ResNet for CIFAR10 classification using Julia


Configuration

Config
config
Config
  batchsize: Int64 32
  throttle: Int64 20
  lr: Float32 0.001f0
  epochs: Int64 2
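The printed `config` value above suggests a keyword-defined struct. A minimal sketch, assuming the `@with_kw` macro from Parameters.jl (the actual cell may use a different mechanism):

```julia
using Parameters  # assumed dependency, provides @with_kw

# Hyperparameters matching the printed `config` value.
@with_kw struct Config
    batchsize::Int64 = 32     # samples per minibatch
    throttle::Int64  = 20     # seconds between evaluation callbacks
    lr::Float32      = 0.001f0  # learning rate for the optimiser
    epochs::Int64    = 2      # number of training epochs
end

config = Config()
```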

Data


The data only needs to be downloaded once

download_data
false
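`download_data` is a flag guarding the one-time download. A sketch of how the CIFAR10 data might be fetched and loaded with MLDatasets.jl (assumed data source; the exact API depends on the MLDatasets version):

```julia
using MLDatasets  # assumed dependency for CIFAR10

download_data = false  # flip to true on the first run

if download_data
    # Trigger the one-time download of the dataset binaries.
    CIFAR10(split = :train)
    CIFAR10(split = :test)
end

# Features are 32×32×3×N arrays; targets are integer labels in 0:9.
train_x, train_y = CIFAR10(split = :train)[:]
test_x,  test_y  = CIFAR10(split = :test)[:]
```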
train_aug
8-step Augmentor.ImmutablePipeline:
 1.) Either: (50%) Flip the X axis. (50%) No operation.
 2.) Either: (50%) ShearX by ϕ ∈ -5:5 degree. (50%) ShearY by ψ ∈ -5:5 degree.
 3.) Rotate by θ ∈ -15:15 degree
 4.) Crop a 32×32 window around the center
 5.) Zoom by I ∈ {0.9×0.9, 1.0×1.0, 1.1×1.1, 1.2×1.2}
 6.) Split colorant into its color channels
 7.) Permute dimension order to (3, 2, 1)
 8.) Convert eltype to Float32
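The 8-step pipeline printed above corresponds to an Augmentor.jl composition; each printed step maps to one operation. A sketch reconstructing it (step numbers in the comments refer to the printed pipeline):

```julia
using Augmentor

train_aug = FlipX(0.5) |>                       # 1) flip X axis half the time
    Either(ShearX(-5:5), ShearY(-5:5)) |>       # 2) random shear on one axis
    Rotate(-15:15) |>                           # 3) small random rotation
    CropSize(32, 32) |>                         # 4) crop a centered 32×32 window
    Zoom(0.9:0.1:1.2) |>                        # 5) random zoom factor
    SplitChannels() |>                          # 6) colorant → channel array
    PermuteDims((3, 2, 1)) |>                   # 7) WHC layout expected by Flux
    ConvertEltype(Float32)                      # 8) Float32 pixels
```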
collate (generic function with 1 method)
collate (generic function with 2 methods)
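`collate` gains a second method in the next cell, which suggests one method that runs the augmentation pipeline (for training) and one that does not (for evaluation). A hypothetical sketch — the argument shapes, the `0:9` label range, and the use of `augmentbatch!` are all assumptions:

```julia
using Augmentor, Flux

# Training method: augment a batch of images, stack them into a
# 32×32×3×B Float32 array, and one-hot encode the labels.
function collate((imgs, labels), pipeline)
    batch = Array{Float32}(undef, 32, 32, 3, length(imgs))
    augmentbatch!(batch, imgs, pipeline)
    batch, Flux.onehotbatch(labels, 0:9)
end

# Evaluation method: no augmentation, just one-hot encode the labels.
collate((imgs, labels)) = imgs, Flux.onehotbatch(labels, 0:9)
```

The subsequent cells likely wrap the collated data in minibatch iterators, e.g. `Flux.Data.DataLoader` with `batchsize = config.batchsize` and `shuffle = true` for the training set.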

Model

conv_block (generic function with 1 method)
basic_residual (generic function with 1 method)
residual_block (generic function with 1 method)
residual_body (generic function with 1 method)
stem (generic function with 2 methods)
head (generic function with 2 methods)
resnet (generic function with 1 method)
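The helpers above (`conv_block`, `basic_residual`, `residual_block`, …) compose into the ResNet printed below. A hedged sketch of the building blocks in Flux, inferred from the printed model structure — note the actual notebook uses a custom `AddMerge` layer with a learnable scalar, which this sketch replaces with a plain add-then-relu via `Parallel`; the `pad = 1` is also an assumption:

```julia
using Flux

# 3×3 convolution + batch norm, with the activation folded into the norm.
conv_block(ch::Pair; activation = relu) =
    Chain(Conv((3, 3), ch, pad = 1), BatchNorm(ch.second, activation))

# Two conv blocks; the second has no activation, which is applied
# only after the residual merge.
basic_residual(ch::Pair) = Chain(
    conv_block(ch),
    conv_block(ch.second => ch.second, activation = identity),
)

# Residual block: identity shortcut when channel counts match,
# otherwise a 1×1 projection (as in the printed model).
function residual_block(ch::Pair)
    shortcut = ch.first == ch.second ? identity :
        Chain(Conv((1, 1), ch), BatchNorm(ch.second))
    Chain(Parallel(+, basic_residual(ch), shortcut), x -> relu.(x))
end
```

`residual_body`, `stem`, and `head` then stack these: a three-block stem (3→32→32→64 channels), four channel stages (64, 128, 256, 512) with `MaxPool` downsampling between the later stages, and a `GlobalMeanPool`/`Dropout`/`Dense(512, 10)` head.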
model
Chain(
  Chain(Chain(Conv((3, 3), 3=>32), BatchNorm(32, λ = relu)),
        Chain(Conv((3, 3), 32=>32), BatchNorm(32, λ = relu)),
        Chain(Conv((3, 3), 32=>64), BatchNorm(64, λ = relu))),
  Chain(SkipConnection(Chain(Chain(Conv((3, 3), 64=>64), BatchNorm(64, λ = relu)), Chain(Conv((3, 3), 64=>64), BatchNorm(64))), AddMerge(Float32[0.0], identity)),
        SkipConnection(Chain(Chain(Conv((3, 3), 64=>64), BatchNorm(64, λ = relu)), Chain(Conv((3, 3), 64=>64), BatchNorm(64))), AddMerge(Float32[0.0], identity)),
        SkipConnection(Chain(Chain(Conv((3, 3), 64=>128), BatchNorm(128, λ = relu)), Chain(Conv((3, 3), 128=>128), BatchNorm(128))), AddMerge(Float32[0.0], Chain(Conv((1, 1), 64=>128), BatchNorm(128)))),
        SkipConnection(Chain(Chain(Conv((3, 3), 128=>128), BatchNorm(128, λ = relu)), Chain(Conv((3, 3), 128=>128), BatchNorm(128))), AddMerge(Float32[0.0], identity)),
        MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)),
        SkipConnection(Chain(Chain(Conv((3, 3), 128=>256), BatchNorm(256, λ = relu)), Chain(Conv((3, 3), 256=>256), BatchNorm(256))), AddMerge(Float32[0.0], Chain(Conv((1, 1), 128=>256), BatchNorm(256)))),
        SkipConnection(Chain(Chain(Conv((3, 3), 256=>256), BatchNorm(256, λ = relu)), Chain(Conv((3, 3), 256=>256), BatchNorm(256))), AddMerge(Float32[0.0], identity)),
        MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)),
        SkipConnection(Chain(Chain(Conv((3, 3), 256=>512), BatchNorm(512, λ = relu)), Chain(Conv((3, 3), 512=>512), BatchNorm(512))), AddMerge(Float32[0.0], Chain(Conv((1, 1), 256=>512), BatchNorm(512)))),
        SkipConnection(Chain(Chain(Conv((3, 3), 512=>512), BatchNorm(512, λ = relu)), Chain(Conv((3, 3), 512=>512), BatchNorm(512))), AddMerge(Float32[0.0], identity))),
  Chain(GlobalMeanPool(), flatten, Dropout(0.3), Dense(512, 10)))

Training

loss (generic function with 1 method)
opt
accuracy (generic function with 1 method)
evalcb
(::Flux.var"#throttled#42"{Flux.var"#throttled#38#43"{Bool,Bool,Main.workspace3.var"#11#12",Int64}}) (generic function with 1 method)
do_training
false
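`do_training` gates the slow training loop, mirroring the `download_data` flag. A sketch of how `loss`, `opt`, `accuracy`, and the throttled callback might fit together, using Flux's implicit-parameters `train!` API (which matches the `Flux.throttle` closure printed above); `train_loader`, `test_x`, and `test_y` are assumed names from the Data section:

```julia
using Flux
using Flux: onecold, throttle

# Cross-entropy on raw logits, since the head is a bare Dense(512, 10).
loss(x, y) = Flux.logitcrossentropy(model(x), y)

opt = ADAM(config.lr)

# Fraction of correctly classified samples.
accuracy(x, y) = sum(onecold(model(x)) .== onecold(y)) / size(y, 2)

# Report test accuracy at most once every `config.throttle` seconds.
evalcb = throttle(() -> @show(accuracy(test_x, test_y)), config.throttle)

do_training = false  # flip to true to actually train

if do_training
    for epoch in 1:config.epochs
        Flux.train!(loss, Flux.params(model), train_loader, opt; cb = evalcb)
    end
end
```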