#!/usr/bin/env python3
"""
Test WrinkleBrane Optimizations
Validate performance and fidelity improvements from optimizations.
"""

import sys
import time
import traceback
from pathlib import Path

# Make the in-repo "src" directory importable before the wrinklebrane imports.
sys.path.append(str(Path(__file__).resolve().parent / "src"))

import numpy as np
import torch

from wrinklebrane.membrane_bank import MembraneBank
from wrinklebrane.codes import hadamard_codes
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim
from wrinklebrane.optimizations import (
    compute_adaptive_alphas,
    generate_extended_codes,
    HierarchicalMembraneBank,
    optimized_store_pairs,
)


def _pick_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _timestamp(device):
    """Return a monotonic clock reading taken after pending GPU work is done.

    CUDA kernels are launched asynchronously, so without a synchronize the
    elapsed time would only cover the launch, not the computation.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return time.perf_counter()


def _store_and_read(bank, codes, slicer, keys, patterns, alphas):
    """Store ``patterns`` into ``bank`` with ``store_pairs`` and read them back.

    ``bank.write`` is handed the difference between the new and the current
    membrane state, exactly as every call site in this script did.
    Returns the slicer output with its leading axis squeezed away
    (assumes that axis is the batch axis of size 1 -- TODO confirm).
    """
    updated = store_pairs(bank.read(), codes, keys, patterns, alphas)
    bank.write(updated - bank.read())
    return slicer(bank.read()).squeeze(0)


def _psnr_per_pattern(patterns, readouts):
    """Return the PSNR of ``readouts[i]`` against ``patterns[i]`` for every pattern."""
    return [
        psnr(patterns[i].cpu().numpy(), readouts[i].cpu().numpy())
        for i in range(patterns.shape[0])
    ]


def _make_energy_pattern(i, H, W, device):
    """Build the i-th (H, W) test pattern used by the adaptive-alpha test.

    The shape cycles with ``i % 3``: filled circle (full energy), centred
    square (half energy), shifted diagonal line (one tenth energy).
    """
    pattern = torch.zeros(H, W, device=device)
    energy_scale = 0.1 + i * 0.3  # 0.1 .. 2.2 for i in 0..7
    step = i // 3

    if i % 3 == 0:
        # High energy circle centred on the image. x runs along the width and
        # y along the height, so each is centred on its own axis.
        radius = 3 + step
        ys = torch.arange(H, device=device).unsqueeze(1)
        xs = torch.arange(W, device=device).unsqueeze(0)
        inside = (xs - W // 2) ** 2 + (ys - H // 2) ** 2 <= radius ** 2
        pattern[inside] = energy_scale
    elif i % 3 == 1:
        # Medium energy square, centred independently on each axis.
        size = 4 + step
        row0 = (H - size) // 2
        col0 = (W - size) // 2
        pattern[row0:row0 + size, col0:col0 + size] = energy_scale * 0.5
    else:
        # Low energy diagonal line shifted down by ``step`` rows; only the
        # row index can run off the image, the column index is always < W.
        length = max(0, min(H - step, W))
        cols = torch.arange(length, device=device)
        pattern[cols + step, cols] = energy_scale * 0.1
    return pattern


def test_adaptive_alphas():
    """Test adaptive alpha scaling vs uniform alphas.

    Stores the same K patterns once with all-ones alphas and once with the
    alphas from ``compute_adaptive_alphas``, then compares per-pattern PSNR.

    Returns:
        True-like value when the average PSNR with adaptive alphas is higher.
    """
    print("🧪 Testing Adaptive Alpha Scaling...")

    device = _pick_device()
    B, L, H, W, K = 1, 32, 16, 16, 8

    # Two independent banks so the two runs cannot influence each other.
    bank_uniform = MembraneBank(L, H, W, device=device)
    bank_adaptive = MembraneBank(L, H, W, device=device)
    bank_uniform.allocate(B)
    bank_adaptive.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Test patterns with varying energies.
    patterns = torch.stack(
        [_make_energy_pattern(i, H, W, device) for i in range(K)]
    )
    keys = torch.arange(K, device=device)

    # Uniform alphas.
    uniform_alphas = torch.ones(K, device=device)
    uniform_readouts = _store_and_read(
        bank_uniform, C, slicer, keys, patterns, uniform_alphas
    )

    # Adaptive alphas.
    adaptive_alphas = compute_adaptive_alphas(patterns, C, keys)
    adaptive_readouts = _store_and_read(
        bank_adaptive, C, slicer, keys, patterns, adaptive_alphas
    )

    # Compare fidelity.
    uniform_psnr = _psnr_per_pattern(patterns, uniform_readouts)
    adaptive_psnr = _psnr_per_pattern(patterns, adaptive_readouts)

    print(" Pattern-by-pattern comparison:")
    for i, (u_psnr, a_psnr) in enumerate(zip(uniform_psnr, adaptive_psnr)):
        energy = torch.norm(patterns[i]).item()
        print(f" Pattern {i}: Energy={energy:.3f}, Alpha={adaptive_alphas[i]:.3f}")
        print(f" Uniform PSNR: {u_psnr:.1f}dB, Adaptive PSNR: {a_psnr:.1f}dB")

    avg_uniform = np.mean(uniform_psnr)
    avg_adaptive = np.mean(adaptive_psnr)
    improvement = avg_adaptive - avg_uniform

    print("\n Results Summary:")
    print(f" Uniform alphas: {avg_uniform:.1f}dB average PSNR")
    print(f" Adaptive alphas: {avg_adaptive:.1f}dB average PSNR")
    print(f" Improvement: {improvement:.1f}dB ({improvement/avg_uniform*100:.1f}%)")

    return improvement > 0


def test_extended_codes():
    """Test extended code generation for K > L scenarios.

    For each K, generates codes with ``generate_extended_codes``, measures how
    far the Gram matrix of the (first L) code vectors is from the identity,
    and runs a store/readout round trip on random patterns.

    Returns:
        dict mapping K to a dict with keys ``orthogonality_error``,
        ``avg_psnr``, ``min_psnr`` and ``std_psnr``.
    """
    print("\n🧪 Testing Extended Code Generation...")

    device = _pick_device()
    L = 32  # Small number of layers
    test_Ks = [16, 32, 64, 128]  # Including K > L cases

    results = {}

    for K in test_Ks:
        print(f" Testing L={L}, K={K} (capacity: {K/L:.1f}x)")

        C = generate_extended_codes(L, K, method="auto", device=device)

        # Orthogonality (only for the first L vectors when K > L, since more
        # than L vectors of length L cannot all be mutually orthogonal).
        if K <= L:
            G = C.T @ C
            I_approx = torch.eye(K, device=device, dtype=C.dtype)
        else:
            C_ortho = C[:, :L]
            G = C_ortho.T @ C_ortho
            I_approx = torch.eye(L, device=device, dtype=C.dtype)
        orthogonality_error = torch.norm(G - I_approx).item()

        # Actual storage scenario.
        B, H, W = 1, 8, 8
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)
        slicer = make_slicer(C)

        # Never use more keys than there are code columns.
        actual_K = min(K, C.shape[1])
        patterns = torch.rand(actual_K, H, W, device=device)
        keys = torch.arange(actual_K, device=device)
        alphas = torch.ones(actual_K, device=device)

        # Store and retrieve.
        readouts = _store_and_read(bank, C, slicer, keys, patterns, alphas)

        # Fidelity statistics over the stored patterns.
        psnr_values = _psnr_per_pattern(patterns, readouts)
        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        std_psnr = np.std(psnr_values)

        results[K] = {
            "orthogonality_error": orthogonality_error,
            "avg_psnr": avg_psnr,
            "min_psnr": min_psnr,
            "std_psnr": std_psnr,
        }

        print(f" Orthogonality error: {orthogonality_error:.6f}")
        print(f" PSNR: {avg_psnr:.1f}±{std_psnr:.1f}dB (min: {min_psnr:.1f}dB)")

    return results


def test_hierarchical_memory():
    """Test hierarchical memory bank organization.

    Stores the same patterns in a regular ``MembraneBank`` and in a
    ``HierarchicalMembraneBank``, then reports timing and an estimated memory
    footprint for both.

    Returns:
        True when the hierarchical bank's estimated footprint is smaller.
    """
    print("\n🧪 Testing Hierarchical Memory Bank...")

    device = _pick_device()
    L, H, W = 64, 32, 32
    K = 32

    hierarchical_bank = HierarchicalMembraneBank(L, H, W, levels=3, device=device)
    hierarchical_bank.allocate(1)

    # Regular bank for comparison.
    regular_bank = MembraneBank(L, H, W, device=device)
    regular_bank.allocate(1)

    # Test patterns in three complexity tiers.
    patterns = []
    for i in range(K):
        if i < K // 3:
            # High complexity: noise over the whole image.
            pattern = torch.rand(H, W, device=device)
        elif i < 2 * K // 3:
            # Medium complexity: noise in the central half only.
            pattern = torch.zeros(H, W, device=device)
            pattern[H//4:3*H//4, W//4:3*W//4] = torch.rand(H//2, W//2, device=device)
        else:
            # Low complexity: a constant 4x4 square in the centre.
            pattern = torch.zeros(H, W, device=device)
            pattern[H//2-2:H//2+2, W//2-2:W//2+2] = torch.ones(4, 4, device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Regular storage (setup is kept outside the timed region).
    C_regular = hadamard_codes(L, K).to(device)
    slicer_regular = make_slicer(C_regular)
    alphas_regular = torch.ones(K, device=device)

    start_time = _timestamp(device)
    regular_readouts = _store_and_read(
        regular_bank, C_regular, slicer_regular, keys, patterns, alphas_regular
    )
    regular_time = _timestamp(device) - start_time

    # Hierarchical storage.
    start_time = _timestamp(device)
    hierarchical_bank.store_hierarchical(patterns, keys)
    hierarchical_time = _timestamp(device) - start_time

    # Memory estimate: 4 bytes per element (assumes float32 storage and that
    # every level keeps the full H x W resolution -- TODO confirm).
    regular_memory = L * H * W * 4  # Single bank
    hierarchical_memory = sum(bank.L * H * W * 4 for bank in hierarchical_bank.banks)
    memory_savings = (regular_memory - hierarchical_memory) / regular_memory * 100

    # Regular fidelity.
    avg_regular_psnr = np.mean(_psnr_per_pattern(patterns, regular_readouts))

    print(" Regular Bank:")
    print(f" Storage time: {regular_time*1000:.2f}ms")
    print(f" Memory usage: {regular_memory/1e6:.2f}MB")
    print(f" Average PSNR: {avg_regular_psnr:.1f}dB")

    print(" Hierarchical Bank:")
    print(f" Storage time: {hierarchical_time*1000:.2f}ms")
    print(f" Memory usage: {hierarchical_memory/1e6:.2f}MB")
    print(f" Memory savings: {memory_savings:.1f}%")
    print(f" Levels: {hierarchical_bank.levels}")

    for i, bank in enumerate(hierarchical_bank.banks):
        level_fraction = bank.L / hierarchical_bank.total_L
        print(f" Level {i}: L={bank.L} ({level_fraction:.1%})")

    return memory_savings > 0


def test_optimized_storage():
    """Test the complete optimized storage pipeline.

    Stores mixed-energy patterns once through ``store_pairs`` with all-ones
    alphas and once through ``optimized_store_pairs``, then compares average
    PSNR and wall-clock time.

    Returns:
        True-like value when the optimized pipeline has higher average PSNR.
    """
    print("\n🧪 Testing Optimized Storage Pipeline...")

    device = _pick_device()
    B, L, H, W, K = 1, 64, 32, 32, 48

    bank_original = MembraneBank(L, H, W, device=device)
    bank_optimized = MembraneBank(L, H, W, device=device)
    bank_original.allocate(B)
    bank_optimized.allocate(B)

    # Codes come from the extended generator (the one exercised for K > L in
    # test_extended_codes); in this test K=48 is below L=64.
    C = generate_extended_codes(L, K, method="auto", device=device)
    slicer = make_slicer(C)

    # Mixed complexity test patterns.
    patterns = []
    for i in range(K):
        if i % 4 == 0:
            # High energy patterns
            pattern = torch.rand(H, W, device=device) * 2.0
        elif i % 4 == 1:
            # Medium energy patterns
            pattern = torch.rand(H, W, device=device) * 1.0
        elif i % 4 == 2:
            # Low energy patterns
            pattern = torch.rand(H, W, device=device) * 0.5
        else:
            # Very sparse patterns: draw the mask ONCE so the number of
            # random values matches the number of selected cells. Drawing two
            # independent masks makes the masked assignment fail with a
            # shape mismatch whenever their counts differ.
            pattern = torch.zeros(H, W, device=device)
            mask = torch.rand(H, W, device=device) > 0.95
            pattern[mask] = torch.rand(int(mask.sum()), device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Original storage.
    start_time = _timestamp(device)
    alphas_original = torch.ones(K, device=device)
    original_readouts = _store_and_read(
        bank_original, C, slicer, keys, patterns, alphas_original
    )
    original_time = _timestamp(device) - start_time

    # Optimized storage.
    start_time = _timestamp(device)
    M_optimized = optimized_store_pairs(
        bank_optimized.read(), C, keys, patterns,
        adaptive_alphas=True, sparsity_threshold=0.01
    )
    bank_optimized.write(M_optimized - bank_optimized.read())
    optimized_readouts = slicer(bank_optimized.read()).squeeze(0)
    optimized_time = _timestamp(device) - start_time

    # Compare results.
    avg_original = np.mean(_psnr_per_pattern(patterns, original_readouts))
    avg_optimized = np.mean(_psnr_per_pattern(patterns, optimized_readouts))

    fidelity_improvement = avg_optimized - avg_original
    # Elapsed times are plain Python floats, so guard the division explicitly.
    if original_time > 0:
        speed_improvement = (original_time - optimized_time) / original_time * 100
    else:
        speed_improvement = 0.0

    print(" Original Pipeline:")
    print(f" Time: {original_time*1000:.2f}ms")
    print(f" Average PSNR: {avg_original:.1f}dB")

    print(" Optimized Pipeline:")
    print(f" Time: {optimized_time*1000:.2f}ms")
    print(f" Average PSNR: {avg_optimized:.1f}dB")

    print(" Improvements:")
    print(f" Fidelity: +{fidelity_improvement:.1f}dB ({fidelity_improvement/avg_original*100:.1f}%)")
    print(f" Speed: {speed_improvement:.1f}% {'faster' if speed_improvement > 0 else 'slower'}")

    return fidelity_improvement > 0


def main():
    """Run complete optimization test suite.

    Returns:
        False when any test raised; otherwise whether at least one of the four
        optimizations showed an improvement.
    """
    print("🚀 WrinkleBrane Optimization Test Suite")
    print("="*50)

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    success_count = 0
    total_tests = 4

    try:
        # Test adaptive alphas
        if test_adaptive_alphas():
            print("✅ Adaptive alpha scaling: IMPROVED PERFORMANCE")
            success_count += 1
        else:
            print("⚠️ Adaptive alpha scaling: NO IMPROVEMENT")

        # Test extended codes
        extended_results = test_extended_codes()
        # Reasonable quality threshold
        if all(r['avg_psnr'] > 50 for r in extended_results.values()):
            print("✅ Extended code generation: WORKING")
            success_count += 1
        else:
            print("⚠️ Extended code generation: QUALITY ISSUES")

        # Test hierarchical memory
        if test_hierarchical_memory():
            print("✅ Hierarchical memory: MEMORY SAVINGS")
            success_count += 1
        else:
            print("⚠️ Hierarchical memory: NO SAVINGS")

        # Test optimized storage
        if test_optimized_storage():
            print("✅ Optimized storage pipeline: IMPROVED FIDELITY")
            success_count += 1
        else:
            print("⚠️ Optimized storage pipeline: NO IMPROVEMENT")

        print("\n" + "="*50)
        print(f"🎯 Optimization Results: {success_count}/{total_tests} improvements successful")

        if success_count == total_tests:
            print("🏆 ALL OPTIMIZATIONS WORKING PERFECTLY!")
        elif success_count > total_tests // 2:
            print("✅ MAJORITY OF OPTIMIZATIONS SUCCESSFUL")
        else:
            print("⚠️ Mixed results - some optimizations need work")

    except Exception as e:
        # Top-level boundary of the script: report the failure with its
        # traceback instead of letting the interpreter abort mid-report.
        print(f"\n❌ Optimization tests failed with error: {e}")
        traceback.print_exc()
        return False

    return success_count > 0


if __name__ == "__main__":
    main()