|
|
""" |
|
|
David Training Pipeline |
|
|
======================== |
|
|
Training pipeline for David multi-scale crystal classifier. |
|
|
|
|
|
Should be placed at: geovocab2/train/model/core/david_trainer.py |
|
|
Or run from: scripts/train_david.py |
|
|
|
|
|
Features: |
|
|
- Pure fp32 training (no mixed precision for geometric stability) |
|
|
- Adaptive training controller (freeze/unfreeze scales) |
|
|
- Gradient analysis and scaling |
|
|
- SafeTensors checkpoint support |
|
|
- Enhanced loss component tracking |
|
|
- Proper weight organization: weights/model_name/timestamp/ |
|
|
- Accuracy in filenames and comprehensive tracking |
|
|
- Master models index (MODELS_INDEX.json) |
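
Usage (minimal sketch; assumes the module path noted above):

    from geovocab2.train.model.core.david_trainer import DavidTrainingConfig, train_david

    config = DavidTrainingConfig(preset="balanced", num_epochs=10)
    david, best_acc = train_david(config)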
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import HfApi, create_repo, upload_folder, upload_file |
|
|
import numpy as np |
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import tempfile |
|
|
from datetime import datetime |
|
|
from tqdm.auto import tqdm |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
from dataclasses import dataclass, field, asdict |
|
|
|
|
|
|
|
|
from geovocab2.train.config.david_config import ( |
|
|
DavidArchitectureConfig, |
|
|
DavidPresets, |
|
|
SharingMode, |
|
|
FusionMode |
|
|
) |
|
|
|
|
|
from geovocab2.train.model.core.david import ( |
|
|
David, |
|
|
MultiScaleCrystalLoss, |
|
|
) |
|
|
|
|
|
|
|
|
from geovocab2.shapes.factory import SimplexFactory |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class DavidTrainingConfig: |
|
|
""" |
|
|
Complete training configuration for David. |
|
|
Separate from model architecture config. |
|
|
""" |
|
|
|
|
|
|
|
|
name: str = "david_training" |
|
|
run_id: str = "" |
|
|
|
|
|
|
|
|
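    # Dataset and encoder variant(s); model_variant may be a single variant name or a list for merged multi-encoder training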
dataset_name: str = "AbstractPhil/imagenet-clip-features-orderly" |
|
|
model_variant: Union[str, List[str]] = "clip_vit_b16" |
|
|
num_classes: int = 1000 |
|
|
|
|
|
|
|
|
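    # Architecture selection: a preset name or a path to a custom DavidArchitectureConfig JSON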
preset: Optional[str] = "balanced" |
|
|
custom_config_path: Optional[str] = None |
|
|
|
|
|
|
|
|
num_classes_override: Optional[int] = None |
|
|
use_belly_override: Optional[bool] = None |
|
|
belly_expand_override: Optional[float] = None |
|
|
progressive_training_override: Optional[bool] = True |
|
|
scale_warmup_epochs_override: Optional[Dict[int, int]] = None |
|
|
|
|
|
|
|
|
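    # Core optimization hyperparameters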
num_epochs: int = 50 |
|
|
batch_size: int = 512 |
|
|
learning_rate: float = 5e-3 |
|
|
weight_decay: float = 1e-5 |
|
|
warmup_epochs: int = 3 |
|
|
|
|
|
|
|
|
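    # Loss composition: cross-entropy plus Rose margin loss and optional Cayley regularization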
use_rose_loss: bool = True |
|
|
rose_initial_weight: float = 0.01 |
|
|
rose_max_weight: float = 0.1 |
|
|
rose_weight_schedule: str = "adaptive" |
|
|
use_cayley_loss: bool = False |
|
|
cayley_weight: float = 0.001 |
|
|
scale_loss_balance: Optional[Dict[int, float]] = None |
|
|
|
|
|
|
|
|
use_mixed_precision: bool = False |
|
|
gradient_clip: float = 5.0 |
|
|
scheduler_type: str = "cosine_restarts" |
|
|
min_lr: float = 1e-6 |
|
|
|
|
|
|
|
|
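    # Adaptive freezing/unfreezing of individual scales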
freeze_strategy: str = "never" |
|
|
freeze_threshold: float = 90.0 |
|
|
unfreeze_on_plateau: bool = True |
|
|
patience: int = 10 |
|
|
|
|
|
|
|
|
track_gradients: bool = True |
|
|
gradient_scale_threshold: float = 1e-5 |
|
|
gradient_scale_multiplier: float = 10.0 |
|
|
|
|
|
|
|
|
log_interval: int = 50 |
|
|
val_interval: int = 1 |
|
|
save_interval: int = 5 |
|
|
log_fusion_weights: bool = True |
|
|
log_loss_components: bool = True |
|
|
|
|
|
|
|
|
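    # Checkpoint format: "pytorch", "safetensors", or "both"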
save_format: str = "both" |
|
|
|
|
|
|
|
|
hf_repo: Optional[str] = "" |
|
|
upload_to_hub: bool = False |
|
|
|
|
|
|
|
|
base_dir: str = "./david_training" |
|
|
|
|
|
|
|
|
num_workers: int = 10 |
|
|
pin_memory: bool = True |
|
|
prefetch_factor: int = 4 |
|
|
persistent_workers: bool = True |
|
|
|
|
|
def __post_init__(self): |
|
|
"""Generate run_id if not provided.""" |
|
|
if not self.run_id: |
|
|
self.run_id = datetime.now().strftime('%Y%m%d_%H%M%S') |
|
|
|
|
|
def to_dict(self) -> dict: |
|
|
"""Convert to dictionary.""" |
|
|
return asdict(self) |
|
|
|
|
|
@classmethod |
|
|
def from_dict(cls, data: dict) -> 'DavidTrainingConfig': |
|
|
"""Create from dictionary.""" |
|
|
return cls(**data) |
|
|
|
|
|
def to_json(self, path: str): |
|
|
"""Save to JSON.""" |
|
|
data = self.to_dict() |
|
|
|
|
|
if data.get('scale_loss_balance'): |
|
|
data['scale_loss_balance'] = { |
|
|
str(k): v for k, v in data['scale_loss_balance'].items() |
|
|
} |
|
|
if data.get('scale_warmup_epochs_override'): |
|
|
data['scale_warmup_epochs_override'] = { |
|
|
str(k): v for k, v in data['scale_warmup_epochs_override'].items() |
|
|
} |
|
|
with open(path, 'w') as f: |
|
|
json.dump(data, f, indent=2) |
|
|
|
|
|
@classmethod |
|
|
def from_json(cls, path: str) -> 'DavidTrainingConfig': |
|
|
"""Load from JSON.""" |
|
|
with open(path, 'r') as f: |
|
|
data = json.load(f) |
|
|
|
|
|
if 'scale_loss_balance' in data and data['scale_loss_balance']: |
|
|
data['scale_loss_balance'] = { |
|
|
int(k): v for k, v in data['scale_loss_balance'].items() |
|
|
} |
|
|
|
|
|
if 'scale_warmup_epochs_override' in data and data['scale_warmup_epochs_override']: |
|
|
data['scale_warmup_epochs_override'] = { |
|
|
int(k): v for k, v in data['scale_warmup_epochs_override'].items() |
|
|
} |
|
|
return cls(**data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdaptiveTrainingController: |
|
|
"""Manages adaptive training strategies for multi-scale model.""" |
|
|
|
|
|
def __init__(self, model: David, config: DavidTrainingConfig): |
|
|
self.model = model |
|
|
self.config = config |
|
|
|
|
|
scales = model.scales |
|
|
self.scale_history = {scale: [] for scale in scales} |
|
|
self.best_scale_acc = {scale: 0.0 for scale in scales} |
|
|
self.scales_frozen = {scale: False for scale in scales} |
|
|
|
|
|
self.overall_history = [] |
|
|
self.plateau_counter = 0 |
|
|
self.best_overall = 0.0 |
|
|
|
|
|
def update_metrics(self, scale_accuracies: Dict[int, float], overall_accuracy: float): |
|
|
"""Update metrics and best scores.""" |
|
|
for scale, acc in scale_accuracies.items(): |
|
|
self.scale_history[scale].append(acc) |
|
|
if acc > self.best_scale_acc[scale]: |
|
|
self.best_scale_acc[scale] = acc |
|
|
|
|
|
self.overall_history.append(overall_accuracy) |
|
|
|
|
|
if overall_accuracy > self.best_overall: |
|
|
self.best_overall = overall_accuracy |
|
|
self.plateau_counter = 0 |
|
|
else: |
|
|
self.plateau_counter += 1 |
|
|
|
|
|
def should_freeze_scale(self, scale: int, current_acc: float) -> bool: |
|
|
"""Determine if a scale should be frozen.""" |
|
|
if self.config.freeze_strategy == "never": |
|
|
return False |
|
|
|
|
|
if self.scales_frozen[scale]: |
|
|
return False |
|
|
|
|
|
if self.config.freeze_strategy == "performance": |
|
|
return current_acc >= self.config.freeze_threshold |
|
|
|
|
|
return False |
|
|
|
|
|
def should_unfreeze_scales(self) -> bool: |
|
|
"""Check if scales should be unfrozen due to plateau.""" |
|
|
if not self.config.unfreeze_on_plateau: |
|
|
return False |
|
|
        return self.plateau_counter >= self.config.patience
|
|
|
|
|
def apply_adaptive_strategies(self, scale_accuracies: Dict[int, float], epoch: int): |
|
|
"""Apply freeze/unfreeze based on performance.""" |
|
|
active_scales = self.model.get_active_scales() |
|
|
|
|
|
|
|
|
for scale, acc in scale_accuracies.items(): |
|
|
if self.should_freeze_scale(scale, acc): |
|
|
|
|
|
active_unfrozen = [s for s in active_scales if not self.scales_frozen.get(s, False)] |
|
|
|
|
|
if len(active_unfrozen) <= 1: |
|
|
print(f"[⚠️] Skipping freeze of scale {scale} (would leave no active trainable scales)") |
|
|
continue |
|
|
|
|
|
self.model.freeze_scale(scale) |
|
|
self.scales_frozen[scale] = True |
|
|
print(f"[❄️] Froze scale {scale} (acc={acc:.2f}%)") |
|
|
|
|
|
if self.should_unfreeze_scales() and any(self.scales_frozen.values()): |
|
|
for scale in self.model.scales: |
|
|
if self.scales_frozen[scale]: |
|
|
self.model.unfreeze_scale(scale) |
|
|
self.scales_frozen[scale] = False |
|
|
self.plateau_counter = 0 |
|
|
print(f"[🔥] Unfroze all scales due to plateau") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_optimizer(david: David, config: DavidTrainingConfig) -> torch.optim.Optimizer: |
|
|
"""Create optimizer with parameter groups.""" |
|
|
|
|
|
param_groups = [] |
|
|
|
|
|
|
|
|
if hasattr(david, 'shared_extractor'): |
|
|
param_groups.append({ |
|
|
'params': david.shared_extractor.parameters(), |
|
|
'lr': config.learning_rate, |
|
|
'name': 'shared' |
|
|
}) |
|
|
elif hasattr(david, 'shared_base'): |
|
|
param_groups.append({ |
|
|
'params': david.shared_base.parameters(), |
|
|
'lr': config.learning_rate, |
|
|
'name': 'shared' |
|
|
}) |
|
|
|
|
|
|
|
|
for scale in david.scales: |
|
|
scale_params = [] |
|
|
if david.sharing_mode == SharingMode.HIERARCHICAL: |
|
|
head = getattr(david, f'head_{scale}', None) |
|
|
if head: |
|
|
scale_params.extend(head.parameters()) |
|
|
refine = getattr(david, f'refine_{scale}', None) |
|
|
if refine: |
|
|
scale_params.extend(refine.parameters()) |
|
|
else: |
|
|
scale_params.extend(david.heads[str(scale)].parameters()) |
|
|
|
|
|
if scale_params: |
|
|
param_groups.append({ |
|
|
'params': scale_params, |
|
|
'lr': config.learning_rate, |
|
|
'name': f'scale_{scale}' |
|
|
}) |
|
|
|
|
|
|
|
|
if hasattr(david, 'fusion'): |
|
|
param_groups.append({ |
|
|
'params': david.fusion.parameters(), |
|
|
'lr': config.learning_rate * 0.5, |
|
|
'name': 'fusion' |
|
|
}) |
|
|
elif hasattr(david, 'fusion_weights'): |
|
|
param_groups.append({ |
|
|
'params': [david.fusion_weights], |
|
|
'lr': config.learning_rate * 0.5, |
|
|
'name': 'fusion' |
|
|
}) |
|
|
|
|
|
return torch.optim.AdamW(param_groups, weight_decay=config.weight_decay) |
|
|
|
|
|
|
|
|
def create_scheduler(optimizer: torch.optim.Optimizer, |
|
|
                     config: DavidTrainingConfig) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
|
|
"""Create learning rate scheduler.""" |
|
|
|
|
|
if config.scheduler_type == "cosine_restarts": |
|
|
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( |
|
|
optimizer, T_0=10, T_mult=2, eta_min=config.min_lr |
|
|
) |
|
|
elif config.scheduler_type == "cosine": |
|
|
return torch.optim.lr_scheduler.CosineAnnealingLR( |
|
|
optimizer, T_max=config.num_epochs, eta_min=config.min_lr |
|
|
) |
|
|
else: |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_gradients(model: David, config: DavidTrainingConfig) -> Dict[str, float]: |
|
|
"""Analyze gradient magnitudes for debugging.""" |
|
|
grad_stats = { |
|
|
'mean': 0.0, |
|
|
'max': 0.0, |
|
|
'min': float('inf'), |
|
|
'num_zero': 0, |
|
|
'num_small': 0, |
|
|
'total': 0 |
|
|
} |
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
if param.grad is not None: |
|
|
grad_norm = param.grad.norm().item() |
|
|
grad_stats['mean'] += grad_norm |
|
|
grad_stats['max'] = max(grad_stats['max'], grad_norm) |
|
|
grad_stats['min'] = min(grad_stats['min'], grad_norm) |
|
|
grad_stats['total'] += 1 |
|
|
|
|
|
if grad_norm < 1e-10: |
|
|
grad_stats['num_zero'] += 1 |
|
|
elif grad_norm < config.gradient_scale_threshold: |
|
|
grad_stats['num_small'] += 1 |
|
|
|
|
|
if grad_stats['total'] > 0: |
|
|
grad_stats['mean'] /= grad_stats['total'] |
|
|
|
|
|
return grad_stats |
|
|
|
|
|
|
|
|
def scale_small_gradients(model: David, config: DavidTrainingConfig): |
|
|
"""Scale up very small gradients to prevent vanishing.""" |
|
|
if not config.track_gradients: |
|
|
return |
|
|
|
|
|
for param in model.parameters(): |
|
|
if param.grad is not None: |
|
|
grad_norm = param.grad.norm() |
|
|
if grad_norm < config.gradient_scale_threshold and grad_norm > 0: |
|
|
param.grad.mul_(config.gradient_scale_multiplier) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_model_readme( |
|
|
config: DavidTrainingConfig, |
|
|
david_config: DavidArchitectureConfig, |
|
|
best_metrics: Dict, |
|
|
run_id: str |
|
|
) -> str: |
|
|
"""Generate README.md for model card.""" |
|
|
|
|
|
readme = f"""--- |
|
|
language: en |
|
|
license: mit |
|
|
tags: |
|
|
- image-classification |
|
|
- imagenet |
|
|
- multi-scale |
|
|
- feature-geometry |
|
|
- david |
|
|
datasets: |
|
|
- imagenet-1k |
|
|
metrics: |
|
|
- accuracy |
|
|
model-index: |
|
|
- name: David-{david_config.sharing_mode}-{david_config.fusion_mode} |
|
|
results: |
|
|
- task: |
|
|
type: image-classification |
|
|
dataset: |
|
|
name: ImageNet-1K |
|
|
type: imagenet-1k |
|
|
metrics: |
|
|
- type: accuracy |
|
|
value: {best_metrics.get('best_val_acc', 0.0):.2f} |
|
|
--- |
|
|
|
|
|
# David: Multi-Scale Feature Classifier |
|
|
|
|
|
**David** is a multi-scale deep learning classifier that uses feature geometry (pentachora/4-simplexes) |
|
|
as class prototypes with role-weighted similarity computation (Rose Loss). |
|
|
|
|
|
This version feeds multiple CLIP-ViT feature variants into a single shared space simultaneously.
The experiment tests whether substantially different variants, such as clip-vit-b-patch32 and clip-vit-b-patch16,
can coexist in the same shared space once the appropriate checks and spacing are applied.
|
|
|
|
|
## Model Details |
|
|
|
|
|
### Architecture |
|
|
- **Preset**: {config.preset} |
|
|
- **Sharing Mode**: {david_config.sharing_mode} |
|
|
- **Fusion Mode**: {david_config.fusion_mode} |
|
|
- **Scales**: {david_config.scales} |
|
|
- **Feature Dim**: {david_config.feature_dim} |
|
|
- **Parameters**: {best_metrics.get('parameters', 0):,} |
|
|
|
|
|
### Training Configuration |
|
|
- **Dataset**: {config.dataset_name} |
|
|
- **Model Variant**: {config.model_variant} |
|
|
- **Epochs**: {config.num_epochs} |
|
|
- **Batch Size**: {config.batch_size} |
|
|
- **Learning Rate**: {config.learning_rate} |
|
|
- **Rose Loss Weight**: {config.rose_initial_weight} → {config.rose_max_weight} |
|
|
- **Cayley Loss**: {config.use_cayley_loss} |
|
|
|
|
|
## Performance |
|
|
|
|
|
### Best Results |
|
|
- **Validation Accuracy**: {best_metrics.get('best_val_acc', 0.0):.2f}% |
|
|
- **Best Epoch**: {best_metrics.get('best_epoch', 0)} |
|
|
- **Final Train Accuracy**: {best_metrics.get('final_train_acc', 0.0):.2f}% |
|
|
|
|
|
### Per-Scale Performance |
|
|
""" |
|
|
|
|
|
if 'scale_accuracies' in best_metrics: |
|
|
for scale, acc in best_metrics['scale_accuracies'].items(): |
|
|
readme += f"- **Scale {scale}**: {acc:.2f}%\n" |
|
|
|
|
|
readme += f""" |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Quick Model Lookup |
|
|
|
|
|
**Check `MODELS_INDEX.json` in the repo root** - it lists all trained models sorted by accuracy with links to weights and configs. |
|
|
|
|
|
### Repository Structure |
|
|
|
|
|
``` |
|
|
{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}/ |
|
|
├── MODELS_INDEX.json # 📊 Master index of all models (sorted by accuracy) |
|
|
├── README.md # This file |
|
|
├── best_model.json # Latest best model info |
|
|
├── weights/ |
|
|
│ └── {david_config.name}/ |
|
|
│ └── {run_id}/ |
|
|
│ ├── MODEL_SUMMARY.txt # 🎯 Human-readable performance summary |
|
|
│ ├── training_history.json # 📈 Epoch-by-epoch training curve |
|
|
│ ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors # ⭐ Accuracy in filename! |
|
|
│ ├── best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}_metadata.json |
|
|
│ ├── final_model.safetensors |
|
|
│ ├── checkpoint_epoch_X_accYY.YY.safetensors |
|
|
│ ├── david_config.json |
|
|
│ └── train_config.json |
|
|
└── runs/ |
|
|
└── {david_config.name}/ |
|
|
└── {run_id}/ |
|
|
└── events.out.tfevents.* # TensorBoard logs |
|
|
``` |
|
|
|
|
|
### Loading the Model |
|
|
|
|
|
```python |
|
|
from geovocab2.train.model.core.david import David, DavidArchitectureConfig |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Browse available models in MODELS_INDEX.json first! |
|
|
|
|
|
# Specify model variant and run |
|
|
model_name = "{david_config.name}" |
|
|
run_id = "{run_id}" |
|
|
accuracy = "{best_metrics.get('best_val_acc', 0.0):.2f}" # From MODELS_INDEX.json |
|
|
|
|
|
# Download config |
|
|
config_path = hf_hub_download( |
|
|
repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}", |
|
|
filename=f"weights/{{model_name}}/{{run_id}}/david_config.json" |
|
|
) |
|
|
config = DavidArchitectureConfig.from_json(config_path) |
|
|
|
|
|
# Download weights (accuracy in filename!) |
|
|
weights_path = hf_hub_download( |
|
|
repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}", |
|
|
filename=f"weights/{{model_name}}/{{run_id}}/best_model_acc{{accuracy}}.safetensors" |
|
|
) |
|
|
|
|
|
# Download training history (optional - see full training curve) |
|
|
history_path = hf_hub_download( |
|
|
repo_id="{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}", |
|
|
filename=f"weights/{{model_name}}/{{run_id}}/training_history.json" |
|
|
) |
|
|
|
|
|
# Load model |
|
|
from safetensors.torch import load_file |
|
|
david = David.from_config(config) |
|
|
david.load_state_dict(load_file(weights_path)) |
|
|
david.eval() |
|
|
``` |
|
|
|
|
|
### Inference |
|
|
|
|
|
```python |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
# Assuming you have CLIP features (512-dim for ViT-B/16) |
|
|
features = get_clip_features(image) # [1, 512] |
|
|
|
|
|
# Load anchors |
|
|
anchors_dict = torch.load("anchors.pth") |
|
|
|
|
|
# Forward pass |
|
|
with torch.no_grad(): |
|
|
logits, _ = david(features, anchors_dict) |
|
|
predictions = logits.argmax(dim=-1) |
|
|
``` |
|
|
|
|
|
## Architecture Overview |
|
|
|
|
|
### Multi-Scale Processing |
|
|
David processes inputs at multiple scales ({', '.join(map(str, david_config.scales))}), |
|
|
allowing it to capture both coarse and fine-grained features. |
|
|
|
|
|
### Shared Representation Space |
|
|
This variant embeds multiple CLIP-ViT encoders in the same shared representation space.
|
|
|
|
|
### Feature Geometry |
|
|
Each class is represented by a pentachoron (4-simplex) in embedding space with 5 vertices: |
|
|
- **Anchor**: Primary class representative |
|
|
- **Need**: Complementary direction |
|
|
- **Relation**: Contextual alignment |
|
|
- **Purpose**: Functional direction |
|
|
- **Observer**: Meta-perspective |
|
|
|
|
|
### Rose Loss |
|
|
Similarity computation uses role-weighted cosine similarities: |
|
|
``` |
|
|
score = w_anchor * sim(z, anchor) + w_need * sim(z, need) + ... |
|
|
``` |
|
|
|
|
|
### Fusion Strategy |
|
|
**{david_config.fusion_mode}**: Intelligently combines predictions from multiple scales. |
|
|
|
|
|
## Training Details |
|
|
|
|
|
### Loss Components |
|
|
- **Cross-Entropy**: Standard classification loss |
|
|
- **Rose Loss**: Pentachora role-weighted margin loss (weight: {config.rose_initial_weight}→{config.rose_max_weight}) |
|
|
- **Cayley Loss**: Geometric regularization ({'enabled' if config.use_cayley_loss else 'disabled'}) |
|
|
|
|
|
### Optimization |
|
|
- **Optimizer**: AdamW |
|
|
- **Weight Decay**: {config.weight_decay} |
|
|
- **Scheduler**: {config.scheduler_type} |
|
|
- **Gradient Clip**: {config.gradient_clip} |
|
|
- **Mixed Precision**: {config.use_mixed_precision} |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex |
|
|
@software{{david_classifier_2025, |
|
|
title = {{David: Multi-Scale Feature Classifier}}, |
|
|
author = {{AbstractPhil}}, |
|
|
year = {{2025}}, |
|
|
url = {{https://huggingface.co/{config.hf_repo if config.hf_repo else 'AbstractPhil/david'}}}, |
|
|
note = {{Run ID: {run_id}}} |
|
|
}} |
|
|
``` |
|
|
|
|
|
## License |
|
|
|
|
|
MIT License |
|
|
|
|
|
## Acknowledgments |
|
|
|
|
|
Built with feature lattice geometry and multi-scale deep learning. |
|
|
Special thanks to Claude (Anthropic) for debugging assistance. |
|
|
|
|
|
--- |
|
|
|
|
|
*Generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}* |
|
|
""" |
|
|
|
|
|
return readme |
|
|
|
|
|
|
|
|
def save_best_model_json( |
|
|
filepath: str, |
|
|
metrics: Dict, |
|
|
config: DavidTrainingConfig, |
|
|
david_config: DavidArchitectureConfig |
|
|
): |
|
|
"""Save best_model.json with comprehensive metrics.""" |
|
|
|
|
|
model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}" |
|
|
|
|
|
best_model_info = { |
|
|
"model_name": model_name, |
|
|
"run_id": config.run_id, |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
|
|
|
|
|
|
"best_val_acc": metrics.get('best_val_acc', 0.0), |
|
|
"best_epoch": metrics.get('best_epoch', 0), |
|
|
"final_train_acc": metrics.get('final_train_acc', 0.0), |
|
|
"final_train_loss": metrics.get('final_train_loss', 0.0), |
|
|
|
|
|
|
|
|
"scale_accuracies": metrics.get('scale_accuracies', {}), |
|
|
|
|
|
|
|
|
"architecture": { |
|
|
"preset": config.preset, |
|
|
"sharing_mode": david_config.sharing_mode, |
|
|
"fusion_mode": david_config.fusion_mode, |
|
|
"scales": david_config.scales, |
|
|
"feature_dim": david_config.feature_dim, |
|
|
"num_classes": david_config.num_classes, |
|
|
"use_belly": david_config.use_belly, |
|
|
"belly_expand": david_config.belly_expand, |
|
|
}, |
|
|
|
|
|
|
|
|
"training": { |
|
|
"dataset": config.dataset_name, |
|
|
"model_variant": config.model_variant, |
|
|
"num_epochs": config.num_epochs, |
|
|
"batch_size": config.batch_size, |
|
|
"learning_rate": config.learning_rate, |
|
|
"rose_weight": f"{config.rose_initial_weight}→{config.rose_max_weight}", |
|
|
"cayley_loss": config.use_cayley_loss, |
|
|
"optimizer": "AdamW", |
|
|
"scheduler": config.scheduler_type, |
|
|
}, |
|
|
|
|
|
|
|
|
"files": { |
|
|
"weights_safetensors": f"weights/{model_name}/{config.run_id}/best_model_acc{metrics.get('best_val_acc', 0.0):.2f}.safetensors", |
|
|
"weights_pytorch": f"weights/{model_name}/{config.run_id}/best_model.pth", |
|
|
"config": f"weights/{model_name}/{config.run_id}/david_config.json", |
|
|
"training_config": f"weights/{model_name}/{config.run_id}/train_config.json", |
|
|
"tensorboard": f"runs/{model_name}/{config.run_id}/" |
|
|
} |
|
|
} |
|
|
|
|
|
with open(filepath, 'w') as f: |
|
|
json.dump(best_model_info, f, indent=2) |
|
|
|
|
|
print(f"[📄] Saved best_model.json: {filepath}") |
|
|
|
|
|
|
|
|
def create_model_summary( |
|
|
weights_dir: str, |
|
|
config: DavidTrainingConfig, |
|
|
david_config: DavidArchitectureConfig, |
|
|
best_metrics: Dict, |
|
|
model_name: str |
|
|
): |
|
|
"""Create prominent model summary with accuracy front and center.""" |
|
|
|
|
|
summary_path = os.path.join(weights_dir, 'MODEL_SUMMARY.txt') |
|
|
|
|
|
best_acc = best_metrics.get('best_val_acc', 0.0) |
|
|
training_history = best_metrics.get('training_history', {}) |
|
|
|
|
|
summary = f""" |
|
|
╔══════════════════════════════════════════════════════════════╗ |
|
|
║ DAVID MODEL SUMMARY ║ |
|
|
╠══════════════════════════════════════════════════════════════╣ |
|
|
║ ║ |
|
|
║ 🎯 VALIDATION ACCURACY: {best_acc:.2f}% ║ |
|
|
║ ║ |
|
|
╚══════════════════════════════════════════════════════════════╝ |
|
|
|
|
|
MODEL: {model_name} |
|
|
RUN ID: {config.run_id} |
|
|
BEST EPOCH: {best_metrics.get('best_epoch', 0) + 1}/{config.num_epochs} |
|
|
|
|
|
═══════════════════════════════════════════════════════════════ |
|
|
|
|
|
📊 PERFORMANCE BREAKDOWN |
|
|
|
|
|
Final Training Accuracy: {best_metrics.get('final_train_acc', 0.0):.2f}% |
|
|
Best Validation Accuracy: {best_acc:.2f}% |
|
|
|
|
|
Per-Scale Accuracies: |
|
|
""" |
|
|
|
|
|
scale_accs = best_metrics.get('scale_accuracies', {}) |
|
|
for scale in sorted(scale_accs.keys()): |
|
|
acc = scale_accs[scale] |
|
|
summary += f" • Scale {scale:4d}: {acc:.2f}%\n" |
|
|
|
|
|
summary += f""" |
|
|
═══════════════════════════════════════════════════════════════ |
|
|
|
|
|
🏗️ ARCHITECTURE |
|
|
|
|
|
Preset: {config.preset} |
|
|
Sharing Mode: {david_config.sharing_mode} |
|
|
Fusion Mode: {david_config.fusion_mode} |
|
|
Scales: {len(david_config.scales)} scales - {david_config.scales} |
|
|
Feature Dim: {david_config.feature_dim} |
|
|
Parameters: {best_metrics.get('parameters', 0):,} |
|
|
|
|
|
═══════════════════════════════════════════════════════════════ |
|
|
|
|
|
📈 TRAINING CURVE |
|
|
|
|
|
""" |
|
|
|
|
|
if training_history and 'val_acc' in training_history: |
|
|
summary += "Epoch | Train Acc | Val Acc | Learning Rate\n" |
|
|
summary += "------|-----------|----------|--------------\n" |
|
|
|
|
|
for i, epoch in enumerate(training_history.get('epochs', [])): |
|
|
train_acc = training_history['train_acc'][i] if i < len(training_history['train_acc']) else 0 |
|
|
val_acc = training_history['val_acc'][i] if i < len(training_history['val_acc']) else 0 |
|
|
lr = training_history['lr'][i] if i < len(training_history['lr']) else 0 |
|
|
|
|
|
marker = " 👑" if val_acc == best_acc else "" |
|
|
summary += f"{epoch:5d} | {train_acc:8.2f}% | {val_acc:7.2f}%{marker} | {lr:.2e}\n" |
|
|
|
|
|
summary += f""" |
|
|
═══════════════════════════════════════════════════════════════ |
|
|
|
|
|
📁 FILES |
|
|
|
|
|
Best Model: best_model_acc{best_acc:.2f}.safetensors |
|
|
Config: david_config.json |
|
|
Training Cfg: train_config.json |
|
|
History: training_history.json |
|
|
|
|
|
═══════════════════════════════════════════════════════════════ |
|
|
|
|
|
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} |
|
|
""" |
|
|
|
|
|
with open(summary_path, 'w') as f: |
|
|
f.write(summary) |
|
|
|
|
|
print(f"[📄] Created MODEL_SUMMARY.txt") |
|
|
return summary_path |
|
|
|
|
|
|
|
|
def update_models_index( |
|
|
config: DavidTrainingConfig, |
|
|
david_config: DavidArchitectureConfig, |
|
|
best_metrics: Dict, |
|
|
model_name: str |
|
|
): |
|
|
"""Update master models index file tracking all trained models.""" |
|
|
|
|
|
if not config.upload_to_hub or not config.hf_repo: |
|
|
return |
|
|
|
|
|
try: |
|
|
from huggingface_hub import hf_hub_download |
|
|
api = HfApi() |
|
|
|
|
|
|
|
|
try: |
|
|
index_path = hf_hub_download( |
|
|
repo_id=config.hf_repo, |
|
|
filename="MODELS_INDEX.json", |
|
|
repo_type="model" |
|
|
) |
|
|
with open(index_path, 'r') as f: |
|
|
models_index = json.load(f) |
|
|
        except Exception:
|
|
|
|
|
models_index = { |
|
|
"repository": config.hf_repo, |
|
|
"updated": datetime.now().isoformat(), |
|
|
"models": [] |
|
|
} |
|
|
|
|
|
|
|
|
model_entry = { |
|
|
"model_name": model_name, |
|
|
"run_id": config.run_id, |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"best_val_acc": best_metrics.get('best_val_acc', 0.0), |
|
|
"best_epoch": best_metrics.get('best_epoch', 0), |
|
|
"num_scales": len(david_config.scales), |
|
|
"scales": david_config.scales, |
|
|
"parameters": best_metrics.get('parameters', 0), |
|
|
"sharing_mode": david_config.sharing_mode, |
|
|
"fusion_mode": david_config.fusion_mode, |
|
|
"preset": config.preset, |
|
|
"weights_path": f"weights/{model_name}/{config.run_id}/best_model_acc{best_metrics.get('best_val_acc', 0.0):.2f}.safetensors", |
|
|
"config_path": f"weights/{model_name}/{config.run_id}/david_config.json", |
|
|
"history_path": f"weights/{model_name}/{config.run_id}/training_history.json" |
|
|
} |
|
|
|
|
|
|
|
|
models_index["models"] = [m for m in models_index["models"] if m.get("run_id") != config.run_id] |
|
|
models_index["models"].append(model_entry) |
|
|
|
|
|
|
|
|
models_index["models"].sort(key=lambda x: x.get("best_val_acc", 0), reverse=True) |
|
|
models_index["updated"] = datetime.now().isoformat() |
|
|
models_index["total_models"] = len(models_index["models"]) |
|
|
|
|
|
|
|
|
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f: |
|
|
json.dump(models_index, f, indent=2) |
|
|
temp_path = f.name |
|
|
|
|
|
|
|
|
api.upload_file( |
|
|
path_or_fileobj=temp_path, |
|
|
path_in_repo="MODELS_INDEX.json", |
|
|
repo_id=config.hf_repo, |
|
|
commit_message=f"Update models index - {model_name} @ {best_metrics.get('best_val_acc', 0.0):.2f}%" |
|
|
) |
|
|
|
|
|
os.unlink(temp_path) |
|
|
print(f"[📊] Updated MODELS_INDEX.json - {len(models_index['models'])} models tracked") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[⚠️] Failed to update models index: {e}") |
|
|
|
|
|
|
|
|
def upload_to_huggingface( |
|
|
local_dir: str, |
|
|
repo_id: str, |
|
|
commit_message: str, |
|
|
path_in_repo: Optional[str] = None, |
|
|
patterns: Optional[List[str]] = None |
|
|
): |
|
|
"""Upload directory to HuggingFace Hub.""" |
|
|
|
|
|
try: |
|
|
api = HfApi() |
|
|
|
|
|
|
|
|
try: |
|
|
create_repo(repo_id, exist_ok=True, repo_type="model") |
|
|
print(f"[🤗] Repo ready: {repo_id}") |
|
|
except Exception as e: |
|
|
print(f"[⚠️] Repo exists or creation failed: {e}") |
|
|
|
|
|
|
|
|
if patterns: |
|
|
|
|
|
for pattern in patterns: |
|
|
matching_files = list(Path(local_dir).rglob(pattern)) |
|
|
for file_path in matching_files: |
|
|
rel_path = file_path.relative_to(local_dir) |
|
|
if path_in_repo: |
|
|
repo_path = f"{path_in_repo}/{rel_path}" |
|
|
else: |
|
|
repo_path = str(rel_path) |
|
|
|
|
|
api.upload_file( |
|
|
path_or_fileobj=str(file_path), |
|
|
path_in_repo=repo_path, |
|
|
repo_id=repo_id, |
|
|
commit_message=commit_message |
|
|
) |
|
|
else: |
|
|
|
|
|
api.upload_folder( |
|
|
folder_path=local_dir, |
|
|
repo_id=repo_id, |
|
|
path_in_repo=path_in_repo, |
|
|
commit_message=commit_message |
|
|
) |
|
|
|
|
|
print(f"[✅] Uploaded to Hub: https://huggingface.co/{repo_id}") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[❌] Hub upload failed: {e}") |
|
|
print(f" Continuing training (files saved locally)") |
|
|
|
|
|
|
|
|
def prepare_hub_upload( |
|
|
weights_dir: str, |
|
|
runs_dir: str, |
|
|
config: DavidTrainingConfig, |
|
|
david_config: DavidArchitectureConfig, |
|
|
best_metrics: Dict, |
|
|
model_name: str |
|
|
): |
|
|
"""Prepare and upload all artifacts to HuggingFace Hub.""" |
|
|
|
|
|
if not config.upload_to_hub or not config.hf_repo: |
|
|
return |
|
|
|
|
|
print("\n[🤗] Preparing HuggingFace Hub upload...") |
|
|
|
|
|
|
|
|
summary_path = create_model_summary(weights_dir, config, david_config, best_metrics, model_name) |
|
|
|
|
|
|
|
|
update_models_index(config, david_config, best_metrics, model_name) |
|
|
|
|
|
api = HfApi() |
|
|
try: |
|
|
create_repo(config.hf_repo, exist_ok=True, repo_type="model") |
|
|
    except Exception:
|
|
pass |
|
|
|
|
|
|
|
|
with tempfile.TemporaryDirectory() as temp_dir: |
|
|
|
|
|
readme_path = os.path.join(temp_dir, "README.md") |
|
|
readme_content = generate_model_readme(config, david_config, best_metrics, config.run_id) |
|
|
with open(readme_path, 'w') as f: |
|
|
f.write(readme_content) |
|
|
print(f"[📝] Generated README.md") |
|
|
|
|
|
|
|
|
best_json_path = os.path.join(temp_dir, "best_model.json") |
|
|
save_best_model_json(best_json_path, best_metrics, config, david_config) |
|
|
|
|
|
|
|
|
print(f"[📤] Uploading root files...") |
|
|
|
|
|
api.upload_file( |
|
|
path_or_fileobj=readme_path, |
|
|
path_in_repo="README.md", |
|
|
repo_id=config.hf_repo, |
|
|
commit_message=f"Update README - Run {config.run_id}" |
|
|
) |
|
|
|
|
|
api.upload_file( |
|
|
path_or_fileobj=best_json_path, |
|
|
path_in_repo="best_model.json", |
|
|
repo_id=config.hf_repo, |
|
|
commit_message=f"Update metrics - Run {config.run_id}" |
|
|
) |
|
|
|
|
|
|
|
|
weights_repo_path = f"weights/{model_name}/{config.run_id}" |
|
|
best_acc = best_metrics.get('best_val_acc', 0.0) |
|
|
|
|
|
print(f"[📤] Uploading essential files to {weights_repo_path}...") |
|
|
|
|
|
|
|
|
files_to_upload = [ |
|
|
('MODEL_SUMMARY.txt', 'MODEL_SUMMARY.txt'), |
|
|
('training_history.json', 'training_history.json'), |
|
|
('david_config.json', 'david_config.json'), |
|
|
('train_config.json', 'train_config.json'), |
|
|
(f'best_model_acc{best_acc:.2f}.safetensors', f'best_model_acc{best_acc:.2f}.safetensors'), |
|
|
(f'best_model_acc{best_acc:.2f}_metadata.json', f'best_model_acc{best_acc:.2f}_metadata.json'), |
|
|
] |
|
|
|
|
|
for local_filename, repo_filename in files_to_upload: |
|
|
local_path = os.path.join(weights_dir, local_filename) |
|
|
if os.path.exists(local_path): |
|
|
try: |
|
|
api.upload_file( |
|
|
path_or_fileobj=local_path, |
|
|
path_in_repo=f"{weights_repo_path}/{repo_filename}", |
|
|
repo_id=config.hf_repo, |
|
|
commit_message=f"Update {repo_filename} - Run {config.run_id}" |
|
|
) |
|
|
except Exception as e: |
|
|
print(f"[⚠️] Failed to upload {repo_filename}: {e}") |
|
|
|
|
|
print(f"[✅] Uploaded to Hub: https://huggingface.co/{config.hf_repo}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint( |
|
|
filepath: str, |
|
|
david: David, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], |
|
|
epoch: int, |
|
|
metrics: Dict, |
|
|
train_config: DavidTrainingConfig |
|
|
): |
|
|
"""Save checkpoint in PyTorch and/or SafeTensors format.""" |
|
|
|
|
|
checkpoint = { |
|
|
'epoch': epoch, |
|
|
'model_state_dict': david.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'scheduler_state_dict': scheduler.state_dict() if scheduler else None, |
|
|
'metrics': metrics, |
|
|
'train_config': train_config.to_dict(), |
|
|
} |
|
|
|
|
|
|
|
|
val_acc = metrics.get('best_val_acc') or metrics.get('val_acc') |
|
|
if val_acc: |
|
|
acc_suffix = f"_acc{val_acc:.2f}" |
|
|
filepath = filepath + acc_suffix |
|
|
|
|
|
if train_config.save_format in ['pytorch', 'both']: |
|
|
torch.save(checkpoint, filepath + '.pth') |
|
|
print(f"[💾] Saved PyTorch: {filepath}.pth") |
|
|
|
|
|
if train_config.save_format in ['safetensors', 'both']: |
|
|
try: |
|
|
from safetensors.torch import save_file |
|
|
|
|
|
|
|
|
model_state = {k: v.contiguous() for k, v in david.state_dict().items()} |
|
|
save_file(model_state, filepath + '.safetensors') |
|
|
|
|
|
|
|
|
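            # Keep the sidecar metadata JSON-serializable; optimizer/scheduler state
            # is preserved only in the .pth checkpoint, not in this JSON file.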
            metadata = {k: v for k, v in checkpoint.items()
                        if k not in ['model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict']}
|
|
with open(filepath + '_metadata.json', 'w') as f: |
|
|
json.dump(metadata, f, indent=2, default=str) |
|
|
|
|
|
print(f"[💾] Saved SafeTensors: {filepath}.safetensors") |
|
|
except ImportError: |
|
|
print(f"[⚠️] SafeTensors not available, skipping") |
|
|
|
|
|
|
|
|
def load_checkpoint( |
|
|
checkpoint_path: str, |
|
|
david: David, |
|
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, |
|
|
device: str = "cuda" |
|
|
) -> Tuple[int, Dict]: |
|
|
"""Load checkpoint and return epoch and metrics.""" |
|
|
|
|
|
if checkpoint_path.endswith('.safetensors'): |
|
|
|
|
|
try: |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
model_state = load_file(checkpoint_path, device=device) |
|
|
david.load_state_dict(model_state) |
|
|
|
|
|
|
|
|
metadata_path = checkpoint_path.replace('.safetensors', '_metadata.json') |
|
|
with open(metadata_path, 'r') as f: |
|
|
metadata = json.load(f) |
|
|
|
|
|
epoch = metadata.get('epoch', 0) |
|
|
metrics = metadata.get('metrics', {}) |
|
|
|
|
|
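            # Optimizer/scheduler state is restored only when resuming from a .pth checkpoint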
if optimizer and 'optimizer_state_dict' in metadata: |
|
|
optimizer.load_state_dict(metadata['optimizer_state_dict']) |
|
|
|
|
|
if scheduler and 'scheduler_state_dict' in metadata and metadata['scheduler_state_dict']: |
|
|
scheduler.load_state_dict(metadata['scheduler_state_dict']) |
|
|
|
|
|
print(f"[✅] Loaded from SafeTensors: {checkpoint_path}") |
|
|
return epoch, metrics |
|
|
|
|
|
except ImportError: |
|
|
raise ImportError("safetensors not installed") |
|
|
|
|
|
else: |
|
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location=device) |
|
|
|
|
|
david.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
|
|
if optimizer and 'optimizer_state_dict' in checkpoint: |
|
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
|
|
|
if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']: |
|
|
scheduler.load_state_dict(checkpoint['scheduler_state_dict']) |
|
|
|
|
|
print(f"[✅] Loaded from PyTorch: {checkpoint_path}") |
|
|
return checkpoint['epoch'], checkpoint.get('metrics', {}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImageNetHFDataset(Dataset): |
|
|
"""PyTorch Dataset wrapper for HuggingFace ImageNet features.""" |
|
|
|
|
|
def __init__(self, dataset_name: str, model_variant: str, split: str = "train"): |
|
|
|
|
|
print(f"[📥] Loading {split} split for {model_variant}...") |
|
|
self.dataset = load_dataset( |
|
|
dataset_name, |
|
|
name=model_variant, |
|
|
split=split |
|
|
) |
|
|
self.length = len(self.dataset) |
|
|
print(f"[✅] Loaded {self.length:,} samples from {split} split") |
|
|
|
|
|
def __len__(self): |
|
|
return self.length |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
item = self.dataset[idx] |
|
|
features = torch.tensor(item['clip_features'], dtype=torch.float32) |
|
|
label = torch.tensor(item['label'], dtype=torch.long) |
|
|
return features, label |
|
|
|
|
|
|
|
|
class MergedImageNetDataset(Dataset): |
|
|
""" |
|
|
Merge multiple CLIP variants into a single dataset. |
|
|
Perfect for testing if David can unify different encoder spaces! |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dataset_name: str, |
|
|
model_variants: List[str], |
|
|
split: str = "train", |
|
|
shuffle_seed: int = 42 |
|
|
): |
|
|
print(f"[🔀] Creating merged dataset from {len(model_variants)} variants...") |
|
|
|
|
|
self.datasets = [] |
|
|
self.cumulative_lengths = [0] |
|
|
|
|
|
|
|
|
for variant in model_variants: |
|
|
print(f"[📥] Loading {split} split for {variant}...") |
|
|
ds = load_dataset( |
|
|
dataset_name, |
|
|
name=variant, |
|
|
split=split |
|
|
) |
|
|
self.datasets.append(ds) |
|
|
self.cumulative_lengths.append(self.cumulative_lengths[-1] + len(ds)) |
|
|
print(f"[✅] Loaded {len(ds):,} samples from {variant}") |
|
|
|
|
|
self.total_length = self.cumulative_lengths[-1] |
|
|
|
|
|
|
|
|
print(f"[🎲] Shuffling {self.total_length:,} samples (seed={shuffle_seed})...") |
|
|
rng = np.random.RandomState(shuffle_seed) |
|
|
self.shuffle_indices = rng.permutation(self.total_length) |
|
|
|
|
|
print(f"[✅] Merged dataset ready: {self.total_length:,} samples from {len(model_variants)} encoders") |
|
|
|
|
|
def __len__(self): |
|
|
return self.total_length |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
|
|
|
actual_idx = int(self.shuffle_indices[idx]) |
|
|
|
|
|
|
|
|
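        # Locate which source dataset this shuffled global index falls into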
dataset_idx = 0 |
|
|
for i, cumsum in enumerate(self.cumulative_lengths[1:]): |
|
|
if actual_idx < cumsum: |
|
|
dataset_idx = i |
|
|
break |
|
|
|
|
|
|
|
|
local_idx = actual_idx - self.cumulative_lengths[dataset_idx] |
|
|
item = self.datasets[dataset_idx][local_idx] |
|
|
|
|
|
features = torch.tensor(item['clip_features'], dtype=torch.float32) |
|
|
label = torch.tensor(item['label'], dtype=torch.long) |
|
|
|
|
|
return features, label |
|
|
|
|
|
|
|
|
def create_dataloaders(config: DavidTrainingConfig): |
|
|
"""Create train and validation dataloaders.""" |
|
|
|
|
|
|
|
|
if isinstance(config.model_variant, list): |
|
|
print(f"[🧪] MULTI-ENCODER EXPERIMENT: Merging {len(config.model_variant)} variants") |
|
|
train_dataset = MergedImageNetDataset( |
|
|
config.dataset_name, |
|
|
config.model_variant, |
|
|
"train" |
|
|
) |
|
|
val_dataset = MergedImageNetDataset( |
|
|
config.dataset_name, |
|
|
config.model_variant, |
|
|
"validation" |
|
|
) |
|
|
else: |
|
|
|
|
|
train_dataset = ImageNetHFDataset( |
|
|
config.dataset_name, config.model_variant, "train" |
|
|
) |
|
|
val_dataset = ImageNetHFDataset( |
|
|
config.dataset_name, config.model_variant, "validation" |
|
|
) |
|
|
|
|
|
train_loader = DataLoader( |
|
|
train_dataset, |
|
|
batch_size=config.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=config.num_workers, |
|
|
pin_memory=config.pin_memory, |
|
|
prefetch_factor=config.prefetch_factor, |
|
|
persistent_workers=config.persistent_workers |
|
|
) |
|
|
|
|
|
val_loader = DataLoader( |
|
|
val_dataset, |
|
|
batch_size=config.batch_size * 2, |
|
|
shuffle=False, |
|
|
num_workers=config.num_workers, |
|
|
pin_memory=config.pin_memory, |
|
|
prefetch_factor=config.prefetch_factor, |
|
|
persistent_workers=config.persistent_workers |
|
|
) |
|
|
|
|
|
return train_loader, val_loader |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrystalGenerator: |
|
|
"""Generate crystals for all scales.""" |
|
|
|
|
|
def __init__(self, num_classes: int, scales: List[int], device: str = "cuda"): |
|
|
self.num_classes = num_classes |
|
|
self.scales = scales |
|
|
self.device = device |
|
|
self.factories = { |
|
|
scale: SimplexFactory(k=4, embed_dim=scale, method="random") |
|
|
for scale in scales |
|
|
} |
|
|
|
|
|
def generate(self, seed: int = 42) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]: |
|
|
"""Generate anchors and crystals for all scales.""" |
|
|
|
|
|
anchors_dict = {} |
|
|
crystals_dict = {} |
|
|
|
|
|
for scale in tqdm(self.scales, desc="Generating crystals"): |
|
|
factory = self.factories[scale] |
|
|
batch_crystals = [] |
|
|
|
|
|
for class_idx in range(self.num_classes): |
|
|
crystal = factory.build( |
|
|
backend="torch", |
|
|
device=self.device, |
|
|
dtype=torch.float32, |
|
|
seed=seed + class_idx, |
|
|
validate=True |
|
|
) |
|
|
batch_crystals.append(crystal) |
|
|
|
|
|
crystals = torch.stack(batch_crystals) |
|
|
anchors = F.normalize(crystals[:, 0, :], dim=-1) |
|
|
|
|
|
|
|
|
anchor_sims = anchors @ anchors.T |
|
|
off_diag = anchor_sims[~torch.eye(self.num_classes, dtype=bool, device=anchors.device)] |
|
|
max_sim = off_diag.max().item() |
|
|
mean_sim = off_diag.mean().item() |
|
|
|
|
|
print(f" Scale {scale}: max_sim={max_sim:.4f}, mean_sim={mean_sim:.4f}") |
|
|
|
|
|
if max_sim > 0.99: |
|
|
print(f" ⚠️ WARNING: Anchors too similar at scale {scale}!") |
|
|
|
|
|
anchors_dict[scale] = anchors |
|
|
crystals_dict[scale] = crystals |
|
|
|
|
|
return anchors_dict, crystals_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch( |
|
|
david: David, |
|
|
train_loader: DataLoader, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
criterion: MultiScaleCrystalLoss, |
|
|
anchors_dict: Dict[int, torch.Tensor], |
|
|
crystals_dict: Dict[int, torch.Tensor], |
|
|
epoch: int, |
|
|
config: DavidTrainingConfig, |
|
|
writer: Optional[SummaryWriter], |
|
|
global_step: int |
|
|
) -> Tuple[float, float, int, Dict]: |
|
|
"""Train for one epoch - Pure FP32.""" |
|
|
|
|
|
david.train() |
|
|
david.update_epoch(epoch) |
|
|
|
|
|
total_loss = 0 |
|
|
correct = 0 |
|
|
total = 0 |
|
|
loss_components_sum = {} |
|
|
|
|
|
active_scales = david.get_active_scales() |
|
|
|
|
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}") |
|
|
|
|
|
for batch_idx, (features, labels) in enumerate(pbar): |
|
|
features = features.cuda(non_blocking=True) |
|
|
labels = labels.cuda(non_blocking=True) |
|
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
|
|
|
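        # Forward pass in pure fp32: fused logits plus per-scale logits and features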
combined, logits_list, features_list, fusion_weights = david( |
|
|
features, anchors_dict, return_all_scales=True |
|
|
) |
|
|
|
|
|
|
|
|
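        # Multi-scale loss: cross-entropy on the fused logits plus per-scale Rose/Cayley terms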
losses = criterion( |
|
|
combined, logits_list, features_list, |
|
|
labels, crystals_dict, epoch |
|
|
) |
|
|
|
|
|
|
|
|
losses['total'].backward() |
|
|
|
|
|
|
|
|
if config.track_gradients and batch_idx % config.log_interval == 0: |
|
|
grad_stats = analyze_gradients(david, config) |
|
|
if writer: |
|
|
step = global_step + batch_idx |
|
|
writer.add_scalar('train/grad_mean', grad_stats['mean'], step) |
|
|
writer.add_scalar('train/grad_max', grad_stats['max'], step) |
|
|
writer.add_scalar('train/grad_num_small', grad_stats['num_small'], step) |
|
|
|
|
|
|
|
|
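        # Rescue near-vanishing gradients, then clip the global norm for stability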
scale_small_gradients(david, config) |
|
|
|
|
|
|
|
|
torch.nn.utils.clip_grad_norm_(david.parameters(), config.gradient_clip) |
|
|
|
|
|
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
|
total_loss += losses['total'].item() |
|
|
_, predicted = torch.max(combined, 1) |
|
|
total += labels.size(0) |
|
|
correct += (predicted == labels).sum().item() |
|
|
|
|
|
|
|
|
for key, value in losses.items(): |
|
|
if key not in loss_components_sum: |
|
|
loss_components_sum[key] = 0.0 |
|
|
loss_components_sum[key] += value.item() |
|
|
|
|
|
|
|
|
if writer and batch_idx % config.log_interval == 0: |
|
|
step = global_step + batch_idx |
|
|
writer.add_scalar('train/loss_batch', losses['total'].item(), step) |
|
|
writer.add_scalar('train/acc_batch', 100 * correct / total, step) |
|
|
|
|
|
if config.log_loss_components: |
|
|
for key, value in losses.items(): |
|
|
if key != 'total': |
|
|
writer.add_scalar(f'train/loss_{key}', value.item(), step) |
|
|
|
|
|
if config.log_fusion_weights and fusion_weights is not None: |
|
|
if fusion_weights.dim() == 2: |
|
|
mean_weights = fusion_weights.mean(dim=0) |
|
|
for i, w in enumerate(mean_weights): |
|
|
if i < len(active_scales): |
|
|
writer.add_scalar( |
|
|
f'train/fusion_weight_{active_scales[i]}', |
|
|
w.item(), step |
|
|
) |
|
|
|
|
|
writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], step) |
|
|
|
|
|
pbar.set_postfix({ |
|
|
'loss': f'{total_loss / (batch_idx + 1):.4f}', |
|
|
'acc': f'{100 * correct / total:.2f}%' |
|
|
}) |
|
|
|
|
|
global_step += 1 |
|
|
|
|
|
|
|
|
avg_components = {k: v / len(train_loader) for k, v in loss_components_sum.items()} |
|
|
|
|
|
return ( |
|
|
total_loss / len(train_loader), |
|
|
100 * correct / total, |
|
|
global_step, |
|
|
avg_components |
|
|
) |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def validate( |
|
|
david: David, |
|
|
val_loader: DataLoader, |
|
|
anchors_dict: Dict[int, torch.Tensor], |
|
|
config: DavidTrainingConfig |
|
|
) -> Tuple[float, Dict[int, float]]: |
|
|
"""Validate model - Pure FP32.""" |
|
|
|
|
|
david.eval() |
|
|
|
|
|
correct = 0 |
|
|
total = 0 |
|
|
active_scales = david.get_active_scales() |
|
|
scale_correct = {scale: 0 for scale in active_scales} |
|
|
|
|
|
for features, labels in tqdm(val_loader, desc="Validation", leave=False): |
|
|
features = features.cuda(non_blocking=True) |
|
|
labels = labels.cuda(non_blocking=True) |
|
|
|
|
|
|
|
|
combined, logits_list, _, _ = david( |
|
|
features, anchors_dict, return_all_scales=True |
|
|
) |
|
|
|
|
|
_, predicted = torch.max(combined, 1) |
|
|
total += labels.size(0) |
|
|
correct += (predicted == labels).sum().item() |
|
|
|
|
|
for i, scale in enumerate(active_scales): |
|
|
if i < len(logits_list): |
|
|
_, scale_pred = torch.max(logits_list[i], 1) |
|
|
scale_correct[scale] += (scale_pred == labels).sum().item() |
|
|
|
|
|
accuracy = 100 * correct / total |
|
|
scale_accs = {s: 100 * scale_correct[s] / total for s in scale_correct} |
|
|
|
|
|
return accuracy, scale_accs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_david(config: DavidTrainingConfig): |
|
|
"""Main training pipeline.""" |
|
|
|
|
|
|
|
|
torch.set_float32_matmul_precision('high') |
|
|
|
|
|
print("="*80) |
|
|
print("🌟 DAVID TRAINING PIPELINE") |
|
|
print("="*80) |
|
|
print(f"Run ID: {config.run_id}") |
|
|
print(f"Preset: {config.preset}") |
|
|
print(f"Batch Size: {config.batch_size}") |
|
|
print(f"Learning Rate: {config.learning_rate}") |
|
|
print(f"Mixed Precision: {config.use_mixed_precision}") |
|
|
print(f"TensorFloat32: Enabled (high precision)") |
|
|
print("="*80) |
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
|
|
|
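    # Resolve the architecture config: an explicit JSON path takes precedence over the preset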
if config.custom_config_path: |
|
|
david_config = DavidArchitectureConfig.from_json(config.custom_config_path) |
|
|
print(f"[📁] Loaded custom config: {config.custom_config_path}") |
|
|
elif config.preset: |
|
|
david_config = DavidPresets.get_preset(config.preset) |
|
|
print(f"[⚙️] Using preset: {config.preset}") |
|
|
else: |
|
|
raise ValueError("Must specify either preset or custom_config_path") |
|
|
|
|
|
|
|
|
model_name = f"David-{david_config.sharing_mode}-{david_config.fusion_mode}" |
|
|
print(f"[🏷️] Model: {model_name}") |
|
|
|
|
|
|
|
|
weights_dir = os.path.join(config.base_dir, "weights", model_name, config.run_id) |
|
|
runs_dir = os.path.join(config.base_dir, "runs", model_name, config.run_id) |
|
|
os.makedirs(weights_dir, exist_ok=True) |
|
|
os.makedirs(runs_dir, exist_ok=True) |
|
|
|
|
|
print(f"[📁] Weights: {weights_dir}") |
|
|
print(f"[📁] Logs: {runs_dir}") |
|
|
|
|
|
writer = SummaryWriter(runs_dir) |
|
|
|
|
|
|
|
|
if config.num_classes_override: |
|
|
david_config.num_classes = config.num_classes_override |
|
|
if config.use_belly_override is not None: |
|
|
david_config.use_belly = config.use_belly_override |
|
|
if config.belly_expand_override is not None: |
|
|
david_config.belly_expand = config.belly_expand_override |
|
|
if config.progressive_training_override is not None: |
|
|
david_config.progressive_training = config.progressive_training_override |
|
|
if not david_config.progressive_training: |
|
|
|
|
|
david_config.scale_warmup_epochs = {s: 0 for s in david_config.scales} |
|
|
|
|
|
|
|
|
if config.scale_warmup_epochs_override is not None: |
|
|
david_config.scale_warmup_epochs = config.scale_warmup_epochs_override |
|
|
|
|
|
if not david_config.progressive_training: |
|
|
print(f"[⚙️] Enabling progressive training (custom warmup schedule provided)") |
|
|
david_config.progressive_training = True |
|
|
|
|
|
print(f"[⚙️] Progressive training: {david_config.progressive_training}") |
|
|
if david_config.progressive_training: |
|
|
print(f" Scale warmup schedule: {david_config.scale_warmup_epochs}") |
|
|
|
|
|
|
|
|
david_config_path = os.path.join(weights_dir, "david_config.json") |
|
|
david_config.to_json(david_config_path) |
|
|
print(f"[💾] Saved David config: {david_config_path}") |
|
|
|
|
|
train_config_path = os.path.join(weights_dir, "train_config.json") |
|
|
config.to_json(train_config_path) |
|
|
print(f"[💾] Saved training config: {train_config_path}") |
|
|
|
|
|
|
|
|
david = David.from_config(david_config).cuda() |
|
|
print(f"\n{david}\n") |
|
|
|
|
|
|
|
|
total_params = sum(p.numel() for p in david.parameters()) |
|
|
trainable_params = sum(p.numel() for p in david.parameters() if p.requires_grad) |
|
|
print(f"[📊] Total Parameters: {total_params:,}") |
|
|
print(f"[📊] Trainable Parameters: {trainable_params:,}") |
|
|
|
|
|
|
|
|
train_loader, val_loader = create_dataloaders(config) |
|
|
|
|
|
|
|
|
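    # Generate per-scale class crystals (pentachora) and their normalized anchor vertices up front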
crystal_gen = CrystalGenerator( |
|
|
david_config.num_classes, |
|
|
david_config.scales, |
|
|
str(device) |
|
|
) |
|
|
anchors_dict, crystals_dict = crystal_gen.generate() |
|
|
|
|
|
|
|
|
criterion = MultiScaleCrystalLoss( |
|
|
scales=david_config.scales, |
|
|
num_classes=david_config.num_classes, |
|
|
use_rose_loss=config.use_rose_loss, |
|
|
use_cayley_loss=config.use_cayley_loss, |
|
|
rose_initial_weight=config.rose_initial_weight, |
|
|
rose_max_weight=config.rose_max_weight, |
|
|
cayley_weight=config.cayley_weight, |
|
|
scale_loss_balance=config.scale_loss_balance |
|
|
).cuda() |
|
|
|
|
|
optimizer = create_optimizer(david, config) |
|
|
scheduler = create_scheduler(optimizer, config) |
|
|
|
|
|
controller = AdaptiveTrainingController(david, config) |
|
|
|
|
|
|
|
|
best_val_acc = 0.0 |
|
|
best_epoch = 0 |
|
|
best_scale_accs = {} |
|
|
global_step = 0 |
|
|
final_train_acc = 0.0 |
|
|
final_train_loss = 0.0 |
|
|
|
|
|
|
|
|
training_history = { |
|
|
'epochs': [], |
|
|
'train_loss': [], |
|
|
'train_acc': [], |
|
|
'val_acc': [], |
|
|
'scale_accs': {}, |
|
|
'lr': [] |
|
|
} |
|
|
|
|
|
|
|
|
print("\n[🔍] Running diagnostic forward/backward pass...") |
|
|
david.train() |
|
|
|
|
|
|
|
|
for features_test, labels_test in train_loader: |
|
|
features_test = features_test.cuda(non_blocking=True)[:8] |
|
|
labels_test = labels_test.cuda(non_blocking=True)[:8] |
|
|
|
|
|
|
|
|
combined_test, logits_test, features_test_out, _ = david( |
|
|
features_test, anchors_dict, return_all_scales=True |
|
|
) |
|
|
|
|
|
|
|
|
losses_test = criterion( |
|
|
combined_test, logits_test, features_test_out, |
|
|
labels_test, crystals_dict, epoch=0 |
|
|
) |
|
|
|
|
|
print(f" Initial loss: {losses_test['total'].item():.6f}") |
|
|
print(f" Loss components:") |
|
|
for key, value in losses_test.items(): |
|
|
if key != 'total': |
|
|
print(f" {key}: {value.item():.6f}") |
|
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
losses_test['total'].backward() |
|
|
|
|
|
|
|
|
grad_count = sum(1 for p in david.parameters() if p.grad is not None and p.grad.norm() > 0) |
|
|
total_grad_params = sum(1 for p in david.parameters() if p.requires_grad) |
|
|
print(f" Parameters with non-zero gradients: {grad_count}/{total_grad_params}") |
|
|
|
|
|
if grad_count == 0: |
|
|
print(f" ❌ ERROR: No gradients! Training will not work.") |
|
|
return None, 0.0 |
|
|
elif grad_count < total_grad_params * 0.5: |
|
|
print(f" ⚠️ WARNING: Less than 50% of parameters have gradients") |
|
|
else: |
|
|
print(f" ✅ Gradients look good") |
|
|
|
|
|
break |
|
|
|
|
|
print("\n[🚀] Starting training...\n") |
|
|
|
|
|
for epoch in range(config.num_epochs): |
|
|
epoch_start = time.time() |
|
|
|
|
|
|
|
|
train_loss, train_acc, global_step, loss_components = train_epoch( |
|
|
david, train_loader, optimizer, criterion, |
|
|
anchors_dict, crystals_dict, epoch, config, |
|
|
writer, global_step |
|
|
) |
|
|
|
|
|
|
|
|
val_acc, scale_accs = validate(david, val_loader, anchors_dict, config) |
|
|
|
|
|
|
|
|
controller.update_metrics(scale_accs, val_acc) |
|
|
controller.apply_adaptive_strategies(scale_accs, epoch) |
|
|
|
|
|
|
|
|
if scheduler: |
|
|
scheduler.step() |
|
|
|
|
|
epoch_time = time.time() - epoch_start |
|
|
|
|
|
|
|
|
print(f"\n📊 Epoch {epoch+1}/{config.num_epochs} ({epoch_time:.1f}s)") |
|
|
print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%") |
|
|
print(f" Val: Acc={val_acc:.2f}% (Best: {best_val_acc:.2f}%)") |
|
|
print(f" Active scales: {david.get_active_scales()}") |
|
|
print(f" LR: {optimizer.param_groups[0]['lr']:.2e}") |
|
|
|
|
|
if config.log_loss_components and loss_components: |
|
|
print(f" Loss breakdown:") |
|
|
for key, value in sorted(loss_components.items()): |
|
|
if key != 'total': |
|
|
print(f" {key:20s}: {value:.6f}") |
|
|
|
|
|
for scale, acc in scale_accs.items(): |
|
|
frozen = "❄️" if controller.scales_frozen.get(scale, False) else "🔥" |
|
|
print(f" {frozen} Scale {scale}: {acc:.2f}%") |
|
|
|
|
|
|
|
|
final_train_acc = train_acc |
|
|
final_train_loss = train_loss |
|
|
|
|
|
|
|
|
training_history['epochs'].append(epoch + 1) |
|
|
training_history['train_loss'].append(train_loss) |
|
|
training_history['train_acc'].append(train_acc) |
|
|
training_history['val_acc'].append(val_acc) |
|
|
training_history['lr'].append(optimizer.param_groups[0]['lr']) |
|
|
|
|
|
|
|
|
for scale, acc in scale_accs.items(): |
|
|
if scale not in training_history['scale_accs']: |
|
|
training_history['scale_accs'][scale] = [] |
|
|
training_history['scale_accs'][scale].append(acc) |
|
|
|
|
|
|
|
|
writer.add_scalar('train/loss', train_loss, epoch) |
|
|
writer.add_scalar('train/acc', train_acc, epoch) |
|
|
writer.add_scalar('val/acc', val_acc, epoch) |
|
|
|
|
|
for scale, acc in scale_accs.items(): |
|
|
writer.add_scalar(f'val/acc_scale_{scale}', acc, epoch) |
|
|
|
|
|
|
|
|
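        # New best model: persist training history and checkpoint, optionally push to the Hub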
if val_acc > best_val_acc: |
|
|
best_val_acc = val_acc |
|
|
best_epoch = epoch |
|
|
best_scale_accs = scale_accs.copy() |
|
|
|
|
|
|
|
|
history_path = os.path.join(weights_dir, 'training_history.json') |
|
|
with open(history_path, 'w') as f: |
|
|
json.dump(training_history, f, indent=2) |
|
|
|
|
|
save_checkpoint( |
|
|
os.path.join(weights_dir, 'best_model'), |
|
|
david, optimizer, scheduler, epoch, |
|
|
{ |
|
|
'best_val_acc': best_val_acc, |
|
|
'best_epoch': best_epoch, |
|
|
'scale_accuracies': best_scale_accs, |
|
|
'training_history': training_history |
|
|
}, |
|
|
config |
|
|
) |
|
|
|
|
|
|
|
|
if config.upload_to_hub: |
|
|
best_metrics = { |
|
|
'best_val_acc': best_val_acc, |
|
|
'best_epoch': best_epoch, |
|
|
'scale_accuracies': best_scale_accs, |
|
|
'final_train_acc': train_acc, |
|
|
'final_train_loss': train_loss, |
|
|
'training_history': training_history, |
|
|
'parameters': total_params |
|
|
} |
|
|
prepare_hub_upload(weights_dir, runs_dir, config, david_config, best_metrics, model_name) |
|
|
|
|
|
|
|
|
if (epoch + 1) % config.save_interval == 0: |
|
|
save_checkpoint( |
|
|
os.path.join(weights_dir, f'checkpoint_epoch_{epoch+1}'), |
|
|
david, optimizer, scheduler, epoch, |
|
|
{'val_acc': val_acc}, |
|
|
config |
|
|
) |
|
|
|
|
|
|
|
|
save_checkpoint( |
|
|
os.path.join(weights_dir, 'final_model'), |
|
|
david, optimizer, scheduler, config.num_epochs - 1, |
|
|
{'final_val_acc': val_acc}, |
|
|
config |
|
|
) |
|
|
|
|
|
writer.close() |
|
|
|
|
|
|
|
|
if config.upload_to_hub: |
|
|
print("\n[🤗] Performing final HuggingFace Hub upload...") |
|
|
final_metrics = { |
|
|
'best_val_acc': best_val_acc, |
|
|
'best_epoch': best_epoch, |
|
|
'scale_accuracies': best_scale_accs, |
|
|
'final_train_acc': final_train_acc, |
|
|
'final_train_loss': final_train_loss, |
|
|
'training_history': training_history, |
|
|
'parameters': total_params |
|
|
} |
|
|
prepare_hub_upload(weights_dir, runs_dir, config, david_config, final_metrics, model_name) |
|
|
|
|
|
|
|
|
if os.path.exists(runs_dir): |
|
|
runs_repo_path = f"runs/{model_name}/{config.run_id}" |
|
|
print(f"[📤] Uploading TensorBoard logs to {runs_repo_path}...") |
|
|
upload_to_huggingface( |
|
|
local_dir=runs_dir, |
|
|
repo_id=config.hf_repo, |
|
|
commit_message=f"Upload TensorBoard logs - {model_name} - Run {config.run_id}", |
|
|
path_in_repo=runs_repo_path |
|
|
) |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print(f"🎉 Training Complete!") |
|
|
print(f" Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch+1})") |
|
|
print(f" Final Train Acc: {final_train_acc:.2f}%") |
|
|
print(f" Weights: {weights_dir}") |
|
|
if config.upload_to_hub: |
|
|
print(f" Hub: https://huggingface.co/{config.hf_repo}") |
|
|
print("="*80) |
|
|
|
|
|
return david, best_val_acc |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
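    # Example configuration: a multi-encoder "unified space" experiment over merged CLIP variants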
config = DavidTrainingConfig( |
|
|
preset="balanced", |
|
|
|
|
|
|
|
|
model_variant=["clip_vit_b16", "clip_vit_laion_b32"], |
|
|
|
|
|
num_epochs=10, |
|
|
batch_size=1024, |
|
|
learning_rate=1e-2, |
|
|
|
|
|
|
|
|
scale_warmup_epochs_override={ |
|
|
256: 0, |
|
|
512: 2, |
|
|
768: 5, |
|
|
1024: 8 |
|
|
}, |
|
|
|
|
|
use_rose_loss=True, |
|
|
rose_initial_weight=0.2, |
|
|
rose_max_weight=0.8, |
|
|
|
|
|
use_cayley_loss=True, |
|
|
cayley_weight=0.01, |
|
|
|
|
|
freeze_strategy="never", |
|
|
gradient_clip=10.0, |
|
|
|
|
|
save_format="safetensors", |
|
|
upload_to_hub=False, |
|
|
hf_repo="YourName/YourRepoHere" |
|
|
) |
|
|
|
|
|
print("="*80) |
|
|
print("🧪 UNIFIED SPACE EXPERIMENT") |
|
|
print("="*80) |
|
|
print(f"Testing if David can unify:") |
|
|
if isinstance(config.model_variant, list): |
|
|
for variant in config.model_variant: |
|
|
print(f" • {variant}") |
|
|
print("="*80) |
|
|
|
|
|
david, best_acc = train_david(config) |