sanatan_ai / server.py
vikramvasudevan's picture
Upload folder using huggingface_hub
ceb171d verified
raw
history blame
20.2 kB
# server.py
import random
import traceback
from typing import Optional
import uuid
from fastapi import APIRouter, HTTPException, Request, Query
from fastapi.responses import JSONResponse
import pycountry
from pydantic import BaseModel
from chat_utils import chat
from config import SanatanConfig
from db import SanatanDatabase
from metadata import MetadataWhereClause
from modules.audio.model import AudioRequest, AudioType
from modules.audio.service import svc_get_audio_urls, svc_get_indices_with_audio
from modules.config.categories import get_scripture_categories
from modules.dropbox.discources import get_discourse_by_id, get_discourse_summaries
from modules.languages.get_v2 import handle_fetch_languages_v2
from modules.quiz.answer_validator import validate_answer
from modules.quiz.models import Question
from modules.quiz.quiz_helper import generate_question
import logging
from modules.video.model import VideoRequest
from modules.video.service import svc_get_video_urls
# Install the default root logging handler and give this module its own
# logger, emitting records at INFO level and above.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Router on which every endpoint in this module registers itself.
router = APIRouter()
# In-memory mapping from session_id -> thread_id.
# It lives only in this process, so entries are lost on restart.
# For production, you may want Redis or a DB for persistence.
thread_map: dict[str, str] = {}
class Message(BaseModel):
    """Request body accepted by the `/greet` and `/chat` endpoints."""

    # Language the reply should be written in.
    language: str
    # Text of the user's message.
    text: str
    # Optional session ID from the client; handlers create one when it is missing.
    session_id: Optional[str] = None
class QuizGeneratePayload(BaseModel):
    """Request body for `/quiz/generate`; every field may be omitted."""

    # Language the question should be written in.
    language: str | None = "English"
    # Collection to draw the question from.
    scripture: str | None = None
    # Difficulty level of the question.
    complexity: str | None = None
    # Question format.
    mode: str | None = None
    # Optional session ID from the client.
    session_id: str | None = None
class QuizEvalPayload(BaseModel):
    """Request body for `/quiz/eval`: a question plus the answer to check."""

    # Language the evaluation should be written in.
    language: str | None = "English"
    # The question that was asked.
    q: Question
    # The user's answer to that question.
    answer: str
    # Optional session ID from the client.
    session_id: str | None = None
# Language code -> the language's name written in its own script.
# Codes missing from this table fall back to the English name at lookup time.
LANG_NATIVE_NAMES = {
    "en": "English",
    "fr": "Français",
    "es": "Español",
    "hi": "हिन्दी",  # Hindi
    "bn": "বাংলা",  # Bengali
    "te": "తెలుగు",  # Telugu
    "mr": "मराठी",  # Marathi
    "ta": "தமிழ்",  # Tamil
    "ur": "اردو",  # Urdu
    "gu": "ગુજરાતી",  # Gujarati
    "kn": "ಕನ್ನಡ",  # Kannada
    "ml": "മലയാളം",  # Malayalam
    "pa": "ਪੰਜਾਬੀ",  # Punjabi
    "as": "অসমীয়া",  # Assamese
    "mai": "मैथिली",  # Maithili
    "sd": "سنڌي",  # Sindhi
    "sat": "ᱥᱟᱱᱛᱟᱲᱤ",  # Santali
}
@router.get("/languages")
async def handle_fetch_languages():
    """List the supported languages, sorted by English name.

    Each entry is a dict with `code`, `name` (the name reported by
    pycountry) and `native_name` (from LANG_NATIVE_NAMES, falling back to
    the English name). Codes that pycountry cannot resolve are left out.
    """
    codes = (
        "en",
        "fr",
        "es",
        "hi",
        "bn",
        "te",
        "mr",
        "ta",
        "ur",
        "gu",
        "kn",
        "ml",
        "pa",
        "as",
        "mai",
        "sd",
        "sat",
    )
    entries = []
    for lang_code in codes:
        # Try the code as a two-letter code first, then as a three-letter one.
        record = pycountry.languages.get(alpha_2=lang_code)
        if not record:
            record = pycountry.languages.get(alpha_3=lang_code)
        if record is None:
            continue  # unknown code: skip it
        english = record.name
        entries.append(
            {
                "code": lang_code,
                "name": english,
                "native_name": LANG_NATIVE_NAMES.get(lang_code, english),
            }
        )
    return sorted(entries, key=lambda entry: entry["name"])
