import gradio as gr
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import os

# Disable wandb logging
os.environ["WANDB_DISABLED"] = "true"

# Global state shared between the Gradio callbacks
model = None
tokenizer = None
training_status = "Not started"


def load_model():
    global model, tokenizer
    try:
        # Configure 4-bit quantization for memory efficiency
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )

        # Load model and tokenizer
        model_name = "LLM360/K2-Think"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map="auto",
        )

        # Set padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # A purely 4-bit-quantized model cannot be fine-tuned directly
        # (Trainer rejects it), so attach trainable LoRA adapters. The
        # target_modules below are the usual attention projections; adjust
        # them if the checkpoint uses different module names.
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

        return "Model loaded successfully!"
    except Exception as e:
        return f"Error loading model: {str(e)}"


def prepare_data():
    try:
        # Load a sample dataset (you can replace this with your own)
        dataset = load_dataset("imdb")

        # Preprocessing function
        def preprocess_function(examples):
            # Format the text for instruction tuning
            texts = []
            for text, label in zip(examples["text"], examples["label"]):
                sentiment = "positive" if label == 1 else "negative"
                texts.append(
                    f"Analyze the sentiment of this movie review: {text}\nSentiment: {sentiment}"
                )

            # Tokenize; pad to a fixed length so examples batch cleanly
            tokenized = tokenizer(
                texts, truncation=True, padding="max_length", max_length=256
            )

            # For causal LM training, the labels are the input ids
            tokenized["labels"] = tokenized["input_ids"].copy()
            return tokenized

        # Apply preprocessing
        tokenized_dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset["train"].column_names,
        )

        # Use a small subset for the demo
        train_dataset = tokenized_dataset["train"].shuffle().select(range(50))

        return train_dataset, "Data prepared successfully!"
    except Exception as e:
        return None, f"Error preparing data: {str(e)}"


def train_model():
    global model, tokenizer, training_status
    try:
        if model is None or tokenizer is None:
            training_status = "Please load the model first."
            yield training_status
            return

        training_status = "Starting training..."
        yield training_status

        # Prepare data
        train_dataset, status = prepare_data()
        if train_dataset is None:
            training_status = status
            yield training_status
            return
        training_status = status
        yield training_status

        # Set up training arguments
        training_args = TrainingArguments(
            output_dir="./k2-think-finetuned",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            num_train_epochs=1,
            learning_rate=2e-5,
            fp16=True,
            save_strategy="no",
            logging_steps=5,
        )
        training_status = "Training configuration set up..."
        yield training_status

        # Create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
        )
        training_status = "Starting training process..."
        yield training_status

        # Start training
        trainer.train()
        training_status = "Training completed! Saving model..."
        yield training_status

        # Save model (with LoRA attached, this writes the adapter weights)
        model.save_pretrained("./k2-think-finetuned")
        tokenizer.save_pretrained("./k2-think-finetuned")
        training_status = "Model saved successfully! Ready for inference."
        yield training_status
    except Exception as e:
        training_status = f"Error during training: {str(e)}"
        yield training_status
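
# Optional helper -- a minimal sketch, not wired into the pipeline above.
# prepare_data() copies input_ids straight into labels, so the loss is also
# computed on the prompt text. If you only want to learn the completion,
# masking the prompt tokens with -100 (which the cross-entropy loss ignores)
# is the usual alternative; swap this into preprocess_function if desired.
def mask_prompt_labels(input_ids, prompt_length):
    """Return labels where the first `prompt_length` tokens are ignored (-100)."""
    labels = list(input_ids)
    labels[:prompt_length] = [-100] * prompt_length
    return labels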

def generate_text(prompt):
    if model is None or tokenizer is None:
        return "Please load the model first."
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,  # passes the attention mask along with the input ids
            max_new_tokens=200,
            num_return_sequences=1,
            do_sample=True,  # temperature only has an effect when sampling
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error generating text: {str(e)}"


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# K2-Think Model Training")

    with gr.Tab("Training"):
        gr.Markdown("## Fine-tune K2-Think Model")
        with gr.Row():
            load_btn = gr.Button("Load Model")
            train_btn = gr.Button("Start Training")
        status_output = gr.Textbox(label="Training Status", value=training_status)

        load_btn.click(load_model, outputs=status_output)
        # train_model is a generator, so Gradio streams each yielded status update
        train_btn.click(train_model, outputs=status_output)

    with gr.Tab("Inference"):
        gr.Markdown("## Test Your Fine-tuned Model")
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Analyze the sentiment of this movie review: This movie was amazing!",
            )
            generate_btn = gr.Button("Generate")
        output_text = gr.Textbox(label="Generated Text")

        generate_btn.click(generate_text, inputs=prompt_input, outputs=output_text)

demo.launch()
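
# Reloading the fine-tuned adapter later -- a minimal sketch, assuming the
# `peft` package and the save paths used above. Because training attaches a
# LoRA adapter, save_pretrained() writes adapter weights rather than a full
# checkpoint, so a fresh process would reload roughly like this:
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(
#       "LLM360/K2-Think",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#       device_map="auto",
#   )
#   model = PeftModel.from_pretrained(base, "./k2-think-finetuned")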