Self-Attention 自注意力计算:Q、K、V 矩阵的数学本质

📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 前置关联:Seq2Seq 标准注意力入门 · 词嵌入与位置编码基础
🧱 后续模块:多头注意力 Multi-Head Attention · Transformer 编码器-解码器


1. 入门:Self-Attention 和标准 Seq2Seq 注意力,究竟差在哪?

注意力机制的核心思想就三个字——「挑重点」。无论是读书、看图片还是处理序列,我们都希望模型能把“目光”放在最关键的地方。但同样是“挑重点”,Self-Attention 和经典 Seq2Seq 中的注意力,选材范围完全不同

1.1 一张表看懂两种注意力

维度标准 Seq2Seq AttentionSelf-Attention(Transformer 的心脏)
Q/K/V 的来源Q 只来自解码器当前步,K、V 全部来自编码器Q、K、V 都来自同一个输入序列
要解决什么问题让解码器「参考编码器的信息」来生成下一个词让序列里的每个词都重新认识自己——融合整个序列中其他词的信息,更新自己的表示
典型例子翻译 “猫 → cat” 时,解码器的 Q 关注编码器的 “猫”句子 “The animal didn't cross the street because it was too tired”,it 的 Self-Attention 会自动把最高权重放在 animal

一句话总结:标准注意力是「解码器看编码器」,Self-Attention 是「序列自己看自己」。

🧠 小思考:正是因为 Self-Attention 完全工作在同一个序列上,所以 Transformer 的编码器可以并行计算所有词的更新,而 RNN 必须一步接一步,这是性能革命的关键。


2. 核心:Q、K、V 矩阵的物理意义与完整计算

Self-Attention 的核心工具是三个可以学习的投影矩阵W_qW_kW_v。它们就像三副不同的“眼镜”,让同一个词向量去扮演三种不同的角色。

2.1 给 Q/K/V 一个“人话”版本的比喻

假设你有一堆待整理的便签,每张便签代表一个词(向量 x),你的任务是给每张便签写一个更丰富、更有上下文信息的新版本

  • 🕵️ Query(Q): 我在便签上写:「我现在需要找什么样的信息?」
  • 🏷️ Key(K): 我在便签上写:「我自己身上有哪些核心标签?」
  • 📦 Value(V): 我在便签上写:「如果有人选中我,我能分享给他哪些具体内容?」

整个 Self-Attention 的运行流程就像一场“全员匹配大会”:

  1. 给所有便签都写上 Q、K、V(通过 W_qW_kW_v 投影得到)。
  2. 对于便签 A,拿着它的 Q 去和所有便签的 K 做「相似度匹配」(点积),得到匹配分数。
  3. 把分数平滑一下,再转换成 0~1 之间的权重(加起来等于 1)。
  4. 用这些权重去加权所有便签的 V,得到便签 A 的全新表示。

于是,每个词都带着整个序列的“集体智慧”重新出炉了。

2.2 纯 PyTorch 实现单头 Self-Attention

下面的代码完整实现了单头 Self-Attention,每一处关键计算都标注了张量维度的变化——对于理解 Transformer,维度就是生命线

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        # 三个独立的可学习线性投影层,把输入映射到 Q/K/V 空间
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        # x 的形状:(batch_size, seq_len, embed_dim)
        # 既可以是词嵌入,也可以是上一层 Transformer 的输出

        # ── 步骤 1:投影得到 Q / K / V ──
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)  # (batch, seq_len, embed_dim)
        V = self.W_v(x)  # (batch, seq_len, embed_dim)

        # ── 步骤 2-4:缩放点积注意力 ──
        output, attention_weights = self._scaled_dot_product(Q, K, V, mask)
        return output, attention_weights

    def _scaled_dot_product(self, Q, K, V, mask=None):
        d_k = Q.size(-1)          # 投影后的维度(默认就是 embed_dim)

        # ── 步骤 2:计算相似度(点积)并缩放 ──
        # 为什么缩放? 点积值如果太大,softmax 后会变得极端,梯度消失。
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)
        scores = scores / (d_k ** 0.5)                 # 除以 √d_k

        # ── 可选:遮盖掉不应关注的位置 ──
        # 解码器中用来屏蔽未来词,或忽略填充符 <pad>
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # 负无穷,softmax 后趋近于 0

        # ── 步骤 3:softmax 得到注意力权重 ──
        attention_weights = F.softmax(scores, dim=-1)  # 每行权重之和为 1

        # ── 步骤 4:用权重加权 V,得到输出 ──
        output = torch.matmul(attention_weights, V)     # (batch, seq_len, embed_dim)
        return output, attention_weights

💡 看不懂维度?没关系,记住一条法则:

  • QK 点积产生 (seq_len × seq_len) 的“关系矩阵”,每一行代表该词对所有词的关注分数。
  • 归一化后变成权重,再乘上 V,得到融合了全局信息的新表示。

3. 进阶:多头注意力(Multi-Head Attention)

单头 Self-Attention 已经能解决很多问题,但它一次只能学会一种“关注模式”。就像你闭上一只眼睛看世界,能感知距离,但看不清立体深度。

