import os
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Disable transformers warnings for cleaner output
logging.getLogger("transformers").setLevel(logging.ERROR)


class EndpointHandler:
    """
    Optimized handler for a merged Llama-3.1-8B model with fallback strategies.
    Tries the best configuration first, then falls back to more compatible options.
    """

    def __init__(self, path=""):
        # Get the token from the environment, like the local implementation
        self.ACCESS_TOKEN = os.getenv("HF_TOKEN")

        # Load the tokenizer to match the local implementation
        self.tokenizer = AutoTokenizer.from_pretrained(path, token=self.ACCESS_TOKEN)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Try loading the model with progressive fallback strategies
        self.model = self._load_model_with_fallback(path)

        # System message reused for every request
        self.system_message = "Extract the json format in from the data present in . Write nothing else."

        # Warm up the model with a dummy input
        self._warmup()

    def _load_model_with_fallback(self, path):
        """Load the model with progressive fallback strategies."""

        # Strategy 1: Best performance - Flash Attention 2 + bfloat16
        try:
            print("Attempting to load model with Flash Attention and bfloat16...")
            model = AutoModelForCausalLM.from_pretrained(
                path,
                token=self.ACCESS_TOKEN,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                attn_implementation="flash_attention_2",
                low_cpu_mem_usage=True,
                use_cache=True,
            ).eval()

            # Try to compile for additional speed
            try:
                model = torch.compile(model, mode="reduce-overhead")
                print("✓ Model loaded with Flash Attention, bfloat16, and compilation")
            except Exception:
                print(
                    "✓ Model loaded with Flash Attention and bfloat16 (compilation failed)"
                )
            return model
        except Exception as e:
            print(f"Flash Attention loading failed: {e}")

        # Strategy 2: Standard loading without Flash Attention, but with bfloat16
        try:
            print("Attempting to load model without Flash Attention...")
            model = AutoModelForCausalLM.from_pretrained(
                path,
                token=self.ACCESS_TOKEN,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                low_cpu_mem_usage=True,
                use_cache=True,
            ).eval()

            # Try to compile for additional speed
            try:
                model = torch.compile(model, mode="reduce-overhead")
                print("✓ Model loaded with bfloat16 and compilation")
            except Exception:
                print("✓ Model loaded with bfloat16 (compilation failed)")
            return model
        except Exception as e:
            print(f"Standard bfloat16 loading failed: {e}")

        # Strategy 3: Try float16 instead of bfloat16
        try:
            print("Attempting to load model with float16...")
            model = AutoModelForCausalLM.from_pretrained(
                path,
                token=self.ACCESS_TOKEN,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                use_cache=True,
            ).eval()
            print("✓ Model loaded with float16")
            return model
        except Exception as e:
            print(f"Float16 loading failed: {e}")

        # Strategy 4: 4-bit quantization matching the local implementation exactly
        try:
            print("Attempting to load model with 4-bit quantization...")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            )
            model = AutoModelForCausalLM.from_pretrained(
                path,
                token=self.ACCESS_TOKEN,  # Add token like the local implementation
                quantization_config=quantization_config,
                device_map="cuda",  # Use "cuda" like local, not "auto"
                torch_dtype=torch.float16,
            ).eval()
            print("✓ Model loaded with 4-bit quantization")
            return model
        except Exception as e:
            print(f"4-bit quantization loading failed: {e}")
failed: {e}") # Strategy 5: CPU fallback (very slow but should work) try: print("Attempting to load model on CPU (this will be slow)...") model = AutoModelForCausalLM.from_pretrained( path, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True, ).eval() print("⚠ Model loaded on CPU - inference will be slow") return model except Exception as e: print(f"CPU loading failed: {e}") raise RuntimeError("Failed to load model with any strategy") def _warmup(self): """Warm up the model to avoid first-call latency""" try: dummy_messages = [ {"role": "system", "content": self.system_message}, {"role": "user", "content": "test"}, ] input_ids = self.tokenizer.apply_chat_template( dummy_messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", ).to(self.model.device) with torch.no_grad(): self.model.generate( input_ids=input_ids, max_new_tokens=1, do_sample=False, pad_token_id=self.tokenizer.eos_token_id, ) except Exception: pass # Ignore warmup errors def __call__(self, data: dict): """ Optimized inference call. Expected input: { "inputs": "", "parameters": { "max_new_tokens": 120, "temperature": 0.6, "top_p": 0.9, ... (any valid generate kwargs) } } Returns: generated string (no wrapper dict) to match InferenceClient.text_generation(). """ prompt = data.get("inputs") if not prompt: raise ValueError("Missing 'inputs' field in request data.") params = data.get("parameters", {}).copy() # Extract generation parameters to match local implementation exactly max_new_tokens = params.pop("max_new_tokens", 1024) temperature = params.pop("temperature", 0.6) top_p = params.pop("top_p", 0.9) # Don't set do_sample explicitly to match local behavior # Build messages efficiently messages = [ {"role": "system", "content": self.system_message}, {"role": "user", "content": prompt}, ] # Tokenize exactly like the local implementation input_ids = self.tokenizer.apply_chat_template( conversation=messages, # Use 'conversation' parameter like local tokenize=True, add_generation_prompt=True, return_tensors="pt", padding=True, # Use padding like local ) # Generate attention mask exactly like local implementation attention_mask = input_ids.ne(self.tokenizer.pad_token_id) # Generate with parameters matching local implementation exactly with torch.no_grad(): # Disable gradients for inference output_ids = self.model.generate( input_ids=input_ids.to(self.model.device), attention_mask=attention_mask.to(self.model.device), pad_token_id=self.tokenizer.eos_token_id, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, ) # Decode only the generated tokens exactly like local implementation response = self.tokenizer.decode( output_ids[0][input_ids.shape[1] :], skip_special_tokens=True ) return response.strip()