Double ML in Numpyro using scope

causal inference
numpyro
Published

April 1, 2025

Background

This is more of a tutorial on using numpyro’s scope handler. It’s fairly straightforward and allows one to use a composable model framework in numpyro, i.e. calling multiple models within a model.

The code fold below just has some imports and helper functions

Code
from abc import abstractmethod, ABC
from typing_extensions import Self
from typing import Any, Optional, Dict

import numpyro
numpyro.set_host_device_count(4)

import numpy as np
import scipy.special as sp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import jax.numpy as jnp
import jax.scipy.special as jsp
from jax import random
from numpyro import distributions as dist
from numpyro.handlers import scope
from numpyro.infer import MCMC, NUTS

SEED = 99

class NumpyroModel(ABC):

    def __init__(self):
        self._fitted = None
        self.samples = None
        self.inference_obj = None


    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplemented

    @property
    def params(self):
        assert self._fitted
        return self.samples

    def fit(
        self,
        *args,
        num_chains: int = 2,
        num_samples: int = 2000,
        num_warmup: int = 1000,
        seed: int = None,
        inference_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Self:

        inference_kwargs = {} if inference_kwargs is None else inference_kwargs
        rng_key = random.PRNGKey(seed or np.random.choice(10000))
        kernel = NUTS(self, **inference_kwargs)
        mcmc = MCMC(
            sampler=kernel,
            num_chains=num_chains,
            num_warmup=num_warmup,
            num_samples=num_samples,
        )
        mcmc.run(rng_key, *args, **kwargs)

        # extract posterior and save results internally
        self.samples = mcmc.get_samples()
        self.inference_obj = mcmc
        self._fitted = True
        return self

How does scope work?

We’ll simulate a dataset and fit it with scope to show how it works

# Simulate a single-feature linear dataset: y = ALPHA + X·BETA + noise.
# The order of the rng draws below determines the simulated values.
rng = np.random.default_rng(SEED)
N = 1000
X = rng.normal(0, 3, size=(N,1))  # one covariate column, sd = 3
ALPHA, BETA = np.array([1.2]), np.array([0.4])  # ground-truth intercept and slope
y = ALPHA + np.dot(X, BETA) + rng.normal(size=N)  # additive standard-normal noise
# Scatter of the single covariate against the outcome
plt.scatter(X[:,0], y)
plt.xlabel("X")
plt.ylabel("Y")
plt.show()

Now let’s define a simple linear model

class Linear(NumpyroModel):
    """Gaussian linear regression with an intercept and one weight per column of ``X``."""

    def __call__(self, X, y=None):
        """Declare the sites ``alpha``, ``sigma``, ``beta``, ``mu`` and ``obs``.

        Returns the tuple ``(mu, y)``: the linear predictor and the value of
        the ``obs`` site (conditioned on ``y`` when it is given).
        """
        n_obs, n_features = X.shape[0], X.shape[-1]

        intercept = numpyro.sample("alpha", dist.Normal(0, 1))
        noise_scale = numpyro.sample("sigma", dist.HalfNormal(1))

        # one coefficient per feature column
        with numpyro.plate("features", n_features):
            coefs = numpyro.sample("beta", dist.Normal(0, 1))

        # one likelihood term per row of X
        with numpyro.plate("obs_id", n_obs):
            linear_pred = alpha_plus_xb = intercept + jnp.dot(X, coefs)
            mu = numpyro.deterministic("mu", alpha_plus_xb)
            likelihood = dist.Normal(mu, noise_scale)
            y = numpyro.sample("obs", likelihood, obs=y)

        return mu, y

Fitting a model using scope

We could just fit a model with the linear model above, but to illustrate how scope works, we’ll use the linear model as a sub-module in another model

Note, scope has no use in this example, this is just to illustrate how it works

class ScopeExampleModel(NumpyroModel):
    """Wraps a ``Linear`` model in ``scope`` so its sites get the ``M.`` prefix."""

    def __call__(self, X, y=None):
        """Run a scoped ``Linear`` model on ``X``/``y`` and return its result."""
        inner = Linear()
        prefixed = scope(inner, prefix="M", divider=".")
        result = prefixed(X, y=y)
        return result


# Fit the plain and the scoped model on the same data with the same seed,
# so their posteriors can be compared site by site.
m1 = Linear().fit(X=X, y=y, seed=SEED)
m2 = ScopeExampleModel().fit(X=X, y=y, seed=SEED)

Now let’s look at the fitted parameters (they’re the same, but scope adds a prefix)

If we peek at both posteriors below, we see that the results are identical but the site names differ: Model 2 (which used scope) has the prefix M. on each site. That’s all scope does.

This is an important feature that lets you call multiple models within your numpyro model. Remember that models with duplicate site names will fail. Scope allows us to call the same model multiple times inside a different model.

# looks at first 3 samples for each parameter
def peek(dct, n=3):
    """Return the first ``n`` entries of each array in ``dct``, flattened.

    Keys whose last two characters are ``mu`` are left out.
    """
    preview = {}
    for name, draws in dct.items():
        if name[-2:] == 'mu':
            continue
        head = draws[:n]
        preview[name] = head.ravel()
    return preview

peek( m1.params )
{'alpha': Array([1.2196213, 1.1488996, 1.1982138], dtype=float32),
 'beta': Array([0.41865802, 0.4032275 , 0.40824327], dtype=float32),
 'sigma': Array([0.9782854, 0.9691421, 1.0134417], dtype=float32)}
peek( m2.params )
{'M.alpha': Array([1.2196213, 1.1488996, 1.1982138], dtype=float32),
 'M.beta': Array([0.41865802, 0.4032275 , 0.40824327], dtype=float32),
 'M.sigma': Array([0.9782854, 0.9691421, 1.0134417], dtype=float32)}

Double ML example

This is the fun part. We’ll use Double Machine Learning as an example of why scope is so useful.

Double ML background

For some quick background, Double ML is a procedure for estimating unbiased Average Treatment Effects (ATE) where \(A\) is the treatment, \(y\) is the outcome, \(X\) are covariates, and * represents residuals, ie \(A^* = (A-\hat{A})\). It is comprised of 3 models:

  1. Model 1: Propensity Model, \(E[A|X]\): Debiases the treatment with propensity scores
  2. Model 2: Outcome Model \(E[y|X]\): Denoises the Outcome
  3. Model 3: Final Model \(E[y^*|A^*]\): Estimates the ATE as \(\beta\) in a Linear Regression on the residuals of the outcome and treatment

Simulate some data with a biased treatment

Code
# simulate a biased treatment
# X -> A -> y, X -> y
# (the covariates X drive both treatment assignment and the outcome)

# treatment effect
tau = 0.4 

X = rng.normal(0,1.5, size=(N,3))  # three covariates, sd = 1.5
b = rng.normal(0, 0.5, size=3)  # covariate weights for treatment assignment
pA = sp.expit(0.25 + np.dot(X, b))  # treatment probability via inverse-logit
A = rng.binomial(1, pA)  # binary treatment drawn with probability pA

b_y = rng.normal(0, 0.5,size=3)  # covariate weights for the outcome

# outcome mean: intercept + treatment effect + covariate contribution
mu_y = -1.2 + tau*A + np.dot(X, b_y)
y = rng.normal(mu_y, 0.2)  # Gaussian outcome noise with sd = 0.2


# Show the data as a dataframe
# Covariate columns are named X0, X1, ...; treatment and outcome are appended after them
df = pd.DataFrame(
    X, columns=[f"X{i}" for i in range(X.shape[1])]
).assign(A=A, y=y)
df.head()
X0 X1 X2 A y
0 0.301738 0.587291 0.837701 1 -0.381980
1 2.883451 -0.783693 1.644344 0 -0.544268
2 -2.214088 1.222192 0.332317 1 -0.983728
3 0.909440 1.710911 3.296030 0 0.685601
4 1.826930 -1.149363 -1.683725 1 -0.996186

Next, let’s make a Logit model

class Logit(NumpyroModel):
    """Logistic regression with an intercept and one weight per column of ``X``."""

    def __call__(self, X, y=None):
        """Declare the sites ``alpha``, ``beta``, ``eta``, ``mu`` and ``obs``.

        Returns the tuple ``(mu, y)``: the success probability and the value
        of the ``obs`` site (conditioned on ``y`` when it is given).
        """
        n_obs, n_features = X.shape[0], X.shape[-1]

        intercept = numpyro.sample("alpha", dist.Normal(0, 2.5))

        # one coefficient per feature column
        with numpyro.plate("features", n_features):
            coefs = numpyro.sample("beta", dist.Normal(0, 1))

        # one Bernoulli likelihood term per row of X
        with numpyro.plate("obs_id", n_obs):
            log_odds = numpyro.deterministic("eta", intercept + jnp.dot(X, coefs))
            mu = numpyro.deterministic("mu", jsp.expit(log_odds))
            likelihood = dist.Bernoulli(mu)
            y = numpyro.sample("obs", likelihood, obs=y)

        return mu, y

Finally, use scope to fit multiple models within 1 numpyro model

Notice how simple Scope allows the __call__ function to be. We can also plug and play any models we want including non-parametric models for the debias/denoise models of Double ML

class DoubleML(NumpyroModel):
    """Double ML estimator built from three scoped numpyro sub-models.

    The propensity model's sites are prefixed with ``Mt.``, the outcome
    model's with ``My.``, and the final ``Linear`` regression on the
    residuals with ``M.``.
    """

    def __init__(
        self,
        propensity_model: NumpyroModel = None,
        outcome_model: NumpyroModel = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Each sub-model gets its own prefix so site names never collide.
        self.ate_model = scope(Linear(), prefix="M", divider=".")
        self.denoise_model = scope(outcome_model, prefix="My", divider=".")
        self.debias_model = scope(propensity_model, prefix="Mt", divider=".")

    def __call__(self, X, A, y=None):
        """Chain the three sub-models and return the final regression's output."""
        # Stage 1: treatment model, E[A|X]
        treatment_hat, treatment = self.debias_model(X=X, y=A)
        # Stage 2: outcome model, E[y|X]
        outcome_hat, outcome = self.denoise_model(X=X, y=y)

        # Residualise both the treatment and the outcome
        treatment_resid = treatment - treatment_hat
        outcome_resid = outcome - outcome_hat

        # Stage 3: regress outcome residuals on treatment residuals;
        # the treatment residual is reshaped into a single feature column.
        design = treatment_resid[:, None]
        return self.ate_model(X=design, y=outcome_resid)

    @property
    def ate(self):
        """Flattened posterior draws of the ``M.beta`` site."""
        beta_draws = self.params["M.beta"]
        return beta_draws.ravel()


# Fit Double ML with a logistic propensity model and a linear outcome model;
# X, A and y are forwarded positionally to the model's __call__.
dml = DoubleML(
    propensity_model=Logit(), 
    outcome_model=Linear()
).fit(X, A, y, seed=SEED)

Pulling the average treatment effect

From the Double ML literature, it’s clear that the \(\beta\) estimate of the final model is the unbiased average treatment effect, so all we have to do is pull that from the model posterior. We just have to make sure to include the prefix, M., when referencing the site.

# Plot the parameter recovery
# Histogram of the posterior draws of the final-stage slope (site 'M.beta')
sns.histplot(dml.params['M.beta'], element='step', label='Estimate')
# Dashed red line marks the simulated treatment effect tau
plt.axvline(tau, color='r', ls='--', label='Ground Truth')
plt.legend()
plt.xlabel("ATE")
plt.title("Estimated ATE vs. Actual")
Text(0.5, 1.0, 'Estimated ATE vs. Actual')

What happens without scope?

Note that without scope this would have failed because there would be multiple sites with the same names, since both the Linear() and Logit() models have sites named alpha, beta, etc.

import traceback

# We'll just change the __init__ to not use scope
# but we'll keep everything else the same
class IncorrectDoubleML(DoubleML):
    """``DoubleML`` variant whose sub-models are NOT wrapped in ``scope``.

    Used to demonstrate the duplicate-site-name failure.
    """

    def __init__(
        self,
        propensity_model: NumpyroModel = None,
        outcome_model: NumpyroModel = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Overwrite the attributes set by the parent with the raw,
        # un-prefixed models.
        self.debias_model, self.denoise_model = propensity_model, outcome_model
        self.ate_model = Linear()


# Fitting is expected to fail here; print the traceback instead of
# letting the exception propagate.
try:
    dml = IncorrectDoubleML(
        propensity_model=Logit(),
        outcome_model=Linear()
    ).fit(X, A, y)
except Exception:
    # No ``as e`` binding: print_exc() reads the active exception itself.
    traceback.print_exc()
Traceback (most recent call last):
  File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/1941668504.py", line 23, in <module>
    ).fit(X, A, y)
  File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/3488910688.py", line 60, in fit
    mcmc.run(rng_key, *args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 640, in run
    states, last_state = pmap(partial_map_fn)(map_args)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/api.py", line 1804, in cache_miss
    execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 285, in xla_pmap_impl_lazy
    compiled_fun, fingerprint = parallel_callable(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
    ans = call(fun, *args)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 570, in parallel_callable
    pmap_computation = lower_parallel_callable(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 727, in lower_parallel_callable
    jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 659, in stage_parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2392, in trace_to_jaxpr_final
    jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2336, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 416, in _single_chain_mcmc
    new_init_state = self.sampler.init(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/hmc.py", line 713, in init
    init_params = self._init_state(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/hmc.py", line 657, in _init_state
    ) = initialize_model(
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/util.py", line 656, in initialize_model
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/util.py", line 450, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/handlers.py", line 171, in get_trace
    self(*args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 105, in __call__
    return self.fn(*args, **kwargs)
  File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/1207791297.py", line 17, in __call__
    mu_y, y = self.denoise_model(X=X, y=y) # denoise outcome E[y|X]
  File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/3665807378.py", line 4, in __call__
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 222, in sample
    msg = apply_stack(initial_msg)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 59, in apply_stack
    handler.postprocess_message(msg)
  File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/handlers.py", line 156, in postprocess_message
    assert not (
AssertionError: all sites must have unique names but got `alpha` duplicated

Conclusion

That’s about it. Scope is a great tool for extending numpyro and lets us start to stack and compose models modularly. This makes Double ML and other causal model implementations incredibly easy to read and extend.