Densely Connected Networks (DenseNet) ​
ResNet significantly changed the view of how to parametrize the functions in deep networks. DenseNet (dense convolutional network) is to some extent the logical extension of this [94]. DenseNet is characterized by both the connectivity pattern where each layer connects to all the preceding layers and the concatenation operation (rather than the addition operator in ResNet) to preserve and reuse features from earlier layers. To understand how to arrive at it, let's take a small detour to mathematics.
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using CUDA, cuDNN
Activating project at `/workspace/d2l-julia/d2lai`
From ResNet to DenseNet ​
Recall the Taylor expansion for functions. At the point $x = 0$ it can be written as

$$f(x) = f(0) + x \cdot \left[ f'(0) + x \cdot \left[ \frac{f''(0)}{2!} + x \cdot \left[ \frac{f'''(0)}{3!} + \cdots \right] \right] \right].$$

The key point is that it decomposes a function into terms of increasingly higher order. In a similar vein, ResNet decomposes functions into

$$f(\mathbf{x}) = \mathbf{x} + g(\mathbf{x}),$$

that is, into a simple linear term and a more complex nonlinear one. What if we wanted to capture (not necessarily add) information beyond two terms? One solution of this kind is DenseNet [94].
The main difference between ResNet (left) and DenseNet (right) in cross-layer connections: use of addition and use of concatenation.
As shown in the figure above, the key difference between ResNet and DenseNet is that in the latter case outputs are concatenated (denoted by $[\cdot, \cdot]$) rather than added. As a result, we perform a mapping from $\mathbf{x}$ to its values after applying an increasingly complex sequence of functions:

$$\mathbf{x} \to \left[ \mathbf{x},\ f_1(\mathbf{x}),\ f_2\left(\left[\mathbf{x}, f_1(\mathbf{x})\right]\right),\ f_3\left(\left[\mathbf{x}, f_1(\mathbf{x}), f_2\left(\left[\mathbf{x}, f_1(\mathbf{x})\right]\right)\right]\right),\ \ldots \right].$$
In the end, all these functions are combined in an MLP to reduce the number of features again. In terms of implementation this is quite simple: rather than adding terms, we concatenate them. The name DenseNet arises from the fact that the dependency graph between variables becomes quite dense. The final layer of such a chain is densely connected to all previous layers. The dense connections are shown in the figure below.
Dense connections in DenseNet. Note how the dimensionality increases with depth.
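To make the difference concrete, here is a minimal sketch (purely illustrative, not part of the model below) using Flux's `SkipConnection`, which accepts an arbitrary combination function: passing `+` gives a ResNet-style block, while concatenating along the channel dimension gives a DenseNet-style block.
conv = Conv((3,3), 8 => 8, pad = 1)
res_block = SkipConnection(conv, +) # ResNet: add the input back to the output
dense_block = SkipConnection(conv, (fx, x) -> cat(x, fx; dims = 3)) # DenseNet: concatenate along channels
x = rand(Float32, 16, 16, 8, 4)
size(res_block(x)) # (16, 16, 8, 4): channel count unchanged
size(dense_block(x)) # (16, 16, 16, 4): channels accumulate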
The main components that comprise a DenseNet are dense blocks and transition layers. The former define how the inputs and outputs are concatenated, while the latter control the number of channels so that it is not too large, since the expansion $\mathbf{x} \to \left[\mathbf{x}, f_1(\mathbf{x}), f_2\left(\left[\mathbf{x}, f_1(\mathbf{x})\right]\right), \ldots\right]$ can be quite high-dimensional.
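As a back-of-the-envelope relation implied by this construction (our shorthand, not notation from the paper), a dense block with $l$ convolution blocks and growth rate $k$ maps $c_\text{in}$ input channels to

$$c_\text{out} = c_\text{in} + l \cdot k$$

channels, and each transition layer then roughly halves this count again.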
Dense Blocks ​
DenseNet uses the modified "batch normalization, activation, and convolution" structure of ResNet (see the exercise in :numref:sec_resnet). First, we implement a convolution block; for simplicity, our implementation below applies the convolution first, followed by batch normalization with a ReLU activation.
struct DenseNetConvBlock{N} <: AbstractModel
net::N
end
Flux.@layer DenseNetConvBlock
(d::DenseNetConvBlock)(x) = d.net(x)
function DenseNetConvBlock(channel_in::Int, num_channels::Int)
net = Chain(
Conv((3,3), channel_in => num_channels, pad = 1),
BatchNorm(num_channels, relu)
)
DenseNetConvBlock(net) # wrap the Chain so the struct's forward pass is used
end
DenseNetConvBlock
A dense block consists of multiple convolution blocks, each using the same number of output channels. In the forward propagation, however, we concatenate the input and output of each convolution block on the channel dimension. Since Flux layers take explicit sizes, we track the number of input channels as we add convolution blocks so that each block receives the right dimensionality.
struct DenseBlock{N} <: AbstractModel
net::N
end
Flux.@layer DenseBlock
function DenseBlock(channel_in::Int, num_convs, num_channels; return_output_channels = false)
prev_channels = channel_in
conv_layers = map(1:num_convs) do i
block = DenseNetConvBlock(prev_channels, num_channels)
prev_channels += num_channels
return block
end
net = DenseBlock(conv_layers)
if return_output_channels
return net, prev_channels
else
return net
end
end
function (d::DenseBlock)(x)
for block in d.net
y = block(x)
x = cat(x, y; dims = 3)
end
return x
end
In the following example, we define a DenseBlock instance with two convolution blocks of 10 output channels. When using an input with three channels, we will get an output with $3 + 2 \times 10 = 23$ channels.
block = DenseBlock(3, 2, 10)
block(rand(8,8,3,16)) |> size
(8, 8, 23, 16)
Transition Layers ​
Since each dense block will increase the number of channels, adding too many of them will lead to an excessively complex model. A transition layer is used to control the complexity of the model. It reduces the number of channels by using a $1 \times 1$ convolution. Moreover, it halves the height and width via average pooling with a stride of 2.
struct DenseNetTransitionBlock{N} <: AbstractModel
net::N
end
Flux.@layer DenseNetTransitionBlock
function DenseNetTransitionBlock(channel_in, num_channels)
net = Chain(
BatchNorm(channel_in, relu),
Conv((1,1), channel_in => num_channels),
MeanPool((2,2), pad = 1, stride = 2)
)
DenseNetTransitionBlock(net) # wrap the Chain so the struct's forward pass is used
end
(dtb::DenseNetTransitionBlock)(x) = dtb.net(x)
Apply a transition layer with 10 channels to the output of the dense block in the previous example. This reduces the number of output channels to 10 and shrinks the height and width (here from 8 to 5, because the pooling layer uses padding of 1).
block = DenseNetTransitionBlock(23, 10)
block(rand(8, 8, 23, 16)) |> size
(5, 5, 10, 16)
DenseNet Model ​
Next, we will construct a DenseNet model. DenseNet first uses the same single convolutional layer and max-pooling layer as in ResNet.
struct DenseNetB1{N} <: AbstractModel
net::N
end
Flux.@layer DenseNetB1
(b1::DenseNetB1)(x) = b1.net(x)
function DenseNetB1(channel_in::Int = 1)
net = Chain(
Conv((7,7), channel_in => 64, pad = 3, stride = 2),
BatchNorm(64, relu),
MaxPool((3,3), stride = 2, pad = 1)
)
DenseNetB1(net)
end
DenseNetB1
Then, similar to the four modules made up of residual blocks that ResNet uses, DenseNet uses four dense blocks. As with ResNet, we can set the number of convolutional layers used in each dense block. Here, we set it to 4, consistent with the ResNet-18 model in :numref:sec_resnet. Furthermore, we set the number of channels (i.e., growth rate) for the convolutional layers in the dense block to 32, so 128 channels will be added to each dense block.
In ResNet, the height and width are reduced between each module by a residual block with a stride of 2. Here, we use the transition layer to halve the height and width and halve the number of channels. Similar to ResNet, a global pooling layer and a fully connected layer are connected at the end to produce the output.
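Before assembling the full network, we can sanity-check the channel bookkeeping with the blocks defined above (this snippet is purely illustrative, not part of the model): a dense block with 4 convolutions and growth rate 32 applied to the 64 stem channels should output 64 + 4 × 32 = 192 channels, and the subsequent transition layer halves that to 96, matching the shapes in the model summary below.
_, out_channels = DenseBlock(64, 4, 32; return_output_channels = true)
out_channels # 192 == 64 + 4 * 32
out_channels ÷ 2 # 96 channels after the first transition layer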
struct DenseNet{N} <:AbstractClassifier
net::N
end
Flux.@layer DenseNet
(dn::DenseNet)(x) = dn.net(x)
function DenseNet(channel_in::Int = 1; growth_rate = 32, arch = (4,4,4,4), num_classes = 10)
prev_channels = 64
layers = []
for i in 1:length(arch)
block, prev_channels = DenseBlock(prev_channels, arch[i], growth_rate; return_output_channels = true)
push!(layers, block)
if i != length(arch)
transition_layer = DenseNetTransitionBlock(prev_channels, prev_channels ÷ 2)
prev_channels = prev_channels ÷ 2
push!(layers, transition_layer)
end
end
net = Flux.@autosize (96, 96, 1, 1) Chain(
DenseNetB1(channel_in),
layers...,
GlobalMeanPool(),
Flux.flatten,
Dense(_ => num_classes),
softmax
)
DenseNet(net)
end
DenseNet
model = DenseNet()
DenseNet(
Chain(
DenseNetB1(
Chain(
Conv((7, 7), 1 => 64, pad=3, stride=2), # 3_200 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
),
),
DenseBlock(
[
Chain(
Conv((3, 3), 64 => 32, pad=1), # 18_464 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 96 => 32, pad=1), # 27_680 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 128 => 32, pad=1), # 36_896 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 160 => 32, pad=1), # 46_112 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
Chain(
BatchNorm(192, relu), # 384 parameters, plus 384
Conv((1, 1), 192 => 96), # 18_528 parameters
MeanPool((2, 2), pad=1),
),
DenseBlock(
[
Chain(
Conv((3, 3), 96 => 32, pad=1), # 27_680 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 128 => 32, pad=1), # 36_896 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 160 => 32, pad=1), # 46_112 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 192 => 32, pad=1), # 55_328 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
Chain(
BatchNorm(224, relu), # 448 parameters, plus 448
Conv((1, 1), 224 => 112), # 25_200 parameters
MeanPool((2, 2), pad=1),
),
DenseBlock(
[
Chain(
Conv((3, 3), 112 => 32, pad=1), # 32_288 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 144 => 32, pad=1), # 41_504 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 176 => 32, pad=1), # 50_720 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 208 => 32, pad=1), # 59_936 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
Chain(
BatchNorm(240, relu), # 480 parameters, plus 480
Conv((1, 1), 240 => 120), # 28_920 parameters
MeanPool((2, 2), pad=1),
),
DenseBlock(
[
Chain(
Conv((3, 3), 120 => 32, pad=1), # 34_592 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 152 => 32, pad=1), # 43_808 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 184 => 32, pad=1), # 53_024 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 216 => 32, pad=1), # 62_240 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
GlobalMeanPool(),
Flux.flatten,
Dense(248 => 10), # 2_490 parameters
NNlib.softmax,
),
) # Total: 82 trainable arrays, 754_082 parameters,
# plus 40 non-trainable, 2_464 parameters, summarysize 2.894 MiB.
Training ​
Since we are using a deeper network here, in this section, we will reduce the input height and width from 224 to 96 to simplify the computation.
model = DenseNet()
data = d2lai.FashionMNISTData(batchsize = 128, resize = (96, 96))
opt = Descent(0.01)
trainer = Trainer(model, data, opt; max_epochs = 10, gpu = true, board_yscale = :identity)
d2lai.fit(trainer);
[ Info: Train Loss: 0.63423675, Val Loss: 0.44235954, Val Acc: 0.875
[ Info: Train Loss: 0.26092514, Val Loss: 0.35349703, Val Acc: 0.9375
[ Info: Train Loss: 0.23870032, Val Loss: 0.54305226, Val Acc: 0.8125
[ Info: Train Loss: 0.36317107, Val Loss: 0.43822515, Val Acc: 0.875
[ Info: Train Loss: 0.3757471, Val Loss: 0.3553314, Val Acc: 0.875
[ Info: Train Loss: 0.16658472, Val Loss: 0.48415565, Val Acc: 0.875
[ Info: Train Loss: 0.16947578, Val Loss: 0.3061001, Val Acc: 0.875
[ Info: Train Loss: 0.25147822, Val Loss: 0.4952318, Val Acc: 0.8125
[ Info: Train Loss: 0.14965303, Val Loss: 0.32668728, Val Acc: 0.875
[ Info: Train Loss: 0.14355798, Val Loss: 0.32147193, Val Acc: 0.875
Summary and Discussion ​
The main components that comprise DenseNet are dense blocks and transition layers. For the latter, we need to keep the dimensionality under control when composing the network by adding transition layers that shrink the number of channels again. In terms of cross-layer connections, in contrast to ResNet, where inputs and outputs are added together, DenseNet concatenates inputs and outputs on the channel dimension. Although these concatenation operations reuse features to achieve computational efficiency, unfortunately they lead to heavy GPU memory consumption. As a result, applying DenseNet may require more memory-efficient implementations that may increase training time [147].
Exercises ​
1. Why do we use average pooling rather than max-pooling in the transition layer?
2. One of the advantages mentioned in the DenseNet paper is that its model parameters are smaller than those of ResNet. Why is this the case?
3. One problem for which DenseNet has been criticized is its high memory consumption.
   A. Is this really the case? Try to change the input shape to $224 \times 224$ to compare the actual GPU memory consumption empirically.
   B. Can you think of an alternative means of reducing the memory consumption? How would you need to change the framework?
4. Implement the various DenseNet versions presented in Table 1 of the DenseNet paper [94].
5. Design an MLP-based model by applying the DenseNet idea. Apply it to the housing price prediction task in :numref:sec_kaggle_house.
1. ​
Since the goal is to pass the outputs from all previous layers on to subsequent layers, we use average pooling. Unlike max-pooling, it does not discard most of the activations in the convolutional layer; instead it takes their mean, so information from every feature map is preserved.
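A tiny illustration (not from the chapter) of the difference on a single 2×2 pooling window:
x = reshape(Float32[1 2; 3 4], 2, 2, 1, 1)
MeanPool((2,2))(x) # 2.5: every activation contributes to the result
MaxPool((2,2))(x) # 4.0: all but the largest activation are discarded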
2. ​
Each layer adds only a small number of feature maps (the growth rate) and reuses the features of all preceding layers via concatenation, so the individual convolutions can stay narrow; in addition, the transition layers repeatedly shrink the number of channels. Together this keeps the model complexity, and by extension the number of parameters, low.
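As a rough empirical check (the ResNet-18 figure is the standard parameter count for that architecture, not measured here; `Flux.trainables` assumes a recent Flux version, with `Flux.params` as the older alternative), the DenseNet defined above has well under a million parameters:
model = DenseNet()
sum(length, Flux.trainables(model)) # ≈ 754_000 parameters, versus roughly 11 million for ResNet-18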
3. ​
A. At $96 \times 96$ the model takes 2169 MiB of GPU memory. Going to $224 \times 224$ increases the side length by a factor of $224 / 96 \approx 2.33$, and since activation memory grows roughly with the spatial area, consumption should increase by roughly $(224/96)^2 \approx 5.4$ times.
B. Sparse connectivity: instead of connecting every layer to all preceding layers, randomly pick a subset of earlier layers to connect to.
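A rough way to make part A empirical (a sketch that assumes `CUDA.memory_status()` and `CUDA.reclaim()` from CUDA.jl, and reuses the same `FashionMNISTData`/`Trainer` setup as in the training section):
for sz in ((96, 96), (224, 224))
model = DenseNet()
data = d2lai.FashionMNISTData(batchsize = 128, resize = sz)
trainer = Trainer(model, data, Descent(0.01); max_epochs = 1, gpu = true)
d2lai.fit(trainer)
println("input size $(sz):")
CUDA.memory_status() # report how much GPU memory is currently in use
CUDA.reclaim() # release cached memory before the next run
end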
4. ​
densenet121 = DenseNet(; arch = (6, 12, 24, 16))
densenet169 = DenseNet(; arch = (6, 12, 32, 32))
densenet201 = DenseNet(; arch = (6, 12, 48, 32))
densenet264 = DenseNet(; arch = (6, 12, 64, 48))
DenseNet(
Chain(
DenseNetB1(
Chain(
Conv((7, 7), 1 => 64, pad=3, stride=2), # 3_200 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
),
),
DenseBlock(
[
Chain(
Conv((3, 3), 64 => 32, pad=1), # 18_464 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 96 => 32, pad=1), # 27_680 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 128 => 32, pad=1), # 36_896 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 160 => 32, pad=1), # 46_112 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 192 => 32, pad=1), # 55_328 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 224 => 32, pad=1), # 64_544 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
Chain(
BatchNorm(256, relu), # 512 parameters, plus 512
Conv((1, 1), 256 => 128), # 32_896 parameters
MeanPool((2, 2), pad=1),
),
DenseBlock(
[
Chain(
Conv((3, 3), 128 => 32, pad=1), # 36_896 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 160 => 32, pad=1), # 46_112 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 192 => 32, pad=1), # 55_328 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 224 => 32, pad=1), # 64_544 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 256 => 32, pad=1), # 73_760 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 288 => 32, pad=1), # 82_976 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 320 => 32, pad=1), # 92_192 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 352 => 32, pad=1), # 101_408 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 384 => 32, pad=1), # 110_624 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 416 => 32, pad=1), # 119_840 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 448 => 32, pad=1), # 129_056 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 480 => 32, pad=1), # 138_272 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
Chain(
BatchNorm(512, relu), # 1_024 parameters, plus 1_024
Conv((1, 1), 512 => 256), # 131_328 parameters
MeanPool((2, 2), pad=1),
),
DenseBlock(
[
Chain(
Conv((3, 3), 256 => 32, pad=1), # 73_760 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 288 => 32, pad=1), # 82_976 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 320 => 32, pad=1), # 92_192 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 352 => 32, pad=1), # 101_408 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 384 => 32, pad=1), # 110_624 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 416 => 32, pad=1), # 119_840 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 448 => 32, pad=1), # 129_056 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 480 => 32, pad=1), # 138_272 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 512 => 32, pad=1), # 147_488 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 544 => 32, pad=1), # 156_704 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 576 => 32, pad=1), # 165_920 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 608 => 32, pad=1), # 175_136 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 640 => 32, pad=1), # 184_352 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 672 => 32, pad=1), # 193_568 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 704 => 32, pad=1), # 202_784 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 736 => 32, pad=1), # 212_000 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 768 => 32, pad=1), # 221_216 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 800 => 32, pad=1), # 230_432 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 832 => 32, pad=1), # 239_648 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 864 => 32, pad=1), # 248_864 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 896 => 32, pad=1), # 258_080 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 928 => 32, pad=1), # 267_296 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 960 => 32, pad=1), # 276_512 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 992 => 32, pad=1), # 285_728 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1024 => 32, pad=1), # 294_944 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1056 => 32, pad=1), # 304_160 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1088 => 32, pad=1), # 313_376 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1120 => 32, pad=1), # 322_592 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1152 => 32, pad=1), # 331_808 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1184 => 32, pad=1), # 341_024 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1216 => 32, pad=1), # 350_240 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1248 => 32, pad=1), # 359_456 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1280 => 32, pad=1), # 368_672 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1312 => 32, pad=1), # 377_888 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1344 => 32, pad=1), # 387_104 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1376 => 32, pad=1), # 396_320 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1408 => 32, pad=1), # 405_536 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1440 => 32, pad=1), # 414_752 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1472 => 32, pad=1), # 423_968 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1504 => 32, pad=1), # 433_184 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1536 => 32, pad=1), # 442_400 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1568 => 32, pad=1), # 451_616 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1600 => 32, pad=1), # 460_832 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1632 => 32, pad=1), # 470_048 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1664 => 32, pad=1), # 479_264 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1696 => 32, pad=1), # 488_480 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1728 => 32, pad=1), # 497_696 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1760 => 32, pad=1), # 506_912 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1792 => 32, pad=1), # 516_128 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1824 => 32, pad=1), # 525_344 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1856 => 32, pad=1), # 534_560 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1888 => 32, pad=1), # 543_776 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1920 => 32, pad=1), # 552_992 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1952 => 32, pad=1), # 562_208 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1984 => 32, pad=1), # 571_424 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2016 => 32, pad=1), # 580_640 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2048 => 32, pad=1), # 589_856 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2080 => 32, pad=1), # 599_072 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2112 => 32, pad=1), # 608_288 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2144 => 32, pad=1), # 617_504 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2176 => 32, pad=1), # 626_720 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2208 => 32, pad=1), # 635_936 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2240 => 32, pad=1), # 645_152 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2272 => 32, pad=1), # 654_368 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
Chain(
BatchNorm(2304, relu), # 4_608 parameters, plus 4_608
Conv((1, 1), 2304 => 1152), # 2_655_360 parameters
MeanPool((2, 2), pad=1),
),
DenseBlock(
[
Chain(
Conv((3, 3), 1152 => 32, pad=1), # 331_808 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1184 => 32, pad=1), # 341_024 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1216 => 32, pad=1), # 350_240 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1248 => 32, pad=1), # 359_456 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1280 => 32, pad=1), # 368_672 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1312 => 32, pad=1), # 377_888 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1344 => 32, pad=1), # 387_104 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1376 => 32, pad=1), # 396_320 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1408 => 32, pad=1), # 405_536 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1440 => 32, pad=1), # 414_752 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1472 => 32, pad=1), # 423_968 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1504 => 32, pad=1), # 433_184 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1536 => 32, pad=1), # 442_400 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1568 => 32, pad=1), # 451_616 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1600 => 32, pad=1), # 460_832 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1632 => 32, pad=1), # 470_048 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1664 => 32, pad=1), # 479_264 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1696 => 32, pad=1), # 488_480 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1728 => 32, pad=1), # 497_696 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1760 => 32, pad=1), # 506_912 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1792 => 32, pad=1), # 516_128 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1824 => 32, pad=1), # 525_344 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1856 => 32, pad=1), # 534_560 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1888 => 32, pad=1), # 543_776 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1920 => 32, pad=1), # 552_992 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1952 => 32, pad=1), # 562_208 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 1984 => 32, pad=1), # 571_424 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2016 => 32, pad=1), # 580_640 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2048 => 32, pad=1), # 589_856 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2080 => 32, pad=1), # 599_072 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2112 => 32, pad=1), # 608_288 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2144 => 32, pad=1), # 617_504 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2176 => 32, pad=1), # 626_720 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2208 => 32, pad=1), # 635_936 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2240 => 32, pad=1), # 645_152 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2272 => 32, pad=1), # 654_368 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2304 => 32, pad=1), # 663_584 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2336 => 32, pad=1), # 672_800 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2368 => 32, pad=1), # 682_016 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2400 => 32, pad=1), # 691_232 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2432 => 32, pad=1), # 700_448 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2464 => 32, pad=1), # 709_664 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2496 => 32, pad=1), # 718_880 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2528 => 32, pad=1), # 728_096 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2560 => 32, pad=1), # 737_312 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2592 => 32, pad=1), # 746_528 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2624 => 32, pad=1), # 755_744 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
Chain(
Conv((3, 3), 2656 => 32, pad=1), # 764_960 parameters
BatchNorm(32, relu), # 64 parameters, plus 64
),
],
),
GlobalMeanPool(),
Flux.flatten,
Dense(2688 => 10), # 26_890 parameters
NNlib.softmax,
),
) # Total: 538 trainable arrays, 53_786_826 parameters,
# plus 268 non-trainable, 14_592 parameters, summarysize 205.290 MiB.
5. ​
struct DenseNetMLPBlock{N} <: AbstractModel
net::N
end
Flux.@layer DenseNetMLPBlock
function DenseNetMLPBlock(features_in, num_features, num_dense; return_output_features = false)
prev_features = features_in
blocks = map(1:num_dense) do i
# Each sub-block is a fully connected layer (with a ReLU, analogous to the convolution blocks above) followed by dropout.
block = Chain(Dense(prev_features => num_features, relu), Dropout(0.4))
prev_features += num_features
return block
end
block = DenseNetMLPBlock(blocks)
if return_output_features
return block, prev_features
else
return block
end
end
function (d::DenseNetMLPBlock)(x)
# Concatenate the input with the output of every sub-block along the feature dimension.
for block in d.net
y = block(x)
x = vcat(x, y)
end
return x
end
struct DenseNetMLP{N} <: AbstractModel
net::N
end
Flux.@layer DenseNetMLP
(m::DenseNetMLP)(x) = m.net(x)
function DenseNetMLP(num_features; num_classes = 10, arch = (4,4,4,4))
prev_features = 64
layers = map(arch) do num_dense
block, prev_features = DenseNetMLPBlock(prev_features, 64, num_dense; return_output_features = true)
return block
end
net = Chain(
Dense(num_features => 64, relu),
Dropout(0.2),
layers...,
Dense(prev_features => num_classes),
softmax
)
DenseNetMLP(net)
end
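A quick smoke test on made-up dimensions (the 100 input features and batch size of 16 are arbitrary placeholders, not the actual feature count of the Kaggle housing data):
mlp = DenseNetMLP(100) # 100 input features, default arch (4, 4, 4, 4)
x = rand(Float32, 100, 16) # batch of 16 samples
size(mlp(x)) # (10, 16)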