"""Flask front-end for a locally trained causal-LM chatbot.

Serves the chat UI, generates replies with the newest model found under
``chatbot_model/``, and persists per-session chat history in SQLite.
"""
import contextlib
import datetime
import os
import secrets
import sqlite3

import torch
from dotenv import load_dotenv
from flask import Flask, jsonify, render_template, request, session
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer

# ✅ Load environment variables (e.g. SECRET_KEY) from a .env file, if present
load_dotenv()

# ✅ Configure Flask app
app = Flask(__name__, static_folder='static', template_folder='templates')
# Without SECRET_KEY in the environment a random key is generated per process,
# so existing session cookies stop being valid after every restart.
app.secret_key = os.getenv('SECRET_KEY', os.urandom(24))
# NOTE(review): CORS is open to every origin and the /sessions and /history
# endpoints have no authentication, so any web page the user visits could read
# or delete chat history. The UI is served by this same app (render_template),
# so restricting origins is worth considering.
CORS(app)


# ✅ Automatically find the latest trained model folder
def get_latest_model_dir(base_dir="chatbot_model"):
    """Return the absolute path of the newest ``trained_model_*`` folder in *base_dir*.

    "Newest" means the lexicographically greatest folder name. That equals
    chronological order only if the name suffix is a zero-padded timestamp
    (assumed — confirm against the training script).

    Raises:
        FileNotFoundError: if *base_dir* does not exist or holds no
            ``trained_model_*`` sub-directory.
    """
    try:
        candidates = [
            name
            for name in os.listdir(base_dir)
            if name.startswith("trained_model_")
            and os.path.isdir(os.path.join(base_dir, name))
        ]
        if not candidates:
            raise FileNotFoundError(
                f"❌ No trained model found in {base_dir}/ directory."
            )
        return os.path.abspath(os.path.join(base_dir, max(candidates)))
    except Exception as e:
        print(f"❌ Error finding model directory: {e}")
        raise


