Dataset: Embrapa Wine Grape Instance Segmentation Dataset
Download:
git clone https://github.com/thsant/wgisd.git
Imports
from pathlib import Path
from collections import defaultdict
import math
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import ops, tv_tensors, io
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as TF
Configuration
# Paths into the cloned WGISD repository.
WGISD_ROOT = Path('./wgisd/')
DATA_DIR = WGISD_ROOT / 'data'          # images (.jpg) and YOLO annotations (.txt)
TRAIN_LIST = WGISD_ROOT / 'train.txt'   # official train split (image names, one per line)
TEST_LIST = WGISD_ROOT / 'test.txt'     # official test split
MODELS_DIR = Path('./models')           # where trained weights are saved
# Data / training hyperparameters.
VAL_RATIO = 0.1        # fraction of the train list held out for validation
IMAGE_SIZE = 512       # images are resized to IMAGE_SIZE x IMAGE_SIZE
NUM_WORKERS = 8        # DataLoader worker processes
BATCH_SIZE = 16
EPOCHS = 1000
LEARNING_RATE = 1e-2   # peak LR for the one-cycle schedule
WEIGHT_DECAY = 1e-2
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", DEVICE)
device: cuda
def get_file_list(list_name):
    """Read a split-list file and return its lines with whitespace stripped."""
    with open(list_name) as handle:
        return [entry.strip() for entry in handle]
def read_annotation(annotation_path):
    """Parse a YOLO-style annotation file into normalized XYXY boxes.

    Each line holds: label, x_center, y_center, width, height, all relative
    to the image size. Returns a list of [x1, y1, x2, y2] boxes clamped to
    the unit square; degenerate boxes are reported and skipped, and a
    missing file yields an empty list.
    """
    bboxes = []
    try:
        with open(annotation_path) as handle:
            for line in handle:
                _, cx, cy, bw, bh = (float(tok) for tok in line.split())
                left = max(cx - 0.5 * bw, 0.)
                top = max(cy - 0.5 * bh, 0.)
                right = min(cx + 0.5 * bw, 1.)
                bottom = min(cy + 0.5 * bh, 1.)
                if left < right and top < bottom:
                    bboxes.append([left, top, right, bottom])
                else:
                    print("Invalid bounding box", annotation_path)
    except FileNotFoundError:
        print("Annotation missing:", annotation_path)
    return bboxes
def denormalize_bounding_boxes(bboxes, shape):
    """Scale unit-square [x1, y1, x2, y2] boxes to pixel coordinates.

    shape is (width, height), as returned by PIL's Image.size.
    """
    w, h = shape
    return [[x1 * w, y1 * h, x2 * w, y2 * h] for x1, y1, x2, y2 in bboxes]
def read_bounding_boxes(file_list):
    """Load image sizes and pixel-space boxes for every name in file_list.

    Returns a list of dicts with keys 'file_name' (Path to the .jpg),
    'bboxes' (pixel XYXY lists) and 'shape' ((width, height)).
    """
    images_data = []
    for name in file_list:
        image_path = DATA_DIR / (name + '.jpg')
        # Use a context manager so the file handle is closed immediately;
        # the original left every opened image's descriptor to the GC,
        # which can exhaust file descriptors on large splits.
        with Image.open(image_path) as image:
            shape = image.size
        annotation_path = DATA_DIR / (name + '.txt')
        bboxes = denormalize_bounding_boxes(read_annotation(annotation_path), shape)
        images_data.append({'file_name': image_path, 'bboxes': bboxes, 'shape': shape})
    return images_data
# Split the official train list into train/validation subsets; the fixed
# random_state keeps the split reproducible across runs.
train_val_list = get_file_list(TRAIN_LIST)
test_list = get_file_list(TEST_LIST)
train_list, val_list = train_test_split(train_val_list, test_size=VAL_RATIO, random_state=0,
                                        shuffle=True)
