Spaces:
Runtime error
import os
import multiprocessing
import concurrent.futures
# Note: recent LangChain releases move these loaders/vectorstores to
# langchain_community; the langchain.* paths assume an older pinned version.
from langchain.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
import faiss
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from datetime import datetime
import json
import gradio as gr
import re
from huggingface_hub import InferenceClient
# from unsloth import FastLanguageModel
import transformers
from transformers import BloomForCausalLM, BloomForTokenClassification, BloomTokenizerFast
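
# Retrieval-augmented generation (RAG) pipeline:
# 1. Load .txt files from a folder and split them into overlapping chunks.
# 2. Embed the chunks with a SentenceTransformer model and index them in FAISS.
# 3. For each query, retrieve the top-k chunks and build a few-shot prompt.
# 4. Generate the answer by streaming a chat completion from the Hugging Face
#    Inference API (HuggingFaceH4/zephyr-7b-beta).
# 5. Serve the whole thing through a Gradio interface.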
class DocumentRetrievalAndGeneration:
    def __init__(self, embedding_model_name, lm_model_id, data_folder):
        # Prefer the HF_TOKEN secret configured for the Space; fall back to the
        # hard-coded token only if the environment variable is not set.
        hf = "hf_VuNNBwnFqlcKzV"
        token = "vCfLXEBxyAOftxvlWpwf"
        self.hf_token = os.getenv('HF_TOKEN', hf + token)
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.cpu_index = self.create_faiss_index()
        self.llm = self.initialize_llm2(lm_model_id)
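
    # Load every text file in the data folder and split it into overlapping chunks.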
    def load_documents(self, folder_path):
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
        all_splits = text_splitter.split_documents(documents)
        print('Length of documents:', len(documents))
        print('Length of all_splits:', len(all_splits))
        return all_splits
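
    # Embed every chunk with the sentence-transformer model and add the vectors
    # to an exact (IndexFlatL2) FAISS index held in CPU memory.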
    def create_faiss_index(self):
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index
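
    # Local-model path (currently unused; __init__ calls initialize_llm2 instead):
    # load the model in 4-bit NF4 via bitsandbytes and wrap it in a
    # text-generation pipeline.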
    def initialize_llm(self, model_id):
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, token=self.hf_token)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        generate_text = pipeline(
            model=model,
            tokenizer=tokenizer,
            return_full_text=True,
            task='text-generation',
            temperature=0.6,
            max_new_tokens=256,
        )
        return generate_text
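
    # Remote-inference path used by __init__: model_id is ignored here and
    # generation is delegated to the hosted Inference API (zephyr-7b-beta).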
    def initialize_llm2(self, model_id):
        self.client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
        # Commented-out local alternatives, kept for reference:
        # try:
        #     pipe = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
        # except Exception:
        #     pipe = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct")
        # pipe = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2")
        # model_name = "mistralai/Mistral-7B-Instruct-v0.2"
        # pipeline = transformers.pipeline(
        #     "text-generation",
        #     model=model_name,
        #     model_kwargs={"torch_dtype": torch.bfloat16},
        #     device="cpu",
        # )
        # return generate_text
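
    # Run model.generate in a worker thread so a hung generation can time out
    # instead of blocking the request (only used by the local-model path).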
    def generate_response_with_timeout(self, model_inputs):
        try:
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(self.llm.model.generate, model_inputs, max_new_tokens=1000, do_sample=True)
                generated_ids = future.result(timeout=800)  # timeout in seconds
                return generated_ids
        except concurrent.futures.TimeoutError:
            return "Text generation process timed out"
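
    # Retrieve the top-k chunks for the query, build the prompt, stream the
    # completion, and extract the final solution text.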
    def query_and_generate_response(self, query):
        query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
        distances, indices = self.cpu_index.search(np.array([query_embedding]), k=5)
        content = ""
        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(self.all_splits):
                content += "-" * 50 + "\n"
                content += self.all_splits[idx].page_content + "\n"
                # distances is aligned positionally with indices, so index it
                # with the loop position, not with the document id.
                distance = distances[0][i]
                print("CHUNK", idx)
                print("Distance :", distance)
                print(self.all_splits[idx].page_content)
                print("############################")
            else:
                print(f"Index {idx} is out of bounds. Skipping.")
        prompt = f"""<s>
You are a knowledgeable assistant with access to a comprehensive database.
I need you to answer my question and provide related information in a specific format.
I have provided five retrieved chunks below; choose the most suitable chunks for answering the query.
{content}
Here's what I need:
Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
Here's my question:
Query:
Solution==>
RETURN ONLY THE SOLUTION. IF THERE IS NO RELEVANT ANSWER IN THE RETRIEVED CHUNKS, RETURN "NO SOLUTION AVAILABLE".
IF THE QUERY AND THE RETRIEVED CHUNKS DO NOT CORRELATE MEANINGFULLY, OR IF THE QUERY IS NOT RELEVANT TO TDA2 OR RELATED TOPICS, RETURN "NO SOLUTION AVAILABLE".
Example1
Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
Example2
Query: "Can BQ25896 support I2C interface?",
Solution: "Yes, the BQ25896 charger supports the I2C interface for communication."
Example3
Query: "Who is the fastest runner in the world",
Solution: "NO SOLUTION AVAILABLE"
Example4
Query: "What is the price of the latest Apple MacBook?"
Solution: "NO SOLUTION AVAILABLE"
</s>
"""
        messages = [{"role": "system", "content": prompt}]
        messages.append({"role": "user", "content": query})
        response = ""
        for message in self.client.chat_completion(messages, max_tokens=2048, stream=True, temperature=0.7):
            token = message.choices[0].delta.content
            if token:  # the final streamed chunk can carry an empty delta
                response += token
                # yield response
        generated_response = response
        # messages = [{"role": "user", "content": prompt}]
        # encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
        # model_inputs = encodeds.to(self.llm.device)
        # start_time = datetime.now()
        # generated_ids = self.generate_response_with_timeout(model_inputs)
        # elapsed_time = datetime.now() - start_time
        # decoded = self.llm.tokenizer.batch_decode(generated_ids)
        # generated_response = decoded[0]
        #########################################################
        # messages = []
        # # Check if history is None or empty and handle accordingly
        # if history:
        #     for user_msg, assistant_msg in history:
        #         messages.append({"role": "user", "content": user_msg})
        #         messages.append({"role": "assistant", "content": assistant_msg})
        # # Always add the current user message
        # messages.append({"role": "user", "content": message})
        # # Construct the prompt using the pipeline's tokenizer
        # prompt = pipeline.tokenizer.apply_chat_template(
        #     messages,
        #     tokenize=False,
        #     add_generation_prompt=True
        # )
        # # Generate the response
        # terminators = [
        #     pipeline.tokenizer.eos_token_id,
        #     pipeline.tokenizer.convert_tokens_to_ids("")
        # ]
        # # Adjust the temperature slightly above given to ensure variety
        # adjusted_temp = temperature + 0.1
        # # Generate outputs with adjusted parameters
        # outputs = pipeline(
        #     prompt,
        #     max_new_tokens=max_new_tokens,
        #     do_sample=True,
        #     temperature=adjusted_temp,
        #     top_p=0.9
        # )
        # # Extract the generated text, skipping the length of the prompt
        # generated_text = outputs[0]["generated_text"]
        # generated_response = generated_text[len(prompt):]
        match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
        match2 = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
        if match1:
            solution_text = match1.group(1).strip()
            if "Solution:" in solution_text:
                solution_text = solution_text.split("Solution:", 1)[1].strip()
        elif match2:
            solution_text = match2.group(1).strip()
        else:
            solution_text = generated_response
        # print("Generated response:", generated_response)
        # print("Time elapsed:", elapsed_time)
        # print("Device in use:", self.llm.device)
        return solution_text, content
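
    # Gradio callback: returns (solution text, retrieved context) for the two
    # output textboxes.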
    def qa_infer_gradio(self, query):
        response = self.query_and_generate_response(query)
        return response

if __name__ == "__main__":
    print("starting...")
    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
    # lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    # Note: lm_model_id is passed through but not used by initialize_llm2.
    lm_model_id = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
    data_folder = 'text_files'
    doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)

    def launch_interface():
        css_code = """
        .gradio-container {
            background-color: #ffffff;
        }
        /* Button styling for all buttons */
        button {
            background-color: #999999; /* Default color for all other buttons */
            color: black;
            border: 1px solid black;
            padding: 10px;
            margin-right: 10px;
            font-size: 16px;   /* Increase font size */
            font-weight: bold; /* Make text bold */
        }
        """
        EXAMPLES = [
            "What are the main types of blood cancer, and how do they differ in terms of symptoms, progression, and treatment options?",
            "What are the latest advancements in the treatment of blood cancer, and how do they improve patient outcomes compared to traditional therapies?",
            "How do genetic factors and environmental exposures contribute to the risk of developing blood cancer, and what preventive measures can be taken?",
        ]

        interface = gr.Interface(
            fn=doc_retrieval_gen.qa_infer_gradio,
            inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
            allow_flagging='never',
            examples=EXAMPLES,
            cache_examples=False,
            # The second return value is the retrieved context, not related queries.
            outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RETRIEVED CONTEXT")],
            css=css_code
        )
        interface.launch(debug=True)

    launch_interface()
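
# A minimal sketch of the dependencies this Space appears to need, inferred only
# from the imports above (package names are an assumption, not taken from the
# original file); on Hugging Face Spaces they would normally be pinned in
# requirements.txt:
#   torch, transformers, bitsandbytes, sentence-transformers, faiss-cpu,
#   langchain, numpy, gradio, huggingface_hub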