The data comes from the Kaggle Web Traffic Time Series Forecasting competition: https://www.kaggle.com/c/web-traffic-time-series-forecasting/
Imports
from pathlib import Path
from datetime import timedelta
from collections import defaultdict
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
Configuration
data_dir = Path("./data")
PRED_STEPS = 60
BATCH_SIZE = 1024
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
csv_file = data_dir / "train_1.csv"
df = pd.read_csv(csv_file)
df.head()
| | Page | 2015-07-01 | 2015-07-02 | 2015-07-03 | 2015-07-04 | 2015-07-05 | 2015-07-06 | 2015-07-07 | 2015-07-08 | 2015-07-09 | ... | 2016-12-22 | 2016-12-23 | 2016-12-24 | 2016-12-25 | 2016-12-26 | 2016-12-27 | 2016-12-28 | 2016-12-29 | 2016-12-30 | 2016-12-31 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2NE1_zh.wikipedia.org_all-access_spider | 18.0 | 11.0 | 5.0 | 13.0 | 14.0 | 9.0 | 9.0 | 22.0 | 26.0 | ... | 32.0 | 63.0 | 15.0 | 26.0 | 14.0 | 20.0 | 22.0 | 19.0 | 18.0 | 20.0 |
| 1 | 2PM_zh.wikipedia.org_all-access_spider | 11.0 | 14.0 | 15.0 | 18.0 | 11.0 | 13.0 | 22.0 | 11.0 | 10.0 | ... | 17.0 | 42.0 | 28.0 | 15.0 | 9.0 | 30.0 | 52.0 | 45.0 | 26.0 | 20.0 |
| 2 | 3C_zh.wikipedia.org_all-access_spider | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 | 4.0 | 0.0 | 3.0 | 4.0 | ... | 3.0 | 1.0 | 1.0 | 7.0 | 4.0 | 4.0 | 6.0 | 3.0 | 4.0 | 17.0 |
| 3 | 4minute_zh.wikipedia.org_all-access_spider | 35.0 | 13.0 | 10.0 | 94.0 | 4.0 | 26.0 | 14.0 | 9.0 | 11.0 | ... | 32.0 | 10.0 | 26.0 | 27.0 | 16.0 | 11.0 | 17.0 | 19.0 | 10.0 | 11.0 |
| 4 | 52_Hz_I_Love_You_zh.wikipedia.org_all-access_s... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | 48.0 | 9.0 | 25.0 | 13.0 | 3.0 | 11.0 | 27.0 | 13.0 | 36.0 | 10.0 |
5 rows × 551 columns
date_to_index = pd.Series(index=pd.Index([pd.to_datetime(c) for c in df.columns[1:]]),
                          data=list(range(df.shape[-1] - 1)))
series_array = df[df.columns[1:]].values
def get_enc_pred_dates(start_day, end_day, pred_length):
pred_start = end_day - pred_length + timedelta(1)
pred_end = end_day
enc_start = start_day
enc_end = end_day - pred_length
return enc_start, enc_end, pred_start, pred_end
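The helper carves a single date range into back-to-back encoder and prediction windows. A quick sanity check on made-up dates (illustrative only, not from the dataset):
s, e = pd.to_datetime('2016-01-01'), pd.to_datetime('2016-03-01')
enc_s, enc_e, pred_s, pred_e = get_enc_pred_dates(s, e, timedelta(10))
print(enc_s.date(), '--', enc_e.date())    # 2016-01-01 -- 2016-02-20
print(pred_s.date(), '--', pred_e.date())  # 2016-02-21 -- 2016-03-01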
def get_time_series(start_date, end_date, series_array, date_to_index):
inds = date_to_index[start_date:end_date]
return series_array[:,inds]
def transform_time_series(series_array, series_mean=None):
    series_array = np.log1p(np.nan_to_num(series_array))  # replace NaNs with 0, then apply log1p
if series_mean is None:
series_mean = series_array.mean(axis=1, keepdims=True)
series_array = series_array - series_mean
series_array = series_array.reshape((series_array.shape[0], 1, series_array.shape[1]))
return series_array, series_mean
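The transform is log1p followed by per-series mean centering, so the model works in log space. Predictions can be mapped back to raw page views by inverting both steps; a minimal sketch, assuming the (N, 1, T) layout produced above (this helper is hypothetical and not used elsewhere in the notebook):
def inverse_transform(series_array, series_mean):
    # Undo the per-series mean centering, then invert log1p back to raw counts.
    # series_mean has shape (N, 1); add an axis so it broadcasts over (N, 1, T).
    return np.expm1(series_array + series_mean[:, :, None])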
def prepare_time_series(enc_start, enc_end, pred_start, pred_end, series_array, date_to_index):
enc_series = get_time_series(enc_start, enc_end, series_array, date_to_index)
enc_series, enc_mean = transform_time_series(enc_series)
pred_series = get_time_series(pred_start, pred_end, series_array, date_to_index)
pred_series, _ = transform_time_series(pred_series, enc_mean)
    # append the (one-step-shifted) target history to the encoder input so the model can be trained with teacher forcing
enc_series = np.concatenate([enc_series, pred_series[:, :, :-1]], axis=-1)
enc_series = torch.as_tensor(enc_series, dtype=torch.float32)
pred_series = torch.as_tensor(pred_series, dtype=torch.float32)
return enc_series, pred_series
data_start_date = df.columns[1]
data_end_date = df.columns[-1]
print(f'Data ranges from {data_start_date} to {data_end_date}')
Data ranges from 2015-07-01 to 2016-12-31
pred_length = timedelta(PRED_STEPS)
first_day = pd.to_datetime(data_start_date)
last_day = pd.to_datetime(data_end_date)
train_enc_start, train_enc_end, train_pred_start, train_pred_end = get_enc_pred_dates(
first_day, last_day - pred_length, pred_length)
val_enc_start, val_enc_end, val_pred_start, val_pred_end = get_enc_pred_dates(
first_day + pred_length, last_day, pred_length)
print('Train encoding:', train_enc_start.date(), '--', train_enc_end.date())
print('Train prediction:', train_pred_start.date(), '--', train_pred_end.date(), '\n')
print('Val encoding:', val_enc_start.date(), '--', val_enc_end.date())
print('Val prediction:', val_pred_start.date(), '--', val_pred_end.date())
print('\nEncoding interval:', (train_pred_start - first_day).days)
print('Prediction interval:', pred_length.days)
Train encoding: 2015-07-01 -- 2016-09-02
Train prediction: 2016-09-03 -- 2016-11-01

Val encoding: 2015-08-30 -- 2016-11-01
Val prediction: 2016-11-02 -- 2016-12-31

Encoding interval: 430
Prediction interval: 60
train_enc_series, train_pred_series = prepare_time_series(train_enc_start, train_enc_end,
train_pred_start, train_pred_end,
series_array, date_to_index)
val_enc_series, val_pred_series = prepare_time_series(val_enc_start, val_enc_end,
val_pred_start, val_pred_end,
series_array, date_to_index)
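A quick shape check, following from the construction above: the encoder input holds the 430-day encoder window plus PRED_STEPS - 1 teacher-forcing steps, and the target holds exactly PRED_STEPS steps:
assert train_pred_series.shape[-1] == PRED_STEPS
assert train_enc_series.shape[-1] == (train_enc_end - train_enc_start).days + 1 + PRED_STEPS - 1
assert train_enc_series.shape[0] == df.shape[0]  # one series per Wikipedia page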
train_dset = torch.utils.data.TensorDataset(train_enc_series, train_pred_series)
val_dset = torch.utils.data.TensorDataset(val_enc_series, val_pred_series)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False)
Figure: visualization of a stack of dilated causal convolutional layers.
Model architecture.
def causal_conv1d(in_channels, out_channels, kernel_size=2, dilation=1, pad=False):
conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=0, dilation=dilation)
if pad:
padding = (kernel_size - 1) * dilation
return nn.Sequential(nn.ConstantPad1d((padding, 0), 0.), conv)
else:
return conv
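With pad=True, the left-only (causal) padding keeps the output as long as the input; with pad=False, each layer trims (kernel_size - 1) * dilation steps from the front. A quick check with arbitrary sizes:
conv = causal_conv1d(1, 1, kernel_size=2, dilation=4, pad=True)
print(conv(torch.randn(2, 1, 16)).shape)   # torch.Size([2, 1, 16])
conv = causal_conv1d(1, 1, kernel_size=2, dilation=4, pad=False)
print(conv(torch.randn(2, 1, 16)).shape)   # torch.Size([2, 1, 12])
The residual stack below uses the unpadded variant, which is why ResidualBlock trims its input to match the shortened output.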
class GatedConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=2, dilation=1):
super().__init__()
self.conv = causal_conv1d(in_channels, 2 * out_channels, kernel_size, dilation)
def forward(self, x):
x = self.conv(x)
a, b = torch.chunk(x, 2, dim=1)
return torch.tanh(a) * torch.sigmoid(b)
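This is WaveNet's gated activation unit, z = tanh(a) * sigmoid(b), with both halves coming from a single convolution with doubled output channels. A shape check with arbitrary channel sizes:
g = GatedConv1d(in_channels=1, out_channels=4, kernel_size=2, dilation=2)
print(g(torch.randn(8, 1, 32)).shape)  # torch.Size([8, 4, 30]); 2 steps trimmed by the unpadded conv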
class ResidualBlock(nn.Module):
def __init__(self, in_channels, gate_channels, out_channels, skip_size=0, kernel_size=2, dilation=1):
super().__init__()
self.skip_size = skip_size
self.conv_gated = GatedConv1d(in_channels, gate_channels, kernel_size, dilation)
self.conv_res = nn.Conv1d(gate_channels, out_channels, 1)
self.conv_skip = nn.Conv1d(gate_channels, out_channels, 1)
    def forward(self, x):
        gated = self.conv_gated(x)
        residual = self.conv_res(gated)
        # The unpadded causal convolution shortens the sequence, so trim x to match.
        out = x[:, :, -residual.size(-1):] + residual
        # Keep only the last skip_size steps for the skip path. Note the edge case:
        # skip_size=0 slices with [-0:], which keeps the entire sequence.
        skip = self.conv_skip(gated[:, :, -self.skip_size:])
        return out, skip
class ResidualStack(nn.Module):
def __init__(self, channels, gate_channels, num_blocks, num_layers, skip_size=0, kernel_size=2):
super().__init__()
self.layers = nn.ModuleList()
for b in range(num_blocks):
for i in range(num_layers):
rate = 2**i
self.layers.append(ResidualBlock(channels, gate_channels, channels, skip_size,
kernel_size, rate))
def forward(self, x):
skips = 0
for layer in self.layers:
x, skip = layer(x)
skips += skip
return skips
def head(in_channels, hidden_channels, out_channels, p_drop=0.):
return nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv1d(in_channels, hidden_channels, 1),
nn.ReLU(inplace=True),
nn.Dropout(p_drop),
nn.Conv1d(hidden_channels, out_channels, 1)
)
def stem(in_channels, out_channels):
return nn.Sequential(
nn.Conv1d(in_channels, out_channels, 1),
nn.Tanh()
)
def wavenet(in_channels, residual_channels, gate_channels, hidden_channels, out_channels, skip_size=0,
num_blocks=2, num_layers=10, kernel_size=2, p_drop=0.):
return nn.Sequential(
stem(in_channels, residual_channels),
ResidualStack(residual_channels, gate_channels, num_blocks, num_layers, skip_size, kernel_size),
head(residual_channels, hidden_channels, out_channels, p_drop=p_drop)
)
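With kernel size 2 and dilations doubling within each block, the receptive field of the full stack is num_blocks * (2**num_layers - 1) + 1 steps. A small helper, introduced here only for illustration:
def receptive_field(num_blocks, num_layers, kernel_size=2):
    # Each layer with dilation d looks back (kernel_size - 1) * d extra steps;
    # within one block the dilations 1, 2, ..., 2**(num_layers - 1) sum to 2**num_layers - 1.
    return num_blocks * (kernel_size - 1) * (2**num_layers - 1) + 1

print(receptive_field(num_blocks=2, num_layers=7))  # 255 -- the configuration instantiated below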
model = wavenet(1, 16, 32, 128, 1, skip_size=PRED_STEPS, num_blocks=2, num_layers=7, p_drop=0.2)
model.to(DEVICE);
def show_or_save(fig, filename=None):
if filename:
fig.savefig(filename, bbox_inches='tight', pad_inches=0.05)
plt.close(fig)
else:
plt.show()
class History:
def __init__(self):
self.values = defaultdict(list)
def append(self, key, value):
self.values[key].append(value)
def reset(self):
for k in self.values.keys():
self.values[k] = []
def _plot(self, key, line_type='-', label=None):
if not label: label=key
xs = np.arange(1, len(self.values[key])+1)
self.ax.plot(xs, self.values[key], line_type, label=label)
def plot_train_val(self, key, x_is_batch=False, ylog=False, filename=None):
fig = plt.figure()
self.ax = fig.add_subplot(111)
self._plot('train ' + key, '.-', 'train')
self._plot('val ' + key, '.-', 'val')
self.ax.legend()
if ylog: self.ax.set_yscale('log')
self.ax.set_xlabel('batch' if x_is_batch else 'epoch')
self.ax.set_ylabel(key)
show_or_save(fig, filename)
class Learner:
def __init__(self, model, loss, optimizer, train_loader, val_loader, device,
epoch_scheduler=None, batch_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.epoch_scheduler = epoch_scheduler
self.batch_scheduler = batch_scheduler
self.history = History()
def iterate(self, loader, backward_pass=False):
total_loss = 0.0
num_samples = 0
for X, Y in loader:
X, Y = X.to(self.device), Y.to(self.device)
Y_pred = self.model(X)
batch_size = X.size(0)
batch_loss = self.loss(Y_pred, Y)
if backward_pass:
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
if self.batch_scheduler is not None:
self.batch_scheduler.step()
total_loss += batch_size * batch_loss.item()
num_samples += batch_size
avg_loss = total_loss / num_samples
return avg_loss
def train(self):
self.model.train()
train_loss = self.iterate(self.train_loader, backward_pass=True)
print(f'train: loss {train_loss:.3f}', end=' ')
self.history.append('train loss', train_loss)
def validate(self):
self.model.eval()
with torch.no_grad():
val_loss = self.iterate(self.val_loader)
print(f'val: loss {val_loss:.3f}')
self.history.append('val loss', val_loss)
def fit(self, epochs):
for i in range(epochs):
print(f'{i+1}/{epochs}', end=' ')
self.train()
self.validate()
if self.epoch_scheduler is not None:
self.epoch_scheduler.step()
loss = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
learner = Learner(model, loss, optimizer, train_loader, val_loader, DEVICE)
EPOCHS = 20
learner.batch_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2,
steps_per_epoch=len(train_loader),
epochs=EPOCHS)
learner.fit(EPOCHS)
1/20 train: loss 0.404 val: loss 0.295
2/20 train: loss 0.297 val: loss 0.291
3/20 train: loss 0.290 val: loss 0.285
4/20 train: loss 0.285 val: loss 0.286
5/20 train: loss 0.283 val: loss 0.273
6/20 train: loss 0.277 val: loss 0.269
7/20 train: loss 0.273 val: loss 0.283
8/20 train: loss 0.270 val: loss 0.269
9/20 train: loss 0.270 val: loss 0.266
10/20 train: loss 0.266 val: loss 0.265
11/20 train: loss 0.266 val: loss 0.265
12/20 train: loss 0.265 val: loss 0.264
13/20 train: loss 0.264 val: loss 0.263
14/20 train: loss 0.264 val: loss 0.263
15/20 train: loss 0.263 val: loss 0.263
16/20 train: loss 0.263 val: loss 0.263
17/20 train: loss 0.263 val: loss 0.262
18/20 train: loss 0.263 val: loss 0.262
19/20 train: loss 0.263 val: loss 0.262
20/20 train: loss 0.262 val: loss 0.262
learner.history.plot_train_val('loss')
@torch.no_grad()
def predict_sequence(model, x, pred_steps, device):
    # Autoregressive decoding for a single series (x is expected to have batch size 1).
    model.eval()
    x = x.to(device)
    pred_sequence = np.zeros(pred_steps)
    for i in range(pred_steps):
        last_pred = model(x)[:, :, -1:]  # one-step-ahead prediction at the newest position
        pred_sequence[i] = last_pred.item()
        x = torch.cat((x, last_pred), dim=-1)  # append the prediction to the input sequence
    return pred_sequence
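Each iteration re-runs the full network on the growing input, so a forecast costs pred_steps complete forward passes. Note also that the encoder input passed in must exclude the teacher-forcing tail appended by prepare_time_series; the slicing below (equivalent to the call at the end of the notebook) strips it:
enc_only = val_enc_series[:, :, :-(PRED_STEPS - 1)]  # drop the PRED_STEPS - 1 teacher-forcing steps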
def predict_and_plot(enc_data, target_data, sample_ind, enc_tail_len=50):
enc_series = enc_data[sample_ind:sample_ind + 1]
pred_series = predict_sequence(model, enc_series, PRED_STEPS, DEVICE)
enc_series_tail = enc_series[0, 0, -enc_tail_len:].numpy()
target_series = target_data[sample_ind, 0].numpy()
pred_series = np.concatenate([enc_series_tail[-1:], pred_series])
target_series = np.concatenate([enc_series_tail[-1:], target_series])
enc_len = len(enc_series_tail)
fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)
ax.plot(range(1, enc_len + 1), enc_series_tail)
ax.plot(range(enc_len, enc_len + PRED_STEPS + 1), target_series, color='orange')
ax.plot(range(enc_len, enc_len + PRED_STEPS + 1), pred_series, color='teal', linestyle='--')
ax.set_title(f'Encoder Series Tail of Length {enc_tail_len}, Target Series, and Predictions')
ax.legend(['Encoding Series', 'Target Series', 'Predictions'])
plt.show()
predict_and_plot(val_enc_series[:, :, :-(PRED_STEPS - 1)], val_pred_series, sample_ind=16555, enc_tail_len=100)