import argparse
import re
import time
from pathlib import Path

import mlx.core as mx
from transformers import AutoTokenizer

from model import load_model


def generate_text(
    prompt: str,
    model_path: str,
    max_tokens: int = 100,
    temperature: float = 0.1,
    top_p: float = 0.9,
    system: str | None = None,
    final_only: bool = False,
    stop_at_boxed: bool = False,
    extract_boxed: bool = False,
    disable_chat_template: bool = False,
    repetition_penalty: float = 1.0,
    frequency_penalty: float = 0.0,
):
    """Generate text with the loaded MLX model using temperature/top-p sampling."""
    print("Loading model and tokenizer...")
    model = load_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Check whether the model directory ships a chat template
    chat_template_path = Path(model_path) / "chat_template.jinja"
    use_chat_format = chat_template_path.exists() and not disable_chat_template
    print(f"Chat template found: {use_chat_format}")

    print("Starting generation...")
    print(f"Prompt: {prompt}")

    # Format the prompt if using the chat template
    if use_chat_format:
        messages = []
        if system is None and final_only:
            system = (
                "You are a helpful assistant. Do not reveal your reasoning. "
                "Respond with only the final answer enclosed in \\boxed{...}."
            )
        if system is not None:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print(f"Formatted prompt: {formatted_prompt}")
    else:
        # No chat template: prepend BOS if the tokenizer defines one
        bos = tokenizer.bos_token or ""
        formatted_prompt = f"{bos}{prompt}"

    # Tokenize the prompt
    prompt_tokens = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    prompt_tokens = mx.array([prompt_tokens])
    print(f"Prompt tokens shape: {prompt_tokens.shape}")
    print(
        f"First few token IDs: {prompt_tokens[0, : min(10, prompt_tokens.shape[1])].tolist()}"
    )

    # Resolve stop-token ids up front; some tokenizers expose a list of EOS ids
    eos_ids = tokenizer.eos_token_id
    if eos_ids is None:
        stop_ids = set()
    elif isinstance(eos_ids, (list, tuple)):
        stop_ids = {int(i) for i in eos_ids}
    else:
        stop_ids = {int(eos_ids)}

    # Generation loop (the full sequence is re-run each step; no KV cache)
    start_time = time.time()
    generated_tokens = []
    freq_counts = {}
    running_text = ""
    seen_box_start = False

    for i in range(max_tokens):
        # Get logits from the model and focus on the next-token position
        logits = model(prompt_tokens)
        next_token_logits = logits[0, -1, :]

        # Apply repetition and frequency penalties before sampling/argmax
        if repetition_penalty and repetition_penalty != 1.0 and generated_tokens:
            # Simple repetition penalty on previously generated tokens,
            # HF-style: positive logits are divided by the penalty, negative ones multiplied
            logits_list = next_token_logits.tolist()
            for tid in set(generated_tokens):
                val = logits_list[tid]
                if val > 0:
                    logits_list[tid] = val / repetition_penalty
                else:
                    logits_list[tid] = val * repetition_penalty
            next_token_logits = mx.array(logits_list)

        if frequency_penalty and frequency_penalty > 0 and generated_tokens:
            # Subtract a multiple of each token's generation count from its logit.
            # Build a dense penalty vector once per step from the running counts.
            vocab_size = next_token_logits.shape[-1]
            pen = [0.0] * vocab_size
            for tid, c in freq_counts.items():
                pen[tid] = frequency_penalty * float(c)
            next_token_logits = next_token_logits - mx.array(pen)

        # Apply temperature (temperature == 0 -> greedy decoding)
        if temperature == 0:
            next_token = int(mx.argmax(next_token_logits).item())
        else:
            # Sampling path: scale logits, then apply the top-p mask in logit space
            scaled_logits = next_token_logits / temperature
            if 0.0 < top_p < 1.0:
                probs = mx.softmax(scaled_logits, axis=-1)
                sorted_probs = mx.sort(probs)[::-1]
                cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
                cutoff_index = mx.sum(cumulative_probs < top_p)
                cutoff_prob = sorted_probs[cutoff_index.item()]
                mask = probs >= cutoff_prob
                scaled_logits = mx.where(mask, scaled_logits, float("-inf"))
            # Sample from logits (MLX categorical expects logits, not probabilities)
            next_token = int(mx.random.categorical(scaled_logits, num_samples=1).item())

        # Stop condition: support single or multiple EOS ids
        if next_token in stop_ids:
            print(f"Stopping generation at EOS token: {next_token}")
            break

        generated_tokens.append(next_token)
        # Update frequency counts
        freq_counts[next_token] = freq_counts.get(next_token, 0) + 1

        # Append the new token for the next iteration
        prompt_tokens = mx.concatenate(
            [prompt_tokens, mx.array([[next_token]])], axis=1
        )

        # Print tokens as we generate, for debugging
        if i < 10:  # Only print the first 10 tokens to avoid spam
            token_text = tokenizer.decode([next_token])
            print(f"Token {i}: {next_token} -> '{token_text}'")

        # Optional boxed stopping condition: stop once '}' appears after \boxed{
        if stop_at_boxed:
            token_text_full = tokenizer.decode([next_token], skip_special_tokens=False)
            running_text += token_text_full
            if not seen_box_start and "\\boxed{" in running_text:
                seen_box_start = True
            if seen_box_start and "}" in running_text.split("\\boxed{", 1)[1]:
                print("Stopping generation at boxed answer.")
                break

    end_time = time.time()

    # Decode and print the result
    if generated_tokens:
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("\n--- Response ---")
        print(response)
    else:
        print("\n--- No tokens generated ---")
    print("------------------")

    generation_speed = (
        len(generated_tokens) / (end_time - start_time) if generated_tokens else 0
    )
    print(f"Generated {len(generated_tokens)} tokens")
    print(f"Generation speed: {generation_speed:.2f} tokens/sec")

    # Also print the full generated sequence including special tokens for debugging
    if generated_tokens:
        full_response = tokenizer.decode(generated_tokens, skip_special_tokens=False)
        print(f"\nFull response (with special tokens): '{full_response}'")

    if extract_boxed and generated_tokens:
        # Keep the last occurrence of \boxed{...} (single backslash in the decoded text)
        m = None
        for m in re.finditer(r"\\boxed\{([^}]*)\}", full_response):
            pass
        if m:
            print(f"\nExtracted boxed answer: {m.group(1).strip()}")
        else:
            print("\nNo \\boxed{...} segment found to extract.")


def main():
    parser = argparse.ArgumentParser(description="Run inference with the MLX model.")
    parser.add_argument(
        "--model-path", type=str, default=".", help="Path to the model directory."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What is the capital of France?",
        help="The prompt to start generation from.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="The maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1, help="Sampling temperature."
    )
    parser.add_argument(
        "--top-p", type=float, default=0.9, help="Top-p (nucleus) sampling parameter."
    )
    parser.add_argument(
        "--system",
        type=str,
        default=None,
        help="Optional system message for the chat template.",
    )
    parser.add_argument(
        "--final-only",
        action="store_true",
        help="Instruct the model to output only the final answer inside \\boxed{...}.",
    )
    parser.add_argument(
        "--stop-at-boxed",
        action="store_true",
        help="Stop generation once a closing '}' appears after \\boxed{.",
    )
    parser.add_argument(
        "--extract-boxed",
        action="store_true",
        help="Extract and print the content inside the last \\boxed{...} in the response.",
    )
    parser.add_argument(
        "--disable-chat-template",
        action="store_true",
        help="Ignore chat_template.jinja and feed the raw prompt (prepended with BOS).",
    )
    parser.add_argument(
        "--repetition-penalty",
        type=float,
        default=1.0,
        help="Penalty (>1.0) to discourage previously generated tokens.",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Subtract alpha * count(token) from logits before sampling.",
    )
    args = parser.parse_args()

    generate_text(
        args.prompt,
        args.model_path,
        args.max_tokens,
        args.temperature,
        args.top_p,
        args.system,
        args.final_only,
        args.stop_at_boxed,
        args.extract_boxed,
        args.disable_chat_template,
        args.repetition_penalty,
        args.frequency_penalty,
    )


if __name__ == "__main__":
    main()
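
# Example invocation (a sketch; assumes this script is saved as inference.py and
# that --model-path points at a directory containing the MLX weights, tokenizer
# files, and optionally chat_template.jinja):
#
#   python inference.py --model-path ./my_model \
#       --prompt "What is 2 + 2?" \
#       --max-tokens 64 --temperature 0 \
#       --final-only --stop-at-boxed --extract-boxed
#
# With --temperature 0 the loop decodes greedily; --final-only injects a system
# message requesting a \boxed{...} answer, --stop-at-boxed halts at the closing
# brace, and --extract-boxed prints the boxed content afterwards.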