from typing import Dict, List, Any

import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        # Get HuggingFace token for gated model access
        hf_token = os.getenv("HF_TOKEN")

        # Load model and tokenizer with authentication
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            token=hf_token
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            token=hf_token
        )

        # Set pad token if not exists
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Simple handler that mimics local LLM behavior for RemoteLLM
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle different input formats that RemoteLLM sends
        if isinstance(inputs, dict) and "messages" in inputs:
            messages = inputs["messages"]
        elif isinstance(inputs, list):
            messages = inputs
        else:
            # Fallback - treat as direct text
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": str(inputs)}
            ]

        # Check if this is a continuation (has assistant message)
        has_assistant = any(msg.get("role") == "assistant" for msg in messages)

        # Apply chat template exactly like BrickGPT does locally
        if has_assistant:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                continue_final_message=True,
                return_tensors='pt'
            )
        else:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors='pt'
            )

        # Move to device
        input_ids = prompt.to(self.model.device)
        attention_mask = torch.ones_like(input_ids)

        # Generation parameters - use BrickGPT defaults
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 10),
            "temperature": parameters.get("temperature", 0.6),
            "top_k": parameters.get("top_k", 20),
            "top_p": parameters.get("top_p", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
            "do_sample": True,
            "num_return_sequences": 1,
            "return_dict_in_generate": True,
        }

        # Generate
        with torch.no_grad():
            output_dict = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **generation_params
            )

        # Extract new tokens and decode EXACTLY like local LLM
        input_length = input_ids.shape[1]
        result_ids = output_dict['sequences'][0][input_length:]

        # CRITICAL: Decode exactly like local LLM (no skip_special_tokens parameter)
        generated_text = self.tokenizer.decode(result_ids)

        # Return in format RemoteLLM expects
        return [{"generated_text": generated_text}]
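

# --- Usage sketch (assumption: not part of the original handler) ---
# A minimal local smoke test showing the payload shape this handler expects.
# The model id below is hypothetical; substitute the actual gated model this
# endpoint serves, and make sure HF_TOKEN is set in the environment.
# Inference Endpoints instantiate EndpointHandler once and call it per request
# with a dict like the one built here.
if __name__ == "__main__":
    handler = EndpointHandler(path="meta-llama/Llama-3.2-1B-Instruct")  # hypothetical model id
    payload = {
        "inputs": {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Say hello."},
            ]
        },
        "parameters": {"max_new_tokens": 16, "temperature": 0.6},
    }
    # Expected shape of the response: [{"generated_text": "..."}]
    print(handler(payload))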