import os
import torch
from threading import Thread
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = os.environ.get("MODEL_ID", "swiss-ai/Apertus-8B-Instruct-2509")

# ---- Load model & tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)

# dtype: prefer bfloat16 on GPUs that support it (e.g. A100); fall back to
# float16 on older GPUs (e.g. T4, which lacks bf16) and float32 on CPU
if torch.cuda.is_available():
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    torch_dtype = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",  # accelerate will shard across available devices
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)

# Keep the EOS token id handy so generation stops cleanly
eos_token_id = tokenizer.eos_token_id


def _apply_chat_template_with_fallback(messages):
    """
    Apply the tokenizer's chat template if present; otherwise fall back to a
    simple role-tagged format. Returns a string prompt (not tokenized).
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        # Fallback formatting
        parts = []
        for m in messages:
            role = m.get("role", "user")
            content = m.get("content", "")
            parts.append(f"<|{role}|>\n{content}\n")
        parts.append("<|assistant|>\n")
        return "\n".join(parts)


def chat_with_model(message, history_messages, perspective):
    """
    Streaming generator for Gradio (Chatbot type='messages').

    Inputs:
      - message: str
      - history_messages: list[{'role': 'user'|'assistant', 'content': str}]
      - perspective: str (system message, optional)

    Yields:
      - (updated_messages_for_chatbot, updated_messages_for_state)
    """
    # Compose chat messages for this turn
    chat_msgs = []
    if perspective and perspective.strip():
        chat_msgs.append({"role": "system", "content": perspective.strip()})

    # Append prior turns from UI state (already in messages format)
    for m in history_messages:
        if "role" in m and "content" in m:
            chat_msgs.append({"role": m["role"], "content": m["content"]})

    # Add the new user message
    chat_msgs.append({"role": "user", "content": message})

    # Build the prompt with the model's chat template
    prompt_text = _apply_chat_template_with_fallback(chat_msgs)
    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Set up streamer for token-wise output
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=eos_token_id,
    )

    # Launch generation in a background thread so we can stream from this one
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Build the assistant reply incrementally as tokens arrive
    reply = ""
    base = history_messages + [{"role": "user", "content": message}]
    for token_text in streamer:
        reply += token_text
        updated = base + [{"role": "assistant", "content": reply}]
        yield updated, updated
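

# ---- Minimal UI wiring (illustrative sketch, not part of the original script) ----
# chat_with_model yields (chatbot_messages, state_messages) pairs, which matches how
# Gradio streams updates into a Chatbot(type="messages") plus a State holding history.
# Everything below is an assumed example: the component names (chatbot, state,
# perspective_box, msg_box) and the Blocks layout are not from the original file, and
# the sketch assumes a recent gradio (4.x+) is installed.
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages", label="Chat")
    state = gr.State([])  # running messages list carried between turns
    perspective_box = gr.Textbox(label="System perspective (optional)")
    msg_box = gr.Textbox(label="Your message")

    # Gradio iterates the generator and pushes each yielded (chatbot, state)
    # pair to the UI, so the reply appears token by token.
    msg_box.submit(
        chat_with_model,
        inputs=[msg_box, state, perspective_box],
        outputs=[chatbot, state],
    )

if __name__ == "__main__":
    demo.launch()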