# sanatan_ai / server.py
# Header text captured from the Hugging Face file viewer:
# uploaded by vikramvasudevan ("Upload folder using huggingface_hub"),
# commit 0a716e1 (verified), 10.6 kB.
# server.py
import json
import random
import traceback
from typing import Optional
import uuid
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
import pycountry
from pydantic import BaseModel
from chat_utils import chat
from config import SanatanConfig
from db import SanatanDatabase
from metadata import MetadataWhereClause
from modules.audio.model import AudioRequest
from modules.audio.service import svc_get_audio_urls
from modules.quiz.answer_validator import validate_answer
from modules.quiz.models import Question
from modules.quiz.quiz_helper import generate_question
import logging
from modules.video.model import VideoRequest
from modules.video.service import svc_get_video_urls
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

router = APIRouter()

# session_id -> thread_id passed to chat() by handle_chat, so every message of
# one client session continues the same conversation thread. This lives in
# process memory only: nothing in this module removes entries, and the map is
# lost on restart and not shared between worker processes.
# For production, you may want Redis or a DB for persistence.
thread_map: dict[str, str] = {}
class Message(BaseModel):
    """Request body for the /greet and /chat endpoints."""

    # Preferred reply language; handle_chat falls back to "English" when empty.
    language: str
    # The user's message text.
    text: str
    # Optional session ID from client; the server mints a UUID4 when absent.
    session_id: Optional[str] = None
class QuizGeneratePayload(BaseModel):
    """Request body for /quiz/generate; unset fields are filled in at random."""

    # Language the question should be written in.
    language: Optional[str] = "English"
    # Collection name to draw the question from (random when omitted).
    scripture: Optional[str] = None
    # e.g. "beginner", "intermediate", "advanced" (random when omitted).
    complexity: Optional[str] = None
    # e.g. "mcq" or "open" (random when omitted).
    mode: Optional[str] = None
    # Optional session ID from client.
    session_id: Optional[str] = None
class QuizEvalPayload(BaseModel):
    """Request body for /quiz/eval: a question plus the user's answer to it."""

    # Language the evaluation should be written in.
    language: Optional[str] = "English"
    # The question being answered (presumably as returned by /quiz/generate).
    q: Question
    # The user's answer text.
    answer: str
    # Optional session ID from client.
    session_id: Optional[str] = None
# Display name of each supported language written in its own script, keyed by
# ISO 639 code (two-letter where one exists; "mai" and "sat" are three-letter).
LANG_NATIVE_NAMES: dict[str, str] = {
    "en": "English",
    "fr": "Français",
    "es": "Español",
    "hi": "हिन्दी",
    "bn": "বাংলা",
    "te": "తెలుగు",
    "mr": "मराठी",
    "ta": "தமிழ்",
    "ur": "اردو",
    "gu": "ગુજરાતી",
    "kn": "ಕನ್ನಡ",
    "ml": "മലയാളം",
    "pa": "ਪੰਜਾਬੀ",
    "as": "অসমীয়া",
    "mai": "मैथिली",
    "sd": "سنڌي",
    "sat": "ᱥᱟᱱᱛᱟᱲᱤ",
}
@router.get("/languages")
async def handle_fetch_languages():
    """Return the supported languages, sorted by English name.

    Each entry is ``{"code", "name", "native_name"}``:
    - ``code``: the ISO 639 key from LANG_NATIVE_NAMES,
    - ``name``: the English name pycountry reports for that code,
    - ``native_name``: the LANG_NATIVE_NAMES value, or ``name`` when empty.

    Codes that pycountry resolves neither as alpha-2 nor alpha-3 are skipped.
    """
    languages = []
    # LANG_NATIVE_NAMES is the single source of truth for supported languages,
    # so adding a language needs one edit instead of two lists kept in sync.
    for code, native_name in LANG_NATIVE_NAMES.items():
        # Two-letter codes resolve via alpha_2; "mai"/"sat" only have alpha_3.
        lang = pycountry.languages.get(alpha_2=code) or pycountry.languages.get(
            alpha_3=code
        )
        if lang is None:
            continue  # skip unknown codes
        english_name = lang.name
        languages.append(
            {
                "code": code,
                "name": english_name,
                "native_name": native_name or english_name,
            }
        )
    languages.sort(key=lambda entry: entry["name"])
    return languages
