Yuto2007 committed
Commit e093a4b · verified · 1 Parent(s): 77e53b3

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,157 @@
1
+ ---
2
+ license: apache-2.0
3
+ language: en
4
+ tags:
5
+ - biology
6
+ - genomics
7
+ - single-cell
8
+ library_name: transformers
9
+ ---
10
+
11
+ # TXModel - Standalone Version
12
+
13
+ **Minimal dependencies!** This model requires only:
14
+ - `transformers`
15
+ - `torch`
16
+ - `safetensors`
17
+
18
+ No llmfoundry, composer, or other libraries needed!
19
+
20
+ ## 🚀 Quick Start
21
+
22
+ ```python
23
+ from transformers import AutoModel
24
+ import torch
25
+
26
+ # Load model (downloads automatically from Hub)
27
+ model = AutoModel.from_pretrained(
28
+ "your-username/tx-model-standalone",
29
+ trust_remote_code=True
30
+ )
31
+
32
+ # Prepare inputs
33
+ genes = torch.randint(0, 100, (2, 10))
34
+ values = torch.rand(2, 10)
35
+ masks = torch.ones(2, 10).bool()
36
+
37
+ # Inference
38
+ model.eval()
39
+ with torch.no_grad():
40
+ output = model(genes=genes, values=values, gen_masks=masks)
41
+
42
+ print(output.last_hidden_state.shape) # [2, 10, d_model]
43
+ ```
44
+
45
+ ## 📦 Installation
46
+
47
+ ```bash
48
+ pip install transformers torch safetensors
49
+ ```
50
+
51
+ That's it! No other dependencies required.
52
+
53
+ ## 🎯 Usage
54
+
55
+ The model works exactly like any other HuggingFace model:
56
+
57
+ ```python
58
+ from transformers import AutoModel
59
+
60
+ # Load from Hub
61
+ model = AutoModel.from_pretrained(
62
+ "your-username/tx-model-standalone",
63
+ trust_remote_code=True
64
+ )
65
+
66
+ # Or load locally
67
+ model = AutoModel.from_pretrained(
68
+ "./path/to/model",
69
+ trust_remote_code=True
70
+ )
71
+
72
+ # Move to GPU
73
+ device = "cuda" if torch.cuda.is_available() else "cpu"
74
+ model = model.to(device)
75
+ model.eval()
76
+
77
+ # Your inference code here
78
+ ```
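+
+ As a concrete example of the "your inference code here" step, here is a minimal sketch using the same toy tensors as the Quick Start (real inputs would come from your tokenized gene IDs and expression values):
+
+ ```python
+ import torch
+
+ # Toy inputs; replace with real gene IDs, expression values, and masks
+ genes = torch.randint(0, 100, (2, 10)).to(device)
+ values = torch.rand(2, 10).to(device)
+ masks = torch.ones(2, 10).bool().to(device)
+
+ with torch.no_grad():
+     output = model(genes=genes, values=values, gen_masks=masks)
+
+ cell_embeddings = output.last_hidden_state[:, 0, :]  # CLS-style cell embedding
+ print(cell_embeddings.shape)  # [2, d_model]
+ ```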
79
+
80
+ ## ⚡ Features
81
+
82
+ - ✅ **Minimal dependencies** (only transformers, torch, and safetensors)
83
+ - ✅ **Works with AutoModel** out of the box
84
+ - ✅ **Hub-ready** - upload and share easily
85
+ - ✅ **Same architecture** as the original model
86
+ - ✅ **Full compatibility** with existing weights
87
+
88
+ ## 📊 Model Details
89
+
90
+ | Property | Value |
91
+ |----------|-------|
92
+ | Parameters | ~70M |
93
+ | Architecture | Transformer Encoder |
94
+ | Hidden Size | 512 |
95
+ | Layers | 12 |
96
+ | Attention Heads | 8 |
97
+
98
+ ## 🔧 Advanced Usage
99
+
100
+ ### Accessing Model Internals
101
+
102
+ ```python
103
+ # Access the TXModel directly
104
+ tx_model = model.tx_model
105
+
106
+ # Get cell embeddings
107
+ output = model(genes, values, masks)
108
+ cell_emb = output.last_hidden_state[:, 0, :] # CLS token
109
+
110
+ # Get gene embeddings
111
+ tx_output = tx_model(genes, values, masks, key_padding_mask=~genes.eq(0))
112
+ gene_embs = tx_output["gene_embeddings"] # If return_gene_embeddings=True
113
+ ```
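+
+ Note that `gene_embeddings` is only present in the output dict when the model was built with `return_gene_embeddings=True`. One way to enable it at load time (a sketch relying on the standard `from_pretrained` behavior of passing unused keyword arguments through to the config):
+
+ ```python
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained(
+     "your-username/tx-model-standalone",
+     trust_remote_code=True,
+     return_gene_embeddings=True,  # overrides the value stored in config.json
+ )
+ ```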
114
+
115
+ ### Batch Processing
116
+
117
+ ```python
118
+ from torch.utils.data import DataLoader
119
+
120
+ # Your dataloader
121
+ dataloader = DataLoader(dataset, batch_size=32)
122
+
123
+ results = []
124
+ for batch in dataloader:
125
+ with torch.no_grad():
126
+ output = model(
127
+ genes=batch['genes'],
128
+ values=batch['values'],
129
+ gen_masks=batch['masks']
130
+ )
131
+ results.append(output.last_hidden_state)
132
+ ```
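+
+ On GPU, each batch needs to be moved to the model's device first, and the per-batch outputs can be concatenated afterwards. A sketch assuming every batch is padded to the same sequence length and uses the same keys as above:
+
+ ```python
+ import torch
+
+ device = next(model.parameters()).device
+ model.eval()
+
+ results = []
+ for batch in dataloader:
+     with torch.no_grad():
+         output = model(
+             genes=batch['genes'].to(device),
+             values=batch['values'].to(device),
+             gen_masks=batch['masks'].to(device),
+         )
+     results.append(output.last_hidden_state.cpu())
+
+ all_hidden = torch.cat(results, dim=0)  # [num_cells, seq_len, d_model]
+ ```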
133
+
134
+ ## 🆚 Compared with the Original Version
135
+
136
+ This standalone version:
137
+ - ✅ Removes dependencies on llmfoundry and composer
138
+ - ✅ Uses only PyTorch and Transformers components
139
+ - ✅ Works with standard HuggingFace tools
140
+ - ✅ Maintains same model architecture and weights
141
+ - ✅ Easier to install and deploy
142
+
143
+ ## 📝 Citation
144
+
145
+ If you use this model, please cite the original work:
146
+
147
+ ```bibtex
148
+ @article{tahoe2024,
149
+ title={Tahoe-x1: Foundation Model for Genomics},
150
+ author={...},
151
+ year={2024}
152
+ }
153
+ ```
154
+
155
+ ## 📄 License
156
+
157
+ Apache 2.0
blocks_standalone.py ADDED
@@ -0,0 +1,520 @@
1
+ # Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
2
+ """
3
+ Standalone implementation of TXModel blocks without external dependencies.
4
+ Only requires: torch, transformers
5
+ """
6
+
7
+ import math
8
+ from typing import Optional, Dict, Any, Tuple
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import Tensor, nn
13
+
14
+
15
+ class MultiheadAttention(nn.Module):
16
+ """Standard multi-head attention implementation"""
17
+
18
+ def __init__(
19
+ self,
20
+ d_model: int,
21
+ n_heads: int,
22
+ kv_n_heads: Optional[int] = None,
23
+ dropout: float = 0.0,
24
+ bias: bool = True,
25
+ device: Optional[str] = None,
26
+ ):
27
+ super().__init__()
28
+ self.d_model = d_model
29
+ self.n_heads = n_heads
30
+ self.kv_n_heads = kv_n_heads if kv_n_heads is not None else n_heads
31
+ self.head_dim = d_model // n_heads
32
+ self.dropout = dropout
33
+
34
+ # Grouped Query Attention support
35
+ self.n_rep = n_heads // self.kv_n_heads
36
+
37
+ self.q_proj = nn.Linear(d_model, d_model, bias=bias, device=device)
38
+ self.k_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, bias=bias, device=device)
39
+ self.v_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, bias=bias, device=device)
40
+ self.out_proj = nn.Linear(d_model, d_model, bias=bias, device=device)
41
+
42
+ self.attn_dropout = nn.Dropout(dropout)
43
+
44
+ def forward(
45
+ self,
46
+ x: Tensor,
47
+ attn_bias: Optional[Tensor] = None,
48
+ key_padding_mask: Optional[Tensor] = None,
49
+ is_causal: bool = False,
50
+ **kwargs
51
+ ) -> Tuple[Tensor, None, None]:
52
+ batch_size, seq_len, _ = x.shape
53
+
54
+ # Project queries, keys, values
55
+ q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
56
+ k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
57
+ v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
58
+
59
+ # Transpose for attention: (batch, heads, seq, head_dim)
60
+ q = q.transpose(1, 2)
61
+ k = k.transpose(1, 2)
62
+ v = v.transpose(1, 2)
63
+
64
+ # Repeat k/v for grouped query attention
65
+ if self.n_rep > 1:
66
+ k = k.repeat_interleave(self.n_rep, dim=1)
67
+ v = v.repeat_interleave(self.n_rep, dim=1)
68
+
69
+ # Scaled dot-product attention
70
+ scale = 1.0 / math.sqrt(self.head_dim)
71
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
72
+
73
+ # Apply attention bias if provided
74
+ if attn_bias is not None:
75
+ attn_scores = attn_scores + attn_bias
76
+
77
+ # Apply key padding mask
78
+ if key_padding_mask is not None:
79
+ # key_padding_mask: (batch, seq_len) with True for valid positions
80
+ # Convert to attention mask: (batch, 1, 1, seq_len)
81
+ mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
82
+ attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
83
+
84
+ # Apply causal mask if needed
85
+ if is_causal:
86
+ causal_mask = torch.triu(
87
+ torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
88
+ diagonal=1
89
+ )
90
+ attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
91
+
92
+ # Softmax and dropout
93
+ attn_weights = F.softmax(attn_scores, dim=-1)
94
+ attn_weights = self.attn_dropout(attn_weights)
95
+
96
+ # Apply attention to values
97
+ output = torch.matmul(attn_weights, v)
98
+
99
+ # Reshape and project output
100
+ output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
101
+ output = self.out_proj(output)
102
+
103
+ return output, None, None
104
+
105
+
106
+ class TXBlock(nn.Module):
107
+ """Transformer encoder block with pre/post normalization support"""
108
+
109
+ def __init__(
110
+ self,
111
+ d_model: int,
112
+ n_heads: int,
113
+ expansion_ratio: int,
114
+ attn_config: Optional[Dict] = None,
115
+ norm_config: Optional[Dict] = None,
116
+ dropout: Optional[float] = 0.0,
117
+ activation: Optional[str] = "gelu",
118
+ device: Optional[str] = None,
119
+ norm_scheme: str = "pre",
120
+ use_glu: bool = False,
121
+ **kwargs: Any,
122
+ ) -> None:
123
+ super().__init__()
124
+
125
+ if attn_config is None:
126
+ attn_config = {}
127
+ if norm_config is None:
128
+ norm_config = {}
129
+
130
+ self.d_model = d_model
131
+ self.n_heads = n_heads
132
+ self.device = device
133
+ self.norm_scheme = norm_scheme
134
+ self.use_glu = use_glu
135
+
136
+ # Attention
137
+ # Accept both spellings: the shipped config.json uses "kv_nheads", the code default is "kv_n_heads"
+ kv_n_heads = attn_config.get("kv_n_heads", attn_config.get("kv_nheads", n_heads))
138
+ self.self_attn = MultiheadAttention(
139
+ d_model=d_model,
140
+ n_heads=n_heads,
141
+ kv_n_heads=kv_n_heads,
142
+ dropout=attn_config.get("attn_pdrop", 0.0),
143
+ device=device,
144
+ )
145
+
146
+ # FFN
147
+ dim_feedforward = d_model * expansion_ratio
148
+ self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
149
+ self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)
150
+
151
+ if use_glu:
152
+ self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)
153
+
154
+ # Normalization
155
+ eps = norm_config.get("eps", 1e-5)
156
+ self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
157
+ self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)
158
+
159
+ # Dropout
160
+ self.post_sa_dropout = nn.Dropout(dropout)
161
+ self.post_ffn_dropout = nn.Dropout(dropout)
162
+
163
+ # Activation
164
+ self.activation = self._get_activation_fn(activation)
165
+
166
+ if norm_scheme not in ["pre", "post"]:
167
+ raise ValueError("norm_scheme must be either pre or post")
168
+
169
+ @staticmethod
170
+ def _get_activation_fn(activation: str):
171
+ if activation == "gelu":
172
+ return nn.GELU()
173
+ elif activation == "relu":
174
+ return nn.ReLU()
175
+ elif activation == "silu" or activation == "swish":
176
+ return nn.SiLU()
177
+ elif activation == "leaky_relu":
178
+ return nn.LeakyReLU()
179
+ else:
180
+ raise ValueError(f"Unknown activation: {activation}")
181
+
182
+ def forward(
183
+ self,
184
+ x: Tensor,
185
+ attn_bias: Optional[Tensor] = None,
186
+ key_padding_mask: Optional[Tensor] = None,
187
+ **kwargs
188
+ ) -> Tensor:
189
+
190
+ if self.norm_scheme == "pre":
191
+ # Pre-norm: norm -> attention -> add
192
+ x = x + self._sa_block(
193
+ self.norm1(x),
194
+ attn_bias=attn_bias,
195
+ key_padding_mask=key_padding_mask,
196
+ )
197
+ x = x + self._ff_block(self.norm2(x))
198
+ else:
199
+ # Post-norm: attention -> add -> norm
200
+ x = self.norm1(
201
+ x + self._sa_block(
202
+ x,
203
+ attn_bias=attn_bias,
204
+ key_padding_mask=key_padding_mask,
205
+ )
206
+ )
207
+ x = self.norm2(x + self._ff_block(x))
208
+
209
+ return x
210
+
211
+ def _sa_block(
212
+ self,
213
+ x: Tensor,
214
+ attn_bias: Optional[Tensor] = None,
215
+ key_padding_mask: Optional[Tensor] = None,
216
+ ) -> Tensor:
217
+ x, _, _ = self.self_attn(
218
+ x,
219
+ attn_bias=attn_bias,
220
+ key_padding_mask=key_padding_mask,
221
+ is_causal=False,
222
+ )
223
+ return self.post_sa_dropout(x)
224
+
225
+ def _ff_block(self, x: Tensor) -> Tensor:
226
+ if self.use_glu:
227
+ # GLU variant: (gate * activation(x)) * up(x)
228
+ x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
229
+ else:
230
+ # Standard FFN
231
+ x = self.down_proj(self.activation(self.up_proj(x)))
232
+ return self.post_ffn_dropout(x)
233
+
234
+
235
+ class TXEncoder(nn.Module):
236
+ """Stack of transformer encoder layers"""
237
+
238
+ def __init__(
239
+ self,
240
+ encoder_layer: TXBlock,
241
+ num_layers: int,
242
+ use_norm: bool = False,
243
+ norm_config: Optional[Dict] = None,
244
+ attn_config: Optional[Dict] = None,
245
+ ):
246
+ super().__init__()
247
+
248
+ if norm_config is None:
249
+ norm_config = {}
250
+
251
+ # Clone the layer
252
+ self.layers = nn.ModuleList([
253
+ TXBlock(
254
+ d_model=encoder_layer.d_model,
255
+ n_heads=encoder_layer.n_heads,
256
+ expansion_ratio=encoder_layer.up_proj.out_features // encoder_layer.d_model,
257
+ attn_config=attn_config,
258
+ norm_config=norm_config,
259
+ activation="gelu",
260
+ device=encoder_layer.device,
261
+ norm_scheme=encoder_layer.norm_scheme,
262
+ use_glu=encoder_layer.use_glu,
263
+ )
264
+ for _ in range(num_layers)
265
+ ])
266
+
267
+ self.use_norm = use_norm
268
+ if use_norm:
269
+ eps = norm_config.get("eps", 1e-5)
270
+ self.norm = nn.LayerNorm(encoder_layer.d_model, eps=eps)
271
+
272
+ def forward(
273
+ self,
274
+ total_embs: Tensor,
275
+ key_padding_mask: Optional[Tensor] = None,
276
+ output_hidden_states: bool = False,
277
+ ) -> Tuple[Tensor, Optional[list]]:
278
+
279
+ x = total_embs
280
+ hidden_states = [] if output_hidden_states else None
281
+
282
+ for layer in self.layers:
283
+ x = layer(
284
+ x,
285
+ attn_bias=None,
286
+ key_padding_mask=key_padding_mask,
287
+ )
288
+
289
+ if output_hidden_states:
290
+ hidden_states.append(x)
291
+
292
+ if self.use_norm:
293
+ x = self.norm(x)
294
+
295
+ return x, hidden_states
296
+
297
+
298
+ class GeneEncoder(nn.Module):
299
+ """Gene embedding with optional extra embeddings"""
300
+
301
+ def __init__(
302
+ self,
303
+ num_embeddings: int,
304
+ embedding_dim: int,
305
+ padding_idx: int = 0,
306
+ use_norm: bool = False,
307
+ gene_encoder_cfg: Optional[Dict] = None,
308
+ ):
309
+ super().__init__()
310
+
311
+ if gene_encoder_cfg is None:
312
+ gene_encoder_cfg = {}
313
+
314
+ self.use_norm = use_norm
315
+ self.embedding = nn.Embedding(
316
+ num_embeddings,
317
+ embedding_dim,
318
+ padding_idx=padding_idx,
319
+ )
320
+
321
+ # For now, no extra embeddings in standalone version
322
+ self.project = nn.Identity()
323
+
324
+ if self.use_norm:
325
+ self.enc_norm = nn.LayerNorm(embedding_dim)
326
+
327
+ def forward(self, x: Tensor) -> Tensor:
328
+ x = self.embedding(x)
329
+ x = self.project(x)
330
+ if self.use_norm:
331
+ x = self.enc_norm(x)
332
+ return x
333
+
334
+
335
+ class ChemEncoder(nn.Module):
336
+ """Chemical compound encoder"""
337
+
338
+ def __init__(
339
+ self,
340
+ d_out: int,
341
+ padding_idx: int = 0,
342
+ activation: str = "leaky_relu",
343
+ use_norm: bool = True,
344
+ freeze: bool = False,
345
+ num_drugs: int = 1000,
346
+ fp_dim: int = 2048,
347
+ ):
348
+ super().__init__()
349
+
350
+ # Initialize with zeros (user should load pretrained weights)
351
+ drug_fps = torch.zeros((num_drugs, fp_dim), dtype=torch.float32)
352
+
353
+ self.embedding = nn.Embedding.from_pretrained(
354
+ drug_fps,
355
+ padding_idx=padding_idx,
356
+ freeze=freeze,
357
+ )
358
+
359
+ self.fc = nn.Linear(fp_dim, d_out)
360
+
361
+ if activation == "leaky_relu":
362
+ self.activation = nn.LeakyReLU()
363
+ elif activation == "relu":
364
+ self.activation = nn.ReLU()
365
+ elif activation == "gelu":
366
+ self.activation = nn.GELU()
367
+ else:
368
+ self.activation = nn.ReLU()
369
+
370
+ self.proj = nn.Linear(d_out, d_out)
371
+
372
+ self.use_norm = use_norm
373
+ if self.use_norm:
374
+ self.norm = nn.LayerNorm(d_out)
375
+
376
+ def forward(self, x: Tensor) -> Tensor:
377
+ x = self.embedding(x)
378
+ x = self.activation(self.fc(x))
379
+ x = self.proj(x)
380
+
381
+ if self.use_norm:
382
+ x = self.norm(x)
383
+ return x
384
+
385
+
386
+ class ContinuousValueEncoder(nn.Module):
387
+ """Encode continuous values to embeddings"""
388
+
389
+ def __init__(
390
+ self,
391
+ d_model: int,
392
+ dropout: float = 0.1,
393
+ max_value: int = 512,
394
+ activation: str = "relu",
395
+ use_norm: bool = False,
396
+ ):
397
+ super().__init__()
398
+
399
+ self.dropout = nn.Dropout(p=dropout)
400
+ self.linear1 = nn.Linear(1, d_model)
401
+
402
+ if activation == "relu":
403
+ self.activation = nn.ReLU()
404
+ elif activation == "gelu":
405
+ self.activation = nn.GELU()
406
+ elif activation == "leaky_relu":
407
+ self.activation = nn.LeakyReLU()
408
+ else:
409
+ self.activation = nn.ReLU()
410
+
411
+ self.linear2 = nn.Linear(d_model, d_model)
412
+
413
+ self.use_norm = use_norm
414
+ if self.use_norm:
415
+ self.norm = nn.LayerNorm(d_model)
416
+
417
+ self.max_value = max_value
418
+
419
+ def forward(self, x: Tensor) -> Tensor:
420
+ # Expand last dimension
421
+ x = x.unsqueeze(-1)
422
+ # Clip to max value
423
+ x = torch.clamp(x, max=self.max_value)
424
+ # Project
425
+ x = self.activation(self.linear1(x))
426
+ x = self.linear2(x)
427
+ if self.use_norm:
428
+ x = self.norm(x)
429
+ return self.dropout(x)
430
+
431
+
432
+ class ExprDecoder(nn.Module):
433
+ """Expression value decoder"""
434
+
435
+ def __init__(
436
+ self,
437
+ d_model: int,
438
+ n_outputs: int = 1,
439
+ n_layers: int = 2,
440
+ activation: str = "leaky_relu",
441
+ ):
442
+ super().__init__()
443
+
444
+ if activation == "leaky_relu":
445
+ self.activation = nn.LeakyReLU()
446
+ elif activation == "relu":
447
+ self.activation = nn.ReLU()
448
+ elif activation == "gelu":
449
+ self.activation = nn.GELU()
450
+ else:
451
+ self.activation = nn.LeakyReLU()
452
+
453
+ self.linear_layers = nn.ModuleList(
454
+ [nn.Linear(d_model, d_model) for _ in range(n_layers)]
455
+ )
456
+ self.out_proj = nn.Linear(d_model, n_outputs)
457
+
458
+ def forward(self, x: Tensor) -> Dict[str, Tensor]:
459
+ for layer in self.linear_layers:
460
+ x = self.activation(layer(x))
461
+ pred_value = self.out_proj(x)
462
+ if pred_value.shape[-1] == 1:
463
+ pred_value = pred_value.squeeze(-1)
464
+ return {"pred": pred_value}
465
+
466
+
467
+ class MVCDecoder(nn.Module):
468
+ """Masked value prediction decoder"""
469
+
470
+ def __init__(
471
+ self,
472
+ d_model: int,
473
+ arch_style: str = "inner product",
474
+ query_activation: str = "sigmoid",
475
+ scaled_dot_product: bool = False,
476
+ ) -> None:
477
+ super().__init__()
478
+
479
+ self.scaled_dot_product = scaled_dot_product
480
+
481
+ if arch_style == "inner product":
482
+ self.gene2query = nn.Linear(d_model, d_model)
483
+
484
+ if query_activation == "sigmoid":
485
+ self.query_activation = nn.Sigmoid()
486
+ elif query_activation == "relu":
487
+ self.query_activation = nn.ReLU()
488
+ elif query_activation == "tanh":
489
+ self.query_activation = nn.Tanh()
490
+ else:
491
+ self.query_activation = nn.Sigmoid()
492
+
493
+ self.W = nn.Linear(d_model, d_model, bias=False)
494
+ else:
495
+ raise ValueError(f"Unknown arch_style: {arch_style}")
496
+
497
+ self.arch_style = arch_style
498
+
499
+ def forward(
500
+ self,
501
+ cell_emb: Tensor,
502
+ gene_embs: Tensor,
503
+ ) -> Dict[str, Tensor]:
504
+
505
+ if self.arch_style == "inner product":
506
+ query_vecs = self.query_activation(
507
+ self.gene2query(gene_embs)
508
+ )
509
+ inner_product_dimension = query_vecs.shape[-1]
510
+ cell_emb = cell_emb.unsqueeze(2)
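+ # W(query_vecs): (B, L, d); cell_emb: (B, d, 1); bmm gives (B, L, 1), squeezed to (B, L): one value per gene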
511
+ pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)
512
+
513
+ if self.scaled_dot_product:
514
+ pred_value = pred_value / torch.sqrt(
515
+ torch.tensor(inner_product_dimension, dtype=pred_value.dtype)
516
+ )
517
+
518
+ return {"pred": pred_value}
519
+ else:
520
+ raise ValueError(f"Unknown arch_style: {self.arch_style}")
config.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "model_type": "tx_model",
3
+ "architectures": [
4
+ "TXModelForHF"
5
+ ],
6
+ "vocab_size": 62720,
7
+ "d_model": 512,
8
+ "n_layers": 12,
9
+ "n_heads": 8,
10
+ "expansion_ratio": 4,
11
+ "norm_scheme": "pre",
12
+ "transformer_activation": "relu",
13
+ "use_glu": false,
14
+ "pad_token_id": 0,
15
+ "pad_value": -2,
16
+ "num_bins": 51,
17
+ "use_chem_token": false,
18
+ "keep_first_n_tokens": 1,
19
+ "cell_emb_style": "cls",
20
+ "return_gene_embeddings": false,
21
+ "standard_scale_outputs": false,
22
+ "attn_config": {
23
+ "attn_impl": "flash",
24
+ "use_attn_mask": false,
25
+ "attn_type": "grouped_query_attention",
26
+ "kv_nheads": 8,
27
+ "attn_pdrop": 0
28
+ },
29
+ "norm_config": {
30
+ "eps": 1e-05,
31
+ "norm_type": "layernorm"
32
+ },
33
+ "gene_encoder_config": {
34
+ "use_norm": true
35
+ },
36
+ "expression_encoder_config": {
37
+ "dropout": 0.1,
38
+ "use_norm": true,
39
+ "max_value": 512,
40
+ "activation": "relu",
41
+ "input_emb_style": "continuous"
42
+ },
43
+ "expression_decoder_config": {
44
+ "n_layers": 1,
45
+ "n_outputs": 1,
46
+ "activation": "leaky_relu"
47
+ },
48
+ "mvc_config": {
49
+ "arch_style": "inner product",
50
+ "query_activation": "sigmoid",
51
+ "scaled_dot_product": true
52
+ },
53
+ "chemical_encoder_config": null,
54
+ "auto_map": {
55
+ "AutoConfig": "configuration_tx.TXConfig",
56
+ "AutoModel": "modeling_tx_standalone.TXModelForHF",
57
+ "AutoModelForCausalLM": "modeling_tx_standalone.TXModelForHF"
58
+ },
59
+ "transformers_version": "4.35.0"
60
+ }
configuration_tx.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
2
+ """
3
+ Configuration class for TXModel compatible with HuggingFace Transformers
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional, Dict, Any
8
+
9
+
10
+ class TXConfig(PretrainedConfig):
11
+ """
12
+ Configuration class for TXModel.
13
+
14
+ This class stores the configuration of a TXModel, which is a Transformer-based model
15
+ for genomic/biological sequence analysis.
16
+
17
+ Args:
18
+ vocab_size (int): Size of the vocabulary
19
+ d_model (int): Dimensionality of the model embeddings
20
+ n_layers (int): Number of transformer layers
21
+ n_heads (int): Number of attention heads
22
+ expansion_ratio (int): Expansion ratio for FFN
23
+ norm_scheme (str): Normalization scheme ('pre' or 'post')
24
+ transformer_activation (str): Activation function for transformer
25
+ cell_emb_style (str): Cell embedding style ('cls', 'avg-pool', 'w-pool')
26
+ pad_token_id (int): ID of the padding token
27
+ pad_value (float): Value for padding
28
+ num_bins (int): Number of bins for expression values
29
+ use_chem_token (bool): Whether to use chemical token encoder
30
+ attn_config (Dict): Attention configuration
31
+ norm_config (Dict): Normalization configuration
32
+ init_config (Dict): Initialization configuration
33
+ gene_encoder_config (Dict): Gene encoder configuration
34
+ expression_encoder_config (Dict): Expression encoder configuration
35
+ expression_decoder_config (Dict): Expression decoder configuration
36
+ mvc_config (Optional[Dict]): MVC decoder configuration
37
+ chemical_encoder_config (Optional[Dict]): Chemical encoder configuration
38
+ use_glu (bool): Whether to use GLU in FFN
39
+ return_gene_embeddings (bool): Whether to return gene embeddings
40
+ standard_scale_outputs (bool): Whether to scale outputs
41
+ """
42
+
43
+ model_type = "tx_model"
44
+
45
+ def __init__(
46
+ self,
47
+ vocab_size: int = 30000,
48
+ d_model: int = 512,
49
+ n_layers: int = 12,
50
+ n_heads: int = 8,
51
+ expansion_ratio: int = 4,
52
+ norm_scheme: str = "pre",
53
+ transformer_activation: str = "gelu",
54
+ cell_emb_style: str = "cls",
55
+ pad_token_id: int = 0,
56
+ pad_value: float = 0.0,
57
+ num_bins: int = 51,
58
+ use_chem_token: bool = False,
59
+ attn_config: Optional[Dict] = None,
60
+ norm_config: Optional[Dict] = None,
61
+ init_config: Optional[Dict] = None,
62
+ gene_encoder_config: Optional[Dict] = None,
63
+ expression_encoder_config: Optional[Dict] = None,
64
+ expression_decoder_config: Optional[Dict] = None,
65
+ mvc_config: Optional[Dict] = None,
66
+ chemical_encoder_config: Optional[Dict] = None,
67
+ use_glu: bool = False,
68
+ return_gene_embeddings: bool = False,
69
+ standard_scale_outputs: bool = False,
70
+ keep_first_n_tokens: int = 1,
71
+ **kwargs
72
+ ):
73
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
74
+
75
+ self.vocab_size = vocab_size
76
+ self.d_model = d_model
77
+ self.n_layers = n_layers
78
+ self.n_heads = n_heads
79
+ self.expansion_ratio = expansion_ratio
80
+ self.norm_scheme = norm_scheme
81
+ self.transformer_activation = transformer_activation
82
+ self.cell_emb_style = cell_emb_style
83
+ self.pad_value = pad_value
84
+ self.num_bins = num_bins
85
+ self.use_chem_token = use_chem_token
86
+ self.keep_first_n_tokens = keep_first_n_tokens
87
+ self.return_gene_embeddings = return_gene_embeddings
88
+ self.standard_scale_outputs = standard_scale_outputs
89
+ self.use_glu = use_glu
90
+
91
+ # Sub-configurations
92
+ self.attn_config = attn_config or {
93
+ "attn_type": "grouped_query_attention",
94
+ "attn_pdrop": 0.0,
95
+ "attn_impl": "flash",
96
+ "use_attn_mask": False,
97
+ "qk_ln": False,
98
+ "qk_gn": False,
99
+ "clip_qkv": None,
100
+ "softmax_scale": None,
101
+ }
102
+
103
+ self.norm_config = norm_config or {
104
+ "norm_type": "low_precision_layernorm",
105
+ "eps": 1e-5,
106
+ }
107
+
108
+ self.init_config = init_config or {
109
+ "name": "kaiming_normal_",
110
+ "fan_mode": "fan_in",
111
+ "init_nonlinearity": "relu",
112
+ "init_div_is_residual": True,
113
+ "emb_init_std": None,
114
+ "emb_init_uniform_lim": None,
115
+ "init_std": None,
116
+ "init_gain": 0.0,
117
+ }
118
+
119
+ self.gene_encoder_config = gene_encoder_config or {
120
+ "use_norm": False,
121
+ }
122
+
123
+ self.expression_encoder_config = expression_encoder_config or {
124
+ "input_emb_style": "continuous",
125
+ "dropout": 0.1,
126
+ "max_value": 512,
127
+ "activation": "relu",
128
+ "use_norm": False,
129
+ }
130
+
131
+ self.expression_decoder_config = expression_decoder_config or {
132
+ "n_outputs": 1,
133
+ "n_layers": 2,
134
+ "activation": "leaky_relu",
135
+ }
136
+
137
+ self.mvc_config = mvc_config
138
+ self.chemical_encoder_config = chemical_encoder_config
139
+
140
+ @classmethod
141
+ def from_yaml_configs(cls, model_config_dict: Dict, collator_config_dict: Dict) -> "TXConfig":
142
+ """
143
+ Create TXConfig from model_config.yml and collator_config.yml dictionaries
144
+
145
+ Args:
146
+ model_config_dict: Dictionary from model_config.yml
147
+ collator_config_dict: Dictionary from collator_config.yml
148
+
149
+ Returns:
150
+ TXConfig instance
151
+ """
152
+ return cls(
153
+ vocab_size=model_config_dict.get("vocab_size"),
154
+ d_model=model_config_dict.get("d_model"),
155
+ n_layers=model_config_dict.get("n_layers"),
156
+ n_heads=model_config_dict.get("n_heads"),
157
+ expansion_ratio=model_config_dict.get("expansion_ratio"),
158
+ norm_scheme=model_config_dict.get("norm_scheme", "pre"),
159
+ transformer_activation=model_config_dict.get("transformer_activation", "gelu"),
160
+ cell_emb_style=model_config_dict.get("cell_emb_style", "cls"),
161
+ pad_token_id=collator_config_dict.get("pad_token_id", 0),
162
+ pad_value=collator_config_dict.get("pad_value", 0.0),
163
+ num_bins=collator_config_dict.get("num_bins", 51),
164
+ use_chem_token=collator_config_dict.get("use_chem_token", False),
165
+ attn_config=model_config_dict.get("attn_config"),
166
+ norm_config=model_config_dict.get("norm_config"),
167
+ init_config=model_config_dict.get("init_config"),
168
+ gene_encoder_config=model_config_dict.get("gene_encoder"),
169
+ expression_encoder_config=model_config_dict.get("expression_encoder"),
170
+ expression_decoder_config=model_config_dict.get("expression_decoder"),
171
+ mvc_config=model_config_dict.get("mvc"),
172
+ chemical_encoder_config=model_config_dict.get("chemical_encoder"),
173
+ use_glu=model_config_dict.get("use_glu", False),
174
+ return_gene_embeddings=model_config_dict.get("return_gene_embeddings", False),
175
+ standard_scale_outputs=model_config_dict.get("standard_scale_outputs", False),
176
+ keep_first_n_tokens=collator_config_dict.get("keep_first_n_tokens", 1),
177
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:217637af5a4d12f3fe2d2648fb9d4d1404b53eea587336c62cfcfbfb26088efd
3
+ size 284008108
model_standalone.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
2
+ """
3
+ Standalone implementation of TXModel without external dependencies.
4
+ Only requires: torch, transformers, safetensors
5
+ """
6
+
7
+ from typing import Optional, Union, Tuple
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import Tensor, nn
11
+
12
+ from blocks_standalone import (
13
+ ChemEncoder,
14
+ ContinuousValueEncoder,
15
+ ExprDecoder,
16
+ GeneEncoder,
17
+ MVCDecoder,
18
+ TXBlock,
19
+ TXEncoder,
20
+ )
21
+
22
+
23
+ class TXModel(nn.Module):
24
+ """Standalone Transformer model for genomic data"""
25
+
26
+ def __init__(
27
+ self,
28
+ vocab_size: int,
29
+ d_model: int,
30
+ n_layers: int,
31
+ n_heads: int,
32
+ expansion_ratio: int,
33
+ pad_token_id: int,
34
+ pad_value: float,
35
+ num_bins: int,
36
+ norm_scheme: str = "pre",
37
+ transformer_activation: str = "gelu",
38
+ cell_emb_style: str = "cls",
39
+ use_chem_token: bool = False,
40
+ attn_config: Optional[dict] = None,
41
+ norm_config: Optional[dict] = None,
42
+ gene_encoder_config: Optional[dict] = None,
43
+ expression_encoder_config: Optional[dict] = None,
44
+ expression_decoder_config: Optional[dict] = None,
45
+ mvc_config: Optional[dict] = None,
46
+ chemical_encoder_config: Optional[dict] = None,
47
+ use_glu: bool = False,
48
+ return_gene_embeddings: bool = False,
49
+ keep_first_n_tokens: int = 1,
50
+ device: Optional[str] = None,
51
+ ):
52
+ super().__init__()
53
+
54
+ self.model_type = "Transformer"
55
+ self.device = device
56
+ self.vocab_size = vocab_size
57
+ self.n_layers = n_layers
58
+ self.n_heads = n_heads
59
+ self.d_model = d_model
60
+ self.expansion_ratio = expansion_ratio
61
+ self.norm_scheme = norm_scheme
62
+ self.transformer_activation = transformer_activation
63
+ self.use_chem_token = use_chem_token
64
+ self.cell_emb_style = cell_emb_style
65
+ self.pad_token_id = pad_token_id
66
+ self.pad_value = pad_value
67
+ self.n_input_bins = num_bins
68
+ self.keep_first_n_tokens = keep_first_n_tokens
69
+ self.return_gene_embeddings = return_gene_embeddings
70
+
71
+ if attn_config is None:
72
+ attn_config = {}
73
+ if norm_config is None:
74
+ norm_config = {}
75
+ if gene_encoder_config is None:
76
+ gene_encoder_config = {"use_norm": False}
77
+ if expression_encoder_config is None:
78
+ expression_encoder_config = {}
79
+ if expression_decoder_config is None:
80
+ expression_decoder_config = {}
81
+
82
+ # Gene encoder
83
+ self.gene_encoder = GeneEncoder(
84
+ self.vocab_size,
85
+ self.d_model,
86
+ padding_idx=self.pad_token_id,
87
+ use_norm=gene_encoder_config.get("use_norm", False),
88
+ gene_encoder_cfg=gene_encoder_config,
89
+ )
90
+
91
+ # Flag encoder
92
+ self.flag_encoder = nn.Embedding(2, self.d_model)
93
+
94
+ # Expression encoder
95
+ self.expression_encoder = ContinuousValueEncoder(
96
+ d_model=self.d_model,
97
+ dropout=expression_encoder_config.get("dropout", 0.1),
98
+ max_value=expression_encoder_config.get("max_value", 512),
99
+ activation=expression_encoder_config.get("activation", "relu"),
100
+ use_norm=expression_encoder_config.get("use_norm", False),
101
+ )
102
+
103
+ # Chemical encoder (if needed)
104
+ if self.use_chem_token:
105
+ if chemical_encoder_config is None:
106
+ chemical_encoder_config = {}
107
+ self.chem_encoder = ChemEncoder(
108
+ d_out=self.d_model,
109
+ padding_idx=chemical_encoder_config.get("padding_idx", 0),
110
+ activation=chemical_encoder_config.get("activation", "leaky_relu"),
111
+ freeze=chemical_encoder_config.get("freeze", False),
112
+ num_drugs=chemical_encoder_config.get("num_drugs", 1000),
113
+ fp_dim=chemical_encoder_config.get("fp_dim", 2048),
114
+ )
115
+
116
+ # Transformer encoder
117
+ encoder_layer = TXBlock(
118
+ d_model=self.d_model,
119
+ n_heads=self.n_heads,
120
+ expansion_ratio=self.expansion_ratio,
121
+ attn_config=attn_config,
122
+ norm_config=norm_config,
123
+ activation=self.transformer_activation,
124
+ device=self.device,
125
+ norm_scheme=self.norm_scheme,
126
+ use_glu=use_glu,
127
+ )
128
+
129
+ self.transformer_encoder = TXEncoder(
130
+ encoder_layer,
131
+ self.n_layers,
132
+ use_norm=self.norm_scheme == "pre",
133
+ norm_config=norm_config,
134
+ attn_config=attn_config,
135
+ )
136
+
137
+ # Expression decoder
138
+ self.expression_decoder = ExprDecoder(
139
+ d_model=self.d_model,
140
+ n_outputs=expression_decoder_config.get("n_outputs", 1),
141
+ n_layers=expression_decoder_config.get("n_layers", 2),
142
+ activation=expression_decoder_config.get("activation", "leaky_relu"),
143
+ )
144
+
145
+ # MVC decoder (if configured)
146
+ if mvc_config is not None:
147
+ self.mvc_decoder = MVCDecoder(
148
+ d_model=self.d_model,
149
+ arch_style=mvc_config.get("arch_style", "inner product"),
150
+ query_activation=mvc_config.get("query_activation", "sigmoid"),
151
+ scaled_dot_product=mvc_config.get("scaled_dot_product", False),
152
+ )
153
+ else:
154
+ self.mvc_decoder = None
155
+
156
+ def transformer_generate(
157
+ self,
158
+ genes: Tensor,
159
+ values: Tensor,
160
+ gen_masks: Tensor,
161
+ key_padding_mask: Tensor,
162
+ drug_ids: Optional[Tensor] = None,
163
+ output_hidden_states: bool = False,
164
+ ) -> Union[Tensor, Tuple[Tensor, list]]:
165
+
166
+ # Encode genes
167
+ token_embs = self.gene_encoder(genes)
168
+
169
+ # Encode expression values
170
+ token_values = self.expression_encoder(values)
171
+ token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)
172
+
173
+ # Flag embeddings
174
+ flag = self.flag_encoder(
175
+ torch.tensor(1, device=token_embs.device)
176
+ ).reshape(1, 1, -1)
177
+ flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag
178
+
179
+ # Combine embeddings
180
+ total_embs = token_embs + token_values + flag_embs
181
+
182
+ # Add chemical embedding if used
183
+ if self.use_chem_token and drug_ids is not None:
184
+ drug_embs = self.chem_encoder(drug_ids)
185
+ total_embs[:, 1, :] = drug_embs
186
+
187
+ # Store gene embeddings for MVC
188
+ self.cur_gene_token_embs = token_embs
189
+
190
+ # Pass through transformer
191
+ output, hidden_states = self.transformer_encoder(
192
+ total_embs=total_embs,
193
+ key_padding_mask=key_padding_mask,
194
+ output_hidden_states=output_hidden_states,
195
+ )
196
+
197
+ return output, hidden_states
198
+
199
+ def forward(
200
+ self,
201
+ genes: Tensor,
202
+ values: Tensor,
203
+ gen_masks: Tensor,
204
+ key_padding_mask: Tensor,
205
+ drug_ids: Optional[Tensor] = None,
206
+ skip_decoders: bool = False,
207
+ output_hidden_states: bool = False,
208
+ ) -> dict:
209
+
210
+ # Generate transformer output
211
+ transformer_output, hidden_states = self.transformer_generate(
212
+ genes, values, gen_masks, key_padding_mask,
213
+ drug_ids, output_hidden_states
214
+ )
215
+
216
+ # Prepare output dict
217
+ output = {
218
+ "transformer_output": transformer_output,
219
+ }
220
+
221
+ if output_hidden_states:
222
+ output["hidden_states"] = hidden_states
223
+
224
+ # Cell embedding (CLS token or pooling)
225
+ if self.cell_emb_style == "cls":
226
+ cell_emb = transformer_output[:, 0, :]
227
+ elif self.cell_emb_style == "avg-pool":
228
+ # Average over non-padding tokens
229
+ mask = key_padding_mask.unsqueeze(-1).float()
230
+ cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
231
+ elif self.cell_emb_style == "w-pool":
232
+ # Weighted pooling (not implemented, use avg)
233
+ mask = key_padding_mask.unsqueeze(-1).float()
234
+ cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
235
+ else:
236
+ cell_emb = transformer_output[:, 0, :]
237
+
238
+ output["cell_emb"] = cell_emb
239
+
240
+ # Return gene embeddings if requested
241
+ if self.return_gene_embeddings:
242
+ output["gene_embeddings"] = transformer_output
243
+
244
+ # Skip decoders if requested
245
+ if skip_decoders:
246
+ return output
247
+
248
+ # Expression decoder
249
+ expr_output = self.expression_decoder(transformer_output)
250
+ output["expr_preds"] = expr_output["pred"]
251
+
252
+ # MVC decoder (if available)
253
+ if self.mvc_decoder is not None:
254
+ mvc_output = self.mvc_decoder(
255
+ cell_emb,
256
+ self.cur_gene_token_embs,
257
+ )
258
+ output["mvc_output"] = mvc_output["pred"]
259
+
260
+ return output
261
+
262
+ @classmethod
263
+ def from_pretrained(cls, model_path: str, **kwargs):
264
+ """Load model from pretrained weights"""
265
+ from safetensors.torch import load_file
266
+ import json
267
+ from pathlib import Path
268
+
269
+ model_path = Path(model_path)
270
+
271
+ # Load config
272
+ with open(model_path / "config.json", "r") as f:
273
+ config = json.load(f)
274
+
275
+ # Create model
276
+ model = cls(
277
+ vocab_size=config["vocab_size"],
278
+ d_model=config["d_model"],
279
+ n_layers=config["n_layers"],
280
+ n_heads=config["n_heads"],
281
+ expansion_ratio=config["expansion_ratio"],
282
+ pad_token_id=config["pad_token_id"],
283
+ pad_value=config["pad_value"],
284
+ num_bins=config["num_bins"],
285
+ norm_scheme=config.get("norm_scheme", "pre"),
286
+ transformer_activation=config.get("transformer_activation", "gelu"),
287
+ cell_emb_style=config.get("cell_emb_style", "cls"),
288
+ use_chem_token=config.get("use_chem_token", False),
289
+ attn_config=config.get("attn_config"),
290
+ norm_config=config.get("norm_config"),
291
+ gene_encoder_config=config.get("gene_encoder_config"),
292
+ expression_encoder_config=config.get("expression_encoder_config"),
293
+ expression_decoder_config=config.get("expression_decoder_config"),
294
+ mvc_config=config.get("mvc_config"),
295
+ chemical_encoder_config=config.get("chemical_encoder_config"),
296
+ use_glu=config.get("use_glu", False),
297
+ return_gene_embeddings=config.get("return_gene_embeddings", False),
298
+ keep_first_n_tokens=config.get("keep_first_n_tokens", 1),
299
+ )
300
+
301
+ # Load weights
302
+ state_dict = load_file(model_path / "model.safetensors")
303
+
304
+ # Remove 'model.tx_model.' or 'tx_model.' prefix if present
305
+ new_state_dict = {}
306
+ for k, v in state_dict.items():
307
+ new_key = k
308
+ if k.startswith('model.tx_model.'):
309
+ new_key = k[len('model.tx_model.'):]  # Remove 'model.tx_model.' prefix (15 chars)
310
+ elif k.startswith('tx_model.'):
311
+ new_key = k[9:] # Remove 'tx_model.'
312
+ elif k.startswith('model.'):
313
+ new_key = k[6:] # Remove 'model.'
314
+ new_state_dict[new_key] = v
315
+
316
+ model.load_state_dict(new_state_dict, strict=False)
317
+
318
+ return model
modeling_tx_standalone.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
2
+ """
3
+ HuggingFace-compatible wrapper for TXModel (Standalone version)
4
+ Only requires: transformers, torch, safetensors
5
+ """
6
+
7
+ from typing import Optional, Union, Tuple
8
+ import torch
9
+ from transformers import PreTrainedModel
10
+ from transformers.modeling_outputs import BaseModelOutput
11
+
12
+ from configuration_tx import TXConfig
13
+ from model_standalone import TXModel
14
+
15
+
16
+ class TXPreTrainedModel(PreTrainedModel):
17
+ """
18
+ Base class for TXModel with HuggingFace integration
19
+ """
20
+ config_class = TXConfig
21
+ base_model_prefix = "tx_model"
22
+ supports_gradient_checkpointing = False
23
+ _no_split_modules = ["TXBlock"]
24
+
25
+ def _init_weights(self, module):
26
+ """Initialize weights"""
27
+ if isinstance(module, torch.nn.Linear):
28
+ module.weight.data.normal_(mean=0.0, std=0.02)
29
+ if module.bias is not None:
30
+ module.bias.data.zero_()
31
+ elif isinstance(module, torch.nn.Embedding):
32
+ module.weight.data.normal_(mean=0.0, std=0.02)
33
+ if module.padding_idx is not None:
34
+ module.weight.data[module.padding_idx].zero_()
35
+ elif isinstance(module, torch.nn.LayerNorm):
36
+ module.bias.data.zero_()
37
+ module.weight.data.fill_(1.0)
38
+
39
+
40
+ class TXModelForHF(TXPreTrainedModel):
41
+ """
42
+ HuggingFace-compatible TXModel
43
+
44
+ This model can be used directly with HuggingFace's from_pretrained()
45
+ and requires only: transformers, torch, safetensors
46
+
47
+ No dependencies on llmfoundry, composer, or other external libraries.
48
+ """
49
+
50
+ def __init__(self, config: TXConfig):
51
+ super().__init__(config)
52
+
53
+ # Initialize standalone model
54
+ self.tx_model = TXModel(
55
+ vocab_size=config.vocab_size,
56
+ d_model=config.d_model,
57
+ n_layers=config.n_layers,
58
+ n_heads=config.n_heads,
59
+ expansion_ratio=config.expansion_ratio,
60
+ pad_token_id=config.pad_token_id,
61
+ pad_value=config.pad_value,
62
+ num_bins=config.num_bins,
63
+ norm_scheme=config.norm_scheme,
64
+ transformer_activation=config.transformer_activation,
65
+ cell_emb_style=config.cell_emb_style,
66
+ use_chem_token=config.use_chem_token,
67
+ attn_config=config.attn_config,
68
+ norm_config=config.norm_config,
69
+ gene_encoder_config=config.gene_encoder_config,
70
+ expression_encoder_config=config.expression_encoder_config,
71
+ expression_decoder_config=config.expression_decoder_config,
72
+ mvc_config=config.mvc_config,
73
+ chemical_encoder_config=config.chemical_encoder_config,
74
+ use_glu=config.use_glu,
75
+ return_gene_embeddings=config.return_gene_embeddings,
76
+ keep_first_n_tokens=config.keep_first_n_tokens,
77
+ )
78
+
79
+ # Post init
80
+ self.post_init()
81
+
82
+ def forward(
83
+ self,
84
+ genes: torch.Tensor,
85
+ values: torch.Tensor,
86
+ gen_masks: torch.Tensor,
87
+ key_padding_mask: Optional[torch.Tensor] = None,
88
+ drug_ids: Optional[torch.Tensor] = None,
89
+ skip_decoders: bool = False,
90
+ output_hidden_states: bool = False,
91
+ return_dict: bool = True,
92
+ ) -> Union[Tuple, BaseModelOutput]:
93
+ """
94
+ Forward pass through the model.
95
+
96
+ Args:
97
+ genes: Gene token IDs [batch_size, seq_len]
98
+ values: Expression values [batch_size, seq_len]
99
+ gen_masks: Generation masks [batch_size, seq_len]
100
+ key_padding_mask: Padding mask [batch_size, seq_len]
101
+ drug_ids: Drug IDs [batch_size] (optional)
102
+ skip_decoders: Whether to skip decoder computation
103
+ output_hidden_states: Whether to return hidden states
104
+ return_dict: Whether to return a dict or tuple
105
+
106
+ Returns:
107
+ Model outputs
108
+ """
109
+
110
+ if key_padding_mask is None:
111
+ key_padding_mask = ~genes.eq(self.config.pad_token_id)
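+ # True marks real tokens, False marks padding (the convention MultiheadAttention expects for key_padding_mask)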
112
+
113
+ outputs = self.tx_model(
114
+ genes=genes,
115
+ values=values,
116
+ gen_masks=gen_masks,
117
+ key_padding_mask=key_padding_mask,
118
+ drug_ids=drug_ids,
119
+ skip_decoders=skip_decoders,
120
+ output_hidden_states=output_hidden_states,
121
+ )
122
+
123
+ if not return_dict:
124
+ return tuple(v for v in outputs.values())
125
+
126
+ # Convert to HuggingFace output format
127
+ return BaseModelOutput(
128
+ last_hidden_state=outputs.get("cell_emb"),
129
+ hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
130
+ )
131
+
132
+ def get_input_embeddings(self):
133
+ """Get input embeddings"""
134
+ return self.tx_model.gene_encoder.embedding
135
+
136
+ def set_input_embeddings(self, value):
137
+ """Set input embeddings"""
138
+ self.tx_model.gene_encoder.embedding = value
139
+
140
+ def get_output_embeddings(self):
141
+ """Get output embeddings (not applicable)"""
142
+ return None
143
+
144
+ @classmethod
145
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
146
+ """
147
+ Load model from pretrained weights.
148
+
149
+ Works with both local paths and HuggingFace Hub.
150
+ Requires only: transformers, torch, safetensors
151
+ """
152
+ # Let parent class handle config and weight loading
153
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
154
+
155
+
156
+ # Alias for easier importing
157
+ TXForCausalLM = TXModelForHF
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ # Standalone version - ONLY these dependencies required!
2
+ transformers>=4.35.0
3
+ torch>=2.0.0
4
+ safetensors>=0.4.0
5
+
6
+ # Optional: for converting from original format
7
+ # omegaconf>=2.3.0 # Only needed for conversion, not for using the model
tokenizer_config.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast",
3
+ "model_max_length": 1000000000000000019884624838656,
4
+ "vocab_size": 62720
5
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff