import random
from threading import Thread

import gradio as gr
import torch
import transformers
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.trainer_utils import set_seed

# Default parameters
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 80
DEFAULT_TEMPERATURE = 0.3
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_MESSAGE = ""
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEFAULT_CKPT_PATH = "ystemsrx/Qwen2.5-Sex"


def _load_model_tokenizer(checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, resume_download=True)
    # Half precision on GPU, full precision on CPU.
    device_map = "auto" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        device_map=device_map,
        torch_dtype=torch_dtype,
        resume_download=True,
    ).eval()
    model.generation_config.max_new_tokens = DEFAULT_MAX_NEW_TOKENS
    return model, tokenizer


def _chat_stream(model, tokenizer, query, history, system_message,
                 top_p, top_k, temperature, max_new_tokens):
    # Rebuild the full conversation from the accumulated history.
    conversation = [{'role': 'system', 'content': system_message}]
    for query_h, response_h in history:
        conversation.append({'role': 'user', 'content': query_h})
        conversation.append({'role': 'assistant', 'content': response_h})
    conversation.append({'role': 'user', 'content': query})

    # tokenizer.apply_chat_template was added in transformers 4.34;
    # fall back to a plain-text prompt on older versions.
    if version.parse(transformers.__version__) >= version.parse("4.34"):
        text = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    else:
        text = "\n".join(
            f"{msg['role'].capitalize()}: {msg['content']}" for msg in conversation
        ) + "\nAssistant:"

    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, timeout=30.0, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Generate in a background thread and drain the streamer as tokens
    # arrive. Joining the thread *before* iterating the streamer would
    # stall the stream, so join only after iteration finishes.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    assistant_reply = ""
    for new_text in streamer:
        assistant_reply += new_text
        yield assistant_reply
    thread.join()


def initialize_model(checkpoint_path=DEFAULT_CKPT_PATH):
    # Re-seed on every launch so sampling differs between runs.
    set_seed(random.randint(0, 2**32 - 1))
    return _load_model_tokenizer(checkpoint_path)


model, tokenizer = initialize_model()


def chat_interface(user_input, history, system_message, top_p, top_k,
                   temperature, max_new_tokens):
    if not user_input.strip():
        yield history, history, system_message, ""
        return
    # Show the user's message immediately with an empty assistant slot,
    # then stream the reply into that slot.
    history.append((user_input, ""))
    yield history, history, system_message, ""
    generator = _chat_stream(model, tokenizer, user_input, history[:-1],
                             system_message, top_p, top_k, temperature,
                             max_new_tokens)
    for assistant_reply in generator:
        history[-1] = (user_input, assistant_reply)
        yield history, history, system_message, ""


def clear_history():
    # Return plain values; gr.Textbox.update was removed in Gradio 4.x,
    # so the textbox is reset by returning "" directly.
    return [], [], DEFAULT_SYSTEM_MESSAGE, ""
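# Optional sanity check, not wired into the app: stream one reply from
# _chat_stream directly, bypassing Gradio. The function name and prompt
# below are illustrative only; call it from a REPL to verify the model
# loads and generates before launching the UI.
def _debug_stream_once(prompt="Hello!"):
    last = ""
    for partial in _chat_stream(model, tokenizer, prompt, [],
                                DEFAULT_SYSTEM_MESSAGE, DEFAULT_TOP_P,
                                DEFAULT_TOP_K, DEFAULT_TEMPERATURE,
                                DEFAULT_MAX_NEW_TOKENS):
        # Each yield is the cumulative reply; print only the new suffix.
        print(partial[len(last):], end="", flush=True)
        last = partial
    print()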
# Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown("# Qwen2.5 Chatbot")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(show_label=False, placeholder="Type your question...")
            send_btn = gr.Button("Send")
        with gr.Column(scale=1):
            clear_btn = gr.Button("Clear history")
            system_message = gr.Textbox(label="System message", value=DEFAULT_SYSTEM_MESSAGE)
            top_p_slider = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, label="Top-p")
            top_k_slider = gr.Slider(0, 100, value=DEFAULT_TOP_K, label="Top-k")
            temperature_slider = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, label="Temperature")
            max_new_tokens_slider = gr.Slider(50, 2048, value=DEFAULT_MAX_NEW_TOKENS, label="Max New Tokens")

    state = gr.State([])
    chat_inputs = [user_input, state, system_message, top_p_slider,
                   top_k_slider, temperature_slider, max_new_tokens_slider]
    chat_outputs = [chatbot, state, system_message, user_input]
    user_input.submit(chat_interface, chat_inputs, chat_outputs, queue=True)
    send_btn.click(chat_interface, chat_inputs, chat_outputs, queue=True)
    clear_btn.click(clear_history, None, chat_outputs, queue=True)

demo.launch()
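# Deployment notes (assumptions, not part of the original script): the
# app needs roughly `pip install torch transformers accelerate gradio
# packaging` (accelerate backs device_map="auto"). demo.launch() binds
# to 127.0.0.1:7860 by default; to serve on a LAN or get a temporary
# public link, the standard Gradio launch options can be used instead:
#
#     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)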