
Neural Attention Pattern Visualization Implementation

A hands-on guide to understanding and visualizing attention mechanisms in transformer models with code examples.

Tags: Machine Learning, Visualization, Python, Transformers

Attention mechanisms have transformed natural language processing, but building an intuition for how they actually work can be challenging. This guide takes a practical approach to visualizing attention patterns in transformer models.

What is Attention?

At its core, attention allows a model to focus on different parts of the input sequence when producing each part of the output. The transformer attention computation follows the formula:

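$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Here Q, K, and V are the query, key, and value matrices, d_k is the key dimension, the softmax is applied row-wise so each query's weights over the keys sum to 1, and the result is matrix-multiplied by V.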
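A tiny worked example makes the formula concrete. The values are illustrative, chosen only for easy arithmetic: one query q = (1, 0, 1, 0) and two keys k_1 = (1, 0, 1, 0) and k_2 = (0, 1, 0, 1), so d_k = 4.

$$\frac{q k_1^\top}{\sqrt{4}} = \frac{2}{2} = 1, \qquad \frac{q k_2^\top}{\sqrt{4}} = \frac{0}{2} = 0$$

$$\mathrm{softmax}(1, 0) = \left(\frac{e}{e+1}, \frac{1}{e+1}\right) \approx (0.731, 0.269)$$

$$\text{output} \approx 0.731\, v_1 + 0.269\, v_2$$

The query puts roughly three quarters of its weight on the key it matches, and the output is the corresponding weighted mix of the value vectors.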

Implementation

import torch
import torch.nn.functional as F
import math
 
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.
 
    Args:
        query: (batch, heads, seq_len, d_k)
        key: (batch, heads, seq_len, d_k)
        value: (batch, heads, seq_len, d_v)
        mask: Optional attention mask
    """
    d_k = query.size(-1)
 
    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
 
    # Apply mask if provided (for padding or causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
 
    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
 
    # Apply attention to values
    output = torch.matmul(attention_weights, value)
 
    return output, attention_weights
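To sanity-check the function above, the sketch below rebuilds the same computation inline with a causal mask (1 = may attend, 0 = masked, matching the `masked_fill(mask == 0, ...)` convention) and compares it with PyTorch's built-in `F.scaled_dot_product_attention`, which assumes PyTorch 2.0 or later. The tensor shapes are arbitrary illustrative values.

```python
import math

import torch
import torch.nn.functional as F

# Illustrative shapes: batch=2, heads=4, seq_len=5, d_k = d_v = 8
torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# Causal mask: position i may only attend to positions j <= i (1 = keep, 0 = mask)
seq_len = q.size(-2)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

# Same steps as scaled_dot_product_attention(), written inline
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
manual_out = weights @ v

# Built-in version; is_causal=True applies the same lower-triangular mask
builtin_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Outputs should agree, and no query should attend to a future position
print(torch.allclose(manual_out, builtin_out, atol=1e-5))
print(torch.all(weights.triu(diagonal=1) == 0).item())
```

If the manual implementation matches, both checks print True.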

Why Visualize?

Visualization helps us:

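  1. Understand Model Behavior: See which input tokens each position draws on when the model makes a specific prediction
  2. Discover Patterns: Observe how attention is distributed across layers and heads
  3. Build Trust: Give an intuitive, inspectable view of part of the decision process
  4. Debug Models: Detect attention drift or unwanted biases
  5. Improve Interpretability: Make black-box models more transparent, keeping in mind that attention weights show where information is routed, not a complete explanation of a prediction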

Building a Visualizer

1. Extracting Attention Weights with Hugging Face Transformers

from transformers import AutoModel, AutoTokenizer
import torch
 
class AttentionExtractor:
    def __init__(self, model_name="bert-base-uncased"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name,
            output_attentions=True,
            # Recent transformers releases may use PyTorch's fused SDPA kernel,
            # which cannot return attention weights; "eager" computes them
            # explicitly (option available in transformers >= 4.36).
            attn_implementation="eager",
        )
        self.model.eval()  # disable dropout so attention is deterministic

    def extract(self, text):
        """Extract attention patterns from a single text."""
        inputs = self.tokenizer(text, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)

        # outputs.attentions: tuple with one tensor per layer,
        # each of shape (batch, heads, seq_len, seq_len)
        attentions = outputs.attentions

        return {
            "tokens": self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]),
            "attentions": [a.squeeze(0) for a in attentions],  # each: (heads, seq_len, seq_len)
            "num_layers": len(attentions),
            "num_heads": attentions[0].shape[1],
        }
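`extract()` returns a Python list with one (heads, seq_len, seq_len) tensor per layer. When comparing layers, it can help to stack them into a single (layers, heads, seq_len, seq_len) tensor. A minimal sketch: `stack_attentions` is a hypothetical helper name, and random softmax-normalized tensors stand in for real model output so the example runs without downloading a model.

```python
import torch

def stack_attentions(attentions):
    """Stack per-layer attention tensors into one (layers, heads, seq_len, seq_len) tensor.

    Assumes `attentions` is a list/tuple with one (heads, seq_len, seq_len)
    tensor per layer, like the "attentions" entry returned by extract().
    """
    stacked = torch.stack(list(attentions), dim=0)
    if stacked.dim() != 4:
        raise ValueError(f"expected a 4-D result, got shape {tuple(stacked.shape)}")
    return stacked

# Synthetic stand-in for a 12-layer, 12-head model on a 6-token input
fake = [torch.softmax(torch.randn(12, 6, 6), dim=-1) for _ in range(12)]
all_attn = stack_attentions(fake)
print(all_attn.shape)  # torch.Size([12, 12, 6, 6])

# Head-averaged attention for every layer at once: (layers, seq_len, seq_len)
per_layer_avg = all_attn.mean(dim=1)
```

With a real model, `stack_attentions(result["attentions"])` yields the same layout; for bert-base-uncased the first two dimensions are 12 layers and 12 heads.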

2. Static Heatmaps with Matplotlib and Seaborn

import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention_heatmap(tokens, attention_matrix, head=0, layer=0):
    """Plot a single attention head as a heatmap.

    attention_matrix: one layer's attention, shape (heads, seq_len, seq_len).
    `layer` is only used in the title.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Select one head; detach and move to CPU so .numpy() works for any tensor
    attn = attention_matrix[head].detach().cpu().numpy()

    sns.heatmap(
        attn,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        ax=ax,
        square=True,
        cbar_kws={"label": "Attention Weight"}
    )

    ax.set_xlabel("Key Position")
    ax.set_ylabel("Query Position")
    ax.set_title(f"Attention Head {head}, Layer {layer}")

    plt.tight_layout()
    return fig

3. Interactive Visualization with Plotly

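import plotly.graph_objects as go
from plotly.subplots import make_subplots

def interactive_attention_view(tokens, attention_matrix):
    """Create an interactive grid with one heatmap per head.

    attention_matrix: one layer's attention, shape (heads, seq_len, seq_len).
    """
    num_heads = attention_matrix.shape[0]
    cols = min(4, num_heads)
    rows = (num_heads + cols - 1) // cols

    # Plotly merges repeated category labels (e.g. two "the" tokens),
    # so prefix each token with its position to keep labels unique
    labels = [f"{i}:{tok}" for i, tok in enumerate(tokens)]

    fig = make_subplots(
        rows=rows, cols=cols,
        subplot_titles=[f"Head {h}" for h in range(num_heads)]
    )

    for head in range(num_heads):
        row, col = head // cols + 1, head % cols + 1
        attn = attention_matrix[head].detach().cpu().numpy()

        fig.add_trace(
            go.Heatmap(
                z=attn,
                x=labels,
                y=labels,
                colorscale="Viridis",
                zmin=0, zmax=1,  # shared scale so the single colorbar applies to all heads
                showscale=(head == 0),
                name=f"Head {head}"
            ),
            row=row, col=col
        )

    # Put query 0 at the top, matching the Matplotlib/Seaborn orientation
    fig.update_yaxes(autorange="reversed")
    fig.update_layout(
        title="Multi-Head Attention Patterns",
        height=250 * rows,
        width=250 * cols
    )

    return fig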

Observable Patterns

Local Attention

Some heads learn to focus on adjacent tokens, useful for:

  • Syntactic dependencies: Subject-verb agreement
  • Local phrase structure: Noun phrases, prepositional phrases

import torch

def detect_local_attention(attention_matrix, window=3):
    """Fraction of total attention mass within ±window positions of each query.

    Works for (seq_len, seq_len) or (heads, seq_len, seq_len) tensors.
    """
    seq_len = attention_matrix.shape[-1]

    # Band mask: 1 where |query - key| <= window, built without Python loops
    idx = torch.arange(seq_len, device=attention_matrix.device)
    local_mask = ((idx[:, None] - idx[None, :]).abs() <= window).to(attention_matrix.dtype)

    local_attention = (attention_matrix * local_mask).sum()
    total_attention = attention_matrix.sum()

    return local_attention / total_attention

Global Attention

Other heads attend broadly across the sequence:

  • Sentence-level semantics: Topic detection
  • Long-range dependencies: Coreference resolution
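One common way to quantify "broad" versus "focused" heads (an addition, not a method from the original post) is the entropy of each query's attention distribution: near 0 when a query attends to a single token, and up to log(seq_len) when attention is uniform. A minimal sketch, assuming one layer's (heads, seq_len, seq_len) tensor as returned by `extract()`; `attention_entropy` is a hypothetical helper name.

```python
import math

import torch

def attention_entropy(attention_matrix):
    """Mean attention entropy per head, in nats.

    attention_matrix: (heads, seq_len, seq_len) with each row summing to 1.
    Near 0: each query focuses on a single key.
    Up to log(seq_len): each query spreads attention uniformly.
    """
    # xlogy(p, p) = p * log(p), defined as 0 where p == 0
    row_entropy = -torch.special.xlogy(attention_matrix, attention_matrix).sum(dim=-1)
    # Average over query positions -> one value per head
    return row_entropy.mean(dim=-1)

seq_len = 8
uniform = torch.full((1, seq_len, seq_len), 1.0 / seq_len)  # maximally broad head
one_hot = torch.eye(seq_len).unsqueeze(0)                    # each token attends only to itself
scores = attention_entropy(torch.cat([uniform, one_hot]))

print(scores[0].item(), math.log(seq_len))  # uniform head: both about 2.079
print(scores[1].item())                     # one-hot head: zero entropy
```

Heads whose entropy stays close to log(seq_len) across inputs are candidates for the "global" behavior described above.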

Head Specialization

Analyses of trained transformers such as BERT have reported that individual heads often appear to specialize. Commonly described categories include:

| Head Type | Function | Example Pattern |
|---|---|---|
| Syntax | Grammatical structure | Subject → Verb |
| Position | Sequential order | Next/previous token |
| Semantic | Meaning relationships | Synonyms, antonyms |
| Entity | Named entities | Person → Organization |
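The "Position" row can be checked directly: a head that mostly attends to the previous (or next) token puts its weight on the sub-diagonal (or super-diagonal) of the attention matrix. A minimal sketch, assuming one layer's (heads, seq_len, seq_len) tensor with rows as queries; `positional_head_scores` is a hypothetical helper name, and the two heads below are synthetic.

```python
import torch

def positional_head_scores(attention_matrix):
    """Average attention each head pays to the previous and the next token.

    attention_matrix: (heads, seq_len, seq_len), rows = queries, cols = keys.
    Returns two (heads,) tensors; values near 1 suggest a positional head.
    """
    # offset=-1 picks entries [i, i-1]: query i -> previous token
    prev = attention_matrix.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)
    # offset=1 picks entries [i, i+1]: query i -> next token
    nxt = attention_matrix.diagonal(offset=1, dim1=-2, dim2=-1).mean(dim=-1)
    return prev, nxt

s = 6
prev_head = torch.diag(torch.ones(s - 1), diagonal=-1)  # query i attends to i-1
prev_head[0, 0] = 1.0                                   # first token has no predecessor
uniform_head = torch.full((s, s), 1.0 / s)

prev, nxt = positional_head_scores(torch.stack([prev_head, uniform_head]))
print(prev)  # previous-token head scores 1.0, uniform head about 1/6
print(nxt)   # previous-token head scores 0.0, uniform head about 1/6
```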

Practical Application: Analyzing BERT

# Example: Analyze what BERT focuses on
extractor = AttentionExtractor("bert-base-uncased")
 
text = "The cat sat on the mat because it was tired."
result = extractor.extract(text)
 
# Find which token "it" attends to most
it_index = result["tokens"].index("it")
print(f"Analyzing attention for 'it' at position {it_index}")
 
for layer in range(result["num_layers"]):
    # Average across heads
    avg_attention = result["attentions"][layer].mean(dim=0)
 
    # What does "it" attend to?
    it_attention = avg_attention[it_index]
    top_indices = it_attention.topk(3).indices.tolist()
 
    top_tokens = [result["tokens"][i] for i in top_indices]
    print(f"Layer {layer}: 'it' → {top_tokens}")

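This prints, for each layer, the three tokens that "it" attends to most when averaged over heads. Whether "cat", the correct antecedent, shows up can vary by layer, and head-averaged BERT attention is often dominated by [CLS], [SEP], and punctuation, so treat the output as a hint about how BERT handles the pronoun rather than proof of how it resolves the reference.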
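Because special tokens can crowd the top of the ranking, it can help to filter them out before picking the top-k. A minimal sketch: `top_attended_tokens` and the special-token set are assumptions (adjust the set for non-BERT tokenizers), and the attention row below is made-up illustrative data, not real BERT output.

```python
import torch

SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}  # BERT-style special tokens (assumption)

def top_attended_tokens(tokens, attn_row, k=3, skip=SPECIAL_TOKENS, query_index=None):
    """Return (token, weight) pairs for the k highest-weighted keys in attn_row,
    ignoring special tokens and, optionally, the query token itself."""
    weights = attn_row.clone()
    for i, tok in enumerate(tokens):
        if tok in skip or i == query_index:
            weights[i] = float("-inf")  # exclude from the ranking
    k = min(k, int(torch.isfinite(weights).sum()))
    top = weights.topk(k).indices.tolist()
    return [(tokens[i], round(attn_row[i].item(), 3)) for i in top]

tokens = ["[CLS]", "the", "cat", "sat", "because", "it", "was", "tired", ".", "[SEP]"]
row = torch.tensor([0.30, 0.05, 0.20, 0.04, 0.02, 0.10, 0.04, 0.02, 0.08, 0.15])
ranked = top_attended_tokens(tokens, row, k=3, query_index=5)
print(ranked)  # [('cat', 0.2), ('.', 0.08), ('the', 0.05)]
```

In the BERT loop above, the call would look like `top_attended_tokens(result["tokens"], avg_attention[it_index], query_index=it_index)`.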

Tools and Resources

Best Practices

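  1. Layer Selection: Patterns change with depth; middle layers are often a good starting point, but compare several layers before drawing conclusions
  2. Head Averaging: Averaging over heads gives a quick overview but can hide specialized heads; inspect individual heads when looking for specific behavior
  3. Normalization: Verify that attention weights are properly normalized, i.e. non-negative with each query's row summing to 1
  4. Context Length: Because each row sums to 1, weights spread over more tokens in longer sequences and individual values shrink; compare against the uniform baseline of 1/seq_len rather than raw values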
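The normalization check can be automated. A minimal sketch, assuming tensors whose last dimension indexes keys (such as the per-layer tensors from `extract()`); `check_attention_normalized` is a hypothetical helper name, and the tolerance is an arbitrary choice.

```python
import torch

def check_attention_normalized(attentions, atol=1e-4):
    """Return True if every attention row is a probability distribution.

    attentions: a tensor or a list of tensors whose last dim indexes keys.
    Rows must be non-negative and sum to 1 (within atol); NaN rows fail.
    """
    if isinstance(attentions, torch.Tensor):
        attentions = [attentions]
    for layer, attn in enumerate(attentions):
        row_sums = attn.sum(dim=-1)
        if (attn < 0).any() or not torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol):
            print(f"Layer {layer}: rows not normalized "
                  f"(min sum {row_sums.min().item():.4f}, max sum {row_sums.max().item():.4f})")
            return False
    return True

good = [torch.softmax(torch.randn(4, 5, 5), dim=-1) for _ in range(3)]
bad = good + [torch.rand(4, 5, 5)]  # raw values, not softmax-normalized

print(check_attention_normalized(good))  # True
print(check_attention_normalized(bad))   # reports layer 3, then prints False
```

A failing check usually means scores were captured before the softmax, or a fully masked row produced NaNs.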