| import copy | |
| from typing import Any, Dict | |
| from transformers import CLIPVisionConfig, PretrainedConfig | |
| from .configuration_llama import CustomLlamaConfig | |
class POINTSChatConfig(PretrainedConfig):
    """Configuration class for `POINTS`.

    A composite configuration that bundles two sub-configurations:
    ``vision_config`` (built as a ``CLIPVisionConfig``) and ``llm_config``
    (built as a ``CustomLlamaConfig``).

    Args:
        **kwargs: ``vision_config`` and ``llm_config`` are consumed by this
            class. Each may be a ``dict`` (expanded into the matching config
            class), an already-constructed config object (stored as-is), or
            omitted (stored as ``None``). All remaining keyword arguments
            are forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "points_chat"
    is_composition = True

    def __init__(self, **kwargs) -> None:
        # Extract the sub-configs *before* delegating, so the base class only
        # receives the keyword arguments that are actually meant for it.
        vision_config = kwargs.pop("vision_config", None)
        llm_config = kwargs.pop("llm_config", None)
        super().__init__(**kwargs)

        # Plain dicts (e.g. loaded from JSON) are promoted to config objects;
        # anything else, including None, is stored unchanged.
        if isinstance(vision_config, dict):
            vision_config = CLIPVisionConfig(**vision_config)
        if isinstance(llm_config, dict):
            llm_config = CustomLlamaConfig(**llm_config)
        self.vision_config = vision_config
        self.llm_config = llm_config

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this configuration to a plain dictionary.

        Returns:
            A deep copy of the instance attributes in which ``vision_config``
            and ``llm_config`` are replaced by the result of their own
            ``to_dict()`` (or ``None`` when the sub-config is unset), plus a
            ``model_type`` entry.
        """
        output = copy.deepcopy(self.__dict__)
        for key in ("vision_config", "llm_config"):
            sub_config = getattr(self, key, None)
            # A sub-config omitted at construction time is None; emit None
            # instead of failing with AttributeError.
            output[key] = None if sub_config is None else sub_config.to_dict()
        # model_type is a class attribute, so it is absent from __dict__
        # unless it was also set on the instance; include it explicitly.
        output["model_type"] = self.__class__.model_type
        return output