ResNet for CIFAR10 classification using Julia
Configuration
xxxxxxxxxx
1
1
using PlutoUI
Table of Contents
xxxxxxxxxx
1
1
TableOfContents()
xxxxxxxxxx
1
1
using Flux
xxxxxxxxxx
1
1
using CUDA
xxxxxxxxxx
1
1
using MLDatasets
xxxxxxxxxx
1
1
using Images
xxxxxxxxxx
1
1
using Augmentor
xxxxxxxxxx
1
1
using Parameters
xxxxxxxxxx
1
1
using IterTools
xxxxxxxxxx
1
1
using OnlineStats
xxxxxxxxxx
1
1
using Printf
Config
xxxxxxxxxx
6
1
# Hyperparameter configuration. Field defaults require Parameters.jl's
# @with_kw (the bare `struct` form would be a syntax error); the export
# dropped the macro. `Config()` then builds an instance with all defaults.
@with_kw struct Config
    batchsize::Int = 32    # minibatch size for train and test loaders
    throttle::Int = 20     # min seconds between evaluation-callback calls
    lr::Float32 = 1f-3     # ADAMW learning rate
    epochs::Int = 2        # number of passes over the training set
end
Config
batchsize: Int64 32
throttle: Int64 20
lr: Float32 0.001f0
epochs: Int64 2
xxxxxxxxxx
1
1
# Instantiate the configuration with its default hyperparameters.
config = Config()
Data
The dataset only needs to be downloaded once; set `download_data = true` on the first run.
false
xxxxxxxxxx
1
1
# Flip to `true` on the first run to fetch CIFAR10; the download persists.
download_data = false
xxxxxxxxxx
3
1
# One-time CIFAR10 download, gated by the flag above.
download_data && CIFAR10.download(i_accept_the_terms_of_use=true)
xxxxxxxxxx
1
1
# CIFAR10 training split as Float32 arrays (images, labels).
train_data = CIFAR10.traindata(Float32);
xxxxxxxxxx
1
1
# CIFAR10 test split as Float32 arrays (images, labels).
test_data = CIFAR10.testdata(Float32);
8-step Augmentor.ImmutablePipeline:
1.) Either: (50%) Flip the X axis. (50%) No operation.
2.) Either: (50%) ShearX by ϕ ∈ -5:5 degree. (50%) ShearY by ψ ∈ -5:5 degree.
3.) Rotate by θ ∈ -15:15 degree
4.) Crop a 32×32 window around the center
5.) Zoom by I ∈ {0.9×0.9, 1.0×1.0, 1.1×1.1, 1.2×1.2}
6.) Split colorant into its color channels
7.) Permute dimension order to (3, 2, 1)
8.) Convert eltype to Float32
xxxxxxxxxx
3
1
# Training-time augmentation pipeline: random horizontal flip, a random
# X- or Y-shear, random rotation, center crop back to 32×32, random zoom,
# then split color channels, permute dims to (3, 2, 1), and convert to
# Float32 so the output matches the network's expected input layout.
train_aug = FlipX(0.5) |> ShearX(-5:5) * ShearY(-5:5) |> Rotate(-15:15) |>
            CropSize(32,32) |> Zoom(0.9:0.1:1.2) |>
            SplitChannels() |> PermuteDims(3, 2, 1) |> ConvertEltype(Float32)
collate (generic function with 1 method)
xxxxxxxxxx
5
1
# Move one (images, labels) batch to the GPU; labels are shifted from the
# dataset's 0-based classes to 1:10 and one-hot encoded.
function collate((imgs, labels))
    x = gpu(imgs)
    y = gpu(Flux.onehotbatch(labels .+ 1, 1:10))
    return x, y
end
collate (generic function with 2 methods)
xxxxxxxxxx
5
1
# Apply the augmentation pipeline on the CPU into a fresh buffer, then
# delegate to the 1-argument collate for one-hot encoding and GPU transfer.
function collate((imgs, labels), aug)
    augmented = Array{Float32}(undef, size(imgs))
    augmentbatch!(augmented, CIFAR10.convert2image(imgs), aug)
    return collate((augmented, labels))
end
xxxxxxxxxx
2
1
# Lazily map augmentation + GPU collation over shuffled training minibatches.
train_loader = imap(d -> collate(d, train_aug),
    Flux.Data.DataLoader(train_data, batchsize=config.batchsize, shuffle=true));
xxxxxxxxxx
2
1
# Test loader: no augmentation, no shuffling — just GPU collation.
test_loader = imap(collate,
    Flux.Data.DataLoader(test_data, batchsize=config.batchsize, shuffle=false));
