"""
Enhanced Flask app with integrated guardrails system.
This module demonstrates how to integrate the guardrails system
with the existing Flask API endpoints.
"""
# ...existing code...
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request
# Load environment variables from .env file
load_dotenv()
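# Configuration note: based on the error handling in the /chat endpoint below,
# LLMService.from_environment() expects either OPENROUTER_API_KEY or
# GROQ_API_KEY to be set; the .env file loaded here is the usual place to
# define one of them.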
app = Flask(__name__)
@app.route("/")
def index():
"""
Renders the chat interface.
"""
return render_template("chat.html")
@app.route("/health")
def health():
"""
Health check endpoint.
"""
return jsonify({"status": "ok"}), 200
@app.route("/chat", methods=["POST"])
def chat():
"""
Enhanced endpoint for conversational RAG interactions with guardrails.
Accepts JSON requests with user messages and returns AI-generated
responses with comprehensive validation and safety checks.
"""
try:
# Validate request contains JSON data
if not request.is_json:
return (
jsonify(
{
"status": "error",
"message": "Content-Type must be application/json",
}
),
400,
)
data = request.get_json()
# Validate required message parameter
message = data.get("message")
if message is None:
return (
jsonify(
{"status": "error", "message": "message parameter is required"}
),
400,
)
if not isinstance(message, str) or not message.strip():
return (
jsonify(
{"status": "error", "message": "message must be a non-empty string"}
),
400,
)
# Extract optional parameters
conversation_id = data.get("conversation_id")
include_sources = data.get("include_sources", True)
include_debug = data.get("include_debug", False)
enable_guardrails = data.get("enable_guardrails", True)
# Initialize enhanced RAG pipeline components
try:
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
from src.embedding.embedding_service import EmbeddingService
from src.llm.llm_service import LLMService
from src.rag.enhanced_rag_pipeline import EnhancedRAGPipeline
from src.rag.rag_pipeline import RAGPipeline
from src.rag.response_formatter import ResponseFormatter
from src.search.search_service import SearchService
from src.vector_store.vector_db import VectorDatabase
# Initialize services
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
embedding_service = EmbeddingService()
search_service = SearchService(vector_db, embedding_service)
# Initialize LLM service from environment
llm_service = LLMService.from_environment()
# Initialize base RAG pipeline
base_rag_pipeline = RAGPipeline(search_service, llm_service)
# Initialize enhanced pipeline with guardrails if enabled
if enable_guardrails:
# Configure guardrails for production use
guardrails_config = {
"min_confidence_threshold": 0.7,
"strict_mode": False,
"enable_response_enhancement": True,
"log_all_results": True,
}
rag_pipeline = EnhancedRAGPipeline(base_rag_pipeline, guardrails_config)
else:
rag_pipeline = base_rag_pipeline
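# At this point rag_pipeline is either the guardrails-wrapped pipeline or the
# plain RAGPipeline, depending on the request's "enable_guardrails" flag.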
# Initialize response formatter
formatter = ResponseFormatter()
except ValueError as e:
return (
jsonify(
{
"status": "error",
"message": f"LLM service configuration error: {str(e)}",
"details": (
"Please ensure OPENROUTER_API_KEY or GROQ_API_KEY "
"environment variables are set"
),
}
),
503,
)
except Exception as e:
return (
jsonify(
{
"status": "error",
"message": f"Service initialization failed: {str(e)}",
}
),
500,
)
# Generate RAG response with enhanced validation
rag_response = rag_pipeline.generate_answer(message.strip())
# Format response for API with guardrails information
if include_sources:
formatted_response = formatter.format_api_response(
rag_response, include_debug
)
# Add guardrails information if available
if hasattr(rag_response, "guardrails_approved"):
formatted_response["guardrails"] = {
"approved": rag_response.guardrails_approved,
"confidence": rag_response.guardrails_confidence,
"safety_passed": rag_response.safety_passed,
"quality_score": rag_response.quality_score,
"warnings": getattr(rag_response, "guardrails_warnings", []),
"fallbacks": getattr(rag_response, "guardrails_fallbacks", []),
}
else:
formatted_response = formatter.format_chat_response(
rag_response, conversation_id, include_sources=False
)
return jsonify(formatted_response)
except Exception as e:
return (
jsonify({"status": "error", "message": f"Chat request failed: {str(e)}"}),
500,
)
@app.route("/chat/health", methods=["GET"])
def chat_health():
"""
Health check endpoint for enhanced RAG chat functionality.
Returns the status of all RAG pipeline components, including guardrails.
"""
try:
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
from src.embedding.embedding_service import EmbeddingService
from src.llm.llm_service import LLMService
from src.rag.enhanced_rag_pipeline import EnhancedRAGPipeline
from src.rag.rag_pipeline import RAGPipeline
from src.search.search_service import SearchService
from src.vector_store.vector_db import VectorDatabase
# Initialize services
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
embedding_service = EmbeddingService()
search_service = SearchService(vector_db, embedding_service)
llm_service = LLMService.from_environment()
# Initialize enhanced pipeline
base_rag_pipeline = RAGPipeline(search_service, llm_service)
enhanced_pipeline = EnhancedRAGPipeline(base_rag_pipeline)
# Get comprehensive health status
health_status = enhanced_pipeline.get_health_status()
return jsonify(
{
"status": "healthy",
"components": health_status,
"timestamp": health_status.get("timestamp", "unknown"),
}
)
except ValueError as e:
# Specific handling for LLM configuration errors
return (
jsonify(
{
"status": "error",
"message": f"LLM configuration error: {str(e)}",
"health": {
"pipeline_status": "unhealthy",
"components": {
"llm_service": {
"status": "unconfigured",
"error": str(e),
}
},
},
}
),
503,
)
except Exception as e:
return (
jsonify(
{
"status": "unhealthy",
"error": str(e),
"components": {"error": "Failed to initialize components"},
}
),
500,
)
@app.route("/guardrails/validate", methods=["POST"])
def validate_response():
"""
Standalone endpoint for validating responses with guardrails.
Allows guardrails validation to be tested on an existing response without
running the full RAG generation step.
"""
try:
if not request.is_json:
return (
jsonify(
{
"status": "error",
"message": "Content-Type must be application/json",
}
),
400,
)
data = request.get_json()
# Validate required parameters
response_text = data.get("response")
query_text = data.get("query")
sources = data.get("sources", [])
if not response_text or not query_text:
return (
jsonify(
{
"status": "error",
"message": "response and query parameters are required",
}
),
400,
)
# Initialize enhanced pipeline for validation
from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
from src.embedding.embedding_service import EmbeddingService
from src.llm.llm_service import LLMService
from src.rag.enhanced_rag_pipeline import EnhancedRAGPipeline
from src.rag.rag_pipeline import RAGPipeline
from src.search.search_service import SearchService
from src.vector_store.vector_db import VectorDatabase
# Initialize services
vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
embedding_service = EmbeddingService()
search_service = SearchService(vector_db, embedding_service)
llm_service = LLMService.from_environment()
# Initialize enhanced pipeline
base_rag_pipeline = RAGPipeline(search_service, llm_service)
enhanced_pipeline = EnhancedRAGPipeline(base_rag_pipeline)
# Perform validation
validation_result = enhanced_pipeline.validate_response_only(
response_text, query_text, sources
)
return jsonify({"status": "success", "validation": validation_result})
except Exception as e:
return (
jsonify({"status": "error", "message": f"Validation failed: {str(e)}"}),
500,
)
if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=8080)
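# Quick local checks (assuming the app is running on port 8080 as configured
# above; the example message is illustrative):
#   curl http://localhost:8080/health
#   curl http://localhost:8080/chat/health
#   curl -X POST http://localhost:8080/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What topics do the indexed documents cover?"}'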