Haiku is a neural network library built on top of JAX, designed to provide simple, composable abstractions for machine learning research.
GitHub repo: https://github.com/deepmind/dm-haiku
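Haiku's core pattern, used throughout this post, is to write module code inside a plain function and let hk.transform (or hk.transform_with_state, used later) turn it into a pair of pure init/apply functions. As a minimal, self-contained sketch (the toy MLP here is only an illustration, not the model trained below):
import jax
import jax.numpy as jnp
import haiku as hk
def _toy_forward(x):
    # Modules may only be created inside a function that gets transformed.
    return hk.nets.MLP([128, 10])(x)
toy_forward = hk.transform(_toy_forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 32 * 32 * 3])             # dummy batch of flattened images
params = toy_forward.init(rng, x)          # parameters are an explicit pytree
logits = toy_forward.apply(params, rng, x)
print(logits.shape)                        # (8, 10)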
Imports
import math
from collections import namedtuple, defaultdict
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
Hide the GPU from TensorFlow so that the tf.data input pipeline does not grab GPU memory that JAX needs:
tf.config.set_visible_devices([], 'GPU')
jax.local_devices()
[GpuDevice(id=0, process_index=0)]
Configuration
IMAGE_SIZE = 32
NUM_CLASSES = 10
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
def augment_image(img):
img = tf.image.resize_with_crop_or_pad(img, 40, 40)
img = tf.image.random_crop(img, [32, 32, 3])
img = tf.image.random_flip_left_right(img)
img = tf.image.random_brightness(img, max_delta=0.2)
img = tf.image.random_contrast(img, 0.8, 1.2)
img = tf.image.random_saturation(img, 0.8, 1.2)
return img
def train_process_sample(x):
image = augment_image(x['image'])
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
return {'images': image, 'labels': x['label']}
def val_process_sample(x):
image = tf.image.convert_image_dtype(x['image'], dtype=tf.float32)
return {'images': image, 'labels': x['label']}
def prepare_train_dataset(dataset_builder, batch_size):
ds = dataset_builder.as_dataset(split='train')
ds = ds.repeat()
ds = ds.map(train_process_sample, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(16 * batch_size, reshuffle_each_iteration=True, seed=0)
ds = ds.batch(batch_size)
ds = ds.prefetch(tf.data.AUTOTUNE)
ds = tfds.as_numpy(ds)
return ds
def prepare_val_dataset(dataset_builder, batch_size):
ds = dataset_builder.as_dataset(split='test')
ds = ds.map(val_process_sample, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(batch_size)
ds = ds.repeat()
ds = ds.prefetch(tf.data.AUTOTUNE)
ds = tfds.as_numpy(ds)
return ds
dataset_builder = tfds.builder('cifar10')
dataset_builder.download_and_prepare()
train_steps_per_epoch = math.ceil(dataset_builder.info.splits['train'].num_examples / BATCH_SIZE)
val_steps_per_epoch = math.ceil(dataset_builder.info.splits['test'].num_examples / BATCH_SIZE)
train_ds = prepare_train_dataset(dataset_builder, BATCH_SIZE)
val_ds = prepare_val_dataset(dataset_builder, BATCH_SIZE)
train_iter = iter(train_ds)
val_iter = iter(val_ds)
Utilities
class Sequential(hk.Module):
def __init__(self, layers):
super().__init__()
self.layers = tuple(layers)
def __call__(self, inputs, *args):
out = inputs
for layer in self.layers:
out = layer(out, *args)
return out
class ConvBlock(hk.Module):
def __init__(self, channels, kernel_size, stride = 1, act = True):
super().__init__()
self.act = act
self.conv = hk.Conv2D(channels, kernel_size, stride=stride, with_bias=False, padding="SAME",
w_init=hk.initializers.VarianceScaling(2., "fan_in", "truncated_normal"))
self.bn = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
def __call__(self, inputs, is_training):
out = self.conv(inputs)
out = self.bn(out, is_training)
if self.act:
out = jax.nn.swish(out)
return out
Residual blocks
class ResidualBlock(hk.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.residual_blocks = Sequential([ConvBlock(channels, 3),
ConvBlock(channels, 3, act=False)])
def __call__(self, inputs, is_training):
shortcut = inputs
residual = self.residual_blocks(inputs, is_training)
if shortcut.shape != residual.shape:
shortcut = ConvBlock(self.channels, 1, act=False)(shortcut, is_training)
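        # gamma is initialised to zero, so each residual block starts out close to the identity mapping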
gamma = hk.get_parameter("gamma", [], inputs.dtype, init=jnp.zeros)
out = shortcut + gamma * residual
out = jax.nn.swish(out)
return out
class Stage(hk.Module):
def __init__(self, channels, num_blocks, stride):
super().__init__()
self.stride = stride
if stride > 1:
self.pool = hk.MaxPool(stride, strides=stride, padding="VALID")
self.blocks = Sequential([ResidualBlock(channels) for _ in range(num_blocks)])
def __call__(self, inputs, is_training):
out = inputs
if self.stride > 1:
out = self.pool(out)
out = self.blocks(out, is_training)
return out
class Body(hk.Module):
def __init__(self, channel_list, num_blocks_list, strides):
super().__init__()
self.stages = Sequential([Stage(channels, num_blocks, stride)
for channels, num_blocks, stride in zip(channel_list, num_blocks_list, strides)])
def __call__(self, inputs, is_training):
out = self.stages(inputs, is_training)
return out
ResNet
class Stem(hk.Module):
def __init__(self, channel_list, stride):
super().__init__()
blocks = []
for channels in channel_list:
blocks.append(ConvBlock(channels, 3, stride=stride))
stride = 1
self.blocks = Sequential(blocks)
def __call__(self, inputs, is_training):
out = self.blocks(inputs, is_training)
return out
class Head(hk.Module):
def __init__(self, classes, p_drop=0.):
super().__init__()
self.p_drop = p_drop
self.linear = hk.Linear(classes)
def __call__(self, inputs, is_training):
out = jnp.mean(inputs, axis=(1, 2))
if is_training and self.p_drop > 0:
out = hk.dropout(hk.next_rng_key(), self.p_drop, out)
out = self.linear(out)
return out
class ResNet(hk.Module):
def __init__(self, classes, channel_list, num_blocks_list, strides, head_p_drop = 0.):
super().__init__()
self.stem = Stem([32, 32, 64], strides[0])
self.body = Body(channel_list, num_blocks_list, strides[1:])
self.head = Head(classes, head_p_drop)
def __call__(self, inputs, is_training):
out = self.stem(inputs, is_training)
out = self.body(out, is_training)
out = self.head(out, is_training)
return out
def _forward(batch, is_training):
images = batch['images']
model = ResNet(NUM_CLASSES,
channel_list = [64, 128, 256, 512],
num_blocks_list = [2, 2, 2, 2],
strides = [1, 1, 2, 2, 2],
head_p_drop = 0.3)
return model(images, is_training)
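# transform_with_state (rather than plain hk.transform) is needed because BatchNorm keeps its moving averages in Haiku state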
forward = hk.transform_with_state(_forward)
rng = jax.random.PRNGKey(0)
init_rng, rng = jax.random.split(rng)
batch = next(train_iter)
params, state = forward.init(init_rng, batch, is_training=True)
def summary_fn(params, state, rng, batch):
logits, new_state = forward.apply(params, state, rng, batch, is_training=True)
return logits
def format_module_name(name):
parts = name.split('~')
return " " * (len(parts) - 1) + parts[-1]
tabulate_fn = hk.experimental.eval_summary(summary_fn)
summary = tabulate_fn(params, state, rng, batch)
print("{:32} | {:17} | {:17} | {}".format("Module","Input","Output", "Param count"))
print("=" * 86)
for i in summary:
print("{:32} | {:17} | {:17} | {:,}".format(
format_module_name(i.module_details.module.module_name),
str(i.args_spec[0]),
str(i.output_spec),
hk.data_structures.tree_size(i.module_details.params)))
Module | Input | Output | Param count
======================================================================================
res_net | f32[32,32,32,3] | f32[32,10] | 11,200,882
 /stem | f32[32,32,32,3] | f32[32,32,32,64] | 28,768
  /sequential | f32[32,32,32,3] | f32[32,32,32,64] | 28,768
   /conv_block | f32[32,32,32,3] | f32[32,32,32,32] | 928
    /conv2_d | f32[32,32,32,3] | f32[32,32,32,32] | 864
    /batch_norm | f32[32,32,32,32] | f32[32,32,32,32] | 64
     /mean_ema | f32[1,1,1,32] | f32[1,1,1,32] | 0
     /var_ema | f32[1,1,1,32] | f32[1,1,1,32] | 0
   /conv_block_1 | f32[32,32,32,32] | f32[32,32,32,32] | 9,280
    /conv2_d | f32[32,32,32,32] | f32[32,32,32,32] | 9,216
    /batch_norm | f32[32,32,32,32] | f32[32,32,32,32] | 64
     /mean_ema | f32[1,1,1,32] | f32[1,1,1,32] | 0
     /var_ema | f32[1,1,1,32] | f32[1,1,1,32] | 0
   /conv_block_2 | f32[32,32,32,32] | f32[32,32,32,64] | 18,560
    /conv2_d | f32[32,32,32,32] | f32[32,32,32,64] | 18,432
    /batch_norm | f32[32,32,32,64] | f32[32,32,32,64] | 128
     /mean_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
     /var_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
 /body | f32[32,32,32,64] | f32[32,4,4,512] | 11,166,984
  /sequential | f32[32,32,32,64] | f32[32,4,4,512] | 11,166,984
   /stage | f32[32,32,32,64] | f32[32,32,32,64] | 147,970
    /sequential | f32[32,32,32,64] | f32[32,32,32,64] | 147,970
     /residual_block | f32[32,32,32,64] | f32[32,32,32,64] | 73,985
      /sequential | f32[32,32,32,64] | f32[32,32,32,64] | 73,985
       /conv_block | f32[32,32,32,64] | f32[32,32,32,64] | 36,993
        /conv2_d | f32[32,32,32,64] | f32[32,32,32,64] | 36,864
        /batch_norm | f32[32,32,32,64] | f32[32,32,32,64] | 129
         /mean_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
         /var_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
       /conv_block_1 | f32[32,32,32,64] | f32[32,32,32,64] | 36,993
        /conv2_d | f32[32,32,32,64] | f32[32,32,32,64] | 36,864
        /batch_norm | f32[32,32,32,64] | f32[32,32,32,64] | 129
         /mean_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
         /var_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
     /residual_block_1 | f32[32,32,32,64] | f32[32,32,32,64] | 73,985
      /sequential | f32[32,32,32,64] | f32[32,32,32,64] | 73,985
       /conv_block | f32[32,32,32,64] | f32[32,32,32,64] | 36,993
        /conv2_d | f32[32,32,32,64] | f32[32,32,32,64] | 36,864
        /batch_norm | f32[32,32,32,64] | f32[32,32,32,64] | 129
         /mean_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
         /var_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
       /conv_block_1 | f32[32,32,32,64] | f32[32,32,32,64] | 36,993
        /conv2_d | f32[32,32,32,64] | f32[32,32,32,64] | 36,864
        /batch_norm | f32[32,32,32,64] | f32[32,32,32,64] | 129
         /mean_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
         /var_ema | f32[1,1,1,64] | f32[1,1,1,64] | 0
   /stage_1 | f32[32,32,32,64] | f32[32,16,16,128] | 525,570
    /max_pool | f32[32,32,32,64] | f32[32,16,16,64] | 0
    /sequential | f32[32,16,16,64] | f32[32,16,16,128] | 525,570
     /residual_block | f32[32,16,16,64] | f32[32,16,16,128] | 230,145
      /sequential | f32[32,16,16,64] | f32[32,16,16,128] | 221,697
       /conv_block | f32[32,16,16,64] | f32[32,16,16,128] | 73,985
        /conv2_d | f32[32,16,16,64] | f32[32,16,16,128] | 73,728
        /batch_norm | f32[32,16,16,128] | f32[32,16,16,128] | 257
         /mean_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
         /var_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
       /conv_block_1 | f32[32,16,16,128] | f32[32,16,16,128] | 147,713
        /conv2_d | f32[32,16,16,128] | f32[32,16,16,128] | 147,456
        /batch_norm | f32[32,16,16,128] | f32[32,16,16,128] | 257
         /mean_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
         /var_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
     /residual_block/conv_block | f32[32,16,16,64] | f32[32,16,16,128] | 8,449
      /conv2_d | f32[32,16,16,64] | f32[32,16,16,128] | 8,192
      /batch_norm | f32[32,16,16,128] | f32[32,16,16,128] | 257
       /mean_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
       /var_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
     /residual_block_1 | f32[32,16,16,128] | f32[32,16,16,128] | 295,425
      /sequential | f32[32,16,16,128] | f32[32,16,16,128] | 295,425
       /conv_block | f32[32,16,16,128] | f32[32,16,16,128] | 147,713
        /conv2_d | f32[32,16,16,128] | f32[32,16,16,128] | 147,456
        /batch_norm | f32[32,16,16,128] | f32[32,16,16,128] | 257
         /mean_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
         /var_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
       /conv_block_1 | f32[32,16,16,128] | f32[32,16,16,128] | 147,713
        /conv2_d | f32[32,16,16,128] | f32[32,16,16,128] | 147,456
        /batch_norm | f32[32,16,16,128] | f32[32,16,16,128] | 257
         /mean_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
         /var_ema | f32[1,1,1,128] | f32[1,1,1,128] | 0
   /stage_2 | f32[32,16,16,128] | f32[32,8,8,256] | 2,099,714
    /max_pool | f32[32,16,16,128] | f32[32,8,8,128] | 0
    /sequential | f32[32,8,8,128] | f32[32,8,8,256] | 2,099,714
     /residual_block | f32[32,8,8,128] | f32[32,8,8,256] | 919,041
      /sequential | f32[32,8,8,128] | f32[32,8,8,256] | 885,761
       /conv_block | f32[32,8,8,128] | f32[32,8,8,256] | 295,425
        /conv2_d | f32[32,8,8,128] | f32[32,8,8,256] | 294,912
        /batch_norm | f32[32,8,8,256] | f32[32,8,8,256] | 513
         /mean_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
         /var_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
       /conv_block_1 | f32[32,8,8,256] | f32[32,8,8,256] | 590,337
        /conv2_d | f32[32,8,8,256] | f32[32,8,8,256] | 589,824
        /batch_norm | f32[32,8,8,256] | f32[32,8,8,256] | 513
         /mean_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
         /var_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
     /residual_block/conv_block | f32[32,8,8,128] | f32[32,8,8,256] | 33,281
      /conv2_d | f32[32,8,8,128] | f32[32,8,8,256] | 32,768
      /batch_norm | f32[32,8,8,256] | f32[32,8,8,256] | 513
       /mean_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
       /var_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
     /residual_block_1 | f32[32,8,8,256] | f32[32,8,8,256] | 1,180,673
      /sequential | f32[32,8,8,256] | f32[32,8,8,256] | 1,180,673
       /conv_block | f32[32,8,8,256] | f32[32,8,8,256] | 590,337
        /conv2_d | f32[32,8,8,256] | f32[32,8,8,256] | 589,824
        /batch_norm | f32[32,8,8,256] | f32[32,8,8,256] | 513
         /mean_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
         /var_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
       /conv_block_1 | f32[32,8,8,256] | f32[32,8,8,256] | 590,337
        /conv2_d | f32[32,8,8,256] | f32[32,8,8,256] | 589,824
        /batch_norm | f32[32,8,8,256] | f32[32,8,8,256] | 513
         /mean_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
         /var_ema | f32[1,1,1,256] | f32[1,1,1,256] | 0
   /stage_3 | f32[32,8,8,256] | f32[32,4,4,512] | 8,393,730
    /max_pool | f32[32,8,8,256] | f32[32,4,4,256] | 0
    /sequential | f32[32,4,4,256] | f32[32,4,4,512] | 8,393,730
     /residual_block | f32[32,4,4,256] | f32[32,4,4,512] | 3,673,089
      /sequential | f32[32,4,4,256] | f32[32,4,4,512] | 3,540,993
       /conv_block | f32[32,4,4,256] | f32[32,4,4,512] | 1,180,673
        /conv2_d | f32[32,4,4,256] | f32[32,4,4,512] | 1,179,648
        /batch_norm | f32[32,4,4,512] | f32[32,4,4,512] | 1,025
         /mean_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
         /var_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
       /conv_block_1 | f32[32,4,4,512] | f32[32,4,4,512] | 2,360,321
        /conv2_d | f32[32,4,4,512] | f32[32,4,4,512] | 2,359,296
        /batch_norm | f32[32,4,4,512] | f32[32,4,4,512] | 1,025
         /mean_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
         /var_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
     /residual_block/conv_block | f32[32,4,4,256] | f32[32,4,4,512] | 132,097
      /conv2_d | f32[32,4,4,256] | f32[32,4,4,512] | 131,072
      /batch_norm | f32[32,4,4,512] | f32[32,4,4,512] | 1,025
       /mean_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
       /var_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
     /residual_block_1 | f32[32,4,4,512] | f32[32,4,4,512] | 4,720,641
      /sequential | f32[32,4,4,512] | f32[32,4,4,512] | 4,720,641
       /conv_block | f32[32,4,4,512] | f32[32,4,4,512] | 2,360,321
        /conv2_d | f32[32,4,4,512] | f32[32,4,4,512] | 2,359,296
        /batch_norm | f32[32,4,4,512] | f32[32,4,4,512] | 1,025
         /mean_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
         /var_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
       /conv_block_1 | f32[32,4,4,512] | f32[32,4,4,512] | 2,360,321
        /conv2_d | f32[32,4,4,512] | f32[32,4,4,512] | 2,359,296
        /batch_norm | f32[32,4,4,512] | f32[32,4,4,512] | 1,025
         /mean_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
         /var_ema | f32[1,1,1,512] | f32[1,1,1,512] | 0
 /head | f32[32,4,4,512] | f32[32,10] | 5,130
  /linear | f32[32,512] | f32[32,10] | 5,130
def cross_entropy_loss(logits, labels):
one_hot_labels = jax.nn.one_hot(labels, NUM_CLASSES)
loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels).mean()
return loss
@jax.jit
def loss_fn(params, state, rng, batch):
logits, new_state = forward.apply(params, state, rng, batch, is_training=True)
loss = cross_entropy_loss(logits, batch['labels'])
return loss, (new_state, logits)
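# has_aux=True: loss_fn returns (loss, aux), so grad_fn returns (grads, aux) while differentiating only the loss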
grad_fn = jax.grad(loss_fn, has_aux=True)
def compute_metrics(logits, labels):
loss = cross_entropy_loss(logits, labels)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
metrics = {
'loss': loss,
'accuracy': accuracy,
}
return metrics
num_train_steps = train_steps_per_epoch * EPOCHS
schedule_fn = optax.cosine_onecycle_schedule(transition_steps=num_train_steps, peak_value=LEARNING_RATE)
optim = optax.adamw(learning_rate=schedule_fn, weight_decay=WEIGHT_DECAY)
opt_state = optim.init(params)
TrainState = namedtuple("TrainState", ["params", "state", "opt_state", "rng"])
train_state = TrainState(params, state, opt_state, rng)
@jax.jit
def train_step(train_state, batch):
rng, new_rng = jax.random.split(train_state.rng)
grads, (new_state, logits) = grad_fn(train_state.params, train_state.state, rng, batch)
updates, new_opt_state = optim.update(grads, train_state.opt_state, train_state.params)
new_params = optax.apply_updates(train_state.params, updates)
train_state = TrainState(new_params, new_state, new_opt_state, new_rng)
metrics = compute_metrics(logits, batch['labels'])
return train_state, metrics
@jax.jit
def eval_step(train_state, batch):
logits, _ = forward.apply(train_state.params, train_state.state, None, batch, is_training=False)
metrics = compute_metrics(logits, batch['labels'])
return metrics
def metrics_summary(metrics):
metrics = jax.device_get(metrics)
    # Stack the per-step metric dicts, then average each metric over the epoch.
    metrics = jax.tree_util.tree_map(lambda *args: np.stack(args), *metrics)
    summary = jax.tree_util.tree_map(lambda x: x.mean(), metrics)
return summary
def log_metrics(history, summary, name):
print(f"{name}: ", end='', flush=True)
for key, val in summary.items():
history[name + ' ' + key].append(val)
print(f"{key} {val:.3f} ", end='')
def train(train_state, train_iter, val_iter, epochs):
history = defaultdict(list)
for epoch in range(1, epochs + 1):
print(f"{epoch}/{epochs} - ", end='')
train_metrics = []
for step in range(train_steps_per_epoch):
batch = next(train_iter)
train_state, metrics = train_step(train_state, batch)
train_metrics.append(metrics)
summary = metrics_summary(train_metrics)
log_metrics(history, summary, 'train')
print('; ', end='')
val_metrics = []
for step in range(val_steps_per_epoch):
batch = next(val_iter)
metrics = eval_step(train_state, batch)
val_metrics.append(metrics)
summary = metrics_summary(val_metrics)
log_metrics(history, summary, 'val')
print()
return history
history = train(train_state, train_iter, val_iter, EPOCHS)
1/100 - train: accuracy 0.510 loss 1.355 ; val: accuracy 0.669 loss 0.951
2/100 - train: accuracy 0.711 loss 0.826 ; val: accuracy 0.753 loss 0.717
3/100 - train: accuracy 0.769 loss 0.667 ; val: accuracy 0.768 loss 0.678
4/100 - train: accuracy 0.802 loss 0.581 ; val: accuracy 0.781 loss 0.654
5/100 - train: accuracy 0.817 loss 0.529 ; val: accuracy 0.817 loss 0.550
6/100 - train: accuracy 0.830 loss 0.492 ; val: accuracy 0.802 loss 0.612
7/100 - train: accuracy 0.841 loss 0.461 ; val: accuracy 0.833 loss 0.507
8/100 - train: accuracy 0.850 loss 0.437 ; val: accuracy 0.833 loss 0.511
9/100 - train: accuracy 0.857 loss 0.412 ; val: accuracy 0.847 loss 0.468
10/100 - train: accuracy 0.864 loss 0.394 ; val: accuracy 0.854 loss 0.445
11/100 - train: accuracy 0.872 loss 0.375 ; val: accuracy 0.846 loss 0.463
12/100 - train: accuracy 0.876 loss 0.361 ; val: accuracy 0.847 loss 0.475
13/100 - train: accuracy 0.879 loss 0.353 ; val: accuracy 0.847 loss 0.477
14/100 - train: accuracy 0.882 loss 0.340 ; val: accuracy 0.846 loss 0.503
15/100 - train: accuracy 0.886 loss 0.332 ; val: accuracy 0.859 loss 0.435
16/100 - train: accuracy 0.890 loss 0.324 ; val: accuracy 0.874 loss 0.412
17/100 - train: accuracy 0.890 loss 0.322 ; val: accuracy 0.874 loss 0.402
18/100 - train: accuracy 0.892 loss 0.316 ; val: accuracy 0.864 loss 0.453
19/100 - train: accuracy 0.892 loss 0.314 ; val: accuracy 0.873 loss 0.403
20/100 - train: accuracy 0.893 loss 0.313 ; val: accuracy 0.851 loss 0.499
21/100 - train: accuracy 0.891 loss 0.314 ; val: accuracy 0.860 loss 0.465
22/100 - train: accuracy 0.891 loss 0.315 ; val: accuracy 0.870 loss 0.442
23/100 - train: accuracy 0.895 loss 0.306 ; val: accuracy 0.872 loss 0.444
24/100 - train: accuracy 0.894 loss 0.308 ; val: accuracy 0.871 loss 0.436
25/100 - train: accuracy 0.895 loss 0.308 ; val: accuracy 0.854 loss 0.509
26/100 - train: accuracy 0.895 loss 0.306 ; val: accuracy 0.862 loss 0.500
27/100 - train: accuracy 0.896 loss 0.302 ; val: accuracy 0.881 loss 0.394
28/100 - train: accuracy 0.898 loss 0.297 ; val: accuracy 0.865 loss 0.464
29/100 - train: accuracy 0.901 loss 0.295 ; val: accuracy 0.864 loss 0.464
30/100 - train: accuracy 0.900 loss 0.293 ; val: accuracy 0.849 loss 0.533
31/100 - train: accuracy 0.902 loss 0.287 ; val: accuracy 0.869 loss 0.434
32/100 - train: accuracy 0.904 loss 0.282 ; val: accuracy 0.881 loss 0.418
33/100 - train: accuracy 0.902 loss 0.282 ; val: accuracy 0.887 loss 0.363
34/100 - train: accuracy 0.906 loss 0.272 ; val: accuracy 0.852 loss 0.520
35/100 - train: accuracy 0.906 loss 0.274 ; val: accuracy 0.870 loss 0.425
36/100 - train: accuracy 0.907 loss 0.269 ; val: accuracy 0.881 loss 0.393
37/100 - train: accuracy 0.910 loss 0.264 ; val: accuracy 0.880 loss 0.406
38/100 - train: accuracy 0.910 loss 0.262 ; val: accuracy 0.886 loss 0.368
39/100 - train: accuracy 0.912 loss 0.257 ; val: accuracy 0.886 loss 0.380
40/100 - train: accuracy 0.914 loss 0.252 ; val: accuracy 0.885 loss 0.415
41/100 - train: accuracy 0.914 loss 0.250 ; val: accuracy 0.885 loss 0.381
42/100 - train: accuracy 0.915 loss 0.246 ; val: accuracy 0.880 loss 0.424
43/100 - train: accuracy 0.916 loss 0.247 ; val: accuracy 0.881 loss 0.396
44/100 - train: accuracy 0.917 loss 0.243 ; val: accuracy 0.895 loss 0.387
45/100 - train: accuracy 0.918 loss 0.241 ; val: accuracy 0.883 loss 0.403
46/100 - train: accuracy 0.920 loss 0.234 ; val: accuracy 0.881 loss 0.399
47/100 - train: accuracy 0.922 loss 0.228 ; val: accuracy 0.903 loss 0.326
48/100 - train: accuracy 0.924 loss 0.223 ; val: accuracy 0.900 loss 0.347
49/100 - train: accuracy 0.925 loss 0.220 ; val: accuracy 0.886 loss 0.397
50/100 - train: accuracy 0.928 loss 0.212 ; val: accuracy 0.889 loss 0.355
51/100 - train: accuracy 0.927 loss 0.210 ; val: accuracy 0.883 loss 0.388
52/100 - train: accuracy 0.931 loss 0.202 ; val: accuracy 0.896 loss 0.358
53/100 - train: accuracy 0.933 loss 0.197 ; val: accuracy 0.900 loss 0.327
54/100 - train: accuracy 0.933 loss 0.196 ; val: accuracy 0.890 loss 0.378
55/100 - train: accuracy 0.936 loss 0.187 ; val: accuracy 0.904 loss 0.330
56/100 - train: accuracy 0.937 loss 0.183 ; val: accuracy 0.895 loss 0.386
57/100 - train: accuracy 0.938 loss 0.178 ; val: accuracy 0.893 loss 0.372
58/100 - train: accuracy 0.942 loss 0.170 ; val: accuracy 0.886 loss 0.414
59/100 - train: accuracy 0.943 loss 0.166 ; val: accuracy 0.907 loss 0.309
60/100 - train: accuracy 0.943 loss 0.163 ; val: accuracy 0.908 loss 0.315
61/100 - train: accuracy 0.946 loss 0.153 ; val: accuracy 0.901 loss 0.338
62/100 - train: accuracy 0.950 loss 0.144 ; val: accuracy 0.912 loss 0.325
63/100 - train: accuracy 0.952 loss 0.141 ; val: accuracy 0.914 loss 0.309
64/100 - train: accuracy 0.953 loss 0.136 ; val: accuracy 0.901 loss 0.345
65/100 - train: accuracy 0.955 loss 0.129 ; val: accuracy 0.907 loss 0.347
66/100 - train: accuracy 0.958 loss 0.121 ; val: accuracy 0.897 loss 0.381
67/100 - train: accuracy 0.960 loss 0.117 ; val: accuracy 0.912 loss 0.308
68/100 - train: accuracy 0.963 loss 0.109 ; val: accuracy 0.908 loss 0.337
69/100 - train: accuracy 0.964 loss 0.105 ; val: accuracy 0.920 loss 0.299
70/100 - train: accuracy 0.966 loss 0.097 ; val: accuracy 0.916 loss 0.351
71/100 - train: accuracy 0.968 loss 0.092 ; val: accuracy 0.919 loss 0.324
72/100 - train: accuracy 0.971 loss 0.086 ; val: accuracy 0.920 loss 0.326
73/100 - train: accuracy 0.972 loss 0.080 ; val: accuracy 0.921 loss 0.328
74/100 - train: accuracy 0.974 loss 0.074 ; val: accuracy 0.923 loss 0.317
75/100 - train: accuracy 0.976 loss 0.068 ; val: accuracy 0.925 loss 0.324
76/100 - train: accuracy 0.978 loss 0.064 ; val: accuracy 0.928 loss 0.301
77/100 - train: accuracy 0.982 loss 0.056 ; val: accuracy 0.925 loss 0.313
78/100 - train: accuracy 0.982 loss 0.052 ; val: accuracy 0.921 loss 0.357
79/100 - train: accuracy 0.983 loss 0.048 ; val: accuracy 0.931 loss 0.315
80/100 - train: accuracy 0.986 loss 0.042 ; val: accuracy 0.932 loss 0.318
81/100 - train: accuracy 0.987 loss 0.037 ; val: accuracy 0.932 loss 0.323
82/100 - train: accuracy 0.989 loss 0.033 ; val: accuracy 0.929 loss 0.338
83/100 - train: accuracy 0.990 loss 0.029 ; val: accuracy 0.937 loss 0.316
84/100 - train: accuracy 0.992 loss 0.024 ; val: accuracy 0.934 loss 0.321
85/100 - train: accuracy 0.993 loss 0.020 ; val: accuracy 0.933 loss 0.358
86/100 - train: accuracy 0.993 loss 0.021 ; val: accuracy 0.937 loss 0.346
87/100 - train: accuracy 0.995 loss 0.015 ; val: accuracy 0.940 loss 0.332
88/100 - train: accuracy 0.996 loss 0.014 ; val: accuracy 0.940 loss 0.322
89/100 - train: accuracy 0.996 loss 0.011 ; val: accuracy 0.945 loss 0.321
90/100 - train: accuracy 0.997 loss 0.010 ; val: accuracy 0.941 loss 0.334
91/100 - train: accuracy 0.998 loss 0.007 ; val: accuracy 0.941 loss 0.342
92/100 - train: accuracy 0.998 loss 0.006 ; val: accuracy 0.945 loss 0.327
93/100 - train: accuracy 0.998 loss 0.005 ; val: accuracy 0.948 loss 0.341
94/100 - train: accuracy 0.998 loss 0.004 ; val: accuracy 0.946 loss 0.335
95/100 - train: accuracy 0.999 loss 0.003 ; val: accuracy 0.947 loss 0.328
96/100 - train: accuracy 0.999 loss 0.003 ; val: accuracy 0.948 loss 0.328
97/100 - train: accuracy 0.999 loss 0.002 ; val: accuracy 0.949 loss 0.334
98/100 - train: accuracy 0.999 loss 0.003 ; val: accuracy 0.949 loss 0.330
99/100 - train: accuracy 0.999 loss 0.002 ; val: accuracy 0.949 loss 0.339
100/100 - train: accuracy 0.999 loss 0.002 ; val: accuracy 0.949 loss 0.335
def plot_history_train_val(history, key):
fig = plt.figure()
ax = fig.add_subplot(111)
xs = np.arange(1, len(history['train ' + key]) + 1)
ax.plot(xs, history['train ' + key], '.-', label='train')
ax.plot(xs, history['val ' + key], '.-', label='val')
ax.set_xlabel('epoch')
ax.set_ylabel(key)
ax.legend()
ax.grid()
plt.show()
plot_history_train_val(history, 'loss')
plot_history_train_val(history, 'accuracy')