Update build-output/torch-universal/triton_llama_mlp/mlp.py
Browse files
build-output/torch-universal/triton_llama_mlp/mlp.py
CHANGED
|
@@ -152,7 +152,7 @@ class TritonLlamaMLP(nn.Module):
|
|
| 152 |
down_output += self.down_proj.bias
|
| 153 |
|
| 154 |
# Reshape back to original dimensions: (batch_size*seq_len, hidden_size) -> (*, hidden_size)
|
| 155 |
-
return down_output.reshape(original_shape)
|
| 156 |
|
| 157 |
|
| 158 |
|
|
|
|
| 152 |
down_output += self.down_proj.bias
|
| 153 |
|
| 154 |
# Reshape back to original dimensions: (batch_size*seq_len, hidden_size) -> (*, hidden_size)
|
| 155 |
+
return down_output.reshape(original_shape).to(dtype)
|
| 156 |
|
| 157 |
|
| 158 |
|