| """ | |
| Hugging Face compatible implementation of Open-MAGVIT2 | |
| Code reference: https://github.com/TencentARC/Open-MAGVIT2 | |
| """ | |
| from transformers import PretrainedConfig | |
class EncoderDecoderConfig(PretrainedConfig):
    """Hyper-parameters for the encoder/decoder half of the tokenizer.

    Each field is looked up in ``kwargs``; absent keys fall back to the
    defaults below. The full ``kwargs`` mapping is also forwarded unchanged
    to ``PretrainedConfig.__init__``.
    """

    model_type = "resnet_encoder_decoder"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        option = kwargs.get  # super() received a copy, so this mapping is untouched
        # Base channel width (presumably scaled per level by ch_mult — confirm in the model code).
        self.ch = option("ch", 128)
        # Input / output channel counts; default 3 (presumably RGB — TODO confirm).
        self.in_channels = option("in_channels", 3)
        self.out_ch = option("out_ch", 3)
        # Latent channel count; the default of 18 equals QuantizerConfig's default `dim`.
        self.z_channels = option("z_channels", 18)
        self.num_res_blocks = option("num_res_blocks", 2)
        # A fresh list is built on every call, so instances never share it.
        self.ch_mult = option("ch_mult", [1, 1, 2, 2, 4])
class QuantizerConfig(PretrainedConfig):
    """Hyper-parameters for the LFQ quantizer.

    Fields are read from ``kwargs`` with the defaults shown below; the whole
    ``kwargs`` mapping is also passed on to ``PretrainedConfig.__init__``.
    """

    model_type = "lfq_quantizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        option = kwargs.get
        self.dim = option("dim", 18)
        # Default is 2 ** 18 == 262144, i.e. 2 ** (default dim).
        self.codebook_size = option("codebook_size", 2 ** 18)
        # Loss-term weights (presumably for the entropy penalty — confirm against the quantizer).
        self.batch_maximization_weight = option("batch_maximization_weight", 1.0)
        self.sample_minimization_weight = option("sample_minimization_weight", 1.0)
class LFQTokenizerConfig(PretrainedConfig):
    r"""
    Configuration for the LFQ tokenizer, composed of two sub-configurations.

    Args:
        encoder_decoder_config (`EncoderDecoderConfig` or `dict`, *optional*):
            Settings for the encoder/decoder. Defaults to ``EncoderDecoderConfig()``.
        quantizer_config (`QuantizerConfig` or `dict`, *optional*):
            Settings for the quantizer. Defaults to ``QuantizerConfig()``.
        kwargs:
            All remaining keyword arguments are forwarded to ``PretrainedConfig``.

    Sub-configs supplied as plain ``dict`` objects (the form nested configs
    take when a saved config is reloaded from JSON) are converted back into
    their config classes. This keeps attribute access such as
    ``config.quantizer_config.dim`` working after a save/load round-trip.
    A missing or ``None`` sub-config falls back to the class defaults.
    """

    model_type = "lfq_tokenizer"

    def __init__(self, **kwargs):
        # Pop the sub-configs so PretrainedConfig does not first store the raw
        # (possibly dict) values as attributes.
        encoder_decoder_config = kwargs.pop("encoder_decoder_config", None)
        quantizer_config = kwargs.pop("quantizer_config", None)
        super().__init__(**kwargs)
        self.encoder_decoder_config = self._as_config(encoder_decoder_config, EncoderDecoderConfig)
        self.quantizer_config = self._as_config(quantizer_config, QuantizerConfig)

    @staticmethod
    def _as_config(value, config_cls):
        """Coerce ``value`` (None, dict, or config instance) into a ``config_cls`` instance."""
        if value is None:
            # Build the default lazily, only when the caller did not supply one.
            return config_cls()
        if isinstance(value, dict):
            return config_cls(**value)
        return value