DraconicDragon's picture
Upload 3 files
6b5de5c verified
import math
import torch
from torch.autograd import Function
from torch.nn import functional as F
class PyTorchSkaFn(Function):
    """Pure-PyTorch per-pixel, per-channel kernel aggregation.

    For every batch item ``b``, input channel ``c`` and location ``(i, j)``::

        out[b, c, i, j] = sum_k patch_k(x)[b, c, i, j] * w[b, c % wc, k, i, j]

    where ``patch_k`` is the k-th element (row-major) of the ``ks x ks``
    neighbourhood around ``(i, j)`` (zero padded by ``(ks - 1) // 2``) and
    ``wc = w.shape[1]``.

    Expected shapes:
        x: ``(n, ic, h, w)``
        w: ``(n, wc, ks * ks, h, w)`` with ``ks`` odd.
    Output: ``(n, ic, h, w)``.
    """

    @staticmethod
    def _compute(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Differentiable computation shared by ``forward`` and ``backward``.

        Raises:
            ValueError: if ``w.shape[2]`` is not the square of an odd integer.
        """
        n, ic, h, width = x.shape
        wc, kk = w.shape[1], w.shape[2]
        ks = math.isqrt(kk)
        # unfold with padding (ks - 1) // 2 only keeps the h*w output grid
        # for odd kernel sizes; reject anything else with a clear message.
        if ks * ks != kk or ks % 2 == 0:
            raise ValueError(
                f"w.shape[2] must be the square of an odd kernel size, got {kk}"
            )
        pad = (ks - 1) // 2

        # (n, ic * kk, h * w) -> (n, ic, kk, h * w): one column per pixel
        # holding that pixel's ks x ks neighbourhood for every channel.
        patches = F.unfold(x, kernel_size=ks, padding=pad).reshape(
            n, ic, kk, h * width
        )

        # reshape (not view) so non-contiguous weight tensors are accepted.
        weights = w.reshape(n, wc, kk, h * width)
        if ic != wc:
            # Input channel ci uses weight channel ci % wc. Indexing (rather
            # than repeat) also works when ic is not a multiple of wc.
            channel_map = torch.arange(ic, device=w.device) % wc
            weights = weights.index_select(1, channel_map)

        # Multiply element-wise and reduce over the kernel axis.
        out = (patches * weights).sum(dim=2)
        return out.reshape(n, ic, h, width)

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Compute the aggregation and keep the inputs for ``backward``."""
        ctx.save_for_backward(x, w)
        return PyTorchSkaFn._compute(x, w)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients w.r.t. ``x`` and ``w`` by recomputing the forward.

        Autograd is disabled inside ``Function.forward``, so the forward is
        re-run under ``enable_grad`` on detached copies. The gradient is then
        taken only for the inputs that need one.
        """
        x, w = ctx.saved_tensors
        need_x, need_w = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        with torch.enable_grad():
            x_d = x.detach().requires_grad_(need_x)
            w_d = w.detach().requires_grad_(need_w)
            out = PyTorchSkaFn._compute(x_d, w_d)
            wanted = [t for t, need in ((x_d, need_x), (w_d, need_w)) if need]
            grads = iter(torch.autograd.grad(out, wanted, grad_output))
        grad_x = next(grads) if need_x else None
        grad_w = next(grads) if need_w else None
        return grad_x, grad_w
class SKA(torch.nn.Module):
    """``nn.Module`` front-end for :class:`PyTorchSkaFn`.

    The module holds no parameters; the kernels ``w`` are passed in by the
    caller on every call.
    """

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Apply ``PyTorchSkaFn`` to ``x`` with kernels ``w`` and return the result."""
        ska_fn = PyTorchSkaFn.apply
        result = ska_fn(x, w)  # type: ignore
        return result