注意力机制 (Attention) 详解:为什么"注意力"就是你所需要的一切

📂 所属阶段:第三阶段 — Transformer 革命(核心篇)
🔗 相关章节:序列到序列模型 (Seq2Seq) · Self-Attention 自注意力计算


1. Attention 的起源

1.1 不用 Attention 的问题

Seq2Seq 的核心缺陷:用一个固定向量压缩整个输入

Encoder:
  输入:"我 爱 NLP [PAD] [PAD]"(假设长度=5)
  → 压缩成 1 个 Context 向量

问题:
  当输入很长时,一个向量无法存储所有信息!
  "我出生于北京……(1000字后)……我会说___"
  → Context 容量有限,无法记住"北京"这个关键信息

这就是"信息瓶颈"问题!

1.2 Attention 的灵感

人脑的注意力:你读这句话时,会"重点关注"某些词

"The animal didn't cross the street because it was too tired."
→ 你读到 "it" 时,注意力集中在 "animal",而不是 "street"

Attention = 让模型自动学会"关注"输入的相关部分

2. Attention 数学原理

2.1 Query-Key-Value 抽象

Attention 机制的三个角色:

Query(查询):我当前在关注什么?
Key(键):我这里有什么信息?
Value(值):这些信息的实际内容是什么?

类比:搜索引擎
  Query = 你的搜索词
  Key = 文章标题
  Value = 文章内容

Attention = 用 Query 在 Key-Value 池中找到最相关的信息

2.2 Scaled Dot-Product Attention

"""
Attention(Q, K, V) = softmax(QK^T / √d_k) × V

步骤:
1. QK^T:计算 Query 和每个 Key 的相似度(点积)
2. / √d_k:除以 √d_k 做缩放(防止梯度消失)
3. softmax:将相似度归一化为概率分布
4. × V:用概率对 Value 加权求和
"""

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    d_k = Q.size(-1)

    # 1. 计算相似度矩阵
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # QK^T: (batch, heads, seq_len, seq_len)

    # 2. 掩码(可选,用于 padding 或解码时遮挡未来信息)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # 3. Softmax 得到注意力权重
    attention_weights = F.softmax(scores, dim=-1)
    # shape: (batch, heads, seq_len, seq_len)

    # 4. 对 V 加权求和
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

3. Attention 可视化

3.1 机器翻译中的注意力

# 示例:翻译时关注正确的源语言词
# "I love NLP" → "我 爱 自然语言处理"

attention_matrix = [
    # to:  我      爱     NLP
    # from
    "I":    [0.9,  0.05,  0.05],  # "我" 主要关注 "I"
    "love": [0.05, 0.9,   0.05],  # "爱" 主要关注 "love"
    "NLP":  [0.03, 0.03,  0.94],  # "NLP" 主要关注 "NLP"
]

# 注意力让模型学习词语对齐!

3.2 可视化代码

import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention(attention_weights, source_tokens, target_tokens):
    """绘制注意力热力图"""
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        attention_weights,
        xticklabels=source_tokens,
        yticklabels=target_tokens,
        cmap="YlOrRd",
        ax=ax,
    )
    ax.set_xlabel("Source Tokens")
    ax.set_ylabel("Target Tokens")
    ax.set_title("Attention Weights Heatmap")
    plt.tight_layout()
    plt.savefig("attention.png")

4. Attention 的优势

# Attention vs RNN/LSTM:

# RNN/LSTM 的问题:
# 1. 信息需要经过多个时间步才能传递(路径长 → 梯度消失)
# 2. 难以并行化(必须按顺序计算)
# 3. 对长序列效果差

# Attention 的优势:
# 1. 直接连接:每个位置可以直接访问所有其他位置(路径长度=O(1))
# 2. 完全并行:所有注意力计算可以并行
# 3. 可解释性:注意力权重可以直接可视化
# 4. 长距离依赖:不再受限于序列长度

5. 小结

Attention 三步曲:

1. QK^T  → 计算 Query 和 Key 的相似度
2. softmax → 归一化为概率分布
3. ×V    → 用概率对 Value 加权求和

这就是 Transformer 的核心!

💡 记住:Attention 是 2017 年 Google 在《Attention is All You Need》中提出的,它是 Transformer 论文的核心,也是现代所有大模型的基石。


🔗 扩展阅读