# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
HuggingFace-compatible wrapper for TXModel (Standalone version)

Only requires: transformers, torch, safetensors
"""

from typing import Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from configuration_tx import TXConfig
from model_standalone import TXModel


class TXPreTrainedModel(PreTrainedModel):
    """
    Base class for TXModel with HuggingFace integration
    """

    config_class = TXConfig
    base_model_prefix = "tx_model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TXBlock"]

    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class TXModelForHF(TXPreTrainedModel):
    """
    HuggingFace-compatible TXModel

    This model can be used directly with HuggingFace's from_pretrained()
    and requires only: transformers, torch, safetensors

    No dependencies on llmfoundry, composer, or other external libraries.
    """

    def __init__(self, config: TXConfig):
        super().__init__(config)

        # Initialize standalone model
        self.tx_model = TXModel(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            n_heads=config.n_heads,
            expansion_ratio=config.expansion_ratio,
            pad_token_id=config.pad_token_id,
            pad_value=config.pad_value,
            num_bins=config.num_bins,
            norm_scheme=config.norm_scheme,
            transformer_activation=config.transformer_activation,
            cell_emb_style=config.cell_emb_style,
            use_chem_token=config.use_chem_token,
            attn_config=config.attn_config,
            norm_config=config.norm_config,
            gene_encoder_config=config.gene_encoder_config,
            expression_encoder_config=config.expression_encoder_config,
            expression_decoder_config=config.expression_decoder_config,
            mvc_config=config.mvc_config,
            chemical_encoder_config=config.chemical_encoder_config,
            use_glu=config.use_glu,
            return_gene_embeddings=config.return_gene_embeddings,
            keep_first_n_tokens=config.keep_first_n_tokens,
        )

        # Post init
        self.post_init()

    def forward(
        self,
        genes: torch.Tensor,
        values: torch.Tensor,
        gen_masks: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        drug_ids: Optional[torch.Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Forward pass through the model.

        Args:
            genes: Gene token IDs [batch_size, seq_len]
            values: Expression values [batch_size, seq_len]
            gen_masks: Generation masks [batch_size, seq_len]
            key_padding_mask: Padding mask [batch_size, seq_len]
            drug_ids: Drug IDs [batch_size] (optional)
            skip_decoders: Whether to skip decoder computation
            output_hidden_states: Whether to return hidden states
            return_dict: Whether to return a dict or tuple

        Returns:
            Model outputs
        """
        if key_padding_mask is None:
            key_padding_mask = ~genes.eq(self.config.pad_token_id)

        outputs = self.tx_model(
            genes=genes,
            values=values,
            gen_masks=gen_masks,
            key_padding_mask=key_padding_mask,
            drug_ids=drug_ids,
            skip_decoders=skip_decoders,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            return tuple(v for v in outputs.values())

        # Convert to HuggingFace output format
        return BaseModelOutput(
            last_hidden_state=outputs.get("cell_emb"),
            hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
        )

    def get_input_embeddings(self):
        """Get input embeddings"""
        return self.tx_model.gene_encoder.embedding

    def set_input_embeddings(self, value):
        """Set input embeddings"""
        self.tx_model.gene_encoder.embedding = value

    def get_output_embeddings(self):
        """Get output embeddings (not applicable)"""
        return None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load model from pretrained weights.

        Works with both local paths and HuggingFace Hub.
        Requires only: transformers, torch, safetensors
        """
        # Let parent class handle config and weight loading
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


# Alias for easier importing
TXForCausalLM = TXModelForHF
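

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). "path/to/tx-checkpoint" is a
    # placeholder for a local directory or Hub repo containing config.json and
    # model.safetensors; the dummy shapes below and the bool dtype of gen_masks
    # are assumptions based on the forward() docstring, and the exact input
    # semantics depend on TXModel in model_standalone.py.
    model = TXModelForHF.from_pretrained("path/to/tx-checkpoint")
    model.eval()

    batch_size, seq_len = 2, 16
    genes = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
    values = torch.rand(batch_size, seq_len)
    gen_masks = torch.zeros(batch_size, seq_len, dtype=torch.bool)

    with torch.no_grad():
        out = model(genes=genes, values=values, gen_masks=gen_masks)

    # BaseModelOutput.last_hidden_state carries the cell embedding ("cell_emb")
    # returned by the underlying TXModel
    print(out.last_hidden_state.shape)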