AsoBozorg committed on
Commit 430f7e7 · verified · 1 Parent(s): 5411677

Update app.py

Files changed (1)
  1. app.py +187 -194
app.py CHANGED
@@ -1,206 +1,199 @@
- import gradio as gr
- import re
- import torch
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
-
- # ---------------------------
- # Load models
- # ---------------------------
- print("Loading models...")
-
- # Sentence similarity model
- retriever = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")

- # NLI model (for faithfulness scoring)
- nli = pipeline("text-classification", model="facebook/bart-large-mnli", top_k=None)

- # Toxicity classifier
- toxicity = pipeline("text-classification", model="unitary/toxic-bert", top_k=None)
-
- # Summarization/fallback model
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-
- # Generation model (for synthetic answers)
- M = None

  def ensure_gen(use_tiny=True):
-     global M
-     if M is None:
-         if use_tiny:
-             model_name = "google/flan-t5-small"
-             M = pipeline("text2text-generation", model=model_name)
-         else:
-             model_name = "google/flan-t5-base"
-             M = pipeline("text2text-generation", model=model_name)
-     return M
-
- # ---------------------------
- # Demo Index (sources)
- # ---------------------------
- INDEX = {
-     "titles": [
-         "IPCC on Climate Change",
-         "Elections Security Myths",
-         "WHO on Vaccines & Safety",
-     ],
-     "urls": [
-         "https://example.org/ipcc",
-         "https://example.org/election-security",
-         "https://example.org/who-vaccines",
-     ],
-     "texts": [
-         "The IPCC states with high confidence that human activities are the dominant cause of global warming since the mid-20th century.",
-         "Studies show that widespread voter fraud in modern elections is extremely rare and not supported by credible evidence.",
-         "The World Health Organization confirms vaccines are safe and effective, with benefits vastly outweighing risks.",
-     ],
- }
-
- # ---------------------------
- # Helpers
- # ---------------------------
- def _clean(txt: str) -> str:
-     return re.sub(r"\s+", " ", txt).strip()
-
- def _extractive_fallback(question, idxs):
-     """If generation fails, fallback to extractive summarizer."""
-     ctx = " ".join([INDEX["texts"][i] for i in idxs])
-     summary = summarizer(ctx, max_length=80, min_length=30, do_sample=False)
-     return summary[0]["summary_text"], [INDEX["titles"][i] for i in idxs]
-
- # ---------------------------
- # Faithfulness scoring
- # ---------------------------
- def faithfulness_scores(answer, idxs):
-     scores, per_source = [], []
-     for i, idx in enumerate(idxs):
-         premise = INDEX["texts"][idx]
-         result = nli({"premise": premise, "hypothesis": answer})[0]
-
-         entail_score = 0.0
-         for item in result:
-             if item["label"].upper().startswith("ENTAIL"):
-                 entail_score = item["score"]
-         scores.append(entail_score)
-         per_source.append((INDEX["titles"][idx], entail_score))
-
-     mean_score = sum(scores) / len(scores) if scores else 0.0
-     return mean_score, per_source
-
- # ---------------------------
- # Toxicity scoring
- # ---------------------------
- def toxicity_risk(answer):
-     result = toxicity(answer)[0]
-     toxic_score = 0.0
-     for item in result:
-         if "toxic" in item["label"].lower():
-             toxic_score = item["score"]
-     return toxic_score
-
- # ---------------------------
- # Answer generation
- # ---------------------------
- def generate_answer(question, idxs, use_tiny=True, max_new=220):
      ensure_gen(use_tiny)
-
      ctx, cites = [], []
      for i, idx in enumerate(idxs):
          ctx.append(f"[{i+1}] {INDEX['texts'][idx]}")
          cites.append(f"[{i+1}] {INDEX['titles'][idx]} – {INDEX['urls'][idx]}")
-
-     instr = (
-         "Write a clear paragraph (3–6 sentences) that answers the user's claim "
-         "STRICTLY using the sources below. Include citations like [1], [2]. "
-         "Do not reply with only citation markers; write complete sentences."
-     )
-
-     # ✅ build ctx_block outside the f-string
-     ctx_block = "\n".join(ctx)
-
      prompt = (
-         f"{instr}\n\nSources:\n{ctx_block}\n\n"
-         f"Claim: {question}\nAnswer:"
-     )
-
-     toks = M.tokenizer(prompt, return_tensors="pt", truncation=True)
-     out = M.model.generate(
-         **toks,
-         max_new_tokens=max_new,
-         min_new_tokens=80,
-         do_sample=True,
-         temperature=0.8,
-         top_p=0.92,
-         repetition_penalty=1.15,
-         no_repeat_ngram_size=3,
-         early_stopping=True,
      )
-     text = _clean(M.tokenizer.decode(out[0], skip_special_tokens=True))
-
-     if len(text) < 60 or re.fullmatch(r"\[+\d+\]+\.?", text):
-         text, cites = _extractive_fallback(question, idxs)
-
-     return text, cites
-
- # ---------------------------
- # Pipeline
- # ---------------------------
- def run_pipeline(claim, src1, src2, src3, use_tiny=True):
-     # Gather candidate sources
-     candidates = list(range(len(INDEX["texts"])))
-     if src1:
-         INDEX["texts"].append(src1)
-         INDEX["titles"].append("Custom Source 1")
-         INDEX["urls"].append("user://source1")
-         candidates.append(len(INDEX["texts"]) - 1)
-     if src2:
-         INDEX["texts"].append(src2)
-         INDEX["titles"].append("Custom Source 2")
-         INDEX["urls"].append("user://source2")
-         candidates.append(len(INDEX["texts"]) - 1)
-     if src3:
-         INDEX["texts"].append(src3)
-         INDEX["titles"].append("Custom Source 3")
-         INDEX["urls"].append("user://source3")
-         candidates.append(len(INDEX["texts"]) - 1)
-
-     # Pick top-3 sources (simplified: first 3 candidates)
-     idxs = candidates[:3]
-
-     # Generate answer
-     answer, cites = generate_answer(claim, idxs, use_tiny=use_tiny)
-
-     # Faithfulness + toxicity
-     faith_total, per_src = faithfulness_scores(answer, idxs)
-     tox = toxicity_risk(answer)
-
-     # PII redaction
-     redacted = re.sub(r"\b[A-Z][a-z]+ [A-Z][a-z]+\b", "[REDACTED]", answer)
-
-     return (
-         f"Faithfulness (mean entailment): {faith_total:.2f} | Toxicity risk: {tox:.2f}\n\n{answer}\n\n"
-         + "\n".join(cites),
-         per_src,
-         redacted,
-     )
-
- # ---------------------------
- # Gradio UI
- # ---------------------------
- with gr.Blocks() as demo:
-     gr.Markdown("## 🎯 TruthLens – Misinformation-Aware RAG\nType a claim or question and see fact-checked answers with citations.")
-
      with gr.Row():
-         claim = gr.Textbox(label="Claim or question", placeholder="e.g., Did humans cause climate change?")
-         src1 = gr.Textbox(label="Optional source 1", lines=3)
-         src2 = gr.Textbox(label="Optional source 2", lines=3)
-         src3 = gr.Textbox(label="Optional source 3", lines=3)
-
-     run_btn = gr.Button("Run TruthLens", variant="primary")
-
-     out_answer = gr.Textbox(label="Fact-checked answer", lines=8)
-     out_table = gr.Dataframe(headers=["Source", "Faithfulness"], label="Per-source faithfulness", wrap=True)
-     out_redact = gr.Textbox(label="PII-redacted answer", lines=6)
-
-     run_btn.click(fn=run_pipeline, inputs=[claim, src1, src2, src3], outputs=[out_answer, out_table, out_redact])
-
- demo.launch()
+ # ─────────────────────────────────────────────────────────────────────────────
+ # TruthLens – Misinformation-Aware RAG (Lite/Full modes)
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from sentence_transformers import SentenceTransformer
+
+ # ===== Config =====
+ GEN_TINY = "google/flan-t5-small"   # Lite mode
+ GEN_FULL = "google/flan-t5-base"    # Full mode
+ EMB_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+ NLI_MODEL = "cross-encoder/nli-deberta-v3-small"  # light NLI (still optional)
+ NER_MODEL = "dslim/bert-base-NER"
+ TOX_MODEL = "unitary/toxic-bert"
+
+ SAMPLE_DOCS = [
+     {"title": "WHO on Vaccines & Safety",
+      "text": "Vaccines undergo rigorous testing and continuous safety monitoring. Severe adverse reactions are rare.",
+      "url": "https://example.org/who-vaccines"},
+     {"title": "IPCC on Climate Change",
+      "text": "It is unequivocal that human influence has warmed the atmosphere, ocean and land.",
+      "url": "https://example.org/ipcc"},
+     {"title": "Elections Security Myths",
+      "text": "Independent audits reduce fraud risk; no credible evidence for nationwide manipulation.",
+      "url": "https://example.org/election-security"},
+ ]
+
+ # ===== Lazy model holders =====
+ class M:
+     emb = None
+     tok = None
+     gen = None
+     nli = None
+     ner = None
+     tox = None
+
+ INDEX = {"emb": None, "texts": [], "titles": [], "urls": []}
+
+ def ensure_emb():
+     if M.emb is None:
+         M.emb = SentenceTransformer(EMB_MODEL)

  def ensure_gen(use_tiny=True):
+     model_id = GEN_TINY if use_tiny else GEN_FULL
+     if (M.gen is None) or (getattr(M.gen, "_id", None) != model_id):
+         M.tok = AutoTokenizer.from_pretrained(model_id)
+         M.gen = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+         M.gen._id = model_id  # remember which is loaded
+
+ def ensure_nli():
+     if M.nli is None:
+         # NOTE: no return_all_scores; we’ll use top_k=None at call time
+         M.nli = pipeline("text-classification", model=NLI_MODEL)
+
+ def ensure_ner():
+     if M.ner is None:
+         M.ner = pipeline("token-classification", model=NER_MODEL, aggregation_strategy="simple")
+
+ def ensure_tox():
+     if M.tox is None:
+         M.tox = pipeline("text-classification", model=TOX_MODEL)
+
+ # ===== Index =====
+ def build_index(extra=None):
+     ensure_emb()
+     texts = [d["text"] for d in SAMPLE_DOCS]
+     titles = [d["title"] for d in SAMPLE_DOCS]
+     urls = [d["url"] for d in SAMPLE_DOCS]
+     if extra:
+         for i, t in enumerate(extra):
+             if t and str(t).strip():
+                 texts.append(str(t).strip()); titles.append(f"UserDoc {i+1}"); urls.append("user://paste")
+     INDEX["emb"] = M.emb.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
+     INDEX["texts"], INDEX["titles"], INDEX["urls"] = texts, titles, urls
+
+ # ===== Core steps =====
+ def retrieve(q, k=3):
+     ensure_emb()
+     if INDEX["emb"] is None:
+         build_index()
+     qv = M.emb.encode([q], normalize_embeddings=True, convert_to_numpy=True)
+     sims = cosine_similarity(qv, INDEX["emb"])[0]
+     return list(np.argsort(-sims)[:k])
+
+ def generate_answer(question, idxs, use_tiny=True, max_new=256):
      ensure_gen(use_tiny)
      ctx, cites = [], []
      for i, idx in enumerate(idxs):
          ctx.append(f"[{i+1}] {INDEX['texts'][idx]}")
          cites.append(f"[{i+1}] {INDEX['titles'][idx]} – {INDEX['urls'][idx]}")
+     ctx_block = "\n".join(ctx)  # joined outside the f-string (backslashes are not allowed inside f-string expressions before Python 3.12)
      prompt = (
+         "Answer the user's claim STRICTLY using the sources below. "
+         "Use citations like [1], [2]. If unsure, say you are uncertain.\n\n"
+         f"Sources:\n{ctx_block}\n\n"
+         f"Question: {question}\nAnswer:"
      )
+     toks = M.tok(prompt, return_tensors="pt", truncation=True)
+     out = M.gen.generate(**toks, max_new_tokens=max_new, do_sample=False)
+     return M.tok.decode(out[0], skip_special_tokens=True), cites
+
+ def nli_faithfulness(answer, idxs):
+     try:
+         ensure_nli()
+         per_src = []
+         for idx in idxs:
+             prem = INDEX["texts"][idx]
+             out = M.nli({"text": prem, "text_pair": answer}, top_k=None)
+             # Normalize shapes: out -> list -> list[dict] or dict
+             scores_obj = out[0] if isinstance(out, list) and out else out
+             scores = [scores_obj] if isinstance(scores_obj, dict) else (scores_obj or [])
+             ent = 0.0
+             for item in scores:
+                 if str(item.get("label", "")).upper().startswith("ENTAIL"):
+                     ent = float(item.get("score", 0.0)); break
+             per_src.append((INDEX["titles"][idx], ent))
+         mean_ent = float(np.mean([s for _, s in per_src])) if per_src else 0.0
+         return mean_ent, per_src, None
+     except Exception as e:
+         return 0.0, [(INDEX["titles"][i], 0.0) for i in idxs], f"NLI skipped: {e}"
+
+ def redact_pii(text):
+     try:
+         ensure_ner()
+         ents = M.ner(text)
+         ents = sorted(ents, key=lambda e: e.get("end",0)-e.get("start",0), reverse=True)
+         out = text
+         for e in ents:
+             s, e2 = int(e.get("start",0)), int(e.get("end",0))
+             span = text[s:e2]
+             if span:
+                 out = out.replace(span, f"<{e.get('entity_group','ENT')}>")
+         return out, None
+     except Exception as e:
+         return text, f"PII redaction skipped: {e}"
+
+ def tox_score(text):
+     try:
+         ensure_tox()
+         pred = M.tox(text)[0]
+         return float(pred.get("score", 0.0)), None
+     except Exception as e:
+         return 0.0, f"Toxicity check skipped: {e}"
+
+ # ===== Pipeline (Lite vs Full) =====
+ def run_pipeline(claim, s1, s2, s3, lite_mode):
+     # Build/refresh index with user sources
+     build_index([s1, s2, s3])
+
+     # 1) Retrieve + Generate (always on)
+     idxs = retrieve(claim, k=3)
+     answer, cites = generate_answer(claim, idxs, use_tiny=lite_mode)
+
+     # 2) Optional checks (only in Full mode, but fail-soft)
+     notes = []
+     if not lite_mode:
+         mean_ent, per_src, nli_note = nli_faithfulness(answer, idxs)
+         if nli_note: notes.append(nli_note)
+         pii, pii_note = redact_pii(answer); redacted = pii
+         if pii_note: notes.append(pii_note)
+         tox, tox_note = tox_score(answer)
+         if tox_note: notes.append(tox_note)
+     else:
+         mean_ent, per_src = 0.0, [(INDEX["titles"][i], 0.0) for i in idxs]
+         redacted, tox = answer, 0.0
+         notes.append("Lite mode: NLI/PII/Toxicity disabled for reliability on free CPU.")
+
+     table = pd.DataFrame({"Source": [s for s, _ in per_src],
+                           "Faithfulness": [round(float(sc), 3) for _, sc in per_src]})
+     summary = f"Faithfulness (mean entailment): {mean_ent:.2f} | Toxicity risk: {tox:.2f}"
+     if notes:
+         summary += " \n" + " \n".join(f"• {n}" for n in notes)
+     return summary, answer, "\n".join(cites), table, redacted
+
+ # ===== UI =====
+ with gr.Blocks(title="TruthLens – Misinformation-Aware RAG") as demo:
+     gr.Markdown("# 🧭 TruthLens – Misinformation-Aware RAG\nType a claim or question and get a grounded answer with citations.")
      with gr.Row():
+         with gr.Column():
+             claim = gr.Textbox(label="Claim or question", lines=2, placeholder="e.g., Did humans cause climate change?")
+             lite = gr.Checkbox(value=True, label="Lite mode (more reliable on free CPU)")
+             run_btn = gr.Button("Run TruthLens", variant="primary")
+         with gr.Column():
+             s1 = gr.Textbox(label="Optional source 1", lines=3)
+             s2 = gr.Textbox(label="Optional source 2", lines=3)
+             s3 = gr.Textbox(label="Optional source 3", lines=3)
+     summary = gr.Markdown()
+     answer = gr.Markdown(label="Answer")
+     cites = gr.Markdown(label="Citations")
+     table = gr.Dataframe(label="Per-source faithfulness")
+     redacted = gr.Textbox(label="PII-redacted answer", lines=3)
+     run_btn.click(run_pipeline, [claim, s1, s2, s3, lite], [summary, answer, cites, table, redacted])
+
+ if __name__ == "__main__":
+     demo.launch()
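
Reviewer note: the new retrieval step (build_index/retrieve above) can be sanity-checked outside the Space with a minimal standalone sketch of the same idea, shown below: encode the SAMPLE_DOCS passages and a claim with the same MiniLM embedding model, then rank passages by cosine similarity. This is an illustrative sketch only; it assumes sentence-transformers, scikit-learn and numpy are installed and that the model can be downloaded, and it does not import app.py.

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Same embedding model as EMB_MODEL in app.py
emb = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# The three SAMPLE_DOCS texts from the commit above
docs = [
    "Vaccines undergo rigorous testing and continuous safety monitoring. Severe adverse reactions are rare.",
    "It is unequivocal that human influence has warmed the atmosphere, ocean and land.",
    "Independent audits reduce fraud risk; no credible evidence for nationwide manipulation.",
]
claim = "Did humans cause climate change?"

# Normalized embeddings ranked by cosine similarity, mirroring retrieve()
doc_vecs = emb.encode(docs, normalize_embeddings=True, convert_to_numpy=True)
claim_vec = emb.encode([claim], normalize_embeddings=True, convert_to_numpy=True)
sims = cosine_similarity(claim_vec, doc_vecs)[0]

for rank, i in enumerate(np.argsort(-sims), start=1):
    print(f"{rank}. score={sims[i]:.3f}  {docs[i][:60]}...")

On the Space itself, the Lite checkbox only changes which flan-t5 checkpoint is loaded and whether the NLI, PII and toxicity passes run; the retrieval step is identical in both modes.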