import os

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
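
# Hugging Face Hub IDs for the base model and the PEFT adapter trained on top of it.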
BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
ADAPTER_ID = "stevenArtificial/Babaru-Llama-3.2-1B-Instruct"
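
# Load the tokenizer from the adapter repo so any tokens added during fine-tuning are picked up.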
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID)
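
# Pick the best available backend: CUDA GPU, Apple Silicon (MPS), or CPU.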
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Load the base weights in half precision on GPU/MPS to save memory; CPU stays in float32.
# Placement is handled by the explicit .to(device) below, so device_map is not needed.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if device != "cpu" else torch.float32,
)

# Apply the PEFT adapter on top of the base weights, then move the model to the device.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.to(device)
model.eval()


def generate(
    prompt: str,
    max_tokens: int,
    top_p: float,
    temperature: float,
) -> str:
    """Sample a completion for `prompt` and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens generated after the prompt; slicing the decoded
    # string by len(prompt) can break when decoding normalizes whitespace.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
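

# Each example row supplies the prompt, max new tokens, top-p, and temperature.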
EXAMPLES = [
    [
        "User: How would you describe the theatrical style of Babaru?\nAssistant:",
        128, 0.9, 0.8,
    ],
    [
        "User: Explain existentialism in a snarky tone.\nAssistant:",
        128, 0.9, 0.7,
    ],
    [
        "User: Summarize the significance of LoRA.\nAssistant:",
        128, 0.8, 0.9,
    ],
]
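
# Assemble the Gradio UI: a prompt box plus sampling controls wired to generate().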
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            lines=4,
            placeholder="Enter your prompt here...",
            label="Prompt",
        ),
        gr.Slider(32, 512, value=128, step=32, label="Max new tokens"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=[
        gr.Textbox(label="Generated Reply"),
    ],
    examples=EXAMPLES,
    title="🎭 Babaru Llama-3.2-1B-Instruct",
    description="Snarky, theatrical AI assistant. Enter a prompt to see Babaru in action!",
    allow_flagging="never",
)

if __name__ == "__main__":
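    # Bind to all interfaces and respect a PORT override from the environment.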
    interface.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )