Update eval.py
eval.py
CHANGED
@@ -6,8 +6,8 @@ from typing import Dict
 import torch
 from datasets import Audio, Dataset, load_dataset, load_metric
 
-from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
-from pyctcdecode import BeamSearchDecoderCTC
+from transformers import AutoFeatureExtractor, AutoModelForCTC, pipeline, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, Wav2Vec2FeatureExtractor
+# from pyctcdecode import BeamSearchDecoderCTC
 
 
 def log_results(result: Dataset, args: Dict[str, str]):
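Since the decoder is now taken from Wav2Vec2ProcessorWithLM inside main() (see the last hunk below), the direct pyctcdecode dependency is no longer needed at import time, hence the commented-out import.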
@@ -15,8 +15,8 @@ def log_results(result: Dataset, args: Dict[str, str]):
 
     log_outputs = args.log_outputs
     lm = "withLM" if args.use_lm else "noLM"
-    model_id = args.model_id.replace("/", "_")
-    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split, lm])
+    model_id = args.model_id.replace("/", "_").replace(".", "")
+    dataset_id = "_".join([model_id] + args.dataset.split("/") + [args.config, args.split, lm])
 
     # load metric
     wer = load_metric("wer")
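Prepending the sanitized model id to dataset_id keeps result files from different models from overwriting each other, and stripping dots avoids filenames that read like extensions. A quick illustration (all ids made up):

    model_id = "some-org/some-model-v1.0".replace("/", "_").replace(".", "")
    dataset_id = "_".join([model_id] + "NbAiLab/NPSC".split("/") + ["16K_mp3", "test", "withLM"])
    print(dataset_id)  # -> "some-org_some-model-v10_NbAiLab_NPSC_16K_mp3_test_withLM"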
@@ -27,7 +27,7 @@ def log_results(result: Dataset, args: Dict[str, str]):
     cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
 
     # print & log results
-    result_str = f"WER: {wer_result}\nCER: {cer_result}"
+    result_str = f"{dataset_id}\nWER: {wer_result}\nCER: {cer_result}"
     print(result_str)
 
     with open(f"{dataset_id}_eval_results.txt", "w") as f:
@@ -57,7 +57,7 @@ def normalize_text(text: str, dataset: str) -> str:
 
     if dataset.lower().endswith("nst"):
         text = text.lower()
-        text = text.replace("(...
+        text = text.replace("(...vær stille under dette opptaket...)", "")
         text = re.sub('[áàâ]', 'a', text)
         text = re.sub('[ä]', 'æ', text)
         text = re.sub('[éèëê]', 'e', text)
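For readers without Norwegian: the phrase being stripped, "vær stille under dette opptaket", means "be quiet during this recording"; it is a session instruction embedded in NST transcripts, presumably removed so it cannot inflate WER.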
@@ -78,10 +78,10 @@ def normalize_text(text: str, dataset: str) -> str:
     text = re.sub('[ç]', 'c', text)
     text = re.sub('[úùüû]', 'u', text)
     text = re.sub('\s', ' ', text)
-
-
-
-
+    text = re.sub("<ee(eh)?>", "e", text)
+    text = re.sub("<mmm?>", "m", text)
+    text = re.sub("<qq>", "q", text)
+    text = re.sub("<inaudible>", "i", text)
 
     # # In addition, we can normalize the target text, e.g. removing new lines characters etc...
     # # note that order is important here!
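The four added substitutions collapse dataset-specific hesitation and noise tags into the single letters the reference transcripts apparently use. A quick illustration on an invented string:

    import re

    text = "<eeeh> det er <mm> vanskelig <inaudible>"  # invented example
    text = re.sub("<ee(eh)?>", "e", text)    # "<ee>" / "<eeeh>" -> "e"
    text = re.sub("<mmm?>", "m", text)       # "<mm>" / "<mmm>" -> "m"
    text = re.sub("<qq>", "q", text)         # "<qq>" -> "q"
    text = re.sub("<inaudible>", "i", text)  # "<inaudible>" -> "i"
    print(text)  # -> "e det er m vanskelig i"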
@@ -112,11 +112,27 @@ def main(args):
     args.device = 0 if torch.cuda.is_available() else -1
     # asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
 
-    feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
-    feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
-    feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
+    model_instance = AutoModelForCTC.from_pretrained(args.model_id)
+    if args.use_lm:
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained(args.model_id)
+        decoder = processor.decoder
+    else:
+        processor = Wav2Vec2Processor.from_pretrained(args.model_id)
+        decoder = None
+    asr = pipeline(
+        "automatic-speech-recognition",
+        model=model_instance,
+        tokenizer=processor.tokenizer,
+        feature_extractor=processor.feature_extractor,
+        decoder=decoder,
+        device=args.device
+    )
+
+    # feature_extractor_dict, _ = Wav2Vec2FeatureExtractor.get_feature_extractor_dict(args.model_id)
+    # feature_extractor_dict["processor_class"] = "Wav2Vec2Processor" if not args.use_lm else "Wav2Vec2ProcessorWithLM"
+    # feature_extractor = Wav2Vec2FeatureExtractor.from_dict(feature_extractor_dict)
 
-    asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
+    # asr = pipeline("automatic-speech-recognition", model=args.model_id, feature_extractor=feature_extractor, device=args.device, decoder=BeamSearchDecoderCTC.load_from_dir("./"))
 
     # map function to decode audio
     def map_to_pred(batch):
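The new block replaces the earlier workaround (rewriting the feature extractor's processor_class and loading a beam-search decoder from the working directory) with an explicit load: the model and matching processor come straight from the checkpoint, and a decoder is handed to the pipeline only when --use_lm is set. Roughly, the with-LM path does the following under the hood; this is a sketch, not part of eval.py, and the model id is a placeholder:

    import numpy as np
    import torch
    from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM

    model = AutoModelForCTC.from_pretrained("some-org/some-model")  # placeholder id
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("some-org/some-model")

    audio = np.zeros(16_000, dtype=np.float32)  # one second of dummy audio
    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # pyctcdecode beam search with the KenLM bundled in the checkpoint:
    transcription = processor.batch_decode(logits.numpy()).text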
@@ -172,4 +188,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    main(args)
+    main(args)
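For completeness, an invocation of the updated script could look like the line below; every value is a placeholder, and the flag spellings are inferred from the args attributes the script reads (model_id, dataset, config, split, use_lm, log_outputs):

    python eval.py --model_id some-org/some-model --dataset NbAiLab/NPSC --config 16K_mp3 --split test --use_lm --log_outputs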