pytorch实现seq2seq时对loss进行mask的方式
在Pytorch实现seq2seq模型中,对于一个batch中的每个序列,其长度可能不一致。对于长度不一致的序列,需要进行pad操作,使其长度一致。但是,在计算loss的时候,pad部分的贡献必须要被剔除,否则会带来噪声。
为了解决这一问题,可以使用mask技术,即使用一个mask张量对loss进行掩码,将pad部分设置为0,只计算有效部分的loss。
下面是实现seq2seq时对loss进行mask的方式的完整攻略:
1.创建mask张量
通过给定的输入序列长度,创建一个bool掩码,其中有效部分为True,pad部分为False。
其中,seq_len为每个序列的长度,pad_idx为pad的token索引,此处默认使用0进行pad。
2.计算loss时掩码
在计算loss时,将mask张量与计算得到的loss张量相乘即可实现mask。
3.示例说明
下面给出两个示例,更好地理解如何使用mask对seq2seq模型的loss进行掩码。
假设我们有如下两个序列:
- 输入序列:['I', 'love', 'you']
- 目标序列:['Ich', 'liebe', 'dich']
其中,我们使用3个token来表示输入和输出序列,对应的pad_idx为0。那么,我们需要将输入和输出序列转换为相同的长度,这里设定为5。那么,经过pad之后,就可以得到如下矩阵:
其中,1/3/2对应的是输入序列中的'I'/'love'/'you',4/5/6对应的是目标序列中的'Ich'/'liebe'/'dich'。
接下来,我们需要创建掩码张量,对于pad部分置为False,其他部分置为True。
最后,计算loss时,使用mask张量掩码:
这里,我们首先使用model计算模型输出,然后计算loss,最后使用target_mask掩码。需要注意的是,这里的target_seqs需要去掉最后的一个token,也就是'pad',以保证input_seqs和target_seqs的长度相同。