Self-Attention 自注意力计算:Q、K、V 矩阵的数学本质

📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 前置关联:Seq2Seq 标准注意力入门 · 词嵌入与位置编码基础
🧱 后续模块:多头注意力 Multi-Head Attention · Transformer 编码器-解码器


1. 入门:Self-Attention vs 标准Seq2Seq Attention

注意力机制的核心目标是「选择性聚焦」——从一堆信息里挑出当前任务需要的重点,但Self-Attention和Seq2Seq的注意力,聚焦的来源/范围完全不一样

1.1 差异对比表

维度标准Seq2Seq AttentionSelf-Attention(Transformer核心)
Q/K/V来源Q仅来自解码器当前步骤,K/V仅来自编码器全序列Q/K/V全部来自同一个输入序列
目标帮助解码器「从编码器里找参考词」生成下一个字帮助序列中的每个词都“重新认识自己”——结合序列中所有其他词的信息,更新自身表示
关键场景举例翻译“猫→cat”时,解码器Q关注编码器的“猫”V句子“The animal didn't cross the street because it was too tired”中,代词“it”的Self-Attention会自动把最高权重给“animal”,解决指代问题

2. 核心:Q、K、V 矩阵的物理意义与计算

要实现「每个词都看遍全序列再更新自己」,核心工具就是三个独立训练的投影矩阵:W_q、W_k、W_v。

2.1 Q/K/V的「人话」类比

假设你正在整理一堆杂乱的便签(每个便签是一个输入词的嵌入向量x),要给每个便签重新写更有信息量的版本:

  • W_q投影得到Query(Q):便签上写「我现在需要找什么样的内容?」
  • W_k投影得到Key(K):便签上写「我自己有哪些核心标签?」
  • W_v投影得到Value(V):便签上写「如果别人需要我,我能分享什么具体信息?」

Self-Attention的流程就是:

  1. 给所有便签都写上Q/K/V
  2. 每一张便签A,拿它的Q去和所有便签的K做「相似度匹配」,得到匹配分数
  3. 给分数加个“平滑处理”,再转成0-1之间的权重(权重加起来是1)
  4. 用权重去加权所有便签的V,得到便签A的「新版信息量表示」

2.2 纯PyTorch实现单头Self-Attention

代码里补全了缺失的F导入,调整了注释的对齐和可读性,加入了关键步骤的维度说明(代码的维度是技术博客的灵魂!):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        # 三个独立的可学习线性投影层
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        # x的维度:(batch_size, seq_len, embed_dim)
        # 这里可以是词嵌入,或者上一层Transformer的输出

        # ---------------------------
        # 步骤1:投影得到Q/K/V
        # ---------------------------
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)  # (batch, seq_len, embed_dim)
        V = self.W_v(x)  # (batch, seq_len, embed_dim)

        # ---------------------------
        # 步骤2-4:缩放点积注意力(核心计算)
        # ---------------------------
        output, attention_weights = self._scaled_dot_product(Q, K, V, mask)
        return output, attention_weights

    def _scaled_dot_product(self, Q, K, V, mask=None):
        # 获取投影后的维度d_k(embed_dim不变的话就是embed_dim)
        d_k = Q.size(-1)

        # ---------------------------
        # 步骤2:Q和K计算相似度(点积)+ 缩放
        # 缩放的目的:防止点积值太大,导致softmax后梯度消失
        # ---------------------------
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)
        scores = scores / (d_k ** 0.5)  # 用d_k的平方根做缩放

        # ---------------------------
        # 可选步骤:加掩码
        # 用于Transformer解码器(遮盖未来词),或填充位置的忽略
        # ---------------------------
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # 把不需要的位置设为负无穷,softmax后趋近于0

        # ---------------------------
        # 步骤3:转成注意力权重(softmax归一化)
        # ---------------------------
        attention_weights = F.softmax(scores, dim=-1)  # 最后一个维度归一化,每个词的权重加起来=1

        # ---------------------------
        # 步骤4:加权V得到输出
        # ---------------------------
        output = torch.matmul(attention_weights, V)  # (batch, seq_len, embed_dim)
        return output, attention_weights

3. 进阶:Multi-Head Attention(多头注意力)

单头Self-Attention虽然能解决指代、关联的问题,但表达能力有限——只能学习一种「关注模式」。

3.1 为什么要做多头?