@router.post("/greet")
async def handle_greet(msg: Message):
    """Return a Markdown welcome message listing every configured scripture.

    Scriptures are listed by title, each with the number of units stored in
    its collection. The client's ``session_id`` is echoed back; a new UUID4 is
    minted when none was supplied.
    """
    db = SanatanDatabase()  # one handle for all count() calls below
    parts = [
        "Namaskaram 🙏 I am **bhashyam.ai** and I can help you explore the following scriptures:\n---\n"
    ]
    for scripture in sorted(SanatanConfig().scriptures, key=lambda doc: doc["title"]):
        num_units = db.count(collection_name=scripture["collection_name"])
        # Inner keys use single quotes: reusing the enclosing double quote
        # inside an f-string is a SyntaxError before Python 3.12 (PEP 701).
        parts.append(f"- {scripture['title']} : `{num_units}` {scripture['unit']}s\n")
    markdown = "".join(parts)
    session_id = msg.session_id or str(uuid.uuid4())
    return {"reply": markdown, "session_id": session_id}
@router.post("/chat")
async def handle_chat(msg: Message, request: Request):
    """Answer one chat message, keeping conversation state per session.

    The client's ``session_id`` (or a freshly minted UUID4) is mapped to a
    stable ``thread_id`` in ``thread_map`` so successive messages of one
    session reach ``chat()`` with the same thread.

    Returns ``{"reply", "session_id"}``. On any failure it returns HTTP 500
    with ``{"reply": "Error: ...", "session_id"}`` so the client keeps its
    session even when the call fails.
    """
    # Use existing session_id if provided, else generate new. Resolved before
    # the try so the error response can carry it too.
    session_id = msg.session_id or str(uuid.uuid4())
    try:
        logger.info("%s : user sent message : %s", session_id, msg.text)
        # Get or create a persistent thread_id for this session.
        if session_id not in thread_map:
            thread_map[session_id] = str(uuid.uuid4())
        thread_id = thread_map[session_id]
        # NOTE(review): chat() is called without await, so it appears to be
        # synchronous; while it runs, this async handler blocks the event loop
        # for every other request. Consider `await asyncio.to_thread(chat, ...)`
        # once chat() is confirmed to be thread-safe.
        reply_text = chat(
            debug_mode=False,
            message=msg.text,
            history=None,
            thread_id=thread_id,
            preferred_language=msg.language or "English",
        )
        # Return both reply and session_id to the client.
        return {"reply": reply_text, "session_id": session_id}
    except Exception as e:
        logger.exception("chat failed for session %s", session_id)
        return JSONResponse(
            status_code=500,
            content={"reply": f"Error: {e}", "session_id": session_id},
        )
@router.post("/quiz/generate")
async def handle_quiz_generate(payload: QuizGeneratePayload, request: Request):
    """Generate one quiz question; payload fields left unset are chosen at random.

    - ``scripture``: collection name; defaults to a random configured
      collection other than ``yt_metadata``.
    - ``complexity``: defaults to one of beginner / intermediate / advanced.
    - ``mode``: defaults to ``mcq`` or ``open``.
    - ``language``: defaults to English.

    Returns ``q.model_dump()`` of the generated question.
    """
    # Same draw order as before (collection, complexity, mode) so seeded runs
    # stay reproducible.
    collection = payload.scripture or random.choice(
        [
            s["collection_name"]
            for s in SanatanConfig().scriptures
            if s["collection_name"] != "yt_metadata"
        ]
    )
    complexity = payload.complexity or random.choice(
        ["beginner", "intermediate", "advanced"]
    )
    mode = payload.mode or random.choice(["mcq", "open"])
    q = generate_question(
        collection=collection,
        complexity=complexity,
        mode=mode,
        # NOTE(review): "lamguage" looks like a typo, but the keyword must match
        # what generate_question() declares (not visible here); rename both
        # together.
        preferred_lamguage=payload.language or "English",
    )
    logger.info("generated quiz question: %s", q.model_dump_json(indent=1))
    return q.model_dump()