3.1 为什么要多个头?

多头注意力相当于同时戴上好几副不同的眼镜,每副眼镜关注不同的语言特征:

  • 🧐 头1:负责捕捉语法关系(主语-谓语搭配)
  • 🧐 头2:负责捕捉语义关系(猫-喵喵叫)
  • 🧐 头3:负责捕捉指代关系(it → animal)
  • 🧐 头4:负责捕捉长程依赖(因为……所以……)

每个注意力头都有自己独立的一套投影矩阵W_qW_kW_v),因此可以通过训练学到完全不同的匹配规则。最后,把所有头的输出拼接起来,再做一次线性变换,就得到了语义更丰富的词表示。

3.2 纯 PyTorch 实现多头注意力

下面是一个可以直接使用的多注意力模块。为了通用性,我们支持 Q、K、V 来自不同输入(编码器-解码器注意力),也支持它们全部来自同一个输入(Self-Attention)。

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads   # 每个头分配到的维度
        self.scale = self.head_dim ** -0.5       # 缩放因子,等同于 1/√d_k

        # Q/K/V 的联合投影(也可以拆开,效果一样)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        # 多头输出拼接后的最终线性融合
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def _split_into_heads(self, x, batch_size):
        # 将最后维度切分为 (num_heads × head_dim)
        # (batch, seq_len, embed_dim) -> (batch, seq_len, num_heads, head_dim)
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        # 转换为 (batch, num_heads, seq_len, head_dim),方便独立计算每个头
        return x.permute(0, 2, 1, 3)

    def forward(self, Q_input, K_input, V_input, mask=None):
        """
        灵活模式:
          - 编码器 Self-Attention: Q_input = K_input = V_input = x
          - 解码器 Self-Attention: 同上
          - 交叉注意力: Q来自解码器,K/V 来自编码器
        """
        batch_size = Q_input.size(0)

        # ── 1. 线性投影 ──
        Q = self.W_q(Q_input)   # (batch, seq_len_Q, embed_dim)
        K = self.W_k(K_input)   # (batch, seq_len_KV, embed_dim)
        V = self.W_v(V_input)   # (batch, seq_len_KV, embed_dim)

        # ── 2. 切分为多头 ──
        Q = self._split_into_heads(Q, batch_size)  # (batch, heads, seq_len_Q, head_dim)
        K = self._split_into_heads(K, batch_size)  # (batch, heads, seq_len_KV, head_dim)
        V = self._split_into_heads(V, batch_size)  # (batch, heads, seq_len_KV, head_dim)

        # ── 3. 每个头独立执行缩放点积注意力 ──
        # 计算注意力分数 (batch, heads, seq_len_Q, seq_len_KV)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = F.softmax(scores, dim=-1)        # 权重
        head_outputs = torch.matmul(attention_weights, V)    # (batch, heads, seq_len_Q, head_dim)

        # ── 4. 合并多头并做最终的线性融合 ──
        # (batch, heads, seq_len_Q, head_dim) -> (batch, seq_len_Q, heads, head_dim)
        head_outputs = head_outputs.permute(0, 2, 1, 3).contiguous()
        # 拼接所有头 -> (batch, seq_len_Q, embed_dim)
        concatenated = head_outputs.view(batch_size, -1, self.num_heads * self.head_dim)
        # 最后再线性变换一次,让不同头学到的特征「相互配合」
        final_output = self.W_o(concatenated)

        return final_output, attention_weights

使用方法建议:

  • 在构建 Transformer 层时,MultiHeadAttention 就是核心积木。
  • PyTorch 官方也有现成的 torch.nn.MultiheadAttention,但其输入维度顺序是 (seq_len, batch, embed_dim),与我们的习惯 (batch, seq_len, embed_dim) 不同,使用时记得加 batch_first=True

4. 小结:一图流 + 两大优势

4.1 单头 Self-Attention 极简流程

flowchart LR
    A[输入序列<br/>词嵌入/上一层输出] --> B["投影 W_q / W_k / W_v<br/>得到 Q / K / V"]
    B --> C["Q·K^T + 缩放<br/>相似度矩阵"]
    C --> D["softmax 归一化<br/>注意力权重"]
    D --> E["权重 · V<br/>得到全局融合的输出"]

这个过程每计算一次,序列中的每一个词就和所有词“交流”了一遍

4.2 Self-Attention 为什么这么强?

  1. 长距离依赖一步到位
    序列中任意两个词,无论相距多远,它们的交互路径长度都是 1。相比之下,RNN 需要 O(n) 步,CNN 需要 O(log n) 步。对于“开头主语、结尾谓语”这种依赖,Self-Attention 直接捕获。

  2. 完全并行,GPU 友好
    所有词的 Q/K/V 投影、相似度计算、权重归一化,都可以一大块张量直接扔进 GPU 并行算完。这是 Transformer 训练速度快、能轻松扩展到大模型的关键。


🔗 优质扩展阅读

📘 掌握了 Q/K/V 和多头注意力,下一步就是搭建一个完整的 Transformer 编码器,敬请期待!