Merge branch 'main' of https://huggingface.co/flax-community/bertin-roberta-large-spanish into main
Browse files
- mc4/mc4.py +52 -31
- run_mlm_flax_stream.py +8 -2
    	
        mc4/mc4.py
    CHANGED
    
    | @@ -283,27 +283,28 @@ class Mc4(datasets.GeneratorBasedBuilder): | |
| 283 | 
             
                BUILDER_CONFIG_CLASS = Mc4Config
         | 
| 284 |  | 
| 285 | 
             
                def __init__(self, *args, writer_batch_size=None, **kwargs):
         | 
|  | |
| 286 | 
             
                    self.sampling_method = kwargs.pop("sampling_method", None)
         | 
|  | |
|  | |
|  | |
|  | |
| 287 | 
             
                    if self.sampling_method:
         | 
| 288 | 
            -
                         | 
| 289 | 
            -
             | 
| 290 | 
            -
                            self.rng = default_rng(seed)
         | 
| 291 | 
             
                        else:
         | 
| 292 | 
             
                            self.rng = default_rng()
         | 
| 293 | 
            -
                        self. | 
| 294 | 
            -
                        self.sampling_factor = kwargs.pop("sampling_factor", None)
         | 
| 295 | 
            -
                        self.boundaries = kwargs.pop("boundaries", None)
         | 
| 296 | 
            -
                        # Loading 5-gram model
         | 
| 297 | 
            -
                        # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
         | 
| 298 | 
            -
                        logger.info("loading model = %s", self.perplexity_model)
         | 
| 299 | 
            -
                        self.pp_model = kenlm.Model(self.perplexity_model)
         | 
| 300 | 
            -
                        if self.sampling_method == "gaussian":
         | 
| 301 | 
            -
                            self.should_keep_doc = self._should_keep_doc_gaussian
         | 
| 302 | 
            -
                        elif self.sampling_method == "random":
         | 
| 303 | 
             
                            self.should_keep_doc = self._should_keep_doc_random
         | 
| 304 | 
             
                        else:
         | 
| 305 | 
            -
                             | 
| 306 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 307 | 
             
                    super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
         | 
| 308 |  | 
| 309 | 
             
                def get_perplexity(self, doc):
         | 
| @@ -341,7 +342,9 @@ class Mc4(datasets.GeneratorBasedBuilder): | |
| 341 | 
             
                    return self.rng.uniform() < weighted_perplexity
         | 
| 342 |  | 
| 343 | 
             
                def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
         | 
| 344 | 
            -
                     | 
|  | |
|  | |
| 345 |  | 
| 346 | 
             
                def _info(self):
         | 
| 347 | 
             
                    return datasets.DatasetInfo(
         | 
| @@ -371,8 +374,18 @@ class Mc4(datasets.GeneratorBasedBuilder): | |
| 371 | 
             
                            for lang in self.config.languages
         | 
| 372 | 
             
                            for index in range(_N_SHARDS_PER_SPLIT[lang][split])
         | 
| 373 | 
             
                        ]
         | 
| 374 | 
            -
                     | 
