# backend/main.py

# --- Standard Library Imports ---
import os
import shutil  # Needed for copying uploaded files to local disk
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Optional
from urllib.parse import unquote

# --- Third-Party Imports ---
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session

# --- Local Module Imports ---
import auth
import crud
import models
import schemas
from database import SessionLocal, engine, get_db
from ml_qna import qna as generate_ml_answer
from pipeline import highlight_text, load_all_models, pipeline_process_pdf
from supabase_utils import upload_file_to_supabase
# from email_automation import download_attached_file
# import imaplib

# This creates any missing database tables in your Neon database
# based on your models.py file.
models.Base.metadata.create_all(bind=engine)
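# Note: create_all only creates tables that do not exist yet; it will not
# alter columns on existing tables. Schema changes would need a migration
# tool such as Alembic.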
# --- Setup for Loading Models on Startup ---
# This dictionary holds the loaded models so we don't reload them on every request.
ml_models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    # This code runs ONCE when the server starts up.
    print("[INFO] Server starting up...")

    # Ensure the system 'automation_user' exists before serving requests.
    print("[INFO] Ensuring system 'automation_user' exists...")
    db = SessionLocal()
    try:
        # Check if the user already exists; if not, create it.
        automation_user = crud.get_user(db, user_id="automation_user")
        if not automation_user:
            print("[INFO] 'automation_user' not found. Creating it now...")
            user_data = schemas.UserCreate(
                id="automation_user",
                name="Automation Service",
                department="System",
                role="system",
                password="automation_pass",  # A placeholder password
            )
            crud.create_user(db, user_data)
            print("[INFO] 'automation_user' created successfully.")
        else:
            print("[INFO] 'automation_user' already exists.")
    finally:
        db.close()  # Always close the database session

    print("[INFO] Loading ML models...")
    tokenizer, model, nlp_model = load_all_models()
    ml_models["tokenizer"] = tokenizer
    ml_models["model"] = model
    ml_models["nlp_model"] = nlp_model
    print("[INFO] ML models loaded successfully and are ready.")

    yield

    # This code runs when the server shuts down.
    ml_models.clear()
    print("[INFO] Server shutting down.")
app = FastAPI(lifespan=lifespan)

# Origins allowed to call this API: local dev servers plus the deployed frontend.
origins = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "https://kochi-metro-document.vercel.app",
    "http://localhost:3003",
    "http://127.0.0.1:3003",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
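# Design note: the origins are listed explicitly rather than as a wildcard.
# The CORS spec forbids the literal "Access-Control-Allow-Origin: *" header on
# credentialed requests, and an explicit allow-list is also safer in general.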
# --- LOCAL UPLOAD DIRECTORY for temporary storage ---
UPLOAD_DIRECTORY = "uploads"
os.makedirs(UPLOAD_DIRECTORY, exist_ok=True)
# --- Diagnostic Endpoints ---
@app.get("/")
def read_root():
    return {"status": "ok", "service": "kmrl-backend-service"}


@app.get("/ping-db")
def ping_db(db: Session = Depends(get_db)):
    try:
        db.execute(text("SELECT 1"))
        return {"status": "ok", "message": "Database connection successful."}
    except OperationalError as e:
        raise HTTPException(
            status_code=500,
            detail=f"Database connection failed: {str(e)}",
        )
# --- User Management Endpoints ---
@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user.id)
    if db_user:
        raise HTTPException(status_code=400, detail="User ID already registered")
    return crud.create_user(db=db, user=user)


@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: str, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user
# --- Document Management Endpoints ---
@app.post("/documents/upload")
def upload_document(
    # Optional when the upload comes from the email automation; the frontend
    # always sends all three.
    title: Optional[str] = Form(None),
    department: Optional[str] = Form(None),
    user_id: Optional[str] = Form(None),
    # The file itself is always required.
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """
    Accepts a PDF upload, stores the original in cloud storage, runs the ML
    pipeline on it, stores the highlighted copy, and records everything in
    the database.
    """
    # --- 1. Set default values and validate the uploader ---
    # If a title wasn't provided (e.g. an email upload), derive one from the filename.
    final_title = title or f"Email Attachment - {file.filename}"
    # If a user_id wasn't provided, attribute the upload to the automation user.
    final_user_id = user_id or "automation_user"
    # If a department wasn't provided, let the pipeline auto-detect it.
    final_department = department or "auto-detected"

    user = crud.get_user(db, user_id=final_user_id)
    if not user:
        raise HTTPException(status_code=404, detail=f"Uploader '{final_user_id}' not found")

    # --- 2. Upload the original file to cloud storage ---
    print("Uploading original file to cloud storage...")
    public_url = upload_file_to_supabase(file.file, file.filename)
    if not public_url:
        raise HTTPException(status_code=500, detail="Could not upload file to cloud storage.")
    print("File uploaded successfully. Public URL:", public_url)
    file.file.seek(0)  # Rewind the file for local processing

    # --- 3. Create the initial database record ---
    document_data = schemas.DocumentCreate(title=final_title, department=final_department)
    db_document = crud.create_document(db=db, document=document_data, file_path=public_url, user_id=final_user_id)
    print(f"Initial document record created in DB with ID: {db_document.id}")

    # --- 4. Save a local copy and run the ML pipeline ---
    local_file_path = os.path.join(UPLOAD_DIRECTORY, file.filename)
    with open(local_file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    print("Starting ML pipeline processing...")
    ml_results = pipeline_process_pdf(
        pdf_path=local_file_path,
        clf_tokenizer=ml_models["tokenizer"],
        clf_model=ml_models["model"],
        nlp_model=ml_models["nlp_model"],
    )
    print("ML pipeline processing complete.")

    # --- 5. Upload the highlighted PDF (if one was created) ---
    highlighted_pdf_path = ml_results.get("highlighted_pdf")
    highlighted_public_url = None
    if highlighted_pdf_path and os.path.exists(highlighted_pdf_path):
        print("Uploading highlighted file to cloud storage...")
        with open(highlighted_pdf_path, "rb") as f:
            highlighted_filename = os.path.basename(highlighted_pdf_path)
            highlighted_public_url = upload_file_to_supabase(f, highlighted_filename)
        print("Highlighted PDF uploaded successfully.")

    # --- 6. Update the database record with the ML results ---
    print("Updating database record with ML results...")
    final_document = crud.update_document_with_ml_results(
        db,
        document_id=db_document.id,
        ml_results=ml_results,
        highlighted_file_path=highlighted_public_url,
    )
    print("Database record updated successfully.")

    # --- 7. Create a notification for the routed department ---
    # The ML results contain the department the document was routed to.
    routed_department = final_document.department
    if routed_department and routed_department != "Unknown":
        notification_message = f"New document '{final_document.title}' has been assigned to your department."
        crud.create_notification(
            db=db,
            document_id=final_document.id,
            department=routed_department,
            message=notification_message,
        )
        print(f"Notification created for department: {routed_department}")

    # --- 8. Clean up local files ---
    try:
        if os.path.exists(local_file_path):
            os.remove(local_file_path)
        if highlighted_pdf_path and os.path.exists(highlighted_pdf_path):
            os.remove(highlighted_pdf_path)
    except OSError as e:
        print(f"Error during file cleanup: {e}")

    # --- 9. Return the final response ---
    return {
        "message": "Document processed and all data saved successfully.",
        "document_info": schemas.Document.model_validate(final_document),
        "highlighted_pdf_url": highlighted_public_url,
    }
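# Example client call for the upload endpoint (a sketch; it assumes the server
# runs on http://localhost:8000 and that the user ID used here exists):
#
#   import requests
#   with open("circular.pdf", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/documents/upload",
#           data={"title": "Safety Circular", "department": "Operations", "user_id": "user_1"},
#           files={"file": ("circular.pdf", f, "application/pdf")},
#       )
#   print(resp.json()["highlighted_pdf_url"])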
# --- Read Endpoints ---
@app.get("/documents/", response_model=List[schemas.Document])
def read_all_documents(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    documents = crud.get_all_documents(db, skip=skip, limit=limit)
    return documents


@app.get("/documents/{department}", response_model=List[schemas.Document])
def read_documents_for_department(department: str, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    documents = crud.get_documents_by_department(db, department=department, skip=skip, limit=limit)
    return documents
# --- Q&A Endpoints ---
def run_ml_qna_in_background(question_id: uuid.UUID, pinecone_pdf_id: str, question_text: str):
    print(f"[BACKGROUND TASK] Starting ML RAG pipeline for question ID: {question_id}")

    # Call the ML function to generate the answer. The Pinecone ID is the
    # original filename (without extension), matching how the PDF was indexed.
    answer_text = generate_ml_answer(
        pdf_id=pinecone_pdf_id,
        query=question_text,
    )
    print(f"[BACKGROUND TASK] Answer generated: {answer_text[:100]}...")

    # Open a fresh session and save the answer to the database.
    db = SessionLocal()
    try:
        crud.update_question_with_answer(
            db=db,
            question_id=question_id,
            answer_text=answer_text,
        )
        print(f"[BACKGROUND TASK] Answer saved to database for question ID: {question_id}")
    finally:
        db.close()
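# Design note: the background task opens its own SessionLocal() rather than
# reusing the request's Depends(get_db) session, because that request-scoped
# session is typically closed once the response has been sent, before this
# task runs.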
@app.post("/documents/{document_id}/questions", response_model=schemas.Question)
def ask_question_on_document(
    document_id: uuid.UUID,
    question: schemas.QuestionCreate,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db),
):
    """
    Endpoint for the frontend to submit a new question.
    It saves the question, then triggers the ML RAG pipeline in the background
    to generate an answer and save it to the database.
    """
    # First, fetch the document to get its uploader's ID.
    document = crud.get_document_by_id(db, document_id=document_id)
    if not document:
        raise HTTPException(status_code=404, detail="Document not found")
    user_id_who_asked = document.uploader_id

    # Create the question in the database with a NULL answer first.
    db_question = crud.create_question(
        db=db,
        document_id=document_id,
        user_id=user_id_who_asked,
        question=question,
    )
    print(f"New question saved with ID: {db_question.id}. Triggering background ML task.")

    # Derive the Pinecone PDF ID from the stored file URL: the URL-decoded
    # filename without its extension.
    pinecone_pdf_id = os.path.splitext(os.path.basename(unquote(document.file_path)))[0]
    background_tasks.add_task(
        run_ml_qna_in_background,
        db_question.id,
        pinecone_pdf_id,
        question.question_text,
    )

    # Return the new question object to the frontend immediately.
    # The frontend will see that `answer_text` is still null.
    return db_question
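# Typical flow: the frontend POSTs a question and gets it back with
# answer_text still null, then polls the GET endpoint below until the
# background task has filled the answer in.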
@app.get("/documents/{document_id}/questions", response_model=List[schemas.Question])
def get_document_questions(
    document_id: uuid.UUID,
    db: Session = Depends(get_db),
):
    """
    Endpoint for the frontend to retrieve the full conversation history
    (all questions and their answers) for a document.
    """
    return crud.get_questions_for_document(db=db, document_id=document_id)
# --- Internal Endpoint for Email Automation ---
@app.patch("/questions/{question_id}/answer")
def submit_answer(
    question_id: uuid.UUID,
    answer: schemas.Answer,
    db: Session = Depends(get_db),
):
    """
    INTERNAL ENDPOINT for the ML service to submit its generated answer
    for a question that has already been created.
    """
    updated_question = crud.update_question_with_answer(
        db=db,
        question_id=question_id,
        answer_text=answer.answer_text,
    )
    if not updated_question:
        raise HTTPException(status_code=404, detail="Question not found")
    print(f"Answer submitted for question ID: {question_id}")
    return {"status": "success", "question": updated_question}
# --- Notification Endpoints ---
@app.get("/notifications/{department}", response_model=List[schemas.Notification])
def read_notifications(department: str, db: Session = Depends(get_db)):
    """Fetches unread notifications for a given department."""
    notifications = crud.get_notifications_for_department(db, department=department)
    return notifications
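# To run locally (a sketch; assumes this file lives in backend/ and uvicorn
# is installed):
#   uvicorn main:app --reload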