import os
import json
import torch
import gradio as gr
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# -----------------------------
# Config: where to load things
# -----------------------------
# Tokenizer repo (already published)
TOKENIZER_REPO = os.environ.get("COOGEE_TOKENIZER_REPO", "jameszhou-gl/ehr-gpt")

# Where to get model weights/config:
# - Option A (default): from the same HF repo as the tokenizer (set names below)
# - Option B: from local files (set LOCAL_* paths)
WEIGHT_REPO = os.environ.get("COOGEE_WEIGHT_REPO", TOKENIZER_REPO)
WEIGHT_FILENAME = os.environ.get("COOGEE_WEIGHT_FILENAME", "model.safetensors")
CONFIG_FILENAME = os.environ.get("COOGEE_CONFIG_FILENAME", "config.json")
LOCAL_WEIGHT = os.environ.get("COOGEE_LOCAL_WEIGHT", "")  # e.g., "hf_upload_tmp/model.safetensors"
LOCAL_CONFIG = os.environ.get("COOGEE_LOCAL_CONFIG", "")  # e.g., "hf_upload_tmp/config.json"

# Optional: HF token if the repo is private in a Space
HF_TOKEN = os.environ.get("HF_TOKEN", None)
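# Usage sketch (illustrative values only): everything above is read from
# environment variables, so the Space can be repointed without editing this
# file, e.g.
#   COOGEE_WEIGHT_REPO=your-username/your-weights
#   COOGEE_WEIGHT_FILENAME=model.safetensors
# or, for a purely local run:
#   COOGEE_LOCAL_WEIGHT=hf_upload_tmp/model.safetensors
#   COOGEE_LOCAL_CONFIG=hf_upload_tmp/config.json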
# -----------------------------
# 1) Import your local model code
# -----------------------------
from model.model import Transformer, ModelArgs  # <- YOUR local classes

# -----------------------------
# 2) Load tokenizer
# -----------------------------
tok = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_fast=False)
# EOS handling: prefer tokenizer.eos_token_id, or fall back to END_RECORD
def get_eos_id():
    if tok.eos_token_id is not None:
        return tok.eos_token_id
    # convert_tokens_to_ids does not raise for unknown tokens; it returns the
    # unk id (or None), so check for that explicitly instead of catching exceptions.
    end_id = tok.convert_tokens_to_ids("END_RECORD")
    if end_id is None or end_id == tok.unk_token_id:
        raise ValueError("No eos_token_id and 'END_RECORD' not found; set eos_token in tokenizer_config.json.")
    return end_id

EOS_ID = get_eos_id()
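# Optional sanity check (a sketch; assumes the tokenizer exposes unk_token_id):
# the prompt builder below prefixes every request with "START_RECORD", so warn
# at startup rather than silently mapping the prefix to UNK at request time.
_start_id = tok.convert_tokens_to_ids("START_RECORD")
if _start_id is None or _start_id == tok.unk_token_id:
    print("[warn] 'START_RECORD' not in tokenizer vocab; prompts may degrade to UNK.")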
# -----------------------------
# 3) Load config & weights
# -----------------------------
if LOCAL_CONFIG and os.path.isfile(LOCAL_CONFIG):
    cfg_path = LOCAL_CONFIG
else:
    cfg_path = hf_hub_download(WEIGHT_REPO, filename=CONFIG_FILENAME, token=HF_TOKEN)

with open(cfg_path, "r") as f:
    cfg = json.load(f)

args = ModelArgs(**cfg)
model = Transformer(args)

if LOCAL_WEIGHT and os.path.isfile(LOCAL_WEIGHT):
    weight_path = LOCAL_WEIGHT
else:
    weight_path = hf_hub_download(WEIGHT_REPO, filename=WEIGHT_FILENAME, token=HF_TOKEN)

state = load_file(weight_path)
missing, unexpected = model.load_state_dict(state, strict=False)
if missing or unexpected:
    print("[load_state_dict] missing:", missing)
    print("[load_state_dict] unexpected:", unexpected)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()
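# Optional startup report (informational only, not required by the app): a
# parameter count and device line makes it easy to confirm in the Space logs
# that the intended checkpoint was loaded.
n_params = sum(p.numel() for p in model.parameters())
print(f"[load] {n_params / 1e6:.1f}M parameters on {device} from {weight_path}")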
# -----------------------------
# 4) Generation wrapper
# -----------------------------
def _has_generate(m):
    return callable(getattr(m, "generate", None))


def _fallback_sample(input_ids, max_new_tokens=256, temperature=1.0, top_p=0.95, top_k=50):
    """
    Minimal sampling loop if your Transformer doesn't implement .generate().
    Assumes model(input_ids) -> logits [B, T, V].
    Stops on EOS_ID.
    """
    model.eval()
    ids = input_ids.clone()
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            logits = model(ids)  # your forward may return logits directly; if it returns a dict, adapt here
            if isinstance(logits, dict):
                logits = logits["logits"]
            next_logits = logits[:, -1, :] / max(temperature, 1e-6)  # [B, V]
            # top-k / nucleus
            if top_k and top_k > 0:
                topk_vals, topk_idx = torch.topk(next_logits, k=min(top_k, next_logits.size(-1)), dim=-1)
                filt = torch.full_like(next_logits, float("-inf"))
                filt.scatter_(1, topk_idx, topk_vals)
                next_logits = filt
            if 0.0 < top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cum = torch.cumsum(probs, dim=-1)
                mask = cum > top_p
                mask[..., 1:] = mask[..., :-1].clone()
                mask[..., 0] = False
                sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
                next_logits = torch.zeros_like(next_logits).scatter(1, sorted_idx, sorted_logits)
            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # [B, 1]
            ids = torch.cat([ids, next_id], dim=1)  # [B, T+1]
            if int(next_id.item()) == EOS_ID:
                break
    return ids
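# Illustrative call (commented out, not executed at import; arguments are
# arbitrary): sample a short continuation from a single-token prompt with the
# fallback loop above.
#   _prompt = torch.tensor([[tok.convert_tokens_to_ids("START_RECORD")]], device=device)
#   _out = _fallback_sample(_prompt, max_new_tokens=32, temperature=0.8, top_p=0.9, top_k=40)
#   print(tok.convert_ids_to_tokens(_out[0].tolist()))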
def generate_timeline(age, sex, race, marital, year,
                      max_new_tokens, temperature, top_p, top_k, seed):
    # Build prompt tokens
    prompt_tokens = [
        "START_RECORD",
        age, sex, race, marital, year,
    ]
    # to ids
    try:
        ids = [[tok.convert_tokens_to_ids(t) for t in prompt_tokens]]
    except Exception as e:
        # Return an empty list for the Dataframe output so Gradio can render it.
        return f"Tokenization error: {e}", []
    input_ids = torch.tensor(ids, dtype=torch.long, device=device)
    # Seed
    if seed is not None and str(seed).strip() != "":
        try:
            s = int(seed)
            torch.manual_seed(s)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(s)
        except Exception:
            pass

    # Generate
    if _has_generate(model):
        out = model.generate(
            input_ids=input_ids,
            max_length=min(args.n_ctx, input_ids.size(1) + int(max_new_tokens)),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            do_sample=True,
            eos_token_id=EOS_ID,
        )
    else:
        out = _fallback_sample(
            input_ids, max_new_tokens=int(max_new_tokens),
            temperature=float(temperature), top_p=float(top_p), top_k=int(top_k)
        )

    gen_ids = out[0, input_ids.size(1):].tolist()
    if EOS_ID in gen_ids:
        end_idx = gen_ids.index(EOS_ID)
        gen_ids = gen_ids[: end_idx + 1]

    gen_tokens = tok.convert_ids_to_tokens(gen_ids)
    timeline = prompt_tokens + gen_tokens
    text_view = " ".join(timeline)
    table_view = [[i, t] for i, t in enumerate(timeline)]
    return text_view, table_view
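# Quick local smoke test (commented out; illustrative arguments that mirror the
# UI defaults below):
#   _text, _table = generate_timeline("AGE_85_90_years", "SEX_M", "RACE_UNKNOWN",
#                                     "MARITAL_STATUS_WIDOWED", "YEAR_2017",
#                                     64, 1.0, 0.95, 50, "")
#   print(_text)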
# -----------------------------
# 5) Gradio UI
# -----------------------------
AGE_OPTS = [f"AGE_{a}_{a+5}_years" for a in range(15, 100, 5)]
SEX_OPTS = ["SEX_M", "SEX_F"]
RACE_OPTS = ["RACE_ASIAN", "RACE_BLACK", "RACE_HISPANIC", "RACE_OTHER", "RACE_UNKNOWN", "RACE_WHITE"]
MARITAL_OPTS = ["MARITAL_STATUS_DIVORCED", "MARITAL_STATUS_MARRIED", "MARITAL_STATUS_SINGLE",
                "MARITAL_STATUS_UNKNOWN", "MARITAL_STATUS_WIDOWED"]
YEAR_OPTS = [f"YEAR_{y}" for y in range(2005, 2021)]

with gr.Blocks(title="Coogee (local model) - Synthetic EHR Generator") as demo:
    gr.Markdown("## Coogee - Generate synthetic EHR timelines (local model class)")
    with gr.Row():
        age = gr.Dropdown(AGE_OPTS, value="AGE_85_90_years", label="Age")
        sex = gr.Dropdown(SEX_OPTS, value="SEX_M", label="Sex")
        race = gr.Dropdown(RACE_OPTS, value="RACE_UNKNOWN", label="Ethnicity")
        marital = gr.Dropdown(MARITAL_OPTS, value="MARITAL_STATUS_WIDOWED", label="Marital")
        year = gr.Dropdown(YEAR_OPTS, value="YEAR_2017", label="Year")
    with gr.Row():
        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="Max new tokens")
        temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
        top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
        seed = gr.Textbox(value="", label="Seed (optional)")
    btn = gr.Button("Generate")
    out_text = gr.Textbox(lines=6, label="Generated timeline")
    out_table = gr.Dataframe(headers=["Idx", "Token"], label="Token table", interactive=False)

    btn.click(
        fn=generate_timeline,
        inputs=[age, sex, race, marital, year, max_new_tokens, temperature, top_p, top_k, seed],
        outputs=[out_text, out_table],
        api_name="generate",
    )

demo.queue(max_size=20).launch()
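# Calling the running Space programmatically (a sketch; replace the repo id
# with your own Space; gradio_client exposes the endpoint as "/generate"):
#   from gradio_client import Client
#   client = Client("your-username/your-space")
#   text, table = client.predict("AGE_85_90_years", "SEX_M", "RACE_UNKNOWN",
#                                "MARITAL_STATUS_WIDOWED", "YEAR_2017",
#                                256, 1.0, 0.95, 50, "", api_name="/generate")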