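# ViBidLawQA_v2 – Streamlit chat app for question answering over Vietnamese legal texts.
# The sidebar lets the user paste a legal document (the context) and choose one of two
# answering modes: 'Extraction' (span extraction) or 'Generation' (seq2seq generation).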
import gc
import time
import torch
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering

st.set_page_config(page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam", page_icon="./app/static/ai.png", layout="centered", initial_sidebar_state="expanded")

with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
if 'messages' not in st.session_state:
    st.session_state.messages = []

st.markdown(f"""
    <div class="logo_area">
        <img src="./app/static/ai.png"/>
    </div>
""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViBidLawQA_v2</h2>", unsafe_allow_html=True)

answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=500)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
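
# Load only the model for the selected answering method and release the other one
# from st.session_state, so at most one model occupies (GPU) memory at a time.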
if answering_method == 'Generation' and 'aqa_model' not in st.session_state:
    # Membership must be checked for both keys explicitly.
    if 'eqa_model' in st.session_state and 'eqa_tokenizer' in st.session_state:
        del st.session_state.eqa_model
        del st.session_state.eqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to generative model...')
    print('Loading generative model...')
    st.session_state.aqa_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path='./models/AQA_model').to(device)
    st.session_state.aqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/AQA_model')
if answering_method == 'Extraction' and 'eqa_model' not in st.session_state:
    if 'aqa_model' in st.session_state and 'aqa_tokenizer' in st.session_state:
        del st.session_state.aqa_model
        del st.session_state.aqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to extraction model...')
    print('Loading extraction model...')
    st.session_state.eqa_model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path='./models/EQA_model').to(device)
    st.session_state.eqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/EQA_model')
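
# Generative answering: the seq2seq model reads the question followed by the context
# (the context is truncated to max_length tokens) and generates a free-form answer.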
def get_abstractive_answer(context, question, max_length=1024, max_target_length=512):
    inputs = st.session_state.aqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            padding='max_length',
                                            return_tensors='pt')
    outputs = st.session_state.aqa_model.generate(inputs=inputs['input_ids'].to(device),
                                                  attention_mask=inputs['attention_mask'].to(device),
                                                  max_length=max_target_length)
    answer = st.session_state.aqa_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if not answer.endswith('.'):
        answer += '.'
    return answer
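
# Typewriter effect: yield the answer word by word so the UI can redraw the
# growing response every 50 ms.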
def generate_text_effect(answer):
    words = answer.split()
    for i in range(len(words)):
        time.sleep(0.05)
        yield " ".join(words[:i + 1])
def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    inputs = st.session_state.eqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            stride=stride,
                                            return_overflowing_tokens=True,
                                            return_offsets_mapping=True,
                                            padding='max_length')
    # Mask out offsets that do not belong to the context (sequence id 1).
    for i in range(len(inputs['input_ids'])):
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs['offset_mapping'][i]
        inputs['offset_mapping'][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]
    input_ids = torch.tensor(inputs["input_ids"]).to(device)
    attention_mask = torch.tensor(inputs["attention_mask"]).to(device)
    with torch.no_grad():
        outputs = st.session_state.eqa_model(input_ids=input_ids, attention_mask=attention_mask)
    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()
    answers = []
    for i in range(len(inputs["input_ids"])):
        start_logit = start_logits[i]
        end_logit = end_logits[i]
        offsets = inputs["offset_mapping"][i]
        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue
                answer = {
                    "text": context[offsets[start_index][0] : offsets[end_index][1]],
                    "logit_score": start_logit[start_index] + end_logit[end_index],
                }
                answers.append(answer)
    if len(answers) > 0:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        return best_answer["text"]
    else:
        return ""
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    st.markdown(f"""
        <div class="{message_class}">
            <img src="{avatar}" class="{avatar_class}" />
            <div class="stMarkdown">{message['content']}</div>
        </div>
    """, unsafe_allow_html=True)
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    st.markdown(f"""
        <div class="user-message">
            <img src="./app/static/human.png" class="user-avatar" />
            <div class="stMarkdown">{prompt}</div>
        </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    message_placeholder = st.empty()
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""
                <div class="assistant-message">
                    <img src="./app/static/ai.png" class="assistant-avatar" />
                    <div class="stMarkdown">{dots}</div>
                </div>
            """, unsafe_allow_html=True)

    full_response = ""
    if answering_method == 'Generation':
        abs_answer = get_abstractive_answer(context=context, question=prompt)
        for word in generate_text_effect(abs_answer):
            full_response = word
            message_placeholder.markdown(f"""
                <div class="assistant-message">
                    <img src="./app/static/ai.png" class="assistant-avatar" />
                    <div class="stMarkdown">{full_response}●</div>
                </div>
            """, unsafe_allow_html=True)
    else:
        ext_answer = get_extractive_answer(context=context, question=prompt)
        for word in generate_text_effect(ext_answer):
            full_response = word
            message_placeholder.markdown(f"""
                <div class="assistant-message">
                    <img src="./app/static/ai.png" class="assistant-avatar" />
                    <div class="stMarkdown">{full_response}●</div>
                </div>
            """, unsafe_allow_html=True)

    message_placeholder.markdown(f"""
        <div class="assistant-message">
            <img src="./app/static/ai.png" class="assistant-avatar" />
            <div class="stMarkdown">{full_response}</div>
        </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})