RAG / experimental_dashboard.py
gauthy08's picture
feat: Integrate query_rag_pipeline into dashboard and cleanup
2e3dc8a
raw
history blame
30.7 kB
"""
Experimental Dashboard for RAG Pipeline Testing
Provides GUI interface for running and visualizing experiments
"""
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from typing import Dict, List, Any
import json
import time
from datetime import datetime
import threading
import queue
# Import for RAG pipeline integration
from model.model import RAGModel
from rails import input as input_guard
from rails.output import OutputGuardrails
from helper import Answer
import os
try:
import secrets_local
HF_TOKEN = secrets_local.HF
except ImportError:
HF_TOKEN = os.environ.get("HF_TOKEN")
# Import experiments - note: experiments are imported within functions to avoid circular imports
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent / "experiments"))
# Import query_rag_pipeline from app.py to avoid code duplication
from app import query_rag_pipeline
def render_experiment_dashboard():
"""Main experimental dashboard interface"""
st.header("πŸ§ͺ RAG Pipeline Experiments")
st.markdown("Run controlled experiments to test and validate RAG pipeline behavior")
# Main content area with tabs
tab1, tab2, tab3 = st.tabs(["πŸ“‹ System Info", "πŸ›‘οΈ Input Guards", "πŸ” Output Guards"])
with tab1:
render_system_info_tab()
with tab2:
render_input_guardrails_tab()
with tab3:
render_output_guardrails_tab()
def render_system_overview():
"""Render quick system overview at the top"""
with st.expander("ℹ️ About this RAG System", expanded=False):
col1, col2 = st.columns(2)
with col1:
st.markdown("**🎯 Purpose:**")
st.write("Test and validate a Retrieval-Augmented Generation (RAG) system for university data queries")
st.markdown("**πŸ”§ Components:**")
st.write("β€’ Sentence Transformers embeddings")
st.write("β€’ ChromaDB vector database")
st.write("β€’ Hugging Face API for text generation")
st.write("β€’ Input/Output security guardrails")
with col2:
st.markdown("**πŸ“Š Sample Queries:**")
st.write("β€’ 'What courses is Maria taking?'")
st.write("β€’ 'Who teaches computer science?'")
st.write("β€’ 'Show me faculty in engineering'")
st.markdown("**⚠️ Test Cases:**")
st.write("β€’ Malicious SQL injection attempts")
st.write("β€’ Personal data extraction tries")
st.write("β€’ Parameter optimization tests")
def get_database_stats():
"""Get real database statistics"""
try:
import sqlite3
import os
# Use absolute path to ensure we find the database
current_dir = os.path.dirname(os.path.abspath(__file__))
db_path = os.path.join(current_dir, 'database', 'university.db')
if not os.path.exists(db_path):
# Try relative path as fallback
db_path = 'database/university.db'
if not os.path.exists(db_path):
st.warning(f"Database file not found. Checked: {db_path}")
return None
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Get counts
student_count = cursor.execute("SELECT COUNT(*) FROM students").fetchone()[0]
faculty_count = cursor.execute("SELECT COUNT(*) FROM faculty").fetchone()[0]
course_count = cursor.execute("SELECT COUNT(*) FROM courses").fetchone()[0]
enrollment_count = cursor.execute("SELECT COUNT(*) FROM enrollments").fetchone()[0]
# Get sample data (using correct column names)
sample_student = cursor.execute("SELECT name FROM students LIMIT 1").fetchone()
sample_faculty = cursor.execute("SELECT name, department FROM faculty LIMIT 1").fetchone()
# Courses table doesn't have department column, get faculty info via join
sample_course_query = """
SELECT c.name, f.department
FROM courses c
JOIN faculty f ON c.faculty_id = f.id
LIMIT 1
"""
sample_course = cursor.execute(sample_course_query).fetchone()
conn.close()
# Success message for debugging
st.success(f"βœ… Database connected! Found {student_count} students, {faculty_count} faculty, {course_count} courses")
return {
'students': student_count,
'faculty': faculty_count,
'courses': course_count,
'enrollments': enrollment_count,
'sample_student': sample_student[0] if sample_student else "No data available",
'sample_faculty': sample_faculty if sample_faculty else ("No data available", "No department"),
'sample_course': sample_course if sample_course else ("No data available", "No department")
}
except Exception as e:
st.error(f"❌ Error connecting to database: {str(e)}")
return None
def render_system_info_tab():
"""Render comprehensive system information tab"""
st.subheader("πŸ“‹ System Information & Database Schema")
# Get real database stats
db_stats = get_database_stats()
if db_stats:
# Live Database Statistics
st.markdown("### πŸ“Š Live Database Statistics")
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("πŸ‘₯ Students", db_stats['students'])
with col2:
st.metric("πŸ‘¨β€πŸ« Faculty", db_stats['faculty'])
with col3:
st.metric("πŸ“š Courses", db_stats['courses'])
with col4:
st.metric("πŸ“ Enrollments", db_stats['enrollments'])
# Database Schema
st.markdown("### πŸ—„οΈ Database Schema")
col1, col2 = st.columns(2)
with col1:
st.markdown("**Tables Overview:**")
# Students table
with st.expander("πŸ‘₯ Students Table", expanded=True):
if db_stats:
st.markdown(f"""
**Columns:**
- `id` (Primary Key)
- `name` (Student full name)
- `email` (Email address - PII)
- `svnr` (Social security number - Sensitive PII)
**Sample Data:**
- {db_stats['sample_student']} ([REDACTED_EMAIL])
- Contains {db_stats['students']} total student records
- All emails and SVNR automatically redacted for privacy
""")
else:
st.markdown("""
**Columns:**
- `id` (Primary Key)
- `name` (Student full name)
- `email` (Email address - PII)
- `svnr` (Social security number - Sensitive PII)
**Sample Data:**
- Database connection not available
- Contains realistic student records with Faker-generated data
- All emails and SVNR automatically redacted for privacy
""")
# Faculty table
with st.expander("πŸ‘¨β€πŸ« Faculty Table"):
if db_stats:
faculty_name, faculty_dept = db_stats['sample_faculty']
st.markdown(f"""
**Columns:**
- `id` (Primary Key)
- `name` (Faculty full name)
- `email` (Email address - PII)
- `department` (Department/specialization)
**Sample Data:**
- {faculty_name} ({faculty_dept})
- Contains {db_stats['faculty']} total faculty records
- Departments include engineering, sciences, humanities
""")
else:
st.markdown("""
**Columns:**
- `id` (Primary Key)
- `name` (Faculty full name)
- `email` (Email address - PII)
- `department` (Department/specialization)
**Sample Data:**
- Database connection not available
- Contains faculty across various academic departments
- Departments include engineering, sciences, humanities
""")
with col2:
# Courses table
with st.expander("πŸ“š Courses Table", expanded=True):
if db_stats:
course_name, course_dept = db_stats['sample_course']
st.markdown(f"""
**Columns:**
- `id` (Primary Key)
- `name` (Course title)
- `faculty_id` (Foreign Key β†’ Faculty)
- `department` (Course department)
**Sample Data:**
- "{course_name}" ({course_dept})
- Contains {db_stats['courses']} total course records
- Generated with realistic university course patterns
""")
else:
st.markdown("""
**Columns:**
- `id` (Primary Key)
- `name` (Course title)
- `faculty_id` (Foreign Key β†’ Faculty)
- `department` (Course department)
**Sample Data:**
- Database connection not available
- Contains realistic university courses across departments
- Generated with realistic university course patterns
""")
# Enrollments table
with st.expander("πŸ“ Enrollments Table"):
if db_stats:
avg_enrollments = db_stats['enrollments'] // db_stats['students'] if db_stats['students'] > 0 else 0
st.markdown(f"""
**Columns:**
- `id` (Primary Key)
- `student_id` (Foreign Key β†’ Students)
- `course_id` (Foreign Key β†’ Courses)
**Purpose:**
Links students to their enrolled courses (Many-to-Many relationship)
**Statistics:**
- {db_stats['enrollments']} total enrollment records
- Average enrollments per student: {avg_enrollments}
""")
else:
st.markdown("""
**Columns:**
- `id` (Primary Key)
- `student_id` (Foreign Key β†’ Students)
- `course_id` (Foreign Key β†’ Courses)
**Purpose:**
Links students to their enrolled courses (Many-to-Many relationship)
**Statistics:**
- Database connection not available
- Contains realistic enrollment patterns for university students
""")
# RAG System Details
st.markdown("### πŸ€– RAG Pipeline Components")
col1, col2, col3 = st.columns(3)
with col1:
st.markdown("**πŸ“₯ Input Processing:**")
st.write("β€’ Language detection")
st.write("β€’ SQL injection detection")
st.write("β€’ Toxic content filtering")
st.write("β€’ Intent classification")
with col2:
st.markdown("**πŸ” Retrieval:**")
st.write("β€’ Sentence-BERT embeddings")
st.write("β€’ ChromaDB similarity search")
st.write("β€’ Context window management")
st.write("β€’ Relevance scoring")
with col3:
st.markdown("**πŸ“€ Output Generation:**")
st.write("β€’ Hugging Face API")
st.write("β€’ PII redaction")
st.write("β€’ Hallucination detection")
st.write("β€’ Response validation")
# Security Information
st.markdown("### πŸ”’ Security & Privacy Features")
with st.expander("πŸ›‘οΈ Security Measures", expanded=True):
col1, col2 = st.columns(2)
with col1:
st.markdown("**Input Guardrails:**")
st.write("βœ… SQL injection prevention")
st.write("βœ… Command injection blocking")
st.write("βœ… Toxic language filtering")
st.write("βœ… Language validation")
with col2:
st.markdown("**Output Guardrails:**")
st.write("βœ… Email address redaction")
st.write("βœ… SVNR number protection")
st.write("βœ… Irrelevant response filtering")
st.write("βœ… Data leakage prevention")
# Experiment Information
st.markdown("### πŸ§ͺ Available Experiments")
exp_info = [
{
"Experiment": "πŸ›‘οΈ Input Guards",
"Purpose": "Test security against malicious inputs",
"Tests": "SQL injection, toxic content, data extraction attempts",
"Goal": "Block harmful queries while allowing legitimate ones"
},
{
"Experiment": "πŸ” Output Guards",
"Purpose": "Validate response safety and quality",
"Tests": "PII leakage, SVNR exposure, relevance checking",
"Goal": "Prevent sensitive data exposure and ensure relevance"
}
]
df = pd.DataFrame(exp_info)
st.dataframe(df, use_container_width=True)
def render_input_guardrails_tab():
"""Render input guardrails experiment interface"""
st.subheader("πŸ›‘οΈ Input Guardrails Testing")
# Add explanation
with st.expander("ℹ️ About Input Guardrails", expanded=False):
st.markdown("""
**Purpose:** Test the system's ability to detect and block malicious or inappropriate inputs.
**What we test:**
- 🚫 **SQL Injection**: Attempts to manipulate database queries
- 🚫 **Command Injection**: System command execution attempts
- 🚫 **Toxic Content**: Inappropriate or offensive language
- 🚫 **Data Extraction**: Attempts to access sensitive information (emails, SVNR)
- βœ… **Legitimate Queries**: Normal university-related questions should pass
**How it works:**
- Language detection to ensure English input
- Pattern matching for common attack vectors
- Content filtering for inappropriate language
- Context analysis for data extraction attempts
""")
col1, col2 = st.columns([2, 1])
with col1:
st.markdown("### πŸ§ͺ Test Custom Input")
# Custom input testing
test_input = st.text_area(
"Enter your test input:",
placeholder="Try: 'SELECT * FROM students WHERE name=admin--' or 'What courses does Maria take?'",
height=100,
help="Enter any input to test if our guardrails can detect malicious content"
)
test_button = st.button("πŸ” Test Input", type="primary")
if test_button and test_input:
test_single_input(test_input)
with col2:
st.markdown("### πŸš€ Quick Attack Tests")
# Load test cases directly from experiment file
try:
from experiments.experiment_1_input_guardrails import InputGuardrailsExperiment
exp = InputGuardrailsExperiment()
test_cases = exp._get_test_cases()
for i, test_case in enumerate(test_cases):
if test_case["expected_blocked"]: # Only show attack cases, not legitimate
with st.container():
# Map categories to icons
icon_map = {
"sql_injection": "πŸ’‰",
"xss_injection": "πŸ”“",
"toxicity": "🀬",
"command_injection": "πŸ’₯"
}
icon = icon_map.get(test_case["category"], "⚠️")
st.markdown(f"**{icon} {test_case['name']}**")
st.caption("Expected: BLOCK")
if st.button(f"Test {test_case['name']}", key=f"test_{i}", use_container_width=True):
test_single_input(test_case['input'])
st.markdown("---")
# Add legitimate test from experiment file
legitimate_cases = [tc for tc in test_cases if not tc["expected_blocked"]]
if legitimate_cases:
test_case = legitimate_cases[0] # Use first legitimate case
st.markdown("**βœ… Legitimate Query**")
st.caption("Expected: ALLOW")
if st.button("Test Legitimate", key="legitimate_test", use_container_width=True):
test_single_input(test_case['input'])
except Exception as e:
st.error(f"Could not load test cases: {e}")
st.info("Using fallback test cases...")
# Fallback to simple test
if st.button("Test SQL Injection", key="fallback_test", use_container_width=True):
test_single_input("SELECT * FROM students WHERE name='admin'--")
# Results display
if "input_test_results" in st.session_state:
display_input_test_results()
def render_output_guardrails_tab():
"""Render output guardrails experiment interface"""
st.subheader("πŸ” Output Guardrails Testing")
# Add explanation
with st.expander("ℹ️ About Output Guardrails", expanded=False):
st.markdown("""
**Purpose:** Test the system's ability to detect and filter problematic content in generated responses.
**What we detect:**
- πŸ“§ **PII Leakage**: Email addresses that shouldn't be exposed
- πŸ”’ **SVNR Exposure**: Social security numbers (highly sensitive)
- 🎯 **Relevance Issues**: Responses not related to university queries
- 🚫 **Data Leakage**: Any sensitive information exposure
**How it works:**
- Pattern matching for emails and ID numbers
- Content analysis for relevance to university context
- Automatic redaction of detected sensitive data
- Response filtering based on content quality
**Test Process:**
1. Simulate problematic AI responses
2. Run through output guardrails
3. Check what gets detected and filtered
""")
col1, col2 = st.columns(2)
with col1:
st.markdown("### πŸ§ͺ Test Custom Response")
# Custom response testing (like input guardrails)
test_response = st.text_area(
"Enter AI response to test:",
placeholder="Try: 'Student email: [email protected]' or 'SVNR: 1235567890'",
height=100,
help="Enter any AI response to test if our output guardrails can detect issues"
)
test_button = st.button("πŸ” Test Response", type="primary")
if test_button and test_response:
test_real_output_filtering(test_response)
with col2:
st.markdown("### πŸš€ Quick Response Tests")
# Load test cases directly from experiment file (like input guardrails)
try:
# Get API key (needed for experiment initialization)
try:
import secrets_local
api_key = secrets_local.HF
except ImportError:
import os
api_key = os.environ.get("HF_TOKEN", "dummy_key")
from experiments.experiment_2_output_guardrails import OutputGuardrailsExperiment
exp = OutputGuardrailsExperiment(api_key)
test_cases = exp._get_test_cases()
# Create demo responses for each test category
demo_responses = [
{
"name": "πŸ“§ Email Leakage Test",
"response": "Student email: [email protected], [email protected]",
"expected": "FILTER",
"category": "pii_leakage"
},
{
"name": "πŸ”’ SVNR Leakage Test",
"response": "Student SVNR numbers: 1235567890, 9870543210",
"expected": "FILTER",
"category": "svnr_exposure"
},
{
"name": "🎯 Irrelevant Response Test",
"response": "Today's weather is sunny with 25Β°C temperature",
"expected": "FILTER",
"category": "irrelevant"
}
]
for i, demo in enumerate(demo_responses):
with st.container():
st.markdown(f"**{demo['name']}**")
st.caption(f"Expected: {demo['expected']}")
if st.button(f"Test Response #{i+1}", key=f"response_test_{i}", use_container_width=True):
test_real_output_filtering(demo['response'])
st.markdown("---")
except Exception as e:
st.error(f"Could not load output test cases: {e}")
st.info("Using fallback test...")
if st.button("Test Email Detection", key="fallback_output_test", use_container_width=True):
test_real_output_filtering("Student email: [email protected]")
# Results display
if "output_test_results" in st.session_state:
display_output_test_results()
def test_single_input(test_input: str):
"""Test a single input through the complete RAG pipeline"""
try:
# Initialize RAG components
model = RAGModel(HF_TOKEN)
output_guardrails = OutputGuardrails()
input_guardrails = input_guard.InputGuardRails()
# Run through complete RAG pipeline
start_time = time.time()
result = query_rag_pipeline(test_input, model, output_guardrails, input_guardrails)
# Determine if input was blocked by checking if we got a guardrail rejection
blocked = ("Invalid input" in result.answer or
"SQL injection" in result.answer or
"inappropriate" in result.answer.lower() or
"blocked" in result.answer.lower())
# Store results in compatible format
st.session_state.input_test_results = {
"input": test_input,
"blocked": blocked,
"reason": result.answer if blocked else "Input accepted - generated response successfully",
"full_answer": result.answer,
"sources": result.sources,
"processing_time": result.processing_time,
"timestamp": datetime.now().strftime('%H:%M:%S')
}
except Exception as e:
st.error(f"Error testing input: {e}")
def test_real_output_filtering(test_response: str):
"""Test output filtering through complete RAG pipeline by generating a response that should contain the test content"""
try:
# Initialize RAG components
model = RAGModel(HF_TOKEN)
output_guardrails = OutputGuardrails()
input_guardrails = input_guard.InputGuardRails()
# Create a query that would likely generate the test response content
# This is a bit of a hack, but it allows us to test output filtering in a realistic way
if "email" in test_response.lower():
test_query = "What are some example student contact details?"
elif "svnr" in test_response.lower() or any(char.isdigit() for char in test_response):
test_query = "Can you show me student identification numbers?"
elif "weather" in test_response.lower() or "temperature" in test_response.lower():
test_query = "What's the current weather like?"
elif "programming" in test_response.lower() or "computer science" in test_response.lower():
test_query = "What computer science courses are available?"
else:
test_query = "Tell me about university information"
# Run through complete pipeline with output guardrails disabled first to get raw response
raw_result = query_rag_pipeline(test_query, model, output_guardrails, input_guardrails,
input_guardrails_active=True, output_guardrails_active=False)
# Now test the provided response text through output guardrails manually
# (This simulates what would happen if the LLM generated the test response)
from rag import retriever
# Get context for the test query
try:
context = retriever.search(test_query, top_k=3)
except:
context = []
# Test the provided response against output guardrails
guardrail_results = output_guardrails.check(test_query, test_response, context)
# Apply redaction to the test response
filtered_response = test_response
from helper import EMAIL_PATTERN
filtered_response = EMAIL_PATTERN.sub('[REDACTED_EMAIL]', filtered_response)
filtered_response = output_guardrails.redact_svnrs(filtered_response)
# Process guardrail results
issues_detected = []
for check_name, result in guardrail_results.items():
if not result.passed:
issue_details = ", ".join(result.issues) if result.issues else "Failed validation"
issues_detected.append(f"{check_name}: {issue_details}")
blocked = len(issues_detected) > 0
# Store results in session state
st.session_state.output_test_results = {
"original": test_response,
"filtered": filtered_response,
"blocked": blocked,
"issues": issues_detected,
"query_used": test_query,
"context_docs": len(context),
"guardrails_enabled": True,
"timestamp": datetime.now().strftime('%H:%M:%S'),
"system": "REAL"
}
except Exception as e:
st.error(f"Error testing output filtering: {e}")
import traceback
st.error(f"Details: {traceback.format_exc()}")
def test_output_filtering(response: str, enable_filtering: bool):
"""Test output filtering (legacy/fallback method)"""
try:
# Simple filtering simulation
filtered_response = response
issues = []
if enable_filtering:
if "@" in response:
issues.append("Email detected")
filtered_response = response.replace("@", "[EMAIL]")
if any(char.isdigit() for char in response) and len([c for c in response if c.isdigit()]) > 5:
issues.append("Potential SVNR/ID detected")
st.session_state.output_test_results = {
"original": response,
"filtered": filtered_response,
"issues": issues,
"guardrails_enabled": enable_filtering,
"timestamp": datetime.now().strftime('%H:%M:%S'),
"system": "SIMULATED" # Mark as simulation
}
except Exception as e:
st.error(f"Error testing output: {e}")
def display_input_test_results():
"""Display input test results from RAG pipeline"""
results = st.session_state.input_test_results
st.markdown("### πŸ” Input Test Results (Full RAG Pipeline)")
col1, col2 = st.columns(2)
with col1:
st.markdown("**Input Query:**")
st.code(results["input"])
if not results["blocked"] and results.get("sources"):
st.markdown("**Sources Retrieved:**")
with st.expander(f"πŸ“š {len(results['sources'])} sources found"):
for source in results["sources"][:3]: # Show first 3 sources
st.write(f"β€’ {source['title']}")
with col2:
if results["blocked"]:
st.error(f"🚫 BLOCKED: {results['reason']}")
else:
st.success("βœ… ALLOWED - Generated Response")
if results.get("full_answer"):
with st.expander("πŸ“ Generated Response"):
st.write(results["full_answer"])
# Show performance metrics
if results.get("processing_time"):
st.metric("Processing Time", f"{results['processing_time']:.3f}s")
st.caption(f"Tested at {results['timestamp']} | System: Full RAG Pipeline")
def display_output_test_results():
"""Display output test results from RAG pipeline integration"""
results = st.session_state.output_test_results
# Show system type
system_type = results.get("system", "UNKNOWN")
st.markdown("### πŸ” Output Test Results")
col1, col2 = st.columns(2)
with col1:
st.markdown("**Original Response:**")
# Handle both 'original' and 'input' keys for compatibility
original_text = results.get("original", results.get("input", ""))
st.write(original_text)
if results.get("query_used"):
st.markdown("**Query Used:**")
st.caption(f"πŸ“ {results['query_used']}")
with col2:
st.markdown("**Filtered Response:**")
st.write(results["filtered"])
if results.get("context_docs"):
st.markdown("**Context Retrieved:**")
st.caption(f"πŸ“š {results['context_docs']} documents")
# Handle both old and new result formats
if system_type == "REAL":
# New real system results
if results.get("blocked", False):
st.error("🚫 Response BLOCKED by output guardrails")
if results.get("issues"):
st.warning("**Issues detected:**")
for issue in results["issues"]:
st.write(f"β€’ {issue}")
else:
st.success("βœ… Response PASSED output guardrails")
else:
# Legacy simulated results
if results.get("issues"):
st.warning(f"Issues detected: {', '.join(results['issues'])}")
else:
st.success("No issues detected")
st.caption(f"Tested at {results['timestamp']} | System: {system_type} (RAG Pipeline Integration)")
if __name__ == "__main__":
render_experiment_dashboard()