Asymptotics of Learning with Deep Structured (Random) Features
Dominik Schröder, Hugo Cui, Daniil Dmitriev, Bruno Loureiro
ICML 2024
Summary
We derive an asymptotic formula for the generalization error of deep neural networks with structured (random) features, confirming a widely believed conjecture. We also show that our results can capture feature maps learned by deep, finite-width neural networks trained under gradient descent.
A widely observed phenomenon in deep learning is that the generalization error of a trained model, viewed as a function of model complexity, often follows the so-called “double descent” curve. This curve has a peak in the generalization error at a certain model complexity, followed by a decrease in the error as the complexity increases further. The following plot shows the generalization error computed using our asymptotic formula for a deep neural network with structured features. The double descent curve is visible for sufficiently large additive noise.
Asymptotic formula for the generalization error
For independent samples of a zero-mean random vector with covariance matrix
define the sample covariance matrix, and the Gram matrices
and the corresponding resolvents
The deterministic equivalents of these matrices are
where is the solution of the self-consistent equation
and
The generalization error of ridge regression is given by
where are the covariances of the student features and the teacher features
is the target weight vector, and is the covariance of the additive noise.
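The random-matrix objects defined at the start of this section can be checked numerically. The following is a minimal sketch, assuming Gaussian samples with a diagonal covariance Ω and evaluating the resolvents at z = -λ with λ > 0 (so they are real and well defined). It compares normalized traces of the sample-covariance resolvent G(z) = (XᵀX/n - z)⁻¹ and the Gram resolvent Ǧ(z) = (XXᵀ/n - z)⁻¹ with the textbook Marchenko–Pastur-type deterministic equivalent, in which m̌(z) solves m̌ = -1/(z - (1/n) tr[Ω(I + m̌Ω)⁻¹]) and G(z) ≈ -(1/z)(I + m̌Ω)⁻¹. This standard form is assumed here and may differ in notation and generality from the equations in the paper; all function names are hypothetical.

```python
import numpy as np

def solve_m_check(omega_eigs, n, lam, tol=1e-12, max_iter=10_000):
    """Fixed point of m = 1 / (lam + (1/n) * sum_i w_i / (1 + m * w_i)),
    i.e. the self-consistent equation evaluated at z = -lam."""
    m = 0.0
    for _ in range(max_iter):
        m_new = 1.0 / (lam + np.sum(omega_eigs / (1.0 + m * omega_eigs)) / n)
        if abs(m_new - m) < tol:
            return m_new
        m = m_new
    return m

def deterministic_traces(omega_eigs, n, lam):
    """Predicted (1/p) tr G(-lam) and (1/n) tr G_check(-lam)."""
    m = solve_m_check(omega_eigs, n, lam)
    # G(-lam) ~ (1/lam) * (I + m * Omega)^{-1}, averaged over the eigenvalues of Omega
    return np.mean(1.0 / (lam * (1.0 + m * omega_eigs))), m

def empirical_traces(omega_eigs, n, lam, rng):
    p = omega_eigs.size
    # Gaussian samples (an assumption) with covariance diag(omega_eigs)
    X = rng.standard_normal((n, p)) * np.sqrt(omega_eigs)
    sample_cov = X.T @ X / n        # p x p sample covariance
    gram = X @ X.T / n              # n x n Gram matrix
    G = np.linalg.inv(sample_cov + lam * np.eye(p))   # resolvent at z = -lam
    G_check = np.linalg.inv(gram + lam * np.eye(n))
    return np.trace(G) / p, np.trace(G_check) / n

rng = np.random.default_rng(0)
omega_eigs = np.linspace(0.5, 2.0, 300)
det_G, det_Gc = deterministic_traces(omega_eigs, n=600, lam=0.1)
emp_G, emp_Gc = empirical_traces(omega_eigs, n=600, lam=0.1, rng=rng)
print(det_G, emp_G, det_Gc, emp_Gc)
```

With p = 300 and n = 600 the two pairs of traces should agree to within a few percent, and the agreement tightens as p and n grow proportionally.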
Numerical illustration
Here we consider the MNIST dataset and train a simple 4-layer neural network to recognize whether a digit is even or odd. We then use the trained network, up to its readout layer, as a feature map and perform ridge regression on the resulting features. The following plot compares the empirical generalization error of this ridge regression problem with the prediction of our asymptotic formula. The double descent curve is visible for sufficiently small additive noise. The deterministic equivalents remain valid throughout the training process.
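The trained MNIST features are not reproduced here. As a hedged stand-in, the sketch below uses a hypothetical linear-Gaussian teacher-student model: teacher features ψ ~ N(0, I_k), student features φ = Fψ/√k + √η ξ with ξ ~ N(0, I_p), and labels y = θᵀψ + σε with scalar noise (the text allows a general noise covariance). Because the model is Gaussian, the test error of the ridge estimator w on a fresh noiseless test point is exactly θᵀΨθ - 2wᵀΦθ + wᵀΩw, with student covariance Ω = E[φφᵀ], cross covariance Φ = E[φψᵀ] and teacher covariance Ψ = E[ψψᵀ]; this notation is chosen for the sketch and is not necessarily the paper's. Sweeping the number of samples n past the feature dimension p with a tiny ridge penalty should show the double descent peak near n = p.

```python
import numpy as np

def teacher_student_errors(sample_sizes, p=100, k=100, eta=0.5, lam=1e-6,
                           noise_std=0.5, seed=0):
    """Exact test error of ridge regression in a hypothetical Gaussian teacher-student model."""
    rng = np.random.default_rng(seed)
    F = rng.standard_normal((p, k))
    Phi = F / np.sqrt(k)                         # cross covariance E[phi psi^T]
    Omega = Phi @ Phi.T + eta * np.eye(p)        # student covariance E[phi phi^T]
    Psi = np.eye(k)                              # teacher covariance E[psi psi^T]
    theta = rng.standard_normal(k) / np.sqrt(k)  # hypothetical target weights
    errors = []
    for n in sample_sizes:
        psi = rng.standard_normal((n, k))                                # teacher features
        phi = psi @ Phi.T + np.sqrt(eta) * rng.standard_normal((n, p))   # student features
        y = psi @ theta + noise_std * rng.standard_normal(n)             # noisy labels
        # Ridge estimator fitted on the student features only
        w = np.linalg.solve(phi.T @ phi / n + lam * np.eye(p), phi.T @ y / n)
        # E[(theta^T psi - w^T phi)^2] over a fresh test point
        errors.append(theta @ Psi @ theta - 2 * w @ Phi @ theta + w @ Omega @ w)
    return np.array(errors)

ns = [25, 50, 100, 200, 400, 800]
errs = teacher_student_errors(ns)
for n, e in zip(ns, errs):
    print(f"n = {n:4d}  test error = {e:.3f}")
```

Raising lam smooths the peak at n = p, while raising noise_std makes it more pronounced.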
Abstract
For a large class of feature maps we provide a tight asymptotic characterization of the test error associated with learning the readout layer, in the high-dimensional limit where the input dimension, hidden layer widths, and number of training samples are proportionally large. This characterization is formulated in terms of the population covariance of the features. Our work is partially motivated by the problem of learning with Gaussian rainbow neural networks, namely deep non-linear fully-connected networks with random but structured weights, whose row-wise covariances are further allowed to depend on the weights of previous layers. For such networks we also derive a closed-form formula for the feature covariance in terms of the weight matrices. We further find that in some cases our results can capture feature maps learned by deep, finite-width neural networks trained under gradient descent.
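A Gaussian rainbow network as described in the abstract can be sampled in a few lines. This is a minimal sketch under stated assumptions: the rows of each weight matrix W_l are drawn from N(0, C_l), with the hypothetical choice C_1 = I and C_l = I + α W_{l-1}W_{l-1}ᵀ / p_{l-2}, so that each layer's row covariance depends on the previous layer's weights; the activation is tanh and pre-activations are scaled by 1/√(fan-in). The paper derives a closed-form feature covariance; here that covariance is instead estimated by Monte Carlo, which is a swapped-in technique. Because tanh is odd and the inputs are symmetric Gaussians, the features have zero mean, matching the zero-mean assumption of the asymptotic formula.

```python
import numpy as np

def sample_rainbow_weights(widths, alpha, rng):
    """widths = [d, p_1, ..., p_L]. Rows of W_l ~ N(0, C_l) with C_1 = I and
    C_l = I + alpha * W_{l-1} W_{l-1}^T / p_{l-2} (a hypothetical choice)."""
    weights, prev = [], None
    for l in range(1, len(widths)):
        p_in, p_out = widths[l - 1], widths[l]
        if prev is None:
            C = np.eye(p_in)
        else:
            # Row covariance depends on the previous layer's weights
            C = np.eye(p_in) + alpha * prev @ prev.T / prev.shape[1]
        L = np.linalg.cholesky(C)                    # C = L L^T
        W = rng.standard_normal((p_out, p_in)) @ L.T # each row has covariance C
        weights.append(W)
        prev = W
    return weights

def rainbow_features(X, weights, activation=np.tanh):
    """Propagate inputs X (n x d) through the network; returns n x p_L features."""
    H = X
    for W in weights:
        H = activation(H @ W.T / np.sqrt(W.shape[1]))
    return H

rng = np.random.default_rng(0)
widths = [30, 60, 60, 40]
weights = sample_rainbow_weights(widths, alpha=0.5, rng=rng)
X = rng.standard_normal((20_000, widths[0]))
F = rainbow_features(X, weights)
Omega_mc = F.T @ F / X.shape[0]   # Monte Carlo estimate of the feature covariance
```

The eigenvalues of Omega_mc are the kind of population-covariance input that the characterization in the abstract is formulated in terms of.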