versae committed on
Commit 13308d6 · 1 Parent(s): e2da4ba

Update eval.py

Files changed (1):
  eval.py +31 -15
eval.py CHANGED
@@ -6,8 +6,8 @@ from typing import Dict
 import torch
 from datasets import Audio, Dataset, load_dataset, load_metric

-from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
-from pyctcdecode import BeamSearchDecoderCTC
+from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
+# from pyctcdecode import BeamSearchDecoderCTC


 def log_results(result: Dataset, args: Dict[str, str]):
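Note: the direct pyctcdecode import can go because Wav2Vec2ProcessorWithLM already bundles a pyctcdecode BeamSearchDecoderCTC. A minimal sketch of that coupling, using a public demo checkpoint as a stand-in (any repo that ships a language_model/ folder should behave the same):

    from transformers import Wav2Vec2ProcessorWithLM

    # Stand-in repo id; loading only works for checkpoints that include LM files.
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
    print(type(processor.decoder).__name__)  # pyctcdecode's BeamSearchDecoderCTC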
@@ -15,8 +15,8 @@ def log_results(result: Dataset, args: Dict[str, str]):

     log_outputs = args.log_outputs
     lm = "withLM" if args.use_lm else "noLM"
-    model_id = args.model_id.replace("/", "_")
-    dataset_id = "_".join(args.dataset.split("/") + [model_id, args.config, args.split, lm])
+    model_id = args.model_id.replace("/", "_").replace(".", "")
+    dataset_id = "_".join([model_id] + args.dataset.split("/") + [args.config, args.split, lm])

     # load metric
     wer = load_metric("wer")
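For context, the new dataset_id puts the sanitized model id first and strips dots that would be awkward in a filename. A toy trace with hypothetical ids (the real values come from the CLI arguments):

    # Hypothetical model and dataset ids, for illustration only.
    model_id = "NbAiLab/wav2vec2-1b-npsc".replace("/", "_").replace(".", "")
    dataset_id = "_".join([model_id] + "NbAiLab/NPSC".split("/") + ["16K_mp3", "test", "withLM"])
    print(dataset_id)  # NbAiLab_wav2vec2-1b-npsc_NbAiLab_NPSC_16K_mp3_test_withLM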
@@ -27,7 +27,7 @@
     cer_result = cer.compute(references=result["target"], predictions=result["prediction"])

     # print & log results
-    result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
+    result_str = f"{dataset_id}\nWER: {wer_result}\nCER: {cer_result}"
     print(result_str)

     with open(f"{dataset_id}_eval_results.txt", "w") as f:
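The metrics themselves are standard. A self-contained sketch of the computation above with toy strings (datasets' load_metric delegates to jiwer here, so jiwer must be installed; newer releases move these metrics to the evaluate library):

    from datasets import load_metric

    wer = load_metric("wer")
    cer = load_metric("cer")

    refs = ["vær stille under dette opptaket"]
    preds = ["vær stille under opptaket"]
    print(wer.compute(references=refs, predictions=preds))  # 0.2: one of five words deleted
    print(cer.compute(references=refs, predictions=preds))  # character-level analogue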
@@ -57,7 +57,7 @@ def normalize_text(text: str, dataset: str) -> str:

     if dataset.lower().endswith("nst"):
         text = text.lower()
-        text = text.replace("(...Vær stille under dette opptaket...)", "")
+        text = text.replace("(...vær stille under dette opptaket...)", "")
         text = re.sub('[áàâ]', 'a', text)
         text = re.sub('[ä]', 'æ', text)
         text = re.sub('[éèëê]', 'e', text)
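The lowercasing of the pattern matters: text.lower() runs one line earlier, so the old capitalized pattern could never match. A toy illustration:

    # The transcript is lowercased before the replace, so the pattern
    # must be lowercase too; the capitalized variant silently did nothing.
    text = "(...Vær stille under dette opptaket...) hei".lower()
    text = text.replace("(...vær stille under dette opptaket...)", "")
    print(text.strip())  # hei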
@@ -78,10 +78,10 @@
         text = re.sub('[ç]', 'c', text)
         text = re.sub('[úùüû]', 'u', text)
         text = re.sub('\s', ' ', text)
-        text = re.sub('<ee>', 'eee', text)
-        text = re.sub('<qq>', 'qqq', text)
-        text = re.sub('<mm>', 'mmm', text)
-        text = re.sub('<inaudible>', 'xxx', text)
+        text = re.sub("<ee(eh)?>", "e", text)
+        text = re.sub("<mmm?>", "m", text)
+        text = re.sub("<qq>", "q", text)
+        text = re.sub("<inaudible>", "i", text)

     # # In addition, we can normalize the target text, e.g. removing new lines characters etc...
     # # note that order is important here!
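The hesitation and noise tags now collapse to single characters rather than tripled ones, presumably to line up with the model's vocabulary, and the <ee(eh)?> and <mmm?> patterns also catch spelling variants of the tags. A quick check on a made-up NST-style transcript:

    import re

    # Made-up transcript; <eeeh>, <mm>, <inaudible> are NST-style tags.
    text = "ja <eeeh> det var <mm> bra <inaudible>".lower()
    text = re.sub("<ee(eh)?>", "e", text)
    text = re.sub("<mmm?>", "m", text)
    text = re.sub("<qq>", "q", text)
    text = re.sub("<inaudible>", "i", text)
    print(text)  # ja e det var m bra i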
@@ -112,11 +112,27 @@ def main(args):
     args.device = 0 if torch.cuda.is_available() else -1
     # asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)

-    feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
-    feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
-    feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
+    model_instance = AutoModelForCTC.from_pretrained(args.model_id)
+    if args.use_lm:
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
+        decoder = processor.decoder
+    else:
+        processor = Wav2Vec2Processor.from_pretrained(args.model_id)
+        decoder = None
+    asr = pipeline(
+        "automatic-speech-recognition",
+        model=model_instance,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        decoder=decoder,
+        device=args.device
+    )
+
+    # feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
+    # feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
+    # feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)

-    asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
+    # asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))

     # map function to decode audio
     def map_to_pred(batch):
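With the processor split into its tokenizer and feature extractor, the pipeline no longer needs the processor_class workaround or a decoder loaded from disk. A hypothetical smoke test for the assembled pipeline; "sample.wav" is a stand-in path (a 16 kHz numpy array also works, and file input requires ffmpeg):

    # Transcribe one recording through the pipeline built above; decoding
    # goes through the LM beam-search decoder when use_lm is set.
    prediction = asr("sample.wav")
    print(prediction["text"])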
@@ -172,4 +188,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    main(args)
+    main(args)