danieldk's picture
danieldk HF Staff
Add Python invalid dependency test kernel
bb1e912
import torch
import torch.nn.functional as F
from ._ops import add_op_namespace_prefix
@torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=())
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Compute ``silu(x1) * x2`` where ``x1``/``x2`` are the two halves of ``x``.

    ``x`` is split along its last dimension into a first half ``x1`` (passed
    through SiLU) and a second half ``x2``. Any number of leading dimensions
    is supported; the result has shape ``x.shape[:-1] + (x.shape[-1] // 2,)``.

    Raises:
        ValueError: if the last dimension of ``x`` is odd. Without this check
            the two halves differ in size and broadcasting can silently
            produce a wrongly-shaped result (e.g. a last dimension of 3 would
            yield 2 output columns instead of an error).
    """
    last = x.shape[-1]
    if last % 2 != 0:
        raise ValueError(
            f"silu_and_mul expects an even last dimension, got {last}"
        )
    d = last // 2
    return F.silu(x[..., :d]) * x[..., d:]
def backward(ctx, grad_output):
    """Gradient of ``silu(lhs) * rhs`` w.r.t. the concatenated input ``[lhs, rhs]``.

    Args:
        ctx: autograd context; ``ctx.saved_tensors[0]`` is the forward input.
        grad_output: upstream gradient, shaped like the forward output.

    Returns:
        A tensor shaped like the saved input: the gradient for the first half
        followed by the gradient for the second half along the last dimension.
    """
    inp = ctx.saved_tensors[0]
    half = inp.shape[-1] // 2
    lhs, rhs = inp[..., :half], inp[..., half:]

    sig = torch.sigmoid(lhs)
    act = F.silu(lhs)

    # d/dz [z * sigmoid(z)] = sigmoid(z) + silu(z) * (1 - sigmoid(z))
    grad_lhs = grad_output * rhs * (sig + act * (1 - sig))
    grad_rhs = grad_output * act
    return torch.cat((grad_lhs, grad_rhs), dim=-1)
def setup_context(ctx, inputs, output):
    """Stash the single forward input on ``ctx`` for use in ``backward``.

    ``inputs`` must contain exactly one tensor (unpacking raises ``ValueError``
    otherwise); ``output`` is not needed by the backward pass and is ignored.
    """
    [x] = inputs
    ctx.save_for_backward(x)
# Attach autograd support to the custom op: `setup_context` saves the forward
# input on the context, and `backward` turns grad_output into the gradient
# w.r.t. that input.
_silu_and_mul.register_autograd(backward, setup_context=setup_context)
@_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    """Fake (shape/dtype-only) implementation of ``silu_and_mul``.

    Must agree with the eager implementation, which slices the last dimension
    with ``...`` and therefore accepts any number of leading dimensions: the
    output keeps ``x.shape[:-1]`` and halves the last dimension. Hard-coding
    2-D (``x.shape[0], x.shape[1] // 2``) would raise on 1-D inputs and report
    the wrong shape for inputs with more than two dimensions.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    last = x.shape[-1]
    if last % 2 != 0:
        raise ValueError(
            f"silu_and_mul expects an even last dimension, got {last}"
        )
    return x.new_empty((*x.shape[:-1], last // 2))