Self-Attention: The Mathematical Essence of the Q, K, V Matrices

📂 Stage: Stage 3 — The Transformer Revolution (Core)
🔗 Related chapters: Attention Mechanisms in Detail · Multi-Head Attention


1. Self-Attention vs Attention

1.1 The Difference

Standard Attention (Seq2Seq):
  Query comes from the Decoder
  Key and Value come from the Encoder
  → decides "which parts of the encoded input" to attend to when generating the current output word

Self-Attention (Transformer):
  Query, Key, and Value all come from the same sequence
  → each word attends to "which other words in the sequence" in order to understand itself

Example: "The animal didn't cross the street because it was too tired"
→ Self-Attention for "it": the highest weight falls on "animal"
→ the model learns pronoun coreference automatically!

2. The Intuition Behind Q, K, V

2.1 How Q, K, V Are Computed

"""
每个输入词先转为嵌入向量 x(来自词嵌入层)
然后通过三个独立的线性变换得到 Q、K、V:

Q = x · W_q  (Query:我在找什么)
K = x · W_k  (Key:我有什么特征)
V = x · W_v  (Value:如果匹配成功,我贡献什么信息)

W_q, W_k, W_v 是可学习的参数!
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, embed_dim)

        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)
        V = self.W_v(x)

        # Scaled dot-product attention
        output, weights = self.scaled_dot_product(Q, K, V, mask)
        return output, weights

    def scaled_dot_product(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Masked positions get -1e9 so softmax drives their weight to ~0
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, V), weights
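The same computation can be sanity-checked end-to-end with plain tensor ops, using random matrices in place of the learned W_q, W_k, W_v (the dimensions here are toy values chosen only for illustration):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, embed_dim = 2, 5, 16
x = torch.randn(batch, seq_len, embed_dim)

# Random projections stand in for the learned W_q, W_k, W_v
W_q = torch.randn(embed_dim, embed_dim)
W_k = torch.randn(embed_dim, embed_dim)
W_v = torch.randn(embed_dim, embed_dim)

Q, K, V = x @ W_q, x @ W_k, x @ W_v

scores = Q @ K.transpose(-2, -1) / math.sqrt(embed_dim)
weights = F.softmax(scores, dim=-1)
output = weights @ V

print(output.shape)   # torch.Size([2, 5, 16])  -- same shape as x
print(weights.shape)  # torch.Size([2, 5, 5])   -- one weight per word pair
```

Note the shapes: the attention weight matrix is (seq_len × seq_len) per batch element, and each row is a probability distribution over the sequence (it sums to 1 after softmax).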

3. Multi-Head Attention

3.1 Why Multiple Heads?

The problem with single-head Attention:
→ one head can only learn one "attention pattern"

The advantage of multiple heads:
→ each head independently learns a different attention pattern
→ some heads capture syntactic relations, others semantic relations
→ much greater expressive power!

Example (translation task):
  Head 1: subject-verb relations
  Head 2: modifier relations
  Head 3: coreference relations
  ...

3.2 Multi-Head Attention Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=8):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def split_heads(self, x, batch_size):
        # Split embed_dim into num_heads separate heads
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)  # (batch, heads, seq_len, head_dim)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        # Split into heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # Scaled dot-product attention
        # (scale = head_dim ** -0.5, so we multiply rather than divide)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(weights, V)

        # Merge the heads back together
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.num_heads * self.head_dim)

        # Final output projection
        output = self.W_o(attn_output)
        return output, weights
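The split/merge bookkeeping is where shape bugs usually hide, so it is worth checking in isolation that the merge exactly inverts the split (toy dimensions again, chosen only for illustration):

```python
import torch

batch, seq_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads  # 4

x = torch.randn(batch, seq_len, embed_dim)

# Split: (batch, seq_len, embed_dim) -> (batch, heads, seq_len, head_dim)
split = x.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
print(split.shape)  # torch.Size([2, 4, 5, 4])

# Merge: undo the permute, make memory contiguous, then flatten the heads.
# This recovers the original tensor exactly.
merged = split.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, embed_dim)
print(torch.equal(merged, x))  # True
```

The `.contiguous()` call matters: after `permute` the tensor is no longer laid out contiguously in memory, and `view` would raise an error without it.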

4. Summary

The Self-Attention pipeline:

1. Input x passes through three linear projections to produce Q, K, V
2. QK^T / √d_k → similarity matrix
3. softmax → attention weights
4. weights × V → weighted output
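The four steps above, worked through on a hand-checkable toy example (seq_len = 2, d_k = 2, made-up values):

```python
import math
import torch
import torch.nn.functional as F

# Step 1 assumed done: tiny hand-picked Q, K, V (seq_len=2, d_k=2)
Q = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])
K = torch.tensor([[1.0, 0.0],
                  [1.0, 1.0]])
V = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])

# Step 2: QK^T / sqrt(d_k) -> similarity matrix
scores = Q @ K.T / math.sqrt(2)

# Step 3: softmax -> attention weights; each row sums to 1
weights = F.softmax(scores, dim=-1)
# Row 0 of QK^T is [1, 1]: equal scores, so word 0 attends
# equally to both positions -> [0.5, 0.5]

# Step 4: weighted sum of V
output = weights @ V
# Row 0: 0.5*[1, 2] + 0.5*[3, 4] = [2, 3]
print(weights)
print(output)
```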

Multi-Head = multiple Self-Attention operations in parallel → richer representations

💡 Remember: Self-Attention connects any two positions in a sequence directly (O(1) path length), which is the fundamental reason it can model long sequences.


🔗 Further Reading