@router.get("/languages_v2")
async def fn_handle_fetch_languages_v2():
    """Return the result of the v2 language handler unchanged."""
    return await handle_fetch_languages_v2()
@router.post("/greet")
async def handle_greet(msg: Message):
    """Build the greeting that introduces the available scriptures.

    The reply is Markdown: a fixed introduction followed by one bullet per
    configured scripture (sorted by title) giving its unit count.

    Args:
        msg: Incoming message; only `session_id` is read.

    Returns:
        A dict with `reply` (the Markdown text) and `session_id` (the
        client's ID, or a newly generated UUID when none was sent).
    """
    scriptures = sorted(SanatanConfig().scriptures, key=lambda doc: doc["title"])
    # One database object serves every count instead of one per scripture.
    db = SanatanDatabase()
    parts = [
        "Namaskaram 🙏 I am **bhashyam.ai** and I can help you explore the following scriptures:\n---\n"
    ]
    for scripture in scriptures:
        num_units = db.count(collection_name=scripture["collection_name"])
        # Single quotes inside the f-string keep this line valid on Python
        # versions older than 3.12, which reject reusing the outer quote.
        parts.append(
            f"- {scripture['title']} : `{num_units}` {scripture['unit']}s\n"
        )
    session_id = msg.session_id or str(uuid.uuid4())
    return {"reply": "".join(parts), "session_id": session_id}
@router.post("/chat")
async def handle_chat(msg: Message, request: Request):
    """Forward a user message to the chat backend and return its reply.

    A session ID is taken from the request or generated, and each session is
    bound to one thread ID kept in the module-level `thread_map`.

    Returns:
        `{"reply": ..., "session_id": ...}` on success, or a 500 JSON
        response whose `reply` carries the error text on failure.
    """
    try:
        # Reuse the client's session ID when there is one.
        session_id = msg.session_id or str(uuid.uuid4())
        print(session_id, ": user sent message : ", msg.text)

        # Look up this session's thread, creating it on first contact.
        if session_id in thread_map:
            thread_id = thread_map[session_id]
        else:
            thread_id = thread_map[session_id] = str(uuid.uuid4())

        answer = chat(
            debug_mode=False,
            message=msg.text,
            history=None,
            thread_id=thread_id,
            preferred_language=msg.language or "English",
        )
        # The session ID goes back so the client can continue the conversation.
        return {"reply": answer, "session_id": session_id}
    except Exception as exc:
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"reply": f"Error: {exc}"})
@router.post("/quiz/generate")
async def handle_quiz_generate(payload: QuizGeneratePayload, request: Request):
    """Generate one quiz question, picking random values for omitted fields.

    Args:
        payload: Optional scripture, complexity, mode and language.
        request: Incoming request (unused).

    Returns:
        The generated question as a plain dict (`model_dump()`).
    """
    collection = payload.scripture
    if not collection:
        # Read the scripture list from a SanatanConfig instance, as every
        # other endpoint in this module does, rather than from the class.
        collection = random.choice(
            [
                s["collection_name"]
                for s in SanatanConfig().scriptures
                if s["collection_name"] != "yt_metadata"
            ]
        )
    q = generate_question(
        collection=collection,
        complexity=payload.complexity
        or random.choice(["beginner", "intermediate", "advanced"]),
        mode=payload.mode or random.choice(["mcq", "open"]),
        # NOTE(review): the keyword is spelled `lamguage`; it is kept because it
        # must match generate_question's parameter name - confirm and fix both.
        preferred_lamguage=payload.language or "English",
    )
    print(q.model_dump_json(indent=1))
    return q.model_dump()
@router.post("/quiz/eval")
async def handle_quiz_eval(payload: QuizEvalPayload, request: Request):
    """Check an answer against its question and return the validator's result."""
    language = payload.language or "English"
    verdict = validate_answer(payload.q, payload.answer, preferred_language=language)
    print(verdict.model_dump_json(indent=1))
    return verdict
@router.get("/scriptures")
async def handle_get_scriptures():
    """Map each collection name to its title, leaving out `yt_metadata`."""
    return {
        entry["collection_name"]: entry["title"]
        for entry in SanatanConfig().scriptures
        if entry["collection_name"] != "yt_metadata"
    }
