VQGAN:一步步从图像重建到新图像生成

2023年12月18日 由 alex 发表 1074 0

本文的目标是VQGAN,作为一个整体系统用于生成新图片。VQVAE的概念是编码器、解码器和码本的同时训练,这对所有可能的图像都是通用的。码本是一组256嵌入向量。任何具有输入分辨率256x256的图像的潜在空间由码本向量的某个子集表示。VQVAE流程图的图示如下图底部:


5

图1


编码器将输入图像(分辨率为256x256像素)转换为潜在空间,该潜在空间有一个16x16的平面,每个条目都是一个包含256个数值的向量。然后,潜在空间中的每个条目都将改变为码本中与其L2度量最近的向量——这个过程称为向量量化。因此,潜在空间由16x16的码本索引平面表示。将这个量化的潜在空间发送到解码器,我们得到了重建的图像。


在VQGAN中,自编码器的部分通过额外的CNN——基于补丁(Patch)的鉴别器(Discriminator)进行扩展。鉴别器具有分类器结构。在图片中显示了VQVAE与鉴别器之间的交互:在图像被重建后,它被发送到鉴别器,鉴别器为图像补丁产生类别值。鉴别器为输入和重建图像的“每个补丁的类别”空间获取类别值,并在每个补丁上验证这些空间之间的类别差异:相同类别(真实)或不同类别(伪造)。鉴别器参与VQGAN的训练,并试图最大化其损失,但是共同的损失=VQVAE的损失+来自鉴别器的损失要最小化。训练模型进行图像重建时不使用鉴别器,它是在训练步骤中用于改善VQVAE质量的。在GAN训练的下一步,为生成新图像,鉴别器发挥了重要作用。


潜在空间的实际实验


在这一部分,我展示了VQGAN在实践中的图像重建,并实验了潜在空间、码本及其在生成新图像中的作用。这里,我在我的谷歌协作平台中使用以下代码。


导入:


import copy
import cv2
import sys
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np


Google Drive 映射:


from google.colab import drive
drive.mount('/content/gdrive')


CUDA 设备设置:


device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


VQGAN安装及模型下载:


%pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 einops>=0.3.0
sys.path.append(".")
!git clone https://github.com/CompVis/taming-transformers
%cd taming-transformers
# download a VQGAN with f=16 (16x compression per spatial dimension) and with a codebook with 1024 entries
!mkdir -p logs/vqgan_imagenet_f16_1024/checkpoints
!mkdir -p logs/vqgan_imagenet_f16_1024/configs
!wget 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt'
!wget 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/configs/model.yaml'


# also disable grad to save memory
torch.set_grad_enabled(False)


在上面的代码中,安装了具有1024个代码簿条目的最小模型。


两个实用函数用于从文件中读取图像并将其转换为torch张量,以及用于显示输入和输出图像:


def get_img_tensor(name,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Resize((256, 256)),
                   ])):
    img = Image.open(name)
    img = transform(img)
    img = img.unsqueeze(0)
    return img
 
def show_results(img, out):
    rec = custom_to_pil(out[0])
    _, ax = plt.subplots(1, 2, figsize=(12, 5))
    if img is not None:
        ax[0].imshow(img[0].permute(1, 2, 0))
    ax[0].axis("off")
    ax[1].imshow(rec)
    ax[1].axis("off")  
    plt.show()


以下的代码包含了使用VQGAN进行图像重建的函数:


from omegaconf import OmegaConf
from taming.models.vqgan import VQModel
def load_config(config_path):
    config = OmegaConf.load(config_path)
    return config
def load_vqgan(config, ckpt_path=None):
    model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()
def preprocess_vqgan(x):
    x = 2.*x - 1.
    return x
