
Einops & Einsum

NeuroScript borrows two powerful ideas from the tensor programming world: einops-style shape transforms (via the fat arrow => operator) and Einstein summation (via the Einsum primitive). This page explains both concepts, their history, and how to use them.


A Brief History

Einstein Summation (1916)

In 1916, Albert Einstein introduced a notational shorthand in his paper on general relativity: when an index variable appears twice in a term, it implies summation over that index. Instead of writing:

c_ij = Σ_k  a_ik * b_kj

you simply write c_ij = a_ik b_kj — the repeated k index tells you to sum over it. This Einstein summation convention became standard in physics and mathematics.

In 2011, NumPy 1.6 brought this idea to programming with numpy.einsum. PyTorch, TensorFlow, and JAX all followed with their own einsum functions. The key insight: a short string like "bij,bjk->bik" can express matrix multiplication, transposition, trace, outer products, and dozens of other operations in a single call.

Einops (2018)

Alex Rogozhnikov's einops library tackled a different but related pain point. Deep learning code is full of reshape, transpose, permute, unsqueeze, and squeeze calls — operations that rearrange tensor dimensions without doing math on the values. These calls are hard to read and easy to get wrong.

Einops introduced a readable mini-language for shape transforms:

# PyTorch: x.reshape(b, h, w, c // g, g).permute(0, 3, 4, 1, 2)
# Einops: rearrange(x, 'b h w (c g) -> b c g h w', g=num_groups)

The idea: name your dimensions, describe where they go. The library figures out the reshapes and transposes for you.

NeuroScript's Approach

NeuroScript combines both ideas as first-class language features:

  • Fat arrow (=>) — einops-style shape transforms directly in pipelines
  • Einsum primitive — Einstein summation as a callable neuron

No library imports. No string parsing at runtime. The compiler knows your shapes and checks them at compile time.


Fat Arrow: Shape Transforms

The fat arrow => reshapes, transposes, splits, and merges dimensions inline. It reads left-to-right: data flows in from the left, comes out with the shape on the right.

Basic Reshape

The simplest case: rearranging dimensions.

Transpose

Swap height and width dimensions of an image tensor

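As a rough sketch of what this transform does (the NeuroScript line follows the Quick Reference syntax at the bottom of this page; the tensor math is shown in NumPy, and the generated PyTorch `permute` call is analogous — sizes are arbitrary illustration values):

```python
import numpy as np

# NeuroScript (sketch, per the Quick Reference syntax):
#   in => [b, c, w, h] -> out

x = np.random.rand(2, 3, 4, 5)     # [b, c, h, w], illustrative sizes
y = np.transpose(x, (0, 1, 3, 2))  # swap h and w; torch: x.permute(0, 1, 3, 2)
assert y.shape == (2, 3, 5, 4)     # [b, c, w, h]
```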

The compiler matches dimension names between the input shape [b, c, h, w] and the target shape [b, c, w, h], then generates the correct permute or reshape call.

Splitting Dimensions

Use binding syntax (name=expr) to decompose a single dimension into multiple:

Multi-Head Split

Split the model dimension into separate heads — the core reshape in multi-head attention

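A minimal sketch of the split, with the NeuroScript taken from the Quick Reference syntax and the tensor math written in NumPy (PyTorch's `reshape`/`permute` are analogous; the sizes are arbitrary):

```python
import numpy as np

# NeuroScript (sketch):
#   in => [batch, seq, heads, dh=dim/heads]
#      => [batch, heads, seq, dh] -> out

batch, seq, dim, heads = 2, 10, 64, 8            # illustrative sizes
x = np.random.rand(batch, seq, dim)
x = x.reshape(batch, seq, heads, dim // heads)   # split dim into (heads, dh)
x = np.transpose(x, (0, 2, 1, 3))                # move heads before seq
assert x.shape == (2, 8, 10, 8)                  # [batch, heads, seq, dh]
```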

Here's what happens step by step:

  1. Input arrives as [batch, seq, dim]
  2. First => splits dim into heads and dh (where dh = dim / heads), producing [batch, seq, heads, dh]
  3. Second => reorders to [batch, heads, seq, dh]

This is the reshape that every multi-head attention implementation needs — expressed in one line.

Binding expressions and shape inference

Dimension bindings like dh=dim/heads are tracked by name, but the arithmetic constraint is not yet propagated through shape inference — the compiler treats dh as a free variable. Full end-to-end checking of binding math is a known TODO.

Merging Dimensions

The reverse: merge multiple dimensions into one.

ViT Flatten

Merge spatial dimensions for Vision Transformer patch processing

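A sketch of the merge, assuming the `hw=h*w` binding syntax from the Quick Reference; the runnable half is NumPy (PyTorch's `reshape`/`permute` are analogous):

```python
import numpy as np

# NeuroScript (sketch):
#   in => [b, c, hw=h*w] => [b, hw, c] -> out

b, c, h, w = 2, 3, 4, 5            # illustrative sizes
x = np.random.rand(b, c, h, w)
x = x.reshape(b, c, h * w)         # hw = h*w: merge spatial dims
x = np.transpose(x, (0, 2, 1))     # channels last: [b, hw, c]
assert x.shape == (2, 20, 3)
```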

hw=h*w merges the height and width dimensions into a single spatial dimension, then the second arrow transposes channels to the last position. This is the spatial flattening step in Vision Transformers.

The others Keyword

When you want to flatten everything after a certain point without naming each dimension:

Flatten Tail

Collapse all trailing dimensions into one using the others keyword

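A sketch of the flatten, using the `others` form from the Quick Reference; as the note below explains, it compiles to a `-1` (inferred) dimension, which NumPy supports the same way:

```python
import numpy as np

# NeuroScript (sketch): in => [b, others] -> out

x = np.random.rand(2, 3, 4, 5)
flat = x.reshape(x.shape[0], -1)   # `others` compiles to -1 (infer size)
assert flat.shape == (2, 60)       # 3 * 4 * 5 = 60 trailing elements merged
```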

others acts as a wildcard that absorbs all remaining dimensions and flattens them into one. It compiles to PyTorch's -1 dimension (infer size). Note that others is reshape-only syntax — it can't appear in port declarations. The output port uses flat as a free dimension variable whose size is inferred from the collapsed result.


Annotated Transforms: @reduce and @repeat

Plain => preserves the total number of elements — it only rearranges them. But sometimes you need to remove or add dimensions. That's what annotations are for.

@reduce — Collapsing Dimensions

@reduce removes dimensions by applying a reduction function. The compiler figures out which dimensions to reduce by comparing the source and target shapes.

Global Average Pooling

Reduce spatial dimensions with mean — a single line replaces AdaptiveAvgPool2d

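A sketch of the reduction, with the NeuroScript line taken from the Quick Reference and the equivalent tensor math in NumPy (the generated PyTorch call is noted in the comment):

```python
import numpy as np

# NeuroScript (sketch): in => @reduce(mean) [b, c] -> out

x = np.random.rand(2, 3, 4, 5)  # [b, c, h, w], illustrative sizes
y = x.mean(axis=(2, 3))         # torch: x.mean(dim=(2, 3))
assert y.shape == (2, 3)        # [b, c]
```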

The input has [b, c, h, w] and the target has [b, c]. The compiler sees that h and w are missing from the target and generates .mean(dim=(2, 3)).

Available reduction strategies:

Strategy  | PyTorch equivalent   | Use case
mean      | .mean(dim=...)       | Average pooling, feature aggregation
sum       | .sum(dim=...)        | Accumulation, counting
max       | .amax(dim=...)       | Max pooling, peak detection
min       | .amin(dim=...)       | Minimum finding
prod      | .prod(dim=...)       | Product reduction
logsumexp | .logsumexp(dim=...)  | Numerically stable log-sum-exp

Here's a sequence-level reduction — common in classification where you need a single vector per sequence:

Sequence Mean Pooling

Reduce the sequence dimension to get a fixed-size representation

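A sketch of the sequence-level reduction, assuming the same `@reduce` syntax as above; the NumPy half mirrors what a `.mean(dim=1)` in the PyTorch output would do:

```python
import numpy as np

# NeuroScript (sketch): in => @reduce(mean) [batch, dim] -> out

x = np.random.rand(2, 10, 64)   # [batch, seq, dim], illustrative sizes
y = x.mean(axis=1)              # reduce the seq dimension
assert y.shape == (2, 64)       # one fixed-size vector per sequence
```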

@reduce with a Neuron

Instead of a built-in function, you can use a learned reduction — pass a neuron that processes the dimensions being removed:

neuron LearnedPool(dim):
  in: [batch, seq, dim]
  out: [batch, dim]
  graph:
    in => @reduce(AttentionPool(dim)) [batch, dim] -> out

The compiler instantiates AttentionPool and calls it to perform the reduction, giving you a learnable alternative to mean pooling.

@repeat — Adding Dimensions

@repeat is the inverse of @reduce: it adds new dimensions by copying or broadcasting data.

Expand Dimension

Add a new dimension of size 1 by repeating (broadcasting) the tensor

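A sketch of the expansion, with the NeuroScript line from the Quick Reference; NumPy's `expand_dims` plays the role of `unsqueeze` here, and like PyTorch's `expand` it returns a view:

```python
import numpy as np

# NeuroScript (sketch): in => @repeat(copy) [b, c, 1, h, w] -> out

x = np.random.rand(2, 3, 4, 5)  # [b, c, h, w], illustrative sizes
y = np.expand_dims(x, 2)        # torch: x.unsqueeze(2)
assert y.shape == (2, 3, 1, 4, 5)
assert np.shares_memory(x, y)   # still a view: no data copied
```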

The compiler sees a new dimension (size 1) that doesn't exist in the source and generates unsqueeze + expand.

@repeat(copy) creates a view, not a copy

Despite the name, PyTorch's expand creates a view — no data is copied in memory. In-place operations on the expanded tensor will raise an error; call .contiguous() first if you need a true copy.


Chaining Transforms in Pipelines

Fat arrows compose naturally with neuron calls in a pipeline. You can reshape, process, and reshape again:

Attention Head Pipeline

Split into heads, process, merge back — a common attention pattern

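A sketch of the round-trip in NumPy, under stated assumptions: `W` is a hypothetical stand-in for the projection's learned weights (in real NeuroScript this step would be a neuron call in the pipeline), and all sizes are illustrative:

```python
import numpy as np

batch, seq, dim, heads = 2, 10, 64, 8  # illustrative sizes
dh = dim // heads
W = np.random.rand(dh, dh)             # hypothetical stand-in for projection weights

x = np.random.rand(batch, seq, dim)
x = x.reshape(batch, seq, heads, dh)   # 1. split dim into heads x dh
x = np.transpose(x, (0, 2, 1, 3))      # 2. move heads before seq
x = x @ W                              # 3. per-head linear (placeholder for a neuron)
x = np.transpose(x, (0, 2, 1, 3))      # 4. move heads back after seq
x = x.reshape(batch, seq, dim)         # 5. merge heads and dh back into dim
assert x.shape == (2, 10, 64)          # same shape in, same shape out
```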

This is a full reshape round-trip:

  1. Split dim into heads x dh
  2. Move heads before seq
  3. Apply a per-head linear projection
  4. Move heads back after seq
  5. Merge heads and dh back into dim

All shape transitions are checked at compile time (see the note above about binding expression limitations).


Einsum: Einstein Summation

For operations that combine values across dimensions (not just rearranging them), NeuroScript provides the Einsum primitive. It takes an equation string using Einstein notation.

Reading Einsum Equations

An einsum equation has the form inputs -> output, where each tensor is described by its index letters:

"bij,bjk->bik"
 ^^^ ^^^  ^^^
  |   |    └── output: batch × i × k
  |   └─────── second input: batch × j × k
  └──────────── first input: batch × i × j

The rule: indices that appear on the left but not on the right are summed over. Here j appears in both inputs but not the output, so it gets summed — that's matrix multiplication.

Common Einsum Patterns

Here are the most useful patterns, each expressed as an equation:

Equation        | Operation        | Description
ij->ji          | Transpose        | Swap rows and columns
ii->i           | Diagonal         | Extract the diagonal
ii->            | Trace            | Sum of diagonal elements
ij,jk->ik       | Matrix multiply  | Standard matmul
bij,bjk->bik    | Batched matmul   | Matmul with batch dimension
bhqd,bhkd->bhqk | Attention scores | Query-key dot products
i,j->ij         | Outer product    | Every pair multiplied
ij->i           | Row sum          | Sum over columns
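These identities are easy to check against numpy.einsum, which uses the same notation:

```python
import numpy as np

a = np.arange(6, dtype=float).reshape(2, 3)
v = np.array([1.0, 2.0])

assert np.allclose(np.einsum('ij->ji', a), a.T)                 # transpose
assert np.allclose(np.einsum('ij,jk->ik', a, a.T), a @ a.T)     # matrix multiply
sq = a @ a.T
assert np.isclose(np.einsum('ii->', sq), np.trace(sq))          # trace
assert np.allclose(np.einsum('i,j->ij', v, v), np.outer(v, v))  # outer product
assert np.allclose(np.einsum('ij->i', a), a.sum(axis=1))        # row sum
```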

Using Einsum in NeuroScript

Einsum is a neuron with an equation parameter. Multi-input operations receive their inputs as a tuple — tuple elements map positionally to Einsum's named input ports (a, b) in declaration order:

Batched Matrix Multiply

Einsum for batched matrix multiplication using Einstein notation

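The contraction itself can be verified directly with numpy.einsum (NeuroScript's Einsum neuron uses the same equation string; sizes here are arbitrary):

```python
import numpy as np

left = np.random.rand(4, 2, 3)    # [batch, i, j]
right = np.random.rand(4, 3, 5)   # [batch, j, k]
out = np.einsum('bij,bjk->bik', left, right)
assert out.shape == (4, 2, 5)            # [batch, i, k]
assert np.allclose(out, left @ right)    # identical to batched matmul
```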

Batched matmul needs two distinct inputs — left with shape [batch, i, j] and right with shape [batch, j, k]. The equation `bij,bjk->bik` contracts over the j index, producing [batch, i, k].

Attention Scores

The core computation in attention — computing query-key similarity. Here the input is forked to serve as both queries and keys (self-attention):

Self-Attention Scores

Query-key dot product via einsum — the heart of the attention mechanism

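The same equation checked with numpy.einsum (arbitrary sizes; in NeuroScript the forked input would arrive as the `(q, k)` tuple shown in the Quick Reference):

```python
import numpy as np

q = np.random.rand(2, 8, 10, 16)  # [batch, heads, q_len, d_head]
k = np.random.rand(2, 8, 12, 16)  # [batch, heads, k_len, d_head]
scores = np.einsum('bhqd,bhkd->bhqk', q, k)  # sum over d_head
assert scores.shape == (2, 8, 10, 12)        # one score per query-key pair
```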

d_head appears in both inputs but not the output, so it gets summed — producing the dot product between each query and key vector.


Fat Arrow vs Einsum: When to Use Which

The two features serve different purposes:

                     | Fat Arrow (=>)                   | Einsum
What it does         | Rearranges tensor layout         | Computes over tensor values
Element count        | Preserved (unless annotated)     | Can change (contractions)
Learned params       | None (pure shape op)             | None (pure math op)
Compile-time checked | Yes (full shape validation)      | Partial (variadic shapes)
Typical use          | Reshape, transpose, split, merge | Matmul, dot product, contraction

Rule of thumb: if you're just moving data around (reshape, transpose, split heads), use =>. If you're multiplying and summing across indices, use Einsum.


Putting It Together

A real-world example combining both: self-attention scores with multi-head reshaping.

Multi-Head Self-Attention Scores

Fork input, reshape each branch into heads with fat arrow, compute scores with einsum

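A runnable sketch of the whole pipeline in NumPy, under stated assumptions: the fork is modeled as two uses of the same tensor, the `split_heads` helper is a hypothetical name standing in for the two fat-arrow steps, and all sizes are illustrative:

```python
import numpy as np

batch, seq, dim, heads = 2, 10, 64, 8  # illustrative sizes
dh = dim // heads
x = np.random.rand(batch, seq, dim)

def split_heads(t):
    # => [batch, seq, heads, dh] => [batch, heads, seq, dh]
    return np.transpose(t.reshape(batch, seq, heads, dh), (0, 2, 1, 3))

q, k = split_heads(x), split_heads(x)        # fork (self-attention)
scores = np.einsum('bhqd,bhkd->bhqk', q, k)  # contract the dh dimension
assert scores.shape == (2, 8, 10, 10)        # [batch, heads, seq, seq]
```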

This shows the two features working together:

  • The input is forked into q and k branches (self-attention)
  • Fat arrows reshape each branch from [batch, seq, dim] into [batch, heads, seq, dh]
  • The reshaped tensors are joined as a tuple and passed to Einsum, which contracts the dh dimension

Quick Reference

Fat Arrow Syntax

# Basic reshape/transpose
in => [b, c, w, h] -> out

# Split a dimension
in => [batch, seq, heads, dh=dim/heads] -> out

# Merge dimensions
in => [b, c, hw=h*w] -> out

# Flatten remaining dims
in => [b, others] -> out

# Reduce dimensions
in => @reduce(mean) [b, c] -> out
in => @reduce(sum) [batch, dim] -> out

# Reduce with a neuron
in => @reduce(AttentionPool(dim)) [batch, dim] -> out

# Add dimensions
in => @repeat(copy) [b, c, 1, h, w] -> out

# Chain transforms with neuron calls (-> before the neuron call)
in ->
  => [batch, seq, heads, dh=dim/heads]
  => [batch, heads, seq, dh]
  SomeNeuron() ->
  => [batch, seq, dim]
  out

Einsum Syntax

# Two-input operations (fork input, pass as tuple)
in -> (a, b)
(a, b) -> Einsum(`bij,bjk->bik`) -> out # batched matmul

in -> (q, k)
(q, k) -> Einsum(`bhqd,bhkd->bhqk`) -> out # attention scores

Try It Yourself

Experiment with these concepts:

  • Change @reduce(mean) to @reduce(sum) or @reduce(max) and see how the generated code changes
  • Try chaining multiple => steps to build complex reshapes
  • Combine => with neuron calls in a pipeline
  • Write an einsum equation for an outer product: "i,j->ij"
  • Click "Show Analysis" to see the compiled output