anakin87 committed · commit 4c41de2 · parent: 08d96b7

refactor EntailmentChecker: only relevant documents are used

Files changed:
- Rock_fact_checker.py (+3 -4)
- app_utils/backend_utils.py (+6 -28)
- app_utils/entailment_checker.py (+42 -8)
Rock_fact_checker.py CHANGED

@@ -97,8 +97,8 @@ def main():
 
     # Display results
     if st.session_state.results:
-
-
+        docs = st.session_state.results["documents"]
+        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]
 
         # show different messages depending on entailment results
        max_key = max(agg_entailment_info, key=agg_entailment_info.get)

@@ -107,12 +107,11 @@ def main():
 
         st.markdown(f"###### Aggregate entailment information:")
         col1, col2 = st.columns([2, 1])
-        agg_entailment_info = results["agg_entailment_info"]
         fig = create_ternary_plot(agg_entailment_info)
         with col1:
             st.plotly_chart(fig, use_container_width=True)
         with col2:
-            st.write(
+            st.write(agg_entailment_info)
 
         st.markdown(f"###### Most Relevant snippets:")
         df, urls = create_df_for_relevant_snippets(docs)
app_utils/backend_utils.py CHANGED

@@ -44,7 +44,11 @@ def start_haystack():
         embedding_model=RETRIEVER_MODEL,
         model_format=RETRIEVER_MODEL_FORMAT,
     )
-    entailment_checker = EntailmentChecker(
+    entailment_checker = EntailmentChecker(
+        model_name_or_path=NLI_MODEL,
+        use_gpu=False,
+        entailment_contradiction_threshold=0.5,
+    )
 
     pipe = Pipeline()
     pipe.add_node(component=retriever, name="retriever", inputs=["Query"])

@@ -60,30 +64,4 @@ pipe = start_haystack()
 def query(statement: str, retriever_top_k: int = 5):
     """Run query and verify statement"""
     params = {"retriever": {"top_k": retriever_top_k}}
-    results = pipe.run(statement, params=params)
-
-    scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
-    for i, doc in enumerate(results["documents"]):
-        scores += doc.score
-        ent_info = doc.meta["entailment_info"]
-        con, neu, ent = (
-            ent_info["contradiction"],
-            ent_info["neutral"],
-            ent_info["entailment"],
-        )
-        agg_con += con * doc.score
-        agg_neu += neu * doc.score
-        agg_ent += ent * doc.score
-
-        # if in the first documents there is strong evidence of entailment/contradiction,
-        # there is no need to consider less relevant documents
-        if max(agg_con, agg_ent) / scores > 0.5:
-            results["documents"] = results["documents"][: i + 1]
-            break
-
-    results["agg_entailment_info"] = {
-        "contradiction": round(agg_con / scores, 2),
-        "neutral": round(agg_neu / scores, 2),
-        "entailment": round(agg_ent / scores, 2),
-    }
-    return results
+    return pipe.run(statement, params=params)
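With the aggregation moved into the node, query() collapses to a single pipe.run call. A usage sketch under the wiring above; the statement is a made-up example, and the output keys are the ones EntailmentChecker now returns:

# Hypothetical usage; assumes start_haystack() built `pipe` as shown above.
results = query("Freddie Mercury was the lead singer of Queen", retriever_top_k=5)
print(results["aggregate_entailment_info"])  # e.g. {"contradiction": ..., "neutral": ..., "entailment": ...}
for doc in results["documents"]:             # only the documents the checker actually used
    print(doc.score, doc.meta["entailment_info"])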
app_utils/entailment_checker.py CHANGED

@@ -4,13 +4,14 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 import torch
 from haystack.nodes.base import BaseComponent
 from haystack.modeling.utils import initialize_device_settings
-from haystack.schema import Document
+from haystack.schema import Document
 
 
 class EntailmentChecker(BaseComponent):
     """
     This node checks the entailment between every document content and the query.
-    It enrichs the documents metadata with
+    It enriches the documents' metadata with entailment information.
+    It also returns aggregate entailment information.
     """
 
     outgoing_edges = 1

@@ -22,6 +23,7 @@ class EntailmentChecker(BaseComponent):
         tokenizer: Optional[str] = None,
         use_gpu: bool = True,
         batch_size: int = 16,
+        entailment_contradiction_threshold: float = 0.5,
     ):
         """
         Load a Natural Language Inference model from Transformers.

@@ -31,7 +33,9 @@ class EntailmentChecker(BaseComponent):
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param tokenizer: Name of the tokenizer (usually the same as model)
         :param use_gpu: Whether to use GPU (if available).
-
+        :param batch_size: Number of Documents to be processed at a time.
+        :param entailment_contradiction_threshold: if, in the first N documents, there is strong evidence of entailment/contradiction
+            (aggregate entailment or contradiction greater than the threshold), the less relevant documents are not taken into account.
         """
         super().__init__()
 

@@ -43,6 +47,7 @@ class EntailmentChecker(BaseComponent):
             pretrained_model_name_or_path=model_name_or_path, revision=model_version
         )
         self.batch_size = batch_size
+        self.entailment_contradiction_threshold = entailment_contradiction_threshold
         self.model.to(str(self.devices[0]))
 
         id2label = AutoConfig.from_pretrained(model_name_or_path).id2label

@@ -53,12 +58,41 @@ class EntailmentChecker(BaseComponent):
         )
 
     def run(self, query: str, documents: List[Document]):
-        for doc in documents:
-            entailment_dict = self.get_entailment(premise=doc.content, hypotesis=query)
-            doc.meta["entailment_info"] = entailment_dict
-        return {"documents": documents}, "output_1"
 
-
+        scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
+        for i, doc in enumerate(documents):
+            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
+            doc.meta["entailment_info"] = entailment_info
+
+            scores += doc.score
+            con, neu, ent = (
+                entailment_info["contradiction"],
+                entailment_info["neutral"],
+                entailment_info["entailment"],
+            )
+            agg_con += con * doc.score
+            agg_neu += neu * doc.score
+            agg_ent += ent * doc.score
+
+            # if in the first documents there is strong evidence of entailment/contradiction,
+            # there is no need to consider less relevant documents
+            if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold:
+                break
+
+        aggregate_entailment_info = {
+            "contradiction": round(agg_con / scores, 2),
+            "neutral": round(agg_neu / scores, 2),
+            "entailment": round(agg_ent / scores, 2),
+        }
+
+        entailment_checker_result = {
+            "documents": documents[: i + 1],
+            "aggregate_entailment_info": aggregate_entailment_info,
+        }
+
+        return entailment_checker_result, "output_1"
+
+    def run_batch(self, queries: List[str], documents: List[Document]):
         pass
 
     def get_entailment(self, premise, hypotesis):
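
The core of the refactor is the score-weighted aggregation with early stopping in run(): each document's entailment probabilities are weighted by its retrieval score, running averages are kept, and the loop exits as soon as aggregate entailment or contradiction clears the threshold, so less relevant documents are never evaluated. A standalone re-implementation with plain tuples, for illustration only; the document scores and probabilities are invented:

# Plain-Python sketch of the score-weighted aggregation with early stopping.
# Documents are (retrieval_score, entailment_probs) pairs; values are invented.
docs = [
    (0.92, {"contradiction": 0.02, "neutral": 0.08, "entailment": 0.90}),
    (0.85, {"contradiction": 0.05, "neutral": 0.15, "entailment": 0.80}),
    (0.40, {"contradiction": 0.60, "neutral": 0.30, "entailment": 0.10}),
]
threshold = 0.5

scores = agg_con = agg_neu = agg_ent = 0.0
used = 0
for i, (score, probs) in enumerate(docs):
    scores += score
    agg_con += probs["contradiction"] * score
    agg_neu += probs["neutral"] * score
    agg_ent += probs["entailment"] * score
    used = i + 1
    # Stop as soon as the running weighted average is decisive:
    # here already after the first document (0.90 > 0.5).
    if max(agg_con, agg_ent) / scores > threshold:
        break

aggregate = {
    "contradiction": round(agg_con / scores, 2),
    "neutral": round(agg_neu / scores, 2),
    "entailment": round(agg_ent / scores, 2),
}
print(used, aggregate)  # -> 1 {'contradiction': 0.02, 'neutral': 0.08, 'entailment': 0.9}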