sp-embraceable committed
Commit 89dbe87 · verified · 1 Parent(s): f74147e

Upload Provence

Files changed (4)
  1. README.md +199 -0
  2. config.json +44 -0
  3. model.safetensors +3 -0
  4. modeling_provence.py +472 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
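The README's "How to Get Started" section is still a placeholder. Going by the `auto_map` entries in `config.json` and the `process()` method in `modeling_provence.py` below, usage presumably looks like the following minimal sketch; the repository id is a stand-in, the question/context strings are invented, and nltk's punkt sentence-splitter data must be available:

```python
import nltk
nltk.download("punkt", quiet=True)  # "punkt_tab" on recent nltk versions

from transformers import AutoModel

# trust_remote_code is required so that auto_map resolves to modeling_provence.Provence
model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)

out = model.process(
    question="What does Provence do?",
    context="Provence prunes retrieved contexts sentence by sentence. "
    "It also scores query-context relevance for reranking.",
    threshold=0.1,  # keep-probability cutoff for tokens, rounded to sentences
)
print(out["pruned_context"])    # context with low-relevance sentences removed
print(out["reranking_score"])   # relevance score from the ranking head
print(out["compression_rate"])  # percentage of characters pruned away
```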
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "architectures": [
+     "Provence"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoConfig": "modeling_provence.ProvenceConfig",
+     "AutoModel": "modeling_provence.Provence"
+   },
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "id2label": {
+     "0": "LABEL_0"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "label2id": {
+     "LABEL_0": 0
+   },
+   "layer_norm_eps": 1e-07,
+   "max_position_embeddings": 512,
+   "max_relative_positions": -1,
+   "model_type": "Provence",
+   "norm_rel_ebd": "layer_norm",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "pad_token_id": 0,
+   "pooler_dropout": 0,
+   "pooler_hidden_act": "gelu",
+   "pooler_hidden_size": 1024,
+   "pos_att_type": [
+     "p2c",
+     "c2p"
+   ],
+   "position_biased_input": false,
+   "position_buckets": 256,
+   "relative_attention": true,
+   "share_att_key": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.53.2",
+   "type_vocab_size": 0,
+   "vocab_size": 128100
+ }
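The `auto_map` block is what routes the Auto classes to the custom code shipped in this commit; the remaining fields describe a DeBERTa-v3-large-sized encoder (24 layers, hidden size 1024, 512-position window, 128k-token vocabulary). A small sketch of inspecting the config, with the repository id again a stand-in:

```python
from transformers import AutoConfig

# trust_remote_code lets auto_map resolve to modeling_provence.ProvenceConfig
cfg = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)
print(cfg.model_type)                          # Provence
print(cfg.num_hidden_layers, cfg.hidden_size)  # 24 1024
print(cfg.max_position_embeddings)             # 512, also used as the pruning window
```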
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e838da129cf72caa2ce36dfacca1b7a748ff2e7cb2c6682eed9a5839ebf90aca
+ size 1740308732
modeling_provence.py ADDED
@@ -0,0 +1,472 @@
+ import string
+ from typing import Optional, Union, Tuple, List
+ from dataclasses import dataclass
+ from tqdm import tqdm
+ import warnings
+ import nltk
+ import numpy as np
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import AutoTokenizer
+ from transformers import DebertaV2PreTrainedModel, DebertaV2Model, PretrainedConfig
+ try:
+     from transformers.models.deberta_v2.modeling_deberta_v2 import (
+         StableDropout,
+         ContextPooler,
+     )
+ except ImportError:
+     # recent transformers releases no longer export StableDropout
+     from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler
+
+     StableDropout = nn.Dropout
+ from transformers.modeling_outputs import ModelOutput
+
+
+ @dataclass
+ class RankingCompressionOutput(ModelOutput):
+
+     compression_logits: torch.FloatTensor = None
+     ranking_scores: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+ # adapted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L1357
+
+
+ class ProvenceConfig(PretrainedConfig):
+
+     model_type = "Provence"
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+
+ class Provence(DebertaV2PreTrainedModel):
+
+     config_class = ProvenceConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         num_labels = getattr(config, "num_labels", 2)
+         self.num_labels = num_labels
+         self.deberta = DebertaV2Model(config)
+         self.pooler = ContextPooler(config)
+         output_dim = self.pooler.output_dim
+
+         ### RANKING LAYER
+         self.classifier = nn.Linear(output_dim, num_labels)
+         drop_out = getattr(config, "cls_dropout", None)
+         drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+         self.dropout = StableDropout(drop_out)
+
+         ### COMPRESSION LAYER: another head
+         token_dropout = drop_out
+         self.token_dropout = nn.Dropout(token_dropout)
+         self.token_classifier = nn.Linear(
+             config.hidden_size, 2
+         )  # hard-coded two token labels: drop (0) / keep (1)
+         self.name = "Provence"
+         self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
+         self.max_len = config.max_position_embeddings
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+     ) -> RankingCompressionOutput:
+         outputs = self.deberta(
+             input_ids,
+             attention_mask=attention_mask,
+         )
+
+         encoder_layer = outputs[0]
+         pooled_output = self.pooler(encoder_layer)
+         pooled_output = self.dropout(pooled_output)
+         ranking_logits = self.classifier(pooled_output)
+         compression_logits = self.token_classifier(self.token_dropout(encoder_layer))
+         ranking_scores = ranking_logits[
+             :, 0
+         ].squeeze()  # the first logit column serves as the ranking score
+
+         return RankingCompressionOutput(
+             compression_logits=compression_logits,
+             ranking_scores=ranking_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
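`forward()` runs the shared encoder once and feeds it to both heads: the pooled representation yields a ranking logit per sequence, and every token gets a two-way drop/keep logit. A minimal sketch of calling it directly, assuming a stand-in repository id (in practice, `process()` below handles batching and sentence rounding for you):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
enc = model.tokenizer(
    "What is Provence?",                # query
    "Provence is a region in France.",  # context, encoded as a text pair
    return_tensors="pt",
)
with torch.no_grad():
    out = model(enc["input_ids"], enc["attention_mask"])
print(out.ranking_scores)            # scalar relevance score (squeezed for batch size 1)
print(out.compression_logits.shape)  # (1, seq_len, 2): per-token drop/keep logits
```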
+     def process(
+         self,
+         question: Union[List[str], str],
+         context: Union[List[List[str]], str],
+         title: Optional[Union[List[List[str]], str]] = "first_sentence",
+         batch_size=32,
+         threshold=0.1,
+         always_select_title=False,
+         reorder=False,
+         top_k=5,
+         enable_warnings=True,
+     ):
+         # convert inputs into queries of type List[str] and contexts/titles of type List[List[str]]
+         if type(question) == str:
+             queries = [question]
+         else:  # list of strings
+             queries = question
+         if type(context) == str:
+             contexts = [[context]]
+         else:
+             contexts = context
+         if type(title) == str and title != "first_sentence":
+             titles = [[title]]
+         else:
+             titles = title
+         assert (
+             titles == "first_sentence"
+             or titles is None
+             or (type(titles) == list and len(titles) == len(queries))
+         ), "Variable 'titles' must be 'first_sentence', None, or a list of lists of strings of the same length as 'queries'"
+         if type(titles) == list:
+             assert all(
+                 [
+                     len(titles_item) == len(contexts_item)
+                     for titles_item, contexts_item in zip(titles, contexts)
+                 ]
+             ), "Each list in 'titles' must have the same length as the corresponding list in 'context'"
+         assert len(queries) == len(
+             contexts
+         ), "Lists 'queries' and 'contexts' must have same lengths"
+         dataset = TestDataset(
+             queries=queries,
+             contexts=contexts,
+             titles=titles,
+             tokenizer=self.tokenizer,
+             max_len=self.max_len,
+             enable_warnings=enable_warnings,
+         )
+         selected_contexts = [
+             [{0: contexts[i][j]} for j in range(len(contexts[i]))]
+             for i in range(len(queries))
+         ]
+         reranking_scores = [
+             [None for j in range(len(contexts[i]))] for i in range(len(queries))
+         ]
+         compressions = [
+             [0 for j in range(len(contexts[i]))] for i in range(len(queries))
+         ]
+         with torch.no_grad():
+             for batch_start in tqdm(
+                 range(0, len(dataset), batch_size), desc="Pruning contexts..."
+             ):
+                 qis = dataset.qis[batch_start : batch_start + batch_size]
+                 cis = dataset.cis[batch_start : batch_start + batch_size]
+                 sis = dataset.sis[batch_start : batch_start + batch_size]
+                 sent_coords = dataset.sent_coords[
+                     batch_start : batch_start + batch_size
+                 ]
+                 ids_list = dataset.ids[batch_start : batch_start + batch_size]
+                 ids = pad_sequence(
+                     ids_list, batch_first=True, padding_value=dataset.pad_idx
+                 ).to(self.device)
+                 mask = (ids != dataset.pad_idx).to(self.device)
+                 outputs = self.forward(ids, mask)
+                 scores = F.softmax(outputs["compression_logits"].cpu(), dim=-1)[:, :, 1]
+                 token_preds = scores > threshold
+                 reranking_scrs = outputs["ranking_scores"].cpu().numpy()
+                 if len(reranking_scrs.shape) == 0:
+                     # squeeze() collapsed a batch of size 1 into a scalar
+                     reranking_scrs = reranking_scrs[None]
+                 for (
+                     ids_list_,
+                     token_preds_,
+                     rerank_score,
+                     qi,
+                     ci,
+                     si,
+                     sent_coords_,
+                 ) in zip(
+                     ids_list, token_preds, reranking_scrs, qis, cis, sis, sent_coords
+                 ):
+                     selected_mask = sentence_rounding(
+                         token_preds_.cpu().numpy(),
+                         np.array(sent_coords_),
+                         threshold=threshold,
+                         always_select_title=always_select_title
+                         and si == 0
+                         and titles is not None,
+                     )
+                     assert len(selected_mask) == len(token_preds_)
+                     selected_contexts[qi][ci][si] = ids_list_[
+                         selected_mask[: len(ids_list_)]
+                     ]
+                     if si == 0:
+                         reranking_scores[qi][ci] = rerank_score
+         for i in range(len(queries)):
+             for j in range(len(contexts[i])):
+                 if type(selected_contexts[i][j][0]) != str:
+                     # concatenate the kept token ids of all blocks of this context, in block order
+                     toks = torch.cat(
+                         [
+                             ids_
+                             for _, ids_ in sorted(
+                                 selected_contexts[i][j].items(), key=lambda x: x[0]
+                             )
+                         ]
+                     )
+                     selected_contexts[i][j] = self.tokenizer.decode(
+                         toks,
+                         skip_special_tokens=True,
+                         clean_up_tokenization_spaces=False,
+                     )
+                 else:
+                     selected_contexts[i][j] = selected_contexts[i][j][0]
+                 len_original = len(contexts[i][j])
+                 len_compressed = len(selected_contexts[i][j])
+                 compressions[i][j] = (len_original - len_compressed) / len_original * 100
+             if reorder:
+                 idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
+                 selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
+                 reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
+                 compressions[i] = [compressions[i][j] for j in idxs]
+
+         if type(context) == str:
+             selected_contexts = selected_contexts[0][0]
+             reranking_scores = reranking_scores[0][0]
+             compressions = compressions[0][0]
+
+         return {
+             "pruned_context": selected_contexts,
+             "reranking_score": reranking_scores,
+             "compression_rate": compressions,
+         }
+
+
+ # Some utility functions
+
+
+ def sentence_rounding(predictions, chunks, threshold, always_select_title=True):
+     """
+     predictions: a binary vector containing 1 for tokens that were selected and 0 otherwise
+     chunks: an array of [start, end) pairs, i.e. sentence k occupies predictions[start:end]
+
+     The function rounds token-level selections to sentence granularity: it computes the
+     mean prediction over each sentence and selects the whole sentence if that mean exceeds
+     the threshold. If always_select_title is set and at least one sentence was selected,
+     the first chunk (the title) is selected as well.
+     """
+     cumulative_sum = np.cumsum(predictions)
+     chunk_sums = cumulative_sum[chunks[:, 1] - 1] - np.where(
+         chunks[:, 0] > 0, cumulative_sum[chunks[:, 0] - 1], 0
+     )
+     chunk_lengths = chunks[:, 1] - chunks[:, 0]
+     chunk_means = chunk_sums / chunk_lengths
+     if always_select_title and (chunk_means > threshold).any():
+         chunk_means[0] = 1
+     # pad with zero means for the tokens before the first and after the last chunk
+     means = np.hstack((np.zeros(1), chunk_means, np.zeros(1)))
+     repeats = np.hstack(
+         ([chunks[0][0]], chunk_lengths, [predictions.shape[0] - chunks[-1][1]])
+     )
+     return np.repeat(means, repeats) > threshold
+
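A tiny worked example of the rounding (values constructed for illustration): with two sentences at token coordinates [3, 6) and [6, 8), the first sentence's mean of 2/3 clears a 0.5 threshold, so all three of its tokens are selected, while the second sentence and the query/special-token prefix are dropped:

```python
import numpy as np
from modeling_provence import sentence_rounding

preds = np.array([0, 0, 0, 1, 1, 0, 0, 0])  # token-level keep decisions
chunks = np.array([[3, 6], [6, 8]])         # [start, end) per sentence
mask = sentence_rounding(preds, chunks, threshold=0.5, always_select_title=False)
print(mask)  # [False False False  True  True  True False False]
```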
+
+ def normalize(s: str) -> str:
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_punc(lower(s)))
+
+
+ def sent_split_and_tokenize(text, tokenizer, max_len):
+     sents_nltk = nltk.sent_tokenize(text)
+     sents = []
+     for j, sent_nltk in enumerate(sents_nltk):
+         tokinput = (" " if j != 0 else "") + sent_nltk
+         tok = tokenizer.encode(tokinput, add_special_tokens=False)
+         ltok = len(tok)
+         if ltok == 0:
+             continue
+         if ltok <= max_len:
+             sents.append(tok)
+         else:
+             # hard-split sentences that exceed the token budget into max_len windows
+             for begin in range(0, ltok, max_len):
+                 sents.append(tok[begin : begin + max_len])
+     return sents
+
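`sent_split_and_tokenize` is what enforces the per-sentence token budget. A quick illustration, assuming a DeBERTa-v3 tokenizer of the same family as this checkpoint as a stand-in:

```python
import nltk
nltk.download("punkt", quiet=True)  # "punkt_tab" on recent nltk versions
from transformers import AutoTokenizer
from modeling_provence import sent_split_and_tokenize

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")  # stand-in tokenizer
sents = sent_split_and_tokenize(
    "Provence prunes contexts. It also reranks them.", tok, max_len=64
)
print(len(sents))            # 2 -- one token list per sentence
print(tok.decode(sents[0]))  # Provence prunes contexts.
```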
+
+ class TestDataset(Dataset):
+     def __init__(
+         self,
+         queries,
+         contexts,
+         tokenizer,
+         max_len=512,
+         titles="first_sentence",
+         enable_warnings=True,
+     ):
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+         # hardcoded DeBERTa-specific special-token indexes
+         self.pad_idx = 0
+         self.cls_idx = [1]
+         self.sep_idx = [2]
+         self.eos = [2]
+         self.nb_spe_tok = len(self.cls_idx) + len(self.sep_idx)
+         self.enable_warnings = enable_warnings
+         self.unusual_query_length = (
+             self.max_len // 2
+         )  # TODO: change to data-driven value
+         self.unusual_title_len = self.max_len // 2  # TODO: change to data-driven value
+         self.create_dataset(contexts, queries, titles)
+         self.len = len(self.cis)
+
+     def create_dataset(self, contexts, queries, titles="first_sentence"):
+         self.qis = []
+         self.cis = []
+         self.sis = []
+         self.sent_coords = []
+         self.cntx_coords = []
+         self.ids = []
+         if self.enable_warnings:
+             warnings_dict = {
+                 "zero_len_query": set(),
+                 "too_long_query": set(),
+                 "unusually_long_query": set(),
+                 "unusually_long_title": set(),
+                 "split_context": set(),
+             }
+         for i, query in enumerate(queries):
+             # normalize the query because all training data has normalized queries
+             tokenized_query = self.tokenizer.encode(
+                 normalize(query), add_special_tokens=False
+             )
+             query_len = len(tokenized_query)
+             if query_len == 0:
+                 if self.enable_warnings:
+                     warnings_dict["zero_len_query"].add(i)
+                 continue
+             elif query_len >= self.max_len - self.nb_spe_tok - 1:  # -1 for eos
+                 if self.enable_warnings:
+                     warnings_dict["too_long_query"].add(i)
+                 continue
+             elif query_len >= self.unusual_query_length:
+                 if self.enable_warnings:
+                     warnings_dict["unusually_long_query"].add(i)
+             left_0 = len(tokenized_query) + self.nb_spe_tok
+             tokenized_seq_0 = self.cls_idx + tokenized_query + self.sep_idx
+             max_len = self.max_len - left_0 - 1
+             for j, cntx in enumerate(contexts[i]):
+                 title = titles[i][j] if type(titles) == list else titles
+                 # each (sent + query + special tokens) <= max_len
+                 tokenized_sents = sent_split_and_tokenize(cntx, self.tokenizer, max_len)
+                 if title is not None and title != "first_sentence":
+                     tokenized_title = self.tokenizer.encode(
+                         title, add_special_tokens=False
+                     )
+                     ltok = len(tokenized_title)
+                     if ltok == 0:
+                         pass
+                     elif ltok <= max_len:
+                         tokenized_sents = [tokenized_title] + tokenized_sents
+                     else:
+                         if self.enable_warnings and ltok >= self.unusual_title_len:
+                             warnings_dict["unusually_long_title"].add(i)
+                         tokenized_sents = [
+                             tokenized_title[begin : begin + max_len]
+                             for begin in range(0, ltok, max_len)
+                         ] + tokenized_sents
+                 tokenized_seq = tokenized_seq_0
+                 left = left_0
+                 sent_coords = []
+                 block = 0
+                 for idx, tokenized_sent in enumerate(tokenized_sents):
+                     l = len(tokenized_sent)
+                     if left + l <= self.max_len - 1:
+                         sent_coords.append([left, left + l])
+                         tokenized_seq = tokenized_seq + tokenized_sent
+                         left += l
+                     else:
+                         # the context does not fit: emit the current block and start a new one
+                         if self.enable_warnings:
+                             warnings_dict["split_context"].add(i)
+                         if len(tokenized_seq) > left_0:
+                             tokenized_seq = tokenized_seq + self.eos
+                             self.qis.append(i)
+                             self.cis.append(j)
+                             self.sis.append(block)
+                             self.sent_coords.append(sent_coords)
+                             self.cntx_coords.append(
+                                 [sent_coords[0][0], sent_coords[-1][1]]
+                             )
+                             self.ids.append(torch.tensor(tokenized_seq))
+                         tokenized_seq = tokenized_seq_0 + tokenized_sent
+                         sent_coords = [[left_0, left_0 + l]]
+                         left = left_0 + l
+                         block += 1
+                 if len(tokenized_seq) > left_0:
+                     tokenized_seq = tokenized_seq + self.eos
+                     self.qis.append(i)
+                     self.cis.append(j)
+                     self.sis.append(block)
+                     self.sent_coords.append(sent_coords)
+                     self.cntx_coords.append([sent_coords[0][0], sent_coords[-1][1]])
+                     self.ids.append(torch.tensor(tokenized_seq))
+         if self.enable_warnings:
+             self.print_warnings(warnings_dict, len(queries))
+
+     def __len__(self):
+         return len(self.ids)
+
+     def print_warnings(self, warnings_dict, N):
+         info = " You can suppress Provence warnings by setting enable_warnings=False."
+         n = len(warnings_dict["zero_len_query"])
+         if n > 0:
+             ex = list(warnings_dict["zero_len_query"])[:10]
+             warnings.warn(
+                 f"{n} out of {N} queries have zero length, e.g. at indexes {ex}. "
+                 "These examples will be skipped in context pruning, "
+                 "their contexts will be kept as is." + info
+             )
+         n = len(warnings_dict["too_long_query"])
+         if n > 0:
+             ex = list(warnings_dict["too_long_query"])[:10]
+             warnings.warn(
+                 f"{n} out of {N} queries are too long for context length {self.max_len}, "
+                 f"e.g. at indexes {ex}. These examples will be skipped in context pruning, "
+                 "their contexts will be kept as is." + info
+             )
+         n = len(warnings_dict["unusually_long_query"])
+         if n > 0:
+             ex = list(warnings_dict["unusually_long_query"])[:10]
+             warnings.warn(
+                 f"{n} out of {N} queries are longer than {self.unusual_query_length} tokens, "
+                 f"e.g. at indexes {ex}. These examples will be processed as usual in context pruning, "
+                 "but the quality of context pruning could be reduced." + info
+             )
+         n = len(warnings_dict["unusually_long_title"])
+         if n > 0:
+             ex = list(warnings_dict["unusually_long_title"])[:10]
+             warnings.warn(
+                 f"{n} out of {N} titles are longer than {self.unusual_title_len} tokens, "
+                 f"e.g. at indexes {ex}. These examples will be processed as usual in context pruning, "
+                 "but the quality of context pruning could be reduced." + info
+             )
+         n = len(warnings_dict["split_context"])
+         if n > 0:
+             ex = list(warnings_dict["split_context"])[:10]
+             warnings.warn(
+                 f"{n} out of {N} contexts were split into several pieces for context pruning "
+                 f"because Provence's context length is limited to {self.max_len} tokens. "
+                 "This could potentially reduce the quality of context pruning. "
+                 "You could consider checking and reducing the lengths of contexts, queries, or titles."
+                 + info
+             )