Model:
CompVis/ldm-celebahq-256
Paper: High-Resolution Image Synthesis with Latent Diffusion Models
Abstract:
By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimizing powerful DMs often consumes hundreds of GPU days, and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows, for the first time, reaching a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes, and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.
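The core idea of the abstract, running the diffusion process in the latent space of a pretrained autoencoder rather than in pixel space, can be sketched with the same diffusers components used in the snippets below. This is a minimal, illustrative sketch of the training objective only (random tensors stand in for real training images); it is not the authors' original training code:

```python
import torch
from diffusers import UNet2DModel, VQModel, DDIMScheduler

model_id = "CompVis/ldm-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id, subfolder="unet")
vqvae = VQModel.from_pretrained(model_id, subfolder="vqvae")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# dummy batch of 256x256 RGB images in [-1, 1] (stand-in for real training data)
images = torch.randn(1, 3, 256, 256)

# 1) encode images into the compact latent space of the pretrained autoencoder
latents = vqvae.encode(images).latents

# 2) diffuse: add noise to the latents at a randomly drawn timestep
noise = torch.randn_like(latents)
t = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = scheduler.add_noise(latents, noise, t)

# 3) the UNet is trained to predict the added noise from the noisy latents
noise_pred = unet(noisy_latents, t).sample
loss = torch.nn.functional.mse_loss(noise_pred, noise)  # E || eps - eps_theta(z_t, t) ||^2
```

Because the UNet operates on the small latents defined by the checkpoint's config (64x64 here) instead of 256x256 RGB pixels, each evaluation is far cheaper than for a pixel-space DM at the same output resolution, which is the computational saving the abstract refers to.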
Authors:
Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer
Sampling with the high-level DiffusionPipeline API:

```python
# !pip install diffusers
from diffusers import DiffusionPipeline

model_id = "CompVis/ldm-celebahq-256"

# load model and scheduler
pipeline = DiffusionPipeline.from_pretrained(model_id)

# run pipeline in inference (sample random noise and denoise)
# recent diffusers versions return the result via the `.images` attribute
image = pipeline(num_inference_steps=200).images[0]

# save image
image.save("ldm_generated_image.png")
```
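For reproducible or batched sampling, the unconditional LDM pipeline in recent diffusers releases also accepts `batch_size`, `generator`, and `eta` arguments; treat these argument names as an assumption about the installed diffusers version rather than as part of this model card:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256")

# fix the random seed so repeated runs produce the same faces
generator = torch.manual_seed(42)

# assumes the unconditional LDM pipeline accepts batch_size / generator / eta
images = pipeline(
    batch_size=4,
    generator=generator,
    num_inference_steps=200,
    eta=0.0,
).images

for i, img in enumerate(images):
    img.save(f"ldm_celebahq_{i}.png")
```

With `eta=0.0` the DDIM sampler is deterministic, so two runs with the same seed yield identical images.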
Alternatively, run the denoising loop yourself with the individual components (UNet, DDIM scheduler, VQ-VAE):

```python
# !pip install diffusers
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import PIL.Image
import numpy as np
import tqdm

seed = 3

# load all models
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_pretrained("CompVis/ldm-celebahq-256", subfolder="scheduler")

# move models to GPU if available
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)
vqvae.to(torch_device)

# sample Gaussian noise in the latent space to be denoised
generator = torch.manual_seed(seed)
noise = torch.randn(
    (1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size),
    generator=generator,
).to(torch_device)

# set the number of inference steps for DDIM
scheduler.set_timesteps(num_inference_steps=200)

image = noise
for t in tqdm.tqdm(scheduler.timesteps):
    # predict the noise residual of the current latent
    with torch.no_grad():
        residual = unet(image, t).sample

    # compute the previous latent x_{t-1} according to the DDIM formula
    prev_image = scheduler.step(residual, t, image, eta=0.0).prev_sample

    # x_t -> x_{t-1}
    image = prev_image

# decode the final latent into an image with the VQ-VAE decoder
with torch.no_grad():
    image = vqvae.decode(image).sample

# post-process: [-1, 1] float -> [0, 255] uint8
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)

image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save(f"generated_image_{seed}.png")
```
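The `scheduler.step(residual, t, image, eta=0.0)` call in the loop above implements the standard DDIM update (Song et al., 2021), up to implementation details such as optional clipping of the predicted clean sample. Writing \(\bar{\alpha}_t\) for the cumulative noise-schedule product and \(\epsilon_\theta(x_t, t)\) for the UNet's noise prediction (the `residual`):

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar{\alpha}_t}}, \qquad x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t,t) + \sigma_t\,\varepsilon$$

$$\sigma_t = \eta\,\sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}}\,\sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}, \qquad \varepsilon \sim \mathcal{N}(0, I)$$

With `eta=0.0` (\(\eta = 0\)) the noise term vanishes and sampling is deterministic, so the output depends only on the initial latent noise and hence on `seed`.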