使用合成数据进行语义分割:一种生成方法

2023年09月21日 由 alex 发表 313 0

计算机视觉应用现在已经在每个技术领域中无处不在。这些模型高效且有效,研究人员每年都在尝试新的想法。新的前沿是试图消除深度学习的最大负担:需要大量标记数据的需求。正如本文所解释的那样,解决这个问题的方法是使用合成数据。


从这些研究中收益最多的计算机视觉领域无疑是语义分割领域,即预测图像中每个像素的标签,以便从图像中检索出感兴趣的对象。正如人们所预期的那样,手动标记训练集是一项费时、费力且容易出错的过程,因此有各种新方法利用合成数据。


在本文中,我们将看到其中一种方法,它利用生成对抗网络来解决使用合成数据时的一个问题:域适应问题。


数据生成


为了生成语义分割任务的数据,最常见的解决方案是使用与渲染引擎关联的模拟器。通过这种方式,可以根据需要产生图像,改变光照条件、对象的数量和姿势以及它们之间的交互,始终与像素完美的语义标签相关联。例如,一个非常流行的数据集是GTAV ,几乎所有研究都将其用作基准。这个数据集使用的模拟引擎是同名的视频游戏。该数据集包含从汽车驾驶员的视角拍摄的图像,非常适合像自动驾驶这样的应用。另一个著名的数据集是SINTHIA,它也包含了城市环境的图像。


1


生成方法用于领域自适应


直接使用合成数据进行模型训练是不够的。即使渲染引擎非常逼真,神经网络仍然会学习到模拟环境中存在的一些不真实的模式,无法很好地推广到真实世界的数据上。这被称为域自适应问题。


为了克服这个问题,模型必须在训练过程中学习到将源域S(合成域)和目标域T(真实域)的特征分布重新对齐的最优方式。这可以通过各种方法实现,例如对抗训练、知识蒸馏和自监督学习。


特别是,对抗训练是一种生成方法,它将源域数据转换为更类似于目标域的分布。它可以表示为:


给定一个源域数据集Dₛ = {(xᵢˢ, yᵢˢ), i=1…nₛ}和一个目标域数据集Dₜ = {xᵢᵗ, i=1…nₜ},其中xᵢˢ和xᵢᵗ是输入样本,yᵢˢ是xᵢˢ对应的标签,目标是学习一个映射函数?ᵢˢ = G(xᵢˢ),称为生成器,它将源域特征映射到目标域特征,使得在转换后的源域图像上训练的深度学习模型在目标域上表现良好。这通过一个判别器实现,判别器是一个神经网络,同时接收真实和转换后的合成图像作为输入,并尝试预测输入是否来自真实分布。


这些网络在对抗的环境中进行训练,只有当鉴别器失败时,生成器才会获胜。当转换后的图像与真实图像非常相似以至于鉴别器不能区分它们时,过程收敛,使得预测准确率不比随机猜测(50%准确率)更好。


几何引导的输入-输出自适应


各种算法使用了生成方法。其中之一被称为GIO-Ada,它代表着几何引导的输入-输出自适应。该算法对原始方法进行了两个改进。


1. 它使用从模拟引擎中很容易获取的另一个信息:深度图。直觉是对象的几何信息在其深度信息中编码得比像素的语义标签更好。因此,模型还会被训练来估计输入图像的深度图,并且这个额外的信息仅在训练过程中用作辅助损失。


2. 它在输出级别使用了第二个对抗阶段,有一个第二个判别器操作于任务网络的输出(包括语义标签图和几何深度图),训练判断预测的输出是来自真实还是合成图像。


2


全体架构由4个神经网络组成:生成器,用于转换合成图像;任务网络,用于预测真实图像和转换图像的标签和深度图;以及两个判别器。所有网络都经过端到端的训练,使用遵循对抗训练规则的通用优化步骤。


PyTorch Lightning的实现


为了方便实现和训练这个复杂的算法,可以使用一个叫做 pytorch_lightning 的库。这是一个对 PyTorch 的封装,可以帮助避免重新实现一些必要的与 torch 一起使用的样板代码,例如实现训练循环、处理日志记录和保存超参数和权重、管理 GPU(或多个 GPU)以及执行优化器步骤。在我们的情况下,最后一个特性是不必要的,因为对抗性训练的特殊之处在于生成器和判别器之间的优化步骤的交替,需要自定义。


让我们开始导入库并定义一个实用函数,用于为判别器创建标签。


import itertools
from typing import Iterator
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics.classification.jaccard import MulticlassJaccardIndex

def _labels(inputs: torch.Tensor, fill_value: int) -> torch.Tensor:
    return torch.full((inputs.size(0), 1), fill_value).to(inputs)


神经网络以torch模块的形式实现。设B为批量大小,C为图像通道数,K为类别数,W、H为图像的宽度和高度:


