Simple detector¶

Dataset: Embrapa Wine Grape Instance Segmentation Dataset

Download:

git clone https://github.com/thsant/wgisd.git

Configuration¶

Imports

In [1]:
from pathlib import Path
from collections import defaultdict
import math
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from tqdm import tqdm, trange

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import ops, tv_tensors, io
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as TF

Configuration

In [2]:
# Dataset locations (WGISD repository cloned next to this notebook).
WGISD_ROOT = Path('./wgisd/')
DATA_DIR = WGISD_ROOT / 'data'
TRAIN_LIST = WGISD_ROOT / 'train.txt'
TEST_LIST = WGISD_ROOT / 'test.txt'

# Directory where trained weights are saved/loaded.
MODELS_DIR = Path('./models')

# Fraction of the official train split held out for validation.
VAL_RATIO = 0.1

# Images (and boxes) are resized to IMAGE_SIZE x IMAGE_SIZE.
IMAGE_SIZE = 512

NUM_WORKERS = 8  # DataLoader worker processes
BATCH_SIZE = 16
EPOCHS = 1000

# AdamW / OneCycleLR hyper-parameters.
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
In [3]:
# Prefer the GPU when available; tensors and the model are moved to DEVICE.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda

Data¶

Reading bounding box information¶

In [4]:
def get_file_list(list_name):
    """Read a split-list file and return its lines as stripped base names."""
    with open(list_name) as f:
        return [line.strip() for line in f]
In [5]:
def read_annotation(annotation_path):
    """Parse a YOLO-format annotation file into a list of normalized XYXY boxes.

    Each non-empty line holds: label x_center y_center width height, all in
    [0, 1].  The label is ignored (single-class detection).  Boxes are
    converted to [x1, y1, x2, y2], clamped to the unit square; degenerate
    boxes are reported and skipped.  Returns [] when the file is missing.
    """
    bboxes = []
    try:
        with open(annotation_path) as f:
            for line in f:
                # Tolerate blank/trailing lines, which would otherwise crash
                # the 5-way unpack below with a ValueError.
                if not line.strip():
                    continue
                label, x_center, y_center, width, height = [float(s) for s in line.split()]
                x1 = max(x_center - 0.5 * width, 0.)
                y1 = max(y_center - 0.5 * height, 0.)
                x2 = min(x_center + 0.5 * width, 1.)
                y2 = min(y_center + 0.5 * height, 1.)
                if x1 < x2 and y1 < y2:
                    bboxes.append([x1, y1, x2, y2])
                else:
                    print("Invalid bounding box", annotation_path)

    except FileNotFoundError:
        print("Annotation missing:", annotation_path)

    return bboxes
In [6]:
def denormalize_bounding_boxes(bboxes, shape):
    """Scale normalized XYXY boxes to pixel coordinates for a (w, h) image size."""
    w, h = shape
    return [[x1 * w, y1 * h, x2 * w, y2 * h] for x1, y1, x2, y2 in bboxes]
In [7]:
def read_bounding_boxes(file_list):
    """Collect image path, pixel-space boxes and image size for each base name.

    For every name in `file_list`, reads <name>.jpg from DATA_DIR for its
    size and <name>.txt for its normalized boxes, and returns a list of
    {'file_name', 'bboxes', 'shape'} dicts with boxes in pixel coordinates.
    """
    images_data = []
    for name in file_list:
        image_path = DATA_DIR / (name + '.jpg')
        # Image.open is lazy and keeps the file handle open; a context
        # manager closes it promptly instead of leaking one per image.
        with Image.open(image_path) as image:
            shape = image.size  # (width, height)

        annotation_path = DATA_DIR / (name + '.txt')
        bboxes = read_annotation(annotation_path)
        bboxes = denormalize_bounding_boxes(bboxes, shape)

        images_data.append({'file_name': image_path, 'bboxes': bboxes, 'shape': shape})

    return images_data
