saikamal1108 commited on
Commit
9149ed1
·
verified ·
1 Parent(s): b109215

Create train_pipeline.py

Browse files
Files changed (1) hide show
  1. train_pipeline.py +104 -0
train_pipeline.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_pipeline.py
2
+ import re
3
+ import numpy as np
4
+ from jiwer import wer
5
+ from transformers import (
6
+ Wav2Vec2ForCTC,
7
+ Wav2Vec2CTCTokenizer,
8
+ Wav2Vec2FeatureExtractor,
9
+ Wav2Vec2Processor,
10
+ TrainingArguments,
11
+ Trainer
12
+ )
13
+ from datasets import Audio
14
+ import torch
15
+ from dataclasses import dataclass
16
+ from preprocess import load_telugu_dataset, normalize_text
17
+ from vocab import build_vocab
18
+
19
+ def prepare_dataset(batch, processor):
20
+ speech = batch["audio"]["array"]
21
+ batch["input_values"] = processor(speech, sampling_rate=16000).input_values[0]
22
+ batch["labels"] = processor.tokenizer(normalize_text(batch["text"])).input_ids
23
+ return batch
24
+
25
+ @dataclass
26
+ class DataCollatorCTC:
27
+ processor: Wav2Vec2Processor
28
+ padding: bool = True
29
+
30
+ def __call__(self, features):
31
+ inputs = [{"input_values": f["input_values"]} for f in features]
32
+ labels = [{"input_ids": f["labels"]} for f in features]
33
+
34
+ batch = self.processor.pad(inputs, return_tensors="pt")
35
+ with self.processor.as_target_processor():
36
+ labels_batch = self.processor.pad(labels, return_tensors="pt")
37
+
38
+ labels = labels_batch["input_ids"]
39
+ labels[labels == self.processor.tokenizer.pad_token_id] = -100
40
+ batch["labels"] = labels
41
+ return batch
42
+
43
+ def train_model():
44
+ # 1. Load dataset
45
+ ds = load_telugu_dataset()
46
+ ds = ds.train_test_split(test_size=0.1)
47
+ train = ds["train"]
48
+ test = ds["test"]
49
+
50
+ # 2. Build vocab
51
+ build_vocab(train, text_col="text")
52
+
53
+ # 3. Processor
54
+ tokenizer = Wav2Vec2CTCTokenizer("vocab.json", pad_token="[PAD]", unk_token="[UNK]")
55
+ extractor = Wav2Vec2FeatureExtractor(sampling_rate=16000, do_normalize=True)
56
+ processor = Wav2Vec2Processor(extractor, tokenizer)
57
+
58
+ # 4. Prepare
59
+ train = train.map(lambda x: prepare_dataset(x, processor))
60
+ test = test.map(lambda x: prepare_dataset(x, processor))
61
+
62
+ # 5. Load XLS-R model
63
+ model = Wav2Vec2ForCTC.from_pretrained(
64
+ "facebook/wav2vec2-xls-r-300m",
65
+ vocab_size=len(tokenizer),
66
+ pad_token_id=tokenizer.pad_token_id,
67
+ ctc_loss_reduction="mean"
68
+ )
69
+ model.freeze_feature_extractor()
70
+
71
+ data_collator = DataCollatorCTC(processor)
72
+
73
+ def compute_metrics(pred):
74
+ pred_ids = np.argmax(pred.predictions, axis=-1)
75
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
76
+
77
+ preds = processor.batch_decode(pred_ids)
78
+ refs = processor.batch_decode(pred.label_ids)
79
+ return {"wer": wer(refs, preds)}
80
+
81
+ args = TrainingArguments(
82
+ output_dir="./model",
83
+ per_device_train_batch_size=4,
84
+ per_device_eval_batch_size=4,
85
+ fp16=True,
86
+ evaluation_strategy="epoch",
87
+ save_strategy="epoch",
88
+ num_train_epochs=5,
89
+ push_to_hub=True,
90
+ hub_model_id="your-username/telugu-asr-xlsr"
91
+ )
92
+
93
+ trainer = Trainer(
94
+ model=model,
95
+ args=args,
96
+ train_dataset=train,
97
+ eval_dataset=test,
98
+ tokenizer=processor.feature_extractor,
99
+ data_collator=data_collator,
100
+ compute_metrics=compute_metrics,
101
+ )
102
+
103
+ trainer.train()
104
+ trainer.push_to_hub()