近期,MistralAI宣布推出了Pixtral 12B,这是一种新型开源多模态模型。根据论文所述,与其他开源多模态模型相比,该模型取得了最优结果。那么,Pixtral 12B与其他模型有何不同?在本篇文章中,我将首先介绍Pixtral 12B的架构,然后展示如何在Python中实现它。
Pixtral 12B架构解析
概述
Pixtral 12B是一个开源多模态语言模型,经过训练能够理解图像和文本。它能够执行多模态任务,如视觉问答(VQA)、视觉推理、光学字符识别(OCR)和图像描述。如下所示,它的性能优于其他类似规模甚至某些规模更大的开源模型。
评价指标为MM-MT-Bench和LMSys-Vision ELO。MM-MT-Bench由来自各种场景、涵盖32个核心元任务的多选视觉问题组成。LMSys-Vision ELO是一个基于ELO评分的开源评估系统,用于评估人类对多模态模型的偏好。
如你所见,与Llama-3.2 11B、LLaVa-OneVision 7B和Qwen-2-VL 7B等规模相似的模型相比,Pixtral 12B的性能表现出色。此外,它的性能甚至优于规模更大的模型,如Llama-3.2 90B和LLaVa-OneVision 72B。为何Pixtral 12B在较小规模下能取得如此出色的成果?让我们一起探究其内部机制。
Pixtral 12B架构详解
Pixtral 12B基于Transformer架构,由多模态解码器和视觉编码器组成。对于文本信息,它使用Tekken分词器,该分词器采用带有Tiktoken的字节对编码(你可以从此处查看更多信息)。另一方面,视觉编码器对图像信息进行编码,而多模态解码器则接收图像和文本标记作为输入,并生成文本。整体架构如下。
这种架构对于多模态语言模型来说是典型的。然而,由于采用了原创的视觉编码器,Pixtral 12B能够接受多种分辨率的多张图像输入。该视觉编码器对视觉Transformer进行了从头开始的训练,并做出了以下四项关键改变。
中断标记(Break Tokens)
Pixtral 12B引入了中断标记,以帮助模型理解多张图像和不同的分辨率。遵循常规的ViT(Vision Transformer)做法,它将图像分割成多个图像块。在上面的图示中,[b]代表[IMAGE BREAK](图像中断)标记,[e]代表[IMAGE END](图像结束)标记。[IMAGE BREAK]标记将插入到图像块行变化的位置,而[IMAGE END]标记将插入到图像块的最后一个位置。
前馈层中的门控(Gating in Feedforward Layers)
Pixtral 12B在注意力块中用门控前馈层[2]替换了标准的前馈层。门控是一种门控线性单元(GLU),它是一种神经网络层,定义为两个线性层的逐元素乘积,其中一个线性层经过sigmoid激活。与Transformer中通常使用的ReLU或GELU激活函数相比,它带来了质量上的提升。GLU的数学公式如下。
除了GLU之外,Pixtral 12B还在一层中使用了混合专家(Mixture of Experts,MoE)层,该层也是通过sigmoid函数进行激活的。官方实现如下所示。
class MoeLayer(nn.Module):
def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
super().__init__()
assert len(experts) > 0
self.experts = nn.ModuleList(experts)
self.gate = gate
self.args = moe_args
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# Mixture of Experts part : choose the top candidate
# We can get xW1 in the mathematical formula
gate_logits = self.gate(inputs)
weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
# GLU part
# we can get sigma(xW1)
weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
results = torch.zeros_like(inputs)
for i, expert in enumerate(self.experts):
batch_idx, nth_expert = torch.where(selected_experts == i)
# GLU part
# we can get component-wise (sigma(xW1) x xV)
results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
return results
序列打包(Sequence Packing)
Pixtral 12B将图像沿序列维度展平,以便在单个批次内高效处理图像,并将它们连接起来。在官方实现中,它看起来如下所示。
# pass images through initial convolution independently
patch_embeds_list = [self.patch_conv(img.unsqueeze(0)).squeeze(0) for img in images]
# flatten to a single sequence
patch_embeds = torch.cat([p.flatten(1).permute(1, 0) for p in patch_embeds_list], dim=0)
你可能会注意到,与其他一起输入的图像之间存在一些信息“泄漏”。Pixtral 12B构建了一个块对角掩码,以避免来自不同输入图像的信息泄漏。得益于这个注意力掩码,我们能够消除其他图像的影响。
RoPE-2D:位置编码的替代函数
Pixtral 12B用相对旋转位置编码(RoPE)替代了传统的固定或学习得到的图像块位置编码。RoPE使用旋转矩阵对绝对位置进行编码,并在自注意力公式中纳入了明确的相对位置依赖关系。其优点之一是输入长度的灵活性。我们如何实现这一功能呢?设x为d维的图像块向量(键或查询特征)。图像块向量x(i, j)表示图像中位置为(i, j)的图像块。x(i, j)的RoPE-2D变换表示为:
实际上,这个公式可以改变为下面的公式,这样我们就可以高效地计算RoPE-2D。官方实现使用的就是这个版本。公式和代码如下所示。
# calculate the rotation matrix
def precompute_freqs_cis_2d(
dim: int,
height: int,
width: int,
theta: float,
) -> torch.Tensor:
"""
freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
(height, width) position tuples
"""
# (dim / 2) frequency bases
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
h = torch.arange(height, device=freqs.device)
w = torch.arange(width, device=freqs.device)
freqs_h = torch.outer(h, freqs[::2]).float()
freqs_w = torch.outer(w, freqs[1::2]).float()
freqs_2d = torch.cat(
[
freqs_h[:, None, :].repeat(1, width, 1),
freqs_w[None, :, :].repeat(height, 1, 1),
],
dim=-1,
)
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
# RoPE-2D operation
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = freqs_cis[:, None, :]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)
除了视觉编码器之外,Pixtral 12B还具备一个视觉-语言适配器,用于将图像特征与文本特征对齐,以及一个多模态解码器。多模态解码器是基于Mistral Nemo 12B构建的,这是一个仅解码器的语言模型,在各种任务中表现出色。解码器将图像和文本视为相同,并使用因果自注意力机制来顺畅地促进多图像对话。
到目前为止,我们已经了解了Pixtral 12B的架构。接下来,我们将用Python实现它,并尝试应用于图像标题任务。
Python实现——使用Pixtral 12B的几个用例
我将向你展示如何实现Pixtral 12B。首先,我们设置环境。接下来,我们将Pixtral 12B应用于几个任务,如复杂图形推理和OCR。我将使用Pixtral 12B论文[1]中引入的MM-MT-Bench数据集进行评估。
环境设置
我使用了一个带有Python 3.11的conda环境。我在Ubuntu 20.04上进行了实验,配置了cuda 12.4和40 GB的显存。(目前还没有量化选项。)根据我的实验,运行Pixtral 12B推理代码至少需要30 GB的显存。
conda create -n mistral python=3.11 -ycreate -n mistral python=3.11 -y
conda activate mistral
接下来,我们需要通过pip安装以下几个库。
pip install transformers accelerate datasets
在这篇文章中,我将mistral-inference仓库安装为可编辑模式。你也可以选择不使用-e选项来安装它。
git clone https://github.com/mistralai/mistral-inferenceclone https://github.com/mistralai/mistral-inference
pip install -e .
你需要从hugging face-hub安装模型权重。请注意,你需要进行身份验证才能从官方仓库安装它们。你还必须按照本说明中的参考设置一个访问令牌来进行安装。设置好令牌后,请执行以下命令。
huggingface-cli login
你将会看到一个设置令牌的界面,你可以在其中输入你的令牌。
所有准备工作都已完成!现在,让我们来实现Pixtral 12B吧。
Python实现——使用Pixtral 12B的几个用例
首先,我们需要下载MM-MT-Bench数据集。我们可以通过Huggingface数据集API快速下载它。
from datasets import load_datasetimport load_dataset
ds = load_dataset("mistralai/MM-MT-Bench")
这个数据集包含92个带有问题陈述的示例,其中包括文本提示、图像和类别,以及相应的答案。文本提示部分是以JSON格式序列化的。你可以使用下面的代码提取其中的每一个示例。
# i is an index number. You can set this number freely until 91.
i = 1
# extract an image
image = ds['eval']['image'][i]
# extract a prompt
data = json.loads(ds['eval']['conversation'][i])
prompt = data[0]['content'][1]['text']
# extract a category
category = ds['eval']['category'][i]
为了准备输入,我们使用了mistral-common库为方便使用而提供的ChatCompletionRequest。请注意,如果你传递的是PIL图像作为输入,则需要使用ImageChunk类;如果你传递的是在线图像的URL,则需要使用ImageURLChunk类。
# prepare an input for PIL Image object
completion_request = ChatCompletionRequest(messages=[UserMessage(content=[ImageChunk(image=image), TextChunk(text=prompt)])])
# prepare an input for an online image
# completion_request = ChatCompletionRequest(messages=[UserMessage(content=[ImageURLChunk(image_url=url), TextChunk(text=prompt)])])
输入准备完成后,你需要将它们传递给分词器(Tokenizer)。
# encode an input using tokenizer
encoded = tokenizer.encode_chat_completion(completion_request)
images = encoded.images
tokens = encoded.tokens
你需要将分词后的图像和文本与模型和某些参数一起传递给生成函数。
# inference using Pixtral 12B
out_tokens, _ = generate([tokens], model, images=[images], max_tokens=256, temperature=0.35, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
result = tokenizer.decode(out_tokens[0])
这不是很简单吗?我们只需要几行代码就可以实现Pixtral 12B的推理。
现在,让我们来检验一下它的能力。我从数据集中挑选了三个多模态任务。请注意,输入图像的分辨率越高越好,最大分辨率为1024像素。第一个问题是关于高级线性代数定理的PDF。我们要求Pixtral 12B给出PDF中特定部分的定理。
由于我将最大令牌数设置为256,所以答案在完整答案的中间就停止了,但我们可以看到Pixtral 12B能够根据提示正确地回答问题。
第二个例子是关于图表理解的。我们询问了这幅插图的内容。
如你所见,Pixtral 12B能够准确地描述图表。我认为这个例子很有挑战性,因为文本框中有一些颜色模式,但结果却相当令人印象深刻。它可以应用于光学字符识别(OCR)任务。
最后一个例子是关于Python Django文档的。我们询问了这份文档的主题。
Pixtral 12B能够详细阅读和理解信息。它不仅能够捕捉文档内容,还能捕捉到页眉和页脚的信息。
总结
即使参数较少,Pixtral 12B也能执行各种多模态任务。尽管其架构与传统的多模态大型语言模型相似,但作者从头开始为多模态任务训练了一个特定的视觉编码器,与其他开源模型相比,它可以表现出更优越的性能。它可以应用于各种多模态任务,如视觉问题回答(VQA)和光学字符识别(OCR)。