Spaces:
Sleeping
Sleeping
Joseph Pollack committed on
update for realtime
Browse files
- scripts/train.py +14 -1
- scripts/train_lora.py +15 -0
scripts/train.py
CHANGED
|
@@ -29,6 +29,7 @@ from typing import Tuple, Optional
|
|
| 29 |
import torch
|
| 30 |
from datasets import load_dataset, Audio, Dataset
|
| 31 |
from transformers import (
|
|
|
|
| 32 |
VoxtralForConditionalGeneration,
|
| 33 |
VoxtralProcessor,
|
| 34 |
Trainer,
|
|
@@ -254,7 +255,7 @@ def main():
|
|
| 254 |
parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
|
| 255 |
parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
|
| 256 |
parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
|
| 257 |
-
parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-
|
| 258 |
parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
|
| 259 |
parser.add_argument("--batch-size", type=int, default=2)
|
| 260 |
parser.add_argument("--eval-batch-size", type=int, default=4)
|
|
@@ -359,6 +360,18 @@ def main():
|
|
| 359 |
print("⚠️ Training will continue without experiment tracking")
|
| 360 |
|
| 361 |
print("Loading processor and model...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
processor = VoxtralProcessor.from_pretrained(model_checkpoint)
|
| 363 |
model = VoxtralForConditionalGeneration.from_pretrained(
|
| 364 |
model_checkpoint,
|
|
|
|
| 29 |
import torch
|
| 30 |
from datasets import load_dataset, Audio, Dataset
|
| 31 |
from transformers import (
|
| 32 |
+
AutoConfig,
|
| 33 |
VoxtralForConditionalGeneration,
|
| 34 |
VoxtralProcessor,
|
| 35 |
Trainer,
|
|
|
|
| 255 |
parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
|
| 256 |
parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
|
| 257 |
parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
|
| 258 |
+
parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
|
| 259 |
parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
|
| 260 |
parser.add_argument("--batch-size", type=int, default=2)
|
| 261 |
parser.add_argument("--eval-batch-size", type=int, default=4)
|
|
|
|
| 360 |
print("⚠️ Training will continue without experiment tracking")
|
| 361 |
|
| 362 |
print("Loading processor and model...")
|
| 363 |
+
# Full fine-tuning supports only the non-Realtime Voxtral (VoxtralForConditionalGeneration).
|
| 364 |
+
# Voxtral Realtime uses a different architecture and is not supported by this script yet.
|
| 365 |
+
try:
|
| 366 |
+
config = AutoConfig.from_pretrained(model_checkpoint)
|
| 367 |
+
except Exception:
|
| 368 |
+
config = None
|
| 369 |
+
if getattr(config, "model_type", None) == "voxtral_realtime":
|
| 370 |
+
raise ValueError(
|
| 371 |
+
"Full fine-tuning does not support Voxtral Realtime checkpoints (model_type=voxtral_realtime). "
|
| 372 |
+
"Use the non-Realtime Voxtral model, e.g.:\n"
|
| 373 |
+
" --model-checkpoint mistralai/Voxtral-Mini-3B-2507"
|
| 374 |
+
)
|
| 375 |
processor = VoxtralProcessor.from_pretrained(model_checkpoint)
|
| 376 |
model = VoxtralForConditionalGeneration.from_pretrained(
|
| 377 |
model_checkpoint,
|
scripts/train_lora.py
CHANGED
|
@@ -31,6 +31,7 @@ from typing import Tuple, Optional
|
|
| 31 |
import torch
|
| 32 |
from datasets import load_dataset, Audio, Dataset
|
| 33 |
from transformers import (
|
|
|
|
| 34 |
VoxtralForConditionalGeneration,
|
| 35 |
VoxtralProcessor,
|
| 36 |
Trainer,
|
|
@@ -375,6 +376,20 @@ def main():
|
|
| 375 |
print("⚠️ Training will continue without experiment tracking")
|
| 376 |
|
| 377 |
print("Loading processor and model...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
processor = VoxtralProcessor.from_pretrained(model_checkpoint)
|
| 379 |
lora_cfg = LoraConfig(
|
| 380 |
r=args.lora_r,
|
|
|
|
| 31 |
import torch
|
| 32 |
from datasets import load_dataset, Audio, Dataset
|
| 33 |
from transformers import (
|
| 34 |
+
AutoConfig,
|
| 35 |
VoxtralForConditionalGeneration,
|
| 36 |
VoxtralProcessor,
|
| 37 |
Trainer,
|
|
|
|
| 376 |
print("⚠️ Training will continue without experiment tracking")
|
| 377 |
|
| 378 |
print("Loading processor and model...")
|
| 379 |
+
# LoRA training supports only the non-Realtime Voxtral (e.g. Voxtral-Mini-3B-2507).
|
| 380 |
+
# Voxtral Realtime (e.g. Voxtral-Mini-4B-Realtime-2602) uses a different config and
|
| 381 |
+
# is not compatible with VoxtralForConditionalGeneration.
|
| 382 |
+
try:
|
| 383 |
+
config = AutoConfig.from_pretrained(model_checkpoint)
|
| 384 |
+
except Exception:
|
| 385 |
+
config = None
|
| 386 |
+
if getattr(config, "model_type", None) == "voxtral_realtime":
|
| 387 |
+
raise ValueError(
|
| 388 |
+
"LoRA training does not support Voxtral Realtime checkpoints (model_type=voxtral_realtime). "
|
| 389 |
+
"Use the non-Realtime Voxtral model for LoRA, e.g.:\n"
|
| 390 |
+
" --model-checkpoint mistralai/Voxtral-Mini-3B-2507\n"
|
| 391 |
+
"For full fine-tuning of the Realtime model, use scripts/train.py instead."
|
| 392 |
+
)
|
| 393 |
processor = VoxtralProcessor.from_pretrained(model_checkpoint)
|
| 394 |
lora_cfg = LoraConfig(
|
| 395 |
r=args.lora_r,
|