class ScriptureRequest(BaseModel):
    """Request body for `/scripture`: which scripture and which unit to fetch."""

    # Matched against the `name` of the configured scriptures.
    scripture_name: str
    # Index of the unit to fetch from that scripture's collection.
    unit_index: int
@router.post("/scripture")
async def get_scripture(req: ScriptureRequest):
    """
    Fetch one scripture unit by index and return it in canonical form.

    The canonical document also carries `total`, the number of units in the
    collection. An `{"error": ...}` dict is returned when the scripture is
    not configured or no usable document exists at that index.
    """
    logger.info("get_scripture: received request to fetch scripture: %s", req)

    # Locate the config entry whose name matches the request.
    scripture_cfg = None
    for candidate in SanatanConfig().scriptures:
        if candidate["name"] == req.scripture_name:
            scripture_cfg = candidate
            break
    if not scripture_cfg:
        return {"error": f"Scripture '{req.scripture_name}' not found"}

    collection = scripture_cfg["collection_name"]
    fetched = SanatanDatabase().fetch_document_by_index(
        collection_name=collection,
        index=req.unit_index,
    )
    # An empty value, a bare string, or a payload with an "error" key all
    # count as "nothing to show".
    unusable = (not fetched) or isinstance(fetched, str) or ("error" in fetched)
    if unusable:
        return {"error": f"No data available for unit {req.unit_index}"}

    document = SanatanConfig().canonicalize_document(
        scripture_name=req.scripture_name,
        document_text=fetched.get("document", ""),
        metadata_doc=fetched,
    )
    # The unit total is included for the client's pagination.
    document["total"] = SanatanDatabase().count(collection)
    return document
@router.get("/scripture_configs")
async def get_scripture_configs():
    """Describe every configured scripture for the client.

    Returns:
        `{"scriptures": [...]}` sorted by title. Each entry carries the
        scripture's descriptive fields, its unit count, its metadata fields
        (with callable `lov` entries evaluated) and its field mapping with
        callables removed.
    """
    config = SanatanConfig()
    # One database object serves every count instead of one per scripture.
    db = SanatanDatabase()
    scriptures = []
    for s in config.scriptures:
        num_units = db.count(collection_name=s["collection_name"])

        # Shallow-copy each metadata field so the replaced `lov` value does
        # not overwrite the callable stored in the original config.
        metadata_fields = []
        for field in s.get("metadata_fields", []):
            field_copy = dict(field)
            lov = field_copy.get("lov")
            if callable(lov):  # evaluate the function
                try:
                    field_copy["lov"] = lov()
                except Exception:
                    # Best effort: one failing list of values must not break
                    # the whole response, but the failure is worth recording.
                    logger.warning(
                        "get_scripture_configs: could not evaluate lov for field %s of scripture %s",
                        field_copy.get("name"),
                        s.get("name"),
                        exc_info=True,
                    )
                    field_copy["lov"] = []
            metadata_fields.append(field_copy)

        scriptures.append(
            {
                "name": s["name"],  # e.g. "bhagavad_gita"
                "title": s["title"],  # e.g. "Bhagavad Gita"
                "banner_url": s.get("banner_url", None),
                "category": s["category"],  # e.g. "Philosophy"
                "unit": s["unit"],  # e.g. "verse" or "page"
                "unit_field": s.get("unit_field", s.get("unit")),
                "total": num_units,
                "enabled": "field_mapping" in s,
                "source": s.get("source", ""),
                "credits": s.get(
                    "credits", {"art": [], "data": [], "audio": [], "video": []}
                ),
                "metadata_fields": metadata_fields,
                "field_mapping": config.remove_callables(s.get("field_mapping", {})),
            }
        )
    return {"scriptures": sorted(scriptures, key=lambda s: s["title"])}
class ScriptureFirstSearchRequst(BaseModel):
    """Request body for the first-match scripture search."""

    # Optional metadata filter handed to the database query.
    filter_obj: MetadataWhereClause | None = None
    # Optional audio filter applied to the matches.
    has_audio: AudioType | None = None
