Fixes and defaults
- mc4/mc4.py +18 -13
- run_mlm_flax_stream.py +8 -8
mc4/mc4.py
CHANGED

@@ -8,7 +8,6 @@ import datasets
 import kenlm
 import numpy as np
 from numpy.random import default_rng
-rng = default_rng()


 logger = datasets.logging.get_logger(__name__)

@@ -284,11 +283,16 @@ class Mc4(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIG_CLASS = Mc4Config

     def __init__(self, *args, writer_batch_size=None, **kwargs):
-        self.sampling_method = kwargs.pop("sampling_method")
+        self.sampling_method = kwargs.pop("sampling_method", None)
         if self.sampling_method:
-            self.perplexity_model = kwargs.pop("perplexity_model")
-            self.sampling_factor = kwargs.pop("sampling_factor")
-            self.boundaries = kwargs.pop("boundaries")
+            seed = kwargs.pop("seed", None)
+            if seed is not None:
+                self.rng = default_rng(seed)
+            else:
+                self.rng = default_rng()
+            self.perplexity_model = kwargs.pop("perplexity_model", None)
+            self.sampling_factor = kwargs.pop("sampling_factor", None)
+            self.boundaries = kwargs.pop("boundaries", None)
             # Loading 5-gram model
             # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
             logger.info("loading model = %s", self.perplexity_model)

@@ -298,7 +302,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             elif self.sampling_method == "random":
                 self.should_keep_doc = self._should_keep_doc_random
             else:
-                self.should_keep_doc = self.
+                self.should_keep_doc = self._should_keep_doc_step

         super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)

@@ -311,7 +315,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             doc_length += length
         return 10.0 ** (-doc_log_score / doc_length)

-    def _should_keep_doc_step(self, doc, factor=1, boundaries=None):
+    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
        perplexity = self.get_perplexity(doc)
        if boundaries is None:
            boundaries = [536394.99320948, 662247.50212365, 919250.87225178]

@@ -322,21 +326,22 @@ class Mc4(datasets.GeneratorBasedBuilder):
        elif boundaries[1] < perplexity < boundaries[2]:
            quartile_range = boundaries[2] - boundaries[1]
        elif perplexity >= boundaries[2]:
-            quartile_range =
+            quartile_range = 10 * boundaries[2]
        probability = factor / quartile_range
-        return rng.uniform() < probability
+        return self.rng.uniform() < probability

-    def _should_keep_doc_gaussian(self, doc, factor=0.
+    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
        perplexity = self.get_perplexity(doc)
        if boundaries is not None:
            m = boundaries[1]
        else:
            m = 662247.50212365
-
-
+        exponential = np.exp(-9/2 * ((perplexity - m) / m) ** 2)
+        weighted_perplexity = factor * exponential
+        return self.rng.uniform() < weighted_perplexity

    def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
-        return rng.uniform() <= 0.5
+        return self.rng.uniform() <= 0.5

    def _info(self):
        return datasets.DatasetInfo(
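For reference, here is a minimal, self-contained sketch of the sampling behaviour these functions implement after the change. The helper names and the first two branches of the step function are illustrative assumptions (the hunks above only show the tail of that if/elif chain); the boundaries, default factors, and formulas mirror the diff.

import numpy as np
from numpy.random import default_rng

# Default quartile boundaries from the diff above (perplexity values).
BOUNDARIES = [536394.99320948, 662247.50212365, 919250.87225178]

def should_keep_step(perplexity, rng, factor=1.5e5, boundaries=BOUNDARIES):
    # Each perplexity bucket is kept with probability factor / quartile_range,
    # so the very wide top bucket (perplexity >= boundaries[2]) is sampled
    # far less often than the middle quartiles.
    if perplexity <= boundaries[0]:  # assumed branch, not shown in the hunk
        quartile_range = boundaries[0]
    elif boundaries[0] < perplexity < boundaries[1]:  # assumed branch
        quartile_range = boundaries[1] - boundaries[0]
    elif boundaries[1] < perplexity < boundaries[2]:
        quartile_range = boundaries[2] - boundaries[1]
    else:  # perplexity >= boundaries[2]
        quartile_range = 10 * boundaries[2]
    probability = factor / quartile_range
    return rng.uniform() < probability

def should_keep_gaussian(perplexity, rng, factor=0.78, boundaries=BOUNDARIES):
    # Keep probability peaks at `factor` when the document's perplexity equals
    # the median boundary m and decays like a Gaussian away from it.
    m = boundaries[1]
    exponential = np.exp(-9 / 2 * ((perplexity - m) / m) ** 2)
    return rng.uniform() < factor * exponential

# Seeding the generator, as the builder now does via the `seed` kwarg,
# makes the stream of keep/drop decisions reproducible.
rng = default_rng(42)
print(should_keep_step(4.0e5, rng), should_keep_gaussian(6.6e5, rng))

With the module-level `rng = default_rng()` removed, two runs that pass the same `seed` now make identical keep/drop decisions, which the old unseeded generator did not guarantee.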
run_mlm_flax_stream.py
CHANGED

@@ -256,28 +256,27 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels


-
 @dataclass
 class SamplingArguments:
     """
     Arguments pertaining to how to perform sampling of the dataset.
     """

-    perplexity_model: Optional[str]
-        default="es.arpa.bin", metadata={"help": "
+    perplexity_model: Optional[str] = field(
+        default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
     )
-    sampling_method: Optional[str]
-        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document."}
+    sampling_method: Optional[str] = field(
+        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
     )
-    sampling_factor: Optional[
-        default=
+    sampling_factor: Optional[float] = field(
+        default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
     )
     boundaries: Optional[str] = field(
         default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
     )

     def __post_init__(self):
-        self.boundaries = [float(q) for q in self.boundaries.split(",")]
+        self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]


 def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:

@@ -397,6 +396,7 @@ if __name__ == "__main__":
         sampling_factor=sampling_args.sampling_factor,
         boundaries=sampling_args.boundaries,
         perplexity_model=sampling_args.perplexity_model,
+        seed=training_args.seed,
     )

     if model_args.config_name:
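The second hunk forwards the new sampling keywords, including the training seed, to the dataset builder. A rough end-to-end sketch of that wiring is below; the parser setup, the local "./mc4" script path, and the "es" configuration are simplified assumptions (the real script also parses model and data arguments), while the keyword names match the diff.

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class SamplingArguments:
    # Condensed version of the dataclass defined above.
    perplexity_model: Optional[str] = field(default="./es.arpa.bin")
    sampling_method: Optional[str] = field(default=None)
    sampling_factor: Optional[float] = field(default=None)
    boundaries: Optional[str] = field(
        default="536394.99320948,662247.50212365,919250.87225178"
    )

    def __post_init__(self):
        self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]

# Run with CLI flags, e.g.:
#   python train.py --output_dir ./out --sampling_method gaussian --sampling_factor 0.78
parser = HfArgumentParser((SamplingArguments, TrainingArguments))
sampling_args, training_args = parser.parse_args_into_dataclasses()

# Extra keyword arguments reach Mc4.__init__, which pops them before calling
# super(); `seed` now controls the sampling RNG inside the builder.
dataset = load_dataset(
    "./mc4",                      # assumed local path to the dataset script
    "es",                         # assumed language configuration
    split="train",
    streaming=True,
    sampling_method=sampling_args.sampling_method,
    sampling_factor=sampling_args.sampling_factor,
    boundaries=sampling_args.boundaries,
    perplexity_model=sampling_args.perplexity_model,
    seed=training_args.seed,
)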