|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import time |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from typing import Dict |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
from sentence_transformers.util import cos_sim |
|
|
from accelerate import Accelerator |
|
|
from scipy.stats import zscore |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Take the compute device from the Accelerate runtime so the reranker is
# placed on whatever device this process was given.
accelerator = Accelerator()
device = accelerator.device
print("Using accelerator device =", device)

from sentence_transformers import CrossEncoder

# Cross-encoder reranker, created once at import time and reused by every
# retrieval call in this module.
model_sf_mxbai = CrossEncoder("mixedbread-ai/mxbai-rerank-large-v1", device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def RAG_retrieval_Base(queryText, passages, min_threshold=0.0, max_num_passages=None):
    """Rerank ``passages`` against ``queryText`` and return the best matches.

    The module-level ``model_sf_mxbai`` reranker scores the passages; the
    resulting ranking is optionally filtered by score and de-duplicated by
    passage text.

    Args:
        queryText: Query handed to the reranker.
        passages: Indexable collection of candidate passages. The
            ``corpus_id`` values of the ranking are used as indices into it.
        min_threshold: When greater than 0, rows whose ``score`` is below this
            value are dropped. Values <= 0 disable score filtering.
        max_num_passages: Number of top results requested from the reranker.
            When falsy (``None`` or 0), 10% of ``len(passages)`` is requested,
            with a minimum of one.

    Returns:
        A ``pandas.DataFrame`` holding the ranking columns (``corpus_id`` and
        ``score`` are read here) plus a ``Passage`` column with the matching
        passage, with duplicate passages removed (first occurrence kept).
        An empty DataFrame is returned when there are no passages, when the
        reranker yields no results, or when any error occurs (the error is
        printed, not raised).
    """
    try:
        # Nothing to rank: skip the model call instead of relying on the
        # exception handler below.
        if passages is None or len(passages) == 0:
            return pd.DataFrame()

        # An explicit limit wins; otherwise fall back to 10% of the corpus,
        # but never ask for fewer than one result.
        if max_num_passages:
            top_k = max_num_passages
        else:
            top_k = max(1, int(0.1 * len(passages)))

        ranking = model_sf_mxbai.rank(
            queryText, passages, return_documents=False, top_k=top_k
        )
        if not ranking:
            return pd.DataFrame()

        ranked = pd.DataFrame(ranking)
        if min_threshold > 0:
            # .copy() detaches the filtered rows from the original frame so
            # adding a column below is not a chained assignment on a slice.
            ranked = ranked[ranked["score"] >= min_threshold].copy()

        # The ranking was requested without documents, so map ids back to text.
        ranked["Passage"] = [passages[i] for i in ranked["corpus_id"]]

        return ranked.drop_duplicates(subset="Passage", keep="first")

    except Exception as e:
        # Best-effort contract: callers always receive a DataFrame.
        print(f"An error occurred: {e}")
        return pd.DataFrame()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo run: rerank a handful of sentences against one query and print
    # the resulting DataFrame.
    queryText = 'A man is eating a piece of bread'
    passages = [
        "A man is eating food.",
        "A man is eating pasta.",
        "The girl is carrying a baby.",
        "A man is riding a horse.",
    ]

    # min_threshold=0 disables score filtering; at most three results are requested.
    df_retrieved = RAG_retrieval_Base(
        queryText,
        passages,
        min_threshold=0,
        max_num_passages=3,
    )
    print(df_retrieved)

    print("end of computations")
|
|
|
|
|
|