@router.post("/scripture/{scripture_name}/search")
async def search_scripture_find_first_match(
    scripture_name: str,
    req: ScriptureFirstSearchRequst,
):
    """
    Search a scripture and return at most one result: the match with the
    lowest `_global_index` that also passes the optional audio filter.

    The response is `{"results": [...]}` on success and `{"error": ...}`
    when the scripture is unknown or the search raises.
    """
    where_clause = req.filter_obj
    audio_filter = req.has_audio
    try:
        logger.info(
            "search_scripture_find_first_match: searching for %s with filters=%s | has_audio=%s",
            scripture_name,
            where_clause,
            audio_filter,
        )
        db = SanatanDatabase()
        scripture_cfg = next(
            (s for s in SanatanConfig().scriptures if s["name"] == scripture_name),
            None,
        )
        if not scripture_cfg:
            return {"error": f"Scripture '{scripture_name}' not found"}

        # Step 1: fetch candidates. With an audio filter every match is
        # needed, because the filter is applied afterwards; without one the
        # first match is enough.
        if audio_filter:
            results = db.fetch_all_matches(
                collection_name=scripture_cfg["collection_name"],
                metadata_where_clause=where_clause,
                page=None,
                page_size=None,
            )
        else:
            first = db.fetch_first_match(
                collection_name=scripture_cfg["collection_name"],
                metadata_where_clause=where_clause,
            )
            results = {
                "ids": list(first["ids"]),
                "documents": list(first["documents"]),
                "metadatas": list(first["metadatas"]),
                "total_matches": 1,
            }

        candidates = []
        for position, metadata_doc in enumerate(results["metadatas"]):
            metadata_doc["id"] = results["ids"][position]
            text = results["documents"][position] if results.get("documents") else None
            candidates.append(
                SanatanConfig().canonicalize_document(
                    scripture_name, text, metadata_doc
                )
            )

        # Step 2: apply the audio filter.
        if audio_filter and candidates:
            excluding = audio_filter == AudioType.none
            if excluding or audio_filter == AudioType.any:
                # "none" and "any" both look at every concrete audio type.
                types_to_query = [
                    AudioType.recitation,
                    AudioType.virutham,
                    AudioType.upanyasam,
                    AudioType.santhai,
                ]
            else:
                types_to_query = [audio_filter]

            indices_with_audio = set()
            for atype in types_to_query:
                indices_with_audio.update(
                    await svc_get_indices_with_audio(scripture_name, atype)
                )

            if excluding:
                candidates = [
                    doc
                    for doc in candidates
                    if doc["_global_index"] not in indices_with_audio
                ]
            else:
                candidates = [
                    doc
                    for doc in candidates
                    if doc["_global_index"] in indices_with_audio
                ]

        # Step 3: order by global index, then keep only the first survivor.
        candidates.sort(key=lambda doc: doc["_global_index"])
        return {"results": candidates[:1]}
    except Exception as exc:
        logger.error("Error while searching %s", exc, exc_info=True)
        return {"error": str(exc)}
class ScriptureMultiSearchRequest(BaseModel):
    """Request body for the paginated scripture search."""

    # Optional metadata filter handed to the database query.
    filter_obj: MetadataWhereClause | None = None
    # 1-based page number.
    page: int = 1
    # Number of results per page.
    page_size: int = 20
    # Optional audio filter applied to the matches.
    has_audio: AudioType | None = None
