import torch
import torch.nn.functional as F

from ._ops import add_op_namespace_prefix


@torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=())
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Compute ``silu(x1) * x2`` where ``x1``/``x2`` are the halves of ``x``.

    The last dimension of ``x`` is split at ``x.shape[-1] // 2``: the first
    half is passed through SiLU and multiplied elementwise by the second half.

    Args:
        x: Input tensor of any rank >= 1.
            NOTE(review): the last dimension is presumably expected to be
            even; with an odd size the two halves differ in length.

    Returns:
        A new tensor with the same leading dimensions as ``x`` and a last
        dimension of ``x.shape[-1] // 2``.
    """
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def backward(ctx, grad_output):
    """Return the gradient of ``silu(x1) * x2`` with respect to ``x``.

    Args:
        ctx: Autograd context holding the input saved by ``setup_context``.
        grad_output: Gradient flowing in from the op's output.

    Returns:
        The gradient for ``x``: the gradients of the two halves concatenated
        along the last dimension, in the same order as the forward split.
    """
    x = ctx.saved_tensors[0]
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]

    sigmoid_x1 = torch.sigmoid(x1)
    silu_x1 = F.silu(x1)
    # d/dx silu(x) = sigmoid(x) + silu(x) * (1 - sigmoid(x))
    dsilu_dx1 = sigmoid_x1 + silu_x1 * (1 - sigmoid_x1)

    dx1 = grad_output * x2 * dsilu_dx1
    dx2 = grad_output * silu_x1
    return torch.cat([dx1, dx2], dim=-1)


def setup_context(ctx, inputs, output):
    """Save the op's input tensor on ``ctx`` for use in ``backward``."""
    (x,) = inputs
    ctx.save_for_backward(x)


_silu_and_mul.register_autograd(backward, setup_context=setup_context)


@_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    """Fake (shape-only) kernel: describe the output without computing it.

    Mirrors the real kernel's shape rule for inputs of any rank: all leading
    dimensions are kept and only the last dimension is halved.
    """
    return x.new_empty((*x.shape[:-1], x.shape[-1] // 2))