import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,  # supported in newer PyTorch versions
        )
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention + residual connection + LayerNorm (post-LN)
        attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward + residual connection + LayerNorm
        ff_output = self.linear2(F.gelu(self.linear1(x)))
        x = self.norm2(x + self.dropout(ff_output))
        return x
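
# For reference (not part of the original example): recent PyTorch versions ship
# a built-in layer with the same post-LN structure. This is a sketch, assuming
# your PyTorch version supports the batch_first and norm_first arguments:
builtin_layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1,
    activation="gelu", batch_first=True, norm_first=False,
)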
# Usage example
encoder_layer = TransformerEncoderLayer(d_model=512, num_heads=8)
x = torch.randn(32, 100, 512) # (batch, seq_len, d_model)
output = encoder_layer(x)
print(output.shape) # torch.Size([32, 100, 512])
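
# A minimal mask sketch (assumption: the layer forwards `mask` to
# nn.MultiheadAttention as attn_mask, where a boolean True marks positions
# that must NOT be attended to). A causal mask blocks attention to the future:
seq_len = x.size(1)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
masked_output = encoder_layer(x, mask=causal_mask)
print(masked_output.shape)  # torch.Size([32, 100, 512])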