zeroranker committed
Commit f1b46ec · verified · 1 Parent(s): f0d44da

Create app.py

Files changed (1)
  1. app.py +179 -0
app.py ADDED
@@ -0,0 +1,179 @@
+ import gradio as gr
+ import torch
+ from datasets import load_dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig
+ import os
+
+ # Disable wandb
+ os.environ["WANDB_DISABLED"] = "true"
+
+ # Global variables shared across the Gradio callbacks
+ model = None
+ tokenizer = None
+ training_status = "Not started"
+
+ def load_model():
+     global model, tokenizer
+     try:
+         # Configure 4-bit quantization for memory efficiency
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_quant_type="nf4",
+         )
+
+         # Load model and tokenizer
+         model_name = "LLM360/K2-Think"
+         tokenizer = AutoTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             quantization_config=quantization_config,
+             device_map="auto"
+         )
+
+         # Set padding token
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         return "Model loaded successfully!"
+     except Exception as e:
+         return f"Error loading model: {str(e)}"
+
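Note that `Trainer` will not full-finetune a model loaded purely in 4-bit, so `trainer.train()` below would raise an error as written. The usual fix is to attach LoRA adapters with the `peft` library before training. A minimal sketch, assuming `peft` is installed; the rank, alpha, and `target_modules` values are illustrative placeholders, not settings confirmed for K2-Think:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

def attach_lora_adapters(base_model):
    # Prepare the frozen 4-bit base for training (norm casting, input grads)
    base_model = prepare_model_for_kbit_training(base_model)
    lora_config = LoraConfig(
        r=16,               # illustrative adapter rank
        lora_alpha=32,      # illustrative scaling factor
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed projection names
    )
    # Wrap the model so only the small adapter weights are trainable
    return get_peft_model(base_model, lora_config)

With this, `load_model` would call `model = attach_lora_adapters(model)` right after loading, and the existing `Trainer` setup works unchanged.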
+ def prepare_data():
+     try:
+         # Load a sample dataset (you can replace this with your own)
+         dataset = load_dataset("imdb")
+
+         # Preprocessing function
+         def preprocess_function(examples):
+             # Format the text for instruction tuning
+             texts = []
+             for text, label in zip(examples["text"], examples["label"]):
+                 sentiment = "positive" if label == 1 else "negative"
+                 texts.append(f"Analyze the sentiment of this movie review: {text}\nSentiment: {sentiment}")
+
+             # Tokenize
+             tokenized = tokenizer(texts, truncation=True, padding=True, max_length=256)
+
+             # Create labels, masking padded positions with -100 so they are
+             # ignored by the loss
+             tokenized["labels"] = [
+                 [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
+                 for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
+             ]
+
+             return tokenized
+
+         # Apply preprocessing
+         tokenized_dataset = dataset.map(
+             preprocess_function,
+             batched=True,
+             remove_columns=dataset["train"].column_names
+         )
+
+         # Use a small subset for the demo
+         train_dataset = tokenized_dataset["train"].shuffle(seed=42).select(range(50))
+
+         return train_dataset, "Data prepared successfully!"
+     except Exception as e:
+         return None, f"Error preparing data: {str(e)}"
+
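The labels above still include the prompt, so the model also learns to reproduce the review text itself. If you only want loss on the final answer, the prompt positions can be masked too. A sketch under the same prompt format; note that tokenizing the prompt alone can differ by a token or two at the boundary, so treat the length as an approximation:

def mask_prompt_tokens(input_ids, prompt_text):
    # -100 is the ignore index used by Transformers' cross-entropy loss
    prompt_len = len(tokenizer(prompt_text)["input_ids"])
    labels = list(input_ids)
    labels[:prompt_len] = [-100] * prompt_len
    return labels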
+ def train_model():
+     global model, tokenizer, training_status
+     try:
+         training_status = "Starting training..."
+         yield training_status
+
+         # Prepare data
+         train_dataset, status = prepare_data()
+         if train_dataset is None:
+             training_status = status
+             yield training_status
+             return
+
+         training_status = status
+         yield training_status
+
+         # Set up training arguments
+         training_args = TrainingArguments(
+             output_dir="./k2-think-finetuned",
+             per_device_train_batch_size=1,
+             gradient_accumulation_steps=4,
+             num_train_epochs=1,
+             learning_rate=2e-5,
+             fp16=True,
+             save_strategy="no",
+             logging_steps=5,
+         )
+
+         training_status = "Training configuration set up..."
+         yield training_status
+
+         # Create trainer (note: a 4-bit quantized base model needs adapters
+         # attached first; see the PEFT sketch above)
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+         )
+
+         training_status = "Starting training process..."
+         yield training_status
+
+         # Start training
+         trainer.train()
+
+         training_status = "Training completed! Saving model..."
+         yield training_status
+
+         # Save model
+         model.save_pretrained("./k2-think-finetuned")
+         tokenizer.save_pretrained("./k2-think-finetuned")
+
+         training_status = "Model saved successfully! Ready for inference."
+         yield training_status
+
+     except Exception as e:
+         training_status = f"Error during training: {str(e)}"
+         yield training_status
+
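For a model of this size, even one sample per step can exhaust a Space's GPU memory. Two `TrainingArguments` options commonly help, shown here as a hedged variant of the configuration above (the optimizer name assumes bitsandbytes is installed):

training_args = TrainingArguments(
    output_dir="./k2-think-finetuned",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=2e-5,
    fp16=True,
    gradient_checkpointing=True,   # recompute activations to cut memory
    optim="paged_adamw_8bit",      # 8-bit paged optimizer from bitsandbytes
    save_strategy="no",
    logging_steps=5,
)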
+ def generate_text(prompt):
+     if model is None or tokenizer is None:
+         return "Please load the model first."
+
+     try:
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=200,
+             do_sample=True,  # required for temperature to have any effect
+             temperature=0.7,
+             pad_token_id=tokenizer.pad_token_id,
+         )
+         return tokenizer.decode(outputs[0], skip_special_tokens=True)
+     except Exception as e:
+         return f"Error generating text: {str(e)}"
+
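For a chat-style UI it is often nicer to stream tokens as they are produced rather than block until `generate` returns. A sketch using Transformers' `TextIteratorStreamer`; the function name is a hypothetical drop-in alternative to `generate_text`:

from threading import Thread
from transformers import TextIteratorStreamer

def generate_text_streaming(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=200, do_sample=True,
                             temperature=0.7, streamer=streamer)
    # Run generation in a background thread so tokens can be consumed as they arrive
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # Gradio re-renders the output on every yield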
+ # Create the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# K2-Think Model Training")
+
+     with gr.Tab("Training"):
+         gr.Markdown("## Fine-tune K2-Think Model")
+
+         with gr.Row():
+             load_btn = gr.Button("Load Model")
+             train_btn = gr.Button("Start Training")
+
+         status_output = gr.Textbox(label="Training Status", value=training_status)
+
+         load_btn.click(load_model, outputs=status_output)
+         train_btn.click(train_model, outputs=status_output)
+
+     with gr.Tab("Inference"):
+         gr.Markdown("## Test Your Fine-tuned Model")
+
+         with gr.Row():
+             prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Analyze the sentiment of this movie review: This movie was amazing!")
+             generate_btn = gr.Button("Generate")
+
+         output_text = gr.Textbox(label="Generated Text")
+
+         generate_btn.click(generate_text, inputs=prompt_input, outputs=output_text)
+
+ demo.launch()
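One runtime note: `train_model` is a generator, and streaming its status updates into the Textbox relies on Gradio's queue. Recent Gradio versions enable it by default; if updates do not stream on your version, enabling it explicitly is a safe fallback:

demo.queue().launch()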