
Intervening on early readouts for mitigating spurious features and simplicity bias (blog.research.google)

research.google · 2 years ago
Posted by Rishabh Tiwari, Pre-doctoral Researcher, and Pradeep Shenoy, Research Scientist, Google Research

Machine learning models in the real world are often trained on limited data that may contain unintended statistical biases. For example, in the CELEBA celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces. Here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as medical diagnosis.

Surprisingly, recent work has also discovered an inherent tendency of deep networks to amplify such statistical biases, through the so-called simplicity bias of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in the training, and continue to anchor on these features, failing to identify more complex and potentially more accurate features.

With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying early readouts and feature forgetting. First, in “Using Early Readouts to Mediate Featural Bias in Distillation”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in model distillation, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “Overcoming Simplicity Bias in Deep Networks using a Feature Sieve”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features.
This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches.

Our AI Principles and our Responsible AI practices guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.

Figure: Animation comparing hypothetical responses from two models trained with and without the feature sieve.

Early readouts for debiasing distillation

We first illustrate the diagnostic value of early readouts and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the cross-entropy loss between student outputs and the ground-truth labels) and teacher matching (minimizing the KL divergence loss between student and teacher outputs for any given input).

Suppose one trains a linear deco
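
The standard distillation objective described in this section, a mixture of label matching and teacher matching, can be sketched in plain Python as below. This is a minimal illustration, not the authors' implementation: the mixing weight `alpha`, the softmax `temperature`, the KL direction (teacher relative to student), and all function names are assumptions introduced here for the example.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_sum_exp for z in scaled]


def cross_entropy(student_logits, label):
    """Label matching: negative log-probability the student gives the true class."""
    return -log_softmax(student_logits)[label]


def kl_divergence(teacher_logits, student_logits, temperature=1.0):
    """Teacher matching: KL(teacher || student) over the class distribution.

    The direction of the divergence is an assumption of this sketch.
    """
    log_p = log_softmax(teacher_logits, temperature)
    log_q = log_softmax(student_logits, temperature)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def distillation_loss(student_logits, teacher_logits, label,
                      alpha=0.5, temperature=1.0):
    """Per-example mixture of the two terms.

    alpha (hypothetical name) weights teacher matching; 1 - alpha weights
    label matching.
    """
    label_term = cross_entropy(student_logits, label)
    teacher_term = kl_divergence(teacher_logits, student_logits, temperature)
    return (1.0 - alpha) * label_term + alpha * teacher_term


if __name__ == "__main__":
    student = [2.0, 0.5, -1.0]
    teacher = [1.5, 1.0, -0.5]
    print(distillation_loss(student, teacher, label=0, alpha=0.5))
```

Setting `alpha=0` reduces the objective to ordinary supervised training on the labels, while `alpha=1` trains the student to match the teacher's output distribution only.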
