ResNet for CIFAR10 classification using Julia
Configuration
# Notebook dependencies
using PlutoUI

# Render a sidebar table of contents in the Pluto notebook
TableOfContents()

using Flux
using CUDA
using MLDatasets
using Images
using Augmentor
using Parameters
using IterTools
using OnlineStats
using Printf
# Hyper-parameter container. @with_kw (Parameters.jl) provides field
# defaults and a keyword constructor, e.g. Config(lr = 1f-2); the notebook
# output (batchsize 32, throttle 20, lr 0.001f0, epochs 2) confirms the
# defaults below.
@with_kw struct Config
    batchsize::Int = 32   # samples per mini-batch
    throttle::Int = 20    # seconds between evaluation-callback invocations
    lr::Float32 = 1f-3    # learning rate for ADAMW
    epochs::Int = 2       # passes over the training set
end
config = Config()

# One needs to download the CIFAR10 data only once; flip to `true` on first run.
download_data = false
# Download CIFAR10 on demand (no-op once the dataset is cached locally).
if download_data
    CIFAR10.download(i_accept_the_terms_of_use=true)
end
# CIFAR10 training split as Float32 (images, labels).
train_data = CIFAR10.traindata(Float32);
# CIFAR10 held-out test split as Float32 (images, labels).
test_data = CIFAR10.testdata(Float32);
# Training-time augmentation pipeline (8 steps, applied per batch):
#   1) 50% horizontal flip            5) random zoom in {0.9, 1.0, 1.1, 1.2}
#   2) 50/50 shear in X or Y (±5°)    6) split colorant into channels
#   3) random rotation (±15°)         7) permute dimension order to (3, 2, 1)
#   4) center crop back to 32×32      8) convert eltype to Float32
train_aug = FlipX(0.5) |> ShearX(-5:5) * ShearY(-5:5) |> Rotate(-15:15) |>
            CropSize(32, 32) |> Zoom(0.9:0.1:1.2) |>
            SplitChannels() |> PermuteDims(3, 2, 1) |> ConvertEltype(Float32)
# Move an (images, labels) batch to the GPU and one-hot encode the labels.
# Labels arrive 0-based, hence the `.+ 1` shift into the 1:10 class range.
function collate((imgs, labels))
    imgs = imgs |> gpu
    labels = Flux.onehotbatch(labels .+ 1, 1:10) |> gpu
    imgs, labels
end
# Augmenting variant: run `aug` over the whole batch (via image conversion,
# writing into a preallocated buffer), then delegate to the plain `collate`
# for one-hot encoding and GPU transfer.
function collate((imgs, labels), aug)
    imgs_aug = Array{Float32}(undef, size(imgs))
    augmentbatch!(imgs_aug, CIFAR10.convert2image(imgs), aug)
    collate((imgs_aug, labels))
end
# Shuffled training batches; imap keeps the augmentation lazy (per batch).
train_loader = imap(d -> collate(d, train_aug),
    Flux.Data.DataLoader(train_data, batchsize=config.batchsize, shuffle=true));
# Unaugmented, unshuffled evaluation batches.
test_loader = imap(collate,
    Flux.Data.DataLoader(test_data, batchsize=config.batchsize, shuffle=false));
# Sanity check: take one training batch and display the class name and image
# of its first sample (the trailing comma builds a (label, image) tuple).
begin
    batch = iterate(train_loader)[1]
    CIFAR10.classnames()[Flux.onecold(cpu(batch[2])[:, 1], 1:10)],
    CIFAR10.convert2image(cpu(batch[1])[:, :, :, 1])
end
# Convolution → batch-norm block with "same" padding and Kaiming init.
# `ch` is an in=>out channel Pair; the activation lives in the BatchNorm
# so it can be disabled (activation=identity) for pre-merge layers.
function conv_block(ch::Pair; kernel_size=3, stride=1, activation=relu)
    Chain(Conv((kernel_size, kernel_size), ch, pad=SamePad(), stride=stride,
               init=Flux.kaiming_normal),
          BatchNorm(ch.second, activation))
end
# Two stacked conv blocks; the second carries no activation so the skip
# connection can be added before the final ReLU (applied in AddMerge).
function basic_residual(ch::Pair)
    Chain(conv_block(ch),
          conv_block(ch.second => ch.second, activation=identity))
