# Source: Hugging Face Space file "app.py" by piyushthepandey (commit d58cd6c).
import torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# ── IDs ──────────────────────────────────────────────────────────────────────
BASE_ID = "unsloth/Qwen2.5-0.5B" # backbone you trained on
LORA_ID = "piyushthepandey/qwen05b_pubhealth_lora" # your LoRA repo
# ── 1 Tokenizer ─────────────────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(
BASE_ID,
trust_remote_code=True, # needed for Qwen-style tokenizers
use_fast=True,
)
# ── 2 Load the base model on ***CPU*** only ─────────────────────────────────
# float32 is safest on CPU (half-precision on CPU is very slow / fragile)
base = AutoModelForCausalLM.from_pretrained(
BASE_ID,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
device_map={"": "cpu"}, # pin every module to the CPU (no GPU assumed)
)
base.eval() # inference mode: disables dropout etc.
# ── 3 Attach LoRA adapters (still CPU) ──────────────────────────────────────
# `lora` wraps `base`; both names stay live so the UI can compare them.
lora = PeftModel.from_pretrained(base, LORA_ID, device_map={"": "cpu"})
lora.eval()
# ── 4 Prompt template --------------------------------------------------------
# Filled in by generate(); it deliberately ends with the "### SQL:" marker so
# the model's completion (the SQL itself) follows immediately after it.
TPL = """### Instruction:
Translate the user request into SQL given the schema.
### Schema:
{schema}
### Request:
{request}
### SQL explanation:
{explanation}
### SQL:
"""
@torch.inference_mode()
def generate(schema, request, explanation, which):
    """Generate SQL for a natural-language request with the selected model.

    Parameters
    ----------
    schema : str
        DDL schema text inserted into the prompt.
    request : str
        Natural-language request to translate.
    explanation : str
        Optional SQL-explanation hint inserted into the prompt.
    which : str
        "LoRA" selects the fine-tuned adapter model; anything else the base.

    Returns
    -------
    str
        The model's completion (the generated SQL), whitespace-stripped.
    """
    model = lora if which == "LoRA" else base
    prompt = TPL.format(schema=schema, request=request, explanation=explanation)
    inputs = tok(prompt, return_tensors="pt").to("cpu")
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,              # greedy decoding for reproducibility
        eos_token_id=tok.eos_token_id,
    )
    # Decode only the newly generated tokens rather than decoding everything
    # and splitting on "### SQL:" — the old split returned the wrong slice if
    # the user's schema/request text itself contained that marker.
    prompt_len = inputs["input_ids"].shape[1]
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
# ── 5 Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Qwen-0.5B Public-Health Text-to-SQL") as demo:
    # NOTE: the original header emoji was mojibake ("πŸ₯"); restored to 🏥.
    gr.Markdown(
        "## 🏥 Qwen-2.5-0.5B – Public-Health Text-to-SQL \n"
        "**Compare base vs fine-tuned LoRA on CPU**"
    )
    schema = gr.Textbox(lines=5, label="DDL schema")
    req = gr.Textbox(lines=2, label="Natural-language request")
    expl = gr.Textbox(lines=2, label="(optional) SQL explanation")
    which = gr.Radio(["LoRA", "Base"], value="LoRA", label="Model")
    out = gr.Code(label="Generated SQL")
    # Wire the button to generate(); inputs map positionally to its params.
    gr.Button("Generate").click(
        generate,
        inputs=[schema, req, expl, which],
        outputs=out,
    )
demo.launch()