Jasleen05 commited on
Commit
8f998ed
Β·
verified Β·
1 Parent(s): 9e93605

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -62
app.py CHANGED
@@ -12,25 +12,33 @@ load_dotenv()
12
 
13
  # βœ… Configure Flask app
14
  app = Flask(__name__, static_folder='static', template_folder='templates')
15
- app.secret_key = os.urandom(24) # for sessions
16
  CORS(app)
17
 
18
- # βœ… Automatically find the latest trained model folder in chatbot_model/
19
  def get_latest_model_dir(base_dir="chatbot_model"):
20
- subdirs = [
21
- d for d in os.listdir(base_dir)
22
- if d.startswith("trained_model_") and os.path.isdir(os.path.join(base_dir, d))
23
- ]
24
- if not subdirs:
25
- raise Exception("❌ No trained model found in chatbot_model/ directory.")
26
- latest = sorted(subdirs)[-1]
27
- return os.path.abspath(os.path.join(base_dir, latest))
 
 
 
 
28
 
29
  # βœ… Load model from latest folder
30
- model_path = get_latest_model_dir()
31
- print(f"πŸ“¦ Loading model from: {model_path}")
32
- tokenizer = AutoTokenizer.from_pretrained(model_path)
33
- model = AutoModelForCausalLM.from_pretrained(model_path)
 
 
 
 
34
 
35
  # βœ… Setup device
36
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -40,41 +48,94 @@ model.eval()
40
  # βœ… SQLite DB setup
41
  DB_FILE = "chat_history.db"
42
 
43
- def init_db():
44
- with sqlite3.connect(DB_FILE) as conn:
45
- c = conn.cursor()
46
- c.execute('''
47
- CREATE TABLE IF NOT EXISTS chat (
48
- id INTEGER PRIMARY KEY AUTOINCREMENT,
49
- session_id TEXT,
50
- timestamp TEXT,
51
- user TEXT,
52
- bot TEXT
53
- )
54
- ''')
55
- conn.commit()
56
-
57
  def insert_chat(session_id, user_msg, bot_msg):
58
- with sqlite3.connect(DB_FILE) as conn:
59
- c = conn.cursor()
60
- c.execute("INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
61
- (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg))
62
- conn.commit()
63
-
64
- def fetch_history(session_id, limit=50):
65
- with sqlite3.connect(DB_FILE) as conn:
66
- c = conn.cursor()
67
- c.execute("SELECT timestamp, user, bot FROM chat WHERE session_id=? ORDER BY id DESC LIMIT ?",
68
- (session_id, limit))
69
- history = c.fetchall()
70
- return history[::-1] # latest at bottom
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # βœ… Route for home page
73
  @app.route('/')
74
  def home():
75
- # Start new chat session with unique session ID
76
- session['chat_id'] = datetime.datetime.now().strftime("session_%Y%m%d_%H%M%S")
77
- return render_template('index.html')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # βœ… POST endpoint for chatbot
80
  @app.route('/chat', methods=['POST'])
@@ -86,13 +147,15 @@ def chat():
86
  if not user_input:
87
  return jsonify({"error": "Missing or empty 'message' in request"}), 400
88
 
89
- session_id = session.get("chat_id", "default_session")
 
 
 
 
90
 
91
- # πŸ” Format prompt (OpenAssistant style)
92
  prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
93
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
94
 
95
- # 🧠 Generate response
96
  with torch.no_grad():
97
  output = model.generate(
98
  input_ids=inputs["input_ids"],
@@ -107,35 +170,88 @@ def chat():
107
  eos_token_id=tokenizer.eos_token_id
108
  )
109
 
110
- # 🧹 Decode and clean output
111
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
112
  reply = decoded.split("<|assistant|>")[-1].split("<|prompter|>")[0].strip()
113
 
114
- print("🧠 Prompt:", prompt)
115
- print("πŸ’¬ Bot Reply:", reply or "[EMPTY]")
116
-
117
  if not reply:
118
  reply = "⚠️ Sorry, I couldn't generate a response. Please try again."
119
 
120
- # πŸ’Ύ Save to DB
121
  insert_chat(session_id, user_input, reply)
122
 
123
  return jsonify({"response": reply})
124
 
125
  except Exception as e:
 
126
  return jsonify({"error": str(e)}), 500
127
 
128
- # βœ… GET endpoint for chat history of current session
129
  @app.route('/history', methods=['GET'])
130
  def history():
131
- session_id = session.get("chat_id", "default_session")
132
- history = fetch_history(session_id)
133
- return jsonify([
134
- {"timestamp": t, "user": u, "bot": b}
135
- for t, u, b in history
136
- ])
137
-
138
- # βœ… Run the Flask server
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  if __name__ == '__main__':
140
- init_db()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  app.run(debug=True, port=5005)
 
12
 
13
  # βœ… Configure Flask app
14
  app = Flask(__name__, static_folder='static', template_folder='templates')
