摘要:本文详细介绍了基于扩散模型构建的文本到视频生成系统,展示了在MSRV-TT和Shutterstock视频标注数据集上训练的模型输出结果。以下是模型在不同提示词下的生成示例。
本文详细介绍了基于扩散模型构建的文本到视频生成系统,展示了在MSRV-TT和Shutterstock视频标注数据集上训练的模型输出结果。以下是模型在不同提示词下的生成示例。
首先展示一些模型生成效果展示
提示词:"A person holding a camera"(训练10K步)
拿相机的人物场景
提示词:"Spaceship crossing the bridge"(训练10K步)
飞船穿过桥梁场景
提示词:"News Reporter speaking"(训练10K步)
新闻记者讲话场景
在Moving Mnist数据集上训练的模型(训练5K步)
使用训练模型生成的合成Moving Mnist数据
扩散模型是当前文本到视频生成领域的主流架构,被广泛应用于OpenAI Sora、Stable Video Diffusion等系统中。本文将从基本原理出发,实现一个基于扩散原理的架构,构建能够根据文本提示生成视频或GIF的模型系统。
3D U-Net结构:专为视频处理设计,能够有效处理随时间变化的帧序列。此U-Net架构融合了多重注意力机制:
时间注意力:负责捕捉帧间的时序关联空间注意力:处理每一帧内的区域关系这些注意力层与特殊功能模块协同工作,从视频数据中提取关键特征扩散过程原理:模型的工作机制可以简述为以下步骤:
向训练视频添加噪声直至其变为纯随机分布模型学习逆向过程,即从噪声中恢复原始信号生成阶段,从随机噪声出发,模型逐步去除噪声文本提示通过BERT转换为嵌入向量,引导UNet的去噪方向通过反复迭代,最终生成与文本语义匹配的视频内容为了更直观地理解系统架构,下面是一个简化的模型框架图:
核心文件说明:
train.py是训练模型的主脚本,负责加载配置、初始化模型组件并启动训练流程generate.py用于利用训练好的模型根据文本提示生成视频内容src目录包含模型架构、数据处理、扩散过程实现、文本处理和训练工具的核心代码成功实现本项目需要对面向对象编程(OOP)和神经网络(NN)有深入理解。熟悉PyTorch框架将有助于理解和修改代码实现。源代码地址在文章最后
环境配置首先需要克隆项目仓库并安装依赖:
git clone https://github.com/FareedKhan-dev/text2video-from-scratch cd text2video-from-scratchpip install -r requirements.txt本项目使用的主要库及其功能:
# 系统操作相关import os # 文件系统操作import yaml # 配置文件解析from pathlib import Path # 跨平台路径处理import subprocess # 执行系统命令import zipfile # ZIP文件处理# 数据处理相关import pandas as pd # 结构化数据处理from tqdm import tqdm # 进度显示# 图像、视频处理from PIL import Image # 图像处理from moviepy.editor import VideoFileClip # 视频编辑from datasets import load_dataset # 数据集加载# PyTorch核心组件import torch # 深度学习框架from torch import nn, einsum # 神经网络构建与张量运算from torch.nn import functional as F # 函数式APIfrom torch.utils import data # 数据加载工具# 张量操作工具from einops import rearrange # 张量重排from einops_exts import rearrange_many, check_shape # 扩展张量操作from rotary_embedding_torch import RotaryEmbedding # 旋转位置编码# 自然语言处理from transformers import BertModel, BertTokenizer # BERT模型与分词器# 其他工具import copy # 对象复制from torch.optim import Adam # 优化器from torch.cuda.amp import autocast, GradScaler # 混合精度训练import math # 数学函数import colorsys # 颜色空间转换训练数据准备为了构建高质量的文本到视频生成模型,我们需要多样化的带标注视频数据集。MSR-VTT(Microsoft Research Video to Text)是理想选择,它包含来自20个类别的10,000个视频剪辑,每个剪辑都有对应的英文描述标注。以下是数据获取与处理的实现方法:
# 使用Kaggle API下载数据集def download_kaggle_dataset(dataset_name: str, download_dir: str) -> None: # 确保目录存在 Path(download_dir).mkdir(parents=True, exist_ok=True) # 调用Kaggle命令行工具下载数据集command = f"kaggle datasets download {dataset_name} -p {download_dir}" subprocess.run(command, shell=True, check=True) # 解压下载的数据文件def unzip_file(zip_path: str, extract_dir: str) -> None: with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(extract_dir) # 可视化数据集中的随机视频样本def visualize_random_videos(videos_dir: str, num_videos: int = 8) -> None: # 获取所有MP4文件video_files = [f for f in os.listdir(videos_dir) if f.endswith('.MP4')] # 随机抽样指定数量的视频random_videos = random.sample(video_files, num_videos) # 创建显示网格fig, axes = plt.subplots(2, 4, figsize=(12, 6)) axes = axes.ravel # 处理并显示每个视频的首帧for i, video_file in enumerate(random_videos): video_path = os.path.join(videos_dir, video_file) # 加载视频并提取前2秒帧clip = VideoFileClip(video_path).subclip(0, 2) # 获取首帧并显示frame = clip.get_frame(0) axes[i].imshow(frame) axes[i].axis('off') axes[i].set_title(f"Video {i+1}") plt.tight_layout plt.show实际执行数据下载与可视化:
# 数据集下载与解压kaggle_dataset_name = 'vishnutheepb/msrvtt' download_dir = './msrvtt_data' unzip_dir = './msrvtt_data/msrvtt' download_kaggle_dataset(kaggle_dataset_name, download_dir) zip_file_path = os.path.join(download_dir, 'msrvtt.zip') unzip_file(zip_file_path, unzip_dir) # 可视化8个随机视频样本videos_dir = os.path.join(unzip_dir, 'TrainValVideo') visualize_random_videos(videos_dir)MSRVTT数据集样本展示
为便于训练,我们需要将视频转换为更高效的格式。我们选择将MP4文件转换为GIF格式并创建对应的文本描述文件:
# 创建训练数据集,转换视频格式并准备标注def create_training_data(videos_dir: str, output_dir: str, size=(64, 64), duration=2) -> None: Path(output_dir).mkdir(parents=True, exist_ok=True) video_files = [f for f in os.listdir(videos_dir) if f.endswith('.mp4')] for video_file in video_files: video_path = os.path.join(videos_dir, video_file) base_name = os.path.splitext(video_file)[0] gif_path = os.path.join(output_dir, f"{base_name}.gif") txt_path = os.path.join(output_dir, f"{base_name}.txt") # 视频转GIF处理clip = VideoFileClip(video_path).subclip(0, duration) clip = clip.resize(size) clip.write_gif(gif_path, program='ffmpeg') # 创建文本描述文件with open(txt_path, "w") as txt_file: txt_file.write(f"{base_name}") print(f"已处理: {video_file} -> {base_name}.gif and {base_name}.txt") # 执行数据转换videos_dir = "./msrvtt_data/msrvtt/TrainValVideo" output_dir = "./training_data" create_training_data(videos_dir, output_dir)经过处理,我们的训练数据将采用以下结构:
training_data/ |── video1.gif |── video1.txt |── video2.gif |── video2.txt ...此格式便于模型训练过程中高效地加载和处理数据。
为了提高代码的可维护性和模块化程度,我们定义了一系列辅助函数,用于处理常见任务如参数检查、数据迭代和掩码生成等:
# 检查变量是否存在(非None)def exists(x: Union[None, object]) -> bool: return x is not None # 空操作函数,接受任意参数但不执行任何操作def noop(*args, **kwargs) -> None: pass # 检查整数是否为奇数def is_odd(n: int) -> bool: return (n % 2) == 1 # 返回值或默认值def default(val: Union[None, object], d: Union[object, Callable[, object]]) -> object: if exists(val): return val return d if callable(d) else d # 数据加载器循环迭代器def cycle(dl: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader: while True: for data in dl: yield data # 将总数划分为指定大小的组def num_to_groups(num: int, divisor: int) -> List[int]: groups = num // divisor remainder = num % divisor arr = [divisor] * groups if remainder > 0: arr.append(remainder) return arr # 生成概率掩码def prob_mask_like(shape: Tuple[int, ...], prob: float, device: torch.device) -> torch.Tensor: if prob == 1: return torch.ones(shape, device=device, dtype=torch.bool) elif prob == 0: return torch.zeros(shape, device=device, dtype=torch.bool) else: return torch.zeros(shape, device=device).float.uniform_(0, 1) bool: if not isinstance(x, (list, tuple)): return False return all([type(el) == str for el in x])注意力机制是扩散视频生成模型的核心组件,使模型能够有选择地关注输入的重要部分,无论是空间区域、时间帧还是文本信息。以下是注意力模块的实现:
class EinopsToAndFrom(nn.Module): def __init__(self, from_einops: str, to_einops: str, fn: Callable[[torch.Tensor], torch.Tensor]) -> None: super.__init__ self.from_einops = from_einops # 输入形状描述self.to_einops = to_einops # 目标形状描述self.fn = fn # 要应用的变换函数def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: shape = x.shape # 创建重建原始形状所需的参数字典reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))# 重排张量到目标形状x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') # 应用变换函数x = self.fn(x, **kwargs)# 恢复原始形状x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) return x class Attention(nn.Module): def __init__( self, dim: int, heads: int = 4, dim_head: int = 32, rotary_emb: Optional[nn.Module] = None ) -> None: super.__init__ self.scale = dim_head ** -0.5 # 注意力缩放因子self.heads = heads hidden_dim = dim_head * heads self.rotary_emb = rotary_emb # 可选的旋转位置编码self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) # QKV投影self.to_out = nn.Linear(hidden_dim, dim, bias=False) # 输出投影def forward( self, x: torch.Tensor, pos_bias: Optional[torch.Tensor] = None, focus_present_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: n, device = x.shape[-2], x.device qkv = self.to_qkv(x).chunk(3, dim=-1) # 分离QKV向量# 如果focus_present_mask全部激活,直接返回值向量if exists(focus_present_mask) and focus_present_mask.all: values = qkv[-1] return self.to_out(values) # 重排QKV用于多头处理q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) q = q * self.scale # 缩放查询向量# 应用旋转位置编码(如提供)if exists(self.rotary_emb): q = self.rotary_emb.rotate_queries_or_keys(q) k = self.rotary_emb.rotate_queries_or_keys(k) # 计算注意力分数sim = einsum('... h i d, ... h j d -> ... h i j', q, k) # 应用位置偏置(如提供)if pos_bias is not None: sim = sim + pos_bias # 处理焦点掩码if focus_present_mask is not None and not (~focus_present_mask).all: attend_all_mask = torch.ones((n, n), device=device, dtype=torch.bool) attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) mask = torch.where( rearrange(focus_present_mask, 'b -> b 1 1 1 1'), rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), ) sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) # 掩码应用# 数值稳定性优化sim = sim - sim.amax(dim=-1, keepdim=True).detach # 计算注意力权重attn = sim.softmax(dim=-1) # 加权汇总值向量out = einsum('... h i j, ... h j d -> ... h i d', attn, v) # 重排输出out = rearrange(out, '... h n d -> ... n (h d)') return self.to_out(out) # 应用输出投影EinopsToAndFrom 类提供了张量形状转换的灵活机制,使模型能够在不同表示之间高效切换。而 Attention 类实现了多头注意力机制的核心算法,支持位置编码、掩码操作和数值稳定性优化,这些都是高质量视频生成的关键要素。
在深度学习视频生成架构中,基础构建模块是整个网络功能的关键组件。我们实现的U-Net架构采用了层次化结构,由多个专用模块组成,每个模块都具有特定的数据转换功能。
视频生成架构的基础构建模块包括三个关键组件:
标准卷积块 (Block):作为神经网络的基础单元,实现了输入特征的基本转换处理。该模块采用3D卷积层、层归一化和SiLU激活函数的组合,支持通过scale-shift操作进行条件处理,使其能够根据不同的条件输入(如时间嵌入)动态调整特征响应。
残差块 (ResnetBlock):通过引入跳跃连接增强了网络训练稳定性,有效缓解了深层网络中的梯度消失问题。此模块包含两个串联的标准卷积块,并支持时间条件嵌入,通过MLP网络将时间信息转换为特征调制信号,从而使模型能够根据时间步信息调整特征表示。
空间线性注意力模块 (SpatialLinearAttention):在视频帧的空间维度上实现高效注意力计算,使模型能够捕获图像内的长距离依赖关系。该模块采用了多头注意力设计,分别对查询和键应用softmax操作,实现了对输入特征的自适应加权处理。
这些基础模块具有不同但相互补充的功能:Block模块处理局部特征转换,ResnetBlock增强训练稳定性并支持时间条件,而SpatialLinearAttention则捕获远距离空间依赖关系。通过组合这些模块,模型能够有效处理视频数据的时空复杂性。
视频生成架构需要多种辅助组件来增强模型功能和训练稳定性。这些组件包括参数平滑机制、归一化层和位置编码等,为模型提供了必要的数据处理能力。
视频生成架构的通用组件提供了多种功能增强机制:
指数移动平均 (EMA):实现了模型参数的平滑更新策略,有效降低训练过程中的参数波动。通过对历史权重和当前权重进行加权平均,提高了模型的泛化能力和推理稳定性。这在视频生成任务中尤为重要,因为它帮助模型产生更连贯的时序输出。
残差连接包装器 (Residual):为任意函数模块添加跳跃连接,促进梯度在深层网络中的高效传播。这种设计不仅简化了残差路径的实现,还提高了整个架构的训练效率。
正弦位置嵌入 (SinusoidalPosEmb):为一维输入(如时间步)生成周期性位置编码,使模型能够区分不同的时间位置。该组件采用正弦和余弦函数的组合,创建了一种具有良好插值特性的嵌入表示。
上采样和下采样函数:分别通过3D转置卷积和标准3D卷积实现特征图的分辨率变换,维持通道数不变的同时改变空间维度。这些操作在U-Net结构的编码器和解码器部分扮演关键角色。
自定义归一化层:实现了多种特化的归一化策略,包括:
LayerNorm:实现通道维度上的特征归一化,带有可学习的缩放参数RMSNorm:提供基于均方根的归一化,计算复杂度更低PreNorm:在应用任意函数前执行归一化预处理,稳定深层网络中的信号传播这些通用组件协同工作,确保了视频生成模型的训练稳定性和生成质量,为处理高维时空数据提供了必要的计算基础。
相对位置编码在处理序列数据时至关重要,它使模型能够理解和利用输入元素之间的相对位置关系,而不依赖于绝对位置信息。这在视频生成等需要捕捉时空依赖关系的任务中尤为关键。
RelativePositionBias实现了一种高效的相对位置编码机制,通过桶化处理将相对位置映射到可学习的嵌入空间,具有以下特点:
自适应位置表示:与固定的位置嵌入不同,相对位置偏置学习不同位置元素之间的关系,而非绝对位置,这使模型在泛化到不同长度序列时更加灵活。
非线性桶化机制:通过_relative_position_bucket方法实现了一种智能的相对位置量化策略,对近距离位置使用线性分桶(提供精确表示),对远距离位置使用对数分桶(提供高效表示)。这种设计兼顾了计算效率和表示能力。
方向感知:通过区分正负相对距离,模型能够感知序列元素的相对方向,对于捕捉视频中的时序关系至关重要。
多头兼容设计:通过为每个注意力头生成独立的位置偏置,允许不同注意力头专注于不同类型的位置模式和依赖关系。
可扩展性:通过num_buckets和max_distance参数,该机制可以灵活调整以适应不同长度和复杂度的序列数据。
此位置编码组件通过为自注意力层提供精确的位置信息,使视频生成模型能够更好地理解时间和空间结构,在保持计算效率的同时提高生成质量。
# 将通道数映射到相应图像模式的字典 CHANNELS_TO_MODE = { 1: 'L', # 1通道对应灰度模式('L'模式) 3: 'RGB', # 3通道对应RGB颜色模式 4: 'RGBA' # 4通道对应RGBA颜色模式(带透明度) } # 生成器函数,用于从多帧图像(如GIF)中提取所有图像 def seek_all_images(img: Image.Image, channels: int = 3): # 确保指定的通道数有效 assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' # 获取通道数对应的模式 mode = CHANNELS_TO_MODE[channels] i = 0 while True: try: # 寻找图像中的第i帧 img.seek(i) # 转换图像帧为所需模式并返回 yield img.convert(mode) except EOFError: # 帧结束(EOF),跳出循环 break i += 1 # 将视频张量转换为GIF并保存到指定路径的函数 def video_tensor_to_gif( tensor: torch.Tensor, path: str, duration: int = 120, loop: int = 0, optimize: bool = True ): # 将视频张量中的每一帧转换为PIL图像 images = map(T.ToPILImage, tensor.unbind(dim=1)) # 解包第一张图像和其余图像 first_img, *rest_imgs = images # 保存GIF并指定参数 first_img.save( path, save_all=True, # 将所有帧保存为GIF的一部分 append_images=rest_imgs, # 将其他帧附加到GIF中 duration=duration, # 设置每帧的持续时间(毫秒) loop=loop, # 设置GIF的循环次数(0表示无限循环) optimize=optimize # 启用GIF文件优化 ) # 返回图像列表作为结果 return images # 将GIF转换为张量(帧序列)的函数 def gif_to_tensor( path: str, channels: int = 3, transform: T.Compose = T.ToTensor ) -> torch.Tensor: # 从给定路径打开GIF图像 img = Image.open(path) # 将GIF中的所有帧转换为张量,应用转换 tensors = tuple(map(transform, seek_all_images(img, channels=channels))) # 沿帧维度将张量堆叠成单个张量 return torch.stack(tensors, dim=1) # 恒等函数:返回不变的输入张量 def identity(t, *args, **kwargs): return t # 将图像张量归一化到[-1, 1]范围的函数 def normalize_img(t: torch.Tensor) -> torch.Tensor: # 通过将张量值从[0, 1]缩放到[-1, 1]来归一化 return t * 2 - 1 # 将图像张量反归一化回[0, 1]范围的函数 def unnormalize_img(t: torch.Tensor) -> torch.Tensor: # 通过将张量值从[-1, 1]缩放到[0, 1]来反归一化 return (t + 1) * 0.5 # 确保张量具有指定帧数的函数 def cast_num_frames(t: torch.Tensor, *, frames: int) -> torch.Tensor: # 获取张量中当前的帧数 f = t.shape[1] if f == frames: # 如果帧数已经符合要求,则返回不变的张量 return t if f > frames: # 如果帧数超过需要,则截取张量到所需帧数 return t[:, :frames] # 如果帧数不足,则用零填充张量(无新帧) return torch.nn.functional.padding(t, (0, 0, 0, 0, 0, frames - f))上述工具函数集实现了对视频和帧序列数据的多种处理能力。seek_all_images函数提供了从多帧图像(如GIF)中逐帧提取单帧的机制,而video_tensor_to_gif允许将模型生成的视频张量转换为标准GIF格式并保存至存储设备。与此相对应,gif_to_tensor则实现了从GIF文件到PyTorch张量的转换,便于后续的神经网络处理。
此外,辅助函数如normalize_img和unnormalize_img提供了图像值域在标准化区间([-1,1])与图像标准区间([0,1])之间的双向转换,这是深度学习模型处理图像数据的常见需求。cast_num_frames函数则灵活处理视频帧数调整,通过截取或填充操作确保视频序列符合模型的输入要求。
为了支持扩散模型的训练过程,我们实现了专用的数据集类,用于高效处理视频数据与对应的文本描述。该类确保数据以PyTorch可用的批次形式加载,同时实现了必要的预处理转换。
# 用于处理GIF或视频文件的自定义数据集类 class Dataset(data.Dataset): # 使用所需参数初始化数据集 def __init__( self, folder: str, # 存储数据集的文件夹路径 image_size: int, # 每个图像调整到的大小 channels: int = 3, # 颜色通道数(默认为3,RGB) num_frames: int = 16, # 每个视频提取的帧数(默认为16) horizontal_flip: bool = False, # 是否应用水平翻转增强 force_num_frames: bool = True, # 是否强制视频张量具有确切的`num_frames`帧数 exts: List[str] = ['gif'] # 要查找的文件扩展名列表(默认为['gif']) ) -> None: # 调用父构造函数(来自PyTorch的Dataset) super.__init__ # 初始化数据集属性 self.folder = folder self.image_size = image_size self.channels = channels # 获取文件夹(及子文件夹)中与给定扩展名匹配的所有文件路径 self.paths = [ p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') ] # 定义必要时用于转换帧数的函数 # 如果`force_num_frames`为True,我们应用`cast_num_frames`函数,否则使用恒等函数 self.cast_num_frames_fn = partial(cast_num_frames, frames=num_frames) if force_num_frames else identity # 定义要应用于每个图像的转换(调整大小、随机翻转、裁剪和转换为张量) self.transform = T.Compose([ T.Resize(image_size), # 将图像调整到目标大小 T.RandomHorizontalFlip if horizontal_flip else T.Lambda(identity), # 如果指定则应用随机水平翻转 T.CenterCrop(image_size), # 中心裁剪图像到目标大小 T.ToTensor # 将图像转换为PyTorch张量 ]) # 返回数据集中样本的总数 def __len__(self) -> int: return len(self.paths) # 通过索引获取特定样本(图像及其对应文本,如果有) def __getitem__(self, index: int) -> Tuple[torch.Tensor, Optional[str]]: # 获取给定索引的样本文件路径 path = self.paths[index] # 使用`gif_to_tensor`函数将GIF(或视频)转换为张量 # 应用先前定义的转换 tensor = gif_to_tensor(path, self.channels, transform=self.transform) # 将张量转换为具有正确帧数(如果需要) tensor = self.cast_num_frames_fn(tensor) # 检查此图像是否有对应的文本文件(相同名称,.txt扩展名) text_path = path.with_suffix(".txt") if text_path.exists: # 如果文本文件存在,读取其内容 with open(text_path, 'r') as f: text = f.read # 返回张量和文件中的文本 return tensor, text else: # 如果没有文本文件,返回张量,文本为`None` return tensor, NoneDataset类提供了完整的视频数据加载解决方案,通过初始化时接收的参数灵活配置数据处理行为。它在实例化过程中扫描指定文件夹中符合条件的视频文件,并建立转换管道用于图像处理。该类支持数据集大小查询和基于索引的数据获取,核心功能包括:
通过__init__方法配置数据源位置、图像尺寸、通道数和预处理参数利用__len__方法返回数据集的样本总数通过__getitem__方法检索指定索引的视频数据及其对应文本描述此数据集类支持数据增强(如水平翻转),帧数调整,并处理视频与文本的自动关联匹配,为模型训练提供高质量的数据流。
高斯扩散过程是本系统生成视频的核心机制,负责实现视频的逐步去噪生成。该过程模拟了从随机噪声逐渐恢复有意义视频内容的过程,通过深度神经网络指导每一步去噪。
高斯扩散原理示意图
# 根据时间步从张量中提取值的辅助函数 def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: b, *_ = t.shape # 获取批量大小 out = a.gather(-1, t) # 根据时间步提取值 return out.reshape(b, *((1,) * (len(x_shape) - 1))) # 重塑以匹配输入形状 # 为beta创建余弦调度的函数 def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor: steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float64) # 创建时间网格 alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 # 余弦函数 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] # 归一化 betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) # 计算beta值 return torch.clip(betas, 0, 0.9999) # 确保beta值保持在范围内 # 高斯扩散模型的主类 class GaussianDiffusion(nn.Module): def __init__(self, denoise_fn: nn.Module, *, image_size: int, num_frames: int, timesteps: int = 1000): super.__init__ self.denoise_fn = denoise_fn self.image_size = image_size self.num_frames = num_frames betas = cosine_beta_schedule(timesteps) # 获取beta调度 # 初始化模型计算的各种张量 alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) timesteps, = betas.shape self.num_timesteps = int(timesteps) # 注册缓冲区(不通过梯度下降更新的张量) register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) register_buffer('betas', betas) register_buffer('alphas_cumprod', alphas_cumprod) register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) # 更多初始化,用于各种系数(用于计算后验和前向过程) register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) # 计算q分布的均值、方差和对数方差的函数 def q_mean_variance(self, x_start: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = extract(1. - self.alphas_cumprod, t, x_start.shape) log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance # 从噪声数据预测图像起点的函数 def predict_start_from_noise(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) # 计算后验分布的函数 def q_posterior(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: posterior_mean = ( extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped # 使用模型预测进行去噪的函数 def p_mean_variance(self, x: torch.Tensor, t: torch.Tensor, clip_denoised: bool) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: x_recon = self.predict_start_from_noise(x, t=t, noise=self.denoise_fn(x)) if clip_denoised: # 如果需要,对去噪后的图像进行裁剪 x_recon = x_recon.clamp(-1., 1.) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance # 单个去噪步骤的函数 @torch.inference_mode def p_sample(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=True) noise = torch.randn_like(x) # 添加随机噪声 return model_mean + noise * (0.5 * model_log_variance).exp # 返回去噪后的图像 # 生成样本的函数(整个循环) @torch.inference_mode def p_sample_loop(self, shape: torch.Size) -> torch.Tensor: img = torch.randn(shape, device=self.device) # 从随机噪声开始 for t in reversed(range(self.num_timesteps)): # 迭代去噪 img = self.p_sample(img, t) return (img + 1) * 0.5 # 返回适当范围内的最终图像 # 生成一批样本的函数 @torch.inference_mode def sample(self, batch_size: int = 16) -> torch.Tensor: return self.p_sample_loop((batch_size, self.channels, self.num_frames, self.image_size, self.image_size)) # 计算噪声图像和去噪图像之间的损失(如L1损失)的函数 def p_losses(self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor = None) -> torch.Tensor: x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # 向图像添加噪声 x_recon = self.denoise_fn(x_noisy, t) # 使用模型去噪 # 计算噪声输出和去噪输出之间的损失 return F.l1_loss(noise, x_recon) # 模型的前向传递 def forward(self, x: torch.Tensor) -> torch.Tensor: t = torch.randint(0, self.num_timesteps, (x.shape[0],)) # 随机时间步 return self.p_losses(x, t) # 计算损失GaussianDiffusion类实现了扩散过程的核心机制,它基于指定的去噪网络构建完整的视频生成流程。该类在初始化时计算扩散过程所需的各种参数,包括基于余弦调度的噪声水平控制。主要组件包括:
前向扩散过程:通过q_sample方法向原始视频添加噪声反向去噪过程:通过p_sample和p_sample_loop方法从纯噪声逐步恢复视频内容损失计算:p_losses方法衡量预测噪声与实际噪声之间的差异条件生成:sample方法支持基于文本条件的视频生成扩散过程的精妙之处在于通过数学建模将视频生成分解为多步有监督去噪任务,每一步都基于神经网络对噪声分布的准确预测,最终实现从纯噪声到高质量视频的演化。
文本处理模块为了实现文本条件引导的视频生成,我们需要将文本转换为神经网络可处理的向量表示。下面实现了基于BERT的文本编码系统:
文本处理模块提供了从自然语言文本到语义向量的转换能力,是文本条件视频生成的关键组件。该模块采用预训练的BERT模型将文本转换为上下文感知的嵌入向量,实现了以下核心功能:
惰性加载机制:通过全局变量和初始化检查,确保BERT模型和分词器只加载一次文本分词:将原始文本转换为BERT可处理的标记ID序列向量嵌入提取:从BERT模型获取文本的语义表示,支持多种嵌入方式句子级表示:使用[CLS]标记作为整个文本的表示平均池化表示:对所有非填充标记的隐藏状态进行平均这些文本嵌入随后被注入到扩散模型中,引导视频生成过程朝向与文本语义一致的方向发展,实现"文本到视频"的精确转换。
训练器设计训练器模块封装了模型训练的完整流程,包括数据加载、优化、模型保存等关键环节。下面是训练器的核心实现:
class Trainer: def __init__(self, diffusion_model: nn.Module, folder: str, *, ema_decay: float = 0.995, train_batch_size: int = 32, train_lr: float = 1e-4, train_num_steps: int = 100000, gradient_accumulate_every: int = 2, amp: bool = False, save_model_every: int = 1000, results_folder: str = './results'): # 初始化训练器、数据集、优化器和其他配置 self.model = diffusion_model # 扩散模型 self.ema = EMA(ema_decay) # 用于平均权重的EMA模型 self.ema_model = copy.deepcopy(self.model) # EMA的副本 self.batch_size = train_batch_size # 批量大小 self.train_num_steps = train_num_steps # 总训练步数 self.ds = Dataset(folder, image_size=diffusion_model.image_size) # 视频数据集 self.dl = cycle(torch.utils.data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True)) # 数据加载器 self.opt = Adam(diffusion_model.parameters, lr=train_lr) # 优化器 self.step = 0 # 步数计数器 self.amp = amp # 混合精度标志 self.scaler = GradScaler(enabled=amp) # 用于混合精度的缩放器 self.results_folder = Path(results_folder) # 保存结果的文件夹 self.results_folder.mkdir(exist_ok=True, parents=True) # 如果不存在则创建结果文件夹 def reset_parameters(self): # 重置EMA模型以匹配模型的参数 self.ema_model.load_state_dict(self.model.state_dict) def step_ema(self): # 如果训练步数超过阈值,则更新EMA模型 if self.step >= 2000: # 2000步后开始更新EMA self.ema.update_model_average(self.ema_model, self.model) def save(self, milestone: int): # 在里程碑保存模型、EMA模型和优化器状态 torch.save({'step': self.step, 'model': self.model.state_dict, 'ema': self.ema_model.state_dict, 'scaler': self.scaler.state_dict}, self.results_folder / f'model-{milestone}.pt') def load(self, milestone: int): # 从检查点加载模型 data = torch.load(self.results_folder / f'model-{milestone}.pt') self.step = data['step'] self.model.load_state_dict(data['model']) self.ema_model.load_state_dict(data['ema']) self.scaler.load_state_dict(data['scaler']) def train(self, log_fn: Callable[[dict], None] = noop): # 训练循环 while self.stepTrainer类提供了扩散模型训练的完整框架,集成了现代深度学习训练技术,包括:
模型权重管理:通过指数移动平均(EMA)机制维护模型参数的平滑版本优化控制:支持混合精度训练、梯度累积和梯度裁剪,提高训练效率和稳定性检查点机制:定期保存模型状态,支持训练中断后的恢复数据流管理:高效处理视频数据与文本描述的批次加载训练器类的设计充分考虑了扩散模型训练的特殊需求,通过配置参数如批量大小、学习率和训练步数等,可以灵活适应不同的训练场景和计算资源。整体训练流程由train方法协调,它实现了数据加载、前向传播、损失计算、反向传播和参数更新的完整循环。
配置系统为确保模型训练和推理过程的可复现性和灵活性,我们实现了配置系统来管理所有相关参数:
# 内容:用于使用文本到视频扩散训练TinySora模型的默认配置文件 training_data_dir: "./training_data" # 包含训练数据(文本和视频帧)的目录 model: dim: 64 # 模型的维度(嵌入大小) use_bert_text_cond: True # 为模型输入启用基于BERT的文本条件 dim_mults: [1, 2, 4, 8] # 每个模型块的缩放因子(增加深度) init_dim: null # 初始维度,未指定(默认为None) init_kernel_size: 7 # 初始层的内核大小(通常用于卷积层) use_sparse_linear_attn: True # 为效率启用稀疏注意力机制 block_type: "basic" # 模型块的类型(例如,'basic'、'resnet'等) diffusion: image_size: 32 # 视频帧的高度和宽度 num_frames: 5 # 视频中的帧数(序列长度) timesteps: 10 # 训练期间使用的扩散时间步数 loss_type: "l1" # 优化的损失函数('l1'表示L1损失) use_dynamic_thres: False # 是否在训练期间使用动态阈值 dynamic_thres_percentile: 0.9 # 用于动态阈值的阈值百分比 trainer: ema_decay: 0.995 # 模型权重的指数移动平均衰减率 train_batch_size: 2 # 训练期间每批的样本数 train_lr: 0.0001 # 训练学习率 train_num_steps: 10000 # 总训练步数(轮次) gradient_accumulate_every: 1 # 梯度累积频率(1表示不累积) amp: False # 是否使用自动混合精度进行训练(默认:False) step_start_ema: 2000 # 开始应用EMA平滑的步骤 update_ema_every: 10 # 更新EMA权重的频率(每10步) save_model_every: 10 # 每10步保存一次模型 results_folder: "./saved_models" # 保存结果(模型、样本)的文件夹 num_sample_rows: 4 # 采样期间显示的行数(可视化) max_grad_norm: null # 用于裁剪的最大梯度范数(null表示不裁剪)配置系统采用YAML格式,将参数组织为三个主要部分:
模型配置:定义了架构规模和特性,如维度、注意力机制类型和文本条件使用扩散过程配置:设置了图像尺寸、帧数和扩散步数等关键参数训练器配置:管理训练流程的各方面,包括批量大小、学习率和保存频率这种参数化设计使研究者能够在不修改代码的情况下轻松调整模型行为,便于进行实验和比较不同配置的效果。
模型初始化与训练完成所有组件的定义后,我们可以初始化模型并启动训练过程:
# 使用模型的配置参数初始化3D U-Net模型。 # 此模型被移动到GPU(cuda)。 model = Unet3D(**config['model']).cuda # 以U-Net模型作为去噪函数初始化GaussianDiffusion模型。 # 从`config['diffusion']`加载扩散过程的其他配置参数。 # 模型被移动到GPU(cuda)。 diffusion = GaussianDiffusion( denoise_fn = model, # 该模型将用于在扩散过程中去除噪声图像。 **config['diffusion'] # 其他扩散设置,如时间步、噪声调度等。 ).cuda # 使用扩散模型、训练配置和包含训练数据的文件夹初始化Trainer类。 # 这也被移动到GPU。 trainer = Trainer( diffusion_model = diffusion, # 要训练的扩散模型。 **config['trainer'], # 训练过程的配置设置(例如,学习率、批量大小)。 folder = config['training_data_dir'] # 存储训练数据的目录。 ) # 开始训练 trainer.train训练过程会在每次迭代后显示损失值,输出类似于:
0: 0.9512512 1: 0.5235211 ...训练完成后,模型权重将保存在配置的results_folder目录中,随时可用于视频生成。
视频生成实现要使用训练好的模型生成视频,我们需要实现以下辅助函数:
def generate_video(diffusion: GaussianDiffusion, text: str, batch_size: int, cond_scale: float) -> torch.Tensor: """使用训练好的扩散模型生成视频。""" with torch.no_grad: video = diffusion.sample(cond=[text], batch_size=batch_size, cond_scale=cond_scale) return video def save_video_as_gif_pil(video_tensor: torch.Tensor, output_path: str) -> None: video_np = (video_tensor.squeeze(0).permute(1, 2, 3, 0).cpu.numpy * 255).astype(np.uint8) frames = [Image.fromarray(frame) for frame in video_np] frames[0].save(output_path, save_all=True, append_images=frames[1:], duration=100, loop=0) print(f"已保存GIF:{output_path}")以下是使用训练好的模型生成视频的完整流程:
# 对于推理,我们应该加载预训练模型 DEFAULT_MODEL_PATH = "./saved_models" DEFAULT_OUTPUT_DIR = "./results" # 找到最新的模型检查点 model_path = DEFAULT_MODEL_PATH if os.path.isdir(model_path): checkpoint_files = [f for f in os.listdir(model_path) if f.endswith(".pt")] if not checkpoint_files: raise FileNotFoundError(f"在{model_path}中未找到模型检查点") checkpoint_files.sort model_path = os.path.join(model_path, checkpoint_files[-1]) print('从路径加载模型:', model_path) trainer.load(milestone=-1) # 加载最新模型 # 生成视频 text_prompt = "News Reporter talking" batch_size = 1 cond_scale = 2.0 generated_video = generate_video(diffusion, text_prompt, batch_size, cond_scale) # 保存视频 gif_filename = sanitize_filename(text_prompt) + ".gif" output_path = os.path.join(DEFAULT_OUTPUT_DIR, gif_filename) # 如果输出目录不存在,则创建 Path(DEFAULT_OUTPUT_DIR).mkdir(parents=True, exist_ok=True) save_video_as_gif_pil(generated_video, output_path)模型生成的视频示例(经过10K轮训练):
"新闻记者讲话"提示生成的视频
训练10,000轮后的多种模型输出结果
未来研究方向本项目可通过以下方向进一步拓展和优化:
数据多样化:除MSR-VTT外,可使用COCO或ActivityNet等数据集进行训练,或针对特定应用场景收集定制数据架构改进:增加时间分辨率,生成更长的视频序列提高空间分辨率,生成更高清的视频内容训练策略优化:实现渐进式训练,先学习低分辨率再提升到高分辨率探索不同的噪声调度方案添加姿态引导或布局控制结合图像和文本作为混合条件来源:deephub