|
|
| from base_class import Embedding_Model |
| import pickle |
| from sentence_transformers import SentenceTransformer |
|
|
| from openai.embeddings_utils import ( |
| get_embedding, |
| ) |
|
|
|
|
class HuggingfaceSentenceTransformerModel(Embedding_Model):
    """Embedding model backed by a ``SentenceTransformer`` checkpoint.

    Calling an instance forwards its argument to
    ``SentenceTransformer.encode`` and returns the result unchanged.
    """

    # Checkpoint name used when the caller does not pass ``model_name``.
    EMBEDDING_MODEL = "distiluse-base-multilingual-cased-v2"

    def __init__(
        self,
        model_name=EMBEDDING_MODEL,
        *,
        cache_folder="/app/ckpt/",
    ) -> None:
        """Load the sentence-transformer checkpoint.

        Args:
            model_name: Name passed both to the base-class initializer and
                to ``SentenceTransformer``.
            cache_folder: Directory passed to ``SentenceTransformer`` as its
                ``cache_folder``. Defaults to the previously hard-coded
                ``"/app/ckpt/"`` so existing callers are unaffected.
        """
        super().__init__(model_name)
        self.model = SentenceTransformer(model_name, cache_folder=cache_folder)

    def __call__(self, text):
        """Return ``self.model.encode(text)``.

        The return annotation used to be ``None``, which contradicted the
        ``return`` statement. The concrete type is whatever
        ``SentenceTransformer.encode`` produces for the given input.
        """
        return self.model.encode(text)
|
|
|
|
class OpenAIEmbeddingModel(Embedding_Model):
    """Embedding model that calls ``get_embedding`` and caches the results.

    NOTE(review): ``self.embedding_cache`` and ``self.embedding_cache_path``
    are read here but never assigned in this class; presumably the
    ``Embedding_Model`` base class provides them -- confirm.
    """

    # Model name used when the caller does not pass ``model_name``.
    EMBEDDING_MODEL = "text-embedding-ada-002"

    def __init__(self, model_name=EMBEDDING_MODEL) -> None:
        """Initialize the base class and remember which model to query.

        Args:
            model_name: Name forwarded to the base-class initializer and
                later passed to ``get_embedding``.
        """
        super().__init__(model_name)
        # Stored explicitly because embedding_from_string reads it; whether
        # the base class also sets it is not visible from this file.
        self.model_name = model_name

    def embedding_from_string(
        self,
        string: str,
    ) -> list:
        """Return embedding of given string, using a cache to avoid recomputing.

        On a cache miss the embedding is fetched with ``get_embedding``,
        stored under the ``(string, model_name)`` key, and the whole cache
        is pickled to ``self.embedding_cache_path``.

        Args:
            string: Text to embed.

        Returns:
            The cached or freshly fetched embedding for ``string``.
        """
        cache_key = (string, self.model_name)
        if cache_key not in self.embedding_cache:
            self.embedding_cache[cache_key] = get_embedding(
                string, self.model_name)
            # Persist straight away so the new entry is on disk as soon as
            # it has been computed.
            with open(self.embedding_cache_path, "wb") as embedding_cache_file:
                pickle.dump(self.embedding_cache, embedding_cache_file)
        return self.embedding_cache[cache_key]

    def __call__(self, text: str) -> list:
        """Return the embedding for ``text`` via ``embedding_from_string``.

        The return annotation used to be ``None`` even though the embedding
        is returned; it now matches ``embedding_from_string``.
        """
        return self.embedding_from_string(text)
|
|