File size: 14,995 Bytes
dc2b9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
#!/usr/bin/env python3
"""
Test WrinkleBrane Optimizations
Validate performance and fidelity improvements from optimizations.
"""

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent / "src"))

import torch
import numpy as np
import time
from wrinklebrane.membrane_bank import MembraneBank  
from wrinklebrane.codes import hadamard_codes
from wrinklebrane.slicer import make_slicer
from wrinklebrane.write_ops import store_pairs
from wrinklebrane.metrics import psnr, ssim
from wrinklebrane.optimizations import (
    compute_adaptive_alphas,
    generate_extended_codes, 
    HierarchicalMembraneBank,
    optimized_store_pairs
)

def test_adaptive_alphas():
    """Compare readout fidelity of adaptive alpha scaling against uniform alphas.

    Builds K synthetic patterns (filled circles, squares and diagonal lines)
    whose amplitude grows with the pattern index, stores them in two separate
    membrane banks -- once with all-ones alphas and once with the alphas
    returned by ``compute_adaptive_alphas`` -- and reads every pattern back
    through the same slicer. Per-pattern and average PSNR are printed.

    Returns:
        bool: True when the average PSNR obtained with adaptive alphas is
        higher than the average PSNR obtained with uniform alphas.
    """
    print("🧪 Testing Adaptive Alpha Scaling...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 32, 16, 16, 8

    # Create test setup: one bank per alpha strategy so the writes do not mix.
    bank_uniform = MembraneBank(L, H, W, device=device)
    bank_adaptive = MembraneBank(L, H, W, device=device)
    bank_uniform.allocate(B)
    bank_adaptive.allocate(B)

    C = hadamard_codes(L, K).to(device)
    slicer = make_slicer(C)

    # Coordinate grids used to rasterise the circles in one vectorised step.
    # Rows run over H (first tensor axis), columns over W (second axis).
    rows = torch.arange(H, device=device).unsqueeze(1)
    cols = torch.arange(W, device=device).unsqueeze(0)

    # Create test patterns with varying energies
    patterns = []
    for i in range(K):
        pattern = torch.zeros(H, W, device=device)
        energy_scale = 0.1 + i * 0.3  # 0.1 .. 2.2 for K = 8
        step = i // 3

        if i % 3 == 0:  # High energy circles
            radius = 3 + step
            # Centre the circle on (H//2, W//2): rows are measured against H
            # and columns against W, so the shape stays centred when H != W.
            inside = (cols - W // 2) ** 2 + (rows - H // 2) ** 2 <= radius ** 2
            pattern[inside] = energy_scale
        elif i % 3 == 1:  # Medium energy squares
            size = 4 + step
            top = (H - size) // 2
            left = (W - size) // 2
            pattern[top:top + size, left:left + size] = energy_scale * 0.5
        else:  # Low energy lines
            # Diagonal shifted down by `step` rows; the range bound keeps the
            # row index (d + step) below H and the column index (d) below W.
            for d in range(min(H - step, W)):
                pattern[d + step, d] = energy_scale * 0.1

        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Test uniform alphas
    uniform_alphas = torch.ones(K, device=device)
    M_uniform = store_pairs(bank_uniform.read(), C, keys, patterns, uniform_alphas)
    bank_uniform.write(M_uniform - bank_uniform.read())
    uniform_readouts = slicer(bank_uniform.read()).squeeze(0)

    # Test adaptive alphas
    adaptive_alphas = compute_adaptive_alphas(patterns, C, keys)
    M_adaptive = store_pairs(bank_adaptive.read(), C, keys, patterns, adaptive_alphas)
    bank_adaptive.write(M_adaptive - bank_adaptive.read())
    adaptive_readouts = slicer(bank_adaptive.read()).squeeze(0)

    # Move everything to NumPy once instead of once per pattern.
    patterns_np = patterns.cpu().numpy()
    uniform_np = uniform_readouts.cpu().numpy()
    adaptive_np = adaptive_readouts.cpu().numpy()

    # Compare fidelity
    uniform_psnr = []
    adaptive_psnr = []

    print("   Pattern-by-pattern comparison:")
    for i in range(K):
        u_psnr = psnr(patterns_np[i], uniform_np[i])
        a_psnr = psnr(patterns_np[i], adaptive_np[i])

        uniform_psnr.append(u_psnr)
        adaptive_psnr.append(a_psnr)

        energy = torch.norm(patterns[i]).item()
        print(f"     Pattern {i}: Energy={energy:.3f}, Alpha={adaptive_alphas[i]:.3f}")
        print(f"       Uniform PSNR: {u_psnr:.1f}dB, Adaptive PSNR: {a_psnr:.1f}dB")

    avg_uniform = np.mean(uniform_psnr)
    avg_adaptive = np.mean(adaptive_psnr)
    improvement = avg_adaptive - avg_uniform

    print("\n   Results Summary:")
    print(f"     Uniform alphas:   {avg_uniform:.1f}dB average PSNR")
    print(f"     Adaptive alphas:  {avg_adaptive:.1f}dB average PSNR")
    print(f"     Improvement:      {improvement:.1f}dB ({improvement/avg_uniform*100:.1f}%)")

    return bool(improvement > 0)


def test_extended_codes():
    """Exercise ``generate_extended_codes`` for K below, equal to and above L.

    For each tested K the function measures how far the Gram matrix of the
    codes (restricted to the first L columns when K > L) is from the identity,
    then stores random patterns with those codes and reports PSNR statistics
    of the readouts.

    Returns:
        dict: maps each tested K to a dict with the keys
        ``orthogonality_error``, ``avg_psnr``, ``min_psnr`` and ``std_psnr``.
    """
    print("\n🧪 Testing Extended Code Generation...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L = 32  # Small number of layers
    results = {}

    # The last two values cover the K > L cases.
    for K in (16, 32, 64, 128):
        print(f"   Testing L={L}, K={K} (capacity: {K/L:.1f}x)")

        C = generate_extended_codes(L, K, method="auto", device=device)

        # Orthogonality is checked on all codes when K <= L, and only on the
        # first L columns in the overcomplete case.
        if K > L:
            checked_codes, checked_count = C[:, :L], L
        else:
            checked_codes, checked_count = C, K
        gram = checked_codes.T @ checked_codes
        identity = torch.eye(checked_count, device=device, dtype=C.dtype)
        orthogonality_error = torch.norm(gram - identity).item()

        # Storage round trip on a small 8x8 membrane.
        B, H, W = 1, 8, 8
        bank = MembraneBank(L, H, W, device=device)
        bank.allocate(B)

        slicer = make_slicer(C)

        # Never use more keys than there are code columns available.
        actual_K = min(K, C.shape[1])
        patterns = torch.rand(actual_K, H, W, device=device)
        keys = torch.arange(actual_K, device=device)
        alphas = torch.ones(actual_K, device=device)

        # Store, then read everything back.
        M = store_pairs(bank.read(), C, keys, patterns, alphas)
        bank.write(M - bank.read())
        readouts = slicer(bank.read()).squeeze(0)

        # Fidelity statistics over all stored patterns.
        psnr_values = [
            psnr(patterns[idx].cpu().numpy(), readouts[idx].cpu().numpy())
            for idx in range(actual_K)
        ]
        avg_psnr = np.mean(psnr_values)
        min_psnr = np.min(psnr_values)
        std_psnr = np.std(psnr_values)

        results[K] = dict(
            orthogonality_error=orthogonality_error,
            avg_psnr=avg_psnr,
            min_psnr=min_psnr,
            std_psnr=std_psnr,
        )

        print(f"     Orthogonality error: {orthogonality_error:.6f}")
        print(f"     PSNR: {avg_psnr:.1f}±{std_psnr:.1f}dB (min: {min_psnr:.1f}dB)")

    return results


def test_hierarchical_memory():
    """Compare a ``HierarchicalMembraneBank`` with a flat ``MembraneBank``.

    Stores the same K patterns (dense, centre-block and small-square patterns)
    in both banks, times the store step of each, and compares the memory
    footprint computed from the layer counts. PSNR is reported for the
    regular bank only; this test does not read back from the hierarchical bank.

    Returns:
        bool: True when the hierarchical bank's computed memory footprint is
        smaller than that of the regular bank.
    """
    print("\n🧪 Testing Hierarchical Memory Bank...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    L, H, W = 64, 32, 32
    K = 32

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the wall-clock
        # measurements below cover the actual work, not just the launches.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # Create hierarchical bank
    hierarchical_bank = HierarchicalMembraneBank(L, H, W, levels=3, device=device)
    hierarchical_bank.allocate(1)

    # Create regular bank for comparison
    regular_bank = MembraneBank(L, H, W, device=device)
    regular_bank.allocate(1)

    # Create test patterns with different complexity levels
    patterns = []
    for i in range(K):
        if i < K // 3:  # High complexity patterns
            pattern = torch.rand(H, W, device=device)
        elif i < 2 * K // 3:  # Medium complexity patterns
            pattern = torch.zeros(H, W, device=device)
            pattern[H//4:3*H//4, W//4:3*W//4] = torch.rand(H//2, W//2, device=device)
        else:  # Low complexity patterns
            pattern = torch.zeros(H, W, device=device)
            pattern[H//2-2:H//2+2, W//2-2:W//2+2] = torch.ones(4, 4, device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)

    # Test regular storage
    C_regular = hadamard_codes(L, K).to(device)
    slicer_regular = make_slicer(C_regular)
    alphas_regular = torch.ones(K, device=device)

    # Time only store + write so the figure is comparable with the
    # hierarchical store below (which performs no readout).
    _sync()
    start_time = time.perf_counter()
    M_regular = store_pairs(regular_bank.read(), C_regular, keys, patterns, alphas_regular)
    regular_bank.write(M_regular - regular_bank.read())
    _sync()
    regular_time = time.perf_counter() - start_time

    regular_readouts = slicer_regular(regular_bank.read()).squeeze(0)

    # Test hierarchical storage
    _sync()
    start_time = time.perf_counter()
    hierarchical_bank.store_hierarchical(patterns, keys)
    _sync()
    hierarchical_time = time.perf_counter() - start_time

    # Calculate memory usage (4 bytes per element, one batch entry per bank)
    bytes_per_element = 4
    regular_memory = L * H * W * bytes_per_element  # Single bank
    hierarchical_memory = sum(
        bank.L * H * W * bytes_per_element for bank in hierarchical_bank.banks
    )
    memory_savings = (regular_memory - hierarchical_memory) / regular_memory * 100

    # Calculate regular fidelity
    patterns_np = patterns.cpu().numpy()
    regular_np = regular_readouts.cpu().numpy()
    regular_psnr = [psnr(patterns_np[i], regular_np[i]) for i in range(K)]

    avg_regular_psnr = np.mean(regular_psnr)

    print("   Regular Bank:")
    print(f"     Storage time: {regular_time*1000:.2f}ms")
    print(f"     Memory usage: {regular_memory/1e6:.2f}MB")
    print(f"     Average PSNR: {avg_regular_psnr:.1f}dB")

    print("   Hierarchical Bank:")
    print(f"     Storage time: {hierarchical_time*1000:.2f}ms")
    print(f"     Memory usage: {hierarchical_memory/1e6:.2f}MB")
    print(f"     Memory savings: {memory_savings:.1f}%")
    print(f"     Levels: {hierarchical_bank.levels}")

    for i, bank in enumerate(hierarchical_bank.banks):
        level_fraction = bank.L / hierarchical_bank.total_L
        print(f"       Level {i}: L={bank.L} ({level_fraction:.1%})")

    return memory_savings > 0


def test_optimized_storage():
    """Compare ``optimized_store_pairs`` with plain ``store_pairs``.

    Stores the same K mixed-energy patterns in two banks, once through
    ``store_pairs`` with all-ones alphas and once through
    ``optimized_store_pairs`` (adaptive alphas, sparsity threshold 0.01),
    reads both back through the same slicer and prints timing and PSNR.

    Returns:
        bool: True when the average PSNR of the optimized pipeline is higher
        than that of the original pipeline.
    """
    print("\n🧪 Testing Optimized Storage Pipeline...")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, L, H, W, K = 1, 64, 32, 32, 48

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the wall-clock
        # measurements below cover the actual work, not just the launches.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # Create test banks
    bank_original = MembraneBank(L, H, W, device=device)
    bank_optimized = MembraneBank(L, H, W, device=device)
    bank_original.allocate(B)
    bank_optimized.allocate(B)

    # Codes come from the extended generator; here K (48) does not exceed L (64).
    C = generate_extended_codes(L, K, method="auto", device=device)
    slicer = make_slicer(C)

    # Create mixed complexity test patterns
    patterns = []
    for i in range(K):
        if i % 4 == 0:  # High energy patterns
            pattern = torch.rand(H, W, device=device) * 2.0
        elif i % 4 == 1:  # Medium energy patterns
            pattern = torch.rand(H, W, device=device) * 1.0
        elif i % 4 == 2:  # Low energy patterns
            pattern = torch.rand(H, W, device=device) * 0.5
        else:  # Very sparse patterns
            pattern = torch.zeros(H, W, device=device)
            # Draw the mask once and reuse it for both the index and the value
            # count; two independent draws would almost never select the same
            # number of pixels and the assignment would fail on shape mismatch.
            mask = torch.rand(H, W, device=device) > 0.95
            pattern[mask] = torch.rand(int(mask.sum()), device=device)
        patterns.append(pattern)

    patterns = torch.stack(patterns)
    keys = torch.arange(K, device=device)
    alphas_original = torch.ones(K, device=device)

    # Original storage (store + write + readout)
    _sync()
    start_time = time.perf_counter()
    M_original = store_pairs(bank_original.read(), C, keys, patterns, alphas_original)
    bank_original.write(M_original - bank_original.read())
    original_readouts = slicer(bank_original.read()).squeeze(0)
    _sync()
    original_time = time.perf_counter() - start_time

    # Optimized storage (store + write + readout)
    _sync()
    start_time = time.perf_counter()
    M_optimized = optimized_store_pairs(
        bank_optimized.read(), C, keys, patterns,
        adaptive_alphas=True, sparsity_threshold=0.01
    )
    bank_optimized.write(M_optimized - bank_optimized.read())
    optimized_readouts = slicer(bank_optimized.read()).squeeze(0)
    _sync()
    optimized_time = time.perf_counter() - start_time

    # Compare results; move tensors to NumPy once rather than per pattern.
    patterns_np = patterns.cpu().numpy()
    original_np = original_readouts.cpu().numpy()
    optimized_np = optimized_readouts.cpu().numpy()

    original_psnr = [psnr(patterns_np[i], original_np[i]) for i in range(K)]
    optimized_psnr = [psnr(patterns_np[i], optimized_np[i]) for i in range(K)]

    avg_original = np.mean(original_psnr)
    avg_optimized = np.mean(optimized_psnr)
    fidelity_improvement = avg_optimized - avg_original
    # Guard against a zero elapsed time so the percentage cannot divide by zero.
    if original_time > 0:
        speed_improvement = (original_time - optimized_time) / original_time * 100
    else:
        speed_improvement = 0.0

    print("   Original Pipeline:")
    print(f"     Time: {original_time*1000:.2f}ms")
    print(f"     Average PSNR: {avg_original:.1f}dB")

    print("   Optimized Pipeline:")
    print(f"     Time: {optimized_time*1000:.2f}ms")
    print(f"     Average PSNR: {avg_optimized:.1f}dB")

    print("   Improvements:")
    # The "+" format flag prints the correct sign for regressions as well
    # (a hard-coded "+" prefix would render "+-1.2dB").
    print(f"     Fidelity: {fidelity_improvement:+.1f}dB ({fidelity_improvement/avg_original*100:.1f}%)")
    print(f"     Speed: {speed_improvement:.1f}% {'faster' if speed_improvement > 0 else 'slower'}")

    return bool(fidelity_improvement > 0)


def main():
    """Run the four optimization tests in order and print a summary.

    Returns:
        bool: False when any test raised an exception; otherwise True when
        at least one test reported success.
    """
    print("🚀 WrinkleBrane Optimization Test Suite")
    print("="*50)

    # Fixed seeds keep the randomly generated patterns reproducible.
    torch.manual_seed(42)
    np.random.seed(42)

    def extended_codes_pass():
        # The extended-code test returns per-K statistics; it counts as a
        # success only when every average PSNR clears the 50 dB threshold.
        stats_by_k = test_extended_codes()
        return all(stats['avg_psnr'] > 50 for stats in stats_by_k.values())

    # (check, message on success, message otherwise) in execution order.
    suite = (
        (
            test_adaptive_alphas,
            "✅ Adaptive alpha scaling: IMPROVED PERFORMANCE",
            "⚠️  Adaptive alpha scaling: NO IMPROVEMENT",
        ),
        (
            extended_codes_pass,
            "✅ Extended code generation: WORKING",
            "⚠️  Extended code generation: QUALITY ISSUES",
        ),
        (
            test_hierarchical_memory,
            "✅ Hierarchical memory: MEMORY SAVINGS",
            "⚠️  Hierarchical memory: NO SAVINGS",
        ),
        (
            test_optimized_storage,
            "✅ Optimized storage pipeline: IMPROVED FIDELITY",
            "⚠️  Optimized storage pipeline: NO IMPROVEMENT",
        ),
    )

    success_count = 0
    total_tests = 4

    try:
        for check, success_message, warning_message in suite:
            if check():
                print(success_message)
                success_count += 1
            else:
                print(warning_message)

        print("\n" + "="*50)
        print(f"🎯 Optimization Results: {success_count}/{total_tests} improvements successful")

        if success_count == total_tests:
            print("🏆 ALL OPTIMIZATIONS WORKING PERFECTLY!")
        elif success_count > total_tests // 2:
            print("✅ MAJORITY OF OPTIMIZATIONS SUCCESSFUL")
        else:
            print("⚠️  Mixed results - some optimizations need work")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback and
        # signal it through the return value instead of propagating.
        print(f"\n❌ Optimization tests failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    return success_count > 0


if __name__ == "__main__":
    # Turn main()'s boolean result into the process exit status so shells and
    # CI can detect a failed run (previously the script always exited 0).
    sys.exit(0 if main() else 1)