import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.data import Dataset
from transformers import XLMRobertaPreTrainedModel, XLMRobertaModel, PretrainedConfig, AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
from typing import Optional, Union, Tuple, List
import warnings
import numpy as np
from tqdm import tqdm
import string

import spacy

# Multilingual sentence segmenter used to split contexts into sentences
# (requires the "xx_sent_ud_sm" spaCy model to be installed).
nlp = spacy.load("xx_sent_ud_sm")

@dataclass
class RankingCompressionOutput(ModelOutput):
    """Model output bundling the reranking score and the per-token compression logits."""

    loss: Optional[torch.FloatTensor] = None
    compression_loss: Optional[torch.FloatTensor] = None
    ranking_loss: Optional[torch.FloatTensor] = None
    compression_logits: torch.FloatTensor = None
    ranking_scores: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

class XProvenceConfig(PretrainedConfig):
    model_type = "XProvence"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

class XProvence(XLMRobertaPreTrainedModel):
    config_class = XProvenceConfig

    def __init__(self, config):
        super().__init__(config)
        num_labels = getattr(config, "num_labels", 2)
        self.num_labels = num_labels
        self.roberta = XLMRobertaModel(config)
        output_dim = config.hidden_size

        # Sequence-level head: relevance (reranking) logits
        self.classifier = nn.Linear(output_dim, num_labels)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = nn.Dropout(drop_out)

        # Token-level head: binary keep/drop logits for context compression
        token_dropout = drop_out
        self.token_dropout = nn.Dropout(token_dropout)
        self.token_classifier = nn.Linear(config.hidden_size, 2)

        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self.max_len = config.max_position_embeddings - 4

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        ranking_labels: Optional[torch.LongTensor] = None,
        loss_weight: Optional[float] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], RankingCompressionOutput]:
        """Simplified forward pass returning both reranking scores and per-token compression logits."""
        outputs = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = outputs["pooler_output"]
        pooled_output = self.dropout(pooled_output)
        ranking_logits = self.classifier(pooled_output)
        compression_logits = self.token_classifier(self.token_dropout(encoder_layer))
        ranking_scores = ranking_logits[:, 0].squeeze()

        compression_loss = None
        ranking_loss = None
        if labels is not None:
            # Token-level keep/drop supervision
            labels = labels.to(compression_logits.device)
            loss_fct = CrossEntropyLoss()
            compression_loss = loss_fct(compression_logits.view(-1, 2), labels.view(-1))
        if ranking_labels is not None:
            # Sequence-level relevance supervision
            ranking_labels = ranking_labels.to(ranking_logits.device)
            loss_fct = MSELoss()
            ranking_loss = loss_fct(ranking_scores, ranking_labels.squeeze())
        loss = None
        if (labels is not None) and (ranking_labels is not None):
            w = loss_weight if loss_weight else 1
            loss = compression_loss + w * ranking_loss
        elif labels is not None:
            loss = compression_loss
        elif ranking_labels is not None:
            loss = ranking_loss

        return RankingCompressionOutput(
            loss=loss,
            compression_loss=compression_loss,
            ranking_loss=ranking_loss,
            compression_logits=compression_logits,
            ranking_scores=ranking_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

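    # Shape note (illustrative, not part of the original file): for a batch of B
    # query-context sequences of length T, compression_logits has shape (B, T, 2),
    # i.e. per-token keep/drop logits, while ranking_scores has shape (B,) after
    # the squeeze (a 0-d tensor when B == 1), one relevance score per pair.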
					
						
    def process(
        self,
        question: Union[List[str], str],
        context: Union[List[List[str]], str],
        title: Optional[Union[List[List[str]], str]] = "first_sentence",
        batch_size=32,
        threshold=0.3,
        always_select_title=False,
        reorder=False,
        top_k=5,
        enable_warnings=True,
    ):
        """Prune contexts for the given question(s): returns the pruned contexts,
        reranking scores, and compression rates (percentage of characters removed)."""
        if type(question) == str:
            queries = [question]
        else:
            queries = question
        if type(context) == str:
            contexts = [[context]]
        else:
            contexts = context
        if type(title) == str and title != "first_sentence":
            titles = [[title]]
        else:
            titles = title
        assert (
            titles == "first_sentence"
            or titles is None
            or (type(titles) == list and len(titles) == len(queries))
        ), "Variable 'titles' must be 'first_sentence', None, or a list of lists of strings of the same length as 'queries'"
        if type(titles) == list:
            assert all(
                [
                    len(titles_item) == len(contexts_item)
                    for titles_item, contexts_item in zip(titles, contexts)
                ]
            ), "Each list in 'titles' must have the same length as the corresponding list in 'context'"
        assert len(queries) == len(
            contexts
        ), "Lists 'queries' and 'contexts' must have same lengths"
        dataset = TestDataset(
            queries=queries,
            contexts=contexts,
            titles=titles,
            tokenizer=self.tokenizer,
            max_len=self.max_len,
            enable_warnings=enable_warnings,
        )
        selected_contexts = [
            [{0: contexts[i][j]} for j in range(len(contexts[i]))]
            for i in range(len(queries))
        ]
        reranking_scores = [
            [None for j in range(len(contexts[i]))] for i in range(len(queries))
        ]
        compressions = [
            [0 for j in range(len(contexts[i]))] for i in range(len(queries))
        ]
        with torch.no_grad():
            for batch_start in tqdm(
                range(0, len(dataset), batch_size), desc="Pruning contexts..."
            ):
                qis = dataset.qis[batch_start : batch_start + batch_size]
                cis = dataset.cis[batch_start : batch_start + batch_size]
                sis = dataset.sis[batch_start : batch_start + batch_size]
                sent_coords = dataset.sent_coords[
                    batch_start : batch_start + batch_size
                ]
                ids_list = dataset.ids[batch_start : batch_start + batch_size]
                ids = pad_sequence(
                    ids_list, batch_first=True, padding_value=dataset.pad_idx
                ).to(self.device)
                mask = (ids != dataset.pad_idx).to(self.device)
                outputs = self.forward(ids, mask)
                # Probability that each token should be kept
                scores = F.softmax(outputs["compression_logits"].cpu(), dim=-1)[:, :, 1]
                token_preds = scores > threshold
                reranking_scrs = outputs["ranking_scores"].cpu().numpy()
                if len(reranking_scrs.shape) == 0:
                    # Batch of size 1: restore the batch dimension lost by squeeze()
                    reranking_scrs = reranking_scrs[None]
                for (
                    ids_list_,
                    token_preds_,
                    rerank_score,
                    qi,
                    ci,
                    si,
                    sent_coords_,
                ) in zip(
                    ids_list, token_preds, reranking_scrs, qis, cis, sis, sent_coords
                ):
                    # Round token-level decisions to whole-sentence decisions
                    selected_mask = sentence_rounding(
                        token_preds_.cpu().numpy(),
                        np.array(sent_coords_),
                        threshold=threshold,
                        always_select_title=always_select_title
                        and si == 0
                        and titles is not None,
                    )
                    assert len(selected_mask) == len(token_preds_)
                    selected_contexts[qi][ci][si] = ids_list_[
                        selected_mask[: len(ids_list_)]
                    ]
                    if si == 0:
                        reranking_scores[qi][ci] = rerank_score
        for i in range(len(queries)):
            for j in range(len(contexts[i])):
                if type(selected_contexts[i][j][0]) != str:
                    # Concatenate the selected tokens of all blocks and decode back to text
                    toks = torch.cat(
                        [
                            ids_
                            for _, ids_ in sorted(
                                selected_contexts[i][j].items(), key=lambda x: x[0]
                            )
                        ]
                    )
                    selected_contexts[i][j] = self.tokenizer.decode(
                        toks,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )
                else:
                    # This context was skipped (e.g. empty or too long query): keep it as is
                    selected_contexts[i][j] = selected_contexts[i][j][0]
                len_original = len(contexts[i][j])
                len_compressed = len(selected_contexts[i][j])
                compressions[i][j] = (len_original - len_compressed) / len_original * 100
            if reorder:
                # Sort the contexts of query i by decreasing reranking score and keep top_k
                idxs = np.argsort(reranking_scores[i])[::-1][:top_k]
                selected_contexts[i] = [selected_contexts[i][j] for j in idxs]
                reranking_scores[i] = [reranking_scores[i][j] for j in idxs]
                compressions[i] = [compressions[i][j] for j in idxs]

        # Unwrap singleton outputs when a single (question, context) pair was given
        if type(context) == str:
            selected_contexts = selected_contexts[0][0]
            reranking_scores = reranking_scores[0][0]
            compressions = compressions[0][0]

        return {
            "pruned_context": selected_contexts,
            "reranking_score": reranking_scores,
            "compression_rate": compressions,
        }

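# Example (illustrative, not part of the original file): process() also accepts
# batched inputs, one list of contexts per query, e.g.
#   model.process(
#       question=["q1", "q2"],
#       context=[["ctx1 for q1", "ctx2 for q1"], ["ctx1 for q2"]],
#   )
# which returns pruned contexts, reranking scores, and compression rates
# with the same nesting as `context`.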
					
						
def sentence_rounding(predictions, chunks, threshold, always_select_title=True):
    """
    predictions: a binary vector containing 1 for tokens which were selected and 0 otherwise
    chunks: an array of [start, end) pairs of sentence coordinates, i.e. a sentence spans predictions[start:end]
    The function averages the token-level predictions within each sentence and returns a
    token-level boolean mask that keeps a sentence (all of its tokens) iff its mean exceeds the threshold.
    """
    cumulative_sum = np.cumsum(predictions)
    # Sum of predictions within each sentence, computed from the cumulative sum
    chunk_sums = cumulative_sum[chunks[:, 1] - 1] - np.where(
        chunks[:, 0] > 0, cumulative_sum[chunks[:, 0] - 1], 0
    )
    chunk_lengths = chunks[:, 1] - chunks[:, 0]
    chunk_means = chunk_sums / chunk_lengths
    if always_select_title and (chunk_means > threshold).any():
        # Force selection of the title (first chunk) whenever any sentence is selected
        chunk_means[0] = 1
    means = np.hstack((np.zeros(1), chunk_means, np.zeros(1)))
    repeats = np.hstack(
        ([chunks[0][0]], chunk_lengths, [predictions.shape[0] - chunks[-1][1]])
    )
    # Broadcast each sentence mean back onto its tokens (query/padding positions get 0) and threshold
    return np.repeat(means, repeats) > threshold

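# Worked example (illustrative, not from the original file): with
# predictions = np.array([1, 1, 0, 0, 1]) and chunks = np.array([[0, 2], [2, 5]]),
# the per-sentence means are 1.0 and 1/3, so with threshold=0.5 and
# always_select_title=False the returned token mask is
# [True, True, False, False, False]: the first sentence is kept whole,
# the second is dropped whole.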
					
						
def normalize(s: str) -> str:
    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))

def sent_split_and_tokenize(text, tokenizer, max_len):
    # Split the text into sentences with spaCy, then tokenize each sentence;
    # sentences longer than max_len tokens are chunked into pieces of at most max_len tokens.
    sents_raw = [sent.text.strip() for sent in nlp(text).sents]
    sents = []
    for j, sent_raw in enumerate(sents_raw):
        tokinput = (" " if j != 0 else "") + sent_raw
        tok = tokenizer.encode(tokinput, add_special_tokens=False)
        ltok = len(tok)
        if ltok == 0:
            continue
        if ltok <= max_len:
            sents.append(tok)
        else:
            for begin in range(0, ltok, max_len):
                sents.append(tok[begin : begin + max_len])
    return sents

class TestDataset(Dataset):
    def __init__(
        self,
        queries,
        contexts,
        tokenizer,
        max_len=6000,
        titles="first_sentence",
        enable_warnings=True,
    ):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_idx = self.tokenizer.pad_token_id
        self.cls_idx = [self.tokenizer.cls_token_id]
        self.sep_idx = [self.tokenizer.sep_token_id]
        self.eos = [self.tokenizer.eos_token_id]

        self.nb_spe_tok = len(self.cls_idx) + len(self.sep_idx)
        self.enable_warnings = enable_warnings
        # Heuristic lengths above which a warning about unusually long queries/titles is emitted
        self.unusual_query_length = self.max_len // 2
        self.unusual_title_length = self.max_len // 2
        self.create_dataset(queries, contexts, titles)
        self.len = len(self.cis)

    def create_dataset(self, queries, contexts, titles="first_sentence"):
        self.qis = []
        self.cis = []
        self.sis = []
        self.sent_coords = []
        self.cntx_coords = []
        self.ids = []
        if self.enable_warnings:
            warnings_dict = {
                "zero_len_query": set(),
                "too_long_query": set(),
                "unusually_long_query": set(),
                "unusually_long_title": set(),
                "split_context": set(),
            }
        for i, query in enumerate(queries):
            tokenized_query = self.tokenizer.encode(
                normalize(query), add_special_tokens=False
            )
            query_len = len(tokenized_query)
            if query_len == 0:
                if self.enable_warnings:
                    warnings_dict["zero_len_query"].add(i)
                continue
            elif query_len >= self.max_len - self.nb_spe_tok - 1:
                if self.enable_warnings:
                    warnings_dict["too_long_query"].add(i)
                continue
            elif query_len >= self.unusual_query_length:
                if self.enable_warnings:
                    warnings_dict["unusually_long_query"].add(i)
            left_0 = len(tokenized_query) + self.nb_spe_tok
            tokenized_seq_0 = self.cls_idx + tokenized_query + self.sep_idx
            max_len = self.max_len - left_0 - 1
            for j, cntx in enumerate(contexts[i]):
                title = titles[i][j] if type(titles) == list else titles
                tokenized_sents = sent_split_and_tokenize(cntx, self.tokenizer, max_len)
                if title is not None and title != "first_sentence":
                    tokenized_title = self.tokenizer.encode(
                        title, add_special_tokens=False
                    )
                    ltok = len(tokenized_title)
                    if ltok == 0:
                        pass
                    elif ltok <= max_len:
                        tokenized_sents = [tokenized_title] + tokenized_sents
                    else:
                        if self.enable_warnings and ltok >= self.unusual_title_length:
                            warnings_dict["unusually_long_title"].add(i)
                        tokenized_sents = [
                            tokenized_title[begin : begin + max_len]
                            for begin in range(0, ltok, max_len)
                        ] + tokenized_sents
                tokenized_seq = tokenized_seq_0
                left = left_0
                sent_coords = []
                block = 0
                for idx, tokenized_sent in enumerate(tokenized_sents):
                    l = len(tokenized_sent)
                    if left + l <= self.max_len - 1:
                        # The sentence fits into the current block
                        sent_coords.append([left, left + l])
                        tokenized_seq = tokenized_seq + tokenized_sent
                        left += l
                    else:
                        # The block is full: flush it and start a new block with the current sentence
                        if self.enable_warnings:
                            warnings_dict["split_context"].add(i)
                        if len(tokenized_seq) > left_0:
                            tokenized_seq = tokenized_seq + self.eos
                            self.qis.append(i)
                            self.cis.append(j)
                            self.sis.append(block)
                            self.sent_coords.append(sent_coords)
                            self.cntx_coords.append(
                                [sent_coords[0][0], sent_coords[-1][1]]
                            )
                            self.ids.append(torch.tensor(tokenized_seq))
                        tokenized_seq = tokenized_seq_0 + tokenized_sent
                        sent_coords = [[left_0, left_0 + l]]
                        left = left_0 + l
                        block += 1
                if len(tokenized_seq) > left_0:
                    tokenized_seq = tokenized_seq + self.eos
                    self.qis.append(i)
                    self.cis.append(j)
                    self.sis.append(block)
                    self.sent_coords.append(sent_coords)
                    self.cntx_coords.append([sent_coords[0][0], sent_coords[-1][1]])
                    self.ids.append(torch.tensor(tokenized_seq))
        if self.enable_warnings:
            self.print_warnings(warnings_dict, len(queries))

    def __len__(self):
        return len(self.ids)

    def print_warnings(self, warnings_dict, N):
        n = len(warnings_dict["zero_len_query"])
        info = " You can suppress Provence warnings by setting enable_warnings=False."
        if n > 0:
            ex = list(warnings_dict["zero_len_query"])[:10]
            warnings.warn(
                f"{n} out of {N} queries have zero length, e.g. at indexes {ex}. "
                "These examples will be skipped in context pruning, "
                "their contexts will be kept as is." + info
            )
        n = len(warnings_dict["too_long_query"])
        if n > 0:
            ex = list(warnings_dict["too_long_query"])[:10]
            warnings.warn(
                f"{n} out of {N} queries are too long for context length {self.max_len}, "
                f"e.g. at indexes {ex}. These examples will be skipped in context pruning, "
                "their contexts will be kept as is." + info
            )
        n = len(warnings_dict["unusually_long_query"])
        if n > 0:
            ex = list(warnings_dict["unusually_long_query"])[:10]
            warnings.warn(
                f"{n} out of {N} queries are longer than {self.unusual_query_length} tokens, "
                f"e.g. at indexes {ex}. These examples will be processed as usual in context pruning, "
                "but the quality of context pruning could be reduced." + info
            )
        n = len(warnings_dict["unusually_long_title"])
        if n > 0:
            ex = list(warnings_dict["unusually_long_title"])[:10]
            warnings.warn(
                f"{n} out of {N} titles are longer than {self.unusual_title_length} tokens, "
                f"e.g. at indexes {ex}. These examples will be processed as usual in context pruning, "
                "but the quality of context pruning could be reduced." + info
            )
        n = len(warnings_dict["split_context"])
        if n > 0:
            ex = list(warnings_dict["split_context"])[:10]
            warnings.warn(
                f"{n} out of {N} contexts were split into several pieces for context pruning, "
                f"due to a limited context length of Provence which is equal to {self.max_len}. "
                "This could potentially reduce the quality of context pruning. "
                "You could consider checking and reducing lengths of contexts, queries, or titles."
                + info
            )
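

# Minimal usage sketch (illustrative, not part of the original file). The
# checkpoint path below is a placeholder: substitute the actual XProvence
# checkpoint you are working with.
if __name__ == "__main__":
    model = XProvence.from_pretrained("path/to/xprovence-checkpoint")  # placeholder path
    model.eval()
    result = model.process(
        question="What goes into a pesto sauce?",
        context="Pesto is a sauce made of basil, pine nuts, garlic, olive oil and cheese. "
        "It originated in Genoa, Italy.",
    )
    print(result["pruned_context"])
    print(result["reranking_score"], result["compression_rate"])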