英文

蝴蝶生成对抗网络

模型描述

基于Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis这篇文章。文章提到:“值得注意的是,该模型只需在一块RTX-2080 GPU上进行几个小时的训练就能从头开始收敛,并且具有一致的性能,即使训练样本少于100张。” 该模型也被称为Light-GAN模型。该模型使用了来自lucidrains的脚本 repo 进行训练,但我使用了官方库中的转换方式,因为我们的训练图像已经被裁剪和对齐。 repo 的官方论文中有更多细节。

用途和限制

用于娱乐和学习~

如何使用

使用

import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN # install the community-events repo above

gan = LightweightGAN.from_pretrained("ceyda/butterfly_cropped_uniq1K_512")
gan.eval()
batch_size = 1
with torch.no_grad():
        ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
        ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
        # ims is [BxWxHxC] call Image.fromarray(ims[0])

限制和偏见

  • 在训练过程中,我过滤了数据集,只保留了每个物种的一只蝴蝶。否则,模型生成的蝴蝶变化较少(一些物种的图像较多会占据主导地位)。
  • 数据集还使用了['漂亮的蝴蝶','一只蝴蝶','展翅蝴蝶','多彩的蝴蝶']的CLIP分数进行了过滤。尽管这样做是为了排除不包含蝴蝶的图像(只有科学标签,杂乱的图像)从 full dataset 中。但很容易想象在某些情况下,这种方法可能会导致问题;谁能说哪只蝴蝶“漂亮”并应该包含在数据集中呢?即使CLIP无法识别一只蝴蝶,也可能导致其被排除在数据集之外,引发偏见。

训练数据

使用了1000张图像,虽然有可能增加这个数量,但我们没有时间手动筛选数据集,并且想看看是否可能像论文中提到的那样进行低数据训练。更多详细信息请查看 data card

训练过程

在2个A4000上进行了大约1天的训练。在7-12小时内可以看到良好的结果。重要参数:"--batch_size 64 --gradient_accumulate_every 4 --image_size 512 --mixed_precision fp16"训练日志可以在 here 中查看。

评估结果

对100张图像进行了FID分数计算,不同的检查点得到的结果为 here ,但不能说它太有意义(由于FID分数的缺点)。

生成的图像

demo 中进行操作。

BibTeX条目和引用信息

在huggan sprint期间制作。

模型训练者:Ceyda Cinarel https://twitter.com/ceyda_cinarel

Jonathan Whitaker的额外贡献 https://twitter.com/johnowhitaker