Update modeling_ovis.py
Browse files- modeling_ovis.py +4 -4
    	
        modeling_ovis.py
    CHANGED
    
    | @@ -288,10 +288,10 @@ class Ovis(OvisPreTrainedModel): | |
| 288 | 
             
                    super().__init__(config, *inputs, **kwargs)
         | 
| 289 | 
             
                    attn_kwargs = dict()
         | 
| 290 | 
             
                    if self.config.llm_attn_implementation:
         | 
| 291 | 
            -
                        if self.config.llm_attn_implementation == "flash_attention_2":
         | 
| 292 | 
            -
             | 
| 293 | 
            -
             | 
| 294 | 
            -
             | 
| 295 | 
             
                        attn_kwargs["attn_implementation"] = self.config.llm_attn_implementation
         | 
| 296 | 
             
                    self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
         | 
| 297 | 
             
                    assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
         | 
|  | |
| 288 | 
             
                    super().__init__(config, *inputs, **kwargs)
         | 
| 289 | 
             
                    attn_kwargs = dict()
         | 
| 290 | 
             
                    if self.config.llm_attn_implementation:
         | 
| 291 | 
            +
                        # if self.config.llm_attn_implementation == "flash_attention_2":
         | 
| 292 | 
            +
                        #     assert (is_flash_attn_2_available() and
         | 
| 293 | 
            +
                        #             version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.6.3")), \
         | 
| 294 | 
            +
                        #         "Using `flash_attention_2` requires having `flash_attn>=2.6.3` installed."
         | 
| 295 | 
             
                        attn_kwargs["attn_implementation"] = self.config.llm_attn_implementation
         | 
| 296 | 
             
                    self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
         | 
| 297 | 
             
                    assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
         | 
