from typing import Literal

import numpy as np
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
from sentence_transformers import SentenceTransformer

load_dotenv()

# Local HuggingFace model
hf_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# OpenAI client
client = OpenAI()

# Tokenizer matching the OpenAI embeddings model
tokenizer = tiktoken.encoding_for_model("text-embedding-3-large")

EMBED_DIM = 3072  # output dimension of text-embedding-3-large
MAX_INPUT_TOKENS = 8192  # per-input token limit for the embeddings endpoint

# -------------------------------
# Helpers
# -------------------------------
def _get_hf_embedding(texts: list[str]) -> list[list[float]]:
    """Get embeddings using the local HuggingFace SentenceTransformer."""
    return hf_model.encode(texts).tolist()


def chunk_text(text: str, max_tokens: int = 1000) -> list[str]:
    """Split text into chunks of at most max_tokens tokens."""
    tokens = tokenizer.encode(text)
    return [
        tokenizer.decode(tokens[i : i + max_tokens])
        for i in range(0, len(tokens), max_tokens)
    ]


def _get_openai_embedding(texts: list[str]) -> list[list[float]]:
    """Get embeddings for a list of texts. If a text is too long, chunk + average."""
    final_embeddings = []
    for text in texts:
        if not isinstance(text, str) or not text.strip():
            # Fallback: zero vector for empty or non-string input
            final_embeddings.append([0.0] * EMBED_DIM)
            continue

        # Split into chunks if the text exceeds the model's input limit
        if len(tokenizer.encode(text)) > MAX_INPUT_TOKENS:
            chunks = chunk_text(text)
        else:
            chunks = [text]

        # Drop empty or whitespace-only chunks
        clean_chunks = [c.strip() for c in chunks if isinstance(c, str) and c.strip()]
        if not clean_chunks:
            final_embeddings.append([0.0] * EMBED_DIM)
            continue

        try:
            response = client.embeddings.create(
                model="text-embedding-3-large",
                input=clean_chunks,
            )
            chunk_embeddings = [np.array(d.embedding) for d in response.data]
            # Collapse per-chunk embeddings into one vector per text
            avg_embedding = np.mean(chunk_embeddings, axis=0)
            final_embeddings.append(avg_embedding.tolist())
        except Exception as e:
            print(f"Embedding failed for text[:100]={text[:100]!r}, error={e}")
            final_embeddings.append([0.0] * EMBED_DIM)  # fallback

    return final_embeddings


embedding_cache: dict[tuple, list[list[float]]] = {}


def get_embedding(texts: list[str], backend: Literal["hf", "openai"] = "hf") -> list[list[float]]:
    key = (backend, tuple(texts))  # tuples are hashable, lists are not
    if key in embedding_cache:
        return embedding_cache[key]
    if backend == "hf":
        embedding_cache[key] = _get_hf_embedding(texts)
    else:
        embedding_cache[key] = _get_openai_embedding(texts)
    return embedding_cache[key]


# -------------------------------
# Example
# -------------------------------
if __name__ == "__main__":
    texts = [
        "short text example",
        "very long text " * 2000,  # will get chunked
    ]
    embs = get_embedding(texts, backend="openai")
    print(len(embs), "embeddings returned")
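

# -------------------------------
# Comparing embeddings (sketch)
# -------------------------------
# A minimal sketch of how the returned vectors might be compared, assuming
# cosine similarity is the intended metric (the code above does not say).
# Note that averaging chunk embeddings, as _get_openai_embedding does for
# long texts, generally yields a vector that is *not* unit-length even when
# each chunk embedding is, so normalizing inside the similarity matters.
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two embedding vectors, guarding zero vectors."""
    va, vb = np.asarray(a), np.asarray(b)
    denom = np.linalg.norm(va) * np.linalg.norm(vb)
    # Zero-vector fallbacks from the embedding helpers get similarity 0.0
    return float(va @ vb / denom) if denom else 0.0

# Hypothetical usage:
#   q, d = get_embedding(["query", "document"], backend="hf")
#   print(cosine_similarity(q, d))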