|
|
|
|
|
""" |
|
|
Test WrinkleBrane Optimizations |
|
|
Validate performance and fidelity improvements from optimizations. |
|
|
""" |
|
|
|
|
|
import sys
import time
from pathlib import Path

# Put the in-repo ``src`` directory at the FRONT of the import path so the
# local ``wrinklebrane`` sources are what gets tested, even when another copy
# of the package is installed (``append`` would let the installed one win).
_SRC_DIR = str(Path(__file__).resolve().parent / "src")
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)

import numpy as np
import torch

from wrinklebrane.codes import hadamard_codes
from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.metrics import psnr, ssim
from wrinklebrane.optimizations import (
    HierarchicalMembraneBank,
    compute_adaptive_alphas,
    generate_extended_codes,
    optimized_store_pairs,
)
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
|
|
|
|
|
def test_adaptive_alphas(): |
|
|
"""Test adaptive alpha scaling vs uniform alphas.""" |
|
|
print("π§ͺ Testing Adaptive Alpha Scaling...") |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
B, L, H, W, K = 1, 32, 16, 16, 8 |
|
|
|
|
|
|
|
|
bank_uniform = MembraneBank(L, H, W, device=device) |
|
|
bank_adaptive = MembraneBank(L, H, W, device=device) |
|
|
bank_uniform.allocate(B) |
|
|
bank_adaptive.allocate(B) |
|
|
|
|
|
C = hadamard_codes(L, K).to(device) |
|
|
slicer = make_slicer(C) |
|
|
|
|
|
|
|
|
patterns = [] |
|
|
for i in range(K): |
|
|
pattern = torch.zeros(H, W, device=device) |
|
|
|
|
|
energy_scale = 0.1 + i * 0.3 |
|
|
|
|
|
if i % 3 == 0: |
|
|
for y in range(H): |
|
|
for x in range(W): |
|
|
if (x - H//2)**2 + (y - W//2)**2 <= (3 + i//3)**2: |
|
|
pattern[y, x] = energy_scale |
|
|
elif i % 3 == 1: |
|
|
size = 4 + i//3 |
|
|
start = (H - size) // 2 |
|
|
pattern[start:start+size, start:start+size] = energy_scale * 0.5 |
|
|
else: |
|
|
for d in range(min(H, W)): |
|
|
if d + i//3 < H and d + i//3 < W: |
|
|
pattern[d + i//3, d] = energy_scale * 0.1 |
|
|
|
|
|
patterns.append(pattern) |
|
|
|
|
|
patterns = torch.stack(patterns) |
|
|
keys = torch.arange(K, device=device) |
|
|
|
|
|
|
|
|
uniform_alphas = torch.ones(K, device=device) |
|
|
M_uniform = store_pairs(bank_uniform.read(), C, keys, patterns, uniform_alphas) |
|
|
bank_uniform.write(M_uniform - bank_uniform.read()) |
|
|
uniform_readouts = slicer(bank_uniform.read()).squeeze(0) |
|
|
|
|
|
|
|
|
adaptive_alphas = compute_adaptive_alphas(patterns, C, keys) |
|
|
M_adaptive = store_pairs(bank_adaptive.read(), C, keys, patterns, adaptive_alphas) |
|
|
bank_adaptive.write(M_adaptive - bank_adaptive.read()) |
|
|
adaptive_readouts = slicer(bank_adaptive.read()).squeeze(0) |
|
|
|
|
|
|
|
|
uniform_psnr = [] |
|
|
adaptive_psnr = [] |
|
|
|
|
|
print(" Pattern-by-pattern comparison:") |
|
|
for i in range(K): |
|
|
u_psnr = psnr(patterns[i].cpu().numpy(), uniform_readouts[i].cpu().numpy()) |
|
|
a_psnr = psnr(patterns[i].cpu().numpy(), adaptive_readouts[i].cpu().numpy()) |
|
|
|
|
|
uniform_psnr.append(u_psnr) |
|
|
adaptive_psnr.append(a_psnr) |
|
|
|
|
|
energy = torch.norm(patterns[i]).item() |
|
|
print(f" Pattern {i}: Energy={energy:.3f}, Alpha={adaptive_alphas[i]:.3f}") |
|
|
print(f" Uniform PSNR: {u_psnr:.1f}dB, Adaptive PSNR: {a_psnr:.1f}dB") |
|
|
|
|
|
avg_uniform = np.mean(uniform_psnr) |
|
|
avg_adaptive = np.mean(adaptive_psnr) |
|
|
improvement = avg_adaptive - avg_uniform |
|
|
|
|
|
print(f"\n Results Summary:") |
|
|
print(f" Uniform alphas: {avg_uniform:.1f}dB average PSNR") |
|
|
print(f" Adaptive alphas: {avg_adaptive:.1f}dB average PSNR") |
|
|
print(f" Improvement: {improvement:.1f}dB ({improvement/avg_uniform*100:.1f}%)") |
|
|
|
|
|
return improvement > 0 |
|
|
|
|
|
|
|
|
def test_extended_codes(): |
|
|
"""Test extended code generation for K > L scenarios.""" |
|
|
print("\nπ§ͺ Testing Extended Code Generation...") |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
L = 32 |
|
|
test_Ks = [16, 32, 64, 128] |
|
|
|
|
|
results = {} |
|
|
|
|
|
for K in test_Ks: |
|
|
print(f" Testing L={L}, K={K} (capacity: {K/L:.1f}x)") |
|
|
|
|
|
|
|
|
C = generate_extended_codes(L, K, method="auto", device=device) |
|
|
|
|
|
|
|
|
if K <= L: |
|
|
G = C.T @ C |
|
|
I_approx = torch.eye(K, device=device, dtype=C.dtype) |
|
|
orthogonality_error = torch.norm(G - I_approx).item() |
|
|
else: |
|
|
|
|
|
C_ortho = C[:, :L] |
|
|
G = C_ortho.T @ C_ortho |
|
|
I_approx = torch.eye(L, device=device, dtype=C.dtype) |
|
|
orthogonality_error = torch.norm(G - I_approx).item() |
|
|
|
|
|
|
|
|
B, H, W = 1, 8, 8 |
|
|
bank = MembraneBank(L, H, W, device=device) |
|
|
bank.allocate(B) |
|
|
|
|
|
slicer = make_slicer(C) |
|
|
|
|
|
|
|
|
|
|
|
actual_K = min(K, C.shape[1]) |
|
|
patterns = torch.rand(actual_K, H, W, device=device) |
|
|
keys = torch.arange(actual_K, device=device) |
|
|
alphas = torch.ones(actual_K, device=device) |
|
|
|
|
|
|
|
|
M = store_pairs(bank.read(), C, keys, patterns, alphas) |
|
|
bank.write(M - bank.read()) |
|
|
readouts = slicer(bank.read()).squeeze(0) |
|
|
|
|
|
|
|
|
psnr_values = [] |
|
|
for i in range(actual_K): |
|
|
psnr_val = psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy()) |
|
|
psnr_values.append(psnr_val) |
|
|
|
|
|
avg_psnr = np.mean(psnr_values) |
|
|
min_psnr = np.min(psnr_values) |
|
|
std_psnr = np.std(psnr_values) |
|
|
|
|
|
results[K] = { |
|
|
"orthogonality_error": orthogonality_error, |
|
|
"avg_psnr": avg_psnr, |
|
|
"min_psnr": min_psnr, |
|
|
"std_psnr": std_psnr |
|
|
} |
|
|
|
|
|
print(f" Orthogonality error: {orthogonality_error:.6f}") |
|
|
print(f" PSNR: {avg_psnr:.1f}Β±{std_psnr:.1f}dB (min: {min_psnr:.1f}dB)") |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def test_hierarchical_memory(): |
|
|
"""Test hierarchical memory bank organization.""" |
|
|
print("\nπ§ͺ Testing Hierarchical Memory Bank...") |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
L, H, W = 64, 32, 32 |
|
|
K = 32 |
|
|
|
|
|
|
|
|
hierarchical_bank = HierarchicalMembraneBank(L, H, W, levels=3, device=device) |
|
|
hierarchical_bank.allocate(1) |
|
|
|
|
|
|
|
|
regular_bank = MembraneBank(L, H, W, device=device) |
|
|
regular_bank.allocate(1) |
|
|
|
|
|
|
|
|
patterns = [] |
|
|
for i in range(K): |
|
|
if i < K // 3: |
|
|
pattern = torch.rand(H, W, device=device) |
|
|
elif i < 2 * K // 3: |
|
|
pattern = torch.zeros(H, W, device=device) |
|
|
pattern[H//4:3*H//4, W//4:3*W//4] = torch.rand(H//2, W//2, device=device) |
|
|
else: |
|
|
pattern = torch.zeros(H, W, device=device) |
|
|
pattern[H//2-2:H//2+2, W//2-2:W//2+2] = torch.ones(4, 4, device=device) |
|
|
patterns.append(pattern) |
|
|
|
|
|
patterns = torch.stack(patterns) |
|
|
keys = torch.arange(K, device=device) |
|
|
|
|
|
|
|
|
C_regular = hadamard_codes(L, K).to(device) |
|
|
slicer_regular = make_slicer(C_regular) |
|
|
alphas_regular = torch.ones(K, device=device) |
|
|
|
|
|
start_time = time.time() |
|
|
M_regular = store_pairs(regular_bank.read(), C_regular, keys, patterns, alphas_regular) |
|
|
regular_bank.write(M_regular - regular_bank.read()) |
|
|
regular_readouts = slicer_regular(regular_bank.read()).squeeze(0) |
|
|
regular_time = time.time() - start_time |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
hierarchical_bank.store_hierarchical(patterns, keys) |
|
|
hierarchical_time = time.time() - start_time |
|
|
|
|
|
|
|
|
regular_memory = L * H * W * 4 |
|
|
hierarchical_memory = sum(bank.L * H * W * 4 for bank in hierarchical_bank.banks) |
|
|
memory_savings = (regular_memory - hierarchical_memory) / regular_memory * 100 |
|
|
|
|
|
|
|
|
regular_psnr = [] |
|
|
for i in range(K): |
|
|
psnr_val = psnr(patterns[i].cpu().numpy(), regular_readouts[i].cpu().numpy()) |
|
|
regular_psnr.append(psnr_val) |
|
|
|
|
|
avg_regular_psnr = np.mean(regular_psnr) |
|
|
|
|
|
print(f" Regular Bank:") |
|
|
print(f" Storage time: {regular_time*1000:.2f}ms") |
|
|
print(f" Memory usage: {regular_memory/1e6:.2f}MB") |
|
|
print(f" Average PSNR: {avg_regular_psnr:.1f}dB") |
|
|
|
|
|
print(f" Hierarchical Bank:") |
|
|
print(f" Storage time: {hierarchical_time*1000:.2f}ms") |
|
|
print(f" Memory usage: {hierarchical_memory/1e6:.2f}MB") |
|
|
print(f" Memory savings: {memory_savings:.1f}%") |
|
|
print(f" Levels: {hierarchical_bank.levels}") |
|
|
|
|
|
for i, bank in enumerate(hierarchical_bank.banks): |
|
|
level_fraction = bank.L / hierarchical_bank.total_L |
|
|
print(f" Level {i}: L={bank.L} ({level_fraction:.1%})") |
|
|
|
|
|
return memory_savings > 0 |
|
|
|
|
|
|
|
|
def test_optimized_storage(): |
|
|
"""Test the complete optimized storage pipeline.""" |
|
|
print("\nπ§ͺ Testing Optimized Storage Pipeline...") |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
B, L, H, W, K = 1, 64, 32, 32, 48 |
|
|
|
|
|
|
|
|
bank_original = MembraneBank(L, H, W, device=device) |
|
|
bank_optimized = MembraneBank(L, H, W, device=device) |
|
|
bank_original.allocate(B) |
|
|
bank_optimized.allocate(B) |
|
|
|
|
|
|
|
|
C = generate_extended_codes(L, K, method="auto", device=device) |
|
|
slicer = make_slicer(C) |
|
|
|
|
|
|
|
|
patterns = [] |
|
|
for i in range(K): |
|
|
if i % 4 == 0: |
|
|
pattern = torch.rand(H, W, device=device) * 2.0 |
|
|
elif i % 4 == 1: |
|
|
pattern = torch.rand(H, W, device=device) * 1.0 |
|
|
elif i % 4 == 2: |
|
|
pattern = torch.rand(H, W, device=device) * 0.5 |
|
|
else: |
|
|
pattern = torch.zeros(H, W, device=device) |
|
|
pattern[torch.rand(H, W, device=device) > 0.95] = torch.rand((torch.rand(H, W, device=device) > 0.95).sum(), device=device) |
|
|
patterns.append(pattern) |
|
|
|
|
|
patterns = torch.stack(patterns) |
|
|
keys = torch.arange(K, device=device) |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
alphas_original = torch.ones(K, device=device) |
|
|
M_original = store_pairs(bank_original.read(), C, keys, patterns, alphas_original) |
|
|
bank_original.write(M_original - bank_original.read()) |
|
|
original_readouts = slicer(bank_original.read()).squeeze(0) |
|
|
original_time = time.time() - start_time |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
M_optimized = optimized_store_pairs( |
|
|
bank_optimized.read(), C, keys, patterns, |
|
|
adaptive_alphas=True, sparsity_threshold=0.01 |
|
|
) |
|
|
bank_optimized.write(M_optimized - bank_optimized.read()) |
|
|
optimized_readouts = slicer(bank_optimized.read()).squeeze(0) |
|
|
optimized_time = time.time() - start_time |
|
|
|
|
|
|
|
|
original_psnr = [] |
|
|
optimized_psnr = [] |
|
|
|
|
|
for i in range(K): |
|
|
o_psnr = psnr(patterns[i].cpu().numpy(), original_readouts[i].cpu().numpy()) |
|
|
opt_psnr = psnr(patterns[i].cpu().numpy(), optimized_readouts[i].cpu().numpy()) |
|
|
|
|
|
original_psnr.append(o_psnr) |
|
|
optimized_psnr.append(opt_psnr) |
|
|
|
|
|
avg_original = np.mean(original_psnr) |
|
|
avg_optimized = np.mean(optimized_psnr) |
|
|
fidelity_improvement = avg_optimized - avg_original |
|
|
speed_improvement = (original_time - optimized_time) / original_time * 100 |
|
|
|
|
|
print(f" Original Pipeline:") |
|
|
print(f" Time: {original_time*1000:.2f}ms") |
|
|
print(f" Average PSNR: {avg_original:.1f}dB") |
|
|
|
|
|
print(f" Optimized Pipeline:") |
|
|
print(f" Time: {optimized_time*1000:.2f}ms") |
|
|
print(f" Average PSNR: {avg_optimized:.1f}dB") |
|
|
|
|
|
print(f" Improvements:") |
|
|
print(f" Fidelity: +{fidelity_improvement:.1f}dB ({fidelity_improvement/avg_original*100:.1f}%)") |
|
|
print(f" Speed: {speed_improvement:.1f}% {'faster' if speed_improvement > 0 else 'slower'}") |
|
|
|
|
|
return fidelity_improvement > 0 |
|
|
|
|
|
|
|
|
def main():
    """Run the optimization test suite and print a summary.

    Seeds torch and NumPy, then runs the four checks in order and counts how
    many report an improvement. Any exception aborts the suite: the traceback
    is printed and False is returned.

    Returns:
        True if at least one check succeeded; False otherwise or on error.
    """
    print("🚀 WrinkleBrane Optimization Test Suite")
    print("=" * 50)

    torch.manual_seed(42)
    np.random.seed(42)

    # Each entry: (check, success predicate on its result, success msg, failure msg).
    # Keeping them in one table means total_tests cannot drift from the real count.
    checks = [
        (
            test_adaptive_alphas,
            bool,
            "✅ Adaptive alpha scaling: IMPROVED PERFORMANCE",
            "⚠️ Adaptive alpha scaling: NO IMPROVEMENT",
        ),
        (
            test_extended_codes,
            lambda results: all(r['avg_psnr'] > 50 for r in results.values()),
            "✅ Extended code generation: WORKING",
            "⚠️ Extended code generation: QUALITY ISSUES",
        ),
        (
            test_hierarchical_memory,
            bool,
            "✅ Hierarchical memory: MEMORY SAVINGS",
            "⚠️ Hierarchical memory: NO SAVINGS",
        ),
        (
            test_optimized_storage,
            bool,
            "✅ Optimized storage pipeline: IMPROVED FIDELITY",
            "⚠️ Optimized storage pipeline: NO IMPROVEMENT",
        ),
    ]
    total_tests = len(checks)
    success_count = 0

    # Top-level boundary of the script: report any failure with its traceback
    # instead of crashing mid-summary.
    try:
        for run_check, succeeded, ok_message, fail_message in checks:
            if succeeded(run_check()):
                print(ok_message)
                success_count += 1
            else:
                print(fail_message)

        print("\n" + "=" * 50)
        print(f"🎯 Optimization Results: {success_count}/{total_tests} improvements successful")

        if success_count == total_tests:
            print("🎉 ALL OPTIMIZATIONS WORKING PERFECTLY!")
        elif success_count > total_tests // 2:
            print("✅ MAJORITY OF OPTIMIZATIONS SUCCESSFUL")
        else:
            print("⚠️ Mixed results - some optimizations need work")

    except Exception as e:
        print(f"\n❌ Optimization tests failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return success_count > 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate the suite outcome as the process exit status (0 = success)
    # so shells and CI can detect a failed run.
    sys.exit(0 if main() else 1)