#Self-Attention Computation: The Mathematical Essence of the Q, K, V Matrices
📂 Stage: Stage 3 - The Transformer Revolution (Core)
🔗 Related chapters: Attention Mechanism in Detail · Multi-Head Attention
#1. Self-Attention vs Attention
#1.1 The Difference
Standard Attention (Seq2Seq):
Query comes from the Decoder
Key and Value come from the Encoder
→ decides "which encoder states to attend to" when generating the current decoded token
Self-Attention (Transformer):
Query, Key, and Value all come from the same sequence
→ every token attends to "which other tokens in the sequence" to understand itself
Example: "The animal didn't cross the street because it was too tired"
→ Self-Attention for "it": the highest weight goes to "animal"
→ the model learns pronoun resolution automatically!
#2. The Physical Meaning of Q, K, V
#2.1 How Q, K, V Are Obtained
"""
每个输入词先转为嵌入向量 x(来自词嵌入层)
然后通过三个独立的线性变换得到 Q、K、V:
Q = x · W_q (Query:我在找什么)
K = x · W_k (Key:我有什么特征)
V = x · W_v (Value:如果匹配成功,我贡献什么信息)
W_q, W_k, W_v 是可学习的参数!
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, embed_dim)
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)
        V = self.W_v(x)
        # Apply scaled dot-product attention
        output, weights = self.scaled_dot_product(Q, K, V, mask)
        return output, weights

    def scaled_dot_product(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights
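A quick sanity check of the shapes (the batch and sequence sizes below are hypothetical, chosen only for illustration):
attn = SelfAttention(embed_dim=512)
x = torch.randn(2, 10, 512)        # (batch=2, seq_len=10, embed_dim=512)
output, weights = attn(x)
print(output.shape)                # torch.Size([2, 10, 512])
print(weights.shape)               # torch.Size([2, 10, 10]): one weight per token pair
print(weights[0].sum(dim=-1))      # each row sums to 1 after softmax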
#3. Multi-Head Attention
#3.1 Why Multiple Heads?
The problem with single-head attention:
→ one head can only learn one "attention pattern"
The advantage of multiple heads:
→ each head independently learns a different attention pattern
→ some heads capture syntactic relations, others semantic relations
→ the model's expressive power grows substantially!
Example (translation task):
Head 1: subject-verb relations
Head 2: modifier relations
Head 3: coreference relations
...
#3.2 Multi-Head Attention Implementation
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def split_heads(self, x, batch_size):
        # Split embed_dim into num_heads separate heads
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)  # (batch, heads, seq_len, head_dim)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Linear projections
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        # Split into heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)
        # Scaled dot-product attention: multiply by 1/sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(weights, V)
        # Merge the heads back together
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.num_heads * self.head_dim)
        # Final linear projection
        output = self.W_o(attn_output)
        return output, weights
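A shape check, which also looks back at section 1.1: because Q, K, V are separate arguments, the same module covers both self-attention (all three from one sequence) and Seq2Seq-style cross-attention (Query from the decoder, Key/Value from the encoder). The sizes below are hypothetical, chosen only for illustration:
mha = MultiHeadAttention(embed_dim=512, num_heads=8)

# Self-attention: Q, K, V all come from the same sequence
x = torch.randn(2, 10, 512)
out, weights = mha(x, x, x)
print(out.shape)       # torch.Size([2, 10, 512])
print(weights.shape)   # torch.Size([2, 8, 10, 10]): one attention map per head

# Cross-attention (Seq2Seq style): Query from the decoder, Key/Value from the encoder
dec = torch.randn(2, 7, 512)   # decoder states
enc = torch.randn(2, 10, 512)  # encoder states
out, weights = mha(dec, enc, enc)
print(out.shape)       # torch.Size([2, 7, 512])
print(weights.shape)   # torch.Size([2, 8, 7, 10])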
#4. Summary
The Self-Attention pipeline:
1. The input x passes through three linear transformations to produce Q, K, V
2. QK^T / √d_k → similarity (score) matrix
3. softmax → attention weights
4. weights × V → weighted output
Multi-Head = multiple Self-Attention operations in parallel → richer representations
💡 Remember: Self-Attention lets any two positions in a sequence interact directly (O(1) path length), which is the fundamental reason it can handle long sequences.