IndexError: map::at when running inference with medium-length prompts
I repeatedly run into an error when I use medium-length prompts. How do I fix it?
The error is as below:
"
IndexError Traceback (most recent call last)
/tmp/ipykernel_1453/3969631446.py in <cell line: 0>()
26 model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0") # Input goes to first GPU
27
---> 28 generated_ids = model.generate(
29 **model_inputs,
30 max_new_tokens=1024,
/usr/local/lib/python3.11/dist-packages/fla/models/rwkv7/modeling_rwkv7.py in generate(self, *args, **kwargs)
477 def generate(self, *args, **kwargs):
478 try:
--> 479 return super().generate(*args, **kwargs)
480 except AttributeError as exception:
481 if 'past_key_values' in str(exception):
/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
117
118 return decorate_context
/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, custom_generate, **kwargs)
2595
2596 # 12. run sample (it degenerates to greedy search when generation_config.do_sample=False)
-> 2597 result = self._sample(
2598 input_ids,
2599 logits_processor=prepared_logits_processor,
/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py in _sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
3555
3556 if is_prefill:
-> 3557 outputs = self(**model_inputs, return_dict=True)
3558 is_prefill = False
3559 else:
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177
/usr/local/lib/python3.11/dist-packages/transformers/utils/deprecation.py in wrapped_func(*args, **kwargs)
170 warnings.warn(message, FutureWarning, stacklevel=2)
171
--> 172 return func(*args, **kwargs)
173
174 return wrapped_func
/usr/local/lib/python3.11/dist-packages/fla/models/rwkv7/modeling_rwkv7.py in forward(self, input_ids, attention_mask, inputs_embeds, past_key_values, labels, shift_labels, use_cache, output_attentions, output_hidden_states, return_dict, logits_to_keep, **kwargs)
546 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
547
--> 548 outputs = self.model(
549 input_ids=input_ids,
550 attention_mask=attention_mask,
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
/usr/local/lib/python3.11/dist-packages/fla/models/rwkv7/modeling_rwkv7.py in forward(self, input_ids, attention_mask, inputs_embeds, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict, cu_seqlens, **kwargs)
413 )
414 else:
--> 415 hidden_states, attentions, past_key_values, v_first = layer(
416 hidden_states,
417 attention_mask=attention_mask,
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177
/usr/local/lib/python3.11/dist-packages/fla/models/rwkv7/modeling_rwkv7.py in forward(self, hidden_states, attention_mask, past_key_values, use_cache, output_attentions, v_first, cu_seqlens, **kwargs)
188 residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
189 hidden_states = self.attn_norm(residual)
--> 190 hidden_states, attentions, past_key_values, v_first = self.attn(
191 hidden_states=hidden_states,
192 attention_mask=attention_mask,
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
1740
1741 # torchrec tests the code consistency with the following code
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1751
1752 result = None
/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
173 output = module._old_forward(*args, **kwargs)
174 else:
--> 175 output = module._old_forward(*args, **kwargs)
176 return module._hf_hook.post_forward(module, output)
177
/usr/local/lib/python3.11/dist-packages/fla/layers/rwkv7.py in forward(self, hidden_states, attention_mask, past_key_values, use_cache, output_attentions, v_first, cu_seqlens, **kwargs)
290 # if training, use chunk mode no matter how short the sequence is
291 # launching the triton kernel for just one token will actually be slower
--> 292 o, recurrent_state = chunk_rwkv7(
293 r=r,
294 log_w=w,
/usr/local/lib/python3.11/dist-packages/fla/ops/rwkv7/chunk.py in chunk_rwkv7(r, k, v, a, b, w, log_w, scale, initial_state, output_final_state, cu_seqlens, head_first)
75 assert log_w is not None, "Either w or log_w must be provided!"
76
---> 77 return chunk_dplr_delta_rule(
78 q=r,
79 k=k,
/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py in _fn(*args, **kwargs)
743 )
744 try:
--> 745 return fn(*args, **kwargs)
746 finally:
747 _maybe_set_eval_frame(prior)
/usr/local/lib/python3.11/dist-packages/fla/ops/generalized_delta_rule/dplr/chunk.py in chunk_dplr_delta_rule(q, k, v, a, b, gk, scale, initial_state, output_final_state, cu_seqlens, head_first)
348 )
349 scale = k.shape[-1] ** -0.5 if scale is None else scale
--> 350 o, final_state = ChunkDPLRDeltaRuleFunction.apply(
351 q,
352 k,
/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py in apply(cls, *args, **kwargs)
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
576
577 if not is_setup_ctx_defined:
/usr/local/lib/python3.11/dist-packages/fla/utils.py in wrapper(*args, **kwargs)
158
159 with ctx:
--> 160 return fn(*contiguous_args, **contiguous_kwargs)
161
162 return wrapper
/usr/local/lib/python3.11/dist-packages/torch/amp/autocast_mode.py in decorate_fwd(*args, **kwargs)
501 if cast_inputs is None:
502 args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
--> 503 return fwd(*args, **kwargs)
504 else:
505 autocast_context = torch.is_autocast_enabled(device_type)
/usr/local/lib/python3.11/dist-packages/fla/ops/generalized_delta_rule/dplr/chunk.py in forward(ctx, q, k, v, a, b, gk, scale, initial_state, output_final_state, cu_seqlens)
109 ):
110 chunk_size = 16
--> 111 o, final_state = chunk_dplr_fwd(
112 q=q,
113 k=k,
/usr/local/lib/python3.11/dist-packages/fla/ops/generalized_delta_rule/dplr/chunk.py in chunk_dplr_fwd(q, k, v, a, b, gk, scale, initial_state, output_final_state, cu_seqlens, chunk_size)
35 T = q.shape[1]
36 BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
---> 37 gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens)
38
39 A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
/usr/local/lib/python3.11/dist-packages/fla/ops/rwkv6/chunk.py in chunk_rwkv6_fwd_cumsum(g, chunk_size, cu_seqlens)
83 def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
84 # keep cummulative normalizer in fp32
---> 85 chunk_rwkv6_fwd_cumsum_kernel[grid](
86 g,
87 gi,
/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py in <lambda>(*args, **kwargs)
328 memorizes the grid.
329 """
--> 330 return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
331 # return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
332
/usr/local/lib/python3.11/dist-packages/triton/runtime/autotuner.py in run(self, *args, **kwargs)
383 for v, heur in self.values.items():
384 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 385 return self.fn.run(*args, **kwargs)
386
387
/usr/local/lib/python3.11/dist-packages/triton/runtime/autotuner.py in run(self, *args, **kwargs)
184 pruned_configs = self.prune_configs(kwargs)
185 bench_start = time.time()
--> 186 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
187 bench_end = time.time()
188 self.bench_time = bench_end - bench_start
/usr/local/lib/python3.11/dist-packages/triton/runtime/autotuner.py in <dictcomp>(.0)
184 pruned_configs = self.prune_configs(kwargs)
185 bench_start = time.time()
--> 186 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
187 bench_end = time.time()
188 self.bench_time = bench_end - bench_start
/usr/local/lib/python3.11/dist-packages/triton/runtime/autotuner.py in _bench(self, config, *args, **meta)
164
165 try:
--> 166 return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
167 except (OutOfResources, CompileTimeAssertionFailure):
168 return [float("inf"), float("inf"), float("inf")]
/usr/local/lib/python3.11/dist-packages/triton/testing.py in do_bench(fn, warmup, rep, grad_to_none, quantiles, return_mode)
115 di = runtime.driver.active.get_device_interface()
116
--> 117 fn()
118 di.synchronize()
119
/usr/local/lib/python3.11/dist-packages/triton/runtime/autotuner.py in kernel_call()
150 self.pre_hook(full_nargs)
151 try:
--> 152 self.fn.run(
153 *args,
154 **current,
/usr/local/lib/python3.11/dist-packages/triton/runtime/jit.py in run(self, grid, warmup, *args, **kwargs)
621 # compile the kernel
622 src = self.ASTSource(self, signature, constants, configs[0])
--> 623 kernel = self.compile(
624 src,
625 target=target,
/usr/local/lib/python3.11/dist-packages/triton/compiler/compiler.py in compile(src, target, options)
277 use_ir_loc = os.environ.get("USE_IR_LOC", None)
278 for ext, compile_ir in list(stages.items())[first_stage:]:
--> 279 next_module = compile_ir(module, metadata)
280 ir_filename = f"{file_name}.{ext}"
281 if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None):
/usr/local/lib/python3.11/dist-packages/triton/backends/nvidia/compiler.py in <lambda>(src, metadata)
385 stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
386 stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
--> 387 stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
388 stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
389 stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
/usr/local/lib/python3.11/dist-packages/triton/backends/nvidia/compiler.py in make_llir(src, metadata, options, capability)
284 if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
285 passes.llvmir.add_di_scope(pm)
--> 286 pm.run(mod)
287 # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
288 llvm.init_targets()
IndexError: map::at"
And here is the code I am using:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
    'fla-hub/rwkv7-7.2B-g0',
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained('fla-hub/rwkv7-7.2B-g0', trust_remote_code=True)

prompt = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True
)
model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")  # input goes to the first GPU

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=1024,
    do_sample=True,
    temperature=1.0,
    top_p=0.3,
    repetition_penalty=1.2
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
print(response)
I believe there is something wrong with your card and Triton. You could configure the model to always use the fused_recurrent kernel :)
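For reference, a minimal sketch of forcing the recurrent path, assuming the RWKV7 config exposes an attn_mode field (defaulting to "chunk") and that each decoder layer keeps the selected kernel in layer.attn.mode; both names come from the fla RWKV7 implementation and should be checked against your installed version:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_id = 'fla-hub/rwkv7-7.2B-g0'

# Switch the attention kernel from the chunked Triton path (the one whose
# compilation fails in the traceback above) to the recurrent one before
# the model is built.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config.attn_mode = 'fused_recurrent'  # assumed config field; defaults to 'chunk'

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Belt and braces: also override the per-layer attribute in case the layers
# were built with 'chunk' (assumes the model.model.layers[i].attn layout
# visible in the traceback).
for layer in model.model.layers:
    if hasattr(layer.attn, 'mode'):
        layer.attn.mode = 'fused_recurrent'

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

The rest of the generation code stays the same. The fused_recurrent kernel processes the prompt token by token instead of in chunks, so prefill is slower, but it avoids the chunked kernel whose Triton compilation raises the IndexError.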