模型:
ceyda/butterfly_cropped_uniq1K_512
基于Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis这篇文章。文章提到:“值得注意的是,该模型只需在一块RTX-2080 GPU上进行几个小时的训练就能从头开始收敛,并且具有一致的性能,即使训练样本少于100张。” 该模型也被称为Light-GAN模型。该模型使用了来自lucidrains的脚本 repo 进行训练,但我使用了官方库中的转换方式,因为我们的训练图像已经被裁剪和对齐。 repo 的官方论文中有更多细节。
用于娱乐和学习~
使用
import torch from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN # install the community-events repo above gan = LightweightGAN.from_pretrained("ceyda/butterfly_cropped_uniq1K_512") gan.eval() batch_size = 1 with torch.no_grad(): ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255 ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8) # ims is [BxWxHxC] call Image.fromarray(ims[0])
使用了1000张图像,虽然有可能增加这个数量,但我们没有时间手动筛选数据集,并且想看看是否可能像论文中提到的那样进行低数据训练。更多详细信息请查看 data card 。
在2个A4000上进行了大约1天的训练。在7-12小时内可以看到良好的结果。重要参数:"--batch_size 64 --gradient_accumulate_every 4 --image_size 512 --mixed_precision fp16"训练日志可以在 here 中查看。
对100张图像进行了FID分数计算,不同的检查点得到的结果为 here ,但不能说它太有意义(由于FID分数的缺点)。
在 demo 中进行操作。
在huggan sprint期间制作。
模型训练者:Ceyda Cinarel https://twitter.com/ceyda_cinarel
Jonathan Whitaker的额外贡献 https://twitter.com/johnowhitaker