Scripts for perplexity sampling and fixes
Browse files- config.py +3 -0
- convert.py +10 -6
- run_mlm_flax.py +60 -60
- run_mlm_flax_stream.py +719 -0
- run_stream.sh +27 -0
- test_script.py +0 -45
- tokens.py +2 -2
config.py
CHANGED
|
@@ -2,3 +2,6 @@
|
|
| 2 |
from transformers import RobertaConfig
|
| 3 |
config = RobertaConfig.from_pretrained("roberta-large")
|
| 4 |
config.save_pretrained("./")
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from transformers import RobertaConfig
|
| 3 |
config = RobertaConfig.from_pretrained("roberta-large")
|
| 4 |
config.save_pretrained("./")
|
| 5 |
+
|
| 6 |
+
config = RobertaConfig.from_pretrained("roberta-base")
|
| 7 |
+
config.save_pretrained("./config-base.json")
|
convert.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
| 1 |
-
from
|
| 2 |
-
from transformers import
|
| 3 |
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
from jax import numpy as jnp
|
| 2 |
+
from transformers import FlaxRobertaForMaskedLM, RobertaForMaskedLM
|
| 3 |
|
| 4 |
+
def to_f32(t):
|
| 5 |
+
return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
|
| 6 |
|
| 7 |
+
flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
|
| 8 |
+
flax_model.params = to_f32(flax_model.params)
|
| 9 |
+
flax_model.save_pretrained("./")
|
| 10 |
+
|
| 11 |
+
model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
|
| 12 |
+
model.save_pretrained("./", save_config=False)
|
run_mlm_flax.py
CHANGED
|
@@ -110,9 +110,6 @@ class DataTrainingArguments:
|
|
| 110 |
dataset_config_name: Optional[str] = field(
|
| 111 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 112 |
)
|
| 113 |
-
dataset_streaming: bool = field(
|
| 114 |
-
default=False, metadata={"help": "Whether dataset_name should be retrieved using streaming if available."}
|
| 115 |
-
)
|
| 116 |
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
| 117 |
validation_file: Optional[str] = field(
|
| 118 |
default=None,
|
|
@@ -322,7 +319,7 @@ if __name__ == "__main__":
|
|
| 322 |
# download the dataset.
|
| 323 |
if data_args.dataset_name is not None:
|
| 324 |
# Downloading and loading a dataset from the hub.
|
| 325 |
-
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
|
| 326 |
|
| 327 |
if "validation" not in datasets.keys():
|
| 328 |
datasets["validation"] = load_dataset(
|
|
@@ -330,14 +327,12 @@ if __name__ == "__main__":
|
|
| 330 |
data_args.dataset_config_name,
|
| 331 |
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 332 |
cache_dir=model_args.cache_dir,
|
| 333 |
-
streaming=data_args.dataset_streaming,
|
| 334 |
)
|
| 335 |
datasets["train"] = load_dataset(
|
| 336 |
data_args.dataset_name,
|
| 337 |
data_args.dataset_config_name,
|
| 338 |
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 339 |
cache_dir=model_args.cache_dir,
|
| 340 |
-
streaming=data_args.dataset_streaming,
|
| 341 |
)
|
| 342 |
else:
|
| 343 |
data_files = {}
|
|
@@ -456,6 +451,7 @@ if __name__ == "__main__":
|
|
| 456 |
num_proc=data_args.preprocessing_num_workers,
|
| 457 |
load_from_cache_file=not data_args.overwrite_cache,
|
| 458 |
)
|
|
|
|
| 459 |
# Enable tensorboard only on the master node
|
| 460 |
has_tensorboard = is_tensorboard_available()
|
| 461 |
if has_tensorboard and jax.process_index() == 0:
|
|
@@ -483,6 +479,7 @@ if __name__ == "__main__":
|
|
| 483 |
"Please run pip install tensorboard to enable."
|
| 484 |
)
|
| 485 |
|
|
|
|
| 486 |
# Data collator
|
| 487 |
# This one will take care of randomly masking the tokens.
|
| 488 |
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
|
@@ -491,7 +488,14 @@ if __name__ == "__main__":
|
|
| 491 |
rng = jax.random.PRNGKey(training_args.seed)
|
| 492 |
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
| 493 |
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
# Store some constant
|
| 497 |
num_epochs = int(training_args.num_train_epochs)
|
|
@@ -526,17 +530,24 @@ if __name__ == "__main__":
|
|
| 526 |
return traverse_util.unflatten_dict(flat_mask)
|
| 527 |
|
| 528 |
# create adam optimizer
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
|
| 538 |
# Setup train state
|
| 539 |
-
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=
|
| 540 |
|
| 541 |
# Define gradient update step fn
|
| 542 |
def train_step(state, batch, dropout_rng):
|
|
@@ -634,54 +645,43 @@ if __name__ == "__main__":
|
|
| 634 |
|
| 635 |
train_metrics = []
|
| 636 |
|
| 637 |
-
if training_args.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
if jax.process_index() == 0:
|
| 639 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 640 |
model.save_pretrained(
|
| 641 |
-
|
| 642 |
params=params,
|
| 643 |
push_to_hub=training_args.push_to_hub,
|
| 644 |
-
temp_dir=True,
|
| 645 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
| 646 |
)
|
| 647 |
-
|
| 648 |
-
# ======================== Evaluating ==============================
|
| 649 |
-
num_eval_samples = len(tokenized_datasets["validation"])
|
| 650 |
-
eval_samples_idx = jnp.arange(num_eval_samples)
|
| 651 |
-
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
| 652 |
-
|
| 653 |
-
eval_metrics = []
|
| 654 |
-
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
| 655 |
-
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
| 656 |
-
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
| 657 |
-
|
| 658 |
-
# Model forward
|
| 659 |
-
model_inputs = shard(model_inputs.data)
|
| 660 |
-
metrics = p_eval_step(state.params, model_inputs)
|
| 661 |
-
eval_metrics.append(metrics)
|
| 662 |
-
|
| 663 |
-
# normalize eval metrics
|
| 664 |
-
eval_metrics = get_metrics(eval_metrics)
|
| 665 |
-
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
| 666 |
-
eval_normalizer = eval_metrics.pop("normalizer")
|
| 667 |
-
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
| 668 |
-
|
| 669 |
-
# Update progress bar
|
| 670 |
-
epochs.desc = (
|
| 671 |
-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
| 672 |
-
)
|
| 673 |
-
|
| 674 |
-
# Save metrics
|
| 675 |
-
if has_tensorboard and jax.process_index() == 0:
|
| 676 |
-
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
| 677 |
-
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
| 678 |
-
|
| 679 |
-
# save checkpoint after each epoch and push checkpoint to the hub
|
| 680 |
-
if jax.process_index() == 0:
|
| 681 |
-
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 682 |
-
model.save_pretrained(
|
| 683 |
-
training_args.output_dir,
|
| 684 |
-
params=params,
|
| 685 |
-
push_to_hub=training_args.push_to_hub,
|
| 686 |
-
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
| 687 |
-
)
|
|
|
|
| 110 |
dataset_config_name: Optional[str] = field(
|
| 111 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 112 |
)
|
|
|
|
|
|
|
|
|
|
| 113 |
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
| 114 |
validation_file: Optional[str] = field(
|
| 115 |
default=None,
|
|
|
|
| 319 |
# download the dataset.
|
| 320 |
if data_args.dataset_name is not None:
|
| 321 |
# Downloading and loading a dataset from the hub.
|
| 322 |
+
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
| 323 |
|
| 324 |
if "validation" not in datasets.keys():
|
| 325 |
datasets["validation"] = load_dataset(
|
|
|
|
| 327 |
data_args.dataset_config_name,
|
| 328 |
split=f"train[:{data_args.validation_split_percentage}%]",
|
| 329 |
cache_dir=model_args.cache_dir,
|
|
|
|
| 330 |
)
|
| 331 |
datasets["train"] = load_dataset(
|
| 332 |
data_args.dataset_name,
|
| 333 |
data_args.dataset_config_name,
|
| 334 |
split=f"train[{data_args.validation_split_percentage}%:]",
|
| 335 |
cache_dir=model_args.cache_dir,
|
|
|
|
| 336 |
)
|
| 337 |
else:
|
| 338 |
data_files = {}
|
|
|
|
| 451 |
num_proc=data_args.preprocessing_num_workers,
|
| 452 |
load_from_cache_file=not data_args.overwrite_cache,
|
| 453 |
)
|
| 454 |
+
|
| 455 |
# Enable tensorboard only on the master node
|
| 456 |
has_tensorboard = is_tensorboard_available()
|
| 457 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
|
| 479 |
"Please run pip install tensorboard to enable."
|
| 480 |
)
|
| 481 |
|
| 482 |
+
|
| 483 |
# Data collator
|
| 484 |
# This one will take care of randomly masking the tokens.
|
| 485 |
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
|
|
|
| 488 |
rng = jax.random.PRNGKey(training_args.seed)
|
| 489 |
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
| 490 |
|
| 491 |
+
if model_args.model_name_or_path:
|
| 492 |
+
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
| 493 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
| 494 |
+
)
|
| 495 |
+
else:
|
| 496 |
+
model = FlaxAutoModelForMaskedLM.from_config(
|
| 497 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
| 498 |
+
)
|
| 499 |
|
| 500 |
# Store some constant
|
| 501 |
num_epochs = int(training_args.num_train_epochs)
|
|
|
|
| 530 |
return traverse_util.unflatten_dict(flat_mask)
|
| 531 |
|
| 532 |
# create adam optimizer
|
| 533 |
+
if training_args.adafactor:
|
| 534 |
+
# We use the default parameters here to initialize adafactor,
|
| 535 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
| 536 |
+
optimizer = optax.adafactor(
|
| 537 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
| 538 |
+
)
|
| 539 |
+
else:
|
| 540 |
+
optimizer = optax.adamw(
|
| 541 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
| 542 |
+
b1=training_args.adam_beta1,
|
| 543 |
+
b2=training_args.adam_beta2,
|
| 544 |
+
eps=training_args.adam_epsilon,
|
| 545 |
+
weight_decay=training_args.weight_decay,
|
| 546 |
+
mask=decay_mask_fn,
|
| 547 |
+
)
|
| 548 |
|
| 549 |
# Setup train state
|
| 550 |
+
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
| 551 |
|
| 552 |
# Define gradient update step fn
|
| 553 |
def train_step(state, batch, dropout_rng):
|
|
|
|
| 645 |
|
| 646 |
train_metrics = []
|
| 647 |
|
| 648 |
+
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
| 649 |
+
# ======================== Evaluating ==============================
|
| 650 |
+
num_eval_samples = len(tokenized_datasets["validation"])
|
| 651 |
+
eval_samples_idx = jnp.arange(num_eval_samples)
|
| 652 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
| 653 |
+
|
| 654 |
+
eval_metrics = []
|
| 655 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
| 656 |
+
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
| 657 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
| 658 |
+
|
| 659 |
+
# Model forward
|
| 660 |
+
model_inputs = shard(model_inputs.data)
|
| 661 |
+
metrics = p_eval_step(state.params, model_inputs)
|
| 662 |
+
eval_metrics.append(metrics)
|
| 663 |
+
|
| 664 |
+
# normalize eval metrics
|
| 665 |
+
eval_metrics = get_metrics(eval_metrics)
|
| 666 |
+
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
| 667 |
+
eval_normalizer = eval_metrics.pop("normalizer")
|
| 668 |
+
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
| 669 |
+
|
| 670 |
+
# Update progress bar
|
| 671 |
+
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
| 672 |
+
|
| 673 |
+
# Save metrics
|
| 674 |
+
if has_tensorboard and jax.process_index() == 0:
|
| 675 |
+
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
| 676 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
| 677 |
+
|
| 678 |
+
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
| 679 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
| 680 |
if jax.process_index() == 0:
|
| 681 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 682 |
model.save_pretrained(
|
| 683 |
+
training_args.output_dir,
|
| 684 |
params=params,
|
| 685 |
push_to_hub=training_args.push_to_hub,
|
|
|
|
| 686 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
| 687 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
run_mlm_flax_stream.py
ADDED
|
@@ -0,0 +1,719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
| 18 |
+
text file or a dataset.
|
| 19 |
+
|
| 20 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
| 21 |
+
https://huggingface.co/models?filter=masked-lm
|
| 22 |
+
"""
|
| 23 |
+
import logging
|
| 24 |
+
import os
|
| 25 |
+
import sys
|
| 26 |
+
import time
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
from dataclasses import dataclass, field
|
| 29 |
+
|
| 30 |
+
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Dict, List, Optional, Tuple
|
| 33 |
+
|
| 34 |
+
import datasets
|
| 35 |
+
import numpy as np
|
| 36 |
+
from datasets import load_dataset
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
import flax
|
| 40 |
+
import jax
|
| 41 |
+
import jax.numpy as jnp
|
| 42 |
+
import kenlm
|
| 43 |
+
import optax
|
| 44 |
+
from flax import jax_utils, traverse_util
|
| 45 |
+
from flax.training import train_state
|
| 46 |
+
from flax.training.common_utils import get_metrics, onehot, shard
|
| 47 |
+
from transformers import (
|
| 48 |
+
CONFIG_MAPPING,
|
| 49 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
| 50 |
+
AutoConfig,
|
| 51 |
+
AutoTokenizer,
|
| 52 |
+
FlaxAutoModelForMaskedLM,
|
| 53 |
+
HfArgumentParser,
|
| 54 |
+
PreTrainedTokenizerBase,
|
| 55 |
+
TensorType,
|
| 56 |
+
TrainingArguments,
|
| 57 |
+
is_tensorboard_available,
|
| 58 |
+
set_seed,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if datasets.__version__ <= "1.8.0":
|
| 63 |
+
raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
| 67 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass
|
| 71 |
+
class ModelArguments:
|
| 72 |
+
"""
|
| 73 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
model_name_or_path: Optional[str] = field(
|
| 77 |
+
default=None,
|
| 78 |
+
metadata={
|
| 79 |
+
"help": "The model checkpoint for weights initialization."
|
| 80 |
+
"Don't set if you want to train a model from scratch."
|
| 81 |
+
},
|
| 82 |
+
)
|
| 83 |
+
model_type: Optional[str] = field(
|
| 84 |
+
default=None,
|
| 85 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
| 86 |
+
)
|
| 87 |
+
config_name: Optional[str] = field(
|
| 88 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
| 89 |
+
)
|
| 90 |
+
tokenizer_name: Optional[str] = field(
|
| 91 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
| 92 |
+
)
|
| 93 |
+
cache_dir: Optional[str] = field(
|
| 94 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
| 95 |
+
)
|
| 96 |
+
use_fast_tokenizer: bool = field(
|
| 97 |
+
default=True,
|
| 98 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
| 99 |
+
)
|
| 100 |
+
dtype: Optional[str] = field(
|
| 101 |
+
default="float32",
|
| 102 |
+
metadata={
|
| 103 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
| 104 |
+
},
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
@dataclass
|
| 108 |
+
class DataTrainingArguments:
|
| 109 |
+
"""
|
| 110 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
dataset_name: Optional[str] = field(
|
| 114 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
| 115 |
+
)
|
| 116 |
+
dataset_config_name: Optional[str] = field(
|
| 117 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 118 |
+
)
|
| 119 |
+
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
| 120 |
+
validation_file: Optional[str] = field(
|
| 121 |
+
default=None,
|
| 122 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
| 123 |
+
)
|
| 124 |
+
train_ref_file: Optional[str] = field(
|
| 125 |
+
default=None,
|
| 126 |
+
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
| 127 |
+
)
|
| 128 |
+
validation_ref_file: Optional[str] = field(
|
| 129 |
+
default=None,
|
| 130 |
+
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
| 131 |
+
)
|
| 132 |
+
overwrite_cache: bool = field(
|
| 133 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 134 |
+
)
|
| 135 |
+
validation_split_percentage: Optional[int] = field(
|
| 136 |
+
default=5,
|
| 137 |
+
metadata={
|
| 138 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
| 139 |
+
},
|
| 140 |
+
)
|
| 141 |
+
max_seq_length: Optional[int] = field(
|
| 142 |
+
default=None,
|
| 143 |
+
metadata={
|
| 144 |
+
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
| 145 |
+
"than this will be truncated. Default to the max input length of the model."
|
| 146 |
+
},
|
| 147 |
+
)
|
| 148 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 149 |
+
default=None,
|
| 150 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 151 |
+
)
|
| 152 |
+
mlm_probability: float = field(
|
| 153 |
+
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
| 154 |
+
)
|
| 155 |
+
pad_to_max_length: bool = field(
|
| 156 |
+
default=False,
|
| 157 |
+
metadata={
|
| 158 |
+
"help": "Whether to pad all samples to `max_seq_length`. "
|
| 159 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
| 160 |
+
},
|
| 161 |
+
)
|
| 162 |
+
line_by_line: bool = field(
|
| 163 |
+
default=False,
|
| 164 |
+
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
| 165 |
+
)
|
| 166 |
+
text_column_name: str = field(
|
| 167 |
+
default="text", metadata={"help": "The name of the column to retrieve the training text."}
|
| 168 |
+
)
|
| 169 |
+
shuffle_buffer_size: int = field(
|
| 170 |
+
default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
|
| 171 |
+
)
|
| 172 |
+
num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
|
| 173 |
+
num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
|
| 174 |
+
|
| 175 |
+
def __post_init__(self):
|
| 176 |
+
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
| 177 |
+
raise ValueError("Need either a dataset name or a training/validation file.")
|
| 178 |
+
else:
|
| 179 |
+
if self.train_file is not None:
|
| 180 |
+
extension = self.train_file.split(".")[-1]
|
| 181 |
+
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
| 182 |
+
if self.validation_file is not None:
|
| 183 |
+
extension = self.validation_file.split(".")[-1]
|
| 184 |
+
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@flax.struct.dataclass
|
| 188 |
+
class FlaxDataCollatorForLanguageModeling:
|
| 189 |
+
"""
|
| 190 |
+
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
| 191 |
+
are not all of the same length.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
| 195 |
+
The tokenizer used for encoding the data.
|
| 196 |
+
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
| 197 |
+
The probability with which to (randomly) mask tokens in the input.
|
| 198 |
+
|
| 199 |
+
.. note::
|
| 200 |
+
|
| 201 |
+
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
| 202 |
+
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
| 203 |
+
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
| 204 |
+
argument :obj:`return_special_tokens_mask=True`.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
tokenizer: PreTrainedTokenizerBase
|
| 208 |
+
mlm_probability: float = 0.15
|
| 209 |
+
|
| 210 |
+
def __post_init__(self):
|
| 211 |
+
if self.tokenizer.mask_token is None:
|
| 212 |
+
raise ValueError(
|
| 213 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
| 214 |
+
"You should pass `mlm=False` to train on causal language modeling instead."
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
|
| 218 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
| 219 |
+
batch = self.tokenizer.pad(examples, return_tensors=TensorType.NUMPY)
|
| 220 |
+
|
| 221 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
| 222 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
| 223 |
+
|
| 224 |
+
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
| 225 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
| 226 |
+
)
|
| 227 |
+
return batch
|
| 228 |
+
|
| 229 |
+
def mask_tokens(
|
| 230 |
+
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
| 231 |
+
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
| 232 |
+
"""
|
| 233 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
| 234 |
+
"""
|
| 235 |
+
labels = inputs.copy()
|
| 236 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
| 237 |
+
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
| 238 |
+
special_tokens_mask = special_tokens_mask.astype("bool")
|
| 239 |
+
|
| 240 |
+
probability_matrix[special_tokens_mask] = 0.0
|
| 241 |
+
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
| 242 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
| 243 |
+
|
| 244 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
| 245 |
+
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
| 246 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
| 247 |
+
|
| 248 |
+
# 10% of the time, we replace masked input tokens with random word
|
| 249 |
+
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
| 250 |
+
indices_random &= masked_indices & ~indices_replaced
|
| 251 |
+
|
| 252 |
+
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
| 253 |
+
inputs[indices_random] = random_words[indices_random]
|
| 254 |
+
|
| 255 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
| 256 |
+
return inputs, labels
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
@dataclass
|
| 261 |
+
class SamplingArguments:
|
| 262 |
+
"""
|
| 263 |
+
Arguments pertaining to how to perform sampling of the dataset.
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
perplexity_model: Optional[str] = field(
|
| 267 |
+
default="es.arpa.bin", metadata={"help": "kenlm model to use to get perplexity values."}
|
| 268 |
+
)
|
| 269 |
+
sampling_method: Optional[str] = field(
|
| 270 |
+
default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document."}
|
| 271 |
+
)
|
| 272 |
+
sampling_factor: Optional[int] = field(
|
| 273 |
+
default=1, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
|
| 274 |
+
)
|
| 275 |
+
quartiles: Optional[str] = field(
|
| 276 |
+
default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
def __post_init__(self):
|
| 280 |
+
self.quartiles = [float(q) for q in self.quartiles.split(",")]
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
| 284 |
+
num_samples = len(samples_idx)
|
| 285 |
+
samples_to_remove = num_samples % batch_size
|
| 286 |
+
|
| 287 |
+
if samples_to_remove != 0:
|
| 288 |
+
samples_idx = samples_idx[:-samples_to_remove]
|
| 289 |
+
sections_split = num_samples // batch_size
|
| 290 |
+
batch_idx = np.split(samples_idx, sections_split)
|
| 291 |
+
return batch_idx
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
|
| 295 |
+
"""
|
| 296 |
+
The training iterator is advanced so that after groupifying the samples,
|
| 297 |
+
`num_samples` of length `max_seq_length` are returned.
|
| 298 |
+
"""
|
| 299 |
+
num_total_tokens = max_seq_length * num_samples
|
| 300 |
+
samples = defaultdict(list)
|
| 301 |
+
|
| 302 |
+
i = 0
|
| 303 |
+
while i < num_total_tokens:
|
| 304 |
+
tokenized_samples = next(train_iterator)
|
| 305 |
+
i += len(tokenized_samples["input_ids"])
|
| 306 |
+
|
| 307 |
+
# concatenate tokenized samples to list
|
| 308 |
+
samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
|
| 309 |
+
|
| 310 |
+
# Concatenated tokens are split to lists of length `max_seq_length`.
|
| 311 |
+
# Note that remainedr of % max_seq_length are thrown away.
|
| 312 |
+
def group_texts(examples):
|
| 313 |
+
result = {
|
| 314 |
+
k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
|
| 315 |
+
for k, t in examples.items()
|
| 316 |
+
}
|
| 317 |
+
return result
|
| 318 |
+
|
| 319 |
+
grouped_samples = group_texts(samples)
|
| 320 |
+
return grouped_samples
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
| 324 |
+
summary_writer.scalar("train_time", train_time, step)
|
| 325 |
+
|
| 326 |
+
train_metrics = get_metrics(train_metrics)
|
| 327 |
+
for key, vals in train_metrics.items():
|
| 328 |
+
tag = f"train_{key}"
|
| 329 |
+
for i, val in enumerate(vals):
|
| 330 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def write_eval_metric(summary_writer, eval_metrics, step):
|
| 334 |
+
for metric_name, value in eval_metrics.items():
|
| 335 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
if __name__ == "__main__":
|
| 339 |
+
# See all possible arguments in src/transformers/training_args.py
|
| 340 |
+
# or by passing the --help flag to this script.
|
| 341 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
| 342 |
+
|
| 343 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
|
| 344 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
| 345 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
| 346 |
+
# let's parse it to get our arguments.
|
| 347 |
+
model_args, data_args, training_args, sampling_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
| 348 |
+
else:
|
| 349 |
+
model_args, data_args, training_args, sampling_args = parser.parse_args_into_dataclasses()
|
| 350 |
+
|
| 351 |
+
if (
|
| 352 |
+
os.path.exists(training_args.output_dir)
|
| 353 |
+
and os.listdir(training_args.output_dir)
|
| 354 |
+
and training_args.do_train
|
| 355 |
+
and not training_args.overwrite_output_dir
|
| 356 |
+
):
|
| 357 |
+
raise ValueError(
|
| 358 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
| 359 |
+
"Use --overwrite_output_dir to overcome."
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Setup logging
|
| 363 |
+
logging.basicConfig(
|
| 364 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 365 |
+
level="INFO",
|
| 366 |
+
datefmt="[%X]",
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# Log on each process the small summary:
|
| 370 |
+
logger = logging.getLogger(__name__)
|
| 371 |
+
logger.warning(
|
| 372 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
| 373 |
+
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
| 377 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
| 378 |
+
|
| 379 |
+
# Set seed before initializing model.
|
| 380 |
+
set_seed(training_args.seed)
|
| 381 |
+
|
| 382 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
| 383 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
| 384 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
| 385 |
+
#
|
| 386 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
| 387 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
| 388 |
+
if data_args.dataset_name is not None:
|
| 389 |
+
# Downloading and loading a dataset from the hub.
|
| 390 |
+
dataset = load_dataset(
|
| 391 |
+
data_args.dataset_name,
|
| 392 |
+
data_args.dataset_config_name,
|
| 393 |
+
cache_dir=model_args.cache_dir,
|
| 394 |
+
streaming=True,
|
| 395 |
+
split="train",
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
if model_args.config_name:
|
| 399 |
+
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
| 400 |
+
elif model_args.model_name_or_path:
|
| 401 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
| 402 |
+
else:
|
| 403 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
| 404 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
| 405 |
+
|
| 406 |
+
if model_args.tokenizer_name:
|
| 407 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 408 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
| 409 |
+
)
|
| 410 |
+
elif model_args.model_name_or_path:
|
| 411 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 412 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
| 413 |
+
)
|
| 414 |
+
else:
|
| 415 |
+
raise ValueError(
|
| 416 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
| 417 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Loading 5-gram model
|
| 421 |
+
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
| 422 |
+
if sampling_args.sampling_method:
|
| 423 |
+
pp_model = kenlm.Model(sampling_args.perplexity_model)
|
| 424 |
+
|
| 425 |
+
def get_perplexity(doc):
|
| 426 |
+
doc_log_score, doc_length = 0, 0
|
| 427 |
+
for line in doc.split("\n"):
|
| 428 |
+
log_score = pp_model.score(line)
|
| 429 |
+
length = len(line.split()) + 1
|
| 430 |
+
doc_log_score += log_score
|
| 431 |
+
doc_length += length
|
| 432 |
+
return 10.0 ** (-doc_log_score / doc_length)
|
| 433 |
+
|
| 434 |
+
def should_keep_doc_step(doc, factor=1, boundaires=None):
|
| 435 |
+
perplexity = get_perplexity(doc)
|
| 436 |
+
if boundaires is None:
|
| 437 |
+
boundaires = [536394.99320948, 662247.50212365, 919250.87225178]
|
| 438 |
+
if perplexity <= boundaires[0]:
|
| 439 |
+
quartile_range = boundaires[0]
|
| 440 |
+
elif boundaires[0] < perplexity < boundaires[1]:
|
| 441 |
+
quartile_range = boundaires[1] - boundaires[0]
|
| 442 |
+
elif boundaires[1] < perplexity < boundaires[2]:
|
| 443 |
+
quartile_range = boundaires[2] - boundaires[1]
|
| 444 |
+
elif perplexity >= boundaires[2]:
|
| 445 |
+
quartile_range = 100 * boundaires[2]
|
| 446 |
+
probability = factor / quartile_range
|
| 447 |
+
return np.random() < probability
|
| 448 |
+
|
| 449 |
+
def should_keep_doc_gaussian(doc, factor=0.4, boundaires=None):
|
| 450 |
+
perplexity = get_perplexity(doc)
|
| 451 |
+
if boundaires is not None:
|
| 452 |
+
m = boundaires[1]
|
| 453 |
+
else:
|
| 454 |
+
m = 662247.50212365
|
| 455 |
+
weighted_perplexity = factor*np.exp(-9/2*((perplexity-m)/m)**2)
|
| 456 |
+
return np.random.uniform() < weighted_perplexity
|
| 457 |
+
|
| 458 |
+
if sampling_args.sampling_method == "gaussian":
|
| 459 |
+
should_keep_doc = should_keep_doc_gaussian
|
| 460 |
+
else:
|
| 461 |
+
should_keep_doc = should_keep_doc_gaussian
|
| 462 |
+
|
| 463 |
+
def tokenize_function(examples):
|
| 464 |
+
return tokenizer([
|
| 465 |
+
example for example in examples[data_args.text_column_name]
|
| 466 |
+
if should_keep_doc(
|
| 467 |
+
example,
|
| 468 |
+
factor=sampling_args.sampling_factor,
|
| 469 |
+
boundaries=sampling_args.boundaries
|
| 470 |
+
)
|
| 471 |
+
], return_special_tokens_mask=True)
|
| 472 |
+
else:
|
| 473 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
| 474 |
+
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
| 475 |
+
# efficient when it receives the `special_tokens_mask`.
|
| 476 |
+
def tokenize_function(examples):
|
| 477 |
+
return tokenizer(
|
| 478 |
+
examples[data_args.text_column_name],
|
| 479 |
+
return_special_tokens_mask=True
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
tokenized_datasets = dataset.map(
|
| 483 |
+
tokenize_function,
|
| 484 |
+
batched=True,
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
shuffle_seed = training_args.seed
|
| 488 |
+
tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
|
| 489 |
+
|
| 490 |
+
# Enable tensorboard only on the master node
|
| 491 |
+
has_tensorboard = is_tensorboard_available()
|
| 492 |
+
if has_tensorboard and jax.process_index() == 0:
|
| 493 |
+
try:
|
| 494 |
+
from flax.metrics.tensorboard import SummaryWriter
|
| 495 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
| 496 |
+
# Enable Weight&Biases
|
| 497 |
+
import wandb
|
| 498 |
+
wandb.init(
|
| 499 |
+
entity='wandb',
|
| 500 |
+
project='hf-flax-bertin-roberta-es',
|
| 501 |
+
sync_tensorboard=True,
|
| 502 |
+
)
|
| 503 |
+
wandb.config.update(training_args)
|
| 504 |
+
wandb.config.update(model_args)
|
| 505 |
+
wandb.config.update(data_args)
|
| 506 |
+
except ImportError as ie:
|
| 507 |
+
has_tensorboard = False
|
| 508 |
+
logger.warning(
|
| 509 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
| 510 |
+
)
|
| 511 |
+
else:
|
| 512 |
+
logger.warning(
|
| 513 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
| 514 |
+
"Please run pip install tensorboard to enable."
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
# Data collator
|
| 518 |
+
# This one will take care of randomly masking the tokens.
|
| 519 |
+
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
| 520 |
+
|
| 521 |
+
# Initialize our training
|
| 522 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
| 523 |
+
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
| 524 |
+
|
| 525 |
+
if model_args.model_name_or_path:
|
| 526 |
+
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
| 527 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
| 528 |
+
)
|
| 529 |
+
else:
|
| 530 |
+
model = FlaxAutoModelForMaskedLM.from_config(
|
| 531 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Store some constant
|
| 535 |
+
num_epochs = int(training_args.num_train_epochs)
|
| 536 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 537 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 538 |
+
|
| 539 |
+
# define number steps per stream epoch
|
| 540 |
+
num_train_steps = data_args.num_train_steps
|
| 541 |
+
|
| 542 |
+
# Create learning rate schedule
|
| 543 |
+
warmup_fn = optax.linear_schedule(
|
| 544 |
+
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
| 545 |
+
)
|
| 546 |
+
decay_fn = optax.linear_schedule(
|
| 547 |
+
init_value=training_args.learning_rate,
|
| 548 |
+
end_value=0,
|
| 549 |
+
transition_steps=num_train_steps - training_args.warmup_steps,
|
| 550 |
+
)
|
| 551 |
+
linear_decay_lr_schedule_fn = optax.join_schedules(
|
| 552 |
+
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
| 556 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
| 557 |
+
# mask boolean with the same structure as the parameters.
|
| 558 |
+
# The mask is True for parameters that should be decayed.
|
| 559 |
+
# Note that this mask is specifically adapted for FlaxBERT-like models.
|
| 560 |
+
# For other models, one should correct the layer norm parameter naming
|
| 561 |
+
# accordingly.
|
| 562 |
+
def decay_mask_fn(params):
|
| 563 |
+
flat_params = traverse_util.flatten_dict(params)
|
| 564 |
+
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
| 565 |
+
return traverse_util.unflatten_dict(flat_mask)
|
| 566 |
+
|
| 567 |
+
# create adam optimizer
|
| 568 |
+
adamw = optax.adamw(
|
| 569 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
| 570 |
+
b1=training_args.adam_beta1,
|
| 571 |
+
b2=training_args.adam_beta2,
|
| 572 |
+
eps=training_args.adam_epsilon,
|
| 573 |
+
weight_decay=training_args.weight_decay,
|
| 574 |
+
mask=decay_mask_fn,
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
# Setup train state
|
| 578 |
+
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
|
| 579 |
+
|
| 580 |
+
# Define gradient update step fn
|
| 581 |
+
def train_step(state, batch, dropout_rng):
|
| 582 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
| 583 |
+
|
| 584 |
+
def loss_fn(params):
|
| 585 |
+
labels = batch.pop("labels")
|
| 586 |
+
|
| 587 |
+
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
| 588 |
+
|
| 589 |
+
# compute loss, ignore padded input tokens
|
| 590 |
+
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
| 591 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
| 592 |
+
|
| 593 |
+
# take average
|
| 594 |
+
loss = loss.sum() / label_mask.sum()
|
| 595 |
+
|
| 596 |
+
return loss
|
| 597 |
+
|
| 598 |
+
grad_fn = jax.value_and_grad(loss_fn)
|
| 599 |
+
loss, grad = grad_fn(state.params)
|
| 600 |
+
grad = jax.lax.pmean(grad, "batch")
|
| 601 |
+
new_state = state.apply_gradients(grads=grad)
|
| 602 |
+
|
| 603 |
+
metrics = jax.lax.pmean(
|
| 604 |
+
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
return new_state, metrics, new_dropout_rng
|
| 608 |
+
|
| 609 |
+
# Create parallel version of the train step
|
| 610 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
| 611 |
+
|
| 612 |
+
# Define eval fn
|
| 613 |
+
def eval_step(params, batch):
|
| 614 |
+
labels = batch.pop("labels")
|
| 615 |
+
|
| 616 |
+
logits = model(**batch, params=params, train=False)[0]
|
| 617 |
+
|
| 618 |
+
# compute loss, ignore padded input tokens
|
| 619 |
+
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
| 620 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
| 621 |
+
|
| 622 |
+
# compute accuracy
|
| 623 |
+
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
| 624 |
+
|
| 625 |
+
# summarize metrics
|
| 626 |
+
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
| 627 |
+
metrics = jax.lax.psum(metrics, axis_name="batch")
|
| 628 |
+
|
| 629 |
+
return metrics
|
| 630 |
+
|
| 631 |
+
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
| 632 |
+
|
| 633 |
+
# Replicate the train state on each device
|
| 634 |
+
state = jax_utils.replicate(state)
|
| 635 |
+
|
| 636 |
+
train_time = 0
|
| 637 |
+
train_start = time.time()
|
| 638 |
+
train_metrics = []
|
| 639 |
+
eval_metrics = []
|
| 640 |
+
|
| 641 |
+
training_iter = iter(tokenized_datasets)
|
| 642 |
+
|
| 643 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
| 644 |
+
eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|
| 645 |
+
|
| 646 |
+
steps = tqdm(range(num_train_steps), desc="Training...", position=0)
|
| 647 |
+
for step in range(num_train_steps):
|
| 648 |
+
# ======================== Training ================================
|
| 649 |
+
try:
|
| 650 |
+
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
| 651 |
+
except StopIteration:
|
| 652 |
+
# Once the end of the dataset stream is reached, the training iterator
|
| 653 |
+
# is reinitialized and reshuffled and a new eval dataset is randomely chosen.
|
| 654 |
+
shuffle_seed += 1
|
| 655 |
+
tokenized_datasets.set_epoch(shuffle_seed)
|
| 656 |
+
|
| 657 |
+
training_iter = iter(tokenized_datasets)
|
| 658 |
+
|
| 659 |
+
eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|
| 660 |
+
samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
|
| 661 |
+
|
| 662 |
+
# process input samples
|
| 663 |
+
model_inputs = data_collator(samples)
|
| 664 |
+
|
| 665 |
+
# Model forward
|
| 666 |
+
model_inputs = shard(model_inputs.data)
|
| 667 |
+
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
| 668 |
+
|
| 669 |
+
train_metrics.append(train_metric)
|
| 670 |
+
|
| 671 |
+
if step % training_args.logging_steps == 0 and step > 0:
|
| 672 |
+
steps.write(
|
| 673 |
+
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
| 674 |
+
)
|
| 675 |
+
train_time += time.time() - train_start
|
| 676 |
+
if has_tensorboard and jax.process_index() == 0:
|
| 677 |
+
write_train_metric(summary_writer, train_metrics, train_time, step)
|
| 678 |
+
train_metrics = []
|
| 679 |
+
|
| 680 |
+
# ======================== Evaluating ==============================
|
| 681 |
+
if step % training_args.eval_steps == 0 and step > 0:
|
| 682 |
+
eval_samples_idx = jnp.arange(data_args.num_eval_samples)
|
| 683 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
| 684 |
+
|
| 685 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
|
| 686 |
+
# process input samples
|
| 687 |
+
batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
|
| 688 |
+
model_inputs = data_collator(batch_eval_samples)
|
| 689 |
+
|
| 690 |
+
# Model forward
|
| 691 |
+
model_inputs = shard(model_inputs.data)
|
| 692 |
+
metrics = p_eval_step(state.params, model_inputs)
|
| 693 |
+
eval_metrics.append(metrics)
|
| 694 |
+
|
| 695 |
+
# normalize eval metrics
|
| 696 |
+
eval_metrics = get_metrics(eval_metrics)
|
| 697 |
+
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
| 698 |
+
eval_normalizer = eval_metrics.pop("normalizer")
|
| 699 |
+
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
| 700 |
+
|
| 701 |
+
# Update progress bar
|
| 702 |
+
steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
| 703 |
+
|
| 704 |
+
if has_tensorboard and jax.process_index() == 0:
|
| 705 |
+
write_eval_metric(summary_writer, eval_metrics, step)
|
| 706 |
+
eval_metrics = []
|
| 707 |
+
|
| 708 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
| 709 |
+
if jax.process_index() == 0:
|
| 710 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
| 711 |
+
model.save_pretrained(
|
| 712 |
+
training_args.output_dir,
|
| 713 |
+
params=params,
|
| 714 |
+
push_to_hub=training_args.push_to_hub,
|
| 715 |
+
commit_message=f"Saving weights and logs of step {step+1}",
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# update tqdm bar
|
| 719 |
+
steps.update(1)
|
run_stream.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# From https://arxiv.org/pdf/1907.11692.pdf for base model
|
| 2 |
+
python -c "import jax; print('TPUs', jax.device_count())"
|
| 3 |
+
./run_mlm_flax_stream.py \
|
| 4 |
+
--output_dir="./" \
|
| 5 |
+
--model_type="roberta" \
|
| 6 |
+
--config_name="./config-base.json" \
|
| 7 |
+
--tokenizer_name="./" \
|
| 8 |
+
--dataset_name="mc4" \
|
| 9 |
+
--dataset_config_name="es" \
|
| 10 |
+
--max_seq_length="128" \
|
| 11 |
+
--pad_to_max_length \
|
| 12 |
+
--per_device_train_batch_size="256" \
|
| 13 |
+
--per_device_eval_batch_size="256" \
|
| 14 |
+
--adam_beta1="0.9" \
|
| 15 |
+
--adam_beta2="0.98" \
|
| 16 |
+
--adam_epsilon="1e-6" \
|
| 17 |
+
--learning_rate="6e-4" \
|
| 18 |
+
--weight_decay="0.01" \
|
| 19 |
+
--save_strategy="steps" \
|
| 20 |
+
--save_steps="1000" \
|
| 21 |
+
--save_total_limit="5" \
|
| 22 |
+
--warmup_steps="24000" \
|
| 23 |
+
--overwrite_output_dir \
|
| 24 |
+
--num_train_steps="500000" \
|
| 25 |
+
--eval_steps="1000" \
|
| 26 |
+
--dtype="bfloat16" \
|
| 27 |
+
--logging_steps="500" 2>&1 | tee run_stream.log
|
test_script.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
| 1 |
-
"""CONFIG"""
|
| 2 |
-
#!/usr/bin/env python3
|
| 3 |
-
from transformers import RobertaConfig
|
| 4 |
-
config = RobertaConfig.from_pretrained("roberta-large")
|
| 5 |
-
config.save_pretrained("./")
|
| 6 |
-
|
| 7 |
-
"""TOKENIZER"""
|
| 8 |
-
#!/usr/bin/env python3
|
| 9 |
-
from datasets import load_dataset
|
| 10 |
-
from tokenizers import ByteLevelBPETokenizer
|
| 11 |
-
# load dataset
|
| 12 |
-
dataset = load_dataset("large_spanish_corpus")
|
| 13 |
-
# Instantiate tokenizer
|
| 14 |
-
tokenizer = ByteLevelBPETokenizer()
|
| 15 |
-
def batch_iterator(batch_size=1000):
|
| 16 |
-
for i in range(0, len(dataset), batch_size):
|
| 17 |
-
yield dataset[i: i + batch_size]["text"]
|
| 18 |
-
# Customized training
|
| 19 |
-
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
|
| 20 |
-
"<s>",
|
| 21 |
-
"<pad>",
|
| 22 |
-
"</s>",
|
| 23 |
-
"<unk>",
|
| 24 |
-
"<mask>",
|
| 25 |
-
])
|
| 26 |
-
# Save files to disk
|
| 27 |
-
tokenizer.save("./tokenizer.json")
|
| 28 |
-
|
| 29 |
-
"""TOKENIZER"""
|
| 30 |
-
#!/usr/bin/env bash
|
| 31 |
-
./run_mlm_flax.py \
|
| 32 |
-
--output_dir="./" \
|
| 33 |
-
--model_type="roberta" \
|
| 34 |
-
--config_name="./" \
|
| 35 |
-
--tokenizer_name="./" \
|
| 36 |
-
--dataset_name="large_spanish_corpus" \
|
| 37 |
-
--dataset_config_name \ # I think this would be empty
|
| 38 |
-
--max_seq_length="128" \
|
| 39 |
-
--per_device_train_batch_size="4" \
|
| 40 |
-
--per_device_eval_batch_size="4" \
|
| 41 |
-
--learning_rate="3e-4" \
|
| 42 |
-
--warmup_steps="1000" \
|
| 43 |
-
--overwrite_output_dir \
|
| 44 |
-
--num_train_epochs="8" \
|
| 45 |
-
--push_to_hub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokens.py
CHANGED
|
@@ -3,11 +3,11 @@ from datasets import load_dataset
|
|
| 3 |
from tokenizers import ByteLevelBPETokenizer
|
| 4 |
|
| 5 |
# Load dataset
|
| 6 |
-
dataset = load_dataset("oscar", "unshuffled_deduplicated_es")
|
| 7 |
|
| 8 |
# Instantiate tokenizer
|
| 9 |
tokenizer = ByteLevelBPETokenizer()
|
| 10 |
-
def batch_iterator(batch_size=
|
| 11 |
for i in range(0, len(dataset), batch_size):
|
| 12 |
yield dataset["text"][i: i + batch_size]
|
| 13 |
|
|
|
|
| 3 |
from tokenizers import ByteLevelBPETokenizer
|
| 4 |
|
| 5 |
# Load dataset
|
| 6 |
+
dataset = load_dataset("oscar", "unshuffled_deduplicated_es", split="train[:5000000]")
|
| 7 |
|
| 8 |
# Instantiate tokenizer
|
| 9 |
tokenizer = ByteLevelBPETokenizer()
|
| 10 |
+
def batch_iterator(batch_size=100_000):
|
| 11 |
for i in range(0, len(dataset), batch_size):
|
| 12 |
yield dataset["text"][i: i + batch_size]
|
| 13 |
|