"deer"
xxxxxxxxxx
5
1
# Sanity check: pull one augmented batch and display the first sample's
# class name alongside its image.
begin
    batch = iterate(train_loader)[1]
    CIFAR10.classnames()[Flux.onecold(cpu(batch[2])[:, 1], 1:10)],
    CIFAR10.convert2image(cpu(batch[1])[:,:,:, 1])
end
Model
conv_block (generic function with 1 method)
xxxxxxxxxx
5
1
# Square convolution (same padding, Kaiming-normal init) followed by
# BatchNorm carrying the activation. `ch` maps in=>out channels.
function conv_block(ch::Pair; kernel_size=3, stride=1, activation=relu)
    conv = Conv((kernel_size, kernel_size), ch, pad=SamePad(), stride=stride,
                init=Flux.kaiming_normal)
    return Chain(conv, BatchNorm(ch.second, activation))
end
basic_residual (generic function with 1 method)
xxxxxxxxxx
4
1
# Two-conv residual branch; the second block has no activation so the
# nonlinearity is applied after the skip merge (in AddMerge).
function basic_residual(ch::Pair)
    c1 = conv_block(ch)
    c2 = conv_block(ch.second => ch.second, activation=identity)
    return Chain(c1, c2)
end
xxxxxxxxxx
19
1
begin
    # Learned residual merge: y = relu.(gamma .* x1 .+ expand(x2)).
    # `gamma` starts at zero so each residual branch initially contributes
    # nothing; `expand` is a 1×1 conv when channel counts differ, else identity.
    struct AddMerge
        gamma
        expand
    end

    # The export dropped the macro here ("Flux. AddMerge"); @functor is
    # required so `gamma` (and `expand`'s conv) are trainable and gpu-movable.
    Flux.@functor AddMerge

    function AddMerge(ch::Pair)
        if ch.first == ch.second
            expand = identity
        else
            expand = conv_block(ch, kernel_size=1, activation=identity)
        end
        AddMerge([0.f0], expand)
    end

    (m::AddMerge)(x1, x2) = relu.(m.gamma .* x1 .+ m.expand(x2))
end
residual_block (generic function with 1 method)
xxxxxxxxxx
4
1
# Full residual block: two-conv branch wrapped in a skip connection whose
# merge handles any channel mismatch.
function residual_block(ch::Pair)
    return SkipConnection(basic_residual(ch), AddMerge(ch))
end
residual_body (generic function with 1 method)
xxxxxxxxxx
15
1
# Stack groups of residual blocks. Each group holds `rep` blocks; a group
# whose stride exceeds 1 is preceded by a MaxPool, and the channel width
# doubles from one group to the next.
function residual_body(in_channels, repetitions, downsamplings)
    blocks = []
    out_channels = in_channels
    for (count, stride) in zip(repetitions, downsamplings)
        stride > 1 && push!(blocks, MaxPool((stride, stride)))
        for _ in 1:count
            push!(blocks, residual_block(in_channels => out_channels))
            in_channels = out_channels  # only the group's first block widens
        end
        out_channels *= 2
    end
    return Chain(blocks...)
end
stem (generic function with 2 methods)
xxxxxxxxxx
9
1
# Entry convolutions: one conv_block per entry of `channel_list`; only the
# first conv may use the caller's stride, the rest run at stride 1.
function stem(in_channels=3; channel_list = [32, 32, 64], stride=1)
    convs = []
    for out_channels in channel_list
        push!(convs, conv_block(in_channels => out_channels, stride=stride))
        in_channels = out_channels
        stride = 1  # downsampling (if any) happens once, up front
    end
    return Chain(convs...)
end
head (generic function with 2 methods)
xxxxxxxxxx
6
1
# Classifier head: global average pool -> flatten -> dropout -> linear
# layer producing raw logits (paired with logitcrossentropy below).
function head(in_channels, classes, p_drop=0.)
    return Chain(GlobalMeanPool(),
                 flatten,
                 Dropout(p_drop),
                 Dense(in_channels, classes))
end
resnet (generic function with 1 method)
xxxxxxxxxx
5
1
# Assemble the full ResNet: stem, residual body, classifier head.
# NOTE(review): the stem consumes downsamplings[1] while the body receives
# the whole list; zip against `repetitions` then drops the surplus tail
# entry — confirm this sharing/truncation is intended.
function resnet(classes, repetitions, downsamplings; in_channels=3, p_drop=0.)
    base_width = 64
    network = Chain(stem(in_channels, stride=downsamplings[1]),
                    residual_body(base_width, repetitions, downsamplings[1:end]),
                    head(base_width * 2^(length(repetitions)-1), classes, p_drop))
    return network
