# coogee-demo/app.py
import os
import json
import torch
import gradio as gr
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
# -----------------------------
# Config: where to load things
# -----------------------------
# Tokenizer repo (already published)
TOKENIZER_REPO = os.environ.get("COOGEE_TOKENIZER_REPO", "jameszhou-gl/ehr-gpt")
# Where to get model weights/config:
# - Option A (default): from the same HF repo as tokenizer (set names below)
# - Option B: from local files (set LOCAL_* paths)
WEIGHT_REPO = os.environ.get("COOGEE_WEIGHT_REPO", TOKENIZER_REPO)
WEIGHT_FILENAME = os.environ.get("COOGEE_WEIGHT_FILENAME", "model.safetensors")
CONFIG_FILENAME = os.environ.get("COOGEE_CONFIG_FILENAME", "config.json")
LOCAL_WEIGHT = os.environ.get("COOGEE_LOCAL_WEIGHT", "") # e.g., "hf_upload_tmp/model.safetensors"
LOCAL_CONFIG = os.environ.get("COOGEE_LOCAL_CONFIG", "") # e.g., "hf_upload_tmp/config.json"
# Optional: HF token if the repo is private in a Space
HF_TOKEN = os.environ.get("HF_TOKEN", None)
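# Example: to use local files instead of downloading from the Hub, export the
# COOGEE_LOCAL_* paths suggested above before launching, e.g.
#   export COOGEE_LOCAL_WEIGHT=hf_upload_tmp/model.safetensors
#   export COOGEE_LOCAL_CONFIG=hf_upload_tmp/config.json
#   python app.py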
# -----------------------------
# 1) Import your local model code
# -----------------------------
from model.model import Transformer, ModelArgs # <- YOUR local classes
# -----------------------------
# 2) Load tokenizer
# -----------------------------
tok = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_fast=False, token=HF_TOKEN)
# EOS handling: prefer tokenizer.eos_token_id, or fall back to END_RECORD
def get_eos_id():
    if tok.eos_token_id is not None:
        return tok.eos_token_id
    # Fall back to END_RECORD; note that convert_tokens_to_ids usually maps an
    # unknown token to unk_token_id (or None) rather than raising.
    eos_id = tok.convert_tokens_to_ids("END_RECORD")
    if eos_id is None or eos_id == tok.unk_token_id:
        raise ValueError("No eos_token_id and 'END_RECORD' not found; set eos_token in tokenizer_config.json.")
    return eos_id
EOS_ID = get_eos_id()
# -----------------------------
# 3) Load config & weights
# -----------------------------
if LOCAL_CONFIG and os.path.isfile(LOCAL_CONFIG):
    cfg_path = LOCAL_CONFIG
else:
    cfg_path = hf_hub_download(WEIGHT_REPO, filename=CONFIG_FILENAME, token=HF_TOKEN)

with open(cfg_path, "r") as f:
    cfg = json.load(f)
args = ModelArgs(**cfg)
model = Transformer(args)
if LOCAL_WEIGHT and os.path.isfile(LOCAL_WEIGHT):
    weight_path = LOCAL_WEIGHT
else:
    weight_path = hf_hub_download(WEIGHT_REPO, filename=WEIGHT_FILENAME, token=HF_TOKEN)

state = load_file(weight_path)
missing, unexpected = model.load_state_dict(state, strict=False)
if missing or unexpected:
    print("[load_state_dict] missing:", missing)
    print("[load_state_dict] unexpected:", unexpected)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()
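# Inference only: the model stays in eval mode, and the fallback sampling loop
# below runs under torch.no_grad(), so no gradients are tracked during generation.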
# -----------------------------
# 4) Generation wrapper
# -----------------------------
def _has_generate(m):
    return callable(getattr(m, "generate", None))
def _fallback_sample(input_ids, max_new_tokens=256, temperature=1.0, top_p=0.95, top_k=50):
    """
    Minimal sampling loop if your Transformer doesn't implement .generate().
    Assumes model(input_ids) -> logits [B, T, V].
    Stops on EOS_ID.
    """
    model.eval()
    ids = input_ids.clone()
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            logits = model(ids)  # your forward may return logits directly; if it returns a dict, adapt here
            if isinstance(logits, dict):
                logits = logits["logits"]
            next_logits = logits[:, -1, :] / max(temperature, 1e-6)  # [B, V]
            # top-k filtering: keep only the k highest-scoring tokens
            if top_k and top_k > 0:
                topk_vals, topk_idx = torch.topk(next_logits, k=min(top_k, next_logits.size(-1)), dim=-1)
                filt = torch.full_like(next_logits, float("-inf"))
                filt.scatter_(1, topk_idx, topk_vals)
                next_logits = filt
            # nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative probability exceeds top_p
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cum = torch.cumsum(probs, dim=-1)
                mask = cum > top_p
                # shift right so the first token that crosses the threshold is still kept
                mask[..., 1:] = mask[..., :-1].clone()
                mask[..., 0] = False
                sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
                next_logits = torch.zeros_like(next_logits).scatter(1, sorted_idx, sorted_logits)
            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # [B, 1]
            ids = torch.cat([ids, next_id], dim=1)  # [B, T+1]
            if int(next_id.item()) == EOS_ID:
                break
    return ids
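# Rough smoke test for the fallback sampler (e.g. from a Python shell), assuming a
# single START_RECORD prompt token:
#   _ids = torch.tensor([[tok.convert_tokens_to_ids("START_RECORD")]], device=device)
#   print(tok.convert_ids_to_tokens(_fallback_sample(_ids, max_new_tokens=8)[0].tolist()))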
def generate_timeline(age, sex, race, marital, year,
                      max_new_tokens, temperature, top_p, top_k, seed):
    # Build prompt tokens
    prompt_tokens = [
        "START_RECORD",
        age, sex, race, marital, year,
    ]
    # Convert tokens to ids
    try:
        ids = [[tok.convert_tokens_to_ids(t) for t in prompt_tokens]]
    except Exception as e:
        return f"Tokenization error: {e}", []
    input_ids = torch.tensor(ids, dtype=torch.long, device=device)
    # Seed (optional, for reproducible sampling)
    if seed is not None and str(seed).strip() != "":
        try:
            s = int(seed)
            torch.manual_seed(s)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(s)
        except Exception:
            pass
    # Generate
    if _has_generate(model):
        out = model.generate(
            input_ids=input_ids,
            max_length=min(args.n_ctx, input_ids.size(1) + int(max_new_tokens)),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            do_sample=True,
            eos_token_id=EOS_ID,
        )
    else:
        out = _fallback_sample(
            input_ids, max_new_tokens=int(max_new_tokens),
            temperature=float(temperature), top_p=float(top_p), top_k=int(top_k)
        )
    # Decode only the newly generated part, truncated at EOS (inclusive)
    gen_ids = out[0, input_ids.size(1):].tolist()
    if EOS_ID in gen_ids:
        end_idx = gen_ids.index(EOS_ID)
        gen_ids = gen_ids[: end_idx + 1]
    gen_tokens = tok.convert_ids_to_tokens(gen_ids)
    timeline = prompt_tokens + gen_tokens
    text_view = " ".join(timeline)
    table_view = [[i, t] for i, t in enumerate(timeline)]
    return text_view, table_view
# -----------------------------
# 5) Gradio UI
# -----------------------------
AGE_OPTS = [f"AGE_{a}_{a+5}_years" for a in range(15, 100, 5)]
SEX_OPTS = ["SEX_M", "SEX_F"]
RACE_OPTS = ["RACE_ASIAN", "RACE_BLACK", "RACE_HISPANIC", "RACE_OTHER", "RACE_UNKNOWN", "RACE_WHITE"]
MARITAL_OPTS = ["MARITAL_STATUS_DIVORCED", "MARITAL_STATUS_MARRIED", "MARITAL_STATUS_SINGLE",
                "MARITAL_STATUS_UNKNOWN", "MARITAL_STATUS_WIDOWED"]
YEAR_OPTS = [f"YEAR_{y}" for y in range(2005, 2021)]
with gr.Blocks(title="Coogee (local model) - Synthetic EHR Generator") as demo:
    gr.Markdown("## Coogee - Generate synthetic EHR timelines (local model class)")
    with gr.Row():
        age = gr.Dropdown(AGE_OPTS, value="AGE_85_90_years", label="Age")
        sex = gr.Dropdown(SEX_OPTS, value="SEX_M", label="Sex")
        race = gr.Dropdown(RACE_OPTS, value="RACE_UNKNOWN", label="Ethnicity")
        marital = gr.Dropdown(MARITAL_OPTS, value="MARITAL_STATUS_WIDOWED", label="Marital")
        year = gr.Dropdown(YEAR_OPTS, value="YEAR_2017", label="Year")
    with gr.Row():
        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="Max new tokens")
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
        top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        seed = gr.Textbox(value="", label="Seed (optional)")
    btn = gr.Button("Generate")
    out_text = gr.Textbox(lines=6, label="Generated timeline")
    out_table = gr.Dataframe(headers=["Idx", "Token"], label="Token table", interactive=False)
    btn.click(
        fn=generate_timeline,
        inputs=[age, sex, race, marital, year, max_new_tokens, temperature, top_p, top_k, seed],
        outputs=[out_text, out_table],
        api_name="generate",
    )

demo.queue(max_size=20).launch()
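# Run locally with `python app.py`; Gradio prints a local URL once the server starts.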