medmekk HF Staff commited on
Commit
92390cf
·
verified ·
1 Parent(s): 55896b9

Update build-output/torch-universal/triton_llama_mlp/mlp.py

Browse files
build-output/torch-universal/triton_llama_mlp/mlp.py CHANGED
@@ -152,7 +152,7 @@ class TritonLlamaMLP(nn.Module):
152
  down_output += self.down_proj.bias
153
 
154
  # Reshape back to original dimensions: (batch_size*seq_len, hidden_size) -> (*, hidden_size)
155
- return down_output.reshape(original_shape)
156
 
157
 
158
 
 
152
  down_output += self.down_proj.bias
153
 
154
  # Reshape back to original dimensions: (batch_size*seq_len, hidden_size) -> (*, hidden_size)
155
+ return down_output.reshape(original_shape).to(dtype)
156
 
157
 
158