print("Number of train images:", len(train_list))
print("Number of validation images:", len(val_list))
print("Number of train+validation images:", len(train_list) + len(val_list))
print("Number of test images:", len(test_list))
print("Total number of images:", len(train_list) + len(val_list) + len(test_list))
Number of train images: 217 Number of validation images: 25 Number of train+validation images: 242 Number of test images: 58 Total number of images: 300
# Pre-compute image paths, sizes and pixel-space boxes for each split.
train_data = read_bounding_boxes(train_list)
val_data = read_bounding_boxes(val_list)
test_data = read_bounding_boxes(test_list)
Plotting
def plot_image(image):
    """Display a single image on a 10x10-inch axis with the axes hidden."""
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_axis_off()
    ax.imshow(image)
    plt.show()
def plot_image_bounding_boxes(image, bboxes, bboxes_gt=None):
    """Draw boxes on the image and show it: bboxes in red, optional
    ground-truth bboxes_gt in blue."""
    canvas = TF.convert_image_dtype(image, dtype=torch.uint8)
    canvas = torchvision.utils.draw_bounding_boxes(image=canvas, boxes=bboxes, colors='red', width=2)
    if bboxes_gt is not None:
        canvas = torchvision.utils.draw_bounding_boxes(image=canvas, boxes=bboxes_gt, colors='blue', width=2)
    plot_image(TF.to_pil_image(canvas))
# Training augmentation: geometric + photometric transforms applied jointly
# to the image and its bounding boxes by torchvision transforms v2. After the
# affine warp, boxes are clamped to the image and degenerate ones removed.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), antialias=True),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomAffine(
        degrees=10.,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        shear=8.
    ),
    transforms.RandomPhotometricDistort(p=1),
    transforms.ClampBoundingBoxes(),
    transforms.SanitizeBoundingBoxes(labels_getter=None),
    transforms.ToDtype(torch.float32, scale=True),  # uint8 [0,255] -> float32 [0,1]
])
# Validation/test: deterministic resize + dtype conversion only.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), antialias=True),
    transforms.ToDtype(torch.float32, scale=True),
])
class ImagesDataset(torch.utils.data.Dataset):
    """Detection dataset yielding (image, bounding_boxes) pairs.

    images_data: list of dicts produced by read_bounding_boxes
    transform:   torchvision v2 transform applied jointly to image and boxes
    """
    def __init__(self, images_data, transform):
        self.images_data = images_data
        self.transform = transform

    def __len__(self):
        return len(self.images_data)

    def __getitem__(self, idx):
        img_data = self.images_data[idx]
        file_path = img_data['file_name']
        image = tv_tensors.Image(io.read_image(str(file_path)))
        bboxes = img_data['bboxes']
        # An image without annotations still needs a (0, 4) tensor so the
        # BoundingBoxes wrapper and downstream transforms accept it.
        if len(bboxes) == 0: bboxes = torch.zeros((0, 4), dtype=torch.float32)
        bboxes = tv_tensors.BoundingBoxes(bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=image.shape[1:])
        image, bboxes = self.transform(image, bboxes)
        return image, bboxes

    def show_image(self, idx):
        # Visual sanity check: plot the (transformed) sample with its boxes.
        image, bboxes = self[idx]
        plot_image_bounding_boxes(image, bboxes)
# Build the three dataset views; only training gets augmentation.
train_dset = ImagesDataset(train_data, train_transform)
val_dset = ImagesDataset(val_data, val_transform)
test_dset = ImagesDataset(test_data, val_transform)
# Eyeball the first validation sample with its ground-truth boxes.
val_dset.show_image(0)
Normalize bounding boxes
def normalize_bboxes(bboxes):
    """Rescale XYXY pixel boxes to the unit square using their canvas size."""
    h, w = bboxes.canvas_size
    scale = torch.tensor([w, h])
    left_top = bboxes[:, :2] / scale
    right_bottom = bboxes[:, 2:] / scale
    return torch.cat((left_top, right_bottom), dim=1)
Pad ground truth bounding boxes to allow formation of a batch tensor.
def extend_tensor(t, max_len):
    """Zero-pad tensor t along dim 0 to length max_len (no-op if already long enough).

    The padding now matches t's dtype and device; the original used
    torch.zeros defaults, which silently up/down-cast non-float tensors and
    fails to concatenate with GPU-resident tensors.
    """
    length = len(t)
    if length >= max_len:
        return t
    pad = torch.zeros((max_len - length,) + t.shape[1:], dtype=t.dtype, device=t.device)
    return torch.cat((t, pad), dim=0)