In [8]:
train_val_list = get_file_list(TRAIN_LIST)
In [9]:
test_list = get_file_list(TEST_LIST)
In [10]:
train_list, val_list = train_test_split(train_val_list, test_size=VAL_RATIO, random_state=0,
                                        shuffle=True)
In [11]:
print("Number of train images:", len(train_list))
print("Number of validation images:", len(val_list))
print("Number of train+validation images:", len(train_list) + len(val_list))
print("Number of test images:", len(test_list))
print("Total number of images:", len(train_list) + len(val_list) + len(test_list))
Number of train images: 217
Number of validation images: 25
Number of train+validation images: 242
Number of test images: 58
Total number of images: 300
In [12]:
train_data = read_bounding_boxes(train_list)
val_data = read_bounding_boxes(val_list)
test_data = read_bounding_boxes(test_list)

Utilities¶

Plotting

In [13]:
def plot_image(image):
    """Display a single image on a 10x10 inch figure with the axes hidden."""
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.set_axis_off()
    plt.show()
In [14]:
def plot_image_bounding_boxes(image, bboxes, bboxes_gt=None):
    """Plot `image` with predicted boxes in red and optional GT boxes in blue."""
    canvas = TF.convert_image_dtype(image, dtype=torch.uint8)
    canvas = torchvision.utils.draw_bounding_boxes(image=canvas, boxes=bboxes, colors='red', width=2)
    if bboxes_gt is not None:
        canvas = torchvision.utils.draw_bounding_boxes(image=canvas, boxes=bboxes_gt, colors='blue', width=2)
    plot_image(TF.to_pil_image(canvas))

Dataset¶

In [15]:
# Training-time augmentation: geometric flips/affine jitter plus photometric
# distortion; boxes are transformed jointly with the image, clamped to the
# canvas, and degenerate boxes removed before float conversion to [0, 1].
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), antialias=True),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomAffine(
        degrees=10.,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        shear=8.
    ),
    transforms.RandomPhotometricDistort(p=1),
    transforms.ClampBoundingBoxes(),
    transforms.SanitizeBoundingBoxes(labels_getter=None),
    transforms.ToDtype(torch.float32, scale=True),
])
In [16]:
# Deterministic evaluation pipeline: resize + float conversion only.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), antialias=True),
    transforms.ToDtype(torch.float32, scale=True),
])
In [17]:
class ImagesDataset(torch.utils.data.Dataset):
    """Detection dataset yielding (image, bounding_boxes) pairs.

    Wraps the record list produced by `read_bounding_boxes` and applies a
    torchvision-v2 transform jointly to each image and its boxes.
    """

    def __init__(self, images_data, transform):
        self.images_data = images_data
        self.transform = transform

    def __len__(self):
        return len(self.images_data)

    def __getitem__(self, idx):
        record = self.images_data[idx]
        image = tv_tensors.Image(io.read_image(str(record['file_name'])))

        boxes = record['bboxes']
        if len(boxes) == 0:
            # Keep a well-formed (0, 4) tensor so the transforms still work.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
        boxes = tv_tensors.BoundingBoxes(boxes, format=tv_tensors.BoundingBoxFormat.XYXY,
                                         canvas_size=image.shape[1:])

        return self.transform(image, boxes)

    def show_image(self, idx):
        """Plot sample `idx` with its (transformed) boxes."""
        image, boxes = self[idx]
        plot_image_bounding_boxes(image, boxes)
In [18]:
train_dset = ImagesDataset(train_data, train_transform)
In [19]:
val_dset = ImagesDataset(val_data, val_transform)
In [20]:
test_dset = ImagesDataset(test_data, val_transform)
In [21]:
val_dset.show_image(0)

DataLoader¶

Normalize bounding boxes

In [22]:
def normalize_bboxes(bboxes):
    """Convert pixel-space XYXY boxes to [0, 1] coordinates.

    `bboxes` must carry a tv_tensors-style `canvas_size` attribute as (h, w).
    Returns a plain tensor of the same shape.
    """
    h, w = bboxes.canvas_size
    scale = torch.tensor([w, h])
    return torch.cat((bboxes[:, :2] / scale, bboxes[:, 2:] / scale), axis=1)