想象你只靠一只眼睛看世界,只能捕捉到「平面的距离信息」;如果是两只眼睛(两个头),就能捕捉「立体的深度」;如果是八只眼睛(NLP常用的8头),甚至能同时关注:

  • 语法关系(比如“主语-谓语”的搭配)
  • 语义关系(比如“猫-喵喵叫”的常识)
  • 指代关系(比如前面的“it”指“animal”)
  • 上下文窗口(比如“因为-所以”的因果对)

每个头的W_q、W_k、W_v都是独立初始化、独立训练的,所以能学到完全不同的「匹配规则」,最后再把所有头的输出「拼起来」,就能得到更丰富的词表示。

3.2 PyTorch实现多头注意力

同样补全了维度注释,优化了变量命名的一致性,还调整了多头合并的顺序逻辑:

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        # 检查embed_dim能不能被num_heads整除,保证每个头的维度是整数
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 每个头分配的维度
        self.scale = self.head_dim ** -0.5  # 缩放因子,和单头一样

        # 四个可学习的线性层:Q/K/V投影 + 多头输出的最终融合
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def _split_into_heads(self, x, batch_size):
        # 将embed_dim的投影结果,切分成num_heads个head_dim的子向量
        # 维度变换:(batch, seq_len, embed_dim) → (batch, seq_len, num_heads, head_dim)
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        # 调整顺序为:(batch, num_heads, seq_len, head_dim)
        # 方便后续对每个头单独做scaled dot-product
        return x.permute(0, 2, 1, 3)

    def forward(self, Q_input, K_input, V_input, mask=None):
        # 这里的Q/K/V_input不一定是同一个x!
        # 如果是Encoder的Self-Attention:三者都是x
        # 如果是Decoder的Self-Attention:三者都是x
        # 如果是Decoder的Encoder-Decoder Attention:Q是Decoder当前层的x,K/V是Encoder的最终输出
        batch_size = Q_input.size(0)

        # ---------------------------
        # 步骤1:对Q/K/V_input分别做线性投影
        # ---------------------------
        Q = self.W_q(Q_input)  # (batch, seq_len_Q, embed_dim)
        K = self.W_k(K_input)  # (batch, seq_len_KV, embed_dim)
        V = self.W_v(V_input)  # (batch, seq_len_KV, embed_dim)

        # ---------------------------
        # 步骤2:切分成多头
        # ---------------------------
        Q = self._split_into_heads(Q, batch_size)  # (batch, heads, seq_len_Q, head_dim)
        K = self._split_into_heads(K, batch_size)  # (batch, heads, seq_len_KV, head_dim)
        V = self._split_into_heads(V, batch_size)  # (batch, heads, seq_len_KV, head_dim)

        # ---------------------------
        # 步骤3:对每个头单独做scaled dot-product
        # ---------------------------
        # 计算相似度+缩放:(batch, heads, seq_len_Q, seq_len_KV)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # 可选步骤:加掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # 归一化得到权重:(batch, heads, seq_len_Q, seq_len_KV)
        attention_weights = F.softmax(scores, dim=-1)

        # 加权V得到每个头的输出:(batch, heads, seq_len_Q, head_dim)
        head_outputs = torch.matmul(attention_weights, V)

        # ---------------------------
        # 步骤4:合并多头+最终线性融合
        # ---------------------------
        # 调整顺序回:(batch, seq_len_Q, heads, head_dim)
        head_outputs = head_outputs.permute(0, 2, 1, 3).contiguous()
        # 合并最后两个维度:(batch, seq_len_Q, embed_dim)
        concatenated = head_outputs.view(batch_size, -1, self.num_heads * self.head_dim)
        # 最终线性融合,让多头的输出“相互配合”
        final_output = self.W_o(concatenated)

        return final_output, attention_weights

4. 小结

4.1 单头Self-Attention的极简流程

flowchart LR
    A[输入序列<br/>词嵌入/上一层输出] --> B["投影W_q/W_k/W_v<br/>得到Q/K/V"]
    B --> C["Q·K^T + 缩放<br/>得到相似度矩阵"]
    C --> D["softmax归一化<br/>得到注意力权重"]
    D --> E["权重·V<br/>得到加权表示输出"]

4.2 Self-Attention的两个核心优势

  1. 全局关联直接建立:序列中任意两个位置的词,路径长度都是1(RNN是n,CNN是log n),完美解决长距离依赖问题
  2. 完全并行化计算:所有词的Q/K/V投影、相似度计算、权重归一化,可以同时在GPU上完成,训练速度远快于RNN

🔗 优质扩展阅读