网站首页 > 文章精选 正文
Mask机制几乎贯穿了Transformer架构的始终,若不能首先将mask机制交代清楚,就难以对Transformer进行连贯的阐述。因此,决定将mask机制的介绍放在最前面,如果一开始难以理解,可以结合后文中的整体架构再回来理解mask机制。
Transformer 模型中的掩码(Mask)机制主要用于控制注意力计算中的信息流动,确保模型在处理序列数据时遵循特定约束(如忽略无效信息或防止未来信息泄露)。以下是几种核心掩码机制及其作用:
1. Padding Mask(填充掩码)
- 作用:处理变长序列的批量训练问题。当序列长度不一致时,需用特殊符号(如 [PAD])填充较短序列,而 Padding Mask 会标记这些填充位置,使模型在计算注意力时忽略它们
- 实现方式:生成与输入序列同形状的布尔矩阵,填充位置为 False(或 0),有效位置为 True(或 1)。在计算注意力得分(Q·K)后,将填充位置替换为极小的负值(如 -1e9),使 Softmax 后权重趋近于 0。
- 应用场景:所有注意力层(编码器和解码器均需使用)。
2. Sequence Mask(序列掩码 / Look-Ahead Mask / Causal Mask)
- 作用:防止解码器在预测当前位置时访问未来信息,确保自回归生成的正确性(如机器翻译或文本生成)。
- 实现方式:生成上三角矩阵(对角线及以上为 0 或 -∞,对角线及以下为 1),使当前位置仅能关注历史位置。
- 示例代码:
- python
- 复制
- # 生成上三角掩码矩阵 mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) # 上三角为1 mask = mask.masked_fill(mask == 1, float('-inf')) # 替换为负无穷
- 应用场景:仅解码器的自注意力层(如 Transformer Decoder 的 Masked Multi-Head Attention)。
3. 组合掩码(Padding + Sequence Mask)
- 作用:解码器需同时处理填充位置和未来信息屏蔽。通过叠加两种掩码,确保模型既忽略填充符,又仅依赖历史信息。
- 实现方式:将 Padding Mask(形状 [batch, seq_len])与 Sequence Mask(形状 [seq_len, seq_len])按位与(&)或相加。
- 示例:
- python
- 复制
- # 结合两种掩码 combined_mask = padding_mask & sequence_mask # 布尔矩阵逻辑与 scores = scores.masked_fill(combined_mask == 0, -1e9) # 应用至注意力得分
4. 进阶变体:预训练任务中的掩码
以下机制主要用于预训练模型(如 BERT、ERNIE),不属于原始 Transformer 结构,但基于相似原理:
- MLM Mask(掩码语言模型):
在 BERT 中随机遮盖输入 Token(如 15%),要求模型预测被遮盖的词。遮盖方式包括:Token-level Mask:随机遮盖单个 Token(原始 BERT)。 - Whole-Word Mask:遮盖整个词(如 “Superman” 拆分为 “Super” 和 “man” 时同时遮盖)。
- N-gram Mask:按比例遮盖 Unigram 到 4-gram(MacBERT)。
- Knowledge Mask(知识驱动掩码):
ERNIE 针对实体或短语进行整体遮盖(如遮盖 “哈利波特” 而非单个字),迫使模型学习语义知识。
总结对比表
掩码类型 | 作用位置 | 主要目的 | 实现关键 |
Padding Mask | 所有注意力层 | 忽略填充符 [PAD] | 填充位置设为 -∞,Softmax 权重归零 |
Sequence Mask | 解码器自注意力层 | 防止未来信息泄露 | 上三角矩阵(值 -∞) |
组合掩码 | 解码器自注意力层 | 同时处理填充与未来信息 | Padding Mask + Sequence Mask |
MLM/Knowledge Mask | 预训练模型(如 BERT) | 学习上下文感知表示 | 随机/实体级遮盖 |
提示:原始 Transformer 仅依赖 Padding Mask 和 Sequence Mask,组合掩码是其衍生应用;其他变体(如 MLM)属于预训练任务的扩展设计。
猜你喜欢
- 2025-08-03 学习回顾—BGP(0x03C)-配置向对等体发送缺省路由
- 2025-08-03 「网络」五大网络概念:IP、子网掩码、网关、DHCP和PPPoE
- 2025-08-03 记住3个部分、2个地址,1个公式,你也能轻松划分子网
- 2025-08-03 OSPF基础配置命令及案例
- 2025-08-03 学习回顾——OSPF路由协议(0x2A)-配置OSPF路由聚合
- 2025-08-03 静态路由深入讲解
- 2025-08-03 CTF竞赛密码学 之 LFSR
- 2025-08-03 学习回顾—BGP(0x045)-配置BGP路由振荡抑制
- 2025-08-03 HUAWEI FW全局选路策略详解
- 2025-08-03 WIFI管理员密码忘记了怎么办?路由器默认WIFI管理密码是多少?
- 最近发表
- 标签列表
-
- newcoder (56)
- 字符串的长度是指 (45)
- drawcontours()参数说明 (60)
- unsignedshortint (59)
- postman并发请求 (47)
- python列表删除 (50)
- 左程云什么水平 (56)
- 编程题 (64)
- postgresql默认端口 (66)
- 数据库的概念模型独立于 (48)
- 产生系统死锁的原因可能是由于 (51)
- 数据库中只存放视图的 (62)
- 在vi中退出不保存的命令是 (53)
- 哪个命令可以将普通用户转换成超级用户 (49)
- noscript标签的作用 (48)
- 联合利华网申 (49)
- swagger和postman (46)
- 结构化程序设计主要强调 (53)
- 172.1 (57)
- apipostwebsocket (47)
- 唯品会后台 (61)
- 简历助手 (56)
- offshow (61)
- mysql数据库面试题 (57)
- fmt.println (52)