---
library_name: transformers
tags:
- image-classification
- computer-vision
- vit
- vision-transformer
- orthogonal-residual-updates
- imagenet
license: cc-by-sa-4.0
pipeline_tag: image-classification
results:
- task:
    type: image-classification
  dataset:
    name: ImageNet-1k
    type: ImageNet-1k
  metrics:
  - name: Validation Accuracy Top@1
    type: Validation Accuracy Top@1
    value: 74.62
---

# Model Card for OrthoViT-B ImageNet-1k

This model is a Vision Transformer (ViT-B) trained on [ImageNet-1k](https://huggingface.co/datasets/timm/imagenet-1k-wds), incorporating _Orthogonal Residual Updates_ as proposed in the paper [Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks](https://arxiv.org/abs/2505.11881). The core idea is to decompose a module's output relative to the input stream and add only the component orthogonal to this stream, aiming for richer feature learning and more efficient training.

This specific checkpoint was trained for approximately 90,000 steps (roughly 270 epochs out of a planned 300).

## Model Details

### Evaluation

_**Note:** The validation accuracy below was measured on the checkpoint at step 90k (not the final model); results may differ slightly from those reported in the paper._

| Steps | Connection | Top-1 Accuracy (%) | Top-5 Accuracy (%) | Link |
|-------|------------|--------------------|--------------------|------|
| 90k | Orthogonal | **74.62** | **92.26** | [this model](https://huggingface.co/BootsofLagrangian/ortho-vit-b-imagenet1k-hf) |
| 90k | Linear | 71.23 | 90.29 | [linear baseline](https://huggingface.co/BootsofLagrangian/linear-vit-b-imagenet1k-hf) |

### Abstract

Residual connections are pivotal for deep neural networks, enabling greater depth by mitigating vanishing gradients. However, in standard residual updates, the module's output is directly added to the input stream.
This can lead to updates that predominantly reinforce or modulate the existing stream direction, potentially underutilizing the module's capacity for learning entirely novel features. In this work, we introduce _Orthogonal Residual Update_: we decompose the module's output relative to the input stream and add only the component orthogonal to this stream. This design aims to guide modules to contribute primarily new representational directions, fostering richer feature learning while promoting more efficient training. We demonstrate that our orthogonal update strategy improves generalization accuracy and training stability across diverse architectures (ResNetV2, Vision Transformers) and datasets (CIFARs, TinyImageNet, ImageNet-1k), achieving, for instance, a +4.3\%p top-1 accuracy gain for ViT-B on ImageNet-1k.

### Method Overview

Our core idea is to modify the standard residual update $x_{n+1} = x_n + f(\sigma(x_n))$ by projecting out the component of $f(\sigma(x_n))$ that is parallel to $x_n$. This parallel component is the vector projection of $f(\sigma(x_n))$ onto $x_n$, and the orthogonal component is what remains:

$$
f_{\parallel}(x_n) = \frac{\langle x_n, f(\sigma(x_n)) \rangle}{\langle x_n, x_n \rangle}\, x_n, \qquad f_{\perp}(x_n) = f(\sigma(x_n)) - f_{\parallel}(x_n).
$$

The update then becomes $x_{n+1} = x_n + f_{\perp}(x_n)$, where $\langle x_n, f_{\perp}(x_n) \rangle = 0$ by construction.

![Figure 1: Intuition behind Orthogonal Residual Update](img/figure1.jpg)

*Figure 1: (Left) Standard residual update. (Right) Our Orthogonal Residual Update, which discards the parallel component $f_{\parallel}$ and adds only the orthogonal component $f_{\perp}$.*

Since $f_{\parallel}(x_n)$ is a scalar multiple of $x_n$, it can only rescale the existing stream (amplifying it or, with a negative coefficient, opposing it). Discarding it aims to ensure that each module primarily contributes new information to the residual stream, enhancing representational diversity and mitigating interference from updates that merely rescale or oppose the existing stream.

### Key Results: Stable and Efficient Learning

Our Orthogonal Residual Update strategy leads to more stable training dynamics and improved learning efficiency. For example, models trained with our method often converge faster and reach higher validation accuracy, as illustrated by the training curves in Figure 2.
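To make the Method Overview concrete, here is a minimal PyTorch sketch of an orthogonal residual update. It is not the authors' implementation (that lives in the repository listed under Model Sources); the names `orthogonal_component` and `OrthoResidualBlock` are hypothetical, and three details are assumptions: $\sigma$ is taken to be a pre-norm `LayerNorm`, the projection is computed per token along the last (feature) dimension, and a small `eps` guards against division by zero when $\lVert x_n \rVert$ is near zero.

```python
import torch
import torch.nn as nn


def orthogonal_component(x: torch.Tensor, f: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return the part of `f` orthogonal to `x` along the last dimension.

    f_parallel = (<x, f> / (<x, x> + eps)) * x   (projection of f onto x)
    f_perp     = f - f_parallel
    eps is an assumed numerical guard, not taken from the paper.
    """
    dot = (x * f).sum(dim=-1, keepdim=True)        # <x_n, f>, one scalar per token
    norm_sq = (x * x).sum(dim=-1, keepdim=True)    # <x_n, x_n>, one scalar per token
    return f - (dot / (norm_sq + eps)) * x


class OrthoResidualBlock(nn.Module):
    """Hypothetical block computing x_{n+1} = x_n + f_perp(x_n)."""

    def __init__(self, dim: int, module: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # stands in for sigma (assumed pre-norm)
        self.module = module           # f, e.g. an MLP or attention sub-layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        update = self.module(self.norm(x))          # f(sigma(x_n))
        return x + orthogonal_component(x, update)  # add only the orthogonal part


# Usage: a ViT-style MLP sub-layer on a batch of 2 sequences of 5 tokens, width 16.
torch.manual_seed(0)
dim = 16
mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
block = OrthoResidualBlock(dim, mlp)
x = torch.randn(2, 5, dim)
y = block(x)
# The added update (y - x) should be orthogonal to x for every token, up to float error.
print(((y - x) * x).sum(dim=-1).abs().max().item())
```

The projection is taken onto the incoming stream $x_n$, not onto the normalized input $\sigma(x_n)$, following the Method Overview. Replacing `orthogonal_component(x, update)` with plain `update` recovers the standard (linear) residual update used by the baseline in the table above.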
![Figure 2: Training Dynamics and Efficiency Comparison](img/figure2.jpg)

*Figure 2: Comparison on a representative setting (e.g., ViT-B on ImageNet-1k) showing Orthogonal Residual Update (blue) reaching lower training loss and higher validation accuracy in less wall-clock time than linear residual updates (red).*

### Model Sources

- **Repository (Original Implementation):** [https://github.com/BootsofLagrangian/ortho-residual](https://github.com/BootsofLagrangian/ortho-residual)
- **Paper:** [Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks (arXiv:2505.11881)](https://arxiv.org/abs/2505.11881)

## Evaluation

The script below evaluates this checkpoint on the ImageNet-1k validation split (50,000 images) and prints top-1 and top-5 accuracy. Pass `--cpu` to force CPU execution.

```python
import argparse
from typing import List, Tuple

import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForImageClassification

IMG_SIZE = 224
MEAN = [0.485, 0.456, 0.406]  # standard ImageNet normalization statistics
STD = [0.229, 0.224, 0.225]


def to_rgb(img):
    # Some ImageNet JPEGs are grayscale or CMYK; convert everything to 3-channel RGB.
    return img.convert("RGB")


# Defined at module level (no lambdas or closures) so DataLoader workers can pickle them.
transform_eval = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])


def collate_fn(batch):
    # "jpg" holds the decoded PIL image and "cls" the integer class label.
    images = torch.stack([transform_eval(item["jpg"]) for item in batch])
    labels = torch.tensor([item["cls"] for item in batch])
    return images, labels


def accuracy_counts(
    logits: torch.Tensor,
    target: torch.Tensor,
    topk: Tuple[int, ...] = (1, 5),
) -> List[int]:
    """
    Given model outputs and targets, return the number of correct
    predictions for each k in topk.
    """
    maxk = max(topk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)  # (B, maxk)
    pred = pred.t()  # (maxk, B)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # A sample is top-k correct if its label appears among its first k predictions.
        res.append(int(correct[:k].reshape(-1).sum().item()))
    return res


def evaluate_model(args: argparse.Namespace) -> None:
    device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    print(f"Using device: {device}")

    model = AutoModelForImageClassification.from_pretrained(
        "BootsofLagrangian/ortho-vit-b-imagenet1k-hf",
        trust_remote_code=True,  # the checkpoint ships custom modeling code
    )
    model.to(device)
    model.eval()

    val_dataset = load_dataset("timm/imagenet-1k-wds", split="validation")
    val_loader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
        pin_memory=device.type == "cuda",  # pinned memory only helps host-to-GPU copies
    )

    total_samples, correct_top1, correct_top5 = 0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)
            logits = model(pixel_values=images).logits
            top1, top5 = accuracy_counts(logits, labels, topk=(1, 5))
            correct_top1 += top1
            correct_top5 += top5
            total_samples += images.size(0)

    top1_accuracy = correct_top1 / total_samples * 100
    top5_accuracy = correct_top5 / total_samples * 100
    print("\n--- Evaluation Results ---")
    print(f"Total samples evaluated: {total_samples}")
    print(f"Top-1 Accuracy: {top1_accuracy:.2f}%")
    print(f"Top-5 Accuracy: {top5_accuracy:.2f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate OrthoViT-B on the ImageNet-1k validation split.")
    parser.add_argument("--cpu", action="store_true", help="Force CPU even if CUDA is available.")
    evaluate_model(parser.parse_args())
```

## Citation

```bib
@article{oh2025revisitingresidualconnectionsorthogonal,
  title={Revisiting Residual Connections: Orthogonal Updates for Stable and Efficient Deep Networks},
  author={Giyeong Oh and Woohyun Cho and Siyeol Kim and Suhwan Choi and Younjae Yu},
  year={2025},
  journal={arXiv preprint arXiv:2505.11881},
  eprint={2505.11881},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2505.11881}
}
```