Instructions to use lijiang/Omni-Diffusion with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use lijiang/Omni-Diffusion with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("lijiang/Omni-Diffusion", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| # coding=utf-8 | |
| # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import warnings | |
| import copy | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional, Tuple, Union | |
| import torch | |
| import torch.distributions as dists | |
| from torch.nn import functional as F | |
| from transformers import __version__ | |
| from transformers.generation.configuration_utils import ( | |
| GenerationConfig | |
| ) | |
| from transformers.utils import ( | |
| ModelOutput, | |
| is_torchdynamo_compiling, | |
| logging, | |
| ) | |
| logger = logging.get_logger(__name__) | |
| from tqdm import tqdm | |
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering over the last dimension of ``logits``.

    Tokens are ranked by logit. Every token after the first one whose cumulative softmax
    probability exceeds ``top_p`` has its logit replaced by the dtype's minimum value. The
    threshold-crossing token itself is always kept, and so is the single best token.
    ``top_p`` must be a number; callers are expected to skip this function otherwise.
    """
    desc_logits, desc_idx = torch.sort(logits, descending=True)
    running_mass = F.softmax(desc_logits, dim=-1).cumsum(dim=-1)
    drop_sorted = running_mass > top_p
    # Shift the decision one slot to the right so the token that crosses the threshold survives.
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = False
    # Map the decisions from sorted order back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, desc_idx, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Values strictly below the k-th largest are replaced by the dtype's minimum, so ties with
    the k-th value survive. A ``top_k`` larger than the vocabulary is clamped to its size.
    """
    k = min(top_k, logits.size(-1))
    kth_largest = torch.topk(logits, k).values[..., -1:]
    below_cutoff = logits < kth_largest
    return logits.masked_fill(below_cutoff, torch.finfo(logits.dtype).min)
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False,
                  repeat_penalty=1.0, max_position_penalty=1.0, past_x=None, mask_id=None,):
    """Choose a token for every row of ``logits`` and report a confidence score for that choice.

    The last dimension of ``logits`` is the vocabulary. Every leading index is treated as an
    independent position.

    Args:
        logits: Unnormalised scores, shaped ``(..., vocab)``. The tensor is not modified in place
            by the penalty steps.
        temperature: ``> 0`` samples from ``softmax(logits / temperature)``; ``0`` takes the argmax.
        top_p: Nucleus threshold; ignored when ``None`` or ``>= 1``.
        top_k: Keep only the ``top_k`` best logits; ignored when ``None``.
        margin_confidence: Use ``p(top1) - p(top2)`` as the confidence.
        neg_entropy: Use ``sum(p * log p)`` (negative entropy) as the confidence. Applied after,
            and therefore overriding, ``margin_confidence``.
        repeat_penalty: For every token id present in ``past_x`` (ignoring id 0 and ``mask_id``),
            negative logits are multiplied and non-negative logits divided by this factor.
            ``1.0`` disables it.
        max_position_penalty: Rows after the first 100 (along dim -2) are penalised like
            ``repeat_penalty`` with a factor ramping linearly from 1.0 towards this value.
            ``1.0`` disables it.
        past_x: Tensor of previously produced ids; required when ``repeat_penalty != 1.0``.
        mask_id: Mask token id excluded from ``past_x``.

    Returns:
        ``(confidence, x0)``: the per-row confidence and the chosen ids, each shaped ``logits.shape[:-1]``.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    if repeat_penalty != 1.0:
        # Each already-emitted id is penalised once, regardless of how often it appeared.
        emitted = past_x[(past_x != 0) & (past_x != mask_id)]
        hit = torch.zeros(logits.shape[-1], dtype=torch.bool, device=logits.device)
        hit[torch.unique(emitted).to(logits.device)] = True
        penalised = torch.where(logits < 0, logits * repeat_penalty, logits / repeat_penalty)
        logits = torch.where(hit, penalised, logits)

    if max_position_penalty != 1.0:
        num_rows = logits.shape[-2]
        if num_rows > 100:
            # The first 100 rows are untouched; the rest ramp linearly from 1.0 towards
            # max_position_penalty. The ramp is built in float64 (like Python floats) and
            # passed through the default dtype before casting to the logits' dtype.
            ramp_len = num_rows - 100
            ramp = torch.arange(ramp_len, dtype=torch.float64, device=logits.device)
            ramp = ramp / ramp_len * (max_position_penalty - 1.0) + 1.0
            penalty = torch.cat([torch.ones(100, dtype=torch.float64, device=logits.device), ramp])
            penalty = penalty.to(torch.get_default_dtype()).to(logits.dtype).unsqueeze(-1)
            logits = torch.where(logits < 0, logits * penalty, logits / penalty)

    probs = torch.softmax(logits, dim=-1)
    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # Invalid distribution (e.g. NaN probabilities): fall back to greedy selection.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        top_two = torch.topk(probs, 2, dim=-1).values
        confidence = top_two[..., 0] - top_two[..., 1]
    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)
    return confidence, x0
@dataclass
class DreamModelOutput(ModelOutput):
    """Container for the results of diffusion generation.

    ``transformers.ModelOutput`` subclasses are meant to be dataclasses, and recent
    ``transformers`` releases reject undecorated subclasses at instantiation time. That is
    why ``@dataclass`` is applied here (``dataclass`` is already imported by this module).

    Attributes:
        sequences: presumably the generated token ids (not constructed in the visible code; confirm).
        history: optional per-step snapshots of the sequence.
    """

    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None
class DreamGenerationConfig(GenerationConfig):
    """Generation settings for masked-diffusion decoding.

    ``__init__`` does not call ``GenerationConfig.__init__``. It sets only the attributes
    below, plus any extra keyword arguments when the config is not built from a model config.
    ``validate`` is overridden as a no-op.
    """

    def __init__(self, **kwargs):
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion specific params
        self.eps: float = kwargs.pop("eps", 1e-3)  # last value of the denoising timestep schedule
        self.steps: int = kwargs.pop("steps", 512)  # upper bound on the number of denoising steps
        self.alg: str = kwargs.pop("alg", 'origin')  # name of the unmasking strategy
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)
        # Parameters that define the output variables of `generate`
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)
        # Special tokens that can be used at generation time
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})
        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
        # interface.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)
        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err
        # Validate the values of the attributes
        self.validate(is_init=True)

    def validate(self, is_init=False, **kwargs):
        """Skip the parent's validation, which targets autoregressive decoding flags.

        Extra keyword arguments (e.g. a ``strict`` flag that some ``transformers`` versions pass)
        are accepted and ignored. Without this, callers using the parent's signature would
        raise ``TypeError``.
        """
        pass
class DreamGenerationMixin:
    """Mixin adding masked-diffusion generation (``diffusion_generate``) to a model class.

    The host class must provide ``self.config``, ``self.generation_config``, ``self.device``,
    ``self.model.embed_tokens`` and ``self.forward_dream``. All of them are used below.
    """

    # Token id written into the masked span when ``add_boa_token`` is set
    # (the original code labels it ``<|begin_of_audio|>``).
    _BEGIN_OF_AUDIO_TOKEN_ID = 151684

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ) -> Tuple[Optional[torch.LongTensor], Optional[torch.LongTensor]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]

        This must be a ``staticmethod``: it takes no ``self`` but is invoked as
        ``self._expand_inputs_for_generation(expand_size=..., ...)``. Without the decorator,
        the instance would bind to ``expand_size`` and the call would raise ``TypeError``.
        """
        # Do not call torch.repeat_interleave if expand_size is 1 because it clones
        # the input tensor and thus requires more memory although no change is applied
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Performs validation related to the resulting generated length.

        Warns when the generic default ``max_length`` of 20 is in effect. Raises ``ValueError``
        when the prompt already fills ``max_length``.
        """
        # Can't throw warnings/exceptions during compilation
        if is_torchdynamo_compiling():
            return
        # 1. Max length warnings related to poor parameterization
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
                "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
                "generation.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            raise ValueError(
                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_length` or, better yet, setting `max_new_tokens`."
            )

    def _prepare_generated_length(
        self,
        generation_config,
        has_default_max_length,
        input_ids_length,
    ):
        """Set ``generation_config.max_length`` (in place) from ``max_new_tokens`` / the prompt length.

        ``max_new_tokens`` takes precedence over an explicit ``max_length``. The result is capped
        at ``self.config.max_position_embeddings`` when that attribute exists. Returns the same
        config object.
        """
        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        elif has_default_max_length:
            if generation_config.max_length == DreamGenerationConfig().max_length:
                generation_config.max_length = generation_config.max_length + input_ids_length
                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
                if max_position_embeddings is not None:
                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
        return generation_config

    def _prepare_generation_config(
        self, generation_config: Optional["DreamGenerationConfig"], **kwargs: Dict
    ) -> "DreamGenerationConfig":
        """
        Prepares the base generation config, then applies any generation configuration options from kwargs. This
        function handles retrocompatibility with respect to configuration files.
        """
        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True
        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
        # exception will be raised in `_validate_model_kwargs`
        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            # `update` returns the kwargs it did not consume; they are not needed here.
            generation_config.update(**kwargs)
            # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id
        return generation_config

    def _prepare_special_tokens(
        self,
        generation_config: "DreamGenerationConfig",
        device: Optional[Union[torch.device, str]] = None,
    ):
        """
        Prepares the special tokens for generation, overwriting the generation config with their processed versions
        converted to tensor.
        Note that `generation_config` is changed in place and stops being serializable after this method is called.
        That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
        function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
        """
        # Convert special tokens to tensors
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)
        # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)
        # Set pad token if unset (and there are conditions to do so)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
        # Update generation config with the updated special tokens tensors
        # NOTE: this must be written into a different attribute name than the one holding the original special tokens
        # (in their non-tensor form), in order to enable end-to-end compilation. See
        # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional["DreamGenerationConfig"] = None,
        inputs_embeds=None,
        prefix_lm=False,
        alg=None,
        block_size=-1,
        cfg=0.0,
        add_boa_token=False,
        **kwargs,
    ) -> Tuple[torch.LongTensor, Optional[list]]:
        """Generate by iteratively unmasking a block of mask tokens appended to the prompt.

        Args:
            inputs: Prompt token ids. Either this or ``inputs_embeds`` must be given.
            generation_config: Overrides the model's default generation config.
            inputs_embeds: Prompt embeddings. They overwrite the embeddings of the first
                ``inputs_embeds.shape[1]`` positions at every step.
            prefix_lm: Forwarded to ``_sample``, which does not read it.
            alg: Unmasking strategy; defaults to ``generation_config.alg``.
            block_size: Width of the semi-autoregressive decoding blocks; ``<= 0`` means one block.
            cfg: Classifier-free-guidance scale; ``> 0`` requires ``task`` and ``tokenizer`` kwargs.
            add_boa_token: Insert the begin-of-audio token into the masked span before decoding.
            **kwargs: Generation-config overrides plus ``attention_mask``,
                ``generation_tokens_hook_func``, ``generation_logits_hook_func`` and extra
                ``_sample`` options (e.g. ``step_ratio``, ``repeat_penalty``, ``max_position_penalty``).

        Returns:
            ``(x, histories)`` as produced by ``_sample``.

        Raises:
            ValueError: if neither ``inputs`` nor ``inputs_embeds`` is given, or the prompt fills ``max_length``.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        # 2. Define model inputs
        if inputs is not None:
            input_ids = inputs
            device = input_ids.device
            input_ids_length = input_ids.shape[-1]
        elif inputs_embeds is not None:
            input_ids = None
            device = inputs_embeds.device
            input_ids_length = inputs_embeds.shape[1]
        else:
            raise ValueError("Either `inputs` or `inputs_embeds` must be provided.")
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        # 3. Prepare `max_length`.
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )
        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 4. Check inputs. `device` is used instead of `input_ids.device` because `input_ids`
        # is None when only `inputs_embeds` is given.
        if not is_torchdynamo_compiling() and self.device.type != device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )
        if (
            generation_config.pad_token_id is not None
            and input_ids is not None
            and attention_mask is None
            and torch.any(input_ids == generation_config.pad_token_id)
        ):
            warnings.warn(
                "Padding was detected but no attention mask is passed here. For correct "
                "generation results, please set `attention_mask` when batch-padding inputs.",
                UserWarning,
            )
        assert generation_config.num_return_sequences == 1, \
            "Currently, we only support num_return_sequences = 1 for diffusion generation."
        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            generation_tokens_hook_func=generation_tokens_hook_func,
            generation_logits_hook_func=generation_logits_hook_func,
            inputs_embeds=inputs_embeds,
            device=device,
            prefix_lm=prefix_lm,
            alg=alg,
            block_size=block_size,
            cfg=cfg,
            add_boa_token=add_boa_token,
            **kwargs,
        )

    @staticmethod
    def _build_cfg_uncond_prompt(task, tokenizer, prompt_len):
        """Return the token ids of the unconditional prompt used for classifier-free guidance.

        Raises:
            ValueError: if ``tokenizer`` is missing or ``task`` is not ``"S2I"`` / ``"T2I"``.
        """
        import random

        if tokenizer is None:
            raise ValueError("Classifier-free guidance (cfg > 0) requires a `tokenizer` keyword argument.")
        if task == "S2I":
            empty_prompt = "<|im_start|>system\nPlease generate an image based on the input audio.<|im_end|>\n"
            empty_prompt += "<|im_start|>user\n<|im_end|>\n<|im_start|>assistant\n"
            return tokenizer.encode(empty_prompt)
        if task == "T2I":
            head = "<|im_start|>user\nGenerate an image based on the provided text description.\n"
            tail = "<|im_end|>\n<|im_start|>assistant\n"
            first_audio_token = tokenizer.encode("<|begin_of_audio|>")[0]
            # Replace the caption with random ids below the <|begin_of_audio|> id. The filler is sized
            # so the result matches `prompt_len` (assuming encode(head) + encode(tail) tokenizes
            # like encode(head + tail)).
            filler = random.sample(range(first_audio_token), prompt_len - len(tokenizer.encode(head + tail)))
            return tokenizer.encode(head) + filler + tokenizer.encode(tail)
        raise ValueError(f"Classifier-free guidance is only implemented for tasks 'S2I' and 'T2I', got {task!r}.")

    @staticmethod
    def _build_cfg_attention_mask(seq_len, uncond_len, device):
        """Return a ``(2, 1, seq_len, seq_len)`` bool mask for the [conditional, unconditional] batch.

        The conditional row attends everywhere. The unconditional row attends only within its
        first ``uncond_len`` positions.
        """
        cond = torch.ones((1, seq_len, seq_len), dtype=torch.bool, device=device)
        uncond = torch.zeros((1, seq_len, seq_len), dtype=torch.bool, device=device)
        uncond[:, :uncond_len, :uncond_len] = True
        return torch.cat([cond, uncond]).unsqueeze(1)

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: "DreamGenerationConfig",
        generation_tokens_hook_func,
        generation_logits_hook_func,
        inputs_embeds=None,
        prefix_lm=False,
        device=None,
        step_ratio=None,
        penalty=1.2,
        alg=None,
        block_size=None,
        add_boa_token=False,
        max_position_penalty=1.0,
        repeat_penalty=1.0,
        cfg=0.0,
        **kwargs,
    ) -> Tuple[torch.LongTensor, Optional[list]]:
        """Run the block-wise denoising loop.

        The prompt is right-padded with ``mask_token_id`` up to ``max_length``. The masked span
        is then decoded block by block (``block_size`` wide). Each block gets a share of the step
        budget proportional to its width. ``prefix_lm`` and ``penalty`` are accepted for interface
        compatibility and are not read.

        Returns:
            ``(x, histories)``: the final ids, plus a list of per-step snapshots when
            ``generation_config.return_dict_in_generate`` is set (otherwise ``None``).

        Raises:
            ValueError: for missing ``inputs_embeds``/``device`` or invalid CFG settings.
            RuntimeError: for an unknown ``alg``.
        """
        output_history = True
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        max_new_tokens = generation_config.max_new_tokens
        eps = generation_config.eps
        alg = generation_config.alg if alg is None else alg
        logger.info("denoise algorithm: %s", alg)
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k
        histories = [] if (return_dict_in_generate and output_history) else None

        # Prompt length comes from the ids, or from the embeddings when no ids were given.
        if input_ids is None:
            if device is None or inputs_embeds is None:
                raise ValueError("`device` and `inputs_embeds` are required when `input_ids` is None.")
            bsz, input_length = inputs_embeds.shape[:2]
        else:
            input_length = input_ids.shape[-1]
        if max_new_tokens is None:
            # Use whatever room `max_length` leaves after the prompt.
            max_new_tokens = max_length - input_length
        steps = min(generation_config.steps, max_new_tokens)
        if block_size is None or block_size <= 0:
            block_size = max_new_tokens
        if input_ids is None:
            max_length = input_length + max_new_tokens
            # Placeholder prompt ids; their embeddings are replaced by `inputs_embeds` every step.
            input_ids = torch.zeros((bsz, input_length), dtype=torch.long, device=device)

        tok_idx = None
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
        x = generation_tokens_hook_func(None, x, None)
        if step_ratio is not None:
            steps = int(max_new_tokens * step_ratio)
        if add_boa_token:
            # Unmasked count + 20% of the masked count, i.e. 20% into the masked span
            # (assumes batch size 1 with the masks directly after the prompt).
            num_masked = (x == mask_token_id).sum()
            boa_index = int((x.shape[1] - num_masked) + num_masked * 0.2)
            x[:, boa_index] = self._BEGIN_OF_AUDIO_TOKEN_ID

        total_steps = steps
        num_masked = int((x == mask_token_id).sum())
        block_num = (num_masked + block_size - 1) // block_size  # ceil division

        # Classifier-free guidance: the unconditional prompt is built once (it contains random filler).
        uncond_ids = None
        cfg_attention_mask = None
        if cfg > 0:
            uncond_prompt = self._build_cfg_uncond_prompt(kwargs.get("task"), kwargs.get("tokenizer"), input_length)
            uncond_ids = torch.tensor(uncond_prompt).unsqueeze(0).to(x.dtype).to(x.device)

        prev_x = None  # sequence after the previous step; used by alg == "entropy-penalty"
        for block_idx in range(block_num):
            block_start = input_length + block_idx * block_size
            block_mask = torch.zeros(x.shape[-1], dtype=torch.bool, device=x.device)
            block_mask[block_start:block_start + block_size] = True
            # Step budget proportional to this block's width. At least one step, so a narrow
            # block is never left fully masked.
            steps = max(1, int(block_mask.sum() / (x.shape[-1] - input_length) * total_steps))
            timesteps = torch.linspace(1, eps, steps + 1, device=x.device)
            for i in tqdm(range(steps)):
                mask_index = (x == mask_token_id)
                if mask_index.sum() == 0:
                    break
                if alg != "origin" and not (x[:, block_mask] == mask_token_id).any():
                    # This block is fully decoded. Stop, rather than force a token from outside the
                    # block (all of which carry -inf confidence below).
                    break

                inputs_embeds_curr = self.model.embed_tokens(x)
                if inputs_embeds is not None:
                    inputs_embeds_curr[:, :inputs_embeds.shape[1]] = inputs_embeds
                if cfg > 0:
                    uncond_x = torch.cat([uncond_ids, x[:, input_length:]], dim=1)
                    uncond_embeds = self.model.embed_tokens(uncond_x)
                    seq_len = inputs_embeds_curr.shape[1]
                    uncond_len = uncond_embeds.shape[1]
                    if cfg_attention_mask is None:
                        cfg_attention_mask = self._build_cfg_attention_mask(
                            seq_len, uncond_len, inputs_embeds_curr.device
                        )
                    if seq_len != uncond_len:
                        # Right-pad the unconditional branch with zero embeddings to the same length.
                        padding = torch.zeros_like(inputs_embeds_curr[:, :seq_len - uncond_len, :])
                        uncond_embeds = torch.cat([uncond_embeds, padding], dim=1)
                    # NOTE(review): the unconditional branch is right-padded. When its prompt is
                    # shorter than the conditional prompt (e.g. task "S2I"), its generated tokens sit
                    # at earlier absolute positions than in the conditional branch, yet the two logit
                    # tensors are combined position-by-position below. Confirm this offset is intended.
                    model_logits = self.forward_dream(
                        None, cfg_attention_mask, tok_idx,
                        inputs_embeds=torch.cat([inputs_embeds_curr, uncond_embeds]),
                    ).logits
                    cond_logits, uncond_logits = model_logits[:1], model_logits[1:]
                    logits = uncond_logits + (cfg + 1) * (cond_logits - uncond_logits)
                else:
                    logits = self.forward_dream(None, attention_mask, tok_idx,
                                                inputs_embeds=inputs_embeds_curr).logits
                # Shift right by one: position p is scored from the model output at p - 1.
                logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
                logits = generation_logits_hook_func(i, x, logits)
                mask_logits = logits[mask_index]
                t = timesteps[i]
                s = timesteps[i + 1]

                if alg == 'origin':
                    # Each masked position is revealed independently with probability 1 - s/t
                    # (all remaining ones on the last step).
                    p_transfer = 1 - s / t if i < steps - 1 else 1
                    x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
                    transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
                    _, x0[transfer_index_t_s] = sample_tokens(
                        mask_logits[transfer_index_t_s],
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        max_position_penalty=max_position_penalty,
                    )
                    x[mask_index] = x0.clone()
                else:
                    sample_kwargs = dict(
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        max_position_penalty=max_position_penalty,
                    )
                    if alg == 'maskgit_plus':
                        pass
                    elif alg == 'topk_margin':
                        sample_kwargs["margin_confidence"] = True
                    elif alg == 'entropy':
                        sample_kwargs["neg_entropy"] = True
                    elif alg == "entropy-penalty":
                        # No penalty before the first step has produced a previous sequence.
                        sample_kwargs.update(
                            neg_entropy=True,
                            repeat_penalty=repeat_penalty if prev_x is not None else 1.0,
                            past_x=prev_x,
                            mask_id=mask_token_id,
                        )
                    else:
                        raise RuntimeError(f"Unknown alg: {alg}")
                    confidence, x0 = sample_tokens(mask_logits, **sample_kwargs)

                    # Only masked positions inside the current block may be revealed this step.
                    # NOTE(review): uses batch row 0 only - assumes batch size 1.
                    in_block = block_mask[mask_index[0]]
                    confidence = confidence + torch.where(in_block, 0, -torch.inf).to(confidence.device)
                    num_mask_token = int((x[:, block_mask] == mask_token_id).sum())
                    if i < steps - 1:
                        number_transfer_tokens = int(num_mask_token * (1 - s / t))
                    else:
                        number_transfer_tokens = num_mask_token
                    number_transfer_tokens = max(number_transfer_tokens, 1)
                    if alg_temp is None or alg_temp == 0:
                        _, transfer_index = torch.topk(confidence, number_transfer_tokens)
                    else:
                        confidence = F.softmax(confidence / alg_temp, dim=-1)
                        transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
                    x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
                    x0_[transfer_index] = x0[transfer_index].clone()
                    x[mask_index] = x0_

                x = generation_tokens_hook_func(i, x, logits)
                if histories is not None:
                    histories.append(x.clone())
                if alg == "entropy-penalty":
                    prev_x = histories[-1] if histories is not None else x.clone()
        return (x, histories)