Update model code
InternVL2-40B-Pretrain/configuration_internvl_chat.py
CHANGED
@@ -61,6 +61,8 @@ class InternVLChatConfig(PretrainedConfig):
         self.ps_version = ps_version # pixel shuffle version
         self.min_dynamic_patch = min_dynamic_patch
         self.max_dynamic_patch = max_dynamic_patch
+        # By default, we use tie_word_embeddings=False for models of all sizes.
+        self.tie_word_embeddings = self.llm_config.tie_word_embeddings

         logger.info(f'vision_select_layer: {self.select_layer}')
         logger.info(f'ps_version: {self.ps_version}')
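Mirroring tie_word_embeddings from the nested llm_config onto the top-level config means callers no longer have to reach into the sub-config to learn whether input and output embeddings share weights. A quick check, with an illustrative checkpoint path:

import torch
from transformers import AutoConfig

# The repo path is illustrative; any checkpoint shipping this config code works.
config = AutoConfig.from_pretrained('OpenGVLab/InternVL2-40B', trust_remote_code=True)

# The flag is now mirrored at the top level; per the comment in the diff it stays False.
assert config.tie_word_embeddings == config.llm_config.tie_word_embeddings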
InternVL2-40B-Pretrain/modeling_intern_vit.py
CHANGED
@@ -364,6 +364,7 @@ class InternVisionEncoder(nn.Module):
 class InternVisionModel(PreTrainedModel):
     main_input_name = 'pixel_values'
     _supports_flash_attn_2 = True
+    supports_gradient_checkpointing = True
     config_class = InternVisionConfig
     _no_split_modules = ['InternVisionEncoderLayer']

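supports_gradient_checkpointing = True tells transformers that the vision tower can recompute activations during the backward pass; without the flag, gradient_checkpointing_enable() refuses with a ValueError. A minimal sketch, again with an illustrative checkpoint path:

import torch
from transformers import AutoModel

# The repo path is illustrative.
model = AutoModel.from_pretrained('OpenGVLab/InternVL2-40B',
                                  torch_dtype=torch.bfloat16,
                                  trust_remote_code=True)

# Legal now that InternVisionModel declares supports_gradient_checkpointing;
# activations are recomputed in backward, trading compute for memory.
model.vision_model.gradient_checkpointing_enable()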
InternVL2-40B-Pretrain/modeling_internvl_chat.py
CHANGED
@@ -36,12 +36,13 @@ class InternVLChatModel(PreTrainedModel):
     main_input_name = 'pixel_values'
     base_model_prefix = 'language_model'
     _supports_flash_attn_2 = True
+    supports_gradient_checkpointing = True
     _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer']

     def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
         super().__init__(config)

-        assert version_cmp(transformers.__version__, '4.
+        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
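The bumped assert goes through the file's version_cmp helper; its definition is not part of this diff, so the sketch below is an assumption about the usual packaging-based implementation it matches:

import operator

import transformers
from packaging import version

def version_cmp(v1, v2, op='eq'):
    # Compare two version strings with the named operator ('ge', 'le', ...).
    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))

# Mirrors the guard in __init__: require transformers >= 4.37.0.
assert version_cmp(transformers.__version__, '4.37.0', 'ge')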
@@ -108,7 +109,7 @@ class InternVLChatModel(PreTrainedModel):
         B, N, C = input_embeds.shape
         input_embeds = input_embeds.reshape(B * N, C)

-        if torch.distributed.get_rank() == 0:
+        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
             print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

         input_ids = input_ids.reshape(B * N)
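torch.distributed.get_rank() raises a RuntimeError when no default process group exists, so the unguarded call crashed plain single-process inference; the added is_initialized() check keeps the debug print harmless in both settings. The same guard as a reusable helper:

import torch.distributed as dist

def is_main_process():
    # True for single-process runs and for rank 0 of a distributed run;
    # dist.get_rank() would raise if no process group were initialized.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

if is_main_process():
    print('logged once, with or without torch.distributed')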
@@ -343,3 +344,13 @@ class InternVLChatModel(PreTrainedModel):
         )

         return outputs
+
+    @property
+    def lm_head(self):
+        return self.language_model.get_output_embeddings()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()
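lm_head, get_input_embeddings, and get_output_embeddings now delegate to the wrapped language model, so generic transformers machinery (weight tying, resize_token_embeddings, generation utilities) finds the right modules on the composite model. Continuing from the model loaded in the earlier sketch:

# Both accessors resolve through self.language_model.
wte = model.get_input_embeddings()    # token embedding matrix
head = model.get_output_embeddings()  # output projection, also model.lm_head

# With tie_word_embeddings=False (see the config change above),
# the two matrices are separate parameters.
assert head.weight.data_ptr() != wte.weight.data_ptr()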