Adding reading streaming files from local disk
Browse files- mc4/mc4.py +36 -14
- run_mlm_flax_stream.py +8 -2
mc4/mc4.py
CHANGED
|
@@ -283,6 +283,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
| 283 |
BUILDER_CONFIG_CLASS = Mc4Config
|
| 284 |
|
| 285 |
def __init__(self, *args, writer_batch_size=None, **kwargs):
|
|
|
|
| 286 |
self.sampling_method = kwargs.pop("sampling_method", None)
|
| 287 |
if self.sampling_method:
|
| 288 |
seed = kwargs.pop("seed", None)
|
|
@@ -290,19 +291,20 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
| 290 |
self.rng = default_rng(seed)
|
| 291 |
else:
|
| 292 |
self.rng = default_rng()
|
| 293 |
-
self.
|
| 294 |
-
self.sampling_factor = kwargs.pop("sampling_factor", None)
|
| 295 |
-
self.boundaries = kwargs.pop("boundaries", None)
|
| 296 |
-
# Loading 5-gram model
|
| 297 |
-
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
| 298 |
-
logger.info("loading model = %s", self.perplexity_model)
|
| 299 |
-
self.pp_model = kenlm.Model(self.perplexity_model)
|
| 300 |
-
if self.sampling_method == "gaussian":
|
| 301 |
-
self.should_keep_doc = self._should_keep_doc_gaussian
|
| 302 |
-
elif self.sampling_method == "random":
|
| 303 |
self.should_keep_doc = self._should_keep_doc_random
|
| 304 |
else:
|
| 305 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
|
| 308 |
|
|
@@ -341,7 +343,9 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
| 341 |
return self.rng.uniform() < weighted_perplexity
|
| 342 |
|
| 343 |
def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
|
| 344 |
-
|
|
|
|
|
|
|
| 345 |
|
| 346 |
def _info(self):
|
| 347 |
return datasets.DatasetInfo(
|
|
@@ -371,8 +375,18 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
| 371 |
for lang in self.config.languages
|
| 372 |
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
|
| 373 |
]
|
| 374 |
-
|
| 375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
return [
|
| 377 |
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
|
| 378 |
datasets.SplitGenerator(
|
|
@@ -385,6 +399,14 @@ class Mc4(datasets.GeneratorBasedBuilder):
|
|
| 385 |
id_ = 0
|
| 386 |
for filepath in filepaths:
|
| 387 |
logger.info("generating examples from = %s", filepath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
|
| 389 |
if self.sampling_method:
|
| 390 |
logger.info("sampling method = %s", self.sampling_method)
|
|
|
|
| 283 |
BUILDER_CONFIG_CLASS = Mc4Config
|
| 284 |
|
| 285 |
def __init__(self, *args, writer_batch_size=None, **kwargs):
|
| 286 |
+
self.filepaths = kwargs.pop(filepaths, {})
|
| 287 |
self.sampling_method = kwargs.pop("sampling_method", None)
|
| 288 |
if self.sampling_method:
|
| 289 |
seed = kwargs.pop("seed", None)
|
|
|
|
| 291 |
self.rng = default_rng(seed)
|
| 292 |
else:
|
| 293 |
self.rng = default_rng()
|
| 294 |
+
if self.sampling_method == "random":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
self.should_keep_doc = self._should_keep_doc_random
|
| 296 |
else:
|
| 297 |
+
self.perplexity_model = kwargs.pop("perplexity_model", None)
|
| 298 |
+
self.sampling_factor = kwargs.pop("sampling_factor", None)
|
| 299 |
+
self.boundaries = kwargs.pop("boundaries", None)
|
| 300 |
+
# Loading 5-gram model
|
| 301 |
+
# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
|
| 302 |
+
logger.info("loading model = %s", self.perplexity_model)
|
| 303 |
+
self.pp_model = kenlm.Model(self.perplexity_model)
|
| 304 |
+
if self.sampling_method == "gaussian":
|
| 305 |
+
self.should_keep_doc = self._should_keep_doc_gaussian
|
| 306 |
+
else:
|
| 307 |
+
self.should_keep_doc = self._should_keep_doc_step
|
| 308 |
|
| 309 |
super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
|
| 310 |
|
|
|
|
| 343 |
return self.rng.uniform() < weighted_perplexity
|
| 344 |
|
| 345 |
def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
|
| 346 |
+
if factor is None:
|
| 347 |
+
factor = 0.5
|
| 348 |
+
return self.rng.uniform() <= factor
|
| 349 |
|
| 350 |
def _info(self):
|
| 351 |
return datasets.DatasetInfo(
|
|
|
|
| 375 |
for lang in self.config.languages
|
| 376 |
for index in range(_N_SHARDS_PER_SPLIT[lang][split])
|
| 377 |
]
|
| 378 |
+
if "train" in self.filepaths:
|
| 379 |
+
train_downloaded_files = self.filepaths["train"]
|
| 380 |
+
if not isinstance(train_downloaded_files, (tuple, list)):
|
| 381 |
+
train_downloaded_files = [train_downloaded_files]
|
| 382 |
+
else:
|
| 383 |
+
train_downloaded_files = dl_manager.download(data_urls["train"])
|
| 384 |
+
if "validation" in self.filepaths:
|
| 385 |
+
validation_downloaded_files = self.filepaths["validation"]
|
| 386 |
+
if not isinstance(validation_downloaded_files, (tuple, list)):
|
| 387 |
+
validation_downloaded_files = [validation_downloaded_files]
|
| 388 |
+
else:
|
| 389 |
+
validation_downloaded_files = dl_manager.download(data_urls["validation"])
|
| 390 |
return [
|
| 391 |
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
|
| 392 |
datasets.SplitGenerator(
|
|
|
|
| 399 |
id_ = 0
|
| 400 |
for filepath in filepaths:
|
| 401 |
logger.info("generating examples from = %s", filepath)
|
| 402 |
+
if filepath.endswith("json") or filepath.endswith("jsonl"):
|
| 403 |
+
with open(filepath, "r", encoding="utf-8") as f:
|
| 404 |
+
for line in f:
|
| 405 |
+
if line:
|
| 406 |
+
example = json.loads(line)
|
| 407 |
+
yield id_, example
|
| 408 |
+
id_ += 1
|
| 409 |
+
else:
|
| 410 |
with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
|
| 411 |
if self.sampling_method:
|
| 412 |
logger.info("sampling method = %s", self.sampling_method)
|
run_mlm_flax_stream.py
CHANGED
|
@@ -178,10 +178,10 @@ class DataTrainingArguments:
|
|
| 178 |
else:
|
| 179 |
if self.train_file is not None:
|
| 180 |
extension = self.train_file.split(".")[-1]
|
| 181 |
-
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
| 182 |
if self.validation_file is not None:
|
| 183 |
extension = self.validation_file.split(".")[-1]
|
| 184 |
-
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
| 185 |
|
| 186 |
|
| 187 |
@flax.struct.dataclass
|
|
@@ -386,6 +386,11 @@ if __name__ == "__main__":
|
|
| 386 |
# 'text' is found. You can easily tweak this behavior (see below).
|
| 387 |
if data_args.dataset_name is not None:
|
| 388 |
# Downloading and loading a dataset from the hub.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
dataset = load_dataset(
|
| 390 |
data_args.dataset_name,
|
| 391 |
data_args.dataset_config_name,
|
|
@@ -397,6 +402,7 @@ if __name__ == "__main__":
|
|
| 397 |
boundaries=sampling_args.boundaries,
|
| 398 |
perplexity_model=sampling_args.perplexity_model,
|
| 399 |
seed=training_args.seed,
|
|
|
|
| 400 |
)
|
| 401 |
|
| 402 |
if model_args.config_name:
|
|
|
|
| 178 |
else:
|
| 179 |
if self.train_file is not None:
|
| 180 |
extension = self.train_file.split(".")[-1]
|
| 181 |
+
assert extension in ["csv", "json", "txt", "gz"], "`train_file` should be a csv, a json or a txt file."
|
| 182 |
if self.validation_file is not None:
|
| 183 |
extension = self.validation_file.split(".")[-1]
|
| 184 |
+
assert extension in ["csv", "json", "txt", "gz"], "`validation_file` should be a csv, a json or a txt file."
|
| 185 |
|
| 186 |
|
| 187 |
@flax.struct.dataclass
|
|
|
|
| 386 |
# 'text' is found. You can easily tweak this behavior (see below).
|
| 387 |
if data_args.dataset_name is not None:
|
| 388 |
# Downloading and loading a dataset from the hub.
|
| 389 |
+
filepaths = {}
|
| 390 |
+
if data_args.train_file:
|
| 391 |
+
filepaths["train"] = data_args.train_file
|
| 392 |
+
if data_args.validation_file:
|
| 393 |
+
filepaths["validation"] = data_args.validation_file
|
| 394 |
dataset = load_dataset(
|
| 395 |
data_args.dataset_name,
|
| 396 |
data_args.dataset_config_name,
|
|
|
|
| 402 |
boundaries=sampling_args.boundaries,
|
| 403 |
perplexity_model=sampling_args.perplexity_model,
|
| 404 |
seed=training_args.seed,
|
| 405 |
+
filepaths=filepaths,
|
| 406 |
)
|
| 407 |
|
| 408 |
if model_args.config_name:
|