Residual Networks (ResNet) and ResNeXt
As we design ever deeper networks it becomes imperative to understand how adding layers can increase the complexity and expressiveness of the network. Even more important is the ability to design networks where adding layers makes networks strictly more expressive rather than just different. To make some progress we need a bit of mathematics.
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using CUDA, cuDNN
using Statistics, Flux.Zygote
Activating project at `/workspace/d2l-julia/d2lai`
Function Classes
Consider $\mathcal{F}$, the class of functions that a specific network architecture (together with learning rates and other hyperparameter settings) can reach. That is, for all $f \in \mathcal{F}$ there exists some set of parameters (e.g., weights and biases) that can be obtained through training on a suitable dataset. Let's assume that $f^*$ is the "truth" function that we really would like to find. If it is in $\mathcal{F}$, we are in good shape, but typically we will not be quite so lucky. Instead, we will try to find some $f^*_\mathcal{F}$ which is our best bet within $\mathcal{F}$. For instance, given a dataset with features $\mathbf{X}$ and labels $\mathbf{y}$, we might try finding it by solving the optimization problem $f^*_\mathcal{F} \stackrel{\textrm{def}}{=} \mathop{\mathrm{argmin}}_f L(\mathbf{X}, \mathbf{y}, f)$ subject to $f \in \mathcal{F}$.
We know that regularization [138], [139] may control the complexity of $\mathcal{F}$ and achieve consistency, so a larger amount of training data generally leads to a better $f^*_\mathcal{F}$. It would be only reasonable to assume that if we design a different and more powerful architecture $\mathcal{F}'$ we should arrive at a better outcome; in other words, we would expect $f^*_{\mathcal{F}'}$ to be "better" than $f^*_{\mathcal{F}}$. However, if $\mathcal{F} \not\subseteq \mathcal{F}'$ there is no guarantee that this even happens; in fact, $f^*_{\mathcal{F}'}$ might well be worse. As illustrated by the figure below, for non-nested function classes a larger function class does not always move us closer to the "truth" function $f^*$.
For non-nested function classes, a larger function class (indicated by area) does not guarantee we will get closer to the "truth" function ($f^*$). This does not happen in nested function classes.
Thus, only if larger function classes contain the smaller ones are we guaranteed that increasing them strictly increases the expressive power of the network. For deep neural networks, if we can train the newly-added layer into an identity function $f(\mathbf{x}) = \mathbf{x}$, the new model will be at least as effective as the original model. As the new model may get a better solution to fit the training dataset, the added layer might make it easier to reduce training errors.
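To make this concrete, here is a minimal sketch (using Flux with a hypothetical Dense layer; it is not part of the ResNet code below) of why a skip connection makes the identity trivially representable: if the wrapped layer computes the zero function, the whole block computes the identity.
layer = Dense(4 => 4)                 # hypothetical added layer
block = Parallel(+, identity, layer)  # block(x) = x + layer(x)
layer.weight .= 0                     # drive the added layer towards the zero function
layer.bias .= 0
x = rand(Float32, 4, 2)
block(x) ≈ x                          # true: the enlarged class still contains the identity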
This is the question that He et al. [93] considered when working on very deep computer vision models. At the heart of their proposed residual network (ResNet) is the idea that every additional layer should more easily contain the identity function as one of its elements. These considerations are rather profound but they led to a surprisingly simple solution, a residual block. With it, ResNet won the ImageNet Large Scale Visual Recognition Challenge in 2015. The design had a profound influence on how to build deep neural networks. For instance, residual blocks have been added to recurrent networks [140], [141]. Likewise, Transformers [142] use them to stack many layers of networks efficiently. Residual connections are also used in graph neural networks [143] and, as a basic concept, they have been used extensively in computer vision [86], [144]. Note that residual networks are predated by highway networks [145] that share some of the motivation, albeit without the elegant parametrization around the identity function.
Residual Blocks
Let's focus on a local part of a neural network, as depicted in the figure below. Denote the input by $\mathbf{x}$. We assume that $f(\mathbf{x})$, the desired underlying mapping we want to obtain by learning, is to be used as input to the activation function on the top. On the left, the portion within the dotted-line box must directly learn $f(\mathbf{x})$. On the right, the portion within the dotted-line box needs to learn the residual mapping $g(\mathbf{x}) = f(\mathbf{x}) - \mathbf{x}$, which is how the residual block derives its name. If the identity mapping $f(\mathbf{x}) = \mathbf{x}$ is the desired underlying mapping, the residual mapping amounts to $g(\mathbf{x}) = 0$ and is thus easier to learn: we only need to push the weights and biases of the layers within the dotted-line box to zero. The solid line carrying the layer input $\mathbf{x}$ to the addition operator is called a residual connection (or shortcut connection).
In a regular block (left), the portion within the dotted-line box must directly learn the mapping $f(\mathbf{x})$. In a residual block (right), the portion within the dotted-line box needs to learn the residual mapping $g(\mathbf{x}) = f(\mathbf{x}) - \mathbf{x}$, making the identity mapping $f(\mathbf{x}) = \mathbf{x}$ easier to learn.
ResNet has VGG's full $3\times 3$ convolutional layer design. The residual block has two $3\times 3$ convolutional layers with the same number of output channels. Each convolutional layer is followed by a batch normalization layer and a ReLU activation function. Then, we skip these two convolution operations and add the input directly before the final ReLU activation function. This kind of design requires that the output of the two convolutional layers has to be of the same shape as the input, so that they can be added together. If we want to change the number of channels, we need to introduce an additional $1\times 1$ convolutional layer to transform the input into the desired shape for the addition operation. Let's have a look at the code below.
struct Residual{N} <: AbstractModel
    net::N
end

function Residual(channels_in::Int; num_channels::Int = channels_in,
                  use_1x1conv = !isequal(channels_in, num_channels), stride = 1)
    conv_chain = Chain(
        Conv((3, 3), channels_in => num_channels, pad = 1, stride = stride),
        BatchNorm(num_channels, relu),
        Conv((3, 3), num_channels => num_channels, pad = 1),
        BatchNorm(num_channels),
    )
    # When the shape changes, a 1×1 convolution adjusts the skip path;
    # otherwise the input is passed through unchanged.
    net = if use_1x1conv
        Parallel(+, conv_chain, Conv((1, 1), channels_in => num_channels, stride = stride))
    else
        Parallel(+, conv_chain, Flux.identity)
    end
    Residual(net)
end

(r::Residual)(x) = relu.(r.net(x))

This code generates two types of networks: one where we add the input to the output before applying the ReLU nonlinearity whenever use_1x1conv = false, and one where we adjust channels and resolution by means of a $1\times 1$ convolution before adding. The figure below illustrates this.
ResNet block with and without a $1\times 1$ convolution, which transforms the input into the desired shape for the addition operation.
Now let's look at a situation where the input and output are of the same shape, where the $1\times 1$ convolution is not needed.
r = Residual(3)
r(rand(Float32, 4, 5, 3, 32)) |> size
(4, 5, 3, 32)
We also have the option to change the number of output channels; whenever num_channels differs from the number of input channels, a $1\times 1$ convolution on the shortcut (use_1x1conv = true, the default in this case) adjusts the input shape for the addition. Passing stride = 2 in addition halves the output height and width, which comes in handy at the beginning of each ResNet block to reduce the spatial dimensionality.
r = Residual(3; num_channels = 6)
r(rand(Float32, 4, 5, 3, 32)) |> size
(4, 5, 6, 32)
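Setting stride = 2 on top of the channel change additionally halves the height and width, as used at the start of each later ResNet stage. A quick sketch with the same (arbitrary) input shape as above:
r = Residual(3; num_channels = 6, stride = 2)
r(rand(Float32, 4, 5, 3, 32)) |> size  # expected (2, 3, 6, 32)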
ResNet Model
The first two layers of ResNet are the same as those of the GoogLeNet we described before: the $7\times 7$ convolutional layer with 64 output channels and a stride of 2 is followed by the $3\times 3$ max-pooling layer with a stride of 2. The difference is the batch normalization layer added after each convolutional layer in ResNet.
struct ResNetB1{N} <: AbstractModel
    net::N
end

function ResNetB1()
    net = Chain(
        Conv((7, 7), 1 => 64, pad = 3, stride = 2),
        BatchNorm(64, relu),
        MaxPool((3, 3), pad = 1, stride = 2),
    )
    ResNetB1(net)
end

(r::ResNetB1)(x) = r.net(x)
Flux.@layer ResNetB1

GoogLeNet uses four modules made up of Inception blocks. However, ResNet uses four modules made up of residual blocks, each of which uses several residual blocks with the same number of output channels. The number of channels in the first module is the same as the number of input channels. Since a max-pooling layer with a stride of 2 has already been used, it is not necessary to reduce the height and width. In the first residual block for each of the subsequent modules, the number of channels is doubled compared with that of the previous module, and the height and width are halved.
struct ResNetBlock{N} <: AbstractModel
    net::N
end

function ResNetBlock(channel_in, num_residuals, num_channels; first_block = false)
    block = map(1:num_residuals) do i
        if first_block
            # The first module keeps the input resolution and channel count.
            Residual(channel_in)
        elseif i == 1
            # First residual block of a later module: change the channel count and halve H and W.
            Residual(channel_in; num_channels, stride = 2)
        else
            Residual(num_channels)
        end
    end |> Chain
    ResNetBlock(block)
end

Flux.@layer ResNetBlock
(r::ResNetBlock)(x) = r.net(x)
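As a quick, hedged sanity check (the input shape below is arbitrary), a non-first ResNetBlock should halve the resolution while increasing the channel count:
blk = ResNetBlock(64, 2, 128)
blk(rand(Float32, 24, 24, 64, 1)) |> size  # expected (12, 12, 128, 1)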
Then, we add all the modules to ResNet. Here, two residual blocks are used for each module. Lastly, just like GoogLeNet, we add a global average pooling layer, followed by the fully connected layer output.

struct ResNet{N} <: AbstractClassifier
    net::N
end

Flux.@layer ResNet

function ResNet(arch::Tuple, num_classes::Int = 10)
    # The input channel count of each later stage equals the output channels of the previous one.
    channel_ins = last.(arch[1:end-1])
    # @autosize infers the input size of the final Dense layer (the `_` placeholder).
    net = Flux.@autosize (96, 96, 1, 1) Chain(
        ResNetB1(),
        ResNetBlock(64, arch[1]..., first_block = true),
        map(arch[2:end], channel_ins) do (num_residuals, num_channels), channel_in
            ResNetBlock(channel_in, num_residuals, num_channels)
        end |> Chain,
        GlobalMeanPool(),
        Flux.flatten,
        Dense(_ => num_classes),
        softmax
    )
    ResNet(net)
end

(r::ResNet)(x) = r.net(x)

There are four convolutional layers in each module (excluding the $1\times 1$ convolutional layer). Together with the first $7\times 7$ convolutional layer and the final fully connected layer, there are 18 layers in total. Therefore, this model is commonly known as ResNet-18. By configuring different numbers of channels and residual blocks in the modules, we can create different ResNet models, such as the deeper 152-layer ResNet-152. Although the main architecture of ResNet is similar to that of GoogLeNet, ResNet's structure is simpler and easier to modify. All these factors have resulted in the rapid and widespread use of ResNet. The figure below depicts the full ResNet-18.
The ResNet-18 architecture.
Before training ResNet, let's observe how the input shape changes across different modules in ResNet. As in all the previous architectures, the resolution decreases while the number of channels increases up until the point where a global average pooling layer aggregates all features; we verify this layer by layer right after constructing the model below.
arch = ((2, 64), (2, 128), (2, 256), (2, 512))
model = ResNet(arch)
ResNet(
Chain(
ResNetB1(
Chain(
Conv((7, 7), 1 => 64, pad=3, stride=2), # 3_200 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
),
),
ResNetBlock(
Chain(
[
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 64 => 64, pad=1), # 36_928 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1), # 36_928 parameters
BatchNorm(64), # 128 parameters, plus 128
),
identity,
),
),
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 64 => 64, pad=1), # 36_928 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 64, pad=1), # 36_928 parameters
BatchNorm(64), # 128 parameters, plus 128
),
identity,
),
),
],
),
),
Chain(
ResNetBlock(
Chain(
[
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 64 => 128, pad=1, stride=2), # 73_856 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1), # 147_584 parameters
BatchNorm(128), # 256 parameters, plus 256
),
Conv((1, 1), 64 => 128, stride=2), # 8_320 parameters
),
),
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 128 => 128, pad=1), # 147_584 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1), # 147_584 parameters
BatchNorm(128), # 256 parameters, plus 256
),
identity,
),
),
],
),
),
ResNetBlock(
Chain(
[
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 128 => 256, pad=1, stride=2), # 295_168 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1), # 590_080 parameters
BatchNorm(256), # 512 parameters, plus 512
),
Conv((1, 1), 128 => 256, stride=2), # 33_024 parameters
),
),
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 256 => 256, pad=1), # 590_080 parameters
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((3, 3), 256 => 256, pad=1), # 590_080 parameters
BatchNorm(256), # 512 parameters, plus 512
),
identity,
),
),
],
),
),
ResNetBlock(
Chain(
[
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 256 => 512, pad=1, stride=2), # 1_180_160 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1), # 2_359_808 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
Conv((1, 1), 256 => 512, stride=2), # 131_584 parameters
),
),
Residual(
Parallel(
+,
Chain(
Conv((3, 3), 512 => 512, pad=1), # 2_359_808 parameters
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((3, 3), 512 => 512, pad=1), # 2_359_808 parameters
BatchNorm(512), # 1_024 parameters, plus 1_024
),
identity,
),
),
],
),
),
),
GlobalMeanPool(),
Flux.flatten,
Dense(512 => 10), # 5_130 parameters
NNlib.softmax,
),
) # Total: 76 trainable arrays, 11_178_378 parameters,
# plus 34 non-trainable, 7_808 parameters, summarysize 42.680 MiB.
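As promised above, we can push a dummy batch through the top-level stages and watch the resolution shrink while the channel count grows (a sketch assuming the 96×96 grayscale input used for training below):
x = rand(Float32, 96, 96, 1, 1)   # dummy batch of one Fashion-MNIST-sized image
for layer in model.net.layers     # the stem, the residual stages, pooling, and the head
    x = layer(x)
    println(summary(x))           # e.g. "24×24×64×1 Array{Float32, 4}" right after the stem
end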
Training
We train ResNet on the Fashion-MNIST dataset, just like before. ResNet is quite a powerful and flexible architecture. The plot capturing training and validation loss illustrates a significant gap between the two curves, with the training loss being considerably lower. For a network of this flexibility, more training data would offer a distinct benefit in closing the gap and improving accuracy.
data = d2lai.FashionMNISTData(batchsize = 128, resize = (96,96));
opt = Descent(0.01)
trainer = Trainer(model, data, opt; max_epochs = 10, gpu = true, board_yscale = :identity)
d2lai.fit(trainer);
[ Info: Train Loss: 0.19277047, Val Loss: 0.19109984, Val Acc: 0.9375
[ Info: Train Loss: 0.12554948, Val Loss: 0.14957266, Val Acc: 0.9375
[ Info: Train Loss: 0.22092374, Val Loss: 0.14626466, Val Acc: 1.0
[ Info: Train Loss: 0.04200567, Val Loss: 0.088903934, Val Acc: 1.0
[ Info: Train Loss: 0.19900467, Val Loss: 0.1520779, Val Acc: 0.9375
[ Info: Train Loss: 0.032146264, Val Loss: 0.08688806, Val Acc: 1.0
[ Info: Train Loss: 0.06265808, Val Loss: 0.07228865, Val Acc: 1.0
[ Info: Train Loss: 0.008486914, Val Loss: 0.22081932, Val Acc: 0.9375
[ Info: Train Loss: 0.0028302989, Val Loss: 0.14812489, Val Acc: 0.9375
[ Info: Train Loss: 0.008302619, Val Loss: 0.14797957, Val Acc: 0.9375
ResNeXt
One of the challenges one encounters in the design of ResNet is the trade-off between nonlinearity and dimensionality within a given block. That is, we could add more nonlinearity by increasing the number of layers, or by increasing the width of the convolutions. An alternative strategy is to increase the number of channels that can carry information between blocks. Unfortunately, the latter comes with a quadratic penalty, since the computational cost of ingesting $c_\textrm{i}$ channels and emitting $c_\textrm{o}$ channels is proportional to $\mathcal{O}(c_\textrm{i} \cdot c_\textrm{o})$ (see our discussion in :numref:sec_channels).
We can take some inspiration from the Inception block of Figure, which has information flowing through the block in separate groups. Applying the idea of multiple independent groups to the ResNet block of Figure led to the design of ResNeXt [80]. Different from the smorgasbord of transformations in Inception, ResNeXt adopts the same transformation in all branches, thus minimizing the need for manual tuning of each branch.
The ResNeXt block. The use of grouped convolution with $g$ groups is $g$ times faster than a dense convolution. It is a bottleneck residual block when the number of intermediate channels $b$ is less than $c$.
Breaking up a convolution from $c_\textrm{i}$ to $c_\textrm{o}$ channels into one of $g$ groups of size $c_\textrm{i}/g$ generating $g$ outputs of size $c_\textrm{o}/g$ is called, quite fittingly, a grouped convolution. The computational cost (proportionally) is reduced from $\mathcal{O}(c_\textrm{i} \cdot c_\textrm{o})$ to $\mathcal{O}(g \cdot (c_\textrm{i}/g) \cdot (c_\textrm{o}/g)) = \mathcal{O}(c_\textrm{i} \cdot c_\textrm{o} / g)$, i.e., it is $g$ times faster. Even better, the number of parameters needed to generate the output is also reduced from a $c_\textrm{i} \times c_\textrm{o}$ matrix to $g$ smaller matrices of size $(c_\textrm{i}/g) \times (c_\textrm{o}/g)$, again a $g$ times reduction.
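To see the saving concretely, here is a small sketch comparing the parameter count of a dense and a grouped $3\times 3$ convolution (128 channels and 16 groups, matching the grouped convolution inside the ResNeXt block instantiated below):
dense   = Conv((3, 3), 128 => 128)               # 3*3*128*128 + 128 = 147_584 parameters
grouped = Conv((3, 3), 128 => 128, groups = 16)  # 3*3*(128÷16)*128 + 128 = 9_344 parameters
length(dense.weight) + length(dense.bias), length(grouped.weight) + length(grouped.bias)
# expected (147584, 9344), i.e., roughly a 16-fold reduction in the 3×3 weights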
The only challenge in this design is that no information is exchanged between the $g$ groups. The ResNeXt block above amends this in two ways: the grouped convolution with a $3\times 3$ kernel is sandwiched in between two $1\times 1$ convolutions, the second of which serves double duty in changing the number of channels back. The benefit is that we only pay the $\mathcal{O}(c \cdot b)$ cost for the $1\times 1$ kernels and can make do with an $\mathcal{O}(b^2 / g)$ cost for the $3\times 3$ kernels. Similar to the residual block implementation in :numref:subsec_residual-blks, the residual connection is replaced (thus generalized) by a $1\times 1$ convolution.
The right-hand figure in Figure provides a much more concise summary of the resulting network block. It will also play a major role in the design of generic modern CNNs in :numref:sec_cnn-design. Note that the idea of grouped convolutions dates back to the implementation of AlexNet [90]. When distributing the network across two GPUs with limited memory, the implementation treated each GPU as its own channel with no ill effects.
The following implementation of the ResNeXtBlock type takes as arguments groups ($g$) and bot_mul, the multiplier that determines the number of intermediate (bottleneck) channels bot_channels ($b$). Lastly, when we need to reduce the height and width of the representation, we add a stride of 2 by setting stride = 2, in which case the shortcut is replaced by a $1\times 1$ convolution with the same stride (use_1x1conv = true).
struct ResNeXtBlock{N} <: AbstractModel
    net::N
end

function ResNeXtBlock(channel_in::Int, groups::Int, bot_mul; num_channels = channel_in,
                      use_1x1conv = !isequal(channel_in, num_channels), stride = 1)
    bot_channels = Int(round(num_channels * bot_mul))
    bottleneck_net = Chain(
        Conv((1, 1), channel_in => bot_channels),
        BatchNorm(bot_channels, relu),
        Conv((3, 3), bot_channels => bot_channels, pad = 1, stride = stride, groups = groups),
        BatchNorm(bot_channels, relu),
        Conv((1, 1), bot_channels => num_channels, stride = 1),
        BatchNorm(num_channels, relu),
    )
    net = if !use_1x1conv
        Parallel(+, Flux.identity, bottleneck_net)
    else
        sidenet = Chain(
            Conv((1, 1), channel_in => num_channels, stride = stride),
            BatchNorm(num_channels, relu),
        )
        Parallel(+, sidenet, bottleneck_net)
    end
    ResNeXtBlock(net)
end

Flux.@layer ResNeXtBlock
(r::ResNeXtBlock)(x) = r.net(x)

Its use is entirely analogous to that of the Residual block discussed previously. For instance, with the defaults (use_1x1conv = false, stride = 1) the input and output are of the same shape. Alternatively, setting num_channels different from the input channel count switches the shortcut to a $1\times 1$ convolution, and stride = 2 would additionally halve the output height and width.
blk = ResNeXtBlock(32, 16, 1)
ResNeXtBlock(
Parallel(
+,
identity,
Chain(
Conv((1, 1), 32 => 32), # 1_056 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
Conv((3, 3), 32 => 32, pad=1, groups=16), # 608 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
Conv((1, 1), 32 => 32), # 1_056 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
),
) # Total: 12 trainable arrays, 2_912 parameters,
# plus 6 non-trainable, 192 parameters, summarysize 13.344 KiB.
blk = ResNeXtBlock(32, 16, 2; num_channels = 64)
ResNeXtBlock(
Parallel(
+,
Chain(
Conv((1, 1), 32 => 64), # 2_112 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
),
Chain(
Conv((1, 1), 32 => 128), # 4_224 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1, groups=16), # 9_344 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((1, 1), 128 => 64), # 8_256 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
),
),
) # Total: 16 trainable arrays, 24_704 parameters,
# plus 8 non-trainable, 768 parameters, summarysize 101.125 KiB.
Summary and Discussion
Nested function classes are desirable since they allow us to obtain strictly more powerful rather than merely different function classes when adding capacity. One way of accomplishing this is by letting additional layers simply pass the input through to the output. Residual connections allow for this. As a consequence, this changes the inductive bias from simple functions being of the form $f(\mathbf{x}) = 0$ to simple functions looking like $f(\mathbf{x}) = \mathbf{x}$.
The residual mapping can learn the identity function more easily, such as pushing parameters in the weight layer to zero. We can train an effective deep neural network by having residual blocks. Inputs can forward propagate faster through the residual connections across layers. As a consequence, we can thus train much deeper networks. For instance, the original ResNet paper [93] allowed for up to 152 layers. Another benefit of residual networks is that it allows us to add layers, initialized as the identity function, during the training process. After all, the default behavior of a layer is to let the data pass through unchanged. This can accelerate the training of very large networks in some cases.
Prior to residual connections, bypassing paths with gating units were introduced to effectively train highway networks with over 100 layers [145]. Using identity functions as bypassing paths, ResNet performed remarkably well on multiple computer vision tasks. Residual connections had a major influence on the design of subsequent deep neural networks, of either convolutional or sequential nature. As we will introduce later, the Transformer architecture [142] adopts residual connections (together with other design choices) and is pervasive in areas as diverse as language, vision, speech, and reinforcement learning.
ResNeXt is an example of how the design of convolutional neural networks has evolved over time: by being more frugal with computation and trading it off against the size of the activations (number of channels), it allows for faster and more accurate networks at lower cost. An alternative way of viewing grouped convolutions is to think of a block-diagonal matrix for the convolutional weights. Note that there are quite a few such "tricks" that lead to more efficient networks. For instance, ShiftNet [95] mimics the effects of a $3\times 3$ convolution simply by adding shifted activations to the channels, offering increased function complexity, this time without any computational cost.
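Returning to the block-diagonal view of grouped convolutions, a toy sketch (the sizes are arbitrary and unrelated to the networks above): with $g$ groups, the $c \times c$ channel-mixing matrix of a $1\times 1$ convolution has nonzero entries only in $g$ diagonal blocks of size $(c/g) \times (c/g)$.
g, c = 4, 8
blocks = [rand(Float32, c ÷ g, c ÷ g) for _ in 1:g]
W = cat(blocks...; dims = (1, 2))  # block-diagonal c×c mixing matrix
count(!iszero, W), c^2             # expected (16, 64): only 1/g of the entries are used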
A common feature of the designs we have discussed so far is that the network design is fairly manual, primarily relying on the ingenuity of the designer to find the "right" network hyperparameters. While clearly feasible, it is also very costly in terms of human time and there is no guarantee that the outcome is optimal in any sense. In :numref:sec_cnn-design we will discuss a number of strategies for obtaining high quality networks in a more automated fashion. In particular, we will review the notion of network design spaces that led to the RegNetX/Y models [97].
Exercises
What are the major differences between the Inception block in Figure and the residual block? How do they compare in terms of computation, accuracy, and the classes of functions they can describe?
Refer to Table 1 in the ResNet paper [93] to implement different variants of the network.
For deeper networks, ResNet introduces a "bottleneck" architecture to reduce model complexity. Try to implement it.
In subsequent versions of ResNet, the authors changed the "convolution, batch normalization, and activation" structure to the "batch normalization, activation, and convolution" structure. Make this improvement yourself. See Figure 1 in He et al. [146] for details.
Why can't we just increase the complexity of functions without bound, even if the function classes are nested?
## 3.
struct ResNetBottleNeck{N} <: AbstractModel
    net::N
end

Flux.@layer ResNetBottleNeck
(rbn::ResNetBottleNeck)(x) = rbn.net(x)

function ResNetBottleNeck(channel_in, bot_mul)
    bottled_channels = Int(round(channel_in * bot_mul))
    main_net = Chain(
        Conv((1, 1), channel_in => bottled_channels),
        BatchNorm(bottled_channels, relu),
        Conv((3, 3), bottled_channels => bottled_channels, pad = 1, stride = 1),
        BatchNorm(bottled_channels, relu),
        Conv((1, 1), bottled_channels => channel_in),
        BatchNorm(channel_in, relu),
    )
    net = Parallel(+, Flux.identity, main_net)
    ResNetBottleNeck(net)
end
b_neck = ResNetBottleNeck(64, 0.25)
ResNetBottleNeck(
Parallel(
+,
identity,
Chain(
Conv((1, 1), 64 => 16), # 1_040 parameters
BatchNorm(16, relu), # 32 parameters, plus 32
Conv((3, 3), 16 => 16, pad=1), # 2_320 parameters
BatchNorm(16, relu), # 32 parameters, plus 32
Conv((1, 1), 16 => 64), # 1_088 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
),
),
) # Total: 12 trainable arrays, 4_640 parameters,
# plus 6 non-trainable, 192 parameters, summarysize 20.094 KiB.
Flux.outputsize(b_neck.net, (96, 96, 64, 1))
(96, 96, 64, 1)
## 4.
struct ResidualBlockNew{N} <: AbstractModel
    net::N
end

Flux.@layer ResidualBlockNew
(rbn::ResidualBlockNew)(x) = rbn.net(x)

function ResidualBlockNew(channel_in::Int; num_channels = channel_in,
                          use_1x1conv = !isequal(channel_in, num_channels), stride = 1)
    # Pre-activation ordering: batch normalization and ReLU come before each convolution.
    main_net = Chain(
        BatchNorm(channel_in, relu),
        Conv((3, 3), channel_in => num_channels, stride = stride, pad = 1),
        BatchNorm(num_channels, relu),
        Conv((3, 3), num_channels => num_channels, stride = 1, pad = 1),
    )
    shortcut_connection = if use_1x1conv
        Conv((1, 1), channel_in => num_channels, stride = stride)
    else
        Flux.identity
    end
    net = Parallel(+, shortcut_connection, main_net)
    ResidualBlockNew(net)
end
rb_new = ResidualBlockNew(64; num_channels = 128, stride = 2)
ResidualBlockNew(
Parallel(
+,
Conv((1, 1), 64 => 128, stride=2), # 8_320 parameters
Chain(
BatchNorm(64, relu), # 128 parameters, plus 128
Conv((3, 3), 64 => 128, pad=1, stride=2), # 73_856 parameters
BatchNorm(128, relu), # 256 parameters, plus 256
Conv((3, 3), 128 => 128, pad=1), # 147_584 parameters
),
),
) # Total: 10 trainable arrays, 230_144 parameters,
# plus 4 non-trainable, 384 parameters, summarysize 901.500 KiB.
Flux.outputsize(rb_new.net, (96, 96, 64, 1))
(48, 48, 128, 1)