#!/usr/bin/env python3
"""
Test WrinkleBrane Optimizations
Validate performance and fidelity improvements from optimizations.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "src"))
import torch
import numpy as np
import time
from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.codes import hadamard_codes
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim
from wrinklebrane.optimizations import (
    compute_adaptive_alphas,
    generate_extended_codes,
    HierarchicalMembraneBank,
    optimized_store_pairs,
)


def test_adaptive_alphas():
    """Test adaptive alpha scaling against uniform alphas."""
    print("🧪 Testing Adaptive Alpha Scaling...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8

    # Create test setup
    bank_uniform = MembraneBank(L, H, W, device=device)
    bank_adaptive = MembraneBank(L, H, W, device=device)
    bank_uniform.allocate(B)
    bank_adaptive.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Create test patterns with varying energies
    patterns = []
    for i in range(K):
        pattern = torch.zeros(H, W, device=device)
        # Energy levels vary from 0.1 (i=0) up to 2.2 (i=7)
        energy_scale = 0.1 + i * 0.3
        if i % 3 == 0:  # High-energy circles
            for y in range(H):
                for x in range(W):
                    if (x - W // 2) ** 2 + (y - H // 2) ** 2 <= (3 + i // 3) ** 2:
                        pattern[y, x] = energy_scale
        elif i % 3 == 1:  # Medium-energy squares
            size = 4 + i // 3
            start = (H - size) // 2
            pattern[start:start + size, start:start + size] = energy_scale * 0.5
        else:  # Low-energy diagonal lines
            for d in range(min(H, W)):
                if d + i // 3 < H and d + i // 3 < W:
                    pattern[d + i // 3, d] = energy_scale * 0.1
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Baseline: uniform alphas
    uniform_alphas = torch.ones(K, device=device)
    M_uniform = store_pairs(bank_uniform.read(), C, keys, patterns, uniform_alphas)
    bank_uniform.write(M_uniform - bank_uniform.read())
    uniform_readouts = slicer(bank_uniform.read()).squeeze(0)

    # Adaptive alphas
    adaptive_alphas = compute_adaptive_alphas(patterns, C, keys)
    M_adaptive = store_pairs(bank_adaptive.read(), C, keys, patterns, adaptive_alphas)
    bank_adaptive.write(M_adaptive - bank_adaptive.read())
    adaptive_readouts = slicer(bank_adaptive.read()).squeeze(0)

    # Compare fidelity pattern by pattern
    uniform_psnr = []
    adaptive_psnr = []
    print("  Pattern-by-pattern comparison:")
    for i in range(K):
        u_psnr = psnr(patterns[i].cpu().numpy(), uniform_readouts[i].cpu().numpy())
        a_psnr = psnr(patterns[i].cpu().numpy(), adaptive_readouts[i].cpu().numpy())
        uniform_psnr.append(u_psnr)
        adaptive_psnr.append(a_psnr)

        energy = torch.norm(patterns[i]).item()
        print(f"    Pattern {i}: Energy={energy:.3f}, Alpha={adaptive_alphas[i]:.3f}")
        print(f"      Uniform PSNR: {u_psnr:.1f}dB, Adaptive PSNR: {a_psnr:.1f}dB")

    avg_uniform = np.mean(uniform_psnr)
    avg_adaptive = np.mean(adaptive_psnr)
    improvement = avg_adaptive - avg_uniform

    print("\n  Results Summary:")
    print(f"    Uniform alphas:  {avg_uniform:.1f}dB average PSNR")
    print(f"    Adaptive alphas: {avg_adaptive:.1f}dB average PSNR")
    print(f"    Improvement: {improvement:.1f}dB ({improvement / avg_uniform * 100:.1f}%)")

    return improvement > 0
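

# For intuition, one plausible flavor of adaptive scaling is to normalize each
# pattern's write gain by its energy, so low-energy patterns are not drowned out
# by high-energy neighbors. This is only an illustrative sketch; the actual policy
# lives in wrinklebrane.optimizations.compute_adaptive_alphas and may differ:
def energy_normalized_alphas(patterns: torch.Tensor, target: float = 1.0) -> torch.Tensor:
    """Alphas inversely proportional to per-pattern L2 energy (illustrative only)."""
    energies = patterns.flatten(1).norm(dim=1).clamp_min(1e-8)  # (K,) per-pattern norms
    return target / energies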


def test_extended_codes():
    """Test extended code generation for K > L scenarios."""
    print("\n🧪 Testing Extended Code Generation...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    L = 32  # Small number of layers
    test_Ks = [16, 32, 64, 128]  # Including K > L cases
    results = {}

    for K in test_Ks:
        print(f"  Testing L={L}, K={K} (capacity: {K / L:.1f}x)")

        # Generate extended codes
        C = generate_extended_codes(L, K, method="auto", device=device)

        # Check orthogonality; when K > L, only the first L columns can be
        # mutually orthogonal, so measure just that part
        if K <= L:
            G = C.T @ C
            I_approx = torch.eye(K, device=device, dtype=C.dtype)
            orthogonality_error = torch.norm(G - I_approx).item()
        else:
            # Overcomplete case: measure orthogonality of the first L vectors
            C_ortho = C[:, :L]
            G = C_ortho.T @ C_ortho
            I_approx = torch.eye(L, device=device, dtype=C.dtype)
            orthogonality_error = torch.norm(G - I_approx).item()

        # Test in an actual storage scenario
        B, H, W = 1, 8, 8
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)
        slicer = make_slicer(C)

        # Create test patterns, limiting keys to the available codes:
        # if K exceeds C.shape[1], test with fewer actual patterns
        actual_K = min(K, C.shape[1])
        patterns = torch.rand(actual_K, H, W, device=device)
        keys = torch.arange(actual_K, device=device)
        alphas = torch.ones(actual_K, device=device)

        # Store and retrieve
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        # Calculate fidelity statistics
        psnr_values = []
        for i in range(actual_K):
            psnr_val = psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
            psnr_values.append(psnr_val)

        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        std_psnr = np.std(psnr_values)

        results[K] = {
            "orthogonality_error": orthogonality_error,
            "avg_psnr": avg_psnr,
            "min_psnr": min_psnr,
            "std_psnr": std_psnr,
        }

        print(f"    Orthogonality error: {orthogonality_error:.6f}")
        print(f"    PSNR: {avg_psnr:.1f}±{std_psnr:.1f}dB (min: {min_psnr:.1f}dB)")

    return results
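

# For overcomplete codes (K > L), perfect orthogonality is impossible, so a useful
# extra figure of merit is the worst-case correlation between distinct code columns.
# An illustrative check (not exercised by the suite above); it assumes only that C
# is an L x K matrix with nonzero columns:
def max_cross_correlation(C: torch.Tensor) -> float:
    """Largest |cosine similarity| between distinct columns of C."""
    Cn = C / C.norm(dim=0, keepdim=True).clamp_min(1e-12)  # unit-normalize columns
    G = Cn.T @ Cn                                          # Gram matrix of the code set
    G.fill_diagonal_(0.0)                                  # ignore self-correlations
    return G.abs().max().item()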


def test_hierarchical_memory():
    """Test hierarchical memory bank organization."""
    print("\n🧪 Testing Hierarchical Memory Bank...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    L, H, W = 64, 32, 32
    K = 32

    # Create hierarchical bank
    hierarchical_bank = HierarchicalMembraneBank(L, H, W, levels=3, device=device)
    hierarchical_bank.allocate(1)

    # Create regular bank for comparison
    regular_bank = MembraneBank(L, H, W, device=device)
    regular_bank.allocate(1)

    # Create test patterns at three complexity levels
    patterns = []
    for i in range(K):
        if i < K // 3:  # High complexity: dense noise
            pattern = torch.rand(H, W, device=device)
        elif i < 2 * K // 3:  # Medium complexity: noise in the central region
            pattern = torch.zeros(H, W, device=device)
            pattern[H // 4:3 * H // 4, W // 4:3 * W // 4] = torch.rand(H // 2, W // 2, device=device)
        else:  # Low complexity: small constant block
            pattern = torch.zeros(H, W, device=device)
            pattern[H // 2 - 2:H // 2 + 2, W // 2 - 2:W // 2 + 2] = torch.ones(4, 4, device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Regular storage
    C_regular = hadamard_codes(L, K).to(device)
    slicer_regular = make_slicer(C_regular)
    alphas_regular = torch.ones(K, device=device)

    _sync(device)
    start_time = time.time()
    M_regular = store_pairs(regular_bank.read(), C_regular, keys, patterns, alphas_regular)
    regular_bank.write(M_regular - regular_bank.read())
    regular_readouts = slicer_regular(regular_bank.read()).squeeze(0)
    _sync(device)
    regular_time = time.time() - start_time

    # Hierarchical storage
    _sync(device)
    start_time = time.time()
    hierarchical_bank.store_hierarchical(patterns, keys)
    _sync(device)
    hierarchical_time = time.time() - start_time

    # Estimate memory usage (4 bytes per float32 element)
    regular_memory = L * H * W * 4  # Single bank
    hierarchical_memory = sum(bank.L * H * W * 4 for bank in hierarchical_bank.banks)
    memory_savings = (regular_memory - hierarchical_memory) / regular_memory * 100

    # Calculate fidelity of the regular bank
    regular_psnr = []
    for i in range(K):
        psnr_val = psnr(patterns[i].cpu().numpy(), regular_readouts[i].cpu().numpy())
        regular_psnr.append(psnr_val)
    avg_regular_psnr = np.mean(regular_psnr)

    print("  Regular Bank:")
    print(f"    Storage time: {regular_time * 1000:.2f}ms")
    print(f"    Memory usage: {regular_memory / 1e6:.2f}MB")
    print(f"    Average PSNR: {avg_regular_psnr:.1f}dB")

    print("  Hierarchical Bank:")
    print(f"    Storage time: {hierarchical_time * 1000:.2f}ms")
    print(f"    Memory usage: {hierarchical_memory / 1e6:.2f}MB")
    print(f"    Memory savings: {memory_savings:.1f}%")
    print(f"    Levels: {hierarchical_bank.levels}")

    for i, bank in enumerate(hierarchical_bank.banks):
        level_fraction = bank.L / hierarchical_bank.total_L
        print(f"      Level {i}: L={bank.L} ({level_fraction:.1%})")

    return memory_savings > 0
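

# The 4-bytes-per-element estimates above assume float32 membranes. A dtype-agnostic
# alternative (a sketch; it assumes each bank exposes its membrane tensor via .read(),
# as the banks in this file do):
def bank_bytes(bank: MembraneBank) -> int:
    """Actual bytes held by a bank's membrane tensor."""
    M = bank.read()
    return M.numel() * M.element_size()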


def test_optimized_storage():
    """Test the complete optimized storage pipeline."""
    print("\n🧪 Testing Optimized Storage Pipeline...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    B, L, H, W, K = 1, 64, 32, 32, 48

    # Create test banks
    bank_original = MembraneBank(L, H, W, device=device)
    bank_optimized = MembraneBank(L, H, W, device=device)
    bank_original.allocate(B)
    bank_optimized.allocate(B)

    # Use extended code generation, which handles arbitrary K (including K > L)
    C = generate_extended_codes(L, K, method="auto", device=device)
    slicer = make_slicer(C)

    # Create mixed-complexity test patterns
    patterns = []
    for i in range(K):
        if i % 4 == 0:  # High-energy patterns
            pattern = torch.rand(H, W, device=device) * 2.0
        elif i % 4 == 1:  # Medium-energy patterns
            pattern = torch.rand(H, W, device=device) * 1.0
        elif i % 4 == 2:  # Low-energy patterns
            pattern = torch.rand(H, W, device=device) * 0.5
        else:  # Very sparse patterns: build the mask once so the number of
            # assigned values matches the number of selected pixels
            pattern = torch.zeros(H, W, device=device)
            mask = torch.rand(H, W, device=device) > 0.95
            pattern[mask] = torch.rand(int(mask.sum()), device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Original storage
    _sync(device)
    start_time = time.time()
    alphas_original = torch.ones(K, device=device)
    M_original = store_pairs(bank_original.read(), C, keys, patterns, alphas_original)
    bank_original.write(M_original - bank_original.read())
    original_readouts = slicer(bank_original.read()).squeeze(0)
    _sync(device)
    original_time = time.time() - start_time

    # Optimized storage
    _sync(device)
    start_time = time.time()
    M_optimized = optimized_store_pairs(
        bank_optimized.read(), C, keys, patterns,
        adaptive_alphas=True, sparsity_threshold=0.01
    )
    bank_optimized.write(M_optimized - bank_optimized.read())
    optimized_readouts = slicer(bank_optimized.read()).squeeze(0)
    _sync(device)
    optimized_time = time.time() - start_time

    # Compare fidelity
    original_psnr = []
    optimized_psnr = []
    for i in range(K):
        o_psnr = psnr(patterns[i].cpu().numpy(), original_readouts[i].cpu().numpy())
        opt_psnr = psnr(patterns[i].cpu().numpy(), optimized_readouts[i].cpu().numpy())
        original_psnr.append(o_psnr)
        optimized_psnr.append(opt_psnr)

    avg_original = np.mean(original_psnr)
    avg_optimized = np.mean(optimized_psnr)
    fidelity_improvement = avg_optimized - avg_original
    speed_improvement = (original_time - optimized_time) / original_time * 100

    print("  Original Pipeline:")
    print(f"    Time: {original_time * 1000:.2f}ms")
    print(f"    Average PSNR: {avg_original:.1f}dB")

    print("  Optimized Pipeline:")
    print(f"    Time: {optimized_time * 1000:.2f}ms")
    print(f"    Average PSNR: {avg_optimized:.1f}dB")

    print("  Improvements:")
    print(f"    Fidelity: {fidelity_improvement:+.1f}dB ({fidelity_improvement / avg_original * 100:.1f}%)")
    print(f"    Speed: {speed_improvement:.1f}% {'faster' if speed_improvement > 0 else 'slower'}")

    return fidelity_improvement > 0
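

# A quick way to gauge how sparse the generated patterns actually are. Presumably
# this is the kind of statistic sparsity_threshold in optimized_store_pairs keys
# off, but that is an assumption; its exact semantics live in
# wrinklebrane.optimizations.
def sparsity_fraction(patterns: torch.Tensor, threshold: float = 0.01) -> float:
    """Fraction of pixels whose magnitude falls below `threshold`."""
    return (patterns.abs() < threshold).float().mean().item()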


def main():
    """Run the complete optimization test suite."""
    print("🚀 WrinkleBrane Optimization Test Suite")
    print("=" * 50)

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    success_count = 0
    total_tests = 4

    try:
        # Test adaptive alphas
        if test_adaptive_alphas():
            print("✅ Adaptive alpha scaling: IMPROVED PERFORMANCE")
            success_count += 1
        else:
            print("⚠️ Adaptive alpha scaling: NO IMPROVEMENT")

        # Test extended codes
        extended_results = test_extended_codes()
        if all(r["avg_psnr"] > 50 for r in extended_results.values()):  # Reasonable quality threshold
            print("✅ Extended code generation: WORKING")
            success_count += 1
        else:
            print("⚠️ Extended code generation: QUALITY ISSUES")

        # Test hierarchical memory
        if test_hierarchical_memory():
            print("✅ Hierarchical memory: MEMORY SAVINGS")
            success_count += 1
        else:
            print("⚠️ Hierarchical memory: NO SAVINGS")

        # Test optimized storage
        if test_optimized_storage():
            print("✅ Optimized storage pipeline: IMPROVED FIDELITY")
            success_count += 1
        else:
            print("⚠️ Optimized storage pipeline: NO IMPROVEMENT")

        print("\n" + "=" * 50)
        print(f"🎯 Optimization Results: {success_count}/{total_tests} improvements successful")

        if success_count == total_tests:
            print("🏆 ALL OPTIMIZATIONS WORKING PERFECTLY!")
        elif success_count > total_tests // 2:
            print("✅ MAJORITY OF OPTIMIZATIONS SUCCESSFUL")
        else:
            print("⚠️ Mixed results - some optimizations need work")

    except Exception as e:
        print(f"\n❌ Optimization tests failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return success_count > 0


if __name__ == "__main__":
    # Propagate success/failure to the shell so CI can pick it up
    sys.exit(0 if main() else 1)