
Weight Decay

Now that we have characterized the problem of overfitting, we can introduce our first regularization technique. Recall that we can always mitigate overfitting by collecting more training data. However, that can be costly, time-consuming, or entirely out of our control, making it impossible in the short run. For now, we can assume that we already have as much high-quality data as our resources permit and focus on the tools at our disposal when the dataset is taken as a given.

Recall that in our polynomial regression example (:numref:subsec_polynomial-curve-fitting) we could limit our model's capacity by tweaking the degree of the fitted polynomial. Indeed, limiting the number of features is a popular technique for mitigating overfitting. However, simply tossing aside features can be too blunt an instrument. Sticking with the polynomial regression example, consider what might happen with high-dimensional input. The natural extensions of polynomials to multivariate data are called monomials, which are simply products of powers of variables. The degree of a monomial is the sum of the powers. For example, $x_1^2 x_2$ and $x_3 x_5^2$ are both monomials of degree 3.

Note that the number of terms with degree $d$ blows up rapidly as $d$ grows larger. Given $k$ variables, the number of monomials of degree $d$ is $\binom{k - 1 + d}{k - 1}$. Even small changes in degree, say from 2 to 3, dramatically increase the complexity of our model. Thus we often need a more fine-grained tool for adjusting function complexity.
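To see how quickly this grows, the sketch below (plain Julia; the function name `count_monomials` is hypothetical) enumerates all exponent vectors of total degree $d$ by brute force and compares the count with the binomial formula. Brute force is only practical for small $k$ and $d$, which is exactly the point: for $k = 200$ variables the formula alone tells us the count jumps from 20,100 monomials of degree 2 to 1,353,400 of degree 3.

```julia
# Count the monomials of total degree d in k variables by brute-force enumeration
# of exponent vectors (a_1, ..., a_k) with a_1 + ... + a_k == d.
function count_monomials(k::Int, d::Int)
    total = 0
    for a in Iterators.product(ntuple(_ -> 0:d, k)...)
        if sum(a) == d
            total += 1
        end
    end
    return total
end

# Compare the enumeration with the closed form binomial(k - 1 + d, k - 1).
for (k, d) in [(2, 3), (5, 2), (5, 3), (8, 3)]
    println("k = $k, d = $d: enumerated = $(count_monomials(k, d)), ",
            "formula = $(binomial(k - 1 + d, k - 1))")
end

# With k = 200 input variables, going from degree 2 to degree 3 (formula only):
println(binomial(200 - 1 + 2, 200 - 1), " => ", binomial(200 - 1 + 3, 200 - 1))
```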

julia
using Pkg;
Pkg.activate("../../d2lai")
using d2lai, Flux
  Activating project at `/workspace/workspace/d2l-julia/d2lai`

Norms and Weight Decay

Rather than directly manipulating the number of parameters, weight decay operates by restricting the values that the parameters can take. More commonly called $\ell_2$ regularization outside of deep learning circles when optimized by minibatch stochastic gradient descent, weight decay might be the most widely used technique for regularizing parametric machine learning models. The technique is motivated by the basic intuition that among all functions $f$, the function $f = 0$ (assigning the value $0$ to all inputs) is in some sense the simplest, and that we can measure the complexity of a function by the distance of its parameters from zero. But how precisely should we measure the distance between a function and zero? There is no single right answer. In fact, entire branches of mathematics, including parts of functional analysis and the theory of Banach spaces, are devoted to addressing such issues.

One simple interpretation might be to measure the complexity of a linear function $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$ by some norm of its weight vector, e.g., $\|\mathbf{w}\|^2$. Recall that we introduced the $\ell_2$ norm and $\ell_1$ norm, which are special cases of the more general $\ell_p$ norm, in :numref:subsec_lin-algebra-norms. The most common method for ensuring a small weight vector is to add its norm as a penalty term to the problem of minimizing the loss. Thus we replace our original objective, minimizing the prediction loss on the training labels, with a new objective, minimizing the sum of the prediction loss and the penalty term. Now, if our weight vector grows too large, our learning algorithm might focus on minimizing the weight norm $\|\mathbf{w}\|^2$ rather than minimizing the training error. That is exactly what we want. To illustrate things in code, we revive our previous example from :numref:sec_linear_regression for linear regression. There, our loss was given by

$$L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.$$

Recall that $\mathbf{x}^{(i)}$ are the features, $y^{(i)}$ is the label for any data example $i$, and $(\mathbf{w}, b)$ are the weight and bias parameters, respectively. To penalize the size of the weight vector, we must somehow add $\|\mathbf{w}\|^2$ to the loss function, but how should the model trade off the standard loss for this new additive penalty? In practice, we characterize this trade-off via the regularization constant $\lambda$, a nonnegative hyperparameter that we fit using validation data:

$$L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2.$$

For $\lambda = 0$, we recover our original loss function. For $\lambda > 0$, we restrict the size of $\|\mathbf{w}\|$. We divide by 2 by convention: when we take the derivative of a quadratic function, the 2 and 1/2 cancel out, ensuring that the expression for the update looks nice and simple. The astute reader might wonder why we work with the squared norm and not the standard norm (i.e., the Euclidean distance). We do this for computational convenience. By squaring the $\ell_2$ norm, we remove the square root, leaving the sum of squares of each component of the weight vector. This makes the derivative of the penalty easy to compute: the sum of derivatives equals the derivative of the sum.
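As a quick numerical check of this convenience, the sketch below (plain Julia with the LinearAlgebra standard library; all function names are hypothetical) implements the penalized objective with examples stored as the columns of `X`, writes down its analytic gradient with respect to $\mathbf{w}$, and compares that gradient with central finite differences. The penalty term contributes exactly $\lambda \mathbf{w}$ to the gradient.

```julia
using LinearAlgebra

# Squared-error loss L(w, b) = (1/n) Σ_i ½ (wᵀx⁽ⁱ⁾ + b − y⁽ⁱ⁾)², examples as columns of X.
sq_loss(w, b, X, y) = sum(abs2, X' * w .+ b .- y) / (2 * length(y))

# Penalized objective L(w, b) + (λ/2)‖w‖²; the bias is not penalized.
penalized_loss(w, b, X, y, λ) = sq_loss(w, b, X, y) + λ / 2 * dot(w, w)

# Analytic gradient w.r.t. w: (1/n) Σ_i x⁽ⁱ⁾ (wᵀx⁽ⁱ⁾ + b − y⁽ⁱ⁾) + λ w.
grad_w(w, b, X, y, λ) = X * (X' * w .+ b .- y) ./ length(y) .+ λ .* w

# Central finite differences as an independent check of the gradient.
function fd_grad(f, w; h = 1e-6)
    g = similar(w)
    for j in eachindex(w)
        e = zeros(length(w))
        e[j] = h
        g[j] = (f(w .+ e) - f(w .- e)) / (2h)
    end
    return g
end

X = [1.0 2.0 -1.0; 0.5 -1.0 3.0]   # 2 features, 3 examples (one per column)
y = [1.0, 0.0, 2.0]
w = [0.3, -0.7]; b = 0.1; λ = 0.5

g_analytic = grad_w(w, b, X, y, λ)
g_numeric = fd_grad(v -> penalized_loss(v, b, X, y, λ), w)
println(g_analytic, "  ", g_numeric)
```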

Moreover, you might ask why we work with the $\ell_2$ norm in the first place and not, say, the $\ell_1$ norm. In fact, other choices are valid and popular throughout statistics. While $\ell_2$-regularized linear models constitute the classic ridge regression algorithm, $\ell_1$-regularized linear regression is a similarly fundamental method in statistics, popularly known as lasso regression. One reason to work with the $\ell_2$ norm is that it places an outsize penalty on large components of the weight vector. This biases our learning algorithm towards models that distribute weight evenly across a larger number of features. In practice, this might make them more robust to measurement error in a single variable. By contrast, $\ell_1$ penalties lead to models that concentrate weights on a small set of features by driving the other weights to zero. This gives us an effective method for feature selection, which may be desirable for other reasons. For example, if our model only relies on a few features, then we may not need to collect, store, or transmit data for the other (dropped) features.
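A small numerical example makes the difference concrete. The two hypothetical weight vectors below have the same $\ell_1$ norm, so an $\ell_1$ penalty cannot tell them apart, while the squared $\ell_2$ penalty is four times larger for the vector that puts all of its weight on a single feature.

```julia
using LinearAlgebra

# Two hypothetical weight vectors with the same ℓ1 norm: one concentrates all
# weight on a single feature, the other spreads it evenly over four features.
w_concentrated = [1.0, 0.0, 0.0, 0.0]
w_spread = [0.25, 0.25, 0.25, 0.25]

for (name, w) in [("concentrated", w_concentrated), ("spread", w_spread)]
    l1 = norm(w, 1)        # ℓ1 norm: Σ_i |w_i|
    l2sq = sum(abs2, w)    # squared ℓ2 norm: Σ_i w_i², the weight-decay penalty
    println(rpad(name, 13), "ℓ1 = ", l1, ", squared ℓ2 = ", l2sq)
end
```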

Using the same notation as in :eqref:eq_linreg_batch_update, the minibatch stochastic gradient descent update for $\ell_2$-regularized regression is:

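$$\mathbf{w} \leftarrow \left(1 - \eta\lambda\right)\mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right).$$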
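The update rule can be transcribed directly. The sketch below (plain Julia; the function names are hypothetical) performs one minibatch step in the "shrink, then subtract the data gradient" form of the equation and, separately, as an ordinary gradient step on $L(\mathbf{w}, b) + \frac{\lambda}{2}\|\mathbf{w}\|^2$; the two agree up to floating-point rounding. The equation only covers $\mathbf{w}$; for the bias, the sketch assumes a plain gradient step without decay.

```julia
# One minibatch SGD step for ℓ2-regularized linear regression, written as
# w ← (1 − ηλ) w − (η/|B|) Σ_i x⁽ⁱ⁾ (wᵀx⁽ⁱ⁾ + b − y⁽ⁱ⁾).
# The minibatch examples are the columns of X; the bias b is updated without decay.
function sgd_step_weight_decay(w, b, X, y, η, λ)
    r = X' * w .+ b .- y          # residuals wᵀx⁽ⁱ⁾ + b − y⁽ⁱ⁾, one per example
    B = length(y)                 # minibatch size |B|
    w_new = (1 - η * λ) .* w .- (η / B) .* (X * r)
    b_new = b - (η / B) * sum(r)
    return w_new, b_new
end

# The same step computed as plain gradient descent on L(w, b) + (λ/2)‖w‖².
function sgd_step_penalized(w, b, X, y, η, λ)
    r = X' * w .+ b .- y
    B = length(y)
    gw = X * r ./ B .+ λ .* w     # gradient of the penalized loss w.r.t. w
    gb = sum(r) / B               # gradient w.r.t. b (no penalty on the bias)
    return w .- η .* gw, b - η * gb
end

X = [1.0 2.0 -1.0 0.0; 0.5 -1.0 3.0 2.0]   # 2 features, minibatch of 4 examples
y = [1.0, 0.0, 2.0, -1.0]
w0, b0 = [0.3, -0.7], 0.1
w1, b1 = sgd_step_weight_decay(w0, b0, X, y, 0.1, 0.5)
w2, b2 = sgd_step_penalized(w0, b0, X, y, 0.1, 0.5)
println(w1, "  ", w2)
```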

As before, we update $\mathbf{w}$ based on the amount by which our estimate differs from the observation. However, we also shrink the size of $\mathbf{w}$ towards zero. That is why the method is sometimes called "weight decay": given the penalty term alone, our optimization algorithm decays the weight by a factor of $1 - \eta\lambda$ at each step of training. In contrast to feature selection, weight decay offers us a mechanism for continuously adjusting the complexity of a function. Smaller values of $\lambda$ correspond to less constrained $\mathbf{w}$, whereas larger values of $\lambda$ constrain $\mathbf{w}$ more considerably. Whether we include a corresponding bias penalty $b^2$ can vary across implementations, and may vary across layers of a neural network. Often, we do not regularize the bias term. Moreover, although $\ell_2$ regularization may not be equivalent to weight decay for other optimization algorithms, the idea of regularization through shrinking the size of weights still holds true.
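One way to see how $\lambda$ controls the size of $\mathbf{w}$ is the closed-form ridge regression solution. Setting the gradient of $\frac{1}{n}\sum_i \frac{1}{2}(\mathbf{w}^\top \mathbf{x}^{(i)} - y^{(i)})^2 + \frac{\lambda}{2}\|\mathbf{w}\|^2$ to zero gives $(\mathbf{X}\mathbf{X}^\top + n\lambda \mathbf{I})\mathbf{w} = \mathbf{X}\mathbf{y}$, where the columns of $\mathbf{X}$ are the examples. The sketch below solves this system on random data shaped like the high-dimensional example in the next section ($d = 200$, $n = 20$), dropping the bias as a simplifying assumption, and shows $\|\mathbf{w}\|$ shrinking as $\lambda$ grows. The seed and the grid of $\lambda$ values are arbitrary choices.

```julia
using LinearAlgebra, Random

# Closed-form minimizer of (1/n) Σ_i ½ (wᵀx⁽ⁱ⁾ − y⁽ⁱ⁾)² + (λ/2)‖w‖² with no bias term
# (a simplifying assumption). The n examples are the columns of the d × n matrix X.
ridge_weights(X, y, λ) = (X * X' + length(y) * λ * I) \ (X * y)

rng = MersenneTwister(0)
d, n = 200, 20
X = randn(rng, d, n)
# Labels from the linear model 0.01 Σ_i x_i plus N(0, 0.01²) noise, without a bias.
y = 0.01 .* vec(sum(X; dims = 1)) .+ 0.01 .* randn(rng, n)

λs = [1e-3, 1e-2, 1e-1, 1.0, 10.0]
norms = [norm(ridge_weights(X, y, λ)) for λ in λs]
for (λ, nw) in zip(λs, norms)
    println("λ = ", λ, "  ‖w‖ = ", round(nw; sigdigits = 4))
end
```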

High-Dimensional Linear Regression

We can illustrate the benefits of weight decay through a simple synthetic example.

First, we generate some data as before:

$$y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \textrm{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$$

In this synthetic dataset, our label is given by an underlying linear function of our inputs, corrupted by Gaussian noise with zero mean and standard deviation 0.01. For illustrative purposes, we can make the effects of overfitting pronounced by increasing the dimensionality of our problem to $d = 200$ and working with a small training set of only 20 examples.
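Why does this setup overfit so easily? With $d = 200$ weights and only 20 training examples, the training equations are underdetermined, so some linear function fits the training labels exactly. The plain-Julia sketch below (independent of d2lai; the seed and the helper names `make_labels` and `design` are arbitrary, hypothetical choices) draws data from the same formula, absorbs the bias into an extra constant feature, and uses the pseudoinverse to pick the minimum-norm interpolating solution: its training error is essentially zero while its validation error is not. Gradient descent from a random initialization need not find this particular solution, but it faces the same abundance of perfect fits.

```julia
using LinearAlgebra, Random, Statistics

rng = MersenneTwister(1)
d, n_train, n_val = 200, 20, 100
w_true, b_true = fill(0.01, d), 0.05

# Labels y = 0.05 + Σ_i 0.01 x_i + ϵ with ϵ ~ N(0, 0.01²); examples are columns of X.
make_labels(X) = X' * w_true .+ b_true .+ 0.01 .* randn(rng, size(X, 2))

X_train = randn(rng, d, n_train); y_train = make_labels(X_train)
X_val = randn(rng, d, n_val);     y_val = make_labels(X_val)

# Design matrix: one row per example, with a constant-1 column that absorbs the bias.
design(X) = Matrix(vcat(X, ones(1, size(X, 2)))')

# Minimum-norm solution of the underdetermined system design(X_train) * θ = y_train.
θ = pinv(design(X_train)) * y_train

train_mse = mean(abs2, design(X_train) * θ .- y_train)
val_mse = mean(abs2, design(X_val) * θ .- y_val)
println("train MSE = ", train_mse, ", validation MSE = ", val_mse)
```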

julia
struct PolynomialData{XT, YT, A} <: d2lai.AbstractData
    X::XT      # inputs, one example per column (num_inputs × n)
    y::YT      # labels, a 1 × n row
    args::A
    function PolynomialData(num_train, num_val, num_inputs, batch_size)
        args = (num_train = num_train, num_val = num_val, num_inputs = num_inputs, batchsize = batch_size)
        n = num_train + num_val
        X = randn(num_inputs, n)
        b = 0.05   # bias of the underlying linear function
        # y = 0.05 + Σᵢ 0.01 xᵢ + ϵ with ϵ ~ N(0, 0.01²)
        y = 0.01*ones(1, num_inputs)*X .+ b .+ 0.01*randn(1, n)
        new{typeof(X), typeof(y), typeof(args)}(X, y, args)
    end
end
# The first num_train columns form the training set; the remaining columns form the validation set.
function d2lai.get_dataloader(data::PolynomialData; train = true)
    if train
        return Flux.DataLoader((data.X[:, 1:data.args.num_train], data.y[:, 1:data.args.num_train]), batchsize = data.args.batchsize, shuffle=true)
    else
        return Flux.DataLoader((data.X[:, data.args.num_train + 1 : end], data.y[:, data.args.num_train + 1 : end]), batchsize = data.args.batchsize)
    end
end

Implementation from Scratch

Now, let's try implementing weight decay from scratch. Since minibatch stochastic gradient descent is our optimizer, we just need to add the squared $\ell_2$ penalty to the original loss function.

Defining $\ell_2$ Norm Penalty

Perhaps the most convenient way of implementing this penalty is to square all entries of the weights and sum them.

julia
function l2_penalty(w)
    sum(w.^2)   # sum of the squared entries of w, i.e., the squared ℓ2 norm
end
l2_penalty (generic function with 1 method)
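Julia offers several equivalent ways to compute this quantity. The sketch below (plain Julia; the 1 × 200 shape mirrors the weight matrix of `Dense(200 => 1)`) checks that they agree. `sum(abs2, w)` avoids the temporary array that `w.^2` allocates, and for a matrix `norm` returns the Frobenius norm, that is, the square root of the sum of squared entries.

```julia
using LinearAlgebra

w = randn(1, 200)    # same shape as the weight of Dense(200 => 1): out × in

p_broadcast = sum(w .^ 2)   # what l2_penalty computes; allocates a temporary array
p_abs2 = sum(abs2, w)       # same sum, no temporary array
p_dot = dot(w, w)           # inner product of w with itself, summed over all entries
p_norm = norm(w)^2          # norm of a matrix defaults to the Frobenius norm
println((p_broadcast, p_abs2, p_dot, p_norm))
```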

Defining the Model

julia
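struct WeightDecayScratch{N, A} <: AbstractModel
    net::N    # the underlying network, here a single Dense layer
    args::A   # hyperparameters: the regularization constant lambda
end

function WeightDecayScratch(net, lambda::Real = 0.01)
    args = (lambda = lambda, )
    return WeightDecayScratch(net, args)
end
d2lai.forward(m::WeightDecayScratch, x) = m.net(x)

# Prediction loss plus lambda times the squared norm of the weights.
# Only the weight matrix is penalized; the bias is left unregularized.
function d2lai.loss(m::WeightDecayScratch, y_pred, y)
    mse_loss = Flux.Losses.mse(y_pred, y)
    reg_loss = m.args.lambda*l2_penalty(m.net.weight)
    return mse_loss + reg_loss
end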
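Neither term in this loss carries the factor $\frac{1}{2}$ from the text: `Flux.Losses.mse` averages squared errors without it, and `l2_penalty` sums squares without it. Because both factors are missing, the coded loss is exactly twice $L(\mathbf{w}, b) + \frac{\lambda}{2}\|\mathbf{w}\|^2$, so it has the same minimizer; for SGD, the doubling acts like a doubled learning rate. The standalone sketch below checks this with hypothetical plain-Julia stand-ins for the Flux pieces.

```julia
using Statistics

# Hypothetical plain-Julia stand-ins for the Flux pieces: a linear model ŷ = w * X .+ b
# with w of size 1 × d (as in Dense(d => 1)) and examples as the columns of X.
mse_plain(ŷ, y) = mean(abs2, ŷ .- y)    # mean of squared errors, no factor 1/2

# The loss as coded above: MSE plus λ times the sum of squared weights.
code_loss(w, b, X, y, λ) = mse_plain(w * X .+ b, y) + λ * sum(abs2, w)

# The objective from the text: L(w, b) + (λ/2)‖w‖², with L carrying the factor 1/2.
text_loss(w, b, X, y, λ) =
    sum(abs2, w * X .+ b .- y) / (2 * size(y, 2)) + λ / 2 * sum(abs2, w)

X = randn(4, 6); y = randn(1, 6)         # 4 features, 6 examples
w = randn(1, 4); b = 0.3; λ = 0.7
println(code_loss(w, b, X, y, λ) / text_loss(w, b, X, y, λ))   # ratio of the two losses
```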

The following code fits our model on the training set with 20 examples and evaluates it on the validation set with 100 examples.

julia
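function train_scratch(lambda)
    model = WeightDecayScratch(Dense(200 => 1), lambda)  # linear model with 200 inputs
    opt = Descent(0.01)                                   # plain SGD with learning rate 0.01
    data = PolynomialData(20, 100, 200, 5)                # 20 train, 100 val, d = 200, batch size 5
    trainer = Trainer(model, data, opt; max_epochs = 10)
    d2lai.fit(trainer)
end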
train_scratch (generic function with 1 method)

Training without Regularization

We now run this code with `lambda = 0`, disabling weight decay. Note that we overfit badly, decreasing the training error but not the validation error, a textbook case of overfitting.

julia
train_scratch(0.)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(200 => 1)     # 201 parameters
│   summary(x) = "200×5 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/3711C/src/layers/stateless.jl:60
    [ Info: Train Loss: 0.17010443647907264, Val Loss: 0.3104697962097079
    [ Info: Train Loss: 0.00812942767681148, Val Loss: 0.3761586611835473
    [ Info: Train Loss: 0.00300404867619369, Val Loss: 0.36316384810156355
    [ Info: Train Loss: 9.086800392004014e-5, Val Loss: 0.36251333818878384
    [ Info: Train Loss: 3.768173429835639e-5, Val Loss: 0.3636249171162744
    [ Info: Train Loss: 9.808228607727102e-6, Val Loss: 0.36464670373184327
    [ Info: Train Loss: 6.473687428255859e-6, Val Loss: 0.3645929430537423
    [ Info: Train Loss: 1.6601482946387506e-7, Val Loss: 0.3646625322339013
    [ Info: Train Loss: 7.307524895225433e-7, Val Loss: 0.36467572724532893
    [ Info: Train Loss: 5.9002328313041295e-8, Val Loss: 0.3647579105026305
(WeightDecayScratch{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, @NamedTuple{lambda::Float64}}(Dense(200 => 1), (lambda = 0.0,)), (val_loss = [0.6315150462155444, 1.7667208566284978, 1.0124669697653603, 0.34469795893243227, 1.0563052358000395, 2.566221664070826, 0.4486341039828149, 1.606172932980381, 2.1008154580934293, 2.67316597039848, 1.156265822363412, 1.0914530668393172, 4.342967840175924, 2.177160883786672, 1.0240883719609932, 0.30057182957817985, 6.080012696794702, 3.1974150170736544, 1.5931458666944185, 0.3647579105026305], val_acc = nothing))

Using Weight Decay

Below, we run with substantial weight decay. Note that the training error increases but the validation error decreases. This is precisely the effect we expect from regularization.

julia
train_scratch(3.)
    [ Info: Train Loss: 3.2797984202712356, Val Loss: 4.062983477700901
    [ Info: Train Loss: 1.9804117886097512, Val Loss: 2.4610657281266337
    [ Info: Train Loss: 1.2077013343062588, Val Loss: 1.500083301680783
    [ Info: Train Loss: 0.7378329506568404, Val Loss: 0.9140557421194878
    [ Info: Train Loss: 0.4516074772083898, Val Loss: 0.5599806966984647
    [ Info: Train Loss: 0.2769862280470018, Val Loss: 0.3442206324581537
    [ Info: Train Loss: 0.17134064441566937, Val Loss: 0.2129407598175414
    [ Info: Train Loss: 0.10625916231477397, Val Loss: 0.13353039201811642
    [ Info: Train Loss: 0.06703511114350612, Val Loss: 0.08539619997682976
    [ Info: Train Loss: 0.0427132283461683, Val Loss: 0.05658541332204206
(WeightDecayScratch{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, @NamedTuple{lambda::Float64}}(Dense(200 => 1), (lambda = 3.0,)), (val_loss = [0.0720421476114861, 0.10422826972139035, 0.05659159672156999, 0.06927200854633188, 0.08334887807514543, 0.08471663080022433, 0.056548645586230295, 0.050229097932397465, 0.0728615335716399, 0.0886988647475379, 0.06967753364200492, 0.08053099017313697, 0.1301141913732636, 0.08189592795502351, 0.047899721839636644, 0.04776282464020387, 0.07244801286783671, 0.06684342421053806, 0.05908849218872609, 0.05658541332204206], val_acc = nothing))

Summary

Regularization is a common method for dealing with overfitting. Classical regularization techniques add a penalty term to the loss function (when training) to reduce the complexity of the learned model. One particular choice for keeping the model simple is using an $\ell_2$ penalty. This leads to weight decay in the update steps of the minibatch stochastic gradient descent algorithm. In practice, the weight decay functionality is provided in optimizers from deep learning frameworks. Different sets of parameters can have different update behaviors within the same training loop.
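The last point can be illustrated without any framework. In the hypothetical sketch below (plain Julia; `sgd_step!` and the dictionary layout are illustrative, not a Flux or d2lai API), each named parameter group carries its own decay coefficient, and the step applies $p \leftarrow (1 - \eta\lambda_p)p - \eta \mathbf{g}_p$. With zero gradients, ten steps shrink the weights by a factor of $0.97^{10}$ while the bias, which has no decay entry, is left unchanged.

```julia
# Hypothetical per-group SGD step: p ← (1 − η λ_p) p − η g_p, where λ_p is the
# decay coefficient of parameter group p (0 if the group has no entry).
function sgd_step!(params::Dict{Symbol, Vector{Float64}},
                   grads::Dict{Symbol, Vector{Float64}},
                   decay::Dict{Symbol, Float64}, η::Float64)
    for (name, p) in params
        λp = get(decay, name, 0.0)
        p .= (1 - η * λp) .* p .- η .* grads[name]   # update the group in place
    end
    return params
end

params = Dict(:weight => [1.0, -2.0, 0.5], :bias => [0.8])
zero_grads = Dict(:weight => zeros(3), :bias => zeros(1))   # isolate the decay effect
decay = Dict(:weight => 3.0)                               # decay the weights, not the bias
for _ in 1:10
    sgd_step!(params, zero_grads, decay, 0.01)
end
println(params[:weight], "  ", params[:bias])
```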

Exercises

  1. Experiment with the value of $\lambda$ in the estimation problem in this section. Plot training and validation loss as a function of $\lambda$. What do you observe?

  2. Use a validation set to find the optimal value of λ. Is it really the optimal value? Does this matter?

  3. What would the update equations look like if instead of $\|\mathbf{w}\|^2$ we used $\sum_i |w_i|$ as our penalty of choice ($\ell_1$ regularization)?

  4. We know that $\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}$. Can you find a similar equation for matrices (see the Frobenius norm in :numref:subsec_lin-algebra-norms)?

  5. Review the relationship between training error and generalization error. In addition to weight decay, increased training, and the use of a model of suitable complexity, what other ways might help us deal with overfitting?

  6. In Bayesian statistics we use the product of prior and likelihood to arrive at a posterior via $P(w \mid x) \propto P(x \mid w) P(w)$. How can you identify $P(w)$ with regularization?
