| import gradio as gr |
| import random |
| import time, os |
| import copy |
| import re |
|
|
| import torch |
| from rich.console import Console |
| from rich.table import Table |
| from datetime import datetime |
|
|
| from threading import Thread |
| from typing import Optional |
| from transformers import TextIteratorStreamer |
|
|
| from utils.special_tok_llama2 import ( |
| B_CODE, |
| E_CODE, |
| B_RESULT, |
| E_RESULT, |
| B_INST, |
| E_INST, |
| B_SYS, |
| E_SYS, |
| DEFAULT_PAD_TOKEN, |
| DEFAULT_BOS_TOKEN, |
| DEFAULT_EOS_TOKEN, |
| DEFAULT_UNK_TOKEN, |
| IGNORE_INDEX, |
| ) |
|
|
| from finetuning.conversation_template import ( |
| json_to_code_result_tok_temp, |
| msg_to_code_result_tok_temp, |
| ) |
|
|
| import warnings |
|
|
# Silence UserWarnings emitted from modules whose name starts with
# "transformers", and lower TensorFlow's C++ log verbosity.
# NOTE(review): the env var is set after ``transformers`` is imported; it only
# takes effect if TensorFlow has not been initialised yet — confirm.
warnings.filterwarnings(action="ignore", category=UserWarning, module="transformers")
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "2"})
|
|
|
|
| from code_interpreter.LlamaCodeInterpreter import LlamaCodeInterpreter |
|
|
|
|
class StreamingLlamaCodeInterpreter(LlamaCodeInterpreter):
    """``LlamaCodeInterpreter`` whose ``generate`` streams text instead of returning it.

    ``generate`` launches ``self.model.generate`` on a background thread and
    publishes the decoded continuation through ``self.streamer``; callers
    iterate ``self.streamer`` to consume the text as it is produced.
    """

    # Streamer created by the most recent ``generate`` call (``None`` before the
    # first call). Every call replaces it with a fresh streamer.
    streamer: Optional[TextIteratorStreamer] = None

    # NOTE: inference mode is thread-local, so this decorator does not reach the
    # worker thread started below; gradient tracking there is presumably
    # disabled by ``model.generate`` itself — confirm.
    @torch.inference_mode()
    def generate(
        self,
        prompt: str = "[INST]\n###User : hi\n###Assistant :",
        max_new_tokens=512,
        do_sample: bool = True,
        use_cache: bool = True,
        top_p: float = 0.95,
        temperature: float = 0.1,
        top_k: int = 50,
        repetition_penalty: float = 1.0,
        stream_timeout: Optional[float] = 5.0,
    ) -> str:
        """Start streaming a completion of ``prompt`` and return immediately.

        Generation stops at the EOS token or at the end-of-code token, so a
        caller can execute a finished code block before requesting more text.

        Args:
            prompt: Full prompt text to continue.
            max_new_tokens, do_sample, use_cache, top_p, temperature, top_k,
            repetition_penalty: Forwarded unchanged to ``self.model.generate``.
            stream_timeout: Seconds a consumer iterating ``self.streamer`` waits
                for the next chunk before the iteration raises ``queue.Empty``;
                ``None`` waits forever.

        Returns:
            Always ``""``; the generated text is read from ``self.streamer``.
        """
        # The keyword must be lower-case ``timeout``. A mis-cased ``Timeout`` is
        # silently collected as an extra decode kwarg, leaving the streamer with
        # no timeout, so a crashed generation thread blocks consumers forever.
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=stream_timeout
        )

        # Place the prompt tensors on the model's device instead of leaving
        # them on CPU.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        # Stop on the plain EOS token *or* at the end of a code block.
        stop_token_ids = [
            self.tokenizer.convert_tokens_to_ids(DEFAULT_EOS_TOKEN),
            self.tokenizer.convert_tokens_to_ids(E_CODE),
        ]

        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            eos_token_id=stop_token_ids,
            streamer=self.streamer,
        )

        # Generate in the background; the text arrives through self.streamer.
        Thread(target=self.model.generate, kwargs=generation_kwargs).start()
        return ""
|
|
|
|
def change_markdown_image(text: str):
    """Remove markdown images whose target is wrapped in single quotes.

    Every ``![alt]('target')`` occurrence (on a single line) is replaced with
    the empty string. Images with unquoted or double-quoted targets are left
    untouched.

    NOTE(review): the function name suggests rewriting the link rather than
    deleting it — confirm that removal is the intended behaviour.
    """
    single_quoted_image = r"!\[(.*?)\]\(\'(.*?)\'\)"
    return re.compile(single_quoted_image).sub("", text)
