
Fine-Tuning

In earlier chapters, we discussed how to train models on the Fashion-MNIST training dataset, which has only 60,000 images. We also described ImageNet, the most widely used large-scale image dataset in academia, with more than 10 million images and 1000 object classes. However, the datasets we usually encounter fall somewhere between these two in size.

Suppose that we want to recognize different types of chairs from images, and then recommend purchase links to users. One possible method is to first identify 100 common chairs, take 1000 images of different angles for each chair, and then train a classification model on the collected image dataset. Although this chair dataset may be larger than the Fashion-MNIST dataset, the number of examples is still less than one-tenth of that in ImageNet. This may lead to overfitting of complicated models that are suitable for ImageNet on this chair dataset. Besides, due to the limited amount of training examples, the accuracy of the trained model may not meet practical requirements.

In order to address the above problems, an obvious solution is to collect more data. However, collecting and labeling data can take a lot of time and money. For example, in order to collect the ImageNet dataset, researchers have spent millions of dollars from research funding. Although the current data collection cost has been significantly reduced, this cost still cannot be ignored.

Another solution is to apply transfer learning to transfer the knowledge learned from the source dataset to the target dataset. For example, although most of the images in the ImageNet dataset have nothing to do with chairs, the model trained on this dataset may extract more general image features, which can help identify edges, textures, shapes, and object composition. These similar features may also be effective for recognizing chairs.

Steps

In this section, we will introduce a common technique in transfer learning: fine-tuning. As shown in the figure below, fine-tuning consists of the following four steps:

  1. Pretrain a neural network model, i.e., the source model, on a source dataset (e.g., the ImageNet dataset).

  2. Create a new neural network model, i.e., the target model. This copies all model designs and their parameters on the source model except the output layer. We assume that these model parameters contain the knowledge learned from the source dataset and this knowledge will also be applicable to the target dataset. We also assume that the output layer of the source model is closely related to the labels of the source dataset; thus it is not used in the target model.

  3. Add an output layer to the target model, whose number of outputs is the number of categories in the target dataset. Then randomly initialize the model parameters of this layer.

  4. Train the target model on the target dataset, such as a chair dataset. The output layer will be trained from scratch, while the parameters of all the other layers are fine-tuned based on the parameters of the source model.

Figure: Fine-tuning.

When target datasets are much smaller than source datasets, fine-tuning helps to improve models' generalization ability.
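
To make these four steps concrete, here is a minimal sketch of how they map onto the Flux/Metalhead workflow used in the rest of this chapter. The direct `ResNet(18; pretrain = true)` call and the two-class head are assumptions for illustration; the actual code below works around a Metalhead loading issue and uses a custom model type.

julia
using Flux, Metalhead

# 1. Obtain a source model pretrained on ImageNet (assumed API; see the note later
#    in this section about a known loading issue in Metalhead v0.9.5).
source = ResNet(18; pretrain = true)

# 2. Copy everything except the ImageNet-specific output layer.
target_backbone = backbone(source)

# 3. Add a freshly initialized output layer sized for the target dataset (2 classes here).
target_head = Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => 2))
target = Chain(target_backbone, target_head)

# 4. Train `target` on the target dataset, using a small learning rate for the copied
#    layers and a larger one for the new output layer.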

Hot Dog Recognition

Let's demonstrate fine-tuning via a concrete case: hot dog recognition. We will fine-tune a ResNet model that was pretrained on the ImageNet dataset, using a small dataset of a few thousand images with and without hot dogs, and then use the fine-tuned model to recognize hot dogs in images.

julia
using Pkg;
Pkg.activate("../../d2lai")
using d2lai, Flux, Images, CUDA, cuDNN, DataAugmentation, Serialization
using Metalhead
using Flux.Zygote
using MPI, NCCL, Statistics, Plots
  Activating project at `~/d2l-julia/d2lai`
julia
using Logging

# A minimal logger that prints "[Level] message", matching the training logs shown
# later in this section.
struct SimpleLogger <: AbstractLogger
    min_level::LogLevel
end

Logging.min_enabled_level(logger::SimpleLogger) = logger.min_level

function Logging.shouldlog(logger::SimpleLogger, level, _module, group, id)
    return level >= logger.min_level
end

function Logging.handle_message(logger::SimpleLogger, level, msg, _module, group, id, file, line; kwargs...)
    println("[$(level)] $msg")
end

# Install it as the global logger so that `@info` messages use this format.
global_logger(SimpleLogger(Logging.Info))
SimpleLogger(Info)

Reading the Dataset

The hot dog dataset we use was taken from online images. This dataset consists of 1400 positive-class images containing hot dogs and as many negative-class images containing other foods. 1000 images of each class are used for training and the rest for testing.

After unzipping the downloaded dataset, we obtain two folders hotdog/train and hotdog/test. Both folders have hotdog and not-hotdog subfolders, either of which contains images of the corresponding class.

julia
function read_hotdog_dataset(extracted_folder, aug = ImageToTensor(); train = true)
    folder = train ? "train" : "test"
    hotdog_imgs = readdir(joinpath(extracted_folder, "hotdog", folder, "hotdog"); join = true)
    len_hotdog = length(hotdog_imgs)

    nothotdog_imgs = readdir(joinpath(extracted_folder, "hotdog", folder, "not-hotdog"); join = true)
    len_nothotdog = length(nothotdog_imgs)
    img_names = vcat(hotdog_imgs, nothotdog_imgs)

    # Load each image, wrap it as a DataAugmentation item, apply the augmentation
    # pipeline, and collect the result as a 3-d array (spatial dims first, channels last).
    feature_imgs = map(img_names) do img_name
        img = Images.load(img_name)
        img = Image(img)
        img_tensor = apply(aug, img) |> itemdata |> collect
        img_tensor = permutedims(img_tensor, (2, 1, 3))
    end

    # Label hot dogs as 1 and all other foods as 0.
    labels = vcat(fill(1, len_hotdog), fill(0, len_nothotdog))
    return feature_imgs, labels
end
read_hotdog_dataset (generic function with 2 methods)

We read all of the image files in the training and testing datasets, respectively. The first 8 positive examples and the last 8 negative images are shown below. As you can see, the images vary in size and aspect ratio.

julia
file = d2lai._download("hotdog.zip")
extracted_folder = d2lai._extract(file)

train_imgs, _ = read_hotdog_dataset(extracted_folder; train = true)
test_imgs, _ = read_hotdog_dataset(extracted_folder; train = false)

show_imgs = vcat(train_imgs[1:8], train_imgs[end-7:end])
d2lai.show_images(show_imgs, 8, 2)

During training, we randomly flip each image horizontally and crop a random 224×224 area as the input; during testing, we take a central 224×224 crop instead. In addition, for the three RGB (red, green, and blue) color channels we standardize their values channel by channel: the mean value of a channel is subtracted from each value of that channel, and the result is then divided by the standard deviation of that channel.

julia
# Training pipeline: a random horizontal flip followed by a random 224×224 crop.
train_aug = DataAugmentation.compose(
    Maybe(FlipX{2}()),
    RandomCrop((224, 224)),
    ImageToTensor()
)

# Testing pipeline: a central 224×224 crop (a random horizontal flip is kept here too).
test_aug = DataAugmentation.compose(
    Maybe(FlipX{2}()),
    CenterCrop((224, 224)),
    ImageToTensor()
)

# Standardize each color channel of a batched array (channels along dimension 3):
# subtract the channel mean and divide by the channel standard deviation.
function normalize_batched(img::AbstractArray{T,4}, mean, std) where T
    return (img .- reshape(mean, 1, 1, :, 1)) ./ reshape(std, 1, 1, :, 1)
end


struct HotDogData{T,V,A} <: d2lai.AbstractData
    train::T
    val::V
    args::A
    function HotDogData(; batchsize = 64, flatten = false, aug = nothing)
        file = d2lai._download("hotdog.zip")
        extracted_folder = d2lai._extract(file)
        # For simplicity, both splits are preprocessed once at load time with `test_aug`.
        train_imgs, train_labels = read_hotdog_dataset(extracted_folder, test_aug; train = true)
        test_imgs, test_labels = read_hotdog_dataset(extracted_folder, test_aug; train = false)

        # Stack the per-image arrays into a single 4-d batch (samples along the last dim).
        train_imgs = stack(train_imgs; dims = 4)
        test_imgs = stack(test_imgs; dims = 4)

        # Standard ImageNet per-channel mean and standard deviation, matching the
        # statistics the backbone was pretrained with.
        mean = [0.485f0, 0.456f0, 0.406f0]
        std = [0.229f0, 0.224f0, 0.225f0]

        train_imgs_normalized = normalize_batched(train_imgs, mean, std)
        test_imgs_normalized = normalize_batched(test_imgs, mean, std)

        train = (features = train_imgs_normalized, labels = train_labels)
        test = (features = test_imgs_normalized, labels = test_labels)
        args = (batchsize = batchsize, flatten = flatten)
        new{typeof(train), typeof(test), typeof(args)}(train, test, args)
    end
end

function d2lai.get_dataloader(data::HotDogData; train = true)
    if train
        return Flux.DataLoader(data.train; shuffle = true, batchsize = data.args.batchsize)
    else
        return Flux.DataLoader(data.val; shuffle = true, batchsize = data.args.batchsize)
    end
end
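
As a quick sanity check, we can construct the data container and draw a batch from it. This is a hedged usage sketch; given the 224×224 crops, three color channels, and batching along the last dimension, each feature batch should be a 224×224×3×batchsize array.

julia
# Hedged usage sketch for HotDogData and d2lai.get_dataloader defined above.
sample_data = HotDogData(; batchsize = 64)
loader = d2lai.get_dataloader(sample_data; train = true)
batch = first(loader)
size(batch.features), size(batch.labels)   # expected: ((224, 224, 3, 64), (64,))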

Defining and Initializing the Model

We use ResNet-18, which was pretrained on the ImageNet dataset, as the source model. Normally, specifying pretrain = true downloads the pretrained model parameters automatically; if the model is used for the first time, an Internet connection is required for the download.

Note

Since there is a known issue with loading the pretrained ResNet-18 in Metalhead v0.9.5, we instead load a serialized copy of the pretrained model.

julia
model = Serialization.deserialize("../../resnet18.jls")
ResNet(
  Chain(
    Chain(
      Chain(
        Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
        BatchNorm(64, relu),            # 128 parameters, plus 128
        MaxPool((3, 3), pad=1, stride=2),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
            NNlib.relu,
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
            NNlib.relu,
            Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
            BatchNorm(64),              # 128 parameters, plus 128
          ),
        ),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          Chain(
            Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
          Chain(
            Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
            BatchNorm(128),             # 256 parameters, plus 256
            NNlib.relu,
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
            NNlib.relu,
            Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
            BatchNorm(128),             # 256 parameters, plus 256
          ),
        ),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          Chain(
            Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
          Chain(
            Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
            BatchNorm(256),             # 512 parameters, plus 512
            NNlib.relu,
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
            NNlib.relu,
            Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
            BatchNorm(256),             # 512 parameters, plus 512
          ),
        ),
      ),
      Chain(
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          Chain(
            Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
          Chain(
            Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
            NNlib.relu,
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
        ),
        Parallel(
          PartialFunction(
            "",
            Metalhead.addact,
            (NNlib.relu,),
            NamedTuple(),
          ),
          identity,
          Chain(
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
            NNlib.relu,
            Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
            BatchNorm(512),             # 1_024 parameters, plus 1_024
          ),
        ),
      ),
    ),
    Chain(
      AdaptiveMeanPool((1, 1)),
      MLUtils.flatten,
      Dense(512 => 1000),               # 513_000 parameters
    ),
  ),
)         # Total: 62 trainable arrays, 11_689_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 44.636 MiB.

To access the backbone of the model (everything except the classification head), we can do:

julia
backbone(model)
Chain(
  Chain(
    Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
    BatchNorm(64, relu),                # 128 parameters, plus 128
    MaxPool((3, 3), pad=1, stride=2),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
        NNlib.relu,
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
        NNlib.relu,
        Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
        BatchNorm(64),                  # 128 parameters, plus 128
      ),
    ),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      Chain(
        Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
      Chain(
        Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
        NNlib.relu,
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
        NNlib.relu,
        Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
        BatchNorm(128),                 # 256 parameters, plus 256
      ),
    ),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      Chain(
        Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
      Chain(
        Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
        NNlib.relu,
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
        NNlib.relu,
        Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
        BatchNorm(256),                 # 512 parameters, plus 512
      ),
    ),
  ),
  Chain(
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      Chain(
        Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
      Chain(
        Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
        NNlib.relu,
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
    ),
    Parallel(
      PartialFunction(
        "",
        Metalhead.addact,
        (NNlib.relu,),
        NamedTuple(),
      ),
      identity,
      Chain(
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
        NNlib.relu,
        Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
        BatchNorm(512),                 # 1_024 parameters, plus 1_024
      ),
    ),
  ),
)         # Total: 60 trainable arrays, 11_176_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 42.679 MiB.
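
The remaining part of the source model is its classification head. Assuming Metalhead also provides a `classifier` accessor as the counterpart of `backbone` (an assumption; it is not used elsewhere in this section), the head can be inspected directly:

julia
# Assumed Metalhead accessor for the classification head (counterpart of `backbone`).
classifier(model)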

The final layer of the source model is a fully connected layer that transforms ResNet's global average pooling outputs into the 1000 class outputs of the ImageNet dataset. We then construct a new neural network as the target model. It is defined in the same way as the pretrained source model, except that its number of outputs in the final layer is set to the number of classes in the target dataset (rather than 1000).

In the code below, the model parameters before the output layer of the target model instance ResNet18PreTrained are initialized to the parameters of the corresponding layers of the source model. Since these parameters were obtained by pretraining on ImageNet, they are already effective, so only a small learning rate is needed to fine-tune them. In contrast, the parameters of the output layer are randomly initialized and generally require a larger learning rate to be learned from scratch. Letting the base learning rate be η, a learning rate of 10η will be used to update the parameters of the output layer (see the optimizer settings further below, where the backbone uses Adam(5e-5) and the new head uses Adam(5e-4)).

julia
Base.@kwdef struct ResNet18PreTrained{B,F} <: d2lai.AbstractClassifier
    backbone::B
    fc::F
end

Flux.@layer ResNet18PreTrained

# Forward pass: run the pretrained backbone, then the new classification head.
(rnet::ResNet18PreTrained)(x) = rnet.fc(rnet.backbone(x))

# The keyword constructor generated by Base.@kwdef lets train_ch13 below rebuild the
# model from its sub-fields after each optimiser update.
function ResNet18PreTrained(num_classes::Int64)
    # net = ResNet(18, pretrain = true, inchannels = inchannels)
    # Does not work currently with Metalhead v0.9.5, so we load a serialized copy instead.
    net = Serialization.deserialize("../../resnet18.jls")
    backbone_ = backbone(net)
    # Freshly initialized classification head for `num_classes` target classes.
    fc = Chain(
        AdaptiveMeanPool((1, 1)),
        Flux.flatten,
        Dense(512 => num_classes)
    )
    ResNet18PreTrained(backbone_, fc)
end

net = ResNet18PreTrained(2) |> gpu
ResNet18PreTrained(
  Chain(
    Chain(
      Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
      BatchNorm(64, relu),              # 128 parameters, plus 128
      MaxPool((3, 3), pad=1, stride=2),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
          NNlib.relu,
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
          NNlib.relu,
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
        ),
      ),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        Chain(
          Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
        Chain(
          Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
          BatchNorm(128),               # 256 parameters, plus 256
          NNlib.relu,
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
          NNlib.relu,
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
      ),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        Chain(
          Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
        Chain(
          Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
          BatchNorm(256),               # 512 parameters, plus 512
          NNlib.relu,
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
          NNlib.relu,
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
      ),
    ),
    Chain(
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        Chain(
          Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
        Chain(
          Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
          NNlib.relu,
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
      ),
      Parallel(
        PartialFunction(
          "",
          Metalhead.addact,
          (NNlib.relu,),
          NamedTuple(),
        ),
        identity,
        Chain(
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
          NNlib.relu,
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
      ),
    ),
  ),
  Chain(
    AdaptiveMeanPool((1, 1)),
    Flux.flatten,
    Dense(512 => 2),                    # 1_026 parameters
  ),
)         # Total: 62 trainable arrays, 11_177_538 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 17.992 KiB.

We define the train_ch13 function again with certain changes. This implementation expects trainer.opt to be a NamedTuple that provides an optimiser for each sub-field of the model. As you can see below, we map over the keys of this NamedTuple to set up one optimiser state per sub-model, and after every update we reconstruct the full model from the updated sub-models. This reconstruction is the reason why ResNet18PreTrained is defined with Base.@kwdef.

julia
function train_ch13(model, train_iter, test_iter, trainer, opt::NamedTuple = trainer.opt; num_epochs = 100, batchsize = 256, verbose = true)
  DistributedUtils.initialize(NCCLBackend)
  backend = DistributedUtils.get_distributed_backend(NCCLBackend)
  rank = DistributedUtils.local_rank(backend)

  # Set up one optimiser state per sub-field of the model (e.g. backbone and fc),
  # each synchronized across workers.
  config = map(trainer.opt, keys(trainer.opt)) do opt, key
      submodel = getproperty(model, key)
      submodel = DistributedUtils.synchronize!!(backend, DistributedUtils.FluxDistributedModel(submodel); root=0)

      opt = DistributedUtils.DistributedOptimizer(backend, opt)

      st_opt = Optimisers.setup(opt, submodel)
      st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)
      (; submodel, opt, st_opt)
  end

  train_data = DistributedUtils.DistributedDataContainer(
            backend, train_iter
        )

  train_loader = Flux.DataLoader(train_data, batchsize = batchsize, shuffle = true)

  val_data = DistributedUtils.DistributedDataContainer(
            backend, test_iter
        )

  val_loader = Flux.DataLoader(val_data, batchsize = batchsize)

  for i in 1:num_epochs 
    losses = (train_losses = [], val_losses = [], val_acc = [])
    for batch in train_loader 
      l, back = Zygote.pullback(d2lai.training_step, model, batch)
      g = back(one(l))[1]
      # Update each sub-model with its own optimiser state...
      config = map(keys(trainer.opt), config) do key, cnf
        updated_state, updated_submodel = Optimisers.update(cnf.st_opt, cnf.submodel, getproperty(g, key))
        (; st_opt = updated_state, submodel = updated_submodel, opt = cnf.opt)
      end
      # ...then rebuild the full model from the updated sub-models via the keyword
      # constructor (this is why ResNet18PreTrained uses Base.@kwdef).
      re = collect(keys(trainer.opt)) .=> getproperty.(config, :submodel)
      model = typeof(model)(; re...)
      push!(losses.train_losses, d2lai.training_step(model, batch))
    end
    for batch in val_loader 
      val_loss, val_acc = d2lai.validation_step(model, batch)
      push!(losses.val_losses, val_loss)
      push!(losses.val_acc, val_acc)
    end
    verbose && @info "Epoch: $i Training Loss: $(mean(losses.train_losses)) Val Loss: $(mean(losses.val_losses)) Val Acc: $(mean(losses.val_acc))"

    d2lai.draw_metrics(model, i, trainer, losses)
  end
  verbose && Plots.display(trainer.board.plt)
end
train_ch13 (generic function with 2 methods)
julia

data = HotDogData(; batchsize = 128)

# Cross-entropy loss on the two classes (hot dog = 1, not hot dog = 0).
function d2lai.loss(model::ResNet18PreTrained, y_pred, y)
    Flux.logitcrossentropy(y_pred, Flux.onehotbatch(y, 0:1))
end

# A small learning rate for the pretrained backbone and a 10x larger one for the
# freshly initialized output layer.
opt = (
    backbone = Flux.Optimisers.Adam(5e-5),
    fc = Flux.Optimisers.Adam(5e-4)
)
trainer = Trainer(net, nothing, opt)

train_iter = data.train |> gpu
test_iter = data.val |> gpu

train_ch13(net, train_iter, test_iter, trainer; num_epochs = 10)
[Info] Epoch: 1 Training Loss: 0.33564004 Val Loss: 0.20766422 Val Acc: 0.90625
[Info] Epoch: 2 Training Loss: 0.09363069 Val Loss: 0.11151349 Val Acc: 0.9599609375
[Info] Epoch: 3 Training Loss: 0.029383207 Val Loss: 0.10629809 Val Acc: 0.9658203125
[Info] Epoch: 4 Training Loss: 0.00954765 Val Loss: 0.10614552 Val Acc: 0.9599609375
[Info] Epoch: 5 Training Loss: 0.004738503 Val Loss: 0.107036114 Val Acc: 0.966796875
[Info] Epoch: 6 Training Loss: 0.0028371501 Val Loss: 0.1050554 Val Acc: 0.96875
[Info] Epoch: 7 Training Loss: 0.0020083808 Val Loss: 0.10353405 Val Acc: 0.9697265625
[Info] Epoch: 8 Training Loss: 0.0015513377 Val Loss: 0.10319546 Val Acc: 0.96875
[Info] Epoch: 9 Training Loss: 0.0012680623 Val Loss: 0.10395229 Val Acc: 0.9609375
[Info] Epoch: 10 Training Loss: 0.0010735149 Val Loss: 0.105356134 Val Acc: 0.9609375
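
With training finished, we can use the fine-tuned model for the task we set out to solve: recognizing hot dogs in images. The following is a hedged sketch; the indexing and the 0:1 label order follow from how the dataset and the loss were defined above, and softmax and onecold come from Flux.

julia
# Hedged sketch: classify a few held-out images with the fine-tuned model.
x = test_iter.features[:, :, :, 1:8]    # eight normalized 224×224×3 test images (on the GPU)
logits = cpu(net(x))                    # 2×8 matrix of class scores
probs = softmax(logits)                 # column-wise class probabilities
preds = Flux.onecold(logits, 0:1)       # 1 = hot dog, 0 = not hot dog, as labeled above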

For comparison, we define an identical model, but initialize all of its model parameters to random values. Since the entire model needs to be trained from scratch, we can use a larger learning rate.

julia
# The same architecture, but with all parameters randomly initialized.
scratch_model = Metalhead.ResNet(18; pretrain = false)
scratch_backbone = backbone(scratch_model)
scratch_fc = Chain(
        AdaptiveMeanPool((1, 1)),
        Flux.flatten,
        Dense(512 => 2)
    )
scratch_model = ResNet18PreTrained(scratch_backbone, scratch_fc) |> gpu

# Larger learning rates, since every parameter is trained from scratch.
opt = (
    backbone = Flux.Optimisers.Adam(5e-4),
    fc = Flux.Optimisers.Adam(5e-3)
)
trainer = Trainer(scratch_model, nothing, opt)
train_ch13(scratch_model, train_iter, test_iter, trainer; num_epochs = 10)
[Info] Epoch: 1 Training Loss: 1.0970988 Val Loss: 0.67144907 Val Acc: 0.650390625
[Info] Epoch: 2 Training Loss: 0.9108904 Val Loss: 1.2049842 Val Acc: 0.466796875
[Info] Epoch: 3 Training Loss: 0.49674577 Val Loss: 0.47337252 Val Acc: 0.8388671875
[Info] Epoch: 4 Training Loss: 0.3468305 Val Loss: 0.37830657 Val Acc: 0.8603515625
[Info] Epoch: 5 Training Loss: 0.3518307 Val Loss: 0.43837076 Val Acc: 0.822265625
[Info] Epoch: 6 Training Loss: 0.34420398 Val Loss: 0.33772746 Val Acc: 0.875
[Info] Epoch: 7 Training Loss: 0.34272197 Val Loss: 0.34560323 Val Acc: 0.8564453125
[Info] Epoch: 8 Training Loss: 0.29734886 Val Loss: 0.37139577 Val Acc: 0.841796875
[Info] Epoch: 9 Training Loss: 0.25808263 Val Loss: 0.39100593 Val Acc: 0.830078125
[Info] Epoch: 10 Training Loss: 0.2289773 Val Loss: 0.31808153 Val Acc: 0.8759765625

As we can see, for the same number of epochs the fine-tuned model tends to perform better, because its initial parameter values are more effective.

Summary

  • Transfer learning transfers knowledge learned from the source dataset to the target dataset. Fine-tuning is a common technique for transfer learning.

  • The target model copies all model designs with their parameters from the source model except the output layer, and fine-tunes these parameters based on the target dataset. In contrast, the output layer of the target model needs to be trained from scratch.

  • Generally, fine-tuning parameters uses a smaller learning rate, while training the output layer from scratch can use a larger learning rate.

Exercises

  1. Keep increasing the learning rate of the fine-tuned model net. How does the accuracy of the model change?

  2. Further adjust the hyperparameters of net and scratch_model in the comparative experiment. Do they still differ in accuracy?

  3. Set the parameters before the output layer of net to those of the source model and do not update them during training. How does the accuracy of the model change? You can use the code sketch below as a starting point.
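
This is a hedged sketch that freezes the backbone's optimiser state; Optimisers is already in scope (train_ch13 uses it above), and freeze! is assumed to be available in the installed Optimisers.jl version.

julia
# Hedged sketch for Exercise 3: keep the pretrained backbone fixed and train only the
# new output layer.
frozen_net = ResNet18PreTrained(2) |> gpu
st = Optimisers.setup(Optimisers.Adam(5e-4), frozen_net)
Optimisers.freeze!(st.backbone)   # gradients for the backbone are ignored from now on
# ...then train `frozen_net` with the frozen optimiser state `st` in an ordinary loop...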