
Asymptotics of Learning with Deep Structured (Random) Features

Dominik Schröder, Hugo Cui, Daniil Dmitriev, Bruno Loureiro

ICML 2024

Summary

We derive an approximate formula for the generalization error of deep neural networks with structured (random) features, confirming a widely believed conjecture. We also show that our results can capture feature maps learned by deep, finite-width neural networks trained under gradient descent.

A widely observed phenomenon in deep learning is that the generalization error of a trained model often follows the so-called “double descent” curve. This curve is characterized by a peak in the generalization error at a certain model complexity, followed by a decrease in the error as the model complexity increases further. The following plot shows the generalization error computed using our asymptotic formula for a deep neural network with structured features. The double descent curve is visible for sufficiently large additive noise.

Asymptotic formula for the generalization error

For $n$ independent samples $x_i\in\mathbb{R}^p$ of a zero-mean random vector with covariance matrix

$$\Omega := \mathbf E\, x_i x_i^\top,$$

collect the samples into the data matrix $X := (x_1,\dots,x_n)\in\mathbb{R}^{p\times n}$ and define the sample covariance and Gram matrices

$$\frac{XX^\top}{p}\quad\text{and}\quad \frac{X^\top X}{p},$$

and the corresponding resolvents

$$G(\lambda):=\Bigl(\frac{XX^\top}{p}+\lambda\Bigr)^{-1}, \qquad \check G(\lambda):=\Bigl(\frac{X^\top X}{p}+\lambda\Bigr)^{-1}.$$

The deterministic equivalents of these matrices are

$$G(\lambda)\approx M(\lambda),\quad \check G(\lambda)\approx m(\lambda)I,$$

where $\langle A\rangle$ denotes the normalized trace of a matrix $A$ and $m(\lambda)$ is the solution of the self-consistent equation

$$\frac{1}{m(\lambda)}=\lambda + \lambda \langle\Omega M(\lambda)\rangle = \lambda + \Bigl\langle \Omega\Bigl(1+\frac{n}{p}m(\lambda)\Omega\Bigr)^{-1}\Bigr\rangle,$$

and

$$M(\lambda):= \Bigl(\lambda + \frac{n}{p}\lambda m(\lambda)\Omega\Bigr)^{-1}.$$
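
In practice the self-consistent equation can be solved by a simple fixed-point iteration. The following sketch (our own illustration, not code accompanying the paper) does this in NumPy for an arbitrary choice of $\Omega$, builds $M(\lambda)$, and compares the deterministic equivalents with the empirical resolvents of Gaussian samples; it assumes the samples are collected as columns of $X\in\mathbb{R}^{p\times n}$ and that $\langle\cdot\rangle$ is the normalized trace.

```python
import numpy as np

def solve_self_consistent(Omega, lam, n, p, tol=1e-10, max_iter=10_000):
    """Fixed-point iteration for the self-consistent equation
    1/m = lam + < Omega (1 + (n/p) m Omega)^{-1} >,
    where <A> denotes the normalized trace tr(A)/p."""
    m = 1.0 / lam  # arbitrary positive initial guess
    I = np.eye(p)
    for _ in range(max_iter):
        m_new = 1.0 / (lam + np.trace(Omega @ np.linalg.inv(I + (n / p) * m * Omega)) / p)
        if abs(m_new - m) < tol:
            return m_new
        m = m_new
    return m

rng = np.random.default_rng(0)
p, n, lam = 200, 400, 0.1
Omega = np.diag(np.linspace(0.5, 2.0, p))   # illustrative anisotropic covariance

m = solve_self_consistent(Omega, lam, n, p)
M = np.linalg.inv(lam * np.eye(p) + (n / p) * lam * m * Omega)

# Empirical resolvents for X = (x_1, ..., x_n) with columns x_i ~ N(0, Omega)
X = np.linalg.cholesky(Omega) @ rng.standard_normal((p, n))
G = np.linalg.inv(X @ X.T / p + lam * np.eye(p))        # G(lambda), p x p
G_check = np.linalg.inv(X.T @ X / p + lam * np.eye(n))  # check-G(lambda), n x n
print(np.trace(G) / p, np.trace(M) / p)   # <G> should be close to <M>
print(np.trace(G_check) / n, m)           # <check-G> should be close to m
```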

The generalization error of ridge regression is given by

$$\mathcal E_\mathrm{gen}^\mathrm{rmt}(\lambda) := \frac{1}{k}\,\theta_\ast^\top \frac{\Psi-\frac{n}{p} m\lambda\, \Phi^\top (M+\lambda M^2)\Phi}{1-\frac{n}{p}(\lambda m)^2\langle\Omega M\Omega M\rangle}\,\theta_\ast + \langle\Sigma\rangle\, \frac{\frac{n}{p}(\lambda m)^2 \langle M\Omega M\Omega\rangle}{1-\frac{n}{p}(\lambda m)^2\langle\Omega M\Omega M\rangle},$$

where $\Omega,\Phi,\Psi$ are the covariances of the student features $x$ and the teacher features $z$,

$$\Omega := \mathbf E\, x x^\top\in\mathbb{R}^{p\times p},\quad \Phi := \mathbf E\, x z^\top\in\mathbb{R}^{p\times k},\quad \Psi := \mathbf E\, z z^\top\in\mathbb{R}^{k\times k},$$

$\theta_\ast\in\mathbb{R}^{k}$ is the target weight vector, and $\Sigma\in\mathbb{R}^{k\times k}$ is the covariance of the additive noise.
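
The formula can be transcribed directly once $m(\lambda)$ and $M(\lambda)$ are available. The sketch below is an illustrative implementation under the same conventions as above (normalized traces, $\Phi\in\mathbb{R}^{p\times k}$); it reuses the `solve_self_consistent` helper from the previous snippet and is not the authors' reference code.

```python
import numpy as np

def generalization_error_rmt(Omega, Phi, Psi, theta_star, Sigma, lam, n):
    """Deterministic-equivalent generalization error E_gen^rmt(lambda).
    Omega: (p, p) student covariance, Phi: (p, k) student-teacher covariance,
    Psi: (k, k) teacher covariance, theta_star: (k,), Sigma: (k, k) noise covariance.
    Assumes solve_self_consistent() from the previous snippet is in scope."""
    p, k = Phi.shape
    m = solve_self_consistent(Omega, lam, n, p)
    M = np.linalg.inv(lam * np.eye(p) + (n / p) * lam * m * Omega)

    OmegaM = Omega @ M
    omm = np.trace(OmegaM @ OmegaM) / p                     # <Omega M Omega M>
    denom = 1.0 - (n / p) * (lam * m) ** 2 * omm

    bias_num = Psi - (n / p) * m * lam * (Phi.T @ (M + lam * M @ M) @ Phi)
    bias = theta_star @ bias_num @ theta_star / (k * denom)
    noise = (np.trace(Sigma) / k) * (n / p) * (lam * m) ** 2 * omm / denom
    return bias + noise
```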


Numerical illustration

Here we consider the MNIST dataset and train a simple 4-layer neural network to recognize whether a digit is even or odd. We then use the output of the last hidden layer of this network, i.e. the input to the readout layer, as a feature map and perform ridge regression on these features. The following plot compares the empirical generalization error of this ridge regression problem with the prediction of our asymptotic formula. The double descent curve is visible for sufficiently small regularization λ. The deterministic equivalents remain valid throughout the training process.
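
As a rough sketch of this pipeline (not the paper's experimental code), the snippet below trains a small 4-layer MLP with PyTorch, freezes it, uses its last hidden layer as the feature map, and fits the readout by ridge regression in closed form. Synthetic Gaussian data with a linear teacher stands in for MNIST, and the architecture, optimizer, and hyperparameters are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
d, n_train, n_test, lam = 32, 2000, 1000, 1.0

# Synthetic stand-in for MNIST even/odd: Gaussian inputs, +/-1 labels from a linear teacher
teacher = torch.randn(d)
def make_data(n):
    x = torch.randn(n, d)
    return x, torch.sign(x @ teacher)

x_train, y_train = make_data(n_train)
x_test, y_test = make_data(n_test)

# 4-layer MLP; `features` is the map up to the last hidden layer, `readout` the trained head
features = torch.nn.Sequential(
    torch.nn.Linear(d, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 128), torch.nn.ReLU(),
)
readout = torch.nn.Linear(128, 1)
model = torch.nn.Sequential(features, readout)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(200):  # a few full-batch gradient steps, for simplicity
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x_train).squeeze(-1), y_train)
    loss.backward()
    opt.step()

# Ridge regression on the frozen features, closed form: w = (F^T F + lam I)^{-1} F^T y
with torch.no_grad():
    F_train, F_test = features(x_train), features(x_test)
    p = F_train.shape[1]
    w = torch.linalg.solve(F_train.T @ F_train + lam * torch.eye(p), F_train.T @ y_train)
    test_error = torch.mean((F_test @ w - y_test) ** 2)
    print(f"empirical generalization error: {test_error:.3f}")
```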

Figure: Generalization error of feature regression using either the Neural Network (NN) features or linear features, for regularization λ = 0.1, 1, 10, 100, as a function of the number of samples. The solid lines represent the deterministic equivalents, while the dots represent the empirical generalization error. The grey line represents the gradient descent loss of the NN.

Abstract

For a large class of feature maps we provide a tight asymptotic characterisation of the test error associated with learning the readout layer, in the high-dimensional limit where the input dimension, hidden layer widths, and number of training samples are proportionally large. This characterization is formulated in terms of the population covariance of the features. Our work is partially motivated by the problem of learning with Gaussian rainbow neural networks, namely deep non-linear fully-connected networks with random but structured weights, whose row-wise covariances are further allowed to depend on the weights of previous layers. For such networks we also derive a closed-form formula for the feature covariance in terms of the weight matrices. We further find that in some cases our results can capture feature maps learned by deep, finite-width neural networks trained under gradient descent.

Paper

2402.13999.pdf