jameszhou-gl committed on
Commit 4e2ac81
0 Parent(s):

Initial commit of Coogee demo

Files changed (3)
  1. app.py +214 -0
  2. model/model.py +491 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ import json
+ import torch
+ import gradio as gr
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoTokenizer
+
+ # -----------------------------
+ # Config: where to load things
+ # -----------------------------
+ # Tokenizer repo (already published)
+ TOKENIZER_REPO = os.environ.get("COOGEE_TOKENIZER_REPO", "jameszhou-gl/ehr-gpt")
+
+ # Where to get model weights/config:
+ # - Option A (default): from the same HF repo as the tokenizer (set names below)
+ # - Option B: from local files (set LOCAL_* paths)
+ WEIGHT_REPO = os.environ.get("COOGEE_WEIGHT_REPO", TOKENIZER_REPO)
+ WEIGHT_FILENAME = os.environ.get("COOGEE_WEIGHT_FILENAME", "model.safetensors")
+ CONFIG_FILENAME = os.environ.get("COOGEE_CONFIG_FILENAME", "config.json")
+
+ LOCAL_WEIGHT = os.environ.get("COOGEE_LOCAL_WEIGHT", "")  # e.g., "hf_upload_tmp/model.safetensors"
+ LOCAL_CONFIG = os.environ.get("COOGEE_LOCAL_CONFIG", "")  # e.g., "hf_upload_tmp/config.json"
+
+ # Optional: HF token if the repo is private in a Space
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # -----------------------------
+ # 1) Import your local model code
+ # -----------------------------
+ from model.model import Transformer, ModelArgs  # <- YOUR local classes
+
+ # -----------------------------
+ # 2) Load tokenizer
+ # -----------------------------
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_fast=False)
+
+ # EOS handling: prefer tokenizer.eos_token_id, or fall back to END_RECORD
+ def get_eos_id():
+     if tok.eos_token_id is not None:
+         return tok.eos_token_id
+     try:
+         return tok.convert_tokens_to_ids("END_RECORD")
+     except Exception:
+         raise ValueError("No eos_token_id and 'END_RECORD' not found; set eos_token in tokenizer_config.json.")
+
+ EOS_ID = get_eos_id()
+
+ # -----------------------------
+ # 3) Load config & weights
+ # -----------------------------
+ if LOCAL_CONFIG and os.path.isfile(LOCAL_CONFIG):
+     cfg_path = LOCAL_CONFIG
+ else:
+     cfg_path = hf_hub_download(WEIGHT_REPO, filename=CONFIG_FILENAME, token=HF_TOKEN)
+
+ with open(cfg_path, "r") as f:
+     cfg = json.load(f)
+
+ args = ModelArgs(**cfg)
+ model = Transformer(args)
+
+ if LOCAL_WEIGHT and os.path.isfile(LOCAL_WEIGHT):
+     weight_path = LOCAL_WEIGHT
+ else:
+     weight_path = hf_hub_download(WEIGHT_REPO, filename=WEIGHT_FILENAME, token=HF_TOKEN)
+
+ state = load_file(weight_path)
+ missing, unexpected = model.load_state_dict(state, strict=False)
+ if missing or unexpected:
+     print("[load_state_dict] missing:", missing)
+     print("[load_state_dict] unexpected:", unexpected)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device).eval()
+
+ # -----------------------------
+ # 4) Generation wrapper
+ # -----------------------------
+ def _has_generate(m):
+     return callable(getattr(m, "generate", None))
+
+ def _fallback_sample(input_ids, max_new_tokens=256, temperature=1.0, top_p=0.95, top_k=50):
+     """
+     Minimal sampling loop used if the Transformer doesn't implement .generate().
+     Assumes model(input_ids) -> logits [B, T, V].
+     Stops on EOS_ID.
+     """
+     model.eval()
+     ids = input_ids.clone()
+     with torch.no_grad():
+         for _ in range(int(max_new_tokens)):
+             logits = model(ids)  # your forward may return logits directly; if it returns a dict, adapt here
+             if isinstance(logits, dict):
+                 logits = logits["logits"]
+             next_logits = logits[:, -1, :] / max(temperature, 1e-6)  # [B, V]
+
+             # top-k / nucleus filtering
+             if top_k and top_k > 0:
+                 topk_vals, topk_idx = torch.topk(next_logits, k=min(top_k, next_logits.size(-1)), dim=-1)
+                 filt = torch.full_like(next_logits, float("-inf"))
+                 filt.scatter_(1, topk_idx, topk_vals)
+                 next_logits = filt
+
+             if 0.0 < top_p < 1.0:
+                 sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
+                 probs = torch.softmax(sorted_logits, dim=-1)
+                 cum = torch.cumsum(probs, dim=-1)
+                 mask = cum > top_p
+                 mask[..., 1:] = mask[..., :-1].clone()
+                 mask[..., 0] = False
+                 sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
+                 next_logits = torch.zeros_like(next_logits).scatter(1, sorted_idx, sorted_logits)
+
+             probs = torch.softmax(next_logits, dim=-1)
+             next_id = torch.multinomial(probs, num_samples=1)  # [B, 1]
+
+             ids = torch.cat([ids, next_id], dim=1)  # [B, T+1]
+             if int(next_id.item()) == EOS_ID:
+                 break
+     return ids
+
+ def generate_timeline(age, sex, race, marital, year,
+                       max_new_tokens, temperature, top_p, top_k, seed):
+     # Build prompt tokens
+     prompt_tokens = [
+         "START_RECORD",
+         age, sex, race, marital, year,
+     ]
+     # to ids
+     try:
+         ids = [[tok.convert_tokens_to_ids(t) for t in prompt_tokens]]
+     except Exception as e:
+         return f"Tokenization error: {e}", ""
+
+     input_ids = torch.tensor(ids, dtype=torch.long, device=device)
+
+     # Seed
+     if seed is not None and str(seed).strip() != "":
+         try:
+             s = int(seed)
+             torch.manual_seed(s)
+             if torch.cuda.is_available():
+                 torch.cuda.manual_seed_all(s)
+         except Exception:
+             pass
+
+     # Generate
+     if _has_generate(model):
+         # The local Transformer.generate() accepts (input_ids, max_length, temperature, top_p, end_token_id);
+         # top-k is only applied by the fallback sampler below.
+         out = model.generate(
+             input_ids=input_ids,
+             max_length=min(args.n_ctx, input_ids.size(1) + int(max_new_tokens)),
+             temperature=float(temperature),
+             top_p=float(top_p),
+             end_token_id=EOS_ID,
+         )
+     else:
+         out = _fallback_sample(
+             input_ids, max_new_tokens=int(max_new_tokens),
+             temperature=float(temperature), top_p=float(top_p), top_k=int(top_k)
+         )
+
+     gen_ids = out[0, input_ids.size(1):].tolist()
+     if EOS_ID in gen_ids:
+         end_idx = gen_ids.index(EOS_ID)
+         gen_ids = gen_ids[: end_idx + 1]
+
+     gen_tokens = tok.convert_ids_to_tokens(gen_ids)
+     timeline = prompt_tokens + gen_tokens
+     text_view = " ".join(timeline)
+     table_view = [[i, t] for i, t in enumerate(timeline)]
+     return text_view, table_view
+
+ # -----------------------------
+ # 5) Gradio UI
+ # -----------------------------
+ AGE_OPTS = [f"AGE_{a}_{a+5}_years" for a in range(15, 100, 5)]
+ SEX_OPTS = ["SEX_M", "SEX_F"]
+ RACE_OPTS = ["RACE_ASIAN", "RACE_BLACK", "RACE_HISPANIC", "RACE_OTHER", "RACE_UNKNOWN", "RACE_WHITE"]
+ MARITAL_OPTS = ["MARITAL_STATUS_DIVORCED", "MARITAL_STATUS_MARRIED", "MARITAL_STATUS_SINGLE",
+                 "MARITAL_STATUS_UNKNOWN", "MARITAL_STATUS_WIDOWED"]
+ YEAR_OPTS = [f"YEAR_{y}" for y in range(2005, 2021)]
+
+ with gr.Blocks(title="Coogee (local model) — Synthetic EHR Generator") as demo:
+     gr.Markdown("## Coogee — Generate synthetic EHR timelines (local model class)")
+
+     with gr.Row():
+         age = gr.Dropdown(AGE_OPTS, value="AGE_85_90_years", label="Age")
+         sex = gr.Dropdown(SEX_OPTS, value="SEX_M", label="Sex")
+         race = gr.Dropdown(RACE_OPTS, value="RACE_UNKNOWN", label="Ethnicity")
+         marital = gr.Dropdown(MARITAL_OPTS, value="MARITAL_STATUS_WIDOWED", label="Marital")
+         year = gr.Dropdown(YEAR_OPTS, value="YEAR_2017", label="Year")
+
+     with gr.Row():
+         max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="Max new tokens")
+         temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
+         top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
+         top_k = gr.Slider(0, 200, value=50, step=1, label="Top-k")
+         seed = gr.Textbox(value="", label="Seed (optional)")
+
+     btn = gr.Button("Generate")
+     out_text = gr.Textbox(lines=6, label="Generated timeline")
+     out_table = gr.Dataframe(headers=["Idx", "Token"], label="Token table", interactive=False)
+
+     btn.click(
+         fn=generate_timeline,
+         inputs=[age, sex, race, marital, year, max_new_tokens, temperature, top_p, top_k, seed],
+         outputs=[out_text, out_table],
+         api_name="generate",
+     )
+
+ demo.queue(max_size=20).launch()
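For local testing outside a Space, the loader above can also be pointed at files on disk through the COOGEE_LOCAL_* variables, which app.py reads at import time; a minimal sketch, assuming a hypothetical checkpoint/ directory holding config.json and model.safetensors:

import os

# Hypothetical local paths; when unset, app.py falls back to hf_hub_download.
os.environ["COOGEE_LOCAL_CONFIG"] = "checkpoint/config.json"
os.environ["COOGEE_LOCAL_WEIGHT"] = "checkpoint/model.safetensors"

import app  # importing app.py loads tokenizer and weights, then launches the Gradio demo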
model/model.py ADDED
@@ -0,0 +1,491 @@
+ '''
+ code by Guanglin Zhou ([email protected])
+ Reference: https://github.com/openai/gpt-oss/blob/main/gpt_oss/torch/model.py
+ 17 Sep 2025: Add KV-cache for efficient inference.
+ '''
+ import math, json, os
+ import numpy as np
+ from dataclasses import dataclass, asdict
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ def load_auxiliary_embeddings(embd_path: str, vocab_size: int, embd_dim: int, embd_type: str) -> torch.nn.Parameter:
+     """Load auxiliary embeddings (hierarchy or semantic) from an npz file.
+
+     Args:
+         embd_path: Path to the npz file containing embeddings
+         vocab_size: Size of the vocabulary
+         embd_dim: Dimension of the embeddings
+         embd_type: Type of embeddings (for logging purposes)
+
+     Returns:
+         torch.nn.Parameter: Frozen embedding matrix of shape (vocab_size, embd_dim)
+     """
+     embeddings = np.zeros((vocab_size, embd_dim))
+
+     with np.load(embd_path) as data:
+         for token_id_str in data.files:
+             token_id = int(token_id_str)
+             embeddings[token_id] = data[token_id_str]
+
+     print(f"Loaded {embd_type} embeddings for {len(data.files)} tokens")
+
+     return nn.Parameter(
+         torch.FloatTensor(embeddings),
+         requires_grad=False
+     )
+
+
+ @dataclass
+ class ModelArgs:
+     LOG_DIR: str  # Directory containing model checkpoints and embeddings
+     vocab_size: int = -1  # later loaded from tokenizer
+     n_embd: int = 576
+     hidden_dim: int = 768
+     n_layers: int = 6
+     n_ctx: int = 2048
+     num_attention_heads: int = 9
+     num_key_value_heads: int = 3
+     drop_out: float = 0.1
+     hidden_act: str = "silu"
+     initializer_range: float = 0.041666666666666664
+     rms_norm_eps: float = 1e-5
+     use_cache: bool = True
+     pad_token_id: Optional[int] = None
+     bos_token_id: int = 0
+     eos_token_id: int = 0
+     tie_word_embeddings: bool = True
+     rope_theta: float = 10000.0
+     use_hierarchy_embd: bool = False
+     hierarchy_dim: Optional[int] = None
+     use_semantic_embd: bool = False
+     semantic_dim: Optional[int] = None
+
+     @classmethod
+     def from_dict(cls, d: dict):
+         return cls(**{k: v for k, v in d.items() if k in cls.__annotations__})
+
+     def to_json(self, path: str):
+         with open(path, "w") as f:
+             json.dump(asdict(self), f, indent=2)
+
+     @classmethod
+     def from_json(cls, path: str):
+         with open(path, "r") as f:
+             d = json.load(f)
+         return cls.from_dict(d)
+
+ class RMSNorm(nn.Module):
+     def __init__(self, n_embd, eps=1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(n_embd))
+         self.eps = eps
+
+     def forward(self, x):
+         variance = x.pow(2).mean(-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.eps)
+         return self.weight * x
+
+ def precompute_rope_frequencies(n_embd: int, n_ctx: int, theta: float = 10000.0):
+     position = torch.arange(n_ctx).unsqueeze(1)  # [seq_len, 1]
+     div_term = theta ** (torch.arange(0, n_embd, 2).float() / n_embd)  # [n_embd/2]
+     freqs = position / div_term  # [seq_len, n_embd/2]
+     return freqs
+
+ def apply_rotary_embeddings(x: torch.Tensor, freqs: torch.Tensor):
+     # x shape: [batch, seq_len, heads, head_dim]
+     # freqs shape: [seq_len, head_dim/2]
+     x_rot = x.float()
+
+     # Reshape freqs to match x's dimensions
+     freqs = freqs.unsqueeze(0).unsqueeze(2)  # [1, seq_len, 1, n_embd/2]
+
+     # Split channels for rotation
+     x1, x2 = x_rot[..., :x_rot.shape[-1]//2], x_rot[..., x_rot.shape[-1]//2:]
+
+     # Apply rotary embeddings
+     cos = torch.cos(freqs).to(x.device)
+     sin = torch.sin(freqs).to(x.device)
+
+     # Ensure broadcasting dimensions match
+     cos = cos.expand_as(x1)
+     sin = sin.expand_as(x1)
+
+     # Rotate x1 and x2
+     x1_rot = x1 * cos - x2 * sin
+     x2_rot = x2 * cos + x1 * sin
+
+     # Concatenate back
+     return torch.cat([x1_rot, x2_rot], dim=-1).to(x.dtype)
+
+ def apply_rope_with_pos_ids(x: torch.Tensor, freqs: torch.Tensor, position_ids: torch.Tensor):
+     """
+     x: [B, T, H, Dh] (queries or keys)
+     freqs: [max_seq_len, Dh/2] (precomputed table)
+     position_ids: [B, T] absolute positions for these tokens
+     """
+     B, T, H, Dh = x.shape
+     in_dtype = x.dtype
+     x = x.float()
+
+     # gather the cos/sin rows for each position in the batch
+     cos = torch.cos(freqs[position_ids])  # [B, T, Dh/2]
+     sin = torch.sin(freqs[position_ids])  # [B, T, Dh/2]
+
+     # expand to heads
+     cos = cos.unsqueeze(2).expand(B, T, H, Dh // 2)  # [B, T, H, Dh/2]
+     sin = sin.unsqueeze(2).expand(B, T, H, Dh // 2)
+
+     x1, x2 = x[..., :Dh//2], x[..., Dh//2:]
+     x_rot1 = x1 * cos - x2 * sin
+     x_rot2 = x2 * cos + x1 * sin
+     out = torch.cat([x_rot1, x_rot2], dim=-1)
+     return out.to(dtype=in_dtype)
+
+ class SelfAttention(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.n_embd = args.n_embd
+         self.num_heads = args.num_attention_heads
+         self.num_kv_heads = args.num_key_value_heads
+         self.head_dim = args.n_embd // args.num_attention_heads
+
+         # Adjust projections to match head dimensions
+         self.q_proj = nn.Linear(args.n_embd, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(args.n_embd, self.num_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(args.n_embd, self.num_kv_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, args.n_embd, bias=False)
+
+         # Initialize rotary embeddings
+         self.register_buffer(
+             "rope_freqs",
+             precompute_rope_frequencies(
+                 self.head_dim,  # Use full head_dim for frequencies
+                 args.n_ctx,
+                 args.rope_theta
+             ),
+             persistent=False
+         )
+         self.attn_drop = nn.Dropout(args.drop_out)
+         self.residual_drop = nn.Dropout(args.drop_out)
+
+     def forward(self, hidden_states, attention_mask=None, past_key_value: Optional[tuple] = None, use_cache: bool = False, position_offset: Optional[int] = None):
+         """
+         hidden_states: [B, Tq, D]
+         past_key_value: Optional[(past_k, past_v)] each [B, H_kv, Tp, Dh] already RoPE-rotated
+         use_cache: if True, return current (k, v) concatenated with past for caching upstream
+         position_offset: if provided, absolute position of the first token in `hidden_states`
+                          (i.e., past length). If None, defaults to the cached past length.
+         """
+         B, Tq, _ = hidden_states.size()
+         Hq = self.num_heads
+         Hkv = self.num_kv_heads
+         Dh = self.head_dim
+
+         # Projections
+         q = self.q_proj(hidden_states).view(B, Tq, Hq, Dh)
+         k_new = self.k_proj(hidden_states).view(B, Tq, Hkv, Dh)
+         v_new = self.v_proj(hidden_states).view(B, Tq, Hkv, Dh)
+
+         # Absolute positions for the NEW tokens
+         past_len = 0 if past_key_value is None else past_key_value[0].size(2)
+         if position_offset is None:
+             position_offset = past_len
+         pos_ids_new = (torch.arange(Tq, device=hidden_states.device) + position_offset).view(1, Tq).expand(B, Tq)
+
+         # Apply RoPE to new q and k
+         q = apply_rope_with_pos_ids(q, self.rope_freqs, pos_ids_new)
+         k_new = apply_rope_with_pos_ids(k_new, self.rope_freqs, pos_ids_new)
+
+         # Prepare full K/V in KV-head space, concatenate with past if any
+         if past_key_value is not None:
+             past_k, past_v = past_key_value
+             k_cat = torch.cat([past_k, k_new.transpose(1, 2)], dim=2)
+             v_cat = torch.cat([past_v, v_new.transpose(1, 2)], dim=2)
+         else:
+             k_cat = k_new.transpose(1, 2)
+             v_cat = v_new.transpose(1, 2)
+         Tk = k_cat.size(2)
+
+         # Expand KV to query-heads if using GQA
+         if Hkv < Hq:
+             repeat = Hq // Hkv
+             k_full = k_cat.repeat_interleave(repeat, dim=1)
+             v_full = v_cat.repeat_interleave(repeat, dim=1)
+         else:
+             k_full = k_cat
+             v_full = v_cat
+
+         # Scaled dot-product attention
+         q = q.transpose(1, 2)  # (B, Hq, Tq, Dh)
+
+         attn_scores = torch.matmul(q, k_full.transpose(-2, -1)) / math.sqrt(Dh)
+
+         i_abs = (position_offset + torch.arange(Tq, device=hidden_states.device).unsqueeze(1))
+         j_abs = torch.arange(Tk, device=hidden_states.device).unsqueeze(0)
+         causal = (j_abs <= i_abs).unsqueeze(0).unsqueeze(0)
+         attn_scores = attn_scores.masked_fill(~causal, float('-inf'))
+
+         # Optional extra mask (e.g., padding mask shaped/broadcastable to [B, 1, Tq, Tk])
+         if attention_mask is not None:
+             attn_scores = attn_scores + attention_mask
+
+         attn_probs = F.softmax(attn_scores, dim=-1)
+         attn_probs = self.attn_drop(attn_probs)
+         context = torch.matmul(attn_probs, v_full)
+
+         context = context.transpose(1, 2).contiguous().view(B, Tq, Hq*Dh)
+         out = self.o_proj(context)
+         out = self.residual_drop(out)
+         if use_cache:
+             return out, (k_cat, v_cat)
+         else:
+             return out, None
+
+ class FeedForward(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.gate_proj = nn.Linear(args.n_embd, args.hidden_dim, bias=False)
+         self.up_proj = nn.Linear(args.n_embd, args.hidden_dim, bias=False)
+         self.down_proj = nn.Linear(args.hidden_dim, args.n_embd, bias=False)
+         self.act_fn = nn.SiLU()
+         self.drop_out = nn.Dropout(args.drop_out)
+
+     def forward(self, x):
+         gate = self.act_fn(self.gate_proj(x))
+         up = self.up_proj(x)
+         out = self.down_proj(gate * up)
+         return self.drop_out(out)
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.self_attn = SelfAttention(args)
+         self.ffn = FeedForward(args)
+         self.input_layernorm = RMSNorm(args.n_embd, args.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(args.n_embd, args.rms_norm_eps)
+
+     def forward(self, hidden_states, attention_mask=None, past_key_value: Optional[tuple] = None, use_cache: bool = False, position_offset: Optional[int] = None):
+         residual = hidden_states
+         x = self.input_layernorm(hidden_states)
+         attn_out, present_kv = self.self_attn(x, attention_mask, past_key_value, use_cache, position_offset)
+         hidden_states = residual + attn_out
+
+         residual = hidden_states
+         x = self.post_attention_layernorm(hidden_states)
+         ffn_out = self.ffn(x)
+         hidden_states = residual + ffn_out
+
+         return hidden_states, present_kv
+
+ class Transformer(nn.Module):
+     def __init__(self, args: ModelArgs):
+         super().__init__()
+         self.args = args
+
+         self.tok_embeddings = nn.Embedding(args.vocab_size, args.n_embd)
+         self.emb_drop = nn.Dropout(args.drop_out)
+         # Optional auxiliary embeddings
+         if args.use_hierarchy_embd:
+             hierarchy_embd_path = os.path.join(args.LOG_DIR, "knowledge_embd", "hierarchy_embd.npz")
+             self.hierarchy_embeddings = load_auxiliary_embeddings(
+                 hierarchy_embd_path,
+                 args.vocab_size,
+                 args.hierarchy_dim,
+                 "hierarchy"
+             )
+             self.hierarchy_proj = nn.Linear(args.hierarchy_dim, args.n_embd, bias=False)  # project to n_embd
+             self.alpha_h = nn.Parameter(torch.tensor(1.0))
+             self.h_rms = RMSNorm(args.n_embd, args.rms_norm_eps)
+         if args.use_semantic_embd:
+             semantic_embd_path = os.path.join(args.LOG_DIR, "knowledge_embd", "semantic_embd.npz")
+             self.semantic_embeddings = load_auxiliary_embeddings(
+                 semantic_embd_path,
+                 args.vocab_size,
+                 args.semantic_dim,
+                 "semantic"
+             )
+             self.semantic_proj = nn.Linear(args.semantic_dim, args.n_embd, bias=False)  # project to n_embd
+             self.alpha_s = nn.Parameter(torch.tensor(1.0))
+             self.s_rms = RMSNorm(args.n_embd, args.rms_norm_eps)
+
+         self.layers = nn.ModuleList()
+         for _ in range(args.n_layers):
+             self.layers.append(DecoderBlock(args))
+         self.final_norm = RMSNorm(args.n_embd, args.rms_norm_eps)
+         # Add output projection before weight tying
+         self.output = nn.Linear(args.n_embd, args.vocab_size, bias=False)
+         # Initialize weights
+         self.apply(self._init_weights)
+
+         # Tie weights if configured
+         if args.tie_word_embeddings:
+             self.output.weight = self.tok_embeddings.weight
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=self.args.initializer_range)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=self.args.initializer_range)
+
+     def print_model_params(self, print_details: bool = True):
+         """Print a detailed breakdown of model parameters."""
+         total_params = 0
+         if print_details:
+             print("\nModel Parameter Details:")
+             print("-" * 100)
+             print(f"{'Layer':<40} {'Shape':>20} {'Parameters':>15} {'Status':>15}")
+             print("-" * 100)
+
+             for name, param in self.named_parameters():
+                 param_count = param.numel()
+                 total_params += param_count
+                 status = "Trainable" if param.requires_grad else "Frozen"
+                 print(f"{name:<40} {str(list(param.shape)):>20} {param_count:>15,} {status:>15}")
+
+             print("-" * 80)
+             print(f"{'Total Parameters':<40} {' ':>20} {total_params:>15,}")
+             print("\nParameter count by component:")
+
+             # Count parameters by major components
+             def count_params(pattern):
+                 all_params = sum(p.numel() for name, p in self.named_parameters() if pattern in name)
+                 trainable = sum(p.numel() for name, p in self.named_parameters() if pattern in name and p.requires_grad)
+                 frozen = sum(p.numel() for name, p in self.named_parameters() if pattern in name and not p.requires_grad)
+                 return all_params, trainable, frozen
+
+             components = {
+                 'Token Embeddings': 'tok_embeddings',
+                 'Hierarchy Embeddings': 'hierarchy_embeddings',
+                 'Hierarchy Projection': 'hierarchy_proj',
+                 'Semantic Embeddings': 'semantic_embeddings',
+                 'Semantic Projection': 'semantic_proj',
+                 'DecoderBlock': 'layers',
+                 'Final Norm': 'final_norm',
+                 'Output Layer': 'output'
+             }
+
+             print(f"{'Component':<20} {'Total':>15} {'Trainable':>15} {'Frozen':>15}")
+             print("-" * 70)
+
+             for component, pattern in components.items():
+                 total, trainable, frozen = count_params(pattern)
+                 print(f"{component:<20} {total:>15,} {trainable:>15,} {frozen:>15,}")
+
+         # Count trainable vs non-trainable parameters
+         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         frozen_params = sum(p.numel() for p in self.parameters() if not p.requires_grad)
+         print(f"\nTrainable parameters: {trainable_params:>15,}")
+         print(f"Frozen parameters: {frozen_params:>15,}")
+         print(f"Total parameters: {trainable_params + frozen_params:>15,}")
+
+     def forward(self, input_ids, attention_mask=None, past_key_values: Optional[list] = None, use_cache: bool = False):
+         B, Tq = input_ids.shape
+         device = input_ids.device
+
+         hidden_states = self.tok_embeddings(input_ids)  # [B, Tq, D]
+
+         # Optional auxiliary embeddings (unchanged)
+         if self.args.use_hierarchy_embd:
+             h = self.hierarchy_proj(self.hierarchy_embeddings[input_ids])
+             hidden_states = hidden_states + self.alpha_h * self.h_rms(h)
+         if self.args.use_semantic_embd:
+             s = self.semantic_proj(self.semantic_embeddings[input_ids])
+             hidden_states = hidden_states + self.alpha_s * self.s_rms(s)
+         hidden_states = self.emb_drop(hidden_states)
+         # The causal mask is built per layer inside attention using absolute positions;
+         # if padding masks are needed, pass them via the attention_mask argument.
+
+         if past_key_values is None:
+             past_key_values = [None] * len(self.layers)
+
+         # Absolute starting position for the FIRST query token in this call
+         past_len = 0 if past_key_values[0] is None else past_key_values[0][0].size(2)
+
+         presents = [] if use_cache else None
+         for layer, past_kv in zip(self.layers, past_key_values):
+             hidden_states, present_kv = layer(hidden_states, attention_mask, past_kv, use_cache, position_offset=past_len)
+             if use_cache:
+                 presents.append(present_kv)
+         hidden_states = self.final_norm(hidden_states)
+         logits = self.output(hidden_states)
+         if use_cache:
+             return logits, presents
+         else:
+             return logits
+
+     @staticmethod
+     def _sample_top_p(logits: torch.Tensor, top_p: float = 0.9, temperature: float = 1.0) -> torch.Tensor:
+         """
+         logits: [B, V]
+         returns: next_token ids [B]
+         """
+         if temperature <= 0:
+             # greedy fallback
+             return torch.argmax(logits, dim=-1)
+
+         logits = logits / temperature
+         probs = F.softmax(logits, dim=-1)  # [B, V]
+
+         # sort by prob desc
+         sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)  # [B, V], [B, V]
+         cumsum = torch.cumsum(sorted_probs, dim=-1)  # [B, V]
+
+         # mask everything past the nucleus (keep the first token that crosses top_p)
+         cutoff = cumsum > top_p  # [B, V] boolean
+         # shift mask right so we keep at least one token per row
+         cutoff[..., 1:] = cutoff[..., :-1].clone()
+         cutoff[..., 0] = False
+
+         sorted_probs = sorted_probs.masked_fill(cutoff, 0.0)
+         # re-normalize
+         sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
+
+         # sample in the sorted space then map back to original vocab ids
+         next_sorted_idx = torch.multinomial(sorted_probs, num_samples=1)  # [B, 1]
+         next_token = torch.gather(sorted_idx, -1, next_sorted_idx).squeeze(-1)  # [B]
+         return next_token
+
+     def print_alpha_values(self):
+         alpha_h, alpha_s = None, None
+         if hasattr(self, 'alpha_h'):
+             print(f"Hierarchy (alpha_h): {self.alpha_h.item():.4f}")
+             alpha_h = self.alpha_h.item()
+         if hasattr(self, 'alpha_s'):
+             print(f"Semantic (alpha_s): {self.alpha_s.item():.4f}")
+             alpha_s = self.alpha_s.item()
+         return alpha_h, alpha_s
+
+     def generate(self, input_ids, max_length=None, temperature=1.0, top_p=0.9, end_token_id=None):
+         """
+         input_ids: [B, T] (B can be 1)
+         returns: [B, T_out]
+         """
+         self.eval()
+         max_length = self.args.n_ctx if max_length is None else min(max_length, self.args.n_ctx)
+         end_token_id = self.args.eos_token_id if end_token_id is None else end_token_id  # defaults to eos_token_id; may differ across tasks
+
+         device = input_ids.device
+         cur = input_ids
+         B = cur.size(0)
+         finished = torch.zeros(B, dtype=torch.bool, device=device)
+
+         with torch.no_grad():
+             logits, past = self(cur, use_cache=True)
+
+             while cur.size(1) < max_length:
+                 next_logits = logits[:, -1, :]  # [B, V]
+                 next_token = Transformer._sample_top_p(next_logits, top_p, temperature)  # [B]
+                 next_token = torch.where(finished, torch.full_like(next_token, end_token_id), next_token)
+
+                 cur = torch.cat([cur, next_token.unsqueeze(1)], dim=1)  # [B, T+1]
+                 finished = finished | (next_token == end_token_id)
+                 if torch.all(finished):
+                     break
+
+                 last = next_token.view(B, 1)
+                 logits, past = self(last, use_cache=True, past_key_values=past)
+
+         return cur
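The KV-cache path above can be exercised on its own; a minimal sketch with a tiny, illustrative configuration and random weights (the sizes here are assumptions for demonstration, not the published checkpoint):

import torch
from model.model import ModelArgs, Transformer

# Illustrative, tiny configuration; the repo's config.json defines the real sizes.
args = ModelArgs(LOG_DIR=".", vocab_size=128, n_embd=64, hidden_dim=96,
                 n_layers=2, n_ctx=256, num_attention_heads=4, num_key_value_heads=2)
model = Transformer(args).eval()

prompt = torch.randint(0, args.vocab_size, (1, 8))       # [B=1, T=8] random prompt ids
out = model.generate(prompt, max_length=32, temperature=1.0, top_p=0.9, end_token_id=0)
print(out.shape)                                          # at most [1, 32]; stops early on end_token_id

# The cache can also be driven by hand: prefill once, then feed one token at a time.
logits, past = model(prompt, use_cache=True)
next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick, for illustration only
logits, past = model(next_id, use_cache=True, past_key_values=past)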
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ transformers>=4.43
+ huggingface_hub>=0.23
+ safetensors
+ gradio>=4.0
+ numpy
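Once the demo from app.py is running, the endpoint registered with api_name="generate" can be called programmatically; a minimal sketch using gradio_client (installed alongside gradio), assuming the default local URL:

from gradio_client import Client

# Assumes the Gradio demo is running locally on the default port.
client = Client("http://127.0.0.1:7860/")
text_view, table_view = client.predict(
    "AGE_85_90_years", "SEX_M", "RACE_UNKNOWN", "MARITAL_STATUS_WIDOWED", "YEAR_2017",
    256, 1.0, 0.95, 50, "",      # max_new_tokens, temperature, top_p, top_k, seed
    api_name="/generate",
)
print(text_view)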