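In this article, I will try to build a small-scale Stable Diffusion-like model from scratch. "Small-scale" here means we will use MNIST, a small dataset you have probably heard of. I chose this dataset because training should not take long, which lets us quickly check whether our results are improving. All the code you see throughout this blog is available in my GitHub repository.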
Topic | Video Link
Neural Network | Neural Network Video
PyTorch | PyTorch Video
Linear Algebra | Linear Algebra Video
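Stable Diffusion works in a compressed latent space because a 512x512 color image has an enormous number of values. By comparison, the compressed latent image Stable Diffusion uses is 48 times smaller and contains far fewer values. This sharply reduced processing requirement is what makes it possible to run Stable Diffusion on a desktop computer with an NVIDIA GPU that has 8 GB of memory. A small latent space works because natural images follow patterns rather than being random. Stable Diffusion then relies on the decoder of a variational autoencoder (VAE) to paint in fine details such as eyes.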
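The 48x figure can be checked with quick arithmetic. This sketch assumes the standard Stable Diffusion v1 setup in which a 512x512 RGB image is encoded into a 4-channel 64x64 latent; that latent shape is an assumption, not something stated in this post.

```python
# Number of values in a 512x512 RGB image
pixel_values = 512 * 512 * 3
# Assumed Stable Diffusion v1 latent: 4 channels at 64x64 (8x downsampling per side)
latent_values = 64 * 64 * 4
# How many times fewer values the latent holds
ratio = pixel_values / latent_values
print(pixel_values, latent_values, ratio)  # 786432 16384 48.0
```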
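Stable Diffusion V1 was trained on three datasets that LAION compiled from Common Crawl. These include the LAION-Aesthetics v2.6 dataset, which contains images with an aesthetic score of 6 or higher.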
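We will use the MNIST dataset from the torchvision module, which contains small 28x28 images of handwritten digits (0-9). As mentioned earlier, we want a small dataset so that training does not take too long. Let's take a look at what our dataset looks like.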
# Import the required libraries
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# Download and load the training dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# Extract a batch of unique images
unique_images, unique_labels = next(iter(train_loader))
unique_images = unique_images.numpy()
# Display a grid of unique images
fig, axes = plt.subplots(4, 16, figsize=(16, 4), sharex=True, sharey=True) # Create a 4x16 grid of subplots with a wider figure
for i in range(4):  # Loop over rows
    for j in range(16):  # Loop over columns
        index = i * 16 + j  # Calculate the index in the batch
        axes[i, j].imshow(unique_images[index].squeeze(), cmap='gray')  # Show the image using a grayscale colormap
        axes[i, j].axis('off')  # Turn off axis labels and ticks
plt.show()  # Display the plot
# Import the PyTorch library for tensor operations.
import torch
# Import the neural network module from PyTorch.
import torch.nn as nn
# Import functional operations from PyTorch.
import torch.nn.functional as F
# Import the 'numpy' library for numerical operations.
import numpy as np
# Import the 'functools' module for higher-order functions.
import functools
# Import the Adam optimizer from PyTorch.
from torch.optim import Adam
# Import the DataLoader class from PyTorch for handling datasets.
from torch.utils.data import DataLoader
# Import data transformation functions from torchvision.
import torchvision.transforms as transforms
# Import the MNIST dataset from torchvision.
from torchvision.datasets import MNIST
# Import 'tqdm' for creating progress bars during training.
import tqdm
# Import 'trange' and 'tqdm' specifically for notebook compatibility.
from tqdm.notebook import trange, tqdm
# Import the learning rate scheduler from PyTorch.
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
# Import the 'matplotlib.pyplot' library for plotting graphs.
import matplotlib.pyplot as plt
# Import the 'make_grid' function from torchvision.utils for visualizing image grids.
from torchvision.utils import make_grid
# Install the 'einops' library for easy manipulation of tensors (notebook magic; use `pip install einops` in a shell)
%pip install einops
# Install the 'lpips' library for computing perceptual similarity between images
%pip install lpips
# Importing the `rearrange` function from the `einops` library
from einops import rearrange
# Importing the `math` module for mathematical operations
import math
# Forward diffusion for N steps in 1D.
def forward_diffusion_1D(x0, noise_strength_fn, t0, nsteps, dt):
    """
    Parameters:
    - x0: Initial sample value (scalar)
    - noise_strength_fn: Function of time, outputs scalar noise strength
    - t0: Initial time
    - nsteps: Number of diffusion steps
    - dt: Time step size
    Returns:
    - x: Trajectory of sample values over time
    - t: Corresponding time points for the trajectory
    """
    # Initialize the trajectory array
    x = np.zeros(nsteps + 1)
    # Set the initial sample value
    x[0] = x0
    # Generate time points for the trajectory
    t = t0 + np.arange(nsteps + 1) * dt
    # Perform Euler-Maruyama time steps for diffusion simulation
    for i in range(nsteps):
        # Get the noise strength at the current time
        noise_strength = noise_strength_fn(t[i])
        # Generate a random normal variable
        random_normal = np.random.randn()
        # Euler-Maruyama update: the noise increment scales with sqrt(dt)
        x[i + 1] = x[i] + random_normal * noise_strength * np.sqrt(dt)
    # Return the trajectory and corresponding time points
    return x, t
# Example noise strength function: always equal to 1
def noise_strength_constant(t):
    """
    Example noise strength function that returns a constant value (1).
    Parameters:
    - t: Time parameter (unused in this example)
    Returns:
    - Constant noise strength (1)
    """
    return 1
# Number of diffusion steps
nsteps = 100
# Initial time
t0 = 0
# Time step size
dt = 0.1
# Noise strength function
noise_strength_fn = noise_strength_constant
# Initial sample value
x0 = 0
# Number of tries for visualization
num_tries = 5
# Setting larger width and smaller height for the plot
plt.figure(figsize=(15, 5))
# Loop for multiple trials
for i in range(num_tries):
    # Simulate forward diffusion
    x, t = forward_diffusion_1D(x0, noise_strength_fn, t0, nsteps, dt)
    # Plot the trajectory
    plt.plot(t, x, label=f'Trial {i+1}')  # Adding a label for each trial
# Labeling the plot
plt.xlabel('Time', fontsize=20)
plt.ylabel('Sample Value ($x$)', fontsize=20)
# Title of the plot
plt.title('Forward Diffusion Visualization', fontsize=20)
# Adding a legend to identify each trial
plt.legend()
# Show the plot
plt.show()
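As a quick sanity check of the forward process, the sketch below simulates many independent trajectories at once with NumPy and confirms that, with constant noise strength σ and the sqrt(dt) scaling in the Euler-Maruyama update, the variance of x(t) − x0 grows like σ²t. The vectorized helper `simulate_endpoints` is hypothetical and not part of the original code.

```python
import numpy as np

def simulate_endpoints(num_paths, sigma, nsteps, dt, seed=0):
    """Hypothetical helper: run many 1D forward-diffusion paths in parallel from x0 = 0
    and return their final values x(nsteps * dt)."""
    rng = np.random.default_rng(seed)
    x = np.zeros(num_paths)
    for _ in range(nsteps):
        # Same Euler-Maruyama increment as forward_diffusion_1D: sigma * sqrt(dt) * N(0, 1)
        x += sigma * np.sqrt(dt) * rng.standard_normal(num_paths)
    return x

sigma, nsteps, dt = 1.0, 100, 0.1
endpoints = simulate_endpoints(num_paths=20000, sigma=sigma, nsteps=nsteps, dt=dt)
T = nsteps * dt
print(f"empirical variance: {endpoints.var():.2f}, predicted sigma^2 * T: {sigma**2 * T:.2f}")
```

Without the sqrt(dt) factor the variance would instead grow like σ²t/dt, so the plotted spread would depend on the step size rather than only on elapsed time.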
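s(x, t) is called the score function. Knowing this function allows us to reverse the forward diffusion and turn noise back into our initial state.
If our starting point is always the single point x0 = 0 and the noise strength σ is constant, the score function is exactly $s(x, t) = -\frac{x - x_0}{\sigma^2 t} = -\frac{x}{\sigma^2 t}$.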
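Where this formula comes from: with constant noise strength σ, forward diffusion started at x0 leaves x(t) normally distributed with mean x0 and variance σ²t, because each Euler-Maruyama step adds an independent increment of variance σ²dt. The score is the derivative of the log-density with respect to x:

```latex
\begin{aligned}
p(x, t) &= \frac{1}{\sqrt{2\pi\sigma^2 t}} \exp\!\left(-\frac{(x - x_0)^2}{2\sigma^2 t}\right) \\
\log p(x, t) &= -\frac{(x - x_0)^2}{2\sigma^2 t} - \tfrac{1}{2}\log\!\left(2\pi\sigma^2 t\right) \\
s(x, t) = \frac{\partial}{\partial x}\log p(x, t) &= -\frac{x - x_0}{\sigma^2 t}
\end{aligned}
```

The reverse-time Euler-Maruyama update in reverse_diffusion_1D, $x_{i+1} = x_i + \sigma^2 s(x_i, T - t_i)\,\Delta t + \sigma\sqrt{\Delta t}\, z_i$ with $z_i \sim \mathcal{N}(0, 1)$, plugs this score into the time-reversed process.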
# Reverse diffusion for N steps in 1D.
def reverse_diffusion_1D(x0, noise_strength_fn, score_fn, T, nsteps, dt):
    """
    Parameters:
    - x0: Initial sample value (scalar)
    - noise_strength_fn: Function of time, outputs scalar noise strength
    - score_fn: Score function
    - T: Final time
    - nsteps: Number of diffusion steps
    - dt: Time step size
    Returns:
    - x: Trajectory of sample values over time
    - t: Corresponding time points for the trajectory
    """
    # Initialize the trajectory array
    x = np.zeros(nsteps + 1)
    # Set the initial sample value
    x[0] = x0
    # Generate time points for the trajectory
    t = np.arange(nsteps + 1) * dt
    # Perform Euler-Maruyama time steps for reverse diffusion simulation
    for i in range(nsteps):
        # Calculate noise strength at the current (reversed) time T - t[i]
        noise_strength = noise_strength_fn(T - t[i])
        # Calculate the score using the score function (the data point is assumed to be 0)
        score = score_fn(x[i], 0, noise_strength, T - t[i])
        # Generate a random normal variable
        random_normal = np.random.randn()
        # Update the trajectory using the reverse Euler-Maruyama method
        x[i + 1] = x[i] + score * noise_strength**2 * dt + noise_strength * random_normal * np.sqrt(dt)
    # Return the trajectory and corresponding time points
    return x, t
# Example score function: the exact score when diffusion starts from the single point x0
def score_simple(x, x0, noise_strength, t):
    """
    Parameters:
    - x: Current sample value (scalar)
    - x0: Initial sample value (scalar)
    - noise_strength: Scalar noise strength at the current time
    - t: Current time
    Returns:
    - score: Score calculated based on the provided formula
    """
    # Calculate the score: s(x, t) = -(x - x0) / (sigma^2 * t)
    score = - (x - x0) / ((noise_strength**2) * t)
    # Return the calculated score
    return score
# Number of reverse diffusion steps
nsteps = 100
# Initial time for reverse diffusion
t0 = 0
# Time step size for reverse diffusion
dt = 0.1
# Function defining constant noise strength for reverse diffusion
noise_strength_fn = noise_strength_constant
# Example score function for reverse diffusion
score_fn = score_simple
# Initial sample value for reverse diffusion
x0 = 0
# Final time for reverse diffusion
T = 11
# Number of tries for visualization
num_tries = 5
# Setting larger width and smaller height for the plot
plt.figure(figsize=(15, 5))
# Loop for multiple trials
for i in range(num_tries):
    # Draw from the noise distribution: diffusion for time T with noise strength 1
    # has variance T, i.e. standard deviation sqrt(T)
    x0 = np.random.normal(loc=0, scale=np.sqrt(T))
    # Simulate reverse diffusion
    x, t = reverse_diffusion_1D(x0, noise_strength_fn, score_fn, T, nsteps, dt)
    # Plot the trajectory
    plt.plot(t, x, label=f'Trial {i+1}')  # Adding a label for each trial
# Labeling the plot
plt.xlabel('Time', fontsize=20)
plt.ylabel('Sample Value ($x$)', fontsize=20)
# Title of the plot
plt.title('Reverse Diffusion Visualized', fontsize=20)
# Adding a legend to identify each trial
plt.legend()
# Show the plot
plt.show()
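Here, p0(x0) denotes our target distribution (for example, images of cars and cats), and x(noised) denotes a sample x0 from the target distribution after a single forward diffusion step. Put more simply, x(noised) − x0 is essentially a normally distributed random variable.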
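Because x(noised) − x0 is Gaussian, the score of the noising kernel has a closed form, and that is what makes training tractable. The sketch below writes the resulting denoising score matching objective in the notation of the loss_fn code later in this post, where σ(t) is the marginal standard deviation and z is standard normal noise:

```latex
\begin{aligned}
x_{\text{noised}} &= x_0 + \sigma(t)\, z, \qquad z \sim \mathcal{N}(0, I) \\
\nabla_{x}\log p\big(x_{\text{noised}} \mid x_0\big) &= -\frac{x_{\text{noised}} - x_0}{\sigma(t)^2} = -\frac{z}{\sigma(t)} \\
\mathcal{L}(\theta) &= \mathbb{E}_{t \sim \mathcal{U}(\epsilon,\, 1-\epsilon)}\; \mathbb{E}_{x_0 \sim p_0}\; \mathbb{E}_{z}\Big[\big\lVert \sigma(t)\, s_\theta(x_{\text{noised}}, t) + z \big\rVert_2^2\Big]
\end{aligned}
```

The last line is the squared score error weighted by σ(t)², since σ(t)²‖s_θ + z/σ(t)‖² = ‖σ(t) s_θ + z‖². This is the quantity `(score * std + z)**2` that loss_fn sums over pixels and averages over the batch.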
# Define a module for Gaussian random features used to encode time steps.
class GaussianFourierProjection(nn.Module):
    def __init__(self, embed_dim, scale=30.):
        """
        Parameters:
        - embed_dim: Dimensionality of the embedding (output dimension)
        - scale: Scaling factor for random weights (frequencies)
        """
        super().__init__()
        # Randomly sample weights (frequencies) during initialization.
        # These weights (frequencies) are fixed during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
    def forward(self, x):
        """
        Parameters:
        - x: Input tensor representing time steps, shape [batch]
        """
        # Calculate the cosine and sine projections: Cosine(2 pi freq x), Sine(2 pi freq x)
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        # Concatenate the sine and cosine projections along the last dimension
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
# Define a module for a fully connected layer that reshapes outputs to feature maps.
class Dense(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        Parameters:
        - input_dim: Dimensionality of the input features
        - output_dim: Dimensionality of the output features
        """
        super().__init__()
        # Define a fully connected layer
        self.dense = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        """
        Parameters:
        - x: Input tensor
        Returns:
        - Output tensor after passing through the fully connected layer
          and reshaping to a 4D tensor (feature map)
        """
        # Apply the fully connected layer and reshape the output to a 4D tensor.
        # This broadcasts the 2D tensor to a 4D tensor, adding the same value across space.
        return self.dense(x)[..., None, None]
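The Dense module reshapes the output of a fully connected layer into a 4D tensor, effectively turning it into a feature map. It takes the number of input features (input_dim) and the desired number of output features (output_dim). In the forward pass, the input tensor x goes through the fully connected layer (self.dense(x)), and the output is reshaped into a 4D tensor by appending two singleton dimensions ([..., None, None]). The result has shape [batch, output_dim, 1, 1], so when it is added to a convolutional feature map of shape [batch, output_dim, H, W], broadcasting adds the same value at every spatial position. This is how the time embedding is injected into each convolutional layer.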
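A minimal NumPy sketch of the broadcasting step (PyTorch follows the same broadcasting rules as NumPy). The random weight and bias stand in for `nn.Linear(embed_dim, channels)`; all sizes are arbitrary illustration values.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, embed_dim, channels, H, W = 2, 8, 4, 5, 5

# Stand-in for nn.Linear(embed_dim, channels): weight [channels, embed_dim] and bias [channels]
weight = rng.standard_normal((channels, embed_dim))
bias = rng.standard_normal(channels)

time_embedding = rng.standard_normal((batch, embed_dim))    # [B, embed_dim]
feature_map = rng.standard_normal((batch, channels, H, W))  # [B, C, H, W]

dense_out = time_embedding @ weight.T + bias  # [B, C]
dense_4d = dense_out[..., None, None]         # [B, C, 1, 1], as in Dense.forward
combined = feature_map + dense_4d             # broadcast over H and W

print(dense_4d.shape, combined.shape)  # (2, 4, 1, 1) (2, 4, 5, 5)
# The added offset is identical at every spatial position of a given (batch, channel)
offset = combined - feature_map
print(np.allclose(offset, offset[:, :, :1, :1]))  # True
```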
# Define a time-dependent score-based model built upon the U-Net architecture.
class UNet(nn.Module):
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """
        Initialize a time-dependent score-based network.
        Parameters:
        - marginal_prob_std: A function that takes time t and gives the standard deviation
          of the perturbation kernel p_{0t}(x(t) | x(0)).
        - channels: The number of channels for feature maps of each resolution.
        - embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        # Additional encoding layers
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std
    def forward(self, x, t, y=None):
        """
        Parameters:
        - x: Input tensor
        - t: Time tensor
        - y: Target tensor (not used in this forward pass)
        Returns:
        - h: Output tensor after passing through the U-Net architecture
        """
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        # Additional encoding path layers
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        # Decoding path (skip connections concatenate encoder features along channels)
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(torch.cat([h, h3], dim=1))
        h += self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(torch.cat([h, h2], dim=1))
        h += self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(torch.cat([h, h1], dim=1))
        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
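In the U-Net architecture, tensor shapes change as information flows through the encoding and decoding paths. On the encoding (downsampling) path, each convolutional layer shrinks the spatial size, producing h1, h2, h3 and h4 in turn. On the decoding path, transposed convolutions restore the spatial information: the tensor h grows back toward the original spatial size, and at each step (from h4 back to h1) features from the matching encoder layer are concatenated in to help the upsampling. Finally, the last layer produces the output h, and the normalization step divides it by the marginal standard deviation σ(t) so that the predicted score has the right scale. The exact tensor shapes depend on the kernel size, stride and padding of the convolutional layers, and these choices shape the model's ability to capture and reconstruct detail.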
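To make the shape discussion concrete for 28x28 MNIST inputs, the sketch below applies the output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (assuming dilation 1 and no padding, as in the layers above) to the UNet layer settings. It also shows why tconv3 and tconv2 need `output_padding=1`: without it, the upsampled maps would be 11 and 25 pixels wide, one short of h2 and h1, and could not be concatenated with them.

```python
def conv2d_out(size, kernel=3, stride=1, padding=0):
    """Spatial output size of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel=3, stride=1, padding=0, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: conv1 (stride 1), then conv2-conv4 (stride 2), all 3x3 without padding
h1 = conv2d_out(28, stride=1)
h2 = conv2d_out(h1, stride=2)
h3 = conv2d_out(h2, stride=2)
h4 = conv2d_out(h3, stride=2)

# Decoder: tconv4 (no output_padding), tconv3 and tconv2 (output_padding=1), tconv1 (stride 1)
d3 = conv_transpose2d_out(h4, stride=2)
d2 = conv_transpose2d_out(d3, stride=2, output_padding=1)
d1 = conv_transpose2d_out(d2, stride=2, output_padding=1)
out = conv_transpose2d_out(d1, stride=1)

print("encoder:", [h1, h2, h3, h4])   # encoder: [26, 12, 5, 2]
print("decoder:", [d3, d2, d1, out])  # decoder: [5, 12, 26, 28]
```

Each decoder output matches the spatial size of the encoder feature map it is joined with (d3 with h3, d2 with h2, d1 with h1), and the final output returns to 28x28.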
# Define a time-dependent score-based model built upon the U-Net architecture, with additive skip connections.
class UNet_res(nn.Module):
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
        """
        Parameters:
        - marginal_prob_std: A function that takes time t and gives the standard deviation
          of the perturbation kernel p_{0t}(x(t) | x(0)).
        - channels: The number of channels for feature maps of each resolution.
        - embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)
        # The swish activation function
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std
    def forward(self, x, t, y=None):
        """
        Parameters:
        - x: Input tensor
        - t: Time tensor
        - y: Target tensor (not used in this forward pass)
        Returns:
        - h: Output tensor after passing through the U-Net architecture
        """
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        # Decoding path (skip connections add encoder features instead of concatenating)
        h = self.tconv4(h4)
        h += self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3)
        h += self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2)
        h += self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)
        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
In this context, σ(t) is called the marginal standard deviation. Essentially, it measures the spread of the distribution of x(t) given the initial value x(0).
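This section does not write the SDE out, but the diffusion_coeff code returns σ^t, which corresponds to the drift-free SDE dx = σ^t dw on t ∈ [0, 1]. Assuming that SDE, the closed form used in marginal_prob_std follows by integrating the squared diffusion coefficient:

```latex
\begin{aligned}
dx &= \sigma^{t}\, dw, \qquad t \in [0, 1] \\
\operatorname{Var}\big[x(t) \mid x(0)\big] &= \int_0^t \big(\sigma^{s}\big)^2\, ds = \int_0^t e^{2s\ln\sigma}\, ds = \frac{\sigma^{2t} - 1}{2\ln\sigma} \\
\sigma(t) &= \sqrt{\frac{\sigma^{2t} - 1}{2\ln\sigma}}
\end{aligned}
```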
# Use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Marginal Probability Standard Deviation Function
def marginal_prob_std(t, sigma):
    """
    Compute the standard deviation of $p_{0t}(x(t) | x(0))$.
    Parameters:
    - t: A vector of time steps.
    - sigma: The sigma in our SDE.
    Returns:
    - The standard deviation.
    """
    # Convert time steps to a PyTorch tensor (no copy if t is already a tensor)
    t = torch.as_tensor(t, device=device)
    # Calculate and return the standard deviation based on the given formula
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))
# Diffusion Coefficient Function
def diffusion_coeff(t, sigma):
    """
    Compute the diffusion coefficient of our SDE.
    Parameters:
    - t: A vector of time steps.
    - sigma: The sigma in our SDE.
    Returns:
    - The vector of diffusion coefficients.
    """
    # Calculate and return the diffusion coefficients g(t) = sigma^t
    return torch.as_tensor(sigma**t, device=device)
Now we initialize the marginal probability standard deviation and diffusion coefficient functions with sigma = 25.
# Sigma Value
sigma = 25.0
# marginal probability standard
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
# diffusion coefficient
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
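As a numerical sanity check of the closed form with sigma = 25, the sketch below compares it against a midpoint-rule integral of the squared diffusion coefficient, using only the standard library. The helper names `closed_form_std` and `numeric_std` are hypothetical.

```python
import math

def closed_form_std(t, sigma):
    """Closed-form marginal std, same formula as marginal_prob_std (scalar version)."""
    return math.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def numeric_std(t, sigma, n=100_000):
    """Midpoint-rule approximation of sqrt(integral_0^t (sigma^s)^2 ds)."""
    h = t / n
    total = sum(sigma ** (2 * (k + 0.5) * h) for k in range(n)) * h
    return math.sqrt(total)

sigma = 25.0
for t in (0.1, 0.5, 1.0):
    print(f"t={t}: closed form {closed_form_std(t, sigma):.4f}, numeric {numeric_std(t, sigma):.4f}")
```

At t = 1 the marginal standard deviation is roughly 9.85, which is the per-pixel scale of the initial noise that the Euler-Maruyama sampler starts from.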
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    """
    The loss function for training score-based generative models.
    Parameters:
    - model: A PyTorch model instance that represents a time-dependent score-based model.
    - x: A mini-batch of training data.
    - marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
    - eps: A tolerance value for numerical stability.
    """
    # Sample time uniformly in the range (eps, 1-eps)
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - 2 * eps) + eps
    # Find the noise std at the sampled time `t`
    std = marginal_prob_std(random_t)
    # Generate normally distributed noise
    z = torch.randn_like(x)
    # Perturb the input data with the generated noise
    perturbed_x = x + z * std[:, None, None, None]
    # Get the score from the model using the perturbed data and time
    score = model(perturbed_x, random_t)
    # Calculate the loss based on the score and noise
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
# number of steps
num_steps = 500
def Euler_Maruyama_sampler(score_model,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           x_shape=(1, 28, 28),
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3, y=None):
    """
    Generate samples from score-based models with the Euler-Maruyama solver.
    Parameters:
    - score_model: A PyTorch model that represents the time-dependent score-based model.
    - marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
    - diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    - batch_size: The number of samples to generate by calling this function once.
    - x_shape: The shape of the samples.
    - num_steps: The number of sampling steps, equivalent to the number of discretized time steps.
    - device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
    - eps: The smallest time step for numerical stability.
    - y: Target tensor (passed to the model; ignored by unconditional models).
    Returns:
    - Samples.
    """
    # Initialize time and the initial sample
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) * marginal_prob_std(t)[:, None, None, None]
    # Generate time steps
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    # Sample using Euler-Maruyama method
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x
# Define the score-based model and move it to the specified device
score_model = torch.nn.DataParallel(UNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
# Number of training epochs
n_epochs = 50
# Size of a mini-batch
batch_size = 2048
# Learning rate
lr = 5e-4
# Load the MNIST dataset and create a data loader
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Define the Adam optimizer for training the model
optimizer = Adam(score_model.parameters(), lr=lr)
# Progress bar for epochs
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    # Iterate through mini-batches in the data loader
    for x, y in tqdm(data_loader):
        x = x.to(device)
        # Calculate the loss and perform backpropagation
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    # Print the averaged training loss for the current epoch
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Save the model checkpoint after each epoch of training
    torch.save(score_model.state_dict(), 'ckpt.pth')
# Select the device (GPU if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the pre-trained model checkpoint and restore the weights into score_model
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)
# Set sample batch size and number of steps
sample_batch_size = 64
num_steps = 500
# Choose the Euler-Maruyama sampler
sampler = Euler_Maruyama_sampler
# Generate samples using the specified sampler
samples = sampler(score_model,
                  marginal_prob_std=marginal_prob_std_fn,
                  diffusion_coeff=diffusion_coeff_fn,
                  batch_size=sample_batch_size,
                  num_steps=num_steps,
                  device=device,
                  y=None)
# Clip samples to be in the range [0, 1]
samples = samples.clamp(0.0, 1.0)
# Visualize the generated samples
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
# Plot the sample grid
plt.figure(figsize=(6, 6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
# Initialize the alternate U-Net model for training.
score_model = torch.nn.DataParallel(UNet_res(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
# Set the number of training epochs, mini-batch size, and learning rate.
n_epochs = 75
batch_size = 1024
lr = 1e-3
# Load the MNIST dataset for training.
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Initialize the Adam optimizer with the specified learning rate.
optimizer = Adam(score_model.parameters(), lr=lr)
# Learning rate scheduler to adjust the learning rate during training.
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
# Training loop over epochs.
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    # Iterate over mini-batches in the training data loader.
    for x, y in data_loader:
        x = x.to(device)
        # Compute the loss for the current mini-batch.
        loss = loss_fn(score_model, x, marginal_prob_std_fn)
        # Zero the gradients, backpropagate, and update the model parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate the total loss and the number of processed items.
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    # Adjust the learning rate using the scheduler.
    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]
    # Print the average loss and learning rate for the current epoch.
    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Save the model checkpoint after each epoch of training.
    torch.save(score_model.state_dict(), 'ckpt_res.pth')
# Select the device (GPU if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the pre-trained model checkpoint and restore the weights into score_model
ckpt = torch.load('ckpt_res.pth', map_location=device)
score_model.load_state_dict(ckpt)
# Set sample batch size and number of steps
sample_batch_size = 64
num_steps = 500
# Choose the Euler-Maruyama sampler
sampler = Euler_Maruyama_sampler
# Generate samples using the specified sampler
samples = sampler(score_model,
                  marginal_prob_std=marginal_prob_std_fn,
                  diffusion_coeff=diffusion_coeff_fn,
                  batch_size=sample_batch_size,
                  num_steps=num_steps,
                  device=device,
                  y=None)
# Clip samples to be in the range [0, 1]
samples = samples.clamp(0.0, 1.0)
# Visualize the generated samples
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
# Plot the sample grid
plt.figure(figsize=(6, 6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
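This allocation of attention helps pick out relevant combinations of features. For example, when translating the English phrase "This is cool" into French, the correct answer ("c'est cool") requires attending to two words at once rather than translating word by word. Mathematically, we use the attention distribution a_ij to weight the values v_j: $c_i = \sum_j a_{ij} v_j$, where $a_{ij} = \mathrm{softmax}_j\left(q_i \cdot k_j / \sqrt{d}\right)$ and d is the query/key dimension.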
class CrossAttention(nn.Module):
    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1):
        """
        Initialize the CrossAttention module.
        Parameters:
        - embed_dim: The dimensionality of the query and key projections.
        - hidden_dim: The dimensionality of the hidden representations (and of the output).
        - context_dim: The dimensionality of the context representations (if not self attention).
        - num_heads: Number of attention heads (currently supports 1 head).
        Note: For simplicity reasons, the implementation assumes 1-head attention.
        Feel free to implement multi-head attention using fancy tensor manipulations.
        """
        super(CrossAttention, self).__init__()
        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        # Linear layer for query projection
        self.query = nn.Linear(hidden_dim, embed_dim, bias=False)
        # Check if self-attention or cross-attention
        if context_dim is None:
            self.self_attn = True
            self.key = nn.Linear(hidden_dim, embed_dim, bias=False)
            self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)
        else:
            self.self_attn = False
            self.key = nn.Linear(context_dim, embed_dim, bias=False)
            self.value = nn.Linear(context_dim, hidden_dim, bias=False)
    def forward(self, tokens, context=None):
        """
        Forward pass of the CrossAttention module.
        Parameters:
        - tokens: Input tokens with shape [batch, sequence_len, hidden_dim].
        - context: Context information with shape [batch, context_seq_len, context_dim].
          If self_attn is True, context is ignored.
        Returns:
        - ctx_vecs: Context vectors after attention with shape [batch, sequence_len, hidden_dim].
        """
        if self.self_attn:
            # Self-attention case
            Q = self.query(tokens)
            K = self.key(tokens)
            V = self.value(tokens)
        else:
            # Cross-attention case
            Q = self.query(tokens)
            K = self.key(context)
            V = self.value(context)
        # Compute score matrices, attention matrices, and context vectors
        scoremats = torch.einsum("BTH,BSH->BTS", Q, K)  # Inner product of Q and K, a tensor
        attnmats = F.softmax(scoremats / math.sqrt(self.embed_dim), dim=-1)  # Softmax of scoremats
        ctx_vecs = torch.einsum("BTS,BSH->BTH", attnmats, V)  # Weighted average value vectors by attnmats
        return ctx_vecs
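The CrossAttention class is a module that implements the attention mechanism in a neural network. It takes input tokens and, optionally, context information. When used for self-attention, it models the relationships among the input tokens themselves; when used for cross-attention, it models the interactions between the input tokens and the context. The module uses linear projections to produce queries, keys and values. It computes the score matrix, applies softmax to obtain attention weights, and computes context vectors as the attention-weighted sum of the values. This lets the network selectively focus on different parts of the input or context and capture relevant information during learning. The forward method implements these steps and returns the context vectors after attention.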
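The same single-head computation can be sketched in NumPy to check shapes and the softmax normalization. The weight matrices here are random stand-ins stored as [in, out] (so `x @ W` replaces `nn.Linear`, which stores [out, in] and computes `x @ W.T`); all sizes are arbitrary.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(tokens, context, Wq, Wk, Wv):
    """NumPy sketch of one CrossAttention forward pass (no bias, one head).
    tokens: [B, T, hidden_dim]; context: [B, S, context_dim]."""
    Q = tokens @ Wq    # [B, T, embed_dim]
    K = context @ Wk   # [B, S, embed_dim]
    V = context @ Wv   # [B, S, hidden_dim]
    embed_dim = Q.shape[-1]
    scores = np.einsum("bth,bsh->bts", Q, K)            # [B, T, S]
    attn = softmax(scores / np.sqrt(embed_dim), axis=-1)
    return np.einsum("bts,bsh->bth", attn, V), attn     # [B, T, hidden_dim]

rng = np.random.default_rng(0)
B, T, S, hidden_dim, context_dim, embed_dim = 2, 6, 3, 8, 5, 4
tokens = rng.standard_normal((B, T, hidden_dim))
context = rng.standard_normal((B, S, context_dim))
Wq = rng.standard_normal((hidden_dim, embed_dim))
Wk = rng.standard_normal((context_dim, embed_dim))
Wv = rng.standard_normal((context_dim, hidden_dim))

ctx_vecs, attn = single_head_attention(tokens, context, Wq, Wk, Wv)
print(ctx_vecs.shape)                       # (2, 6, 8)
print(np.allclose(attn.sum(axis=-1), 1.0))  # True
```

Each of the T token positions receives a weighted average of the S value vectors, so the output keeps the token sequence length and the value dimension (hidden_dim).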
Next, let's look at the second attention component, TransformerBlock.
class TransformerBlock(nn.Module):
    """The transformer block that combines self-attn, cross-attn, and feed forward neural net"""
    def __init__(self, hidden_dim, context_dim):
        """
        Initialize the TransformerBlock.
        Parameters:
        - hidden_dim: The dimensionality of the hidden state.
        - context_dim: The dimensionality of the context tensor.
        Note: For simplicity, the self-attn and cross-attn use the same hidden_dim.
        """
        super(TransformerBlock, self).__init__()
        # Self-attention module
        self.attn_self = CrossAttention(hidden_dim, hidden_dim)
        # Cross-attention module
        self.attn_cross = CrossAttention(hidden_dim, hidden_dim, context_dim)
        # Layer normalization modules
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        # Implement a 2-layer MLP with K * hidden_dim hidden units (K = 3), and nn.GELU nonlinearity
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 3 * hidden_dim),
            nn.GELU(),
            nn.Linear(3 * hidden_dim, hidden_dim)
        )
    def forward(self, x, context=None):
        """
        Forward pass of the TransformerBlock.
        Parameters:
        - x: Input tensor with shape [batch, sequence_len, hidden_dim].
        - context: Context tensor with shape [batch, context_seq_len, context_dim].
        Returns:
        - x: Output tensor after passing through the TransformerBlock.
        """
        # Apply self-attention with layer normalization and residual connection
        x = self.attn_self(self.norm1(x)) + x
        # Apply cross-attention with layer normalization and residual connection
        x = self.attn_cross(self.norm2(x), context=context) + x
        # Apply feed forward neural network with layer normalization and residual connection
        x = self.ffn(self.norm3(x)) + x
        return x
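The TransformerBlock class is a building block of a Transformer model that combines self-attention, cross-attention and a feed-forward neural network. It takes an input tensor of shape [batch, sequence_len, hidden_dim] and, optionally, a context tensor of shape [batch, context_seq_len, context_dim]. Each of the three sublayers first applies layer normalization to its input (pre-norm) and is wrapped in a residual connection. The feed-forward part is a two-layer MLP with a GELU nonlinearity that provides further nonlinear transformation. The output of the block has the same shape as its input.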
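To illustrate the pre-norm residual pattern, here is a NumPy sketch of the feed-forward sublayer x + FFN(LayerNorm(x)) with a 3x-wide hidden layer. Two simplifications are assumptions of this sketch: the layer norm has no learned scale or shift, and GELU uses the tanh approximation (`nn.GELU` defaults to the exact erf form). A useful property of the residual wiring is that a zero output projection makes the sublayer an identity map.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (feature) dimension, without learned scale/shift
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu_tanh(x):
    # Tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def ffn_sublayer(x, W1, b1, W2, b2):
    """Pre-norm residual feed-forward sublayer: x + W2(GELU(W1(LayerNorm(x))))."""
    h = gelu_tanh(layer_norm(x) @ W1 + b1)  # [B, T, 3 * hidden_dim]
    return x + h @ W2 + b2                  # residual connection, back to hidden_dim

rng = np.random.default_rng(0)
B, T, hidden_dim = 2, 5, 8
x = rng.standard_normal((B, T, hidden_dim))
W1 = rng.standard_normal((hidden_dim, 3 * hidden_dim))
b1 = np.zeros(3 * hidden_dim)
W2 = rng.standard_normal((3 * hidden_dim, hidden_dim))
b2 = np.zeros(hidden_dim)

y = ffn_sublayer(x, W1, b1, W2, b2)
print(y.shape)  # (2, 5, 8)
# With a zero output projection, the residual path returns x unchanged
print(np.allclose(ffn_sublayer(x, W1, b1, np.zeros_like(W2), b2), x))  # True
```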
class SpatialTransformer(nn.Module):
    def __init__(self, hidden_dim, context_dim):
        """
        Initialize the SpatialTransformer.
        Parameters:
        - hidden_dim: The dimensionality of the hidden state.
        - context_dim: The dimensionality of the context tensor.
        """
        super(SpatialTransformer, self).__init__()
        # TransformerBlock for spatial transformation
        self.transformer = TransformerBlock(hidden_dim, context_dim)
    def forward(self, x, context=None):
        """
        Forward pass of the SpatialTransformer.
        Parameters:
        - x: Input tensor with shape [batch, channels, height, width].
        - context: Context tensor with shape [batch, context_seq_len, context_dim].
        Returns:
        - x: Output tensor after applying spatial transformation.
        """
        b, c, h, w = x.shape
        x_in = x
        # Combine the spatial dimensions and move the channel dimension to the end
        x = rearrange(x, "b c h w -> b (h w) c")
        # Apply the sequence transformer
        x = self.transformer(x, context)
        # Reverse the process
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        # Residual connection
        return x + x_in
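The two `rearrange` calls in SpatialTransformer flatten the spatial grid into a sequence of h*w tokens with c features each, then undo it. The NumPy sketch below expresses the same round trip with `transpose` and `reshape` (an equivalent formulation for illustration, not how einops is implemented) and checks which token a given pixel ends up in.

```python
import numpy as np

b, c, h, w = 2, 3, 4, 5
x = np.arange(b * c * h * w, dtype=float).reshape(b, c, h, w)

# "b c h w -> b (h w) c": move channels last, then merge h and w into one sequence axis
tokens = x.transpose(0, 2, 3, 1).reshape(b, h * w, c)

# "b (h w) c -> b c h w": split the sequence axis back into h and w, move channels first
restored = tokens.reshape(b, h, w, c).transpose(0, 3, 1, 2)

print(tokens.shape)                 # (2, 20, 3)
print(np.array_equal(restored, x))  # True
# Pixel (row 1, col 2) becomes token index 1 * w + 2, carrying all c channel values
print(np.array_equal(tokens[:, 1 * w + 2, :], x[:, :, 1, 2]))  # True
```

Treating each pixel as a token is what lets the attention layers relate distant image regions to each other and to the class-label context.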
Now we can integrate the SpatialTransformer layers into our U-Net architecture.
class UNet_Tranformer(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,
                 text_dim=256, nClass=10):
        """Initialize a time-dependent score-based network.
        - marginal_prob_std: A function that takes time t and gives the standard deviation
          of the perturbation kernel p_{0t}(x(t) | x(0)).
        - channels: The number of channels for feature maps of each resolution.
        - embed_dim: The dimensionality of Gaussian random feature embeddings of time.
        - text_dim: The embedding dimension of text/digits.
        - nClass: Number of classes to model.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time.
        # GaussianFourierProjection and Dense are assumed to be the helper modules
        # defined in the earlier U-Net section of this article.
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Encoding layers where the resolution decreases
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.attn3 = SpatialTransformer(channels[2], text_dim)
        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        self.attn4 = SpatialTransformer(channels[3], text_dim)
        # Decoding layers where the resolution increases
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)
        # The swish activation function
        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std
        # Learnable embedding of the class label (digit), used as cross-attention context
        self.cond_embed = nn.Embedding(nClass, text_dim)
    def forward(self, x, t, y=None):
        """Forward pass of the UNet_Transformer model.
        - x: Input tensor (noisy images).
        - t: Time tensor.
        - y: Target tensor (class labels; required, since it is embedded below).
        Returns:
        - h: Output tensor after passing through the UNet_Transformer architecture.
        """
        # Obtain the Gaussian random feature embedding for t
        embed = self.act(self.time_embed(t))
        # Embed the labels as a length-1 context sequence of shape (B, 1, text_dim)
        y_embed = self.cond_embed(y).unsqueeze(1)
        # Encoding path
        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h3 = self.attn3(h3, y_embed)
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        h4 = self.attn4(h4, y_embed)
        # Decoding path (skip connections are added, not concatenated)
        h = self.tconv4(h4) + self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h = self.tconv3(h + h3) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h = self.tconv2(h + h2) + self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h = self.tconv1(h + h1)
        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
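Because conv1 to conv4 use no padding, the additive skip connections h + h3, h + h2 and h + h1 only work if each transposed convolution restores the exact spatial size of its encoder counterpart, which is the job of output_padding=1 on tconv3 and tconv2. The sketch below traces a 28x28 MNIST input through the standard output-size formulas of `nn.Conv2d` and `nn.ConvTranspose2d` (dilation 1, padding 0); the helper names `conv_out` and `tconv_out` are hypothetical:

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output size of nn.Conv2d along one spatial axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def tconv_out(size, kernel=3, stride=1, padding=0, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial axis (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Encoder: conv1 (stride 1), conv2..conv4 (stride 2), kernel 3, no padding
h1 = conv_out(28, stride=1)   # 26
h2 = conv_out(h1, stride=2)   # 12
h3 = conv_out(h2, stride=2)   # 5
h4 = conv_out(h3, stride=2)   # 2

# Decoder: each output must match the encoder map it is added to
d4 = tconv_out(h4, stride=2)                    # 5, matches h3
d3 = tconv_out(d4, stride=2, output_padding=1)  # 12, matches h2
d2 = tconv_out(d3, stride=2, output_padding=1)  # 26, matches h1
d1 = tconv_out(d2, stride=1)                    # 28, the input size again

print([h1, h2, h3, h4], [d4, d3, d2, d1])
# Without output_padding, tconv3 would produce 11x11 and could not be added to h2
print(tconv_out(d4, stride=2))  # 11
```

The same arithmetic shows why the attention layers are cheap here: attn3 sees 5*5 = 25 tokens and attn4 only 2*2 = 4.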
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-5):
    """The loss function for training score-based generative models with conditional information.
    - model: A PyTorch model instance that represents a time-dependent score-based model.
    - x: A mini-batch of training data.
    - y: Conditional information (target tensor).
    - marginal_prob_std: A function that gives the standard deviation of the perturbation kernel.
    - eps: A tolerance value for numerical stability.
    Returns:
    - loss: The calculated loss.
    """
    # Sample time uniformly in the range [eps, 1)
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    # Generate random noise with the same shape as the input
    z = torch.randn_like(x)
    # Compute the standard deviation of the perturbation kernel at the sampled time
    std = marginal_prob_std(random_t)
    # Perturb the input data with the generated noise scaled by the standard deviation
    perturbed_x = x + z * std[:, None, None, None]
    # Get the model's score for the perturbed input, considering conditional information
    score = model(perturbed_x, random_t, y=y)
    # Denoising score matching: the ideal score is -z / std, so penalize ||score * std + z||^2
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
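The loss above is denoising score matching. For a perturbed sample x~ = x + std * z with z ~ N(0, I), the score of the Gaussian perturbation kernel is -(x~ - x) / std^2 = -z / std, so a model that outputs exactly that value makes score * std + z vanish; multiplying by std amounts to the usual sigma_t^2 weighting of the squared error. The pure-Python sketch below checks this on one flattened toy image, with a fixed std standing in for marginal_prob_std(random_t) and a hypothetical per-example helper `dsm_loss`:

```python
import random


def dsm_loss(score, z, std):
    """Per-example loss sum((score * std + z)^2), matching the
    torch.sum(..., dim=(1, 2, 3)) term of loss_fn_cond for one image."""
    return sum((s * std + zi) ** 2 for s, zi in zip(score, z))


random.seed(0)
x = [random.random() for _ in range(784)]   # a flattened 28x28 toy "image"
z = [random.gauss(0.0, 1.0) for _ in x]     # the injected Gaussian noise
std = 0.7                                   # stand-in for marginal_prob_std(t)
perturbed = [xi + std * zi for xi, zi in zip(x, z)]

# Exact score of the kernel N(x, std^2 I), evaluated at the perturbed point
ideal_score = [-(p - xi) / std ** 2 for p, xi in zip(perturbed, x)]
zero_score = [0.0] * len(x)

print(dsm_loss(ideal_score, z, std))  # ~0, up to floating-point rounding
print(dsm_loss(zero_score, z, std))   # equals sum(z^2), about 784 on average
```

In the real loss the network never sees z, so it can only approach this minimum by learning to predict the noise from the perturbed image, the time and the label.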
# Imports used below (device and marginal_prob_std_fn are defined earlier in the article)
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm, trange
# Specify whether to continue training or initialize a new model
continue_training = False  # Either True or False
if not continue_training:
    # Initialize a new UNet with Transformer model
    score_model = torch.nn.DataParallel(UNet_Tranformer(marginal_prob_std=marginal_prob_std_fn))
    score_model = score_model.to(device)
# Set training hyperparameters
n_epochs = 100  #{'type':'integer'}
batch_size = 1024  #{'type':'integer'}
lr = 10e-4  #{'type':'number'}
# Load the MNIST dataset and create a data loader
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Define the optimizer and learning rate scheduler
optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
# Use tqdm to display a progress bar over epochs
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    # Iterate over batches in the data loader
    for x, y in tqdm(data_loader):
        x = x.to(device)
        y = y.to(device)
        # Compute the loss using the conditional score-based model
        loss = loss_fn_cond(score_model, x, y, marginal_prob_std_fn)
        # Backpropagate and update the model parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    # Adjust the learning rate once per epoch using the scheduler
    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]
    # Print epoch information including average loss and current learning rate
    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Save the model checkpoint after each epoch of training
    torch.save(score_model.state_dict(), 'ckpt_transformer.pth')
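The LambdaLR schedule multiplies the base learning rate by max(0.2, 0.98 ** epoch): the rate decays geometrically until it reaches a floor of 20% of lr. The sketch below evaluates that factor in plain Python to show where the floor takes over with the settings above (lr = 10e-4, which is 1e-3). `lr_at` is a hypothetical helper, and "epoch" here means the number of scheduler.step() calls made so far, which is how LambdaLR counts:

```python
def lr_at(epoch, base_lr=10e-4):
    """Learning rate after `epoch` scheduler steps under
    lr_lambda = max(0.2, 0.98 ** epoch)."""
    return base_lr * max(0.2, 0.98 ** epoch)


# First epoch at which 0.98 ** epoch drops below the 0.2 floor
first_floor = next(e for e in range(1000) if 0.98 ** e < 0.2)
print(first_floor)  # 80

for e in (0, 10, 50, first_floor, 99):
    print(e, lr_at(e))  # 1e-3 at the start, about 2e-4 from the floor onward
```

With n_epochs = 100, roughly the last 20 epochs therefore run at a constant 2e-4, which keeps late updates from shrinking toward zero.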
## Load the pre-trained checkpoint from disk.
# device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
ckpt = torch.load('ckpt_transformer.pth', map_location=device)
score_model.load_state_dict(ckpt)
# Specify the digit for which to generate samples
digit = 9 #@param {'type':'integer'}
# Set the batch size for generating samples
sample_batch_size = 64 #@param {'type':'integer'}
# Set the number of steps for the Euler-Maruyama sampler
num_steps = 250 #@param {'type':'integer'}
# Choose the sampler type (Euler-Maruyama, pc_sampler, ode_sampler)
sampler = Euler_Maruyama_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}
# score_model.eval()
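## Generate samples using the specified sampler.
# Argument order assumed to match the Euler_Maruyama_sampler defined earlier:
# (score_model, marginal_prob_std, diffusion_coeff, batch_size, num_steps, device, y)
samples = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  num_steps=num_steps,
                  device=device,
                  y=digit*torch.ones(sample_batch_size, dtype=torch.long))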
## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
# Create a grid of samples for visualization
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
# Plot the generated samples
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
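The Euler-Maruyama sampler integrates the reverse-time SDE from t = 1 down to a small eps, alternating a drift step g(t)^2 * score * dt with fresh Gaussian noise of scale g(t) * sqrt(dt). Its definition is not repeated in this section, so the sketch below is an illustration under stated assumptions, not the article's code. It assumes the variance-exploding SDE dx = sigma^t dw with sigma = 25 (a common choice in score-based tutorials; check it against the article's marginal_prob_std_fn), and it replaces the trained network with the exact score of a toy 1-D dataset N(0, 1), which is known in closed form. `toy_score` and `euler_maruyama` are hypothetical names. If the sampler works, its output should have mean near 0 and variance near 1:

```python
import math
import random

SIGMA = 25.0  # assumed SDE parameter


def marginal_prob_std(t):
    """Std of the perturbation kernel p_{0t}(x(t) | x(0)) for dx = SIGMA**t dw."""
    return math.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * math.log(SIGMA)))


def diffusion_coeff(t):
    """Diffusion coefficient g(t) = SIGMA**t."""
    return SIGMA ** t


def toy_score(x, t):
    """Exact score when the data are N(0, 1), so p_t = N(0, 1 + std(t)^2)."""
    return -x / (1.0 + marginal_prob_std(t) ** 2)


def euler_maruyama(score_fn, n_samples=2000, num_steps=500, eps=1e-3, seed=0):
    """Reverse-time Euler-Maruyama integration from t = 1 down to t = eps."""
    rng = random.Random(seed)
    # Start from the approximate prior N(0, std(1)^2) at t = 1
    xs = [rng.gauss(0.0, marginal_prob_std(1.0)) for _ in range(n_samples)]
    # Evenly spaced times from 1 down to eps (like torch.linspace(1., eps, num_steps))
    times = [1.0 - i * (1.0 - eps) / (num_steps - 1) for i in range(num_steps)]
    dt = times[0] - times[1]
    mean_xs = xs
    for t in times:
        g = diffusion_coeff(t)
        # Drift toward higher density, then inject noise of scale g * sqrt(dt)
        mean_xs = [x + g ** 2 * score_fn(x, t) * dt for x in xs]
        xs = [m + math.sqrt(dt) * g * rng.gauss(0.0, 1.0) for m in mean_xs]
    # Return the last mean, i.e. without the final noise injection
    return mean_xs


samples = euler_maruyama(toy_score)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(round(mean, 2), round(var, 2))  # mean near 0, variance near 1
```

In the article's setting, score_fn is the trained score_model evaluated on image tensors with the digit label, but the update rule has the same shape.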