使用Python从头构建AI文本转视频模型

2024年06月14日 由 alex 发表 137 0

OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经问世或未来将出现的文本转视频模型,是继大型语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。在本文中,我们将从头开始构建一个小规模的文本转视频模型。我们将输入一个文本提示,我们训练过的模型将根据该提示生成视频。本文将涵盖从理解理论概念到编码整个架构并生成最终结果的所有内容。


在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。我建议使用Colab或Kaggle的 T4 GPU进行更高效、更快速的训练。


我们正在构建什么

我们将采用与传统机器学习或深度学习模型类似的方法,在数据集上进行训练,然后在未见过的数据上进行测试。在文本到视频的背景下,假设我们有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集。我们将对模型进行训练,以生成猫抓球或狗追老鼠的视频。


2


虽然此类训练数据集在互联网上很容易获得,但所需的计算能力却非常高。因此,我们将使用由 Python 代码生成的移动物体视频数据集。


我们将使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。我曾尝试使用扩散模型,但它因内存需求而崩溃,这超出了我的能力范围。另一方面,GAN 在训练和测试方面更加简单快捷。


先决条件

我们将使用 OOP(面向对象编程),因此你必须对它和神经网络有基本的了解。对 GAN(生成对抗网络)的了解不是必须的,因为我们将在这里介绍其架构。


Topic                          Link
OOP Video Link
Neural Networks Theory Video Link
GAN Architecture Video Link
Python basics Video Link


了解 GAN 架构

了解 GAN 非常重要,因为我们的大部分架构都依赖于它。让我们来探讨一下什么是 GAN、GAN 的组成部分等。


什么是 GAN?

生成对抗网络(GAN)是一种深度学习模型,其中有两个神经网络在竞争:一个神经网络从给定的数据集中创建新数据(如图像或音乐),另一个神经网络则尝试辨别数据的真伪。这一过程一直持续到生成的数据与原始数据无法区分为止。


实际应用

  1. 生成图像: GAN 可根据文本提示生成逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。
  2. 数据增强: 它们生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。
  3. 完善缺失信息: GANs 可以填补缺失数据,例如为能源应用从地形图中生成地下图像。
  4. 生成三维模型: 它们能将二维图像转换成三维模型,在医疗保健等领域非常有用,能为手术规划创建逼真的器官图像。


GAN 如何工作?

它由两个深度神经网络组成:生成器和判别器。这些网络在对抗设置中一起训练,其中一个生成新数据,另一个评估数据是真是假。


以下是 GAN 工作原理的简要概述:

  1. 训练集分析: 生成器分析训练集以识别数据属性,而判别器则独立分析相同数据以学习其属性。
  2. 数据修改: 生成器在数据的某些属性上添加噪音(随机变化)。
  3. 数据传递: 然后将修改后的数据传递给判别器。
  4. 概率计算: 判别器计算生成的数据来自原始数据集的概率。
  5. 反馈回路: 鉴别器向生成器提供反馈,指导生成器在下一个循环中减少随机噪音。
  6. 对抗训练: 生成器试图将判别器的错误最大化,而判别器则试图将自己的错误最小化。通过多次训练迭代,两个网络都会得到改进和发展。
  7. 平衡状态: 训练一直持续到鉴别器无法再区分真实数据和合成数据,表明生成器已成功学会生成真实数据。至此,训练过程结束。


3


GAN 训练示例

让我们以图像到图像的翻译为例,解释 GAN 模型,重点是修改人脸。


  1. 输入图像: 输入是一张真实的人脸图像。
  2. 属性修改: 生成器会修改人脸的属性,例如在眼睛上添加太阳镜。
  3. 生成图像: 生成器创建一组添加了太阳镜的图像。
  4. 判别器的任务: 判别器将收到真实图像(戴太阳镜的人)和生成图像(添加了太阳镜的人脸)的混合图像。
  5. 评估: 鉴别器尝试区分真实图像和生成图像。
  6. 反馈回路: 如果鉴别器能正确识别假图像,生成器就会调整参数,生成更有说服力的图像。如果生成器成功骗过了鉴别器,鉴别器就会更新参数,以提高检测能力。


通过这种对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而鉴别器则越来越善于识别假图像,直到达到平衡,鉴别器再也无法区分真实图像和生成的图像。至此,GAN 已成功学会生成逼真的修改图像。


