返回博客列表

神经注意力模式可视化实现

使用代码示例深入理解 Transformer 模型中的注意力机制及其可视化方法。

#Machine Learning#Visualization#Python#Transformers

注意力机制显著改变了自然语言处理领域,但对其工作原理的直观理解仍有一定门槛。本指南提供了可视化 Transformer 模型注意力模式的实践方法。

注意力机制概述

从本质上看,注意力机制允许模型在生成输出的每个部分时,聚焦于输入序列的不同位置。Transformer 的注意力计算遵循以下公式:

Transformer 的注意力计算遵循以下公式:Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

实现示例

import torch
import torch.nn.functional as F
import math
 
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    计算缩放点积注意力。
 
    参数:
        query: (batch, heads, seq_len, d_k)
        key: (batch, heads, seq_len, d_k)
        value: (batch, heads, seq_len, d_v)
        mask: 可选的注意力掩码
    """
    d_k = query.size(-1)
 
    # 计算注意力分数
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
 
    # 如有掩码则应用(用于填充或因果注意力)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
 
    # Softmax 获取注意力权重
    attention_weights = F.softmax(scores, dim=-1)
 
    # 将注意力应用于 value
    output = torch.matmul(attention_weights, value)
 
    return output, attention_weights

可视化的意义

可视化有助于:

  1. 理解模型行为:分析特定预测背后的依据
  2. 发现模式:观察模型习得的注意力分布特征
  3. 建立信任:直观呈现决策过程
  4. 调试模型:发现注意力漂移或不期望的偏置
  5. 提升可解释性:让黑盒模型更加透明

构建可视化工具

1. 使用 HuggingFace 提取特征

from transformers import AutoModel, AutoTokenizer
import torch
 
class AttentionExtractor:
    def __init__(self, model_name="bert-base-chinese"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name,
            output_attentions=True
        )
        self.model.eval()
 
    def extract(self, text):
        """从文本中提取注意力模式。"""
        inputs = self.tokenizer(text, return_tensors="pt")
 
        with torch.no_grad():
            outputs = self.model(**inputs)
 
        # outputs.attentions: (batch, heads, seq_len, seq_len) 的元组
        attentions = outputs.attentions
 
        return {
            "tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attentions": [a.squeeze(0) for a in attentions],
            "num_layers": len(attentions),
            "num_heads": attentions[0].shape[1]
        }

2. 使用 Matplotlib 可视化

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
 
def plot_attention_heatmap(tokens, attention_matrix, head=0, layer=0):
    """将单个注意力头绘制为热力图。"""
    fig, ax = plt.subplots(figsize=(10, 8))
 
    # attention_matrix: (heads, seq_len, seq_len)
    attn = attention_matrix[head].numpy()
 
    sns.heatmap(
        attn,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        ax=ax,
        square=True,
        cbar_kws={"label": "注意力权重"}
    )
 
    ax.set_xlabel("Key 位置")
    ax.set_ylabel("Query 位置")
    ax.set_title(f"注意力头 {head}, 层 {layer}")
 
    plt.tight_layout()
    return fig

3. 使用 Plotly 创建交互式可视化

import plotly.graph_objects as go
from plotly.subplots import make_subplots
 
def interactive_attention_view(tokens, attention_matrix):
    """创建交互式多头注意力视图。"""
    num_heads = attention_matrix.shape[0]
    cols = min(4, num_heads)
    rows = (num_heads + cols - 1) // cols
 
    fig = make_subplots(rows=rows, cols=cols)
 
    for head in range(num_heads):
        row, col = head // cols + 1, head % cols + 1
        attn = attention_matrix[head].numpy()
 
        fig.add_trace(
            go.Heatmap(
                z=attn,
                x=tokens,
                y=tokens,
                colorscale="Viridis",
                showscale=(head == 0),
                name=f"头 {head}"
            ),
            row=row, col=col
        )
 
    fig.update_layout(
        title="多头注意力模式",
        height=200 * rows,
        width=200 * cols
    )
 
    return fig

可观察的模式

局部注意力

某些头学习关注相邻 token,适用于:

  • 句法依赖:主谓一致
  • 局部短语结构:名词短语、介词短语
def detect_local_attention(attention_matrix, window=3):
    """测量局部注意力的集中程度。"""
    seq_len = attention_matrix.shape[-1]
    local_mask = torch.zeros(seq_len, seq_len)
 
    for i in range(seq_len):
        for j in range(max(0, i - window), min(seq_len, i + window + 1)):
            local_mask[i, j] = 1
 
    local_attention = (attention_matrix * local_mask).sum()
    total_attention = attention_matrix.sum()
 
    return local_attention / total_attention

全局注意力

其他头在整个序列上广泛关注:

  • 句子级语义:主题检测
  • 长距离依赖:指代消解

注意力头特化

研究表明不同的头专门负责:

头类型功能模式示例
句法头语法结构主语 → 谓语
位置头序列顺序下一个/上一个 token
语义头语义关系同义词、反义词
实体头命名实体人物 → 组织

实践应用:分析 BERT

# 示例:分析 BERT 关注什么
extractor = AttentionExtractor("bert-base-chinese")
 
text = "猫坐在垫子上因为它累了。"
result = extractor.extract(text)
 
# 找出"它"最关注哪个 token
it_index = result["tokens"].index("它")
print(f"分析位置 {it_index} 的 '它' 的注意力")
 
for layer in range(result["num_layers"]):
    # 在头之间取平均
    avg_attention = result["attentions"][layer].mean(dim=0)
 
    # "它"关注什么?
    it_attention = avg_attention[it_index]
    top_indices = it_attention.topk(3).indices.tolist()
 
    top_tokens = [result["tokens"][i] for i in top_indices]
    print(f"层 {layer}: '它' → {top_tokens}")

这揭示了 BERT 如何通过不同层解析代词引用。

工具与资源

最佳实践

  1. 层选择:中间层通常捕获最有趣的模式
  2. 头平均:某些分析需要平均各头,其他则需要单独观察
  3. 归一化:应检查注意力权重是否正确归一化
  4. 上下文长度:更长的序列可能稀释注意力模式