"""Text-embedding helpers with two interchangeable backends.

- ``"hf"``: local SentenceTransformer model (multilingual MiniLM), no API calls.
- ``"openai"``: OpenAI ``text-embedding-3-large``; texts longer than the
  model's token limit are chunked and their chunk embeddings averaged.
"""

from typing import Literal

import numpy as np
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
from sentence_transformers import SentenceTransformer

# Load OPENAI_API_KEY (and any other secrets) from .env before the client is built.
load_dotenv()

# Local HuggingFace model (downloaded/cached on first use).
hf_model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)

# OpenAI client; reads OPENAI_API_KEY from the environment populated above.
client = OpenAI()

# Single source of truth for the OpenAI embeddings model name.
OPENAI_EMBEDDING_MODEL = "text-embedding-3-large"

# text-embedding-3-large accepts at most 8191 input tokens per text.
# NOTE: the original check used `> 8192`, which let a text of exactly
# 8192 tokens reach the API un-chunked and fail.
MAX_EMBEDDING_TOKENS = 8191

# Tokenizer matching the OpenAI embeddings model, used for length checks
# and for splitting long texts into token-bounded chunks.
tokenizer = tiktoken.encoding_for_model(OPENAI_EMBEDDING_MODEL)

# -------------------------------
# Helpers
# -------------------------------


def _get_hf_embedding(texts: list[str]) -> list[list[float]]:
    """Embed *texts* with the local SentenceTransformer model.

    Returns one embedding (list of floats) per input text, in order.
    """
    return hf_model.encode(texts).tolist()


def chunk_text(text: str, max_tokens: int = 1000) -> list[str]:
    """Split *text* into consecutive pieces of at most *max_tokens* tokens.

    Tokenization uses the OpenAI embeddings tokenizer. Returns an empty
    list for an empty input. Chunk boundaries fall on token boundaries,
    which may not coincide with word boundaries in the original string.
    """
    tokens = tokenizer.encode(text)
    return [
        tokenizer.decode(tokens[start : start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]


def _get_openai_embedding(texts: list[str]) -> list[list[float]]:
    """Embed *texts* with the OpenAI API, chunking over-long inputs.

    Texts within the model's token limit are sent as a single input; longer
    texts are split with :func:`chunk_text` and the chunk embeddings are
    averaged into one vector.

    NOTE(review): the average is unweighted and not re-normalized, so a
    chunked text's vector is generally not unit-norm even though the raw
    API embeddings are — confirm downstream consumers don't assume unit
    norm before changing this.
    """
    final_embeddings: list[list[float]] = []
    for text in texts:
        # Chunk only when the text exceeds the model's hard input limit.
        if len(tokenizer.encode(text)) > MAX_EMBEDDING_TOKENS:
            chunks = chunk_text(text)
        else:
            chunks = [text]

        # One API call covers all chunks of this text.
        response = client.embeddings.create(
            model=OPENAI_EMBEDDING_MODEL,
            input=chunks,
        )
        chunk_embeddings = [np.array(d.embedding) for d in response.data]

        # Collapse multiple chunk vectors into one (no-op for a single chunk).
        avg_embedding = np.mean(chunk_embeddings, axis=0)
        final_embeddings.append(avg_embedding.tolist())

    return final_embeddings


def get_embedding(
    texts: list[str], backend: Literal["hf", "openai"] = "hf"
) -> list[list[float]]:
    """Return one embedding per text using the chosen backend.

    Parameters
    ----------
    texts:
        The strings to embed.
    backend:
        ``"hf"`` for the local SentenceTransformer model (default), or
        ``"openai"`` for the OpenAI API with automatic chunk-and-average
        handling of over-long texts. Any value other than ``"hf"`` falls
        through to the OpenAI path (preserves original dispatch behavior).
    """
    if backend == "hf":
        return _get_hf_embedding(texts)
    return _get_openai_embedding(texts)


# -------------------------------
# Example
# -------------------------------
if __name__ == "__main__":
    texts = [
        "short text example",
        "very long text " * 2000,  # exceeds the token limit -> will get chunked
    ]
    embs = get_embedding(texts, backend="openai")
    print(len(embs), "embeddings returned")