搭建舞台

我们将使用一系列 Python 库,让我们导入它们。


# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset
# Module for image transformations
import torchvision.transforms as transforms
# Neural network module in PyTorch
import torch.nn as nn
# Optimization algorithms in PyTorch
import torch.optim as optim
# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision.utils import save_image
# Module for plotting graphs and images
import matplotlib.pyplot as plt
# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML
# Module for encoding and decoding binary data to text
import base64


现在我们已经导入了所有库,下一步就是定义训练数据,我们将使用这些数据来训练我们的 GAN 架构。


编码训练数据

我们需要至少 10,000 个视频作为训练数据。为什么呢?因为我用更少的数据进行过测试,结果非常糟糕,几乎什么都看不到。下一个大问题是:这些视频是关于什么的?我们的训练视频数据集由一个向不同方向运动的圆圈和不同的动作组成。那么,让我们编码生成 10,000 个视频,看看它是什么样子的。


# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = (64, 64)
# Define the size of the shapes (Circle)
shape_size = 10


设置完一些基本参数后,接下来我们需要定义训练数据集的文本提示,并根据这些提示生成训练视频。


# Define text prompts and corresponding movements for circles
prompts_and_movements = [
    ("circle moving down", "circle", "down"),  # Move circle downward
    ("circle moving left", "circle", "left"),  # Move circle leftward
    ("circle moving right", "circle", "right"),  # Move circle rightward
    ("circle moving diagonally up-right", "circle", "diagonal_up_right"),  # Move circle diagonally up-right
    ("circle moving diagonally down-left", "circle", "diagonal_down_left"),  # Move circle diagonally down-left
    ("circle moving diagonally up-left", "circle", "diagonal_up_left"),  # Move circle diagonally up-left
    ("circle moving diagonally down-right", "circle", "diagonal_down_right"),  # Move circle diagonally down-right
    ("circle rotating clockwise", "circle", "rotate_clockwise"),  # Rotate circle clockwise
    ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),  # Rotate circle counter-clockwise
    ("circle shrinking", "circle", "shrink"),  # Shrink circle
    ("circle expanding", "circle", "expand"),  # Expand circle
    ("circle bouncing vertically", "circle", "bounce_vertical"),  # Bounce circle vertically
    ("circle bouncing horizontally", "circle", "bounce_horizontal"),  # Bounce circle horizontally
    ("circle zigzagging vertically", "circle", "zigzag_vertical"),  # Zigzag circle vertically
    ("circle zigzagging horizontally", "circle", "zigzag_horizontal"),  # Zigzag circle horizontally
    ("circle moving up-left", "circle", "up_left"),  # Move circle up-left
    ("circle moving down-right", "circle", "down_right"),  # Move circle down-right
    ("circle moving down-left", "circle", "down_left"),  # Move circle down-left
]


我们已经利用这些提示定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据提示移动圆。


# Define function with parameters
def create_image_with_moving_shape(size, frame_num, shape, direction):
  
    # Create a new RGB image with specified size and white background
    img = Image.new('RGB', size, color=(255, 255, 255))  
    # Create a drawing context for the image
    draw = ImageDraw.Draw(img)  
    # Calculate the center coordinates of the image
    center_x, center_y = size[0] // 2, size[1] // 2  
    # Initialize position with center for all movements
    position = (center_x, center_y)  
    # Define a dictionary mapping directions to their respective position adjustments or image transformations
    direction_map = {  
        # Adjust position downwards based on frame number
        "down": (0, frame_num * 5 % size[1]),  
        # Adjust position to the left based on frame number
        "left": (-frame_num * 5 % size[0], 0),  
        # Adjust position to the right based on frame number
        "right": (frame_num * 5 % size[0], 0),  
        # Adjust position diagonally up and to the right
        "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        # Adjust position diagonally down and to the left
        "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),  
        # Adjust position diagonally up and to the left
        "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        # Adjust position diagonally down and to the right
        "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),  
        # Rotate the image clockwise based on frame number
        "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),  
        # Rotate the image counter-clockwise based on frame number
        "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),  
        # Adjust position for a bouncing effect vertically
        "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),  
        # Adjust position for a bouncing effect horizontally
        "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),  
        # Adjust position for a zigzag effect vertically
        "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),  
        # Adjust position for a zigzag effect horizontally
        "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),  
        # Adjust position upwards and to the right based on frame number
        "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        # Adjust position upwards and to the left based on frame number
        "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),  
        # Adjust position downwards and to the right based on frame number
        "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),  
        # Adjust position downwards and to the left based on frame number
        "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])  
    }
    # Check if direction is in the direction map
    if direction in direction_map:  
        # Check if the direction maps to a position adjustment
        if isinstance(direction_map[direction], tuple):  
            # Update position based on the adjustment
            position = tuple(np.add(position, direction_map[direction]))  
        else:  # If the direction maps to an image transformation
            # Update the image based on the transformation
            img = direction_map[direction]  
    # Return the image as a numpy array
    return np.array(img)


