updated the model and script to load local data
- pytorch_model.bin +1 -1
- run_clm_flax.py +6 -1
- run_pretraining.sh +3 -2
pytorch_model.bin
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:f67e392707d4b269ea616f717bb7451e79d8cc0235449e990209b12bb74aad45
 size 1444576537
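The pointer above is all Git LFS stores in the repository; the actual weight file is addressed by its sha256 oid and byte size. As a quick sanity check that a cloned checkpoint matches the pointer, here is a minimal Python sketch (the filename pytorch_model.bin and the oid/size values are copied from the diff above; everything else is illustrative):

import hashlib
from pathlib import Path

# Values copied from the Git LFS pointer in the diff above.
EXPECTED_OID = "f67e392707d4b269ea616f717bb7451e79d8cc0235449e990209b12bb74aad45"
EXPECTED_SIZE = 1444576537

def verify_lfs_blob(path: str) -> bool:
    """Check that a materialized LFS file matches the pointer's size and sha256 oid."""
    blob = Path(path)
    if blob.stat().st_size != EXPECTED_SIZE:
        return False
    sha = hashlib.sha256()
    with blob.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            sha.update(chunk)
    return sha.hexdigest() == EXPECTED_OID

if __name__ == "__main__":
    print(verify_lfs_blob("pytorch_model.bin"))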
run_clm_flax.py
CHANGED

@@ -112,6 +112,9 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+    dataset_data_dir: Optional[str] = field(
+        default=None, metadata={"help": "The name of the data directory."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -296,19 +299,21 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
+            data_args.dataset_name, data_args.dataset_config_name, data_dir=data_args.dataset_data_dir, cache_dir=model_args.cache_dir, keep_in_memory=False
         )

         if "validation" not in dataset.keys():
             dataset["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
+                data_dir=data_args.dataset_data_dir,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
             dataset["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
+                data_dir=data_args.dataset_data_dir,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
             )
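In effect, the change lets load_dataset point at a local loading script and forward a local data directory to it via the data_dir keyword. A minimal sketch of that call outside the training script, using the paths that run_pretraining.sh below passes in (the cache_dir value here is illustrative):

from datasets import load_dataset

# Minimal sketch, not the training script itself: load a dataset whose
# loading script lives in a local directory while the raw files sit in a
# separate data directory, mirroring the new --dataset_data_dir flag.
dataset = load_dataset(
    "./datasets/id_collection",    # path to the local dataset loading script
    "id_collection",               # dataset configuration name
    data_dir="/data/collection",   # exposed to the loading script as config.data_dir
    cache_dir="./cache",           # illustrative cache location
    keep_in_memory=False,
)

print(dataset)  # DatasetDict with whatever splits the loading script defines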
run_pretraining.sh
CHANGED

@@ -9,8 +9,9 @@ export WANDB_LOG_MODEL="true"
     --model_type="gpt2" \
     --config_name="${MODEL_DIR}" \
     --tokenizer_name="${MODEL_DIR}" \
-    --dataset_name="
-    --dataset_config_name="
+    --dataset_name="./datasets/id_collection" \
+    --dataset_config_name="id_collection" \
+    --dataset_data_dir="/data/collection" \
     --do_train --do_eval \
     --block_size="512" \
     --per_device_train_batch_size="24" \
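For completeness, the new --dataset_data_dir flag reaches data_args through transformers.HfArgumentParser, which turns every DataTrainingArguments field into a command-line option. A self-contained sketch with a trimmed-down dataclass (not the full one from run_clm_flax.py):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Trimmed-down copy of the fields relevant to local loading.
    dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name or path to a local loading script."})
    dataset_config_name: Optional[str] = field(default=None, metadata={"help": "Dataset configuration name."})
    dataset_data_dir: Optional[str] = field(default=None, metadata={"help": "The name of the data directory."})

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(
    [
        "--dataset_name", "./datasets/id_collection",
        "--dataset_config_name", "id_collection",
        "--dataset_data_dir", "/data/collection",
    ]
)
print(data_args.dataset_data_dir)  # -> /data/collection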