AbstractPhil committed · Commit 2e01525 · verified · 1 Parent(s): 7d7e3f5

Create model_trainer.py

Files changed (1)
  1. model_trainer.py +1897 -0
model_trainer.py ADDED
@@ -0,0 +1,1897 @@
+ """
+ David Training Pipeline
+ ========================
+ Training pipeline for David multi-scale crystal classifier.
+
+ Should be placed at: geovocab2/train/model/core/david_trainer.py
+ Or run from: scripts/train_david.py
+
+ Features:
+ - Pure fp32 training (no mixed precision for geometric stability)
+ - Adaptive training controller (freeze/unfreeze scales)
+ - Gradient analysis and scaling
+ - SafeTensors checkpoint support
+ - Enhanced loss component tracking
+ - Proper weight organization: weights/model_name/timestamp/
+ - Accuracy in filenames and comprehensive tracking
+ - Master models index (MODELS_INDEX.json)
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from datasets import load_dataset
+ from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
+ import numpy as np
+ import os
+ import json
+ import time
+ import tempfile
+ from datetime import datetime
+ from tqdm.auto import tqdm
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
+ from dataclasses import dataclass, field, asdict
+
+ # Import David components
+ from geovocab2.train.config.david_config import (
+     DavidArchitectureConfig,
+     DavidPresets,
+     SharingMode,
+     FusionMode
+ )
+
+ from geovocab2.train.model.core.david import (
+     David,
+     MultiScaleCrystalLoss,
+ )
+
+ # Import SimplexFactory
+ from geovocab2.shapes.factory import SimplexFactory
+
+
+ # ============================================================================
+ # TRAINING CONFIGURATION
+ # ============================================================================
+
+ @dataclass
+ class DavidTrainingConfig:
+     """
+     Complete training configuration for David.
+     Separate from model architecture config.
+     """
+
+     # Metadata
+     name: str = "david_training"
+     run_id: str = ""  # Auto-generated timestamp
+
+     # Dataset
+     dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"
+     model_variant: Union[str, List[str]] = "clip_vit_b16"  # Single or list for multi-encoder
+     num_classes: int = 1000
+
+     # Model architecture (references to david_config)
+     preset: Optional[str] = "balanced"  # Or None to use custom config
+     custom_config_path: Optional[str] = None  # Path to custom david_config.json
+
+     # Architecture overrides (applied to preset or custom config)
+     num_classes_override: Optional[int] = None
+     use_belly_override: Optional[bool] = None
+     belly_expand_override: Optional[float] = None
+     progressive_training_override: Optional[bool] = True  # Override progressive training
+     scale_warmup_epochs_override: Optional[Dict[int, int]] = None  # Custom warmup schedule
+
+     # Training hyperparameters
+     num_epochs: int = 50
+     batch_size: int = 512
+     learning_rate: float = 5e-3
+     weight_decay: float = 1e-5
+     warmup_epochs: int = 3
+
+     # Loss weights
+     use_rose_loss: bool = True
+     rose_initial_weight: float = 0.01
+     rose_max_weight: float = 0.1
+     rose_weight_schedule: str = "adaptive"
+     use_cayley_loss: bool = False
+     cayley_weight: float = 0.001
+     scale_loss_balance: Optional[Dict[int, float]] = None
+
+     # Optimization
+     use_mixed_precision: bool = False  # Keep False for stability
+     gradient_clip: float = 5.0
+     scheduler_type: str = "cosine_restarts"
+     min_lr: float = 1e-6
+
+     # Adaptive training (safer defaults)
+     freeze_strategy: str = "never"  # "performance" or "never"
+     freeze_threshold: float = 90.0  # Only freeze when scale hits 90% accuracy
+     unfreeze_on_plateau: bool = True
+     patience: int = 10
+
+     # Gradient monitoring
+     track_gradients: bool = True
+     gradient_scale_threshold: float = 1e-5
+     gradient_scale_multiplier: float = 10.0
+
+     # Logging
+     log_interval: int = 50
+     val_interval: int = 1
+     save_interval: int = 5
+     log_fusion_weights: bool = True
+     log_loss_components: bool = True
+
+     # Checkpointing
+     save_format: str = "both"  # "pytorch", "safetensors", or "both"
+
+     # HuggingFace Hub (optional)
+     hf_repo: Optional[str] = ""  # Your HF repo, e.g. "AbstractPhil/gated-david"
+     upload_to_hub: bool = False
+
+     # Local paths
+     base_dir: str = "./david_training"
+
+     # Hardware
+     num_workers: int = 10
+     pin_memory: bool = True
+     prefetch_factor: int = 4
+     persistent_workers: bool = True
+
+     def __post_init__(self):
+         """Generate run_id if not provided."""
+         if not self.run_id:
+             self.run_id = datetime.now().strftime('%Y%m%d_%H%M%S')
+
+     def to_dict(self) -> dict:
+         """Convert to dictionary."""
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: dict) -> 'DavidTrainingConfig':
+         """Create from dictionary."""
+         return cls(**data)
+
+     def to_json(self, path: str):
+         """Save to JSON."""
+         data = self.to_dict()
+         # Convert any nested dicts with int keys to str keys
+         if data.get('scale_loss_balance'):
+             data['scale_loss_balance'] = {
+                 str(k): v for k, v in data['scale_loss_balance'].items()
+             }
+         if data.get('scale_warmup_epochs_override'):
+             data['scale_warmup_epochs_override'] = {
+                 str(k): v for k, v in data['scale_warmup_epochs_override'].items()
+             }
+         with open(path, 'w') as f:
+             json.dump(data, f, indent=2)
+
+     @classmethod
+     def from_json(cls, path: str) -> 'DavidTrainingConfig':
+         """Load from JSON."""
+         with open(path, 'r') as f:
+             data = json.load(f)
+         # Convert str keys back to int for scale_loss_balance
+         if 'scale_loss_balance' in data and data['scale_loss_balance']:
+             data['scale_loss_balance'] = {
+                 int(k): v for k, v in data['scale_loss_balance'].items()
+             }
+         # Convert str keys back to int for scale_warmup_epochs_override
+         if 'scale_warmup_epochs_override' in data and data['scale_warmup_epochs_override']:
+             data['scale_warmup_epochs_override'] = {
+                 int(k): v for k, v in data['scale_warmup_epochs_override'].items()
+             }
+         return cls(**data)
+
+
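Note on the `to_json`/`from_json` pair above: JSON object keys are always strings, which is why the integer-keyed `scale_loss_balance` and `scale_warmup_epochs_override` dicts are converted on the way out and converted back on load. The standalone sketch below reproduces that round trip with only the standard library. The helper names and the example weights are hypothetical and are not part of this commit.

```python
import json
from typing import Dict, Optional


def keys_to_str(mapping: Optional[Dict[int, float]]) -> Optional[Dict[str, float]]:
    """Convert int keys to str so the mapping is explicit JSON (hypothetical helper)."""
    if not mapping:
        return mapping
    return {str(k): v for k, v in mapping.items()}


def keys_to_int(mapping: Optional[Dict[str, float]]) -> Optional[Dict[int, float]]:
    """Undo keys_to_str after json.load (hypothetical helper)."""
    if not mapping:
        return mapping
    return {int(k): v for k, v in mapping.items()}


# Made-up per-scale loss weights, keyed by scale dimension.
scale_loss_balance = {256: 1.0, 512: 0.8, 1024: 0.5}

encoded = json.dumps(keys_to_str(scale_loss_balance))
decoded_raw = json.loads(encoded)    # keys come back as str
restored = keys_to_int(decoded_raw)  # keys are int again
```

Without the `int(k)` step on load, a lookup such as `restored[512]` would raise `KeyError`, because the loaded dict would only contain the key `"512"`.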
+ # ============================================================================
+ # ADAPTIVE TRAINING CONTROLLER
+ # ============================================================================
+
+ class AdaptiveTrainingController:
+     """Manages adaptive training strategies for multi-scale model."""
+
+     def __init__(self, model: David, config: DavidTrainingConfig):
+         self.model = model
+         self.config = config
+
+         scales = model.scales
+         self.scale_history = {scale: [] for scale in scales}
+         self.best_scale_acc = {scale: 0.0 for scale in scales}
+         self.scales_frozen = {scale: False for scale in scales}
+
+         self.overall_history = []
+         self.plateau_counter = 0
+         self.best_overall = 0.0
+
+     def update_metrics(self, scale_accuracies: Dict[int, float], overall_accuracy: float):
+         """Update metrics and best scores."""
+         for scale, acc in scale_accuracies.items():
+             self.scale_history[scale].append(acc)
+             if acc > self.best_scale_acc[scale]:
+                 self.best_scale_acc[scale] = acc
+
+         self.overall_history.append(overall_accuracy)
+
+         if overall_accuracy > self.best_overall:
+             self.best_overall = overall_accuracy
+             self.plateau_counter = 0
+         else:
+             self.plateau_counter += 1
+
+     def should_freeze_scale(self, scale: int, current_acc: float) -> bool:
+         """Determine if a scale should be frozen."""
+         if self.config.freeze_strategy == "never":
+             return False
+
+         # Already frozen: nothing to do
+         if self.scales_frozen[scale]:
+             return False
+
+         if self.config.freeze_strategy == "performance":
+             return current_acc >= self.config.freeze_threshold
+
+         return False
+
+     def should_unfreeze_scales(self) -> bool:
+         """Check if scales should be unfrozen due to plateau."""
+         if not self.config.unfreeze_on_plateau:
+             return False
+         # Plateau length is hard-coded to 5 epochs; config.patience is not consulted here
+         return self.plateau_counter >= 5
+
+     def apply_adaptive_strategies(self, scale_accuracies: Dict[int, float], epoch: int):
+         """Apply freeze/unfreeze based on performance."""
+         active_scales = self.model.get_active_scales()
+
+         # Don't freeze scales if it would leave no trainable parameters
+         for scale, acc in scale_accuracies.items():
+             if self.should_freeze_scale(scale, acc):
+                 # Count how many active scales would remain unfrozen
+                 active_unfrozen = [s for s in active_scales if not self.scales_frozen.get(s, False)]
+
+                 if len(active_unfrozen) <= 1:
+                     print(f"[⚠️] Skipping freeze of scale {scale} (would leave no active trainable scales)")
+                     continue
+
+                 self.model.freeze_scale(scale)
+                 self.scales_frozen[scale] = True
+                 print(f"[❄️] Froze scale {scale} (acc={acc:.2f}%)")
+
+         if self.should_unfreeze_scales() and any(self.scales_frozen.values()):
+             for scale in self.model.scales:
+                 if self.scales_frozen[scale]:
+                     self.model.unfreeze_scale(scale)
+                     self.scales_frozen[scale] = False
+             self.plateau_counter = 0
+             print("[🔥] Unfroze all scales due to plateau")
+
+
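Note on the plateau logic above: `update_metrics` resets `plateau_counter` only on a strict improvement, so an epoch that merely ties the best accuracy still counts toward the plateau, and `should_unfreeze_scales` fires once the counter reaches 5. The sketch below isolates that bookkeeping in plain Python so it can be exercised without a model. `PlateauTracker` and the accuracy values are hypothetical and exist only for illustration.

```python
class PlateauTracker:
    """Hypothetical stand-in for the controller's plateau bookkeeping."""

    def __init__(self, unfreeze_after: int = 5):
        self.unfreeze_after = unfreeze_after
        self.best = 0.0
        self.counter = 0

    def update(self, accuracy: float) -> bool:
        """Record one epoch; return True once the plateau is long enough."""
        if accuracy > self.best:   # strict improvement resets the counter
            self.best = accuracy
            self.counter = 0
        else:                      # ties and regressions both count
            self.counter += 1
        return self.counter >= self.unfreeze_after


tracker = PlateauTracker()
# Made-up validation accuracies: the best value first appears at index 1.
accuracies = [70.0, 72.0, 71.0, 71.5, 71.9, 72.0, 71.0]
fired = [tracker.update(acc) for acc in accuracies]
first_trigger = fired.index(True)
```

With these values the trigger arrives at index 6, five non-improving epochs after the best one. The tie at index 5 does not reset the counter.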
+ # ============================================================================
+ # OPTIMIZER & SCHEDULER CREATION
+ # ============================================================================
+
+ def create_optimizer(david: David, config: DavidTrainingConfig) -> torch.optim.Optimizer:
+     """Create optimizer with parameter groups."""
+
+     param_groups = []
+
+     # Shared parameters (if present)
+     if hasattr(david, 'shared_extractor'):
+         param_groups.append({
+             'params': david.shared_extractor.parameters(),
+             'lr': config.learning_rate,
+             'name': 'shared'
+         })
+     elif hasattr(david, 'shared_base'):
+         param_groups.append({
+             'params': david.shared_base.parameters(),
+             'lr': config.learning_rate,
+             'name': 'shared'
+         })
+
+     # Scale-specific parameters
+     for scale in david.scales:
+         scale_params = []
+         if david.sharing_mode == SharingMode.HIERARCHICAL:
+             # Compare against None explicitly: container modules define __len__,
+             # so an empty one would be falsy in a plain truthiness test
+             head = getattr(david, f'head_{scale}', None)
+             if head is not None:
+                 scale_params.extend(head.parameters())
+             refine = getattr(david, f'refine_{scale}', None)
+             if refine is not None:
+                 scale_params.extend(refine.parameters())
+         else:
+             scale_params.extend(david.heads[str(scale)].parameters())
+
+         if scale_params:
+             param_groups.append({
+                 'params': scale_params,
+                 'lr': config.learning_rate,
+                 'name': f'scale_{scale}'
+             })
+
+     # Fusion parameters (trained at half the base learning rate)
+     if hasattr(david, 'fusion'):
+         param_groups.append({
+             'params': david.fusion.parameters(),
+             'lr': config.learning_rate * 0.5,
+             'name': 'fusion'
+         })
+     elif hasattr(david, 'fusion_weights'):
+         param_groups.append({
+             'params': [david.fusion_weights],
+             'lr': config.learning_rate * 0.5,
+             'name': 'fusion'
+         })
+
+     return torch.optim.AdamW(param_groups, weight_decay=config.weight_decay)
+
+
+ def create_scheduler(optimizer: torch.optim.Optimizer,
+                      config: DavidTrainingConfig) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
+     """Create learning rate scheduler, or return None for an unknown scheduler_type."""
+
+     if config.scheduler_type == "cosine_restarts":
+         return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+             optimizer, T_0=10, T_mult=2, eta_min=config.min_lr
+         )
+     elif config.scheduler_type == "cosine":
+         return torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer, T_max=config.num_epochs, eta_min=config.min_lr
+         )
+     else:
+         return None
+
+
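Note on the `cosine_restarts` branch above: with `T_0=10` and `T_mult=2` the cycle lengths are 10, 20, 40, and so on, so the learning rate jumps back to its base value at epochs 10 and 30 of the default 50-epoch run. The sketch below recomputes that schedule in plain Python from the cosine-annealing formula, assuming the scheduler is stepped once per epoch (the training loop is not part of this excerpt, so that is an assumption). `warm_restart_lrs` is a hypothetical helper, not a PyTorch API.

```python
import math
from typing import List


def warm_restart_lrs(base_lr: float, min_lr: float, num_epochs: int,
                     t_0: int = 10, t_mult: int = 2) -> List[float]:
    """Per-epoch LR for cosine annealing with warm restarts (illustrative only)."""
    lrs = []
    t_cur, t_i = 0, t_0  # position inside the current cycle, current cycle length
    for _ in range(num_epochs):
        cosine = (1 + math.cos(math.pi * t_cur / t_i)) / 2
        lrs.append(min_lr + (base_lr - min_lr) * cosine)
        t_cur += 1
        if t_cur >= t_i:  # cycle finished: restart and lengthen the next one
            t_cur, t_i = 0, t_i * t_mult
    return lrs


# Defaults taken from DavidTrainingConfig: learning_rate=5e-3, min_lr=1e-6, num_epochs=50
lrs = warm_restart_lrs(base_lr=5e-3, min_lr=1e-6, num_epochs=50)
restart_epochs = [e for e in range(1, len(lrs)) if lrs[e] > lrs[e - 1]]
print(restart_epochs)
print(f"final-epoch lr: {lrs[-1]:.2e}")
```

One consequence worth checking against the intended schedule: the third cycle is 40 epochs long and starts at epoch 30, so a 50-epoch run ends about halfway through it, with the learning rate at roughly half the base value rather than near `min_lr`.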
+ # ============================================================================
+ # GRADIENT ANALYSIS
+ # ============================================================================
+
+ def analyze_gradients(model: David, config: DavidTrainingConfig) -> Dict[str, float]:
+     """Analyze gradient magnitudes for debugging."""
+     grad_stats = {
+         'mean': 0.0,
+         'max': 0.0,
+         'min': float('inf'),
+         'num_zero': 0,
+         'num_small': 0,
+         'total': 0
+     }
+
+     for name, param in model.named_parameters():
+         if param.grad is not None:
+             grad_norm = param.grad.norm().item()
+             grad_stats['mean'] += grad_norm
+             grad_stats['max'] = max(grad_stats['max'], grad_norm)
+             grad_stats['min'] = min(grad_stats['min'], grad_norm)
+             grad_stats['total'] += 1
+
+             if grad_norm < 1e-10:
+                 grad_stats['num_zero'] += 1
+             elif grad_norm < config.gradient_scale_threshold:
+                 grad_stats['num_small'] += 1
+
+     if grad_stats['total'] > 0:
+         grad_stats['mean'] /= grad_stats['total']
+     else:
+         # No gradients were seen: report 0.0 rather than the float('inf') sentinel
+         grad_stats['min'] = 0.0
+
+     return grad_stats
+
+
+ def scale_small_gradients(model: David, config: DavidTrainingConfig):
+     """Scale up very small gradients to prevent vanishing."""
+     if not config.track_gradients:
+         return
+
+     for param in model.parameters():
+         if param.grad is not None:
+             grad_norm = param.grad.norm()
+             # Exactly-zero gradients are left untouched
+             if 0 < grad_norm < config.gradient_scale_threshold:
+                 param.grad.mul_(config.gradient_scale_multiplier)
+
+
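Note on the bucketing in `analyze_gradients` above: each parameter's gradient norm lands in exactly one bucket. Norms below `1e-10` count as zero, norms from there up to `gradient_scale_threshold` count as small, and everything else is only reflected in mean, max and min. The sketch below applies the same thresholds to a plain list of floats so the bucketing can be checked without PyTorch. `summarize_grad_norms` and the sample norms are hypothetical.

```python
from typing import Dict, Iterable


def summarize_grad_norms(norms: Iterable[float],
                         small_threshold: float = 1e-5,
                         zero_threshold: float = 1e-10) -> Dict[str, float]:
    """Bucket gradient norms using the same thresholds as the trainer (illustrative only)."""
    values = list(norms)
    if not values:
        # Nothing to summarize: return zeros instead of an inf minimum
        return {'mean': 0.0, 'max': 0.0, 'min': 0.0,
                'num_zero': 0, 'num_small': 0, 'total': 0}
    return {
        'mean': sum(values) / len(values),
        'max': max(values),
        'min': min(values),
        'num_zero': sum(1 for v in values if v < zero_threshold),
        'num_small': sum(1 for v in values if zero_threshold <= v < small_threshold),
        'total': len(values),
    }


# Made-up norms: one dead parameter, one tiny gradient, two healthy ones.
stats = summarize_grad_norms([0.0, 3e-6, 2e-2, 1.5])
```

In this sample only the `3e-6` entry falls in the small bucket, which is also the only one `scale_small_gradients` would multiply by `gradient_scale_multiplier`.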
+ # ============================================================================
+ # HUGGINGFACE HUB UTILITIES
+ # ============================================================================
+
+ def generate_model_readme(
+     config: DavidTrainingConfig,
+     david_config: DavidArchitectureConfig,
+     best_metrics: Dict,
+     run_id: str
+ ) -> str:
+     """Generate README.md for model card."""
+
+     readme = f"""---
+ language: en
+ license: mit
+ tags:
+ - image-classification
+ - imagenet
+ - multi-scale
+ - feature-geometry
+ - david
+ datasets:
+ - imagenet-1k
+ metrics:
+ - accuracy
+ model-index:
+ - name: David-{david_config.sharing_mode}-{david_config.fusion_mode}
+   results:
+   - task:
+       type: image-classification
+     dataset:
+       name: ImageNet-1K
+       type: imagenet-1k
+     metrics:
+     - type: accuracy
+       value: {best_metrics.get('best_val_acc', 0.0):.2f}
+ ---
+
+ # David: Multi-Scale Feature Classifier
+
+ **David** is a multi-scale deep learning classifier that uses feature geometry (pentachora/4-simplexes)
+ as class prototypes with role-weighted similarity computation (Rose Loss).
+
+ This version feeds multiple CLIP ViT variants simultaneously into one shared space.
+ The experiment tests whether divergent variants such as clip-vit-b-patch32 and clip-vit-b-patch16 can
+ coexist in the same shared space when the correct checks and spacings are applied.
+
+ ## Model Details
+
+ ### Architecture
+ - **Preset**: {config.preset}
+ - **Sharing Mode**: {david_config.sharing_mode}
+ - **Fusion Mode**: {david_config.fusion_mode}
+ - **Scales**: {david_config.scales}
+ - **Feature Dim**: {david_config.feature_dim}
+ - **Parameters**: {best_metrics.get('parameters', 0):,}
+
+ ### Training Configuration
+ - **Dataset**: {config.dataset_name}
+ - **Model Variant**: {config.model_variant}
+ - **Epochs**: {config.num_epochs}
+ - **Batch Size**: {config.batch_size}
+ - **Learning Rate**: {config.learning_rate}
+ - **Rose Loss Weight**: {config.rose_initial_weight} → {config.rose_max_weight}
+ - **Cayley Loss**: {config.use_cayley_loss}
+
+ ## Performance
+
+ ### Best Results
+ - **Validation Accuracy**: {best_metrics.get('best_val_acc', 0.0):.2f}%
+ - **Best Epoch**: {best_metrics.get('best_epoch', 0)}
+ - **Final Train Accuracy**: {best_metrics.get('final_train_acc', 0.0):.2f}%
+
+ ### Per-Scale Performance
+ """
+
+     if 'scale_accuracies' in best_metrics:
+         for scale, acc in best_metrics['scale_accuracies'].items():
+             readme += f"- **Scale {scale}**: {acc:.2f}%\n"
+
+     readme += f"""
+
+ ## Usage
+
+ ### Quick Model Lookup
+
+ **Check `MODELS_INDEX.json` in the repo root** - it lists all trained models sorted by accuracy with links to weights and configs.
+
+ ### Repository Structure
+
+ ```
+ {config.hf_repo if config.hf_repo else 'AbstractPhil/david'}/
+ ├── MODELS_INDEX.json                  # 📊 Master index of all models (sorted by accuracy)
+ ├── README.md                          # This file
+ ├── best_model.json                    # Latest best model info
+ ├── weights/
+ │   └── {david_config.name}/
+ │       └── {run_id}/
+ │           ├── MODEL_SUMMARY.txt      # 🎯 Human-readable performance summary
+ │           ├── training_history.json  # 📈 Epoch-by-epoch training curve
+ │           ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors  # ⭐ Accuracy in filename!
+ │           ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}_metadata.json
+ │           ├── final_model.safetensors
+ │           ├── checkpoint_epoch_X_accYY.YY.safetensors
+ │           ├── david_config.json
+ │           └── train_config.json
+ └── runs/
+     └── {david_config.name}/
+         └── {run_id}/
+             └── events.out.tfevents.*  # TensorBoard logs
+ ```
+
+ ### Loading the Model
+
+ ```python
+ from geovocab2.train.config.david_config import DavidArchitectureConfig
+ from geovocab2.train.model.core.david import David
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ # Browse available models in MODELS_INDEX.json first!
+
+ # Specify model variant and run
+ model_name = "{david_config.name}"
+ run_id = "{run_id}"
+ accuracy = "{best_metrics.get('best_val_acc', 0.0):.2f}"  # From MODELS_INDEX.json
+
+ # Download config
+ config_path = hf_hub_download(
+     repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
+     filename=f"weights/{{model_name}}/{{run_id}}/david_config.json"
+ )
+ config = DavidArchitectureConfig.from_json(config_path)
+
+ # Download weights (accuracy in filename!)
+ weights_path = hf_hub_download(
+     repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
+     filename=f"weights/{{model_name}}/{{run_id}}/best_model_acc{{accuracy}}.safetensors"
+ )
+
+ # Download training history (optional - see full training curve)
+ history_path = hf_hub_download(
+     repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
+     filename=f"weights/{{model_name}}/{{run_id}}/training_history.json"
+ )
+
+ # Load model
+ david = David.from_config(config)
+ david.load_state_dict(load_file(weights_path))
+ david.eval()
+ ```
+
+ ### Inference
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Assuming you have CLIP features (512-dim for ViT-B/16)
+ features = get_clip_features(image)  # [1, 512]
+
+ # Load anchors
+ anchors_dict = torch.load("anchors.pth")
+
+ # Forward pass
+ with torch.no_grad():
+     logits, _ = david(features, anchors_dict)
+     predictions = logits.argmax(dim=-1)
+ ```
+
+ ## Architecture Overview
+
+ ### Multi-Scale Processing
+ David processes inputs at multiple scales ({', '.join(map(str, david_config.scales))}),
+ allowing it to capture both coarse and fine-grained features.
+
+ ### Shared Representation Space
+ This variant places multiple CLIP ViT models in the same representation space.
+
+ ### Feature Geometry
+ Each class is represented by a pentachoron (4-simplex) in embedding space with 5 vertices:
+ - **Anchor**: Primary class representative
+ - **Need**: Complementary direction
+ - **Relation**: Contextual alignment
+ - **Purpose**: Functional direction
+ - **Observer**: Meta-perspective
+
+ ### Rose Loss
+ Similarity computation uses role-weighted cosine similarities:
+ ```
+ score = w_anchor * sim(z, anchor) + w_need * sim(z, need) + ...
+ ```
+
+ ### Fusion Strategy
+ **{david_config.fusion_mode}**: Combines the per-scale predictions into a single output.
+
+ ## Training Details
+
+ ### Loss Components
+ - **Cross-Entropy**: Standard classification loss
+ - **Rose Loss**: Pentachora role-weighted margin loss (weight: {config.rose_initial_weight}→{config.rose_max_weight})
+ - **Cayley Loss**: Geometric regularization ({'enabled' if config.use_cayley_loss else 'disabled'})
+
+ ### Optimization
+ - **Optimizer**: AdamW
+ - **Weight Decay**: {config.weight_decay}
+ - **Scheduler**: {config.scheduler_type}
+ - **Gradient Clip**: {config.gradient_clip}
+ - **Mixed Precision**: {config.use_mixed_precision}
+
+ ## Citation
+
+ ```bibtex
+ @software{{david_classifier_2025,
+     title = {{David: Multi-Scale Feature Classifier}},
+     author = {{AbstractPhil}},
+     year = {{2025}},
+     url = {{https://huggingface.co/{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}}},
+     note = {{Run ID: {run_id}}}
+ }}
+ ```
+
+ ## License
+
+ MIT License
+
+ ## Acknowledgments
+
+ Built with feature lattice geometry and multi-scale deep learning.
+ Special thanks to Claude (Anthropic) for debugging assistance.
+
+ ---
+
+ *Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
+ """
+
+     return readme
+
+
+ def save_best_model_json(
+     filepath: str,
+     metrics: Dict,
+     config: DavidTrainingConfig,
+     david_config: DavidArchitectureConfig
+ ):
+     """Save best_model.json with comprehensive metrics."""
+
+     model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}"
+
+     best_model_info = {
+         "model_name": model_name,
+         "run_id": config.run_id,
+         "timestamp": datetime.now().isoformat(),
+
+         # Best metrics
+         "best_val_acc": metrics.get('best_val_acc', 0.0),
+         "best_epoch": metrics.get('best_epoch', 0),
+         "final_train_acc": metrics.get('final_train_acc', 0.0),
+         "final_train_loss": metrics.get('final_train_loss', 0.0),
+
+         # Per-scale performance
+         "scale_accuracies": metrics.get('scale_accuracies', {}),
+
+         # Architecture
+         "architecture": {
+             "preset": config.preset,
+             "sharing_mode": david_config.sharing_mode,
+             "fusion_mode": david_config.fusion_mode,
+             "scales": david_config.scales,
+             "feature_dim": david_config.feature_dim,
+             "num_classes": david_config.num_classes,
+             "use_belly": david_config.use_belly,
+             "belly_expand": david_config.belly_expand,
+         },
+
+         # Training config
+         "training": {
+             "dataset": config.dataset_name,
+             "model_variant": config.model_variant,
+             "num_epochs": config.num_epochs,
+             "batch_size": config.batch_size,
+             "learning_rate": config.learning_rate,
+             "rose_weight": f"{config.rose_initial_weight}→{config.rose_max_weight}",
+             "cayley_loss": config.use_cayley_loss,
+             "optimizer": "AdamW",
+             "scheduler": config.scheduler_type,
+         },
+
+         # Files (organized by model/run)
+         "files": {
+             "weights_safetensors": f"weights/{model_name}/{config.run_id}/best_model_acc{metrics.get('best_val_acc', 0.0):.2f}.safetensors",
+             "weights_pytorch": f"weights/{model_name}/{config.run_id}/best_model.pth",
+             "config": f"weights/{model_name}/{config.run_id}/david_config.json",
+             "training_config": f"weights/{model_name}/{config.run_id}/train_config.json",
+             "tensorboard": f"runs/{model_name}/{config.run_id}/"
+         }
+     }
+
+     with open(filepath, 'w') as f:
+         json.dump(best_model_info, f, indent=2)
+
+     print(f"[📄] Saved best_model.json: {filepath}")
+
+
+ def create_model_summary(
+     weights_dir: str,
+     config: DavidTrainingConfig,
+     david_config: DavidArchitectureConfig,
+     best_metrics: Dict,
+     model_name: str
+ ):
+     """Create prominent model summary with accuracy front and center."""
+
+     summary_path = os.path.join(weights_dir, 'MODEL_SUMMARY.txt')
+
+     best_acc = best_metrics.get('best_val_acc', 0.0)
+     training_history = best_metrics.get('training_history', {})
+
+     summary = f"""
+ ╔══════════════════════════════════════════════════════════════╗
+ ║                     DAVID MODEL SUMMARY                      ║
+ ╠══════════════════════════════════════════════════════════════╣
+ ║                                                              ║
+ ║  🎯 VALIDATION ACCURACY: {best_acc:.2f}%                              ║
+ ║                                                              ║
+ ╚══════════════════════════════════════════════════════════════╝
+
+ MODEL:      {model_name}
+ RUN ID:     {config.run_id}
+ BEST EPOCH: {best_metrics.get('best_epoch', 0) + 1}/{config.num_epochs}
+
+ ═══════════════════════════════════════════════════════════════
+
+ 📊 PERFORMANCE BREAKDOWN
+
+ Final Training Accuracy:  {best_metrics.get('final_train_acc', 0.0):.2f}%
+ Best Validation Accuracy: {best_acc:.2f}%
+
+ Per-Scale Accuracies:
+ """
+
+     scale_accs = best_metrics.get('scale_accuracies', {})
+     for scale in sorted(scale_accs.keys()):
+         acc = scale_accs[scale]
+         summary += f"  • Scale {scale:4d}: {acc:.2f}%\n"
+
+     summary += f"""
+ ═══════════════════════════════════════════════════════════════
+
+ 🏗️ ARCHITECTURE
+
+ Preset:       {config.preset}
+ Sharing Mode: {david_config.sharing_mode}
+ Fusion Mode:  {david_config.fusion_mode}
+ Scales:       {len(david_config.scales)} scales - {david_config.scales}
+ Feature Dim:  {david_config.feature_dim}
+ Parameters:   {best_metrics.get('parameters', 0):,}
+
+ ═══════════════════════════════════════════════════════════════
+
+ 📈 TRAINING CURVE
+
+ """
+
+     if training_history and 'val_acc' in training_history:
+         summary += "Epoch | Train Acc | Val Acc | Learning Rate\n"
+         summary += "------|-----------|----------|--------------\n"
+
+         for i, epoch in enumerate(training_history.get('epochs', [])):
+             train_acc = training_history['train_acc'][i] if i < len(training_history['train_acc']) else 0
+             val_acc = training_history['val_acc'][i] if i < len(training_history['val_acc']) else 0
+             lr = training_history['lr'][i] if i < len(training_history['lr']) else 0
+
+             marker = " 👑" if val_acc == best_acc else ""
+             summary += f"{epoch:5d} | {train_acc:8.2f}% | {val_acc:7.2f}%{marker} | {lr:.2e}\n"
+
+     summary += f"""
+ ═══════════════════════════════════════════════════════════════
+
+ 📁 FILES
+
+ Best Model:   best_model_acc{best_acc:.2f}.safetensors
+ Config:       david_config.json
+ Training Cfg: train_config.json
+ History:      training_history.json
+
+ ═══════════════════════════════════════════════════════════════
+
+ Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
+ """
+
+     with open(summary_path, 'w') as f:
+         f.write(summary)
+
+     print("[📄] Created MODEL_SUMMARY.txt")
+     return summary_path
+
+
788
+ def update_models_index(
789
+ config: DavidTrainingConfig,
790
+ david_config: DavidArchitectureConfig,
791
+ best_metrics: Dict,
792
+ model_name: str
793
+ ):
794
+ """Update master models index file tracking all trained models."""
795
+
796
+ if not config.upload_to_hub or not config.hf_repo:
797
+ return
798
+
799
+ try:
800
+ from huggingface_hub import hf_hub_download
801
+ api = HfApi()
802
+
803
+ # Try to download existing index
804
+ try:
805
+ index_path = hf_hub_download(
806
+ repo_id=config.hf_repo,
807
+ filename="MODELS_INDEX.json",
808
+ repo_type="model"
809
+ )
810
+ with open(index_path, 'r') as f:
811
+ models_index = json.load(f)
812
+ except Exception:
813
+ # Create a new index if one doesn't exist yet (or could not be downloaded)
814
+ models_index = {
815
+ "repository": config.hf_repo,
816
+ "updated": datetime.now().isoformat(),
817
+ "models": []
818
+ }
819
+
820
+ # Add current model entry
821
+ model_entry = {
822
+ "model_name": model_name,
823
+ "run_id": config.run_id,
824
+ "timestamp": datetime.now().isoformat(),
825
+ "best_val_acc": best_metrics.get('best_val_acc', 0.0),
826
+ "best_epoch": best_metrics.get('best_epoch', 0),
827
+ "num_scales": len(david_config.scales),
828
+ "scales": david_config.scales,
829
+ "parameters": best_metrics.get('parameters', 0),
830
+ "sharing_mode": david_config.sharing_mode,
831
+ "fusion_mode": david_config.fusion_mode,
832
+ "preset": config.preset,
833
+ "weights_path": f"weights/{model_name}/{config.run_id}/best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors",
834
+ "config_path": f"weights/{model_name}/{config.run_id}/david_config.json",
835
+ "history_path": f"weights/{model_name}/{config.run_id}/training_history.json"
836
+ }
837
+
838
+ # Remove old entry for same run_id if exists (update)
839
+ models_index["models"] = [m for m in models_index["models"] if m.get("run_id") != config.run_id]
840
+ models_index["models"].append(model_entry)
841
+
842
+ # Sort by accuracy (descending)
843
+ models_index["models"].sort(key=lambda x: x.get("best_val_acc", 0), reverse=True)
844
+ models_index["updated"] = datetime.now().isoformat()
845
+ models_index["total_models"] = len(models_index["models"])
846
+
847
+ # Save locally
848
+ with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
849
+ json.dump(models_index, f, indent=2)
850
+ temp_path = f.name
851
+
852
+ # Upload to hub root
853
+ api.upload_file(
854
+ path_or_fileobj=temp_path,
855
+ path_in_repo="MODELS_INDEX.json",
856
+ repo_id=config.hf_repo,
857
+ commit_message=f"Update models index - {model_name} @ {best_metrics.get('best_val_acc', 0.0):.2f}%"
858
+ )
859
+
860
+ os.unlink(temp_path)
861
+ print(f"[📊] Updated MODELS_INDEX.json - {len(models_index['models'])} models tracked")
862
+
863
+ except Exception as e:
864
+ print(f"[⚠️] Failed to update models index: {e}")
865
+
866
+
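The index update in `update_models_index` is an upsert keyed on `run_id`, followed by a descending sort on accuracy. A minimal sketch of just that step, with no Hub access; the helper name `upsert_model_entry` is hypothetical and the entries carry only the two fields needed here:

```python
def upsert_model_entry(models_index, entry):
    """Replace any entry with the same run_id, then re-sort by accuracy."""
    models = [
        m for m in models_index.get("models", [])
        if m.get("run_id") != entry["run_id"]
    ]
    models.append(entry)
    # Highest validation accuracy first
    models.sort(key=lambda m: m.get("best_val_acc", 0), reverse=True)
    models_index["models"] = models
    models_index["total_models"] = len(models)
    return models_index


index = {"models": [{"run_id": "a", "best_val_acc": 70.0},
                    {"run_id": "b", "best_val_acc": 80.0}]}
# Re-running "a" with a better score replaces the old entry instead of duplicating it
upsert_model_entry(index, {"run_id": "a", "best_val_acc": 85.5})
```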
867
+ def upload_to_huggingface(
868
+ local_dir: str,
869
+ repo_id: str,
870
+ commit_message: str,
871
+ path_in_repo: Optional[str] = None,
872
+ patterns: Optional[List[str]] = None
873
+ ):
874
+ """Upload directory to HuggingFace Hub."""
875
+
876
+ try:
877
+ api = HfApi()
878
+
879
+ # Create repo if it doesn't exist
880
+ try:
881
+ create_repo(repo_id, exist_ok=True, repo_type="model")
882
+ print(f"[🤗] Repo ready: {repo_id}")
883
+ except Exception as e:
884
+ print(f"[⚠️] Repo creation failed: {e}")
885
+
886
+ # Upload folder
887
+ if patterns:
888
+ # Upload specific patterns
889
+ for pattern in patterns:
890
+ matching_files = [p for p in Path(local_dir).rglob(pattern) if p.is_file()]
891
+ for file_path in matching_files:
892
+ rel_path = file_path.relative_to(local_dir)
893
+ if path_in_repo:
894
+ repo_path = f"{path_in_repo}/{rel_path.as_posix()}"
895
+ else:
896
+ repo_path = rel_path.as_posix()
897
+
898
+ api.upload_file(
899
+ path_or_fileobj=str(file_path),
900
+ path_in_repo=repo_path,
901
+ repo_id=repo_id,
902
+ commit_message=commit_message
903
+ )
904
+ else:
905
+ # Upload entire folder
906
+ api.upload_folder(
907
+ folder_path=local_dir,
908
+ repo_id=repo_id,
909
+ path_in_repo=path_in_repo,
910
+ commit_message=commit_message
911
+ )
912
+
913
+ print(f"[✅] Uploaded to Hub: https://huggingface.co/{repo_id}")
914
+
915
+ except Exception as e:
916
+ print(f"[❌] Hub upload failed: {e}")
917
+ print(f" Continuing training (files saved locally)")
918
+
919
+
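The pattern branch of `upload_to_huggingface` walks `local_dir` recursively and maps each match to a path inside the repo. A minimal sketch of that mapping without any network calls, assuming forward-slash repo paths and that directories matched by the pattern should be skipped; `plan_uploads` is a hypothetical helper, not part of `huggingface_hub`:

```python
import tempfile
from pathlib import Path


def plan_uploads(local_dir, patterns, path_in_repo=None):
    """Return (local_file, repo_path) pairs for files matching any pattern."""
    root = Path(local_dir)
    plan = []
    for pattern in patterns:
        for file_path in sorted(root.rglob(pattern)):
            if not file_path.is_file():
                continue  # rglob can also yield directories
            # as_posix() keeps repo paths slash-separated on every OS
            rel = file_path.relative_to(root).as_posix()
            repo_path = f"{path_in_repo}/{rel}" if path_in_repo else rel
            plan.append((str(file_path), repo_path))
    return plan


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "sub").mkdir()
    (Path(tmp) / "a.json").write_text("{}")
    (Path(tmp) / "sub" / "b.json").write_text("{}")
    (Path(tmp) / "notes.txt").write_text("skip me")
    repo_paths = [repo for _, repo in plan_uploads(tmp, ["*.json"], "weights/run1")]
```

Building the plan first also makes it easy to print or inspect what would be uploaded before calling the Hub API.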
920
+ def prepare_hub_upload(
921
+ weights_dir: str,
922
+ runs_dir: str,
923
+ config: DavidTrainingConfig,
924
+ david_config: DavidArchitectureConfig,
925
+ best_metrics: Dict,
926
+ model_name: str
927
+ ):
928
+ """Prepare and upload all artifacts to HuggingFace Hub."""
929
+
930
+ if not config.upload_to_hub or not config.hf_repo:
931
+ return
932
+
933
+ print("\n[🤗] Preparing HuggingFace Hub upload...")
934
+
935
+ # Create model summary file
936
+ summary_path = create_model_summary(weights_dir, config, david_config, best_metrics, model_name)
937
+
938
+ # Update master models index
939
+ update_models_index(config, david_config, best_metrics, model_name)
940
+
941
+ api = HfApi()
942
+ try:
943
+ create_repo(config.hf_repo, exist_ok=True, repo_type="model")
944
+ except Exception:
945
+ pass
946
+
947
+ # Create temporary directory for root files
948
+ with tempfile.TemporaryDirectory() as temp_dir:
949
+ # Generate README at root
950
+ readme_path = os.path.join(temp_dir, "README.md")
951
+ readme_content = generate_model_readme(config, david_config, best_metrics, config.run_id)
952
+ with open(readme_path, 'w', encoding='utf-8') as f:
953
+ f.write(readme_content)
954
+ print(f"[📝] Generated README.md")
955
+
956
+ # Save best_model.json at root
957
+ best_json_path = os.path.join(temp_dir, "best_model.json")
958
+ save_best_model_json(best_json_path, best_metrics, config, david_config)
959
+
960
+ # Upload root files (README.md, best_model.json)
961
+ print(f"[📤] Uploading root files...")
962
+
963
+ api.upload_file(
964
+ path_or_fileobj=readme_path,
965
+ path_in_repo="README.md",
966
+ repo_id=config.hf_repo,
967
+ commit_message=f"Update README - Run {config.run_id}"
968
+ )
969
+
970
+ api.upload_file(
971
+ path_or_fileobj=best_json_path,
972
+ path_in_repo="best_model.json",
973
+ repo_id=config.hf_repo,
974
+ commit_message=f"Update metrics - Run {config.run_id}"
975
+ )
976
+
977
+ # Upload ONLY essential weight files (not entire directory!)
978
+ weights_repo_path = f"weights/{model_name}/{config.run_id}"
979
+ best_acc = best_metrics.get('best_val_acc', 0.0)
980
+
981
+ print(f"[📤] Uploading essential files to {weights_repo_path}...")
982
+
983
+ # List of specific files to upload (not entire directory)
984
+ files_to_upload = [
985
+ ('MODEL_SUMMARY.txt', 'MODEL_SUMMARY.txt'),
986
+ ('training_history.json', 'training_history.json'),
987
+ ('david_config.json', 'david_config.json'),
988
+ ('train_config.json', 'train_config.json'),
989
+ (f'best_model_acc{best_acc:.2f}.safetensors', f'best_model_acc{best_acc:.2f}.safetensors'),
990
+ (f'best_model_acc{best_acc:.2f}_metadata.json', f'best_model_acc{best_acc:.2f}_metadata.json'),
991
+ ]
992
+
993
+ for local_filename, repo_filename in files_to_upload:
994
+ local_path = os.path.join(weights_dir, local_filename)
995
+ if os.path.exists(local_path):
996
+ try:
997
+ api.upload_file(
998
+ path_or_fileobj=local_path,
999
+ path_in_repo=f"{weights_repo_path}/{repo_filename}",
1000
+ repo_id=config.hf_repo,
1001
+ commit_message=f"Update {repo_filename} - Run {config.run_id}"
1002
+ )
1003
+ except Exception as e:
1004
+ print(f"[⚠️] Failed to upload {repo_filename}: {e}")
1005
+
1006
+ print(f"[✅] Uploaded to Hub: https://huggingface.co/{config.hf_repo}")
1007
+
1008
+ # Upload tensorboard logs (only if they exist and it's final upload)
1009
+ # Skip TensorBoard during training to avoid huge uploads every epoch
1010
+ # if os.path.exists(runs_dir):
1011
+ # runs_repo_path = f"runs/{model_name}/{config.run_id}"
1012
+ # print(f"[📤] Uploading TensorBoard logs to {runs_repo_path}...")
1013
+ # upload_to_huggingface(
1014
+ # local_dir=runs_dir,
1015
+ # repo_id=config.hf_repo,
1016
+ # commit_message=f"Upload TensorBoard logs - {model_name} - Run {config.run_id}",
1017
+ # path_in_repo=runs_repo_path
1018
+ # )
1019
+
1020
+
1021
+ # ============================================================================
1022
+ # CHECKPOINT UTILITIES
1023
+ # ============================================================================
1024
+
1025
+ def save_checkpoint(
1026
+ filepath: str,
1027
+ david: David,
1028
+ optimizer: torch.optim.Optimizer,
1029
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
1030
+ epoch: int,
1031
+ metrics: Dict,
1032
+ train_config: DavidTrainingConfig
1033
+ ):
1034
+ """Save checkpoint in PyTorch and/or SafeTensors format."""
1035
+
1036
+ checkpoint = {
1037
+ 'epoch': epoch,
1038
+ 'model_state_dict': david.state_dict(),
1039
+ 'optimizer_state_dict': optimizer.state_dict(),
1040
+ 'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
1041
+ 'metrics': metrics,
1042
+ 'train_config': train_config.to_dict(),
1043
+ }
1044
+
1045
+ # Add accuracy to filename if available
1046
+ val_acc = metrics.get('best_val_acc') or metrics.get('val_acc')
1047
+ if val_acc is not None:
1048
+ acc_suffix = f"_acc{val_acc:.2f}"
1049
+ filepath = filepath + acc_suffix
1050
+
1051
+ if train_config.save_format in ['pytorch', 'both']:
1052
+ torch.save(checkpoint, filepath + '.pth')
1053
+ print(f"[💾] Saved PyTorch: {filepath}.pth")
1054
+
1055
+ if train_config.save_format in ['safetensors', 'both']:
1056
+ try:
1057
+ from safetensors.torch import save_file
1058
+
1059
+ # Save model state
1060
+ model_state = {k: v.contiguous() for k, v in david.state_dict().items()}
1061
+ save_file(model_state, filepath + '.safetensors')
1062
+
1063
+ # Save everything except the model weights as JSON metadata (non-JSON values are stringified)
1064
+ metadata = {k: v for k, v in checkpoint.items()
1065
+ if k not in ['model_state_dict']}
1066
+ with open(filepath + '_metadata.json', 'w') as f:
1067
+ json.dump(metadata, f, indent=2, default=str)
1068
+
1069
+ print(f"[💾] Saved SafeTensors: {filepath}.safetensors")
1070
+ except ImportError:
1071
+ print(f"[⚠️] SafeTensors not available, skipping")
1072
+
1073
+
1074
+ def load_checkpoint(
1075
+ checkpoint_path: str,
1076
+ david: David,
1077
+ optimizer: Optional[torch.optim.Optimizer] = None,
1078
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
1079
+ device: str = "cuda"
1080
+ ) -> Tuple[int, Dict]:
1081
+ """Load checkpoint and return epoch and metrics."""
1082
+
1083
+ if checkpoint_path.endswith('.safetensors'):
1084
+ # Load SafeTensors format
1085
+ try:
1086
+ from safetensors.torch import load_file
1087
+
1088
+ model_state = load_file(checkpoint_path, device=device)
1089
+ david.load_state_dict(model_state)
1090
+
1091
+ # Load metadata
1092
+ metadata_path = checkpoint_path.replace('.safetensors', '_metadata.json')
1093
+ with open(metadata_path, 'r') as f:
1094
+ metadata = json.load(f)
1095
+
1096
+ epoch = metadata.get('epoch', 0)
1097
+ metrics = metadata.get('metrics', {})
1098
+
1099
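+ # NOTE: metadata was written with json.dump(default=str), so tensor-valued
+ # optimizer state is stored as strings and may not restore correctly.
+ if optimizer and 'optimizer_state_dict' in metadata: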
1100
+ optimizer.load_state_dict(metadata['optimizer_state_dict'])
1101
+
1102
+ if scheduler and 'scheduler_state_dict' in metadata and metadata['scheduler_state_dict']:
1103
+ scheduler.load_state_dict(metadata['scheduler_state_dict'])
1104
+
1105
+ print(f"[✅] Loaded from SafeTensors: {checkpoint_path}")
1106
+ return epoch, metrics
1107
+
1108
+ except ImportError:
1109
+ raise ImportError("safetensors not installed")
1110
+
1111
+ else:
1112
+ # Load PyTorch format
1113
+ checkpoint = torch.load(checkpoint_path, map_location=device)
1114
+
1115
+ david.load_state_dict(checkpoint['model_state_dict'])
1116
+
1117
+ if optimizer and 'optimizer_state_dict' in checkpoint:
1118
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
1119
+
1120
+ if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
1121
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
1122
+
1123
+ print(f"[✅] Loaded from PyTorch: {checkpoint_path}")
1124
+ return checkpoint['epoch'], checkpoint.get('metrics', {})
1125
+
1126
+
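`save_checkpoint` and `load_checkpoint` agree on file names only by convention: an `_acc{val_acc:.2f}` suffix on a shared stem, with the metadata file found by swapping the `.safetensors` extension. A minimal sketch of that naming scheme, assuming the suffix is skipped only when no accuracy is given; both helper names are hypothetical:

```python
def checkpoint_paths(base_path, val_acc=None):
    """Build the file names that share one checkpoint stem."""
    stem = base_path if val_acc is None else f"{base_path}_acc{val_acc:.2f}"
    return {
        "pytorch": stem + ".pth",
        "safetensors": stem + ".safetensors",
        "metadata": stem + "_metadata.json",
    }


def metadata_path_for(safetensors_path):
    """Derive the metadata path the same way the loader does."""
    return safetensors_path.replace(".safetensors", "_metadata.json")


paths = checkpoint_paths("weights/best_model", 71.456)
```

Since the accuracy is rounded to two decimals in the name, callers that rebuild the path later need the same rounded value.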
1127
+ # ============================================================================
1128
+ # DATASET
1129
+ # ============================================================================
1130
+
1131
+ class ImageNetHFDataset(Dataset):
1132
+ """PyTorch Dataset wrapper for HuggingFace ImageNet features."""
1133
+
1134
+ def __init__(self, dataset_name: str, model_variant: str, split: str = "train"):
1135
+ # Load only the specific split to avoid downloading all data
1136
+ print(f"[📥] Loading {split} split for {model_variant}...")
1137
+ self.dataset = load_dataset(
1138
+ dataset_name,
1139
+ name=model_variant, # Dataset configuration/variant name
1140
+ split=split # Only load this specific split
1141
+ )
1142
+ self.length = len(self.dataset)
1143
+ print(f"[✅] Loaded {self.length:,} samples from {split} split")
1144
+
1145
+ def __len__(self):
1146
+ return self.length
1147
+
1148
+ def __getitem__(self, idx):
1149
+ item = self.dataset[idx]
1150
+ features = torch.tensor(item['clip_features'], dtype=torch.float32)
1151
+ label = torch.tensor(item['label'], dtype=torch.long)
1152
+ return features, label
1153
+
1154
+
1155
+ class MergedImageNetDataset(Dataset):
1156
+ """
1157
+ Merge multiple CLIP variants into a single dataset.
1158
+ Perfect for testing if David can unify different encoder spaces!
1159
+ """
1160
+
1161
+ def __init__(
1162
+ self,
1163
+ dataset_name: str,
1164
+ model_variants: List[str], # e.g., ['clip_vit_b16', 'clip_vit_laion_b16']
1165
+ split: str = "train",
1166
+ shuffle_seed: int = 42
1167
+ ):
1168
+ print(f"[🔀] Creating merged dataset from {len(model_variants)} variants...")
1169
+
1170
+ self.datasets = []
1171
+ self.cumulative_lengths = [0]
1172
+
1173
+ # Load each variant
1174
+ for variant in model_variants:
1175
+ print(f"[📥] Loading {split} split for {variant}...")
1176
+ ds = load_dataset(
1177
+ dataset_name,
1178
+ name=variant,
1179
+ split=split
1180
+ )
1181
+ self.datasets.append(ds)
1182
+ self.cumulative_lengths.append(self.cumulative_lengths[-1] + len(ds))
1183
+ print(f"[✅] Loaded {len(ds):,} samples from {variant}")
1184
+
1185
+ self.total_length = self.cumulative_lengths[-1]
1186
+
1187
+ # Create shuffled indices for fair mixing
1188
+ print(f"[🎲] Shuffling {self.total_length:,} samples (seed={shuffle_seed})...")
1189
+ rng = np.random.RandomState(shuffle_seed)
1190
+ self.shuffle_indices = rng.permutation(self.total_length)
1191
+
1192
+ print(f"[✅] Merged dataset ready: {self.total_length:,} samples from {len(model_variants)} encoders")
1193
+
1194
+ def __len__(self):
1195
+ return self.total_length
1196
+
1197
+ def __getitem__(self, idx):
1198
+ # Map shuffled index to original dataset
1199
+ actual_idx = int(self.shuffle_indices[idx])
1200
+
1201
+ # Find which dataset this index belongs to
1202
+ dataset_idx = 0
1203
+ for i, cumsum in enumerate(self.cumulative_lengths[1:]):
1204
+ if actual_idx < cumsum:
1205
+ dataset_idx = i
1206
+ break
1207
+
1208
+ # Get item from the correct dataset
1209
+ local_idx = actual_idx - self.cumulative_lengths[dataset_idx]
1210
+ item = self.datasets[dataset_idx][local_idx]
1211
+
1212
+ features = torch.tensor(item['clip_features'], dtype=torch.float32)
1213
+ label = torch.tensor(item['label'], dtype=torch.long)
1214
+
1215
+ return features, label
1216
+
1217
+
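`MergedImageNetDataset.__getitem__` finds the owning dataset with a linear scan over the cumulative lengths. A minimal sketch of the same lookup using binary search, assuming only the dataset lengths matter. The class name `MergedIndex` is hypothetical, and Python's `random` stands in for the NumPy permutation used above:

```python
import bisect
import random


class MergedIndex:
    """Map a shuffled global index to (dataset_idx, local_idx)."""

    def __init__(self, lengths, seed=42):
        self.cumulative = [0]
        for n in lengths:
            self.cumulative.append(self.cumulative[-1] + n)
        # Fixed-seed shuffle so the mixing is reproducible
        self.order = list(range(self.cumulative[-1]))
        random.Random(seed).shuffle(self.order)

    def locate(self, idx):
        actual = self.order[idx]
        # bisect_right returns the position of the first boundary > actual
        dataset_idx = bisect.bisect_right(self.cumulative, actual) - 1
        return dataset_idx, actual - self.cumulative[dataset_idx]


index = MergedIndex([3, 2])
located = sorted(index.locate(i) for i in range(5))
```

With only a handful of encoder variants the linear scan is perfectly adequate; the binary search mainly removes the loop from a per-sample code path.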
1218
+ def create_dataloaders(config: DavidTrainingConfig):
1219
+ """Create train and validation dataloaders."""
1220
+
1221
+ # Check if model_variant is a list (multi-encoder experiment)
1222
+ if isinstance(config.model_variant, list):
1223
+ print(f"[🧪] MULTI-ENCODER EXPERIMENT: Merging {len(config.model_variant)} variants")
1224
+ train_dataset = MergedImageNetDataset(
1225
+ config.dataset_name,
1226
+ config.model_variant, # List of variants
1227
+ "train"
1228
+ )
1229
+ val_dataset = MergedImageNetDataset(
1230
+ config.dataset_name,
1231
+ config.model_variant,
1232
+ "validation"
1233
+ )
1234
+ else:
1235
+ # Single encoder (normal mode)
1236
+ train_dataset = ImageNetHFDataset(
1237
+ config.dataset_name, config.model_variant, "train"
1238
+ )
1239
+ val_dataset = ImageNetHFDataset(
1240
+ config.dataset_name, config.model_variant, "validation"
1241
+ )
1242
+
1243
+ train_loader = DataLoader(
1244
+ train_dataset,
1245
+ batch_size=config.batch_size,
1246
+ shuffle=True,
1247
+ num_workers=config.num_workers,
1248
+ pin_memory=config.pin_memory,
1249
+ prefetch_factor=config.prefetch_factor,
1250
+ persistent_workers=config.persistent_workers
1251
+ )
1252
+
1253
+ val_loader = DataLoader(
1254
+ val_dataset,
1255
+ batch_size=config.batch_size * 2,
1256
+ shuffle=False,
1257
+ num_workers=config.num_workers,
1258
+ pin_memory=config.pin_memory,
1259
+ prefetch_factor=config.prefetch_factor,
1260
+ persistent_workers=config.persistent_workers
1261
+ )
1262
+
1263
+ return train_loader, val_loader
1264
+
1265
+
1266
+ # ============================================================================
1267
+ # CRYSTAL GENERATOR
1268
+ # ============================================================================
1269
+
1270
+ class CrystalGenerator:
1271
+ """Generate crystals for all scales."""
1272
+
1273
+ def __init__(self, num_classes: int, scales: List[int], device: str = "cuda"):
1274
+ self.num_classes = num_classes
1275
+ self.scales = scales
1276
+ self.device = device
1277
+ self.factories = {
1278
+ scale: SimplexFactory(k=4, embed_dim=scale, method="random")
1279
+ for scale in scales
1280
+ }
1281
+
1282
+ def generate(self, seed: int = 42) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]:
1283
+ """Generate anchors and crystals for all scales."""
1284
+
1285
+ anchors_dict = {}
1286
+ crystals_dict = {}
1287
+
1288
+ for scale in tqdm(self.scales, desc="Generating crystals"):
1289
+ factory = self.factories[scale]
1290
+ batch_crystals = []
1291
+
1292
+ for class_idx in range(self.num_classes):
1293
+ crystal = factory.build(
1294
+ backend="torch",
1295
+ device=self.device,
1296
+ dtype=torch.float32,
1297
+ seed=seed + class_idx,
1298
+ validate=True
1299
+ )
1300
+ batch_crystals.append(crystal)
1301
+
1302
+ crystals = torch.stack(batch_crystals)
1303
+ anchors = F.normalize(crystals[:, 0, :], dim=-1)
1304
+
1305
+ # Verify anchor diversity
1306
+ anchor_sims = anchors @ anchors.T
1307
+ off_diag = anchor_sims[~torch.eye(self.num_classes, dtype=bool, device=anchors.device)]
1308
+ max_sim = off_diag.max().item()
1309
+ mean_sim = off_diag.mean().item()
1310
+
1311
+ print(f" Scale {scale}: max_sim={max_sim:.4f}, mean_sim={mean_sim:.4f}")
1312
+
1313
+ if max_sim > 0.99:
1314
+ print(f" ⚠️ WARNING: Anchors too similar at scale {scale}!")
1315
+
1316
+ anchors_dict[scale] = anchors
1317
+ crystals_dict[scale] = crystals
1318
+
1319
+ return anchors_dict, crystals_dict
1320
+
1321
+
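The anchor diversity check in `CrystalGenerator.generate` normalizes each anchor and looks at the off-diagonal entries of the similarity matrix. A minimal pure-Python sketch of the same statistic, assuming small list-of-lists input instead of tensors; `anchor_similarity_stats` is a hypothetical helper:

```python
import math


def anchor_similarity_stats(anchors):
    """Return (max, mean) cosine similarity over distinct anchor pairs."""
    unit = []
    for vec in anchors:
        norm = math.sqrt(sum(x * x for x in vec))
        unit.append([x / norm for x in vec])
    sims = [
        sum(a * b for a, b in zip(unit[i], unit[j]))
        for i in range(len(unit))
        for j in range(len(unit))
        if i != j  # skip the diagonal (self-similarity is always 1)
    ]
    return max(sims), sum(sims) / len(sims)


# The first and third anchors point the same way, so max_sim reaches 1.0
max_sim, mean_sim = anchor_similarity_stats([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])
```

In this toy input the maximum exceeds the 0.99 threshold used above, which is the situation the warning is meant to flag.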
1322
+ # ============================================================================
1323
+ # TRAINING LOOP
1324
+ # ============================================================================
1325
+
1326
+ def train_epoch(
1327
+ david: David,
1328
+ train_loader: DataLoader,
1329
+ optimizer: torch.optim.Optimizer,
1330
+ criterion: MultiScaleCrystalLoss,
1331
+ anchors_dict: Dict[int, torch.Tensor],
1332
+ crystals_dict: Dict[int, torch.Tensor],
1333
+ epoch: int,
1334
+ config: DavidTrainingConfig,
1335
+ writer: Optional[SummaryWriter],
1336
+ global_step: int
1337
+ ) -> Tuple[float, float, int, Dict]:
1338
+ """Train for one epoch - Pure FP32."""
1339
+
1340
+ david.train()
1341
+ david.update_epoch(epoch)
1342
+
1343
+ total_loss = 0
1344
+ correct = 0
1345
+ total = 0
1346
+ loss_components_sum = {}
1347
+
1348
+ active_scales = david.get_active_scales()
1349
+
1350
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
1351
+
1352
+ for batch_idx, (features, labels) in enumerate(pbar):
1353
+ features = features.cuda(non_blocking=True)
1354
+ labels = labels.cuda(non_blocking=True)
1355
+
1356
+ # Zero gradients
1357
+ optimizer.zero_grad()
1358
+
1359
+ # Forward pass - Pure FP32, no autocast
1360
+ combined, logits_list, features_list, fusion_weights = david(
1361
+ features, anchors_dict, return_all_scales=True
1362
+ )
1363
+
1364
+ # Compute loss
1365
+ losses = criterion(
1366
+ combined, logits_list, features_list,
1367
+ labels, crystals_dict, epoch
1368
+ )
1369
+
1370
+ # Backward
1371
+ losses['total'].backward()
1372
+
1373
+ # Gradient analysis
1374
+ if config.track_gradients and batch_idx % config.log_interval == 0:
1375
+ grad_stats = analyze_gradients(david, config)
1376
+ if writer:
1377
+ step = global_step  # already incremented once per batch
1378
+ writer.add_scalar('train/grad_mean', grad_stats['mean'], step)
1379
+ writer.add_scalar('train/grad_max', grad_stats['max'], step)
1380
+ writer.add_scalar('train/grad_num_small', grad_stats['num_small'], step)
1381
+
1382
+ # Scale small gradients
1383
+ scale_small_gradients(david, config)
1384
+
1385
+ # Gradient clipping
1386
+ torch.nn.utils.clip_grad_norm_(david.parameters(), config.gradient_clip)
1387
+
1388
+ # Optimizer step
1389
+ optimizer.step()
1390
+
1391
+ # Metrics
1392
+ total_loss += losses['total'].item()
1393
+ _, predicted = torch.max(combined, 1)
1394
+ total += labels.size(0)
1395
+ correct += (predicted == labels).sum().item()
1396
+
1397
+ # Accumulate loss components
1398
+ for key, value in losses.items():
1399
+ if key not in loss_components_sum:
1400
+ loss_components_sum[key] = 0.0
1401
+ loss_components_sum[key] += value.item()
1402
+
1403
+ # Logging
1404
+ if writer and batch_idx % config.log_interval == 0:
1405
+ step = global_step  # already incremented once per batch
1406
+ writer.add_scalar('train/loss_batch', losses['total'].item(), step)
1407
+ writer.add_scalar('train/acc_batch', 100 * correct / total, step)
1408
+
1409
+ if config.log_loss_components:
1410
+ for key, value in losses.items():
1411
+ if key != 'total':
1412
+ writer.add_scalar(f'train/loss_{key}', value.item(), step)
1413
+
1414
+ if config.log_fusion_weights and fusion_weights is not None:
1415
+ if fusion_weights.dim() == 2:
1416
+ mean_weights = fusion_weights.mean(dim=0)
1417
+ for i, w in enumerate(mean_weights):
1418
+ if i < len(active_scales):
1419
+ writer.add_scalar(
1420
+ f'train/fusion_weight_{active_scales[i]}',
1421
+ w.item(), step
1422
+ )
1423
+
1424
+ writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
1425
+
1426
+ pbar.set_postfix({
1427
+ 'loss': f'{total_loss / (batch_idx + 1):.4f}',
1428
+ 'acc': f'{100 * correct / total:.2f}%'
1429
+ })
1430
+
1431
+ global_step += 1
1432
+
1433
+ # Average loss components
1434
+ avg_components = {k: v / len(train_loader) for k, v in loss_components_sum.items()}
1435
+
1436
+ return (
1437
+ total_loss / len(train_loader),
1438
+ 100 * correct / total,
1439
+ global_step,
1440
+ avg_components
1441
+ )
1442
+
1443
+
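`train_epoch` receives `global_step`, increments it once per batch, and returns it, so the counter alone identifies a batch across epochs. A minimal sketch of that bookkeeping with no model involved; `run_epoch` is a hypothetical stand-in that only records the step values at which logging would fire:

```python
def run_epoch(num_batches, global_step, log_interval):
    """Return the updated counter and the steps at which logging fired."""
    logged_steps = []
    for batch_idx in range(num_batches):
        if batch_idx % log_interval == 0:
            # Log at the running counter; adding batch_idx on top of it
            # would count the in-epoch offset twice.
            logged_steps.append(global_step)
        global_step += 1
    return global_step, logged_steps


step = 0
step, first = run_epoch(num_batches=5, global_step=step, log_interval=2)
step, second = run_epoch(num_batches=5, global_step=step, log_interval=2)
```

The logged steps from consecutive epochs are contiguous and evenly spaced, which is what a TensorBoard x-axis expects.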
1444
+ @torch.no_grad()
1445
+ def validate(
1446
+ david: David,
1447
+ val_loader: DataLoader,
1448
+ anchors_dict: Dict[int, torch.Tensor],
1449
+ config: DavidTrainingConfig
1450
+ ) -> Tuple[float, Dict[int, float]]:
1451
+ """Validate model - Pure FP32."""
1452
+
1453
+ david.eval()
1454
+
1455
+ correct = 0
1456
+ total = 0
1457
+ active_scales = david.get_active_scales()
1458
+ scale_correct = {scale: 0 for scale in active_scales}
1459
+
1460
+ for features, labels in tqdm(val_loader, desc="Validation", leave=False):
1461
+ features = features.cuda(non_blocking=True)
1462
+ labels = labels.cuda(non_blocking=True)
1463
+
1464
+ # Forward pass - no autocast
1465
+ combined, logits_list, _, _ = david(
1466
+ features, anchors_dict, return_all_scales=True
1467
+ )
1468
+
1469
+ _, predicted = torch.max(combined, 1)
1470
+ total += labels.size(0)
1471
+ correct += (predicted == labels).sum().item()
1472
+
1473
+ for i, scale in enumerate(active_scales):
1474
+ if i < len(logits_list):
1475
+ _, scale_pred = torch.max(logits_list[i], 1)
1476
+ scale_correct[scale] += (scale_pred == labels).sum().item()
1477
+
1478
+ accuracy = 100 * correct / total
1479
+ scale_accs = {s: 100 * scale_correct[s] / total for s in scale_correct}
1480
+
1481
+ return accuracy, scale_accs
1482
+
1483
+
1484
+ # ============================================================================
1485
+ # MAIN TRAINING FUNCTION
1486
+ # ============================================================================
1487
+
1488
+ def train_david(config: DavidTrainingConfig):
1489
+ """Main training pipeline."""
1490
+
1491
+ # Enable TensorFloat32 for better performance on Ampere+ GPUs
1492
+ torch.set_float32_matmul_precision('high')
1493
+
1494
+ print("="*80)
1495
+ print("🌟 DAVID TRAINING PIPELINE")
1496
+ print("="*80)
1497
+ print(f"Run ID: {config.run_id}")
1498
+ print(f"Preset: {config.preset}")
1499
+ print(f"Batch Size: {config.batch_size}")
1500
+ print(f"Learning Rate: {config.learning_rate}")
1501
+ print(f"Mixed Precision: {config.use_mixed_precision}")
1502
+ print(f"TensorFloat32: Enabled (high precision)")
1503
+ print("="*80)
1504
+
1505
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1506
+
1507
+ # Load or create David config FIRST (needed for model_name)
1508
+ if config.custom_config_path:
1509
+ david_config = DavidArchitectureConfig.from_json(config.custom_config_path)
1510
+ print(f"[📁] Loaded custom config: {config.custom_config_path}")
1511
+ elif config.preset:
1512
+ david_config = DavidPresets.get_preset(config.preset)
1513
+ print(f"[⚙️] Using preset: {config.preset}")
1514
+ else:
1515
+ raise ValueError("Must specify either preset or custom_config_path")
1516
+
1517
+ # Create model name from architecture
1518
+ model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}"
1519
+ print(f"[🏷️] Model: {model_name}")
1520
+
1521
+ # Setup directories with proper hierarchy: weights/model_name/timestamp/
1522
+ weights_dir = os.path.join(config.base_dir, "weights", model_name, config.run_id)
1523
+ runs_dir = os.path.join(config.base_dir, "runs", model_name, config.run_id)
1524
+ os.makedirs(weights_dir, exist_ok=True)
1525
+ os.makedirs(runs_dir, exist_ok=True)
1526
+
1527
+ print(f"[📁] Weights: {weights_dir}")
1528
+ print(f"[📁] Logs: {runs_dir}")
1529
+
1530
+ writer = SummaryWriter(runs_dir)
1531
+
1532
+ # Apply overrides
1533
+ if config.num_classes_override:
1534
+ david_config.num_classes = config.num_classes_override
1535
+ if config.use_belly_override is not None:
1536
+ david_config.use_belly = config.use_belly_override
1537
+ if config.belly_expand_override is not None:
1538
+ david_config.belly_expand = config.belly_expand_override
1539
+ if config.progressive_training_override is not None:
1540
+ david_config.progressive_training = config.progressive_training_override
1541
+ if not david_config.progressive_training:
1542
+ # Disable warmup if progressive training disabled
1543
+ david_config.scale_warmup_epochs = {s: 0 for s in david_config.scales}
1544
+
1545
+ # Override scale warmup schedule if provided
1546
+ if config.scale_warmup_epochs_override is not None:
1547
+ david_config.scale_warmup_epochs = config.scale_warmup_epochs_override
1548
+ # Enable progressive training if custom schedule provided
1549
+ if not david_config.progressive_training:
1550
+ print(f"[⚙️] Enabling progressive training (custom warmup schedule provided)")
1551
+ david_config.progressive_training = True
1552
+
1553
+ print(f"[⚙️] Progressive training: {david_config.progressive_training}")
1554
+ if david_config.progressive_training:
1555
+ print(f" Scale warmup schedule: {david_config.scale_warmup_epochs}")
1556
+
1557
+ # Save configs
1558
+ david_config_path = os.path.join(weights_dir, "david_config.json")
1559
+ david_config.to_json(david_config_path)
1560
+ print(f"[💾] Saved David config: {david_config_path}")
1561
+
1562
+ train_config_path = os.path.join(weights_dir, "train_config.json")
1563
+ config.to_json(train_config_path)
1564
+ print(f"[💾] Saved training config: {train_config_path}")
1565
+
1566
+ # Initialize David
1567
+ david = David.from_config(david_config).cuda()
1568
+ print(f"\n{david}\n")
1569
+
1570
+ # Count parameters
1571
+ total_params = sum(p.numel() for p in david.parameters())
1572
+ trainable_params = sum(p.numel() for p in david.parameters() if p.requires_grad)
1573
+ print(f"[📊] Total Parameters: {total_params:,}")
1574
+ print(f"[📊] Trainable Parameters: {trainable_params:,}")
1575
+
1576
+ # Load data
1577
+ train_loader, val_loader = create_dataloaders(config)
1578
+
1579
+ # Generate crystals
1580
+ crystal_gen = CrystalGenerator(
1581
+ david_config.num_classes,
1582
+ david_config.scales,
1583
+ str(device)
1584
+ )
1585
+ anchors_dict, crystals_dict = crystal_gen.generate()
1586
+
1587
+ # Setup training
1588
+ criterion = MultiScaleCrystalLoss(
1589
+ scales=david_config.scales,
1590
+ num_classes=david_config.num_classes,
1591
+ use_rose_loss=config.use_rose_loss,
1592
+ use_cayley_loss=config.use_cayley_loss,
1593
+ rose_initial_weight=config.rose_initial_weight,
1594
+ rose_max_weight=config.rose_max_weight,
1595
+ cayley_weight=config.cayley_weight,
1596
+ scale_loss_balance=config.scale_loss_balance
1597
+ ).cuda()
1598
+
1599
+ optimizer = create_optimizer(david, config)
1600
+ scheduler = create_scheduler(optimizer, config)
1601
+
1602
+ controller = AdaptiveTrainingController(david, config)
1603
+
1604
+ # Tracking
1605
+ best_val_acc = 0.0
1606
+ best_epoch = 0
1607
+ best_scale_accs = {}
1608
+ global_step = 0
1609
+ final_train_acc = 0.0
1610
+ final_train_loss = 0.0
1611
+
1612
+ # Training history for epoch-by-epoch tracking
1613
+ training_history = {
1614
+ 'epochs': [],
1615
+ 'train_loss': [],
1616
+ 'train_acc': [],
1617
+ 'val_acc': [],
1618
+ 'scale_accs': {},
1619
+ 'lr': []
1620
+ }
1621
+
1622
+ # DIAGNOSTIC: Test one forward/backward pass before training
1623
+ print("\n[🔍] Running diagnostic forward/backward pass...")
1624
+ david.train()
1625
+
1626
+ # Get a small batch
1627
+ for features_test, labels_test in train_loader:
1628
+ features_test = features_test.cuda(non_blocking=True)[:8] # Just 8 samples
1629
+ labels_test = labels_test.cuda(non_blocking=True)[:8]
1630
+
1631
+ # Forward
1632
+ combined_test, logits_test, features_test_out, _ = david(
1633
+ features_test, anchors_dict, return_all_scales=True
1634
+ )
1635
+
1636
+ # Loss
1637
+ losses_test = criterion(
1638
+ combined_test, logits_test, features_test_out,
1639
+ labels_test, crystals_dict, epoch=0
1640
+ )
1641
+
1642
+ print(f" Initial loss: {losses_test['total'].item():.6f}")
1643
+ print(f" Loss components:")
1644
+ for key, value in losses_test.items():
1645
+ if key != 'total':
1646
+ print(f" {key}: {value.item():.6f}")
1647
+
1648
+ # Backward
1649
+ optimizer.zero_grad()
1650
+ losses_test['total'].backward()
1651
+
1652
+ # Check gradients
1653
+ grad_count = sum(1 for p in david.parameters() if p.grad is not None and p.grad.norm() > 0)
1654
+ total_grad_params = sum(1 for p in david.parameters() if p.requires_grad)
1655
+ print(f" Parameters with non-zero gradients: {grad_count}/{total_grad_params}")
1656
+
1657
+ if grad_count == 0:
1658
+ print(f" ❌ ERROR: No gradients! Training will not work.")
1659
+ return None, 0.0
1660
+ elif grad_count < total_grad_params * 0.5:
1661
+ print(f" ⚠️ WARNING: Less than 50% of parameters have gradients")
1662
+ else:
1663
+ print(f" ✅ Gradients look good")
1664
+
1665
+ break # Only test one batch
1666
+
1667
+ print("\n[🚀] Starting training...\n")
1668
+
1669
+ for epoch in range(config.num_epochs):
1670
+ epoch_start = time.time()
1671
+
1672
+ # Train
1673
+ train_loss, train_acc, global_step, loss_components = train_epoch(
1674
+ david, train_loader, optimizer, criterion,
1675
+ anchors_dict, crystals_dict, epoch, config,
1676
+ writer, global_step
1677
+ )
1678
+
1679
+ # Validate
1680
+ val_acc, scale_accs = validate(david, val_loader, anchors_dict, config)
1681
+
1682
+ # Update controller
1683
+ controller.update_metrics(scale_accs, val_acc)
1684
+ controller.apply_adaptive_strategies(scale_accs, epoch)
1685
+
1686
+ # Step scheduler
1687
+ if scheduler:
1688
+ scheduler.step()
1689
+
1690
+ epoch_time = time.time() - epoch_start
1691
+
+         # Print
+         print(f"\n📊 Epoch {epoch+1}/{config.num_epochs} ({epoch_time:.1f}s)")
+         print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
+         # best_val_acc is only updated further down, so "Best" here is the best of the earlier epochs
+         print(f" Val: Acc={val_acc:.2f}% (Best: {best_val_acc:.2f}%)")
+         print(f" Active scales: {david.get_active_scales()}")
+         print(f" LR: {optimizer.param_groups[0]['lr']:.2e}")
+
+         if config.log_loss_components and loss_components:
+             print(f" Loss breakdown:")
+             for key, value in sorted(loss_components.items()):
+                 if key != 'total':
+                     print(f" {key:20s}: {value:.6f}")
+
+         for scale, acc in scale_accs.items():
+             frozen = "❄️" if controller.scales_frozen.get(scale, False) else "🔥"
+             print(f" {frozen} Scale {scale}: {acc:.2f}%")
+
+         # Update tracking
+         final_train_acc = train_acc
+         final_train_loss = train_loss
+
+         # Record training history
+         training_history['epochs'].append(epoch + 1)
+         training_history['train_loss'].append(train_loss)
+         training_history['train_acc'].append(train_acc)
+         training_history['val_acc'].append(val_acc)
+         training_history['lr'].append(optimizer.param_groups[0]['lr'])
+
+         # Record per-scale accuracies
+         for scale, acc in scale_accs.items():
+             if scale not in training_history['scale_accs']:
+                 training_history['scale_accs'][scale] = []
+             training_history['scale_accs'][scale].append(acc)
+
+         # TensorBoard
+         writer.add_scalar('train/loss', train_loss, epoch)
+         writer.add_scalar('train/acc', train_acc, epoch)
+         writer.add_scalar('val/acc', val_acc, epoch)
+
+         for scale, acc in scale_accs.items():
+             writer.add_scalar(f'val/acc_scale_{scale}', acc, epoch)
+
+         # Save best
+         if val_acc > best_val_acc:
+             best_val_acc = val_acc
+             best_epoch = epoch
+             best_scale_accs = scale_accs.copy()
+
+             # Save training history alongside best model
+             history_path = os.path.join(weights_dir, 'training_history.json')
+             with open(history_path, 'w') as f:
+                 json.dump(training_history, f, indent=2)
+
+             save_checkpoint(
+                 os.path.join(weights_dir, 'best_model'),
+                 david, optimizer, scheduler, epoch,
+                 {
+                     'best_val_acc': best_val_acc,
+                     'best_epoch': best_epoch,
+                     'scale_accuracies': best_scale_accs,
+                     'training_history': training_history
+                 },
+                 config
+             )
+
+             # Upload to hub when best model improves
+             if config.upload_to_hub:
+                 best_metrics = {
+                     'best_val_acc': best_val_acc,
+                     'best_epoch': best_epoch,
+                     'scale_accuracies': best_scale_accs,
+                     'final_train_acc': train_acc,
+                     'final_train_loss': train_loss,
+                     'training_history': training_history,
+                     'parameters': total_params
+                 }
+                 prepare_hub_upload(weights_dir, runs_dir, config, david_config, best_metrics, model_name)
+
+         # Periodic save
+         if (epoch + 1) % config.save_interval == 0:
+             save_checkpoint(
+                 os.path.join(weights_dir, f'checkpoint_epoch_{epoch+1}'),
+                 david, optimizer, scheduler, epoch,
+                 {'val_acc': val_acc},
+                 config
+             )
+
+     # Final save
+     save_checkpoint(
+         os.path.join(weights_dir, 'final_model'),
+         david, optimizer, scheduler, config.num_epochs - 1,
+         {'final_val_acc': val_acc},
+         config
+     )
+
+     writer.close()
+
+     # Final hub upload with all artifacts
+     if config.upload_to_hub:
+         print("\n[🤗] Performing final HuggingFace Hub upload...")
+         final_metrics = {
+             'best_val_acc': best_val_acc,
+             'best_epoch': best_epoch,
+             'scale_accuracies': best_scale_accs,
+             'final_train_acc': final_train_acc,
+             'final_train_loss': final_train_loss,
+             'training_history': training_history,
+             'parameters': total_params
+         }
+         prepare_hub_upload(weights_dir, runs_dir, config, david_config, final_metrics, model_name)
+
+         # Upload TensorBoard logs at the end
+         if os.path.exists(runs_dir):
+             runs_repo_path = f"runs/{model_name}/{config.run_id}"
+             print(f"[📤] Uploading TensorBoard logs to {runs_repo_path}...")
+             upload_to_huggingface(
+                 local_dir=runs_dir,
+                 repo_id=config.hf_repo,
+                 commit_message=f"Upload TensorBoard logs - {model_name} - Run {config.run_id}",
+                 path_in_repo=runs_repo_path
+             )
+
+     print("\n" + "="*80)
+     print(f"🎉 Training Complete!")
+     print(f" Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch+1})")
+     print(f" Final Train Acc: {final_train_acc:.2f}%")
+     print(f" Weights: {weights_dir}")
+     if config.upload_to_hub:
+         print(f" Hub: https://huggingface.co/{config.hf_repo}")
+     print("="*80)
+
+     return david, best_val_acc
+
+
+ # ============================================================================
+ # USAGE EXAMPLE
+ # ============================================================================
+
+ if __name__ == "__main__":
+     # ============================================================================
+     # EXPERIMENT 1: Single Encoder (Standard Training)
+     # ============================================================================
+
+     # config = DavidTrainingConfig(
+     #     preset="balanced",
+     #     model_variant="clip_vit_b16",  # Single encoder
+     #
+     #     num_epochs=10,
+     #     batch_size=1024,
+     #     learning_rate=1e-2,
+     #
+     #     use_rose_loss=True,
+     #     rose_initial_weight=0.1,
+     #     rose_max_weight=0.5,
+     #
+     #     upload_to_hub=True,
+     #     hf_repo="AbstractPhil/gated-david",
+     # )
+
+     # ============================================================================
+     # EXPERIMENT 2: Multi-Encoder Unified Space (THE TEST!)
+     # ============================================================================
+
+     config = DavidTrainingConfig(
+         preset="balanced",  # 4 scales: [256, 512, 768, 1024]
+
+         # 🧪 MULTI-ENCODER: OpenAI CLIP-B/16 vs LAION CLIP-B/32
+         model_variant=["clip_vit_b16", "clip_vit_laion_b32"],  # B/16 and B/32
+
+         num_epochs=10,
+         batch_size=1024,
+         learning_rate=1e-2,
+
+         # Custom warmup for 4 scales
+         scale_warmup_epochs_override={
+             256: 0,
+             512: 2,
+             768: 5,
+             1024: 8
+         },
+
+         use_rose_loss=True,
+         rose_initial_weight=0.2,  # Higher for diversity
+         rose_max_weight=0.8,
+
+         use_cayley_loss=True,  # Extra geometric regularization
+         cayley_weight=0.01,
+
+         freeze_strategy="never",
+         gradient_clip=10.0,
+
+         save_format="safetensors",
+         upload_to_hub=False,
+         hf_repo="YourName/YourRepoHere",  # e.g. "AbstractPhil/david-shared-space"
+     )
+
+     print("="*80)
+     print("🧪 UNIFIED SPACE EXPERIMENT")
+     print("="*80)
+     print(f"Testing if David can unify:")
+     if isinstance(config.model_variant, list):
+         for variant in config.model_variant:
+             print(f" • {variant}")
+     print("="*80)
+
+     david, best_acc = train_david(config)
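
A minimal, standard-library-only sketch of the best-epoch bookkeeping used in the training loop above, run on made-up accuracy values. The names `record_epoch` and `fake_epochs` are hypothetical and do not appear in model_trainer.py. It illustrates two consequences of writing `training_history.json` only when validation accuracy improves: the file on disk stops at the best epoch, and the integer scale keys come back as strings after a JSON round trip.

```python
import json
import os
import tempfile


def record_epoch(history, epoch, val_acc, scale_accs):
    """Append one epoch of validation metrics to the history dict (hypothetical helper)."""
    history["epochs"].append(epoch + 1)
    history["val_acc"].append(val_acc)
    for scale, acc in scale_accs.items():
        history["scale_accs"].setdefault(scale, []).append(acc)


history = {"epochs": [], "val_acc": [], "scale_accs": {}}
best_val_acc, best_epoch = 0.0, -1

# Made-up (val_acc, per-scale accuracy) pairs, for illustration only.
fake_epochs = [
    (41.0, {256: 40.0, 512: 42.0}),
    (45.5, {256: 44.0, 512: 47.0}),
    (44.0, {256: 43.0, 512: 45.0}),
]

with tempfile.TemporaryDirectory() as weights_dir:
    history_path = os.path.join(weights_dir, "training_history.json")
    for epoch, (val_acc, scale_accs) in enumerate(fake_epochs):
        record_epoch(history, epoch, val_acc, scale_accs)
        # Same rule as the trainer: persist only when validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc, best_epoch = val_acc, epoch
            with open(history_path, "w") as f:
                json.dump(history, f, indent=2)
    with open(history_path) as f:
        reloaded = json.load(f)

print("best epoch (0-indexed):", best_epoch, "best val acc:", best_val_acc)
print("epochs in memory:", history["epochs"])
print("epochs on disk:", reloaded["epochs"])
print("scale keys on disk:", sorted(reloaded["scale_accs"]))
```

Code that reloads the history file would need to convert the scale keys back with `int(...)` before comparing them with the integer scales used during training.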