@router.post("/scripture/{scripture_name}/search/all")
async def search_scripture_find_all_matches(
    scripture_name: str, req: ScriptureMultiSearchRequest
):
    """
    Search a scripture collection and return one page of the matching results.

    - `scripture_name`: matched against the `name` of each configured scripture
    - `filter_obj`: optional MetadataWhereClause passed to the database query
    - `page`: 1-based page number
    - `page_size`: Number of results per page
    - `has_audio`: optional AudioType; the members handled here are
      any, none, recitation, virutham, upanyasam and santhai

    Returns `results`, `total_matches` (counted after the audio filter),
    `page` and `page_size`, or `{"error": ...}` when the scripture is
    unknown or the search raises.
    """
    filter_obj = req.filter_obj
    page = req.page
    page_size = req.page_size
    has_audio = req.has_audio
    try:
        logger.info(
            "search_scripture_find_all_matches: searching for %s with filters %s | page=%s, page_size=%s, has_audio=%s",
            scripture_name,
            filter_obj,
            page,
            page_size,
            has_audio,
        )
        db = SanatanDatabase()
        # One config object is reused for the lookup and for every document.
        sanatan_config = SanatanConfig()
        config = next(
            (s for s in sanatan_config.scriptures if s["name"] == scripture_name),
            None,
        )
        if not config:
            return {"error": f"Scripture '{scripture_name}' not found"}

        # 1. Fetch every match unpaginated: the audio filter runs after the
        #    query, so paging has to happen on the filtered list.
        results = db.fetch_all_matches(
            collection_name=config["collection_name"],
            metadata_where_clause=filter_obj,
            page=None,
            page_size=None,
        )
        documents = results.get("documents")
        formatted_results = []
        for i, metadata_doc in enumerate(results["metadatas"]):
            metadata_doc["id"] = results["ids"][i]
            document_text = documents[i] if documents else None
            formatted_results.append(
                sanatan_config.canonicalize_document(
                    scripture_name, document_text, metadata_doc
                )
            )

        # 2. Apply the has_audio filter.
        if has_audio:
            exclude_audio = has_audio == AudioType.none
            if exclude_audio or has_audio == AudioType.any:
                # "none" and "any" both need the indices of every audio type.
                audio_types = [
                    AudioType.recitation,
                    AudioType.virutham,
                    AudioType.upanyasam,
                    AudioType.santhai,
                ]
            else:
                audio_types = [has_audio]

            audio_indices = set()
            for atype in audio_types:
                audio_indices.update(
                    await svc_get_indices_with_audio(scripture_name, atype)
                )

            if exclude_audio:
                # Keep only the units that have no audio of any type.
                formatted_results = [
                    r
                    for r in formatted_results
                    if r["_global_index"] not in audio_indices
                ]
            else:
                # Keep only the units that have the requested audio.
                formatted_results = [
                    r for r in formatted_results if r["_global_index"] in audio_indices
                ]

        # 3. Paginate the filtered results.
        # NOTE(review): page and page_size are not validated; a page below 1
        # produces negative slice bounds - confirm whether clients can send it.
        total_matches = len(formatted_results)
        start_idx = (page - 1) * page_size
        end_idx = start_idx + page_size
        return {
            "results": formatted_results[start_idx:end_idx],
            "total_matches": total_matches,
            "page": page,
            "page_size": page_size,
        }
    except Exception as e:
        logger.error("Error while searching %s", e, exc_info=True)
        return {"error": str(e)}
@router.post("/audio")
async def generate_audio_urls(req: AudioRequest):
    """Log the request and return the audio URLs produced by the audio service."""
    logger.info("generate_audio_urls: %s", req)
    return await svc_get_audio_urls(req)
# NOTE(review): this handler shares its function name with the `/audio`
# handler, so the module-level name ends up bound to this one. The name is
# kept for interface stability; consider renaming it to generate_video_urls.
@router.post("/video")
async def generate_audio_urls(req: VideoRequest):
    """Log the request and return the video URLs produced by the video service."""
    # The log prefix names the video endpoint so its entries can be told
    # apart from those of the `/audio` handler.
    logger.info("generate_video_urls: %s", req)
    video_urls = await svc_get_video_urls(req)
    return video_urls
@router.get("/scripture_categories")
def route_get_scripture_categories():
    """Return the scripture categories supplied by `get_scripture_categories`."""
    categories = get_scripture_categories()
    return categories
@router.get("/donation/products")
def route_get_donation_product_ids(include_tests: bool = False):
    """List the donation product IDs as `{"id": ...}` dicts.

    When `include_tests` is true, the `android.test.*` IDs are appended
    after the donation units.
    """
    product_ids = [
        "donation_unit_0100",
        "donation_unit_0500",
        "donation_unit_1000",
        "donation_unit_2500",
        "donation_unit_5000",
    ]
    if include_tests:
        product_ids.extend(
            [
                "android.test.purchased",
                "android.test.canceled",
                "android.test.refunded",
                "android.test.item_unavailable",
            ]
        )
    return [{"id": product_id} for product_id in product_ids]
@router.get("/discourse/list")
async def get_all_discourses(
    page: int = Query(1, ge=1, description="Page number (1-indexed)"),
    per_page: int = Query(10, ge=1, le=100, description="Number of items per page"),
):
    """
    Return one page of discourse topic summaries.

    The payload comes straight from `get_discourse_summaries`; each topic
    is documented as carrying:
    - id
    - topic_name
    - thumbnail_url
    """
    return await get_discourse_summaries(page=page, per_page=per_page)
@router.get("/discourse/find/{topic_id}")
async def get_discourse_detail(topic_id: int):
    """
    Look up a discourse topic by its ID and return its full details.

    Raises a 404 HTTPException when the lookup yields nothing.
    """
    found = await get_discourse_by_id(topic_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Discourse topic not found")