# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision.transforms import Resize
from transformers import AutoConfig, AutoModel, Siglip2VisionConfig, Siglip2VisionModel

from . import models
from .utils import ScalingLayer

class TextAlignedTokenizer(nn.Module):
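    """Tokenizes images into tokens aligned with a SigLIP2 teacher.

    A pretrained SigLIP2 vision encoder embeds the input, a quantizing
    bottleneck compresses the features, and a shallow SigLIP2-style decoder
    reconstructs teacher-space features from the bottleneck output.
    """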
    def __init__(
        self,
        bottleneck,
        bottleneck_token_num=256,
        input_size=384,
        teacher='google/siglip2-so400m-patch14-384',
        input_type='quant',  # choose from ['quant', 'rec', 'indices']
        pool_scale=1,  # choose from [1, 2, 3]
        decoder_depth=3,
        select_layer_id=-2,
        *args,
        **kwargs
    ):
        super().__init__()
        self.input_size = input_size
        self.bottleneck_token_num = bottleneck_token_num
        self.teacher = teacher
        self.input_type = input_type
        self.pool_scale = pool_scale
        self.decoder_depth = decoder_depth
        self.select_layer_id = select_layer_id

        self.bottleneck_dim = bottleneck['args']['bottleneck_dim']

        self.encoder_config = AutoConfig.from_pretrained(teacher)
        self.encoder = AutoModel.from_config(self.encoder_config).vision_model
        self.encoder_hidden_dim = self.encoder.config.hidden_size

        self.decoder_config = Siglip2VisionConfig()
        self.decoder_config.update({
            'patch_size': 1,
            'num_hidden_layers': self.decoder_depth,
            'num_channels': self.bottleneck_dim,
            'hidden_size': self.encoder_hidden_dim,
        })
        self.decoder = Siglip2VisionModel(self.decoder_config)

        self.encode_task_layer = nn.Sequential(
            nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
            nn.Tanh())
        self.decode_task_layer = nn.Sequential(
            nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.encoder_hidden_dim, self.encoder_hidden_dim))

        bottleneck_args = {
            'token_nums': self.bottleneck_token_num,
            'input_dim': self.encoder_hidden_dim,
            'output_dim': self.bottleneck_dim}
        self.bottleneck = models.make(bottleneck, args=bottleneck_args)

        self.scale_layer = ScalingLayer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.image_resize = Resize((self.input_size, self.input_size))

    def set_vq_eval_deterministic(self, deterministic=True):
        self.bottleneck.regularizer.set_eval_deterministic(deterministic)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @classmethod
    def from_checkpoint(cls, ckpt, load_teacher=True, **kwargs):
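        """Instantiate the model from a checkpoint saved by this codebase.

        If `load_teacher` is False, teacher-prefixed weights are dropped from
        the state dict before loading.
        """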
        ckpt = torch.load(ckpt, map_location='cpu')
        ckpt_kwargs = ckpt["model"]["args"]
        model = cls(**kwargs, **ckpt_kwargs)
        sd = ckpt["model"]["sd"]
        if not load_teacher:
            sd = {k: v for k, v in sd.items() if not k.startswith('teacher')}
        model.load_state_dict(sd, strict=True)
        return model

    def encode(self, x, **kwargs):
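        """Encode images (b, c, h, w) or videos (b, c, t, h, w) into bottleneck tokens.

        Returns a dict containing the bottleneck output ('encoded'), the
        'pool_scale' used, the pre-bottleneck 'vq_feats', and any extra keys
        emitted by the bottleneck (e.g. 'regularized_z', 'bottleneck_rep').
        """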
        if x.ndim == 5:
            x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.scale_layer(x)
        if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
            x = self.image_resize(x)
        vq_feats = self.encoder(x, output_hidden_states=True).hidden_states[self.select_layer_id]
        pool_scale = kwargs.get("pool_scale", self.pool_scale)
        if pool_scale != 1:
            vq_feats = self.avg_pool(vq_feats, pool_scale)
        vq_feats = self.encode_task_layer(vq_feats.to(x))
        bottleneck_out = self.bottleneck(vq_feats)
        z = bottleneck_out.pop('output')
        return {'encoded': z, 'pool_scale': pool_scale, 'vq_feats': vq_feats, **bottleneck_out}

    def avg_pool(self, z, pool_scale=1):
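        """Average-pool token features spatially by a factor of `pool_scale`.

        Accepts a (b, n, c) token sequence (n must be a perfect square) or a
        (b, c, h, w) feature map; returns a pooled (b, n', c) sequence.
        """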
        if z.ndim == 3:
            b, n, c = z.shape
            p = int(n ** 0.5)
            z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
        z = F.avg_pool2d(
            z,
            kernel_size=(pool_scale, pool_scale),
            stride=(pool_scale, pool_scale)
        ).contiguous()
        z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
        return z

    def decode(self, z):
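        """Map bottleneck features back to the teacher feature space."""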
        if z.ndim == 4:
            z = rearrange(z, 'b c p1 p2 -> b (p1 p2) c')
        attention_mask = torch.ones(z.shape[:2], dtype=torch.int, device=z.device)
        p = int(z.shape[1] ** 0.5)
        spatial_shape = torch.tensor([[p, p]] * z.shape[0], device=self.device)
        z = self.decoder(z, attention_mask, spatial_shape, output_hidden_states=True).last_hidden_state
        z = self.decode_task_layer(z)
        return z

    def decode_from_bottleneck(self, bottleneck_rep):
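        """Decode teacher-space features from a bottleneck representation (e.g. token indices)."""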
        z = self.bottleneck.decode(bottleneck_rep)  # (b, n, c)
        p = int(z.shape[1] ** 0.5)
        z = rearrange(z, 'b (p1 p2) c -> b c p1 p2', p1=p, p2=p)
        return self.decode(z)

    def forward(self, data, **kwargs):
        # data: image in shape (b, c, h, w) or video in shape (b, c, t, h, w)
        encode_output = self.encode(data, **kwargs)
        vq_feats = encode_output['encoded']
        p = int(vq_feats.shape[1] ** 0.5)
        vq_feats = rearrange(vq_feats, 'b (h w) c -> b c h w', h=p, w=p)
        pred_feats = self.decode(vq_feats)
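        # Choose what 'encoded' exposes downstream: quantized latents ('quant'),
        # discrete token indices ('indices'), or reconstructed features ('rec').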
        if self.input_type == 'quant':
            z = encode_output["regularized_z"]  # [b, n, c]
        elif self.input_type == 'indices':
            z = encode_output["bottleneck_rep"]  # [b, n]
        elif self.input_type == 'rec':
            z = pred_feats  # [b, n, c]
        encode_output['encoded'] = z
        return encode_output
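
# Minimal usage sketch (not part of the original file): construct the tokenizer
# and run a batch of images through it. The bottleneck spec below is
# hypothetical; its real schema depends on `models.make` and the bottleneck
# registry defined elsewhere in this repo. Building the encoder also fetches
# the 'google/siglip2-so400m-patch14-384' config from the Hugging Face Hub.
if __name__ == '__main__':
    bottleneck_spec = {
        'name': 'vq',                     # hypothetical registry name
        'args': {'bottleneck_dim': 64},   # must provide 'bottleneck_dim'
    }
    tok = TextAlignedTokenizer(bottleneck_spec, input_type='rec')
    images = torch.rand(2, 3, 384, 384)   # values in [0, 1]; scaled to [-1, 1] internally
    out = tok(images)
    print({k: getattr(v, 'shape', v) for k, v in out.items()})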