File size: 596 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import sys
from pathlib import Path

import torch

sys.path.append(str(Path(__file__).resolve().parents[1] / "src"))
from wrinklebrane.codes import dct_codes, gram_matrix, hadamard_codes


def _assert_orthonormal(C: torch.Tensor, atol: float = 1e-5) -> None:
    """Assert that the columns of ``C`` are orthonormal.

    ``gram_matrix(C)`` is compared entry-wise against the ``K x K`` identity,
    where ``K = C.shape[1]``.

    Args:
        C: Code matrix whose columns are checked (indexed via ``C.shape[1]``,
            so it must have at least two dimensions).
        atol: Absolute tolerance for the comparison with the identity.

    Raises:
        AssertionError: If the Gram matrix is not ``K x K``, or if any entry
            differs from the identity by more than ``atol``.
    """
    G = gram_matrix(C)
    K = C.shape[1]
    # torch.allclose broadcasts its arguments, so a mis-shaped Gram matrix
    # could otherwise compare "close" to the identity; check the shape first.
    assert tuple(G.shape) == (K, K), (
        f"expected Gram matrix of shape {(K, K)}, got {tuple(G.shape)}"
    )
    # Build the reference on G's own dtype and device: allclose rejects
    # mismatched dtypes/devices (e.g. a CUDA C or a promoted Gram matrix).
    eye = torch.eye(K, dtype=G.dtype, device=G.device)
    # The message expression is only evaluated when the assertion fails.
    assert torch.allclose(G, eye, atol=atol), (
        "columns are not orthonormal: "
        f"max |G - I| = {(G - eye).abs().max().item():.3e} (atol={atol})"
    )


def test_hadamard_codes_orthogonality() -> None:
    """Hadamard codes generated with L=16, K=8 must have orthonormal columns."""
    length, n_codes = 16, 8
    codes = hadamard_codes(L=length, K=n_codes)
    _assert_orthonormal(codes)


def test_dct_codes_orthogonality() -> None:
    """DCT codes generated with L=16, K=16 must have orthonormal columns."""
    size = 16
    codes = dct_codes(L=size, K=size)
    _assert_orthonormal(codes)