Automatic Dim Labelling with Numpyro?

numpyro
tensors
ArviZ
Published

February 23, 2025

Code
from typing import Optional, Callable, List, Dict
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import inspect, Predictive, MCMC, NUTS

import jaxlib
import jax
import jax.numpy as jnp
from jax import random
from jaxtyping import Float, Array, Int


import arviz as az
numpyro.set_host_device_count(2)
SEED: int = 99

print(
    "numpyro version:", numpyro.__version__,
    "\njax version:", jax.__version__,
    "\njaxlib version", jaxlib.__version__,
    "\nArviZ version: ", az.__version__
)
numpyro version: 0.17.0 
jax version: 0.5.0 
jaxlib version 0.5.0 
ArviZ version:  0.20.0

The goal of this post is to figure out how to use numpyro internals to auto-label variable dimensions for ArviZ. PyMC is heavily integrated with ArviZ and their dimension labelling is fantastic - I want it to be just as easy with numpyro.

This isn’t going to be a very fun blog post; it’s meant to follow my thought process as I work through this problem. Maybe that’s helpful for some people? But mostly it will be a good reference for myself in the future.

Part 1: Setting a baseline

What are coords and dims?

I’m assuming readers are familiar with this, but here is a quick refresher. ArviZ takes in a Bayesian posterior and organizes it as an xarray dataset - picture pandas but for tensors. ArviZ expects coords and dims to map parameter dimensions to names and categories, as shown below.

coords = {
    "category": ['A', 'B', 'C', 'D', 'E'],
    'features': ['X1', 'X2', 'X3']
}

dims = {
    'beta': ['category', 'features'],
    "gamma": ['category']
}
idata = az.from_numpyro(mcmc_object, coords=coords, dims=dims)

The coords map the category names to a positional index, where A corresponds to index 0, B corresponds to index 1, and so on.

Our dims tell us that our gamma parameter is indexed by category. This would estimate a gamma parameter for each category, i.e. gamma[0] is the gamma estimate for category A. beta in this case has a separate estimate for each category {A, B, C, D, E} and for each feature {X1, X2, X3}.
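
To make the label-to-index mapping concrete, here is a minimal sketch in plain numpy. It is not what ArviZ or xarray do internally: the `draws` arrays are random placeholders, `select` is a hypothetical helper, and a single leading draw axis stands in for (chain, draw).

```python
import numpy as np

coords = {
    "category": ["A", "B", "C", "D", "E"],
    "features": ["X1", "X2", "X3"],
}
dims = {"beta": ["category", "features"], "gamma": ["category"]}

# Placeholder "posterior" draws: a leading draw axis, then the labelled axes.
rng = np.random.default_rng(0)
draws = {
    "beta": rng.normal(size=(4, 5, 3)),
    "gamma": rng.normal(size=(4, 5)),
}

def select(name, **labels):
    """Hypothetical helper: turn label lookups into positional indexing."""
    index = [slice(None)]  # keep every draw
    for dim in dims[name]:
        if dim in labels:
            # coords give the position of each label along this dimension
            index.append(coords[dim].index(labels[dim]))
        else:
            index.append(slice(None))
    return draws[name][tuple(index)]

gamma_A = select("gamma", category="A")                   # draws["gamma"][:, 0]
beta_A_X2 = select("beta", category="A", features="X2")   # draws["beta"][:, 0, 1]
```

The labels only ever get translated into positions, which is why a wrong or misordered dims entry silently attaches names to the wrong axis.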

This is helpful for so many reasons - plots are easier to read with the category names displayed for each parameter, and you can do operations using the dimension names.

The code snippet below returns the average estimate for beta when category=A, which is extremely easy to read:

beta = idata.posterior['beta']
avg_beta_groupA = beta.sel(category='A').mean(("chain", "draw"))

A properly labelled InferenceData object from ArviZ is a massive quality of life improvement.

Numpyro’s default behavior

Let’s explore numpyro’s default behavior and see where it falls short. We’ll start with simulating some regression data

rng = np.random.default_rng(SEED)
N, n_feats = 1000, 3
z_cardinality = 10

alpha = 1.3
beta = rng.normal(size=n_feats)
gamma = rng.normal(size=z_cardinality)
sigma = 1

X = rng.normal(size=(N, n_feats))
z = rng.choice(range(z_cardinality), size=N)

mu = alpha + np.dot(X, beta) + gamma[z]
y = rng.normal(mu, sigma)

# we'll make a coords dictionary for later
coords = {
    "Z": ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    "obs_id": np.arange(N),
    "features": ['X1', 'X2', 'X3'],
}

Next, let’s fit a linear regression model and pass it into ArviZ to see what the dims for our parameters look like

Code
def run_inference(
    model: Callable,
    key: random.PRNGKey,
    num_warmup: int = 50, 
    num_chains: int = 2, 
    num_samples: int = 50,
    **kwargs
) -> MCMC:
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_chains=num_chains, num_samples=num_samples, progress_bar=False)
    mcmc.run(key, **kwargs)
    return mcmc

def print_site_dims(idata, vars=('alpha', 'beta', 'mu')):
    # Print the dimension names ArviZ assigned to each posterior variable.
    for var in vars:
        print(f"{var}:", idata.posterior[var].dims)

def model(
    X: Float[Array, "obs features"], 
    Z: Float[Array, " obs"], 
    y: Float[Array, " obs"] = None
) -> Float[Array, " obs"]:

    n_features = X.shape[-1]
    n_groups = len(np.unique(Z))
    
    with numpyro.plate("Z", n_groups):
        alpha = numpyro.sample("alpha", dist.Normal(0, 3))

    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with numpyro.plate("features", n_features):
        beta = numpyro.sample("beta", dist.Normal(0, 1))

    with numpyro.plate("obs_id", X.shape[0]):
        mu = numpyro.deterministic("mu", alpha[Z] + jnp.dot(X, beta))
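        # Note: `obs=y` is not passed here, so `y` is sampled as a latent site
        # rather than conditioned on the data.
        return numpyro.sample("y", dist.Normal(mu, sigma))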

mcmc = run_inference(model, random.PRNGKey(0), X=X, Z=z, y=y)
idata = az.from_numpyro(mcmc)
print_site_dims(idata)
alpha: ('chain', 'draw', 'alpha_dim_0')
beta: ('chain', 'draw', 'beta_dim_0')
mu: ('chain', 'draw', 'mu_dim_0')

What we really want is for those dimensions labeled “{site}_dim_0” to be labelled by their plate names, so we need to see if there’s a way to automatically grab the plate metadata. Luckily there’s an inspect module in numpyro that seems like it might cover some of this.

Pulling model relations

model_relations = inspect.get_model_relations(model, model_kwargs=dict(X=X, Z=z, y=y))
model_relations
{'sample_sample': {'alpha': [],
  'sigma': [],
  'beta': [],
  'mu': ['alpha', 'beta'],
  'y': ['sigma', 'mu']},
 'sample_param': {'alpha': [], 'sigma': [], 'beta': [], 'mu': [], 'y': []},
 'sample_dist': {'alpha': 'Normal',
  'sigma': 'HalfNormal',
  'beta': 'Normal',
  'mu': 'Deterministic',
  'y': 'Normal'},
 'param_constraint': {},
 'plate_sample': {'Z': ['alpha'], 'features': ['beta'], 'obs_id': ['mu', 'y']},
 'observed': []}

It looks like plate_sample has exactly what we need, it’s just keyed by plate instead of by site. It should really be formatted as

{
    "alpha": ["Z"],
    "beta": ["features"],
    "mu": ["obs_id"],
    "y": ["obs_id"]
}

All we need to do is make a quick dictionary reversing function

from collections import defaultdict
def reverse_plate_mapping(plate_mapping: Dict[str, List[str]]) -> Dict[str, List[str]]:
    reversed_map = defaultdict(list)
    for key, values in plate_mapping.items():
        for value in values:
            reversed_map[value].append(key)
    return reversed_map

dims = reverse_plate_mapping(model_relations['plate_sample'])

idata = az.from_numpyro(mcmc, coords=coords, dims=dims)
print_site_dims(idata)
alpha: ('chain', 'draw', 'Z')
beta: ('chain', 'draw', 'features')
mu: ('chain', 'draw', 'obs_id')

Perfect! Next let’s make sure this works when a site is within multiple plates - I’m concerned that my approach for reversing the plate mapping won’t get the order right in that case

Part 2: Adding an edge case with multiple plates per site

Does our approach preserve the dimension order?

We’ll have to simulate some new data where \(\beta\) has group-specific effects. I’m going to make each feature have a group specific effect so that I can test if my solutions will work with nested plates.

# Simulate new data
beta = rng.normal(size=(z_cardinality, n_feats))
mu = alpha + (X * beta[z]).sum(-1) + gamma[z]
y = rng.normal(mu, sigma)

def model2(
    X: Float[Array, "obs features"], 
    Z: Float[Array, " obs"], 
    y: Float[Array, " obs"] = None
) -> Float[Array, " obs"]:

    feature_plate = numpyro.plate("features", X.shape[-1], dim=-2)
    group_plate = numpyro.plate("Z", len(np.unique(Z)), dim=-1)
    
    with group_plate:
        alpha = numpyro.sample("alpha", dist.Normal(0, 3))

    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with group_plate, feature_plate:
        beta = numpyro.sample("beta", dist.Normal(0, 1))

    with numpyro.plate("obs_id", X.shape[0]):
        mu = numpyro.deterministic("mu", alpha[Z] + (X*beta[:,Z].T).sum(-1))
        return numpyro.sample("y", dist.Normal(mu, sigma))

mcmc = run_inference(model2, random.PRNGKey(0), X=X, Z=z, y=y)
idata = az.from_numpyro(mcmc)
print_site_dims(idata)
alpha: ('chain', 'draw', 'alpha_dim_0')
beta: ('chain', 'draw', 'beta_dim_0', 'beta_dim_1')
mu: ('chain', 'draw', 'mu_dim_0')
model2_relations = inspect.get_model_relations(model2, model_kwargs=dict(X=X, Z=z, y=y))
model2_relations
{'sample_sample': {'alpha': [],
  'sigma': [],
  'beta': [],
  'mu': ['alpha', 'beta'],
  'y': ['sigma', 'mu']},
 'sample_param': {'alpha': [], 'sigma': [], 'beta': [], 'mu': [], 'y': []},
 'sample_dist': {'alpha': 'Normal',
  'sigma': 'HalfNormal',
  'beta': 'Normal',
  'mu': 'Deterministic',
  'y': 'Normal'},
 'param_constraint': {},
 'plate_sample': {'features': ['beta'],
  'Z': ['alpha', 'beta'],
  'obs_id': ['mu', 'y']},
 'observed': []}

This is tricky. Beta should have dims (chain, draw, features, Z), but how do we know the plate_sample dictionary will return the correct order? What happens if we reverse the plate order?

Code
def model2_reversed(
    X: Float[Array, "obs features"], 
    Z: Float[Array, " obs"], 
    y: Float[Array, " obs"] = None
) -> Float[Array, " obs"]:

    # reversed the plates
    feature_plate = numpyro.plate("features", X.shape[-1], dim=-1)
    group_plate = numpyro.plate("Z", len(np.unique(Z)), dim=-2)
    
    with group_plate:
        alpha = numpyro.sample("alpha", dist.Normal(0, 3))

    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with feature_plate, group_plate:
        beta = numpyro.sample("beta", dist.Normal(0, 1))

    with numpyro.plate("obs_id", X.shape[0]):
        mu = numpyro.deterministic("mu", alpha[Z].squeeze(-1) + (X*beta[Z]).sum(-1))
        return numpyro.sample("y", dist.Normal(mu, sigma))

mcmc_reversed = run_inference(model2_reversed, random.PRNGKey(0), X=X, Z=z, y=y)
model2_reversed_relations = inspect.get_model_relations(model2_reversed, model_kwargs=dict(X=X, Z=z, y=y))

print(model2_relations['plate_sample'])
print(model2_reversed_relations['plate_sample'])
{'features': ['beta'], 'Z': ['alpha', 'beta'], 'obs_id': ['mu', 'y']}
{'features': ['beta'], 'Z': ['alpha', 'beta'], 'obs_id': ['mu', 'y']}

This is bad news - comparing the two plate_sample outputs above, we see the same result despite the plates having different dims. plate_sample doesn’t preserve the plate order, so our previous approach would end up mislabelling dims.
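
The mislabelling can be checked without running any model. The sketch below reuses the plate_sample dictionary printed above together with the beta shapes implied by the `dim` arguments in each model; `coord_sizes`, `beta_shapes` and `labels_match` are names introduced just for illustration.

```python
from collections import defaultdict

def reverse_plate_mapping(plate_mapping):
    reversed_map = defaultdict(list)
    for plate, sites in plate_mapping.items():
        for site in sites:
            reversed_map[site].append(plate)
    return dict(reversed_map)

# Both models return the same plate_sample dictionary.
plate_sample = {"features": ["beta"], "Z": ["alpha", "beta"], "obs_id": ["mu", "y"]}
dims = reverse_plate_mapping(plate_sample)

coord_sizes = {"features": 3, "Z": 10}
# beta shapes implied by the plate `dim` arguments in each model
beta_shapes = {"model2": (3, 10), "model2_reversed": (10, 3)}

# Do the sizes implied by the reversed mapping line up with the actual axes?
labels_match = {
    name: tuple(coord_sizes[d] for d in dims["beta"]) == shape
    for name, shape in beta_shapes.items()
}

print(dims["beta"])   # ['features', 'Z']
print(labels_match)   # {'model2': True, 'model2_reversed': False}
```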

Finding better plate metadata

I took a closer look at the get_model_relations function and it looks like this is where they’re pulling plate information

def get_model_relations(model, model_args=None, model_kwargs=None):
    ...

    sample_plates = {
        name: [frame.name for frame in site["cond_indep_stack"]]
        for name, site in trace.items()
        if site["type"] in ["sample", "deterministic"]
    }

    ...

I’m going to try and see what other information is in there by copying their approach of pulling a model trace. Some helper functions are hidden in the code-fold below

Code
from functools import partial
from numpyro import handlers
from numpyro.infer.initialization import init_to_sample
from numpyro.ops.pytree import PytreeTrace

model_kwargs = dict(X=X, Z=z, y=y)

def _get_dist_name(fn):
    if isinstance(
        fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)
    ):
        return _get_dist_name(fn.base_dist)
    return type(fn).__name__

def get_trace(model, model_kwargs):
    # We use `init_to_sample` to get around ImproperUniform distribution,
    # which does not have `sample` method.
    subs_model = handlers.substitute(
        handlers.seed(model, 0),
        substitute_fn=init_to_sample,
    )
    trace = handlers.trace(subs_model).get_trace(**model_kwargs)
    # Work around an issue where jax.eval_shape does not work
    # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
    # Here we will remove `fn` and store its name in the trace.
    for name, site in trace.items():
        if site["type"] == "sample":
            site["fn_name"] = _get_dist_name(site.pop("fn"))
        elif site["type"] == "deterministic":
            site["fn_name"] = "Deterministic"
    return PytreeTrace(trace)
trace = jax.eval_shape(partial(get_trace, model2, model_kwargs)).trace
trace['beta']
{'args': (),
 'intermediates': [],
 'value': ShapeDtypeStruct(shape=(3, 10), dtype=float32),
 '_control_flow_done': True,
 'type': 'sample',
 'name': 'beta',
 'kwargs': {'rng_key': None, 'sample_shape': ()},
 'scale': None,
 'is_observed': False,
 'cond_indep_stack': [CondIndepStackFrame(name='features', dim=-2, size=3),
  CondIndepStackFrame(name='Z', dim=-1, size=10)],
 'infer': {},
 'fn_name': 'Normal'}

This is perfect! For each site, we can see the plates they are nested within and the dimensions of those plates.

We also see for the reversed model, the order of the plates and the dim values follow a consistent pattern.

trace_reversed = jax.eval_shape(partial(get_trace, model2_reversed, model_kwargs)).trace
trace_reversed['beta']
{'args': (),
 'intermediates': [],
 'value': ShapeDtypeStruct(shape=(10, 3), dtype=float32),
 '_control_flow_done': True,
 'type': 'sample',
 'name': 'beta',
 'kwargs': {'rng_key': None, 'sample_shape': ()},
 'scale': None,
 'is_observed': False,
 'cond_indep_stack': [CondIndepStackFrame(name='Z', dim=-2, size=10),
  CondIndepStackFrame(name='features', dim=-1, size=3)],
 'infer': {},
 'fn_name': 'Normal'}

We should be able to make a working dims mapping from this for each site in our model.

It turns out that they already do this for us on L337

sample_plates = {
    name: [frame.name for frame in site["cond_indep_stack"]]
    for name, site in trace.items()
    if site["type"] in ["sample", "deterministic"]
}

The working implementation is below (but could probably be cleaned up)

def get_site_dims(model: Callable, **kwargs):
    def _get_dist_name(fn):
        if isinstance(
            fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)
        ):
            return _get_dist_name(fn.base_dist)
        return type(fn).__name__

    def get_trace():
        # We use `init_to_sample` to get around ImproperUniform distribution,
        # which does not have `sample` method.
        subs_model = handlers.substitute(
            handlers.seed(model, 0),
            substitute_fn=init_to_sample,
        )
        trace = handlers.trace(subs_model).get_trace(**kwargs)
        # Work around an issue where jax.eval_shape does not work
        # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
        # Here we will remove `fn` and store its name in the trace.
        for name, site in trace.items():
            if site["type"] == "sample":
                site["fn_name"] = _get_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
        return PytreeTrace(trace)

    # We use eval_shape to avoid any array computation.
    trace = jax.eval_shape(get_trace).trace
    sample_plates = {
        name: [frame.name for frame in site["cond_indep_stack"]]
        for name, site in trace.items()
        if site["type"] in ["sample", "deterministic"]
    }
    return {k:v for k,v in sample_plates.items() if len(v) > 0}


idata = az.from_numpyro(
    mcmc,
    dims=get_site_dims(model2, **model_kwargs),
    coords=coords
)

print_site_dims(idata)
alpha: ('chain', 'draw', 'Z')
beta: ('chain', 'draw', 'features', 'Z')
mu: ('chain', 'draw', 'obs_id')
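
The implementation above takes the frame names in whatever order they appear in the stack. A slightly more defensive variant, sketched here, orders them by `dim` instead, since more negative dims sit further to the left. It uses a stand-in namedtuple with the same fields as the frames printed in the traces; `dims_from_frames` is a hypothetical helper, and it assumes the plates occupy contiguous dims with no gaps.

```python
from collections import namedtuple

# Stand-in with the same fields as the frames shown in the traces above.
Frame = namedtuple("Frame", ["name", "dim", "size"])

def dims_from_frames(frames):
    """Order plate names left to right by their (negative) dim."""
    return [f.name for f in sorted(frames, key=lambda f: f.dim)]

model2_beta = [Frame("features", -2, 3), Frame("Z", -1, 10)]
reversed_beta = [Frame("Z", -2, 10), Frame("features", -1, 3)]
# Same plates as model2, but listed in the opposite stack order
shuffled = [Frame("Z", -1, 10), Frame("features", -2, 3)]

print(dims_from_frames(model2_beta))    # ['features', 'Z']
print(dims_from_frames(reversed_beta))  # ['Z', 'features']
print(dims_from_frames(shuffled))       # ['features', 'Z']
```

Sorting explicitly means the labels follow the axes even if the frames are ever listed in a different order.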

Part 3: What happens with the ZeroSumNormal distribution

A quick tangent on tensor shapes and why the ZeroSumNormal is different

I recently ported the ZeroSumNormal distribution from PyMC to numpyro and it follows some different conventions than the typical distribution.

To understand it, you probably need to have some background on tensor shapes - there’s a great blog from Eric Ma here.

Typically in numpyro, when we put a sample site within a plate, the batch shape is determined by the plates. Entries along batch dimensions are independent, though not necessarily identically distributed. For instance, our previous alpha parameter was stratified by 10 groups, so we get an independent draw of alpha for each group.

The ZeroSumNormal instead uses an event shape rather than a batch shape. Entries within an event shape can depend on each other, and here the zero-sum constraint is what creates that dependency.
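
To see the dependency that a zero-sum constraint introduces, here is a numpy-only sketch. It imposes the constraint by subtracting the mean from independent normal draws, which is one simple way to illustrate the idea and is not taken from the numpyro implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_groups = 5000, 10

# Independent draws per group, like a batch dimension under a plate
raw = rng.normal(size=(n_draws, n_groups))

# Impose the zero-sum constraint across the group axis (assumed construction)
zero_sum = raw - raw.mean(axis=-1, keepdims=True)

row_sums = zero_sum.sum(axis=-1)  # numerically zero for every draw

# Correlation between the first two groups, before and after the constraint
corr_raw = np.corrcoef(raw[:, 0], raw[:, 1])[0, 1]
corr_zero_sum = np.corrcoef(zero_sum[:, 0], zero_sum[:, 1])[0, 1]

print(np.abs(row_sums).max())
print(corr_raw, corr_zero_sum)
```

The unconstrained groups are roughly uncorrelated, while the constrained ones pick up a negative correlation: knowing nine of the values pins down the tenth.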

While we typically might define a categorical parameter like this in numpyro:

with numpyro.plate("groups", n_groups):
    gamma = numpyro.sample("gamma", dist.Normal(0,1))

with the ZeroSumNormal we define categorical parameters like this:

gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

We need to figure out how to have dims mapped for the ZeroSumNormal despite it not using a plate. The first test will be seeing what happens when we do nest it under a plate

def test_model():
    n_groups = 10
    with numpyro.plate("groups", n_groups):
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

mcmc = run_inference(test_model, random.PRNGKey(0), num_chains=1)
mcmc.get_samples()['gamma'].shape
(50, 10, 10)

Unfortunately this doesn’t work: it creates a batch_shape of 10 on top of the event_shape of 10, when we only want the event_shape.

A custom primitive for labelling event dims

This is a problem that I’m not sure is solvable with the current tools. But what if we could create a primitive that saves dim names for us in the trace without actually doing anything else? i.e.

# option 1
with pseudo_plate("groups", n_groups):
    gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

# option 2
with site_labeller("groups"):
    gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

The main idea is that we could store information that’s retrievable from the trace, mimicking plates but without actually expanding the parameter shape like a plate would.

Below is an implementation and a quick test model to make sure it has the correct shape saved

from numpyro.primitives import Messenger, Message, CondIndepStackFrame

class pseudo_plate(numpyro.plate):
    
    def __init__(
        self,
        name: str,
        size: int,
        subsample_size: Optional[int] = None,
        dim: Optional[int] = None,
    ) -> None:
        self.name = name
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError("dim arg must be negative.")
        self.dim, self._indices = self._subsample(
            self.name, self.size, subsample_size, dim
        )
        self.subsample_size = self._indices.shape[0]

    # We'll try only adding our pseudoplate to the CondIndepStack without doing anything else
    def process_message(self, msg: Message) -> None:
        cond_indep_stack: list[CondIndepStackFrame] = msg["cond_indep_stack"]
        frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
        cond_indep_stack.append(frame)


def test_model():
    n_groups = 10
    with pseudo_plate("groups", 0):
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

mcmc = run_inference(test_model, random.PRNGKey(0), num_chains=1)
mcmc.get_samples()['gamma'].shape
(50, 10)

As we can see above, we got the correct shape for gamma. Let’s see if this works in a model

def model3(
    X: Float[Array, "obs features"], 
    Z: Float[Array, " obs"], 
    y: Float[Array, " obs"] = None,
    try_pseudo_plate = True,
) -> Float[Array, " obs"]:

    feature_plate = numpyro.plate("features", X.shape[-1], dim=-2)
    group_plate = numpyro.plate("Z", len(np.unique(Z)), dim=-1)
    
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with group_plate, feature_plate:
        beta = numpyro.sample("beta", dist.Normal(0, 1))
    if try_pseudo_plate:
        with pseudo_plate("groups", 0):
            gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(group_plate.size,)))
    else:
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(group_plate.size,)))

    with numpyro.plate("obs_id", X.shape[0]):
        mu = numpyro.deterministic("mu", alpha + gamma[Z] + (X*beta[:,Z].T).sum(-1))
        return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

mcmc_zsn_pseudoplate = run_inference(model3, random.PRNGKey(0), num_chains=2, num_warmup=1000, num_samples=1000, **model_kwargs)
idata_zsn_pseudoplate = az.from_numpyro(
    mcmc_zsn_pseudoplate,
    dims=get_site_dims(model3, **model_kwargs),
    coords=coords
)

print_site_dims(idata_zsn_pseudoplate, vars=['alpha', 'beta', 'gamma', 'mu'])
alpha: ('chain', 'draw')
beta: ('chain', 'draw', 'features', 'Z')
gamma: ('chain', 'draw', 'groups')
mu: ('chain', 'draw', 'obs_id')

This seemed to work! But are there any unintended consequences of having another “plate” in the CondIndepStack? Let’s fit a second version of the model without the pseudo_plate and see if they return similar results

mcmc_zsn = run_inference(model3, random.PRNGKey(0), num_chains=2, num_warmup=1000, num_samples=1000, **model_kwargs, try_pseudo_plate=False)

We’ll check that parameter estimation for the site is the same

import matplotlib.pyplot as plt
az.plot_forest(
    [mcmc_zsn_pseudoplate, mcmc_zsn],
    model_names=['With Pseudo Plate', 'Without Pseudo Plate'],
     var_names=['gamma']
)
plt.title("Gamma Estimates Across both Models")
plt.show()

And more importantly, we’ll make sure this doesn’t impact the final likelihood estimate

idata_zsn_pseudoplate = az.from_numpyro(mcmc_zsn_pseudoplate)
idata_zsn = az.from_numpyro(mcmc_zsn)

delta = idata_zsn.log_likelihood['y'] - idata_zsn_pseudoplate.log_likelihood['y']

print("Difference in Log Likelihood with and without the pseudoplate:", float(delta.mean()))
Difference in Log Likelihood with and without the pseudoplate: 0.0

I’m a bit shocked to say that this works so far.
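
One caveat with this check: a mean difference of zero can hide pointwise differences that cancel out. A stricter comparison looks at the largest absolute difference. The sketch below uses made-up arrays purely to show the distinction; with the real objects the same idea would apply to the `delta` computed above.

```python
import numpy as np

# Made-up values standing in for two log-likelihood arrays.
a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 2.0])

delta = a - b
mean_delta = float(delta.mean())            # differences of -1 and +1 cancel
max_abs_delta = float(np.abs(delta).max())  # largest pointwise disagreement

print(mean_delta)     # 0.0
print(max_abs_delta)  # 1.0
```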

Another edge case: both batch and event dims

Next we need to make sure it works if we also have a batch dimension

def test_model_batches():
    n_groups = 10
    with pseudo_plate("groups", n_groups), numpyro.plate("batches", 3):
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

trace = jax.eval_shape(partial(get_trace, test_model_batches, model_kwargs={})).trace
trace['gamma']['value']
ShapeDtypeStruct(shape=(3, 1, 10), dtype=float32)

Sadly this didn’t quite work. The dim handling isn’t doing what we want: the batch shape is (3, 1) instead of just (3,). This comes from how the plates were arranged - the pseudo plate took dim=-1, which pushed batches to dim=-2. Let’s look at the trace

trace['gamma']
{'args': (),
 'intermediates': [],
 'value': ShapeDtypeStruct(shape=(3, 1, 10), dtype=float32),
 '_control_flow_done': True,
 'type': 'sample',
 'name': 'gamma',
 'kwargs': {'rng_key': None, 'sample_shape': ()},
 'scale': None,
 'is_observed': False,
 'cond_indep_stack': [CondIndepStackFrame(name='batches', dim=-2, size=3),
  CondIndepStackFrame(name='groups', dim=-1, size=10)],
 'infer': {},
 'fn_name': 'ZeroSumNormal'}
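
The (3, 1, 10) shape follows from the dim bookkeeping in the trace above: a plate at dim=-2 needs a batch shape with two axes, and the event shape is appended on the right. Here is a small sketch of that arithmetic. The helper name `batch_shape_for` is made up, and this mirrors the observed shapes rather than numpyro's actual code path.

```python
def batch_shape_for(dim, size):
    """Smallest batch shape that places `size` at the (negative) axis `dim`."""
    shape = [1] * (-dim)
    shape[dim] = size
    return tuple(shape)

event_shape = (10,)

# `batches` was pushed to dim=-2 because `groups` already held dim=-1
pushed = batch_shape_for(-2, 3) + event_shape
# what we wanted: `batches` at dim=-1
wanted = batch_shape_for(-1, 3) + event_shape

print(pushed)  # (3, 1, 10)
print(wanted)  # (3, 10)
```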

Improving the primitive to store new metadata

Taking a step back, it’s clear I thought about this wrong - the Conditional Independence Stack is for batch dimensions that are conditionally independent, and by putting our pseudoplate in there, we’re messing with how the batch dimensions are arranged.

It’s also hard to follow exactly how the conditional independence stack is used throughout the code, and whether that will have any unexpected consequences.

I’m going to try something new - let’s make a new primitive that creates a dep_stack, put the pseudo plates in there, and see if it fixes things. This is going to get even more ugly, so I’m hiding the implementation below in a code fold

Code
import warnings
from typing import Tuple
from collections import namedtuple
import numpyro

from numpyro.primitives import Messenger, Message, apply_stack, _subsample_fn
from numpyro.util import find_stack_level

DepStackFrame = namedtuple("DepStackFrame", ["name", "dim", "size"])


class pseudo_plate(numpyro.plate):
    
    def __init__(
        self,
        name: str,
        size: int,
        subsample_size: Optional[int] = None,
        dim: Optional[int] = None,
    ) -> None:
        self.name = name
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError("dim arg must be negative.")
        self.dim, self._indices = self._subsample(
            self.name, self.size, subsample_size, dim
        )
        self.subsample_size = self._indices.shape[0]

    # Record the pseudo plate in a separate `dep_stack` instead of the CondIndepStack
    def process_message(self, msg: Message) -> None:
        if msg["type"] not in ("param", "sample", "plate", "deterministic"):
            if msg["type"] == "control_flow":
                raise NotImplementedError(
                    "Cannot use control flow primitive under a `plate` primitive."
                    " Please move those `plate` statements into the control flow"
                    " body function. See `scan` documentation for more information."
                )
            return

        if (
            "block_plates" in msg.get("infer", {})
            and self.name in msg["infer"]["block_plates"]
        ):
            return
            
        frame = DepStackFrame(self.name, self.dim, self.subsample_size)
        msg['dep_stack'] = msg.get('dep_stack', []) + [frame]

    def _get_event_shape(self, dep_stack: List[DepStackFrame]) -> Tuple[int]:
        n_dims = max(-f.dim for f in dep_stack)
        event_shape = [1] * n_dims
        for f in dep_stack:
            event_shape[f.dim] = f.size
        return tuple(event_shape)

    # We need to make sure dims get arranged properly when there are multiple plates
    @staticmethod
    def _subsample(name, size, subsample_size, dim):
        msg = {
            "type": "plate",
            "fn": _subsample_fn,
            "name": name,
            "args": (size, subsample_size),
            "kwargs": {"rng_key": None},
            "value": (
                None
                if (subsample_size is not None and size != subsample_size)
                else jnp.arange(size)
            ),
            "scale": 1.0,
            "cond_indep_stack": [],
            "dep_stack": [],
        }
        apply_stack(msg)
        subsample = msg["value"]
        subsample_size = msg["args"][1]
        if subsample_size is not None and subsample_size != subsample.shape[0]:
            warnings.warn(
                "subsample_size does not match len(subsample), {} vs {}.".format(
                    subsample_size, len(subsample)
                )
                + " Did you accidentally use different subsample_size in the model and guide?",
                stacklevel=find_stack_level(),
            )
        dep_stack = msg["dep_stack"]
        occupied_dims = {f.dim for f in dep_stack}
        if dim is None:
            new_dim = -1
            while new_dim in occupied_dims:
                new_dim -= 1
            dim = new_dim
        else:
            assert dim not in occupied_dims
        return dim, subsample

Note that I’m copying a lot of the code above directly from the plate primitive.
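
The part of `_subsample` that matters most here is the dim allocation at the end: take the first free negative dim, or check that a requested dim is not already taken. Pulled out as a standalone sketch (the function name `allocate_dim` is hypothetical, and it raises an error instead of asserting):

```python
def allocate_dim(occupied_dims, dim=None):
    """Pick the rightmost free negative dim, or validate a requested one."""
    if dim is None:
        dim = -1
        while dim in occupied_dims:
            dim -= 1
        return dim
    if dim in occupied_dims:
        raise ValueError(f"dim {dim} is already occupied")
    return dim

print(allocate_dim(set()))     # -1
print(allocate_dim({-1}))      # -2
print(allocate_dim({-1, -3}))  # -2
```

Because the pseudo plate now only consults `dep_stack` and no longer appends to `cond_indep_stack`, it doesn’t take up a slot that a real plate would want.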

def test_model_batches():
    n_groups = 10
    with pseudo_plate("groups", n_groups), numpyro.plate("batches", 3):
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))


trace = jax.eval_shape(partial(get_trace, test_model_batches, model_kwargs={})).trace
trace['gamma']['value']
ShapeDtypeStruct(shape=(3, 10), dtype=float32)

Perfect, this now works. We just need to combine our two plate stacks!

trace['gamma']['cond_indep_stack'] + trace['gamma']['dep_stack']
[CondIndepStackFrame(name='batches', dim=-1, size=3),
 DepStackFrame(name='groups', dim=-1, size=10)]
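
Note that both frames report dim=-1, because batch dims and event dims are counted separately. When combining them, the batch names need to come first and the event names second, matching the (3, 10) shape. A sketch with stand-in namedtuples (`combined_dims` is a hypothetical helper):

```python
from collections import namedtuple

# Stand-ins with the same fields as the frames printed above.
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])
DepStackFrame = namedtuple("DepStackFrame", ["name", "dim", "size"])

def combined_dims(cond_indep_stack, dep_stack):
    """Batch dims (left) followed by event dims (right), each ordered by dim."""
    batch = sorted(cond_indep_stack, key=lambda f: f.dim)
    event = sorted(dep_stack, key=lambda f: f.dim)
    return [f.name for f in batch + event]

cond_indep_stack = [CondIndepStackFrame("batches", -1, 3)]
dep_stack = [DepStackFrame("groups", -1, 10)]

names = combined_dims(cond_indep_stack, dep_stack)
sizes = tuple(f.size for f in cond_indep_stack + dep_stack)

print(names)  # ['batches', 'groups']
print(sizes)  # (3, 10)
```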

Part 4: The resulting implementation

Finally, let’s bring this all together. The code-fold below has all of the needed functions in one place, but they really need to be cleaned up.

Code
import warnings
from typing import Tuple, Any, Optional, List, Dict, Callable
from collections import namedtuple
from numpy.typing import ArrayLike
from functools import partial

import numpy as np
import numpyro
from numpyro import distributions as dist
from numpyro import handlers
from numpyro.primitives import Messenger, apply_stack, _subsample_fn
from numpyro.util import find_stack_level
from numpyro.infer import NUTS, MCMC
from numpyro.infer.initialization import init_to_sample
from numpyro.ops.pytree import PytreeTrace

import jax
from jax import random
import jax.numpy as jnp
from jaxtyping import Float, Array, Int
import arviz as az

# #######################
# Code for pseudo_plate
# #######################
Message = dict[str, Any]

DepStackFrame = namedtuple("DepStackFrame", ["name", "dim", "size"])

class pseudo_plate(numpyro.plate):
    
    def __init__(
        self,
        name: str,
        size: int,
        subsample_size: Optional[int] = None,
        dim: Optional[int] = None,
    ) -> None:
        self.name = name
        self.size = size
        if dim is not None and dim >= 0:
            raise ValueError("dim arg must be negative.")
        self.dim, self._indices = self._subsample(
            self.name, self.size, subsample_size, dim
        )
        self.subsample_size = self._indices.shape[0]

    # Record the pseudo plate in a separate `dep_stack` instead of the CondIndepStack
    def process_message(self, msg: Message) -> None:
        if msg["type"] not in ("param", "sample", "plate", "deterministic"):
            if msg["type"] == "control_flow":
                raise NotImplementedError(
                    "Cannot use control flow primitive under a `plate` primitive."
                    " Please move those `plate` statements into the control flow"
                    " body function. See `scan` documentation for more information."
                )
            return

        if (
            "block_plates" in msg.get("infer", {})
            and self.name in msg["infer"]["block_plates"]
        ):
            return
            
        frame = DepStackFrame(self.name, self.dim, self.subsample_size)
        msg['dep_stack'] = msg.get('dep_stack', []) + [frame]

        if msg["type"] == "deterministic":
            return

    def _get_event_shape(self, dep_stack: List[DepStackFrame]) -> Tuple[int]:
        n_dims = max(-f.dim for f in dep_stack)
        event_shape = [1] * n_dims
        for f in dep_stack:
            event_shape[f.dim] = f.size
        return tuple(event_shape)

    # We need to make sure dims get arranged properly when there are multiple plates
    @staticmethod
    def _subsample(name, size, subsample_size, dim):
        msg = {
            "type": "plate",
            "fn": _subsample_fn,
            "name": name,
            "args": (size, subsample_size),
            "kwargs": {"rng_key": None},
            "value": (
                None
                if (subsample_size is not None and size != subsample_size)
                else jnp.arange(size)
            ),
            "scale": 1.0,
            "cond_indep_stack": [],
            "dep_stack": [],
        }
        apply_stack(msg)
        subsample = msg["value"]
        subsample_size = msg["args"][1]
        if subsample_size is not None and subsample_size != subsample.shape[0]:
            warnings.warn(
                "subsample_size does not match len(subsample), {} vs {}.".format(
                    subsample_size, len(subsample)
                )
                + " Did you accidentally use different subsample_size in the model and guide?",
                stacklevel=find_stack_level(),
            )
        dep_stack = msg["dep_stack"]
        occupied_dims = {f.dim for f in dep_stack}
        if dim is None:
            new_dim = -1
            while new_dim in occupied_dims:
                new_dim -= 1
            dim = new_dim
        else:
            assert dim not in occupied_dims
        return dim, subsample


# #######################
# Code for pulling dims
# #######################

def get_site_dims(model: Callable, **kwargs):
    def _get_dist_name(fn):
        if isinstance(
            fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)
        ):
            return _get_dist_name(fn.base_dist)
        return type(fn).__name__

    def get_trace():
        # We use `init_to_sample` to get around ImproperUniform distribution,
        # which does not have `sample` method.
        subs_model = handlers.substitute(
            handlers.seed(model, 0),
            substitute_fn=init_to_sample,
        )
        trace = handlers.trace(subs_model).get_trace(**kwargs)
        # Work around an issue where jax.eval_shape does not work
        # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
        # Here we will remove `fn` and store its name in the trace.
        for name, site in trace.items():
            if site["type"] == "sample":
                site["fn_name"] = _get_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
        return PytreeTrace(trace)

    # We use eval_shape to avoid any array computation.
    trace = jax.eval_shape(get_trace).trace

    sample_plates = {
        name: [frame.name for frame in site["cond_indep_stack"] + site.get("dep_stack", [])]
        for name, site in trace.items()
        if site["type"] in ["sample", "deterministic"]
    }
    return {k:v for k,v in sample_plates.items() if len(v) > 0}

# helper function
def print_site_dims(idata, vars=("alpha", "beta", "mu")):
    # Print the dim names ArviZ attached to each posterior variable
    for var in vars:
        print(f"{var}:", idata.posterior[var].dims)
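
The last step of `get_site_dims` reads plate names in whatever order the frames sit on the stack. If that order ever disagrees with the `dim` order, the labels would land on the wrong axes. Below is a minimal sketch of a variant that sorts frames by `dim` before reading the names. `Frame` and `dims_from_frames` are hypothetical stand-ins written for this illustration, not numpyro objects, and the sketch assumes every frame exposes `name` and `dim`.

```python
from collections import namedtuple
from typing import Dict, List

# Hypothetical stand-in for a plate frame; only `name` and `dim` are used here.
Frame = namedtuple("Frame", ["name", "dim", "size"])


def dims_from_frames(sites: Dict[str, List[Frame]]) -> Dict[str, List[str]]:
    """Map each site to its plate names, ordered from leftmost to rightmost dim."""
    out = {}
    for site_name, frames in sites.items():
        # dim=-2 sorts before dim=-1, which matches the left-to-right axis order
        ordered = sorted(frames, key=lambda f: f.dim)
        if ordered:  # sites outside of any plate get no entry
            out[site_name] = [f.name for f in ordered]
    return out


sites = {
    "alpha": [],
    "beta": [Frame("Z", -1, 10), Frame("features", -2, 3)],
    "gamma": [Frame("Z", -1, 10)],
}
site_dims = dims_from_frames(sites)
print(site_dims)
```

Sorting makes the result independent of how the `with` blocks were nested, at the cost of assuming that no two frames on a site share the same `dim`.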

The simulation code is under the following code-fold.

Code
# #################
# Simulation Code 
# #################

SEED: int = 99
rng = np.random.default_rng(SEED)
N, n_feats = 1000, 3
z_cardinality = 10

alpha = 1.3
beta = rng.normal(size=n_feats)
gamma = rng.normal(size=z_cardinality)
sigma = 1

X = rng.normal(size=(N, n_feats))
z = rng.choice(range(z_cardinality), size=N)

mu = alpha + np.dot(X, beta) + gamma[z]
y = rng.normal(mu, sigma)

# beta = rng.normal(size=(z_cardinality, n_feats))
# mu = alpha + (X * beta[z]).sum(-1) + gamma[z]
# y = rng.normal(mu, sigma)
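
The commented-out lines describe a variant with a separate slope vector per group, stored as `(groups, features)`, while plates with features at `dim=-2` and groups at `dim=-1` give the transposed `(features, groups)` layout. The sketch below is a small numpy check, on made-up shapes and values, that indexing either layout by the group index yields the same per-observation effect. All names in it are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, n_feats, n_groups = 6, 3, 4

X = rng.normal(size=(N, n_feats))
z = rng.integers(n_groups, size=N)

# One row of slopes per group: shape (n_groups, n_feats)
beta_group_major = rng.normal(size=(n_groups, n_feats))
# Transposed layout: shape (n_feats, n_groups)
beta_feature_major = beta_group_major.T

# beta_group_major[z] picks each observation's slope row -> (N, n_feats)
effect_a = (X * beta_group_major[z]).sum(-1)
# beta_feature_major[:, z] is (n_feats, N), so transpose it before multiplying
effect_b = (X * beta_feature_major[:, z].T).sum(-1)
# The same row-wise dot product written as an einsum
effect_c = np.einsum("nf,nf->n", X, beta_group_major[z])

print(effect_a.shape, np.allclose(effect_a, effect_b), np.allclose(effect_a, effect_c))
```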

And finally, here’s how we’d actually use all of these tools:

def run_inference(
    model: Callable,
    key: random.PRNGKey,
    num_warmup: int = 100, 
    num_chains: int = 2, 
    num_samples: int = 100,
    coords: Optional[Dict[str, ArrayLike]] = None,
    **kwargs
) -> Tuple[MCMC, az.InferenceData]:

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_chains=num_chains, num_samples=num_samples, progress_bar=False)
    mcmc.run(key, **kwargs)

    dims = get_site_dims(model, **kwargs)
    idata = az.from_numpyro(mcmc, coords=coords, dims=dims)
    return mcmc, idata


def model(
    X: Float[Array, "obs features"], 
    Z: Int[Array, " obs"],  # integer group index for each observation
    y: Optional[Float[Array, " obs"]] = None,
    try_pseudo_plate: bool = True,  # not used in the body of this version
) -> Float[Array, " obs"]:

    feature_plate = numpyro.plate("features", X.shape[-1], dim=-2)
    group_plate = numpyro.plate("Z", len(np.unique(Z)), dim=-1)
    
    alpha = numpyro.sample("alpha", dist.Normal(0, 3))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1))
    with group_plate, feature_plate:
        beta = numpyro.sample("beta", dist.Normal(0, 1))
    with pseudo_plate("Z", group_plate.size):
        gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(group_plate.size,)))

    with numpyro.plate("obs_id", X.shape[0]):
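        # Index with the `Z` argument rather than the global `z`.
        # beta has shape (features, Z), so beta[:, Z].T is (obs, features).
        mu = numpyro.deterministic("mu", alpha + gamma[Z] + (X * beta[:, Z].T).sum(-1))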
        return numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


coords = {
    "Z": ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],
    "obs_id": np.arange(N),
    "features": ['X1', 'X2', 'X3'],
}
model_kwargs = dict(X=X, Z=z, y=y)
mcmc, idata = run_inference(model, random.PRNGKey(0), coords=coords, **model_kwargs)
print_site_dims(idata, vars=['beta', 'gamma', 'mu'])
beta: ('chain', 'draw', 'features', 'Z')
gamma: ('chain', 'draw', 'Z')
mu: ('chain', 'draw', 'obs_id')

The output above shows that every dim is labelled correctly, and the output below shows the Z dim mapped to the coords we provided:

idata.posterior['gamma'].coords['Z']
<xarray.DataArray 'Z' (Z: 10)> Size: 40B
array(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'], dtype='<U1')
Coordinates:
  * Z        (Z) <U1 40B 'A' 'B' 'C' 'D' 'E' 'F' 'G' 'H' 'I' 'J'
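
With labelled coords in place, a group can be pulled out by name, for example with xarray's `idata.posterior["gamma"].sel(Z="C")`. The sketch below shows roughly what that lookup amounts to, using plain numpy on placeholder draws. `select` is a hypothetical helper written for this illustration, and the random values stand in for real posterior samples.

```python
import numpy as np

coords = {"Z": ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]}

# Placeholder draws with shape (chain, draw, Z)
rng = np.random.default_rng(0)
gamma_draws = rng.normal(size=(2, 100, len(coords["Z"])))


def select(draws, labels, label):
    """Return the slice of the last axis whose coord equals `label`."""
    return draws[..., labels.index(label)]


gamma_C = select(gamma_draws, coords["Z"], "C")
print(gamma_C.shape)
```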

Conclusion

This was a lot more difficult than I expected. There are no guarantees that this implementation is stable, since it relies on very low-level numpyro code, but I’m hoping it’s solid enough to get cleaner output from numpyro models. I’d want to reorganize the code and run more tests against it before using it in future projects, and I’d be curious whether the numpyro devs have any easier ideas for pulling this off.

Some things that still need to be explored are:

  1. Will this still work when there’s scoping with sub-models? Or integration with other libraries like flax?
  2. What changes need to be made to get this approach extended to SVI?