from contextlib import nullcontext

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Identifier of the base checkpoint; passed to from_pretrained for both the
# tokenizer and the model weights.
BASE_ID = "unsloth/Qwen2.5-0.5B"

# Identifier of the LoRA adapter that is loaded on top of the base model.
LORA_ID = "piyushthepandey/qwen05b_pubhealth_lora"

# Tokenizer of the base checkpoint, requested as the fast implementation.
# NOTE(review): trust_remote_code=True allows the loader to execute Python
# code shipped in the BASE_ID repository. Confirm this checkpoint actually
# needs it and drop the flag otherwise.
tok = AutoTokenizer.from_pretrained(
    BASE_ID,
    trust_remote_code=True,
    use_fast=True,
)

# Base causal LM, loaded in float32 with every module mapped to the CPU.
base = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map={"": "cpu"},
)
base.eval()  # inference only: switch off training-time behaviour

# Attach the fine-tuned LoRA adapter to the already-loaded base model.
# NOTE(review): PEFT wraps `base` rather than copying it, so the adapter
# layers presumably end up inside `base` as well - confirm against the PEFT
# docs before treating `base` as an adapter-free model after this line.
lora = PeftModel.from_pretrained(base, LORA_ID, device_map={"": "cpu"})
lora.eval()

# Prompt template. The placeholders {schema}, {request} and {explanation} are
# filled with str.format; the text ends right after the "### SQL:" heading so
# that the model's continuation is the SQL answer.
TPL = """### Instruction:
Translate the user request into SQL given the schema.

### Schema:
{schema}

### Request:
{request}

### SQL explanation:
{explanation}

### SQL:
"""

@torch.inference_mode()
def generate(schema, request, explanation, which):
    """Generate SQL for a request with either the LoRA or the base weights.

    Args:
        schema: DDL text inserted under the ``### Schema:`` heading.
        request: Natural-language request inserted under ``### Request:``.
        explanation: Text inserted under ``### SQL explanation:``; may be an
            empty string.
        which: ``"LoRA"`` runs with the adapter enabled; any other value runs
            with the adapter disabled (base weights only).

    Returns:
        The text produced by the model after the prompt, decoded without
        special tokens and stripped of surrounding whitespace.
    """
    prompt = TPL.format(schema=schema, request=request, explanation=explanation)
    ids = tok(prompt, return_tensors="pt").to("cpu")

    # `lora` wraps `base` in place, so calling `base.generate` would still go
    # through the injected LoRA layers and both radio options would give the
    # same output. The adapter-free result is obtained from the PEFT wrapper
    # with the adapter temporarily disabled.
    adapter_ctx = nullcontext() if which == "LoRA" else lora.disable_adapter()
    with adapter_ctx:
        out = lora.generate(
            **ids,
            max_new_tokens=128,
            do_sample=False,  # greedy decoding: deterministic output
            eos_token_id=tok.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Decode only what
    # was generated instead of searching the full text for the last
    # "### SQL:" marker, which picks the wrong fragment whenever the model
    # emits that heading again.
    prompt_len = ids["input_ids"].shape[1]
    completion = out[0][prompt_len:]
    return tok.decode(completion, skip_special_tokens=True).strip()


# Gradio UI: three text inputs and a model selector feed generate(); the
# returned text is shown in a code widget.
with gr.Blocks(title="Qwen-0.5B Public-Health Text-to-SQL") as demo:
    # NOTE(review): the heading previously contained mojibake; it is restored
    # here as a hospital emoji and an em dash (presumed originals - confirm).
    # Both are written as escapes so they survive re-encoding of this file.
    gr.Markdown(
        "## \U0001F3E5 Qwen-2.5-0.5B \u2014 Public-Health Text-to-SQL \n"
        "**Compare base vs fine-tuned LoRA on CPU**"
    )
    schema = gr.Textbox(lines=5, label="DDL schema")
    req = gr.Textbox(lines=2, label="Natural-language request")
    expl = gr.Textbox(lines=2, label="(optional) SQL explanation")
    which = gr.Radio(["LoRA", "Base"], value="LoRA", label="Model")
    out = gr.Code(label="Generated SQL")
    gr.Button("Generate").click(
        generate,
        inputs=[schema, req, expl, which],
        outputs=out,
    )

# Start the web server for the demo.
demo.launch()