上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。


# Iterate over the number of videos to generate
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)
    
    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    
    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)
    
    # Generate frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape based on the current frame number, shape, and direction
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        
        # Save the generated image as a PNG file in the video directory
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)


运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。


4


每个训练视频文件夹都包含帧和文本提示。让我们来看看训练数据集的样本。


5


在我们的训练数据集中,我们没有包含圆圈向上移动然后向右移动的动作。我们将以此作为测试提示,在未见过的数据上对训练模型进行评估。


还有一点需要注意的是,我们的训练数据中确实包含了许多物体远离场景或部分出现在摄像头前的样本,这与我们在 OpenAI Sora 演示视频中观察到的情况类似。


6


之所以在训练数据中加入此类样本,是为了测试当圆形从最角落进入场景时,我们的模型能否在不破坏其形状的情况下保持一致性。


既然训练数据已经生成,我们就需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,进行归一化等转换有助于将数据缩放到更小的范围内,从而提高训练架构的收敛性和稳定性。


预处理我们的训练数据

我们必须编写一个用于文本到视频任务的数据集类,它可以从训练数据集目录中读取视频帧及其对应的文本提示,使其可以在 PyTorch 中使用。


# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Initialize the dataset with root directory and optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        # Initialize lists to store frame paths and corresponding prompts
        self.frame_paths = []
        self.prompts = []
        # Loop through each video directory
        for video_dir in self.video_dirs:
            # List all PNG files in the video directory and store their paths
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file in the video directory and store its content
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt for each frame in the video and store in prompts list
            self.prompts.extend([prompt] * len(frames))
    # Return the total number of samples in the dataset
    def __len__(self):
        return len(self.frame_paths)
    # Retrieve a sample from the dataset given an index
    def __getitem__(self, idx):
        # Get the path of the frame corresponding to the given index
        frame_path = self.frame_paths[idx]
        # Open the image using PIL (Python Imaging Library)
        image = Image.open(frame_path)
        # Get the prompt corresponding to the given index
        prompt = self.prompts[idx]
        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)
        # Return the transformed image and the prompt
        return image, prompt


在对架构进行编码之前,我们需要对训练数据进行归一化处理。我们将使用 16 个批次的数据,并对数据进行洗牌,以引入更多随机性。


