Fully Convolutional Networks

As discussed in :numref:sec_semantic_segmentation, semantic segmentation classifies images at the pixel level. A fully convolutional network (FCN) uses a convolutional neural network to transform image pixels into pixel classes [85]. Unlike the CNNs that we encountered earlier for image classification or object detection, a fully convolutional network transforms the height and width of intermediate feature maps back to those of the input image. This is achieved by the transposed convolutional layer introduced in :numref:sec_transposed_conv. As a result, the classification output and the input image have a one-to-one correspondence at the pixel level: the channel dimension at any output pixel holds the classification results for the input pixel at the same spatial position.

julia
using Pkg;
Pkg.activate("../../d2lai")
using d2lai, Images, DataAugmentation
using Flux, Metalhead, CUDA, cuDNN
using Serialization, Statistics
  Activating project at `~/d2l-julia/d2lai`
julia
using Logging

struct SimpleLogger <: AbstractLogger
    min_level::LogLevel
end

Logging.min_enabled_level(logger::SimpleLogger) = logger.min_level

function Logging.shouldlog(logger::SimpleLogger, level, _module, group, id)
    return level >= logger.min_level
end

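# `kwargs` must be keyword arguments (note the `;`): log calls such as Flux's
# element-type warning pass extra keywords like `maxlog`.
function Logging.handle_message(logger::SimpleLogger, level, msg, _module, group, id, file, line; kwargs...)
    println("[$(level)] $msg")
end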

global_logger(SimpleLogger(Logging.Info))
SimpleLogger(Info)

The Model

Here we describe the basic design of the fully convolutional network model. As shown in the figure below, this model first uses a CNN to extract image features, then transforms the number of channels into the number of classes via a 1×1 convolutional layer, and finally transforms the height and width of the feature maps to those of the input image via the transposed convolution introduced in :numref:sec_transposed_conv. As a result, the model output has the same height and width as the input image, and the channel dimension of the output holds the predicted class scores for the input pixel at the same spatial position.

Fully convolutional network.
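
To make the role of the 1×1 convolution concrete, the following minimal sketch treats it as one linear map applied independently at every pixel. It uses plain Julia arrays rather than Flux; the helper name `conv1x1` and the bias-free weight layout (classes × channels) are assumptions made purely for illustration.

```julia
# Hypothetical helper: a bias-free 1×1 convolution written as a matrix product.
# x has layout (height, width, channels); W has layout (classes, channels).
function conv1x1(x::AbstractArray{<:Real,3}, W::AbstractMatrix{<:Real})
    h, w, cin = size(x)
    cout, cin_w = size(W)
    cin == cin_w || throw(DimensionMismatch("x has $cin channels, W expects $cin_w"))
    # Flatten the spatial dimensions so that every row is one pixel's channel vector.
    y = reshape(x, h * w, cin) * transpose(W)
    return reshape(y, h, w, cout)
end

x = rand(Float32, 10, 15, 512)   # feature map with 512 channels
W = rand(Float32, 21, 512)       # one weight row per class
y = conv1x1(x, W)
size(y)                          # (10, 15, 21)
```

Each output pixel `y[i, j, :]` equals `W * x[i, j, :]`, so the spatial size is unchanged and only the channel count goes from 512 to 21.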

Below, we load a ResNet-18 model pretrained on the ImageNet dataset to extract image features and denote the model instance as net. The last few layers of this model include a global average pooling layer and a fully connected layer: they are not needed in the fully convolutional network.

julia
net = Serialization.deserialize("../../resnet18.jls")
ResNet(
  Chain(
    Chain(
      Chain(
        Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        MaxPool((3, 3), pad=1, stride=2),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
            NNlib.relu,
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
            NNlib.relu,
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
          ),
        ),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          Chain(
            Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
          Chain(
            Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
            BatchNorm(128),             # 256 parameters, plus 256
            NNlib.relu,
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
            NNlib.relu,
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
        ),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          Chain(
            Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
          Chain(
            Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
            BatchNorm(256),             # 512 parameters, plus 512
            NNlib.relu,
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
            NNlib.relu,
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
        ),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          Chain(
            Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
          Chain(
            Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
            NNlib.relu,
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
            NNlib.relu,
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
        ),
      ),
    ),
    Chain(
      AdaptiveMeanPool((1, 1)),
      MLUtils.flatten,
      Dense(512 => 1000),               # 513_000 parameters
    ),
  ),
)         # Total: 62 trainable arrays, 11_689_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 44.636 MiB.

Next, we extract the feature extractor of the fully convolutional network with backbone(net). It keeps all the pretrained layers of the ResNet-18 except for the final global average pooling layer and the fully connected layer that are closest to the output.

julia
backbone(net)
Chain(
  Chain(
    Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
    MaxPool((3, 3), pad=1, stride=2),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
        NNlib.relu,
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
        NNlib.relu,
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
      ),
    ),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      Chain(
        Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
      Chain(
        Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
        NNlib.relu,
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
        NNlib.relu,
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
    ),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      Chain(
        Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      Chain(
        Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
        NNlib.relu,
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
        NNlib.relu,
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      Chain(
        Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      Chain(
        Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
        NNlib.relu,
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
        NNlib.relu,
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
    ),
  ),
)         # Total: 60 trainable arrays, 11_176_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 42.679 MiB.
julia
X = rand(Float32, 320, 480, 3, 1)  # Float32 matches the element type of the pretrained weights
backbone(net)(X) |> size
(10, 15, 512, 1)

Given an input with height and width of 320 and 480 respectively, the backbone reduces the height and width to 1/32 of the original, namely 10 and 15.

Next, we use a 1×1 convolutional layer to transform the number of output channels into the number of classes (21) of the Pascal VOC2012 dataset. Finally, we need to increase the height and width of the feature maps by 32 times to change them back to the height and width of the input image. Recall how to calculate the output shape of a convolutional layer in :numref:sec_padding. Since (320−64+16×2+32)/32=10 and (480−64+16×2+32)/32=15, we construct a transposed convolutional layer with a stride of 32, setting the height and width of the kernel to 64 and the padding to 16. In general, for stride s, padding s/2 (assuming s/2 is an integer), and kernel height and width 2s, the transposed convolution increases the height and width of the input by a factor of s. The head below initializes the transposed convolution with the bilinear_kernel function, which is defined in the next section.

julia
Base.@kwdef struct ResNetImageSegModel{B,H} <: d2lai.AbstractClassifier 
    backbone::B
    head::H
end

Flux.@layer ResNetImageSegModel trainable=(backbone, head)

(rnet::ResNetImageSegModel)(x) = rnet.head(rnet.backbone(x))



function ResNetImageSegModel(num_classes::Int64)
    # net = ResNet(18, pretrain = true, inchannels = inchannels)
    # Does not work currently 
    net = Serialization.deserialize("../../resnet18.jls")
    backbone_ = backbone(net)
    fc = Chain(
        Conv((1,1), 512 => num_classes),
        ConvTranspose(bilinear_kernel(num_classes, num_classes, 64); pad = 16, stride = 32)
    )
    ResNetImageSegModel(backbone_, fc)
end
ResNetImageSegModel
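
The shape arithmetic behind the choice of kernel size, stride, and padding can be checked with a few lines of plain Julia. This is a sketch assuming no dilation and no output padding; the helper names `conv_out` and `convtranspose_out` are made up for this example and are not part of Flux.

```julia
# Output length of an ordinary convolution along one dimension.
conv_out(n, k, s, p) = (n + 2p - k) ÷ s + 1

# Output length of a transposed convolution along one dimension.
convtranspose_out(n, k, s, p) = (n - 1) * s - 2p + k

# A 64×64 kernel with stride 32 and padding 16 maps 320×480 to 10×15 ...
down = (conv_out(320, 64, 32, 16), conv_out(480, 64, 32, 16))

# ... and the transposed convolution with the same settings maps 10×15 back.
up = (convtranspose_out(10, 64, 32, 16), convtranspose_out(15, 64, 32, 16))
```

Substituting k = 2s and p = s/2 into the second formula gives (n − 1)s − s + 2s = ns, which is the factor-of-s upsampling used by the model head.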

Initializing Transposed Convolutional Layers

We already know that transposed convolutional layers can increase the height and width of feature maps. In image processing, we may need to scale up an image, i.e., upsampling. Bilinear interpolation is one of the commonly used upsampling techniques. It is also often used for initializing transposed convolutional layers.

To explain bilinear interpolation, say that given an input image we want to calculate each pixel of the upsampled output image. In order to calculate the pixel of the output image at coordinate (x,y), first map (x,y) to coordinate (x′,y′) on the input image, for example, according to the ratio of the input size to the output size. Note that the mapped x′ and y′ are real numbers. Then, find the four pixels closest to coordinate (x′,y′) on the input image. Finally, the pixel of the output image at coordinate (x,y) is calculated based on these four closest pixels on the input image and their relative distance from (x′,y′).
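
The procedure just described can be sketched directly in Julia. The functions `bilinear_sample` and `upsample_bilinear` are hypothetical helpers written for this illustration. They use 1-based coordinates with x indexing the first array dimension, and they map output pixel centres to input coordinates by the size ratio, which is one common convention among several.

```julia
# Interpolate `img` at the real-valued coordinate (x, y) from its four nearest pixels.
function bilinear_sample(img::AbstractMatrix{<:Real}, x::Real, y::Real)
    h, w = size(img)
    x, y = clamp(x, 1, h), clamp(y, 1, w)        # stay inside the image
    x0, y0 = floor(Int, x), floor(Int, y)        # top-left neighbour
    x1, y1 = min(x0 + 1, h), min(y0 + 1, w)      # bottom-right neighbour
    dx, dy = x - x0, y - y0                      # relative distances in [0, 1)
    return (1 - dx) * (1 - dy) * img[x0, y0] + dx * (1 - dy) * img[x1, y0] +
           (1 - dx) * dy * img[x0, y1] + dx * dy * img[x1, y1]
end

# Upsample by an integer factor: map each output pixel centre to input coordinates.
function upsample_bilinear(img::AbstractMatrix{<:Real}, scale::Integer)
    h, w = size(img)
    coord(i) = (i - 0.5) / scale + 0.5
    return [bilinear_sample(img, coord(i), coord(j)) for i in 1:h*scale, j in 1:w*scale]
end

img = [1.0 2.0; 3.0 4.0]
center_value = bilinear_sample(img, 1.5, 1.5)    # 2.5, the mean of the four pixels
upsampled = upsample_bilinear(img, 2)            # 4×4 matrix
```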

Bilinear-interpolation upsampling can be implemented by a transposed convolutional layer whose kernel is constructed by the following bilinear_kernel function. Due to space limitations, we only provide the implementation of bilinear_kernel below, without discussing its algorithm design.

julia
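function bilinear_kernel(in_channels, out_channels, kernel_size)
    # Half-width of the triangular (tent) filter
    factor = (kernel_size + 1) ÷ 2
    # Centre of the kernel in 0-based coordinates
    center = isodd(kernel_size) ? factor - 1 : factor - 0.5
    # Column and row vectors of 0-based offsets; broadcasting them forms a 2-D grid
    og = (reshape(0:kernel_size-1, :, 1), reshape(0:kernel_size-1, 1, :))
    # Separable bilinear filter: the product of two 1-D tent functions
    filt = (1 .- (abs.(og[1] .- center)./factor)).*(1 .- (abs.(og[2] .- center)./factor))
    weights = zeros(kernel_size, kernel_size, in_channels, out_channels)
    # Each channel is upsampled independently, so only matching channel pairs are filled
    for i in 1:min(in_channels, out_channels)
        weights[:, :, i, i] .= filt
    end
    weights
end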
bilinear_kernel (generic function with 1 method)
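
To see why this kernel performs interpolation, it helps to look at one dimension. The sketch below rebuilds the 1-D weights that bilinear_kernel multiplies together and applies them with a naive transposed convolution. Both `tent_weights` and `convtranspose1d` are hypothetical helpers written for this illustration, not Flux functions.

```julia
# 1-D tent weights; the 2-D filter in bilinear_kernel is their outer product.
function tent_weights(kernel_size::Integer)
    factor = (kernel_size + 1) ÷ 2
    center = isodd(kernel_size) ? factor - 1 : factor - 0.5
    return [1 - abs(i - center) / factor for i in 0:kernel_size-1]
end

# Naive 1-D transposed convolution: every input sample scatters a scaled copy
# of the kernel into the output, then `pad` entries are cropped from each end.
function convtranspose1d(x::AbstractVector{<:Real}, w::AbstractVector{<:Real};
                         stride::Integer, pad::Integer)
    full = zeros(Float64, (length(x) - 1) * stride + length(w))
    for (i, xi) in enumerate(x), (j, wj) in enumerate(w)
        full[(i - 1) * stride + j] += xi * wj
    end
    return full[pad+1:end-pad]
end

w = tent_weights(4)                                               # [0.25, 0.75, 0.75, 0.25]
y = convtranspose1d([1.0, 2.0, 3.0, 4.0], w; stride = 2, pad = 1) # 8 samples
```

With stride 2, kernel size 4, and padding 1, the four inputs become eight outputs. The interior values 1.25, 1.75, ..., 3.75 lie on the straight line through the input samples, while the two boundary values are pulled toward zero by the implicit zero padding.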

Reading the Dataset

julia
data = d2lai.VOCSegDataSet((320, 480); batchsize = 32);
train_iter, test_iter = d2lai.load_data_voc(data)
(DataLoader(::Tuple{Array{Float32, 4}, Array{Int64, 3}}, shuffle=true, batchsize=32), DataLoader(::Tuple{Array{Float32, 4}, Array{Int64, 3}}, batchsize=32))

Training

Now we can train our constructed fully convolutional network. The loss function and accuracy calculation here are not essentially different from those in image classification of earlier chapters. Because we use the output channel of the transposed convolutional layer to predict the class for each pixel, the channel dimension is specified in the loss calculation. In addition, the accuracy is calculated based on correctness of the predicted class for all the pixels.

julia
model = ResNetImageSegModel(21) |> f64 |> gpu
# @info model(randn(320, 480, 3, 1)) |> size

function d2lai.loss(::ResNetImageSegModel, y_pred::AbstractArray, y::AbstractArray)
    y_pred = permutedims(y_pred, [3, 1, 2, 4])
    loss = Flux.Losses.logitcrossentropy(y_pred, Flux.onehotbatch(y, 0:20); agg = identity)
    loss = mean(loss; dims = 2)
    loss = mean(loss; dims = 3) 
    loss = mean(loss)
end

function d2lai.accuracy(::ResNetImageSegModel, y_pred::AbstractArray, y::AbstractArray)
    sum(dropdims(getindex.(argmax(y_pred; dims = 3), 3) .- 1; dims = 3) .== y) / length(y)
end



trainer = Trainer(model, nothing, Optimisers.Adam(0.0001))
train_iter = data.train |> gpu 
test_iter = data.val |> gpu
model = d2lai.train_ch13(model, train_iter, test_iter, trainer;  num_epochs = 10, batchsize = 32, verbose = true)
[Info] Epoch: 1 Training Loss: 2.5827725 Val Loss: 1.111204 Val Acc: 0.8008898876822569 Train Acc: 0.5856158348893946
[Info] Epoch: 2 Training Loss: 0.65694624 Val Loss: 0.7186479 Val Acc: 0.8312956349309966 Train Acc: 0.8589283290765225
[Info] Epoch: 3 Training Loss: 0.41647875 Val Loss: 0.6748914 Val Acc: 0.837553752824372 Train Acc: 0.8945024136833221
[Info] Epoch: 4 Training Loss: 0.32431132 Val Loss: 0.6380204 Val Acc: 0.8418911329735197 Train Acc: 0.9125829653338199
[Info] Epoch: 5 Training Loss: 0.26777342 Val Loss: 0.6390684 Val Acc: 0.8420368361855575 Train Acc: 0.9244594735505381
[Info] Epoch: 6 Training Loss: 0.22320928 Val Loss: 0.6339928 Val Acc: 0.8452510133359098 Train Acc: 0.9339833662431319
[Info] Epoch: 7 Training Loss: 0.19791853 Val Loss: 0.6219831 Val Acc: 0.8471264066373178 Train Acc: 0.9398801148802369
[Info] Epoch: 8 Training Loss: 0.17984653 Val Loss: 0.5850241 Val Acc: 0.8473081885897115 Train Acc: 0.9435384861313817
[Info] Epoch: 9 Training Loss: 0.15954861 Val Loss: 0.6019194 Val Acc: 0.8480440468201661 Train Acc: 0.9478737976297792
[Info] Epoch: 10 Training Loss: 0.1479829 Val Loss: 0.6065308 Val Acc: 0.8487400483320102 Train Acc: 0.950754144577753
ResNetImageSegModel(
  Chain(
    Chain(
      Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
      BatchNorm(64, relu),              # 128 parameters, plus 128
      MaxPool((3, 3), pad=1, stride=2),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
          NNlib.relu,
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
          NNlib.relu,
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
        ),
      ),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        Chain(
          Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
        Chain(
          Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
          BatchNorm(128),               # 256 parameters, plus 256
          NNlib.relu,
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
          NNlib.relu,
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
      ),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        Chain(
          Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
        Chain(
          Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
          BatchNorm(256),               # 512 parameters, plus 512
          NNlib.relu,
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
          NNlib.relu,
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
      ),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        Chain(
          Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
        Chain(
          Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
          NNlib.relu,
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
          NNlib.relu,
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
      ),
    ),
  ),
  Chain(
    Conv((1, 1), 512 => 21),            # 10_773 parameters
    ConvTranspose((64, 64), 21 => 21, pad=16, stride=32),  # 1_806_357 parameters
  ),
)         # Total: 64 trainable arrays, 12_993_642 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 18.469 KiB.
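
The per-pixel loss and accuracy used during training can be illustrated on a tiny array without Flux. The sketch below assumes scores laid out as (width, height, classes, batch) and integer labels in 0:C-1, mirroring the layout used above. The function names `pixelwise_cross_entropy` and `pixel_accuracy` are hypothetical and exist only for this example.

```julia
using Statistics

# Mean cross-entropy over all pixels; the softmax runs along the channel dimension (3).
function pixelwise_cross_entropy(logits::AbstractArray{<:Real,4},
                                 labels::AbstractArray{<:Integer,3})
    m = maximum(logits; dims = 3)                                  # for numerical stability
    logp = logits .- m .- log.(sum(exp.(logits .- m); dims = 3))   # log-softmax
    W, H, _, N = size(logits)
    # Labels start at 0, Julia indices at 1, hence the `+ 1`.
    losses = [-logp[i, j, labels[i, j, n] + 1, n] for i in 1:W, j in 1:H, n in 1:N]
    return mean(losses)
end

# Fraction of pixels whose highest-scoring class equals the label.
function pixel_accuracy(logits::AbstractArray{<:Real,4}, labels::AbstractArray{<:Integer,3})
    pred = getindex.(argmax(logits; dims = 3), 3) .- 1
    return mean(dropdims(pred; dims = 3) .== labels)
end

logits = zeros(2, 3, 4, 1)              # uniform scores over 4 classes
labels = zeros(Int, 2, 3, 1)            # every pixel belongs to class 0
uniform_loss = pixelwise_cross_entropy(logits, labels)   # equals log(4)

logits[:, :, 2, :] .= 10.0              # now class 1 wins at every pixel
acc_wrong = pixel_accuracy(logits, labels)               # 0.0
acc_right = pixel_accuracy(logits, ones(Int, 2, 3, 1))   # 1.0
```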

Prediction

When predicting, we need to standardize the input image in each channel and transform the image into the four-dimensional input format required by the CNN.

julia
function predict(model, img)
    preds = model(img)
    pred_classes = getindex.(argmax(preds; dims = 3), 3) .- 1
    dropdims(pred_classes; dims = 3)
end
predict (generic function with 1 method)
julia
function label2image(preds)
    imgs = map(eachslice(preds; dims = 3)) do x
        labeled_img = getindex.(Ref(d2lai.VOC_COLORMAP), cpu(x) .+ 1)
        labeled_img = cat(
            getindex.(labeled_img, 1),
            getindex.(labeled_img, 2),
            getindex.(labeled_img, 3);
            dims = 3
        )
        colorview(RGB, permutedims(Float64.(labeled_img), (3, 2, 1)))
    end
end
label2image (generic function with 1 method)

Images in the test dataset vary in size and shape. Since the model uses a transposed convolutional layer with stride of 32, when the height or width of an input image is indivisible by 32, the output height or width of the transposed convolutional layer will deviate from the shape of the input image. In order to address this issue, we can crop multiple rectangular areas with height and width that are integer multiples of 32 in the image, and perform forward propagation on the pixels in these areas separately. Note that the union of these rectangular areas needs to completely cover the input image. When a pixel is covered by multiple rectangular areas, the average of the transposed convolution outputs in separate areas for this same pixel can be input to the softmax operation to predict the class.
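
The tiling scheme described above can be sketched as follows. This is an illustrative outline, not the procedure used in the rest of this section: `tile_starts` and `tiled_average` are hypothetical helpers, and `score_fn` stands in for a forward pass that returns per-pixel class scores for one tile.

```julia
# Start offsets of tiles of length `tile` that together cover 1:len.
# The last tile is shifted back so that it ends exactly at `len`.
function tile_starts(len::Int, tile::Int)
    tile <= len || throw(ArgumentError("tile ($tile) must not exceed length ($len)"))
    starts = collect(1:tile:(len - tile + 1))
    if last(starts) + tile - 1 < len
        push!(starts, len - tile + 1)
    end
    return starts
end

# Average the scores of overlapping tiles; the result can then go through softmax/argmax.
function tiled_average(score_fn, img::AbstractArray{<:Real,3}, th::Int, tw::Int, num_classes::Int)
    H, W, _ = size(img)
    acc = zeros(Float64, H, W, num_classes)   # summed scores
    cnt = zeros(Int, H, W)                    # number of tiles covering each pixel
    for r in tile_starts(H, th), c in tile_starts(W, tw)
        rows, cols = r:r+th-1, c:c+tw-1
        acc[rows, cols, :] .+= score_fn(img[rows, cols, :])
        cnt[rows, cols] .+= 1
    end
    return acc ./ cnt
end

starts = tile_starts(500, 480)                       # [1, 21]
img = rand(70, 100, 3)
# With the identity as a stand-in "model", averaging must reproduce the image.
averaged = tiled_average(identity, img, 64, 96, 3)
```

Tile sizes of 64 and 96 are multiples of 32, as the stride of the transposed convolution requires. For a real model, `score_fn` would add and remove the batch dimension around the forward pass.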

For simplicity, we only take the first four images of the validation set, each a 320×480 crop, for prediction. For these images, we show the cropped inputs and the predicted segmentations row by row.

julia
imgs_test = data.val[1][:, :, :, 1:4] |> gpu
preds = predict(model, imgs_test)
segmented_imgs = label2image(preds)

colored_imgs = map(eachslice(data.val[1][:, :, :, 1:4]; dims = 4)) do x
    colorview(RGB, permutedims(Float64.(x), (3, 2, 1)))
end

d2lai.show_images(vcat(colored_imgs, segmented_imgs), 2, 4)

Summary

  • The fully convolutional network first uses a CNN to extract image features, then transforms the number of channels into the number of classes via a 1×1 convolutional layer, and finally transforms the height and width of the feature maps to those of the input image via the transposed convolution.

  • In a fully convolutional network, we can initialize the transposed convolutional layer with bilinear-interpolation upsampling.

Exercises

  1. If we use Xavier initialization for the transposed convolutional layer in the experiment, how does the result change?

  2. Can you further improve the accuracy of the model by tuning the hyperparameters?

  3. Predict the classes of all pixels in test images.

  4. The original fully convolutional network paper also uses outputs of some intermediate CNN layers [85]. Try to implement this idea.