LiamKhoaLe committed on
Commit 46db43c · 1 Parent(s): 412cffc

Migration to smaller subs

Files changed (12)
  1. app.py +12 -1271
  2. copy.py +1278 -0
  3. helpers/__init__.py +2 -0
  4. helpers/models.py +54 -0
  5. helpers/pages.py +26 -0
  6. helpers/setup.py +59 -0
  7. routes/auth.py +53 -0
  8. routes/chats.py +463 -0
  9. routes/files.py +215 -0
  10. routes/health.py +71 -0
  11. routes/projects.py +130 -0
  12. routes/reports.py +139 -0
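
The change splits the 1,278-line monolithic app.py into a helpers package (shared FastAPI instance, response models, client setup) and a routes package whose modules attach their endpoints to the shared app when imported. A minimal sketch of that pattern, assuming helpers exposes the FastAPI instance as app (the committed helpers/setup.py and routes/health.py are larger and may differ):

# helpers/__init__.py (sketch) — build the shared FastAPI instance once
from fastapi import FastAPI

app = FastAPI(title="StudyBuddy RAG", version="0.1.0")

# routes/health.py (sketch) — importing this module registers the endpoint
# on the shared app as a side effect; no router objects are passed around.
from helpers import app

@app.get("/healthz")
def health():
    return {"ok": True}
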
app.py CHANGED
@@ -1,1278 +1,19 @@
1
  # https://binkhoale1812-edsummariser.hf.space/
2
- import os, io, re, uuid, json, time, logging
3
- from typing import List, Dict, Any, Optional
4
- from datetime import datetime, timezone
5
- from pydantic import BaseModel
6
- import asyncio
7
 
8
- # Load environment variables from .env file
9
- from dotenv import load_dotenv
10
- load_dotenv()
11
 
12
- from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException, BackgroundTasks
13
- from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
14
- from fastapi.staticfiles import StaticFiles
15
- from fastapi.middleware.cors import CORSMiddleware
16
-
17
- # MongoDB imports
18
- from pymongo.errors import PyMongoError, ConnectionFailure, ServerSelectionTimeoutError
19
-
20
- from utils.api.rotator import APIKeyRotator
21
- from utils.ingestion.parser import parse_pdf_bytes, parse_docx_bytes
22
- from utils.ingestion.caption import BlipCaptioner
23
- from utils.ingestion.chunker import build_cards_from_pages
24
- from utils.rag.embeddings import EmbeddingClient
25
- from utils.rag.rag import RAGStore, ensure_indexes
26
- from utils.api.router import select_model, generate_answer_with_model
27
- from utils.service.summarizer import cheap_summarize
28
- from utils.service.common import trim_text
29
- from utils.logger import get_logger
30
- import re
31
-
32
- # ────────────────────────────── Response Models ──────────────────────────────
33
- class ProjectResponse(BaseModel):
34
- project_id: str
35
- user_id: str
36
- name: str
37
- description: str
38
- created_at: str
39
- updated_at: str
40
-
41
- class ProjectsListResponse(BaseModel):
42
- projects: List[ProjectResponse]
43
-
44
- class ChatMessageResponse(BaseModel):
45
- user_id: str
46
- project_id: str
47
- role: str
48
- content: str
49
- timestamp: float
50
- created_at: str
51
- sources: Optional[List[Dict[str, Any]]] = None
52
-
53
- class ChatHistoryResponse(BaseModel):
54
- messages: List[ChatMessageResponse]
55
-
56
- class MessageResponse(BaseModel):
57
- message: str
58
-
59
- class UploadResponse(BaseModel):
60
- job_id: str
61
- status: str
62
- total_files: Optional[int] = None
63
-
64
- class FileSummaryResponse(BaseModel):
65
- filename: str
66
- summary: str
67
-
68
- class ChatAnswerResponse(BaseModel):
69
- answer: str
70
- sources: List[Dict[str, Any]]
71
- relevant_files: Optional[List[str]] = None
72
-
73
- class HealthResponse(BaseModel):
74
- ok: bool
75
-
76
- class ReportResponse(BaseModel):
77
- filename: str
78
- report_markdown: str
79
- sources: List[Dict[str, Any]]
80
-
81
- # ────────────────────────────── App Setup ──────────────────────────────
82
- logger = get_logger("APP", name="studybuddy")
83
-
84
- app = FastAPI(title="StudyBuddy RAG", version="0.1.0")
85
- app.add_middleware(
86
- CORSMiddleware,
87
- allow_origins=["*"],
88
- allow_credentials=True,
89
- allow_methods=["*"],
90
- allow_headers=["*"],
91
- )
92
-
93
- # Serve static files (index.html, scripts.js, styles.css)
94
- app.mount("/static", StaticFiles(directory="static"), name="static")
95
-
96
- # In-memory job tracker (for progress queries)
97
- app.state.jobs = {}
98
-
99
-
100
- # ────────────────────────────── Global Clients ──────────────────────────────
101
- # API rotators (round robin + auto failover on quota errors)
102
- gemini_rotator = APIKeyRotator(prefix="GEMINI_API_", max_slots=5)
103
- nvidia_rotator = APIKeyRotator(prefix="NVIDIA_API_", max_slots=5)
104
-
105
- # Captioner + Embeddings (lazy init inside classes)
106
- captioner = BlipCaptioner()
107
- embedder = EmbeddingClient(model_name=os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))
108
-
109
- # Mongo / RAG store
110
- try:
111
- rag = RAGStore(mongo_uri=os.getenv("MONGO_URI"), db_name=os.getenv("MONGO_DB", "studybuddy"))
112
- # Test the connection
113
- rag.client.admin.command('ping')
114
- logger.info("[APP] MongoDB connection successful")
115
- ensure_indexes(rag)
116
- logger.info("[APP] MongoDB indexes ensured")
117
- except Exception as e:
118
- logger.error(f"[APP] Failed to initialize MongoDB/RAG store: {str(e)}")
119
- logger.error(f"[APP] MONGO_URI: {os.getenv('MONGO_URI', 'Not set')}")
120
- logger.error(f"[APP] MONGO_DB: {os.getenv('MONGO_DB', 'studybuddy')}")
121
- # Create a dummy RAG store for now - this will cause errors but prevents the app from crashing
122
- rag = None
123
-
124
-
125
- # ────────────────────────────── Auth Helpers/Routes ───────────────────────────
126
- import hashlib
127
- import secrets
128
-
129
-
130
- def _hash_password(password: str, salt: Optional[str] = None) -> Dict[str, str]:
131
- salt = salt or secrets.token_hex(16)
132
- dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), bytes.fromhex(salt), 120000)
133
- return {"salt": salt, "hash": dk.hex()}
134
-
135
-
136
- def _verify_password(password: str, salt: str, expected_hex: str) -> bool:
137
- dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), bytes.fromhex(salt), 120000)
138
- return secrets.compare_digest(dk.hex(), expected_hex)
139
-
140
-
141
- @app.post("/auth/signup")
142
- async def signup(email: str = Form(...), password: str = Form(...)):
143
- email = email.strip().lower()
144
- if not email or not password or "@" not in email:
145
- raise HTTPException(400, detail="Invalid email or password")
146
- users = rag.db["users"]
147
- if users.find_one({"email": email}):
148
- raise HTTPException(409, detail="Email already registered")
149
- user_id = str(uuid.uuid4())
150
- hp = _hash_password(password)
151
- users.insert_one({
152
- "email": email,
153
- "user_id": user_id,
154
- "pw_salt": hp["salt"],
155
- "pw_hash": hp["hash"],
156
- "created_at": int(time.time())
157
- })
158
- logger.info(f"[AUTH] Created user {email} -> {user_id}")
159
- return {"email": email, "user_id": user_id}
160
-
161
-
162
- @app.post("/auth/login")
163
- async def login(email: str = Form(...), password: str = Form(...)):
164
- email = email.strip().lower()
165
- users = rag.db["users"]
166
- doc = users.find_one({"email": email})
167
- if not doc:
168
- raise HTTPException(401, detail="Invalid credentials")
169
- if not _verify_password(password, doc.get("pw_salt", ""), doc.get("pw_hash", "")):
170
- raise HTTPException(401, detail="Invalid credentials")
171
- logger.info(f"[AUTH] Login {email}")
172
- return {"email": email, "user_id": doc.get("user_id")}
173
-
174
-
175
- # ────────────────────────────── Project Management ───────────────────────────
176
- @app.post("/projects/create", response_model=ProjectResponse)
177
- async def create_project(user_id: str = Form(...), name: str = Form(...), description: str = Form("")):
178
- """Create a new project for a user"""
179
- try:
180
- if not rag:
181
- raise HTTPException(500, detail="Database connection not available")
182
-
183
- if not name.strip():
184
- raise HTTPException(400, detail="Project name is required")
185
-
186
- if not user_id.strip():
187
- raise HTTPException(400, detail="User ID is required")
188
-
189
- project_id = str(uuid.uuid4())
190
- current_time = datetime.now(timezone.utc)
191
-
192
- project = {
193
- "project_id": project_id,
194
- "user_id": user_id,
195
- "name": name.strip(),
196
- "description": description.strip(),
197
- "created_at": current_time,
198
- "updated_at": current_time
199
- }
200
-
201
- logger.info(f"[PROJECT] Creating project {name} for user {user_id}")
202
-
203
- # Insert the project
204
- try:
205
- result = rag.db["projects"].insert_one(project)
206
- logger.info(f"[PROJECT] Created project {name} with ID {project_id}, MongoDB result: {result.inserted_id}")
207
- except PyMongoError as mongo_error:
208
- logger.error(f"[PROJECT] MongoDB error creating project: {str(mongo_error)}")
209
- raise HTTPException(500, detail=f"Database error: {str(mongo_error)}")
210
- except Exception as db_error:
211
- logger.error(f"[PROJECT] Database error creating project: {str(db_error)}")
212
- raise HTTPException(500, detail=f"Database error: {str(db_error)}")
213
-
214
- # Return a properly formatted response
215
- response = ProjectResponse(
216
- project_id=project_id,
217
- user_id=user_id,
218
- name=name.strip(),
219
- description=description.strip(),
220
- created_at=current_time.isoformat(),
221
- updated_at=current_time.isoformat()
222
- )
223
-
224
- logger.info(f"[PROJECT] Successfully created project {name} for user {user_id}")
225
- return response
226
-
227
- except HTTPException:
228
- # Re-raise HTTP exceptions
229
- raise
230
- except Exception as e:
231
- logger.error(f"[PROJECT] Error creating project: {str(e)}")
232
- logger.error(f"[PROJECT] Error type: {type(e)}")
233
- logger.error(f"[PROJECT] Error details: {e}")
234
- raise HTTPException(500, detail=f"Failed to create project: {str(e)}")
235
-
236
-
237
- @app.get("/projects", response_model=ProjectsListResponse)
238
- async def list_projects(user_id: str):
239
- """List all projects for a user"""
240
- projects_cursor = rag.db["projects"].find(
241
- {"user_id": user_id}
242
- ).sort("updated_at", -1)
243
-
244
- projects = []
245
- for project in projects_cursor:
246
- projects.append(ProjectResponse(
247
- project_id=project["project_id"],
248
- user_id=project["user_id"],
249
- name=project["name"],
250
- description=project.get("description", ""),
251
- created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
252
- updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
253
- ))
254
-
255
- return ProjectsListResponse(projects=projects)
256
-
257
-
258
- @app.get("/projects/{project_id}", response_model=ProjectResponse)
259
- async def get_project(project_id: str, user_id: str):
260
- """Get a specific project (with user ownership check)"""
261
- project = rag.db["projects"].find_one(
262
- {"project_id": project_id, "user_id": user_id}
263
- )
264
- if not project:
265
- raise HTTPException(404, detail="Project not found")
266
-
267
- return ProjectResponse(
268
- project_id=project["project_id"],
269
- user_id=project["user_id"],
270
- name=project["name"],
271
- description=project.get("description", ""),
272
- created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
273
- updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
274
- )
275
-
276
-
277
- @app.delete("/projects/{project_id}", response_model=MessageResponse)
278
- async def delete_project(project_id: str, user_id: str):
279
- """Delete a project and all its associated data"""
280
- # Check ownership
281
- project = rag.db["projects"].find_one({"project_id": project_id, "user_id": user_id})
282
- if not project:
283
- raise HTTPException(404, detail="Project not found")
284
-
285
- # Delete project and all associated data
286
- rag.db["projects"].delete_one({"project_id": project_id})
287
- rag.db["chunks"].delete_many({"project_id": project_id})
288
- rag.db["files"].delete_many({"project_id": project_id})
289
- rag.db["chat_sessions"].delete_many({"project_id": project_id})
290
-
291
- logger.info(f"[PROJECT] Deleted project {project_id} for user {user_id}")
292
- return MessageResponse(message="Project deleted successfully")
293
-
294
-
295
- # ────────────────────────────── Chat Sessions ──────────────────────────────
296
- @app.post("/chat/save", response_model=MessageResponse)
297
- async def save_chat_message(
298
- user_id: str = Form(...),
299
- project_id: str = Form(...),
300
- role: str = Form(...),
301
- content: str = Form(...),
302
- timestamp: Optional[float] = Form(None),
303
- sources: Optional[str] = Form(None)
304
- ):
305
- """Save a chat message to the session"""
306
- if role not in ["user", "assistant"]:
307
- raise HTTPException(400, detail="Invalid role")
308
-
309
- # Parse optional sources JSON
310
- parsed_sources: Optional[List[Dict[str, Any]]] = None
311
- if sources:
312
- try:
313
- parsed = json.loads(sources)
314
- if isinstance(parsed, list):
315
- parsed_sources = parsed
316
- except Exception:
317
- parsed_sources = None
318
-
319
- message = {
320
- "user_id": user_id,
321
- "project_id": project_id,
322
- "role": role,
323
- "content": content,
324
- "timestamp": timestamp or time.time(),
325
- "created_at": datetime.now(timezone.utc),
326
- **({"sources": parsed_sources} if parsed_sources is not None else {})
327
- }
328
-
329
- rag.db["chat_sessions"].insert_one(message)
330
- return MessageResponse(message="Chat message saved")
331
-
332
-
333
- @app.get("/chat/history", response_model=ChatHistoryResponse)
334
- async def get_chat_history(user_id: str, project_id: str, limit: int = 100):
335
- """Get chat history for a project"""
336
- messages_cursor = rag.db["chat_sessions"].find(
337
- {"user_id": user_id, "project_id": project_id}
338
- ).sort("timestamp", 1).limit(limit)
339
-
340
- messages = []
341
- for message in messages_cursor:
342
- messages.append(ChatMessageResponse(
343
- user_id=message["user_id"],
344
- project_id=message["project_id"],
345
- role=message["role"],
346
- content=message["content"],
347
- timestamp=message["timestamp"],
348
- created_at=message["created_at"].isoformat() if isinstance(message["created_at"], datetime) else str(message["created_at"]),
349
- sources=message.get("sources")
350
- ))
351
-
352
- return ChatHistoryResponse(messages=messages)
353
-
354
-
355
- @app.delete("/chat/history", response_model=MessageResponse)
356
- async def delete_chat_history(user_id: str, project_id: str):
357
- try:
358
- rag.db["chat_sessions"].delete_many({"user_id": user_id, "project_id": project_id})
359
- logger.info(f"[CHAT] Cleared history for user {user_id} project {project_id}")
360
- # Also clear in-memory LRU for this user to avoid stale context
361
- try:
362
- from memo.core import get_memory_system
363
- memory = get_memory_system()
364
- memory.clear(user_id)
365
- logger.info(f"[CHAT] Cleared memory for user {user_id}")
366
- except Exception as me:
367
- logger.warning(f"[CHAT] Failed to clear memory for user {user_id}: {me}")
368
- return MessageResponse(message="Chat history cleared")
369
- except Exception as e:
370
- raise HTTPException(500, detail=f"Failed to clear chat history: {str(e)}")
371
-
372
-
373
- # ────────────────────────────── Helpers ──────────────────────────────
374
- def _infer_mime(filename: str) -> str:
375
- lower = filename.lower()
376
- if lower.endswith(".pdf"):
377
- return "application/pdf"
378
- if lower.endswith(".docx"):
379
- return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
380
- return "application/octet-stream"
381
-
382
-
383
- def _extract_pages(filename: str, file_bytes: bytes) -> List[Dict[str, Any]]:
384
- mime = _infer_mime(filename)
385
- if mime == "application/pdf":
386
- return parse_pdf_bytes(file_bytes)
387
- elif mime == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
388
- return parse_docx_bytes(file_bytes)
389
- else:
390
- raise HTTPException(status_code=400, detail=f"Unsupported file type: {filename}")
391
-
392
-
393
- # ────────────────────────────── Routes ──────────────────────────────
394
- @app.get("/", response_class=HTMLResponse)
395
- def index():
396
- index_path = os.path.join("static", "index.html")
397
- if not os.path.exists(index_path):
398
- return HTMLResponse("<h1>StudyBuddy</h1><p>Static files not found.</p>")
399
- return FileResponse(index_path)
400
-
401
-
402
- @app.post("/upload", response_model=UploadResponse)
403
- async def upload_files(
404
- request: Request,
405
- background_tasks: BackgroundTasks,
406
- user_id: str = Form(...),
407
- project_id: str = Form(...),
408
- files: List[UploadFile] = File(...),
409
- replace_filenames: Optional[str] = Form(None), # JSON array of filenames to replace
410
- rename_map: Optional[str] = Form(None), # JSON object {original: newname}
411
- ):
412
- """
413
- Ingest many files: PDF/DOCX.
414
- Steps:
415
- 1) Extract text & images
416
- 2) Caption images (BLIP base, CPU ok)
417
- 3) Merge captions into page text
418
- 4) Chunk into semantic cards (topic_name, summary, content + metadata)
419
- 5) Embed with all-MiniLM-L6-v2
420
- 6) Store in MongoDB with per-user and per-project metadata
421
- 7) Create a file-level summary
422
- """
423
- job_id = str(uuid.uuid4())
424
-
425
- # Basic upload policy limits
426
- max_files = int(os.getenv("MAX_FILES_PER_UPLOAD", "15"))
427
- max_mb = int(os.getenv("MAX_FILE_MB", "50"))
428
- if len(files) > max_files:
429
- raise HTTPException(400, detail=f"Too many files. Max {max_files} allowed per upload.")
430
-
431
- # Parse replace/rename directives
432
- replace_set = set()
433
- try:
434
- if replace_filenames:
435
- replace_set = set(json.loads(replace_filenames))
436
- except Exception:
437
- pass
438
- rename_dict: Dict[str, str] = {}
439
- try:
440
- if rename_map:
441
- rename_dict = json.loads(rename_map)
442
- except Exception:
443
- pass
444
-
445
- preloaded_files = []
446
- for uf in files:
447
- raw = await uf.read()
448
- if len(raw) > max_mb * 1024 * 1024:
449
- raise HTTPException(400, detail=f"{uf.filename} exceeds {max_mb} MB limit")
450
- # Apply rename if present
451
- eff_name = rename_dict.get(uf.filename, uf.filename)
452
- preloaded_files.append((eff_name, raw))
453
-
454
- # Initialize job status
455
- app.state.jobs[job_id] = {
456
- "created_at": time.time(),
457
- "total": len(preloaded_files),
458
- "completed": 0,
459
- "status": "processing",
460
- "last_error": None,
461
- }
462
-
463
- # Single background task: process files sequentially with isolation
464
- async def _process_all():
465
- for idx, (fname, raw) in enumerate(preloaded_files, start=1):
466
- try:
467
- # If instructed to replace this filename, remove previous data first
468
- if fname in replace_set:
469
- try:
470
- rag.db["chunks"].delete_many({"user_id": user_id, "project_id": project_id, "filename": fname})
471
- rag.db["files"].delete_many({"user_id": user_id, "project_id": project_id, "filename": fname})
472
- logger.info(f"[{job_id}] Replaced prior data for {fname}")
473
- except Exception as de:
474
- logger.warning(f"[{job_id}] Replace delete failed for {fname}: {de}")
475
- logger.info(f"[{job_id}] ({idx}/{len(preloaded_files)}) Parsing {fname} ({len(raw)} bytes)")
476
-
477
- # Extract pages from file
478
- pages = _extract_pages(fname, raw)
479
-
480
- # Caption images per page (if any)
481
- num_imgs = sum(len(p.get("images", [])) for p in pages)
482
- captions = []
483
- if num_imgs > 0:
484
- for p in pages:
485
- caps = []
486
- for im in p.get("images", []):
487
- try:
488
- cap = captioner.caption_image(im)
489
- caps.append(cap)
490
- except Exception as e:
491
- logger.warning(f"[{job_id}] Caption error in {fname}: {e}")
492
- captions.append(caps)
493
- else:
494
- captions = [[] for _ in pages]
495
-
496
- # Merge captions into text
497
- for p, caps in zip(pages, captions):
498
- if caps:
499
- p["text"] = (p.get("text", "") + "\n\n" + "\n".join([f"[Image] {c}" for c in caps])).strip()
500
-
501
- # Build cards
502
- cards = await build_cards_from_pages(pages, filename=fname, user_id=user_id, project_id=project_id)
503
- logger.info(f"[{job_id}] Built {len(cards)} cards for {fname}")
504
-
505
- # Embed & store
506
- embeddings = embedder.embed([c["content"] for c in cards])
507
- for c, vec in zip(cards, embeddings):
508
- c["embedding"] = vec
509
-
510
- rag.store_cards(cards)
511
-
512
- # File-level summary (cheap extractive)
513
- full_text = "\n\n".join(p.get("text", "") for p in pages)
514
- file_summary = await cheap_summarize(full_text, max_sentences=6)
515
- rag.upsert_file_summary(user_id=user_id, project_id=project_id, filename=fname, summary=file_summary)
516
- logger.info(f"[{job_id}] Completed {fname}")
517
- # Update job progress
518
- job = app.state.jobs.get(job_id)
519
- if job:
520
- job["completed"] = idx
521
- job["status"] = "processing" if idx < job.get("total", 0) else "completed"
522
- except Exception as e:
523
- logger.error(f"[{job_id}] Failed processing {fname}: {e}")
524
- job = app.state.jobs.get(job_id)
525
- if job:
526
- job["last_error"] = str(e)
527
- job["completed"] = idx # count as completed attempt
528
- finally:
529
- # Yield control between files to keep loop responsive
530
- await asyncio.sleep(0)
531
-
532
- logger.info(f"[{job_id}] Ingestion complete for {len(preloaded_files)} files")
533
- # Finalize job status
534
- job = app.state.jobs.get(job_id)
535
- if job:
536
- job["status"] = "completed"
537
-
538
- background_tasks.add_task(_process_all)
539
- return UploadResponse(job_id=job_id, status="processing", total_files=len(preloaded_files))
540
-
541
-
542
- @app.get("/upload/status")
543
- async def upload_status(job_id: str):
544
- job = app.state.jobs.get(job_id)
545
- if not job:
546
- raise HTTPException(404, detail="Job not found")
547
- percent = 0
548
- if job.get("total"):
549
- percent = int(round((job.get("completed", 0) / job.get("total", 1)) * 100))
550
- return {
551
- "job_id": job_id,
552
- "status": job.get("status"),
553
- "completed": job.get("completed"),
554
- "total": job.get("total"),
555
- "percent": percent,
556
- "last_error": job.get("last_error"),
557
- "created_at": job.get("created_at"),
558
- }
559
-
560
-
561
- @app.get("/files")
562
- async def list_project_files(user_id: str, project_id: str):
563
- """Return stored filenames and summaries for a project."""
564
- files = rag.list_files(user_id=user_id, project_id=project_id)
565
- # Ensure filenames list
566
- filenames = [f.get("filename") for f in files if f.get("filename")]
567
- return {"files": files, "filenames": filenames}
568
-
569
-
570
- @app.delete("/files", response_model=MessageResponse)
571
- async def delete_file(user_id: str, project_id: str, filename: str):
572
- """Delete a file summary and associated chunks for a project."""
573
- try:
574
- rag.db["files"].delete_many({"user_id": user_id, "project_id": project_id, "filename": filename})
575
- rag.db["chunks"].delete_many({"user_id": user_id, "project_id": project_id, "filename": filename})
576
- logger.info(f"[FILES] Deleted file {filename} for user {user_id} project {project_id}")
577
- return MessageResponse(message="File deleted")
578
- except Exception as e:
579
- raise HTTPException(500, detail=f"Failed to delete file: {str(e)}")
580
-
581
-
582
- @app.get("/cards")
583
- def list_cards(user_id: str, project_id: str, filename: Optional[str] = None, limit: int = 50, skip: int = 0):
584
- """List cards for a project"""
585
- cards = rag.list_cards(user_id=user_id, project_id=project_id, filename=filename, limit=limit, skip=skip)
586
- # Ensure all cards are JSON serializable
587
- serializable_cards = []
588
- for card in cards:
589
- serializable_card = {}
590
- for key, value in card.items():
591
- if key == '_id':
592
- serializable_card[key] = str(value) # Convert ObjectId to string
593
- elif isinstance(value, datetime):
594
- serializable_card[key] = value.isoformat() # Convert datetime to ISO string
595
- else:
596
- serializable_card[key] = value
597
- serializable_cards.append(serializable_card)
598
- # Sort cards by topic_name
599
- return {"cards": serializable_cards}
600
-
601
-
602
- @app.get("/file-summary", response_model=FileSummaryResponse)
603
- def get_file_summary(user_id: str, project_id: str, filename: str):
604
- doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=filename)
605
- if not doc:
606
- raise HTTPException(404, detail="No summary found for that file.")
607
- return FileSummaryResponse(filename=filename, summary=doc.get("summary", ""))
608
-
609
-
610
- @app.post("/report", response_model=ReportResponse)
611
- async def generate_report(
612
- user_id: str = Form(...),
613
- project_id: str = Form(...),
614
- filename: str = Form(...),
615
- outline_words: int = Form(200),
616
- report_words: int = Form(1200),
617
- instructions: str = Form("")
618
- ):
619
- """
620
- Generate a Markdown report for a single document using a lightweight CoT:
621
- 1) Gemini Flash: create a structured outline based on file summary + top chunks
622
- 2) Gemini Pro: expand into a full report with citations
623
- """
624
- logger.info("[REPORT] User Q/report: %s", trim_text(instructions, 15).replace("\n", " "))
625
- # Validate file exists
626
- files_list = rag.list_files(user_id=user_id, project_id=project_id)
627
- filenames_ci = {f.get("filename", "").lower(): f.get("filename") for f in files_list}
628
- eff_name = filenames_ci.get(filename.lower(), filename)
629
- doc_sum = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=eff_name)
630
- if not doc_sum:
631
- raise HTTPException(404, detail="No summary found for that file.")
632
-
633
- # Retrieve top-k chunks for this file using enhanced search
634
- query_text = f"Comprehensive report for {eff_name}"
635
- if instructions.strip():
636
- query_text = f"{instructions} {eff_name}"
637
-
638
- q_vec = embedder.embed([query_text])[0]
639
- hits = rag.vector_search(user_id=user_id, project_id=project_id, query_vector=q_vec, k=8, filenames=[eff_name], search_type="flat")
640
- if not hits:
641
- # Fall back to summary-only report
642
- hits = []
643
-
644
- # Build context
645
- contexts = []
646
- sources_meta = []
647
- for h in hits:
648
- doc = h["doc"]
649
- chunk_id = str(doc.get("_id", ""))
650
- contexts.append(f"[CHUNK_ID: {chunk_id}] [{doc.get('topic_name','Topic')}] {trim_text(doc.get('content',''), 2000)}")
651
- sources_meta.append({
652
- "filename": doc.get("filename"),
653
- "topic_name": doc.get("topic_name"),
654
- "page_span": doc.get("page_span"),
655
- "score": float(h.get("score", 0.0)),
656
- "chunk_id": chunk_id
657
- })
658
- context_text = "\n\n---\n\n".join(contexts) if contexts else ""
659
- file_summary = doc_sum.get("summary", "")
660
-
661
- # Chain-of-thought style two-step with Gemini
662
- from utils.api.router import GEMINI_MED, GEMINI_PRO
663
-
664
- # Step 1: Content filtering and relevance assessment based on user instructions
665
- if instructions.strip():
666
- filter_sys = (
667
- "You are an expert content analyst. Given the user's specific instructions and the document content, "
668
- "identify which sections/chunks are MOST relevant to their request. "
669
- "Each chunk is prefixed with [CHUNK_ID: <id>] - use these exact IDs in your response. "
670
- "Return a JSON object with this structure: {\"relevant_chunks\": [\"<chunk_id_1>\", \"<chunk_id_2>\"], \"focus_areas\": [\"key topic 1\", \"key topic 2\"]}"
671
- )
672
- filter_user = f"USER_INSTRUCTIONS: {instructions}\n\nDOCUMENT_SUMMARY: {file_summary}\n\nAVAILABLE_CHUNKS:\n{context_text}\n\nIdentify only the chunks that directly address the user's specific request."
673
-
674
- try:
675
- selection_filter = {"provider": "gemini", "model": os.getenv("GEMINI_MED", "gemini-2.5-flash")}
676
- filter_response = await generate_answer_with_model(selection_filter, filter_sys, filter_user, gemini_rotator, nvidia_rotator)
677
- logger.info(f"[REPORT] Raw filter response: {filter_response}")
678
- # Try to parse the filter response to get relevant chunks
679
- import json
680
- try:
681
- filter_data = json.loads(filter_response)
682
- relevant_chunk_ids = filter_data.get("relevant_chunks", [])
683
- focus_areas = filter_data.get("focus_areas", [])
684
- logger.info(f"[REPORT] Content filtering identified {len(relevant_chunk_ids)} relevant chunks: {relevant_chunk_ids} and focus areas: {focus_areas}")
685
- # Filter context to only relevant chunks
686
- if relevant_chunk_ids and hits:
687
- filtered_hits = [h for h in hits if str(h["doc"].get("_id", "")) in relevant_chunk_ids]
688
- if filtered_hits:
689
- hits = filtered_hits
690
- logger.info(f"[REPORT] Filtered context from {len(hits)} chunks to {len(filtered_hits)} relevant chunks")
691
- else:
692
- logger.warning(f"[REPORT] No matching chunks found for IDs: {relevant_chunk_ids}")
693
- else:
694
- logger.warning(f"[REPORT] No relevant chunk IDs returned or no hits available")
695
- except json.JSONDecodeError as e:
696
- logger.warning(f"[REPORT] Could not parse filter response, using all chunks. JSON error: {e}. Response: {filter_response}")
697
- except Exception as e:
698
- logger.warning(f"[REPORT] Content filtering failed: {e}")
699
-
700
- # Step 2: Create focused outline based on user instructions
701
- sys_outline = (
702
- "You are an expert technical writer. Create a focused, hierarchical outline for a report based on the user's specific instructions and the MATERIALS. "
703
- "The outline should directly address what the user asked for. Output as Markdown bullet list only. Keep it within about {} words."
704
- ).format(max(100, outline_words))
705
-
706
- instruction_context = f"USER_REQUEST: {instructions}\n\n" if instructions.strip() else ""
707
- user_outline = f"{instruction_context}MATERIALS:\n\n[FILE_SUMMARY from {eff_name}]\n{file_summary}\n\n[DOC_CONTEXT]\n{context_text}"
708
-
709
- try:
710
- # Step 1: Outline with Flash/Med
711
- selection_outline = {"provider": "gemini", "model": os.getenv("GEMINI_MED", "gemini-2.5-flash")}
712
- outline_md = await generate_answer_with_model(selection_outline, sys_outline, user_outline, gemini_rotator, nvidia_rotator)
713
- except Exception as e:
714
- logger.warning(f"Report outline failed: {e}")
715
- outline_md = "# Report Outline\n\n- Introduction\n- Key Topics\n- Conclusion"
716
-
717
- # Step 3: Generate focused report based on user instructions and filtered content
718
- instruction_focus = f"FOCUS ON: {instructions}\n\n" if instructions.strip() else ""
719
- sys_report = (
720
- "You are an expert report writer. Write a focused, comprehensive Markdown report that directly addresses the user's specific request. "
721
- "Using the OUTLINE and MATERIALS:\n"
722
- "- Structure the report to answer exactly what the user asked for\n"
723
- "- Use clear section headings\n"
724
- "- Keep content factual and grounded in the provided materials\n"
725
- f"- Include brief citations like (source: {eff_name}, topic) - use the actual filename provided\n"
726
- "- If the user asked for a specific section/topic, focus heavily on that\n"
727
- f"- Target length ~{max(600, report_words)} words\n"
728
- "- Ensure the report directly fulfills the user's request"
729
- )
730
- user_report = f"{instruction_focus}OUTLINE:\n{outline_md}\n\nMATERIALS:\n[FILE_SUMMARY from {eff_name}]\n{file_summary}\n\n[DOC_CONTEXT]\n{context_text}"
731
-
732
- try:
733
- selection_report = {"provider": "gemini", "model": os.getenv("GEMINI_PRO", "gemini-2.5-pro")}
734
- report_md = await generate_answer_with_model(selection_report, sys_report, user_report, gemini_rotator, nvidia_rotator)
735
- except Exception as e:
736
- logger.error(f"Report generation failed: {e}")
737
- report_md = outline_md + "\n\n" + file_summary
738
-
739
- return ReportResponse(filename=eff_name, report_markdown=report_md, sources=sources_meta)
740
-
741
-
742
- @app.post("/report/pdf")
743
- async def generate_report_pdf(
744
- user_id: str = Form(...),
745
- project_id: str = Form(...),
746
- report_content: str = Form(...)
747
- ):
748
- """
749
- Generate a PDF from report content using the PDF utility module
750
- """
751
- from utils.service.pdf import generate_report_pdf as generate_pdf
752
- from fastapi.responses import Response
753
-
754
- try:
755
- pdf_content = await generate_pdf(report_content, user_id, project_id)
756
-
757
- # Return PDF as response
758
- return Response(
759
- content=pdf_content,
760
- media_type="application/pdf",
761
- headers={"Content-Disposition": f"attachment; filename=report-{datetime.now().strftime('%Y-%m-%d')}.pdf"}
762
- )
763
-
764
- except HTTPException:
765
- # Re-raise HTTP exceptions as-is
766
- raise
767
-
768
-
769
- # ────────────────────────────── Enhanced RAG Helper Functions ──────────────────────────────
770
-
771
- async def _generate_query_variations(question: str, nvidia_rotator) -> List[str]:
772
- """
773
- Generate multiple query variations using Chain of Thought reasoning
774
- """
775
- if not nvidia_rotator:
776
- return [question] # Fallback to original question
777
-
778
- try:
779
- # Use NVIDIA to generate query variations
780
- sys_prompt = """You are an expert at query expansion and reformulation. Given a user question, generate 3-5 different ways to ask the same question that would help retrieve relevant information from a document database.
781
-
782
- Focus on:
783
- 1. Different terminology and synonyms
784
- 2. More specific technical terms
785
- 3. Broader conceptual queries
786
- 4. Question reformulations
787
-
788
- Return only the variations, one per line, no numbering or extra text."""
789
-
790
- user_prompt = f"Original question: {question}\n\nGenerate query variations:"
791
-
792
- from utils.api.router import generate_answer_with_model
793
- selection = {"provider": "nvidia", "model": "meta/llama-3.1-8b-instruct"}
794
- response = await generate_answer_with_model(selection, sys_prompt, user_prompt, None, nvidia_rotator)
795
-
796
- # Parse variations
797
- variations = [line.strip() for line in response.split('\n') if line.strip()]
798
- variations = [v for v in variations if len(v) > 10] # Filter out too short variations
799
-
800
- # Always include original question
801
- if question not in variations:
802
- variations.insert(0, question)
803
-
804
- return variations[:5] # Limit to 5 variations
805
-
806
- except Exception as e:
807
- logger.warning(f"Query variation generation failed: {e}")
808
- return [question]
809
-
810
-
811
- def _deduplicate_and_rank_hits(all_hits: List[Dict], original_question: str) -> List[Dict]:
812
- """
813
- Deduplicate hits by chunk ID and rank by relevance to original question
814
- """
815
- if not all_hits:
816
- return []
817
-
818
- # Deduplicate by chunk ID
819
- seen_ids = set()
820
- unique_hits = []
821
-
822
- for hit in all_hits:
823
- chunk_id = str(hit.get("doc", {}).get("_id", ""))
824
- if chunk_id not in seen_ids:
825
- seen_ids.add(chunk_id)
826
- unique_hits.append(hit)
827
-
828
- # Simple ranking: boost scores for hits that contain question keywords
829
- question_words = set(original_question.lower().split())
830
-
831
- for hit in unique_hits:
832
- content = hit.get("doc", {}).get("content", "").lower()
833
- topic = hit.get("doc", {}).get("topic_name", "").lower()
834
-
835
- # Count keyword matches
836
- content_matches = sum(1 for word in question_words if word in content)
837
- topic_matches = sum(1 for word in question_words if word in topic)
838
-
839
- # Boost score based on keyword matches
840
- keyword_boost = 1.0 + (content_matches * 0.1) + (topic_matches * 0.2)
841
- hit["score"] = hit.get("score", 0.0) * keyword_boost
842
-
843
- # Sort by boosted score
844
- unique_hits.sort(key=lambda x: x.get("score", 0.0), reverse=True)
845
-
846
- return unique_hits
847
-
848
-
849
- @app.post("/chat", response_model=ChatAnswerResponse)
850
- async def chat(
851
- user_id: str = Form(...),
852
- project_id: str = Form(...),
853
- question: str = Form(...),
854
- k: int = Form(6)
855
- ):
856
- # Add timeout protection to prevent hanging
857
- import asyncio
858
- try:
859
- return await asyncio.wait_for(_chat_impl(user_id, project_id, question, k), timeout=120.0)
860
- except asyncio.TimeoutError:
861
- logger.error("[CHAT] Chat request timed out after 120 seconds")
862
- return ChatAnswerResponse(
863
- answer="Sorry, the request took too long to process. Please try again with a simpler question.",
864
- sources=[],
865
- relevant_files=[]
866
- )
867
-
868
- async def _chat_impl(
869
- user_id: str,
870
- project_id: str,
871
- question: str,
872
- k: int
873
- ):
874
- """
875
- RAG chat that answers ONLY from uploaded materials.
876
- - Preload all filenames + summaries; use NVIDIA to classify file relevance to question (true/false)
877
- - Restrict vector search to relevant files (fall back to all if none)
878
- - Bring in recent chat memory: last 3 via NVIDIA relevance; remaining 17 via semantic search
879
- - After answering, summarize (q,a) via NVIDIA and store into LRU (last 20)
880
- """
881
- import sys
882
- from memo.core import get_memory_system
883
- from utils.api.router import NVIDIA_SMALL # reuse default name
884
- memory = get_memory_system()
885
- logger.info("[CHAT] User Q/chat: %s", trim_text(question, 15).replace("\n", " "))
886
-
887
- # 0) Detect any filenames mentioned in the question (e.g., JADE.pdf)
888
- # Supports .pdf, .docx, and .doc for detection purposes
889
- # Only capture contiguous tokens ending with extension (no spaces) to avoid swallowing prompt text
890
- mentioned = set([m.group(0).strip() for m in re.finditer(r"\b[^\s/\\]+?\.(?:pdf|docx|doc)\b", question, re.IGNORECASE)])
891
- if mentioned:
892
- logger.info(f"[CHAT] Detected mentioned filenames in question: {list(mentioned)}")
893
-
894
- # 0a) If the question explicitly asks for a summary/about of a single mentioned file, return its summary directly
895
- if mentioned and (re.search(r"\b(summary|summarize|about|overview)\b", question, re.IGNORECASE)):
896
- # Prefer direct summary when exactly one file is referenced
897
- if len(mentioned) == 1:
898
- fn = next(iter(mentioned))
899
- doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=fn)
900
- if doc:
901
- return ChatAnswerResponse(
902
- answer=doc.get("summary", ""),
903
- sources=[{"filename": fn, "file_summary": True}]
904
- )
905
- # If not found with the same casing, try case-insensitive match against stored filenames
906
- files_ci = rag.list_files(user_id=user_id, project_id=project_id)
907
- match = next((f["filename"] for f in files_ci if f.get("filename", "").lower() == fn.lower()), None)
908
- if match:
909
- doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=match)
910
- if doc:
911
- return ChatAnswerResponse(
912
- answer=doc.get("summary", ""),
913
- sources=[{"filename": match, "file_summary": True}]
914
- )
915
- # If multiple files are referenced with summary intent, proceed to relevance flow below
916
-
917
- # 1) Preload file list + summaries
918
- files_list = rag.list_files(user_id=user_id, project_id=project_id) # [{filename, summary}]
919
-
920
- # 1a) Normalize mentioned filenames against the user's library (case-insensitive)
921
- filenames_ci_map = {f.get("filename", "").lower(): f.get("filename") for f in files_list if f.get("filename")}
922
- mentioned_normalized = []
923
- for mfn in mentioned:
924
- key = mfn.lower()
925
- if key in filenames_ci_map:
926
- mentioned_normalized.append(filenames_ci_map[key])
927
- if mentioned and not mentioned_normalized and files_list:
928
- # Try looser match: contained filenames ignoring spaces
929
- norm = {f.get("filename", "").lower().replace(" ", ""): f.get("filename") for f in files_list if f.get("filename")}
930
- for mfn in mentioned:
931
- key2 = mfn.lower().replace(" ", "")
932
- if key2 in norm:
933
- mentioned_normalized.append(norm[key2])
934
- if mentioned_normalized:
935
- logger.info(f"[CHAT] Normalized mentions to stored filenames: {mentioned_normalized}")
936
-
937
- # 1b) Ask NVIDIA to mark relevance per file
938
- try:
939
- from memo.history import get_history_manager
940
- history_manager = get_history_manager(memory)
941
- relevant_map = await history_manager.files_relevance(question, files_list, nvidia_rotator)
942
- relevant_files = [fn for fn, ok in relevant_map.items() if ok]
943
- logger.info(f"[CHAT] NVIDIA relevant files: {relevant_files}")
944
- except Exception as e:
945
- logger.warning(f"[CHAT] NVIDIA relevance failed, defaulting to all files: {e}")
946
- relevant_files = [f.get("filename") for f in files_list if f.get("filename")]
947
-
948
- # 1c) Ensure any explicitly mentioned files in the question are included
949
- # This safeguards against model misclassification
950
- if mentioned_normalized:
951
- extra = [fn for fn in mentioned_normalized if fn not in relevant_files]
952
- relevant_files.extend(extra)
953
- if extra:
954
- logger.info(f"[CHAT] Forced-include mentioned files into relevance: {extra}")
955
-
956
- # 2) Memory context: recent 3 via NVIDIA, remaining 17 via semantic
957
- # Use enhanced context retrieval if available, otherwise fallback to original method
958
- try:
959
- from memo.history import get_history_manager
960
- history_manager = get_history_manager(memory)
961
- recent_related, semantic_related = await history_manager.related_recent_and_semantic_context(
962
- user_id, question, embedder
963
- )
964
- except Exception as e:
965
- logger.warning(f"[CHAT] Enhanced context retrieval failed, using fallback: {e}")
966
- # Fallback to original method
967
- recent3 = memory.recent(user_id, 3)
968
- if recent3:
969
- sys = "Pick only items that directly relate to the new question. Output the selected items verbatim, no commentary. If none, output nothing."
970
- numbered = [{"id": i+1, "text": s} for i, s in enumerate(recent3)]
971
- user = f"Question: {question}\nCandidates:\n{json.dumps(numbered, ensure_ascii=False)}\nSelect any related items and output ONLY their 'text' values concatenated."
972
- try:
973
- from utils.api.rotator import robust_post_json
974
- key = nvidia_rotator.get_key()
975
- url = "https://integrate.api.nvidia.com/v1/chat/completions"
976
- payload = {
977
- "model": os.getenv("NVIDIA_SMALL", "meta/llama-3.1-8b-instruct"),
978
- "temperature": 0.0,
979
- "messages": [
980
- {"role": "system", "content": sys},
981
- {"role": "user", "content": user},
982
- ]
983
- }
984
- headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key or ''}"}
985
- data = await robust_post_json(url, headers, payload, nvidia_rotator)
986
- recent_related = data["choices"][0]["message"]["content"].strip()
987
- except Exception as e:
988
- logger.warning(f"Recent-related NVIDIA error: {e}")
989
- recent_related = ""
990
- else:
991
- recent_related = ""
992
-
993
- # Get semantic context from remaining memories
994
- rest17 = memory.rest(user_id, 3)
995
- if rest17:
996
- import numpy as np
997
- def _cosine(a: np.ndarray, b: np.ndarray) -> float:
998
- denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0
999
- return float(np.dot(a, b) / denom)
1000
-
1001
- qv = np.array(embedder.embed([question])[0], dtype="float32")
1002
- mats = embedder.embed([s.strip() for s in rest17])
1003
- sims = [(_cosine(qv, np.array(v, dtype="float32")), s) for v, s in zip(mats, rest17)]
1004
- sims.sort(key=lambda x: x[0], reverse=True)
1005
- top = [s for (sc, s) in sims[:3] if sc > 0.15]
1006
- semantic_related = "\n\n".join(top) if top else ""
1007
-
1008
- # 3) Enhanced query reasoning and RAG vector search
1009
- logger.info(f"[CHAT] Starting enhanced vector search with relevant_files={relevant_files}")
1010
-
1011
- # Chain of Thought query breakdown for better retrieval
1012
- enhanced_queries = await _generate_query_variations(question, nvidia_rotator)
1013
- logger.info(f"[CHAT] Generated {len(enhanced_queries)} query variations")
1014
-
1015
- # Try multiple search strategies
1016
- all_hits = []
1017
- search_strategies = ["flat", "hybrid", "local"] # Try most accurate first
1018
-
1019
- for strategy in search_strategies:
1020
- for query_variant in enhanced_queries:
1021
- q_vec = embedder.embed([query_variant])[0]
1022
- hits = rag.vector_search(
1023
- user_id=user_id,
1024
- project_id=project_id,
1025
- query_vector=q_vec,
1026
- k=k,
1027
- filenames=relevant_files if relevant_files else None,
1028
- search_type=strategy
1029
- )
1030
- if hits:
1031
- all_hits.extend(hits)
1032
- logger.info(f"[CHAT] {strategy} search with '{query_variant[:50]}...' returned {len(hits)} hits")
1033
- break # If we found hits with this strategy, move to next query
1034
- if all_hits:
1035
- break # If we found hits, don't try other strategies
1036
-
1037
- # Deduplicate and rank results
1038
- hits = _deduplicate_and_rank_hits(all_hits, question)
1039
- logger.info(f"[CHAT] Final vector search returned {len(hits) if hits else 0} hits")
1040
- if not hits:
1041
- logger.info(f"[CHAT] No hits with relevance filter. relevant_files={relevant_files}")
1042
- # Fallback 1: Try with original question and flat search
1043
- q_vec_original = embedder.embed([question])[0]
1044
- hits = rag.vector_search(
1045
- user_id=user_id,
1046
- project_id=project_id,
1047
- query_vector=q_vec_original,
1048
- k=k,
1049
- filenames=relevant_files if relevant_files else None,
1050
- search_type="flat"
1051
- )
1052
- logger.info(f"[CHAT] Fallback flat search → hits={len(hits) if hits else 0}")
1053
-
1054
- # Fallback 2: if we have explicit mentions, try restricting only to them
1055
- if not hits and mentioned_normalized:
1056
- hits = rag.vector_search(
1057
- user_id=user_id,
1058
- project_id=project_id,
1059
- query_vector=q_vec_original,
1060
- k=k,
1061
- filenames=mentioned_normalized,
1062
- search_type="flat"
1063
- )
1064
- logger.info(f"[CHAT] Fallback with mentioned files only → hits={len(hits) if hits else 0}")
1065
-
1066
- # Fallback 3: if still empty, try without any filename restriction
1067
- if not hits:
1068
- hits = rag.vector_search(
1069
- user_id=user_id,
1070
- project_id=project_id,
1071
- query_vector=q_vec_original,
1072
- k=k,
1073
- filenames=None,
1074
- search_type="flat"
1075
- )
1076
- logger.info(f"[CHAT] Fallback with all files → hits={len(hits) if hits else 0}")
1077
- # If still no hits, and we have mentioned files, try returning their summaries if present
1078
- if not hits and mentioned_normalized:
1079
- fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
1080
- summaries = [fsum_map.get(fn, "") for fn in mentioned_normalized]
1081
- summaries = [s for s in summaries if s]
1082
- if summaries:
1083
- answer = ("\n\n---\n\n").join(summaries)
1084
- return ChatAnswerResponse(
1085
- answer=answer,
1086
- sources=[{"filename": fn, "file_summary": True} for fn in mentioned_normalized],
1087
- relevant_files=mentioned_normalized
1088
- )
1089
- if not hits:
1090
- # Last resort: use summaries from relevant files if we didn't have explicit mentions normalized
1091
- candidates = mentioned_normalized or relevant_files or []
1092
- if candidates:
1093
- fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
1094
- summaries = [fsum_map.get(fn, "") for fn in candidates]
1095
- summaries = [s for s in summaries if s]
1096
- if summaries:
1097
- answer = ("\n\n---\n\n").join(summaries)
1098
- logger.info(f"[CHAT] Falling back to file-level summaries for: {candidates}")
1099
- return ChatAnswerResponse(
1100
- answer=answer,
1101
- sources=[{"filename": fn, "file_summary": True} for fn in candidates],
1102
- relevant_files=candidates
1103
- )
1104
- return ChatAnswerResponse(
1105
- answer="I don't know based on your uploaded materials. Try uploading more sources or rephrasing the question.",
1106
- sources=[],
1107
- relevant_files=relevant_files or mentioned_normalized
1108
- )
1109
- # If we get here, we have hits, so continue with normal flow
1110
- # Compose context
1111
- contexts = []
1112
- sources_meta = []
1113
- for h in hits:
1114
- doc = h["doc"]
1115
- score = h["score"]
1116
- contexts.append(f"[{doc.get('topic_name','Topic')}] {trim_text(doc.get('content',''), 2000)}")
1117
- sources_meta.append({
1118
- "filename": doc.get("filename"),
1119
- "topic_name": doc.get("topic_name"),
1120
- "page_span": doc.get("page_span"),
1121
- "score": float(score),
1122
- "chunk_id": str(doc.get("_id", "")) # Convert ObjectId to string
1123
- })
1124
- context_text = "\n\n---\n\n".join(contexts)
1125
-
1126
- # Add file-level summaries for relevant files
1127
- file_summary_block = ""
1128
- if relevant_files:
1129
- fsum_map = {f["filename"]: f.get("summary","") for f in files_list}
1130
- lines = [f"[{fn}] {fsum_map.get(fn, '')}" for fn in relevant_files]
1131
- file_summary_block = "\n".join(lines)
1132
-
1133
- # Guardrail instruction to avoid hallucination
1134
- system_prompt = (
1135
- "You are a careful study assistant. Answer strictly using the given CONTEXT.\n"
1136
- "If the answer isn't in the context, say 'I don't know based on the provided materials.'\n"
1137
- "Write concise, clear explanations with citations like (source: actual_filename, topic).\n"
1138
- "Use the exact filename as provided in the context, not placeholders.\n"
1139
- )
1140
-
1141
- # Add recent chat context and historical similarity context
1142
- history_block = ""
1143
- if recent_related or semantic_related:
1144
- history_block = "RECENT_CHAT_CONTEXT:\n" + (recent_related or "") + ("\n\nHISTORICAL_SIMILARITY_CONTEXT:\n" + semantic_related if semantic_related else "")
1145
- composed_context = ""
1146
- if history_block:
1147
- composed_context += history_block + "\n\n"
1148
- if file_summary_block:
1149
- composed_context += "FILE_SUMMARIES:\n" + file_summary_block + "\n\n"
1150
- composed_context += "DOC_CONTEXT:\n" + context_text
1151
-
1152
- # Compose user prompt
1153
- user_prompt = f"QUESTION:\n{question}\n\nCONTEXT:\n{composed_context}"
1154
- # Choose model (cost-aware)
1155
- selection = select_model(question=question, context=composed_context)
1156
- logger.info(f"Model selection: {selection}")
1157
- # Generate answer with model
1158
- logger.info(f"[CHAT] Generating answer with {selection['provider']} {selection['model']}")
1159
- try:
1160
- answer = await generate_answer_with_model(
1161
- selection=selection,
1162
- system_prompt=system_prompt,
1163
- user_prompt=user_prompt,
1164
- gemini_rotator=gemini_rotator,
1165
- nvidia_rotator=nvidia_rotator
1166
- )
1167
- logger.info(f"[CHAT] Answer generated successfully, length: {len(answer)}")
1168
- except Exception as e:
1169
- logger.error(f"LLM error: {e}")
1170
- answer = "I had trouble contacting the language model provider just now. Please try again."
1171
- # After answering: summarize QA and store in memory (LRU, last 20)
1172
- try:
1173
- from memo.history import get_history_manager
1174
- history_manager = get_history_manager(memory)
1175
- qa_sum = await history_manager.summarize_qa_with_nvidia(question, answer, nvidia_rotator)
1176
- memory.add(user_id, qa_sum)
1177
-
1178
- # Also store enhanced conversation memory if available
1179
- if memory.is_enhanced_available():
1180
- await memory.add_conversation_memory(
1181
- user_id=user_id,
1182
- question=question,
1183
- answer=answer,
1184
- project_id=project_id,
1185
- context={
1186
- "relevant_files": relevant_files,
1187
- "sources_count": len(sources_meta),
1188
- "timestamp": time.time()
1189
- }
1190
- )
1191
- except Exception as e:
1192
- logger.warning(f"QA summarize/store failed: {e}")
1193
- # Trim for logging
1194
- logger.info("LLM answer (trimmed): %s", trim_text(answer, 200).replace("\n", " "))
1195
- return ChatAnswerResponse(answer=answer, sources=sources_meta, relevant_files=relevant_files)
1196
-
1197
-
1198
- @app.get("/healthz", response_model=HealthResponse)
1199
- def health():
1200
- return HealthResponse(ok=True)
1201
-
1202
-
1203
- @app.get("/test-db")
1204
- async def test_database():
1205
- """Test database connection and basic operations"""
1206
- try:
1207
- if not rag:
1208
- return {
1209
- "status": "error",
1210
- "message": "RAG store not initialized",
1211
- "error_type": "RAGStoreNotInitialized"
1212
- }
1213
-
1214
- # Test basic connection
1215
- rag.client.admin.command('ping')
1216
-
1217
- # Test basic insert/query
1218
- test_collection = rag.db["test_collection"]
1219
- test_doc = {"test": True, "timestamp": datetime.now(timezone.utc)}
1220
- result = test_collection.insert_one(test_doc)
1221
-
1222
- # Test query
1223
- found = test_collection.find_one({"_id": result.inserted_id})
1224
-
1225
- # Clean up
1226
- test_collection.delete_one({"_id": result.inserted_id})
1227
-
1228
- return {
1229
- "status": "success",
1230
- "message": "Database connection and operations working correctly",
1231
- "test_id": str(result.inserted_id),
1232
- "found_doc": str(found["_id"]) if found else None
1233
- }
1234
-
1235
- except Exception as e:
1236
- logger.error(f"[TEST-DB] Database test failed: {str(e)}")
1237
- return {
1238
- "status": "error",
1239
- "message": f"Database test failed: {str(e)}",
1240
- "error_type": str(type(e))
1241
- }
1242
-
1243
-
1244
- @app.get("/rag-status")
1245
- async def rag_status():
1246
- """Check the status of the RAG store"""
1247
- if not rag:
1248
- return {
1249
- "status": "error",
1250
- "message": "RAG store not initialized",
1251
- "rag_available": False
1252
- }
1253
-
1254
- try:
1255
- # Test connection
1256
- rag.client.admin.command('ping')
1257
- return {
1258
- "status": "success",
1259
- "message": "RAG store is available and connected",
1260
- "rag_available": True,
1261
- "database": rag.db.name,
1262
- "collections": {
1263
- "chunks": rag.chunks.name,
1264
- "files": rag.files.name
1265
- }
1266
- }
1267
- except Exception as e:
1268
- return {
1269
- "status": "error",
1270
- "message": f"RAG store connection failed: {str(e)}",
1271
- "rag_available": False,
1272
- "error": str(e)
1273
- }
1274
 
1275
  # Local dev
1276
  # if __name__ == "__main__":
1277
  # import uvicorn
1278
- # uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  # https://binkhoale1812-edsummariser.hf.space/
2
 
3
+ # Minimal orchestrator that exposes the FastAPI app and registers routes
4
+ from helpers import app # FastAPI instance
 
5
 
6
+ # Import route modules for side-effect registration
7
+ import routes.auth as _routes_auth
8
+ import routes.projects as _routes_projects
9
+ import routes.files as _routes_files
10
+ import routes.reports as _routes_report
11
+ import routes.chats as _routes_chat
12
+ import routes.health as _routes_health

 # Local dev
 # if __name__ == "__main__":
 #     import uvicorn
+#     uvicorn.run(app, host="0.0.0.0", port=8000)

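Review note: the new app.py registers every endpoint purely as a side effect of the `routes.*` imports above. A minimal sketch of the pattern, assuming each route module pulls the shared `app` from `helpers` (the actual body of routes/health.py is not shown in this hunk):

# Hypothetical shape of a route module (e.g., routes/health.py); importing it
# from app.py is enough to bind the endpoint onto the shared FastAPI instance.
from helpers import app
from helpers.models import HealthResponse

@app.get("/healthz", response_model=HealthResponse)
def health() -> HealthResponse:
    return HealthResponse(ok=True)

The `_routes_*` aliases exist only to mark the imports as intentional for linters; note that import order in app.py is also the route registration order.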
copy.py ADDED
@@ -0,0 +1,1278 @@
+# https://binkhoale1812-edsummariser.hf.space/
+import os, io, re, uuid, json, time, logging
+from typing import List, Dict, Any, Optional
+from datetime import datetime, timezone
+from pydantic import BaseModel
+import asyncio
+
+# Load environment variables from .env file
+from dotenv import load_dotenv
+load_dotenv()
+
+from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException, BackgroundTasks
+from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
+
+# MongoDB imports
+from pymongo.errors import PyMongoError, ConnectionFailure, ServerSelectionTimeoutError
+
+from utils.api.rotator import APIKeyRotator
+from utils.ingestion.parser import parse_pdf_bytes, parse_docx_bytes
+from utils.ingestion.caption import BlipCaptioner
+from utils.ingestion.chunker import build_cards_from_pages
+from utils.rag.embeddings import EmbeddingClient
+from utils.rag.rag import RAGStore, ensure_indexes
+from utils.api.router import select_model, generate_answer_with_model
+from utils.service.summarizer import cheap_summarize
+from utils.service.common import trim_text
+from utils.logger import get_logger
+import re
+
+# ────────────────────────────── Response Models ──────────────────────────────
+class ProjectResponse(BaseModel):
+    project_id: str
+    user_id: str
+    name: str
+    description: str
+    created_at: str
+    updated_at: str
+
+class ProjectsListResponse(BaseModel):
+    projects: List[ProjectResponse]
+
+class ChatMessageResponse(BaseModel):
+    user_id: str
+    project_id: str
+    role: str
+    content: str
+    timestamp: float
+    created_at: str
+    sources: Optional[List[Dict[str, Any]]] = None
+
+class ChatHistoryResponse(BaseModel):
+    messages: List[ChatMessageResponse]
+
+class MessageResponse(BaseModel):
+    message: str
+
+class UploadResponse(BaseModel):
+    job_id: str
+    status: str
+    total_files: Optional[int] = None
+
+class FileSummaryResponse(BaseModel):
+    filename: str
+    summary: str
+
+class ChatAnswerResponse(BaseModel):
+    answer: str
+    sources: List[Dict[str, Any]]
+    relevant_files: Optional[List[str]] = None
+
+class HealthResponse(BaseModel):
+    ok: bool
+
+class ReportResponse(BaseModel):
+    filename: str
+    report_markdown: str
+    sources: List[Dict[str, Any]]
+
+# ────────────────────────────── App Setup ──────────────────────────────
+logger = get_logger("APP", name="studybuddy")
+
+app = FastAPI(title="StudyBuddy RAG", version="0.1.0")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Serve static files (index.html, scripts.js, styles.css)
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# In-memory job tracker (for progress queries)
+app.state.jobs = {}
+
+
+# ────────────────────────────── Global Clients ──────────────────────────────
+# API rotators (round robin + auto failover on quota errors)
+gemini_rotator = APIKeyRotator(prefix="GEMINI_API_", max_slots=5)
+nvidia_rotator = APIKeyRotator(prefix="NVIDIA_API_", max_slots=5)
+
+# Captioner + Embeddings (lazy init inside classes)
+captioner = BlipCaptioner()
+embedder = EmbeddingClient(model_name=os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))
+
+# Mongo / RAG store
+try:
+    rag = RAGStore(mongo_uri=os.getenv("MONGO_URI"), db_name=os.getenv("MONGO_DB", "studybuddy"))
+    # Test the connection
+    rag.client.admin.command('ping')
+    logger.info("[APP] MongoDB connection successful")
+    ensure_indexes(rag)
+    logger.info("[APP] MongoDB indexes ensured")
+except Exception as e:
+    logger.error(f"[APP] Failed to initialize MongoDB/RAG store: {str(e)}")
+    logger.error(f"[APP] MONGO_URI: {os.getenv('MONGO_URI', 'Not set')}")
+    logger.error(f"[APP] MONGO_DB: {os.getenv('MONGO_DB', 'studybuddy')}")
+    # Create a dummy RAG store for now - this will cause errors but prevents the app from crashing
+    rag = None
+
+
+# ────────────────────────────── Auth Helpers/Routes ───────────────────────────
+import hashlib
+import secrets
+
+
+def _hash_password(password: str, salt: Optional[str] = None) -> Dict[str, str]:
+    salt = salt or secrets.token_hex(16)
+    dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), bytes.fromhex(salt), 120000)
+    return {"salt": salt, "hash": dk.hex()}
+
+
+def _verify_password(password: str, salt: str, expected_hex: str) -> bool:
+    dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), bytes.fromhex(salt), 120000)
+    return secrets.compare_digest(dk.hex(), expected_hex)
+
+
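Review note: the two helpers above are a standard salted PBKDF2-HMAC-SHA256 scheme (16-byte random hex salt, 120,000 iterations, constant-time digest comparison). A quick round-trip sanity check, illustrative only:

# Illustrative only: exercises _hash_password/_verify_password defined above.
rec = _hash_password("hunter2")          # fresh random salt each call
assert _verify_password("hunter2", rec["salt"], rec["hash"])
assert not _verify_password("wrong", rec["salt"], rec["hash"])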
+@app.post("/auth/signup")
+async def signup(email: str = Form(...), password: str = Form(...)):
+    email = email.strip().lower()
+    if not email or not password or "@" not in email:
+        raise HTTPException(400, detail="Invalid email or password")
+    users = rag.db["users"]
+    if users.find_one({"email": email}):
+        raise HTTPException(409, detail="Email already registered")
+    user_id = str(uuid.uuid4())
+    hp = _hash_password(password)
+    users.insert_one({
+        "email": email,
+        "user_id": user_id,
+        "pw_salt": hp["salt"],
+        "pw_hash": hp["hash"],
+        "created_at": int(time.time())
+    })
+    logger.info(f"[AUTH] Created user {email} -> {user_id}")
+    return {"email": email, "user_id": user_id}
+
+
+@app.post("/auth/login")
+async def login(email: str = Form(...), password: str = Form(...)):
+    email = email.strip().lower()
+    users = rag.db["users"]
+    doc = users.find_one({"email": email})
+    if not doc:
+        raise HTTPException(401, detail="Invalid credentials")
+    if not _verify_password(password, doc.get("pw_salt", ""), doc.get("pw_hash", "")):
+        raise HTTPException(401, detail="Invalid credentials")
+    logger.info(f"[AUTH] Login {email}")
+    return {"email": email, "user_id": doc.get("user_id")}
+
+
+# ────────────────────────────── Project Management ───────────────────────────
+@app.post("/projects/create", response_model=ProjectResponse)
+async def create_project(user_id: str = Form(...), name: str = Form(...), description: str = Form("")):
+    """Create a new project for a user"""
+    try:
+        if not rag:
+            raise HTTPException(500, detail="Database connection not available")
+
+        if not name.strip():
+            raise HTTPException(400, detail="Project name is required")
+
+        if not user_id.strip():
+            raise HTTPException(400, detail="User ID is required")
+
+        project_id = str(uuid.uuid4())
+        current_time = datetime.now(timezone.utc)
+
+        project = {
+            "project_id": project_id,
+            "user_id": user_id,
+            "name": name.strip(),
+            "description": description.strip(),
+            "created_at": current_time,
+            "updated_at": current_time
+        }
+
+        logger.info(f"[PROJECT] Creating project {name} for user {user_id}")
+
+        # Insert the project
+        try:
+            result = rag.db["projects"].insert_one(project)
+            logger.info(f"[PROJECT] Created project {name} with ID {project_id}, MongoDB result: {result.inserted_id}")
+        except PyMongoError as mongo_error:
+            logger.error(f"[PROJECT] MongoDB error creating project: {str(mongo_error)}")
+            raise HTTPException(500, detail=f"Database error: {str(mongo_error)}")
+        except Exception as db_error:
+            logger.error(f"[PROJECT] Database error creating project: {str(db_error)}")
+            raise HTTPException(500, detail=f"Database error: {str(db_error)}")
+
+        # Return a properly formatted response
+        response = ProjectResponse(
+            project_id=project_id,
+            user_id=user_id,
+            name=name.strip(),
+            description=description.strip(),
+            created_at=current_time.isoformat(),
+            updated_at=current_time.isoformat()
+        )
+
+        logger.info(f"[PROJECT] Successfully created project {name} for user {user_id}")
+        return response
+
+    except HTTPException:
+        # Re-raise HTTP exceptions
+        raise
+    except Exception as e:
+        logger.error(f"[PROJECT] Error creating project: {str(e)}")
+        logger.error(f"[PROJECT] Error type: {type(e)}")
+        logger.error(f"[PROJECT] Error details: {e}")
+        raise HTTPException(500, detail=f"Failed to create project: {str(e)}")
+
+
+@app.get("/projects", response_model=ProjectsListResponse)
+async def list_projects(user_id: str):
+    """List all projects for a user"""
+    projects_cursor = rag.db["projects"].find(
+        {"user_id": user_id}
+    ).sort("updated_at", -1)
+
+    projects = []
+    for project in projects_cursor:
+        projects.append(ProjectResponse(
+            project_id=project["project_id"],
+            user_id=project["user_id"],
+            name=project["name"],
+            description=project.get("description", ""),
+            created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
+            updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
+        ))
+
+    return ProjectsListResponse(projects=projects)
+
+
+@app.get("/projects/{project_id}", response_model=ProjectResponse)
+async def get_project(project_id: str, user_id: str):
+    """Get a specific project (with user ownership check)"""
+    project = rag.db["projects"].find_one(
+        {"project_id": project_id, "user_id": user_id}
+    )
+    if not project:
+        raise HTTPException(404, detail="Project not found")
+
+    return ProjectResponse(
+        project_id=project["project_id"],
+        user_id=project["user_id"],
+        name=project["name"],
+        description=project.get("description", ""),
+        created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
+        updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
+    )
+
+
+@app.delete("/projects/{project_id}", response_model=MessageResponse)
+async def delete_project(project_id: str, user_id: str):
+    """Delete a project and all its associated data"""
+    # Check ownership
+    project = rag.db["projects"].find_one({"project_id": project_id, "user_id": user_id})
+    if not project:
+        raise HTTPException(404, detail="Project not found")
+
+    # Delete project and all associated data
+    rag.db["projects"].delete_one({"project_id": project_id})
+    rag.db["chunks"].delete_many({"project_id": project_id})
+    rag.db["files"].delete_many({"project_id": project_id})
+    rag.db["chat_sessions"].delete_many({"project_id": project_id})
+
+    logger.info(f"[PROJECT] Deleted project {project_id} for user {user_id}")
+    return MessageResponse(message="Project deleted successfully")
+
+
+# ────────────────────────────── Chat Sessions ──────────────────────────────
+@app.post("/chat/save", response_model=MessageResponse)
+async def save_chat_message(
+    user_id: str = Form(...),
+    project_id: str = Form(...),
+    role: str = Form(...),
+    content: str = Form(...),
+    timestamp: Optional[float] = Form(None),
+    sources: Optional[str] = Form(None)
+):
+    """Save a chat message to the session"""
+    if role not in ["user", "assistant"]:
+        raise HTTPException(400, detail="Invalid role")
+
+    # Parse optional sources JSON
+    parsed_sources: Optional[List[Dict[str, Any]]] = None
+    if sources:
+        try:
+            parsed = json.loads(sources)
+            if isinstance(parsed, list):
+                parsed_sources = parsed
+        except Exception:
+            parsed_sources = None
+
+    message = {
+        "user_id": user_id,
+        "project_id": project_id,
+        "role": role,
+        "content": content,
+        "timestamp": timestamp or time.time(),
+        "created_at": datetime.now(timezone.utc),
+        **({"sources": parsed_sources} if parsed_sources is not None else {})
+    }
+
+    rag.db["chat_sessions"].insert_one(message)
+    return MessageResponse(message="Chat message saved")
+
+
+@app.get("/chat/history", response_model=ChatHistoryResponse)
+async def get_chat_history(user_id: str, project_id: str, limit: int = 100):
+    """Get chat history for a project"""
+    messages_cursor = rag.db["chat_sessions"].find(
+        {"user_id": user_id, "project_id": project_id}
+    ).sort("timestamp", 1).limit(limit)
+
+    messages = []
+    for message in messages_cursor:
+        messages.append(ChatMessageResponse(
+            user_id=message["user_id"],
+            project_id=message["project_id"],
+            role=message["role"],
+            content=message["content"],
+            timestamp=message["timestamp"],
+            created_at=message["created_at"].isoformat() if isinstance(message["created_at"], datetime) else str(message["created_at"]),
+            sources=message.get("sources")
+        ))
+
+    return ChatHistoryResponse(messages=messages)
+
+
+@app.delete("/chat/history", response_model=MessageResponse)
+async def delete_chat_history(user_id: str, project_id: str):
+    try:
+        rag.db["chat_sessions"].delete_many({"user_id": user_id, "project_id": project_id})
+        logger.info(f"[CHAT] Cleared history for user {user_id} project {project_id}")
+        # Also clear in-memory LRU for this user to avoid stale context
+        try:
+            from memo.core import get_memory_system
+            memory = get_memory_system()
+            memory.clear(user_id)
+            logger.info(f"[CHAT] Cleared memory for user {user_id}")
+        except Exception as me:
+            logger.warning(f"[CHAT] Failed to clear memory for user {user_id}: {me}")
+        return MessageResponse(message="Chat history cleared")
+    except Exception as e:
+        raise HTTPException(500, detail=f"Failed to clear chat history: {str(e)}")
+
+
+# ────────────────────────────── Helpers ──────────────────────────────
+def _infer_mime(filename: str) -> str:
+    lower = filename.lower()
+    if lower.endswith(".pdf"):
+        return "application/pdf"
+    if lower.endswith(".docx"):
+        return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+    return "application/octet-stream"
+
+
+def _extract_pages(filename: str, file_bytes: bytes) -> List[Dict[str, Any]]:
+    mime = _infer_mime(filename)
+    if mime == "application/pdf":
+        return parse_pdf_bytes(file_bytes)
+    elif mime == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
+        return parse_docx_bytes(file_bytes)
+    else:
+        raise HTTPException(status_code=400, detail=f"Unsupported file type: {filename}")
+
+
+# ────────────────────────────── Routes ──────────────────────────────
+@app.get("/", response_class=HTMLResponse)
+def index():
+    index_path = os.path.join("static", "index.html")
+    if not os.path.exists(index_path):
+        return HTMLResponse("<h1>StudyBuddy</h1><p>Static files not found.</p>")
+    return FileResponse(index_path)
+
+
+@app.post("/upload", response_model=UploadResponse)
+async def upload_files(
+    request: Request,
+    background_tasks: BackgroundTasks,
+    user_id: str = Form(...),
+    project_id: str = Form(...),
+    files: List[UploadFile] = File(...),
+    replace_filenames: Optional[str] = Form(None),  # JSON array of filenames to replace
+    rename_map: Optional[str] = Form(None),  # JSON object {original: newname}
+):
+    """
+    Ingest many files: PDF/DOCX.
+    Steps:
+      1) Extract text & images
+      2) Caption images (BLIP base, CPU ok)
+      3) Merge captions into page text
+      4) Chunk into semantic cards (topic_name, summary, content + metadata)
+      5) Embed with all-MiniLM-L6-v2
+      6) Store in MongoDB with per-user and per-project metadata
+      7) Create a file-level summary
+    """
+    job_id = str(uuid.uuid4())
+
+    # Basic upload policy limits
+    max_files = int(os.getenv("MAX_FILES_PER_UPLOAD", "15"))
+    max_mb = int(os.getenv("MAX_FILE_MB", "50"))
+    if len(files) > max_files:
+        raise HTTPException(400, detail=f"Too many files. Max {max_files} allowed per upload.")
+
+    # Parse replace/rename directives
+    replace_set = set()
+    try:
+        if replace_filenames:
+            replace_set = set(json.loads(replace_filenames))
+    except Exception:
+        pass
+    rename_dict: Dict[str, str] = {}
+    try:
+        if rename_map:
+            rename_dict = json.loads(rename_map)
+    except Exception:
+        pass
+
+    preloaded_files = []
+    for uf in files:
+        raw = await uf.read()
+        if len(raw) > max_mb * 1024 * 1024:
+            raise HTTPException(400, detail=f"{uf.filename} exceeds {max_mb} MB limit")
+        # Apply rename if present
+        eff_name = rename_dict.get(uf.filename, uf.filename)
+        preloaded_files.append((eff_name, raw))
+
+    # Initialize job status
+    app.state.jobs[job_id] = {
+        "created_at": time.time(),
+        "total": len(preloaded_files),
+        "completed": 0,
+        "status": "processing",
+        "last_error": None,
+    }
+
+    # Single background task: process files sequentially with isolation
+    async def _process_all():
+        for idx, (fname, raw) in enumerate(preloaded_files, start=1):
+            try:
+                # If instructed to replace this filename, remove previous data first
+                if fname in replace_set:
+                    try:
+                        rag.db["chunks"].delete_many({"user_id": user_id, "project_id": project_id, "filename": fname})
+                        rag.db["files"].delete_many({"user_id": user_id, "project_id": project_id, "filename": fname})
+                        logger.info(f"[{job_id}] Replaced prior data for {fname}")
+                    except Exception as de:
+                        logger.warning(f"[{job_id}] Replace delete failed for {fname}: {de}")
+                logger.info(f"[{job_id}] ({idx}/{len(preloaded_files)}) Parsing {fname} ({len(raw)} bytes)")
+
+                # Extract pages from file
+                pages = _extract_pages(fname, raw)
+
+                # Caption images per page (if any)
+                num_imgs = sum(len(p.get("images", [])) for p in pages)
+                captions = []
+                if num_imgs > 0:
+                    for p in pages:
+                        caps = []
+                        for im in p.get("images", []):
+                            try:
+                                cap = captioner.caption_image(im)
+                                caps.append(cap)
+                            except Exception as e:
+                                logger.warning(f"[{job_id}] Caption error in {fname}: {e}")
+                        captions.append(caps)
+                else:
+                    captions = [[] for _ in pages]
+
+                # Merge captions into text
+                for p, caps in zip(pages, captions):
+                    if caps:
+                        p["text"] = (p.get("text", "") + "\n\n" + "\n".join([f"[Image] {c}" for c in caps])).strip()
+
+                # Build cards
+                cards = await build_cards_from_pages(pages, filename=fname, user_id=user_id, project_id=project_id)
+                logger.info(f"[{job_id}] Built {len(cards)} cards for {fname}")
+
+                # Embed & store
+                embeddings = embedder.embed([c["content"] for c in cards])
+                for c, vec in zip(cards, embeddings):
+                    c["embedding"] = vec
+
+                rag.store_cards(cards)
+
+                # File-level summary (cheap extractive)
+                full_text = "\n\n".join(p.get("text", "") for p in pages)
+                file_summary = await cheap_summarize(full_text, max_sentences=6)
+                rag.upsert_file_summary(user_id=user_id, project_id=project_id, filename=fname, summary=file_summary)
+                logger.info(f"[{job_id}] Completed {fname}")
+                # Update job progress
+                job = app.state.jobs.get(job_id)
+                if job:
+                    job["completed"] = idx
+                    job["status"] = "processing" if idx < job.get("total", 0) else "completed"
+            except Exception as e:
+                logger.error(f"[{job_id}] Failed processing {fname}: {e}")
+                job = app.state.jobs.get(job_id)
+                if job:
+                    job["last_error"] = str(e)
+                    job["completed"] = idx  # count as completed attempt
+            finally:
+                # Yield control between files to keep loop responsive
+                await asyncio.sleep(0)
+
+        logger.info(f"[{job_id}] Ingestion complete for {len(preloaded_files)} files")
+        # Finalize job status
+        job = app.state.jobs.get(job_id)
+        if job:
+            job["status"] = "completed"
+
+    background_tasks.add_task(_process_all)
+    return UploadResponse(job_id=job_id, status="processing", total_files=len(preloaded_files))
+
+
+@app.get("/upload/status")
+async def upload_status(job_id: str):
+    job = app.state.jobs.get(job_id)
+    if not job:
+        raise HTTPException(404, detail="Job not found")
+    percent = 0
+    if job.get("total"):
+        percent = int(round((job.get("completed", 0) / job.get("total", 1)) * 100))
+    return {
+        "job_id": job_id,
+        "status": job.get("status"),
+        "completed": job.get("completed"),
+        "total": job.get("total"),
+        "percent": percent,
+        "last_error": job.get("last_error"),
+        "created_at": job.get("created_at"),
+    }
+
+
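Review note: `/upload` returns immediately with a `job_id` while ingestion runs as a background task, so callers are expected to poll `/upload/status`. A hypothetical client-side polling loop (httpx is an assumption here; any HTTP client works):

# Illustrative only: poll /upload/status until the background job finishes.
# Per-file failures are surfaced in the same payload via "last_error".
import time
import httpx

def wait_for_job(base_url: str, job_id: str, interval_s: float = 2.0) -> dict:
    while True:
        status = httpx.get(f"{base_url}/upload/status", params={"job_id": job_id}).json()
        if status["status"] == "completed":
            return status
        time.sleep(interval_s)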
+@app.get("/files")
+async def list_project_files(user_id: str, project_id: str):
+    """Return stored filenames and summaries for a project."""
+    files = rag.list_files(user_id=user_id, project_id=project_id)
+    # Ensure filenames list
+    filenames = [f.get("filename") for f in files if f.get("filename")]
+    return {"files": files, "filenames": filenames}
+
+
+@app.delete("/files", response_model=MessageResponse)
+async def delete_file(user_id: str, project_id: str, filename: str):
+    """Delete a file summary and associated chunks for a project."""
+    try:
+        rag.db["files"].delete_many({"user_id": user_id, "project_id": project_id, "filename": filename})
+        rag.db["chunks"].delete_many({"user_id": user_id, "project_id": project_id, "filename": filename})
+        logger.info(f"[FILES] Deleted file {filename} for user {user_id} project {project_id}")
+        return MessageResponse(message="File deleted")
+    except Exception as e:
+        raise HTTPException(500, detail=f"Failed to delete file: {str(e)}")
+
+
+@app.get("/cards")
+def list_cards(user_id: str, project_id: str, filename: Optional[str] = None, limit: int = 50, skip: int = 0):
+    """List cards for a project"""
+    cards = rag.list_cards(user_id=user_id, project_id=project_id, filename=filename, limit=limit, skip=skip)
+    # Ensure all cards are JSON serializable
+    serializable_cards = []
+    for card in cards:
+        serializable_card = {}
+        for key, value in card.items():
+            if key == '_id':
+                serializable_card[key] = str(value)  # Convert ObjectId to string
+            elif isinstance(value, datetime):
+                serializable_card[key] = value.isoformat()  # Convert datetime to ISO string
+            else:
+                serializable_card[key] = value
+        serializable_cards.append(serializable_card)
+    # Sort cards by topic_name
+    return {"cards": serializable_cards}
+
+
+@app.get("/file-summary", response_model=FileSummaryResponse)
+def get_file_summary(user_id: str, project_id: str, filename: str):
+    doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=filename)
+    if not doc:
+        raise HTTPException(404, detail="No summary found for that file.")
+    return FileSummaryResponse(filename=filename, summary=doc.get("summary", ""))
+
+
+@app.post("/report", response_model=ReportResponse)
+async def generate_report(
+    user_id: str = Form(...),
+    project_id: str = Form(...),
+    filename: str = Form(...),
+    outline_words: int = Form(200),
+    report_words: int = Form(1200),
+    instructions: str = Form("")
+):
+    """
+    Generate a Markdown report for a single document using a lightweight CoT:
+      1) Gemini Flash: create a structured outline based on file summary + top chunks
+      2) Gemini Pro: expand into a full report with citations
+    """
+    logger.info("[REPORT] User Q/report: %s", trim_text(instructions, 15).replace("\n", " "))
+    # Validate file exists
+    files_list = rag.list_files(user_id=user_id, project_id=project_id)
+    filenames_ci = {f.get("filename", "").lower(): f.get("filename") for f in files_list}
+    eff_name = filenames_ci.get(filename.lower(), filename)
+    doc_sum = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=eff_name)
+    if not doc_sum:
+        raise HTTPException(404, detail="No summary found for that file.")
+
+    # Retrieve top-k chunks for this file using enhanced search
+    query_text = f"Comprehensive report for {eff_name}"
+    if instructions.strip():
+        query_text = f"{instructions} {eff_name}"
+
+    q_vec = embedder.embed([query_text])[0]
+    hits = rag.vector_search(user_id=user_id, project_id=project_id, query_vector=q_vec, k=8, filenames=[eff_name], search_type="flat")
+    if not hits:
+        # Fall back to summary-only report
+        hits = []
+
+    # Build context
+    contexts = []
+    sources_meta = []
+    for h in hits:
+        doc = h["doc"]
+        chunk_id = str(doc.get("_id", ""))
+        contexts.append(f"[CHUNK_ID: {chunk_id}] [{doc.get('topic_name','Topic')}] {trim_text(doc.get('content',''), 2000)}")
+        sources_meta.append({
+            "filename": doc.get("filename"),
+            "topic_name": doc.get("topic_name"),
+            "page_span": doc.get("page_span"),
+            "score": float(h.get("score", 0.0)),
+            "chunk_id": chunk_id
+        })
+    context_text = "\n\n---\n\n".join(contexts) if contexts else ""
+    file_summary = doc_sum.get("summary", "")
+
+    # Chain-of-thought style two-step with Gemini
+    from utils.api.router import GEMINI_MED, GEMINI_PRO
+
+    # Step 1: Content filtering and relevance assessment based on user instructions
+    if instructions.strip():
+        filter_sys = (
+            "You are an expert content analyst. Given the user's specific instructions and the document content, "
+            "identify which sections/chunks are MOST relevant to their request. "
+            "Each chunk is prefixed with [CHUNK_ID: <id>] - use these exact IDs in your response. "
+            "Return a JSON object with this structure: {\"relevant_chunks\": [\"<chunk_id_1>\", \"<chunk_id_2>\"], \"focus_areas\": [\"key topic 1\", \"key topic 2\"]}"
+        )
+        filter_user = f"USER_INSTRUCTIONS: {instructions}\n\nDOCUMENT_SUMMARY: {file_summary}\n\nAVAILABLE_CHUNKS:\n{context_text}\n\nIdentify only the chunks that directly address the user's specific request."
+
+        try:
+            selection_filter = {"provider": "gemini", "model": os.getenv("GEMINI_MED", "gemini-2.5-flash")}
+            filter_response = await generate_answer_with_model(selection_filter, filter_sys, filter_user, gemini_rotator, nvidia_rotator)
+            logger.info(f"[REPORT] Raw filter response: {filter_response}")
+            # Try to parse the filter response to get relevant chunks
+            import json
+            try:
+                filter_data = json.loads(filter_response)
+                relevant_chunk_ids = filter_data.get("relevant_chunks", [])
+                focus_areas = filter_data.get("focus_areas", [])
+                logger.info(f"[REPORT] Content filtering identified {len(relevant_chunk_ids)} relevant chunks: {relevant_chunk_ids} and focus areas: {focus_areas}")
+                # Filter context to only relevant chunks
+                if relevant_chunk_ids and hits:
+                    filtered_hits = [h for h in hits if str(h["doc"].get("_id", "")) in relevant_chunk_ids]
+                    if filtered_hits:
+                        hits = filtered_hits
+                        logger.info(f"[REPORT] Filtered context from {len(hits)} chunks to {len(filtered_hits)} relevant chunks")
+                    else:
+                        logger.warning(f"[REPORT] No matching chunks found for IDs: {relevant_chunk_ids}")
+                else:
+                    logger.warning(f"[REPORT] No relevant chunk IDs returned or no hits available")
+            except json.JSONDecodeError as e:
+                logger.warning(f"[REPORT] Could not parse filter response, using all chunks. JSON error: {e}. Response: {filter_response}")
+        except Exception as e:
+            logger.warning(f"[REPORT] Content filtering failed: {e}")
+
+    # Step 2: Create focused outline based on user instructions
+    sys_outline = (
+        "You are an expert technical writer. Create a focused, hierarchical outline for a report based on the user's specific instructions and the MATERIALS. "
+        "The outline should directly address what the user asked for. Output as Markdown bullet list only. Keep it within about {} words."
+    ).format(max(100, outline_words))
+
+    instruction_context = f"USER_REQUEST: {instructions}\n\n" if instructions.strip() else ""
+    user_outline = f"{instruction_context}MATERIALS:\n\n[FILE_SUMMARY from {eff_name}]\n{file_summary}\n\n[DOC_CONTEXT]\n{context_text}"
+
+    try:
+        # Step 1: Outline with Flash/Med
+        selection_outline = {"provider": "gemini", "model": os.getenv("GEMINI_MED", "gemini-2.5-flash")}
+        outline_md = await generate_answer_with_model(selection_outline, sys_outline, user_outline, gemini_rotator, nvidia_rotator)
+    except Exception as e:
+        logger.warning(f"Report outline failed: {e}")
+        outline_md = "# Report Outline\n\n- Introduction\n- Key Topics\n- Conclusion"
+
+    # Step 3: Generate focused report based on user instructions and filtered content
+    instruction_focus = f"FOCUS ON: {instructions}\n\n" if instructions.strip() else ""
+    sys_report = (
+        "You are an expert report writer. Write a focused, comprehensive Markdown report that directly addresses the user's specific request. "
+        "Using the OUTLINE and MATERIALS:\n"
+        "- Structure the report to answer exactly what the user asked for\n"
+        "- Use clear section headings\n"
+        "- Keep content factual and grounded in the provided materials\n"
+        f"- Include brief citations like (source: {eff_name}, topic) - use the actual filename provided\n"
+        "- If the user asked for a specific section/topic, focus heavily on that\n"
+        f"- Target length ~{max(600, report_words)} words\n"
+        "- Ensure the report directly fulfills the user's request"
+    )
+    user_report = f"{instruction_focus}OUTLINE:\n{outline_md}\n\nMATERIALS:\n[FILE_SUMMARY from {eff_name}]\n{file_summary}\n\n[DOC_CONTEXT]\n{context_text}"
+
+    try:
+        selection_report = {"provider": "gemini", "model": os.getenv("GEMINI_PRO", "gemini-2.5-pro")}
+        report_md = await generate_answer_with_model(selection_report, sys_report, user_report, gemini_rotator, nvidia_rotator)
+    except Exception as e:
+        logger.error(f"Report generation failed: {e}")
+        report_md = outline_md + "\n\n" + file_summary
+
+    return ReportResponse(filename=eff_name, report_markdown=report_md, sources=sources_meta)
+
+
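Review note: a hypothetical invocation of the outline-then-expand report flow above (httpx assumed; host and IDs are placeholders, and JADE.pdf is the example filename already used elsewhere in this file):

# Illustrative only; placeholders marked with <...>.
import httpx

resp = httpx.post(
    "https://<host>/report",
    data={
        "user_id": "<user_id>",
        "project_id": "<project_id>",
        "filename": "JADE.pdf",
        "instructions": "Focus on the methodology section",
        "report_words": 800,
    },
    timeout=120.0,
)
print(resp.json()["report_markdown"][:200])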
+@app.post("/report/pdf")
+async def generate_report_pdf(
+    user_id: str = Form(...),
+    project_id: str = Form(...),
+    report_content: str = Form(...)
+):
+    """
+    Generate a PDF from report content using the PDF utility module
+    """
+    from utils.service.pdf import generate_report_pdf as generate_pdf
+    from fastapi.responses import Response
+
+    try:
+        pdf_content = await generate_pdf(report_content, user_id, project_id)
+
+        # Return PDF as response
+        return Response(
+            content=pdf_content,
+            media_type="application/pdf",
+            headers={"Content-Disposition": f"attachment; filename=report-{datetime.now().strftime('%Y-%m-%d')}.pdf"}
+        )
+
+    except HTTPException:
+        # Re-raise HTTP exceptions as-is
+        raise
+
+
+# ────────────────────────────── Enhanced RAG Helper Functions ──────────────────────────────
+
+async def _generate_query_variations(question: str, nvidia_rotator) -> List[str]:
+    """
+    Generate multiple query variations using Chain of Thought reasoning
+    """
+    if not nvidia_rotator:
+        return [question]  # Fallback to original question
+
+    try:
+        # Use NVIDIA to generate query variations
+        sys_prompt = """You are an expert at query expansion and reformulation. Given a user question, generate 3-5 different ways to ask the same question that would help retrieve relevant information from a document database.
+
+Focus on:
+1. Different terminology and synonyms
+2. More specific technical terms
+3. Broader conceptual queries
+4. Question reformulations
+
+Return only the variations, one per line, no numbering or extra text."""
+
+        user_prompt = f"Original question: {question}\n\nGenerate query variations:"
+
+        from utils.api.router import generate_answer_with_model
+        selection = {"provider": "nvidia", "model": "meta/llama-3.1-8b-instruct"}
+        response = await generate_answer_with_model(selection, sys_prompt, user_prompt, None, nvidia_rotator)
+
+        # Parse variations
+        variations = [line.strip() for line in response.split('\n') if line.strip()]
+        variations = [v for v in variations if len(v) > 10]  # Filter out too short variations
+
+        # Always include original question
+        if question not in variations:
+            variations.insert(0, question)
+
+        return variations[:5]  # Limit to 5 variations
+
+    except Exception as e:
+        logger.warning(f"Query variation generation failed: {e}")
+        return [question]
+
+
+def _deduplicate_and_rank_hits(all_hits: List[Dict], original_question: str) -> List[Dict]:
+    """
+    Deduplicate hits by chunk ID and rank by relevance to original question
+    """
+    if not all_hits:
+        return []
+
+    # Deduplicate by chunk ID
+    seen_ids = set()
+    unique_hits = []
+
+    for hit in all_hits:
+        chunk_id = str(hit.get("doc", {}).get("_id", ""))
+        if chunk_id not in seen_ids:
+            seen_ids.add(chunk_id)
+            unique_hits.append(hit)
+
+    # Simple ranking: boost scores for hits that contain question keywords
+    question_words = set(original_question.lower().split())
+
+    for hit in unique_hits:
+        content = hit.get("doc", {}).get("content", "").lower()
+        topic = hit.get("doc", {}).get("topic_name", "").lower()
+
+        # Count keyword matches
+        content_matches = sum(1 for word in question_words if word in content)
+        topic_matches = sum(1 for word in question_words if word in topic)
+
+        # Boost score based on keyword matches
+        keyword_boost = 1.0 + (content_matches * 0.1) + (topic_matches * 0.2)
+        hit["score"] = hit.get("score", 0.0) * keyword_boost
+
+    # Sort by boosted score
+    unique_hits.sort(key=lambda x: x.get("score", 0.0), reverse=True)
+
+    return unique_hits
+
+
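Review note on `_deduplicate_and_rank_hits`: the boost is multiplicative, so a hit with vector score 0.60 whose content contains 3 of the question's tokens and whose topic contains 1 is rescored to 0.60 × (1.0 + 3×0.1 + 1×0.2) = 0.60 × 1.5 = 0.90. Matching is plain substring containment over lowercased whole-question tokens, so stopwords in the question also contribute to the boost.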
+@app.post("/chat", response_model=ChatAnswerResponse)
+async def chat(
+    user_id: str = Form(...),
+    project_id: str = Form(...),
+    question: str = Form(...),
+    k: int = Form(6)
+):
+    # Add timeout protection to prevent hanging
+    import asyncio
+    try:
+        return await asyncio.wait_for(_chat_impl(user_id, project_id, question, k), timeout=120.0)
+    except asyncio.TimeoutError:
+        logger.error("[CHAT] Chat request timed out after 120 seconds")
+        return ChatAnswerResponse(
+            answer="Sorry, the request took too long to process. Please try again with a simpler question.",
+            sources=[],
+            relevant_files=[]
+        )
+
+async def _chat_impl(
+    user_id: str,
+    project_id: str,
+    question: str,
+    k: int
+):
+    """
+    RAG chat that answers ONLY from uploaded materials.
+    - Preload all filenames + summaries; use NVIDIA to classify file relevance to question (true/false)
+    - Restrict vector search to relevant files (fall back to all if none)
+    - Bring in recent chat memory: last 3 via NVIDIA relevance; remaining 17 via semantic search
+    - After answering, summarize (q,a) via NVIDIA and store into LRU (last 20)
+    """
+    import sys
+    from memo.core import get_memory_system
+    from utils.api.router import NVIDIA_SMALL  # reuse default name
+    memory = get_memory_system()
+    logger.info("[CHAT] User Q/chat: %s", trim_text(question, 15).replace("\n", " "))
+
+    # 0) Detect any filenames mentioned in the question (e.g., JADE.pdf)
+    # Supports .pdf, .docx, and .doc for detection purposes
+    # Only capture contiguous tokens ending with extension (no spaces) to avoid swallowing prompt text
+    mentioned = set([m.group(0).strip() for m in re.finditer(r"\b[^\s/\\]+?\.(?:pdf|docx|doc)\b", question, re.IGNORECASE)])
+    if mentioned:
+        logger.info(f"[CHAT] Detected mentioned filenames in question: {list(mentioned)}")
+
+    # 0a) If the question explicitly asks for a summary/about of a single mentioned file, return its summary directly
+    if mentioned and (re.search(r"\b(summary|summarize|about|overview)\b", question, re.IGNORECASE)):
+        # Prefer direct summary when exactly one file is referenced
+        if len(mentioned) == 1:
+            fn = next(iter(mentioned))
+            doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=fn)
+            if doc:
+                return ChatAnswerResponse(
+                    answer=doc.get("summary", ""),
+                    sources=[{"filename": fn, "file_summary": True}]
+                )
+            # If not found with the same casing, try case-insensitive match against stored filenames
+            files_ci = rag.list_files(user_id=user_id, project_id=project_id)
+            match = next((f["filename"] for f in files_ci if f.get("filename", "").lower() == fn.lower()), None)
+            if match:
+                doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=match)
+                if doc:
+                    return ChatAnswerResponse(
+                        answer=doc.get("summary", ""),
+                        sources=[{"filename": match, "file_summary": True}]
+                    )
+        # If multiple files are referenced with summary intent, proceed to relevance flow below
+
+    # 1) Preload file list + summaries
+    files_list = rag.list_files(user_id=user_id, project_id=project_id)  # [{filename, summary}]
+
+    # 1a) Normalize mentioned filenames against the user's library (case-insensitive)
+    filenames_ci_map = {f.get("filename", "").lower(): f.get("filename") for f in files_list if f.get("filename")}
+    mentioned_normalized = []
+    for mfn in mentioned:
+        key = mfn.lower()
+        if key in filenames_ci_map:
+            mentioned_normalized.append(filenames_ci_map[key])
+    if mentioned and not mentioned_normalized and files_list:
+        # Try looser match: contained filenames ignoring spaces
+        norm = {f.get("filename", "").lower().replace(" ", ""): f.get("filename") for f in files_list if f.get("filename")}
+        for mfn in mentioned:
+            key2 = mfn.lower().replace(" ", "")
+            if key2 in norm:
+                mentioned_normalized.append(norm[key2])
+    if mentioned_normalized:
+        logger.info(f"[CHAT] Normalized mentions to stored filenames: {mentioned_normalized}")
+
+    # 1b) Ask NVIDIA to mark relevance per file
+    try:
+        from memo.history import get_history_manager
+        history_manager = get_history_manager(memory)
+        relevant_map = await history_manager.files_relevance(question, files_list, nvidia_rotator)
+        relevant_files = [fn for fn, ok in relevant_map.items() if ok]
+        logger.info(f"[CHAT] NVIDIA relevant files: {relevant_files}")
+    except Exception as e:
+        logger.warning(f"[CHAT] NVIDIA relevance failed, defaulting to all files: {e}")
+        relevant_files = [f.get("filename") for f in files_list if f.get("filename")]
+
+    # 1c) Ensure any explicitly mentioned files in the question are included
+    # This safeguards against model misclassification
+    if mentioned_normalized:
+        extra = [fn for fn in mentioned_normalized if fn not in relevant_files]
+        relevant_files.extend(extra)
+        if extra:
+            logger.info(f"[CHAT] Forced-include mentioned files into relevance: {extra}")
+
+    # 2) Memory context: recent 3 via NVIDIA, remaining 17 via semantic
+    # Use enhanced context retrieval if available, otherwise fallback to original method
+    try:
+        from memo.history import get_history_manager
+        history_manager = get_history_manager(memory)
+        recent_related, semantic_related = await history_manager.related_recent_and_semantic_context(
+            user_id, question, embedder
+        )
+    except Exception as e:
+        logger.warning(f"[CHAT] Enhanced context retrieval failed, using fallback: {e}")
+        # Fallback to original method
+        recent3 = memory.recent(user_id, 3)
+        if recent3:
+            sys = "Pick only items that directly relate to the new question. Output the selected items verbatim, no commentary. If none, output nothing."
+            numbered = [{"id": i+1, "text": s} for i, s in enumerate(recent3)]
+            user = f"Question: {question}\nCandidates:\n{json.dumps(numbered, ensure_ascii=False)}\nSelect any related items and output ONLY their 'text' values concatenated."
+            try:
+                from utils.api.rotator import robust_post_json
+                key = nvidia_rotator.get_key()
+                url = "https://integrate.api.nvidia.com/v1/chat/completions"
+                payload = {
+                    "model": os.getenv("NVIDIA_SMALL", "meta/llama-3.1-8b-instruct"),
+                    "temperature": 0.0,
+                    "messages": [
+                        {"role": "system", "content": sys},
+                        {"role": "user", "content": user},
+                    ]
+                }
+                headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key or ''}"}
+                data = await robust_post_json(url, headers, payload, nvidia_rotator)
+                recent_related = data["choices"][0]["message"]["content"].strip()
+            except Exception as e:
+                logger.warning(f"Recent-related NVIDIA error: {e}")
+                recent_related = ""
+        else:
+            recent_related = ""
+
+        # Get semantic context from remaining memories
+        rest17 = memory.rest(user_id, 3)
+        if rest17:
+            import numpy as np
+            def _cosine(a: np.ndarray, b: np.ndarray) -> float:
+                denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0
+                return float(np.dot(a, b) / denom)
+
+            qv = np.array(embedder.embed([question])[0], dtype="float32")
+            mats = embedder.embed([s.strip() for s in rest17])
+            sims = [(_cosine(qv, np.array(v, dtype="float32")), s) for v, s in zip(mats, rest17)]
+            sims.sort(key=lambda x: x[0], reverse=True)
+            top = [s for (sc, s) in sims[:3] if sc > 0.15]
+            semantic_related = "\n\n".join(top) if top else ""
+
+    # 3) Enhanced query reasoning and RAG vector search
+    logger.info(f"[CHAT] Starting enhanced vector search with relevant_files={relevant_files}")
+
+    # Chain of Thought query breakdown for better retrieval
+    enhanced_queries = await _generate_query_variations(question, nvidia_rotator)
+    logger.info(f"[CHAT] Generated {len(enhanced_queries)} query variations")
+
+    # Try multiple search strategies
+    all_hits = []
+    search_strategies = ["flat", "hybrid", "local"]  # Try most accurate first
+
+    for strategy in search_strategies:
+        for query_variant in enhanced_queries:
+            q_vec = embedder.embed([query_variant])[0]
+            hits = rag.vector_search(
+                user_id=user_id,
+                project_id=project_id,
+                query_vector=q_vec,
+                k=k,
+                filenames=relevant_files if relevant_files else None,
+                search_type=strategy
+            )
+            if hits:
+                all_hits.extend(hits)
+                logger.info(f"[CHAT] {strategy} search with '{query_variant[:50]}...' returned {len(hits)} hits")
+                break  # If we found hits with this strategy, move to next query
+        if all_hits:
+            break  # If we found hits, don't try other strategies
+
+    # Deduplicate and rank results
+    hits = _deduplicate_and_rank_hits(all_hits, question)
+    logger.info(f"[CHAT] Final vector search returned {len(hits) if hits else 0} hits")
+    if not hits:
+        logger.info(f"[CHAT] No hits with relevance filter. relevant_files={relevant_files}")
+        # Fallback 1: Try with original question and flat search
+        q_vec_original = embedder.embed([question])[0]
+        hits = rag.vector_search(
+            user_id=user_id,
+            project_id=project_id,
+            query_vector=q_vec_original,
+            k=k,
+            filenames=relevant_files if relevant_files else None,
+            search_type="flat"
+        )
+        logger.info(f"[CHAT] Fallback flat search → hits={len(hits) if hits else 0}")
+
+        # Fallback 2: if we have explicit mentions, try restricting only to them
+        if not hits and mentioned_normalized:
+            hits = rag.vector_search(
+                user_id=user_id,
+                project_id=project_id,
+                query_vector=q_vec_original,
+                k=k,
+                filenames=mentioned_normalized,
+                search_type="flat"
+            )
+            logger.info(f"[CHAT] Fallback with mentioned files only → hits={len(hits) if hits else 0}")
+
+        # Fallback 3: if still empty, try without any filename restriction
+        if not hits:
+            hits = rag.vector_search(
+                user_id=user_id,
+                project_id=project_id,
+                query_vector=q_vec_original,
+                k=k,
+                filenames=None,
+                search_type="flat"
+            )
+            logger.info(f"[CHAT] Fallback with all files → hits={len(hits) if hits else 0}")
+        # If still no hits, and we have mentioned files, try returning their summaries if present
+        if not hits and mentioned_normalized:
+            fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
+            summaries = [fsum_map.get(fn, "") for fn in mentioned_normalized]
+            summaries = [s for s in summaries if s]
+            if summaries:
+                answer = ("\n\n---\n\n").join(summaries)
+                return ChatAnswerResponse(
+                    answer=answer,
+                    sources=[{"filename": fn, "file_summary": True} for fn in mentioned_normalized],
+                    relevant_files=mentioned_normalized
+                )
+    if not hits:
+        # Last resort: use summaries from relevant files if we didn't have explicit mentions normalized
+        candidates = mentioned_normalized or relevant_files or []
+        if candidates:
+            fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
+            summaries = [fsum_map.get(fn, "") for fn in candidates]
+            summaries = [s for s in summaries if s]
+            if summaries:
+                answer = ("\n\n---\n\n").join(summaries)
+                logger.info(f"[CHAT] Falling back to file-level summaries for: {candidates}")
+                return ChatAnswerResponse(
+                    answer=answer,
+                    sources=[{"filename": fn, "file_summary": True} for fn in candidates],
+                    relevant_files=candidates
+                )
+        return ChatAnswerResponse(
+            answer="I don't know based on your uploaded materials. Try uploading more sources or rephrasing the question.",
+            sources=[],
+            relevant_files=relevant_files or mentioned_normalized
+        )
+    # If we get here, we have hits, so continue with normal flow
+    # Compose context
+    contexts = []
+    sources_meta = []
+    for h in hits:
+        doc = h["doc"]
+        score = h["score"]
+        contexts.append(f"[{doc.get('topic_name','Topic')}] {trim_text(doc.get('content',''), 2000)}")
+        sources_meta.append({
+            "filename": doc.get("filename"),
+            "topic_name": doc.get("topic_name"),
+            "page_span": doc.get("page_span"),
+            "score": float(score),
+            "chunk_id": str(doc.get("_id", ""))  # Convert ObjectId to string
+        })
+    context_text = "\n\n---\n\n".join(contexts)
+
+    # Add file-level summaries for relevant files
+    file_summary_block = ""
+    if relevant_files:
+        fsum_map = {f["filename"]: f.get("summary","") for f in files_list}
+        lines = [f"[{fn}] {fsum_map.get(fn, '')}" for fn in relevant_files]
+        file_summary_block = "\n".join(lines)
+
+    # Guardrail instruction to avoid hallucination
+    system_prompt = (
+        "You are a careful study assistant. Answer strictly using the given CONTEXT.\n"
+        "If the answer isn't in the context, say 'I don't know based on the provided materials.'\n"
+        "Write concise, clear explanations with citations like (source: actual_filename, topic).\n"
+        "Use the exact filename as provided in the context, not placeholders.\n"
+    )
+
+    # Add recent chat context and historical similarity context
+    history_block = ""
+    if recent_related or semantic_related:
+        history_block = "RECENT_CHAT_CONTEXT:\n" + (recent_related or "") + ("\n\nHISTORICAL_SIMILARITY_CONTEXT:\n" + semantic_related if semantic_related else "")
+    composed_context = ""
+    if history_block:
+        composed_context += history_block + "\n\n"
+    if file_summary_block:
+        composed_context += "FILE_SUMMARIES:\n" + file_summary_block + "\n\n"
+    composed_context += "DOC_CONTEXT:\n" + context_text
+
+    # Compose user prompt
+    user_prompt = f"QUESTION:\n{question}\n\nCONTEXT:\n{composed_context}"
+    # Choose model (cost-aware)
+    selection = select_model(question=question, context=composed_context)
+    logger.info(f"Model selection: {selection}")
+    # Generate answer with model
+    logger.info(f"[CHAT] Generating answer with {selection['provider']} {selection['model']}")
+    try:
+        answer = await generate_answer_with_model(
+            selection=selection,
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+            gemini_rotator=gemini_rotator,
+            nvidia_rotator=nvidia_rotator
+        )
+        logger.info(f"[CHAT] Answer generated successfully, length: {len(answer)}")
+    except Exception as e:
+        logger.error(f"LLM error: {e}")
+        answer = "I had trouble contacting the language model provider just now. Please try again."
+    # After answering: summarize QA and store in memory (LRU, last 20)
+    try:
+        from memo.history import get_history_manager
+        history_manager = get_history_manager(memory)
+        qa_sum = await history_manager.summarize_qa_with_nvidia(question, answer, nvidia_rotator)
+        memory.add(user_id, qa_sum)
+
+        # Also store enhanced conversation memory if available
+        if memory.is_enhanced_available():
+            await memory.add_conversation_memory(
+                user_id=user_id,
+                question=question,
+                answer=answer,
+                project_id=project_id,
+                context={
+                    "relevant_files": relevant_files,
+                    "sources_count": len(sources_meta),
+                    "timestamp": time.time()
+                }
+            )
+    except Exception as e:
+        logger.warning(f"QA summarize/store failed: {e}")
+    # Trim for logging
+    logger.info("LLM answer (trimmed): %s", trim_text(answer, 200).replace("\n", " "))
+    return ChatAnswerResponse(answer=answer, sources=sources_meta, relevant_files=relevant_files)
+
+
+@app.get("/healthz", response_model=HealthResponse)
+def health():
+    return HealthResponse(ok=True)
+
+
+@app.get("/test-db")
+async def test_database():
+    """Test database connection and basic operations"""
+    try:
+        if not rag:
+            return {
+                "status": "error",
+                "message": "RAG store not initialized",
+                "error_type": "RAGStoreNotInitialized"
+            }
+
+        # Test basic connection
+        rag.client.admin.command('ping')
+
+        # Test basic insert/query
+        test_collection = rag.db["test_collection"]
+        test_doc = {"test": True, "timestamp": datetime.now(timezone.utc)}
+        result = test_collection.insert_one(test_doc)
+
+        # Test query
+        found = test_collection.find_one({"_id": result.inserted_id})
+
+        # Clean up
+        test_collection.delete_one({"_id": result.inserted_id})
+
+        return {
+            "status": "success",
+            "message": "Database connection and operations working correctly",
+            "test_id": str(result.inserted_id),
+            "found_doc": str(found["_id"]) if found else None
+        }
+
+    except Exception as e:
+        logger.error(f"[TEST-DB] Database test failed: {str(e)}")
+        return {
+            "status": "error",
+            "message": f"Database test failed: {str(e)}",
+            "error_type": str(type(e))
+        }
+
+
+@app.get("/rag-status")
+async def rag_status():
+    """Check the status of the RAG store"""
+    if not rag:
+        return {
+            "status": "error",
+            "message": "RAG store not initialized",
+            "rag_available": False
+        }
+
+    try:
+        # Test connection
+        rag.client.admin.command('ping')
+        return {
+            "status": "success",
+            "message": "RAG store is available and connected",
+            "rag_available": True,
+            "database": rag.db.name,
+            "collections": {
+                "chunks": rag.chunks.name,
+                "files": rag.files.name
+            }
+        }
+    except Exception as e:
+        return {
+            "status": "error",
+            "message": f"RAG store connection failed: {str(e)}",
+            "rag_available": False,
+            "error": str(e)
+        }
+
+# Local dev
+# if __name__ == "__main__":
+#     import uvicorn
+#     uvicorn.run(app, host="0.0.0.0", port=8000)
helpers/__init__.py ADDED
@@ -0,0 +1,2 @@
+# Package init for helpers. Exposes FastAPI app for external import.
+from .setup import app, logger  # re-export for convenience
helpers/models.py ADDED
@@ -0,0 +1,54 @@
+from typing import List, Dict, Any, Optional
+from pydantic import BaseModel
+
+
+# ────────────────────────────── Response Models ──────────────────────────────
+class ProjectResponse(BaseModel):
+    project_id: str
+    user_id: str
+    name: str
+    description: str
+    created_at: str
+    updated_at: str
+
+class ProjectsListResponse(BaseModel):
+    projects: List[ProjectResponse]
+
+class ChatMessageResponse(BaseModel):
+    user_id: str
+    project_id: str
+    role: str
+    content: str
+    timestamp: float
+    created_at: str
+    sources: Optional[List[Dict[str, Any]]] = None
+
+class ChatHistoryResponse(BaseModel):
+    messages: List[ChatMessageResponse]
+
+class MessageResponse(BaseModel):
+    message: str
+
+class UploadResponse(BaseModel):
+    job_id: str
+    status: str
+    total_files: Optional[int] = None
+
+class FileSummaryResponse(BaseModel):
+    filename: str
+    summary: str
+
+class ChatAnswerResponse(BaseModel):
+    answer: str
+    sources: List[Dict[str, Any]]
+    relevant_files: Optional[List[str]] = None
+
+class HealthResponse(BaseModel):
+    ok: bool
+
+class ReportResponse(BaseModel):
+    filename: str
+    report_markdown: str
+    sources: List[Dict[str, Any]]
+
+
helpers/pages.py ADDED
@@ -0,0 +1,26 @@
from typing import List, Dict, Any
from fastapi import HTTPException
from utils.ingestion.parser import parse_pdf_bytes, parse_docx_bytes


# ────────────────────────────── Helpers ──────────────────────────────
def _infer_mime(filename: str) -> str:
    lower = filename.lower()
    if lower.endswith(".pdf"):
        return "application/pdf"
    if lower.endswith(".docx"):
        return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    return "application/octet-stream"


def _extract_pages(filename: str, file_bytes: bytes) -> List[Dict[str, Any]]:
    mime = _infer_mime(filename)
    if mime == "application/pdf":
        return parse_pdf_bytes(file_bytes)
    elif mime == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return parse_docx_bytes(file_bytes)
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported file type: {filename}")
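
A usage sketch for the helper above, assuming the parsers return a list of page dicts with at least `text` and `images` keys (which is how the upload pipeline in routes/files.py consumes them):

```python
# Hypothetical local check, not part of this commit.
from helpers.pages import _extract_pages

with open("lecture1.pdf", "rb") as fh:   # any local PDF
    pages = _extract_pages("lecture1.pdf", fh.read())

print(len(pages), "pages")
print(pages[0].get("text", "")[:200])    # first page text
print(len(pages[0].get("images", [])), "images on page 1")
```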
helpers/setup.py ADDED
@@ -0,0 +1,59 @@
import os
from dotenv import load_dotenv
load_dotenv()

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware

from utils.logger import get_logger
from utils.api.rotator import APIKeyRotator
from utils.ingestion.caption import BlipCaptioner
from utils.rag.embeddings import EmbeddingClient
from utils.rag.rag import RAGStore, ensure_indexes


# ────────────────────────────── App Setup ──────────────────────────────
logger = get_logger("APP", name="studybuddy")

app = FastAPI(title="StudyBuddy RAG", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files (index.html, scripts.js, styles.css)
app.mount("/static", StaticFiles(directory="static"), name="static")

# In-memory job tracker (for progress queries)
app.state.jobs = {}


# ────────────────────────────── Global Clients ──────────────────────────────
# API rotators (round robin + auto failover on quota errors)
gemini_rotator = APIKeyRotator(prefix="GEMINI_API_", max_slots=5)
nvidia_rotator = APIKeyRotator(prefix="NVIDIA_API_", max_slots=5)

# Captioner + Embeddings (lazy init inside classes)
captioner = BlipCaptioner()
embedder = EmbeddingClient(model_name=os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2"))

# Mongo / RAG store
try:
    rag = RAGStore(mongo_uri=os.getenv("MONGO_URI"), db_name=os.getenv("MONGO_DB", "studybuddy"))
    # Test the connection
    rag.client.admin.command('ping')
    logger.info("[APP] MongoDB connection successful")
    ensure_indexes(rag)
    logger.info("[APP] MongoDB indexes ensured")
except Exception as e:
    logger.error(f"[APP] Failed to initialize MongoDB/RAG store: {str(e)}")
    logger.error(f"[APP] MONGO_URI: {os.getenv('MONGO_URI', 'Not set')}")
    logger.error(f"[APP] MONGO_DB: {os.getenv('MONGO_DB', 'studybuddy')}")
    # Fall back to None so the app can still start; DB-backed routes will fail until Mongo is reachable
    rag = None
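
Configuration is entirely environment-driven. A sketch of the .env this setup reads; the numbered key names are an assumption based on `APIKeyRotator(prefix=..., max_slots=5)`, and the exact discovery scheme lives in utils/api/rotator.py:

```
# .env (illustrative values only)
MONGO_URI=mongodb+srv://user:pass@cluster.example.net
MONGO_DB=studybuddy
EMBED_MODEL=sentence-transformers/all-MiniLM-L6-v2
GEMINI_API_1=...   # assumed naming: prefix + slot index, up to 5 slots
GEMINI_API_2=...
NVIDIA_API_1=...
MAX_FILES_PER_UPLOAD=15
MAX_FILE_MB=50
```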
routes/auth.py ADDED
@@ -0,0 +1,53 @@
import uuid, time, hashlib, secrets
from typing import Optional
from fastapi import Form, HTTPException

from helpers.setup import app, rag, logger


# ────────────────────────────── Auth Helpers/Routes ───────────────────────────
def _hash_password(password: str, salt: Optional[str] = None):
    salt = salt or secrets.token_hex(16)
    dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), bytes.fromhex(salt), 120000)
    return {"salt": salt, "hash": dk.hex()}


def _verify_password(password: str, salt: str, expected_hex: str) -> bool:
    dk = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), bytes.fromhex(salt), 120000)
    return secrets.compare_digest(dk.hex(), expected_hex)


@app.post("/auth/signup")
async def signup(email: str = Form(...), password: str = Form(...)):
    email = email.strip().lower()
    if not email or not password or "@" not in email:
        raise HTTPException(400, detail="Invalid email or password")
    users = rag.db["users"]
    if users.find_one({"email": email}):
        raise HTTPException(409, detail="Email already registered")
    user_id = str(uuid.uuid4())
    hp = _hash_password(password)
    users.insert_one({
        "email": email,
        "user_id": user_id,
        "pw_salt": hp["salt"],
        "pw_hash": hp["hash"],
        "created_at": int(time.time())
    })
    logger.info(f"[AUTH] Created user {email} -> {user_id}")
    return {"email": email, "user_id": user_id}


@app.post("/auth/login")
async def login(email: str = Form(...), password: str = Form(...)):
    email = email.strip().lower()
    users = rag.db["users"]
    doc = users.find_one({"email": email})
    if not doc:
        raise HTTPException(401, detail="Invalid credentials")
    if not _verify_password(password, doc.get("pw_salt", ""), doc.get("pw_hash", "")):
        raise HTTPException(401, detail="Invalid credentials")
    logger.info(f"[AUTH] Login {email}")
    return {"email": email, "user_id": doc.get("user_id")}
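
The scheme above is salted PBKDF2-HMAC-SHA256 at 120,000 iterations, with a constant-time comparison on verify. A standalone round trip using only the standard library, mirroring `_hash_password`/`_verify_password`:

```python
import hashlib, secrets

salt = secrets.token_hex(16)                       # 16 random bytes, hex-encoded
dk = hashlib.pbkdf2_hmac("sha256", b"hunter2", bytes.fromhex(salt), 120000)
stored = {"salt": salt, "hash": dk.hex()}          # what signup persists per user

# Verify: re-derive with the stored salt, compare in constant time.
check = hashlib.pbkdf2_hmac("sha256", b"hunter2", bytes.fromhex(stored["salt"]), 120000)
assert secrets.compare_digest(check.hex(), stored["hash"])
```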
routes/chats.py ADDED
@@ -0,0 +1,463 @@
import json, time, re, asyncio, os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from fastapi import Form, HTTPException

from helpers.setup import app, rag, logger, embedder, gemini_rotator, nvidia_rotator
from helpers.models import ChatMessageResponse, ChatHistoryResponse, MessageResponse, ChatAnswerResponse
from utils.service.common import trim_text
from utils.api.router import select_model, generate_answer_with_model


@app.post("/chat/save", response_model=MessageResponse)
async def save_chat_message(
    user_id: str = Form(...),
    project_id: str = Form(...),
    role: str = Form(...),
    content: str = Form(...),
    timestamp: Optional[float] = Form(None),
    sources: Optional[str] = Form(None)
):
    """Save a chat message to the session"""
    if role not in ["user", "assistant"]:
        raise HTTPException(400, detail="Invalid role")

    # Parse optional sources JSON
    parsed_sources: Optional[List[Dict[str, Any]]] = None
    if sources:
        try:
            parsed = json.loads(sources)
            if isinstance(parsed, list):
                parsed_sources = parsed
        except Exception:
            parsed_sources = None

    message = {
        "user_id": user_id,
        "project_id": project_id,
        "role": role,
        "content": content,
        "timestamp": timestamp or time.time(),
        "created_at": datetime.now(timezone.utc),
        **({"sources": parsed_sources} if parsed_sources is not None else {})
    }

    rag.db["chat_sessions"].insert_one(message)
    return MessageResponse(message="Chat message saved")


@app.get("/chat/history", response_model=ChatHistoryResponse)
async def get_chat_history(user_id: str, project_id: str, limit: int = 100):
    """Get chat history for a project"""
    messages_cursor = rag.db["chat_sessions"].find(
        {"user_id": user_id, "project_id": project_id}
    ).sort("timestamp", 1).limit(limit)

    messages = []
    for message in messages_cursor:
        messages.append(ChatMessageResponse(
            user_id=message["user_id"],
            project_id=message["project_id"],
            role=message["role"],
            content=message["content"],
            timestamp=message["timestamp"],
            created_at=message["created_at"].isoformat() if isinstance(message["created_at"], datetime) else str(message["created_at"]),
            sources=message.get("sources")
        ))

    return ChatHistoryResponse(messages=messages)


@app.delete("/chat/history", response_model=MessageResponse)
async def delete_chat_history(user_id: str, project_id: str):
    try:
        rag.db["chat_sessions"].delete_many({"user_id": user_id, "project_id": project_id})
        logger.info(f"[CHAT] Cleared history for user {user_id} project {project_id}")
        # Also clear in-memory LRU for this user to avoid stale context
        try:
            from memo.core import get_memory_system
            memory = get_memory_system()
            memory.clear(user_id)
            logger.info(f"[CHAT] Cleared memory for user {user_id}")
        except Exception as me:
            logger.warning(f"[CHAT] Failed to clear memory for user {user_id}: {me}")
        return MessageResponse(message="Chat history cleared")
    except Exception as e:
        raise HTTPException(500, detail=f"Failed to clear chat history: {str(e)}")


# ────────────────────────────── RAG Chat and Helpers ──────────────────────────────
async def _generate_query_variations(question: str, nvidia_rotator) -> List[str]:
    """Generate multiple query variations using Chain of Thought reasoning"""
    if not nvidia_rotator:
        return [question]  # Fallback to original question

    try:
        # Use NVIDIA to generate query variations
        sys_prompt = """You are an expert at query expansion and reformulation. Given a user question, generate 3-5 different ways to ask the same question that would help retrieve relevant information from a document database.

Focus on:
1. Different terminology and synonyms
2. More specific technical terms
3. Broader conceptual queries
4. Question reformulations

Return only the variations, one per line, no numbering or extra text."""

        user_prompt = f"Original question: {question}\n\nGenerate query variations:"

        selection = {"provider": "nvidia", "model": "meta/llama-3.1-8b-instruct"}
        response = await generate_answer_with_model(selection, sys_prompt, user_prompt, None, nvidia_rotator)

        # Parse variations
        variations = [line.strip() for line in response.split('\n') if line.strip()]
        variations = [v for v in variations if len(v) > 10]  # Filter out too-short variations

        # Always include the original question
        if question not in variations:
            variations.insert(0, question)

        return variations[:5]  # Limit to 5 variations

    except Exception as e:
        logger.warning(f"Query variation generation failed: {e}")
        return [question]


def _deduplicate_and_rank_hits(all_hits: List[Dict], original_question: str) -> List[Dict]:
    """Deduplicate hits by chunk ID and rank by relevance to the original question"""
    if not all_hits:
        return []

    # Deduplicate by chunk ID
    seen_ids = set()
    unique_hits = []

    for hit in all_hits:
        chunk_id = str(hit.get("doc", {}).get("_id", ""))
        if chunk_id not in seen_ids:
            seen_ids.add(chunk_id)
            unique_hits.append(hit)

    # Simple ranking: boost scores for hits that contain question keywords
    question_words = set(original_question.lower().split())

    for hit in unique_hits:
        content = hit.get("doc", {}).get("content", "").lower()
        topic = hit.get("doc", {}).get("topic_name", "").lower()

        # Count keyword matches
        content_matches = sum(1 for word in question_words if word in content)
        topic_matches = sum(1 for word in question_words if word in topic)

        # Boost score based on keyword matches
        keyword_boost = 1.0 + (content_matches * 0.1) + (topic_matches * 0.2)
        hit["score"] = hit.get("score", 0.0) * keyword_boost

    # Sort by boosted score
    unique_hits.sort(key=lambda x: x.get("score", 0.0), reverse=True)

    return unique_hits


@app.post("/chat", response_model=ChatAnswerResponse)
async def chat(
    user_id: str = Form(...),
    project_id: str = Form(...),
    question: str = Form(...),
    k: int = Form(6)
):
    try:
        return await asyncio.wait_for(_chat_impl(user_id, project_id, question, k), timeout=120.0)
    except asyncio.TimeoutError:
        logger.error("[CHAT] Chat request timed out after 120 seconds")
        return ChatAnswerResponse(
            answer="Sorry, the request took too long to process. Please try again with a simpler question.",
            sources=[],
            relevant_files=[]
        )


async def _chat_impl(
    user_id: str,
    project_id: str,
    question: str,
    k: int
):
    from memo.core import get_memory_system
    memory = get_memory_system()
    logger.info("[CHAT] User Q/chat: %s", trim_text(question, 15).replace("\n", " "))

    # Detect explicit filename mentions (e.g. "summarize lecture1.pdf")
    mentioned = set([m.group(0).strip() for m in re.finditer(r"\b[^\s/\\]+?\.(?:pdf|docx|doc)\b", question, re.IGNORECASE)])
    if mentioned:
        logger.info(f"[CHAT] Detected mentioned filenames in question: {list(mentioned)}")

    if mentioned and (re.search(r"\b(summary|summarize|about|overview)\b", question, re.IGNORECASE)):
        if len(mentioned) == 1:
            fn = next(iter(mentioned))
            doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=fn)
            if doc:
                return ChatAnswerResponse(
                    answer=doc.get("summary", ""),
                    sources=[{"filename": fn, "file_summary": True}]
                )
            files_ci = rag.list_files(user_id=user_id, project_id=project_id)
            match = next((f["filename"] for f in files_ci if f.get("filename", "").lower() == fn.lower()), None)
            if match:
                doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=match)
                if doc:
                    return ChatAnswerResponse(
                        answer=doc.get("summary", ""),
                        sources=[{"filename": match, "file_summary": True}]
                    )

    files_list = rag.list_files(user_id=user_id, project_id=project_id)

    # Normalize mentioned names to stored filenames (case-insensitive, then whitespace-insensitive)
    filenames_ci_map = {f.get("filename", "").lower(): f.get("filename") for f in files_list if f.get("filename")}
    mentioned_normalized = []
    for mfn in mentioned:
        key = mfn.lower()
        if key in filenames_ci_map:
            mentioned_normalized.append(filenames_ci_map[key])
    if mentioned and not mentioned_normalized and files_list:
        norm = {f.get("filename", "").lower().replace(" ", ""): f.get("filename") for f in files_list if f.get("filename")}
        for mfn in mentioned:
            key2 = mfn.lower().replace(" ", "")
            if key2 in norm:
                mentioned_normalized.append(norm[key2])
    if mentioned_normalized:
        logger.info(f"[CHAT] Normalized mentions to stored filenames: {mentioned_normalized}")

    try:
        from memo.history import get_history_manager
        history_manager = get_history_manager(memory)
        relevant_map = await history_manager.files_relevance(question, files_list, nvidia_rotator)
        relevant_files = [fn for fn, ok in relevant_map.items() if ok]
        logger.info(f"[CHAT] NVIDIA relevant files: {relevant_files}")
    except Exception as e:
        logger.warning(f"[CHAT] NVIDIA relevance failed, defaulting to all files: {e}")
        relevant_files = [f.get("filename") for f in files_list if f.get("filename")]

    if mentioned_normalized:
        extra = [fn for fn in mentioned_normalized if fn not in relevant_files]
        relevant_files.extend(extra)
        if extra:
            logger.info(f"[CHAT] Forced-include mentioned files into relevance: {extra}")

    try:
        from memo.history import get_history_manager
        history_manager = get_history_manager(memory)
        recent_related, semantic_related = await history_manager.related_recent_and_semantic_context(
            user_id, question, embedder
        )
    except Exception as e:
        logger.warning(f"[CHAT] Enhanced context retrieval failed, using fallback: {e}")
        recent3 = memory.recent(user_id, 3)
        if recent3:
            sys_sel = "Pick only items that directly relate to the new question. Output the selected items verbatim, no commentary. If none, output nothing."
            numbered = [{"id": i + 1, "text": s} for i, s in enumerate(recent3)]
            user_sel = f"Question: {question}\nCandidates:\n{json.dumps(numbered, ensure_ascii=False)}\nSelect any related items and output ONLY their 'text' values concatenated."
            try:
                from utils.api.rotator import robust_post_json
                key = nvidia_rotator.get_key()
                url = "https://integrate.api.nvidia.com/v1/chat/completions"
                payload = {
                    "model": os.getenv("NVIDIA_SMALL", "meta/llama-3.1-8b-instruct"),
                    "temperature": 0.0,
                    "messages": [
                        {"role": "system", "content": sys_sel},
                        {"role": "user", "content": user_sel},
                    ]
                }
                headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key or ''}"}
                data = await robust_post_json(url, headers, payload, nvidia_rotator)
                recent_related = data["choices"][0]["message"]["content"].strip()
            except Exception as e:
                logger.warning(f"Recent-related NVIDIA error: {e}")
                recent_related = ""
        else:
            recent_related = ""
        rest17 = memory.rest(user_id, 3)
        if rest17:
            import numpy as np

            def _cosine(a: np.ndarray, b: np.ndarray) -> float:
                denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0
                return float(np.dot(a, b) / denom)

            qv = np.array(embedder.embed([question])[0], dtype="float32")
            mats = embedder.embed([s.strip() for s in rest17])
            sims = [(_cosine(qv, np.array(v, dtype="float32")), s) for v, s in zip(mats, rest17)]
            sims.sort(key=lambda x: x[0], reverse=True)
            top = [s for (sc, s) in sims[:3] if sc > 0.15]
            semantic_related = "\n\n".join(top) if top else ""
        else:
            semantic_related = ""

    logger.info(f"[CHAT] Starting enhanced vector search with relevant_files={relevant_files}")
    enhanced_queries = await _generate_query_variations(question, nvidia_rotator)
    logger.info(f"[CHAT] Generated {len(enhanced_queries)} query variations")
    all_hits = []
    search_strategies = ["flat", "hybrid", "local"]
    for strategy in search_strategies:
        for query_variant in enhanced_queries:
            q_vec = embedder.embed([query_variant])[0]
            hits = rag.vector_search(
                user_id=user_id,
                project_id=project_id,
                query_vector=q_vec,
                k=k,
                filenames=relevant_files if relevant_files else None,
                search_type=strategy
            )
            if hits:
                all_hits.extend(hits)
                logger.info(f"[CHAT] {strategy} search with '{query_variant[:50]}...' returned {len(hits)} hits")
                break
        if all_hits:
            break
    hits = _deduplicate_and_rank_hits(all_hits, question)
    logger.info(f"[CHAT] Final vector search returned {len(hits) if hits else 0} hits")
    if not hits:
        logger.info(f"[CHAT] No hits with relevance filter. relevant_files={relevant_files}")
        q_vec_original = embedder.embed([question])[0]
        hits = rag.vector_search(
            user_id=user_id,
            project_id=project_id,
            query_vector=q_vec_original,
            k=k,
            filenames=relevant_files if relevant_files else None,
            search_type="flat"
        )
        logger.info(f"[CHAT] Fallback flat search → hits={len(hits) if hits else 0}")
        if not hits and mentioned_normalized:
            hits = rag.vector_search(
                user_id=user_id,
                project_id=project_id,
                query_vector=q_vec_original,
                k=k,
                filenames=mentioned_normalized,
                search_type="flat"
            )
            logger.info(f"[CHAT] Fallback with mentioned files only → hits={len(hits) if hits else 0}")
        if not hits:
            hits = rag.vector_search(
                user_id=user_id,
                project_id=project_id,
                query_vector=q_vec_original,
                k=k,
                filenames=None,
                search_type="flat"
            )
            logger.info(f"[CHAT] Fallback with all files → hits={len(hits) if hits else 0}")
    if not hits and mentioned_normalized:
        fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
        summaries = [fsum_map.get(fn, "") for fn in mentioned_normalized]
        summaries = [s for s in summaries if s]
        if summaries:
            answer = "\n\n---\n\n".join(summaries)
            return ChatAnswerResponse(
                answer=answer,
                sources=[{"filename": fn, "file_summary": True} for fn in mentioned_normalized],
                relevant_files=mentioned_normalized
            )
    if not hits:
        candidates = mentioned_normalized or relevant_files or []
        if candidates:
            fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
            summaries = [fsum_map.get(fn, "") for fn in candidates]
            summaries = [s for s in summaries if s]
            if summaries:
                answer = "\n\n---\n\n".join(summaries)
                logger.info(f"[CHAT] Falling back to file-level summaries for: {candidates}")
                return ChatAnswerResponse(
                    answer=answer,
                    sources=[{"filename": fn, "file_summary": True} for fn in candidates],
                    relevant_files=candidates
                )
        return ChatAnswerResponse(
            answer="I don't know based on your uploaded materials. Try uploading more sources or rephrasing the question.",
            sources=[],
            relevant_files=relevant_files or mentioned_normalized
        )
    contexts = []
    sources_meta = []
    for h in hits:
        doc = h["doc"]
        score = h["score"]
        contexts.append(f"[{doc.get('topic_name','Topic')}] {trim_text(doc.get('content',''), 2000)}")
        sources_meta.append({
            "filename": doc.get("filename"),
            "topic_name": doc.get("topic_name"),
            "page_span": doc.get("page_span"),
            "score": float(score),
            "chunk_id": str(doc.get("_id", ""))
        })
    context_text = "\n\n---\n\n".join(contexts)

    file_summary_block = ""
    if relevant_files:
        fsum_map = {f["filename"]: f.get("summary", "") for f in files_list}
        lines = [f"[{fn}] {fsum_map.get(fn, '')}" for fn in relevant_files]
        file_summary_block = "\n".join(lines)

    system_prompt = (
        "You are a careful study assistant. Answer strictly using the given CONTEXT.\n"
        "If the answer isn't in the context, say 'I don't know based on the provided materials.'\n"
        "Write concise, clear explanations with citations like (source: actual_filename, topic).\n"
        "Use the exact filename as provided in the context, not placeholders.\n"
    )

    history_block = ""
    if recent_related or semantic_related:
        history_block = "RECENT_CHAT_CONTEXT:\n" + (recent_related or "") + ("\n\nHISTORICAL_SIMILARITY_CONTEXT:\n" + semantic_related if semantic_related else "")
    composed_context = ""
    if history_block:
        composed_context += history_block + "\n\n"
    if file_summary_block:
        composed_context += "FILE_SUMMARIES:\n" + file_summary_block + "\n\n"
    composed_context += "DOC_CONTEXT:\n" + context_text

    user_prompt = f"QUESTION:\n{question}\n\nCONTEXT:\n{composed_context}"
    selection = select_model(question=question, context=composed_context)
    logger.info(f"Model selection: {selection}")
    logger.info(f"[CHAT] Generating answer with {selection['provider']} {selection['model']}")
    try:
        answer = await generate_answer_with_model(
            selection=selection,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            gemini_rotator=gemini_rotator,
            nvidia_rotator=nvidia_rotator
        )
        logger.info(f"[CHAT] Answer generated successfully, length: {len(answer)}")
    except Exception as e:
        logger.error(f"LLM error: {e}")
        answer = "I had trouble contacting the language model provider just now. Please try again."
    try:
        from memo.history import get_history_manager
        history_manager = get_history_manager(memory)
        qa_sum = await history_manager.summarize_qa_with_nvidia(question, answer, nvidia_rotator)
        memory.add(user_id, qa_sum)
        if memory.is_enhanced_available():
            await memory.add_conversation_memory(
                user_id=user_id,
                question=question,
                answer=answer,
                project_id=project_id,
                context={
                    "relevant_files": relevant_files,
                    "sources_count": len(sources_meta),
                    "timestamp": time.time()
                }
            )
    except Exception as e:
        logger.warning(f"QA summarize/store failed: {e}")
    logger.info("LLM answer (trimmed): %s", trim_text(answer, 200).replace("\n", " "))
    return ChatAnswerResponse(answer=answer, sources=sources_meta, relevant_files=relevant_files)
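
Two things worth making concrete. The ranking boost in `_deduplicate_and_rank_hits` is multiplicative: a chunk whose content matches three question words and whose topic matches one gets `score * (1.0 + 3*0.1 + 1*0.2) = score * 1.5`. And the endpoint takes form fields, so a client call is a plain POST; a sketch assuming a local instance and httpx, with hypothetical IDs:

```python
import httpx

resp = httpx.post(
    "http://localhost:8000/chat",
    data={
        "user_id": "u-123",        # hypothetical
        "project_id": "p-456",     # hypothetical
        "question": "What does lecture1.pdf say about gradient descent?",
        "k": 6,
    },
    timeout=130.0,  # the server caps its own work at 120 s
)
body = resp.json()
print(body["answer"])
for src in body["sources"]:
    print(src["filename"], src["topic_name"], src["score"])
```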
routes/files.py ADDED
@@ -0,0 +1,215 @@
import os, json, uuid, time, asyncio
from typing import List, Dict, Optional
from datetime import datetime
from fastapi import UploadFile, File, Form, Request, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, HTMLResponse

from helpers.setup import app, rag, logger, embedder, captioner
from helpers.models import UploadResponse, FileSummaryResponse, MessageResponse
from helpers.pages import _extract_pages

from utils.service.summarizer import cheap_summarize
from utils.ingestion.chunker import build_cards_from_pages


@app.get("/", response_class=HTMLResponse)
def index():
    index_path = os.path.join("static", "index.html")
    if not os.path.exists(index_path):
        return HTMLResponse("<h1>StudyBuddy</h1><p>Static files not found.</p>")
    return FileResponse(index_path)


@app.post("/upload", response_model=UploadResponse)
async def upload_files(
    request: Request,
    background_tasks: BackgroundTasks,
    user_id: str = Form(...),
    project_id: str = Form(...),
    files: List[UploadFile] = File(...),
    replace_filenames: Optional[str] = Form(None),
    rename_map: Optional[str] = Form(None),
):
    """
    Ingest many files: PDF/DOCX.
    Steps:
      1) Extract text & images
      2) Caption images (BLIP base, CPU ok)
      3) Merge captions into page text
      4) Chunk into semantic cards (topic_name, summary, content + metadata)
      5) Embed with all-MiniLM-L6-v2
      6) Store in MongoDB with per-user and per-project metadata
      7) Create a file-level summary
    """
    job_id = str(uuid.uuid4())

    max_files = int(os.getenv("MAX_FILES_PER_UPLOAD", "15"))
    max_mb = int(os.getenv("MAX_FILE_MB", "50"))
    if len(files) > max_files:
        raise HTTPException(400, detail=f"Too many files. Max {max_files} allowed per upload.")

    replace_set = set()
    try:
        if replace_filenames:
            replace_set = set(json.loads(replace_filenames))
    except Exception:
        pass
    rename_dict: Dict[str, str] = {}
    try:
        if rename_map:
            rename_dict = json.loads(rename_map)
    except Exception:
        pass

    preloaded_files = []
    for uf in files:
        raw = await uf.read()
        if len(raw) > max_mb * 1024 * 1024:
            raise HTTPException(400, detail=f"{uf.filename} exceeds {max_mb} MB limit")
        eff_name = rename_dict.get(uf.filename, uf.filename)
        preloaded_files.append((eff_name, raw))

    app.state.jobs[job_id] = {
        "created_at": time.time(),
        "total": len(preloaded_files),
        "completed": 0,
        "status": "processing",
        "last_error": None,
    }

    async def _process_all():
        for idx, (fname, raw) in enumerate(preloaded_files, start=1):
            try:
                if fname in replace_set:
                    try:
                        rag.db["chunks"].delete_many({"user_id": user_id, "project_id": project_id, "filename": fname})
                        rag.db["files"].delete_many({"user_id": user_id, "project_id": project_id, "filename": fname})
                        logger.info(f"[{job_id}] Replaced prior data for {fname}")
                    except Exception as de:
                        logger.warning(f"[{job_id}] Replace delete failed for {fname}: {de}")
                logger.info(f"[{job_id}] ({idx}/{len(preloaded_files)}) Parsing {fname} ({len(raw)} bytes)")

                pages = _extract_pages(fname, raw)

                num_imgs = sum(len(p.get("images", [])) for p in pages)
                captions = []
                if num_imgs > 0:
                    for p in pages:
                        caps = []
                        for im in p.get("images", []):
                            try:
                                cap = captioner.caption_image(im)
                                caps.append(cap)
                            except Exception as e:
                                logger.warning(f"[{job_id}] Caption error in {fname}: {e}")
                        captions.append(caps)
                else:
                    captions = [[] for _ in pages]

                for p, caps in zip(pages, captions):
                    if caps:
                        p["text"] = (p.get("text", "") + "\n\n" + "\n".join([f"[Image] {c}" for c in caps])).strip()

                cards = await build_cards_from_pages(pages, filename=fname, user_id=user_id, project_id=project_id)
                logger.info(f"[{job_id}] Built {len(cards)} cards for {fname}")

                embeddings = embedder.embed([c["content"] for c in cards])
                for c, vec in zip(cards, embeddings):
                    c["embedding"] = vec

                rag.store_cards(cards)

                full_text = "\n\n".join(p.get("text", "") for p in pages)
                file_summary = await cheap_summarize(full_text, max_sentences=6)
                rag.upsert_file_summary(user_id=user_id, project_id=project_id, filename=fname, summary=file_summary)
                logger.info(f"[{job_id}] Completed {fname}")
                job = app.state.jobs.get(job_id)
                if job:
                    job["completed"] = idx
                    job["status"] = "processing" if idx < job.get("total", 0) else "completed"
            except Exception as e:
                logger.error(f"[{job_id}] Failed processing {fname}: {e}")
                job = app.state.jobs.get(job_id)
                if job:
                    job["last_error"] = str(e)
                    job["completed"] = idx
            finally:
                await asyncio.sleep(0)

        logger.info(f"[{job_id}] Ingestion complete for {len(preloaded_files)} files")
        job = app.state.jobs.get(job_id)
        if job:
            job["status"] = "completed"

    background_tasks.add_task(_process_all)
    return UploadResponse(job_id=job_id, status="processing", total_files=len(preloaded_files))


@app.get("/upload/status")
async def upload_status(job_id: str):
    job = app.state.jobs.get(job_id)
    if not job:
        raise HTTPException(404, detail="Job not found")
    percent = 0
    if job.get("total"):
        percent = int(round((job.get("completed", 0) / job.get("total", 1)) * 100))
    return {
        "job_id": job_id,
        "status": job.get("status"),
        "completed": job.get("completed"),
        "total": job.get("total"),
        "percent": percent,
        "last_error": job.get("last_error"),
        "created_at": job.get("created_at"),
    }


@app.get("/files")
async def list_project_files(user_id: str, project_id: str):
    """Return stored filenames and summaries for a project."""
    files = rag.list_files(user_id=user_id, project_id=project_id)
    filenames = [f.get("filename") for f in files if f.get("filename")]
    return {"files": files, "filenames": filenames}


@app.delete("/files", response_model=MessageResponse)
async def delete_file(user_id: str, project_id: str, filename: str):
    """Delete a file summary and associated chunks for a project."""
    try:
        rag.db["files"].delete_many({"user_id": user_id, "project_id": project_id, "filename": filename})
        rag.db["chunks"].delete_many({"user_id": user_id, "project_id": project_id, "filename": filename})
        logger.info(f"[FILES] Deleted file {filename} for user {user_id} project {project_id}")
        return MessageResponse(message="File deleted")
    except Exception as e:
        raise HTTPException(500, detail=f"Failed to delete file: {str(e)}")


@app.get("/cards")
def list_cards(user_id: str, project_id: str, filename: Optional[str] = None, limit: int = 50, skip: int = 0):
    """List cards for a project"""
    cards = rag.list_cards(user_id=user_id, project_id=project_id, filename=filename, limit=limit, skip=skip)
    # Ensure all cards are JSON serializable
    serializable_cards = []
    for card in cards:
        serializable_card = {}
        for key, value in card.items():
            if key == '_id':
                serializable_card[key] = str(value)  # Convert ObjectId to string
            elif isinstance(value, datetime):
                serializable_card[key] = value.isoformat()  # Convert datetime to ISO string
            else:
                serializable_card[key] = value
        serializable_cards.append(serializable_card)
    return {"cards": serializable_cards}


@app.get("/file-summary", response_model=FileSummaryResponse)
def get_file_summary(user_id: str, project_id: str, filename: str):
    doc = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=filename)
    if not doc:
        raise HTTPException(404, detail="No summary found for that file.")
    return FileSummaryResponse(filename=filename, summary=doc.get("summary", ""))
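
Since ingestion runs as a background task, clients are expected to poll `/upload/status` with the returned `job_id`. A sketch of that flow (local instance, httpx, hypothetical IDs):

```python
import time
import httpx

base = "http://localhost:8000"
with open("lecture1.pdf", "rb") as fh:
    r = httpx.post(
        f"{base}/upload",
        data={"user_id": "u-123", "project_id": "p-456"},
        files={"files": ("lecture1.pdf", fh, "application/pdf")},
    )
job_id = r.json()["job_id"]

# Poll until the background job reports completion.
while True:
    status = httpx.get(f"{base}/upload/status", params={"job_id": job_id}).json()
    print(f'{status["percent"]}% — {status["status"]}')
    if status["status"] == "completed":
        break
    time.sleep(2)
```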
routes/health.py ADDED
@@ -0,0 +1,71 @@
from helpers.setup import app, rag, logger
from helpers.models import HealthResponse


@app.get("/healthz", response_model=HealthResponse)
def health():
    return HealthResponse(ok=True)


@app.get("/test-db")
async def test_database():
    """Test database connection and basic operations"""
    from datetime import datetime, timezone
    try:
        if not rag:
            return {
                "status": "error",
                "message": "RAG store not initialized",
                "error_type": "RAGStoreNotInitialized"
            }
        rag.client.admin.command('ping')
        test_collection = rag.db["test_collection"]
        test_doc = {"test": True, "timestamp": datetime.now(timezone.utc)}
        result = test_collection.insert_one(test_doc)
        found = test_collection.find_one({"_id": result.inserted_id})
        test_collection.delete_one({"_id": result.inserted_id})
        return {
            "status": "success",
            "message": "Database connection and operations working correctly",
            "test_id": str(result.inserted_id),
            "found_doc": str(found["_id"]) if found else None
        }
    except Exception as e:
        logger.error(f"[TEST-DB] Database test failed: {str(e)}")
        return {
            "status": "error",
            "message": f"Database test failed: {str(e)}",
            "error_type": str(type(e))
        }


@app.get("/rag-status")
async def rag_status():
    """Check the status of the RAG store"""
    if not rag:
        return {
            "status": "error",
            "message": "RAG store not initialized",
            "rag_available": False
        }
    try:
        rag.client.admin.command('ping')
        return {
            "status": "success",
            "message": "RAG store is available and connected",
            "rag_available": True,
            "database": rag.db.name,
            "collections": {
                "chunks": rag.chunks.name,
                "files": rag.files.name
            }
        }
    except Exception as e:
        return {
            "status": "error",
            "message": f"RAG store connection failed: {str(e)}",
            "rag_available": False,
            "error": str(e)
        }
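
A quick smoke test of the three probes above (local instance assumed):

```python
import httpx

base = "http://localhost:8000"
print(httpx.get(f"{base}/healthz").json())     # liveness only: {"ok": true}
print(httpx.get(f"{base}/test-db").json())     # full insert/find/delete round trip
print(httpx.get(f"{base}/rag-status").json())  # connection state + collection names
```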
routes/projects.py ADDED
@@ -0,0 +1,130 @@
import uuid
from datetime import datetime, timezone
from fastapi import Form, HTTPException
from pymongo.errors import PyMongoError

from helpers.setup import app, rag, logger
from helpers.models import ProjectResponse, ProjectsListResponse, MessageResponse


# ────────────────────────────── Project Management ───────────────────────────
@app.post("/projects/create", response_model=ProjectResponse)
async def create_project(user_id: str = Form(...), name: str = Form(...), description: str = Form("")):
    """Create a new project for a user"""
    try:
        if not rag:
            raise HTTPException(500, detail="Database connection not available")

        if not name.strip():
            raise HTTPException(400, detail="Project name is required")

        if not user_id.strip():
            raise HTTPException(400, detail="User ID is required")

        project_id = str(uuid.uuid4())
        current_time = datetime.now(timezone.utc)

        project = {
            "project_id": project_id,
            "user_id": user_id,
            "name": name.strip(),
            "description": description.strip(),
            "created_at": current_time,
            "updated_at": current_time
        }

        logger.info(f"[PROJECT] Creating project {name} for user {user_id}")

        # Insert the project
        try:
            result = rag.db["projects"].insert_one(project)
            logger.info(f"[PROJECT] Created project {name} with ID {project_id}, MongoDB result: {result.inserted_id}")
        except PyMongoError as mongo_error:
            logger.error(f"[PROJECT] MongoDB error creating project: {str(mongo_error)}")
            raise HTTPException(500, detail=f"Database error: {str(mongo_error)}")
        except Exception as db_error:
            logger.error(f"[PROJECT] Database error creating project: {str(db_error)}")
            raise HTTPException(500, detail=f"Database error: {str(db_error)}")

        # Return a properly formatted response
        response = ProjectResponse(
            project_id=project_id,
            user_id=user_id,
            name=name.strip(),
            description=description.strip(),
            created_at=current_time.isoformat(),
            updated_at=current_time.isoformat()
        )

        logger.info(f"[PROJECT] Successfully created project {name} for user {user_id}")
        return response

    except HTTPException:
        # Re-raise HTTP exceptions
        raise
    except Exception as e:
        logger.error(f"[PROJECT] Error creating project ({type(e).__name__}): {str(e)}")
        raise HTTPException(500, detail=f"Failed to create project: {str(e)}")


@app.get("/projects", response_model=ProjectsListResponse)
async def list_projects(user_id: str):
    """List all projects for a user"""
    projects_cursor = rag.db["projects"].find(
        {"user_id": user_id}
    ).sort("updated_at", -1)

    projects = []
    for project in projects_cursor:
        projects.append(ProjectResponse(
            project_id=project["project_id"],
            user_id=project["user_id"],
            name=project["name"],
            description=project.get("description", ""),
            created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
            updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
        ))

    return ProjectsListResponse(projects=projects)


@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: str, user_id: str):
    """Get a specific project (with user ownership check)"""
    project = rag.db["projects"].find_one(
        {"project_id": project_id, "user_id": user_id}
    )
    if not project:
        raise HTTPException(404, detail="Project not found")

    return ProjectResponse(
        project_id=project["project_id"],
        user_id=project["user_id"],
        name=project["name"],
        description=project.get("description", ""),
        created_at=project["created_at"].isoformat() if isinstance(project["created_at"], datetime) else str(project["created_at"]),
        updated_at=project["updated_at"].isoformat() if isinstance(project["updated_at"], datetime) else str(project["updated_at"])
    )


@app.delete("/projects/{project_id}", response_model=MessageResponse)
async def delete_project(project_id: str, user_id: str):
    """Delete a project and all its associated data"""
    # Check ownership
    project = rag.db["projects"].find_one({"project_id": project_id, "user_id": user_id})
    if not project:
        raise HTTPException(404, detail="Project not found")

    # Delete project and all associated data
    rag.db["projects"].delete_one({"project_id": project_id})
    rag.db["chunks"].delete_many({"project_id": project_id})
    rag.db["files"].delete_many({"project_id": project_id})
    rag.db["chat_sessions"].delete_many({"project_id": project_id})

    logger.info(f"[PROJECT] Deleted project {project_id} for user {user_id}")
    return MessageResponse(message="Project deleted successfully")
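
Every other route is scoped by the (user_id, project_id) pair, so the expected client flow starts here. A sketch with a local instance and a hypothetical user:

```python
import httpx

base = "http://localhost:8000"
proj = httpx.post(f"{base}/projects/create", data={
    "user_id": "u-123",
    "name": "Thesis notes",
    "description": "Sources for chapter 2",
}).json()

# The returned project_id scopes uploads, chat, and reports for this user.
print(proj["project_id"])
print(httpx.get(f"{base}/projects", params={"user_id": "u-123"}).json())
```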
routes/reports.py ADDED
@@ -0,0 +1,139 @@
import os
from datetime import datetime
from typing import List, Dict
from fastapi import Form, HTTPException

from helpers.setup import app, rag, logger, embedder, gemini_rotator, nvidia_rotator
from helpers.models import ReportResponse
from utils.service.common import trim_text
from utils.api.router import generate_answer_with_model


@app.post("/report", response_model=ReportResponse)
async def generate_report(
    user_id: str = Form(...),
    project_id: str = Form(...),
    filename: str = Form(...),
    outline_words: int = Form(200),
    report_words: int = Form(1200),
    instructions: str = Form("")
):
    logger.info("[REPORT] User Q/report: %s", trim_text(instructions, 15).replace("\n", " "))
    files_list = rag.list_files(user_id=user_id, project_id=project_id)
    filenames_ci = {f.get("filename", "").lower(): f.get("filename") for f in files_list}
    eff_name = filenames_ci.get(filename.lower(), filename)
    doc_sum = rag.get_file_summary(user_id=user_id, project_id=project_id, filename=eff_name)
    if not doc_sum:
        raise HTTPException(404, detail="No summary found for that file.")

    query_text = f"Comprehensive report for {eff_name}"
    if instructions.strip():
        query_text = f"{instructions} {eff_name}"
    q_vec = embedder.embed([query_text])[0]
    hits = rag.vector_search(user_id=user_id, project_id=project_id, query_vector=q_vec, k=8, filenames=[eff_name], search_type="flat")
    if not hits:
        hits = []

    def _build_context(hit_list: List[Dict]):
        ctxs: List[str] = []
        metas: List[Dict] = []
        for h in hit_list:
            doc = h["doc"]
            chunk_id = str(doc.get("_id", ""))
            ctxs.append(f"[CHUNK_ID: {chunk_id}] [{doc.get('topic_name','Topic')}] {trim_text(doc.get('content',''), 2000)}")
            metas.append({
                "filename": doc.get("filename"),
                "topic_name": doc.get("topic_name"),
                "page_span": doc.get("page_span"),
                "score": float(h.get("score", 0.0)),
                "chunk_id": chunk_id
            })
        return ctxs, metas

    contexts, sources_meta = _build_context(hits)
    context_text = "\n\n---\n\n".join(contexts) if contexts else ""
    file_summary = doc_sum.get("summary", "")

    if instructions.strip():
        filter_sys = (
            "You are an expert content analyst. Given the user's specific instructions and the document content, "
            "identify which sections/chunks are MOST relevant to their request. "
            "Each chunk is prefixed with [CHUNK_ID: <id>] - use these exact IDs in your response. "
            "Return a JSON object with this structure: {\"relevant_chunks\": [\"<chunk_id_1>\", \"<chunk_id_2>\"], \"focus_areas\": [\"key topic 1\", \"key topic 2\"]}"
        )
        filter_user = f"USER_INSTRUCTIONS: {instructions}\n\nDOCUMENT_SUMMARY: {file_summary}\n\nAVAILABLE_CHUNKS:\n{context_text}\n\nIdentify only the chunks that directly address the user's specific request."
        try:
            selection_filter = {"provider": "gemini", "model": os.getenv("GEMINI_MED", "gemini-2.5-flash")}
            filter_response = await generate_answer_with_model(selection_filter, filter_sys, filter_user, gemini_rotator, nvidia_rotator)
            logger.info(f"[REPORT] Raw filter response: {filter_response}")
            import json as _json
            try:
                filter_data = _json.loads(filter_response)
                relevant_chunk_ids = filter_data.get("relevant_chunks", [])
                focus_areas = filter_data.get("focus_areas", [])
                logger.info(f"[REPORT] Content filtering identified {len(relevant_chunk_ids)} relevant chunks: {relevant_chunk_ids} and focus areas: {focus_areas}")
                if relevant_chunk_ids and hits:
                    filtered_hits = [h for h in hits if str(h["doc"].get("_id", "")) in relevant_chunk_ids]
                    if filtered_hits:
                        logger.info(f"[REPORT] Filtered context from {len(hits)} chunks to {len(filtered_hits)} relevant chunks")
                        hits = filtered_hits
                        # Rebuild the context so the outline/report prompts use only the filtered chunks
                        contexts, sources_meta = _build_context(hits)
                        context_text = "\n\n---\n\n".join(contexts) if contexts else ""
                    else:
                        logger.warning(f"[REPORT] No matching chunks found for IDs: {relevant_chunk_ids}")
                else:
                    logger.warning("[REPORT] No relevant chunk IDs returned or no hits available")
            except _json.JSONDecodeError as e:
                logger.warning(f"[REPORT] Could not parse filter response, using all chunks. JSON error: {e}. Response: {filter_response}")
        except Exception as e:
            logger.warning(f"[REPORT] Content filtering failed: {e}")

    sys_outline = (
        "You are an expert technical writer. Create a focused, hierarchical outline for a report based on the user's specific instructions and the MATERIALS. "
        "The outline should directly address what the user asked for. Output as Markdown bullet list only. Keep it within about {} words."
    ).format(max(100, outline_words))
    instruction_context = f"USER_REQUEST: {instructions}\n\n" if instructions.strip() else ""
    user_outline = f"{instruction_context}MATERIALS:\n\n[FILE_SUMMARY from {eff_name}]\n{file_summary}\n\n[DOC_CONTEXT]\n{context_text}"
    try:
        selection_outline = {"provider": "gemini", "model": os.getenv("GEMINI_MED", "gemini-2.5-flash")}
        outline_md = await generate_answer_with_model(selection_outline, sys_outline, user_outline, gemini_rotator, nvidia_rotator)
    except Exception as e:
        logger.warning(f"Report outline failed: {e}")
        outline_md = "# Report Outline\n\n- Introduction\n- Key Topics\n- Conclusion"

    instruction_focus = f"FOCUS ON: {instructions}\n\n" if instructions.strip() else ""
    sys_report = (
        "You are an expert report writer. Write a focused, comprehensive Markdown report that directly addresses the user's specific request. "
        "Using the OUTLINE and MATERIALS:\n"
        "- Structure the report to answer exactly what the user asked for\n"
        "- Use clear section headings\n"
        "- Keep content factual and grounded in the provided materials\n"
        f"- Include brief citations like (source: {eff_name}, topic) - use the actual filename provided\n"
        "- If the user asked for a specific section/topic, focus heavily on that\n"
        f"- Target length ~{max(600, report_words)} words\n"
        "- Ensure the report directly fulfills the user's request"
    )
    user_report = f"{instruction_focus}OUTLINE:\n{outline_md}\n\nMATERIALS:\n[FILE_SUMMARY from {eff_name}]\n{file_summary}\n\n[DOC_CONTEXT]\n{context_text}"
    try:
        selection_report = {"provider": "gemini", "model": os.getenv("GEMINI_PRO", "gemini-2.5-pro")}
        report_md = await generate_answer_with_model(selection_report, sys_report, user_report, gemini_rotator, nvidia_rotator)
    except Exception as e:
        logger.error(f"Report generation failed: {e}")
        report_md = outline_md + "\n\n" + file_summary
    return ReportResponse(filename=eff_name, report_markdown=report_md, sources=sources_meta)


@app.post("/report/pdf")
async def generate_report_pdf(
    user_id: str = Form(...),
    project_id: str = Form(...),
    report_content: str = Form(...)
):
    from utils.service.pdf import generate_report_pdf as generate_pdf
    from fastapi.responses import Response
    try:
        pdf_content = await generate_pdf(report_content, user_id, project_id)
        return Response(
            content=pdf_content,
            media_type="application/pdf",
            headers={"Content-Disposition": f"attachment; filename=report-{datetime.now().strftime('%Y-%m-%d')}.pdf"}
        )
    except HTTPException:
        raise
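
End to end, a client generates the Markdown report first and then feeds it back for PDF rendering. A sketch with a local instance and hypothetical IDs; generation chains two LLM calls, so the timeout is generous:

```python
import httpx

base = "http://localhost:8000"
report = httpx.post(f"{base}/report", data={
    "user_id": "u-123",
    "project_id": "p-456",
    "filename": "lecture1.pdf",
    "instructions": "Focus on the evaluation methodology",
}, timeout=300.0).json()

pdf = httpx.post(f"{base}/report/pdf", data={
    "user_id": "u-123",
    "project_id": "p-456",
    "report_content": report["report_markdown"],
}, timeout=120.0)
with open("report.pdf", "wb") as fh:
    fh.write(pdf.content)
```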