Pad ground truth bounding boxes to allow formation of a batch tensor.

In [23]:
def extend_tensor(t, max_len):
    """Zero-pad `t` along dim 0 up to `max_len` rows; no-op if already long enough."""
    deficit = max_len - len(t)
    if deficit <= 0:
        return t
    padding = torch.zeros((deficit,) + t.shape[1:])
    return torch.cat((t, padding), dim=0)
In [24]:
def collate_fn(batch):
    """Stack images and zero-padded, normalized GT boxes into batch tensors."""
    images, bboxes = zip(*batch)
    image_batch = torch.stack(images, dim=0)

    # Pad every sample to the largest box count so boxes stack into one tensor.
    pad_to = max(len(bb) for bb in bboxes)
    box_batch = torch.stack(
        [extend_tensor(normalize_bboxes(bb), pad_to) for bb in bboxes], dim=0)

    return image_batch, box_batch
In [25]:
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           collate_fn=collate_fn,
                                           num_workers=NUM_WORKERS, pin_memory=True)

val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
                                         collate_fn=collate_fn,
                                         num_workers=NUM_WORKERS, pin_memory=True)
In [26]:
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                          collate_fn=collate_fn,
                                          num_workers=NUM_WORKERS, pin_memory=True)

Model¶

For each cell of the output grid, the model proposes a bounding box centered in that cell, together with a confidence score.

Encoder¶

