Acrosoc commited on
Commit
769df9b
·
verified ·
1 Parent(s): 2d2f572

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -35
app.py CHANGED
@@ -1,55 +1,115 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import pipeline
4
  import torch
 
 
 
5
 
6
- # 1. Инициализация FastAPI приложения
 
 
 
 
 
 
 
7
  app = FastAPI(
8
- title="Text Analysis API",
9
- description="API для анализа текста с использованием моделей из Hugging Face",
10
  version="1.0.0"
11
  )
12
 
13
- # 2. Определение модели для Pydantic (для валидации входных данных)
14
- class TextInput(BaseModel):
15
- text: str
 
 
16
 
17
- # 3. Загрузка модели
18
- # Модель загружается один раз при старте приложения, а не при каждом запросе.
19
- # Это ключевой момент для производительности!
20
- # device=0 использует GPU, если доступен, device=-1 - CPU
21
- # Для Spaces с бесплатным CPU используем device=-1
22
  try:
23
- classifier = pipeline(
24
- "sentiment-analysis",
25
- model="distilbert-base-uncased-finetuned-sst-2-english",
26
- device=-1 # Указываем использование CPU
 
 
 
 
27
  )
28
- print("Модель успешно загружена.")
 
 
29
  except Exception as e:
30
- print(f"Ошибка при загрузке модели: {e}")
31
- classifier = None
 
 
 
 
 
 
 
 
 
 
32
 
33
- # 4. Создание эндпоинта (конечной точки) API
34
  @app.get("/")
35
  def read_root():
36
  """Корневой эндпоинт для проверки работоспособности."""
37
- return {"status": "API is running"}
38
 
39
- @app.post("/analyze")
40
- def analyze_text(request: TextInput):
41
  """
42
- Эндпоинт для анализа тональности текста.
43
- Принимает JSON с полем 'text' и возвращает результат анализа.
44
  """
45
- if not classifier:
46
- return {"error": "Модель не была загружена. Проверьте логи Space."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- text_to_analyze = request.text
49
- result = classifier(text_to_analyze)
50
- return {"input_text": text_to_analyze, "sentiment": result}
51
 
52
- # Пример для запуска локально (не используется в Docker, но полезно для отладки)
53
- # if __name__ == "__main__":
54
- # import uvicorn
55
- # uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
 
3
  import torch
4
+ import transformers
5
+ import charactertokenizer # Импортируем новый токенизатор
6
+ import os
7
 
8
+ # --- 1. Настройка приложения и модели ---
9
+
10
+ # Определяем устройство. Для бесплатных HF Spaces это всегда 'cpu'.
11
+ # Использование os.environ.get для гибкости, если вы переключитесь на GPU.
12
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
13
+ MODEL_NAME = 'ai-forever/charllama-2.6B'
14
+
15
+ # Инициализация FastAPI приложения
16
  app = FastAPI(
17
+ title="CharLLaMA 2.6B API",
18
+ description="API для генерации текста с использованием модели ai-forever/charllama-2.6B",
19
  version="1.0.0"
20
  )
21
 
22
+ # --- 2. Загрузка модели и токенизатора ---
23
+
24
+ # Глобальные переменные для модели и токенизатора
25
+ model = None
26
+ tokenizer = None
27
 
28
+ # Обернем загрузку в try-except для отлова ошибок при старте
 
 
 
 
29
  try:
30
+ print(f"Загрузка токенизатора {MODEL_NAME}...")
31
+ tokenizer = charactertokenizer.CharacterTokenizer.from_pretrained(MODEL_NAME)
32
+
33
+ print(f"Загрузка модели {MODEL_NAME} на устройство {DEVICE}...")
34
+ # Для CPU-инстанций используем torch.float32. Если бы была GPU, можно было бы использовать float16
35
+ model = transformers.AutoModelForCausalLM.from_pretrained(
36
+ MODEL_NAME,
37
+ torch_dtype=torch.float32
38
  )
39
+ model.to(DEVICE)
40
+ print("Модель и токенизатор успешно загружены.")
41
+
42
  except Exception as e:
43
+ print(f"Критическая ошибка при загрузке модели: {e}")
44
+ # Если модель не загрузилась, приложение будет возвращать ошибку.
45
+
46
+ # --- 3. Определение моделей данных (Pydantic) ---
47
+
48
+ class GenerationInput(BaseModel):
49
+ prompt: str
50
+ max_length: int = 512 # Даем пользователю возможность управлять параметрами
51
+ temperature: float = 0.8
52
+ top_p: float = 0.6
53
+
54
+ # --- 4. Создание эндпоинтов API ---
55
 
 
56
  @app.get("/")
57
  def read_root():
58
  """Корневой эндпоинт для проверки работоспособности."""
59
+ return {"status": "API is running", "model_loaded": model is not None}
60
 
61
+ @app.post("/generate")
62
+ def generate_text(request: GenerationInput):
63
  """
64
+ Эндпоинт для генерации текста.
65
+ Принимает JSON с полем 'prompt' и опциональными параметрами генерации.
66
  """
67
+ if not model or not tokenizer:
68
+ raise HTTPException(
69
+ status_code=503,
70
+ detail="Модель не была загружена. Проверьте логи Space."
71
+ )
72
+
73
+ prompt = request.prompt
74
+
75
+ # Параметры генерации из запроса и примера
76
+ generation_args = {
77
+ 'max_length': request.max_length,
78
+ 'num_return_sequences': 1,
79
+ 'do_sample': True,
80
+ 'no_repeat_ngram_size': 10,
81
+ 'temperature': request.temperature,
82
+ 'top_p': request.top_p,
83
+ 'top_k': 0,
84
+ }
85
+
86
+ try:
87
+ # 1. Токенизация входного текста
88
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(DEVICE)
89
+ prompt_len = input_ids.shape[1]
90
+
91
+ # 2. Генерация
92
+ print("Начинаю генерацию...")
93
+ output_ids = model.generate(
94
+ input_ids=input_ids,
95
+ eos_token_id=tokenizer.eos_token_id,
96
+ **generation_args
97
+ )
98
+ print("Генерация завершена.")
99
+
100
+ # 3. Декодирование и постобработка
101
+ # Декодируем только сгенерированную часть, исключая исходный промпт
102
+ generated_part = output_ids[0][prompt_len:]
103
+ output_text = tokenizer.decode(generated_part, skip_special_tokens=True)
104
 
105
+ # Убираем все, что идет после токена конца последовательности, если он есть
106
+ if '</s>' in output_text:
107
+ output_text = output_text.split('</s>')[0].strip()
108
 
109
+ return {
110
+ "input_prompt": prompt,
111
+ "generated_text": output_text
112
+ }
113
+ except Exception as e:
114
+ print(f"Ошибка во время генерации: {e}")
115
+ raise HTTPException(status_code=500, detail=str(e))