import torch
import numpy as np
import io
import matplotlib.pyplot as plt
import pandas as pd
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from datetime import datetime
from PIL import Image
import os
from openai import OpenAI
from ai71 import AI71
# Load the pre-computed dialogue embeddings and the emotion label set
dials_embeddings = pd.read_pickle('https://huggingface.co/datasets/vsrinivas/CBT_dialogue_embed_ds/resolve/main/kaggle_therapy_embeddings.pkl')

with open('emotion_group_labels.txt') as file:
    emotion_group_labels = file.read().splitlines()

# Sentence embedder for retrieval and a zero-shot classifier for emotion detection
embed_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
classifier = pipeline("zero-shot-classification", model='facebook/bart-large-mnli')

# Move the embedding model to the GPU when one is available
if torch.cuda.is_available():
    embed_model = embed_model.to('cuda')

AI71_API_KEY = os.getenv('AI71_API_KEY')
# Detect emotions from patient dialogues
def detect_emotions(text):
    emotion = classifier(text, candidate_labels=emotion_group_labels, batch_size=16)
    # Keep the top 5 labels and renormalise their scores so they sum to 100%
    top_5_scores = [i / sum(emotion['scores'][:5]) for i in emotion['scores'][:5]]
    top_5_emotions = emotion['labels'][:5]
    emotion_set = {l: "{:.2%}".format(s) for l, s in zip(top_5_emotions, top_5_scores)}
    return emotion_set
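# Usage sketch (hypothetical input): detect_emotions("I keep worrying about work")
# returns a dict mapping the top-5 detected labels (drawn from
# emotion_group_labels.txt) to renormalised percentage strings such as '38.00%'.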
# Measure cosine similarity between a pair of vectors
# (note: despite the name, this returns similarity, not distance)
def cosine_distance(vec1, vec2):
    cosine = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    return cosine
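# Optional vectorized variant (a sketch, not used by the retrieval loop below):
# scores one query embedding against a whole matrix of stored embeddings in a
# single NumPy operation, e.g.
# cosine_similarities_batch(user_embedding, dials_embeddings['embeddings']).
def cosine_similarities_batch(query_vec, embeddings):
    matrix = np.vstack(embeddings)   # stack N vectors into an (N, d) array
    dots = matrix @ query_vec        # (N,) dot products with the query
    norms = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
    return dots / norms              # (N,) cosine similarities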
# Generate an image of trigger emotions
def generate_triggers_img(items):
    labels = list(items.keys())
    values = [float(v.strip('%')) for v in items.values()]  # Convert '38.00%' strings to floats for plotting
    new_items = {k: v for k, v in zip(labels, values)}
    new_items = dict(sorted(new_items.items(), key=lambda item: item[1]))
    labels = list(new_items.keys())
    values = list(new_items.values())

    fig, ax = plt.subplots(figsize=(10, 6))
    colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
    bars = ax.barh(labels, values, color=colors)

    # Hide the chart frame and x-axis; keep only the emotion labels and bar annotations
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.tick_params(axis='y', labelsize=18)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_ticks_position('none')

    # Annotate each bar with its percentage value
    for bar in bars:
        width = bar.get_width()
        ax.text(width, bar.get_y() + bar.get_height() / 2, f'{width:.2f}%',
                ha='left', va='center', fontweight='bold', fontsize=18)

    plt.tight_layout()
    plt.savefig('triggers.png')
    triggers_img = Image.open('triggers.png')
    return triggers_img
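# Optional variant (a sketch, not wired in above): the already-imported `io`
# module can keep the chart in memory instead of round-tripping through a file:
#     buf = io.BytesIO()
#     plt.savefig(buf, format='png')
#     buf.seek(0)
#     triggers_img = Image.open(buf)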
def get_doc_response_emotions(user_message, therapy_session_conversation):
    user_messages = []
    user_messages.append(user_message)

    # Detect the patient's top emotions and render them as a bar-chart image
    emotion_set = detect_emotions(user_message)
    print(emotion_set)
    emotions_msg = generate_triggers_img(emotion_set)

    # Embed the message and retrieve the most similar patient dialogue from the dataset
    user_embedding = embed_model.encode(user_message, device='cuda' if torch.cuda.is_available() else 'cpu')

    similarities = []
    for v in dials_embeddings['embeddings']:
        similarities.append(cosine_distance(user_embedding, v))
    top_match_index = similarities.index(max(similarities))

    # Respond with the therapist turn paired with the best-matching patient dialogue
    doc_response = dials_embeddings.iloc[top_match_index]['Doctor']
    therapy_session_conversation.append(["User: " + user_message, "Therapist: " + doc_response])

    print(f"User's message: {user_message}")
    print(f"RAG Matching message: {dials_embeddings.iloc[top_match_index]['Patient']}")
    print(f"Therapist's response: {dials_embeddings.iloc[top_match_index]['Doctor']}\n\n")

    return '', therapy_session_conversation, emotions_msg
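# The three return values (an empty string to clear the input box, the updated
# conversation, and the emotions image) are shaped for a chat-UI callback; a
# hypothetical wiring sketch is given at the end of this file.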
def summarize_and_recommend(therapy_session_conversation):
    session_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    # Collect the user-side turns of the session and prepend the timestamp
    session_conversation = [item[0] for item in therapy_session_conversation]
    session_conversation = [x for x in session_conversation if x is not None]
    session_conversation.insert(0, "Session_time: " + session_time)
    session_conversation_processed = '\n'.join(session_conversation)
    print("session_conversation_processed:", session_conversation_processed)

    # Stream a structured session summary from Falcon-180B via the AI71 API
    full_summary = ""
    for chunk in AI71(AI71_API_KEY).chat.completions.create(
        model="tiiuae/falcon-180b-chat",
        messages=[
            {"role": "system", "content": """You are an Expert Cognitive Behavioural Therapist and Precis writer.
            Summarize 'STRICTLY' the below user content <<<session_conversation_processed>>> 'ONLY' into useful, ethical, relevant and realistic phrases with the format
            Session Time:
            Summary of the patient messages: #in two to four sentences
            Summary of therapist messages: #in two to three sentences
            Summary of the whole session: #in two to three sentences. Ensure the entire session summary strictly does not exceed 100 tokens."""},
            {"role": "user", "content": session_conversation_processed},
        ],
        stream=True,
    ):
        if chunk.choices[0].delta.content:
            summary = chunk.choices[0].delta.content
            full_summary += summary

    full_summary = full_summary.replace('User:', '').strip()
    print("\n")
    print("Full summary:", full_summary)

    # Stream an action plan based strictly on the generated summary
    full_recommendations = ""
    for chunk in AI71(AI71_API_KEY).chat.completions.create(
        model="tiiuae/falcon-180b-chat",
        messages=[
            {"role": "system", "content": """You are an expert Cognitive Behavioural Therapist.
            Based 'STRICTLY' on the full summary <<<full_summary>>> 'ONLY', provide a clinically valid, useful, appropriate action plan for the patient as a bulleted list.
            The list shall contain both medical and non-medical prescriptions, dos and don'ts. The format of the response shall be in passive voice with proper tense.
            - The patient is referred to........ #in one sentence
            - The patient is advised to ........ #in one sentence
            - The patient is refrained from........ #in one sentence
            - It is suggested that the patient ........ #in one sentence
            - Scheduled a follow-up session with the patient........#in one sentence
            *Ensure the list contains NOT MORE THAN 7 points"""},
            {"role": "user", "content": full_summary},
        ],
        stream=True,
    ):
        if chunk.choices[0].delta.content:
            rec = chunk.choices[0].delta.content
            full_recommendations += rec

    full_recommendations = full_recommendations.replace('User:', '').strip()
    print("\n")
    print("Full recommendations:", full_recommendations)

    chatbox = []
    return full_summary, full_recommendations, chatbox
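# ---------------------------------------------------------------------------
# Minimal UI wiring sketch (an assumption: the Space's actual interface code is
# not shown above, and the component names below are hypothetical). It only
# illustrates how the two handlers could plug into a Gradio Blocks app: the
# chat handler clears the textbox, updates the conversation and shows the
# emotions chart; the end-of-session handler returns the summary and plan.
# ---------------------------------------------------------------------------
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Therapy session")
    emotions_chart = gr.Image(label="Detected emotions")
    user_box = gr.Textbox(label="Your message")
    summary_box = gr.Textbox(label="Session summary")
    plan_box = gr.Textbox(label="Recommended action plan")
    end_button = gr.Button("End session and summarize")

    user_box.submit(
        get_doc_response_emotions,
        inputs=[user_box, chatbot],
        outputs=[user_box, chatbot, emotions_chart],
    )
    end_button.click(
        summarize_and_recommend,
        inputs=[chatbot],
        outputs=[summary_box, plan_box, chatbot],
    )

if __name__ == "__main__":
    demo.launch()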