Upload files with `vila-upload`.
Upload media_encoder.py
Upload media.py
Upload modeling_vila.py
Upload configuration_vila.py
Upload builder.py
Upload mm_utils.py
Upload tokenizer_utils.py
Upload siglip_encoder.py
- builder.py +14 -4
- configuration_vila.py +16 -8
- media.py +4 -0
- media_encoder.py +3 -2
- mm_utils.py +1 -1
- modeling_vila.py +131 -35
- siglip_encoder.py +2 -3
- tokenizer_utils.py +2 -2
builder.py
CHANGED

@@ -22,9 +22,9 @@ from dataclasses import asdict
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
+import transformers
 from huggingface_hub import file_exists, repo_exists
 from huggingface_hub.utils import HFValidationError
-import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -33,8 +33,9 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizer,
 )
+
 # from .conversation import *
-from .conversation import …
+from .conversation import SeparatorStyle, default_conversation
 
 SENTINEL_TOKEN = "<vila/sentinel>"
 MEDIA_TOKENS = {
@@ -51,9 +52,11 @@ DUMMY_CONVERSATION = [
     {"from": "gpt", "value": "answer"},
 ] * 10
 
+
 def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
     return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
-
+
+
 def has_tokenizer(repo_id_or_path: str) -> bool:
     # Check if the tokenizer is in a local directory
     if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
@@ -65,12 +68,14 @@ def has_tokenizer(repo_id_or_path: str) -> bool:
     except HFValidationError:
         return False
 
+
 def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
     if not hasattr(tokenizer, "sentinel_token"):
         tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
         tokenizer.sentinel_token = SENTINEL_TOKEN
         tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
 
+
 def tokenize_conversation_legacy(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
@@ -103,6 +108,7 @@ def tokenize_conversation_legacy(
 
     return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
 
+
 def tokenize_conversation(
     messages: Sequence[Dict[str, str]],
     tokenizer: transformers.PreTrainedTokenizer,
@@ -148,6 +154,7 @@ def tokenize_conversation(
     )
     return tokenizer_image_token(text, tokenizer, return_tensors="pt")
 
+
 def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
     _maybe_add_sentinel_token(tokenizer)
     template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
@@ -159,6 +166,7 @@ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
         stop_tokens.add(stop_token)
     return list(stop_tokens)
 
+
 def context_length_extension(config):
     orig_ctx_len = getattr(config, "max_position_embeddings", None)
     model_max_length = getattr(config, "model_max_length", None)
@@ -186,7 +194,7 @@ def build_llm_and_tokenizer(
 
     # Quantization related
     quantization_restore_from_checkpoint = False
-
+
     if quantization_restore_from_checkpoint:
         fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
 
@@ -215,6 +223,8 @@
     if getattr(config, "chat_template", None) is not None:
         print(f"Using chat template: {config.chat_template}")
         fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
+        if not os.path.exists(fpath):
+            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
         with open(fpath) as fd:
             chat_template = fd.read()
         tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
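The only behavioral change in this file is the chat-template fallback in the last hunk. A minimal sketch of the lookup order it introduces, written as a standalone helper for illustration (the function name and arguments here are hypothetical; only the two candidate paths come from the hunk above):

```python
import os


def resolve_chat_template(package_dir: str, model_name_or_path: str, chat_template: str) -> str:
    """Illustrative lookup: bundled chat_templates/ first, then next to the checkpoint."""
    fpath = os.path.join(package_dir, "chat_templates", f"{chat_template}.jinja")
    if not os.path.exists(fpath):
        # Fallback added in this commit: look beside the checkpoint itself, e.g. a
        # vicuna_v1.jinja written out by convert_vila_dev_ckpt_to_remote (see modeling_vila.py).
        fpath = os.path.join(os.path.dirname(model_name_or_path), f"{chat_template}.jinja")
    with open(fpath) as fd:
        return fd.read()
```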
configuration_vila.py
CHANGED

@@ -1,15 +1,24 @@
+import json
 import math
+import os
+import os.path as osp
+from copy import deepcopy
+from threading import Thread
 from typing import List, Optional
-
+
 import torch
 import torchvision
-import os, os.path as osp
-
-from threading import Thread
-from copy import deepcopy
 from PIL import Image
-from transformers import …
-…
+from transformers import (
+    AutoProcessor,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2PreTrainedModel,
+    TextIteratorStreamer,
+)
+
 
 class VILAConfig(PretrainedConfig):
     model_type = "vila"
@@ -82,4 +91,3 @@ class VILAConfig(PretrainedConfig):
         self.video_encoder = video_encoder
 
         super().__init__(**kwargs)
-
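The import cleanup above matters mainly because this file ships as remote code next to the checkpoint. A rough sketch of how the config then resolves through the HF Auto classes; the checkpoint path is a placeholder, and the class resolution relies on the `auto_map` entries written by `convert_vila_dev_ckpt_to_remote` in modeling_vila.py below:

```python
from transformers import AutoConfig

# Placeholder path: any checkpoint carrying these remote-code files and the
# auto_map entries would resolve the same way.
config = AutoConfig.from_pretrained("path/to/vila-remote-checkpoint", trust_remote_code=True)
print(config.model_type)      # "vila"
print(type(config).__name__)  # "VILAConfig", loaded dynamically via auto_map
```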
media.py
CHANGED

@@ -20,13 +20,16 @@ MEDIA_TOKENS = {
     "video": "<vila/video>",
 }
 
+
 class Media:
     pass
 
+
 class File(Media):
     def __init__(self, path: str) -> None:
         self.path = path
 
+
 class Image(File):
     pass
 
@@ -34,6 +37,7 @@ class Image(File):
 class Video(File):
     pass
 
+
 def make_list(obj: Any) -> List:
     return obj if isinstance(obj, list) else [obj]
 
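For context, the classes touched here are thin path wrappers used to interleave media with text in a prompt. The sketch below inlines them (mirroring the hunk) so it runs standalone; the prompt and file paths are hypothetical:

```python
from typing import Any, List


class Media:
    pass


class File(Media):
    def __init__(self, path: str) -> None:
        self.path = path


class Image(File):
    pass


class Video(File):
    pass


def make_list(obj: Any) -> List:
    return obj if isinstance(obj, list) else [obj]


# A prompt can mix plain strings with media placeholders.
prompt = ["Describe the image and the clip.", Image("demo/figure.png"), Video("demo/clip.mp4")]
print(make_list("single item"))                                     # ['single item']
print([type(m).__name__ for m in prompt if isinstance(m, Media)])   # ['Image', 'Video']
```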
media_encoder.py
CHANGED

@@ -1,8 +1,9 @@
-import torch
-from torch import nn
 from functools import partial
 from typing import Any, Dict, List, Optional
 
+import torch
+from torch import nn
+
 
 class BaseEncoder(nn.Module):
     def __init__(self, parent: nn.Module) -> None:
mm_utils.py
CHANGED

@@ -26,7 +26,7 @@ import torch
 from PIL import Image
 from transformers import StoppingCriteria
 
-from …
+from .constants import DEFAULT_IMAGE_TOKEN
 
 
 def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
modeling_vila.py
CHANGED

@@ -1,4 +1,3 @@
-import shutil
 import copy
 import json
 import logging
@@ -6,6 +5,7 @@ import math
 import os
 import os.path
 import os.path as osp
+import shutil
 import warnings
 from abc import ABC
 from collections import OrderedDict, defaultdict, deque
@@ -15,13 +15,12 @@ from threading import Thread
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 from einops import rearrange
 from PIL import Image
-
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -34,28 +33,30 @@ from transformers import (
     Qwen2Config,
     Qwen2ForCausalLM,
     Qwen2PreTrainedModel,
-    TextIteratorStreamer
+    TextIteratorStreamer,
 )
-from transformers.modeling_utils import ContextManagers, no_init_weights
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import ContextManagers, no_init_weights
 
 from .base_projector import MultimodalProjector, MultimodalProjectorConfig
 from .builder import build_llm_and_tokenizer
 from .configuration_vila import VILAConfig
-from .…
-from .…
-from .utils import get_model_config
+from .constants import *
+from .conversation import SeparatorStyle, default_conversation
 from .media import extract_media
+from .media_encoder import BasicImageEncoder, BasicVideoEncoder
 from .mm_utils import process_image, process_images
+from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
 from .tokenizer_utils import tokenize_conversation
-from .…
-
+from .utils import get_model_config
+
 
 # from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
 # quick hack for remote code
 def get_pg_manager():
     return None
 
+
 def get_model_weights_dtype(model: nn.Module):
     pass
 
@@ -72,7 +73,77 @@ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> Pre
     mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
     mm_projector = MultimodalProjector(mm_projector_cfg, config)
     return mm_projector
-
+
+
+def check_dot_in_model_path(model_path: str):
+    """Check if the model path contains dot, which will affect the remote code loading."""
+    if osp.isdir(model_path):  # local model
+        if "." in osp.abspath(model_path):
+            return True
+    else:  # remote model
+        if "." in model_path:
+            return True
+    return False
+
+
+def get_vila_version(model_path: str) -> str:
+    VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
+    for version in VERSIONS:
+        if version in model_path.lower():
+            return version
+    return None
+
+
+def generate_jinja_template(conv_mode: str) -> str:
+    if conv_mode == "vicuna_v1":
+        return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." %}
+{% set roles = ["USER", "ASSISTANT"] %}
+{% set sep = " " %}
+{% set sep2 = "</s>" %}
+
+{{ system_prompt }}
+
+{% for message in messages %}
+    {% if message['role'] == roles[0] %}
+        {{ roles[0] }}{{ sep }}{{ message['content'] }}{{ sep2 }}
+    {% else %}
+        {{ roles[1] }}{{ sep }}{{ message['content'] }}{{ sep2 }}
+    {% endif %}
+{% endfor %}"""
+    elif conv_mode == "llama_3":
+        return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." %}
+{% set roles = ["<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"] %}
+{% set sep = "<|eot_id|>" %}
+{% set sep2 = "<|end_of_text|>" %}
+
+{{ system_prompt }}
+
+{% for message in messages %}
+    {% if message['role'] == 'user' %}
+        {{ roles[0] }}{{ message['content'] }}{{ sep }}
+    {% else %}
+        {{ roles[1] }}{{ message['content'] }}{{ sep }}
+    {% endif %}
+{% endfor %}
+
+{{ sep2 }}"""
+    elif conv_mode == "hermes_2":
+        return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
+{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
+{% set sep = "<|im_end|>" %}
+
+{{ system_prompt }}{{ sep }}
+
+{% for message in messages %}
+    {% if message['role'] == 'user' %}
+        {{ roles[0] }}{{ message['content'] }}{{ sep }}
+    {% else %}
+        {{ roles[1] }}{{ message['content'] }}{{ sep }}
+    {% endif %}
+{% endfor %}"""
+    else:
+        raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
+
 
 def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
     ## skip vision tower instantiation
@@ -110,7 +181,7 @@ class VILAPretrainedModel(PreTrainedModel):
     main_input_name = "input_embeds"
     supports_gradient_checkpointing = True
     _supports_flash_attn_2 = True
-
+
     def __init__(self, config: VILAConfig, *args, **kwargs):
         super().__init__(config)
         self.config = config
@@ -119,22 +190,19 @@
             llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
         else:
            raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
-
+
         # loading on cpu by default
         device_map = kwargs.get("device_map", "cpu")
         self.mm_projector = build_mm_projector(mm_projector_cfg, config)
         self.vision_tower = build_vision_tower(vision_tower_cfg, config)
         if "auto" in device_map or "cuda" in device_map:
             self.mm_projector = self.mm_projector.cuda()
-            self.vision_tower = self.vision_tower.cuda()
+            self.vision_tower = self.vision_tower.cuda()
         # set device_map auto can autoamtically shard llm to different devices
         self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
-
-        self.encoders = {
-            …
-            "video": BasicVideoEncoder(self)
-        }
-
+
+        self.encoders = {"image": BasicImageEncoder(self), "video": BasicVideoEncoder(self)}
+
         self.post_config()
         self.is_loaded = True
 
@@ -143,37 +211,65 @@
     ), "At least one of the components must be instantiated."
 
     @classmethod
-    def convert_vila_dev_ckpt_to_remote(…
+    def convert_vila_dev_ckpt_to_remote(
+        self,
+        model_path: str,
+        output_dir: str = None,
+        vila_version: str | None = None,
+        conv_mode: str | None = None,
+        *model_args,
+        **kwargs,
+    ):
         # assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
         from huggingface_hub import HfApi, snapshot_download
 
         if os.path.isdir(model_path):
             model_path = model_path
         api = HfApi()
+
+        if check_dot_in_model_path(model_path) and output_dir is None:
+            raise ValueError(
+                f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
+            )
+        if output_dir is not None and "." in output_dir:
+            raise ValueError(
+                f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
+            )
+        if vila_version is None:
+            vila_version = get_vila_version(model_path)
+
         if api.repo_exists(model_path):
             model_path = snapshot_download(model_path, local_dir=output_dir)
             print("downloading HF model to", model_path)
+
         cfg_path = os.path.join(model_path, "config.json")
         config = json.load(open(cfg_path))
-        config["version"] = "2.0"
+        config["version"] = "2.0"  # nvila tag
         config["architectures"] = ["VILAForCasualLM"]
         config["auto_map"] = {
             "AutoConfig": "modeling_vila.VILAConfig",
             "AutoModel": "modeling_vila.VILAForCasualLM",
-            "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM"
+            "AutoModelForCausalLM": "modeling_vila.VILAForCasualLM",
         }
         config["model_type"] = "vila"
+        if vila_version in ["vila1.5", "vila-m3"]:
+            if conv_mode is None:
+                raise ValueError(f"Please specify the conversation mode for {model_path}.")
+            config["chat_template"] = conv_mode
+            jinja_template = generate_jinja_template(conv_mode)
+            jinja_path = os.path.join(model_path, f"{conv_mode}.jinja")
+            with open(jinja_path, "w") as f:
+                f.write(jinja_template)
         json.dump(config, open(cfg_path, "w"), indent=2)
         self.copy_remote_py_files(model_path)
-
+
     @classmethod
     def copy_remote_py_files(cls, output_dir):
         ## copy .py and REAMDE for next loading remote code
         current_file_path = os.path.abspath(__file__)
         current_folder = os.path.dirname(current_file_path)
         for file_name in os.listdir(current_folder):
-            if file_name.endswith(".py"):
+            if file_name.endswith(".py") or file_name.endswith(".jinja"):
                 full_file_name = os.path.join(current_folder, file_name)
                 if os.path.isfile(full_file_name):
                     shutil.copy(full_file_name, output_dir)
@@ -222,17 +318,15 @@
             state_dict=mm_projector_state_dict,
         )
         self.config.mm_projector_cfg = self.mm_projector.config
-
+
         ## update and save top-level config
         self.config._name_or_path = output_dir
         self.config.architectures = [self.__class__.__name__]
         self.config.save_pretrained(output_dir)
-
+
         ## copy .py and REAMDE for next loading remote code
         self.copy_remote_py_files(output_dir)
 
-
-
     @classmethod
     def from_pretrained(
         cls,
@@ -258,7 +352,7 @@
         # variables for XGrammar
         # print("DEBUG", len(self.tokenizer.added_tokens_encoder.keys()), self.tokenizer.added_tokens_encoder.keys())
         NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
-
+
         # TODO: SENTINEL_TOKEN is not added, need to check with Zhijian
         self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
         # XGrammar tokenizer and grammar compiler
@@ -318,11 +412,12 @@
         self.get_vision_tower().eval()
         if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
             self.get_mm_projector().eval()
-
+
+
 class VILAForCasualLM(VILAPretrainedModel):
     def __init__(self, config: VILAConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
-
+
     def merge_features_for_dynamic_s2(self, image_features, block_sizes):
         scales = self.get_vision_tower().scales
         resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
@@ -395,7 +490,7 @@
         if getattr(self.config, "dynamic_s2", False):
             image_features = self.get_vision_tower()(images)
             image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
-
+
             image_features = [
                 self.split_chessboard(x, block_size[0], block_size[1])
                 for x, block_size in zip(image_features, new_block_sizes)
@@ -881,6 +976,7 @@
             return outputs.logits, labels
 
         return outputs
+
     @torch.inference_mode()
     def generate(
         self,
@@ -898,7 +994,7 @@
         self,
         prompt: Union[str, List],
         generation_config: Optional[GenerationConfig] = None,
-        response_format…
+        response_format=None,
     ) -> str:
         # TODO(zhijianl): Support directly taking conversation as input
         conversation = [{"from": "human", "value": prompt}]
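Most of the new code in this file supports converting a dev checkpoint into a self-contained remote-code checkpoint (dot check on the path, VILA version detection, and a generated Jinja chat template for older conversation formats). A rough usage sketch based on the signature in the hunk above; the import style, checkpoint directory, and argument values are placeholders, not part of the commit:

```python
from transformers import AutoConfig, AutoModel

from modeling_vila import VILAForCasualLM  # assumes these files are importable as a module/package

# Convert a local dev checkpoint in place (path is hypothetical and contains no dots,
# which check_dot_in_model_path would otherwise reject for remote-code loading).
VILAForCasualLM.convert_vila_dev_ckpt_to_remote(
    "checkpoints/vila1_5-dev",
    vila_version="vila1.5",   # for vila1.5 / vila-m3 a conv_mode is required ...
    conv_mode="vicuna_v1",    # ... so the matching vicuna_v1.jinja template gets written
)

# The converted folder now carries config auto_map entries plus the .py/.jinja files,
# so it loads purely through trust_remote_code.
config = AutoConfig.from_pretrained("checkpoints/vila1_5-dev", trust_remote_code=True)
model = AutoModel.from_pretrained("checkpoints/vila1_5-dev", trust_remote_code=True)
```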
siglip_encoder.py
CHANGED

@@ -20,11 +20,11 @@ import torch.nn.functional as F
 from accelerate.hooks import add_hook_to_module
 from einops import rearrange
 from s2wrapper import forward as multiscale_forward
-from transformers import AutoConfig, PreTrainedModel
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.models.siglip import SiglipVisionModel
-
+
 
 class VisionTower(nn.Module):
     def __init__(self, vision_tower, args, delay_load=False):
@@ -146,7 +146,6 @@
 
         return image_features
 
-
     @property
     def dummy_feature(self):
         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
tokenizer_utils.py
CHANGED

@@ -19,9 +19,9 @@ from typing import Any, Dict, List, Optional, Sequence
 import torch
 import transformers
 
-from .conversation import default_conversation, SeparatorStyle
-from .mm_utils import tokenizer_image_token
 from .constants import IGNORE_INDEX, SENTINEL_TOKEN
+from .conversation import SeparatorStyle, default_conversation
+from .mm_utils import tokenizer_image_token
 
 # __all__ = [
 #     "tokenize_conversation",
|