manueldeprada HF Staff commited on
Commit
1e37df0
·
verified ·
1 Parent(s): 45cf90c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +16 -1
custom_generate/generate.py CHANGED
@@ -612,7 +612,22 @@ def generate(model, *args, **kwargs):
612
  penalty_alpha (`float`): The alpha value for the degeneration penalty.
613
  top_k (`int`): The number of candidates to consider at each step.
614
  """
 
 
 
 
 
 
 
 
 
 
 
615
  generation_outputs = GenerationMixin.generate(
616
- model, *args, custom_generate=_contrastive_search, **kwargs
 
 
 
 
617
  )
618
  return generation_outputs
 
612
  penalty_alpha (`float`): The alpha value for the degeneration penalty.
613
  top_k (`int`): The number of candidates to consider at each step.
614
  """
615
+ cache_implementation = kwargs.pop("cache_implementation", "dynamic_full")
616
+ if cache_implementation != "dynamic_full" and (
617
+ "sliding_attention"
618
+ in getattr(model.config.get_text_config(), "layer_types", [])
619
+ or getattr(model.config.get_text_config(), "sliding_window", 0) > 0
620
+ ):
621
+ logger.warning_once(
622
+ "Contrastive search with sliding window attention requires `cache_implementation='dynamic_full'`. "
623
+ "Using other cache types may break rollback and cause incorrect results."
624
+ )
625
+
626
  generation_outputs = GenerationMixin.generate(
627
+ model,
628
+ *args,
629
+ custom_generate=_contrastive_search,
630
+ cache_implementation=cache_implementation,
631
+ **kwargs,
632
  )
633
  return generation_outputs