Attention mechanisms have significantly transformed natural language processing, but developing an intuitive understanding of their operation can be challenging. This guide provides a practical approach to visualizing attention patterns in transformer models.
What is Attention?
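At its core, attention lets a model focus on different parts of the input sequence when producing each part of the output. Each position's query vector is compared with the key vectors of every position, and the resulting weights mix the corresponding value vectors:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the query and key vectors. Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension, which would otherwise push the softmax into regions with very small gradients.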
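To make the formula concrete, the NumPy sketch below evaluates it on three toy tokens. It is a single-head, unbatched illustration, not the implementation that follows; the helper names (`softmax`, `attention`), the `causal` flag, and the matrix values are assumptions chosen for the example.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtracting the row maximum avoids overflow without changing the result
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V, causal=False):
    """Single-head, unbatched attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # (seq_len, seq_len): query rows, key columns
    if causal:
        # Hide future keys: query i may only attend to keys j <= i
        n = scores.shape[0]
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    weights = softmax(scores, axis=-1)  # every row sums to 1
    return weights @ V, weights


# Three toy tokens with d_k = d_v = 2 (hand-picked values)
Q = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
K = Q.copy()
V = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])

out, w = attention(Q, K, V)
print(np.round(w, 3))
print(out.shape)  # (3, 2)

_, w_causal = attention(Q, K, V, causal=True)
print(np.round(w_causal, 3))  # entries above the diagonal are exactly 0
```

Setting `causal=True` plays the same role as a causal `mask`: hidden positions get a score of negative infinity, so they receive exactly zero weight after the softmax.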
Implementation
```python
import torch
import torch.nn.functional as F
import math


def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        query: (batch, heads, seq_len, d_k)
        key: (batch, heads, seq_len, d_k)
        value: (batch, heads, seq_len, d_v)
        mask: Optional mask broadcastable to (batch, heads, seq_len, seq_len);
              positions where mask == 0 are blocked

    Returns:
        (output, attention_weights)
    """
    d_k = query.size(-1)

    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask if provided (for padding or causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Apply attention to values
    output = torch.matmul(attention_weights, value)
    return output, attention_weights
```

Why Visualize?
Visualization helps us:
- Understand Model Behavior: See which input tokens the model weighs when making a specific prediction
- Discover Patterns: Identify recurring attention structures that heads learn
- Build Trust: Present part of the decision process in an intuitive, visual form
- Debug Models: Detect unexpected attention shifts or unwanted biases
- Improve Interpretability: Make black-box models more transparent
Building a Visualizer
1. Feature Extraction with HuggingFace
```python
from transformers import AutoModel, AutoTokenizer
import torch


class AttentionExtractor:
    def __init__(self, model_name="bert-base-uncased"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name,
            output_attentions=True
        )
        self.model.eval()

    def extract(self, text):
        """Extract attention patterns from text."""
        inputs = self.tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # outputs.attentions: tuple of (batch, heads, seq_len, seq_len)
        attentions = outputs.attentions

        return {
            "tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attentions": [a.squeeze(0) for a in attentions],  # Remove batch dim
            "num_layers": len(attentions),
            "num_heads": attentions[0].shape[1]
        }
```

2. Visualization with Matplotlib
```python
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


def plot_attention_heatmap(tokens, attention_matrix, head=0, layer=0):
    """Plot a single attention head as a heatmap."""
    fig, ax = plt.subplots(figsize=(10, 8))

    # attention_matrix: (heads, seq_len, seq_len)
    attn = attention_matrix[head].numpy()

    sns.heatmap(
        attn,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        ax=ax,
        square=True,
        cbar_kws={"label": "Attention Weight"}
    )
    ax.set_xlabel("Key Position")
    ax.set_ylabel("Query Position")
    ax.set_title(f"Attention Head {head}, Layer {layer}")
    plt.tight_layout()
    return fig
```

3. Interactive Visualization with Plotly
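```python
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def interactive_attention_view(tokens, attention_matrix):
    """Create an interactive multi-head attention view."""
    num_heads = attention_matrix.shape[0]
    cols = min(4, num_heads)
    rows = (num_heads + cols - 1) // cols

    fig = make_subplots(
        rows=rows, cols=cols,
        subplot_titles=[f"Head {h}" for h in range(num_heads)]
    )
    # Plot against integer positions: repeated tokens (e.g. two "the")
    # would otherwise collapse into a single category on the axis
    positions = list(range(len(tokens)))

    for head in range(num_heads):
        row, col = head // cols + 1, head % cols + 1
        attn = attention_matrix[head].numpy()
        fig.add_trace(
            go.Heatmap(
                z=attn,
                x=positions,
                y=positions,
                colorscale="Viridis",
                showscale=(head == 0),
                name=f"Head {head}"
            ),
            row=row, col=col
        )

    # Label the ticks with tokens and put query 0 at the top,
    # matching the matplotlib heatmap
    fig.update_xaxes(tickmode="array", tickvals=positions, ticktext=tokens)
    fig.update_yaxes(tickmode="array", tickvals=positions, ticktext=tokens,
                     autorange="reversed")
    fig.update_layout(
        title="Multi-Head Attention Patterns",
        height=200 * rows,
        width=200 * cols
    )
    return fig
```

Observable Patterns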
Local Attention
Some heads learn to attend mostly to nearby tokens, a pattern useful for:
- Syntactic dependencies: Subject-verb agreement
- Local phrase structure: Noun phrases, prepositional phrases
```python
import torch


def detect_local_attention(attention_matrix, window=3):
    """Fraction of attention mass within `window` positions of each query."""
    seq_len = attention_matrix.shape[-1]
    local_mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        for j in range(max(0, i - window), min(seq_len, i + window + 1)):
            local_mask[i, j] = 1

    local_attention = (attention_matrix * local_mask).sum()
    total_attention = attention_matrix.sum()
    return local_attention / total_attention
```

Global Attention
Other heads attend broadly across the sequence:
- Sentence-level semantics: Topic detection
- Long-range dependencies: Coreference resolution
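Whether a head behaves locally or globally can be summarized with one number. A common choice, assumed here since the article does not define a global-attention metric, is the entropy of each query's attention row, divided by its maximum possible value log(seq_len) so that sequences of different lengths are comparable. The function name `attention_entropy` is hypothetical.

```python
import numpy as np


def attention_entropy(attention_matrix, eps=1e-12):
    """Mean normalized entropy of each head's attention rows.

    attention_matrix: (heads, seq_len, seq_len) with rows summing to 1,
    and seq_len > 1. Returns an array of shape (heads,) in roughly [0, 1]:
    near 0 = each query focuses on one key, near 1 = attention spread evenly.
    """
    attn = np.asarray(attention_matrix, dtype=np.float64)
    seq_len = attn.shape[-1]
    # Shannon entropy of every query row; eps avoids log(0) for exact zeros
    row_entropy = -(attn * np.log(attn + eps)).sum(axis=-1)  # (heads, seq_len)
    # Normalize by the maximum entropy, log(seq_len), then average over queries
    return row_entropy.mean(axis=-1) / np.log(seq_len)


# Head 0 attends uniformly, head 1 attends only to itself
uniform = np.full((1, 4, 4), 0.25)
focused = np.eye(4)[None]
print(np.round(attention_entropy(np.concatenate([uniform, focused])), 3))
```

Because `np.asarray` converts CPU tensors, one entry of the extractor's `attentions` list can be passed directly. Heads scoring near 1 are candidates for the broad, global pattern described above; scores near 0 indicate sharply focused heads.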
Head Specialization
Studies probing trained models report that individual heads often specialize. Commonly described categories include:
| Head Type | Function | Example Pattern |
|---|---|---|
| Syntax | Grammatical structure | Subject → Verb |
| Position | Sequential order | Next/previous token |
| Semantic | Meaning relationships | Synonyms, antonyms |
| Entity | Named entities | Person → Organization |
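Of these categories, positional heads are the simplest to detect automatically. The hypothetical helper `offset_attention` below, a minimal sketch rather than a library function, averages the weight each query places on the key a fixed number of positions away; queries at the sequence boundary, which have no such neighbor, are skipped.

```python
import numpy as np


def offset_attention(attention_matrix, offset=-1):
    """Average attention each query puts on the key `offset` positions away.

    attention_matrix: (heads, seq_len, seq_len); rows are queries, columns keys.
    offset=-1 scores "previous token" heads, offset=1 "next token" heads.
    Returns an array of shape (heads,).
    """
    attn = np.asarray(attention_matrix, dtype=np.float64)
    # With axis1=-2, axis2=-1 this reads attn[..., q, q + offset] for every
    # valid query q; the diagonal ends up as the last axis
    diag = np.diagonal(attn, offset=offset, axis1=-2, axis2=-1)
    return diag.mean(axis=-1)
```

A score near 1 at `offset=-1` marks a previous-token head, and at `offset=1` a next-token head. The syntax, semantic, and entity categories generally need annotated examples to test and are not covered by this check.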
Practical Application: Analyzing BERT
```python
# Example: Analyze what BERT focuses on
extractor = AttentionExtractor("bert-base-uncased")
text = "The cat sat on the mat because it was tired."
result = extractor.extract(text)

# Find which token "it" attends to most
it_index = result["tokens"].index("it")
print(f"Analyzing attention for 'it' at position {it_index}")

for layer in range(result["num_layers"]):
    # Average across heads
    avg_attention = result["attentions"][layer].mean(dim=0)
    # What does "it" attend to?
    it_attention = avg_attention[it_index]
    top_indices = it_attention.topk(3).indices.tolist()
    top_tokens = [result["tokens"][i] for i in top_indices]
    print(f"Layer {layer}: 'it' → {top_tokens}")
```

Comparing the printed lists shows how the pronoun's attention shifts from layer to layer. Attention weights are only one signal, so treat this as a hint about how BERT links "it" to its antecedent rather than proof of how it resolves the reference.
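In BERT the top positions are often special tokens such as `[CLS]` and `[SEP]`, punctuation, or the query token itself, which can crowd out the more informative entries. The sketch below filters those before ranking; the helper `top_attended`, its `skip` default, and the `query_index` parameter are illustrative assumptions sized for BERT's tokenizer.

```python
import numpy as np


def top_attended(tokens, attention_row, k=3, query_index=None,
                 skip=("[CLS]", "[SEP]", ".", ",")):
    """Return up to k (token, weight) pairs with the highest attention weight.

    tokens: list of token strings; attention_row: weights for one query.
    Tokens in `skip` (an assumption suited to BERT-style tokenizers) and,
    if given, the query position itself are left out of the ranking.
    """
    row = np.asarray(attention_row, dtype=np.float64)
    candidates = [
        i for i, tok in enumerate(tokens)
        if tok not in skip and i != query_index
    ]
    # Rank the remaining positions by descending attention weight
    ranked = sorted(candidates, key=lambda i: row[i], reverse=True)
    return [(tokens[i], float(row[i])) for i in ranked[:k]]
```

Inside the loop above, `top_attended(result["tokens"], it_attention, k=3, query_index=it_index)` could stand in for the `topk` call.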
Tools and Resources
- BertViz: Interactive attention visualization for BERT
- Transformer-Explainability: CVPR 2021 paper implementation for visualizing Transformer classifications
Best Practices
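- Layer Selection: Middle layers often show the most interpretable patterns, so start there, but compare against early and late layers before drawing conclusions
- Head Averaging: Averaging heads gives a quick overview but can hide specialized heads; examine individual heads when a specific pattern matters
- Normalization: Check that each query's attention row sums to 1; the softmax guarantees this, but post-processing such as summing over heads or dropping special tokens can break it
- Context Length: Each row's total weight of 1 is spread over more positions in longer sequences, so individual weights shrink and patterns can look diluted; compare inputs of similar length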
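The normalization check can be automated. In this sketch, the name `check_row_normalized` and the default tolerance are assumptions; it verifies that weights are non-negative and that every query row sums to 1.

```python
import numpy as np


def check_row_normalized(attention_matrix, atol=1e-4):
    """Raise ValueError unless weights are non-negative and each row sums to 1."""
    attn = np.asarray(attention_matrix, dtype=np.float64)
    if (attn < 0).any():
        raise ValueError("attention weights must be non-negative")
    row_sums = attn.sum(axis=-1)
    # NaN rows also fail np.allclose, so they are reported here too
    if not np.allclose(row_sums, 1.0, atol=atol):
        worst = float(np.abs(row_sums - 1.0).max())
        raise ValueError(f"rows do not sum to 1 (max deviation {worst:.3g})")
    return True
```

Averaging over heads keeps rows normalized, while summing over heads produces rows that add up to the number of heads and will fail this check.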