Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization

Devansh Arpit Huan Wang Yingbo Zhou Caiming Xiong

Abstract

In Domain Generalization (DG) settings, models trained independently on a given set of training domains have notoriously chaotic performance on distribution-shifted test domains, and stochasticity in optimization (e.g. seed) plays a big role. This makes deep learning models unreliable in real-world settings. We first show that this chaotic behavior exists even along the training optimization trajectory of a single model, and propose a simple model averaging protocol that both significantly boosts domain generalization and diminishes the impact of stochasticity by improving the rank correlation between the in-domain validation accuracy and out-domain test accuracy, which is crucial for reliable early stopping. Taking advantage of our observation, we show that instead of ensembling unaveraged models (as is typical in practice), ensembling moving-average models (EoA) from independent runs further boosts performance. We theoretically explain the boost in performance of ensembling and model averaging by adapting the well-known Bias-Variance trade-off to the domain generalization setting. On the DomainBed benchmark, when using a pre-trained ResNet-50, this ensemble of averages achieves an average of $68.0\%$, beating vanilla ERM (w/o averaging/ensembling) by $\sim 4\%$, and when using a pre-trained RegNetY-16GF, achieves an average of $76.6\%$, beating vanilla ERM by $6\%$. Our code is available at https://github.com/salesforce/ensemble-of-averages.
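The protocol has two ingredients: a moving average of a single model's weights along its training trajectory, and an ensemble that averages the predictions of such averaged models from independent runs. A minimal, dependency-free sketch of both steps is below; all names are illustrative, and the official PyTorch implementation is linked above.

```python
# Illustrative sketch of the Ensemble of Averages (EoA) protocol.
# Function names and data layout are assumptions for this example;
# see https://github.com/salesforce/ensemble-of-averages for the real code.

def update_weight_average(avg_params, params, step):
    """Uniform moving average of model weights along one training
    trajectory: avg_t = (avg_{t-1} * t + w_t) / (t + 1).
    `params` maps parameter names to scalar weights here for simplicity."""
    if avg_params is None:  # first checkpoint initializes the average
        return dict(params)
    return {name: (avg_params[name] * step + params[name]) / (step + 1)
            for name in params}

def ensemble_of_averages(per_run_probs):
    """EoA prediction: average the predicted class probabilities of the
    moving-average models obtained from independent training runs."""
    n_runs = len(per_run_probs)
    n_classes = len(per_run_probs[0])
    return [sum(run[c] for run in per_run_probs) / n_runs
            for c in range(n_classes)]

# Averaging three checkpoints of a one-parameter "model":
avg = None
for step, ckpt in enumerate([{"w": 1.0}, {"w": 3.0}, {"w": 5.0}]):
    avg = update_weight_average(avg, ckpt, step)

# Ensembling the probabilities of two averaged models:
eoa_probs = ensemble_of_averages([[0.8, 0.2], [0.6, 0.4]])
```

The key design choice, per the abstract, is that averaging happens in weight space within a run (cheap, a single forward pass at test time) while ensembling happens in prediction space across runs.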

Code Repositories

salesforce/ensemble-of-averages
Official
pytorch

Benchmarks

Benchmark       | Methodology                          | Average Accuracy
----------------|--------------------------------------|-----------------
DomainNet       | Ensemble of Averages (RegNetY-16GF)  | 60.9
DomainNet       | Ensemble of Averages (ResNeXt-50 32x4d) | 54.6
DomainNet       | Ensemble of Averages (ResNet-50)     | 47.4
Office-Home     | Ensemble of Averages (RegNetY-16GF)  | 83.9
Office-Home     | Ensemble of Averages (ResNeXt-50 32x4d) | 80.2
Office-Home     | Ensemble of Averages (ResNet-50)     | 72.5
PACS            | Ensemble of Averages (RegNetY-16GF)  | 95.8
PACS            | Ensemble of Averages (ResNeXt-50 32x4d) | 93.2
PACS            | Ensemble of Averages (ResNet-50)     | 88.6
TerraIncognita  | Ensemble of Averages (RegNetY-16GF)  | 61.1
TerraIncognita  | Ensemble of Averages (ResNeXt-50 32x4d) | 55.2
TerraIncognita  | Ensemble of Averages (ResNet-50)     | 52.3
VLCS            | Ensemble of Averages (RegNetY-16GF)  | 81.1
VLCS            | Ensemble of Averages (ResNeXt-50 32x4d) | 80.4
VLCS            | Ensemble of Averages (ResNet-50)     | 79.1