15
+ app.secret_key = os.getenv('SECRET_KEY', os.urandom(24)) # Use .env for secret key
16
  CORS(app)
17
 
18
+ # βœ… Automatically find the latest trained model folder
19
  def get_latest_model_dir(base_dir="chatbot_model"):
20
+ try:
21
+ subdirs = [
22
+ d for d in os.listdir(base_dir)
23
+ if d.startswith("trained_model_") and os.path.isdir(os.path.join(base_dir, d))
24
+ ]
25
+ if not subdirs:
26
+ raise Exception("❌ No trained model found in chatbot_model/ directory.")
27
+ latest = sorted(subdirs)[-1]
28
+ return os.path.abspath(os.path.join(base_dir, latest))
29
+ except Exception as e:
30
+ print(f"❌ Error finding model directory: {e}")
31
+ raise
32
 
33
  # βœ… Load model from latest folder
34
+ try:
35
+ model_path = get_latest_model_dir()
36
+ print(f"πŸ“¦ Loading model from: {model_path}")
37
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
38
+ model = AutoModelForCausalLM.from_pretrained(model_path)
39
+ except Exception as e:
40
+ print(f"❌ Failed to load model: {e}")
41
+ raise
42
 
43
  # βœ… Setup device
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
48
  # βœ… SQLite DB setup
49
  DB_FILE = "chat_history.db"
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def insert_chat(session_id, user_msg, bot_msg):
52
+ try:
53
+ with sqlite3.connect(DB_FILE) as conn:
54
+ c = conn.cursor()
55
+ c.execute("INSERT INTO chat (session_id, timestamp, user, bot) VALUES (?, ?, ?, ?)",
56
+ (session_id, datetime.datetime.now().isoformat(), user_msg, bot_msg))
57
+ conn.commit()
58
+ print(f"βœ… Inserted chat for session_id={session_id}")
59
+ except Exception as e:
60
+ print(f"❌ Error inserting chat: {e}")
61
+
62
+ def fetch_history(session_id, limit=100):
63
+ try:
64
+ with sqlite3.connect(DB_FILE) as conn:
65
+ c = conn.cursor()
66
+ c.execute("SELECT timestamp, user, bot FROM chat WHERE session_id=? ORDER BY id ASC LIMIT ?",
67
+ (session_id, limit))
68
+ rows = c.fetchall()
69
+ print(f"πŸ“œ Fetched {len(rows)} messages for session_id={session_id}")
70
+ return rows
71
+ except Exception as e:
72
+ print(f"❌ Error fetching history for session_id={session_id}: {e}")
73
+ return []
74
+
75
+ def get_all_sessions():
76
+ try:
77
+ with sqlite3.connect(DB_FILE) as conn:
78
+ c = conn.cursor()
79
+ c.execute("SELECT session_id, created_at FROM sessions ORDER BY created_at DESC")
80
+ sessions = c.fetchall()
81
+ print(f"πŸ“‹ Fetched {len(sessions)} sessions")
82
+ return sessions
83
+ except Exception as e:
84
+ print(f"❌ Error fetching sessions: {e}")
85
+ return []
86
+
87
+ def create_new_session():
88
+ try:
89
+ session_id = datetime.datetime.now().strftime("session_%Y%m%d_%H%M%S")
90
+ with sqlite3.connect(DB_FILE) as conn:
91
+ c = conn.cursor()
92
+ c.execute("INSERT INTO sessions (session_id, created_at) VALUES (?, ?)",
93
+ (session_id, datetime.datetime.now().isoformat()))
94
+ conn.commit()
95
+ print(f"βœ… Created new session: {session_id}")
96
+ return session_id
97
+ except Exception as e:
98
+ print(f"❌ Error creating new session: {e}")
99
+ raise
100
+
101
+ def delete_session(session_id):
102
+ try:
103
+ with sqlite3.connect(DB_FILE) as conn:
104
+ c = conn.cursor()
105
+ c.execute("DELETE FROM chat WHERE session_id=?", (session_id,))
106
+ c.execute("DELETE FROM sessions WHERE session_id=?", (session_id,))
107
+ conn.commit()
108
+ print(f"πŸ—‘οΈ Deleted session_id={session_id}")
109
+ # Reset session if deleted session is active
110
+ if session.get('chat_id') == session_id:
111
+ session.pop('chat_id', None)
112
+ print(f"πŸ”„ Reset active session_id={session_id}")
113
+ except Exception as e:
114
+ print(f"❌ Error deleting session_id={session_id}: {e}")
115
 
116
  # βœ… Route for home page
117
  @app.route('/')
118
  def home():
