Aleksei 2025-10-22 11:41:54 +03:00
parent b39af02f9d
commit 1d36f728a0
5 changed files with 414 additions and 6 deletions

.env.example Normal file (2 lines changed)

@@ -0,0 +1,2 @@
OPENROUTER_API_KEY=sk-00000000000000000000000000000000
OPENAI_API_KEY=sk-00000000000000000000000000000000
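For reference, a minimal sketch of how these keys are picked up at runtime via python-dotenv, mirroring what src/main.py does (copy .env.example to .env and fill in real values; the committed placeholders are not valid keys):

import os
from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from .env into the process environment
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("OPENAI_API_KEY is not set; copy .env.example to .env and fill it in")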

.gitignore vendored (165 lines changed)

@@ -1,9 +1,7 @@
-# ---> Python
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
 *$py.class

 # C extensions
 *.so
@@ -11,7 +9,7 @@ __pycache__/
 .Python
 build/
 develop-eggs/
-dist/
+#dist/
 downloads/
 eggs/
 .eggs/
@@ -57,9 +55,9 @@ cover/
 *.pot

 # Django stuff:
-*.log
 local_settings.py
-db.sqlite3
+*.sqlite
+*.sqlite3
 db.sqlite3-journal

 # Flask stuff:
@@ -158,5 +156,160 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Snowpack dependency directory (https://snowpack.dev/)
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variable files
.env.development.local
.env.test.local
.env.production.local
.env.local
# parcel-bundler cache (https://parceljs.org/)
.parcel-cache
# Next.js build output
.next
out
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# vuepress v2.x temp and cache directory
.temp
# vitepress build output
**/.vitepress/dist
# vitepress cache directory
**/.vitepress/cache
# Docusaurus cache and generated files
.docusaurus
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# yarn v2
.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
*.pyc
# Packages
#/dist/*
# Unit test / coverage reports
.pytest_cache
.DS_Store
.idea/*
.python-version
.vscode/*
/docs/site/*
.mypy_cache
/poetry.toml
*.js.map

requirements.txt Normal file (5 lines changed)

@@ -0,0 +1,5 @@
Pillow
tqdm
openai
python-dotenv
pydantic
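The dependencies are left unpinned; assuming a standard setup, they can be installed into a fresh virtual environment with pip install -r requirements.txt.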

src/etl.py Normal file (77 lines changed)

@@ -0,0 +1,77 @@
import shutil
from pathlib import Path
from PIL import Image
from tqdm import tqdm
def process_image(source_path: Path, dest_dir: Path, max_width: int):
    """
    Process a single image: resize it if its width exceeds max_width,
    otherwise just copy it. Save the result to dest_dir.
    """
    try:
        with Image.open(source_path) as img:
            dest_path = dest_dir / source_path.name
            if img.width > max_width:
                # Compute the new height, preserving the aspect ratio
                ratio = max_width / float(img.width)
                new_height = int(float(img.height) * ratio)
                # Resize and save
                resized_img = img.resize((max_width, new_height), Image.Resampling.LANCZOS)
                resized_img.save(dest_path)
            else:
                # The image already fits, so just copy it
                shutil.copy2(source_path, dest_path)
    except Exception as e:
        tqdm.write(f"Error processing file {source_path.name}: {e}")


def main():
    """
    Build a processed version of the PlantWild dataset.
    """
    # --- SETTINGS ---
    dataset_plantwild: Path = Path(r'E:\Project\OLIMP\PlantWild\plantwild')
    dataset_processed: Path = Path(r'E:\Project\OLIMP\Benchmark_LLM_PlantWild\dataset')
    num_images_per_class: int = 4
    max_width: int = 640

    # --- SCRIPT LOGIC ---
    source_images_dir = dataset_plantwild / 'images'
    if not source_images_dir.is_dir():
        print(f"Error: source directory not found at '{source_images_dir}'")
        return
    dataset_processed.mkdir(parents=True, exist_ok=True)
    print(f"Target directory: '{dataset_processed}'")
    class_dirs = [d for d in source_images_dir.iterdir() if d.is_dir()]
    print("Starting to process classes...")
    for class_dir in tqdm(class_dirs, desc="Processing classes"):
        class_name = class_dir.name
        new_class_dir = dataset_processed / class_name
        new_class_dir.mkdir(exist_ok=True)
        image_paths = list(class_dir.glob('*.[jJ][pP][gG]')) + \
                      list(class_dir.glob('*.[jJ][pP][eE][gG]')) + \
                      list(class_dir.glob('*.[pP][nN][gG]'))
        if not image_paths:
            tqdm.write(f"Warning: no images found for class '{class_name}'.")
            continue
        selected_images = image_paths[:num_images_per_class]
        for image_path in selected_images:
            process_image(source_path=image_path, dest_dir=new_class_dir, max_width=max_width)
    print("\nDataset processing complete.")


if __name__ == "__main__":
    main()
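A quick sanity check for the resize logic above, as a minimal sketch: it assumes process_image is importable from src/etl.py, and the 1280x720 test size is an arbitrary illustration.

import tempfile
from pathlib import Path
from PIL import Image
from etl import process_image  # assumes src/ is on the import path

with tempfile.TemporaryDirectory() as tmp:
    src_dir, dest_dir = Path(tmp) / "in", Path(tmp) / "out"
    src_dir.mkdir()
    dest_dir.mkdir()
    # 1280 px wide, so it exceeds max_width=640 and must be downscaled
    Image.new("RGB", (1280, 720), "green").save(src_dir / "leaf.jpg")
    process_image(source_path=src_dir / "leaf.jpg", dest_dir=dest_dir, max_width=640)
    with Image.open(dest_dir / "leaf.jpg") as out:
        assert out.size == (640, 360)  # 720 * (640 / 1280) = 360, aspect ratio preserved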

src/main.py Normal file (171 lines changed)

@@ -0,0 +1,171 @@
import os
import csv
import json
import base64
from pathlib import Path
from abc import ABC, abstractmethod
from typing import Tuple
from openai import OpenAI, APIError
from pydantic import BaseModel, Field, ValidationError
from dotenv import load_dotenv
from tqdm import tqdm
# --- Pydantic model for the structured response ---
class PlantDiseasePrediction(BaseModel):
    """Defines the structure of the LLM response."""
    predicted_class: str = Field(..., description="The predicted disease class for the plant image.")


# --- PROVIDER ABSTRACTION ---
class ImageClassifierProvider(ABC):
    @abstractmethod
    def get_prediction(self, image_path: Path, prompt: str, classes: list[str]) -> Tuple[str, str]:
        pass


class BaseProvider(ImageClassifierProvider):
    """Shared base class for providers that use JSON mode."""
    def __init__(self, model_name: str, api_key: str, base_url: str):
        self.model_name = model_name
        self.client = OpenAI(base_url=base_url, api_key=api_key)

    def _encode_image_to_base64(self, image_path: Path) -> str:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def get_prediction(self, image_path: Path, prompt: str, classes: list[str]) -> Tuple[str, str]:
        base64_image = self._encode_image_to_base64(image_path)
        raw_response_str = ""
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                response_format={"type": "json_object"},
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": [
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                    ]}
                ]
            )
            raw_response_str = response.model_dump_json(indent=2)
            message_content = response.choices[0].message.content
            prediction_data = json.loads(message_content)
            prediction_obj = PlantDiseasePrediction(**prediction_data)
            if prediction_obj.predicted_class in classes:
                return prediction_obj.predicted_class, raw_response_str
            else:
                # The model produced a class that is not in the allowed list
                return "INVALID_CLASS_ERROR", raw_response_str
        except (json.JSONDecodeError, ValidationError) as e:
            error_msg = f"JSON parsing or validation error: {e}"
            tqdm.write(f"{error_msg} for {image_path.name}")
            return "PARSING_ERROR", raw_response_str or error_msg
        except APIError as e:
            error_msg = f"API error: {e}"
            tqdm.write(f"{error_msg} for {image_path.name}")
            return "API_ERROR", raw_response_str or str(e)
        except Exception as e:
            error_msg = f"Unknown error: {e}"
            tqdm.write(f"{error_msg} for {image_path.name}")
            return "UNKNOWN_ERROR", raw_response_str or str(e)


class OpenAIProvider(BaseProvider):
    def __init__(self, model_name: str, api_key: str):
        super().__init__(model_name, api_key, base_url="https://api.proxyapi.ru/openai/v1")


class OpenRouterProvider(BaseProvider):
    def __init__(self, model_name: str, api_key: str):
        super().__init__(model_name, api_key, base_url="https://openai.api.proxyapi.ru/v1")


def get_provider(provider_name: str, model_name: str, api_key: str) -> ImageClassifierProvider:
    if provider_name.lower() == 'openrouter':
        return OpenRouterProvider(model_name=model_name, api_key=api_key)
    elif provider_name.lower() == 'openai':
        return OpenAIProvider(model_name=model_name, api_key=api_key)
    else:
        raise ValueError(f"Provider '{provider_name}' is not supported.")


def main():
    load_dotenv()
    PROVIDER = "openai"
    MODEL_NAME = "gpt-5-chat-latest"
    processed_dataset_path = Path(r'E:\Project\OLIMP\Benchmark_LLM_PlantWild\dataset')
    api_key_env_var = "OPENAI_API_KEY" if PROVIDER.lower() == 'openai' else "OPENROUTER_API_KEY"
    api_key = os.getenv(api_key_env_var)
    if not api_key:
        print(f"Error: API key {api_key_env_var} was not found in the .env file.")
        return
    provider = get_provider(PROVIDER, MODEL_NAME, api_key)
    classes_file = processed_dataset_path / 'classes.txt'
    if not classes_file.exists():
        print(f"Error: classes file not found: {classes_file}")
        return
    with open(classes_file, 'r', encoding='utf-8') as f:
        class_names = [line.strip().split(' ', 1)[1] for line in f if line.strip()]
    # The prompt is adapted for JSON mode
    schema_json = json.dumps(PlantDiseasePrediction.model_json_schema(), indent=2)
    prompt = (
        f"You are an expert system for plant disease recognition. "
        f"Analyze the user's image and respond ONLY with a valid JSON object that adheres to the following JSON Schema:\n"
        f"```json\n{schema_json}\n```\n"
        f"The 'predicted_class' must be one of the following valid classes: {', '.join(class_names)}. "
        f"Do not include any other text, explanations, or markdown formatting."
    )
    safe_model_name = MODEL_NAME.replace('/', '_')
    results_filename = f"results_{PROVIDER}_{safe_model_name}.csv"
    error_log_filename = f"error_log_{PROVIDER}_{safe_model_name}.txt"
    if os.path.exists(error_log_filename):
        os.remove(error_log_filename)
    with open(results_filename, 'w', newline='', encoding='utf-8') as csvfile:
        csv_writer = csv.writer(csvfile, delimiter=';')
        csv_writer.writerow(["image_name", "correct_class", "predicted_class"])
        image_files = sorted(processed_dataset_path.rglob('*.[jJ][pP]*[gG]'))
        print(f"Found {len(image_files)} images. Starting the benchmark...")
        error_count = 0
        # ----------------------------------------
        print("REMOVE! FOR DEBUGGING ONLY!")
        image_files = image_files[0:4]
        # ----------------------------------------
        for image_path in tqdm(image_files, desc="Benchmarking"):
            correct_class = image_path.parent.name
            predicted_class, raw_response = provider.get_prediction(image_path, prompt, class_names)
            csv_writer.writerow([image_path.name, correct_class, predicted_class])
            error_keywords = ["ERROR", "REFUSAL"]
            if any(keyword in predicted_class for keyword in error_keywords):
                error_count += 1
                with open(error_log_filename, 'a', encoding='utf-8') as log_file:
                    log_file.write(f"--- ERROR FOR IMAGE: {image_path.name} ---\n")
                    log_file.write(f"PREDICTED_STATUS: {predicted_class}\nRAW RESPONSE:\n{raw_response}\n\n")
    print("\nBenchmark complete.")
    print(f"✅ Results saved to: {results_filename}")
    if error_count > 0:
        print(f"⚠️ {error_count} errors logged. Details in: {error_log_filename}")


if __name__ == "__main__":
    main()
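Once a run finishes, overall accuracy can be computed from the results CSV written above. A minimal sketch: the accuracy helper below is not part of this commit, and error statuses such as API_ERROR simply count as misses. The filename matches what main() produces for the settings shown (PROVIDER="openai", MODEL_NAME="gpt-5-chat-latest").

import csv

def accuracy(results_path: str) -> float:
    """Fraction of rows whose predicted_class matches correct_class."""
    with open(results_path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f, delimiter=";"))
    hits = sum(row["predicted_class"] == row["correct_class"] for row in rows)
    return hits / len(rows) if rows else 0.0

print(f"Accuracy: {accuracy('results_openai_gpt-5-chat-latest.csv'):.2%}")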