In [27]:
class ConvBlock(nn.Sequential):
    """Conv2d ('same' padding, no bias) -> BatchNorm2d, optionally + ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                      padding=(kernel_size - 1) // 2, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        if act:
            layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)
In [28]:
class DownBlock(nn.Sequential):
    """Encoder stage: 2x max-pool downsampling followed by two 3x3 ConvBlocks."""

    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.MaxPool2d(2),
            ConvBlock(in_channels, out_channels, 3),
            ConvBlock(out_channels, out_channels, 3)
        )
In [29]:
class Encoder(nn.Module):
    """Convolutional backbone: a stem plus `num_downsamplings` DownBlocks.

    Channel width doubles at each downsampling.  `forward` returns the final
    feature map together with the list of pre-downsampling feature maps,
    used as skip connections by the decoder.
    """

    def __init__(self, in_channels, channels, num_downsamplings):
        super().__init__()
        self.stem = ConvBlock(in_channels, channels, 3)

        self.blocks = nn.ModuleList()
        width = channels
        for _ in range(num_downsamplings):
            self.blocks.append(DownBlock(width, width * 2))
            width *= 2

    def forward(self, x):
        x = self.stem(x)

        skips = []
        for block in self.blocks:
            skips.append(x)
            x = block(x)
        return x, skips

Decoder¶

The decoder is inspired by Feature Pyramid Network (FPN), arXiv:1612.03144 [cs.CV]

In [30]:
class UpBlock(nn.Module):
    """FPN-style upsampling step with zero-init scaling and stochastic depth.

    Upsamples `x` by 2 and adds a 1x1-projected skip connection scaled by a
    learnable `gamma`.  During training the skip branch is dropped with
    probability `p_drop`; when kept, it is rescaled by 1/keep_prob so its
    expected contribution is unchanged.
    """

    def __init__(self, in_channels, out_channels, p_drop=0.):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = ConvBlock(in_channels, out_channels, 1, act=False)
        self.gamma = nn.Parameter(torch.tensor(0.))
        self.p_drop = p_drop

    def forward(self, x, skip):
        use_skip = True
        keep_prob = 1.
        if self.training and self.p_drop > 0.:
            keep_prob = 1. - self.p_drop
            use_skip = torch.bernoulli(torch.tensor(keep_prob))

        out = self.up(x)
        if use_skip:
            out = out + self.gamma / keep_prob * self.conv(skip)
        return out
In [31]:
class Head(nn.Sequential):
    """Decoder head: ReLU (the preceding skip sum is unactivated) then a 3x3 ConvBlock."""

    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.ReLU(inplace=True),
            ConvBlock(in_channels, out_channels, 3)
        )
In [32]:
class Decoder(nn.Module):
    """FPN-inspired decoder (arXiv:1612.03144).

    A 1x1 projection maps the encoder output to `out_channels`; each UpBlock
    doubles the resolution and mixes in the matching encoder skip (whose
    channel count halves at each level); a final Head refines the result.
    """

    def __init__(self, in_channels, out_channels, num_upsamplings, p_drop_path=0.):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels, 1, act=False)
        self.up = nn.ModuleList()
        skip_channels = in_channels
        for _ in range(num_upsamplings):
            skip_channels //= 2
            self.up.append(UpBlock(skip_channels, out_channels, p_drop_path))

        self.head = Head(out_channels, out_channels)

    def forward(self, x, xs):
        out = self.conv(x)
        # Skips are stored fine-to-coarse; consume them coarse-to-fine.
        for up_layer, skip in zip(self.up, reversed(xs)):
            out = up_layer(out, skip)
        return self.head(out)

Full model¶

In [33]:
class DetectionHead(nn.Module):
    """Predict one box and one confidence logit per output cell.

    A 1x1 conv produces 5 channels per cell: sigmoid-squashed
    (cx-offset, cy-offset, w, h) plus a raw confidence logit.  `to_boxes`
    converts this map into per-cell (x1, y1, x2, y2, score) tuples in
    normalized image coordinates.
    """

    def __init__(self, channels, p_drop=0.):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Dropout(p_drop),
            nn.Conv2d(channels, 5, kernel_size=1)
        )

    def forward(self, x):
        return self.to_boxes(self.proj(x))

    @staticmethod
    def to_boxes(out):
        """Convert a (b, 5, h, w) map to (b, h, w, 5) boxes-with-scores."""
        h, w = out.shape[2:]

        # Box geometry is squashed to (0, 1); centers are cell-relative.
        geom = torch.sigmoid(out[:, :4])

        grid_x = torch.arange(w, device=out.device).unsqueeze(0)
        grid_y = torch.arange(h, device=out.device).unsqueeze(1)
        cx = (geom[:, 0] + grid_x) / w
        cy = (geom[:, 1] + grid_y) / h
        half_w = 0.5 * geom[:, 2]
        half_h = 0.5 * geom[:, 3]

        scores = out[:, 4]  # raw logits; consumers apply sigmoid
        return torch.stack(
            (cx - half_w, cy - half_h, cx + half_w, cy + half_h, scores), dim=3)
In [34]:
class Net(nn.Module):
    """Full detector: Encoder -> FPN-style Decoder -> DetectionHead."""

    def __init__(self, num_downsamplings, num_upsamplings, channels, in_channels, out_channels,
                 decoder_p_drop_path=0., head_p_drop=0.):
        super().__init__()
        bottleneck_channels = channels * 2**num_downsamplings
        self.encoder = Encoder(in_channels, channels, num_downsamplings)
        self.decoder = Decoder(bottleneck_channels, out_channels, num_upsamplings, decoder_p_drop_path)
        self.head = DetectionHead(out_channels, head_p_drop)

    def forward(self, x):
        features, skips = self.encoder(x)
        decoded = self.decoder(features, skips)
        return self.head(decoded)

Model creation¶

In [35]:
def init_model(model):
    """Kaiming-init convolutions, standard BatchNorm init, zero UpBlock gammas.

    Zeroing gamma makes each UpBlock start as a pure upsampling (no skip
    contribution), which stabilizes early training.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1.)
            nn.init.zeros_(module.bias)
        elif isinstance(module, UpBlock):
            nn.init.zeros_(module.gamma)
In [36]:
model = Net(num_downsamplings=5, num_upsamplings=3,
            channels=24, in_channels=3, out_channels=128,
            decoder_p_drop_path=0.1, head_p_drop=0.)
In [37]:
init_model(model)
In [38]:
model.to(DEVICE);
In [39]:
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 10,946,816

Loss¶

