add wandb integration
run_clm_flax.py CHANGED (+18 -0)
@@ -53,6 +53,7 @@ from transformers import (
     is_tensorboard_available,
 )
 from transformers.testing_utils import CaptureLogger
+import wandb
 
 
 logger = logging.getLogger(__name__)
@@ -232,6 +233,13 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
+    if jax.process_index() == 0:
+        wandb.init(
+            entity=os.getenv("WANDB_ENTITY", "indonesian-nlp"),
+            project=os.getenv("WANDB_PROJECT", "huggingface"),
+            sync_tensorboard=True,
+        )
+
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
@@ -250,6 +258,13 @@ def main():
             f"Output directory ({training_args.output_dir}) already exists and is not empty."
             "Use --overwrite_output_dir to overcome."
         )
+    # log your configs with wandb.config, accepts a dict
+    if jax.process_index() == 0:
+        wandb.config.update(training_args)  # optional, log your configs
+        wandb.config.update(model_args)  # optional, log your configs
+        wandb.config.update(data_args)  # optional, log your configs
+
+        wandb.config['test_log'] = 12345  # log additional things
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -591,6 +606,8 @@ def main():
                 epochs.write(
                     f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                 )
+                if jax.process_index() == 0:
+                    wandb.log({'my_metric': train_metrics})
 
                 train_metrics = []
 
@@ -623,6 +640,7 @@ def main():
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
             write_eval_metric(summary_writer, eval_metrics, cur_step)
+            wandb.log({'my_metric': eval_metrics})
 
         if cur_step % training_args.save_steps == 0 and cur_step > 0:
            # save checkpoint after each epoch and push checkpoint to the hub