注意力机制显著改变了自然语言处理领域,但对其工作原理的直观理解仍有一定门槛。本指南提供了可视化 Transformer 模型注意力模式的实践方法。
注意力机制概述
从本质上看,注意力机制允许模型在生成输出的每个部分时,聚焦于输入序列的不同位置。Transformer 的注意力计算遵循以下公式:
Transformer 的注意力计算遵循以下公式:Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
实现示例
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(query, key, value, mask=None):
"""
计算缩放点积注意力。
参数:
query: (batch, heads, seq_len, d_k)
key: (batch, heads, seq_len, d_k)
value: (batch, heads, seq_len, d_v)
mask: 可选的注意力掩码
"""
d_k = query.size(-1)
# 计算注意力分数
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 如有掩码则应用(用于填充或因果注意力)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax 获取注意力权重
attention_weights = F.softmax(scores, dim=-1)
# 将注意力应用于 value
output = torch.matmul(attention_weights, value)
return output, attention_weights可视化的意义
可视化有助于:
- 理解模型行为:分析特定预测背后的依据
- 发现模式:观察模型习得的注意力分布特征
- 建立信任:直观呈现决策过程
- 调试模型:发现注意力漂移或不期望的偏置
- 提升可解释性:让黑盒模型更加透明
构建可视化工具
1. 使用 HuggingFace 提取特征
from transformers import AutoModel, AutoTokenizer
import torch
class AttentionExtractor:
def __init__(self, model_name="bert-base-chinese"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(
model_name,
output_attentions=True
)
self.model.eval()
def extract(self, text):
"""从文本中提取注意力模式。"""
inputs = self.tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = self.model(**inputs)
# outputs.attentions: (batch, heads, seq_len, seq_len) 的元组
attentions = outputs.attentions
return {
"tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
"attentions": [a.squeeze(0) for a in attentions],
"num_layers": len(attentions),
"num_heads": attentions[0].shape[1]
}2. 使用 Matplotlib 可视化
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def plot_attention_heatmap(tokens, attention_matrix, head=0, layer=0):
"""将单个注意力头绘制为热力图。"""
fig, ax = plt.subplots(figsize=(10, 8))
# attention_matrix: (heads, seq_len, seq_len)
attn = attention_matrix[head].numpy()
sns.heatmap(
attn,
xticklabels=tokens,
yticklabels=tokens,
cmap="viridis",
ax=ax,
square=True,
cbar_kws={"label": "注意力权重"}
)
ax.set_xlabel("Key 位置")
ax.set_ylabel("Query 位置")
ax.set_title(f"注意力头 {head}, 层 {layer}")
plt.tight_layout()
return fig3. 使用 Plotly 创建交互式可视化
import plotly.graph_objects as go
from plotly.subplots import make_subplots
def interactive_attention_view(tokens, attention_matrix):
"""创建交互式多头注意力视图。"""
num_heads = attention_matrix.shape[0]
cols = min(4, num_heads)
rows = (num_heads + cols - 1) // cols
fig = make_subplots(rows=rows, cols=cols)
for head in range(num_heads):
row, col = head // cols + 1, head % cols + 1
attn = attention_matrix[head].numpy()
fig.add_trace(
go.Heatmap(
z=attn,
x=tokens,
y=tokens,
colorscale="Viridis",
showscale=(head == 0),
name=f"头 {head}"
),
row=row, col=col
)
fig.update_layout(
title="多头注意力模式",
height=200 * rows,
width=200 * cols
)
return fig可观察的模式
局部注意力
某些头学习关注相邻 token,适用于:
- 句法依赖:主谓一致
- 局部短语结构:名词短语、介词短语
def detect_local_attention(attention_matrix, window=3):
"""测量局部注意力的集中程度。"""
seq_len = attention_matrix.shape[-1]
local_mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
for j in range(max(0, i - window), min(seq_len, i + window + 1)):
local_mask[i, j] = 1
local_attention = (attention_matrix * local_mask).sum()
total_attention = attention_matrix.sum()
return local_attention / total_attention全局注意力
其他头在整个序列上广泛关注:
- 句子级语义:主题检测
- 长距离依赖:指代消解
注意力头特化
研究表明不同的头专门负责:
| 头类型 | 功能 | 模式示例 |
|---|---|---|
| 句法头 | 语法结构 | 主语 → 谓语 |
| 位置头 | 序列顺序 | 下一个/上一个 token |
| 语义头 | 语义关系 | 同义词、反义词 |
| 实体头 | 命名实体 | 人物 → 组织 |
实践应用:分析 BERT
# 示例:分析 BERT 关注什么
extractor = AttentionExtractor("bert-base-chinese")
text = "猫坐在垫子上因为它累了。"
result = extractor.extract(text)
# 找出"它"最关注哪个 token
it_index = result["tokens"].index("它")
print(f"分析位置 {it_index} 的 '它' 的注意力")
for layer in range(result["num_layers"]):
# 在头之间取平均
avg_attention = result["attentions"][layer].mean(dim=0)
# "它"关注什么?
it_attention = avg_attention[it_index]
top_indices = it_attention.topk(3).indices.tolist()
top_tokens = [result["tokens"][i] for i in top_indices]
print(f"层 {layer}: '它' → {top_tokens}")这揭示了 BERT 如何通过不同层解析代词引用。
工具与资源
- BertViz(opens in a new tab) — BERT 的交互式注意力可视化
- Transformer-Explainability(opens in a new tab) — CVPR 2021 论文实现,可视化 Transformer 分类依据
最佳实践
- 层选择:中间层通常捕获最有趣的模式
- 头平均:某些分析需要平均各头,其他则需要单独观察
- 归一化:应检查注意力权重是否正确归一化
- 上下文长度:更长的序列可能稀释注意力模式