Seq2seq模型的一个变种网络:Pointer Network的简单介绍

2017年09月23日 由 yuxiangyu 发表 440290 0
Pointer Network(为方便起见以下称为指针网络)是seq2seq模型的一个变种。他们不是把一个序列转换成另一个序列, 而是产生一系列指向输入序列元素的指针。最基础的用法是对可变长度序列或集合的元素进行排序。

seq2seq的基础是一个LSTM编码器加上一个LSTM解码器。在机器翻译的语境中, 最常听到的是: 用一种语言造句, 编码器把它变成一个固定大小的陈述。解码器将他转换成一个句子, 可能和之前的句子长度不同。例如, "como estas?"-两个单词-将被翻译成 "how are you?"-三个单词。

“注意力”增强时模型效果会更好。这意味着解码器在输入的前后都可以访问。就是说, 它可以从每个步骤访问编码器状态, 而不仅仅是最后一个。思考一下它怎样帮助西班牙语让形容词在名词之前: “neural network”变成 “red neuronal”

在专业术语中,“注意力”(至少是这种特定的 基于内容的注意力) 归结为加权平均值均数。简而言之,编码器状态的加权平均值转换为解码器状态。注意力只是权重的分配。

想知道更多可以访问:https://medium.com/datalogue/attention-in-keras-1892773a4f22

在指针网络中, 注意力更简单:它不考虑输入元素,而是在概率上指向它们。实际上,你得到了输入的排列。有关更多细节和公式, 请参阅论文:

https://arxiv.org/abs/1506.03134

注意, 不需要使用所有的指针。例如, 给定一段文本, 网络可以通过指向两个元素来标记摘录: 它的起始位置和结束位置。

实验


我们从顺序数字开始?换句话说,一个深入的argsort:
In [3]: np.argsort([ 10, 30, 20 ])
Out[3]: array([0, 2, 1], dtype=int64)

In [4]: np.argsort([ 40, 10, 30, 20 ])
Out[4]: array([1, 3, 2, 0], dtype=int64)

令人惊讶的是,作者在论文中没有继续进行完成任务。相反的,他们使用两个奇特的问题:旅行推销员凸包(参考README), 虽然结果是好的。但为什么不按照数字顺序呢?

Pointer Network介绍

原来,数字排序很难做到。他们在后续文件中提到了这个问题(Order Matters: Sequence to sequence for sets)。重点是顺序不能错。也就是说,我们讨论的是输入元素的顺序。作者发现,它对结果影响很大, 这不是我们想要的。因为本质上我们处理的是集合作为输入, 而不是序列。集合没有固定的顺序,所以元素是如何排列在理论上不应该影响结果。

因此, 本文介绍了一种改进的架构, 它们通过连接到另一个LSTM的前馈网络来替换LSTM编码器。这就是说,LSTM重复运行,以产生一个置换不变的嵌入给输入。解码器同样是一个指针网络。

让我们回到数字排列。较长的集合更难去排列。对于5个数字,他们报告的准确度范围是81%-94%, 具体取决于模型 (这里提到的准确度是指正确排序序列的百分比)。当处理15数字时, 这个范围变成了0%-10%。

在我们的研究中,对于五个数字,我们几乎达到了100%的准确度。请注意, 这是Keras所报告的 "分类精度", 意思是在正确位置上元素的百分比。例如, 这个例子是50%准确度,即前两个元素不动, 但最后两个被调换:
4 3 2 1 -> 3 2 0 1

对于有八元素的序列, 分类精度下降到大约33%。我们还尝试了一个更具挑战性的任务, 按它们的和对一个集合进行排序:
[1 2] [3 4] [2 3] -> 0 2 1

网络处理它就像处理简单的(un)标量数字。

我们注意到的一个意想不到的事情是, 网络倾向于重复指针, 尤其是在训练的早期。这是令人失望的:显然它不记得它不久之前的预测。
y_test: [2 0 1 4 3]
p: [2 2 2 2 2]

Pointer Network介绍

在训练的早期, 人们聚集在一起, 构想指针网络的输出。
y_test: [2 0 1 4 3]
p: [2 0 2 4 3]

同时, 训练有时会被某种准确度所困。而一个对少量数字进行训练的网络并不能概括更大的, 比如:
981,66,673
856,10,438
884,808,241

为了帮助网络使用数字, 我们添加一个 ID (1,2, 3...) 到序列的每个元素。这个假设是因为注意力是基于内容的, 也许它可以使用内容中明确编码的位置。此ID是一个数字 (train_with_positions) 或独热向量 (train_with_positions_categorical)。这看起来有点效果,但没有解决根本问题。

实验代码在GitHub可以使用。与original repo相比, 我们添加了一个数据生成脚本, 并更改了训练脚本以从生成的文件中加载数据。我们还将优化算法改成RMSPro, 因为它在处理学习率的过程中似乎收敛得很好。

数据结构


3D数组中的数据。第一个维度 (行) 是像往常一样的例子。第二个维度“列”通常是特征(属性), 但带序列的特征进入第三个维度。第二个维度由给定序列的元素组成。下面是三个序列示例, 每个都有三个元素 (步骤), 每个元素都有两个特征:
array([[[ 8,  2],
[ 3, 3],
[10, 3]],

[[ 1, 4],
[19, 12],
[ 4, 10]],

[[19, 0],
[15, 12],
[ 8, 6]],

目标是按特征的和对元素进行排序, 因此相应的目标将是:
array([[1, 0, 2],
[0, 2, 1],
[2, 0, 1],

并且,它们将被明确编码:
array([[[ 0.,  1.,  0.],
[ 1., 0., 0.],
[ 0., 0., 1.]],

[[ 1., 0., 0.],
[ 0., 0., 1.],
[ 0., 1., 0.]],

[[ 0., 0., 1.],
[ 1., 0., 0.],
[ 0., 1., 0.]],

这里有一个问题,我们一直在讨论循环网络如何处理可变长度的序列,但实际上数据是3D数组,如上所示。换句话说,序列长度是固定的。

Pointer Network介绍

处理这一问题的方法是在最大可能的序列长度上固定维度, 并用零填充未使用的位置。

但它有可能搞乱代价函数,因此我们更好地掩盖那些零, 确保他们在计算损失时被省略。Keras官方的做法似乎是embdedding layer。相关参数为mask_zero:

mask_zero: 无论输入值0是否是一个特殊的 "padding" 值, 都应该被屏蔽掉。当使用可变长度输入的循环层时这很有用。如果它为“True”,那么模型中的所有后续层都需要支持掩蔽, 否则将引发异常。如果 mask_zero设置为True, 那么作为一个序列,词汇表中不能使用索引0(input_dim应等于词汇量“+1”)。

关于实现


我们使用了一个Keras执行的指针网络。GitHub上还有一些其他的, 大部分用Tensorflow。

附录A:指针网络的实现

附录B:seq2seq的一些注意力的实现
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消