@router.post("/quiz/eval")
async def handle_quiz_eval(payload: QuizEvalPayload, request: Request):
    """Check the user's answer to a quiz question and return the verdict.

    The object returned by validate_answer() is handed back as-is for FastAPI
    to serialize; the reply language defaults to English.
    """
    result = validate_answer(
        payload.q, payload.answer, preferred_language=payload.language or "English"
    )
    logger.info("quiz answer evaluation: %s", result.model_dump_json(indent=1))
    return result
@router.get("/scriptures")
async def handle_get_scriptures():
    """Map each scripture's collection name to its display title.

    The ``yt_metadata`` collection is not a scripture and is left out.
    """
    return {
        entry["collection_name"]: entry["title"]
        for entry in SanatanConfig().scriptures
        if entry["collection_name"] != "yt_metadata"
    }
class ScriptureRequest(BaseModel):
    """Request body for POST /scripture: which scripture and which unit to fetch."""

    # Matched against the "name" of a SanatanConfig scripture entry.
    scripture_name: str
    # Global index of the unit (page or verse) within the scripture's collection.
    unit_index: int
@router.post("/scripture")
async def get_scripture(req: ScriptureRequest):
    """
    Return a scripture unit (page or verse, based on config),
    including all metadata fields separately.
    Used for page view to fetch by global index.

    The canonicalized document gets an extra ``total`` key (units in the
    collection) so the client can paginate. Returns ``{"error": ...}`` when
    the scripture name is unknown or no document is available for the index.
    """
    logger.info("get_scripture: received request to fetch scripture: %s", req)
    sanatan_config = SanatanConfig()

    # Find the config entry for the requested scripture.
    config = next(
        (s for s in sanatan_config.scriptures if s["name"] == req.scripture_name),
        None,
    )
    if not config:
        return {"error": f"Scripture '{req.scripture_name}' not found"}

    db = SanatanDatabase()
    raw_doc = db.fetch_document_by_index(
        collection_name=config["collection_name"],
        index=req.unit_index,
    )
    # An empty result, a plain string, or a dict with an "error" key all
    # mean no usable document came back.
    if not raw_doc or isinstance(raw_doc, str) or "error" in raw_doc:
        return {"error": f"No data available for unit {req.unit_index}"}

    canonical_doc = sanatan_config.canonicalize_document(
        scripture_name=req.scripture_name,
        document_text=raw_doc.get("document", ""),
        metadata_doc=raw_doc,
    )
    # Total number of units in the collection, so the client can paginate.
    canonical_doc["total"] = db.count(collection_name=config["collection_name"])
    logger.debug("canonical_doc = %s", canonical_doc)
    return canonical_doc
def _resolve_metadata_fields(scripture):
    """Return copies of ``scripture["metadata_fields"]`` with callable ``lov`` values evaluated.

    Each field dict is shallow-copied so the shared config is never mutated.
    A ``lov`` callable that raises is logged and replaced by an empty list, so
    one failing provider does not break the whole /scripture_configs response.
    """
    resolved = []
    for field in scripture.get("metadata_fields", []):
        field_copy = dict(field)  # shallow copy: nested values stay shared
        lov = field_copy.get("lov")
        if callable(lov):  # evaluate the function
            try:
                field_copy["lov"] = lov()
            except Exception:
                logger.warning(
                    "lov() failed for field %r of scripture %r; using []",
                    field_copy.get("name"),
                    scripture.get("name"),
                    exc_info=True,
                )
                field_copy["lov"] = []
        resolved.append(field_copy)
    return resolved


