Christina Theodoris
committed on
Commit
·
ead0550
1
Parent(s):
471eefc
update tokenizer to include eos token
Browse files- geneformer/tokenizer.py +4 -4
geneformer/tokenizer.py
CHANGED
|
@@ -102,7 +102,7 @@ class TranscriptomeTokenizer:
|
|
| 102 |
model_input_size : int = 2048
|
| 103 |
| Max input size of model to truncate input to.
|
| 104 |
special_token : bool = False
|
| 105 |
-
| Adds CLS token before and SEP token after rank value encoding.
|
| 106 |
gene_median_file : Path
|
| 107 |
| Path to pickle file containing dictionary of non-zero median
|
| 108 |
| gene expression values across Genecorpus-30M.
|
|
@@ -122,7 +122,7 @@ class TranscriptomeTokenizer:
|
|
| 122 |
# input size for tokenization
|
| 123 |
self.model_input_size = model_input_size
|
| 124 |
|
| 125 |
-
# add CLS and SEP tokens
|
| 126 |
self.special_token = special_token
|
| 127 |
|
| 128 |
# load dictionary of gene normalization factors
|
|
@@ -377,14 +377,14 @@ class TranscriptomeTokenizer:
|
|
| 377 |
if self.special_token:
|
| 378 |
example["input_ids"] = example["input_ids"][
|
| 379 |
0 : self.model_input_size - 2
|
| 380 |
-
] # truncate to leave space for CLS and SEP token
|
| 381 |
example["input_ids"] = np.insert(
|
| 382 |
example["input_ids"], 0, self.gene_token_dict.get("<cls>")
|
| 383 |
)
|
| 384 |
example["input_ids"] = np.insert(
|
| 385 |
example["input_ids"],
|
| 386 |
len(example["input_ids"]),
|
| 387 |
-
self.gene_token_dict.get("<sep>"),
|
| 388 |
)
|
| 389 |
else:
|
| 390 |
# Truncate/Crop input_ids to input size
|
|
|
|
| 102 |
model_input_size : int = 2048
|
| 103 |
| Max input size of model to truncate input to.
|
| 104 |
special_token : bool = False
|
| 105 |
+
| Adds CLS token before and EOS token after rank value encoding.
|
| 106 |
gene_median_file : Path
|
| 107 |
| Path to pickle file containing dictionary of non-zero median
|
| 108 |
| gene expression values across Genecorpus-30M.
|
|
|
|
| 122 |
# input size for tokenization
|
| 123 |
self.model_input_size = model_input_size
|
| 124 |
|
| 125 |
+
# add CLS and EOS tokens
|
| 126 |
self.special_token = special_token
|
| 127 |
|
| 128 |
# load dictionary of gene normalization factors
|
|
|
|
| 377 |
if self.special_token:
|
| 378 |
example["input_ids"] = example["input_ids"][
|
| 379 |
0 : self.model_input_size - 2
|
| 380 |
+
] # truncate to leave space for CLS and EOS token
|
| 381 |
example["input_ids"] = np.insert(
|
| 382 |
example["input_ids"], 0, self.gene_token_dict.get("<cls>")
|
| 383 |
)
|
| 384 |
example["input_ids"] = np.insert(
|
| 385 |
example["input_ids"],
|
| 386 |
len(example["input_ids"]),
|
| 387 |
+
self.gene_token_dict.get("<eos>"),
|
| 388 |
)
|
| 389 |
else:
|
| 390 |
# Truncate/Crop input_ids to input size
|