# ✅ Load model from latest folder (runs at import time; aborts startup on failure)
try:
    model_path = get_latest_model_dir()
    print(f"📦 Loading model from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    raise

# ✅ Setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

# ✅ SQLite DB setup
DB_FILE = "chat_history.db"


@contextlib.contextmanager
def _db_connection():
    """Yield a connection to DB_FILE, committing on success and always closing it.

    ``with sqlite3.connect(...) as conn`` on its own only commits or rolls back
    the transaction; it does not close the connection. Without the explicit
    close, every call would leak a connection until garbage collection.
    """
    conn = sqlite3.connect(DB_FILE)
    try:
        with conn:  # commit on success, roll back if the body raises
            yield conn
    finally:
        conn.close()


def insert_chat(session_id, user_msg, bot_msg):
    """Append one user/bot exchange to the history of *session_id*.

    Best-effort: errors are logged and swallowed, so a storage failure does not
    stop the generated reply from reaching the user.
    """
    try:
        with _db_connection() as conn:
            conn.execute(
                "INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
                (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg),
            )
        print(f"✅ Inserted chat for session_id={session_id}")
    except Exception as e:
        print(f"❌ Error inserting chat: {e}")


def fetch_history(session_id, limit=100):
    """Return up to *limit* ``(timestamp, user, bot)`` rows for *session_id*, oldest first.

    Returns an empty list if the query fails.
    """
    try:
        with _db_connection() as conn:
            rows = conn.execute(
                "SELECT timestamp, user, bot FROM chat WHERE session_id=? ORDER BY id ASC LIMIT ?",
                (session_id, limit),
            ).fetchall()
        print(f"📜 Fetched {len(rows)} messages for session_id={session_id}")
        return rows
    except Exception as e:
        print(f"❌ Error fetching history for session_id={session_id}: {e}")
        return []


def get_all_sessions():
    """Return every ``(session_id, created_at)`` row, newest first (empty list on error)."""
    try:
        with _db_connection() as conn:
            sessions = conn.execute(
                "SELECT session_id, created_at FROM sessions ORDER BY created_at DESC"
            ).fetchall()
        print(f"📋 Fetched {len(sessions)} sessions")
        return sessions
    except Exception as e:
        print(f"❌ Error fetching sessions: {e}")
        return []


def create_new_session():
    """Insert a new row into ``sessions`` and return its id.

    Ids keep the ``session_YYYYmmdd_HHMMSS`` prefix plus a random hex suffix.
    The timestamp alone has one-second resolution and ``session_id`` is the
    table's PRIMARY KEY, so two sessions created in the same second used to
    collide.

    Raises:
        sqlite3.Error (or any other exception): logged and re-raised.
    """
    try:
        now = datetime.datetime.now()
        session_id = f"{now.strftime('session_%Y%m%d_%H%M%S')}_{secrets.token_hex(4)}"
        with _db_connection() as conn:
            conn.execute(
                "INSERT INTO sessions (session_id, created_at) VALUES (?, ?)",
                (session_id, now.isoformat()),
            )
        print(f"✅ Created new session: {session_id}")
        return session_id
    except Exception as e:
        print(f"❌ Error creating new session: {e}")
        raise


def delete_session(session_id):
    """Delete a session and all of its messages in a single transaction.

    If *session_id* is the caller's active session, ``chat_id`` is also removed
    from the Flask session. Errors are logged and re-raised, so the route can
    report the failure instead of always answering "deleted".
    """
    try:
        with _db_connection() as conn:
            conn.execute("DELETE FROM chat WHERE session_id=?", (session_id,))
            conn.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
        print(f"🗑️ Deleted session_id={session_id}")
        # Reset session if deleted session is active
        if session.get('chat_id') == session_id:
            session.pop('chat_id', None)
            print(f"🔄 Reset active session_id={session_id}")
    except Exception as e:
        print(f"❌ Error deleting session_id={session_id}: {e}")
        raise


def _history_to_json(rows):
    """Convert ``(timestamp, user, bot)`` rows into JSON-serialisable dicts."""
    return [{"timestamp": t, "user": u, "bot": b} for t, u, b in rows]


def _generate_reply(user_input):
    """Run the model on a single-turn prompt and return the cleaned reply text.

    Returns an empty string if the model produced nothing usable.
    """
    prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=800,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens. If <|assistant|> is registered as
    # a special token, decoding the whole sequence with skip_special_tokens=True
    # would drop it, and the split below would then leave the user's prompt
    # inside the reply.
    prompt_len = inputs["input_ids"].shape[-1]
    decoded = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    # Trim any stray marker and anything after the model starts a new user turn.
    return decoded.split("<|assistant|>")[-1].split("<|prompter|>")[0].strip()


# ✅ Route for home page
@app.route('/')
def home():
    """Render the chat UI, creating a chat session for first-time visitors."""
    try:
        if 'chat_id' not in session:
            session['chat_id'] = create_new_session()
            print(f"🏠 Initialized session_id={session['chat_id']} for new user")
        return render_template('index.html')
    except Exception as e:
        print(f"❌ Error in /: {e}")
        return jsonify({"error": "Failed to initialize session"}), 500


# ✅ Create new session and return it
@app.route('/new_session', methods=['GET'])
def new_session():
    """Create a new chat session, make it the active one and return its id."""
    try:
        new_sid = create_new_session()
        session['chat_id'] = new_sid
        print(f"➕ New session created: {new_sid}")
        return jsonify({"session_id": new_sid})
    except Exception as e:
        print(f"❌ Error in /new_session: {e}")
        return jsonify({"error": str(e)}), 500


# ✅ POST endpoint for chatbot
@app.route('/chat', methods=['POST'])
def chat():
    """Generate a reply to JSON ``{"message": str}`` and store the exchange.

    Returns ``{"response": str}``. A malformed body, or a missing, empty or
    non-string message, returns 400. Any other failure returns 500.
    """
    try:
        # silent=True: a non-JSON body is a client error, not a server crash
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        message = data.get("message", "")
        if not isinstance(message, str) or not message.strip():
            return jsonify({"error": "Missing or empty 'message' in request"}), 400
        user_input = message.strip()

        session_id = session.get("chat_id")
        if not session_id:
            session_id = create_new_session()
            session['chat_id'] = session_id
            print(f"🔄 Created new session_id={session_id} for chat")

        reply = _generate_reply(user_input)
        if not reply:
            reply = "⚠️ Sorry, I couldn't generate a response. Please try again."

        insert_chat(session_id, user_input, reply)
        return jsonify({"response": reply})
    except Exception as e:
        print(f"❌ Error in /chat: {e}")
        return jsonify({"error": str(e)}), 500


# ✅ GET current session chat
@app.route('/history', methods=['GET'])
def history():
    """Return the active session's messages, or ``[]`` if there is no active session."""
    try:
        session_id = session.get("chat_id")
        if not session_id:
            print("⚠️ No session_id found in session")
            return jsonify([])
        return jsonify(_history_to_json(fetch_history(session_id)))
    except Exception as e:
        print(f"❌ Error in /history: {e}")
        return jsonify({"error": str(e)}), 500


# ✅ GET chat for specific session
# The <session_id> URL variable is required: without it Flask calls the view
# with no arguments and every request fails with a TypeError.
@app.route('/history/<session_id>', methods=['GET'])
def session_history(session_id):
    """Return the messages of an arbitrary session."""
    try:
        return jsonify(_history_to_json(fetch_history(session_id)))
    except Exception as e:
        print(f"❌ Error in /history/{session_id}: {e}")
        return jsonify({"error": str(e)}), 500


# ✅ GET all sessions
@app.route('/sessions', methods=['GET'])
def list_sessions():
    """Return every session as ``{"session_id", "created_at"}``, newest first."""
    try:
        sessions = get_all_sessions()
        return jsonify([{"session_id": sid, "created_at": ts} for sid, ts in sessions])
    except Exception as e:
        print(f"❌ Error in /sessions: {e}")
        return jsonify({"error": str(e)}), 500


# ✅ DELETE session
@app.route('/sessions/<session_id>', methods=['DELETE'])
def delete_session_route(session_id):
    """Delete a session and its history; returns 500 if the deletion failed."""
    try:
        delete_session(session_id)
        return jsonify({"status": "deleted"})
    except Exception as e:
        print(f"❌ Error in /sessions/{session_id}: {e}")
        return jsonify({"error": str(e)}), 500


def init_db():
    """Create the ``chat`` and ``sessions`` tables if they do not already exist.

    Raises:
        Exception: any database error, after logging it.
    """
    try:
        with _db_connection() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS chat (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    user TEXT,
                    bot TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    created_at TEXT NOT NULL
                )
            ''')
        print("✅ Database initialized successfully")
    except Exception as e:
        print(f"❌ Error initializing database: {e}")
        raise


# ✅ Initialize database and run server
if __name__ == '__main__':
    # Ensure database is initialized without merging init_db.py
    init_db()
    # NOTE: debug=True enables the Werkzeug reloader, which re-runs this module
    # in a child process; the model may therefore be loaded (and held in
    # memory) twice during development.
    app.run(debug=True, port=5005)