| 375 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 376 | 
             
                    return [
         | 
| 377 | 
             
                        datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
         | 
| 378 | 
             
                        datasets.SplitGenerator(
         | 
| @@ -385,21 +398,29 @@ class Mc4(datasets.GeneratorBasedBuilder): | |
| 385 | 
             
                    id_ = 0
         | 
| 386 | 
             
                    for filepath in filepaths:
         | 
| 387 | 
             
                        logger.info("generating examples from = %s", filepath)
         | 
| 388 | 
            -
                         | 
| 389 | 
            -
                             | 
| 390 | 
            -
                                logger.info("sampling method = %s", self.sampling_method)
         | 
| 391 | 
            -
                                for line in f:
         | 
| 392 | 
            -
                                    if line:
         | 
| 393 | 
            -
                                        example = json.loads(line)
         | 
| 394 | 
            -
                                        if self.should_keep_doc(
         | 
| 395 | 
            -
                                            example["text"],
         | 
| 396 | 
            -
                                            factor=self.sampling_factor,
         | 
| 397 | 
            -
                                            boundaries=self.boundaries):
         | 
| 398 | 
            -
                                            yield id_, example
         | 
| 399 | 
            -
                                            id_ += 1
         | 
| 400 | 
            -
                            else:
         | 
| 401 | 
             
                                for line in f:
         | 
| 402 | 
             
                                    if line:
         | 
| 403 | 
             
                                        example = json.loads(line)
         | 
| 404 | 
             
                                        yield id_, example
         | 
| 405 | 
             
                                        id_ += 1
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 283 | 
             
                BUILDER_CONFIG_CLASS = Mc4Config
         | 
| 284 |  | 
| 285 | 
             
                def __init__(self, *args, writer_batch_size=None, **kwargs):
         | 
| 286 | 
            +
                    self.data_files = kwargs.pop("data_files", {})
         | 
| 287 | 
             
                    self.sampling_method = kwargs.pop("sampling_method", None)
         | 
| 288 | 
            +
                    self.perplexity_model = kwargs.pop("perplexity_model", None)
         | 
| 289 | 
            +
                    self.sampling_factor = kwargs.pop("sampling_factor", None)
         | 
| 290 | 
            +
                    self.boundaries = kwargs.pop("boundaries", None)
         | 
| 291 | 
            +
                    self.seed = kwargs.pop("seed", None)
         | 
| 292 | 
             
                    if self.sampling_method:
         | 
| 293 | 
            +
                        if self.seed is not None:
         | 
| 294 | 
            +
                            self.rng = default_rng(self.seed)
         | 
|  | |
| 295 | 
             
                        else:
         | 
| 296 | 
             
                            self.rng = default_rng()
         | 
| 297 | 
            +
                        if self.sampling_method == "random":
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 298 | 
             
                            self.should_keep_doc = self._should_keep_doc_random
         | 
| 299 | 
             
                        else:
         | 
| 300 | 
            +
                            # Loading 5-gram model
         | 
| 301 | 
            +
                            # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
         | 
| 302 | 
            +
                            logger.info("loading model = %s", self.perplexity_model)
         | 
| 303 | 
            +
                            self.pp_model = kenlm.Model(self.perplexity_model)
         | 
| 304 | 
            +
                            if self.sampling_method == "gaussian":
         | 
| 305 | 
            +
                                self.should_keep_doc = self._should_keep_doc_gaussian
         | 
| 306 | 
            +
                            else:
         | 
| 307 | 
            +
                                self.should_keep_doc = self._should_keep_doc_step
         | 
| 308 | 
             
                    super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)
         | 
| 309 |  | 
| 310 | 
             
                def get_perplexity(self, doc):
         | 
|  | |
| 342 | 
             
                    return self.rng.uniform() < weighted_perplexity
         | 
| 343 |  | 
| 344 | 
             
                def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
         | 
| 345 | 
            +
                    if factor is None:
         | 
| 346 | 
            +
                        factor = 0.5
         | 
| 347 | 
            +
                    return self.rng.uniform() <= factor
         | 
| 348 |  | 
| 349 | 
             
                def _info(self):
         | 
| 350 | 
             
                    return datasets.DatasetInfo(
         | 
|  | |
| 374 | 
             
                            for lang in self.config.languages
         | 
| 375 | 
             
                            for index in range(_N_SHARDS_PER_SPLIT[lang][split])
         | 
| 376 | 
             
                        ]
         | 
| 377 | 
            +
                    if "train" in self.data_files:
         | 
| 378 | 
            +
                        train_downloaded_files = self.data_files["train"]
         | 
| 379 | 
            +
                        if not isinstance(train_downloaded_files, (tuple, list)):
         | 
| 380 | 
            +
                            train_downloaded_files = [train_downloaded_files]
         | 
| 381 | 
            +
                    else:
         | 
| 382 | 
            +
                        train_downloaded_files = dl_manager.download(data_urls["train"])
         | 
| 383 | 
            +
                    if "validation" in self.data_files:
         | 
| 384 | 
            +
                        validation_downloaded_files = self.data_files["validation"]
         | 
| 385 | 
            +
                        if not isinstance(validation_downloaded_files, (tuple, list)):
         | 
| 386 | 
            +
                            validation_downloaded_files = [validation_downloaded_files]
         | 
| 387 | 
            +
                    else:
         | 
| 388 | 
            +
                        validation_downloaded_files = dl_manager.download(data_urls["validation"])
         | 
| 389 | 
             
                    return [
         | 
| 390 | 
             
                        datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
         | 
| 391 | 
             
                        datasets.SplitGenerator(
         | 
|  | |
| 398 | 
             
                    id_ = 0
         | 
| 399 | 
             
                    for filepath in filepaths:
         | 
| 400 | 
             
                        logger.info("generating examples from = %s", filepath)
         | 
| 401 | 
            +
                        if filepath.endswith("jsonl"):
         | 
| 402 | 
            +
                            with open(filepath, "r", encoding="utf-8") as f:
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 403 | 
             
                                for line in f:
         | 
| 404 | 
             
                                    if line:
         | 
| 405 | 
             
                                        example = json.loads(line)
         | 
| 406 | 
             
                                        yield id_, example
         | 
| 407 | 
             
                                        id_ += 1
         | 
| 408 | 
            +
                        else:
         | 
| 409 | 
            +
                            with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
         | 
| 410 | 
            +
                                if self.sampling_method:
         | 
| 411 | 
            +
                                    logger.info("sampling method = %s", self.sampling_method)
         | 
| 412 | 
            +
                                    for line in f:
         | 
| 413 | 
            +
                                        if line:
         | 
| 414 | 
            +
                                            example = json.loads(line)
         | 
| 415 | 
            +
                                            if self.should_keep_doc(
         | 
| 416 | 
            +
                                                example["text"],
         | 
| 417 | 
            +
                                                factor=self.sampling_factor,
         | 
| 418 | 
            +
                                                boundaries=self.boundaries):
         | 
| 419 | 
            +
                                                yield id_, example
         | 
| 420 | 
            +
                                                id_ += 1
         | 
| 421 | 
            +
                                else:
         | 
| 422 | 
            +
                                    for line in f:
         | 
| 423 | 
            +
                                        if line:
         | 
| 424 | 
            +
                                            example = json.loads(line)
         | 
| 425 | 
            +
                                            yield id_, example
         | 
| 426 | 
            +
                                            id_ += 1
         | 
    	
        run_mlm_flax_stream.py
    CHANGED
    
    | @@ -178,10 +178,10 @@ class DataTrainingArguments: | |
| 178 | 
             
                    else:
         | 
| 179 | 
             
                        if self.train_file is not None:
         | 
| 180 | 
             
                            extension = self.train_file.split(".")[-1]
         | 
| 181 | 
            -
                            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
         | 
| 182 | 
             
                        if self.validation_file is not None:
         | 
| 183 | 
             
                            extension = self.validation_file.split(".")[-1]
         | 
| 184 | 
            -
                            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
         | 
| 185 |  | 
| 186 |  | 
| 187 | 
             
            @flax.struct.dataclass
         | 
| @@ -386,6 +386,11 @@ if __name__ == "__main__": | |
| 386 | 
             
                # 'text' is found. You can easily tweak this behavior (see below).
         | 
| 387 | 
             
                if data_args.dataset_name is not None:
         | 
| 388 | 
             
                    # Downloading and loading a dataset from the hub.
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 389 | 
             
                    dataset = load_dataset(
         | 
| 390 | 
             
                        data_args.dataset_name,
         | 
| 391 | 
             
                        data_args.dataset_config_name,
         | 
| @@ -397,6 +402,7 @@ if __name__ == "__main__": | |
| 397 | 
             
                        boundaries=sampling_args.boundaries,
         | 
| 398 | 
             
                        perplexity_model=sampling_args.perplexity_model,
         | 
| 399 | 
             
                        seed=training_args.seed,
         | 
|  | |
| 400 | 
             
                    )
         | 
| 401 |  | 
| 402 | 
             
                if model_args.config_name:
         | 
|  | |
| 178 | 
             
                    else:
         | 
| 179 | 
             
                        if self.train_file is not None:
         | 
| 180 | 
             
                            extension = self.train_file.split(".")[-1]
         | 
| 181 | 
            +
                            assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`train_file` should be a csv, a json (lines) or a txt file."
         | 
| 182 | 
             
                        if self.validation_file is not None:
         | 
| 183 | 
             
                            extension = self.validation_file.split(".")[-1]
         | 
| 184 | 
            +
                            assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`validation_file` should be a csv, a json (lines) or a txt file."
         | 
| 185 |  | 
| 186 |  | 
| 187 | 
             
            @flax.struct.dataclass
         | 
|  | |
| 386 | 
             
                # 'text' is found. You can easily tweak this behavior (see below).
         | 
| 387 | 
             
                if data_args.dataset_name is not None:
         | 
| 388 | 
             
                    # Downloading and loading a dataset from the hub.
         | 
| 389 | 
            +
                    filepaths = {}
         | 
| 390 | 
            +
                    if data_args.train_file:
         | 
| 391 | 
            +
                        filepaths["train"] = data_args.train_file
         | 
| 392 | 
            +
                    if data_args.validation_file:
         | 
| 393 | 
            +
                        filepaths["validation"] = data_args.validation_file
         | 
| 394 | 
             
                    dataset = load_dataset(
         | 
| 395 | 
             
                        data_args.dataset_name,
         | 
| 396 | 
             
                        data_args.dataset_config_name,
         | 
|  | |
| 402 | 
             
                        boundaries=sampling_args.boundaries,
         | 
| 403 | 
             
                        perplexity_model=sampling_args.perplexity_model,
         | 
| 404 | 
             
                        seed=training_args.seed,
         | 
| 405 | 
            +
                        data_files=filepaths,
         | 
| 406 | 
             
                    )
         | 
| 407 |  | 
| 408 | 
             
                if model_args.config_name:
         | 