In [40]:
def batch_box_area(boxes):
    """Area of XYXY boxes; supports arbitrary leading batch dimensions."""
    widths = boxes[..., 2] - boxes[..., 0]
    heights = boxes[..., 3] - boxes[..., 1]
    return widths * heights
In [41]:
def batch_box_scores(boxes_gt,   # shape: (b, n_gt, 4)
                     boxes_pred, # shape: (b, n_pred, 4)
                     eps=1e-7):
    """Pairwise matching score for every (gt, pred) box pair.

    Returns (2 * intersection - pred_area) / gt_area with shape
    (b, n_gt, n_pred): 1 for a perfect match, lower when a prediction
    misses or over-covers the ground truth.
    """
    top_left = torch.max(boxes_gt[..., None, :2], boxes_pred[..., None, :, :2])
    bottom_right = torch.min(boxes_gt[..., None, 2:], boxes_pred[..., None, :, 2:])

    inter_wh = (bottom_right - top_left).clamp(min=0)
    inter = inter_wh[..., 0] * inter_wh[..., 1]  # shape: (b, n_gt, n_pred)

    area_gt = batch_box_area(boxes_gt)[..., None]
    area_pred = batch_box_area(boxes_pred)[..., None, :]

    return (2 * inter - area_pred) / (area_gt + eps)
In [42]:
class DetectionLoss(nn.Module):
    """Confidence (BCE-with-logits) plus geometric loss for the dense detector.

    Each output cell's confidence target is its best (clamped) matching score
    against any ground-truth box; the geometric term pushes matched cells'
    scores toward 1.
    """

    def __init__(self, λ_conf=1.0, λ_geom=1.0):
        super().__init__()
        self.λ_conf = λ_conf
        self.λ_geom = λ_geom
        self.conf_loss = nn.BCEWithLogitsLoss()

    def forward(self, output, target):
        # With no GT boxes, substitute a single degenerate box so the score
        # computation stays well-defined (it matches nothing).
        if target.size(1) == 0:
            target = torch.zeros((target.size(0), 1, 4), device=target.device)

        flat = output.flatten(1, 2)  # (b, h*w, 5)
        pred_boxes = flat[..., :4]
        pred_logits = flat[..., 4]

        scores = batch_box_scores(target, pred_boxes)
        best_scores, _ = torch.max(scores, dim=1)          # best GT per cell
        conf_target = best_scores.clamp(min=0.).detach()   # soft labels in [0, 1]

        conf_loss = self.conf_loss(pred_logits, conf_target)
        geom_loss = torch.mean(conf_target * (1. - best_scores))
        return self.λ_conf * conf_loss + self.λ_geom * geom_loss

Metrics¶

Output preprocessing

from https://github.com/xingyizhou/CenterNet

In [43]:
def heatmap_peaks(heat, kernel=3):
    """Boolean mask of local maxima of a (b, h, w) heatmap in a kernel x kernel window."""
    pad = (kernel - 1) // 2
    pooled = nn.functional.max_pool2d(
        heat.unsqueeze(1), (kernel, kernel), stride=1, padding=pad).squeeze(1)
    return pooled == heat
In [44]:
def remove_zero_boxes(boxes):
    """Drop padding rows that are entirely zero from an (n, 4) box tensor."""
    keep = (boxes != 0.).any(dim=-1)
    return boxes[keep]
In [45]:
class AverageLoss():
    """Tracks the sample-weighted running mean of the batch loss."""

    def __init__(self):
        self.name = "loss"
        self.reset()

    def reset(self):
        """Clear accumulated totals."""
        self.num_samples = 0
        self.total_loss = 0.

    def update(self, data):
        """Accumulate one batch; expects data['batch_size'] and data['loss']."""
        n = data['batch_size']
        self.num_samples += n
        self.total_loss += n * data[self.name]

    def compute(self):
        """Return {'loss': mean loss over all samples since reset()}."""
        return {self.name: self.total_loss / self.num_samples}
