From 1d36f728a02de67543c2e9e0c5abcccd9f2387d3 Mon Sep 17 00:00:00 2001
From: Aleksei <867865@gmail.com>
Date: Wed, 22 Oct 2025 11:41:54 +0300
Subject: [PATCH] git init

---
 .env.example     |   2 +
 .gitignore       | 165 +++++++++++++++++++++++++++++++++++++++++++--
 requirements.txt |   5 ++
 src/etl.py       |  77 +++++++++++++++++++++
 src/main.py      | 171 +++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 414 insertions(+), 6 deletions(-)
 create mode 100644 .env.example
 create mode 100644 requirements.txt
 create mode 100644 src/etl.py
 create mode 100644 src/main.py

diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..447d327
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,2 @@
+OPENROUTER_API_KEY=sk-00000000000000000000000000000000
+OPENAI_API_KEY=sk-00000000000000000000000000000000
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 5d381cc..0ebdfd8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,7 @@
-# ---> Python
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
 *$py.class
-
 # C extensions
 *.so
 
@@ -11,7 +9,7 @@ __pycache__/
 .Python
 build/
 develop-eggs/
-dist/
+#dist/
 downloads/
 eggs/
 .eggs/
@@ -57,9 +55,9 @@ cover/
 *.pot
 
 # Django stuff:
-*.log
 local_settings.py
-db.sqlite3
+*.sqlite
+*.sqlite3
 db.sqlite3-journal
 
 # Flask stuff:
@@ -158,5 +156,160 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+lerna-debug.log*
+.pnpm-debug.log*
+
+# Diagnostic reports (https://nodejs.org/api/report.html)
+report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+# Runtime data
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+lib-cov
+
+# Coverage directory used by tools like istanbul
+coverage
+*.lcov
+
+# nyc test coverage
+.nyc_output
+
+# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+bower_components
+
+# node-waf configuration
+.lock-wscript
+
+# Compiled binary addons (https://nodejs.org/api/addons.html)
+build/Release
+
+# Dependency directories
+node_modules/
+jspm_packages/
+
+# Snowpack dependency directory (https://snowpack.dev/)
+web_modules/
+
+# TypeScript cache
+*.tsbuildinfo
+
+# Optional npm cache directory
+.npm
+
+# Optional eslint cache
+.eslintcache
+
+# Optional stylelint cache
+.stylelintcache
+
+# Microbundle cache
+.rpt2_cache/
+.rts2_cache_cjs/
+.rts2_cache_es/
+.rts2_cache_umd/
+
+# Optional REPL history
+.node_repl_history
+
+# Output of 'npm pack'
+*.tgz
+
+# Yarn Integrity file
+.yarn-integrity
+
+# dotenv environment variable files
+.env.development.local
+.env.test.local
+.env.production.local
+.env.local
+
+# parcel-bundler cache (https://parceljs.org/)
+.parcel-cache
+
+# Next.js build output
+.next
+out
+
+# Nuxt.js build / generate output
+.nuxt
+dist
+
+# Gatsby files
+.cache/
+# Comment in the public line in if your project uses Gatsby and not Next.js
+# https://nextjs.org/blog/next-9-1#public-directory-support
+# public
+
+# vuepress build output
+.vuepress/dist
+
+# vuepress v2.x temp and cache directory
+.temp
+
+# vitepress build output
+**/.vitepress/dist
+
+# vitepress cache directory
+**/.vitepress/cache
+
+# Docusaurus cache and generated files
+.docusaurus
+
+# Serverless directories
+.serverless/
+
+# FuseBox cache
+.fusebox/
+
+# DynamoDB Local files
+.dynamodb/
+
+# TernJS port file
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+.vscode-test
+
+# yarn v2
+.yarn/cache
+.yarn/unplugged
+.yarn/build-state.yml
+.yarn/install-state.gz
+.pnp.*
+
+
+*.pyc
+
+# Packages
+#/dist/*
+# Unit test / coverage reports
+
+.pytest_cache
+
+.DS_Store
+.idea/*
+.python-version
+.vscode/*
+
+/docs/site/*
+.mypy_cache
+
+
+/poetry.toml
+
+*.js.map
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..700d473
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+Pillow
+tqdm
+openai
+python-dotenv
+pydantic
\ No newline at end of file
diff --git a/src/etl.py b/src/etl.py
new file mode 100644
index 0000000..43f64bb
--- /dev/null
+++ b/src/etl.py
@@ -0,0 +1,77 @@
+import shutil
+from pathlib import Path
+from PIL import Image
+from tqdm import tqdm
+
+
+def process_image(source_path: Path, dest_dir: Path, max_width: int):
+    """
+    Processes a single image: resizes it if its width exceeds max_width,
+    otherwise simply copies it. Saves the result to dest_dir.
+    """
+    try:
+        with Image.open(source_path) as img:
+            dest_path = dest_dir / source_path.name
+
+            if img.width > max_width:
+                # Compute the new height, preserving the aspect ratio
+                ratio = max_width / float(img.width)
+                new_height = int(float(img.height) * ratio)
+
+                # Resize and save
+                resized_img = img.resize((max_width, new_height), Image.Resampling.LANCZOS)
+                resized_img.save(dest_path)
+            else:
+                # The image already fits, so just copy it
+                shutil.copy2(source_path, dest_path)
+
+    except Exception as e:
+        tqdm.write(f"Error while processing {source_path.name}: {e}")
+
+
+def main():
+    """
+    Main function for building a processed version of the PlantWild dataset.
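+
+    Assumes the PlantWild layout <root>/images/<class_name>/ on disk and writes
+    up to num_images_per_class resized copies per class into dataset_processed.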
+ """ + # --- НАСТРОЙКИ --- + dataset_plantwild: Path = Path(r'E:\Project\OLIMP\PlantWild\plantwild') + dataset_processed: Path = Path(r'E:\Project\OLIMP\Benchmark_LLM_PlantWild\dataset') + num_images_per_class: int = 4 + max_width: int = 640 + + # --- ЛОГИКА СКРИПТА --- + source_images_dir = dataset_plantwild / 'images' + + if not source_images_dir.is_dir(): + print(f"Ошибка: Исходная директория не найдена по пути '{source_images_dir}'") + return + + dataset_processed.mkdir(parents=True, exist_ok=True) + print(f"Целевая директория: '{dataset_processed}'") + + class_dirs = [d for d in source_images_dir.iterdir() if d.is_dir()] + + print("Начинаем обработку классов...") + for class_dir in tqdm(class_dirs, desc="Обработка классов"): + class_name = class_dir.name + new_class_dir = dataset_processed / class_name + new_class_dir.mkdir(exist_ok=True) + + image_paths = list(class_dir.glob('*.[jJ][pP][gG]')) + \ + list(class_dir.glob('*.[jJ][pP][eE][gG]')) + \ + list(class_dir.glob('*.[pP][nN][gG]')) + + if not image_paths: + tqdm.write(f"Предупреждение: Изображения для класса '{class_name}' не найдены.") + continue + + selected_images = image_paths[:num_images_per_class] + + for image_path in selected_images: + process_image(source_path=image_path, dest_dir=new_class_dir, max_width=max_width) + + print("\nОбработка датасета завершена.") + + +if __name__ == "__main__": + main() diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..c429bab --- /dev/null +++ b/src/main.py @@ -0,0 +1,171 @@ +import os +import csv +import json +import base64 +from pathlib import Path +from abc import ABC, abstractmethod +from typing import Tuple + +from openai import OpenAI, APIError +from pydantic import BaseModel, Field, ValidationError +from dotenv import load_dotenv +from tqdm import tqdm + + +# --- Pydantic-модель для структурированного ответа --- +class PlantDiseasePrediction(BaseModel): + """Определяет структуру ответа от LLM.""" + predicted_class: str = Field(..., description="The predicted disease class for the plant image.") + + +# --- АБСТРАКЦИЯ ПРОВАЙДЕРОВ --- + +class ImageClassifierProvider(ABC): + @abstractmethod + def get_prediction(self, image_path: Path, prompt: str, classes: list[str]) -> Tuple[str, str]: + pass + + +class BaseProvider(ImageClassifierProvider): + """Общий класс для провайдеров, использующих JSON mode.""" + + def __init__(self, model_name: str, api_key: str, base_url: str): + self.model_name = model_name + self.client = OpenAI(base_url=base_url, api_key=api_key) + + def _encode_image_to_base64(self, image_path: Path) -> str: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + def get_prediction(self, image_path: Path, prompt: str, classes: list[str]) -> Tuple[str, str]: + base64_image = self._encode_image_to_base64(image_path) + raw_response_str = "" + try: + response = self.client.chat.completions.create( + model=self.model_name, + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": prompt}, + {"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} + ]} + ] + ) + raw_response_str = response.model_dump_json(indent=2) + message_content = response.choices[0].message.content + + prediction_data = json.loads(message_content) + prediction_obj = PlantDiseasePrediction(**prediction_data) + + if prediction_obj.predicted_class in classes: + return prediction_obj.predicted_class, raw_response_str + else: + # Модель 
+                return "INVALID_CLASS_ERROR", raw_response_str
+
+        except (json.JSONDecodeError, ValidationError) as e:
+            error_msg = f"JSON parsing or validation error: {e}"
+            tqdm.write(f"{error_msg} for {image_path.name}")
+            return "PARSING_ERROR", raw_response_str or error_msg
+        except APIError as e:
+            error_msg = f"API error: {e}"
+            tqdm.write(f"{error_msg} for {image_path.name}")
+            return "API_ERROR", raw_response_str or str(e)
+        except Exception as e:
+            error_msg = f"Unknown error: {e}"
+            tqdm.write(f"{error_msg} for {image_path.name}")
+            return "UNKNOWN_ERROR", raw_response_str or str(e)
+
+
+class OpenAIProvider(BaseProvider):
+    def __init__(self, model_name: str, api_key: str):
+        super().__init__(model_name, api_key, base_url="https://api.proxyapi.ru/openai/v1")
+
+
+class OpenRouterProvider(BaseProvider):
+    def __init__(self, model_name: str, api_key: str):
+        super().__init__(model_name, api_key, base_url="https://openai.api.proxyapi.ru/v1")
+
+
+def get_provider(provider_name: str, model_name: str, api_key: str) -> ImageClassifierProvider:
+    if provider_name.lower() == 'openrouter':
+        return OpenRouterProvider(model_name=model_name, api_key=api_key)
+    elif provider_name.lower() == 'openai':
+        return OpenAIProvider(model_name=model_name, api_key=api_key)
+    else:
+        raise ValueError(f"Provider '{provider_name}' is not supported.")
+
+
+def main():
+    load_dotenv()
+    PROVIDER = "openai"
+    MODEL_NAME = "gpt-5-chat-latest"
+    processed_dataset_path = Path(r'E:\Project\OLIMP\Benchmark_LLM_PlantWild\dataset')
+
+    api_key_env_var = "OPENAI_API_KEY" if PROVIDER.lower() == 'openai' else "OPENROUTER_API_KEY"
+    api_key = os.getenv(api_key_env_var)
+    if not api_key:
+        print(f"Error: API key {api_key_env_var} not found in the .env file.")
+        return
+
+    provider = get_provider(PROVIDER, MODEL_NAME, api_key)
+
+    classes_file = processed_dataset_path / 'classes.txt'
+    if not classes_file.exists():
+        print(f"Error: classes file not found: {classes_file}")
+        return
+
+    with open(classes_file, 'r', encoding='utf-8') as f:
+        class_names = [line.strip().split(' ', 1)[1] for line in f if line.strip()]
+
+    # The prompt is adapted for JSON mode
+    schema_json = json.dumps(PlantDiseasePrediction.model_json_schema(), indent=2)
+    prompt = (
+        f"You are an expert system for plant disease recognition. "
+        f"Analyze the user's image and respond ONLY with a valid JSON object that adheres to the following JSON Schema:\n"
+        f"```json\n{schema_json}\n```\n"
+        f"The 'predicted_class' must be one of the following valid classes: {', '.join(class_names)}. "
+        f"Do not include any other text, explanations, or markdown formatting."
+    )
+
+    safe_model_name = MODEL_NAME.replace('/', '_')
+    results_filename = f"results_{PROVIDER}_{safe_model_name}.csv"
+    error_log_filename = f"error_log_{PROVIDER}_{safe_model_name}.txt"
+
+    if os.path.exists(error_log_filename):
+        os.remove(error_log_filename)
+
+    with open(results_filename, 'w', newline='', encoding='utf-8') as csvfile:
+        csv_writer = csv.writer(csvfile, delimiter=';')
+        csv_writer.writerow(["image_name", "correct_class", "predicted_class"])
+        image_files = sorted(list(processed_dataset_path.rglob('*.[jJ][pP]*[gG]')))
+
+        print(f"Found {len(image_files)} images. Starting the benchmark...")
+        error_count = 0
+
+        # ----------------------------------------
+        print("REMOVE! FOR DEBUGGING ONLY!")
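+        # Debug limiter: keep only the first four images so a full (billable)
+        # pass over the dataset is not triggered while testing.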
+        image_files = image_files[0:4]
+        # ----------------------------------------
+
+        for image_path in tqdm(image_files, desc="Testing"):
+            correct_class = image_path.parent.name
+            predicted_class, raw_response = provider.get_prediction(image_path, prompt, class_names)
+
+            csv_writer.writerow([image_path.name, correct_class, predicted_class])
+
+            error_keywords = ["ERROR", "REFUSAL"]
+            if any(keyword in predicted_class for keyword in error_keywords):
+                error_count += 1
+                with open(error_log_filename, 'a', encoding='utf-8') as log_file:
+                    log_file.write(f"--- ERROR FOR IMAGE: {image_path.name} ---\n")
+                    log_file.write(f"PREDICTED_STATUS: {predicted_class}\nRAW RESPONSE:\n{raw_response}\n\n")
+
+    print("\nTesting complete.")
+    print(f"✅ Results saved to: {results_filename}")
+    if error_count > 0:
+        print(f"⚠️ {error_count} errors logged. Details in: {error_log_filename}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file