Joseph Pollack committed
Commit 617401b · unverified · Parent: 0c53915

update for realtime

Files changed (2)
  1. scripts/train.py +14 -1
  2. scripts/train_lora.py +15 -0
scripts/train.py CHANGED
@@ -29,6 +29,7 @@ from typing import Tuple, Optional
 import torch
 from datasets import load_dataset, Audio, Dataset
 from transformers import (
+    AutoConfig,
     VoxtralForConditionalGeneration,
     VoxtralProcessor,
     Trainer,
@@ -254,7 +255,7 @@ def main():
     parser.add_argument("--dataset-config", type=str, default=None, help="HF dataset config/subset")
     parser.add_argument("--train-count", type=int, default=100, help="Number of training samples to use")
     parser.add_argument("--eval-count", type=int, default=50, help="Number of eval samples to use")
-    parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-4B-Realtime-2602")
+    parser.add_argument("--model-checkpoint", type=str, default="mistralai/Voxtral-Mini-3B-2507")
     parser.add_argument("--output-dir", type=str, default="./voxtral-finetuned")
     parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--eval-batch-size", type=int, default=4)
@@ -359,6 +360,18 @@ def main():
         print("⚠️ Training will continue without experiment tracking")
 
     print("Loading processor and model...")
+    # Full fine-tuning supports only the non-Realtime Voxtral (VoxtralForConditionalGeneration).
+    # Voxtral Realtime uses a different architecture and is not supported by this script yet.
+    try:
+        config = AutoConfig.from_pretrained(model_checkpoint)
+    except Exception:
+        config = None
+    if getattr(config, "model_type", None) == "voxtral_realtime":
+        raise ValueError(
+            "Full fine-tuning does not support Voxtral Realtime checkpoints (model_type=voxtral_realtime). "
+            "Use the non-Realtime Voxtral model, e.g.:\n"
+            "  --model-checkpoint mistralai/Voxtral-Mini-3B-2507"
+        )
     processor = VoxtralProcessor.from_pretrained(model_checkpoint)
     model = VoxtralForConditionalGeneration.from_pretrained(
         model_checkpoint,
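
The guard above can also be exercised on its own before launching a run. Below is a minimal standalone sketch of the same check; is_realtime_checkpoint is a hypothetical helper name used for illustration, and the only library call it relies on is the standard transformers AutoConfig.from_pretrained, exactly as in the diff:

# Sketch: the same model_type check as in the diff, as a standalone helper.
from transformers import AutoConfig

def is_realtime_checkpoint(model_checkpoint: str) -> bool:
    """Return True if the checkpoint's config reports the Realtime architecture."""
    try:
        # Fetches only the config (a small JSON file), not the model weights.
        config = AutoConfig.from_pretrained(model_checkpoint)
    except Exception:
        # Mirror the training scripts: if the config cannot be read, do not
        # block here; let the later model load surface the real error.
        return False
    return getattr(config, "model_type", None) == "voxtral_realtime"

if __name__ == "__main__":
    for ckpt in ("mistralai/Voxtral-Mini-3B-2507",
                 "mistralai/Voxtral-Mini-4B-Realtime-2602"):
        status = "realtime (rejected)" if is_realtime_checkpoint(ckpt) else "supported"
        print(f"{ckpt}: {status}")
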
scripts/train_lora.py CHANGED
@@ -31,6 +31,7 @@ from typing import Tuple, Optional
 import torch
 from datasets import load_dataset, Audio, Dataset
 from transformers import (
+    AutoConfig,
     VoxtralForConditionalGeneration,
     VoxtralProcessor,
     Trainer,
@@ -375,6 +376,20 @@ def main():
         print("⚠️ Training will continue without experiment tracking")
 
     print("Loading processor and model...")
+    # LoRA training supports only the non-Realtime Voxtral (e.g. Voxtral-Mini-3B-2507).
+    # Voxtral Realtime (e.g. Voxtral-Mini-4B-Realtime-2602) uses a different config and
+    # is not compatible with VoxtralForConditionalGeneration.
+    try:
+        config = AutoConfig.from_pretrained(model_checkpoint)
+    except Exception:
+        config = None
+    if getattr(config, "model_type", None) == "voxtral_realtime":
+        raise ValueError(
+            "LoRA training does not support Voxtral Realtime checkpoints (model_type=voxtral_realtime). "
+            "Use the non-Realtime Voxtral model for LoRA, e.g.:\n"
+            "  --model-checkpoint mistralai/Voxtral-Mini-3B-2507\n"
+            "For full fine-tuning of the Realtime model, use scripts/train.py instead."
+        )
     processor = VoxtralProcessor.from_pretrained(model_checkpoint)
     lora_cfg = LoraConfig(
         r=args.lora_r,
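
Both scripts now carry a near-identical guard. One way to avoid the duplication, sketched below under the assumption of a new shared module (a hypothetical scripts/common.py, not part of this commit), is to factor the check into a single helper that each script calls with its own mode label:

# Hypothetical scripts/common.py -- not part of this commit.
from transformers import AutoConfig

def reject_realtime_checkpoint(model_checkpoint: str, mode: str) -> None:
    """Raise early if the checkpoint is a Voxtral Realtime model."""
    try:
        config = AutoConfig.from_pretrained(model_checkpoint)
    except Exception:
        # Mirror the scripts: let the model loader report config errors itself.
        config = None
    if getattr(config, "model_type", None) == "voxtral_realtime":
        raise ValueError(
            f"{mode} does not support Voxtral Realtime checkpoints "
            "(model_type=voxtral_realtime). Use the non-Realtime Voxtral model, e.g.:\n"
            "  --model-checkpoint mistralai/Voxtral-Mini-3B-2507"
        )

train.py would then call reject_realtime_checkpoint(model_checkpoint, "Full fine-tuning") and train_lora.py would call reject_realtime_checkpoint(model_checkpoint, "LoRA training"), keeping the error text in one place.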