git init

This commit is contained in:
parent b39af02f9d
commit 1d36f728a0
.env.example (new file, 2 lines)
@@ -0,0 +1,2 @@
+OPENROUTER_API_KEY=sk-00000000000000000000000000000000
+OPENAI_API_KEY=sk-00000000000000000000000000000000
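Both values are placeholders. src/main.py reads whichever key matches the configured provider through python-dotenv; a minimal sketch of that lookup, assuming a .env file in the working directory:

    import os
    from dotenv import load_dotenv

    load_dotenv()  # reads .env from the current working directory
    api_key = os.getenv("OPENAI_API_KEY")  # or "OPENROUTER_API_KEY" for the openrouter provider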
.gitignore (vendored, 165 changes)
@@ -1,9 +1,7 @@
-# ---> Python
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
 *$py.class
 
 # C extensions
 *.so
 
@@ -11,7 +9,7 @@ __pycache__/
 .Python
 build/
 develop-eggs/
-dist/
+#dist/
 downloads/
 eggs/
 .eggs/
@@ -57,9 +55,9 @@ cover/
 *.pot
 
 # Django stuff:
-*.log
 local_settings.py
-db.sqlite3
+*.sqlite
+*.sqlite3
 db.sqlite3-journal
 
 # Flask stuff:
@@ -158,5 +156,160 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+lerna-debug.log*
+.pnpm-debug.log*
+
+# Diagnostic reports (https://nodejs.org/api/report.html)
+report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+# Runtime data
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+lib-cov
+
+# Coverage directory used by tools like istanbul
+coverage
+*.lcov
+
+# nyc test coverage
+.nyc_output
+
+# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+bower_components
+
+# node-waf configuration
+.lock-wscript
+
+# Compiled binary addons (https://nodejs.org/api/addons.html)
+build/Release
+
+# Dependency directories
+node_modules/
+jspm_packages/
+
+# Snowpack dependency directory (https://snowpack.dev/)
+web_modules/
+
+# TypeScript cache
+*.tsbuildinfo
+
+# Optional npm cache directory
+.npm
+
+# Optional eslint cache
+.eslintcache
+
+# Optional stylelint cache
+.stylelintcache
+
+# Microbundle cache
+.rpt2_cache/
+.rts2_cache_cjs/
+.rts2_cache_es/
+.rts2_cache_umd/
+
+# Optional REPL history
+.node_repl_history
+
+# Output of 'npm pack'
+*.tgz
+
+# Yarn Integrity file
+.yarn-integrity
+
+# dotenv environment variable files
+.env.development.local
+.env.test.local
+.env.production.local
+.env.local
+
+# parcel-bundler cache (https://parceljs.org/)
+.parcel-cache
+
+# Next.js build output
+.next
+out
+
+# Nuxt.js build / generate output
+.nuxt
+dist
+
+# Gatsby files
+.cache/
+# Comment in the public line in if your project uses Gatsby and not Next.js
+# https://nextjs.org/blog/next-9-1#public-directory-support
+# public
+
+# vuepress build output
+.vuepress/dist
+
+# vuepress v2.x temp and cache directory
+.temp
+
+# vitepress build output
+**/.vitepress/dist
+
+# vitepress cache directory
+**/.vitepress/cache
+
+# Docusaurus cache and generated files
+.docusaurus
+
+# Serverless directories
+.serverless/
+
+# FuseBox cache
+.fusebox/
+
+# DynamoDB Local files
+.dynamodb/
+
+# TernJS port file
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+.vscode-test
+
+# yarn v2
+.yarn/cache
+.yarn/unplugged
+.yarn/build-state.yml
+.yarn/install-state.gz
+.pnp.*
+
+
+*.pyc
+
+# Packages
+#/dist/*
+# Unit test / coverage reports
+
+.pytest_cache
+
+.DS_Store
+.idea/*
+.python-version
+.vscode/*
+
+/docs/site/*
+.mypy_cache
+
+
+/poetry.toml
+
+*.js.map
requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
+Pillow
+tqdm
+openai
+python-dotenv
+pydantic
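No versions are pinned, so pip install -r requirements.txt pulls the latest releases: Pillow serves the image ETL in src/etl.py; openai, python-dotenv, and pydantic serve the benchmark in src/main.py; tqdm provides progress bars for both.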
src/etl.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import shutil
+from pathlib import Path
+from PIL import Image
+from tqdm import tqdm
+
+
+def process_image(source_path: Path, dest_dir: Path, max_width: int):
+    """
+    Processes a single image: resizes it if its width exceeds max_width,
+    otherwise simply copies it. Saves the result to dest_dir.
+    """
+    try:
+        with Image.open(source_path) as img:
+            dest_path = dest_dir / source_path.name
+
+            if img.width > max_width:
+                # Compute the new height, preserving the aspect ratio
+                ratio = max_width / float(img.width)
+                new_height = int(float(img.height) * ratio)
+
+                # Resize and save
+                resized_img = img.resize((max_width, new_height), Image.Resampling.LANCZOS)
+                resized_img.save(dest_path)
+            else:
+                # The image already fits, just copy it
+                shutil.copy2(source_path, dest_path)
+
+    except Exception as e:
+        tqdm.write(f"Error while processing file {source_path.name}: {e}")
+
+
+def main():
+    """
+    Main function for building a processed version of the PlantWild dataset.
+    """
+    # --- SETTINGS ---
+    dataset_plantwild: Path = Path(r'E:\Project\OLIMP\PlantWild\plantwild')
+    dataset_processed: Path = Path(r'E:\Project\OLIMP\Benchmark_LLM_PlantWild\dataset')
+    num_images_per_class: int = 4
+    max_width: int = 640
+
+    # --- SCRIPT LOGIC ---
+    source_images_dir = dataset_plantwild / 'images'
+
+    if not source_images_dir.is_dir():
+        print(f"Error: source directory not found at '{source_images_dir}'")
+        return
+
+    dataset_processed.mkdir(parents=True, exist_ok=True)
+    print(f"Target directory: '{dataset_processed}'")
+
+    class_dirs = [d for d in source_images_dir.iterdir() if d.is_dir()]
+
+    print("Starting to process classes...")
+    for class_dir in tqdm(class_dirs, desc="Processing classes"):
+        class_name = class_dir.name
+        new_class_dir = dataset_processed / class_name
+        new_class_dir.mkdir(exist_ok=True)
+
+        image_paths = list(class_dir.glob('*.[jJ][pP][gG]')) + \
+                      list(class_dir.glob('*.[jJ][pP][eE][gG]')) + \
+                      list(class_dir.glob('*.[pP][nN][gG]'))
+
+        if not image_paths:
+            tqdm.write(f"Warning: no images found for class '{class_name}'.")
+            continue
+
+        selected_images = image_paths[:num_images_per_class]
+
+        for image_path in selected_images:
+            process_image(source_path=image_path, dest_dir=new_class_dir, max_width=max_width)
+
+    print("\nDataset processing complete.")
+
+
+if __name__ == "__main__":
+    main()
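For intuition on the resize arithmetic: a 1920x1080 source with max_width = 640 gives ratio = 640/1920 = 1/3, so new_height = int(1080 * 1/3) = 360 and the image is saved at 640x360. Anything already at or below 640 px wide is copied unchanged with shutil.copy2, which also preserves file metadata such as timestamps.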
src/main.py (new file, 171 lines)
@@ -0,0 +1,171 @@
+import os
+import csv
+import json
+import base64
+from pathlib import Path
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+from openai import OpenAI, APIError
+from pydantic import BaseModel, Field, ValidationError
+from dotenv import load_dotenv
+from tqdm import tqdm
+
+
+# --- Pydantic model for the structured response ---
+class PlantDiseasePrediction(BaseModel):
+    """Defines the structure of the LLM response."""
+    predicted_class: str = Field(..., description="The predicted disease class for the plant image.")
+
+
+# --- PROVIDER ABSTRACTION ---
+
+class ImageClassifierProvider(ABC):
+    @abstractmethod
+    def get_prediction(self, image_path: Path, prompt: str, classes: list[str]) -> Tuple[str, str]:
+        pass
+
+
+class BaseProvider(ImageClassifierProvider):
+    """Shared base class for providers that use JSON mode."""
+
+    def __init__(self, model_name: str, api_key: str, base_url: str):
+        self.model_name = model_name
+        self.client = OpenAI(base_url=base_url, api_key=api_key)
+
+    def _encode_image_to_base64(self, image_path: Path) -> str:
+        with open(image_path, "rb") as image_file:
+            return base64.b64encode(image_file.read()).decode('utf-8')
+
+    def get_prediction(self, image_path: Path, prompt: str, classes: list[str]) -> Tuple[str, str]:
+        base64_image = self._encode_image_to_base64(image_path)
+        raw_response_str = ""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                response_format={"type": "json_object"},
+                messages=[
+                    {"role": "system", "content": prompt},
+                    {"role": "user", "content": [
+                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+                    ]}
+                ]
+            )
+            raw_response_str = response.model_dump_json(indent=2)
+            message_content = response.choices[0].message.content
+
+            prediction_data = json.loads(message_content)
+            prediction_obj = PlantDiseasePrediction(**prediction_data)
+
+            if prediction_obj.predicted_class in classes:
+                return prediction_obj.predicted_class, raw_response_str
+            else:
+                # The model generated a class that is not in the list
+                return "INVALID_CLASS_ERROR", raw_response_str
+
+        except (json.JSONDecodeError, ValidationError) as e:
+            error_msg = f"JSON parsing or validation error: {e}"
+            tqdm.write(f"{error_msg} for {image_path.name}")
+            return "PARSING_ERROR", raw_response_str or error_msg
+        except APIError as e:
+            error_msg = f"API error: {e}"
+            tqdm.write(f"{error_msg} for {image_path.name}")
+            return "API_ERROR", raw_response_str or str(e)
+        except Exception as e:
+            error_msg = f"Unknown error: {e}"
+            tqdm.write(f"{error_msg} for {image_path.name}")
+            return "UNKNOWN_ERROR", raw_response_str or str(e)
+
+
+class OpenAIProvider(BaseProvider):
+    def __init__(self, model_name: str, api_key: str):
+        super().__init__(model_name, api_key, base_url="https://api.proxyapi.ru/openai/v1")
+
+
+class OpenRouterProvider(BaseProvider):
+    def __init__(self, model_name: str, api_key: str):
+        super().__init__(model_name, api_key, base_url="https://openai.api.proxyapi.ru/v1")
+
+
+def get_provider(provider_name: str, model_name: str, api_key: str) -> ImageClassifierProvider:
+    if provider_name.lower() == 'openrouter':
+        return OpenRouterProvider(model_name=model_name, api_key=api_key)
+    elif provider_name.lower() == 'openai':
+        return OpenAIProvider(model_name=model_name, api_key=api_key)
+    else:
+        raise ValueError(f"Provider '{provider_name}' is not supported.")
+
+
+def main():
+    load_dotenv()
+    PROVIDER = "openai"
+    MODEL_NAME = "gpt-5-chat-latest"
+    processed_dataset_path = Path(r'E:\Project\OLIMP\Benchmark_LLM_PlantWild\dataset')
+
+    api_key_env_var = "OPENAI_API_KEY" if PROVIDER.lower() == 'openai' else "OPENROUTER_API_KEY"
+    api_key = os.getenv(api_key_env_var)
+    if not api_key:
+        print(f"Error: API key {api_key_env_var} not found in the .env file.")
+        return
+
+    provider = get_provider(PROVIDER, MODEL_NAME, api_key)
+
+    classes_file = processed_dataset_path / 'classes.txt'
+    if not classes_file.exists():
+        print(f"Error: classes file not found: {classes_file}")
+        return
+
+    with open(classes_file, 'r', encoding='utf-8') as f:
+        class_names = [line.strip().split(' ', 1)[1] for line in f if line.strip()]
+
+    # The prompt is adapted for JSON mode
+    schema_json = json.dumps(PlantDiseasePrediction.model_json_schema(), indent=2)
+    prompt = (
+        f"You are an expert system for plant disease recognition. "
+        f"Analyze the user's image and respond ONLY with a valid JSON object that adheres to the following JSON Schema:\n"
+        f"```json\n{schema_json}\n```\n"
+        f"The 'predicted_class' must be one of the following valid classes: {', '.join(class_names)}. "
+        f"Do not include any other text, explanations, or markdown formatting."
+    )
+
+    safe_model_name = MODEL_NAME.replace('/', '_')
+    results_filename = f"results_{PROVIDER}_{safe_model_name}.csv"
+    error_log_filename = f"error_log_{PROVIDER}_{safe_model_name}.txt"
+
+    if os.path.exists(error_log_filename):
+        os.remove(error_log_filename)
+
+    with open(results_filename, 'w', newline='', encoding='utf-8') as csvfile:
+        csv_writer = csv.writer(csvfile, delimiter=';')
+        csv_writer.writerow(["image_name", "correct_class", "predicted_class"])
+        image_files = sorted(list(processed_dataset_path.rglob('*.[jJ][pP]*[gG]')))
+
+        print(f"Found {len(image_files)} images. Starting the benchmark...")
+        error_count = 0
+
+        # ----------------------------------------
+        print("REMOVE THIS! DEBUG ONLY!")
+        image_files = image_files[0:4]
+        # ----------------------------------------
+
+        for image_path in tqdm(image_files, desc="Benchmarking"):
+            correct_class = image_path.parent.name
+            predicted_class, raw_response = provider.get_prediction(image_path, prompt, class_names)
+
+            csv_writer.writerow([image_path.name, correct_class, predicted_class])
+
+            error_keywords = ["ERROR", "REFUSAL"]
+            if any(keyword in predicted_class for keyword in error_keywords):
+                error_count += 1
+                with open(error_log_filename, 'a', encoding='utf-8') as log_file:
+                    log_file.write(f"--- ERROR FOR IMAGE: {image_path.name} ---\n")
+                    log_file.write(f"PREDICTED_STATUS: {predicted_class}\nRAW RESPONSE:\n{raw_response}\n\n")
+
+    print("\nBenchmark complete.")
+    print(f"✅ Results saved to: {results_filename}")
+    if error_count > 0:
+        print(f"⚠️ Logged {error_count} errors. Details in: {error_log_filename}")
+
+
+if __name__ == "__main__":
+    main()
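The runner writes raw predictions but never computes a score. A minimal scoring sketch against the CSV layout defined above (the filename below is an assumption; substitute whatever results_*.csv a run actually produces):

    import csv

    # Hypothetical filename: substitute the actual results file from your run.
    RESULTS_CSV = "results_openai_gpt-5-chat-latest.csv"

    total = 0
    correct = 0
    with open(RESULTS_CSV, newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f, delimiter=';')  # main.py writes with ';' as the delimiter
        for row in reader:
            total += 1
            if row["predicted_class"] == row["correct_class"]:
                correct += 1

    if total:
        print(f"Accuracy: {correct}/{total} = {correct / total:.2%}")
    else:
        print("No rows found.")

Rows whose predicted_class is one of the error statuses (API_ERROR, PARSING_ERROR, INVALID_CLASS_ERROR, UNKNOWN_ERROR) simply count as incorrect here, which matches how the error log is populated.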