"""FastAPI service exposing the ai-forever/FRIDA sentence-embedding model.

Routes:
    GET  /health   - liveness probe.
    GET  /metadata - model name, embedding dimension, supported prompts.
    POST /embed    - embed a list of texts with a FRIDA prompt_name.
"""

from typing import List, Optional

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

app = FastAPI(title="FRIDA Embedding API", version="1.0")

# NOTE(review): per the CORS spec, wildcard origins cannot be combined with
# credentials; Starlette drops the "*" when allow_credentials=True. If
# cookie/auth-header CORS requests are actually needed, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_NAME = "ai-forever/FRIDA"

# Loaded once at import time; downloads the weights on first run.
model = SentenceTransformer(MODEL_NAME)
EMBED_DIM = model.get_sentence_embedding_dimension()

# Prompt names baked into the FRIDA model card. The ordered list is what
# /metadata reports; the frozen set gives O(1) validation in /embed.
SUPPORTED_PROMPTS = [
    "search_query",
    "search_document",
    "paraphrase",
    "categorize",
    "categorize_sentiment",
    "categorize_topic",
    "categorize_entailment",
]
_SUPPORTED_PROMPT_SET = frozenset(SUPPORTED_PROMPTS)


class EmbedRequest(BaseModel):
    # texts: list of input strings to embed (must be non-empty).
    texts: List[str] = Field(..., description="Список текстов")
    # prompt_name: one of SUPPORTED_PROMPTS; None/"" falls back to the default.
    prompt_name: Optional[str] = Field("search_document", description="FRIDA prompt_name")


class EmbedResponse(BaseModel):
    # embeddings: one L2-normalized float vector per input text.
    embeddings: List[List[float]]
    # dim: embedding dimensionality (equals EMBED_DIM).
    dim: int


@app.get("/health")
def health():
    """Liveness probe; returns ok once the model has loaded."""
    return {"status": "ok"}


@app.get("/metadata")
def metadata():
    """Describe the served model: name, dimension, pooling, prompts."""
    return {
        "model": MODEL_NAME,
        "embedding_dim": EMBED_DIM,
        "pooling": "cls",
        "prompts_supported": SUPPORTED_PROMPTS,
    }


@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    """Embed ``req.texts`` with the requested FRIDA prompt.

    Raises:
        HTTPException 400: empty ``texts`` or unknown ``prompt_name``.
    """
    if not req.texts:
        raise HTTPException(status_code=400, detail="texts must be non-empty")

    # `or` also covers an explicit empty-string prompt_name, not just None.
    prompt = req.prompt_name or "search_document"
    if prompt not in _SUPPORTED_PROMPT_SET:
        raise HTTPException(status_code=400, detail=f"Unsupported prompt_name: {prompt}")

    vectors = model.encode(
        req.texts,
        convert_to_numpy=True,
        prompt_name=prompt,
        normalize_embeddings=True,  # unit-norm vectors: dot product == cosine
        batch_size=min(16, max(1, len(req.texts))),
        show_progress_bar=False,
    ).astype(np.float32)

    return {"embeddings": vectors.tolist(), "dim": int(vectors.shape[1])}