| import gradio as gr |
| import torch |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
| import time |
| from typing import Dict, Tuple |
| import warnings |
| warnings.filterwarnings('ignore') |
|
|
| |
| MAX_LENGTH = 1000 |
| MODELS = { |
| "cointegrated/rubert-tiny2": "Лёгкая модель (быстрая)", |
| "s-nlp/rubert-tiny-cased-rured": "Специализированная для классификации", |
| "ai-forever/ruBert-base": "Точная модель (медленнее)" |
| } |
| LABELS = { |
| 0: "Политика", |
| 1: "Экономика", |
| 2: "Наука и технологии", |
| 3: "Культура и искусство", |
| 4: "Спорт", |
| 5: "Здоровье и медицина", |
| 6: "Образование", |
| 7: "Разное" |
| } |
|
|
| class TopicClassifier: |
| def __init__(self): |
| self.models: Dict = {} |
| self.tokenizers: Dict = {} |
| |
| def load_model(self, model_name: str): |
| """Загрузка модели по требованию""" |
| if model_name not in self.models: |
| try: |
| tokenizer = AutoTokenizer.from_pretrained(model_name) |
| model = AutoModelForSequenceClassification.from_pretrained( |
| model_name, |
| num_labels=len(LABELS) |
| ) |
| |
| |
| model.eval() |
| |
| self.models[model_name] = model |
| self.tokenizers[model_name] = tokenizer |
| |
| print(f"Модель {model_name} загружена успешно") |
| except Exception as e: |
| raise Exception(f"Ошибка загрузки модели: {str(e)}") |
| |
| def predict(self, text: str, model_name: str) -> Tuple[Dict, float]: |
| """Предсказание темы текста""" |
| if not text.strip(): |
| raise ValueError("Текст не может быть пустым") |
| |
| if len(text) > MAX_LENGTH: |
| text = text[:MAX_LENGTH] |
| gr.Warning(f"Текст обрезан до {MAX_LENGTH} символов") |
| |
| self.load_model(model_name) |
| |
| start_time = time.time() |
| |
| try: |
| tokenizer = self.tokenizers[model_name] |
| model = self.models[model_name] |
| |
| inputs = tokenizer( |
| text, |
| return_tensors="pt", |
| truncation=True, |
| max_length=512, |
| padding=True |
| ) |
| |
| with torch.no_grad(): |
| outputs = model(**inputs) |
| predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) |
| |
| scores = predictions[0].tolist() |
| results = {LABELS[i]: round(score * 100, 2) for i, score in enumerate(scores)} |
| |
| |
| sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True)) |
| |
| latency = round((time.time() - start_time) * 1000, 2) |
| |
| return sorted_results, latency |
| |
| except Exception as e: |
| raise Exception(f"Ошибка при обработке: {str(e)}") |
|
|
| |
| classifier = TopicClassifier() |
|
|
| def process_text(text: str, model_choice: str) -> Tuple[str, str, str]: |
| """Обработка текста с выбранной моделью""" |
| if not text.strip(): |
| return "⚠️ Введите текст для анализа", "", "0" |
| |
| try: |
| predictions, latency = classifier.predict(text, model_choice) |
| |
| |
| top_topic = list(predictions.keys())[0] |
| top_score = predictions[top_topic] |
| |
| result_text = f"🎯 **Основная тема:** {top_topic} ({top_score}%)\n\n" |
| result_text += "📊 **Распределение тем:**\n" |
| |
| for topic, score in predictions.items(): |
| result_text += f"• {topic}: {score}%\n" |
| |
| |
| json_output = "{\n" |
| for topic, score in predictions.items(): |
| json_output += f' "{topic}": {score},\n' |
| json_output = json_output.rstrip(",\n") + "\n}" |
| |
| return result_text, json_output, str(latency) |
| |
| except ValueError as e: |
| return f"❌ {str(e)}", "", "0" |
| except Exception as e: |
| return f"⚠️ Ошибка: {str(e)}", "", "0" |
|
|
| |
| examples = [ |
| ["Российская экономика показала рост в третьем квартале благодаря увеличению экспорта нефти и газа."], |
| ["Ученые создали новый материал для солнечных батарей с эффективностью 45%."], |
| ["На чемпионате мира по футболу сборная Бразилии одержала победу со счетом 3:1."], |
| ["В музее открылась выставка современных художников, посвященная проблемам экологии."], |
| ["Минздрав рекомендовал новые правила вакцинации для населения старше 60 лет."] |
| ] |
|
|
| |
| with gr.Blocks(title="Классификатор тем текста", theme=gr.themes.Soft()) as demo: |
| gr.Markdown("# 🎯 Классификатор тематики текста") |
| gr.Markdown("Определите основную тему вашего текста с помощью ИИ-моделей") |
| |
| with gr.Row(): |
| with gr.Column(scale=2): |
| model_selector = gr.Dropdown( |
| choices=list(MODELS.keys()), |
| value=list(MODELS.keys())[0], |
| label="📋 Выберите модель", |
| info="Каждая модель имеет разный баланс скорости и точности" |
| ) |
| |
| text_input = gr.Textbox( |
| label="📝 Введите текст для анализа", |
| placeholder="Введите текст на русском языке...", |
| lines=5, |
| max_lines=10 |
| ) |
| |
| process_btn = gr.Button("🔍 Анализировать текст", variant="primary") |
| |
| gr.Markdown("### 📋 Примеры текстов") |
| gr.Examples( |
| examples=examples, |
| inputs=text_input, |
| label="Нажмите на пример для быстрой загрузки" |
| ) |
| |
| with gr.Column(scale=3): |
| with gr.Row(): |
| latency_display = gr.Textbox( |
| label="⏱️ Время обработки", |
| value="0", |
| interactive=False |
| ) |
| latency_display.info = "мсек" |
| |
| output_text = gr.Markdown( |
| label="📊 Результаты классификации" |
| ) |
| |
| json_output = gr.Code( |
| label="📄 JSON-формат результатов", |
| language="json", |
| interactive=False |
| ) |
| |
| |
| process_btn.click( |
| fn=process_text, |
| inputs=[text_input, model_selector], |
| outputs=[output_text, json_output, latency_display] |
| ) |
| |
| |
| with gr.Accordion("ℹ️ Информация о моделях", open=False): |
| gr.Markdown(""" |
| **Доступные модели:** |
| |
| 1. **cointegrated/rubert-tiny2** - Быстрая и легкая модель, идеально подходит для CPU |
| 2. **s-nlp/rubert-tiny-cased-rured** - Специализирована для тематической классификации |
| 3. **ai-forever/ruBert-base** - Самая точная, но требует больше времени |
| |
| **Ограничения:** |
| - Максимальная длина текста: 1000 символов |
| - Только русский язык |
| - Автоматическое определение 8 основных тем |
| """) |
| |
| gr.Markdown("---") |
| gr.Markdown("### 📌 Инструкция") |
| gr.Markdown(""" |
| 1. Выберите модель из списка |
| 2. Введите или вставьте текст для анализа |
| 3. Нажмите кнопку "Анализировать текст" |
| 4. Получите результаты классификации и время обработки |
| """) |
|
|
| if __name__ == "__main__": |
| demo.launch(server_name="0.0.0.0", server_port=7860) |