程序员求职经验分享与学习资料整理平台

网站首页 > 文章精选 正文

Transformer 模型中的掩码(Mask)机制解析

balukai 2025-08-03 04:06:48 文章精选 5 ℃

Mask机制几乎贯穿了Transformer架构的始终,若不能首先将mask机制交代清楚,就难以对Transformer进行连贯的阐述。因此,决定将mask机制的介绍放在最前面,如果一开始难以理解,可以结合后文中的整体架构再回来理解mask机制。

Transformer 模型中的掩码(Mask)机制主要用于控制注意力计算中的信息流动,确保模型在处理序列数据时遵循特定约束(如忽略无效信息或防止未来信息泄露)。以下是几种核心掩码机制及其作用:


1. Padding Mask(填充掩码)

  • 作用:处理变长序列的批量训练问题。当序列长度不一致时,需用特殊符号(如 [PAD])填充较短序列,而 Padding Mask 会标记这些填充位置,使模型在计算注意力时忽略它们
  • 实现方式:生成与输入序列同形状的布尔矩阵,填充位置为 False(或 0),有效位置为 True(或 1)。在计算注意力得分(Q·K)后,将填充位置替换为极小的负值(如 -1e9),使 Softmax 后权重趋近于 0。
  • 应用场景所有注意力层(编码器和解码器均需使用)。

2. Sequence Mask(序列掩码 / Look-Ahead Mask / Causal Mask)

  • 作用:防止解码器在预测当前位置时访问未来信息,确保自回归生成的正确性(如机器翻译或文本生成)。
  • 实现方式:生成上三角矩阵(对角线及以上为 0 或 -∞,对角线及以下为 1),使当前位置仅能关注历史位置。
  • 示例代码:
  • python
  • 复制
  • # 生成上三角掩码矩阵 mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) # 上三角为1 mask = mask.masked_fill(mask == 1, float('-inf')) # 替换为负无穷
  • 应用场景仅解码器的自注意力层(如 Transformer Decoder 的 Masked Multi-Head Attention)。

3. 组合掩码(Padding + Sequence Mask)

  • 作用:解码器需同时处理填充位置和未来信息屏蔽。通过叠加两种掩码,确保模型既忽略填充符,又仅依赖历史信息。
  • 实现方式:将 Padding Mask(形状 [batch, seq_len])与 Sequence Mask(形状 [seq_len, seq_len])按位与(&)或相加。
  • 示例:
  • python
  • 复制
  • # 结合两种掩码 combined_mask = padding_mask & sequence_mask # 布尔矩阵逻辑与 scores = scores.masked_fill(combined_mask == 0, -1e9) # 应用至注意力得分

4. 进阶变体:预训练任务中的掩码

以下机制主要用于预训练模型(如 BERT、ERNIE),不属于原始 Transformer 结构,但基于相似原理:

  • MLM Mask(掩码语言模型)
    在 BERT 中随机遮盖输入 Token(如 15%),要求模型预测被遮盖的词。遮盖方式包括:
    Token-level Mask:随机遮盖单个 Token(原始 BERT)。
  • Whole-Word Mask:遮盖整个词(如 “Superman” 拆分为 “Super” 和 “man” 时同时遮盖)。
  • N-gram Mask:按比例遮盖 Unigram 到 4-gram(MacBERT)。
  • Knowledge Mask(知识驱动掩码)
    ERNIE 针对实体或短语进行整体遮盖(如遮盖 “哈利波特” 而非单个字),迫使模型学习语义知识。

总结对比表

掩码类型

作用位置

主要目的

实现关键

Padding Mask

所有注意力层

忽略填充符 [PAD]

填充位置设为 -∞,Softmax 权重归零

Sequence Mask

解码器自注意力层

防止未来信息泄露

上三角矩阵(值 -∞)

组合掩码

解码器自注意力层

同时处理填充与未来信息

Padding Mask + Sequence Mask

MLM/Knowledge Mask

预训练模型(如 BERT)

学习上下文感知表示

随机/实体级遮盖

提示:原始 Transformer 仅依赖 Padding MaskSequence Mask,组合掩码是其衍生应用;其他变体(如 MLM)属于预训练任务的扩展设计。


Tags:

最近发表
标签列表