#!/usr/bin/env python3
"""
Simple fine-tuning script for Mineral Nano 1
Uses LoRA for efficient training on consumer GPUs
"""
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
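
# Requires: pip install torch transformers datasets peft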
# ============ CONFIGURATION ============
MODEL_NAME = "Luke-Bergen/mineral-nano-1"
OUTPUT_DIR = "./mineral-nano-finetuned"
DATASET_NAME = "timdettmers/openassistant-guanaco" # Example dataset
MAX_LENGTH = 512 # Shorter for faster training
EPOCHS = 1
BATCH_SIZE = 1 # Small for limited VRAM
LEARNING_RATE = 2e-4
# =======================================
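# Any Hugging Face dataset exposing a plain "text" column should drop in here;
# DATASET_NAME above is just an example (some datasets also need a config name
# passed to load_dataset).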
def setup_model():
"""Load model with LoRA for efficient training"""
print("="*60)
print("Loading Model with LoRA (Efficient Training)")
print("="*60)
# Load base model
print("\n[1/3] Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
# Load tokenizer
print("[2/3] Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Setup LoRA (trains only a small fraction of the parameters)
    print("[3/3] Adding LoRA adapters...")
    lora_config = LoraConfig(
        r=8,                                  # LoRA rank
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],  # Which layers to adapt
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
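    # For stronger adaptation, more projections can be targeted, e.g.
    # ["q_proj", "k_proj", "v_proj", "o_proj"]; the exact module names
    # depend on the base model's architecture.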
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    print("\nβœ… Model ready for training!")
    return model, tokenizer
def prepare_dataset(tokenizer):
"""Load and prepare training data"""
print("\n" + "="*60)
print("Preparing Dataset")
print("="*60)
# Load dataset
print("\nLoading dataset...")
dataset = load_dataset(DATASET_NAME, split="train[:1000]") # Use 1000 examples
print(f"βœ… Loaded {len(dataset)} examples")
# Tokenization function
def tokenize_function(examples):
# Format: "User: X\nAssistant: Y"
texts = examples["text"]
return tokenizer(
texts,
truncation=True,
max_length=MAX_LENGTH,
padding="max_length"
)
# Tokenize dataset
print("Tokenizing...")
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset.column_names
)
print("βœ… Dataset ready!")
return tokenized_dataset
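
# Optional sanity check (hypothetical snippet): decode one tokenized example to
# confirm the text round-trips, e.g.
#   sample = tokenized_dataset[0]
#   print(tokenizer.decode(sample["input_ids"][:64], skip_special_tokens=True))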
def train_model(model, tokenizer, dataset):
"""Train the model"""
print("\n" + "="*60)
print("Starting Training")
print("="*60)
# Training arguments
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=4, # Effective batch size = 4
learning_rate=LEARNING_RATE,
fp16=False,
bf16=torch.cuda.is_available(),
logging_steps=10,
save_steps=100,
save_total_limit=2,
optim="adamw_torch",
warmup_steps=50,
lr_scheduler_type="cosine",
report_to="none", # Change to "wandb" for tracking
)
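    # If VRAM is tight, passing gradient_checkpointing=True above trades extra
    # compute for lower memory use.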
    # Data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    # Train!
    print("\nπŸš€ Training started...")
    print("This will take time depending on your hardware:")
    print("  - RTX 3090: ~2-3 hours")
    print("  - RTX 4090: ~1-2 hours")
    print("  - CPU: 10-20 hours (not recommended)")
    print("\n" + "="*60 + "\n")
    trainer.train()
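    # To resume an interrupted run from the latest checkpoint in OUTPUT_DIR,
    # call trainer.train(resume_from_checkpoint=True) instead.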
print("\nβœ… Training complete!")
# Save model
print("\nSaving fine-tuned model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"βœ… Model saved to: {OUTPUT_DIR}")
    return trainer
def test_model(model, tokenizer):
"""Quick test of the trained model"""
print("\n" + "="*60)
print("Testing Fine-Tuned Model")
print("="*60)
test_prompts = [
"Hello! How are you?",
"What is machine learning?",
"Write a short poem about AI."
]
model.eval()
for prompt in test_prompts:
print(f"\nπŸ’¬ Input: {prompt}")
print("πŸ€– Output: ", end="")
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=100,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.pad_token_id
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response[len(prompt):].strip()
print(response)
print("\n" + "="*60)
def main():
"""Main training pipeline"""
print("\n")
print("╔════════════════════════════════════════════════════════╗")
print("β•‘ β•‘")
print("β•‘ MINERAL NANO 1 - SIMPLE TRAINING SCRIPT β•‘")
print("β•‘ LoRA Fine-Tuning (Efficient) β•‘")
print("β•‘ β•‘")
print("β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•")
# Check GPU
if torch.cuda.is_available():
print(f"\nβœ… GPU detected: {torch.cuda.get_device_name(0)}")
print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
print("\n⚠️ No GPU detected - training will be VERY slow")
response = input("Continue anyway? (y/n): ")
if response.lower() != 'y':
return
print("\nπŸ“‹ Training Configuration:")
print(f" Model: {MODEL_NAME}")
print(f" Dataset: {DATASET_NAME}")
print(f" Epochs: {EPOCHS}")
print(f" Batch Size: {BATCH_SIZE}")
print(f" Learning Rate: {LEARNING_RATE}")
print(f" Output: {OUTPUT_DIR}")
input("\nPress ENTER to start training (or Ctrl+C to cancel)...")
# Setup
model, tokenizer = setup_model()
dataset = prepare_dataset(tokenizer)
# Train
trainer = train_model(model, tokenizer, dataset)
# Test
test_model(model, tokenizer)
print("\n" + "="*60)
print("πŸŽ‰ ALL DONE!")
print("="*60)
print(f"\nYour fine-tuned model is saved in: {OUTPUT_DIR}")
print("\nNext steps:")
print("1. Test more with: python test_model.py")
print("2. Upload to HuggingFace: model.push_to_hub('your-username/model-name')")
print("3. Continue training with more data")
print("="*60)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⚠️ Training interrupted!")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()