# Define a set of transformations to be applied to the data
transform = transforms.Compose([
    transforms.ToTensor(), # Convert PIL Image or numpy.ndarray to tensor
    transforms.Normalize((0.5,), (0.5,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)
# Create a dataloader to iterate over the dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)


实现文本嵌入层

你可能在转换器架构中见过,其起点是将文本输入转换为嵌入,以便在多头注意力中进行进一步处理,与此类似,我们必须编码一个文本嵌入层,在此基础上对嵌入数据和图像张量进行 GAN 架构训练。


# Define a class for text embedding
class TextEmbedding(nn.Module):
    # Constructor method with vocab_size and embed_size parameters
    def __init__(self, vocab_size, embed_size):
        # Call the superclass constructor
        super(TextEmbedding, self).__init__()
        # Initialize embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
    # Define the forward pass method
    def forward(self, x):
        # Return embedded representation of input
        return self.embedding(x)


词汇量大小将基于我们的训练数据,稍后我们将计算出来。嵌入大小为 10。如果要处理更大的数据集,也可以使用 Hugging Face 上提供的嵌入模型。


实现生成器层

既然我们已经知道了生成器在 GAN 中的作用,那么让我们对这一层进行编码,然后了解其内容。


class Generator(nn.Module):
    def __init__(self, text_embed_size):
        super(Generator, self).__init__()
        
        # Fully connected layer that takes noise and text embedding as input
        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)
        
        # Transposed convolutional layers to upsample the input
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # Output has 3 channels for RGB images
        
        # Activation functions
        self.relu = nn.ReLU(True)  # ReLU activation function
        self.tanh = nn.Tanh()       # Tanh activation function for final output
    def forward(self, noise, text_embed):
        # Concatenate noise and text embedding along the channel dimension
        x = torch.cat((noise, text_embed), dim=1)
        
        # Fully connected layer followed by reshaping to 4D tensor
        x = self.fc1(x).view(-1, 256, 8, 8)
        
        # Upsampling through transposed convolution layers with ReLU activation
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        
        # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
        x = self.tanh(self.deconv3(x))
        
        return x


该生成器类负责从随机噪音和文本嵌入的组合中创建视频帧。其目的是根据给定的文本描述生成逼真的视频帧。网络从一个全连接层(nn.Linear)开始,该层将噪声矢量和文本嵌入合并为一个特征向量。然后对该向量进行重塑,并通过一系列转置卷积层(nn.ConvTranspose2d),将特征映射逐步上样到所需的视频帧大小。


各层使用 ReLU 激活(nn.ReLU)来实现非线性,最后一层使用 Tanh 激活(nn.Tanh)将输出缩放到 [-1, 1] 范围内。这样,生成器就能将抽象的高维输入转化为可视化表示输入文本的连贯视频帧。


实现鉴别器层

对生成器层进行编码后,我们需要实现另一半,即判别器部分。


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        # Convolutional layers to process input images
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1
        
        # Fully connected layer for classification
        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)
        
        # Activation functions
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2
        self.sigmoid = nn.Sigmoid()  # Sigmoid activation for final output (probability)
    def forward(self, input):
        # Pass input through convolutional layers with LeakyReLU activation
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        
        # Flatten the output of convolutional layers
        x = x.view(-1, 256 * 8 * 8)
        
        # Pass through fully connected layer with Sigmoid activation for binary classification
        x = self.sigmoid(self.fc1(x))
        
        return x


鉴别器类是一种二进制分类器,用于区分真实和生成的视频帧。其目的是评估视频帧的真实性,从而引导生成器产生更真实的输出。该网络由卷积层(nn.Conv2d)和 Leaky ReLU 激活层(nn.LeakyReLU)组成,卷积层用于从输入视频帧中提取分层特征,Leaky ReLU 激活层增加了非线性,同时允许负值具有较小的梯度。然后,对特征图进行扁平化处理,并通过一个全连接层(nn.Linear),最后通过一个 sigmoid 激活层(nn.Sigmoid),输出一个概率分数,显示帧的真假。


通过训练鉴别器对帧进行准确分类,同时训练生成器创建更有说服力的视频帧,因为它的目的是骗过鉴别器。


编码训练参数

我们必须设置训练 GAN 的基本组件,如损失函数、优化器等。


# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10  # Size of the text embedding vector
def encode_text(prompt):
    # Encode a given prompt into a tensor of indices using the vocabulary
    return torch.tensor([vocab[word] for word in prompt.split()])
# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size
netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size
netD = Discriminator().to(device)  # Initialize Discriminator model
criterion = nn.BCELoss().to(device)  # Binary Cross Entropy loss function
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator


这部分我们必须转换代码,以便在 GPU 上运行(如果有的话)。我们已经编写了查找 vocab_size 的代码,并在生成器和判别器中使用了 ADAM 优化器。如果你愿意,也可以选择自己的优化器。在这里,我们将学习率设置为 0.0002 的较小值,嵌入大小为 10,这与其他可公开使用的拥抱脸部模型相比要小得多。


编码训练循环

就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。


# Number of epochs
num_epochs = 13
# Iterate over each epoch
for epoch in range(num_epochs):
    # Iterate over each batch of data
    for i, (data, prompts) in enumerate(dataloader):
        # Move real data to device
        real_data = data.to(device)
        
        # Convert prompts to list
        prompts = [prompt for prompt in prompts]
        # Update Discriminator
        netD.zero_grad()  # Zero the gradients of the Discriminator
        batch_size = real_data.size(0)  # Get the batch size
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)
        output = netD(real_data)  # Forward pass real data through Discriminator
        lossD_real = criterion(output, labels)  # Calculate loss on real data
        lossD_real.backward()  # Backward pass to calculate gradients
       
        # Generate fake data
        noise = torch.randn(batch_size, 100).to(device)  # Generate random noise
        text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddings
        fake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddings
        labels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)
        output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)
        lossD_fake = criterion(output, labels)  # Calculate loss on fake data
        lossD_fake.backward()  # Backward pass to calculate gradients
        optimizerD.step()  # Update Discriminator parameters
        # Update Generator
        netG.zero_grad()  # Zero the gradients of the Generator
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminator
        output = netD(fake_data)  # Forward pass fake data (now updated) through Discriminator
        lossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's response
        lossG.backward()  # Backward pass to calculate gradients
        optimizerG.step()  # Update Generator parameters
    
    # Print epoch information
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")


