"""Command-line interface for the DiffAttention LLM: single-shot text generation,
interactive chat, and checkpoint listing."""

import argparse
import os
import sys

import torch

from inference.inference import (
    force_CPU,
    generate_text_stream,
    list_checkpoints,
    load_model,
)
from inference.model import ByteTokenizer


def main():
    parser = argparse.ArgumentParser(
        description="Text generation with DiffAttention LLM",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Generation mode arguments
    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        help="Run in single-shot mode with the given prompt.",
    )
    parser.add_argument(
        "-c", "--chat", action="store_true", help="Run in interactive chat mode."
    )

    # Chat mode arguments
    parser.add_argument(
        "--system",
        type=str,
        default="You are a helpful chatbot.",
        help="System prompt for chat mode.",
    )
    parser.add_argument(
        "--user_role",
        type=str,
        default="user",
        help="Role name for the user in chat mode.",
    )
    parser.add_argument(
        "--assistant_role",
        type=str,
        default="assistant",
        help="Role name for the assistant in chat mode.",
    )

    # Common arguments
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="model.pt",
        help="Path to the checkpoint file.",
    )
    parser.add_argument(
        "--stop",
        nargs="+",
        default=[],
        help='One or more stop sequences, e.g. --stop "world".',
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.35, help="Sampling temperature."
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=7,
        help="Top-k sampling parameter (0 to disable).",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.35,
        help="Repetition penalty (1.0 for no penalty).",
    )
    parser.add_argument(
        "--list_checkpoints",
        action="store_true",
        help="List available checkpoints and exit.",
    )
    args = parser.parse_args()

    if not args.prompt and not args.chat and not args.list_checkpoints:
        parser.print_help()
        sys.exit(
            "\nError: Either --prompt, --chat, or --list_checkpoints must be specified."
        )

    # List checkpoints if requested
    if args.list_checkpoints:
        print("Available checkpoints:")
        checkpoints = list_checkpoints()
        if not checkpoints:
            print("No checkpoints found.")
        for i, ckpt in enumerate(checkpoints):
            print(f"{i + 1}. {ckpt}")
        return

    # Resolve the checkpoint path, falling back to the latest checkpoint in
    # 'checkpoints/' (preferring final '*end.pt' snapshots) if the given file is missing.
    checkpoint_path = args.checkpoint
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Searching for latest checkpoint in 'checkpoints/' directory...")
        checkpoints = list_checkpoints()
        if not checkpoints:
            sys.exit(
                "No checkpoints found. Please train a model or specify a valid path."
            )
        end_checkpoints = [ckpt for ckpt in checkpoints if "end.pt" in ckpt]
        if end_checkpoints:
            latest_checkpoint = max(end_checkpoints)
        else:
            latest_checkpoint = max(checkpoints)
        checkpoint_path = os.path.join("checkpoints", latest_checkpoint)
        print(f"Using latest checkpoint: {checkpoint_path}")

    # Set device: prefer MPS, then CUDA, unless force_CPU is set
    if torch.backends.mps.is_available() and not force_CPU:
        device = torch.device("mps")
    else:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not force_CPU else "cpu"
        )
    print(f"Using device: {device}")

    tokenizer = ByteTokenizer()

    # Load model
    model = load_model(checkpoint_path, device)

    # --- Mode Handling ---
    if args.chat:
        stop_sequences = args.stop + ["<|im_end|>"]
        # Seed the conversation history with the system prompt in ChatML-style tags
        history = f"<|im_start|>system\n{args.system}<|im_end|>\n"

        print("\n--- Interactive Chat ---")
        print(f"System Prompt: {args.system}")
        print("Type 'exit' or 'quit' to end the session.")
        print("-" * 26)

        while True:
            try:
                user_prompt_display = f"<|im_start|>{args.user_role}\n"
                user_input = input(user_prompt_display)
                if user_input.lower() in ["exit", "quit"]:
                    break

                prompt = (
                    history
                    + f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                    + f"<|im_start|>{args.assistant_role}\n"
                )

                print(f"<|im_start|>{args.assistant_role}")
                sys.stdout.flush()

                generated_text_parts = []
                for chunk in generate_text_stream(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompt,
                    max_new_tokens=args.max_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    repetition_penalty=args.repetition_penalty,
                    device=device,
                    stop_sequences=stop_sequences,
                ):
                    print(chunk, end="", flush=True)
                    generated_text_parts.append(chunk)

                generated_text = "".join(generated_text_parts)
                history += (
                    f"<|im_start|>{args.user_role}\n{user_input}<|im_end|>\n"
                    + f"<|im_start|>{args.assistant_role}\n{generated_text}<|im_end|>\n"
                )
                print()  # Newline after assistant output
            except (KeyboardInterrupt, EOFError):
                print("\nExiting chat.")
                break
    else:
        print(f"\nGenerating text with prompt: '{args.prompt}'")
        print(
            f"Parameters: temp={args.temperature}, top_k={args.top_k}, "
            f"repetition_penalty={args.repetition_penalty}"
        )
        print("\n--- Generation Start ---")

        generated_text_parts = []
        for chunk in generate_text_stream(
            model=model,
            tokenizer=tokenizer,
            prompt=args.prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            device=device,
            stop_sequences=args.stop,
        ):
            print(chunk, end="", flush=True)
            generated_text_parts.append(chunk)

        print("\n--- Generation End ---")
        generated_text = "".join(generated_text_parts)
        full_text = args.prompt + generated_text
        print("\n\nFull generated text (for reference):")
        print("-" * 40)
        print(full_text)
        print("-" * 40)


if __name__ == "__main__":
    main()