chripto committed on
Commit
8f315c1
·
verified ·
1 Parent(s): 9e6f606

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +68 -3
main.py CHANGED
@@ -1,8 +1,11 @@
1
  import os
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel, Field
4
  from huggingface_hub import InferenceClient
5
  import uvicorn
 
 
6
 
7
 
8
  app = FastAPI()
@@ -49,6 +52,56 @@ def format_prompt_for_text_generation(system_prompt: str, history: list, prompt:
49
  return out
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  class Item(BaseModel):
53
  """Payload per /generate/. Per Ministral-3: temperature < 0.1 in produzione (raccomandato)."""
54
  prompt: str
@@ -102,7 +155,19 @@ def generate(item: Item) -> str:
102
  except Exception as e2:
103
  last_error = e2
104
 
105
- # 3) Ultima risorsa: text_generation (solo per modelli che lo supportano su hf-inference)
 
 
 
 
 
 
 
 
 
 
 
 
106
  try:
107
  formatted = format_prompt_for_text_generation(
108
  item.system_prompt, item.history or [], item.prompt
@@ -123,8 +188,8 @@ def generate(item: Item) -> str:
123
  for r in stream
124
  )
125
  return str(stream)
126
- except Exception as e3:
127
- last_error = e3
128
 
129
  raise HTTPException(status_code=502, detail=f"Inference fallita: {str(last_error)}")
130
 
 
1
  import os
2
+ import json
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel, Field
5
  from huggingface_hub import InferenceClient
6
  import uvicorn
7
+ import urllib.request
8
+ import urllib.error
9
 
10
 
11
  app = FastAPI()
 
52
  return out
53
 
54
 
55
def chat_completion_via_http(messages: list, max_tokens: int, temperature: float, top_p: float) -> str | None:
    """
    Call the HF chat-completions endpoint (v1) directly over HTTP.

    Used when the SDK fails because the model does not declare its task
    (e.g. Ministral-3). Returns the assistant message content on success,
    or None on any failure; raises HTTPException(503) when the response
    indicates the model is still loading (carries an "estimated_time").
    """
    # Without a token the hosted endpoint rejects the request; bail out early.
    if not HF_TOKEN:
        return None

    # Prefer a dedicated inference endpoint when configured; otherwise fall
    # back to the public serverless URL for the model.
    if INFERENCE_ENDPOINT_URL:
        endpoint_root = INFERENCE_ENDPOINT_URL.rstrip("/")
    else:
        endpoint_root = f"https://api-inference.huggingface.co/models/{MODEL_ID}"

    payload = json.dumps(
        {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
    ).encode("utf-8")

    request = urllib.request.Request(
        f"{endpoint_root}/v1/chat/completions",
        data=payload,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        },
        method="POST",
    )

    try:
        with urllib.request.urlopen(request, timeout=120) as response:
            parsed = json.loads(response.read().decode())
    except urllib.error.HTTPError as http_err:
        # A 503 carrying "estimated_time" means the model is cold-loading:
        # surface that to the caller instead of silently returning None.
        if http_err.code == 503:
            raw_body = http_err.read().decode() if http_err.fp else ""
            try:
                err_json = json.loads(raw_body)
                if "estimated_time" in err_json:
                    raise HTTPException(
                        status_code=503,
                        detail=f"Modello in caricamento. Riprova tra {err_json.get('estimated_time', 0):.0f}s.",
                    )
            except (ValueError, TypeError):
                pass
        return None
    except Exception:
        # Best-effort fallback path: any other transport error means "no answer".
        return None

    candidates = parsed.get("choices") or []
    if not candidates:
        return None
    message = candidates[0].get("message") or {}
    return (message.get("content") or "").strip()
103
+
104
+
105
  class Item(BaseModel):
106
  """Payload per /generate/. Per Ministral-3: temperature < 0.1 in produzione (raccomandato)."""
107
  prompt: str
 
155
  except Exception as e2:
156
  last_error = e2
157
 
158
+ # 3) Chat completions via HTTP (endpoint v1) – funziona per modelli che non dichiarano il task (es. Ministral-3)
159
+ try:
160
+ content = chat_completion_via_http(
161
+ messages, item.max_new_tokens, temperature, top_p
162
+ )
163
+ if content is not None and content != "":
164
+ return content
165
+ except HTTPException:
166
+ raise
167
+ except Exception as e3:
168
+ last_error = e3
169
+
170
+ # 4) Ultima risorsa: text_generation (solo per modelli che lo supportano su hf-inference)
171
  try:
172
  formatted = format_prompt_for_text_generation(
173
  item.system_prompt, item.history or [], item.prompt
 
188
  for r in stream
189
  )
190
  return str(stream)
191
+ except Exception as e4:
192
+ last_error = e4
193
 
194
  raise HTTPException(status_code=502, detail=f"Inference fallita: {str(last_error)}")
195