Flax is a neural network library and ecosystem for JAX that is designed for flexibility.
GitHub repo: https://github.com/google/flax
Optax is a gradient processing and optimization library for JAX.
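To give a feel for Optax's API before diving in: every optimizer is a gradient transformation with an init/update pair. A minimal sketch (the parameter pytree and gradients below are placeholders, not part of this tutorial's model):

import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}       # placeholder parameters
tx = optax.sgd(learning_rate=0.1)  # any optax optimizer exposes the same interface
opt_state = tx.init(params)
grads = {'w': jnp.ones(3)}         # placeholder gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Below, this pattern is hidden inside Flax's TrainState.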
Imports
import math
from functools import partial
from collections import defaultdict
from typing import Any, Sequence
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state, common_utils
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
Hide the GPU from TensorFlow, so that it does not reserve GPU memory that JAX needs:
tf.config.set_visible_devices([], 'GPU')
jax.local_devices()
[GpuDevice(id=0, process_index=0)]
Configuration
IMAGE_SIZE = 32
NUM_CLASSES = 10
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
rng = jax.random.PRNGKey(0)
def augment_image(img):
    img = tf.image.resize_with_crop_or_pad(img, 40, 40)
    img = tf.image.random_crop(img, [32, 32, 3])
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_brightness(img, max_delta=0.2)
    img = tf.image.random_contrast(img, 0.8, 1.2)
    img = tf.image.random_saturation(img, 0.8, 1.2)
    return img
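A quick way to sanity-check the augmentation pipeline is to push a dummy image through it; a sketch (the random uint8 tensor merely stands in for a CIFAR-10 sample):

dummy = tf.cast(tf.random.uniform([32, 32, 3], maxval=256), tf.uint8)
augmented = augment_image(dummy)
print(augmented.shape, augmented.dtype)  # (32, 32, 3) uint8: padded to 40x40, then re-cropped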
def train_process_sample(x):
    image = augment_image(x['image'])
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return {'image': image, 'label': x['label']}

def val_process_sample(x):
    image = tf.image.convert_image_dtype(x['image'], dtype=tf.float32)
    return {'image': image, 'label': x['label']}
def prepare_train_dataset(dataset_builder, batch_size):
    ds = dataset_builder.as_dataset(split='train')
    ds = ds.repeat()
    ds = ds.map(train_process_sample, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(16 * batch_size, reshuffle_each_iteration=True, seed=0)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(10)
    return ds

def prepare_val_dataset(dataset_builder, batch_size):
    ds = dataset_builder.as_dataset(split='test')
    ds = ds.map(val_process_sample, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    ds = ds.repeat()
    ds = ds.prefetch(10)
    return ds
def tf_to_numpy(xs):
    # TF's private _numpy() gives a zero-copy view of the data,
    # unlike .numpy(), which copies.
    return jax.tree_map(lambda x: x._numpy(), xs)

def dataset_to_iterator(ds):
    return map(tf_to_numpy, ds)
dataset_builder = tfds.builder('cifar10')
dataset_builder.download_and_prepare()
train_steps_per_epoch = math.ceil(dataset_builder.info.splits['train'].num_examples / BATCH_SIZE)
val_steps_per_epoch = math.ceil(dataset_builder.info.splits['test'].num_examples / BATCH_SIZE)
train_ds = prepare_train_dataset(dataset_builder, BATCH_SIZE)
val_ds = prepare_val_dataset(dataset_builder, BATCH_SIZE)
train_iter = dataset_to_iterator(train_ds)
val_iter = dataset_to_iterator(val_ds)
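A quick shape check on one batch (this consumes a batch from the infinite iterator, which is harmless here):

batch = next(train_iter)
print(batch['image'].shape, batch['image'].dtype)  # (32, 32, 32, 3) float32
print(batch['label'].shape, batch['label'].dtype)  # (32,) int64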
ModuleDef = Any
class ConvBlock(nn.Module):
    channels: int
    kernel_size: int
    norm: ModuleDef
    stride: int = 1
    act: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.channels, (self.kernel_size, self.kernel_size), strides=self.stride,
                    padding='SAME', use_bias=False, kernel_init=nn.initializers.kaiming_normal())(x)
        x = self.norm()(x)
        if self.act:
            x = nn.swish(x)
        return x
Residual blocks
class ResidualBlock(nn.Module):
    channels: int
    conv_block: ModuleDef

    @nn.compact
    def __call__(self, x):
        channels = self.channels
        conv_block = self.conv_block
        shortcut = x
        residual = conv_block(channels, 3)(x)
        residual = conv_block(channels, 3, act=False)(residual)
        if shortcut.shape != residual.shape:
            shortcut = conv_block(channels, 1, act=False)(shortcut)
        gamma = self.param('gamma', nn.initializers.zeros, 1, jnp.float32)
        out = shortcut + gamma * residual
        out = nn.swish(out)
        return out
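Because gamma starts at zero, the residual branch contributes nothing at initialization, so a fresh block is (up to the final swish) an identity mapping, a common trick for stabilizing early ResNet training. A quick sketch verifying this identity-at-init behavior (the shapes and channel count are arbitrary):

check_conv_block = partial(ConvBlock, norm=partial(nn.BatchNorm, use_running_average=True))
block = ResidualBlock(16, check_conv_block)
x = jnp.ones((1, 8, 8, 16))
block_vars = block.init(jax.random.PRNGKey(0), x)
y = block.apply(block_vars, x)
print(jnp.allclose(y, nn.swish(x)))  # True: the residual branch is zeroed out by gamma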
class Stage(nn.Module):
    channels: int
    num_blocks: int
    stride: int
    block: ModuleDef

    @nn.compact
    def __call__(self, x):
        stride = self.stride
        if stride > 1:
            x = nn.max_pool(x, (stride, stride), strides=(stride, stride))
        for _ in range(self.num_blocks):
            x = self.block(self.channels)(x)
        return x
class Body(nn.Module):
    channel_list: Sequence[int]
    num_blocks_list: Sequence[int]
    strides: Sequence[int]
    stage: ModuleDef

    @nn.compact
    def __call__(self, x):
        for channels, num_blocks, stride in zip(self.channel_list, self.num_blocks_list, self.strides):
            x = self.stage(channels, num_blocks, stride)(x)
        return x
ResNet
class Stem(nn.Module):
    channel_list: Sequence[int]
    stride: int
    conv_block: ModuleDef

    @nn.compact
    def __call__(self, x):
        stride = self.stride
        for channels in self.channel_list:
            x = self.conv_block(channels, 3, stride=stride)(x)
            stride = 1
        return x
class Head(nn.Module):
    classes: int
    dropout: ModuleDef

    @nn.compact
    def __call__(self, x):
        x = jnp.mean(x, axis=(1, 2))  # global average pooling
        x = self.dropout()(x)
        x = nn.Dense(self.classes)(x)
        return x
class ResNet(nn.Module):
    classes: int
    channel_list: Sequence[int]
    num_blocks_list: Sequence[int]
    strides: Sequence[int]
    head_p_drop: float = 0.

    @nn.compact
    def __call__(self, x, train=True):
        norm = partial(nn.BatchNorm, use_running_average=not train)
        dropout = partial(nn.Dropout, rate=self.head_p_drop, deterministic=not train)
        conv_block = partial(ConvBlock, norm=norm)
        residual_block = partial(ResidualBlock, conv_block=conv_block)
        stage = partial(Stage, block=residual_block)

        x = Stem([32, 32, 64], self.strides[0], conv_block)(x)
        x = Body(self.channel_list, self.num_blocks_list, self.strides[1:], stage)(x)
        x = Head(self.classes, dropout)(x)
        return x
@jax.jit
def initialize(params_rng):
    init_rngs = {'params': params_rng}
    input_shape = (1, IMAGE_SIZE, IMAGE_SIZE, 3)
    variables = model.init(init_rngs, jnp.ones(input_shape, jnp.float32), train=False)
    return variables
model = ResNet(NUM_CLASSES,
               channel_list=[64, 128, 256, 512],
               num_blocks_list=[2, 2, 2, 2],
               strides=[1, 1, 2, 2, 2],
               head_p_drop=0.3)
params_rng, dropout_rng = jax.random.split(rng)
variables = initialize(params_rng)
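A quick look at the model size (counting only the trainable parameters):

num_params = sum(p.size for p in jax.tree_util.tree_leaves(variables['params']))
print(f'{num_params:,} parameters')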
def cross_entropy_loss(logits, labels):
    one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)
    loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    loss = jnp.mean(loss)
    return loss
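As an aside, recent Optax versions also provide optax.softmax_cross_entropy_with_integer_labels, which lets one skip the explicit one-hot encoding; an equivalent sketch, assuming a new-enough Optax:

def cross_entropy_loss_int(logits, labels):
    # Same result as above, without materializing one-hot labels.
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, labels))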
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy,
    }
    return metrics
num_train_steps = train_steps_per_epoch * EPOCHS
schedule_fn = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=LEARNING_RATE)
tx = optax.adamw(learning_rate=schedule_fn, weight_decay=WEIGHT_DECAY)
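It can be worth plotting the schedule before committing to a long run; a quick sketch that samples it once per epoch:

steps = np.arange(0, num_train_steps, train_steps_per_epoch)
plt.plot(steps, [schedule_fn(s) for s in steps])
plt.xlabel('step')
plt.ylabel('learning rate')
plt.show()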
class TrainState(train_state.TrainState):
    batch_stats: Any

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
    tx=tx)
@jax.jit
def train_step(state, batch, dropout_rng):
    dropout_rng = jax.random.fold_in(dropout_rng, state.step)

    def loss_fn(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        logits, new_model_state = state.apply_fn(variables, batch['image'], train=True,
                                                 rngs={'dropout': dropout_rng}, mutable='batch_stats')
        loss = cross_entropy_loss(logits, batch['label'])
        return loss, (new_model_state, logits)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # With has_aux=True, grad_fn returns ((loss, aux), grads).
    (_, (new_model_state, logits)), grads = grad_fn(state.params)
    metrics = compute_metrics(logits, batch['label'])
    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    return new_state, metrics
@jax.jit
def eval_step(state, batch):
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    logits = state.apply_fn(variables, batch['image'], train=False, mutable=False)
    metrics = compute_metrics(logits, batch['label'])
    return metrics
def metrics_summary(metrics):
    metrics = jax.device_get(metrics)
    # Stack the per-step metrics into one array per key, then average.
    metrics = jax.tree_map(lambda *args: np.stack(args), *metrics)
    summary = jax.tree_map(lambda x: x.mean(), metrics)
    return summary
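For example, averaging two per-step metric dicts:

dummy_metrics = [{'loss': jnp.array(1.0), 'accuracy': jnp.array(0.5)},
                 {'loss': jnp.array(3.0), 'accuracy': jnp.array(1.0)}]
print(metrics_summary(dummy_metrics))  # {'loss': 2.0, 'accuracy': 0.75}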
def log_metrics(history, summary, name):
    print(f"{name}: ", end='', flush=True)
    for key, val in summary.items():
        history[name + ' ' + key].append(val)
        print(f"{key} {val:.3f} ", end='')
def train(state, train_iter, val_iter, epochs):
    history = defaultdict(list)
    for epoch in range(1, epochs + 1):
        print(f"{epoch}/{epochs} - ", end='')

        train_metrics = []
        for step in range(train_steps_per_epoch):
            batch = next(train_iter)
            state, metrics = train_step(state, batch, dropout_rng)
            train_metrics.append(metrics)
        summary = metrics_summary(train_metrics)
        log_metrics(history, summary, 'train')
        print('; ', end='')

        val_metrics = []
        for step in range(val_steps_per_epoch):
            batch = next(val_iter)
            metrics = eval_step(state, batch)
            val_metrics.append(metrics)
        summary = metrics_summary(val_metrics)
        log_metrics(history, summary, 'val')
        print()
    return state, history
state, history = train(state, train_iter, val_iter, EPOCHS)
1/100 - train: accuracy 0.519 loss 1.333 ; val: accuracy 0.679 loss 0.929
2/100 - train: accuracy 0.716 loss 0.819 ; val: accuracy 0.737 loss 0.788
3/100 - train: accuracy 0.773 loss 0.664 ; val: accuracy 0.770 loss 0.670
4/100 - train: accuracy 0.799 loss 0.584 ; val: accuracy 0.799 loss 0.596
5/100 - train: accuracy 0.816 loss 0.535 ; val: accuracy 0.800 loss 0.597
...
96/100 - train: accuracy 0.999 loss 0.004 ; val: accuracy 0.949 loss 0.324
97/100 - train: accuracy 0.999 loss 0.002 ; val: accuracy 0.949 loss 0.322
98/100 - train: accuracy 0.999 loss 0.003 ; val: accuracy 0.949 loss 0.322
99/100 - train: accuracy 0.999 loss 0.002 ; val: accuracy 0.950 loss 0.321
100/100 - train: accuracy 0.999 loss 0.002 ; val: accuracy 0.950 loss 0.321
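The trained state returned by train can be persisted; a minimal sketch using flax.training.checkpoints, with an arbitrary target directory:

from flax.training import checkpoints

checkpoints.save_checkpoint(ckpt_dir='/tmp/resnet_cifar10', target=state, step=num_train_steps)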
def plot_history_train_val(history, key):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    xs = np.arange(1, len(history['train ' + key]) + 1)
    ax.plot(xs, history['train ' + key], '.-', label='train')
    ax.plot(xs, history['val ' + key], '.-', label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    plt.show()
plot_history_train_val(history, 'loss')
plot_history_train_val(history, 'accuracy')
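Finally, predictions on a single validation batch, reusing the eval-mode apply (a quick sketch):

batch = next(val_iter)
logits = state.apply_fn({'params': state.params, 'batch_stats': state.batch_stats},
                        batch['image'], train=False, mutable=False)
print(jnp.argmax(logits, -1)[:10])  # predicted classes
print(batch['label'][:10])          # ground-truth labels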