119
+ try:
120
+ if 'chat_id' not in session:
121
+ session['chat_id'] = create_new_session()
122
+ print(f"🏠 Initialized session_id={session['chat_id']} for new user")
123
+ return render_template('index.html')
124
+ except Exception as e:
125
+ print(f"❌ Error in /: {e}")
126
+ return jsonify({"error": "Failed to initialize session"}), 500
127
+
128
+ # βœ… Create new session and return it
129
+ @app.route('/new_session', methods=['GET'])
130
+ def new_session():
131
+ try:
132
+ new_sid = create_new_session()
133
+ session['chat_id'] = new_sid
134
+ print(f"βž• New session created: {new_sid}")
135
+ return jsonify({"session_id": new_sid})
136
+ except Exception as e:
137
+ print(f"❌ Error in /new_session: {e}")
138
+ return jsonify({"error": str(e)}), 500
139
 
140
  # βœ… POST endpoint for chatbot
141
  @app.route('/chat', methods=['POST'])
 
147
  if not user_input:
148
  return jsonify({"error": "Missing or empty 'message' in request"}), 400
149
 
150
+ session_id = session.get("chat_id")
151
+ if not session_id:
152
+ session_id = create_new_session()
153
+ session['chat_id'] = session_id
154
+ print(f"πŸ”„ Created new session_id={session_id} for chat")
155
 
 
156
  prompt = f"<|prompter|> {user_input} <|endoftext|><|assistant|>"
157
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
158
 
 
159
  with torch.no_grad():
160
  output = model.generate(
161
  input_ids=inputs["input_ids"],
 
170
  eos_token_id=tokenizer.eos_token_id
171
  )
172
 
 
173
  decoded = tokenizer.decode(output[0], skip_special_tokens=True)
174
  reply = decoded.split("<|assistant|>")[-1].split("<|prompter|>")[0].strip()
175
 
 
 
 
176
  if not reply:
177
  reply = "⚠️ Sorry, I couldn't generate a response. Please try again."
178
 
 
179
  insert_chat(session_id, user_input, reply)
180
 
181
  return jsonify({"response": reply})
182
 
183
  except Exception as e:
184
+ print(f"❌ Error in /chat: {e}")
185
  return jsonify({"error": str(e)}), 500
186
 
187
+ # βœ… GET current session chat
188
  @app.route('/history', methods=['GET'])
189
  def history():
190
+ try:
191
+ session_id = session.get("chat_id")
192
+ if not session_id:
193
+ print("⚠️ No session_id found in session")
194
+ return jsonify([])
195
+ rows = fetch_history(session_id)
196
+ return jsonify([{"timestamp": t, "user": u, "bot": b} for t, u, b in rows])
197
+ except Exception as e:
198
+ print(f"❌ Error in /history: {e}")
199
+ return jsonify({"error": str(e)}), 500
200
+
201
+ # βœ… GET chat for specific session
202
+ @app.route('/history/<session_id>', methods=['GET'])
203
+ def session_history(session_id):
204
+ try:
205
+ rows = fetch_history(session_id)
206
+ return jsonify([{"timestamp": t, "user": u, "bot": b} for t, u, b in rows])
207
+ except Exception as e:
208
+ print(f"❌ Error in /history/{session_id}: {e}")
209
+ return jsonify({"error": str(e)}), 500
210
+
211
+ # βœ… GET all sessions
212
+ @app.route('/sessions', methods=['GET'])
213
+ def list_sessions():
214
+ try:
215
+ sessions = get_all_sessions()
216
+ return jsonify([{"session_id": sid, "created_at": ts} for sid, ts in sessions])
217
+ except Exception as e:
218
+ print(f"❌ Error in /sessions: {e}")
219
+ return jsonify({"error": str(e)}), 500
220
+
221
+ # βœ… DELETE session
222
+ @app.route('/sessions/<session_id>', methods=['DELETE'])
223
+ def delete_session_route(session_id):
224
+ try:
225
+ delete_session(session_id)
226
+ return jsonify({"status": "deleted"})
227
+ except Exception as e:
228
+ print(f"❌ Error in /sessions/{session_id}: {e}")
229
+ return jsonify({"error": str(e)}), 500
230
+
231
+ # βœ… Initialize database and run server
232
  if __name__ == '__main__':
233
+ # Ensure database is initialized without merging init_db.py
234
+ try:
235
+ with sqlite3.connect(DB_FILE) as conn:
236
+ c = conn.cursor()
237
+ c.execute('''
238
+ CREATE TABLE IF NOT EXISTS chat (
239
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
240
+ session_id TEXT NOT NULL,
241
+ timestamp TEXT NOT NULL,
242
+ user TEXT,
243
+ bot TEXT
244
+ )
245
+ ''')
246
+ c.execute('''
247
+ CREATE TABLE IF NOT EXISTS sessions (
248
+ session_id TEXT PRIMARY KEY,
249
+ created_at TEXT NOT NULL
250
+ )
251
+ ''')
252
+ conn.commit()
253
+ print("βœ… Database initialized successfully")
254
+ except Exception as e:
255
+ print(f"❌ Error initializing database: {e}")
256
+ raise
257
  app.run(debug=True, port=5005)