通过反向传播,我们将调整发生器和鉴别器的损耗。我们在训练循环中使用了 13 个历元。我测试过不同的值,但如果 epoch 超过这个值,结果并不会有太大差别。此外,过度拟合的风险也很高。如果我们有一个更多样化的数据集,包含更多的动作和形状,我们可以考虑使用更高的epochs,但在这种情况下不行。


当我们运行这段代码时,它会开始训练,并在每个epoch后打印生成器和判别器的损失。


## OUTPUT ##
Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776
...


保存训练好的模型

训练完成后,我们需要保存训练好的 GAN 架构的判别器和生成器,这只需两行代码即可实现。


# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')
# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')


生成人工智能视频

正如我们所讨论的,我们在未见过的数据上测试模型的方法与我们的训练数据涉及狗取球和猫追老鼠的例子类似。因此,我们的测试提示可能涉及猫取球或狗追老鼠等场景。


在我们的具体案例中,我们的训练数据中不存在圆圈向上然后向右移动的动作,因此模型对这种特定动作并不熟悉。不过,模型已经接受过其他运动的训练。我们可以用这个动作作为提示来测试我们训练好的模型,并观察它的表现。


# Inference function to generate a video based on a given text prompt
def generate_video(text_prompt, num_frames=10):
    # Create a directory for the generated video frames based on the text prompt
    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)
    
    # Encode the text prompt into a text embedding tensor
    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)
    
    # Generate frames for the video
    for frame_num in range(num_frames):
        # Generate random noise
        noise = torch.randn(1, 100).to(device)
        
        # Generate a fake frame using the Generator network
        with torch.no_grad():
            fake_frame = netG(noise, text_embed)
        
        # Save the generated fake frame as an image file
        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')
# usage of the generate_video function with a specific text prompt
generate_video('circle moving up-right')


当我们运行上述代码时,它会生成一个目录,其中包含生成视频的所有帧。我们需要使用一些代码将所有这些帧合并成一个视频短片。


# Define the path to your folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'

# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]
# Sort the images by name (assuming they are numbered sequentially)
image_files.sort()
# Create a list to store the frames
frames = []
# Read each image and append it to the frames list
for image_file in image_files:
  image_path = os.path.join(folder_path, image_file)
  frame = cv2.imread(image_path)
  frames.append(frame)
# Convert the frames list to a numpy array for easier processing
frames = np.array(frames)
# Define the frame rate (frames per second)
fps = 10
# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))
# Write each frame to the video
for frame in frames:
  out.write(frame)
# Release the video writer
out.release()


确保文件夹路径指向新生成视频的位置。运行此代码后,你的 AI 视频就成功创建了。让我们看看它看起来像什么。


7


我用相同数量的历时进行了多次训练。在这两种情况下,圆都是从底部出现的一半开始。好在我们的模型在两种情况下都尝试了向上向右的运动。例如,在 "尝试 1 "中,圆圈斜向上移动,然后做向上运动;而在 "尝试 2 "中,圆圈斜向上移动,同时缩小尺寸。在这两种情况下,圆都没有向左移动或完全消失,这是个好现象。


此外,本文中讨论的 GAN 架构相对简单。你可以通过整合高级技术或使用语言模型嵌入 (LLM) 而不是基本的神经网络嵌入,使其变得更加复杂。此外,调整嵌入大小等参数也会对模型的有效性产生重大影响。

文章来源:https://medium.com/gitconnected/building-an-ai-text-to-video-model-from-scratch-using-python-35b4eb4002de
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消