# tx-model-standalone / modeling_tx_standalone.py
# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
HuggingFace-compatible wrapper for TXModel (Standalone version)
Only requires: transformers, torch, safetensors
"""
from typing import Optional, Union, Tuple
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from configuration_tx import TXConfig
from model_standalone import TXModel
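# configuration_tx.py and model_standalone.py are companion modules expected to
# ship alongside this file (e.g., in the same checkpoint repository).
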
class TXPreTrainedModel(PreTrainedModel):
"""
Base class for TXModel with HuggingFace integration
"""
config_class = TXConfig
base_model_prefix = "tx_model"
supports_gradient_checkpointing = False
_no_split_modules = ["TXBlock"]
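    # The class attributes above steer HF loading: _no_split_modules keeps each
    # TXBlock on a single device when from_pretrained() is used with a
    # device_map, and gradient checkpointing is advertised as unsupported.
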
def _init_weights(self, module):
"""Initialize weights"""
if isinstance(module, torch.nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, torch.nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, torch.nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class TXModelForHF(TXPreTrainedModel):
"""
HuggingFace-compatible TXModel
This model can be used directly with HuggingFace's from_pretrained()
and requires only: transformers, torch, safetensors
No dependencies on llmfoundry, composer, or other external libraries.
"""
def __init__(self, config: TXConfig):
super().__init__(config)
# Initialize standalone model
self.tx_model = TXModel(
vocab_size=config.vocab_size,
d_model=config.d_model,
n_layers=config.n_layers,
n_heads=config.n_heads,
expansion_ratio=config.expansion_ratio,
pad_token_id=config.pad_token_id,
pad_value=config.pad_value,
num_bins=config.num_bins,
norm_scheme=config.norm_scheme,
transformer_activation=config.transformer_activation,
cell_emb_style=config.cell_emb_style,
use_chem_token=config.use_chem_token,
attn_config=config.attn_config,
norm_config=config.norm_config,
gene_encoder_config=config.gene_encoder_config,
expression_encoder_config=config.expression_encoder_config,
expression_decoder_config=config.expression_decoder_config,
mvc_config=config.mvc_config,
chemical_encoder_config=config.chemical_encoder_config,
use_glu=config.use_glu,
return_gene_embeddings=config.return_gene_embeddings,
keep_first_n_tokens=config.keep_first_n_tokens,
)
        # Run HF post-initialization hooks (applies _init_weights, etc.)
        self.post_init()
def forward(
self,
genes: torch.Tensor,
values: torch.Tensor,
gen_masks: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
drug_ids: Optional[torch.Tensor] = None,
skip_decoders: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[Tuple, BaseModelOutput]:
"""
Forward pass through the model.
Args:
genes: Gene token IDs [batch_size, seq_len]
values: Expression values [batch_size, seq_len]
gen_masks: Generation masks [batch_size, seq_len]
            key_padding_mask: Padding mask [batch_size, seq_len]; True marks
                non-padding positions. Computed from pad_token_id if omitted.
drug_ids: Drug IDs [batch_size] (optional)
skip_decoders: Whether to skip decoder computation
output_hidden_states: Whether to return hidden states
return_dict: Whether to return a dict or tuple
        Returns:
            A BaseModelOutput whose last_hidden_state is the cell embedding
            (hidden_states is populated only when output_hidden_states=True),
            or a plain tuple of the underlying model outputs when
            return_dict=False.
"""
if key_padding_mask is None:
            key_padding_mask = genes.ne(self.config.pad_token_id)
outputs = self.tx_model(
genes=genes,
values=values,
gen_masks=gen_masks,
key_padding_mask=key_padding_mask,
drug_ids=drug_ids,
skip_decoders=skip_decoders,
output_hidden_states=output_hidden_states,
)
if not return_dict:
            return tuple(outputs.values())
# Convert to HuggingFace output format
return BaseModelOutput(
last_hidden_state=outputs.get("cell_emb"),
hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
)
def get_input_embeddings(self):
"""Get input embeddings"""
return self.tx_model.gene_encoder.embedding
def set_input_embeddings(self, value):
"""Set input embeddings"""
self.tx_model.gene_encoder.embedding = value
def get_output_embeddings(self):
"""Get output embeddings (not applicable)"""
return None
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
"""
Load model from pretrained weights.
Works with both local paths and HuggingFace Hub.
Requires only: transformers, torch, safetensors
"""
# Let parent class handle config and weight loading
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
# Alias for easier importing
TXForCausalLM = TXModelForHF
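

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: the checkpoint path is
    # a placeholder, and the dtypes of the dummy inputs below are illustrative
    # assumptions rather than requirements documented in this file.
    model = TXModelForHF.from_pretrained("path/to/tx-checkpoint")
    model.eval()

    batch_size, seq_len = 2, 16
    genes = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
    values = torch.rand(batch_size, seq_len)  # expression values (assumed float)
    gen_masks = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # nothing masked for generation

    with torch.no_grad():
        out = model(genes=genes, values=values, gen_masks=gen_masks)

    # Cell embeddings are exposed as last_hidden_state on the BaseModelOutput.
    print(out.last_hidden_state.shape)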