In [46]:
class DetectionMetrics():
    """Base class for detection metrics computed from dense model outputs.

    Post-processes raw outputs (peak picking on the score heatmap,
    confidence thresholding, NMS), matches surviving predictions to ground
    truth by IoU, and hands per-image (is_true_positive, confidence) pairs
    to the subclass-defined `process_batch` hook.
    """

    def __init__(self, threshold=0.5, iou_threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
        self.threshold = threshold            # min confidence to keep a detection
        self.iou_threshold = iou_threshold    # min IoU to count a true positive
        self.nms_threshold = nms_threshold    # IoU above which NMS suppresses
        self.heatmap_kernel = heatmap_kernel  # window size for peak detection
        self.reset()

    def reset(self):
        # Total ground-truth boxes seen (denominator for recall).
        self._num_true = 0

    def update(self, data):
        """Process one batch; expects data['outputs'] with boxes+logit per cell
        and data['targets'] with zero-padded normalized GT boxes."""
        outputs = data['outputs']
        true_bboxes = data['targets']

        logits = outputs[..., 4]
        scores = torch.sigmoid(logits)
        pred_bboxes = outputs[..., :4]

        # Keep only cells that are local maxima of the score heatmap
        # and exceed the confidence threshold.
        peaks = heatmap_peaks(scores, kernel=self.heatmap_kernel)
        masks = peaks & (scores > self.threshold)

        for true_boxes, pred_boxes, conf, mask in zip(true_bboxes, pred_bboxes, scores, masks):
            true_boxes = remove_zero_boxes(true_boxes)
            
            pred_boxes = pred_boxes[mask]
            conf = conf[mask]

            keep_idx = ops.nms(pred_boxes, conf, self.nms_threshold)
            pred_boxes = pred_boxes[keep_idx]
            conf = conf[keep_idx]

            if len(pred_boxes) > 0 and len(true_boxes) > 0:
                ious = ops.box_iou(pred_boxes, true_boxes)

                # zero all non_maximum values for a given ground truth box,
                # so each GT box can credit at most one prediction
                idx = ious.argmax(dim=0, keepdims=True)
                ious = torch.zeros_like(ious).scatter_(0, idx, ious.gather(0, idx))

                is_true_positive = ious > self.iou_threshold
                is_true_positive = is_true_positive.any(dim=1)
            else:
                is_true_positive = torch.zeros_like(conf, dtype=torch.bool)

            self._num_true += len(true_boxes)
            self.process_batch(is_true_positive, conf)

Simple F1 score for evaluation during training

In [47]:
class F1(DetectionMetrics):
    """F1 score (harmonic mean of precision and recall) at a fixed threshold."""

    def __init__(self, threshold=0.5, iou_threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
        super().__init__(threshold=threshold, iou_threshold=iou_threshold, nms_threshold=nms_threshold,
                         heatmap_kernel=heatmap_kernel)
        self.name = "F1"
        self.reset()

    def reset(self):
        """Clear prediction and true-positive counters."""
        self._num_pred = 0
        self._num_tp = 0
        super().reset()

    def process_batch(self, is_true_positive, conf):
        """Count predictions and true positives for one image."""
        self._num_pred += len(conf)
        self._num_tp += is_true_positive.sum().item()

    def compute(self):
        """Return {'F1': score}; epsilons guard against empty denominators."""
        precision = self._num_tp / (self._num_pred + 1e-7)
        recall = self._num_tp / (self._num_true + 1e-7)
        return {self.name: 2 * precision * recall / (precision + recall + 1e-7)}

Average precision (AP) using all-point interpolation

In [48]:
class AveragePrecision(DetectionMetrics):
    """Average precision (AP) using all-point interpolation.

    Uses threshold=0 so every heatmap peak is kept; predictions are later
    sorted by confidence to trace the precision-recall curve.
    """

    def __init__(self, iou_threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
        super().__init__(threshold=0., iou_threshold=iou_threshold, nms_threshold=nms_threshold,
                         heatmap_kernel=heatmap_kernel)
        self.name = "AP"
        self.reset()

    def reset(self):
        # Per-prediction records accumulated across all images.
        self._tp_list = []
        self._confidence_list = []
        super().reset()

    def process_batch(self, is_true_positive, conf):
        # Record confidence and a 0/1 true-positive flag per prediction.
        self._confidence_list.append(conf)
        tp = torch.zeros_like(conf)
        tp[is_true_positive] = 1.
        self._tp_list.append(tp)

    def compute(self):
        """Return {'AP': area under the interpolated precision-recall curve}."""
        conf = torch.cat(self._confidence_list)
        tp = torch.cat(self._tp_list)

        # sort true positives according to confidence
        idx = conf.argsort(descending=True)
        tp = tp[idx]

        # cumulative sum of true positives and false positives
        tpc = tp.cumsum(0)
        fpc = (1. - tp).cumsum(0)

        # precision and recall curves
        recall = tpc / (self._num_true + 1e-7)
        precision =  tpc / (tpc + fpc)

        # append sentinel values to beginning and end
        z = torch.zeros(1, device=recall.device)
        o = torch.ones(1, device=recall.device)
        recall = torch.cat([z, recall, o])
        precision = torch.cat([o, precision, z])

        # compute precision envelope (make precision non-increasing in recall)
        precision = precision.flip(0)
        precision, _ = precision.cummax(0)
        precision = precision.flip(0)

        # integrate area under curve
        idx = (recall[1:] != recall[:-1]).nonzero(as_tuple=True)[0] # indexes where recall changes
        ap = ((recall[idx + 1] - recall[idx]) * precision[idx + 1]).sum().item() # area under curve

        metrics = {self.name: ap}
        return metrics

Training¶

Training functions¶

In [49]:
def iterate(step_fn, loader, metrics_list):
    """Run `step_fn` over a loader, feeding each batch's results to the metrics.

    Moves inputs/targets to DEVICE, calls step_fn(x, y) -> (loss, outputs),
    and returns the merged dict of every metric's compute() result.
    """
    for metric in metrics_list:
        metric.reset()

    for x, y in loader:
        x = x.to(DEVICE)
        if type(y) in (list, tuple):
            y = [item.to(DEVICE) for item in y]
        else:
            y = y.to(DEVICE)

        loss, out = step_fn(x, y)

        batch_data = {"loss": loss.item(),
                      "batch_size": len(x),
                      "targets": y,
                      "outputs": out}

        for metric in metrics_list:
            metric.update(batch_data)

    results = {}
    for metric in metrics_list:
        results.update(metric.compute())
    return results
In [50]:
def train(model, loss_fn, optimizer, loader, batch_scheduler, metrics_list):
    """Run one optimization epoch and return the training metrics dict."""
    def step(inputs, targets):
        # Forward, backward, parameter and LR update for a single batch.
        predictions = model(inputs)
        batch_loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        batch_scheduler.step()  # OneCycleLR advances once per batch
        return batch_loss.detach(), predictions.detach()

    model.train()
    return iterate(step, loader, metrics_list)
In [51]:
def evaluate(model, loss_fn, loader, metrics_list):
    """Evaluate without gradient tracking and return the metrics dict."""
    def step(inputs, targets):
        predictions = model(inputs)
        batch_loss = loss_fn(predictions, targets)
        return batch_loss.detach(), predictions.detach()

    model.eval()
    with torch.inference_mode():
        return iterate(step, loader, metrics_list)
In [52]:
def update_history(history, metrics, name):
    """Append each metric value to `history` under the key '<name> <metric>'."""
    for key, value in metrics.items():
        history[f"{name} {key}"].append(value)
In [53]:
def history_plot_train_val(history, key):
    """Plot the train and validation curves of `key` against epoch number."""
    fig, ax = plt.subplots()
    epochs = np.arange(1, len(history['train ' + key]) + 1)
    ax.plot(epochs, history['train ' + key], '.-', label='train')
    ax.plot(epochs, history['val ' + key], '.-', label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    plt.show()

Start training¶

In [54]:
loss = DetectionLoss()
In [55]:
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
In [56]:
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
In [57]:
history = defaultdict(list)
In [58]:
metrics = [AverageLoss(), F1()]
In [59]:
pbar = trange(EPOCHS, ncols=140)
for epoch in pbar:
    train_metrics = train(model, loss, optimizer, train_loader, lr_scheduler, metrics)
    update_history(history, train_metrics, "train")
    
    val_metrics = evaluate(model, loss, val_loader, metrics)
    update_history(history, val_metrics, "val")
    pbar.set_postfix({"train F1": f"{train_metrics['F1']:.3f}", "val F1": f"{val_metrics['F1']:.3f}"})
100%|███████████████████████████████████████████████████████████████████| 1000/1000 [2:07:48<00:00,  7.67s/it, train F1=0.947, val F1=0.783]
In [60]:
torch.save(model.state_dict(), str(MODELS_DIR / 'final_model.pt'))

Plotting¶

In [61]:
history_plot_train_val(history, 'loss')
In [62]:
history_plot_train_val(history, 'F1')

Testing¶

In [63]:
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
Out[63]:
<All keys matched successfully>
In [64]:
model.eval();
In [65]:
thresholds = np.linspace(0., 1.0, num=50)
f1_vs_thr = [evaluate(model, loss, val_loader, [F1(threshold=thr)])['F1'] for thr in tqdm(thresholds, ncols=140)]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:00<00:00,  1.21s/it]
In [66]:
plt.plot(thresholds, f1_vs_thr);
plt.grid();
In [67]:
num = np.argmax(f1_vs_thr)
best_thr = thresholds[num]
best_f1 = f1_vs_thr[num]
print(f"Best F1 {best_f1:.3f} for threshold {best_thr:.3f}")
Best F1 0.811 for threshold 0.633
In [68]:
test_f1 = evaluate(model, loss, test_loader, [F1(threshold=best_thr)])['F1']
print(f"Test F1: {test_f1:.3f}")
Test F1: 0.854
In [69]:
test_ap = evaluate(model, loss, test_loader, [AveragePrecision()])['AP']
print(f"Test AP: {test_ap:.3f}")
Test AP: 0.833

Example¶

Convert from scaled tensors to unscaled bounding boxes

In [70]:
def denormalize_bboxes(bboxes, shape):
    """Scale normalized XYXY box tensors to pixel coordinates; `shape` is (h, w)."""
    scale = torch.tensor([shape[1], shape[0]])  # (w, h)
    return torch.cat((bboxes[:, :2] * scale, bboxes[:, 2:] * scale), axis=1)
In [71]:
def output_to_bboxes(output, shape, threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
    """Post-process one image's raw output map into final detections.

    Applies sigmoid to confidences, keeps thresholded heatmap peaks, runs
    NMS, and returns (pixel-space boxes, confidences).
    """
    scores = torch.sigmoid(output[..., 4])
    boxes = output[..., :4]

    # Peak-pick the score heatmap, then drop low-confidence cells.
    peaks = heatmap_peaks(scores.unsqueeze(0), kernel=heatmap_kernel).squeeze(0)
    keep = peaks & (scores > threshold)
    boxes = boxes[keep]
    scores = scores[keep]

    # Non-maximum suppression over the surviving candidates.
    selected = ops.nms(boxes, scores, nms_threshold)
    boxes = boxes[selected]
    scores = scores[selected]

    return denormalize_bboxes(boxes, shape), scores
In [72]:
image, bboxes_gt = test_dset[0]
In [73]:
with torch.inference_mode():
    output = model(image.unsqueeze(0).to(DEVICE))
output = output.detach().cpu().squeeze(0)
pred_bboxes, pred_scores = output_to_bboxes(output, image.shape[1:], threshold=best_thr, nms_threshold=0.5, heatmap_kernel=3)
plot_image_bounding_boxes(image, pred_bboxes, bboxes_gt)

Heatmap of predicted scores

In [74]:
plot_image(TF.to_pil_image(torch.sigmoid(output[..., 4])))