|
|
""" |
|
|
Experimental Dashboard for RAG Pipeline Testing |
|
|
Provides GUI interface for running and visualizing experiments |
|
|
""" |
|
|
|
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
from typing import Dict, List, Any |
|
|
import json |
|
|
import time |
|
|
from datetime import datetime |
|
|
import threading |
|
|
import queue |
|
|
|
|
|
|
|
|
from model.model import RAGModel |
|
|
from rails import input as input_guard |
|
|
from rails.output import OutputGuardrails |
|
|
from helper import Answer |
|
|
import os |
|
|
try: |
|
|
import secrets_local |
|
|
HF_TOKEN = secrets_local.HF |
|
|
except ImportError: |
|
|
HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
|
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
sys.path.append(str(Path(__file__).parent / "experiments")) |
|
|
|
|
|
|
|
|
from app import query_rag_pipeline |
|
|
|
|
|
def render_experiment_dashboard(): |
|
|
"""Main experimental dashboard interface""" |
|
|
|
|
|
st.header("π§ͺ RAG Pipeline Experiments") |
|
|
st.markdown("Run controlled experiments to test and validate RAG pipeline behavior") |
|
|
|
|
|
|
|
|
tab1, tab2, tab3 = st.tabs(["π System Info", "π‘οΈ Input Guards", "π Output Guards"]) |
|
|
|
|
|
with tab1: |
|
|
render_system_info_tab() |
|
|
|
|
|
with tab2: |
|
|
render_input_guardrails_tab() |
|
|
|
|
|
with tab3: |
|
|
render_output_guardrails_tab() |
|
|
|
|
|
def render_system_overview(): |
|
|
"""Render quick system overview at the top""" |
|
|
|
|
|
with st.expander("βΉοΈ About this RAG System", expanded=False): |
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
st.markdown("**π― Purpose:**") |
|
|
st.write("Test and validate a Retrieval-Augmented Generation (RAG) system for university data queries") |
|
|
|
|
|
st.markdown("**π§ Components:**") |
|
|
st.write("β’ Sentence Transformers embeddings") |
|
|
st.write("β’ ChromaDB vector database") |
|
|
st.write("β’ Hugging Face API for text generation") |
|
|
st.write("β’ Input/Output security guardrails") |
|
|
|
|
|
with col2: |
|
|
st.markdown("**π Sample Queries:**") |
|
|
st.write("β’ 'What courses is Maria taking?'") |
|
|
st.write("β’ 'Who teaches computer science?'") |
|
|
st.write("β’ 'Show me faculty in engineering'") |
|
|
|
|
|
st.markdown("**β οΈ Test Cases:**") |
|
|
st.write("β’ Malicious SQL injection attempts") |
|
|
st.write("β’ Personal data extraction tries") |
|
|
st.write("β’ Parameter optimization tests") |
|
|
|
|
|
def get_database_stats(): |
|
|
"""Get real database statistics""" |
|
|
try: |
|
|
import sqlite3 |
|
|
import os |
|
|
|
|
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
db_path = os.path.join(current_dir, 'database', 'university.db') |
|
|
|
|
|
if not os.path.exists(db_path): |
|
|
|
|
|
db_path = 'database/university.db' |
|
|
if not os.path.exists(db_path): |
|
|
st.warning(f"Database file not found. Checked: {db_path}") |
|
|
return None |
|
|
|
|
|
conn = sqlite3.connect(db_path) |
|
|
cursor = conn.cursor() |
|
|
|
|
|
|
|
|
student_count = cursor.execute("SELECT COUNT(*) FROM students").fetchone()[0] |
|
|
faculty_count = cursor.execute("SELECT COUNT(*) FROM faculty").fetchone()[0] |
|
|
course_count = cursor.execute("SELECT COUNT(*) FROM courses").fetchone()[0] |
|
|
enrollment_count = cursor.execute("SELECT COUNT(*) FROM enrollments").fetchone()[0] |
|
|
|
|
|
|
|
|
sample_student = cursor.execute("SELECT name FROM students LIMIT 1").fetchone() |
|
|
sample_faculty = cursor.execute("SELECT name, department FROM faculty LIMIT 1").fetchone() |
|
|
|
|
|
sample_course_query = """ |
|
|
SELECT c.name, f.department |
|
|
FROM courses c |
|
|
JOIN faculty f ON c.faculty_id = f.id |
|
|
LIMIT 1 |
|
|
""" |
|
|
sample_course = cursor.execute(sample_course_query).fetchone() |
|
|
|
|
|
conn.close() |
|
|
|
|
|
|
|
|
st.success(f"β
Database connected! Found {student_count} students, {faculty_count} faculty, {course_count} courses") |
|
|
|
|
|
return { |
|
|
'students': student_count, |
|
|
'faculty': faculty_count, |
|
|
'courses': course_count, |
|
|
'enrollments': enrollment_count, |
|
|
'sample_student': sample_student[0] if sample_student else "No data available", |
|
|
'sample_faculty': sample_faculty if sample_faculty else ("No data available", "No department"), |
|
|
'sample_course': sample_course if sample_course else ("No data available", "No department") |
|
|
} |
|
|
except Exception as e: |
|
|
st.error(f"β Error connecting to database: {str(e)}") |
|
|
return None |
|
|
|
|
|
def render_system_info_tab(): |
|
|
"""Render comprehensive system information tab""" |
|
|
|
|
|
st.subheader("π System Information & Database Schema") |
|
|
|
|
|
|
|
|
db_stats = get_database_stats() |
|
|
|
|
|
if db_stats: |
|
|
|
|
|
st.markdown("### π Live Database Statistics") |
|
|
col1, col2, col3, col4 = st.columns(4) |
|
|
|
|
|
with col1: |
|
|
st.metric("π₯ Students", db_stats['students']) |
|
|
with col2: |
|
|
st.metric("π¨βπ« Faculty", db_stats['faculty']) |
|
|
with col3: |
|
|
st.metric("π Courses", db_stats['courses']) |
|
|
with col4: |
|
|
st.metric("π Enrollments", db_stats['enrollments']) |
|
|
|
|
|
|
|
|
st.markdown("### ποΈ Database Schema") |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
st.markdown("**Tables Overview:**") |
|
|
|
|
|
|
|
|
with st.expander("π₯ Students Table", expanded=True): |
|
|
if db_stats: |
|
|
st.markdown(f""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `name` (Student full name) |
|
|
- `email` (Email address - PII) |
|
|
- `svnr` (Social security number - Sensitive PII) |
|
|
|
|
|
**Sample Data:** |
|
|
- {db_stats['sample_student']} ([REDACTED_EMAIL]) |
|
|
- Contains {db_stats['students']} total student records |
|
|
- All emails and SVNR automatically redacted for privacy |
|
|
""") |
|
|
else: |
|
|
st.markdown(""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `name` (Student full name) |
|
|
- `email` (Email address - PII) |
|
|
- `svnr` (Social security number - Sensitive PII) |
|
|
|
|
|
**Sample Data:** |
|
|
- Database connection not available |
|
|
- Contains realistic student records with Faker-generated data |
|
|
- All emails and SVNR automatically redacted for privacy |
|
|
""") |
|
|
|
|
|
|
|
|
with st.expander("π¨βπ« Faculty Table"): |
|
|
if db_stats: |
|
|
faculty_name, faculty_dept = db_stats['sample_faculty'] |
|
|
st.markdown(f""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `name` (Faculty full name) |
|
|
- `email` (Email address - PII) |
|
|
- `department` (Department/specialization) |
|
|
|
|
|
**Sample Data:** |
|
|
- {faculty_name} ({faculty_dept}) |
|
|
- Contains {db_stats['faculty']} total faculty records |
|
|
- Departments include engineering, sciences, humanities |
|
|
""") |
|
|
else: |
|
|
st.markdown(""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `name` (Faculty full name) |
|
|
- `email` (Email address - PII) |
|
|
- `department` (Department/specialization) |
|
|
|
|
|
**Sample Data:** |
|
|
- Database connection not available |
|
|
- Contains faculty across various academic departments |
|
|
- Departments include engineering, sciences, humanities |
|
|
""") |
|
|
|
|
|
with col2: |
|
|
|
|
|
with st.expander("π Courses Table", expanded=True): |
|
|
if db_stats: |
|
|
course_name, course_dept = db_stats['sample_course'] |
|
|
st.markdown(f""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `name` (Course title) |
|
|
- `faculty_id` (Foreign Key β Faculty) |
|
|
- `department` (Course department) |
|
|
|
|
|
**Sample Data:** |
|
|
- "{course_name}" ({course_dept}) |
|
|
- Contains {db_stats['courses']} total course records |
|
|
- Generated with realistic university course patterns |
|
|
""") |
|
|
else: |
|
|
st.markdown(""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `name` (Course title) |
|
|
- `faculty_id` (Foreign Key β Faculty) |
|
|
- `department` (Course department) |
|
|
|
|
|
**Sample Data:** |
|
|
- Database connection not available |
|
|
- Contains realistic university courses across departments |
|
|
- Generated with realistic university course patterns |
|
|
""") |
|
|
|
|
|
|
|
|
with st.expander("π Enrollments Table"): |
|
|
if db_stats: |
|
|
avg_enrollments = db_stats['enrollments'] // db_stats['students'] if db_stats['students'] > 0 else 0 |
|
|
st.markdown(f""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `student_id` (Foreign Key β Students) |
|
|
- `course_id` (Foreign Key β Courses) |
|
|
|
|
|
**Purpose:** |
|
|
Links students to their enrolled courses (Many-to-Many relationship) |
|
|
|
|
|
**Statistics:** |
|
|
- {db_stats['enrollments']} total enrollment records |
|
|
- Average enrollments per student: {avg_enrollments} |
|
|
""") |
|
|
else: |
|
|
st.markdown(""" |
|
|
**Columns:** |
|
|
- `id` (Primary Key) |
|
|
- `student_id` (Foreign Key β Students) |
|
|
- `course_id` (Foreign Key β Courses) |
|
|
|
|
|
**Purpose:** |
|
|
Links students to their enrolled courses (Many-to-Many relationship) |
|
|
|
|
|
**Statistics:** |
|
|
- Database connection not available |
|
|
- Contains realistic enrollment patterns for university students |
|
|
""") |
|
|
|
|
|
|
|
|
st.markdown("### π€ RAG Pipeline Components") |
|
|
|
|
|
col1, col2, col3 = st.columns(3) |
|
|
|
|
|
with col1: |
|
|
st.markdown("**π₯ Input Processing:**") |
|
|
st.write("β’ Language detection") |
|
|
st.write("β’ SQL injection detection") |
|
|
st.write("β’ Toxic content filtering") |
|
|
st.write("β’ Intent classification") |
|
|
|
|
|
with col2: |
|
|
st.markdown("**π Retrieval:**") |
|
|
st.write("β’ Sentence-BERT embeddings") |
|
|
st.write("β’ ChromaDB similarity search") |
|
|
st.write("β’ Context window management") |
|
|
st.write("β’ Relevance scoring") |
|
|
|
|
|
with col3: |
|
|
st.markdown("**π€ Output Generation:**") |
|
|
st.write("β’ Hugging Face API") |
|
|
st.write("β’ PII redaction") |
|
|
st.write("β’ Hallucination detection") |
|
|
st.write("β’ Response validation") |
|
|
|
|
|
|
|
|
st.markdown("### π Security & Privacy Features") |
|
|
|
|
|
with st.expander("π‘οΈ Security Measures", expanded=True): |
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
st.markdown("**Input Guardrails:**") |
|
|
st.write("β
SQL injection prevention") |
|
|
st.write("β
Command injection blocking") |
|
|
st.write("β
Toxic language filtering") |
|
|
st.write("β
Language validation") |
|
|
|
|
|
with col2: |
|
|
st.markdown("**Output Guardrails:**") |
|
|
st.write("β
Email address redaction") |
|
|
st.write("β
SVNR number protection") |
|
|
st.write("β
Irrelevant response filtering") |
|
|
st.write("β
Data leakage prevention") |
|
|
|
|
|
|
|
|
st.markdown("### π§ͺ Available Experiments") |
|
|
|
|
|
exp_info = [ |
|
|
{ |
|
|
"Experiment": "π‘οΈ Input Guards", |
|
|
"Purpose": "Test security against malicious inputs", |
|
|
"Tests": "SQL injection, toxic content, data extraction attempts", |
|
|
"Goal": "Block harmful queries while allowing legitimate ones" |
|
|
}, |
|
|
{ |
|
|
"Experiment": "π Output Guards", |
|
|
"Purpose": "Validate response safety and quality", |
|
|
"Tests": "PII leakage, SVNR exposure, relevance checking", |
|
|
"Goal": "Prevent sensitive data exposure and ensure relevance" |
|
|
} |
|
|
] |
|
|
|
|
|
df = pd.DataFrame(exp_info) |
|
|
st.dataframe(df, use_container_width=True) |
|
|
|
|
|
def render_input_guardrails_tab(): |
|
|
"""Render input guardrails experiment interface""" |
|
|
|
|
|
st.subheader("π‘οΈ Input Guardrails Testing") |
|
|
|
|
|
|
|
|
with st.expander("βΉοΈ About Input Guardrails", expanded=False): |
|
|
st.markdown(""" |
|
|
**Purpose:** Test the system's ability to detect and block malicious or inappropriate inputs. |
|
|
|
|
|
**What we test:** |
|
|
- π« **SQL Injection**: Attempts to manipulate database queries |
|
|
- π« **Command Injection**: System command execution attempts |
|
|
- π« **Toxic Content**: Inappropriate or offensive language |
|
|
- π« **Data Extraction**: Attempts to access sensitive information (emails, SVNR) |
|
|
- β
**Legitimate Queries**: Normal university-related questions should pass |
|
|
|
|
|
**How it works:** |
|
|
- Language detection to ensure English input |
|
|
- Pattern matching for common attack vectors |
|
|
- Content filtering for inappropriate language |
|
|
- Context analysis for data extraction attempts |
|
|
""") |
|
|
|
|
|
col1, col2 = st.columns([2, 1]) |
|
|
|
|
|
with col1: |
|
|
st.markdown("### π§ͺ Test Custom Input") |
|
|
|
|
|
|
|
|
test_input = st.text_area( |
|
|
"Enter your test input:", |
|
|
placeholder="Try: 'SELECT * FROM students WHERE name=admin--' or 'What courses does Maria take?'", |
|
|
height=100, |
|
|
help="Enter any input to test if our guardrails can detect malicious content" |
|
|
) |
|
|
|
|
|
test_button = st.button("π Test Input", type="primary") |
|
|
|
|
|
if test_button and test_input: |
|
|
test_single_input(test_input) |
|
|
|
|
|
with col2: |
|
|
st.markdown("### π Quick Attack Tests") |
|
|
|
|
|
|
|
|
try: |
|
|
from experiments.experiment_1_input_guardrails import InputGuardrailsExperiment |
|
|
exp = InputGuardrailsExperiment() |
|
|
test_cases = exp._get_test_cases() |
|
|
|
|
|
for i, test_case in enumerate(test_cases): |
|
|
if test_case["expected_blocked"]: |
|
|
with st.container(): |
|
|
|
|
|
icon_map = { |
|
|
"sql_injection": "π", |
|
|
"xss_injection": "π", |
|
|
"toxicity": "π€¬", |
|
|
"command_injection": "π₯" |
|
|
} |
|
|
icon = icon_map.get(test_case["category"], "β οΈ") |
|
|
|
|
|
st.markdown(f"**{icon} {test_case['name']}**") |
|
|
st.caption("Expected: BLOCK") |
|
|
if st.button(f"Test {test_case['name']}", key=f"test_{i}", use_container_width=True): |
|
|
test_single_input(test_case['input']) |
|
|
st.markdown("---") |
|
|
|
|
|
|
|
|
legitimate_cases = [tc for tc in test_cases if not tc["expected_blocked"]] |
|
|
if legitimate_cases: |
|
|
test_case = legitimate_cases[0] |
|
|
st.markdown("**β
Legitimate Query**") |
|
|
st.caption("Expected: ALLOW") |
|
|
if st.button("Test Legitimate", key="legitimate_test", use_container_width=True): |
|
|
test_single_input(test_case['input']) |
|
|
|
|
|
except Exception as e: |
|
|
st.error(f"Could not load test cases: {e}") |
|
|
st.info("Using fallback test cases...") |
|
|
|
|
|
if st.button("Test SQL Injection", key="fallback_test", use_container_width=True): |
|
|
test_single_input("SELECT * FROM students WHERE name='admin'--") |
|
|
|
|
|
|
|
|
if "input_test_results" in st.session_state: |
|
|
display_input_test_results() |
|
|
|
|
|
def render_output_guardrails_tab(): |
|
|
"""Render output guardrails experiment interface""" |
|
|
|
|
|
st.subheader("π Output Guardrails Testing") |
|
|
|
|
|
|
|
|
with st.expander("βΉοΈ About Output Guardrails", expanded=False): |
|
|
st.markdown(""" |
|
|
**Purpose:** Test the system's ability to detect and filter problematic content in generated responses. |
|
|
|
|
|
**What we detect:** |
|
|
- π§ **PII Leakage**: Email addresses that shouldn't be exposed |
|
|
- π’ **SVNR Exposure**: Social security numbers (highly sensitive) |
|
|
- π― **Relevance Issues**: Responses not related to university queries |
|
|
- π« **Data Leakage**: Any sensitive information exposure |
|
|
|
|
|
**How it works:** |
|
|
- Pattern matching for emails and ID numbers |
|
|
- Content analysis for relevance to university context |
|
|
- Automatic redaction of detected sensitive data |
|
|
- Response filtering based on content quality |
|
|
|
|
|
**Test Process:** |
|
|
1. Simulate problematic AI responses |
|
|
2. Run through output guardrails |
|
|
3. Check what gets detected and filtered |
|
|
""") |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
st.markdown("### π§ͺ Test Custom Response") |
|
|
|
|
|
|
|
|
test_response = st.text_area( |
|
|
"Enter AI response to test:", |
|
|
placeholder="Try: 'Student email: [email protected]' or 'SVNR: 1235567890'", |
|
|
height=100, |
|
|
help="Enter any AI response to test if our output guardrails can detect issues" |
|
|
) |
|
|
|
|
|
test_button = st.button("π Test Response", type="primary") |
|
|
|
|
|
if test_button and test_response: |
|
|
test_real_output_filtering(test_response) |
|
|
|
|
|
with col2: |
|
|
st.markdown("### π Quick Response Tests") |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
try: |
|
|
import secrets_local |
|
|
api_key = secrets_local.HF |
|
|
except ImportError: |
|
|
import os |
|
|
api_key = os.environ.get("HF_TOKEN", "dummy_key") |
|
|
|
|
|
from experiments.experiment_2_output_guardrails import OutputGuardrailsExperiment |
|
|
exp = OutputGuardrailsExperiment(api_key) |
|
|
test_cases = exp._get_test_cases() |
|
|
|
|
|
|
|
|
demo_responses = [ |
|
|
{ |
|
|
"name": "π§ Email Leakage Test", |
|
|
"response": "Student email: [email protected], [email protected]", |
|
|
"expected": "FILTER", |
|
|
"category": "pii_leakage" |
|
|
}, |
|
|
{ |
|
|
"name": "π’ SVNR Leakage Test", |
|
|
"response": "Student SVNR numbers: 1235567890, 9870543210", |
|
|
"expected": "FILTER", |
|
|
"category": "svnr_exposure" |
|
|
}, |
|
|
{ |
|
|
"name": "π― Irrelevant Response Test", |
|
|
"response": "Today's weather is sunny with 25Β°C temperature", |
|
|
"expected": "FILTER", |
|
|
"category": "irrelevant" |
|
|
} |
|
|
] |
|
|
|
|
|
for i, demo in enumerate(demo_responses): |
|
|
with st.container(): |
|
|
st.markdown(f"**{demo['name']}**") |
|
|
st.caption(f"Expected: {demo['expected']}") |
|
|
if st.button(f"Test Response #{i+1}", key=f"response_test_{i}", use_container_width=True): |
|
|
test_real_output_filtering(demo['response']) |
|
|
st.markdown("---") |
|
|
|
|
|
except Exception as e: |
|
|
st.error(f"Could not load output test cases: {e}") |
|
|
st.info("Using fallback test...") |
|
|
if st.button("Test Email Detection", key="fallback_output_test", use_container_width=True): |
|
|
test_real_output_filtering("Student email: [email protected]") |
|
|
|
|
|
|
|
|
if "output_test_results" in st.session_state: |
|
|
display_output_test_results() |
|
|
|
|
|
def test_single_input(test_input: str): |
|
|
"""Test a single input through the complete RAG pipeline""" |
|
|
|
|
|
try: |
|
|
|
|
|
model = RAGModel(HF_TOKEN) |
|
|
output_guardrails = OutputGuardrails() |
|
|
input_guardrails = input_guard.InputGuardRails() |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
result = query_rag_pipeline(test_input, model, output_guardrails, input_guardrails) |
|
|
|
|
|
|
|
|
blocked = ("Invalid input" in result.answer or |
|
|
"SQL injection" in result.answer or |
|
|
"inappropriate" in result.answer.lower() or |
|
|
"blocked" in result.answer.lower()) |
|
|
|
|
|
|
|
|
st.session_state.input_test_results = { |
|
|
"input": test_input, |
|
|
"blocked": blocked, |
|
|
"reason": result.answer if blocked else "Input accepted - generated response successfully", |
|
|
"full_answer": result.answer, |
|
|
"sources": result.sources, |
|
|
"processing_time": result.processing_time, |
|
|
"timestamp": datetime.now().strftime('%H:%M:%S') |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
st.error(f"Error testing input: {e}") |
|
|
|
|
|
def test_real_output_filtering(test_response: str): |
|
|
"""Test output filtering through complete RAG pipeline by generating a response that should contain the test content""" |
|
|
|
|
|
try: |
|
|
|
|
|
model = RAGModel(HF_TOKEN) |
|
|
output_guardrails = OutputGuardrails() |
|
|
input_guardrails = input_guard.InputGuardRails() |
|
|
|
|
|
|
|
|
|
|
|
if "email" in test_response.lower(): |
|
|
test_query = "What are some example student contact details?" |
|
|
elif "svnr" in test_response.lower() or any(char.isdigit() for char in test_response): |
|
|
test_query = "Can you show me student identification numbers?" |
|
|
elif "weather" in test_response.lower() or "temperature" in test_response.lower(): |
|
|
test_query = "What's the current weather like?" |
|
|
elif "programming" in test_response.lower() or "computer science" in test_response.lower(): |
|
|
test_query = "What computer science courses are available?" |
|
|
else: |
|
|
test_query = "Tell me about university information" |
|
|
|
|
|
|
|
|
raw_result = query_rag_pipeline(test_query, model, output_guardrails, input_guardrails, |
|
|
input_guardrails_active=True, output_guardrails_active=False) |
|
|
|
|
|
|
|
|
|
|
|
from rag import retriever |
|
|
|
|
|
|
|
|
try: |
|
|
context = retriever.search(test_query, top_k=3) |
|
|
except: |
|
|
context = [] |
|
|
|
|
|
|
|
|
guardrail_results = output_guardrails.check(test_query, test_response, context) |
|
|
|
|
|
|
|
|
filtered_response = test_response |
|
|
from helper import EMAIL_PATTERN |
|
|
filtered_response = EMAIL_PATTERN.sub('[REDACTED_EMAIL]', filtered_response) |
|
|
filtered_response = output_guardrails.redact_svnrs(filtered_response) |
|
|
|
|
|
|
|
|
issues_detected = [] |
|
|
for check_name, result in guardrail_results.items(): |
|
|
if not result.passed: |
|
|
issue_details = ", ".join(result.issues) if result.issues else "Failed validation" |
|
|
issues_detected.append(f"{check_name}: {issue_details}") |
|
|
|
|
|
blocked = len(issues_detected) > 0 |
|
|
|
|
|
|
|
|
st.session_state.output_test_results = { |
|
|
"original": test_response, |
|
|
"filtered": filtered_response, |
|
|
"blocked": blocked, |
|
|
"issues": issues_detected, |
|
|
"query_used": test_query, |
|
|
"context_docs": len(context), |
|
|
"guardrails_enabled": True, |
|
|
"timestamp": datetime.now().strftime('%H:%M:%S'), |
|
|
"system": "REAL" |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
st.error(f"Error testing output filtering: {e}") |
|
|
import traceback |
|
|
st.error(f"Details: {traceback.format_exc()}") |
|
|
|
|
|
def test_output_filtering(response: str, enable_filtering: bool): |
|
|
"""Test output filtering (legacy/fallback method)""" |
|
|
|
|
|
try: |
|
|
|
|
|
filtered_response = response |
|
|
issues = [] |
|
|
|
|
|
if enable_filtering: |
|
|
if "@" in response: |
|
|
issues.append("Email detected") |
|
|
filtered_response = response.replace("@", "[EMAIL]") |
|
|
if any(char.isdigit() for char in response) and len([c for c in response if c.isdigit()]) > 5: |
|
|
issues.append("Potential SVNR/ID detected") |
|
|
|
|
|
st.session_state.output_test_results = { |
|
|
"original": response, |
|
|
"filtered": filtered_response, |
|
|
"issues": issues, |
|
|
"guardrails_enabled": enable_filtering, |
|
|
"timestamp": datetime.now().strftime('%H:%M:%S'), |
|
|
"system": "SIMULATED" |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
st.error(f"Error testing output: {e}") |
|
|
|
|
|
def display_input_test_results(): |
|
|
"""Display input test results from RAG pipeline""" |
|
|
|
|
|
results = st.session_state.input_test_results |
|
|
|
|
|
st.markdown("### π Input Test Results (Full RAG Pipeline)") |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
st.markdown("**Input Query:**") |
|
|
st.code(results["input"]) |
|
|
|
|
|
if not results["blocked"] and results.get("sources"): |
|
|
st.markdown("**Sources Retrieved:**") |
|
|
with st.expander(f"π {len(results['sources'])} sources found"): |
|
|
for source in results["sources"][:3]: |
|
|
st.write(f"β’ {source['title']}") |
|
|
|
|
|
with col2: |
|
|
if results["blocked"]: |
|
|
st.error(f"π« BLOCKED: {results['reason']}") |
|
|
else: |
|
|
st.success("β
ALLOWED - Generated Response") |
|
|
if results.get("full_answer"): |
|
|
with st.expander("π Generated Response"): |
|
|
st.write(results["full_answer"]) |
|
|
|
|
|
|
|
|
if results.get("processing_time"): |
|
|
st.metric("Processing Time", f"{results['processing_time']:.3f}s") |
|
|
|
|
|
st.caption(f"Tested at {results['timestamp']} | System: Full RAG Pipeline") |
|
|
|
|
|
def display_output_test_results(): |
|
|
"""Display output test results from RAG pipeline integration""" |
|
|
|
|
|
results = st.session_state.output_test_results |
|
|
|
|
|
|
|
|
system_type = results.get("system", "UNKNOWN") |
|
|
st.markdown("### π Output Test Results") |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
|
|
|
|
with col1: |
|
|
st.markdown("**Original Response:**") |
|
|
|
|
|
original_text = results.get("original", results.get("input", "")) |
|
|
st.write(original_text) |
|
|
|
|
|
if results.get("query_used"): |
|
|
st.markdown("**Query Used:**") |
|
|
st.caption(f"π {results['query_used']}") |
|
|
|
|
|
with col2: |
|
|
st.markdown("**Filtered Response:**") |
|
|
st.write(results["filtered"]) |
|
|
|
|
|
if results.get("context_docs"): |
|
|
st.markdown("**Context Retrieved:**") |
|
|
st.caption(f"π {results['context_docs']} documents") |
|
|
|
|
|
|
|
|
if system_type == "REAL": |
|
|
|
|
|
if results.get("blocked", False): |
|
|
st.error("π« Response BLOCKED by output guardrails") |
|
|
if results.get("issues"): |
|
|
st.warning("**Issues detected:**") |
|
|
for issue in results["issues"]: |
|
|
st.write(f"β’ {issue}") |
|
|
else: |
|
|
st.success("β
Response PASSED output guardrails") |
|
|
else: |
|
|
|
|
|
if results.get("issues"): |
|
|
st.warning(f"Issues detected: {', '.join(results['issues'])}") |
|
|
else: |
|
|
st.success("No issues detected") |
|
|
|
|
|
st.caption(f"Tested at {results['timestamp']} | System: {system_type} (RAG Pipeline Integration)") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
render_experiment_dashboard() |