end
begin
    # Learnable residual merge: out = relu.(γ .* residual .+ expand(skip)).
    # γ is initialized to zero so each residual branch starts as a no-op;
    # `expand` is a 1×1 conv block when the channel counts differ, identity
    # otherwise.
    struct AddMerge
        gamma
        expand
    end

    # Register fields as trainable so params(model) collects γ and expand.
    Flux.@functor AddMerge

    function AddMerge(ch::Pair)
        if ch.first == ch.second
            expand = identity
        else
            expand = conv_block(ch, kernel_size=1, activation=identity)
        end
        AddMerge([0.f0], expand)
    end

    (m::AddMerge)(x1, x2) = relu.(m.gamma .* x1 .+ m.expand(x2))
end
# Wrap a basic residual in a skip connection merged by AddMerge.
function residual_block(ch::Pair)
    residual = basic_residual(ch)
    SkipConnection(residual, AddMerge(ch))
end
# Build the residual trunk. For each (repetitions, stride) stage pair:
# optionally downsample with max-pooling (stride > 1), then stack `rep`
# residual blocks. The target channel count doubles after every stage;
# only the first block of a stage changes the channel count.
function residual_body(in_channels, repetitions, downsamplings)
    layers = []
    res_channels = in_channels
    for (rep, stride) in zip(repetitions, downsamplings)
        if stride > 1
            push!(layers, MaxPool((stride, stride)))
        end
        for i = 1:rep
            push!(layers, residual_block(in_channels => res_channels))
            in_channels = res_channels
        end
        res_channels *= 2
    end
    Chain(layers...)
end
# Input stem: a short chain of conv blocks; only the first one may stride,
# subsequent blocks always use stride 1.
function stem(in_channels=3; channel_list=[32, 32, 64], stride=1)
    layers = []
    for channels in channel_list
        push!(layers, conv_block(in_channels => channels, stride=stride))
        in_channels = channels
        stride = 1
    end
    Chain(layers...)
end
# Classification head: global average pool → flatten → dropout → linear
# layer producing raw logits (no softmax; see logitcrossentropy loss).
function head(in_channels, classes, p_drop=0.)
    Chain(GlobalMeanPool(),
          flatten,
          Dropout(p_drop),
          Dense(in_channels, classes))
end
# Assemble stem + residual body + head into a full ResNet.
# `downsamplings[1]` is the stem stride; the remaining entries supply the
# per-stage pooling strides (zip truncates to length(repetitions)).
# NOTE(review): `downsamplings[1:end]` is a full copy, so the stem stride is
# reused as stage 1's stride and the last entry goes unused — confirm this
# is the intended stride schedule (the printed model matches it).
function resnet(classes, repetitions, downsamplings; in_channels=3, p_drop=0.)
    Chain(stem(in_channels, stride=downsamplings[1]),
          residual_body(64, repetitions, downsamplings[1:end]),
          head(64 * 2^(length(repetitions)-1), classes, p_drop))
end
# ResNet-18-style network (4 stages × 2 blocks) for the 10 CIFAR classes,
# with 30% dropout in the head, moved to the GPU.
model = resnet(10, [2, 2, 2, 2], [1, 1, 2, 2, 2], p_drop=0.3) |> gpu
# Cross-entropy on raw logits (the model's head has no final softmax).
loss(x, y) = Flux.logitcrossentropy(model(x), y)
# Trainable parameters of the model, for Flux.train!.
ps = params(model);
# ADAMW (lr from config, β = (0.9, 0.999), weight decay 1f-4) composed with
# an inverse-time learning-rate decay.
opt = Flux.Optimiser(InvDecay(0.001), ADAMW(config.lr, (0.9, 0.999), 1f-4))
# Mean top-1 accuracy of `model` over an iterable of (x, y) batches.
# An OnlineStats Mean accumulates per-sample hits so batches never need
# to be concatenated in memory.
function accuracy(model, data)
    m = Mean()
    for (x, y) in data
        fit!(m, Flux.onecold(cpu(model(x)), 1:10) .== Flux.onecold(cpu(y), 1:10))
    end
    value(m)
end
# Validation callback, throttled to at most one run per config.throttle
# seconds; prints test-set accuracy via Printf's @printf.
evalcb = Flux.throttle(config.throttle) do
    @printf "Val accuracy: %.3f\n" accuracy(model, test_loader)
end
# Flip to `true` to run the training loop below.
do_training = false
# Train for config.epochs epochs; Flux.@epochs repeats the train! loop,
# invoking the throttled validation callback as it goes.
if do_training
    Flux.@epochs config.epochs Flux.train!(loss, ps, train_loader, opt, cb=evalcb)
end