Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

Shiori Sagawa, Pang Wei Koh, Tatsunori B. Hashimoto, Percy Liang

Abstract

Overparameterized neural networks can be highly accurate on average on an i.i.d. test set yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that hold on average but not in such groups). Distributionally robust optimization (DRO) allows us to learn models that instead minimize the worst-case training loss over a set of pre-defined groups. However, we find that naively applying group DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss also already has vanishing worst-case training loss. Instead, the poor worst-case performance arises from poor generalization on some groups. By coupling group DRO models with increased regularization (a stronger-than-typical L2 penalty or early stopping), we achieve substantially higher worst-group accuracies, with 10-40 percentage point improvements on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is important for worst-group generalization in the overparameterized regime, even if it is not needed for average generalization. Finally, we introduce a stochastic optimization algorithm, with convergence guarantees, to efficiently train group DRO models.
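
As a rough illustration of the objective described in the abstract, the sketch below computes per-group average losses, takes the worst one, and combines a group-weighted loss with an L2 penalty. It is written in plain Python with no framework. All function names, the step size, and the exponentiated-gradient update of the group weights are assumptions made for illustration: the abstract does not spell out the algorithm, and the authors' implementation lives in kohpangwei/group_DRO.

```python
import math


def group_losses(losses, groups, num_groups):
    """Average the per-example losses within each group.

    Groups with no examples in the batch are reported as None.
    """
    totals = [0.0] * num_groups
    counts = [0] * num_groups
    for loss, g in zip(losses, groups):
        totals[g] += loss
        counts[g] += 1
    return [t / c if c else None for t, c in zip(totals, counts)]


def worst_group_loss(per_group):
    """Worst-case (maximum) average loss over the groups that are present."""
    return max(l for l in per_group if l is not None)


def update_group_weights(q, per_group, step_size):
    """One exponentiated-gradient ascent step on the group weights q.

    Assumed update (hypothetical helper): groups with higher loss gain weight,
    and the weights are renormalized to sum to 1. Absent groups are unchanged
    before renormalization.
    """
    scaled = [
        qg * math.exp(step_size * lg) if lg is not None else qg
        for qg, lg in zip(q, per_group)
    ]
    total = sum(scaled)
    return [s / total for s in scaled]


def robust_objective(q, per_group, params, l2_strength):
    """Group-weighted loss plus an L2 penalty on the parameters."""
    weighted = sum(qg * lg for qg, lg in zip(q, per_group) if lg is not None)
    penalty = l2_strength * sum(p * p for p in params)
    return weighted + penalty


# Toy batch: two groups, the second one has much higher loss.
losses = [0.1, 0.3, 2.0, 1.0]
groups = [0, 0, 1, 1]
per_group = group_losses(losses, groups, num_groups=2)
worst = worst_group_loss(per_group)
q = update_group_weights([0.5, 0.5], per_group, step_size=1.0)
```

In this toy batch the second group carries the higher average loss, so its weight grows after the update. Putting all the weight on the worst group recovers the worst-case loss, and `l2_strength` plays the role of the stronger-than-typical L2 penalty mentioned in the abstract.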

Code Repositories

ssagawa/overparam_spur_corr (PyTorch, mentioned in GitHub)
facebookresearch/DomainBed (PyTorch, mentioned in GitHub)
haoxiang-wang/isr (PyTorch, mentioned in GitHub)
orparask/VS-Loss (PyTorch, mentioned in GitHub)
kohpangwei/group_DRO (official, PyTorch, mentioned in GitHub)
yangarbiter/dp-dg (PyTorch, mentioned in GitHub)

Benchmarks

Benchmark | Methodology | Metrics
domain-generalization-on-nico-animal | DRO (ResNet-18) | Accuracy: 77.61
domain-generalization-on-nico-vehicle | DRO (ResNet-18) | Accuracy: 77.61
domain-generalization-on-pacs-2 | GroupDRO (ResNet-50, DomainBed) | Average Accuracy: 84.4