1. 任务网络必须处理形状为B × C × W × H的图像批量,并返回形状为B × K × W × H的标签预测和形状为B × 1 × W × H的深度预测。一种可能的架构选择是使用DeepLabV3+ [4]作为任务网络,其中包括两个不同的头部,一个用于类别预测,一个用于深度预测。


2. 图像转换网络必须接受所有合成数据作为输入,即形状为B × C × W × H的图像,形状为B × K × W × H的标签,以及形状为B × 1 × W × H的深度图,将它们连接起来,并输出形状为B × C × W × H的转换后的图像。


3. 鉴别器必须接受形状为B × (C或C + K + 1) × W × H的输入,并输出形状为B × 1的输出,表示样本是真实的概率。


class TaskNetwork(nn.Module):
    def __init__(
        self,
        input_channels: int,
        num_classes: int,
        pretrained_backbone: bool = False,
    ) -> None:
        ...
    def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        ...

class ImageTransformNetwork(nn.Module):
    def __init__(
        self,
        input_channels: int,
        output_channels: int,
    ) -> None:
        ...
    def forward(
        self,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> torch.Tensor:
        ...

class Discriminator(nn.Module):
    def __init__(self, input_channels: int) -> None:
        ...
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        ...


其余的代码将在pytorch_lightning LightningModule中实现。在__init__方法中,我们会传递所有的超参数,同时实例化4个神经网络、损失函数和指标。卷积层的权重将从正态分布中初始化,除了任务网络的权重可以进行预训练,例如使用ImageNet数据集。


class GIOAda(pl.LightningModule):
    REAL_LABEL = 1
    FAKE_LABEL = 0
    def __init__(
        self,
        num_classes: int,
        pretrained_backbone: bool,
        init_lr: float,
        betas: tuple[float, float],
        num_epochs: int,
        num_steps_per_epoch: int,
        lam_input: float,
        lam_output: float,
        lam_depth: float,
    ) -> None:
        super().__init__()
        self.save_hyperparameters() # saved in the dictionary self.hparams
        # disabling automatic optimization, as it willl be made manually
        self.automatic_optimization = False
        self.task_network = TaskNetwork(
            input_channels=3,  # RGB Channels
            num_classes=num_classes,  # Classes
            pretrained_backbone=pretrained_backbone,
        )
        self.fake_transformation = ImageTransformNetwork(
            input_channels=num_classes + 4,  # RGB Channels + Classes + Depth
            output_channels=3,  # RGB Channels
        )
        self.input_discriminator = Discriminator(
            input_channels=3,  # RGB Channels
        )
        self.output_discriminator = Discriminator(
            input_channels=num_classes + 1,  # Classes + Depth
        )
        self.depths_loss = nn.L1Loss()
        self.labels_loss = nn.CrossEntropyLoss()
        self.discriminator_loss = nn.BCELoss()
        self.miou_index = MulticlassJaccardIndex(num_classes)
        self.weight_init(pretrained_backbone=pretrained_backbone)
    def weight_init(self, pretrained_backbone: bool = False):
        for name, module in self.named_modules():
            if "task" in name and pretrained_backbone:
                continue
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                module.weight.data.normal_(0, 0.001)
                if module.bias is not None:
                    module.bias.data.zero_()


然后,我们定义优化器和学习率调度器。我们需要一个优化器来处理“生成器”的权重,也就是生成器和任务网络的权重,另一个优化器来处理鉴别器的权重。作为学习率调度器,我们将使用OneCycle策略,在训练的初期通过增加学习率和减小动量来“热身”网络,允许更早地探索权重空间并找到更好的起点。然后,在最后阶段,学习率会通过余弦退火策略进行降低。


def configure_optimizers(def configure_optimizers(
        self,
    ) -> tuple[
        list[torch.optim.Adam], list[torch.optim.lr_scheduler.OneCycleLR]
    ]:
        params_g = itertools.chain(
            self.fake_transformation.parameters(),
            self.task_network.parameters(),
        )
        params_d = itertools.chain(
            self.input_discriminator.parameters(),
            self.output_discriminator.parameters(),
        )
        optimizer_g, lr_sched_g = self._optimizer_lr_scheduler(params_g)
        optimizer_d, lr_sched_d = self._optimizer_lr_scheduler(params_d)
        return [optimizer_g, optimizer_d], [lr_sched_g, lr_sched_d]
    def _optimizer_lr_scheduler(
        self,
        parameters: Iterator[torch.nn.Parameter],
    ) -> tuple[torch.optim.Adam, torch.optim.lr_scheduler.OneCycleLR]:
        optimizer = torch.optim.Adam(
            parameters,
            lr=self.hparams["init_lr"],
            betas=self.hparams["betas"],
        )
        lr_sched = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams["init_lr"],
            epochs=self.hparams["num_epochs"],
            steps_per_epoch=self.hparams["num_steps_per_epoch"],
            base_momentum=self.hparams["betas"][0],
        )
        return optimizer, lr_sched


训练步骤接收的输入有:


1. 从真实数据集中采样的一批真实图像


2. 从合成数据集中采样的一批合成图像,以及对应的标签和深度图


然后进行以下两个操作:


1. 对于辨别器进行优化步骤,其中需要所有输入数据


2. 对于生成器进行优化步骤,只需要合成数据


步骤的顺序对于确保模型的收敛至关重要。由于生成器更容易崩溃,我们应该让辨别器在训练路径上“引领”。这样,在生成器步骤中,辨别器在其任务上会更好一些,为生成器提供“更好”的梯度。


def training_step(self, batch: tuple[torch.Tensor, ...]) -> None:def training_step(self, batch: tuple[torch.Tensor, ...]) -> None:
        optimizer_g, optimizer_d = self.optimizers()  
        real_images, fake_images, labels, depths = batch
        # Update D network: minimize log(D(x)) + log(1 - D(G(z)))
        self.toggle_optimizer(optimizer_d)
        optimizer_d.zero_grad()
        self._discriminator_step(real_images, fake_images, labels, depths)
        optimizer_d.step()
        self.untoggle_optimizer(optimizer_d)
        # Update G network: maximize log(D(G(z))) and minimize task loss
        self.toggle_optimizer(optimizer_g)
        optimizer_g.zero_grad()
        self._generator_step(fake_images, labels, depths)
        optimizer_g.step()
        self.untoggle_optimizer(optimizer_g)


鉴别器步骤只是最小化鉴别器输出的二元交叉熵损失。首先对真实批次进行操作,其中期望的标签全部为1,然后对合成批次进行操作,其中期望的标签全部为0。


def _discriminator_step(def _discriminator_step(
        self,
        real_images: torch.Tensor,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> None:
        disc_lab = _labels(real_images, self.REAL_LABEL)
        disc_input = self.input_discriminator(real_images)
        disc_output = self.output_discriminator(
            torch.concat(self.task_network(real_images), dim=1)
        )
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        self.manual_backward(loss_input + loss_output)
        transformed = self.fake_transformation(fake_images, labels, depths)
        disc_lab = _labels(transformed, self.FAKE_LABEL)
        disc_input = self.input_discriminator(transformed)
        disc_output = self.output_discriminator(
            torch.concat(self.task_network(transformed), dim=1)
        )
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        self.manual_backward(loss_input + loss_output)
        
        # Log losses and metrics
        # ...


生成器步骤相反地最小化了标签的交叉熵损失和深度估计的L1损失,并最大化了鉴别器的二元交叉熵损失。这是通过使用相反的标签计算损失来实现的,因此对于合成输入来说,全部为1。不需要为真实输入计算此损失,因为生成器的权重对这些输出没有影响。


 def _generator_step(def _generator_step(
        self,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> None:
        # Set disc_lab = REAL in order to maximize the loss for the 
        # discriminator when inputs are all fakes
        disc_lab = _labels(fake_images, self.REAL_LABEL)
        # Forward pass on all the networks to collect gradients for G
        transformed = self.fake_transformation(fake_images, labels, depths)
        fake_mask, fake_depth = self.task_network(transformed)
        disc_input = self.input_discriminator(transformed)
        disc_output = self.output_discriminator(
            torch.concat((fake_mask, fake_depth), dim=1)
        )
        # Calculate losses
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        loss_depths = (
            self.depths_loss(fake_depth, depths) * self.hparams["lam_depth"]
        )
        loss_labels = self.labels_loss(fake_mask, labels)
        # Calculate Gradients
        self.manual_backward(
            loss_input + loss_output + loss_depths + loss_labels
        )
        # Log losses and metrics
        # ...


结论


在各种数据集上,这里解释的方法被证明非常有效。在下面的图中,我们可以看到,在利用合成数据进行训练的模型超过只使用小的KITTI数据集进行训练的模型。从大量合成数据中获得的知识使模型能够从真实图像中提取更细致的细节。


3


这个算法也有一些缺点。首先,对抗性训练可能非常不稳定,正如之前看到的不寻常的训练步骤所暗示的那样。因此,要获得良好的结果,进行详尽的超参数搜索是至关重要的。另一个主要问题是,训练生成网络是一项非常耗费内存的工作,尤其是对于高分辨率图像。


最新的研究关注其他方法(例如自学习),利用变压器层中注意机制的强大泛化性质以及领域特定的数据增强技术。



文章来源:https://medium.com/data-reply-it-datatech/semantic-segmentation-with-synthetic-data-a-generative-approach-8ad72d14dec1
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消