@router.get("/scripture_configs")
async def get_scripture_configs():
    """Describe every configured scripture for the client, sorted by title.

    Each entry carries naming and unit information, the stored unit count
    (``total``), ``enabled`` (whether the config has a ``field_mapping``),
    source/credits, and the metadata field definitions with any callable
    ``lov`` already evaluated.
    """
    db = SanatanDatabase()  # one handle for all count() calls below
    scriptures = []
    for s in SanatanConfig().scriptures:
        num_units = db.count(collection_name=s["collection_name"])
        scriptures.append(
            {
                "name": s["name"],  # e.g. "bhagavad_gita"
                "title": s["title"],  # e.g. "Bhagavad Gita"
                "unit": s["unit"],  # e.g. "verse" or "page"
                "unit_field": s.get("unit_field", s.get("unit")),
                "total": num_units,
                "enabled": "field_mapping" in s,
                "source": s.get("source", ""),
                "credits": s.get("credits", f"{s.get('source','')}"),
                "metadata_fields": _resolve_metadata_fields(s),
            }
        )
    return {"scriptures": sorted(scriptures, key=lambda s: s["title"])}
@router.post("/scripture/{scripture_name}/search")
async def search_scripture(
    scripture_name: str,
    filter_obj: Optional[MetadataWhereClause] = None,
):
    """
    Search a scripture collection with optional metadata filters.

    - `scripture_name`: the scripture's configured `name` (resolved to its
      collection via SanatanConfig)
    - `filter_obj`: MetadataWhereClause (filters, groups, operator)

    Matches come from `SanatanDatabase.fetch_first_match` and are returned as
    `{"results": [...]}`, each canonicalized via SanatanConfig. Returns
    `{"error": ...}` for an unknown scripture or any failure while searching.
    """
    try:
        sanatan_config = SanatanConfig()
        config = next(
            (s for s in sanatan_config.scriptures if s["name"] == scripture_name),
            None,
        )
        if config is None:
            # Same message as POST /scripture, instead of a TypeError from
            # indexing None.
            return {"error": f"Scripture '{scripture_name}' not found"}

        results = SanatanDatabase().fetch_first_match(
            collection_name=config["collection_name"],
            metadata_where_clause=filter_obj,
        )
        logger.debug("results = %s", results)

        # Flatten + canonicalize: ids/metadatas/documents are parallel lists,
        # one entry per match.
        documents = results.get("documents")
        formatted_results = []
        for i, (doc_id, metadata_doc) in enumerate(
            zip(results["ids"], results["metadatas"])
        ):
            metadata_doc["id"] = doc_id
            document_text = documents[i] if documents else None
            canonical_doc = sanatan_config.canonicalize_document(
                scripture_name, document_text, metadata_doc
            )
            formatted_results.append(canonical_doc)
        logger.debug("formatted_results = %s", formatted_results)
        return {"results": formatted_results}
    except Exception as e:
        logger.exception("Error while searching %s", e)
        return {"error": str(e)}
@router.post("/audio")
async def generate_audio_urls(req: AudioRequest):
    """Resolve and return the audio URLs for ``req`` via svc_get_audio_urls."""
    logger.info("generate_audio_urls: %s", req)
    return await svc_get_audio_urls(req)
@router.post("/video")
async def generate_video_urls(req: VideoRequest):
    """Resolve and return the video URLs for ``req`` via svc_get_video_urls."""
    # Own name on purpose: reusing ``generate_audio_urls`` here rebinds that
    # module-level name over the /audio handler (the routes themselves are
    # registered by the decorators either way).
    logger.info("generate_video_urls: %s", req)
    video_urls = await svc_get_video_urls(req)
    return video_urls