{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# WaveNet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Paper: van den Oord et al., *WaveNet: A Generative Model for Raw Audio*, [arXiv:1609.03499 [cs.SD]](https://arxiv.org/abs/1609.03499)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data: [Kaggle Web Traffic Time Series Forecasting](https://www.kaggle.com/c/web-traffic-time-series-forecasting/) competition; this notebook reads `train_1.csv` from `./data`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configuration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "from datetime import timedelta\n", "from collections import defaultdict\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Configuration: data directory, forecast horizon (`PRED_STEPS`), batch size, and a random seed so that reruns are reproducible." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data_dir = Path(\"./data\")\n", "# Forecast horizon in steps; the CSV columns are daily dates, so presumably days.\n", "PRED_STEPS = 60\n", "\n", "BATCH_SIZE = 1024\n", "\n", "# Seed NumPy and PyTorch (torch.manual_seed also seeds all CUDA devices) for reproducible runs.\n", "SEED = 42\n", "np.random.seed(SEED)\n", "torch.manual_seed(SEED);" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "device: cuda\n" ] } ], "source": [ "DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "print(\"device:\", DEVICE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data\n", "\n", "`train_1.csv` has one row per page (`Page` column) and one column per day from 2015-07-01 to 2016-12-31; missing values show up as `NaN`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "csv_file = data_dir / \"train_1.csv\"" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(csv_file)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | Page | \n", "2015-07-01 | \n", "2015-07-02 | \n", "2015-07-03 | \n", "2015-07-04 | \n", "2015-07-05 | \n", "2015-07-06 | \n", "2015-07-07 | \n", "2015-07-08 | \n", "2015-07-09 | \n", "... | \n", "2016-12-22 | \n", "2016-12-23 | \n", "2016-12-24 | \n", "2016-12-25 | \n", "2016-12-26 | \n", "2016-12-27 | \n", "2016-12-28 | \n", "2016-12-29 | \n", "2016-12-30 | \n", "2016-12-31 | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "2NE1_zh.wikipedia.org_all-access_spider | \n", "18.0 | \n", "11.0 | \n", "5.0 | \n", "13.0 | \n", "14.0 | \n", "9.0 | \n", "9.0 | \n", "22.0 | \n", "26.0 | \n", "... | \n", "32.0 | \n", "63.0 | \n", "15.0 | \n", "26.0 | \n", "14.0 | \n", "20.0 | \n", "22.0 | \n", "19.0 | \n", "18.0 | \n", "20.0 | \n", "
1 | \n", "2PM_zh.wikipedia.org_all-access_spider | \n", "11.0 | \n", "14.0 | \n", "15.0 | \n", "18.0 | \n", "11.0 | \n", "13.0 | \n", "22.0 | \n", "11.0 | \n", "10.0 | \n", "... | \n", "17.0 | \n", "42.0 | \n", "28.0 | \n", "15.0 | \n", "9.0 | \n", "30.0 | \n", "52.0 | \n", "45.0 | \n", "26.0 | \n", "20.0 | \n", "
2 | \n", "3C_zh.wikipedia.org_all-access_spider | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "4.0 | \n", "0.0 | \n", "3.0 | \n", "4.0 | \n", "... | \n", "3.0 | \n", "1.0 | \n", "1.0 | \n", "7.0 | \n", "4.0 | \n", "4.0 | \n", "6.0 | \n", "3.0 | \n", "4.0 | \n", "17.0 | \n", "
3 | \n", "4minute_zh.wikipedia.org_all-access_spider | \n", "35.0 | \n", "13.0 | \n", "10.0 | \n", "94.0 | \n", "4.0 | \n", "26.0 | \n", "14.0 | \n", "9.0 | \n", "11.0 | \n", "... | \n", "32.0 | \n", "10.0 | \n", "26.0 | \n", "27.0 | \n", "16.0 | \n", "11.0 | \n", "17.0 | \n", "19.0 | \n", "10.0 | \n", "11.0 | \n", "
4 | \n", "52_Hz_I_Love_You_zh.wikipedia.org_all-access_s... | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "NaN | \n", "... | \n", "48.0 | \n", "9.0 | \n", "25.0 | \n", "13.0 | \n", "3.0 | \n", "11.0 | \n", "27.0 | \n", "13.0 | \n", "36.0 | \n", "10.0 | \n", "
5 rows × 551 columns
\n", "