end
Chain(Chain(Chain(Conv((3, 3), 3=>32), BatchNorm(32, λ = relu)), Chain(Conv((3, 3), 32=>32), BatchNorm(32, λ = relu)), Chain(Conv((3, 3), 32=>64), BatchNorm(64, λ = relu))), Chain(SkipConnection(Chain(Chain(Conv((3, 3), 64=>64), BatchNorm(64, λ = relu)), Chain(Conv((3, 3), 64=>64), BatchNorm(64))), AddMerge(Float32[0.0], identity)), SkipConnection(Chain(Chain(Conv((3, 3), 64=>64), BatchNorm(64, λ = relu)), Chain(Conv((3, 3), 64=>64), BatchNorm(64))), AddMerge(Float32[0.0], identity)), SkipConnection(Chain(Chain(Conv((3, 3), 64=>128), BatchNorm(128, λ = relu)), Chain(Conv((3, 3), 128=>128), BatchNorm(128))), AddMerge(Float32[0.0], Chain(Conv((1, 1), 64=>128), BatchNorm(128)))), SkipConnection(Chain(Chain(Conv((3, 3), 128=>128), BatchNorm(128, λ = relu)), Chain(Conv((3, 3), 128=>128), BatchNorm(128))), AddMerge(Float32[0.0], identity)), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), SkipConnection(Chain(Chain(Conv((3, 3), 128=>256), BatchNorm(256, λ = relu)), Chain(Conv((3, 3), 256=>256), BatchNorm(256))), AddMerge(Float32[0.0], Chain(Conv((1, 1), 128=>256), BatchNorm(256)))), SkipConnection(Chain(Chain(Conv((3, 3), 256=>256), BatchNorm(256, λ = relu)), Chain(Conv((3, 3), 256=>256), BatchNorm(256))), AddMerge(Float32[0.0], identity)), MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), SkipConnection(Chain(Chain(Conv((3, 3), 256=>512), BatchNorm(512, λ = relu)), Chain(Conv((3, 3), 512=>512), BatchNorm(512))), AddMerge(Float32[0.0], Chain(Conv((1, 1), 256=>512), BatchNorm(512)))), SkipConnection(Chain(Chain(Conv((3, 3), 512=>512), BatchNorm(512, λ = relu)), Chain(Conv((3, 3), 512=>512), BatchNorm(512))), AddMerge(Float32[0.0], identity))), Chain(GlobalMeanPool(), flatten, Dropout(0.3), Dense(512, 10)))
xxxxxxxxxx
1
1
# ResNet18-style model (4 groups of 2 blocks) for the 10 CIFAR classes.
model = resnet(10, [2, 2, 2, 2], [1, 1, 2, 2, 2], p_drop=0.3) |> gpu
Training
loss (generic function with 1 method)
xxxxxxxxxx
1
1
# Cross-entropy on raw logits (softmax folded into the loss for stability).
loss(x, y) = Flux.logitcrossentropy(model(x), y)
xxxxxxxxxx
1
1
# Trainable parameter collection handed to Flux.train!.
ps = params(model);
0.001
0.001
0.9
0.999
0.0001
xxxxxxxxxx
1
1
# Inverse-time LR decay chained with ADAMW (weight decay 1f-4).
opt = Flux.Optimiser(InvDecay(0.001), ADAMW(config.lr, (0.9, 0.999), 1f-4))
accuracy (generic function with 1 method)
xxxxxxxxxx
7
1
# Streaming top-1 accuracy over `data` using an OnlineStats Mean, so the
# whole prediction set never needs to be materialized at once.
function accuracy(model, data)
    running = Mean()
    for (x, y) in data
        predicted = Flux.onecold(cpu(model(x)), 1:10)
        actual = Flux.onecold(cpu(y), 1:10)
        fit!(running, predicted .== actual)
    end
    return value(running)
end
(::Flux.var"#throttled#42"{Flux.var"#throttled#38#43"{Bool,Bool,Main.workspace3.var"#11#12",Int64}}) (generic function with 1 method)
xxxxxxxxxx
3
1
# Validation callback, throttled to at most once per `config.throttle`
# seconds. The export dropped the macro before the format string; @printf
# (from Printf, imported above) is required to actually print the value.
evalcb = Flux.throttle(config.throttle) do
    @printf "Val accuracy: %.3f\n" accuracy(model, test_loader)
end
false
xxxxxxxxxx
1
1
# Flip to `true` to run the training loop in the cell below.
do_training = false
xxxxxxxxxx
3
1
if do_training
    # The export garbled "Flux.@epochs" into "Flux. config.epochs";
    # @epochs repeats the training pass `config.epochs` times.
    Flux.@epochs config.epochs Flux.train!(loss, ps, train_loader, opt, cb=evalcb)
end