| from typing import Dict, List, Any | |
| from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| import os | |
| import logging | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Load a model directory and serve text generation, classification or embeddings.

    The task is inferred once, at construction time, from the directory name,
    the ``model_type`` in ``config.json`` and the presence of
    ``sentence_bert_config.json``.  Instances are callable with a request dict
    of the form ``{"inputs": ..., "parameters": {...}}``.
    """

    # Per-task keyword defaults; request "parameters" take precedence over them.
    _GENERATION_DEFAULTS = {"max_length": 50, "num_return_sequences": 1}
    _CLASSIFICATION_DEFAULTS = {"return_all_scores": True}

    # Values of ``config.model_type`` that select a task.
    _TEXT_GENERATION_TYPES = ("gpt2",)
    _TEXT_CLASSIFICATION_TYPES = ("bert", "distilbert", "roberta")

    # Lower-cased directory names that select a task regardless of model_type.
    _TEXT_GENERATION_NAMES = ("fine_tuned_gpt2", "merged_distilgpt2")
    _TEXT_CLASSIFICATION_NAMES = (
        "emotion_classifier",
        "emotion_model",
        "intent_classifier",
        "intent_fallback",
    )
    _EMBEDDING_NAMES = ("intent_encoder", "sentence_transformer")

    # File whose presence in the model directory selects the embedding task.
    _SENTENCE_CONFIG_FILE = "sentence_bert_config.json"

    def __init__(self, path=""):
        """Detect the task for the model at *path* and load it.

        Args:
            path: Directory containing the model files and ``config.json``.

        Raises:
            ValueError: If ``config.json`` is missing or unreadable, or if no
                supported task can be determined for the model.
        """
        self.path = path
        logger.info(f"Initializing handler with path: {path}")
        # isdir (not exists) so that a file path cannot make listdir raise
        # inside what is only a diagnostic log line.
        contents = os.listdir(path) if os.path.isdir(path) else "Path does not exist"
        logger.info(f"Directory contents: {contents}")

        try:
            self.task = self._determine_task()
        except Exception as e:
            logger.error(f"Failed to determine task: {str(e)}")
            raise

        logger.info(f"Initializing model for task: {self.task}")
        # Always defined so attribute access is safe in embedding mode too.
        self.tokenizer = None
        self.pipeline = None
        if self.task == "text-generation":
            self._load_pipeline(AutoModelForCausalLM)
        elif self.task == "text-classification":
            self._load_pipeline(AutoModelForSequenceClassification)
        elif self.task == "sentence-embedding":
            self.model = SentenceTransformer(path)
        else:
            raise ValueError(f"Unsupported task: {self.task} for model at {path}")

    def _load_pipeline(self, model_cls):
        """Load model and tokenizer via *model_cls* and build the pipeline for ``self.task``."""
        use_cuda = torch.cuda.is_available()
        self.model = model_cls.from_pretrained(
            self.path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
        self.pipeline = pipeline(
            self.task,
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if use_cuda else -1,
        )

    def _determine_task(self):
        """Return the task name for the model stored at ``self.path``.

        Returns:
            ``"text-generation"``, ``"text-classification"`` or
            ``"sentence-embedding"``.

        Raises:
            ValueError: If ``config.json`` is missing, cannot be loaded, or the
                model matches none of the known names/types.
        """
        config_path = os.path.join(self.path, "config.json")
        logger.info(f"Checking for config.json at: {config_path}")
        if not os.path.exists(config_path):
            logger.error(f"config.json not found in {self.path}")
            raise ValueError(f"config.json not found in {self.path}")

        try:
            config = AutoConfig.from_pretrained(self.path)
        except Exception as e:
            logger.error(f"Failed to load config: {str(e)}")
            raise ValueError(f"Invalid config.json in {self.path}: {str(e)}") from e
        model_type = getattr(config, "model_type", None)

        # normpath drops a trailing separator so "models/foo/" still yields "foo".
        model_name = os.path.basename(os.path.normpath(self.path)).lower() if self.path else ""
        has_sentence_config = os.path.exists(
            os.path.join(self.path, self._SENTENCE_CONFIG_FILE)
        )
        logger.info(f"Model name: {model_name}, Model type: {model_type}")
        return self._resolve_task(model_type, model_name, has_sentence_config)

    @classmethod
    def _resolve_task(cls, model_type, model_name, has_sentence_config):
        """Map model metadata to a task name.

        Embedding markers are checked first: a sentence-embedding model whose
        ``model_type`` is e.g. ``"bert"`` would otherwise always be claimed by
        the classification rule and never reach the embedding rule.

        Args:
            model_type: ``model_type`` from the model config, or ``None``.
            model_name: Lower-cased final component of the model path.
            has_sentence_config: Whether ``sentence_bert_config.json`` exists
                in the model directory.

        Raises:
            ValueError: If no rule matches.
        """
        if model_name in cls._EMBEDDING_NAMES or has_sentence_config:
            return "sentence-embedding"
        if model_type in cls._TEXT_GENERATION_TYPES or model_name in cls._TEXT_GENERATION_NAMES:
            return "text-generation"
        if (
            model_type in cls._TEXT_CLASSIFICATION_TYPES
            or model_name in cls._TEXT_CLASSIFICATION_NAMES
        ):
            return "text-classification"
        raise ValueError(
            f"Could not determine task for model_type: {model_type}, model_name: {model_name}"
        )

    @staticmethod
    def _flatten(result):
        """Return the dicts in *result* as one flat list.

        Accepts a single dict, a list of dicts, or a list of lists of dicts, so
        callers do not depend on how deeply the pipeline nested its output.
        """
        if isinstance(result, dict):
            return [result]
        flat = []
        for entry in result:
            if isinstance(entry, dict):
                flat.append(entry)
            else:
                flat.extend(entry)
        return flat

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run inference for one request.

        Args:
            data: Request dict. ``"inputs"`` holds the text(s) to process;
                optional ``"parameters"`` holds keyword arguments forwarded to
                the pipeline, overriding the built-in defaults.

        Returns:
            A list of result dicts whose keys depend on the task
            (``generated_text``; ``label`` and ``score``; or ``embeddings``).
            On missing input or any inference failure, a single-element list
            ``[{"error": ...}]`` is returned instead of raising.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters") or {}
        if not inputs:
            logger.warning("No inputs provided")
            return [{"error": "No inputs provided"}]

        try:
            logger.info(f"Processing inputs for task: {self.task}")
            if self.task == "text-generation":
                # Merge into one dict: passing the defaults as explicit keywords
                # next to **parameters raises TypeError on any overlapping key.
                kwargs = {**self._GENERATION_DEFAULTS, **parameters}
                result = self.pipeline(inputs, **kwargs)
                return [
                    {"generated_text": item["generated_text"]}
                    for item in self._flatten(result)
                ]
            elif self.task == "text-classification":
                # An explicit top_k from the caller replaces the
                # "return all scores" default rather than being sent with it.
                defaults = {} if "top_k" in parameters else self._CLASSIFICATION_DEFAULTS
                kwargs = {**defaults, **parameters}
                result = self.pipeline(inputs, **kwargs)
                return [
                    {"label": item["label"], "score": item["score"]}
                    for item in self._flatten(result)
                ]
            elif self.task == "sentence-embedding":
                embeddings = self.model.encode(inputs)
                return [{"embeddings": embeddings.tolist()}]
            return [{"error": f"Unsupported task: {self.task}"}]
        except Exception as e:
            # Boundary handler: keep the traceback in the log, but answer the
            # request with an error payload instead of propagating.
            logger.exception(f"Inference failed: {str(e)}")
            return [{"error": f"Inference failed: {str(e)}"}]