CIFAR10 classification using Flax

Flax is a neural network library and ecosystem for JAX that is designed for flexibility.

GitHub repo: https://github.com/google/flax

Optax is a gradient processing and optimization library for JAX.

Configuration

Imports

Disable GPU usage by TensorFlow:
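A minimal sketch of this step. It assumes TensorFlow is used only for the input pipeline. Hiding every GPU from TensorFlow stops it from reserving GPU memory that JAX needs. The call must run before TensorFlow initializes its devices.

```python
import tensorflow as tf

# Make no GPU visible to TensorFlow. This must run before TensorFlow touches
# any GPU; once the runtime is initialized, the call raises a RuntimeError.
tf.config.set_visible_devices([], "GPU")

# TensorFlow now lists no GPUs; JAX still manages the accelerators itself.
print(tf.config.get_visible_devices("GPU"))
```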

Configuration
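A minimal sketch of how the hyperparameters could be kept in one place. Every value in `Config` is a hypothetical default, not the notebook's setting. Only the split sizes are properties of CIFAR-10: 50,000 training and 10,000 test images in 10 classes. The derived step counts assume that the incomplete final training batch is dropped.

```python
from dataclasses import dataclass

# Fixed properties of CIFAR-10.
NUM_TRAIN_EXAMPLES = 50_000
NUM_TEST_EXAMPLES = 10_000
NUM_CLASSES = 10


@dataclass(frozen=True)
class Config:
    # Hypothetical hyperparameters; adjust them for your own run.
    seed: int = 0
    batch_size: int = 128
    num_epochs: int = 30
    learning_rate: float = 0.1
    momentum: float = 0.9
    weight_decay: float = 5e-4
    warmup_epochs: int = 2

    @property
    def steps_per_epoch(self) -> int:
        # Assumes the incomplete final batch is dropped during training.
        return NUM_TRAIN_EXAMPLES // self.batch_size

    @property
    def total_steps(self) -> int:
        return self.num_epochs * self.steps_per_epoch


config = Config()
print(config.steps_per_epoch, config.total_steps)  # 390 11700
```

Because the dataclass is frozen, a variant is made with `Config(batch_size=256)` rather than by mutating `config`.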

Data
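How the arrays are loaded is left open here. This sketch assumes the images are already in memory as `uint8` arrays in NHWC layout `(N, 32, 32, 3)`, and covers only the JAX side. It shows per-channel normalization, plus the random crop (4-pixel reflect padding) and horizontal flip that are commonly used to augment CIFAR-10. The function names and the padding size are assumptions. The statistics are computed from the data rather than hard-coded.

```python
import jax
import jax.numpy as jnp


def channel_stats(images):
    """Per-channel mean and std of uint8 NHWC images after scaling to [0, 1]."""
    x = images.astype(jnp.float32) / 255.0
    return x.mean(axis=(0, 1, 2)), x.std(axis=(0, 1, 2))


def normalize(images, mean, std):
    """Scale uint8 pixels to [0, 1], then standardize each channel."""
    return (images.astype(jnp.float32) / 255.0 - mean) / std


def random_flip(key, images):
    """Mirror each image left-right (along the width axis) with probability 0.5."""
    flip = jax.random.bernoulli(key, 0.5, (images.shape[0],))
    return jnp.where(flip[:, None, None, None], images[:, :, ::-1, :], images)


def random_crop(key, images, padding=4):
    """Reflect-pad every image, then cut out a random window of the original size."""
    n, h, w, c = images.shape
    pad = ((0, 0), (padding, padding), (padding, padding), (0, 0))
    padded = jnp.pad(images, pad, mode="reflect")

    def crop_one(k, img):
        # Top-left corner drawn uniformly from {0, ..., 2 * padding}.
        offsets = jax.random.randint(k, (2,), 0, 2 * padding + 1)
        return jax.lax.dynamic_slice(img, (offsets[0], offsets[1], 0), (h, w, c))

    return jax.vmap(crop_one)(jax.random.split(key, n), padded)


def augment(key, images):
    """Random crop followed by a random horizontal flip."""
    crop_key, flip_key = jax.random.split(key)
    return random_flip(flip_key, random_crop(crop_key, images))
```

The statistics would be computed once on the training split and reused for the test split. `augment` applies only to training batches.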

Model

Residual blocks

ResNet

Training

Loss

Metrics

Optimizer

Train state

Train functions

Start training
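A sketch of the outer training loop. It assumes both splits are held in memory as NumPy arrays. It also assumes step functions with the hypothetical signatures `train_step(state, images, labels) -> (state, metrics)` and `eval_step(state, images, labels) -> metrics`, whose metrics dicts hold batch means under `"loss"` and `"accuracy"`. Shuffling happens on the host with NumPy. Incomplete training batches are dropped so the jitted step always sees one batch shape. Evaluation keeps every example and weights each batch by its size.

```python
import numpy as np


def iterate_batches(rng, images, labels, batch_size, shuffle=True, drop_last=True):
    """Yield (images, labels) mini-batches, optionally shuffled with a NumPy Generator."""
    n = len(images)
    order = rng.permutation(n) if shuffle else np.arange(n)
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]


def run_epoch(step_fn, state, batches, training):
    """Run one pass over `batches`; return the state and size-weighted mean metrics."""
    totals, count = {}, 0
    for x, y in batches:
        if training:
            state, metrics = step_fn(state, x, y)
        else:
            metrics = step_fn(state, x, y)
        n = len(x)
        for name, value in metrics.items():
            totals[name] = totals.get(name, 0.0) + float(value) * n
        count += n
    return state, {name: total / count for name, total in totals.items()}


def train_and_evaluate(state, train_step, eval_step, train_data, test_data,
                       num_epochs, batch_size, seed=0):
    rng = np.random.default_rng(seed)
    history = []
    for epoch in range(1, num_epochs + 1):
        state, train_m = run_epoch(
            train_step, state, iterate_batches(rng, *train_data, batch_size), training=True)
        _, test_m = run_epoch(
            eval_step, state,
            iterate_batches(rng, *test_data, batch_size, shuffle=False, drop_last=False),
            training=False)
        history.append({"epoch": epoch,
                        **{f"train_{k}": v for k, v in train_m.items()},
                        **{f"test_{k}": v for k, v in test_m.items()}})
        print(f"epoch {epoch:3d} | train loss {train_m['loss']:.4f} "
              f"acc {train_m['accuracy']:.4f} | test loss {test_m['loss']:.4f} "
              f"acc {test_m['accuracy']:.4f}")
    return state, history
```

Returning `history` as a list of dicts keeps the per-epoch curves available for plotting after the run.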