Semantic segmentation is a computer vision task that involves classifying and labeling every pixel in an image. Unlike object detection, which identifies objects and places bounding boxes around them, semantic segmentation provides a finer-grained understanding of the image, delineating object boundaries at the pixel level.
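As a toy illustration of the difference (made-up numbers, not from the dataset used below): for an H x W image, detection produces a few labeled boxes, while segmentation produces one class id per pixel:
import numpy as np

H, W = 4, 6
boxes = [(1, 1, 4, 3, 'car')]            # detection: (x1, y1, x2, y2, label)
seg = np.zeros((H, W), dtype=np.int64)   # segmentation: a class id for every pixel
seg[1:3, 1:4] = 1                        # pixels covered by the car get class 1
print(seg)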
I will use U-Net combined with MobileNet as the baseline architecture.
U-Net:
Code
Importing libraries
import os
from glob import glob
import shutil
from pathlib import Path, PurePath
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
from sklearn.model_selection import train_test_split
import torchvision
from torchvision.transforms import transforms as T
from torchvision.utils import draw_bounding_boxes
from tqdm import tqdm
import albumentations as A
import time
%matplotlib inline
torch.manual_seed(42)
Setting up CUDA
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
device
Loading the dataset
IMAGES = '/kaggle/input/car-segmentation/car-segmentation/images/'
MASKS = '/kaggle/input/car-segmentation/car-segmentation/masks/'
CLASSES = '/kaggle/input/car-segmentation/car-segmentation/classes.txt'
MASKS_JSON = '/kaggle/input/car-segmentation/car-segmentation/masks.json'
Data visualization
rows, cols = 3, 3
plt.figure(figsize=(15, 12))
images = [IMAGES + i for i in os.listdir(IMAGES)]
shapes = []
for num, x in enumerate(images):
    if num == 9:
        break
    img = Image.open(x)
    shapes.append(img.size)
    plt.subplot(rows, cols, num + 1)
    plt.title(Path(x).name)
    plt.axis('off')
    plt.imshow(img)
    img.close()

rows, cols = 3, 3
plt.figure(figsize=(15, 12))
mask_shapes = []
masks = [MASKS + i for i in os.listdir(MASKS)]
for num, x in enumerate(masks):
    if num == 9:
        break
    img = Image.open(x)
    mask_shapes.append(img.size)
    plt.subplot(rows, cols, num + 1)
    plt.title(Path(x).name)
    plt.axis('off')
    plt.imshow(img)
    img.close()
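Before building the dataset class it is worth checking what the mask files actually contain; assuming they are single-channel images of class indices (which the pipeline below relies on), a quick check:
m = np.array(Image.open(masks[0]))
print(m.shape, m.dtype, np.unique(m))  # expect a small set of integer class ids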
Data preparation
class CarDataset(Dataset):
    def __init__(self, image_path, mask_path, x, mean, std, transform=None, patch=False):
        self.img_path = image_path
        self.mask_path = mask_path
        self.x = x
        self.mean = mean
        self.std = std
        self.transform = transform
        self.patch = patch

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        img = cv2.imread(self.img_path + self.x[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Masks are assumed to store one class id per pixel, so read them
        # as a single channel instead of three identical BGR channels.
        mask = cv2.imread(self.mask_path + self.x[idx], cv2.IMREAD_GRAYSCALE)
        if self.transform is not None:
            aug = self.transform(image=img, mask=mask)
            img = Image.fromarray(aug['image'])
            mask = aug['mask']
        else:
            img = Image.fromarray(img)
        t = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        img = t(img)
        mask = torch.from_numpy(mask).to(torch.int64)
        if self.patch:
            img, mask = self.tiles(img, mask)
        return img, mask

    def tiles(self, img, mask):
        # Split one image/mask pair into non-overlapping size x size patches.
        img_patches = img.unfold(1, size, size).unfold(2, size, size)
        img_patches = img_patches.contiguous().view(3, -1, size, size)
        img_patches = img_patches.permute(1, 0, 2, 3)
        mask_patches = mask.unfold(0, size, size).unfold(1, size, size)
        mask_patches = mask_patches.contiguous().view(-1, size, size)
        return img_patches, mask_patches
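Neither size nor the x_train / x_val filename lists used below are defined anywhere in the section; a minimal sketch, assuming a 400-pixel input (matching the U-Net docstring further down) and an 80/20 split with the imported train_test_split:
size = 400  # assumed: matches the 400 x 400 input mentioned in the U-Net docstring
filenames = sorted(os.listdir(IMAGES))
x_train, x_val = train_test_split(filenames, test_size=0.2, random_state=42)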
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
t_train = A.Compose([A.Resize(size, size, interpolation=cv2.INTER_NEAREST),
                     A.VerticalFlip(),
                     A.HorizontalFlip(),
                     A.GridDistortion(p=0.2),
                     A.GaussNoise(),
                     A.RandomBrightnessContrast((0, 0.5), (0, 0.5))])
t_val = A.Compose([A.Resize(size, size, interpolation=cv2.INTER_NEAREST),
                   A.HorizontalFlip(),
                   A.GridDistortion(p=0.2)])
train_dataset = CarDataset(IMAGES, MASKS, x_train, mean, std, t_train, patch=False)
val_dataset = CarDataset(IMAGES, MASKS, x_val, mean, std, t_val, patch=False)
batch_size = 3
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
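A quick sanity check on one batch: with the single-channel mask read in CarDataset, masks arrive as (batch, H, W) int64 tensors, which is the layout a criterion like nn.CrossEntropyLoss expects:
imgs, msks = next(iter(train_loader))
print(imgs.shape, msks.shape, msks.dtype)
# e.g. torch.Size([3, 3, 400, 400]) torch.Size([3, 400, 400]) torch.int64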
U-Net architecture
class UNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        """
        Encoder
        Every block in the encoder has two convolution layers followed by a
        max-pooling layer, except the last block, which has no max pooling.
        The input to the U-Net is 400 x 400 x channels.
        """
        self.enc_blk11 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc_blk12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()
        self.enc_blk21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.enc_blk22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.enc_blk31 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.enc_blk32 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.enc_blk41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.enc_blk42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.enc_blk51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.enc_blk52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        """
        Decoder
        Each step upsamples with a transposed convolution, concatenates the
        matching encoder feature map (skip connection), then convolves twice.
        """
        self.upconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec_blk11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.dec_blk12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.upconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec_blk21 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.dec_blk22 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec_blk31 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.dec_blk32 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec_blk41 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.dec_blk42 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # Output layer: 1x1 convolution mapping 64 channels to per-class logits
        self.out_layer = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc11 = self.relu(self.enc_blk11(x))
        enc12 = self.relu(self.enc_blk12(enc11))
        pool1 = self.pool(enc12)
        enc21 = self.relu(self.enc_blk21(pool1))
        enc22 = self.relu(self.enc_blk22(enc21))
        pool2 = self.pool(enc22)
        enc31 = self.relu(self.enc_blk31(pool2))
        enc32 = self.relu(self.enc_blk32(enc31))
        pool3 = self.pool(enc32)
        enc41 = self.relu(self.enc_blk41(pool3))
        enc42 = self.relu(self.enc_blk42(enc41))
        pool4 = self.pool(enc42)
        enc51 = self.relu(self.enc_blk51(pool4))
        enc52 = self.relu(self.enc_blk52(enc51))
        # Decoder with skip connections
        up1 = self.upconv1(enc52)
        up11 = torch.cat([up1, enc42], dim=1)
        dec11 = self.relu(self.dec_blk11(up11))
        dec12 = self.relu(self.dec_blk12(dec11))
        up2 = self.upconv2(dec12)
        up22 = torch.cat([up2, enc32], dim=1)
        dec21 = self.relu(self.dec_blk21(up22))
        dec22 = self.relu(self.dec_blk22(dec21))
        up3 = self.upconv3(dec22)
        up33 = torch.cat([up3, enc22], dim=1)
        dec31 = self.relu(self.dec_blk31(up33))
        dec32 = self.relu(self.dec_blk32(dec31))
        up4 = self.upconv4(dec32)
        up44 = torch.cat([up4, enc12], dim=1)
        dec41 = self.relu(self.dec_blk41(up44))
        dec42 = self.relu(self.dec_blk42(dec41))
        out = self.out_layer(dec42)
        return out
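A forward pass with dummy data confirms that the skip connections line up and that the output carries one logit channel per class (the 5 here is an assumption matching mIoU's default below):
with torch.no_grad():
    out = UNet(5)(torch.randn(1, 3, 400, 400))
print(out.shape)  # torch.Size([1, 5, 400, 400])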
Training
def pixel_accuracy(output, mask):
    with torch.no_grad():
        output = torch.argmax(F.softmax(output, dim=1), dim=1)
        correct = torch.eq(output, mask).int()
        accuracy = float(correct.sum() / float(correct.numel()))
    return accuracy

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=5):
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)
        pred_mask = torch.argmax(pred_mask, dim=1)
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)
        iou_per_class = []
        for cls in range(0, n_classes):
            true_class = (pred_mask == cls)
            true_label = (mask == cls)
            if true_label.long().sum().item() == 0:
                # Class absent from the ground truth: record NaN so it is skipped.
                iou_per_class.append(np.nan)
            else:
                intersect = torch.logical_and(true_class, true_label).sum().float().item()
                union = torch.logical_or(true_class, true_label).sum().float().item()
                iou = (intersect + smooth) / (union + smooth)
                iou_per_class.append(iou)
        return np.nanmean(iou_per_class)
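mIoU expects raw logits of shape (N, C, H, W) and integer masks of shape (N, H, W); classes absent from the ground truth are recorded as np.nan so np.nanmean skips them. A toy call:
logits = torch.randn(2, 5, 8, 8)         # (N, C, H, W) raw model output
target = torch.randint(0, 5, (2, 8, 8))  # (N, H, W) integer class ids
print(mIoU(logits, target))              # a scalar in (0, 1]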
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
def fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, patch=False):
    torch.cuda.empty_cache()
    train_losses = []
    test_losses = []
    val_iou = []
    val_acc = []
    train_iou = []
    train_acc = []
    lrs = []
    min_loss = np.inf
    decrease = 1
    not_improve = 0
    model.to(device)
    fit_time = time.time()
    for epoch in range(epochs):
        since = time.time()
        running_loss = 0
        iou_score = 0
        accuracy = 0
        # Training phase
        model.train()
        for i, data in enumerate(tqdm(train_loader)):
            image_tiles, mask_tiles = data
            if patch:
                # Fold the tile dimension into the batch dimension
                bs, n_tiles, c, h, w = image_tiles.size()
                image_tiles = image_tiles.view(-1, c, h, w)
                mask_tiles = mask_tiles.view(-1, h, w)
            image = image_tiles.to(device)
            mask = mask_tiles.to(device)
            # Forward prop
            output = model(image)
            # Loss
            loss = criterion(output, mask)
            # Evaluation
            iou_score += mIoU(output, mask)
            accuracy += pixel_accuracy(output, mask)
            # Backpropagation and weight update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lrs.append(get_lr(optimizer))
            scheduler.step()
            running_loss += loss.item()
        # Validation phase
        model.eval()
        test_loss = 0
        test_accuracy = 0
        val_iou_score = 0
        with torch.no_grad():
            for i, data in enumerate(tqdm(val_loader)):
                image_tiles, mask_tiles = data
                if patch:
                    bs, n_tiles, c, h, w = image_tiles.size()
                    image_tiles = image_tiles.view(-1, c, h, w)
                    mask_tiles = mask_tiles.view(-1, h, w)
                image = image_tiles.to(device)
                mask = mask_tiles.to(device)
                # Forward prop
                output = model(image)
                # Evaluation
                val_iou_score += mIoU(output, mask)
                test_accuracy += pixel_accuracy(output, mask)
                # Loss
                loss = criterion(output, mask)
                test_loss += loss.item()
        # Mean metrics for this epoch
        train_losses.append(running_loss / len(train_loader))
        test_losses.append(test_loss / len(val_loader))
        val_epoch_loss = test_loss / len(val_loader)
        if val_epoch_loss < min_loss:
            print('Loss Decreasing.... {:.3f} >> {:.3f}'.format(min_loss, val_epoch_loss))
            min_loss = val_epoch_loss
            decrease += 1
            if decrease % 5 == 0:
                print('Saving Model.......')
                torch.save(model, '/kaggle/working/UNet-mIoU-{:.3f}.pt'.format(val_iou_score / len(val_loader)))
        else:
            not_improve += 1
            print(f"Loss did not decrease for {not_improve} time(s)")
            if not_improve == 7:
                print('Loss did not decrease for 7 times, Stopped Training..')
                break
        # IoU and accuracy
        val_iou.append(val_iou_score / len(val_loader))
        train_iou.append(iou_score / len(train_loader))
        train_acc.append(accuracy / len(train_loader))
        val_acc.append(test_accuracy / len(val_loader))
        print("Epoch: {} / {} ".format(epoch + 1, epochs),
              "Train Loss: {:.3f} ".format(running_loss / len(train_loader)),
              "Val Loss: {:.3f} ".format(test_loss / len(val_loader)),
              "Train mIoU: {:.3f} ".format(iou_score / len(train_loader)),
              "Val mIoU: {:.3f} ".format(val_iou_score / len(val_loader)),
              "Train Accuracy: {:.3f} ".format(accuracy / len(train_loader)),
              "Val Accuracy: {:.3f} ".format(test_accuracy / len(val_loader)),
              "Time: {:.2f}m".format((time.time() - since) / 60))
    history = {'train_loss': train_losses,
               'val_loss': test_losses,
               'train_miou': train_iou,
               'val_miou': val_iou,
               'train_acc': train_acc,
               'val_acc': val_acc,
               'lrs': lrs}
    print("Total time: {:.2f} m".format((time.time() - fit_time) / 60))
    return history
model = UNet(5)  # one output channel per class; 5 matches the n_classes default in mIoU
model.to(device)
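The loss function, optimizer, scheduler, and the training call that produces history are not shown above; a minimal sketch, assuming cross-entropy loss, AdamW, and a per-batch OneCycleLR schedule (fit already calls scheduler.step() once per batch). The hyperparameter values here are assumptions, not taken from the original:
max_lr = 1e-3
epochs = 15
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr,
                                                epochs=epochs,
                                                steps_per_epoch=len(train_loader))
history = fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler)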
Loss and mIoU
def plot_loss(history):
    plt.plot(history['val_loss'], label='val', marker='o')
    plt.plot(history['train_loss'], label='train', marker='o')
    plt.title('Loss per epoch'); plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()

def plot_score(history):
    plt.plot(history['train_miou'], label='train_mIoU', marker='*')
    plt.plot(history['val_miou'], label='val_mIoU', marker='*')
    plt.title('Score per epoch'); plt.ylabel('mean IoU')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()

def plot_acc(history):
    plt.plot(history['train_acc'], label='train_accuracy', marker='*')
    plt.plot(history['val_acc'], label='val_accuracy', marker='*')
    plt.title('Accuracy per epoch'); plt.ylabel('Accuracy')
    plt.xlabel('epoch')
    plt.legend(), plt.grid()
    plt.show()
plot_loss(history)
plot_score(history)
plot_acc(history)
Output visualization
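The variables image, mask, pred_mask, and score used by the plot below are never defined in the section; a minimal sketch that runs the trained model on one validation sample and un-normalizes the image for display (index 0 is an arbitrary choice):
model.eval()
img_t, mask = val_dataset[0]                       # one normalized image and its mask
with torch.no_grad():
    output = model(img_t.unsqueeze(0).to(device))  # add a batch dimension
pred_mask = torch.argmax(output, dim=1).squeeze(0).cpu()
score = mIoU(output.cpu(), mask.unsqueeze(0))
# Undo T.Normalize so the image displays with natural colors
image = img_t.permute(1, 2, 0).numpy() * np.array(std) + np.array(mean)
image = np.clip(image, 0, 1)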
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 10))
ax1.imshow(image)
ax1.set_title('Original Image')
ax1.set_axis_off()
ax2.imshow(mask)
ax2.set_title('Ground Truth')
ax2.set_axis_off()
ax3.imshow(pred_mask)
ax3.set_title('UNet-MobileNet | mIoU {:.3f}'.format(score))
ax3.set_axis_off()