|
|
|
|
def gradio_launch(model_path: str, load_in_4bit: bool = True, MAX_TRY: int = 5):
    """Build and launch the Gradio chat UI for the code-interpreter model.

    A single ``StreamingLlamaCodeInterpreter`` is created and shared by the UI
    (and therefore by every browser session). Each user turn streams the
    model's answer. Whenever the answer contains a code block, the block is run
    via ``interpreter.execute_code_and_return_output``; no sandboxing is visible
    here. The result is appended and the model is asked to continue, at most
    ``MAX_TRY`` times per turn.

    Args:
        model_path: Path forwarded to ``StreamingLlamaCodeInterpreter``.
        load_in_4bit: Forwarded to ``StreamingLlamaCodeInterpreter``.
        MAX_TRY: Maximum number of execute-and-continue rounds per user turn.
    """
    from rich.markup import escape  # keep user text from being parsed as markup

    console = Console()

    def to_display_markdown(text: str) -> str:
        """Turn code delimiters into fenced blocks and strip quoted images."""
        text = (
            text.replace(f"{B_CODE}", "\n```python\n")
            .replace(f"{E_CODE}", "\n```\n")
        )
        return change_markdown_image(text)

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        # ``avatar_images`` expects a (user, assistant) pair; a bare string
        # would be indexed character by character.
        chatbot = gr.Chatbot(height=820, avatar_images=(None, "./assets/logo2.png"))
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        interpreter = StreamingLlamaCodeInterpreter(
            model_path=model_path, load_in_4bit=load_in_4bit
        )
        # Snapshot of the starting dialog (it may already hold a system
        # prompt), so "Clear" can reset the model-side conversation too.
        initial_dialog = copy.deepcopy(interpreter.dialog)

        def bot(history):
            """Stream the assistant reply to ``history[-1][0]`` into ``history[-1][1]``."""
            user_message = history[-1][0]
            interpreter.dialog.append({"role": "user", "content": user_message})
            console.print(f"###User : [bold]{escape(user_message)}[/bold]")

            # Prompt through the end-of-instruction token. Each continuation
            # request is ``base_prompt + generated_text``, so earlier rounds of
            # this turn appear exactly once (appending the cumulative text to an
            # already-extended prompt would duplicate them).
            base_prompt = (
                f"{interpreter.dialog_to_prompt(dialog=interpreter.dialog)} {E_INST}"
            )

            interpreter.generate(base_prompt)
            history[-1][1] = ""
            generated_text = ""
            for chunk in interpreter.streamer:
                history[-1][1] += chunk
                generated_text += chunk
                yield history

            # NOTE(review): extract_code_blocks receives the cumulative text of
            # the whole turn; confirm it returns the newest (not yet executed)
            # block.
            has_code, code_block = interpreter.extract_code_blocks(generated_text)

            attempt = 1
            while has_code and attempt <= MAX_TRY:
                history[-1][1] = to_display_markdown(history[-1][1])
                yield history

                code_block = code_block.replace("<unk>_", "").replace("<unk>", "")
                code_output, _error_flag = interpreter.execute_code_and_return_output(
                    code_block
                )
                code_output = interpreter.clean_code_output(code_output)
                generated_text = (
                    f"{generated_text}\n{B_RESULT}\n{code_output}\n{E_RESULT}\n"
                )

                history[-1][1] += f"\n```RESULT\n{code_output}\n```\n"
                history[-1][1] = change_markdown_image(history[-1][1])
                yield history

                interpreter.generate(f"{base_prompt} {generated_text}")
                for chunk in interpreter.streamer:
                    history[-1][1] += chunk
                    generated_text += chunk
                    history[-1][1] = change_markdown_image(history[-1][1])
                    yield history

                has_code, code_block = interpreter.extract_code_blocks(generated_text)

                if generated_text.endswith("</s>"):
                    break

                attempt += 1

            # The loop formats text only before starting another round, so the
            # final round's code delimiters would otherwise stay raw.
            history[-1][1] = to_display_markdown(history[-1][1])
            yield history

            interpreter.dialog.append(
                {
                    "role": "assistant",
                    "content": generated_text.replace("<unk>_", "")
                    .replace("<unk>", "")
                    .replace("</s>", ""),
                }
            )

            print("----------\n" * 2)
            print(interpreter.dialog)
            print("----------\n" * 2)

        def user(user_message, history):
            """Append the submitted message to the chat and clear the textbox."""
            return "", history + [[user_message, None]]

        def clear_chat():
            """Reset the model-side dialog and empty the chat display."""
            interpreter.dialog[:] = copy.deepcopy(initial_dialog)
            return None

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(clear_chat, None, chatbot, queue=False)

    demo.queue()
    demo.launch()
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Process path for LLAMA2_FINETUNEED.")
    # ``--path`` is optional: marking it required would make ``default``
    # unreachable, so omitting it now falls back to the default checkpoint.
    parser.add_argument(
        "--path",
        type=str,
        default="./output/llama-2-7b-codellama-ci",
        help="Path to the finetuned LLAMA2 model (default: %(default)s).",
    )
    args = parser.parse_args()

    gradio_launch(model_path=args.path, load_in_4bit=True)
|
|