def collate_fn(batch):
    """Stack images into a batch tensor and pad per-image boxes to equal count.

    Boxes are normalized to the unit square; padding rows are all zeros
    (stripped later by remove_zero_boxes).
    """
    images, bboxes = zip(*batch)
    batch_images = torch.stack(images, dim=0)
    pad_to = max(len(bbs) for bbs in bboxes)
    padded = [extend_tensor(normalize_bboxes(bbs), pad_to) for bbs in bboxes]
    return batch_images, torch.stack(padded, dim=0)
# DataLoaders share the padding collate function; worker processes and pinned
# memory speed up host-to-GPU transfer. Only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=BATCH_SIZE, shuffle=True,
                                           collate_fn=collate_fn,
                                           num_workers=NUM_WORKERS, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=BATCH_SIZE, shuffle=False,
                                         collate_fn=collate_fn,
                                         num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=BATCH_SIZE, shuffle=False,
                                          collate_fn=collate_fn,
                                          num_workers=NUM_WORKERS, pin_memory=True)
For each cell of the output grid, the model proposes a bounding box centered in that cell, together with a confidence score.
class ConvBlock(nn.Sequential):
    """Conv2d ('same' padding, no bias) + BatchNorm, optionally ending in ReLU."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act=True):
        same_pad = (kernel_size - 1) // 2
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=same_pad, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        if act:
            layers.append(nn.ReLU(inplace=True))
        super().__init__(*layers)
class DownBlock(nn.Sequential):
    """Halve the spatial resolution (max-pool), then apply two 3x3 ConvBlocks."""
    def __init__(self, in_channels, out_channels):
        stages = [
            nn.MaxPool2d(2),
            ConvBlock(in_channels, out_channels, 3),
            ConvBlock(out_channels, out_channels, 3),
        ]
        super().__init__(*stages)
class Encoder(nn.Module):
    """Backbone: a stem conv followed by num_downsamplings DownBlocks.

    The channel count doubles at each downsampling. forward returns the final
    feature map plus the list of pre-downsampling maps used as skip
    connections by the decoder.
    """
    def __init__(self, in_channels, channels, num_downsamplings):
        super().__init__()
        self.stem = ConvBlock(in_channels, channels, 3)
        self.blocks = nn.ModuleList()
        in_channels = channels
        for _ in range(num_downsamplings):
            out_channels = in_channels * 2
            self.blocks.append(DownBlock(in_channels, out_channels))
            in_channels = out_channels

    def forward(self, x):
        x = self.stem(x)
        xs = []
        for block in self.blocks:
            xs.append(x)  # saved BEFORE downsampling -> higher-resolution skip
            x = block(x)
        return x, xs
The decoder is inspired by Feature Pyramid Network (FPN), arXiv:1612.03144 [cs.CV]
class UpBlock(nn.Module):
    """Upsampling step with a zero-initialized skip branch and drop-path.

    Doubles the spatial size, then adds a 1x1-projected skip connection
    scaled by a learnable gamma (initialized to 0 by init_model, so training
    starts from the plain upsampled path). During training, with probability
    p_drop the skip branch is dropped entirely; when kept it is rescaled by
    1/keep_prob so the expected output is unchanged (stochastic depth).
    """
    def __init__(self, in_channels, out_channels, p_drop=0.):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = ConvBlock(in_channels, out_channels, 1, act=False)
        self.gamma = nn.Parameter(torch.tensor(0.))
        self.p_drop = p_drop

    def forward(self, x, skip):
        # Sample whether to include the skip branch (training only).
        if self.training and self.p_drop > 0.:
            keep_prob = 1. - self.p_drop
            add_branch = torch.bernoulli(torch.tensor(keep_prob))
        else:
            keep_prob = 1.
            add_branch = True
        out = self.up(x)
        if add_branch:
            out = out + self.gamma / keep_prob * self.conv(skip)
        return out
class Head(nn.Sequential):
    """Final decoder stage: ReLU followed by a 3x3 ConvBlock."""
    def __init__(self, in_channels, out_channels):
        modules = [nn.ReLU(inplace=True), ConvBlock(in_channels, out_channels, 3)]
        super().__init__(*modules)
class Decoder(nn.Module):
    """FPN-inspired decoder (arXiv:1612.03144).

    Projects the encoder output to out_channels with a 1x1 conv, then runs
    num_upsamplings UpBlocks, each consuming the matching encoder skip map
    (deepest first), and finally a Head.
    """
    def __init__(self, in_channels, out_channels, num_upsamplings, p_drop_path=0.):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels, 1, act=False)
        self.up = nn.ModuleList()
        for _ in range(num_upsamplings):
            in_channels = in_channels // 2  # skip maps have half the channels per level up
            self.up.append(UpBlock(in_channels, out_channels, p_drop_path))
        self.head = Head(out_channels, out_channels)

    def forward(self, x, xs):
        x = self.conv(x)
        # zip stops after num_upsamplings skips even if the encoder saved more.
        for up_layer, skip in zip(self.up, reversed(xs)):
            x = up_layer(x, skip)
        x = self.head(x)
        return x
class DetectionHead(nn.Module):
    """Per-cell box predictor: projects features to 5 channels per cell.

    Channels 0-1 are the box-center offset within the cell, channels 2-3 the
    box width/height (all passed through sigmoid), channel 4 the raw
    confidence logit. to_boxes converts this to normalized XYXY boxes.
    """
    def __init__(self, channels, p_drop=0.):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Dropout(p_drop),
            nn.Conv2d(channels, 5, kernel_size=1)
        )

    def forward(self, x):
        out = self.proj(x)
        boxes = self.to_boxes(out)
        return boxes

    @staticmethod
    def to_boxes(out):
        """Decode raw head output (b, 5, h, w) into boxes (b, h, w, 5).

        Centers are (cell index + sigmoid offset) / grid size, so each cell
        predicts a box centered within itself; width/height are sigmoid
        outputs in [0, 1] relative to the full image. Resulting corner
        coordinates may extend slightly outside [0, 1].
        """
        h, w = out.shape[2:]
        coords = torch.sigmoid(out[:, :4])
        grid_x = torch.arange(w, device=out.device).unsqueeze(0)  # broadcast over rows
        grid_y = torch.arange(h, device=out.device).unsqueeze(1)  # broadcast over cols
        cx = (coords[:, 0] + grid_x) / w
        cy = (coords[:, 1] + grid_y) / h
        pred_w = coords[:, 2]
        pred_h = coords[:, 3]
        x1 = cx - 0.5 * pred_w
        y1 = cy - 0.5 * pred_h
        x2 = cx + 0.5 * pred_w
        y2 = cy + 0.5 * pred_h
        scores = out[:, 4]  # left as logits; sigmoid is applied downstream
        boxes = torch.stack((x1, y1, x2, y2, scores), dim=3)
        return boxes
class Net(nn.Module):
    """Full detector: Encoder -> Decoder -> DetectionHead.

    Output shape is (batch, h, w, 5), where the last dimension holds
    (x1, y1, x2, y2, score logit) in normalized image coordinates.
    """
    def __init__(self, num_downsamplings, num_upsamplings, channels, in_channels, out_channels,
                 decoder_p_drop_path=0., head_p_drop=0.):
        super().__init__()
        # Encoder doubles channels per downsampling, so its output width is known.
        encoder_out_channels = channels * 2**num_downsamplings
        self.encoder = Encoder(in_channels, channels, num_downsamplings)
        self.decoder = Decoder(encoder_out_channels, out_channels, num_upsamplings, decoder_p_drop_path)
        self.head = DetectionHead(out_channels, head_p_drop)

    def forward(self, x):
        x, xs = self.encoder(x)
        out = self.decoder(x, xs)
        boxes = self.head(out)
        return boxes
def init_model(model):
    """Initialize weights in place: Kaiming-normal for conv layers,
    identity-like BatchNorm (weight=1, bias=0), and zeroed UpBlock gamma so
    every skip branch starts disabled."""
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1.)
            nn.init.zeros_(m.bias)
        elif isinstance(m, UpBlock):
            nn.init.zeros_(m.gamma)
# Encoder downsamples 5x (stride 32) but the decoder upsamples only 3x, so
# the output grid is 1/4 of the input resolution (128x128 for 512x512 input).
model = Net(num_downsamplings=5, num_upsamplings=3,
            channels=24, in_channels=3, out_channels=128,
            decoder_p_drop_path=0.1, head_p_drop=0.)
init_model(model)
model.to(DEVICE);
print("Number of parameters: {:,}".format(sum(p.numel() for p in model.parameters())))
Number of parameters: 10,946,816
def batch_box_area(boxes):
    """Area of XYXY boxes; accepts any batched shape (..., 4)."""
    widths = boxes[..., 2] - boxes[..., 0]
    heights = boxes[..., 3] - boxes[..., 1]
    return widths * heights
def batch_box_scores(boxes_gt, # shape: (b, n_gt, 4)
                     boxes_pred, # shape: (b, n_pred, 4)
                     eps=1e-7):
    """Matching score of every predicted box against every ground-truth box.

    Returns shape (b, n_gt, n_pred) with
        score = (2 * intersection - pred_area) / gt_area,
    which is 1 exactly when pred == gt and decreases as the prediction
    misses gt area or spills outside it; unlike IoU the score can go
    negative (when pred_area > 2 * intersection).
    """
    lt = torch.max(boxes_gt[..., None, :2], boxes_pred[..., None, :, :2])  # intersection top-left
    rb = torch.min(boxes_gt[..., None, 2:], boxes_pred[..., None, :, 2:])  # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1] # shape: (b, n_gt, n_pred)
    area_gt = batch_box_area(boxes_gt)[..., None]
    area_pred = batch_box_area(boxes_pred)[..., None, :]
    scores = (2 * inter - area_pred) / (area_gt + eps)
    return scores
class DetectionLoss(nn.Module):
    """Confidence + geometry loss over the dense per-cell predictions.

    Each predicted cell's soft confidence target is its best (clamped to
    >= 0) batch_box_scores match against any ground-truth box. BCE pushes
    the score logits toward that target, while the geometry term penalizes
    (1 - score) weighted by the same mask, pulling well-matched boxes
    toward exact overlap with their ground truth.
    """
    def __init__(self, λ_conf=1.0, λ_geom=1.0):
        super().__init__()
        self.λ_conf = λ_conf   # weight of the confidence (BCE) term
        self.λ_geom = λ_geom   # weight of the geometry term
        self.conf_loss = nn.BCEWithLogitsLoss()

    def forward(self, output, target):
        # With no ground-truth boxes in the batch, substitute one degenerate
        # zero box so the max over gt boxes below stays well-defined.
        if target.size(1) == 0:
            target = torch.zeros((target.size(0), 1, 4), device=target.device)
        output = output.flatten(1, 2)  # (b, h, w, 5) -> (b, h*w, 5)
        pred_bboxes = output[..., :4]
        pred_scores = output[..., 4]
        true_scores = batch_box_scores(target, pred_bboxes)
        true_scores, _ = torch.max(true_scores, dim=1)  # best gt match per prediction
        # Soft target and weight; detached so no gradient flows through it.
        mask = true_scores.clamp(min=0.).detach()
        conf_loss = self.conf_loss(pred_scores, mask)
        geom_loss = torch.mean(mask * (1. - true_scores))
        loss = self.λ_conf * conf_loss + self.λ_geom * geom_loss
        return loss
Output preprocessing
def heatmap_peaks(heat, kernel=3):
    """Boolean mask of local maxima in a (b, h, w) score heatmap.

    A cell is a peak when it equals the maximum of its kernel x kernel
    neighborhood (ties mark every tied cell as a peak).
    """
    pad = (kernel - 1) // 2
    pooled = nn.functional.max_pool2d(heat.unsqueeze(1), (kernel, kernel), stride=1, padding=pad)
    return pooled.squeeze(1) == heat
def remove_zero_boxes(boxes):
    """Drop padding rows (all-zero boxes) added by the collate function."""
    keep = (boxes != 0.).any(dim=-1)
    return boxes[keep]
class AverageLoss():
    """Running sample-weighted mean of batch losses."""
    def __init__(self):
        self.name = "loss"
        self.reset()

    def reset(self):
        # Accumulators for the weighted mean.
        self.num_samples = 0
        self.total_loss = 0.

    def update(self, data):
        count = data['batch_size']
        self.num_samples += count
        self.total_loss += count * data[self.name]

    def compute(self):
        return {self.name: self.total_loss / self.num_samples}
class DetectionMetrics():
    """Base class for detection metrics over padded batched model outputs.

    update() decodes the raw output: sigmoid scores, local-maximum (peak)
    filtering, confidence thresholding, then NMS. Each surviving prediction
    is matched against ground truth (only the highest-IoU prediction per
    ground-truth box can count) and flagged true/false positive. Subclasses
    implement process_batch() and compute().
    """
    def __init__(self, threshold=0.5, iou_threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
        self.threshold = threshold            # min sigmoid score to keep a prediction
        self.iou_threshold = iou_threshold    # min IoU for a true positive
        self.nms_threshold = nms_threshold    # IoU above which NMS suppresses
        self.heatmap_kernel = heatmap_kernel  # neighborhood size for peak detection
        self.reset()

    def reset(self):
        self._num_true = 0  # total ground-truth boxes seen so far

    def update(self, data):
        outputs = data['outputs']
        true_bboxes = data['targets']
        logits = outputs[..., 4]
        scores = torch.sigmoid(logits)
        pred_bboxes = outputs[..., :4]
        peaks = heatmap_peaks(scores, kernel=self.heatmap_kernel)
        masks = peaks & (scores > self.threshold)
        # Process each image in the batch independently.
        for true_boxes, pred_boxes, conf, mask in zip(true_bboxes, pred_bboxes, scores, masks):
            true_boxes = remove_zero_boxes(true_boxes)  # strip collate padding
            pred_boxes = pred_boxes[mask]
            conf = conf[mask]
            keep_idx = ops.nms(pred_boxes, conf, self.nms_threshold)
            pred_boxes = pred_boxes[keep_idx]
            conf = conf[keep_idx]
            if len(pred_boxes) > 0 and len(true_boxes) > 0:
                ious = ops.box_iou(pred_boxes, true_boxes)
                # zero all non_maximum values for a given ground truth box
                idx = ious.argmax(dim=0, keepdims=True)
                ious = torch.zeros_like(ious).scatter_(0, idx, ious.gather(0, idx))
                is_true_positive = ious > self.iou_threshold
                is_true_positive = is_true_positive.any(dim=1)
            else:
                # No predictions or no ground truth: every prediction is a FP.
                is_true_positive = torch.zeros_like(conf, dtype=torch.bool)
            self._num_true += len(true_boxes)
            self.process_batch(is_true_positive, conf)
Simple F1 score for evaluation during training
class F1(DetectionMetrics):
    """F1 score at fixed thresholds; cheap enough to track during training."""
    def __init__(self, threshold=0.5, iou_threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
        super().__init__(threshold=threshold, iou_threshold=iou_threshold, nms_threshold=nms_threshold,
                         heatmap_kernel=heatmap_kernel)
        self.name = "F1"
        self.reset()

    def reset(self):
        self._num_pred = 0  # predictions kept after thresholding + NMS
        self._num_tp = 0    # predictions matched to a ground-truth box
        super().reset()

    def process_batch(self, is_true_positive, conf):
        self._num_pred += len(conf)
        self._num_tp += is_true_positive.sum().item()

    def compute(self):
        eps = 1e-7
        precision = self._num_tp / (self._num_pred + eps)
        recall = self._num_tp / (self._num_true + eps)
        score = 2 * precision * recall / (precision + recall + eps)
        return {self.name: score}
Average precision (AP) using all-point interpolation
class AveragePrecision(DetectionMetrics):
    """Average precision (AP) using all-point interpolation.

    Collects every prediction with its confidence and TP/FP flag (threshold
    set to 0 so nothing is discarded early), then integrates the area under
    the precision-recall curve in compute().
    """
    def __init__(self, iou_threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
        super().__init__(threshold=0., iou_threshold=iou_threshold, nms_threshold=nms_threshold,
                         heatmap_kernel=heatmap_kernel)
        self.name = "AP"
        self.reset()

    def reset(self):
        self._tp_list = []          # 1.0 for true positives, 0.0 for false positives
        self._confidence_list = []  # matching confidence per prediction
        super().reset()

    def process_batch(self, is_true_positive, conf):
        self._confidence_list.append(conf)
        tp = torch.zeros_like(conf)
        tp[is_true_positive] = 1.
        self._tp_list.append(tp)

    def compute(self):
        conf = torch.cat(self._confidence_list)
        tp = torch.cat(self._tp_list)
        # sort true positives according to confidence
        idx = conf.argsort(descending=True)
        tp = tp[idx]
        # cumulative sum of true positives and false positives
        tpc = tp.cumsum(0)
        fpc = (1. - tp).cumsum(0)
        # precision and recall curves
        recall = tpc / (self._num_true + 1e-7)
        precision = tpc / (tpc + fpc)
        # append sentinel values to beginning and end
        z = torch.zeros(1, device=recall.device)
        o = torch.ones(1, device=recall.device)
        recall = torch.cat([z, recall, o])
        precision = torch.cat([o, precision, z])
        # compute precision envelope (running max from the right), making
        # precision monotonically non-increasing in recall
        precision = precision.flip(0)
        precision, _ = precision.cummax(0)
        precision = precision.flip(0)
        # integrate area under curve
        idx = (recall[1:] != recall[:-1]).nonzero(as_tuple=True)[0] # indexes where recall changes
        ap = ((recall[idx + 1] - recall[idx]) * precision[idx + 1]).sum().item() # area under curve
        metrics = {self.name: ap}
        return metrics
def iterate(step_fn, loader, metrics_list):
    """Run step_fn on every batch of loader, feeding results to each metric.

    Resets all metrics first, moves each batch to DEVICE, then returns the
    merged dict of metric.compute() results.
    """
    for metric in metrics_list:
        metric.reset()
    for x, y in loader:
        x = x.to(DEVICE)
        # Targets may be a single tensor or a list/tuple of tensors;
        # isinstance replaces the non-idiomatic `type(y) == list` checks.
        if isinstance(y, (list, tuple)):
            y = [yi.to(DEVICE) for yi in y]
        else:
            y = y.to(DEVICE)
        loss, out = step_fn(x, y)
        data = {"loss": loss.item(),
                "batch_size": len(x),
                "targets": y,
                "outputs": out}
        for metric in metrics_list:
            metric.update(data)
    metrics = {}
    for metric in metrics_list:
        metrics.update(metric.compute())
    return metrics
def train(model, loss_fn, optimizer, loader, batch_scheduler, metrics_list):
    """Run one training epoch: optimize on every batch and collect metrics."""
    def train_step(batch_x, batch_y):
        preds = model(batch_x)
        batch_loss = loss_fn(preds, batch_y)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        batch_scheduler.step()  # the scheduler is stepped per batch, not per epoch
        return batch_loss.detach(), preds.detach()

    model.train()
    return iterate(train_step, loader, metrics_list)
def evaluate(model, loss_fn, loader, metrics_list):
    """Evaluate the model without gradients; returns the merged metric dict."""
    def eval_step(batch_x, batch_y):
        preds = model(batch_x)
        batch_loss = loss_fn(preds, batch_y)
        return batch_loss.detach(), preds.detach()

    model.eval()
    with torch.inference_mode():
        return iterate(eval_step, loader, metrics_list)
def update_history(history, metrics, name):
    """Append each metric value to history under the key '<name> <metric>'."""
    for key, value in metrics.items():
        history[f"{name} {key}"].append(value)
def history_plot_train_val(history, key):
    """Plot the train vs. validation curves of one metric over epochs."""
    train_vals = history['train ' + key]
    val_vals = history['val ' + key]
    epochs = np.arange(1, len(train_vals) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_vals, '.-', label='train')
    ax.plot(epochs, val_vals, '.-', label='val')
    ax.set_xlabel('epoch')
    ax.set_ylabel(key)
    ax.legend()
    ax.grid()
    plt.show()
# Loss, optimizer and a one-cycle LR schedule (stepped once per batch inside
# train()); F1 and average loss are tracked for both splits every epoch.
loss = DetectionLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LEARNING_RATE,
                                             steps_per_epoch=len(train_loader), epochs=EPOCHS)
history = defaultdict(list)
metrics = [AverageLoss(), F1()]
pbar = trange(EPOCHS, ncols=140)
for epoch in pbar:
    train_metrics = train(model, loss, optimizer, train_loader, lr_scheduler, metrics)
    update_history(history, train_metrics, "train")
    val_metrics = evaluate(model, loss, val_loader, metrics)
    update_history(history, val_metrics, "val")
    pbar.set_postfix({"train F1": f"{train_metrics['F1']:.3f}", "val F1": f"{val_metrics['F1']:.3f}"})
100%|███████████████████████████████████████████████████████████████████| 1000/1000 [2:07:48<00:00, 7.67s/it, train F1=0.947, val F1=0.783]
# Persist the trained weights, plot the curves, then reload the checkpoint
# so evaluation below runs from the saved state.
torch.save(model.state_dict(), str(MODELS_DIR / 'final_model.pt'))
history_plot_train_val(history, 'loss')
history_plot_train_val(history, 'F1')
model.load_state_dict(torch.load(str(MODELS_DIR / 'final_model.pt')))
<All keys matched successfully>
model.eval();
# Sweep the confidence threshold on the validation set to maximize F1.
thresholds = np.linspace(0., 1.0, num=50)
f1_vs_thr = [evaluate(model, loss, val_loader, [F1(threshold=thr)])['F1'] for thr in tqdm(thresholds, ncols=140)]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:00<00:00, 1.21s/it]
plt.plot(thresholds, f1_vs_thr);
plt.grid();
# Pick the threshold that maximizes validation F1.
num = np.argmax(f1_vs_thr)
best_thr = thresholds[num]
best_f1 = f1_vs_thr[num]
print(f"Best F1 {best_f1:.3f} for threshold {best_thr:.3f}")
Best F1 0.811 for threshold 0.633
# Final test-set F1 at the validation-selected threshold.
test_f1 = evaluate(model, loss, test_loader, [F1(threshold=best_thr)])['F1']
print(f"Test F1: {test_f1:.3f}")
Test F1: 0.854
# Threshold-free evaluation: all-point interpolated average precision.
test_ap = evaluate(model, loss, test_loader, [AveragePrecision()])['AP']
print(f"Test AP: {test_ap:.3f}")
Test AP: 0.833
Convert from scaled tensors to unscaled bounding boxes
def denormalize_bboxes(bboxes, shape):
    """Scale unit-square XYXY box tensors to pixel coordinates.

    shape is (height, width), i.e. a tensor image's spatial dims.
    """
    scale = torch.tensor([shape[1], shape[0]])
    scaled_lt = bboxes[:, :2] * scale
    scaled_rb = bboxes[:, 2:] * scale
    return torch.cat((scaled_lt, scaled_rb), dim=1)
def output_to_bboxes(output, shape, threshold=0.5, nms_threshold=0.5, heatmap_kernel=3):
    """Decode one unbatched model output (h, w, 5) into final detections.

    Applies sigmoid to the score channel, keeps local peaks above threshold,
    runs NMS, then scales the surviving boxes to pixel coordinates.
    Returns (bboxes, scores).
    """
    logits = output[..., 4]
    scores = torch.sigmoid(logits)
    pred_boxes = output[..., :4]
    # heatmap_peaks expects a batch dim; add and remove one around the call.
    peaks = heatmap_peaks(scores.unsqueeze(0), kernel=heatmap_kernel).squeeze(0)
    mask = peaks & (scores > threshold)
    pred_boxes = pred_boxes[mask]
    scores = scores[mask]
    keep_idx = ops.nms(pred_boxes, scores, nms_threshold)
    pred_boxes = pred_boxes[keep_idx]
    scores = scores[keep_idx]
    bboxes = denormalize_bboxes(pred_boxes, shape)
    return bboxes, scores
# Qualitative check: predictions (red) vs. ground truth (blue) on one test image.
image, bboxes_gt = test_dset[0]
with torch.inference_mode():
    output = model(image.unsqueeze(0).to(DEVICE))
output = output.detach().cpu().squeeze(0)
pred_bboxes, pred_scores = output_to_bboxes(output, image.shape[1:], threshold=best_thr, nms_threshold=0.5, heatmap_kernel=3)
plot_image_bounding_boxes(image, pred_bboxes, bboxes_gt)
Heatmap of predicted scores
# Visualize the raw confidence heatmap (sigmoid of the score channel).
plot_image(TF.to_pil_image(torch.sigmoid(output[..., 4])))