"""
David Training Pipeline
========================
Training pipeline for the David multi-scale crystal classifier.
Should be placed at: geovocab2/train/model/core/david_trainer.py
Or run from: scripts/train_david.py
Features:
- Pure fp32 training (no mixed precision for geometric stability)
- Adaptive training controller (freeze/unfreeze scales)
- Gradient analysis and scaling
- SafeTensors checkpoint support
- Enhanced loss component tracking
- Proper weight organization: weights/model_name/timestamp/
- Accuracy in filenames and comprehensive tracking
- Master models index (MODELS_INDEX.json)
"""
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
import numpy as np
import os
import json
import time
import tempfile
from datetime import datetime
from tqdm.auto import tqdm
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field, asdict
# Import David components
from geovocab2.train.config.david_config import (
DavidArchitectureConfig,
DavidPresets,
SharingMode,
FusionMode
)
from geovocab2.train.model.core.david import (
David,
MultiScaleCrystalLoss,
)
# Import SimplexFactory
from geovocab2.shapes.factory import SimplexFactory
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================
@dataclass
class DavidTrainingConfig:
"""
Complete training configuration for David.
Separate from model architecture config.
"""
# Metadata
name: str = "david_training"
run_id: str = "" # Auto-generated timestamp
# Dataset
dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly"
model_variant: Union[str, List[str]] = "clip_vit_b16" # Single or list for multi-encoder
num_classes: int = 1000
# Model architecture (references to david_config)
preset: Optional[str] = "balanced" # Or None to use custom config
custom_config_path: Optional[str] = None # Path to custom david_config.json
# Architecture overrides (applied to preset or custom config)
num_classes_override: Optional[int] = None
use_belly_override: Optional[bool] = None
belly_expand_override: Optional[float] = None
progressive_training_override: Optional[bool] = True # Override progressive training
scale_warmup_epochs_override: Optional[Dict[int, int]] = None # Custom warmup schedule
# Training hyperparameters
num_epochs: int = 50
batch_size: int = 512
learning_rate: float = 5e-3
weight_decay: float = 1e-5
warmup_epochs: int = 3
# Loss weights
use_rose_loss: bool = True
rose_initial_weight: float = 0.01
rose_max_weight: float = 0.1
rose_weight_schedule: str = "adaptive"
use_cayley_loss: bool = False
cayley_weight: float = 0.001
scale_loss_balance: Optional[Dict[int, float]] = None
# Optimization
use_mixed_precision: bool = False # Keep False for stability
gradient_clip: float = 5.0
scheduler_type: str = "cosine_restarts"
min_lr: float = 1e-6
# Adaptive training (safer defaults)
freeze_strategy: str = "never" # "performance" or "never"
freeze_threshold: float = 90.0 # Only freeze when scale hits 90% accuracy
unfreeze_on_plateau: bool = True
patience: int = 10
# Gradient monitoring
track_gradients: bool = True
gradient_scale_threshold: float = 1e-5
gradient_scale_multiplier: float = 10.0
# Logging
log_interval: int = 50
val_interval: int = 1
save_interval: int = 5
log_fusion_weights: bool = True
log_loss_components: bool = True
# Checkpointing
save_format: str = "both" # "pytorch", "safetensors", or "both"
# HuggingFace Hub (optional)
    hf_repo: Optional[str] = ""  # your HF repo, e.g. "AbstractPhil/gated-david"
upload_to_hub: bool = False
# Local paths
base_dir: str = "./david_training"
# Hardware
num_workers: int = 10
pin_memory: bool = True
prefetch_factor: int = 4
persistent_workers: bool = True
def __post_init__(self):
"""Generate run_id if not provided."""
if not self.run_id:
self.run_id = datetime.now().strftime('%Y%m%d_%H%M%S')
def to_dict(self) -> dict:
"""Convert to dictionary."""
return asdict(self)
@classmethod
def from_dict(cls, data: dict) -> 'DavidTrainingConfig':
"""Create from dictionary."""
return cls(**data)
def to_json(self, path: str):
"""Save to JSON."""
data = self.to_dict()
# Convert any nested dicts with int keys to str keys
if data.get('scale_loss_balance'):
data['scale_loss_balance'] = {
str(k): v for k, v in data['scale_loss_balance'].items()
}
if data.get('scale_warmup_epochs_override'):
data['scale_warmup_epochs_override'] = {
str(k): v for k, v in data['scale_warmup_epochs_override'].items()
}
with open(path, 'w') as f:
json.dump(data, f, indent=2)
@classmethod
def from_json(cls, path: str) -> 'DavidTrainingConfig':
"""Load from JSON."""
with open(path, 'r') as f:
data = json.load(f)
# Convert str keys back to int for scale_loss_balance
if 'scale_loss_balance' in data and data['scale_loss_balance']:
data['scale_loss_balance'] = {
int(k): v for k, v in data['scale_loss_balance'].items()
}
# Convert str keys back to int for scale_warmup_epochs_override
if 'scale_warmup_epochs_override' in data and data['scale_warmup_epochs_override']:
data['scale_warmup_epochs_override'] = {
int(k): v for k, v in data['scale_warmup_epochs_override'].items()
}
return cls(**data)
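# Illustrative sketch (not part of the pipeline): round-tripping the training
# config through JSON. to_json writes the int-keyed dicts with string keys and
# from_json restores them as ints. The path and override values here are
# hypothetical, chosen only to exercise that conversion.
def _example_config_roundtrip(path: str = "/tmp/david_train_config_example.json"):
    cfg = DavidTrainingConfig(
        preset="balanced",
        scale_warmup_epochs_override={256: 0, 512: 2},
    )
    cfg.to_json(path)
    restored = DavidTrainingConfig.from_json(path)
    assert restored.scale_warmup_epochs_override == {256: 0, 512: 2}
    return restored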
# ============================================================================
# ADAPTIVE TRAINING CONTROLLER
# ============================================================================
class AdaptiveTrainingController:
"""Manages adaptive training strategies for multi-scale model."""
def __init__(self, model: David, config: DavidTrainingConfig):
self.model = model
self.config = config
scales = model.scales
self.scale_history = {scale: [] for scale in scales}
self.best_scale_acc = {scale: 0.0 for scale in scales}
self.scales_frozen = {scale: False for scale in scales}
self.overall_history = []
self.plateau_counter = 0
self.best_overall = 0.0
def update_metrics(self, scale_accuracies: Dict[int, float], overall_accuracy: float):
"""Update metrics and best scores."""
for scale, acc in scale_accuracies.items():
self.scale_history[scale].append(acc)
if acc > self.best_scale_acc[scale]:
self.best_scale_acc[scale] = acc
self.overall_history.append(overall_accuracy)
if overall_accuracy > self.best_overall:
self.best_overall = overall_accuracy
self.plateau_counter = 0
else:
self.plateau_counter += 1
def should_freeze_scale(self, scale: int, current_acc: float) -> bool:
"""Determine if a scale should be frozen."""
if self.config.freeze_strategy == "never":
return False
if self.scales_frozen[scale]:
return False
if self.config.freeze_strategy == "performance":
return current_acc >= self.config.freeze_threshold
return False
def should_unfreeze_scales(self) -> bool:
"""Check if scales should be unfrozen due to plateau."""
if not self.config.unfreeze_on_plateau:
return False
        return self.plateau_counter >= self.config.patience
def apply_adaptive_strategies(self, scale_accuracies: Dict[int, float], epoch: int):
"""Apply freeze/unfreeze based on performance."""
active_scales = self.model.get_active_scales()
# Don't freeze scales if it would leave no trainable parameters
for scale, acc in scale_accuracies.items():
if self.should_freeze_scale(scale, acc):
# Count how many active scales would remain unfrozen
active_unfrozen = [s for s in active_scales if not self.scales_frozen.get(s, False)]
if len(active_unfrozen) <= 1:
print(f"[⚠️] Skipping freeze of scale {scale} (would leave no active trainable scales)")
continue
self.model.freeze_scale(scale)
self.scales_frozen[scale] = True
print(f"[❄️] Froze scale {scale} (acc={acc:.2f}%)")
if self.should_unfreeze_scales() and any(self.scales_frozen.values()):
for scale in self.model.scales:
if self.scales_frozen[scale]:
self.model.unfreeze_scale(scale)
self.scales_frozen[scale] = False
self.plateau_counter = 0
print(f"[🔥] Unfroze all scales due to plateau")
# ============================================================================
# OPTIMIZER & SCHEDULER CREATION
# ============================================================================
def create_optimizer(david: David, config: DavidTrainingConfig) -> torch.optim.Optimizer:
"""Create optimizer with parameter groups."""
param_groups = []
# Shared parameters (if exists)
if hasattr(david, 'shared_extractor'):
param_groups.append({
'params': david.shared_extractor.parameters(),
'lr': config.learning_rate,
'name': 'shared'
})
elif hasattr(david, 'shared_base'):
param_groups.append({
'params': david.shared_base.parameters(),
'lr': config.learning_rate,
'name': 'shared'
})
# Scale-specific parameters
for scale in david.scales:
scale_params = []
if david.sharing_mode == SharingMode.HIERARCHICAL:
head = getattr(david, f'head_{scale}', None)
if head:
scale_params.extend(head.parameters())
refine = getattr(david, f'refine_{scale}', None)
if refine:
scale_params.extend(refine.parameters())
else:
scale_params.extend(david.heads[str(scale)].parameters())
if scale_params:
param_groups.append({
'params': scale_params,
'lr': config.learning_rate,
'name': f'scale_{scale}'
})
# Fusion parameters
if hasattr(david, 'fusion'):
param_groups.append({
'params': david.fusion.parameters(),
'lr': config.learning_rate * 0.5,
'name': 'fusion'
})
elif hasattr(david, 'fusion_weights'):
param_groups.append({
'params': [david.fusion_weights],
'lr': config.learning_rate * 0.5,
'name': 'fusion'
})
return torch.optim.AdamW(param_groups, weight_decay=config.weight_decay)
def create_scheduler(optimizer: torch.optim.Optimizer,
                     config: DavidTrainingConfig) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
"""Create learning rate scheduler."""
if config.scheduler_type == "cosine_restarts":
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=10, T_mult=2, eta_min=config.min_lr
)
elif config.scheduler_type == "cosine":
return torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config.num_epochs, eta_min=config.min_lr
)
else:
return None
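# Illustrative sketch: what the default "cosine_restarts" schedule from
# create_scheduler does over epochs. A throwaway parameter stands in for David
# so the example runs standalone; all numbers are hypothetical.
def _example_scheduler_curve():
    cfg = DavidTrainingConfig(num_epochs=20, learning_rate=5e-3, min_lr=1e-6)
    dummy = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([{'params': [dummy], 'lr': cfg.learning_rate}],
                            weight_decay=cfg.weight_decay)
    sched = create_scheduler(opt, cfg)  # CosineAnnealingWarmRestarts(T_0=10, T_mult=2)
    lrs = []
    for _ in range(cfg.num_epochs):
        lrs.append(opt.param_groups[0]['lr'])
        sched.step()  # stepped once per epoch, mirroring train_david
    return lrs  # decays toward min_lr over 10 epochs, restarts, then decays over 20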
# ============================================================================
# GRADIENT ANALYSIS
# ============================================================================
def analyze_gradients(model: David, config: DavidTrainingConfig) -> Dict[str, float]:
"""Analyze gradient magnitudes for debugging."""
grad_stats = {
'mean': 0.0,
'max': 0.0,
'min': float('inf'),
'num_zero': 0,
'num_small': 0,
'total': 0
}
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
grad_stats['mean'] += grad_norm
grad_stats['max'] = max(grad_stats['max'], grad_norm)
grad_stats['min'] = min(grad_stats['min'], grad_norm)
grad_stats['total'] += 1
if grad_norm < 1e-10:
grad_stats['num_zero'] += 1
elif grad_norm < config.gradient_scale_threshold:
grad_stats['num_small'] += 1
if grad_stats['total'] > 0:
grad_stats['mean'] /= grad_stats['total']
return grad_stats
def scale_small_gradients(model: David, config: DavidTrainingConfig):
"""Scale up very small gradients to prevent vanishing."""
if not config.track_gradients:
return
for param in model.parameters():
if param.grad is not None:
grad_norm = param.grad.norm()
if grad_norm < config.gradient_scale_threshold and grad_norm > 0:
param.grad.mul_(config.gradient_scale_multiplier)
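# Illustrative sketch: analyze_gradients and scale_small_gradients only touch
# (named_)parameters(), so a tiny nn.Linear is enough to demo them. The toy
# model, batch, and threshold are hypothetical.
def _example_gradient_monitoring():
    cfg = DavidTrainingConfig(track_gradients=True, gradient_scale_threshold=1e-5)
    toy = torch.nn.Linear(16, 4)
    loss = toy(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    stats = analyze_gradients(toy, cfg)   # mean/max/min grad norms, zero/small counts
    scale_small_gradients(toy, cfg)       # boosts any grad norm below the threshold
    return stats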
# ============================================================================
# HUGGINGFACE HUB UTILITIES
# ============================================================================
def generate_model_readme(
config: DavidTrainingConfig,
david_config: DavidArchitectureConfig,
best_metrics: Dict,
run_id: str
) -> str:
"""Generate README.md for model card."""
readme = f"""---
language: en
license: mit
tags:
- image-classification
- imagenet
- multi-scale
- feature-geometry
- david
datasets:
- imagenet-1k
metrics:
- accuracy
model-index:
- name: David-{david_config.sharing_mode}-{david_config.fusion_mode}
results:
- task:
type: image-classification
dataset:
name: ImageNet-1K
type: imagenet-1k
metrics:
- type: accuracy
value: {best_metrics.get('best_val_acc', 0.0):.2f}
---
# David: Multi-Scale Feature Classifier
**David** is a multi-scale deep learning classifier that uses feature geometry (pentachora/4-simplexes)
as class prototypes with role-weighted similarity computation (Rose Loss).
This version feeds multiple CLIP-ViT feature variants into a single shared space simultaneously.
The experiment tests whether substantially different variants, such as clip-vit-b-patch32 and patch16, can
coexist in the same shared space when the appropriate checks and spacings are applied.
## Model Details
### Architecture
- **Preset**: {config.preset}
- **Sharing Mode**: {david_config.sharing_mode}
- **Fusion Mode**: {david_config.fusion_mode}
- **Scales**: {david_config.scales}
- **Feature Dim**: {david_config.feature_dim}
- **Parameters**: {best_metrics.get('parameters', 0):,}
### Training Configuration
- **Dataset**: {config.dataset_name}
- **Model Variant**: {config.model_variant}
- **Epochs**: {config.num_epochs}
- **Batch Size**: {config.batch_size}
- **Learning Rate**: {config.learning_rate}
- **Rose Loss Weight**: {config.rose_initial_weight} → {config.rose_max_weight}
- **Cayley Loss**: {config.use_cayley_loss}
## Performance
### Best Results
- **Validation Accuracy**: {best_metrics.get('best_val_acc', 0.0):.2f}%
- **Best Epoch**: {best_metrics.get('best_epoch', 0)}
- **Final Train Accuracy**: {best_metrics.get('final_train_acc', 0.0):.2f}%
### Per-Scale Performance
"""
if 'scale_accuracies' in best_metrics:
for scale, acc in best_metrics['scale_accuracies'].items():
readme += f"- **Scale {scale}**: {acc:.2f}%\n"
readme += f"""
## Usage
### Quick Model Lookup
**Check `MODELS_INDEX.json` in the repo root** - it lists all trained models sorted by accuracy with links to weights and configs.
### Repository Structure
```
{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}/
├── MODELS_INDEX.json # 📊 Master index of all models (sorted by accuracy)
├── README.md # This file
├── best_model.json # Latest best model info
├── weights/
│ └── {david_config.name}/
│ └── {run_id}/
│ ├── MODEL_SUMMARY.txt # 🎯 Human-readable performance summary
│ ├── training_history.json # 📈 Epoch-by-epoch training curve
│ ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors # ⭐ Accuracy in filename!
│ ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}_metadata.json
│ ├── final_model.safetensors
│ ├── checkpoint_epoch_X_accYY.YY.safetensors
│ ├── david_config.json
│ └── train_config.json
└── runs/
└── {david_config.name}/
└── {run_id}/
└── events.out.tfevents.* # TensorBoard logs
```
### Loading the Model
```python
from geovocab2.train.model.core.david import David, DavidArchitectureConfig
from huggingface_hub import hf_hub_download
# Browse available models in MODELS_INDEX.json first!
# Specify model variant and run
model_name = "{david_config.name}"
run_id = "{run_id}"
accuracy = "{best_metrics.get('best_val_acc', 0.0):.2f}" # From MODELS_INDEX.json
# Download config
config_path = hf_hub_download(
repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
filename=f"weights/{{model_name}}/{{run_id}}/david_config.json"
)
config = DavidArchitectureConfig.from_json(config_path)
# Download weights (accuracy in filename!)
weights_path = hf_hub_download(
repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
filename=f"weights/{{model_name}}/{{run_id}}/best_model_acc{{accuracy}}.safetensors"
)
# Download training history (optional - see full training curve)
history_path = hf_hub_download(
repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}",
filename=f"weights/{{model_name}}/{{run_id}}/training_history.json"
)
# Load model
from safetensors.torch import load_file
david = David.from_config(config)
david.load_state_dict(load_file(weights_path))
david.eval()
```
### Inference
```python
import torch
import torch.nn.functional as F
# Assuming you have CLIP features (512-dim for ViT-B/16)
features = get_clip_features(image) # [1, 512]
# Load anchors
anchors_dict = torch.load("anchors.pth")
# Forward pass
with torch.no_grad():
logits, _ = david(features, anchors_dict)
predictions = logits.argmax(dim=-1)
```
## Architecture Overview
### Multi-Scale Processing
David processes inputs at multiple scales ({', '.join(map(str, david_config.scales))}),
allowing it to capture both coarse and fine-grained features.
### Shared Representation Space
This variant embeds features from multiple CLIP-ViT encoders in the same representation space.
### Feature Geometry
Each class is represented by a pentachoron (4-simplex) in embedding space with 5 vertices:
- **Anchor**: Primary class representative
- **Need**: Complementary direction
- **Relation**: Contextual alignment
- **Purpose**: Functional direction
- **Observer**: Meta-perspective
### Rose Loss
Similarity computation uses role-weighted cosine similarities:
```
score = w_anchor * sim(z, anchor) + w_need * sim(z, need) + ...
```
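A minimal sketch of that scoring (illustrative only; the actual role weights and vertex ordering live in the loss implementation):
```python
import torch
import torch.nn.functional as F

def rose_score(z, crystals, role_weights):
    # z: [B, D] embeddings, crystals: [C, 5, D] pentachora, role_weights: [5]
    z = F.normalize(z, dim=-1)
    verts = F.normalize(crystals, dim=-1)
    sims = torch.einsum('bd,cvd->bcv', z, verts)   # cosine similarity per vertex role
    return (sims * role_weights).sum(dim=-1)       # [B, C] role-weighted scores
```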
### Fusion Strategy
**{david_config.fusion_mode}**: Intelligently combines predictions from multiple scales.
## Training Details
### Loss Components
- **Cross-Entropy**: Standard classification loss
- **Rose Loss**: Pentachora role-weighted margin loss (weight: {config.rose_initial_weight} → {config.rose_max_weight})
- **Cayley Loss**: Geometric regularization ({'enabled' if config.use_cayley_loss else 'disabled'})
### Optimization
- **Optimizer**: AdamW
- **Weight Decay**: {config.weight_decay}
- **Scheduler**: {config.scheduler_type}
- **Gradient Clip**: {config.gradient_clip}
- **Mixed Precision**: {config.use_mixed_precision}
## Citation
```bibtex
@software{{david_classifier_2025,
title = {{David: Multi-Scale Feature Classifier}},
author = {{AbstractPhil}},
year = {{2025}},
url = {{https://huggingface.co/{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}}},
note = {{Run ID: {run_id}}}
}}
```
## License
MIT License
## Acknowledgments
Built with feature lattice geometry and multi-scale deep learning.
Special thanks to Claude (Anthropic) for debugging assistance.
---
*Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
"""
return readme
def save_best_model_json(
filepath: str,
metrics: Dict,
config: DavidTrainingConfig,
david_config: DavidArchitectureConfig
):
"""Save best_model.json with comprehensive metrics."""
model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}"
best_model_info = {
"model_name": model_name,
"run_id": config.run_id,
"timestamp": datetime.now().isoformat(),
# Best metrics
"best_val_acc": metrics.get('best_val_acc', 0.0),
"best_epoch": metrics.get('best_epoch', 0),
"final_train_acc": metrics.get('final_train_acc', 0.0),
"final_train_loss": metrics.get('final_train_loss', 0.0),
# Per-scale performance
"scale_accuracies": metrics.get('scale_accuracies', {}),
# Architecture
"architecture": {
"preset": config.preset,
"sharing_mode": david_config.sharing_mode,
"fusion_mode": david_config.fusion_mode,
"scales": david_config.scales,
"feature_dim": david_config.feature_dim,
"num_classes": david_config.num_classes,
"use_belly": david_config.use_belly,
"belly_expand": david_config.belly_expand,
},
# Training config
"training": {
"dataset": config.dataset_name,
"model_variant": config.model_variant,
"num_epochs": config.num_epochs,
"batch_size": config.batch_size,
"learning_rate": config.learning_rate,
"rose_weight": f"{config.rose_initial_weight}{config.rose_max_weight}",
"cayley_loss": config.use_cayley_loss,
"optimizer": "AdamW",
"scheduler": config.scheduler_type,
},
# Files (organized by model/run)
"files": {
"weights_safetensors": f"weights/{model_name}/{config.run_id}/best_model_acc{metrics.get('best_val_acc', 0.0):.2f}.safetensors",
"weights_pytorch": f"weights/{model_name}/{config.run_id}/best_model.pth",
"config": f"weights/{model_name}/{config.run_id}/david_config.json",
"training_config": f"weights/{model_name}/{config.run_id}/train_config.json",
"tensorboard": f"runs/{model_name}/{config.run_id}/"
}
}
with open(filepath, 'w') as f:
json.dump(best_model_info, f, indent=2)
print(f"[📄] Saved best_model.json: {filepath}")
def create_model_summary(
weights_dir: str,
config: DavidTrainingConfig,
david_config: DavidArchitectureConfig,
best_metrics: Dict,
model_name: str
):
"""Create prominent model summary with accuracy front and center."""
summary_path = os.path.join(weights_dir, 'MODEL_SUMMARY.txt')
best_acc = best_metrics.get('best_val_acc', 0.0)
training_history = best_metrics.get('training_history', {})
summary = f"""
╔══════════════════════════════════════════════════════════════╗
║ DAVID MODEL SUMMARY ║
╠══════════════════════════════════════════════════════════════╣
║ ║
║ 🎯 VALIDATION ACCURACY: {best_acc:.2f}% ║
║ ║
╚══════════════════════════════════════════════════════════════╝
MODEL: {model_name}
RUN ID: {config.run_id}
BEST EPOCH: {best_metrics.get('best_epoch', 0) + 1}/{config.num_epochs}
═══════════════════════════════════════════════════════════════
📊 PERFORMANCE BREAKDOWN
Final Training Accuracy: {best_metrics.get('final_train_acc', 0.0):.2f}%
Best Validation Accuracy: {best_acc:.2f}%
Per-Scale Accuracies:
"""
scale_accs = best_metrics.get('scale_accuracies', {})
for scale in sorted(scale_accs.keys()):
acc = scale_accs[scale]
summary += f" • Scale {scale:4d}: {acc:.2f}%\n"
summary += f"""
═══════════════════════════════════════════════════════════════
🏗️ ARCHITECTURE
Preset: {config.preset}
Sharing Mode: {david_config.sharing_mode}
Fusion Mode: {david_config.fusion_mode}
Scales: {len(david_config.scales)} scales - {david_config.scales}
Feature Dim: {david_config.feature_dim}
Parameters: {best_metrics.get('parameters', 0):,}
═══════════════════════════════════════════════════════════════
📈 TRAINING CURVE
"""
if training_history and 'val_acc' in training_history:
summary += "Epoch | Train Acc | Val Acc | Learning Rate\n"
summary += "------|-----------|----------|--------------\n"
for i, epoch in enumerate(training_history.get('epochs', [])):
train_acc = training_history['train_acc'][i] if i < len(training_history['train_acc']) else 0
val_acc = training_history['val_acc'][i] if i < len(training_history['val_acc']) else 0
lr = training_history['lr'][i] if i < len(training_history['lr']) else 0
marker = " 👑" if val_acc == best_acc else ""
summary += f"{epoch:5d} | {train_acc:8.2f}% | {val_acc:7.2f}%{marker} | {lr:.2e}\n"
summary += f"""
═══════════════════════════════════════════════════════════════
📁 FILES
Best Model: best_model_acc{best_acc:.2f}.safetensors
Config: david_config.json
Training Cfg: train_config.json
History: training_history.json
═══════════════════════════════════════════════════════════════
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
with open(summary_path, 'w') as f:
f.write(summary)
print(f"[📄] Created MODEL_SUMMARY.txt")
return summary_path
def update_models_index(
config: DavidTrainingConfig,
david_config: DavidArchitectureConfig,
best_metrics: Dict,
model_name: str
):
"""Update master models index file tracking all trained models."""
if not config.upload_to_hub or not config.hf_repo:
return
try:
from huggingface_hub import hf_hub_download
api = HfApi()
# Try to download existing index
try:
index_path = hf_hub_download(
repo_id=config.hf_repo,
filename="MODELS_INDEX.json",
repo_type="model"
)
with open(index_path, 'r') as f:
models_index = json.load(f)
        except Exception:
# Create new index if doesn't exist
models_index = {
"repository": config.hf_repo,
"updated": datetime.now().isoformat(),
"models": []
}
# Add current model entry
model_entry = {
"model_name": model_name,
"run_id": config.run_id,
"timestamp": datetime.now().isoformat(),
"best_val_acc": best_metrics.get('best_val_acc', 0.0),
"best_epoch": best_metrics.get('best_epoch', 0),
"num_scales": len(david_config.scales),
"scales": david_config.scales,
"parameters": best_metrics.get('parameters', 0),
"sharing_mode": david_config.sharing_mode,
"fusion_mode": david_config.fusion_mode,
"preset": config.preset,
"weights_path": f"weights/{model_name}/{config.run_id}/best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors",
"config_path": f"weights/{model_name}/{config.run_id}/david_config.json",
"history_path": f"weights/{model_name}/{config.run_id}/training_history.json"
}
# Remove old entry for same run_id if exists (update)
models_index["models"] = [m for m in models_index["models"] if m.get("run_id") != config.run_id]
models_index["models"].append(model_entry)
# Sort by accuracy (descending)
models_index["models"].sort(key=lambda x: x.get("best_val_acc", 0), reverse=True)
models_index["updated"] = datetime.now().isoformat()
models_index["total_models"] = len(models_index["models"])
# Save locally
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f:
json.dump(models_index, f, indent=2)
temp_path = f.name
# Upload to hub root
api.upload_file(
path_or_fileobj=temp_path,
path_in_repo="MODELS_INDEX.json",
repo_id=config.hf_repo,
commit_message=f"Update models index - {model_name} @ {best_metrics.get('best_val_acc', 0.0):.2f}%"
)
os.unlink(temp_path)
print(f"[📊] Updated MODELS_INDEX.json - {len(models_index['models'])} models tracked")
except Exception as e:
print(f"[⚠️] Failed to update models index: {e}")
def upload_to_huggingface(
local_dir: str,
repo_id: str,
commit_message: str,
path_in_repo: Optional[str] = None,
patterns: Optional[List[str]] = None
):
"""Upload directory to HuggingFace Hub."""
try:
api = HfApi()
# Create repo if it doesn't exist
try:
create_repo(repo_id, exist_ok=True, repo_type="model")
print(f"[🤗] Repo ready: {repo_id}")
except Exception as e:
print(f"[⚠️] Repo exists or creation failed: {e}")
# Upload folder
if patterns:
# Upload specific patterns
for pattern in patterns:
matching_files = list(Path(local_dir).rglob(pattern))
for file_path in matching_files:
rel_path = file_path.relative_to(local_dir)
if path_in_repo:
repo_path = f"{path_in_repo}/{rel_path}"
else:
repo_path = str(rel_path)
api.upload_file(
path_or_fileobj=str(file_path),
path_in_repo=repo_path,
repo_id=repo_id,
commit_message=commit_message
)
else:
# Upload entire folder
api.upload_folder(
folder_path=local_dir,
repo_id=repo_id,
path_in_repo=path_in_repo,
commit_message=commit_message
)
print(f"[✅] Uploaded to Hub: https://huggingface.co/{repo_id}")
except Exception as e:
print(f"[❌] Hub upload failed: {e}")
print(f" Continuing training (files saved locally)")
def prepare_hub_upload(
weights_dir: str,
runs_dir: str,
config: DavidTrainingConfig,
david_config: DavidArchitectureConfig,
best_metrics: Dict,
model_name: str
):
"""Prepare and upload all artifacts to HuggingFace Hub."""
if not config.upload_to_hub or not config.hf_repo:
return
print("\n[🤗] Preparing HuggingFace Hub upload...")
# Create model summary file
summary_path = create_model_summary(weights_dir, config, david_config, best_metrics, model_name)
# Update master models index
update_models_index(config, david_config, best_metrics, model_name)
api = HfApi()
try:
create_repo(config.hf_repo, exist_ok=True, repo_type="model")
    except Exception:
        pass
# Create temporary directory for root files
with tempfile.TemporaryDirectory() as temp_dir:
# Generate README at root
readme_path = os.path.join(temp_dir, "README.md")
readme_content = generate_model_readme(config, david_config, best_metrics, config.run_id)
with open(readme_path, 'w') as f:
f.write(readme_content)
print(f"[📝] Generated README.md")
# Save best_model.json at root
best_json_path = os.path.join(temp_dir, "best_model.json")
save_best_model_json(best_json_path, best_metrics, config, david_config)
# Upload root files (README.md, best_model.json)
print(f"[📤] Uploading root files...")
api.upload_file(
path_or_fileobj=readme_path,
path_in_repo="README.md",
repo_id=config.hf_repo,
commit_message=f"Update README - Run {config.run_id}"
)
api.upload_file(
path_or_fileobj=best_json_path,
path_in_repo="best_model.json",
repo_id=config.hf_repo,
commit_message=f"Update metrics - Run {config.run_id}"
)
# Upload ONLY essential weight files (not entire directory!)
weights_repo_path = f"weights/{model_name}/{config.run_id}"
best_acc = best_metrics.get('best_val_acc', 0.0)
print(f"[📤] Uploading essential files to {weights_repo_path}...")
# List of specific files to upload (not entire directory)
files_to_upload = [
('MODEL_SUMMARY.txt', 'MODEL_SUMMARY.txt'),
('training_history.json', 'training_history.json'),
('david_config.json', 'david_config.json'),
('train_config.json', 'train_config.json'),
(f'best_model_acc{best_acc:.2f}.safetensors', f'best_model_acc{best_acc:.2f}.safetensors'),
(f'best_model_acc{best_acc:.2f}_metadata.json', f'best_model_acc{best_acc:.2f}_metadata.json'),
]
for local_filename, repo_filename in files_to_upload:
local_path = os.path.join(weights_dir, local_filename)
if os.path.exists(local_path):
try:
api.upload_file(
path_or_fileobj=local_path,
path_in_repo=f"{weights_repo_path}/{repo_filename}",
repo_id=config.hf_repo,
commit_message=f"Update {repo_filename} - Run {config.run_id}"
)
except Exception as e:
print(f"[⚠️] Failed to upload {repo_filename}: {e}")
print(f"[✅] Uploaded to Hub: https://huggingface.co/{config.hf_repo}")
# Upload tensorboard logs (only if they exist and it's final upload)
# Skip TensorBoard during training to avoid huge uploads every epoch
# if os.path.exists(runs_dir):
# runs_repo_path = f"runs/{model_name}/{config.run_id}"
# print(f"[📤] Uploading TensorBoard logs to {runs_repo_path}...")
# upload_to_huggingface(
# local_dir=runs_dir,
# repo_id=config.hf_repo,
# commit_message=f"Upload TensorBoard logs - {model_name} - Run {config.run_id}",
# path_in_repo=runs_repo_path
# )
# ============================================================================
# CHECKPOINT UTILITIES
# ============================================================================
def save_checkpoint(
filepath: str,
david: David,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
epoch: int,
metrics: Dict,
train_config: DavidTrainingConfig
):
"""Save checkpoint in PyTorch and/or SafeTensors format."""
checkpoint = {
'epoch': epoch,
'model_state_dict': david.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
'metrics': metrics,
'train_config': train_config.to_dict(),
}
# Add accuracy to filename if available
val_acc = metrics.get('best_val_acc') or metrics.get('val_acc')
if val_acc:
acc_suffix = f"_acc{val_acc:.2f}"
filepath = filepath + acc_suffix
if train_config.save_format in ['pytorch', 'both']:
torch.save(checkpoint, filepath + '.pth')
print(f"[💾] Saved PyTorch: {filepath}.pth")
if train_config.save_format in ['safetensors', 'both']:
try:
from safetensors.torch import save_file
# Save model state
model_state = {k: v.contiguous() for k, v in david.state_dict().items()}
save_file(model_state, filepath + '.safetensors')
# Save metadata separately (now includes full training history)
metadata = {k: v for k, v in checkpoint.items()
if k not in ['model_state_dict']}
with open(filepath + '_metadata.json', 'w') as f:
json.dump(metadata, f, indent=2, default=str)
print(f"[💾] Saved SafeTensors: {filepath}.safetensors")
except ImportError:
print(f"[⚠️] SafeTensors not available, skipping")
def load_checkpoint(
checkpoint_path: str,
david: David,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: str = "cuda"
) -> Tuple[int, Dict]:
"""Load checkpoint and return epoch and metrics."""
if checkpoint_path.endswith('.safetensors'):
# Load SafeTensors format
try:
from safetensors.torch import load_file
model_state = load_file(checkpoint_path, device=device)
david.load_state_dict(model_state)
# Load metadata
metadata_path = checkpoint_path.replace('.safetensors', '_metadata.json')
with open(metadata_path, 'r') as f:
metadata = json.load(f)
epoch = metadata.get('epoch', 0)
metrics = metadata.get('metrics', {})
if optimizer and 'optimizer_state_dict' in metadata:
optimizer.load_state_dict(metadata['optimizer_state_dict'])
if scheduler and 'scheduler_state_dict' in metadata and metadata['scheduler_state_dict']:
scheduler.load_state_dict(metadata['scheduler_state_dict'])
print(f"[✅] Loaded from SafeTensors: {checkpoint_path}")
return epoch, metrics
except ImportError:
raise ImportError("safetensors not installed")
else:
# Load PyTorch format
checkpoint = torch.load(checkpoint_path, map_location=device)
david.load_state_dict(checkpoint['model_state_dict'])
if optimizer and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
print(f"[✅] Loaded from PyTorch: {checkpoint_path}")
return checkpoint['epoch'], checkpoint.get('metrics', {})
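# Illustrative sketch of the checkpoint layout used above: tensors go into a
# .safetensors file, everything else into a sibling *_metadata.json. The toy
# module and path prefix are hypothetical; requires the `safetensors` package.
def _example_safetensors_roundtrip(prefix: str = "/tmp/david_example_ckpt"):
    from safetensors.torch import save_file, load_file
    toy = torch.nn.Linear(8, 2)
    state = {k: v.contiguous() for k, v in toy.state_dict().items()}
    save_file(state, prefix + ".safetensors")
    with open(prefix + "_metadata.json", 'w') as f:
        json.dump({'epoch': 0, 'metrics': {'val_acc': 0.0}}, f, indent=2)
    toy.load_state_dict(load_file(prefix + ".safetensors"))
    return toy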
# ============================================================================
# DATASET
# ============================================================================
class ImageNetHFDataset(Dataset):
"""PyTorch Dataset wrapper for HuggingFace ImageNet features."""
def __init__(self, dataset_name: str, model_variant: str, split: str = "train"):
# Load only the specific split to avoid downloading all data
print(f"[📥] Loading {split} split for {model_variant}...")
self.dataset = load_dataset(
dataset_name,
name=model_variant, # Dataset configuration/variant name
split=split # Only load this specific split
)
self.length = len(self.dataset)
print(f"[✅] Loaded {self.length:,} samples from {split} split")
def __len__(self):
return self.length
def __getitem__(self, idx):
item = self.dataset[idx]
features = torch.tensor(item['clip_features'], dtype=torch.float32)
label = torch.tensor(item['label'], dtype=torch.long)
return features, label
class MergedImageNetDataset(Dataset):
"""
Merge multiple CLIP variants into a single dataset.
Perfect for testing if David can unify different encoder spaces!
"""
def __init__(
self,
dataset_name: str,
model_variants: List[str], # e.g., ['clip_vit_b16', 'clip_vit_laion_b16']
split: str = "train",
shuffle_seed: int = 42
):
print(f"[🔀] Creating merged dataset from {len(model_variants)} variants...")
self.datasets = []
self.cumulative_lengths = [0]
# Load each variant
for variant in model_variants:
print(f"[📥] Loading {split} split for {variant}...")
ds = load_dataset(
dataset_name,
name=variant,
split=split
)
self.datasets.append(ds)
self.cumulative_lengths.append(self.cumulative_lengths[-1] + len(ds))
print(f"[✅] Loaded {len(ds):,} samples from {variant}")
self.total_length = self.cumulative_lengths[-1]
# Create shuffled indices for fair mixing
print(f"[🎲] Shuffling {self.total_length:,} samples (seed={shuffle_seed})...")
rng = np.random.RandomState(shuffle_seed)
self.shuffle_indices = rng.permutation(self.total_length)
print(f"[✅] Merged dataset ready: {self.total_length:,} samples from {len(model_variants)} encoders")
def __len__(self):
return self.total_length
def __getitem__(self, idx):
# Map shuffled index to original dataset
actual_idx = int(self.shuffle_indices[idx])
# Find which dataset this index belongs to
dataset_idx = 0
for i, cumsum in enumerate(self.cumulative_lengths[1:]):
if actual_idx < cumsum:
dataset_idx = i
break
# Get item from the correct dataset
local_idx = actual_idx - self.cumulative_lengths[dataset_idx]
item = self.datasets[dataset_idx][local_idx]
features = torch.tensor(item['clip_features'], dtype=torch.float32)
label = torch.tensor(item['label'], dtype=torch.long)
return features, label
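# Illustrative sketch of MergedImageNetDataset's index mapping: a shuffled
# global index is translated back to (source dataset, local index) via the
# cumulative lengths. Plain lists stand in for the HF datasets; sizes are hypothetical.
def _example_merged_index_mapping():
    lengths = [5, 3]                 # two fake encoder datasets
    cumulative = [0]
    for n in lengths:
        cumulative.append(cumulative[-1] + n)
    shuffle_indices = np.random.RandomState(42).permutation(cumulative[-1])
    mapping = []
    for idx in range(cumulative[-1]):
        actual = int(shuffle_indices[idx])
        dataset_idx = 0
        for i, cumsum in enumerate(cumulative[1:]):
            if actual < cumsum:
                dataset_idx = i
                break
        mapping.append((dataset_idx, actual - cumulative[dataset_idx]))
    return mapping  # (source dataset, local index) pairs, interleaved by the shuffle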
def create_dataloaders(config: DavidTrainingConfig):
"""Create train and validation dataloaders."""
# Check if model_variant is a list (multi-encoder experiment)
if isinstance(config.model_variant, list):
print(f"[🧪] MULTI-ENCODER EXPERIMENT: Merging {len(config.model_variant)} variants")
train_dataset = MergedImageNetDataset(
config.dataset_name,
config.model_variant, # List of variants
"train"
)
val_dataset = MergedImageNetDataset(
config.dataset_name,
config.model_variant,
"validation"
)
else:
# Single encoder (normal mode)
train_dataset = ImageNetHFDataset(
config.dataset_name, config.model_variant, "train"
)
val_dataset = ImageNetHFDataset(
config.dataset_name, config.model_variant, "validation"
)
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
persistent_workers=config.persistent_workers
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size * 2,
shuffle=False,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
persistent_workers=config.persistent_workers
)
return train_loader, val_loader
# ============================================================================
# CRYSTAL GENERATOR
# ============================================================================
class CrystalGenerator:
"""Generate crystals for all scales."""
def __init__(self, num_classes: int, scales: List[int], device: str = "cuda"):
self.num_classes = num_classes
self.scales = scales
self.device = device
self.factories = {
scale: SimplexFactory(k=4, embed_dim=scale, method="random")
for scale in scales
}
def generate(self, seed: int = 42) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]:
"""Generate anchors and crystals for all scales."""
anchors_dict = {}
crystals_dict = {}
for scale in tqdm(self.scales, desc="Generating crystals"):
factory = self.factories[scale]
batch_crystals = []
for class_idx in range(self.num_classes):
crystal = factory.build(
backend="torch",
device=self.device,
dtype=torch.float32,
seed=seed + class_idx,
validate=True
)
batch_crystals.append(crystal)
crystals = torch.stack(batch_crystals)
anchors = F.normalize(crystals[:, 0, :], dim=-1)
# Verify anchor diversity
anchor_sims = anchors @ anchors.T
off_diag = anchor_sims[~torch.eye(self.num_classes, dtype=bool, device=anchors.device)]
max_sim = off_diag.max().item()
mean_sim = off_diag.mean().item()
print(f" Scale {scale}: max_sim={max_sim:.4f}, mean_sim={mean_sim:.4f}")
if max_sim > 0.99:
print(f" ⚠️ WARNING: Anchors too similar at scale {scale}!")
anchors_dict[scale] = anchors
crystals_dict[scale] = crystals
return anchors_dict, crystals_dict
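# Illustrative sketch of the anchor-diversity check above, applied to random
# unit vectors instead of SimplexFactory crystals; class count and dimension
# are hypothetical.
def _example_anchor_diversity(num_classes: int = 100, dim: int = 256):
    anchors = F.normalize(torch.randn(num_classes, dim), dim=-1)
    sims = anchors @ anchors.T
    off_diag = sims[~torch.eye(num_classes, dtype=bool)]
    return off_diag.max().item(), off_diag.mean().item()  # should stay well below 0.99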
# ============================================================================
# TRAINING LOOP
# ============================================================================
def train_epoch(
david: David,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
criterion: MultiScaleCrystalLoss,
anchors_dict: Dict[int, torch.Tensor],
crystals_dict: Dict[int, torch.Tensor],
epoch: int,
config: DavidTrainingConfig,
writer: Optional[SummaryWriter],
global_step: int
) -> Tuple[float, float, int, Dict]:
"""Train for one epoch - Pure FP32."""
david.train()
david.update_epoch(epoch)
total_loss = 0
correct = 0
total = 0
loss_components_sum = {}
active_scales = david.get_active_scales()
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
for batch_idx, (features, labels) in enumerate(pbar):
features = features.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# Zero gradients
optimizer.zero_grad()
# Forward pass - Pure FP32, no autocast
combined, logits_list, features_list, fusion_weights = david(
features, anchors_dict, return_all_scales=True
)
# Compute loss
losses = criterion(
combined, logits_list, features_list,
labels, crystals_dict, epoch
)
# Backward
losses['total'].backward()
# Gradient analysis
if config.track_gradients and batch_idx % config.log_interval == 0:
grad_stats = analyze_gradients(david, config)
if writer:
                step = global_step
writer.add_scalar('train/grad_mean', grad_stats['mean'], step)
writer.add_scalar('train/grad_max', grad_stats['max'], step)
writer.add_scalar('train/grad_num_small', grad_stats['num_small'], step)
# Scale small gradients
scale_small_gradients(david, config)
# Gradient clipping
torch.nn.utils.clip_grad_norm_(david.parameters(), config.gradient_clip)
# Optimizer step
optimizer.step()
# Metrics
total_loss += losses['total'].item()
_, predicted = torch.max(combined, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Accumulate loss components
for key, value in losses.items():
if key not in loss_components_sum:
loss_components_sum[key] = 0.0
loss_components_sum[key] += value.item()
# Logging
if writer and batch_idx % config.log_interval == 0:
            step = global_step
writer.add_scalar('train/loss_batch', losses['total'].item(), step)
writer.add_scalar('train/acc_batch', 100 * correct / total, step)
if config.log_loss_components:
for key, value in losses.items():
if key != 'total':
writer.add_scalar(f'train/loss_{key}', value.item(), step)
if config.log_fusion_weights and fusion_weights is not None:
if fusion_weights.dim() == 2:
mean_weights = fusion_weights.mean(dim=0)
for i, w in enumerate(mean_weights):
if i < len(active_scales):
writer.add_scalar(
f'train/fusion_weight_{active_scales[i]}',
w.item(), step
)
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step)
pbar.set_postfix({
'loss': f'{total_loss / (batch_idx + 1):.4f}',
'acc': f'{100 * correct / total:.2f}%'
})
global_step += 1
# Average loss components
avg_components = {k: v / len(train_loader) for k, v in loss_components_sum.items()}
return (
total_loss / len(train_loader),
100 * correct / total,
global_step,
avg_components
)
@torch.no_grad()
def validate(
david: David,
val_loader: DataLoader,
anchors_dict: Dict[int, torch.Tensor],
config: DavidTrainingConfig
) -> Tuple[float, Dict[int, float]]:
"""Validate model - Pure FP32."""
david.eval()
correct = 0
total = 0
active_scales = david.get_active_scales()
scale_correct = {scale: 0 for scale in active_scales}
for features, labels in tqdm(val_loader, desc="Validation", leave=False):
features = features.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# Forward pass - no autocast
combined, logits_list, _, _ = david(
features, anchors_dict, return_all_scales=True
)
_, predicted = torch.max(combined, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
for i, scale in enumerate(active_scales):
if i < len(logits_list):
_, scale_pred = torch.max(logits_list[i], 1)
scale_correct[scale] += (scale_pred == labels).sum().item()
accuracy = 100 * correct / total
scale_accs = {s: 100 * scale_correct[s] / total for s in scale_correct}
return accuracy, scale_accs
# ============================================================================
# MAIN TRAINING FUNCTION
# ============================================================================
def train_david(config: DavidTrainingConfig):
"""Main training pipeline."""
# Enable TensorFloat32 for better performance on Ampere+ GPUs
torch.set_float32_matmul_precision('high')
print("="*80)
print("🌟 DAVID TRAINING PIPELINE")
print("="*80)
print(f"Run ID: {config.run_id}")
print(f"Preset: {config.preset}")
print(f"Batch Size: {config.batch_size}")
print(f"Learning Rate: {config.learning_rate}")
print(f"Mixed Precision: {config.use_mixed_precision}")
print(f"TensorFloat32: Enabled (high precision)")
print("="*80)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load or create David config FIRST (needed for model_name)
if config.custom_config_path:
david_config = DavidArchitectureConfig.from_json(config.custom_config_path)
print(f"[📁] Loaded custom config: {config.custom_config_path}")
elif config.preset:
david_config = DavidPresets.get_preset(config.preset)
print(f"[⚙️] Using preset: {config.preset}")
else:
raise ValueError("Must specify either preset or custom_config_path")
# Create model name from architecture
model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}"
print(f"[🏷️] Model: {model_name}")
# Setup directories with proper hierarchy: weights/model_name/timestamp/
weights_dir = os.path.join(config.base_dir, "weights", model_name, config.run_id)
runs_dir = os.path.join(config.base_dir, "runs", model_name, config.run_id)
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(runs_dir, exist_ok=True)
print(f"[📁] Weights: {weights_dir}")
print(f"[📁] Logs: {runs_dir}")
writer = SummaryWriter(runs_dir)
# Apply overrides
if config.num_classes_override:
david_config.num_classes = config.num_classes_override
if config.use_belly_override is not None:
david_config.use_belly = config.use_belly_override
if config.belly_expand_override is not None:
david_config.belly_expand = config.belly_expand_override
if config.progressive_training_override is not None:
david_config.progressive_training = config.progressive_training_override
if not david_config.progressive_training:
# Disable warmup if progressive training disabled
david_config.scale_warmup_epochs = {s: 0 for s in david_config.scales}
# Override scale warmup schedule if provided
if config.scale_warmup_epochs_override is not None:
david_config.scale_warmup_epochs = config.scale_warmup_epochs_override
# Enable progressive training if custom schedule provided
if not david_config.progressive_training:
print(f"[⚙️] Enabling progressive training (custom warmup schedule provided)")
david_config.progressive_training = True
print(f"[⚙️] Progressive training: {david_config.progressive_training}")
if david_config.progressive_training:
print(f" Scale warmup schedule: {david_config.scale_warmup_epochs}")
# Save configs
david_config_path = os.path.join(weights_dir, "david_config.json")
david_config.to_json(david_config_path)
print(f"[💾] Saved David config: {david_config_path}")
train_config_path = os.path.join(weights_dir, "train_config.json")
config.to_json(train_config_path)
print(f"[💾] Saved training config: {train_config_path}")
# Initialize David
david = David.from_config(david_config).cuda()
print(f"\n{david}\n")
# Count parameters
total_params = sum(p.numel() for p in david.parameters())
trainable_params = sum(p.numel() for p in david.parameters() if p.requires_grad)
print(f"[📊] Total Parameters: {total_params:,}")
print(f"[📊] Trainable Parameters: {trainable_params:,}")
# Load data
train_loader, val_loader = create_dataloaders(config)
# Generate crystals
crystal_gen = CrystalGenerator(
david_config.num_classes,
david_config.scales,
str(device)
)
anchors_dict, crystals_dict = crystal_gen.generate()
# Setup training
criterion = MultiScaleCrystalLoss(
scales=david_config.scales,
num_classes=david_config.num_classes,
use_rose_loss=config.use_rose_loss,
use_cayley_loss=config.use_cayley_loss,
rose_initial_weight=config.rose_initial_weight,
rose_max_weight=config.rose_max_weight,
cayley_weight=config.cayley_weight,
scale_loss_balance=config.scale_loss_balance
).cuda()
optimizer = create_optimizer(david, config)
scheduler = create_scheduler(optimizer, config)
controller = AdaptiveTrainingController(david, config)
# Tracking
best_val_acc = 0.0
best_epoch = 0
best_scale_accs = {}
global_step = 0
final_train_acc = 0.0
final_train_loss = 0.0
# Training history for epoch-by-epoch tracking
training_history = {
'epochs': [],
'train_loss': [],
'train_acc': [],
'val_acc': [],
'scale_accs': {},
'lr': []
}
# DIAGNOSTIC: Test one forward/backward pass before training
print("\n[🔍] Running diagnostic forward/backward pass...")
david.train()
# Get a small batch
for features_test, labels_test in train_loader:
features_test = features_test.cuda(non_blocking=True)[:8] # Just 8 samples
labels_test = labels_test.cuda(non_blocking=True)[:8]
# Forward
combined_test, logits_test, features_test_out, _ = david(
features_test, anchors_dict, return_all_scales=True
)
# Loss
losses_test = criterion(
combined_test, logits_test, features_test_out,
labels_test, crystals_dict, epoch=0
)
print(f" Initial loss: {losses_test['total'].item():.6f}")
print(f" Loss components:")
for key, value in losses_test.items():
if key != 'total':
print(f" {key}: {value.item():.6f}")
# Backward
optimizer.zero_grad()
losses_test['total'].backward()
# Check gradients
grad_count = sum(1 for p in david.parameters() if p.grad is not None and p.grad.norm() > 0)
total_grad_params = sum(1 for p in david.parameters() if p.requires_grad)
print(f" Parameters with non-zero gradients: {grad_count}/{total_grad_params}")
if grad_count == 0:
print(f" ❌ ERROR: No gradients! Training will not work.")
return None, 0.0
elif grad_count < total_grad_params * 0.5:
print(f" ⚠️ WARNING: Less than 50% of parameters have gradients")
else:
print(f" ✅ Gradients look good")
break # Only test one batch
print("\n[🚀] Starting training...\n")
for epoch in range(config.num_epochs):
epoch_start = time.time()
# Train
train_loss, train_acc, global_step, loss_components = train_epoch(
david, train_loader, optimizer, criterion,
anchors_dict, crystals_dict, epoch, config,
writer, global_step
)
# Validate
val_acc, scale_accs = validate(david, val_loader, anchors_dict, config)
# Update controller
controller.update_metrics(scale_accs, val_acc)
controller.apply_adaptive_strategies(scale_accs, epoch)
# Step scheduler
if scheduler:
scheduler.step()
epoch_time = time.time() - epoch_start
# Print
print(f"\n📊 Epoch {epoch+1}/{config.num_epochs} ({epoch_time:.1f}s)")
print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
print(f" Val: Acc={val_acc:.2f}% (Best: {best_val_acc:.2f}%)")
print(f" Active scales: {david.get_active_scales()}")
print(f" LR: {optimizer.param_groups[0]['lr']:.2e}")
if config.log_loss_components and loss_components:
print(f" Loss breakdown:")
for key, value in sorted(loss_components.items()):
if key != 'total':
print(f" {key:20s}: {value:.6f}")
for scale, acc in scale_accs.items():
frozen = "❄️" if controller.scales_frozen.get(scale, False) else "🔥"
print(f" {frozen} Scale {scale}: {acc:.2f}%")
# Update tracking
final_train_acc = train_acc
final_train_loss = train_loss
# Record training history
training_history['epochs'].append(epoch + 1)
training_history['train_loss'].append(train_loss)
training_history['train_acc'].append(train_acc)
training_history['val_acc'].append(val_acc)
training_history['lr'].append(optimizer.param_groups[0]['lr'])
# Record per-scale accuracies
for scale, acc in scale_accs.items():
if scale not in training_history['scale_accs']:
training_history['scale_accs'][scale] = []
training_history['scale_accs'][scale].append(acc)
# TensorBoard
writer.add_scalar('train/loss', train_loss, epoch)
writer.add_scalar('train/acc', train_acc, epoch)
writer.add_scalar('val/acc', val_acc, epoch)
for scale, acc in scale_accs.items():
writer.add_scalar(f'val/acc_scale_{scale}', acc, epoch)
# Save best
if val_acc > best_val_acc:
best_val_acc = val_acc
best_epoch = epoch
best_scale_accs = scale_accs.copy()
# Save training history alongside best model
history_path = os.path.join(weights_dir, 'training_history.json')
with open(history_path, 'w') as f:
json.dump(training_history, f, indent=2)
save_checkpoint(
os.path.join(weights_dir, 'best_model'),
david, optimizer, scheduler, epoch,
{
'best_val_acc': best_val_acc,
'best_epoch': best_epoch,
'scale_accuracies': best_scale_accs,
'training_history': training_history
},
config
)
# Upload to hub when best model improves
if config.upload_to_hub:
best_metrics = {
'best_val_acc': best_val_acc,
'best_epoch': best_epoch,
'scale_accuracies': best_scale_accs,
'final_train_acc': train_acc,
'final_train_loss': train_loss,
'training_history': training_history,
'parameters': total_params
}
prepare_hub_upload(weights_dir, runs_dir, config, david_config, best_metrics, model_name)
# Periodic save
if (epoch + 1) % config.save_interval == 0:
save_checkpoint(
os.path.join(weights_dir, f'checkpoint_epoch_{epoch+1}'),
david, optimizer, scheduler, epoch,
{'val_acc': val_acc},
config
)
# Final save
save_checkpoint(
os.path.join(weights_dir, 'final_model'),
david, optimizer, scheduler, config.num_epochs - 1,
{'final_val_acc': val_acc},
config
)
writer.close()
# Final hub upload with all artifacts
if config.upload_to_hub:
print("\n[🤗] Performing final HuggingFace Hub upload...")
final_metrics = {
'best_val_acc': best_val_acc,
'best_epoch': best_epoch,
'scale_accuracies': best_scale_accs,
'final_train_acc': final_train_acc,
'final_train_loss': final_train_loss,
'training_history': training_history,
'parameters': total_params
}
prepare_hub_upload(weights_dir, runs_dir, config, david_config, final_metrics, model_name)
# Upload TensorBoard logs at the end
if os.path.exists(runs_dir):
runs_repo_path = f"runs/{model_name}/{config.run_id}"
print(f"[📤] Uploading TensorBoard logs to {runs_repo_path}...")
upload_to_huggingface(
local_dir=runs_dir,
repo_id=config.hf_repo,
commit_message=f"Upload TensorBoard logs - {model_name} - Run {config.run_id}",
path_in_repo=runs_repo_path
)
print("\n" + "="*80)
print(f"🎉 Training Complete!")
print(f" Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch+1})")
print(f" Final Train Acc: {final_train_acc:.2f}%")
print(f" Weights: {weights_dir}")
if config.upload_to_hub:
print(f" Hub: https://huggingface.co/{config.hf_repo}")
print("="*80)
return david, best_val_acc
# ============================================================================
# USAGE EXAMPLE
# ============================================================================
if __name__ == "__main__":
# ============================================================================
# EXPERIMENT 1: Single Encoder (Standard Training)
# ============================================================================
# config = DavidTrainingConfig(
# preset="balanced",
# model_variant="clip_vit_b16", # Single encoder
#
# num_epochs=10,
# batch_size=1024,
# learning_rate=1e-2,
#
# use_rose_loss=True,
# rose_initial_weight=0.1,
# rose_max_weight=0.5,
#
# upload_to_hub=True,
# hf_repo="AbstractPhil/gated-david",
# )
# ============================================================================
# EXPERIMENT 2: Multi-Encoder Unified Space (THE TEST!)
# ============================================================================
config = DavidTrainingConfig(
preset="balanced", # 4 scales: [256, 512, 768, 1024]
        # 🧪 MULTI-ENCODER: OpenAI CLIP-B/16 vs LAION CLIP-B/32
        model_variant=["clip_vit_b16", "clip_vit_laion_b32"],  # mixed patch sizes in one shared space
num_epochs=10,
batch_size=1024,
learning_rate=1e-2,
# Custom warmup for 4 scales
scale_warmup_epochs_override={
256: 0,
512: 2,
768: 5,
1024: 8
},
use_rose_loss=True,
rose_initial_weight=0.2, # Higher for diversity
rose_max_weight=0.8,
use_cayley_loss=True, # Extra geometric regularization
cayley_weight=0.01,
freeze_strategy="never",
gradient_clip=10.0,
save_format="safetensors",
upload_to_hub=False,
hf_repo="YourName/YourRepoHere"#"AbstractPhil/david-shared-space",
)
print("="*80)
print("🧪 UNIFIED SPACE EXPERIMENT")
print("="*80)
print(f"Testing if David can unify:")
if isinstance(config.model_variant, list):
for variant in config.model_variant:
print(f" • {variant}")
print("="*80)
david, best_acc = train_david(config)