Fix train script for NPSC
Browse files
run_speech_recognition_ctc.py
CHANGED
|
@@ -391,6 +391,23 @@ def main():
|
|
| 391 |
# Set seed before initializing model.
|
| 392 |
set_seed(training_args.seed)
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
# 1. First, let's load the dataset
|
| 395 |
raw_datasets = DatasetDict()
|
| 396 |
|
|
@@ -401,6 +418,8 @@ def main():
|
|
| 401 |
split=data_args.train_split_name,
|
| 402 |
use_auth_token=data_args.use_auth_token,
|
| 403 |
)
|
|
|
|
|
|
|
| 404 |
|
| 405 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
| 406 |
raise ValueError(
|
|
@@ -426,6 +445,8 @@ def main():
|
|
| 426 |
split=data_args.eval_split_name,
|
| 427 |
use_auth_token=data_args.use_auth_token,
|
| 428 |
)
|
|
|
|
|
|
|
| 429 |
|
| 430 |
if data_args.max_eval_samples is not None:
|
| 431 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
|
|
|
| 391 |
# Set seed before initializing model.
|
| 392 |
set_seed(training_args.seed)
|
| 393 |
|
| 394 |
+
# Pre-processing dataset
|
| 395 |
+
def preprocess_dataset(entry):
    """Filter predicate for NPSC entries.

    Keeps an entry only when its transcript contains no ``<INAUDIBLE>``
    marker and its sentence language code is Norwegian Bokmål
    (``nb-no``, compared case-insensitively).

    Args:
        entry: a dataset row with at least ``"text"`` and
            ``"sentence_language_code"`` keys.

    Returns:
        bool: True when the entry should be kept for training.
    """
    fully_audible = "<INAUDIBLE>" not in entry["text"]
    is_bokmaal = entry["sentence_language_code"].lower() == "nb-no"
    return fully_audible and is_bokmaal
|
| 400 |
+
|
| 401 |
+
def map_dataset(entry):
    """Normalize an NPSC transcript for CTC training.

    Lower-cases the text, expands the hesitation tags ``<ee>``,
    ``<mm>``, ``<qq>`` into plain letter runs, and maps the accented
    characters ``ó`` and ``é`` to their unaccented forms. The
    replacements are applied in the same fixed order as the original
    chained calls.

    Args:
        entry: a dataset row with a ``"text"`` key.

    Returns:
        dict: ``{"text": normalized_text}`` suitable for ``Dataset.map``.
    """
    text = entry["text"].lower()
    # NOTE: order preserved from the original chained .replace() calls.
    for old, new in (
        ("<ee>", "eee"),
        ("<mm>", "mmm"),
        ("<qq>", "qqq"),
        ("ó", "o"),
        ("é", "e"),
    ):
        text = text.replace(old, new)
    return {"text": text}
|
| 410 |
+
|
| 411 |
# 1. First, let's load the dataset
|
| 412 |
raw_datasets = DatasetDict()
|
| 413 |
|
|
|
|
| 418 |
split=data_args.train_split_name,
|
| 419 |
use_auth_token=data_args.use_auth_token,
|
| 420 |
)
|
| 421 |
+
raw_datasets["train"] = raw_datasets["train"].filter(preprocess_dataset)
|
| 422 |
+
raw_datasets["train"] = raw_datasets["train"].map(map_dataset)
|
| 423 |
|
| 424 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
| 425 |
raise ValueError(
|
|
|
|
| 445 |
split=data_args.eval_split_name,
|
| 446 |
use_auth_token=data_args.use_auth_token,
|
| 447 |
)
|
| 448 |
+
raw_datasets["eval"] = raw_datasets["eval"].filter(preprocess_dataset)
|
| 449 |
+
raw_datasets["eval"] = raw_datasets["eval"].map(map_dataset)
|
| 450 |
|
| 451 |
if data_args.max_eval_samples is not None:
|
| 452 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|