File size: 2,351 Bytes
6b5de5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import math

import torch
from torch.autograd import Function
from torch.nn import functional as F


class PyTorchSkaFn(Function):
    """Reference PyTorch implementation of sliding-kernel attention (SKA).

    Each output pixel is a weighted sum of its ``ks x ks`` neighborhood in
    ``x``, using per-location weights ``w`` of shape ``(n, wc, ks*ks, h, w)``.
    Weight channels are tiled across input channels when ``wc`` divides the
    input channel count. Forward-only: no ``backward`` is defined.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        batch, channels, height, width = x.shape
        weight_ch = w.shape[1]

        # Recover the kernel side length from the flattened ks*ks axis of w,
        # and pick the "same" padding that keeps the spatial size unchanged.
        kernel = int(math.sqrt(w.shape[2]))
        padding = (kernel - 1) // 2
        taps = kernel * kernel
        spatial = height * width

        # Gather every kernel x kernel neighborhood of x into columns
        # ((n, channels*taps, spatial)), then split out the channel and tap
        # axes so the patches line up with the weights.
        patches = F.unfold(x, kernel_size=kernel, padding=padding)
        patches = patches.view(batch, channels, taps, spatial)

        # Collapse the two spatial axes of the per-location weights:
        # (n, wc, ks*ks, h, w) -> (n, wc, taps, spatial).
        weights = w.view(batch, weight_ch, taps, spatial)

        if channels != weight_ch:
            # Weights are shared across groups of input channels; tile them
            # up to the full channel count. This mirrors the "ci % wc"
            # indexing of the Triton kernel this code references.
            weights = weights.repeat(1, channels // weight_ch, 1, 1)

        # Core op: per-tap elementwise product, summed over the tap axis.
        # (n, channels, taps, spatial) -> (n, channels, spatial)
        summed = (patches * weights).sum(dim=2)

        # Restore the image layout: (n, channels, h, w).
        return summed.view(batch, channels, height, width)


class SKA(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around the ``PyTorchSkaFn`` autograd Function."""

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Delegate directly to the Function's apply entry point.
        result = PyTorchSkaFn.apply(x, w)  # type: ignore
        return result