Wine grape detection using a simple YOLO-like detector¶

A simple YOLO-like model for object detection.

References for the YOLO family of detectors:

  • YOLOv1 arXiv:1506.02640 [cs.CV]
  • YOLOv2 arXiv:1612.08242 [cs.CV]
  • YOLOv3 arXiv:1804.02767 [cs.CV]
  • YOLOv4 arXiv:2004.10934 [cs.CV]
  • YOLOv5 https://github.com/ultralytics/yolov5

Dataset: Embrapa Wine Grape Instance Segmentation Dataset (WGISD)

Download:

git clone https://github.com/thsant/wgisd.git

Configuration¶

Imports

In [1]:
import warnings
In [2]:
warnings.simplefilter('ignore')
In [3]:
from pathlib import Path
from functools import partial
from collections import defaultdict
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt

import tqdm
import tqdm.autonotebook
tqdm.autonotebook.tqdm = tqdm.tqdm # hack to force ASCII output everywhere
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import ops
import torchvision.transforms.functional as TF

import albumentations as A

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
import ignite.metrics
import ignite.contrib.handlers

Configuration

In [4]:
WGISD_ROOT = Path('./wgisd/')
DATA_DIR = WGISD_ROOT / 'data'
TRAIN_LIST = WGISD_ROOT / 'train.txt'
TEST_LIST = WGISD_ROOT / 'test.txt'

MODELS_DIR = Path('./models')

VAL_RATIO = 0.1

IMAGE_SIZE = 512

NUM_WORKERS = 8
BATCH_SIZE = 32
EPOCHS = 1000

LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-2
In [5]:
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda

Loss¶

In [6]:
class DetectionLoss(nn.Module):
    def __init__(self, λ_conf=1.0, λ_geom=1.0):
        super().__init__()
        self.λ_conf = λ_conf
        self.λ_geom = λ_geom
        self.conf_loss = nn.BCELoss()
        self.geom_loss = nn.MSELoss()
    
    def forward(self, output, target):
        pred_bboxes = output[..., :4]
        scores = output[..., 4]
        
        gt_mask, gt_bboxes = self.create_gt(scores, target)
        conf_loss = self.conf_loss(scores, gt_mask)
        
        pred_bboxes = pred_bboxes[gt_mask > 0]    
        geom_loss = self.geom_loss(pred_bboxes, gt_bboxes)
        
        loss = self.λ_conf * conf_loss + self.λ_geom * geom_loss
        return loss
    
    @staticmethod
    def create_gt(scores, target):
        gt_mask = torch.zeros_like(scores)
        wh = torch.tensor([scores.shape[-1], scores.shape[-2]], device=scores.device)
        
        sorted_boxes = []
        for true_boxes, mask in zip(target, gt_mask):
            if len(true_boxes) == 0:  # skip images left without boxes (e.g. after augmentation)
                continue
            # set the cells of the mask containing the centers of the bounding boxes to 1
            cxy = (0.5 * (true_boxes[:, 2:] + true_boxes[:, :2]) * wh).to(torch.long)
            mask[cxy[:, 1], cxy[:, 0]] = 1.
            
            # sort ground truth boxes by flattened (row-major) cell index, matching the
            # order produced in forward() by boolean indexing of the predictions with the mask
            num = cxy[:, 1] * wh[0] + cxy[:, 0]
            num_sorted, idx = num.sort()
            # drop duplicate indices arising when several boxes share the same center cell
            idx_mask = torch.empty_like(idx, dtype=torch.bool)
            idx_mask[0] = True
            idx_mask[1:] = num_sorted[1:] != num_sorted[:-1]
            idx = idx[idx_mask]
            
            sorted_boxes.append(true_boxes[idx])
        
        gt_bboxes = torch.cat(sorted_boxes, dim=0)
        
        return gt_mask, gt_bboxes
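
As a quick sanity check (a hypothetical snippet, not part of the original notebook), create_gt can be exercised on a tiny 4×4 grid:

scores = torch.zeros(1, 4, 4)                    # batch of one 4x4 confidence map
target = [torch.tensor([[0.0, 0.0, 0.5, 0.5],    # center (0.25, 0.25) -> cell (1, 1)
                        [0.5, 0.5, 1.0, 1.0]])]  # center (0.75, 0.75) -> cell (3, 3)
mask, gt = DetectionLoss.create_gt(scores, target)
print(mask[0])  # ones exactly at cells (1, 1) and (3, 3)
print(gt)       # the two boxes, ordered by flattened cell index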

Model¶

Backbone¶

In [7]:
class ConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
        padding = (kernel_size - 1) // 2
        layers = [
          nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
          nn.BatchNorm2d(out_channels)
        ]
        if act: layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)
In [8]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, p_drop=0.):
        super().__init__()
        self.residual = nn.Sequential(
            ConvBlock(in_channels, out_channels, stride=stride),
            ConvBlock(out_channels, out_channels, act=False),
            nn.Dropout(p_drop)
        )
        self.shortcut = self.get_shortcut(in_channels, out_channels, stride)
        self.act = nn.ReLU(inplace=True)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual gain, zero-initialized so each block starts as the identity
    
    def forward(self, x):
        out = self.shortcut(x) + self.gamma * self.residual(x)
        return self.act(out)
    
    def get_shortcut(self, in_channels, out_channels, stride):
        if in_channels != out_channels:
            shortcut = ConvBlock(in_channels, out_channels, 1, act=False)
            if stride > 1:
                shortcut = nn.Sequential(nn.AvgPool2d(stride), shortcut)
        elif stride > 1:
            shortcut = nn.AvgPool2d(stride)
        else:
            shortcut = nn.Identity()
        return shortcut
In [9]:
class Body(nn.Sequential):
    def __init__(self, in_channels, channel_list, num_blocks_list, strides, res_p_drop=0.):
        layers = []
        for out_channels, num_blocks, stride in zip(channel_list, num_blocks_list, strides):
            for _ in range(num_blocks):
                layers.append(ResidualBlock(in_channels, out_channels, stride, res_p_drop))
                in_channels = out_channels
                stride = 1
        
        super().__init__(*layers)
In [10]:
class Stem(nn.Sequential):
    def __init__(self, in_channels, channel_list, stride):
        layers = []
        for out_channels in channel_list:
            layers.append(ConvBlock(in_channels, out_channels, 3, stride=stride))
            in_channels = out_channels
            stride = 1
        super().__init__(*layers)
In [11]:
class ResNet(nn.Sequential):
    def __init__(self, num_blocks_list, channel_list, strides, in_channels=3, res_p_drop=0.):
        super().__init__(
            Stem(in_channels, [32, 32, 64], strides[0]),
            Body(64, channel_list, num_blocks_list, strides[1:], res_p_drop),
        )

Full model¶

In [12]:
class Net(nn.Module):
    def __init__(self, num_blocks_list, channel_list, strides, in_channels=3):
        super().__init__()
        self.backbone = ResNet(num_blocks_list, channel_list, strides, in_channels)
        self.head = nn.Conv2d(channel_list[-1], 5, kernel_size=1)

    def forward(self, x):
        out = self.backbone(x)
        out = self.head(out)
        out = torch.sigmoid(out)
        boxes = self.to_boxes(out)
        return boxes
        
    def to_boxes(self, out):
        h, w = out.shape[2:]

        # grid cell indices; adding the predicted offsets (each in (0, 1) after the
        # sigmoid) and dividing by the grid size gives box centers normalized to the image
        grid_x = torch.arange(w, device=out.device).unsqueeze(0)
        grid_y = torch.arange(h, device=out.device).unsqueeze(1)
        cx = (out[:, 0] + grid_x) / w
        cy = (out[:, 1] + grid_y) / h
        pred_w = out[:, 2]
        pred_h = out[:, 3]
        
        x1 = cx - 0.5 * pred_w
        y1 = cy - 0.5 * pred_h
        x2 = cx + 0.5 * pred_w
        y2 = cy + 0.5 * pred_h
        
        scores = out[:, 4]
        boxes = torch.stack((x1, y1, x2, y2, scores), dim=3)
        return boxes
In [13]:
def init_model(model):
    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None: nn.init.zeros_(m.bias)
In [14]:
model = Net(num_blocks_list=[2, 2, 2, 2], channel_list=[64, 128, 256, 512], strides=[2, 2, 2, 2, 2]).to(DEVICE)
In [15]:
init_model(model)
In [16]:
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 11,198,317
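
A quick shape check (hypothetical, not from the original run): with a 512×512 input and an overall stride of 32 (one stride-2 stem plus four stride-2 stages), the model should emit a 16×16 grid of predictions.

x = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
with torch.no_grad():
    boxes = model(x)
print(boxes.shape)  # expected: torch.Size([2, 16, 16, 5]) -> (x1, y1, x2, y2, score) per cell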

Data¶

Reading of bounding box information¶

In [19]:
def get_file_list(list_name):
    with open(list_name) as f:
        file_list = f.readlines()
    file_list = [file_name.strip() for file_name in file_list]
    return file_list
In [20]:
def read_annotation(annotation_path):
    bboxes = []
    try:
        with open(annotation_path) as f:
            for line in f:
                label, x_center, y_center, width, height = [float(s) for s in line.split()]
                x1 = max(x_center - 0.5 * width, 0.)
                y1 = max(y_center - 0.5 * height, 0.)
                x2 = min(x_center + 0.5 * width, 1.)
                y2 = min(y_center + 0.5 * height, 1.)
                if x1 < x2 and y1 < y2:
                    bbox = [x1, y1, x2, y2]
                    bboxes.append(bbox)
                else:
                    print("Invalid bounding box", annotation_path)
    
    except FileNotFoundError:
        print("Annotation missing:", annotation_path)

    return bboxes
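
For illustration, here is how a hypothetical YOLO-format annotation line (class index, then normalized center, width and height) is converted to corner coordinates:

# hypothetical line: "0 0.50 0.40 0.20 0.30"
# label 0, center (0.50, 0.40), width 0.20, height 0.30, all relative to image size
# x1 = 0.50 - 0.10 = 0.40    y1 = 0.40 - 0.15 = 0.25
# x2 = 0.50 + 0.10 = 0.60    y2 = 0.40 + 0.15 = 0.55
# read_annotation returns [[0.40, 0.25, 0.60, 0.55]]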
In [21]:
def read_bounding_boxes(file_list):
    images_data = []
    for name in file_list:
        image_name = name + '.jpg'
        annotation_name = name + '.txt'
        image_path = DATA_DIR / image_name
        annotation_path = DATA_DIR / annotation_name
        bboxes = read_annotation(annotation_path)
        data = {'file_name': image_path, 'bboxes': bboxes}
        images_data.append(data)
            
    return images_data
In [22]:
train_val_list = get_file_list(TRAIN_LIST)
In [23]:
test_list = get_file_list(TEST_LIST)
In [24]:
train_list, val_list = train_test_split(train_val_list, test_size=VAL_RATIO, random_state=0,
                                        shuffle=True)
In [25]:
print("Number of train images:", len(train_list))
print("Number of validation images:", len(val_list))
print("Number of train+validation images:", len(train_list) + len(val_list))
print("Number of test images:", len(test_list))
print("Total number of images:", len(train_list) + len(val_list) + len(test_list))
Number of train images: 217
Number of validation images: 25
Number of train+validation images: 242
Number of test images: 58
Total number of images: 300
In [26]:
train_data = read_bounding_boxes(train_list)
val_data = read_bounding_boxes(val_list)
test_data = read_bounding_boxes(test_list)

Utilities¶

Convert bounding boxes from normalized [0, 1] coordinates to pixel coordinates

In [27]:
def denormalize_bboxes(bboxes, shape):
    wh = torch.tensor([shape[1], shape[0]])
    bboxes = torch.cat((bboxes[:, :2] * wh, bboxes[:, 2:] * wh), axis=1)
    return bboxes
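
A minimal usage sketch with hypothetical values:

boxes = torch.tensor([[0.25, 0.25, 0.75, 0.75]])
print(denormalize_bboxes(boxes, (512, 512)))  # tensor([[128., 128., 384., 384.]])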

Plotting

In [28]:
def plot_image(image):
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    ax.set_axis_off()
    ax.imshow(image)
    plt.show()
In [29]:
def plot_image_bounding_boxes(image, bboxes, bboxes_gt=None):
    image = TF.convert_image_dtype(image, dtype=torch.uint8)
    image_bboxes = torchvision.utils.draw_bounding_boxes(image=image, boxes=bboxes, colors='red', width=2)
    if bboxes_gt is not None:
        image_bboxes = torchvision.utils.draw_bounding_boxes(image=image_bboxes, boxes=bboxes_gt, colors='blue', width=2)
    image_bboxes = TF.to_pil_image(image_bboxes)
    plot_image(image_bboxes)

Dataset¶

In [30]:
class ImagesDataset(torch.utils.data.Dataset):
    def __init__(self, images_data, transform=None):
        self.images_data = images_data
        self.transform = transform
    
    def __len__(self):
        return len(self.images_data)
    
    def __getitem__(self, idx):
        img_data = self.images_data[idx]
        file_path = img_data['file_name']
        image = Image.open(file_path)
        image = image.convert('RGB')
        image = np.array(image)
        
        bboxes = img_data['bboxes']
        labels = [0] * len(bboxes)
        
        if self.transform is not None:
            transformed = self.transform(image=image, bboxes=bboxes, labels=labels)
            image, bboxes = transformed['image'], transformed['bboxes']
        
        if len(bboxes) > 0:
            bboxes = torch.tensor(bboxes, dtype=torch.float32)
        else:
            bboxes = torch.zeros((0, 4), dtype=torch.float32)
        
        image = TF.to_tensor(image)
        
        return image, bboxes
    
    def show_image(self, idx):
        image, bboxes = self[idx]
        bboxes = denormalize_bboxes(bboxes, image.shape[1:])
        plot_image_bounding_boxes(image, bboxes)
In [31]:
train_transform = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Affine(
        scale=(0.9, 1.1),
        translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
        rotate=(-10, 10),
        shear=(-8, 8),
        mode=cv2.BORDER_REFLECT,
        p=1
    ),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1)
], bbox_params=A.BboxParams(format='albumentations', min_visibility=0.1, label_fields=['labels']))
In [32]:
val_transform = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE)
], bbox_params=A.BboxParams(format='albumentations', label_fields=['labels']))
In [33]:
train_dset = ImagesDataset(train_data, train_transform)
In [34]:
val_dset = ImagesDataset(val_data, val_transform)
In [35]:
test_dset = ImagesDataset(test_data, val_transform)
In [36]:
val_dset.show_image(0)

DataLoader¶

In [37]:
def collate_fn(batch):
    images, bboxes = zip(*batch)
    images_tens = torch.stack(images, dim=0)
    return images_tens, bboxes
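
A custom collate function is needed because each image carries a different number of boxes, so the default collate (which stacks every field) would fail. A hypothetical illustration:

batch = [(torch.zeros(3, 8, 8), torch.zeros(2, 4)),   # image with 2 boxes
         (torch.zeros(3, 8, 8), torch.zeros(5, 4))]   # image with 5 boxes
images, bboxes = collate_fn(batch)
print(images.shape)               # torch.Size([2, 3, 8, 8])
print([b.shape for b in bboxes])  # [torch.Size([2, 4]), torch.Size([5, 4])]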
In [38]:
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           collate_fn=collate_fn,
                                           num_workers=NUM_WORKERS, pin_memory=True)

val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
                                         collate_fn=collate_fn,
                                         num_workers=NUM_WORKERS, pin_memory=True)
In [39]:
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                          collate_fn=collate_fn,
                                          num_workers=NUM_WORKERS, pin_memory=True)

Metrics¶

In [40]:
class DetectionMetrics(ignite.metrics.Metric):
    def __init__(self, threshold=0.5, iou_threshold=0.5, nms_threshold=0.5,
                 output_transform=lambda x: x, device="cpu"):
        self.threshold = threshold
        self.iou_threshold = iou_threshold
        self.nms_threshold = nms_threshold
        super().__init__(output_transform=output_transform, device=device)
    
    def reset(self):
        self._num_true = 0
        super().reset()
    
    def update(self, data):
        outputs, targets = data[0].detach(), data[1]
        
        scores = outputs[..., 4]
        pred_bboxes = outputs[..., :4]
        
        masks = (scores > self.threshold)
        
        for true_boxes, pred_boxes, conf, mask in zip(targets, pred_bboxes, scores, masks):
            pred_boxes = pred_boxes[mask]
            conf = conf[mask]
            
            keep_idx = ops.nms(pred_boxes, conf, self.nms_threshold)
            pred_boxes = pred_boxes[keep_idx]
            conf = conf[keep_idx]
            
            if len(pred_boxes) > 0 and len(true_boxes) > 0:
                ious = ops.box_iou(pred_boxes, true_boxes)
                
                # keep only the maximum IoU for each ground truth box, so each
                # ground truth can match at most one prediction
                idx = ious.argmax(dim=0, keepdim=True)
                ious = torch.zeros_like(ious).scatter_(0, idx, ious.gather(0, idx))
                
                is_true_positive = ious > self.iou_threshold
                is_true_positive = is_true_positive.any(dim=1)
            else:
                is_true_positive = torch.zeros_like(conf, dtype=torch.bool)
            
            self._num_true += len(true_boxes)
            self.process_batch(is_true_positive, conf)

A simple F1 score for evaluation during training

In [41]:
class F1(DetectionMetrics):
    def reset(self):
        self._num_pred = 0
        self._num_tp = 0
        super().reset()
    
    def process_batch(self, is_true_positive, conf):
        self._num_pred += len(conf)
        self._num_tp += is_true_positive.sum().item()
    
    def compute(self):
        prec = self._num_tp / (self._num_pred + 1e-7)
        rec = self._num_tp / (self._num_true + 1e-7)
        f1 = 2 * prec * rec / (prec + rec + 1e-7)
        return f1
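
A hypothetical worked example of the computation:

# hypothetical counts, not from the original run:
# num_pred = 10, num_tp = 6, num_true = 8
# precision = 6 / 10 = 0.600
# recall    = 6 / 8  = 0.750
# F1 = 2 * 0.600 * 0.750 / (0.600 + 0.750) ≈ 0.667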

Average precision (AP) using all-point interpolation

In [42]:
class AveragePrecision(DetectionMetrics):
    def __init__(self, iou_threshold=0.5, nms_threshold=0.5,
                 output_transform=lambda x: x, device="cpu"):
        super().__init__(threshold=0., iou_threshold=iou_threshold, nms_threshold=nms_threshold,
                         output_transform=output_transform, device=device)
    
    def reset(self):
        self._tp_list = []
        self._confidence_list = []
        super().reset()
    
    def process_batch(self, is_true_positive, conf):
        self._confidence_list.append(conf)
        tp = torch.zeros_like(conf)
        tp[is_true_positive] = 1.
        self._tp_list.append(tp)
    
    def compute(self):
        conf = torch.cat(self._confidence_list)
        tp = torch.cat(self._tp_list)
        
        # sort true positives according to confidence
        idx = conf.argsort(descending=True)
        tp = tp[idx]
        
        # cumulative sum of true positives and false positives
        tpc = tp.cumsum(0)
        fpc = (1. - tp).cumsum(0)
        
        # precision and recall curves
        recall = tpc / (self._num_true + 1e-7)
        precision = tpc / (tpc + fpc)
        
        # append sentinel values to beginning and end
        z = torch.zeros(1, device=self._device)
        o = torch.ones(1, device=self._device)
        recall = torch.cat([z, recall, o])
        precision = torch.cat([o, precision, z])
        
        # compute precision envelope
        precision = precision.flip(0)
        precision, _ = precision.cummax(0)
        precision = precision.flip(0)
        
        # integrate area under curve
        idx = (recall[1:] != recall[:-1]).nonzero(as_tuple=True)[0] # indexes where recall changes
        ap = ((recall[idx + 1] - recall[idx]) * precision[idx + 1]).sum().item() # area under curve
        
        return ap
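
A hypothetical worked example of the all-point interpolation: with two ground truth boxes and three predictions ranked by confidence as (TP, FP, TP),

# hypothetical numbers, not from the original run:
# tp        = [1, 0, 1]  (ranked by descending confidence)
# recall    = [0.5, 0.5, 1.0]
# precision = [1.0, 0.5, 0.667]
# precision envelope (after adding sentinels): [1, 1, 0.667, 0.667, 0]
# AP = (0.5 - 0) * 1.0 + (1.0 - 0.5) * 0.667 ≈ 0.83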

Training¶

Setup Trainer¶

In [43]:
params = [p for p in model.parameters() if p.requires_grad]
In [44]:
optimizer = optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
In [45]:
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
In [46]:
loss = DetectionLoss()

Trainer

In [47]:
trainer = create_supervised_trainer(model, optimizer, loss, device=DEVICE)
In [48]:
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step());
In [49]:
ignite.metrics.RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
In [50]:
pbar = ignite.contrib.handlers.ProgressBar(persist=True, ncols=100)
In [51]:
pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)

Evaluator

In [52]:
evaluator = create_supervised_evaluator(model,
                metrics={"F1": F1(device=DEVICE), "loss": ignite.metrics.Loss(loss)},
                device=DEVICE)
In [53]:
best_handler = ignite.handlers.Checkpoint({'model': model},
    ignite.handlers.DiskSaver(MODELS_DIR, create_dir=False, require_empty=False),
    n_saved=1, filename_prefix='best',
    score_function=ignite.handlers.Checkpoint.get_default_score_fn("F1"),
    score_name="val_F1",
    global_step_transform=ignite.handlers.global_step_from_engine(trainer)
)
In [54]:
evaluator.add_event_handler(Events.COMPLETED, best_handler);
In [55]:
history = defaultdict(list)
In [56]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    train_state = engine.state
    epoch = train_state.epoch
    max_epochs = train_state.max_epochs
    train_loss = train_state.metrics["loss"]
    history['train loss'].append(train_loss)
    
    evaluator.run(val_loader)
    val_metrics = evaluator.state.metrics
    val_loss = val_metrics["loss"]
    val_f1 = val_metrics["F1"]
    history['val loss'].append(val_loss)
    history['val F1'].append(val_f1)
    
    pbar.pbar.set_postfix({"train loss": f"{train_loss:.3f}",
                           "val loss": f"{val_loss:.3f}", "val F1": f"{val_f1:.3f}"})

Start training¶

In [57]:
trainer.run(train_loader, max_epochs=EPOCHS);
Epoch: [1000/1000] 100%|████████████, train loss=0.020, val loss=0.126, val F1=0.648 [1:24:16<00:00]
In [58]:
torch.save(model.state_dict(), str(MODELS_DIR / 'final_model.pt'))

Plotting¶

In [59]:
def show_or_save(fig, filename=None):
    if filename:
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.05)
        plt.close(fig)
    else:
        plt.show()
In [60]:
def history_plot_train_val(history, key, filename=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    xs = np.arange(1, len(history['train ' + key]) + 1)
    ax.plot(xs, history['train ' + key], '.-', label='train')
    ax.plot(xs, history['val ' + key], '.-', label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    show_or_save(fig, filename)
In [61]:
def history_plot(history, key, filename=None):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    xs = np.arange(1, len(history[key]) + 1)
    ax.plot(xs, history[key], '.-')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.grid()
    show_or_save(fig, filename)
In [62]:
history_plot_train_val(history, 'loss')
In [63]:
history_plot(history, 'val F1')

Testing¶

In [64]:
def evaluate_metric(model, metric, loader, device, **kwargs):
    name = metric.__name__
    evaluator = create_supervised_evaluator(model,
                metrics={name: metric(device=device, **kwargs)},
                device=device)
    evaluator.run(loader)
    val_metric = evaluator.state.metrics[name]
    return val_metric
In [65]:
#model.load_state_dict(torch.load(str(MODELS_DIR / best_handler.last_checkpoint)))
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
Out[65]:
<All keys matched successfully>
In [66]:
model.eval();
In [67]:
thresholds = np.linspace(0., 1.0, num=50)
f1_vs_thr = [evaluate_metric(model, F1, val_loader, DEVICE, threshold=thr) for thr in tqdm(thresholds)]
100%|██████████| 50/50 [01:00<00:00,  1.20s/it]
In [68]:
plt.plot(thresholds, f1_vs_thr);
plt.grid();
In [69]:
num = np.argmax(f1_vs_thr)
best_thr = thresholds[num]
best_f1 = f1_vs_thr[num]
print(f"Best F1 {best_f1:.3f} for threshold {best_thr:.3f}")
Best F1 0.661 for threshold 0.163
In [70]:
test_f1 = evaluate_metric(model, F1, test_loader, DEVICE, threshold=best_thr)
print(f"Test F1: {test_f1:.3f}")
Test F1: 0.709
In [71]:
test_ap = evaluate_metric(model, AveragePrecision, test_loader, DEVICE)
print(f"Test AP: {test_ap:.3f}")
Test AP: 0.656

Example¶

In [72]:
def output_to_bboxes(output, shape, threshold=0.5, nms_threshold=0.5):
    scores = output[..., 4]
    pred_boxes = output[..., :4]
    
    mask = (scores > threshold)
    
    pred_boxes = pred_boxes[mask]
    scores = scores[mask]
    
    keep_idx = ops.nms(pred_boxes, scores, nms_threshold)
    pred_boxes = pred_boxes[keep_idx]
    scores = scores[keep_idx]
    
    bboxes = denormalize_bboxes(pred_boxes, shape)
    
    return bboxes, scores
In [73]:
image, bboxes_gt = test_dset[0]
In [77]:
with torch.no_grad():
    output = model(image.unsqueeze(0).to(DEVICE))
output = output.detach().cpu().squeeze(0)
pred_bboxes, pred_scores = output_to_bboxes(output, image.shape[1:], threshold=best_thr, nms_threshold=0.5)
true_bboxes = denormalize_bboxes(bboxes_gt, image.shape[1:])
plot_image_bounding_boxes(image, pred_bboxes, true_bboxes)

Heatmap of predicted scores

In [75]:
plot_image(TF.to_pil_image(output[..., 4]))