可解释性是人工智能模型的关键主题之一。近来复杂的人工智能往往是黑箱算法,人类很难理解人工智能为什么会产生这些结果。最近,有一篇论文《CLIP Surgery for Better Explainability with Enhancement in Open-Vocabulary Tasks》,主要介绍了 CLIP 的可解释性技术。这篇论文展示了 CLIP 强大的可解释性。因此,我将在本博客中介绍 CLIP_Surgery 的架构及其应用。
快速回顾 CLIP
CLIP 是 OpenAI 开发的改变游戏规则的人工智能之一。得益于其独特的架构,它能够实现零镜头图像分类。其架构如下所示。
CLIP 具有图像和文本编码器,用于创建图像和文本嵌入。训练数据是图像和文本对,例如一张狗的图像和文本 “狗的照片”。如果图像和文本是一对,它将利用对比预训练来对齐图像和文本嵌入,如果不是一对,则不对齐。为了直观地理解,我们来看看下面的例子。在这个例子中,我们使用了三对图像和文本(上图中的 N = 3)。
图像和文本编码器的输出嵌入维度始终为(1,512)。在本例中,图像和文本嵌入的维数分别为(3, 512)。利用嵌入式的余弦相似度,我们可以计算出相似度矩阵,如上图所示。在对比预训练中,CLIP 利用该相似性矩阵对匹配的配对(= 对角线元素)进行对齐,以获得相似性,而其他配对(= 其他元素)则获得不相似性。具体来说,论文中的伪代码过程如下:
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
在计算出图像和文本嵌入的余弦相似度后,他们应用交叉熵损失,使相似度矩阵中的对角线元素为 1,其他元素为 0。作者将这种计算方法称为对比损失。CLIP 仅通过这种对比损失进行训练。
零镜头分类的过程如下。首先,我们输入 n 个候选文本,得到维数为 (n, 512) 的嵌入。然后,我们计算目标图像嵌入和候选文本嵌入之间的相似度。最后,我们可以选择相似度最高的候选文本作为一类。是不是很简单?
程序简单直观,但我们需要用数百万张图像和文本对以及数百个 GPU 来训练 CLIP。从最初的论文来看,他们使用了 32,768 个非常大的迷你批量,在 592 个 V100 GPU 上花了 18 天时间进行训练。因此,许多公司将该模型作为基础模型,而不是从头开始训练。
CLIP Surgery 算法解析
开发 CLIP Surgery 算法主要是为了增强 CLIP 结果的可解释性。令人惊讶的是,CLIP Surgery 可以在不进行任何额外训练的情况下将标签对应的激活图可视化。由于其良好的激活图可视化能力,这项技术可以应用于分割任务的基础模型--“分割任何事物”(Segmentation Anything)。
作者对注意力层进行了彻底检查,以实现无需训练的良好可解释性。请参见下图。
左侧显示的是原始 CLIP 的注意层,右侧显示的是 CLIP 手术的注意层。他们明确指出,查询键自我注意会激活与标签相对应的语义区域。另一方面,价值-价值自我注意只能关注语义区域。这意味着什么?下图显示了查询键自我注意和价值-价值自我注意的激活图可视化。
如图所示,查询关键字自我关注除了可视化目标标签区域外,还可视化无关区域。反之,值-值自我关注则可以关注相应的目标标签区域。根据实验,查询键自关注可能会导致特征图混淆。需要注意的是,这一事实只是启发式的,并没有通过数学定理推导出来。
此外,他们还发现激活图在所有标签中都有冗余特征。请看下图。
可以看到,冗余区域出现在所有标签的相同位置。因此,他们想到了通过去除所有标签中的共同激活区域来去除冗余特征。
他们是如何做到这一点的呢?具体来说,正式的实现方法如下。
# weights to restrain influence of obvious classes on others
# (batch_size, 1, 512) @ (the number of labels, 512).T = (batch_size, 1, the number of labels)
prob = image_features[:, :1, :] @ text_features.t()
# prob has (batch_size, 1, the number of labels)
prob = (prob * 2).softmax(-1)
# w has (batch_size, 1, the number of labels)
w = prob / prob.mean(-1, keepdim=True)
# element-wise multiplied features
# b is batch_size
# n_t is the number of labels
# n_i is the number of tokens (=197)
# c is the feature dimension (=512)
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
# feats has (batch_size, n_i, n_t, c)
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
feats *= w.reshape(1, 1, n_t, 1)
# redundant_feats has (batch_size, n_i, n_t, c)
redundant_feats = feats.mean(2, keepdim=True) # along cls dim
feats = feats - redundant_feats
# sum the element-wise multiplied features as cosine similarity
# similarity has (batch_size, n_i, n_t)
similarity = feats.sum(-1)
为了更好地说明,我在代码中添加了每次计算的尺寸大小转换。现在,让我们一步一步来理解它。
第一个步骤是计算权重向量,以保持每个类别的影响力相等。首先,我们从图像嵌入中提取类标记。在转换器架构中,类别标记是标记维度中的第一个。请注意,类别标记应该包含所有其他标记的信息(如果你对 Vision Transformer 不熟悉,可以参考本博客 [5])。然后,我们计算余弦相似度,得到相似度矩阵。接着,我们将相似性矩阵的值转换为沿标签维度的概率,得到权重矩阵。
在第二部分,我们计算除冗余特征外的特征矩阵。首先,我们计算图像和文本嵌入的元素特征矩阵。直观地说,如上图所示,跨标签的激活区域在该图中将具有更高的值。因此,我们可以通过计算各标签的平均值,从特征矩阵中得到冗余特征。从原始特征矩阵中减去冗余特征后,我们就可以得到纯特征矩阵。
最后,我们通过沿特征维度对特征矩阵求和,得到相似性矩阵。
为了实现特征图的可视化,我们需要对相似性矩阵进行归一化、重塑和插值处理,使其与输入图像的大小相匹配。
可以看到,它可以捕捉到与标签相对应的语义区域。你可以感受到这种可视化是多么强大。
至此,我们已经了解了 CLIP Surgery 的详细算法。在最后一节,我们将检验它在真实世界数据中的能力及其应用。
应用: 检查真实世界数据的能力和 “什么都有可能 ”分段的点提供者
在最后一部分,我将引导大家了解 CLIP Surgery 在真实世界数据和 Segment Anything (SAM) 中的应用。让我们深入了解它们!
环境设置
第一步,你需要设置一个环境。我使用的是 ubuntu20.04、cuda11.7 和 Python3.10 环境。首先,我使用 conda 创建了虚拟环境。
conda create --name sam python==3.10 -y
conda activate sam
conda install pip
## optional: To avoid install libraries on the local environment,
## check the which pip will be used to store libraries
which pip
# I use /opt/conda/envs/sam/bin/pip in my enviornment.
接下来,你需要按照官方说明安装 Pytorch 和 torchvision。你可以根据自己的环境安装相应的版本。例如,下面的命令就是我的情况。
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
然后,你需要使用以下命令安装 SAM 资源库和模型权重。
pip install git+https://github.com/facebookresearch/segment-anything.git
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
你还需要安装 CLIP Surgery 软件仓库。
git clone https://github.com/xmed-lab/CLIP_Surgery.git
最后,你需要安装几个软件包。你可以通过 pip 安装它们,格式为 “pip install <library>”。
tqdm==4.66.5
ftfy==6.2.3
matplotlib
opencv-python
regex
现在,你已经完成了环境设置。
Flickr30k 数据集的 CLIP Surgery能力
首先,我想使用 Flickr30k 数据集[4]来检验 CLIP 手术在真实世界数据中的能力。因此,我将比较 CLIP 和 CLIP Surgery 激活图。稍后我会附上使用的代码。下图是比较结果。
正如你所看到的,普通 CLIP 无法精确检测到对象,但 CLIP Surgery 可以在对象存在时检测到与标签相对应的对象。但是,当对象不存在时,如猫和植物,CLIP Surgery 仍然会遇到问题。造成这一问题的原因之一是后处理中的最小-最大归一化。当激活图中只有不相关的区域时,最小-最大归一化可能会增强它们的值差异。为了解决这个问题,我们可以在最小归一化之前添加一个阈值。在 Flickr 数据集中,相关区域值阈值大于 0.1,这一点可以从相似性图的直方图中得到验证。结果如下所示。
有了阈值,我们就可以去除不相关的区域。阈值可能会根据数据集而改变;因此,我们应该使用直方图检查并找到该值。
Segment Anything 的点提供者
由于激活图可视化的精确性,CLIP Surgery 可以应用于 Segment Anything 的点提供器。请注意,SAM 是 Meta 于 2023 年开发的分割基础模型之一。下图显示了其架构。
SAM 的分割能力令人难以置信。但是,它并不是由带标签的分割数据集训练出来的,因此当我们要指定对象时,需要输入一些点、边界框或遮罩。正如你所猜测的那样,这类注释非常耗时。在这里,CLIP Surgery 可以帮助我们自动找到这些点。让我们看看如何在实际应用中将 CLIP Surgery 和 SAM 结合起来。
要为 SAM 生成点,我们要对激活图进行下采样,并对值进行排序,以选择相关区域。在官方实现中,他们使用维度为(7 x 7)的激活图来找到最相关的区域。当目标对象不存在时也会出现问题,因此我对原始实现稍作修改,添加了一个阈值。结果如下所示。
橙色的点代表与标签相关的点,蓝色的点代表标签的负点。正如你所看到的,它能以相当高的精度检测目标标签坐标。需要注意的是,该点的精度来自 CLIP 功能。因此,如果 CLIP 无法理解目标,就无法准确提供目标点。