marimo.py
import%20marimo%0A%0A__generated_with%20%3D%20%220.1.78%22%0Aapp%20%3D%20marimo.App()%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20%20%20%20%20%23%20Marimo%3A%20reactive%20notebooks%20for%20Python%0A%0A%20%20%20%20%20%20%20%20Website%3A%20%5Bmarimo.io%5D(https%3A%2F%2Fmarimo.io%2F)%0A%0A%20%20%20%20%20%20%20%20Documentation%3A%20%5Bdocs.marimo.io%5D(https%3A%2F%2Fdocs.marimo.io%2F)%0A%0A%20%20%20%20%20%20%20%20Installation%3A%20%60pip%20install%20marimo%60%0A%0A%20%20%20%20%20%20%20%20Create%20or%20edit%20a%20notebook%3A%20%60marimo%20edit%20your_notebook.py%60%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%23%23%20Configuration%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20from%20collections%20import%20defaultdict%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%20%20%20%20import%20torch.optim%20as%20optim%0A%20%20%20%20import%20torch.nn.functional%20as%20F%0A%20%20%20%20from%20torchvision%20import%20datasets%2C%20transforms%0A%0A%20%20%20%20from%20ignite.engine%20import%20Events%0A%20%20%20%20from%20ignite.engine%20import%20create_supervised_trainer%0A%20%20%20%20from%20ignite.engine%20import%20create_supervised_evaluator%0A%20%20%20%20import%20ignite.metrics%0A%20%20%20%20import%20ignite.contrib.handlers%0A%0A%20%20%20%20mo.md(%22Imports%22)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20Events%2C%0A%20%20%20%20%20%20%20%20F%2C%0A%20%20%20%20%20%20%20%20create_supervised_evaluator%2C%0A%20%20%20%20%20%20%20%20create_supervised_trainer%2C%0A%20%20%20%20%20%20%20%20datasets%2C%0A%20%20%20%20%20%20%20%20defaultdict%2C%0A%20%20%20%20%20%20%20%20ignite%2C%0A%20%20%20%20%20%20%20%20nn%2C%0A%20%20%20%20%20%20%20%20np%2C%0A%20%20%20%20%20%20%20%20optim%2C%0A%20%20%20%20%20%20%20%20plt%2C%0A%20%20%20%20%20%20%20%20torch%2C%0A%20%20%20%20%20%20%20%20transforms%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(mo)%3A%0A%20%20%20%20DATA_DIR%3D'.%2Fdata'%0A%0A%20%20%20%20NUM_CLASSES%20%3D%2010%0A%20%20%20%20NUM_WORKERS%20%3D%208%0A%20%20%20%20BATCH_SIZE%20%3D%2032%0A%20%20%20%20EPOCHS%20%3D%2020%0A%20%20%20%20LEARNING_RATE%20%3D%201e-2%0A%20%20%20%20WEIGHT_DECAY%20%3D%201e-2%0A%0A%20%20%20%20mo.md(%22Configuration%22)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20BATCH_SIZE%2C%0A%20%20%20%20%20%20%20%20DATA_DIR%2C%0A%20%20%20%20%20%20%20%20EPOCHS%2C%0A%20%20%20%20%20%20%20%20LEARNING_RATE%2C%0A%20%20%20%20%20%20%20%20NUM_CLASSES%2C%0A%20%20%20%20%20%20%20%20NUM_WORKERS%2C%0A%20%20%20%20%20%20%20%20WEIGHT_DECAY%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20__(mo)%3A%0A%20%20%20%20download_data%20%3D%20mo.ui.checkbox(False%2C%20label%3D%22Download%20data%22)%0A%20%20%20%20do_training%20%3D%20mo.ui.checkbox(False%2C%20label%3D%22Train%20model%22)%0A%20%20%20%20mo.vstack(%5Bdownload_data%2C%20do_training%5D)%0A%20%20%20%20return%20do_training%2C%20download_data%0A%0A%0A%40app.cell%0Adef%20_(torch)%3A%0A%20%20%20%20if%20torch.cuda.is_available()%3A%0A%20%20%20%20%20%20%20%20DEVICE%20%3D%20torch.device(%22cuda%22)%0A%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20DEVICE%20%3D%20torch.device(%22cpu%22)%0A%20%20%20%20print(%22Device%3A%22%2C%20DEVICE)%0A%20%20%20%20return%20DEVICE%2C%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%23%23%20Data%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(transforms)%3A%0A%20%20%20%20train_transform%20%3D%20transforms.Compose(%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.RandomHorizontalFlip()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.RandomCrop(32%2C%20padding%3D4)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.ColorJitter(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20brightness%3D0.2%2C%20contrast%3D0.2%2C%20saturation%3D0.2%0A%20%20%20%20%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.ToTensor()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20transforms.RandomErasing(p%3D0.1)%2C%0A%20%20%20%20%5D)%0A%20%20%20%20return%20train_transform%2C%0A%0A%0A%40app.cell%0Adef%20_(DATA_DIR%2C%20datasets%2C%20download_data%2C%20train_transform%2C%20transforms)%3A%0A%20%20%20%20train_dset%20%3D%20datasets.CIFAR10(root%3DDATA_DIR%2C%20train%3DTrue%2C%20%0A%20%20%20%20%20%20%20%20download%3Ddownload_data.value%2C%20transform%3Dtrain_transform)%0A%0A%20%20%20%20test_dset%20%3D%20datasets.CIFAR10(root%3DDATA_DIR%2C%20train%3DFalse%2C%20%0A%20%20%20%20%20%20%20%20download%3Ddownload_data.value%2C%20transform%3Dtransforms.ToTensor())%0A%20%20%20%20return%20test_dset%2C%20train_dset%0A%0A%0A%40app.cell%0Adef%20_(BATCH_SIZE%2C%20NUM_WORKERS%2C%20test_dset%2C%20torch%2C%20train_dset)%3A%0A%20%20%20%20train_loader%20%3D%20torch.utils.data.DataLoader(%0A%20%20%20%20%20%20%20%20train_dset%2C%0A%20%20%20%20%20%20%20%20batch_size%3DBATCH_SIZE%2C%0A%20%20%20%20%20%20%20%20shuffle%3DTrue%2C%0A%20%20%20%20%20%20%20%20num_workers%3DNUM_WORKERS%2C%0A%20%20%20%20%20%20%20%20pin_memory%3DTrue%0A%20%20%20%20)%0A%0A%20%20%20%20test_loader%20%3D%20torch.utils.data.DataLoader(%0A%20%20%20%20%20%20%20%20test_dset%2C%0A%20%20%20%20%20%20%20%20batch_size%3DBATCH_SIZE%2C%0A%20%20%20%20%20%20%20%20shuffle%3DFalse%2C%0A%20%20%20%20%20%20%20%20num_workers%3DNUM_WORKERS%2C%0A%20%20%20%20%20%20%20%20pin_memory%3DTrue%0A%20%20%20%20)%0A%20%20%20%20return%20test_loader%2C%20train_loader%0A%0A%0A%40app.cell%0Adef%20_(np%2C%20plt)%3A%0A%20%20%20%20def%20dataset_show_image(dset%2C%20idx)%3A%0A%20%20%20%20%20%20%20%20X%2C%20Y%20%3D%20dset%5Bidx%5D%0A%20%20%20%20%20%20%20%20title%20%3D%20%22Ground%20truth%3A%20%7B%7D%22.format(dset.classes%5BY%5D)%0A%20%20%20%20%20%20%20%20fig%20%3D%20plt.figure()%0A%20%20%20%20%20%20%20%20ax%20%3D%20fig.add_subplot(111)%0A%20%20%20%20%20%20%20%20ax.set_axis_off()%0A%20%20%20%20%20%20%20%20ax.imshow(np.moveaxis(X.numpy()%2C%200%2C%20-1))%0A%20%20%20%20%20%20%20%20ax.set_title(title)%0A%20%20%20%20%20%20%20%20return%20ax%0A%20%20%20%20return%20dataset_show_image%2C%0A%0A%0A%40app.cell%0Adef%20_(dataset_show_image%2C%20test_dset)%3A%0A%20%20%20%20dataset_show_image(test_dset%2C%201)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%23%23%20Model%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(nn)%3A%0A%20%20%20%20class%20ConvBlock(nn.Sequential)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20in_channels%2C%20out_channels%2C%20kernel_size%3D3%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20stride%3D1%2C%20act%3DTrue)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20padding%20%3D%20(kernel_size%20-%201)%20%2F%2F%202%0A%20%20%20%20%20%20%20%20%20%20%20%20layers%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.Conv2d(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20in_channels%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20out_channels%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20kernel_size%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20stride%3Dstride%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20padding%3Dpadding%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20bias%3DFalse%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.BatchNorm2d(out_channels)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20act%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20layers.append(nn.ReLU(inplace%3DTrue))%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__(*layers)%0A%20%20%20%20return%20ConvBlock%2C%0A%0A%0A%40app.cell%0Adef%20_(ConvBlock%2C%20nn%2C%20torch)%3A%0A%20%20%20%20class%20ResidualBlock(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20channels)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.residual%20%3D%20nn.Sequential(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ConvBlock(channels%2C%20channels%2C%203)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ConvBlock(channels%2C%20channels%2C%203%2C%20act%3DFalse)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.act%20%3D%20nn.ReLU(inplace%3DTrue)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.%CE%B3%20%3D%20nn.Parameter(torch.zeros(1))%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20x%20%2B%20self.%CE%B3%20*%20self.residual(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20out%20%3D%20self.act(out)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20out%0A%20%20%20%20return%20ResidualBlock%2C%0A%0A%0A%40app.cell%0Adef%20_(ConvBlock%2C%20nn)%3A%0A%20%20%20%20class%20DownBlock(nn.Sequential)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20in_channels%2C%20out_channels)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ConvBlock(in_channels%2C%20out_channels%2C%203)%2C%20nn.MaxPool2d(2)%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20DownBlock%2C%0A%0A%0A%40app.cell%0Adef%20_(DownBlock%2C%20ResidualBlock%2C%20nn)%3A%0A%20%20%20%20class%20ResidualLayer(nn.Sequential)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20in_channels%2C%20out_channels)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20DownBlock(in_channels%2C%20out_channels)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ResidualBlock(out_channels)%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20ResidualLayer%2C%0A%0A%0A%40app.cell%0Adef%20_(nn%2C%20torch)%3A%0A%20%20%20%20class%20TemperatureScaler(nn.Module)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20scaling_factor%3D0.1)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.scaler%20%3D%20nn.Parameter(torch.tensor(scaling_factor))%0A%0A%20%20%20%20%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20x%20*%20self.scaler%0A%20%20%20%20return%20TemperatureScaler%2C%0A%0A%0A%40app.cell%0Adef%20_(TemperatureScaler%2C%20nn)%3A%0A%20%20%20%20class%20Head(nn.Sequential)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20in_channels%2C%20classes)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.AdaptiveMaxPool2d(1)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.Flatten()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.Linear(in_channels%2C%20classes)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20TemperatureScaler()%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20Head%2C%0A%0A%0A%40app.cell%0Adef%20_(ConvBlock%2C%20DownBlock%2C%20Head%2C%20ResidualLayer%2C%20nn)%3A%0A%20%20%20%20class%20Net(nn.Sequential)%3A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20classes%2C%20hidden_channels%2C%20in_channels%3D3)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20channels%20%3D%20%5Bhidden_channels%20*%202**num%20for%20num%20in%20range(4)%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20super().__init__(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ConvBlock(in_channels%2C%20hidden_channels%2C%203)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ResidualLayer(channels%5B0%5D%2C%20channels%5B1%5D)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20DownBlock(channels%5B1%5D%2C%20channels%5B2%5D)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20ResidualLayer(channels%5B2%5D%2C%20channels%5B3%5D)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20Head(channels%5B3%5D%2C%20classes)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20return%20Net%2C%0A%0A%0A%40app.cell%0Adef%20_(nn)%3A%0A%20%20%20%20def%20init_linear(m)%3A%0A%20%20%20%20%20%20%20%20if%20isinstance(m%2C%20(nn.Conv2d%2C%20nn.Linear))%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20nn.init.kaiming_normal_(m.weight)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20m.bias%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nn.init.zeros_(m.bias)%0A%20%20%20%20return%20init_linear%2C%0A%0A%0A%40app.cell%0Adef%20_(DEVICE%2C%20NUM_CLASSES%2C%20Net)%3A%0A%20%20%20%20model%20%3D%20Net(NUM_CLASSES%2C%20hidden_channels%3D64).to(DEVICE)%0A%20%20%20%20return%20model%2C%0A%0A%0A%40app.cell%0Adef%20_(init_linear%2C%20model)%3A%0A%20%20%20%20model.apply(init_linear)%0A%20%20%20%20None%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model)%3A%0A%20%20%20%20print(%22Number%20of%20parameters%3A%20%7B%3A%2C%7D%22.format(sum(%0A%20%20%20%20%20%20%20%20p.numel()%20for%20p%20in%20model.parameters())))%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%23%23%20Training%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(nn)%3A%0A%20%20%20%20loss%20%3D%20nn.CrossEntropyLoss(label_smoothing%3D0.1)%0A%20%20%20%20return%20loss%2C%0A%0A%0A%40app.cell%0Adef%20_(WEIGHT_DECAY%2C%20model%2C%20optim)%3A%0A%20%20%20%20optimizer%20%3D%20optim.AdamW(model.parameters()%2C%20lr%3D1e-6%2C%20weight_decay%3DWEIGHT_DECAY)%0A%20%20%20%20return%20optimizer%2C%0A%0A%0A%40app.cell%0Adef%20_(DEVICE%2C%20create_supervised_trainer%2C%20loss%2C%20model%2C%20optimizer)%3A%0A%20%20%20%20trainer%20%3D%20create_supervised_trainer(model%2C%20optimizer%2C%20loss%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20device%3DDEVICE)%0A%20%20%20%20return%20trainer%2C%0A%0A%0A%40app.cell%0Adef%20_(EPOCHS%2C%20LEARNING_RATE%2C%20optim%2C%20optimizer%2C%20train_loader)%3A%0A%20%20%20%20lr_scheduler%20%3D%20optim.lr_scheduler.OneCycleLR(%0A%20%20%20%20%20%20%20%20optimizer%2C%0A%20%20%20%20%20%20%20%20max_lr%3DLEARNING_RATE%2C%0A%20%20%20%20%20%20%20%20steps_per_epoch%3Dlen(train_loader)%2C%0A%20%20%20%20%20%20%20%20epochs%3DEPOCHS%2C%0A%20%20%20%20)%0A%20%20%20%20return%20lr_scheduler%2C%0A%0A%0A%40app.cell%0Adef%20_(Events%2C%20lr_scheduler%2C%20trainer)%3A%0A%20%20%20%20trainer.add_event_handler(Events.ITERATION_COMPLETED%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20lambda%20engine%3A%20lr_scheduler.step())%0A%20%20%20%20None%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(ignite%2C%20trainer)%3A%0A%20%20%20%20ignite.metrics.RunningAverage(output_transform%3Dlambda%20x%3A%20x).attach(%0A%20%20%20%20%20%20%20%20trainer%2C%20%22loss%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(ignite%2C%20loss)%3A%0A%20%20%20%20val_metrics%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22accuracy%22%3A%20ignite.metrics.Accuracy()%2C%0A%20%20%20%20%20%20%20%20%22loss%22%3A%20ignite.metrics.Loss(loss)%2C%0A%20%20%20%20%7D%0A%20%20%20%20return%20val_metrics%2C%0A%0A%0A%40app.cell%0Adef%20_(DEVICE%2C%20create_supervised_evaluator%2C%20model%2C%20val_metrics)%3A%0A%20%20%20%20evaluator%20%3D%20create_supervised_evaluator(model%2C%20metrics%3Dval_metrics%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20device%3DDEVICE)%0A%20%20%20%20return%20evaluator%2C%0A%0A%0A%40app.cell%0Adef%20_(defaultdict)%3A%0A%20%20%20%20history%20%3D%20defaultdict(list)%0A%20%20%20%20return%20history%2C%0A%0A%0A%40app.cell%0Adef%20_(Events%2C%20evaluator%2C%20history%2C%20test_loader%2C%20trainer)%3A%0A%20%20%20%20%40trainer.on(Events.EPOCH_COMPLETED)%0A%20%20%20%20def%20log_validation_results(engine)%3A%0A%20%20%20%20%20%20%20%20train_state%20%3D%20engine.state%0A%20%20%20%20%20%20%20%20epoch%20%3D%20train_state.epoch%0A%20%20%20%20%20%20%20%20max_epochs%20%3D%20train_state.max_epochs%0A%20%20%20%20%20%20%20%20train_loss%20%3D%20train_state.metrics%5B%22loss%22%5D%0A%20%20%20%20%20%20%20%20history%5B%22train%20loss%22%5D.append(train_loss)%0A%0A%20%20%20%20%20%20%20%20evaluator.run(test_loader)%0A%20%20%20%20%20%20%20%20val_metrics%20%3D%20evaluator.state.metrics%0A%20%20%20%20%20%20%20%20val_loss%20%3D%20val_metrics%5B%22loss%22%5D%0A%20%20%20%20%20%20%20%20val_acc%20%3D%20val_metrics%5B%22accuracy%22%5D%0A%20%20%20%20%20%20%20%20history%5B%22val%20loss%22%5D.append(val_loss)%0A%20%20%20%20%20%20%20%20history%5B%22val%20acc%22%5D.append(val_acc)%0A%0A%20%20%20%20%20%20%20%20print((%22%7B%7D%2F%7B%7D%20-%20train%3A%20loss%20%7B%3A.3f%7D%3B%20%22%20%2B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22val%3A%20loss%20%7B%3A.3f%7D%20accuracy%20%7B%3A.3f%7D%22).format(epoch%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20max_epochs%2C%20train_loss%2C%20val_loss%2C%20val_acc))%0A%20%20%20%20return%20log_validation_results%2C%0A%0A%0A%40app.cell%0Adef%20_(EPOCHS%2C%20do_training%2C%20train_loader%2C%20trainer)%3A%0A%20%20%20%20if%20do_training.value%3A%0A%20%20%20%20%20%20%20%20trainer.run(train_loader%2C%20max_epochs%3DEPOCHS)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(np%2C%20plt)%3A%0A%20%20%20%20def%20history_plot_train_val(history%2C%20key)%3A%0A%20%20%20%20%20%20%20%20fig%20%3D%20plt.figure()%0A%20%20%20%20%20%20%20%20ax%20%3D%20fig.add_subplot(111)%0A%20%20%20%20%20%20%20%20xs%20%3D%20np.arange(1%2C%20len(history%5B'train%20'%20%2B%20key%5D)%20%2B%201)%0A%20%20%20%20%20%20%20%20ax.plot(xs%2C%20history%5B'train%20'%20%2B%20key%5D%2C%20'.-'%2C%20label%3D'train')%0A%20%20%20%20%20%20%20%20ax.plot(xs%2C%20history%5B'val%20'%20%2B%20key%5D%2C%20'.-'%2C%20label%3D'val')%0A%20%20%20%20%20%20%20%20ax.set_xlabel('epoch')%0A%20%20%20%20%20%20%20%20ax.set_ylabel(key)%0A%20%20%20%20%20%20%20%20ax.legend()%0A%20%20%20%20%20%20%20%20ax.grid()%0A%20%20%20%20%20%20%20%20return%20ax%0A%20%20%20%20return%20history_plot_train_val%2C%0A%0A%0A%40app.cell%0Adef%20_(np%2C%20plt)%3A%0A%20%20%20%20def%20history_plot(history%2C%20key)%3A%0A%20%20%20%20%20%20%20%20fig%20%3D%20plt.figure()%0A%20%20%20%20%20%20%20%20ax%20%3D%20fig.add_subplot(111)%0A%20%20%20%20%20%20%20%20xs%20%3D%20np.arange(1%2C%20len(history%5Bkey%5D)%20%2B%201)%0A%20%20%20%20%20%20%20%20ax.plot(xs%2C%20history%5Bkey%5D%2C%20'.-')%0A%20%20%20%20%20%20%20%20ax.set_xlabel('epoch')%0A%20%20%20%20%20%20%20%20ax.set_ylabel(key)%0A%20%20%20%20%20%20%20%20ax.grid()%0A%20%20%20%20%20%20%20%20return%20ax%0A%20%20%20%20return%20history_plot%2C%0A%0A%0A%40app.cell%0Adef%20_(history%2C%20history_plot_train_val)%3A%0A%20%20%20%20history_plot_train_val(history%2C%20'loss')%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(history%2C%20history_plot)%3A%0A%20%20%20%20history_plot(history%2C%20'val%20acc')%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20return%20mo%2C%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()