Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -237,15 +237,20 @@ def convert_to_markdown(text):
|
|
| 237 |
return markdown_text
|
| 238 |
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
-
def
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
def
|
| 247 |
-
self.
|
| 248 |
-
shared_state = State()
|
| 249 |
|
| 250 |
|
| 251 |
#######################################################
|
|
|
|
| 237 |
return markdown_text
|
| 238 |
|
| 239 |
|
| 240 |
+
# Encode the datasets - into train and val sets
|
| 241 |
+
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing encoded features with optional labels.

    Args:
        encodings: Mapping from feature name to an indexable collection of
            per-example values. It must contain an ``"input_ids"`` entry,
            because that entry's length is used as the dataset size.
        labels: Optional indexable collection of per-example labels. When it
            is ``None`` or empty, returned items carry no ``"labels"`` key.
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict of tensors.

        Every entry of ``encodings`` is indexed with ``idx`` and wrapped in
        ``torch.tensor``; a ``"labels"`` tensor is added when labels exist.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Do not rely on truthiness here: NumPy arrays and multi-element
        # tensors raise on bool(). The explicit length check keeps the
        # previous behaviour of skipping an empty label collection.
        if self.labels is not None and len(self.labels) > 0:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from ``encodings["input_ids"]``."""
        return len(self.encodings["input_ids"])
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
#######################################################
|