Shape Inference

NeuroScript automatically infers tensor dimensions as data flows through your network. This powerful feature lets you write flexible, reusable neurons without specifying every dimension explicitly.

Dimension Variables

In NeuroScript, shapes are written with dimension variables such as dim, batch, and seq, which are resolved based on how neurons are connected.

neuron MyNeuron(dim):
    in:  [*, dim]        # Input: any batch dimensions + dim
    out: [*, dim * 4]    # Output: same batch dimensions + dim * 4

When you instantiate this neuron, the compiler infers the actual value of dim from the context.
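As a rough illustration of that inference step, here is a toy Python sketch (not the actual NeuroScript compiler): given a concrete input shape, it binds dim to the last axis and evaluates the declared output shape [*, dim * 4].

```python
# Toy sketch of shape inference for MyNeuron (illustrative only).
def infer_output_shape(input_shape):
    *batch, dim = input_shape        # '*' matches any leading batch dims
    return (*batch, dim * 4)         # same batch dims + dim * 4

# dim is inferred as 64 from the input, so the output is (32, 10, 256).
print(infer_output_shape((32, 10, 64)))
```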

Basic Example: Linear Projection

This neuron takes input of shape [*, dim] and produces output of shape [*, dim * 4]:

Linear Projection

The dimension variable 'dim' is used in both the input shape and the Linear layer's arguments.


Click "Show Analysis" to see the shape contracts. Notice how dim appears in:

  • The input shape [*, dim]
  • The output shape [*, dim * 4]
  • The Linear layer arguments Linear(dim, dim * 4)
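A hedged sketch of the generated PyTorch for this projection, assuming dim has been inferred as 64 from context (the concrete sizes here are illustrative):

```python
import torch
import torch.nn as nn

dim = 64                        # assumed inferred value of dim
proj = nn.Linear(dim, dim * 4)  # Linear(dim, dim * 4) from the contract

x = torch.randn(2, 10, dim)     # [*, dim]: leading batch dims (2, 10)
y = proj(x)                     # [*, dim * 4]
assert y.shape == (2, 10, dim * 4)
```

nn.Linear applies over the last axis only, which is why any number of leading batch dimensions (the * in the contract) passes through unchanged.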

Expand and Contract Pattern

A common pattern is to expand dimensions, apply transformations, then contract back:

Expand-Contract (FFN Pattern)

Dimensions expand (dim -> dim*4) then contract back (dim*4 -> dim)


This is the feed-forward network (FFN) pattern used in Transformers. The intermediate dimension is 4x the input, allowing for richer representations.
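A minimal PyTorch sketch of the pattern, again assuming dim is inferred as 64 (the GELU activation is a common choice in Transformer FFNs, not something the contract above specifies):

```python
import torch
import torch.nn as nn

dim = 64  # assumed inferred value of dim
ffn = nn.Sequential(
    nn.Linear(dim, dim * 4),  # expand:   dim -> dim * 4
    nn.GELU(),                # nonlinearity (illustrative choice)
    nn.Linear(dim * 4, dim),  # contract: dim * 4 -> dim
)

x = torch.randn(8, 16, dim)
assert ffn(x).shape == x.shape  # output shape matches input shape
```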

Multiple Dimension Variables

You can use multiple dimension variables for more complex shapes:

Multi-Head Attention Shape

Three dimension variables: batch (*), sequence (seq), and model dimension (dim)


Here:

  • * captures any batch dimensions
  • seq is the sequence length
  • dim is the model dimension
  • heads is a parameter (number of attention heads)
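To see how these variables interact in plain PyTorch, here is a hedged sketch of the usual head-splitting reshape, with illustrative concrete values for batch, seq, dim, and heads:

```python
import torch

batch, seq, dim, heads = 2, 16, 64, 8  # illustrative values

x = torch.randn(batch, seq, dim)                   # [*, seq, dim]
# Split dim into (heads, dim // heads), then move heads before seq.
q = x.view(batch, seq, heads, dim // heads).transpose(1, 2)
assert q.shape == (batch, heads, seq, dim // heads)
```

Note that dim must be divisible by heads for the split to be valid; this is exactly the kind of constraint shape inference can check when both values are known.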

How Inference Works

When neurons are connected, the compiler:

  1. Matches output shape to input shape - Dimensions align position-by-position
  2. Unifies dimension variables - If output has dim and input expects dim, they must match
  3. Evaluates expressions - Expressions like dim * 4 are computed when dim is known
  4. Propagates constraints - Inferred values flow through the entire graph
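The steps above can be sketched as a tiny unification routine in Python. This is a toy model of the idea, not the actual compiler: strings stand for dimension variables, integers for concrete sizes.

```python
# Toy unifier (illustrative only): match an output shape against an
# input shape position-by-position, binding dimension variables.
def unify(out_shape, in_shape, env=None):
    env = dict(env or {})
    for o, i in zip(out_shape, in_shape):
        o = env.get(o, o)  # substitute variables already bound
        i = env.get(i, i)
        if isinstance(i, str):    # bind input-side variable
            env[i] = o
        elif isinstance(o, str):  # bind output-side variable
            env[o] = i
        elif o != i:              # two concrete sizes must agree
            raise TypeError(f"shape mismatch: {o} vs {i}")
    return env

env = unify([32, 128], ["seq", "dim"])  # infers seq=32, dim=128
print(env["dim"] * 4)                   # expressions like dim * 4 -> 512
```

Once dim is bound, the same env can be reused to check the next connection in the graph, which is how inferred values propagate.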

Try It Yourself

Experiment with the examples above:

  • Change dimension variable names
  • Modify the expansion factor (try dim * 2 or dim * 8)
  • Add more layers to the pipeline
  • Click "Show Analysis" to see how shapes flow through connections