Fixing restore checkpoint step
- mc4/mc4.py +10 -7
- run_mlm_flax_stream.py +63 -6
mc4/mc4.py
CHANGED

@@ -1,11 +1,11 @@
-"""mC4 dataset based on Common Crawl."""
+"""Perplexity Sampled mC4 dataset based on Common Crawl."""


 import gzip
 import json

 import datasets
-import kenlm
+import kenlm  # pip install https://github.com/kpu/kenlm/archive/master.zip
 import numpy as np
 from numpy.random import default_rng

@@ -289,6 +289,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
         self.sampling_factor = kwargs.pop("sampling_factor", None)
         self.boundaries = kwargs.pop("boundaries", None)
         self.seed = kwargs.pop("seed", None)
+        self.kwargs = kwargs
         if self.sampling_method:
             if self.seed is not None:
                 self.rng = default_rng(self.seed)
@@ -316,7 +317,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             doc_length += length
         return 10.0 ** (-doc_log_score / doc_length)

-    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
+    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None, **kwargs):
         perplexity = self.get_perplexity(doc)
         if boundaries is None:
             boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
@@ -331,17 +332,18 @@ class Mc4(datasets.GeneratorBasedBuilder):
             probability = factor / quartile_range
         return self.rng.uniform() < probability

-    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
+    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None, **kwargs):
+        width = kwargs.get("width", 9 / 2)  # width (spread) of the exponential curve
         perplexity = self.get_perplexity(doc)
         if boundaries is not None:
             m = boundaries[1]
         else:
             m = 662247.50212365
-        exponential = np.exp(-
+        exponential = np.exp((-1 / width) * ((perplexity - m) / m) ** 2)
         weighted_perplexity = factor * exponential
         return self.rng.uniform() < weighted_perplexity

-    def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
+    def _should_keep_doc_random(self, doc, factor=None, boundaries=None, **kwargs):
         if factor is None:
             factor = 0.5
         return self.rng.uniform() <= factor
@@ -415,7 +417,8 @@ class Mc4(datasets.GeneratorBasedBuilder):
                     if self.should_keep_doc(
                         example["text"],
                         factor=self.sampling_factor,
-                        boundaries=self.boundaries):
+                        boundaries=self.boundaries,
+                        **self.kwargs):
                         yield id_, example
                         id_ += 1
                 else:
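
Because the leftover builder kwargs are now kept in self.kwargs and forwarded to should_keep_doc, extra sampling parameters such as width can travel from the load_dataset call down to the sampling methods. A hypothetical invocation, assuming the script is loaded from a local path and that the builder is otherwise configured (e.g. with its kenlm model) as the project expects:

# Hypothetical invocation (not part of this commit): parameter names other than
# sampling_method, sampling_factor, boundaries, width and seed are assumptions.
from datasets import load_dataset

mc4 = load_dataset(
    "./mc4/mc4.py",            # the modified dataset script
    languages=["es"],          # assumed language-selection option of the builder
    split="train",
    streaming=True,
    sampling_method="gaussian",   # assumed to select _should_keep_doc_gaussian
    sampling_factor=0.78,
    boundaries=[536394.99320948, 662247.50212365, 919250.87225178],
    width=9 / 2,               # ends up in self.kwargs and is forwarded via **kwargs
    seed=42,
)
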
run_mlm_flax_stream.py
CHANGED

@@ -348,6 +348,24 @@ def save_checkpoint_files(state, data_collator, training_args, save_dir):
         json.dump({"step": unreplicated_state.step.item()}, f)


+def restore_checkpoint(save_dir, state):
+    logger.info(f"Restoring checkpoint from {save_dir}")
+    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
+        params = from_bytes(state.params, f.read())
+
+    with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
+        opt_state = from_bytes(state.opt_state, f.read())
+
+    args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
+    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
+
+    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
+        training_state = json.load(f)
+    step = training_state["step"]
+
+    return params, opt_state, step, args, data_collator
+
+
 def rotate_checkpoints(path, max_checkpoints=5):
     paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
     if len(paths) > max_checkpoints:
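
restore_checkpoint is the inverse of save_checkpoint_files: it reads the model and optimizer pytrees back with flax.serialization.from_bytes, reloads the joblib-pickled training args and data collator, and takes the resume step from training_state.json. A minimal round-trip sketch of the msgpack part, using a toy pytree instead of a real TrainState:

# Minimal sketch of the msgpack round trip used above (toy pytree, not a real TrainState).
import jax.numpy as jnp
from flax.serialization import to_bytes, from_bytes

params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}

# Save: writing flax_model.msgpack amounts to dumping the serialized pytree.
with open("flax_model.msgpack", "wb") as f:
    f.write(to_bytes(params))

# Restore: from_bytes needs a target with the same structure to rebuild the pytree.
with open("flax_model.msgpack", "rb") as f:
    restored = from_bytes(params, f.read())

print(jnp.allclose(restored["dense"]["kernel"], params["dense"]["kernel"]))  # True
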
@@ -484,8 +502,6 @@ if __name__ == "__main__":
     has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
         try:
-            from flax.metrics.tensorboard import SummaryWriter
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
             # Enable Weight&Biases
             import wandb
             wandb.init(
@@ -496,6 +512,8 @@ if __name__ == "__main__":
             wandb.config.update(training_args)
             wandb.config.update(model_args)
             wandb.config.update(data_args)
+            from flax.metrics.tensorboard import SummaryWriter
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
         except ImportError as ie:
             has_tensorboard = False
             logger.warning(
@@ -569,6 +587,42 @@ if __name__ == "__main__":

     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    saved_step = 0
+    if "checkpoint" in model_args.model_name_or_path:
+        params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
+        # Create learning rate schedule
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0, end_value=args.learning_rate, transition_steps=args.warmup_steps
+        )
+        decay_fn = optax.linear_schedule(
+            init_value=args.learning_rate,
+            end_value=0,
+            transition_steps=data_args.num_train_steps - args.warmup_steps,
+        )
+        linear_decay_lr_schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[args.warmup_steps]
+        )
+        # create adam optimizer
+        adamw = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=args.weight_decay,
+            mask=decay_mask_fn,
+        )
+        state = train_state.TrainState(
+            step=saved_step,
+            apply_fn=model.__call__,
+            params=params,
+            tx=adamw,
+            opt_state=opt_state,
+        )
+        # self.args = args
+        # data_collator = data_collator
+        # scheduler_fn = args.learning_rate
+        model.params = params
+

     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
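
When resuming from a checkpoint, the block above rebuilds the warmup-plus-linear-decay learning-rate schedule from the restored args and re-creates the AdamW optimizer around it. A standalone sketch of that schedule shape, with illustrative hyperparameters:

# Standalone sketch of the warmup + linear decay schedule rebuilt above.
# learning_rate, warmup_steps and num_train_steps are illustrative values.
import optax

learning_rate, warmup_steps, num_train_steps = 3e-4, 1_000, 10_000

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

for step in (0, 500, 1_000, 5_000, 10_000):
    print(f"step {step:>6}: lr = {float(schedule_fn(step)):.2e}")

Since the optimizer state restored above includes optax's internal step count, the joined schedule effectively continues from the saved position instead of restarting its warmup.
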
@@ -637,7 +691,10 @@ if __name__ == "__main__":
     eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)

     steps = tqdm(range(num_train_steps), desc="Training...", position=0)
-    for step in range(num_train_steps):
+    for step in range(saved_step, num_train_steps):
+        if step < saved_step:
+            steps.update(1)
+            continue
         # ======================== Training ================================
         try:
             samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
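
The training loop now starts at saved_step (0 for a fresh run), so a restored run resumes its step counter where it stopped. Purely as a generic illustration, and not as part of this commit, a progress bar can also be made to reflect a resumed position with tqdm's initial argument:

# Generic sketch (not from this commit): a progress bar that starts at a resumed step.
from tqdm import tqdm

num_train_steps, saved_step = 10_000, 4_000  # illustrative values
steps = tqdm(total=num_train_steps, initial=saved_step, desc="Training...", position=0)
for step in range(saved_step, num_train_steps):
    # ... training step would go here ...
    steps.update(1)
steps.close()
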
@@ -700,7 +757,7 @@ if __name__ == "__main__":

                # save checkpoint after eval_steps
                if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
-                    logger.info(f"Saving checkpoint at {step
+                    logger.info(f"Saving checkpoint at {step} steps")
                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(
                        training_args.output_dir,
@@ -709,9 +766,9 @@ if __name__ == "__main__":
                        commit_message=f"Saving weights and logs of step {step + 1}",
                    )
                    save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
-                    checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step
+                    checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
                    checkpoints_dir.mkdir(parents=True, exist_ok=True)
-                    model.save_pretrained(checkpoints_dir, params=params
+                    model.save_pretrained(checkpoints_dir, params=params)
                    save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
                    rotate_checkpoints(
                        Path(training_args.output_dir) / "checkpoints",
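
Putting the pieces together: the restore path is triggered when model_name_or_path contains the substring "checkpoint", and that directory is expected to hold the files written by save_checkpoint_files plus the flax_model.msgpack written by save_pretrained. A small, purely illustrative check (file names come from the diff above; the directory path is an example):

# Illustrative sanity check for a resumable checkpoint directory.
# File names come from the diff above; the example path is made up.
import os

EXPECTED_FILES = [
    "flax_model.msgpack",       # model weights, read back via from_bytes
    "optimizer_state.msgpack",  # optimizer state
    "training_args.joblib",     # pickled training args
    "data_collator.joblib",     # pickled data collator
    "training_state.json",      # {"step": ...}, the step to resume from
]

def is_resumable(checkpoint_dir):
    """True if the directory looks like one written by the save-checkpoint block."""
    return all(os.path.isfile(os.path.join(checkpoint_dir, name)) for name in EXPECTED_FILES)

print(is_resumable("output/checkpoints/checkpoint-5000"))  # example path
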