Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team.
GitHub repo: https://github.com/google/trax
The basic units in Trax are tensors, which use the NumPy interface. In Trax, NumPy operations are accelerated on GPU or TPU, and gradients of functions on tensors are computed automatically. This is done in the trax.fastmath package, which supports two backends: JAX (the default) and TensorFlow NumPy.
Installing JAX: https://github.com/google/jax#pip-installation
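As a minimal illustration of the fastmath API (a sketch along the lines of the Trax README, not part of this notebook's pipeline):

# Sketch: NumPy-style ops with automatic gradients via trax.fastmath.
from trax import fastmath

def f(x):
    return 2.0 * x * x

grad_f = fastmath.grad(f)      # derivative of f: x -> 4x
print(grad_f(jnp.array(3.0)))  # expects 12.0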
Imports
import trax
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp
from trax.supervised import training
import numpy as np
import tensorflow as tf
%matplotlib inline
import matplotlib.pyplot as plt
TensorFlow is used here only for the input pipeline, so hide the GPU from it and leave the accelerator to the JAX backend:
tf.config.set_visible_devices([], 'GPU')
print("Current trax backend:", trax.fastmath.backend_name())
print("Number of CPU or TPU devices:", trax.fastmath.device_count())
Configuration
NUM_CLASSES = 10
BATCH_SIZE = 32
NUM_STEPS = 20000
STEPS_PER_CHECKPOINT = 1000
EVAL_BATCHES = 300
WARMUP_STEPS = 500
MAX_LR = 1e-3
OUTPUT_DIR = 'output'
!rm -rf {OUTPUT_DIR}
def augment_image(img):
    # Pad to 40x40, take a random 32x32 crop, then apply a random flip
    # and photometric jitter (brightness, contrast, saturation).
    img = tf.image.resize_with_crop_or_pad(img, 40, 40)
    img = tf.image.random_crop(img, [32, 32, 3])
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_brightness(img, max_delta=0.2)
    img = tf.image.random_contrast(img, 0.8, 1.2)
    img = tf.image.random_saturation(img, 0.8, 1.2)
    return img
def Augment(generator):
    # Pipeline stage: augment every image in each incoming batch.
    for imgs, tgts in generator:
        for i in range(len(imgs)):
            imgs[i] = augment_image(imgs[i])
        yield (imgs, tgts)
def ToFloat(generator):
    # Pipeline stage: scale uint8 pixels to float32 in [0, 1].
    for img, tgt in generator:
        img = img.astype(np.float32) / 255.0
        yield (img, tgt)
train_stream = trax.data.Serial(
    trax.data.TFDS('cifar10', data_dir='data', keys=('image', 'label'), train=True),
    trax.data.Shuffle(),
    trax.data.Batch(BATCH_SIZE),
    Augment,
    ToFloat,
    trax.data.AddLossWeights()  # needed for tl.CrossEntropyLoss
)
eval_stream = trax.data.Serial(
    trax.data.TFDS('cifar10', data_dir='data', keys=('image', 'label'), train=False),
    trax.data.Batch(BATCH_SIZE),
    ToFloat,
    trax.data.AddLossWeights()
)
# Sanity check: draw one batch from the training stream and show the first image.
batch = next(train_stream())
img = batch[0][0]
plt.imshow(img);
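After AddLossWeights, each batch is a triple of (images, targets, per-example loss weights); a quick shape check (assumed shapes for CIFAR-10 with BATCH_SIZE = 32):

# Each pipeline batch is a (images, targets, weights) triple.
imgs, tgts, weights = batch
print(imgs.shape, tgts.shape, weights.shape)  # expects (32, 32, 32, 3) (32,) (32,)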
def ConvBlock(filters, kernel_size=3, strides=1, act=True, mode='train'):
    layers = [
        tl.Conv(filters, (kernel_size, kernel_size), strides=(strides, strides), padding='SAME',
                kernel_initializer=tl.initializers.KaimingNormalInitializer()),
        tl.BatchNorm(mode=mode)
    ]
    if act: layers.append(tl.Relu())
    return layers
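ConvBlock returns a plain list of layers; Trax combinators flatten nested lists, so these helpers can be dropped straight into tl.Serial. A quick shape check (a sketch, using the shapes module imported above):

# Hypothetical check: 'SAME' padding with stride 1 preserves spatial dims.
block = tl.Serial(ConvBlock(16))
x = np.zeros((1, 32, 32, 3), dtype=np.float32)
block.init(shapes.signature(x))
print(block(x).shape)  # expects (1, 32, 32, 16)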
def BasicResidual(res_filters, strides=1, mode='train'):
    return [
        ConvBlock(res_filters, strides=strides, mode=mode),
        ConvBlock(res_filters, act=False, mode=mode)
    ]
def Shortcut(prev_filters, res_filters, strides=1, mode='train'):
    layers = []
    if strides > 1:
        layers.append(tl.AvgPool((strides, strides), (strides, strides)))
    if prev_filters != res_filters:
        layers += ConvBlock(res_filters, kernel_size=1, act=False, mode=mode)
    if len(layers) == 0: layers = None  # None means an identity shortcut
    return layers
def ZerosInitializer(shape, rng):
    return jnp.zeros(shape, jnp.float32)
def ResidualBlock(prev_filters, res_filters, strides=1, mode='train'):
    shortcut = Shortcut(prev_filters, res_filters, strides, mode=mode)
    residual = [
        BasicResidual(res_filters, strides, mode=mode),
        # Scale the residual branch by a learnable scalar initialized to zero,
        # so every block starts out as the identity (SkipInit-style).
        tl.Weights(ZerosInitializer, shape=(1,)),
        tl.Multiply()
    ]
    return [
        tl.Residual(residual, shortcut=shortcut),
        tl.Relu()
    ]
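Because the multiplier starts at zero, a block whose shortcut is the identity computes ReLU(x) at initialization; the residual branch only contributes once the scalar moves away from zero. A hedged sanity check:

# Hypothetical check: with matching filters and stride 1, the block is the
# identity at init, since ReLU(x + 0 * residual(x)) = x for non-negative x.
blk = tl.Serial(ResidualBlock(16, 16, mode='eval'))
x = np.abs(np.random.randn(1, 8, 8, 16)).astype(np.float32)
blk.init(shapes.signature(x))
print(np.allclose(blk(x), x, atol=1e-5))  # expects True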
def ResidualBody(filters, repetitions, strides, mode='train'):
    layers = []
    res_filters = filters
    for rep, stride in zip(repetitions, strides):
        for _ in range(rep):
            layers.append(ResidualBlock(filters, res_filters, stride, mode=mode))
            # Only the first block of a stage changes resolution / filter count.
            filters = res_filters
            stride = 1
        res_filters *= 2  # double the filters in each successive stage
    return layers
def Stem(filter_list, stride=1, mode='train'):
    layers = []
    for filters in filter_list:
        layers.append(ConvBlock(filters, strides=stride, mode=mode))
        stride = 1  # only the first conv uses the given stride
    return layers
def GlobalAvgPool():
    def pool(x):
        pool_size = tuple(x.shape[1:3])  # NHWC
        return trax.fastmath.avg_pool(x, pool_size=pool_size, strides=None, padding='VALID')
    return tl.Fn("GlobalAvgPool", pool)
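GlobalAvgPool is a stateless custom layer built with tl.Fn, so it can be called directly on a test array (a quick sketch):

# Hypothetical check: each feature map collapses to its average value.
gap = GlobalAvgPool()
x = np.ones((2, 8, 8, 64), dtype=np.float32)
print(gap(x).shape)  # expects (2, 1, 1, 64); tl.Flatten then yields (2, 64)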
def Head(classes, p_drop=0., mode='train'):
    layers = [
        GlobalAvgPool(),
        tl.Flatten()
    ]
    if p_drop > 0: layers.append(tl.Dropout(p_drop, mode=mode))
    layers += [
        tl.Dense(classes),
        tl.LogSoftmax()
    ]
    return layers
def ResNet(repetitions, classes, strides=None, p_drop=0., mode='train'):
    if not strides: strides = [2] * (len(repetitions) + 1)
    return tl.Serial(
        Stem([32, 32, 64], stride=strides[0], mode=mode),
        ResidualBody(64, repetitions, strides[1:], mode=mode),
        Head(classes, p_drop=p_drop, mode=mode)
    )
def MyModel(mode='train'):
    return ResNet([2, 2, 2, 2], NUM_CLASSES, strides=[1, 1, 2, 2, 2], p_drop=0.3, mode=mode)
model = MyModel()
eval_model = MyModel(mode='eval')
model
train_task = training.TrainTask(
    labeled_data=train_stream(),
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(),
    lr_schedule=trax.supervised.lr_schedules.warmup_and_rsqrt_decay(WARMUP_STEPS, MAX_LR),
    n_steps_per_checkpoint=STEPS_PER_CHECKPOINT
)
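warmup_and_rsqrt_decay ramps the learning rate linearly to MAX_LR over WARMUP_STEPS and then decays it roughly as 1/sqrt(step). Since the schedule is a plain function of the step number, its shape can be plotted (a sketch):

# Sketch: visualize the warmup + reciprocal-square-root-decay schedule.
schedule = trax.supervised.lr_schedules.warmup_and_rsqrt_decay(WARMUP_STEPS, MAX_LR)
steps = np.arange(1, NUM_STEPS, 100)
plt.plot(steps, [schedule(s) for s in steps])
plt.xlabel('step'); plt.ylabel('learning rate');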
eval_task = training.EvalTask(
    labeled_data=eval_stream(),
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=EVAL_BATCHES
)
training_loop = training.Loop(
    model,
    train_task,
    eval_model=eval_model,
    eval_tasks=[eval_task],
    output_dir=OUTPUT_DIR
)
training_loop.run(NUM_STEPS)
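Once the loop has run, eval_model shares the trained weights with model, so it can be used for inference directly (a sketch; assumes the cells above completed):

# Sketch: predictions on one held-out batch after training.
imgs, tgts, _ = next(eval_stream())
log_probs = eval_model(imgs)           # log-probabilities, shape (batch, NUM_CLASSES)
preds = np.argmax(log_probs, axis=-1)
print("one-batch accuracy:", np.mean(preds == tgts))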