Weight Decay
Now that we have characterized the problem of overfitting, we can introduce our first regularization technique. Recall that we can always mitigate overfitting by collecting more training data. However, that can be costly, time consuming, or entirely out of our control, making it impossible in the short run. For now, we can assume that we already have as much high-quality data as our resources permit and focus on the tools at our disposal when the dataset is taken as a given.
Recall that in our polynomial regression example (:numref:subsec_polynomial-curve-fitting) we could limit our model's capacity by tweaking the degree of the fitted polynomial. Indeed, limiting the number of features is a popular technique for mitigating overfitting. However, simply tossing aside features can be too blunt an instrument. Sticking with the polynomial regression example, consider what might happen with high-dimensional input. The natural extensions of polynomials to multivariate data are called monomials, which are simply products of powers of variables. The degree of a monomial is the sum of the powers. For example, $x_1^2 x_2$ and $x_3 x_5^2$ are both monomials of degree 3.
Note that the number of terms with degree $d$ blows up rapidly as $d$ grows larger. Given $k$ variables, the number of monomials of degree $d$ is $\binom{k - 1 + d}{k - 1}$. Even small changes in degree, say from $2$ to $3$, dramatically increase the complexity of our model. Thus we often need a more fine-grained tool for adjusting function complexity.
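To get a feel for this growth, here is a small Julia sketch that evaluates the closed-form count and cross-checks it by brute-force enumeration of exponent vectors. The helper names `num_monomials` and `count_by_enumeration` are hypothetical, not part of `d2lai`, and the enumeration is only practical for small $k$ and $d$.
```julia
# Closed form: number of monomials of exact degree d in k variables.
num_monomials(k, d) = binomial(k - 1 + d, k - 1)

# Brute-force check (hypothetical helper): count exponent vectors
# (a_1, ..., a_k) with every a_i >= 0 and a_1 + ... + a_k == d.
function count_by_enumeration(k, d)
    total = 0
    for exps in Iterators.product(ntuple(_ -> 0:d, k)...)
        total += sum(exps) == d   # Bool adds as 0 or 1
    end
    return total
end

# With k = 10 variables, going from degree 2 to degree 3 quadruples the count.
for d in 1:4
    println("k = 10, d = $d: ", num_monomials(10, d), " monomials")
end
# k = 10, d = 1: 10 monomials
# k = 10, d = 2: 55 monomials
# k = 10, d = 3: 220 monomials
# k = 10, d = 4: 715 monomials
```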
using Pkg;
Pkg.activate("../../d2lai")
using d2lai, Flux
Norms and Weight Decay
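Rather than directly manipulating the number of parameters, weight decay operates by restricting the values that the parameters can take. More commonly called $\ell_2$ regularization outside of deep learning circles when optimized by minibatch stochastic gradient descent, weight decay might be the most widely used technique for regularizing parametric machine learning models. The technique is motivated by the basic intuition that among all functions $f$, the function $f = 0$ (assigning the value $0$ to all inputs) is in some sense the simplest, and that we can measure the complexity of a function by the distance of its parameters from zero. But how precisely should we measure the distance between a function and zero? There is no single right answer.
One simple interpretation might be to measure the complexity of a linear function $f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$ by some norm of its weight vector, e.g., $\|\mathbf{w}\|^2$. Recall that we introduced the $\ell_2$ norm and $\ell_1$ norm, which are special cases of the more general $\ell_p$ norm, in :numref:subsec_lin-algebra-norms. The most common method for ensuring a small weight vector is to add its norm as a penalty term to the problem of minimizing the loss. Thus we replace our original objective, minimizing the prediction loss on the training labels, with a new objective, minimizing the sum of the prediction loss and the penalty term. Now, if our weight vector grows too large, our learning algorithm might focus on minimizing the weight norm $\|\mathbf{w}\|^2$ rather than minimizing the training error. That is exactly what we want. To illustrate things in code, we revive our previous example from :numref:sec_linear_regression for linear regression. There, our loss was given by
$$L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.$$
Recall that $\mathbf{x}^{(i)}$ are the features, $y^{(i)}$ is the label for data example $i$, and $(\mathbf{w}, b)$ are the weight and bias parameters, respectively. To penalize the size of the weight vector, we must somehow add $\|\mathbf{w}\|^2$ to the loss function, but how should the model trade off the standard loss for this new additive penalty? In practice, we characterize this trade-off via the regularization constant $\lambda$, a nonnegative hyperparameter that we fit using validation data:
$$L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2.$$
For $\lambda = 0$, we recover our original loss function. For $\lambda > 0$, we restrict the size of $\|\mathbf{w}\|$. We divide by $2$ by convention: when we take the derivative of a quadratic function, the $2$ and $1/2$ cancel out, ensuring that the expression for the update looks nice and simple. The astute reader might wonder why we work with the squared norm and not the standard norm (i.e., the Euclidean distance). We do this for computational convenience. By squaring the $\ell_2$ norm, we remove the square root, leaving the sum of squares of each component of the weight vector. This makes the derivative of the penalty easy to compute: the sum of derivatives equals the derivative of the sum.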
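As a sanity check on the $\lambda/2$ convention, the sketch below evaluates the penalized objective and compares its analytic gradient with respect to $\mathbf{w}$, in which the penalty contributes exactly $\lambda \mathbf{w}$, against central finite differences. The helper names `sq_loss`, `penalized_loss`, and `grad_w` are hypothetical, and features are assumed to be stored one example per column, as in this section's data code.
```julia
using LinearAlgebra, Random

# L(w, b) = (1/n) Σ ½ (wᵀx⁽ⁱ⁾ + b − y⁽ⁱ⁾)², with X of size d × n (one example per column).
function sq_loss(w, b, X, y)
    r = X' * w .+ b .- y              # residuals, length n
    return sum(abs2, r) / (2 * length(y))
end

# Penalized objective L(w, b) + (λ/2)‖w‖²; the bias b is not penalized.
penalized_loss(w, b, X, y, λ) = sq_loss(w, b, X, y) + λ / 2 * sum(abs2, w)

# Analytic gradient with respect to w: the 1/2 in λ/2 cancels the 2 from
# differentiating ‖w‖², so the penalty adds exactly λ * w.
function grad_w(w, b, X, y, λ)
    r = X' * w .+ b .- y
    return X * r ./ length(y) .+ λ .* w
end

Random.seed!(0)
d, n = 5, 8
X, y, w, b, λ = randn(d, n), randn(n), randn(d), 0.3, 0.1

# Central finite differences, perturbing one coordinate of w at a time.
h = 1e-6
fd = [(penalized_loss(w .+ h .* (1:d .== j), b, X, y, λ) -
       penalized_loss(w .- h .* (1:d .== j), b, X, y, λ)) / (2h) for j in 1:d]

println(maximum(abs, fd .- grad_w(w, b, X, y, λ)))  # tiny: only rounding error remains
```
Because the objective is quadratic in $\mathbf{w}$, central differences have no truncation error here, so any mismatch beyond rounding would point to a wrong gradient.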
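Moreover, you might ask why we work with the $\ell_2$ norm in the first place and not, say, the $\ell_1$ norm. In fact, other choices are valid and popular throughout statistics. While $\ell_2$-regularized linear models constitute the classic ridge regression algorithm, $\ell_1$-regularized linear regression is a similarly fundamental method in statistics, popularly known as lasso regression. One reason to work with the $\ell_2$ norm is that it places an outsize penalty on large components of the weight vector. This biases our learning algorithm towards models that distribute weight evenly across a larger number of features, which might make them more robust to measurement error in a single variable. By contrast, $\ell_1$ penalties lead to models that concentrate weights on a small set of features by clearing the other weights to zero. This gives us an effective method for feature selection, which may be desirable for other reasons: if our model relies on only a few features, then we may not need to collect, store, or transmit data for the other (dropped) features.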
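The difference between the two penalties can be seen numerically. In this illustrative sketch, two weight vectors have the same $\ell_1$ norm, but the squared $\ell_2$ norm charges the concentrated one four times as much. The gradients show why: the $\ell_2$ term pulls each weight in proportion to its size, while the $\ell_1$ (sub)gradient pushes every nonzero weight with the same force. The function names are hypothetical.
```julia
# Squared ℓ2 norm (weight decay) and ℓ1 norm (lasso) of a weight vector.
l2_squared(w) = sum(abs2, w)
l1_norm(w) = sum(abs, w)

concentrated = [1.0, 0.0, 0.0, 0.0]    # all weight on one feature
spread = [0.25, 0.25, 0.25, 0.25]      # same total weight, spread evenly

println("ℓ1:  ", l1_norm(concentrated), " vs ", l1_norm(spread))        # ℓ1:  1.0 vs 1.0
println("ℓ2²: ", l2_squared(concentrated), " vs ", l2_squared(spread))  # ℓ2²: 1.0 vs 0.25

# Gradients of (λ/2)‖w‖² and λ‖w‖₁ (sign gives a valid subgradient at w_i = 0).
grad_l2(w, λ) = λ .* w         # proportional to each weight
grad_l1(w, λ) = λ .* sign.(w)  # same magnitude for every nonzero weight

println(grad_l2([2.0, 0.5], 0.5), " ", grad_l1([2.0, 0.5], 0.5))  # [1.0, 0.25] [0.5, 0.5]
```
The constant-magnitude push of the $\ell_1$ term is what can drive small weights all the way to zero, whereas the $\ell_2$ pull fades as a weight shrinks.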
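Using the same notation as in :eqref:eq_linreg_batch_update, the minibatch stochastic gradient descent update for $\ell_2$-regularized regression is:
$$\mathbf{w} \leftarrow \left(1 - \eta\lambda\right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right).$$
As before, we update $\mathbf{w}$ based on the amount by which our estimate differs from the observation. However, we also shrink the size of $\mathbf{w}$ towards zero. That is why the method is sometimes called weight decay: given the penalty term alone, our optimization algorithm decays the weight at each step of training. In contrast to feature selection, weight decay offers us a mechanism for continuously adjusting the complexity of a function. Smaller values of $\lambda$ correspond to less constrained $\mathbf{w}$, whereas larger values of $\lambda$ constrain $\mathbf{w}$ more considerably. Whether we include a corresponding bias penalty $b^2$ can vary across implementations, and may vary across layers of a neural network. Often, we do not regularize the bias term. Besides, although $\ell_2$ regularization may not be equivalent to weight decay for other optimization algorithms, the idea of regularization through shrinking the size of weights still holds true.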
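The update rule can be written directly as a function. The sketch below uses a hypothetical helper `sgd_step_weight_decay`, with the minibatch stored one example per column. It first shrinks $\mathbf{w}$ by the factor $1 - \eta\lambda$ and then takes the ordinary gradient step; the bias is updated but, as is common, not decayed. The usage lines check that this matches a plain gradient step on the penalized objective, and that with a zero data gradient the step reduces to pure decay.
```julia
using Random

# One minibatch SGD step for ℓ2-regularized linear regression.
# X is d × |B| (one example per column) and y has length |B|.
function sgd_step_weight_decay(w, b, X, y, η, λ)
    r = X' * w .+ b .- y                              # residuals wᵀx⁽ⁱ⁾ + b − y⁽ⁱ⁾
    nB = length(y)
    w_new = (1 - η * λ) .* w .- (η / nB) .* (X * r)   # decay, then gradient step
    b_new = b - (η / nB) * sum(r)                     # bias is not decayed
    return w_new, b_new
end

Random.seed!(1)
X, y, w, b = randn(3, 4), randn(4), randn(3), 0.0
η, λ = 0.1, 3.0

# The same step written as "w minus η times (data gradient + λw)".
w1, b1 = sgd_step_weight_decay(w, b, X, y, η, λ)
r = X' * w .+ b .- y
w2 = w .- η .* (X * r ./ length(y) .+ λ .* w)
println(maximum(abs, w1 .- w2))   # agrees up to floating-point rounding

# With all-zero features the data gradient vanishes and only the decay remains.
w3, _ = sgd_step_weight_decay(w, b, zeros(3, 4), y, η, λ)
println(w3 ≈ (1 - η * λ) .* w)    # true
```
The factor $1 - \eta\lambda$ lies strictly between $0$ and $1$ whenever $0 < \eta\lambda < 1$, which is what makes each step shrink the weights rather than flip their sign.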
High-Dimensional Linear Regression
We can illustrate the benefits of weight decay through a simple synthetic example.
First, we generate some data as before:
$$y = \sum_{i = 1}^d 0.01 x_i + \epsilon \textrm{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$$
In this synthetic dataset, our label is given by an underlying linear function of our inputs, corrupted by Gaussian noise with zero mean and standard deviation 0.01. For illustrative purposes, we can make the effects of overfitting pronounced by increasing the dimensionality of our problem to $d = 200$ and working with a small training set of only 20 examples.
struct PolynomialData{XT, YT, A} <: d2lai.AbstractData
    X::XT     # features, one example per column (num_inputs × n)
    y::YT     # labels as a 1 × n row
    args::A
    function PolynomialData(num_train, num_val, num_inputs, batch_size)
        args = (num_train = num_train, num_val = num_val, num_inputs = num_inputs, batchsize = batch_size)
        n = num_train + num_val
        # randn produces Float64 data; Flux converts it to the layer's Float32 with a warning
        X = randn(num_inputs, n)
        b = zeros(1)    # bias of the underlying linear function (zero here)
        # label: 0.01 * (x_1 + ... + x_d) + b + Gaussian noise with standard deviation 0.01
        y = 0.01*ones(1, num_inputs)*X .+ b .+ 0.01*randn(1, n)
        new{typeof(X), typeof(y), typeof(args)}(X, y, args)
    end
end
# The first num_train columns form the training set; the remaining columns form the validation set.
function d2lai.get_dataloader(data::PolynomialData; train = true)
    if train
        return Flux.DataLoader((data.X[:, 1:data.args.num_train], data.y[:, 1:data.args.num_train]), batchsize = data.args.batchsize, shuffle=true)
    else
        return Flux.DataLoader((data.X[:, data.args.num_train + 1 : end], data.y[:, data.args.num_train + 1 : end]), batchsize = data.args.batchsize)
    end
end
Implementation from Scratch
Now, let's try implementing weight decay from scratch. Since minibatch stochastic gradient descent is our optimizer, we just need to add the squared $\ell_2$ penalty to the original loss function.
Defining $\ell_2$ Norm Penalty
Perhaps the most convenient way of implementing this penalty is to square all terms in place and sum them.
# Sum of the squared entries of w, i.e. ‖w‖² (no factor 1/2 here).
function l2_penalty(w)
    sum(abs2, w)
end
Defining the Model
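struct WeightDecayScratch{N, A} <: AbstractModel
    net::N    # the underlying network, here a single Dense layer
    args::A   # hyperparameters as a NamedTuple, e.g. (lambda = 0.01,)
end
# Convenience constructor: wrap the regularization constant in a NamedTuple.
function WeightDecayScratch(net, lambda::Real = 0.01)
    args = (lambda = lambda, )
    return WeightDecayScratch(net, args)
end
d2lai.forward(m::WeightDecayScratch, x) = m.net(x)
# Penalized loss: mean squared error plus lambda times the squared ℓ2 norm of the weights.
# Only the weight matrix is penalized; the bias is left unregularized.
# Flux's mse has no factor 1/2 and the penalty has no factor 1/2 either, so this is exactly
# twice (1/n) Σ ½(ŷ − y)² + (λ/2)‖w‖²: both objectives have the same minimizer.
function d2lai.loss(m::WeightDecayScratch, y_pred, y)
    mse_loss = Flux.Losses.mse(y_pred, y)
    reg_loss = m.args.lambda * l2_penalty(m.net.weight)
    return mse_loss + reg_loss
end
The following code fits our model on the training set with 20 examples and evaluates it on the validation set with 100 examples.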
function train_scratch(lambda)
    model = WeightDecayScratch(Dense(200 => 1), lambda)   # 200 inputs, 1 output
    opt = Descent(0.01)                                   # plain SGD with learning rate 0.01
    data = PolynomialData(20, 100, 200, 5)                # 20 train, 100 val, 200 features, batch size 5
    trainer = Trainer(model, data, opt; max_epochs = 10)
    d2lai.fit(trainer)
end
Training without Regularization
We now run this code with lambda = 0, disabling weight decay. Note that we overfit badly, decreasing the training error but not the validation error: a textbook case of overfitting.
train_scratch(0.)
┌ Warning: Layer with Float32 parameters got Float64 input.
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(200 => 1) # 201 parameters
│ summary(x) = "200×5 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/3711C/src/layers/stateless.jl:60
[ Info: Train Loss: 0.17010443647907264, Val Loss: 0.3104697962097079
[ Info: Train Loss: 0.00812942767681148, Val Loss: 0.3761586611835473
[ Info: Train Loss: 0.00300404867619369, Val Loss: 0.36316384810156355
[ Info: Train Loss: 9.086800392004014e-5, Val Loss: 0.36251333818878384
[ Info: Train Loss: 3.768173429835639e-5, Val Loss: 0.3636249171162744
[ Info: Train Loss: 9.808228607727102e-6, Val Loss: 0.36464670373184327
[ Info: Train Loss: 6.473687428255859e-6, Val Loss: 0.3645929430537423
[ Info: Train Loss: 1.6601482946387506e-7, Val Loss: 0.3646625322339013
[ Info: Train Loss: 7.307524895225433e-7, Val Loss: 0.36467572724532893
[ Info: Train Loss: 5.9002328313041295e-8, Val Loss: 0.3647579105026305
(WeightDecayScratch{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, @NamedTuple{lambda::Float64}}(Dense(200 => 1), (lambda = 0.0,)), (val_loss = [0.6315150462155444, 1.7667208566284978, 1.0124669697653603, 0.34469795893243227, 1.0563052358000395, 2.566221664070826, 0.4486341039828149, 1.606172932980381, 2.1008154580934293, 2.67316597039848, 1.156265822363412, 1.0914530668393172, 4.342967840175924, 2.177160883786672, 1.0240883719609932, 0.30057182957817985, 6.080012696794702, 3.1974150170736544, 1.5931458666944185, 0.3647579105026305], val_acc = nothing))
Using Weight Decay
Below, we run with substantial weight decay. Note that the training error increases but the validation error decreases. This is precisely the effect we expect from regularization.
train_scratch(3.)
[ Info: Train Loss: 3.2797984202712356, Val Loss: 4.062983477700901
[ Info: Train Loss: 1.9804117886097512, Val Loss: 2.4610657281266337
[ Info: Train Loss: 1.2077013343062588, Val Loss: 1.500083301680783
[ Info: Train Loss: 0.7378329506568404, Val Loss: 0.9140557421194878
[ Info: Train Loss: 0.4516074772083898, Val Loss: 0.5599806966984647
[ Info: Train Loss: 0.2769862280470018, Val Loss: 0.3442206324581537
[ Info: Train Loss: 0.17134064441566937, Val Loss: 0.2129407598175414
[ Info: Train Loss: 0.10625916231477397, Val Loss: 0.13353039201811642
[ Info: Train Loss: 0.06703511114350612, Val Loss: 0.08539619997682976
[ Info: Train Loss: 0.0427132283461683, Val Loss: 0.05658541332204206
(WeightDecayScratch{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, @NamedTuple{lambda::Float64}}(Dense(200 => 1), (lambda = 3.0,)), (val_loss = [0.0720421476114861, 0.10422826972139035, 0.05659159672156999, 0.06927200854633188, 0.08334887807514543, 0.08471663080022433, 0.056548645586230295, 0.050229097932397465, 0.0728615335716399, 0.0886988647475379, 0.06967753364200492, 0.08053099017313697, 0.1301141913732636, 0.08189592795502351, 0.047899721839636644, 0.04776282464020387, 0.07244801286783671, 0.06684342421053806, 0.05908849218872609, 0.05658541332204206], val_acc = nothing))
Summary
Regularization is a common method for dealing with overfitting. Classical regularization techniques add a penalty term to the loss function (when training) to reduce the complexity of the learned model. One particular choice for keeping the model simple is using an $\ell_2$ penalty. This leads to weight decay in the update steps of the minibatch stochastic gradient descent algorithm.
Exercises
1. Experiment with the value of $\lambda$ in the estimation problem in this section. Plot training and validation loss as a function of $\lambda$. What do you observe?
2. Use a validation set to find the optimal value of $\lambda$. Is it really the optimal value? Does this matter?
3. What would the update equations look like if instead of $\|\mathbf{w}\|^2$ we used $\sum_i |w_i|$ as our penalty of choice ($\ell_1$ regularization)?
4. We know that $\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}$. Can you find a similar equation for matrices (see the Frobenius norm in :numref:subsec_lin-algebra-norms)?
5. Review the relationship between training error and generalization error. In addition to weight decay, increased training, and the use of a model of suitable complexity, what other ways might help us deal with overfitting?
6. In Bayesian statistics we use the product of prior and likelihood to arrive at a posterior via $P(w \mid x) \propto P(x \mid w) P(w)$. How can you identify $P(w)$ with regularization?