Yuto2007 committed (verified)
Commit 8517f78 · 1 Parent(s): e9d18f4

Fix: Use single modeling.py file

Files changed (2):
  1. config.json +3 -3
  2. modeling.py +579 -0
config.json CHANGED
@@ -52,9 +52,9 @@
   },
   "chemical_encoder_config": null,
   "auto_map": {
-    "AutoConfig": "configuration_tx.TXConfig",
-    "AutoModel": "modeling_tx_standalone.TXModelForHF",
-    "AutoModelForCausalLM": "modeling_tx_standalone.TXModelForHF"
+    "AutoConfig": "modeling.TXConfig",
+    "AutoModel": "modeling.TXModelForHF",
+    "AutoModelForCausalLM": "modeling.TXModelForHF"
   },
   "transformers_version": "4.35.0"
 }
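
With this change, all three auto_map entries point to classes defined in the single modeling.py added below, so no separate configuration_tx.py or modeling_tx_standalone.py is needed in the repo. A minimal loading sketch (the Hub repo id below is a placeholder, not part of this commit):

    from transformers import AutoConfig, AutoModel

    repo_id = "<namespace>/<repo-name>"  # placeholder for this model's Hub id

    # With trust_remote_code=True, transformers resolves the auto_map entries
    # in config.json and imports TXConfig / TXModelForHF from modeling.py.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)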
modeling.py ADDED
@@ -0,0 +1,579 @@
+# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
+"""
+TXModel - Complete Standalone Implementation for HuggingFace
+All code in one file - requires ONLY: transformers, torch, safetensors
+"""
+
+import math
+from typing import Optional, Dict, Any, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from transformers import PreTrainedModel, PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
+
+
+# =============================================================================
+# CONFIGURATION
+# =============================================================================
+
+class TXConfig(PretrainedConfig):
+    """Configuration for TXModel"""
+
+    model_type = "tx_model"
+
+    def __init__(
+        self,
+        vocab_size: int = 30000,
+        d_model: int = 512,
+        n_layers: int = 12,
+        n_heads: int = 8,
+        expansion_ratio: int = 4,
+        norm_scheme: str = "pre",
+        transformer_activation: str = "gelu",
+        cell_emb_style: str = "cls",
+        pad_token_id: int = 0,
+        pad_value: float = 0.0,
+        num_bins: int = 51,
+        use_chem_token: bool = False,
+        attn_config: Optional[Dict] = None,
+        norm_config: Optional[Dict] = None,
+        gene_encoder_config: Optional[Dict] = None,
+        expression_encoder_config: Optional[Dict] = None,
+        expression_decoder_config: Optional[Dict] = None,
+        mvc_config: Optional[Dict] = None,
+        chemical_encoder_config: Optional[Dict] = None,
+        use_glu: bool = False,
+        return_gene_embeddings: bool = False,
+        standard_scale_outputs: bool = False,
+        keep_first_n_tokens: int = 1,
+        **kwargs
+    ):
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_layers = n_layers
+        self.n_heads = n_heads
+        self.expansion_ratio = expansion_ratio
+        self.norm_scheme = norm_scheme
+        self.transformer_activation = transformer_activation
+        self.cell_emb_style = cell_emb_style
+        self.pad_value = pad_value
+        self.num_bins = num_bins
+        self.use_chem_token = use_chem_token
+        self.keep_first_n_tokens = keep_first_n_tokens
+        self.return_gene_embeddings = return_gene_embeddings
+        self.standard_scale_outputs = standard_scale_outputs
+        self.use_glu = use_glu
+
+        self.attn_config = attn_config or {}
+        self.norm_config = norm_config or {}
+        self.gene_encoder_config = gene_encoder_config or {}
+        self.expression_encoder_config = expression_encoder_config or {}
+        self.expression_decoder_config = expression_decoder_config or {}
+        self.mvc_config = mvc_config
+        self.chemical_encoder_config = chemical_encoder_config
+
+
+# =============================================================================
+# MODEL BLOCKS
+# =============================================================================
+
+class MultiheadAttention(nn.Module):
+    """Multi-head attention with grouped query support"""
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        kv_n_heads: Optional[int] = None,
+        dropout: float = 0.0,
+        device: Optional[str] = None,
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.kv_n_heads = kv_n_heads if kv_n_heads is not None else n_heads
+        self.head_dim = d_model // n_heads
+        self.dropout = dropout
+        self.n_rep = n_heads // self.kv_n_heads
+
+        self.q_proj = nn.Linear(d_model, d_model, device=device)
+        self.k_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, device=device)
+        self.v_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, device=device)
+        self.out_proj = nn.Linear(d_model, d_model, device=device)
+        self.attn_dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        **kwargs
+    ) -> Tuple[Tensor, None, None]:
+        batch_size, seq_len, _ = x.shape
+
+        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
+        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
+        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
+
+        if self.n_rep > 1:
+            k = k.repeat_interleave(self.n_rep, dim=1)
+            v = v.repeat_interleave(self.n_rep, dim=1)
+
+        scale = 1.0 / math.sqrt(self.head_dim)
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+        if key_padding_mask is not None:
+            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
+            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
+
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        output = torch.matmul(attn_weights, v)
+        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
+        output = self.out_proj(output)
+
+        return output, None, None
+
+
+class TXBlock(nn.Module):
+    """Transformer encoder block"""
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        expansion_ratio: int,
+        attn_config: Optional[Dict] = None,
+        norm_config: Optional[Dict] = None,
+        dropout: float = 0.0,
+        activation: str = "gelu",
+        device: Optional[str] = None,
+        norm_scheme: str = "pre",
+        use_glu: bool = False,
+        **kwargs
+    ):
+        super().__init__()
+
+        attn_config = attn_config or {}
+        norm_config = norm_config or {}
+
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.norm_scheme = norm_scheme
+        self.use_glu = use_glu
+
+        kv_n_heads = attn_config.get("kv_n_heads", n_heads)
+        self.self_attn = MultiheadAttention(
+            d_model=d_model,
+            n_heads=n_heads,
+            kv_n_heads=kv_n_heads,
+            dropout=attn_config.get("attn_pdrop", 0.0),
+            device=device,
+        )
+
+        dim_feedforward = d_model * expansion_ratio
+        self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
+        self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)
+
+        if use_glu:
+            self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)
+
+        eps = norm_config.get("eps", 1e-5)
+        self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
+        self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)
+
+        self.post_sa_dropout = nn.Dropout(dropout)
+        self.post_ffn_dropout = nn.Dropout(dropout)
+
+        self.activation = {
+            "gelu": nn.GELU(),
+            "relu": nn.ReLU(),
+            "silu": nn.SiLU(),
+            "leaky_relu": nn.LeakyReLU(),
+        }.get(activation, nn.GELU())
+
+    def forward(
+        self,
+        x: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        **kwargs
+    ) -> Tensor:
+        if self.norm_scheme == "pre":
+            x = x + self._sa_block(self.norm1(x), key_padding_mask)
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            x = self.norm1(x + self._sa_block(x, key_padding_mask))
+            x = self.norm2(x + self._ff_block(x))
+        return x
+
+    def _sa_block(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        x, _, _ = self.self_attn(x, key_padding_mask=key_padding_mask)
+        return self.post_sa_dropout(x)
+
+    def _ff_block(self, x: Tensor) -> Tensor:
+        if self.use_glu:
+            x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
+        else:
+            x = self.down_proj(self.activation(self.up_proj(x)))
+        return self.post_ffn_dropout(x)
+
+
+class TXEncoder(nn.Module):
+    """Stack of transformer encoder layers"""
+
+    def __init__(
+        self,
+        encoder_layer: TXBlock,
+        num_layers: int,
+        use_norm: bool = False,
+        norm_config: Optional[Dict] = None,
+        **kwargs
+    ):
+        super().__init__()
+
+        norm_config = norm_config or {}
+
+        self.layers = nn.ModuleList([
+            TXBlock(
+                d_model=encoder_layer.d_model,
+                n_heads=encoder_layer.n_heads,
+                expansion_ratio=encoder_layer.up_proj.out_features // encoder_layer.d_model,
+                norm_scheme=encoder_layer.norm_scheme,
+                use_glu=encoder_layer.use_glu,
+            )
+            for _ in range(num_layers)
+        ])
+
+        self.use_norm = use_norm
+        if use_norm:
+            eps = norm_config.get("eps", 1e-5)
+            self.norm = nn.LayerNorm(encoder_layer.d_model, eps=eps)
+
+    def forward(
+        self,
+        total_embs: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        output_hidden_states: bool = False,
+    ) -> Tuple[Tensor, Optional[list]]:
+        x = total_embs
+        hidden_states = [] if output_hidden_states else None
+
+        for layer in self.layers:
+            x = layer(x, key_padding_mask=key_padding_mask)
+            if output_hidden_states:
+                hidden_states.append(x)
+
+        if self.use_norm:
+            x = self.norm(x)
+
+        return x, hidden_states
+
+
+class GeneEncoder(nn.Module):
+    """Gene embedding encoder"""
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = 0,
+        use_norm: bool = False,
+        **kwargs
+    ):
+        super().__init__()
+        self.use_norm = use_norm
+        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+        self.project = nn.Identity()
+
+        if self.use_norm:
+            self.enc_norm = nn.LayerNorm(embedding_dim)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.embedding(x)
+        x = self.project(x)
+        if self.use_norm:
+            x = self.enc_norm(x)
+        return x
+
+
+class ContinuousValueEncoder(nn.Module):
+    """Encode continuous expression values"""
+
+    def __init__(
+        self,
+        d_model: int,
+        dropout: float = 0.1,
+        max_value: int = 512,
+        activation: str = "relu",
+        use_norm: bool = False,
+    ):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        self.linear1 = nn.Linear(1, d_model)
+        self.activation = {"relu": nn.ReLU(), "gelu": nn.GELU(), "leaky_relu": nn.LeakyReLU()}.get(activation, nn.ReLU())
+        self.linear2 = nn.Linear(d_model, d_model)
+        self.use_norm = use_norm
+        if use_norm:
+            self.norm = nn.LayerNorm(d_model)
+        self.max_value = max_value
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = x.unsqueeze(-1)
+        x = torch.clamp(x, max=self.max_value)
+        x = self.activation(self.linear1(x))
+        x = self.linear2(x)
+        if self.use_norm:
+            x = self.norm(x)
+        return self.dropout(x)
+
+
+class ExprDecoder(nn.Module):
+    """Expression value decoder"""
+
+    def __init__(
+        self,
+        d_model: int,
+        n_outputs: int = 1,
+        n_layers: int = 2,
+        activation: str = "leaky_relu",
+    ):
+        super().__init__()
+        self.activation = {"leaky_relu": nn.LeakyReLU(), "relu": nn.ReLU(), "gelu": nn.GELU()}.get(activation, nn.LeakyReLU())
+        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
+        self.out_proj = nn.Linear(d_model, n_outputs)

+    def forward(self, x: Tensor) -> Dict[str, Tensor]:
+        for layer in self.linear_layers:
+            x = self.activation(layer(x))
+        pred_value = self.out_proj(x)
+        if pred_value.shape[-1] == 1:
+            pred_value = pred_value.squeeze(-1)
+        return {"pred": pred_value}
+
+
+class MVCDecoder(nn.Module):
+    """Masked value prediction decoder"""
+
+    def __init__(
+        self,
+        d_model: int,
+        arch_style: str = "inner product",
+        query_activation: str = "sigmoid",
+        scaled_dot_product: bool = False,
+    ):
+        super().__init__()
+        self.scaled_dot_product = scaled_dot_product
+        self.gene2query = nn.Linear(d_model, d_model)
+        self.query_activation = {"sigmoid": nn.Sigmoid(), "relu": nn.ReLU(), "tanh": nn.Tanh()}.get(query_activation, nn.Sigmoid())
+        self.W = nn.Linear(d_model, d_model, bias=False)
+        self.arch_style = arch_style
+
+    def forward(self, cell_emb: Tensor, gene_embs: Tensor) -> Dict[str, Tensor]:
+        query_vecs = self.query_activation(self.gene2query(gene_embs))
+        cell_emb = cell_emb.unsqueeze(2)
+        pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)
+
+        if self.scaled_dot_product:
+            pred_value = pred_value / torch.sqrt(torch.tensor(query_vecs.shape[-1], dtype=pred_value.dtype))
+
+        return {"pred": pred_value}
+
+
+# =============================================================================
+# MAIN MODEL
+# =============================================================================
+
+class TXModel(nn.Module):
+    """Transformer model for genomic data"""
+
+    def __init__(self, config: TXConfig):
+        super().__init__()
+
+        self.config = config
+        self.gene_encoder = GeneEncoder(
+            config.vocab_size,
+            config.d_model,
+            padding_idx=config.pad_token_id,
+            use_norm=config.gene_encoder_config.get("use_norm", False),
+        )
+
+        self.flag_encoder = nn.Embedding(2, config.d_model)
+
+        self.expression_encoder = ContinuousValueEncoder(
+            d_model=config.d_model,
+            dropout=config.expression_encoder_config.get("dropout", 0.1),
+            max_value=config.expression_encoder_config.get("max_value", 512),
+            activation=config.expression_encoder_config.get("activation", "relu"),
+            use_norm=config.expression_encoder_config.get("use_norm", False),
+        )
+
+        encoder_layer = TXBlock(
+            d_model=config.d_model,
+            n_heads=config.n_heads,
+            expansion_ratio=config.expansion_ratio,
+            attn_config=config.attn_config,
+            norm_config=config.norm_config,
+            activation=config.transformer_activation,
+            norm_scheme=config.norm_scheme,
+            use_glu=config.use_glu,
+        )
+
+        self.transformer_encoder = TXEncoder(
+            encoder_layer,
+            config.n_layers,
+            use_norm=config.norm_scheme == "pre",
+            norm_config=config.norm_config,
+        )
+
+        self.expression_decoder = ExprDecoder(
+            d_model=config.d_model,
+            n_outputs=config.expression_decoder_config.get("n_outputs", 1),
+            n_layers=config.expression_decoder_config.get("n_layers", 2),
+            activation=config.expression_decoder_config.get("activation", "leaky_relu"),
+        )
+
+        if config.mvc_config is not None:
+            self.mvc_decoder = MVCDecoder(
+                d_model=config.d_model,
+                arch_style=config.mvc_config.get("arch_style", "inner product"),
+                query_activation=config.mvc_config.get("query_activation", "sigmoid"),
+                scaled_dot_product=config.mvc_config.get("scaled_dot_product", False),
+            )
+        else:
+            self.mvc_decoder = None
+
+    def forward(
+        self,
+        genes: Tensor,
+        values: Tensor,
+        gen_masks: Tensor,
+        key_padding_mask: Tensor,
+        skip_decoders: bool = False,
+        output_hidden_states: bool = False,
+    ) -> dict:
+        # Encode
+        token_embs = self.gene_encoder(genes)
+        token_values = self.expression_encoder(values)
+        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)
+
+        flag = self.flag_encoder(torch.tensor(1, device=token_embs.device)).reshape(1, 1, -1)
+        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag
+
+        total_embs = token_embs + token_values + flag_embs
+
+        self.cur_gene_token_embs = token_embs
+
+        # Transform
+        transformer_output, hidden_states = self.transformer_encoder(
+            total_embs=total_embs,
+            key_padding_mask=key_padding_mask,
+            output_hidden_states=output_hidden_states,
+        )
+
+        # Cell embedding
+        cell_emb = transformer_output[:, 0, :]
+
+        output = {
+            "transformer_output": transformer_output,
+            "cell_emb": cell_emb,
+        }
+
+        if output_hidden_states:
+            output["hidden_states"] = hidden_states
+
+        if skip_decoders:
+            return output
+
+        # Decode
+        expr_output = self.expression_decoder(transformer_output)
+        output["expr_preds"] = expr_output["pred"]
+
+        if self.mvc_decoder is not None:
+            mvc_output = self.mvc_decoder(cell_emb, self.cur_gene_token_embs)
+            output["mvc_output"] = mvc_output["pred"]
+
+        return output
+
+
+# =============================================================================
+# HUGGINGFACE WRAPPER
+# =============================================================================
+
+class TXPreTrainedModel(PreTrainedModel):
+    """Base class for TXModel"""
+    config_class = TXConfig
+    base_model_prefix = "tx_model"
+    supports_gradient_checkpointing = False
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+class TXModelForHF(TXPreTrainedModel):
+    """
+    HuggingFace-compatible TXModel
+
+    Requires ONLY: transformers, torch, safetensors
+    """
+
+    def __init__(self, config: TXConfig):
+        super().__init__(config)
+        self.tx_model = TXModel(config)
+        self.post_init()
+
+    def forward(
+        self,
+        genes: torch.Tensor,
+        values: torch.Tensor,
+        gen_masks: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        skip_decoders: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        **kwargs
+    ) -> Union[Tuple, BaseModelOutput]:
+
+        if key_padding_mask is None:
+            key_padding_mask = ~genes.eq(self.config.pad_token_id)
+
+        outputs = self.tx_model(
+            genes=genes,
+            values=values,
+            gen_masks=gen_masks,
+            key_padding_mask=key_padding_mask,
+            skip_decoders=skip_decoders,
+            output_hidden_states=output_hidden_states,
+        )
+
+        if not return_dict:
+            return tuple(v for v in outputs.values())
+
+        return BaseModelOutput(
+            last_hidden_state=outputs.get("cell_emb"),
+            hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
+        )
+
+    def get_input_embeddings(self):
+        return self.tx_model.gene_encoder.embedding
+
+    def set_input_embeddings(self, value):
+        self.tx_model.gene_encoder.embedding = value
+
+
+# Aliases
+TXForCausalLM = TXModelForHF
+AutoModelForCausalLM = TXModelForHF
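
For reference, a minimal smoke test of the standalone file, assuming it is importable locally as modeling.py; the config values and tensor shapes are illustrative only and do not reflect the released checkpoint. genes, values, and gen_masks are (batch, seq_len) tensors, and the wrapper returns the CLS-position cell embedding as last_hidden_state:

    import torch
    from modeling import TXConfig, TXModelForHF

    # Tiny, randomly initialised config purely for a shape check.
    config = TXConfig(vocab_size=100, d_model=64, n_layers=2, n_heads=4)
    model = TXModelForHF(config)
    model.eval()

    batch, seq_len = 2, 16
    genes = torch.randint(1, config.vocab_size, (batch, seq_len))  # gene token ids (0 is padding)
    values = torch.rand(batch, seq_len)                            # continuous expression values
    gen_masks = torch.zeros(batch, seq_len, dtype=torch.bool)      # True where a value should be generated

    with torch.no_grad():
        out = model(genes=genes, values=values, gen_masks=gen_masks)

    print(out.last_hidden_state.shape)  # cell embedding: torch.Size([2, 64])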