# Gemma-2B Fine-tuned with SFT (400 Steps)
This model is a fine-tuned version of google/gemma-2b on the won-bae/bpo_preference_hh_data dataset using Supervised Fine-Tuning (SFT).
## Model Details
### Model Description
- Developed by: MohamadBazzi
- Model type: Conversational Language Model
- Language(s): English
- License: Apache 2.0
- Finetuned from model: google/gemma-2b
## Training Details
### Training Data
- Dataset: won-bae/bpo_preference_hh_data
- Source: Best-of-n Preference Optimization (BPO) preference data derived from Anthropic's HH-RLHF dataset
- Format: Conversational (prompt-response pairs; a hypothetical example follows this list)
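The dataset's exact schema should be checked on its Hub page; the snippet below is only a hypothetical illustration of an HH-style prompt-response pair, with field names assumed rather than taken from the dataset:

```python
# Hypothetical illustration of an HH-style prompt-response pair.
# Field names ("prompt", "response") are assumptions, not the dataset's exact schema.
example = {
    "prompt": "\n\nHuman: How do I reverse a list in Python?\n\nAssistant:",
    "response": " Use my_list.reverse() to reverse in place, or my_list[::-1] for a reversed copy.",
}
```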
### Training Procedure
#### Training Hyperparameters
- Training regime: Best-of-n Preference Optimization (BPO) + Supervised Fine-tuning (SFT)
- Training steps: 400 (intermediate checkpoint out of 750 total steps)
- Learning rate: 1e-5
- Batch size: 4 per device
- Gradient accumulation: 4 steps (effective batch size of 16 per device)
- Sequence length: 512 tokens
- Optimizer: AdamW
- LoRA configuration: r=32, alpha=64, dropout=0.05
- Precision: bfloat16 with 4-bit quantization (see the configuration sketch after this list)
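The sketch below shows one plausible way to combine these settings with Transformers, PEFT, and TRL. It is a reconstruction for illustration, not the actual training script; the dataset handling and some argument names (which have shifted across TRL versions) are assumptions.

```python
# Illustrative reconstruction of the training setup, not the exact script used.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# 4-bit quantization with bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", quantization_config=bnb_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

# LoRA configuration from the list above
peft_config = LoraConfig(r=32, lora_alpha=64, lora_dropout=0.05, task_type="CAUSAL_LM")

training_args = SFTConfig(
    output_dir="gemma-bpo-sft",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=750,               # intermediate checkpoints saved along the way (e.g. step 400)
    max_seq_length=512,          # argument name varies across TRL versions
    bf16=True,
    optim="adamw_torch",         # AdamW
)

# Split name is an assumption; depending on the dataset schema,
# a dataset_text_field or formatting_func may also be required.
dataset = load_dataset("won-bae/bpo_preference_hh_data", split="train")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
```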
## Usage
### Direct Use
This model is intended for conversational AI applications and can generate helpful, harmless responses to user queries.
### Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model
model = AutoModelForCausalLM.from_pretrained(
    "MohamadBazzi/gemma-bpo-sft",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("MohamadBazzi/gemma-bpo-sft")

# Format your prompt (HH-style "Human:" / "Assistant:" turns)
prompt = "Human: Hello, can you help me with Python programming?\n\nAssistant: "

# Generate a response; move inputs to the model's device since device_map="auto" is used
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_length=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```
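`generate` returns the prompt tokens followed by the continuation, so the decoded string echoes the prompt. To keep only the assistant's reply, one common pattern is to decode just the newly generated tokens:

```python
# Decode only the tokens produced after the prompt, dropping the echoed input.
prompt_length = inputs["input_ids"].shape[-1]
reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(reply)
```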
### Chat Format
The model expects input in a conversational format similar to Anthropic's HH format:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("MohamadBazzi/gemma-bpo-sft")
model = AutoModelForCausalLM.from_pretrained("MohamadBazzi/gemma-bpo-sft")

# Example usage: "Human:" and "Assistant:" turns separated by a blank line
prompt = "Human: What are the key principles of machine learning?\n\nAssistant: "
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=200, do_sample=True, temperature=0.7)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```
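Multi-turn conversations follow the same pattern: alternate Human and Assistant turns separated by blank lines, ending with `Assistant: ` to cue the model's next reply. A small helper (illustrative, not part of this repository) might look like:

```python
# Illustrative helper for building a multi-turn HH-style prompt.
# The "Human:" / "Assistant:" convention follows Anthropic's HH format.
def build_prompt(turns):
    body = "\n\n".join(f"{role}: {text}" for role, text in turns)
    return body + "\n\nAssistant: "

prompt = build_prompt([
    ("Human", "What are the key principles of machine learning?"),
    ("Assistant", "Good data, a suitable model class, and careful evaluation."),
    ("Human", "Can you say more about evaluation?"),
])
```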
## Model Performance
This checkpoint represents an intermediate stage in the training process:
- Checkpoint: 400/750 steps
- Training method: SFT
- Alignment: Optimized for helpful and harmless responses
## Training Infrastructure
This model was trained using:
- The Digital Research Alliance of Canada's Narval cluster (formerly Compute Canada)
- TRL (Transformer Reinforcement Learning) library
- Hugging Face Transformers
- LoRA/PEFT for efficient fine-tuning
## Limitations
- This is an intermediate checkpoint and may not represent the fully converged model
- Intended primarily for research purposes
- May exhibit biases present in the training data
## Citation
```bibtex
@misc{gemma-bpo-sft-2024,
  title={Gemma-2B Fine-tuned with BPO and SFT},
  author={MohamadBazzi},
  year={2024},
  url={https://huggingface.co/MohamadBazzi/gemma-bpo-sft}
}
```