import math

import torch
from torch.autograd import Function
from torch.nn import functional as F


class PyTorchSkaFn(Function):
    # forward must be a @staticmethod on torch.autograd.Function subclasses;
    # calling .apply() fails otherwise.
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Recover the kernel size and padding from the weight tensor shape.
        # w shape: (n, wc, ks*ks, h, w)
        ks = int(math.sqrt(w.shape[2]))
        pad = (ks - 1) // 2
        n, ic, h, width = x.shape
        wc = w.shape[1]  # number of weight channels
        # 1. Extract sliding ks x ks patches from the input. Each column of
        # the result holds the flattened data for one patch. Note that
        # F.unfold copies the data; it does not return a view.
        # Shape: (n, ic * ks * ks, h * w)
        x_unfolded = F.unfold(x, kernel_size=ks, padding=pad)
        # 2. Reshape the unfolded input for element-wise multiplication.
        # Shape: (n, ic, ks * ks, h * w)
        x_unfolded = x_unfolded.view(n, ic, ks * ks, h * width)
        # 3. Prepare the weights for multiplication.
        # w original shape: (n, wc, ks*ks, h, w)
        # w reshaped:       (n, wc, ks*ks, h*w)
        w = w.view(n, wc, ks * ks, h * width)
        # If ic != wc, the weights are grouped: input channel ci uses weight
        # channel ci % wc, as in the Triton kernel's indexing. Tiling the
        # weight channels reproduces that mapping, assuming ic is an integer
        # multiple of wc.
        if ic != wc:
            repeats = ic // wc
            w = w.repeat(1, repeats, 1, 1)
        # 4. Core operation: element-wise multiply, then sum over the kernel
        # dimension. This is the equivalent of the Triton kernel's main loop.
        # (x_unfolded * w) -> shape: (n, ic, ks*ks, h*w)
        # .sum(dim=2) reduces across the kernel dimension (ks*ks).
        # Output shape: (n, ic, h*w)
        output = (x_unfolded * w).sum(dim=2)
        # 5. Reshape the output back to image format.
        # Shape: (n, ic, h, w)
        output = output.view(n, ic, h, width)
        return output
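

# For reference, a naive per-tap loop that computes the same result as the
# unfold path above. This is an illustrative sketch under the assumptions
# already made (odd ks, ic an integer multiple of wc); `ska_reference` is
# not part of the original API.
def ska_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    n, ic, h, width = x.shape
    wc = w.shape[1]
    ks = int(math.sqrt(w.shape[2]))
    pad = (ks - 1) // 2
    xp = F.pad(x, (pad, pad, pad, pad))
    out = torch.zeros_like(x)
    for i in range(ks):
        for j in range(ks):
            wk = w[:, :, i * ks + j]           # (n, wc, h, w): weights for tap (i, j)
            wk = wk.repeat(1, ic // wc, 1, 1)  # channel ci uses weight channel ci % wc
            out = out + xp[:, :, i:i + h, j:j + width] * wk
    return out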


class SKA(torch.nn.Module):
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        return PyTorchSkaFn.apply(x, w)  # type: ignore
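

# A minimal usage sketch; the shapes below are illustrative assumptions, not
# values from the original source.
if __name__ == "__main__":
    n, ic, h, width = 2, 8, 16, 16  # batch, input channels, height, width
    wc, ks = 4, 3                   # weight channels (ic % wc == 0), kernel size
    x = torch.randn(n, ic, h, width)
    w = torch.randn(n, wc, ks * ks, h, width)
    out = SKA()(x, w)
    print(out.shape)  # torch.Size([2, 8, 16, 16])
    # The unfold-based path and the naive loop should agree.
    torch.testing.assert_close(out, ska_reference(x, w))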