def custom_to_pil(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.)/2.
    x = x.permute(1, 2, 0).numpy()
    x = (255*x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x
def reconstruct_with_vqgan(x, model):
    # could also use model(x) for reconstruction but use explicit encoding and decoding here
    z, _, [_, _, indices] = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec


load_config(),load_vqgan() — 用于加载预训练模型的函数。


preprocess_vqgan() — 在将输入的图像张量发送给VQGAN编码器前对其进行预处理的函数。


custom_to_pil() — 在VQGAN解码器之后对重建的图像张量进行后处理的函数。


reconstruct_with_vqgan() — 进行图像重建的函数:该函数调用编码器以获得图像的潜在空间,然后调用解码器获得重建的图像。


注意:一个已知问题“from taming.models.vqgan import VQModel”行中可能会发生错误。


现在一切准备就绪,可以进行重建。加载预训练模型:


config1024 = load_config("logs/vqgan_imagenet_f16_1024/configs/model.yaml")
model1024 = load_vqgan(config1024, ckpt_path="logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt").to(device)


并使用模型进行图像重建:


img = get_img_tensor("image path")
out = reconstruct_with_vqgan(preprocess_vqgan(img.to(device)), model1024)
show_results(img, out)


图像重建结果示例:


6

图2


reconstruct_with_vqgan()调用了Encoder和Decoder这两个函数。让我们来看看这些函数:


z, _, [_, _, indices] = model1024.encode(preprocess_vqgan(img.to(device)))
indices = indices.detach().cpu().numpy()


预训练的编码器模型返回具有形状 (1, 256, 16, 16) 的潜在空间 z 和形状为 (256) 的码本向量索引。拥有这些索引的码本向量如果按光栅顺序放置在 16x16 的平面上,就可以组成潜在空间。换句话说,如果我有一组正确顺序的 256 个索引,我就可以从码本创建一个潜在空间,并通过调用解码器重建图片。在下面的代码中,我尝试了这个操作。


首先,获取码本向量:


ind = torch.arange(1024).to(device)
cb = model1024.quantize.get_codebook_entry(ind, None)
print(cb.shape)


在上面的代码中,我从码本中得到了索引为0,… 1023的向量,即整个码本(我使用的是小型的VQGAN模型)。码本的形状为(1024, 256)。


下面的函数展示了如何从码本+ 256个索引的numpy数组创建潜在空间,并使用Decoder得到输出图片:


def cb_construct(cb, indices, img):
    emb = [cb[i] for i in indices]
    zn = torch.stack(emb)
    zn = torch.reshape(zn, (16, 16, 256))
    zn = torch.unsqueeze(zn, 0)
    zn = zn.permute(0, 3, 1, 2)
    xrec = model1024.decode(zn.to(device))
    show_results(img, xrec)


如果我们使用前面代码块中获取的 cb 和 indices 调用这个函数:


cb_construct(cb, indices, img)


我们获得了与图2上完全相同的重构结果。

如果我们尝试打乱索引并用打乱的向量解码潜在空间:


indices1 = copy.deepcopy(indices)
np.random.shuffle(indices1)
cb_construct(cb, indices1, img)


我们获得了一个具有一定抽象性的新图像。


7


其他一些基于其他图片潜在空间的“抽象艺术”例子:


8


9


10


因此,我们已经通过实验验证,要创造任何图像,我们需要一个码本(codebook),一个已定义顺序的码本向量索引集合和一个解码器。直观上我们明白,我们需要某种系统来定义码本向量的子集和它们索引的顺序,以生成某些类型的逼真图像。


驯服变换器(Taming Transformer)


Taming Transformer是图像生成器的第二阶段。它被训练用于生成新图像潜在空间的索引序列。生成过程从初始条件处理开始。以下是该模型处理的条件图像类型:


11


类型为输入代码的<输入类型+输入代码>作为初始参数发送给模型。输入代码指的是条件图像潜在空间的码本向量索引集合。模型被训练以使用之前预测的索引来预测当前索引。第一个索引是基于输入代码来预测的。变换器预测可能下一个索引的分布(见图1)。如果输入图像分辨率为256x256,所有之前预测的索引都被用来预测当前索引。对于高分辨率图像,每个补丁只使用之前从邻近补丁中预测的索引来进行预测,在下面的图片中以滑动窗口的形式展示:


12


Taming Transformer使用VQGAN和第一阶段训练过的鉴别器模型作为骨干。训练步骤如下:预测整个补丁中码本索引的分布,将预测的索引送到解码器并获取输出补丁,将输出补丁送到鉴别器并获得补丁特征,然后计算输入特征(从输入编码获得)与输出特征之间的交叉熵损失。


我尝试使用这个谷歌协作平台(Google Colab)从分割掩码生成新图像。我使用了协作平台所提议的输入数据。这里是用预训练的驯服变压器模型对同一个分割掩码进行三次不同运行的结果:


13


分割掩膜可能包含多达182类对象(掩膜值从1到182)。


关于Taming Transformer + VQGAN系统的结论:


  1. 该系统能够使用条件图像输入配置生成高质量的逼真图像。
  2. 该系统可用于图像扩展:例如,可以将一个条件图像作为上半部分的图像发送到系统中,输出的图像将包含这个上部分 + 一个生成的下部分。
  3. 需要对条件图像进行特殊的预处理,这取决于输入类型,例如适当配置分割掩膜。
  4. 该系统用于生成类似于条件图像的新图像,但不能用于改变条件图像的风格。


CLIP + VQGAN系统用于生成新图像


首先,我从第2节的代码继续。在下面的代码中,我在编码后改变了图像的潜在空间 - 将其乘以0.7:


img = get_img_tensor("image path")
z, _, [_, _, indices] = model1024.encode(preprocess_vqgan(img.to(device)))
out = model1024.decode(0.7 * z)
show_results(img, out)


因此,我得到了另一种风格的冬季风景画:


14


我可以用另一种方式改变潜在空间,例如,将潜在空间中每个向量的第70个元素乘以50:


ind = 70
z, _, [_, _, indices] = model1024.encode(img.to(device))
z = z.permute(1, 0, 2, 3)
z = [z[i] for i in range(256)]
z[ind] = z[ind] * 50
z = torch.stack(z)
z = z.permute(1, 0, 2, 3)
out = model1024.decode(z)
show_results(img, out) 


景观风格以另一种方式改变:


15


CLIP + VQGAN 系统的理念类似:通过改变整个潜在空间来以期望的方式改变图像。CLIP 扮演了鉴别器的角色,它理解文本描述中期望的图像,并生成一个损失值。CLIP 是一个被训练来找出图像与文本相似性的系统。生成过程看起来像是一个带有潜在空间权重变化的训练过程:CLIP 生成文本描述的嵌入向量,VQGAN 解码潜在空间并获取图像,然后 CLIP 生成图像的嵌入向量,CLIP + VQGAN 系统计算其余输入文本描述的嵌入向量的余弦相似度。这个系统的目标是最大化相似度(相似度的区间是(0, 1))。为了实现这个目标,系统在反向传播步骤中改变潜在空间中的权重。主要的挑战是在反向传播期间找到一种梯度传输的技巧,因为 VQGAN 和 CLIP 不是系统的骨干,只是加载的预训练模型。我尝试了来自这个谷歌 colab 的 CLIP + VQGAN 实现。我根据输入文本描述改变了输入图像以在图像中获得视频效果。下面的图片展示了一些结果:


文本提示:“被雪覆盖的云杉”。


图像变换:


16


雪覆盖的冷杉枝条和松果。


17


文字提示1: 白色和红色的点点散布在一个大象的黑色轮廓上。


文字提示2: 白色和红色的花朵位于一个大象的黑色轮廓上。


图像变换:


18


文本提示:“新年云杉水彩画”。


19


CLIP + VQGAN系统的结论:


与Taming Transformer+ VQGAN系统相反,CLIP + VQGAN系统是为艺术创作而非为真实感图像生成提供的解决方案。它能够接受用户以最便捷的形式输入,即文本描述,来生成不同风格的图像。


结论


用于图像压缩的自动编码器 -> 相对较小的图像潜在空间,特定数据集的图像重建。


矢量量化自动编码器 -> 基于码本和矢量量化技术的图像潜在空间,任意图像的高质量图像重建。


驯服变压器 + VQGAN -> 基于向量索引预测的码本向量生成新图像。


CLIP + VQGAN -> 基于根据文本描述改变的图像潜在空间生成新图像。

文章来源:https://medium.com/@olga.mindlina/vqgan-from-image-reconstruction-to-new-image-generation-step-by-step-02af2cbe1a52
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消