Traceback (most recent call last):
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/1941668504.py", line 23, in <module>
).fit(X, A, y)
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/3488910688.py", line 60, in fit
mcmc.run(rng_key, *args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 640, in run
states, last_state = pmap(partial_map_fn)(map_args)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/api.py", line 1804, in cache_miss
execute = pxla.xla_pmap_impl_lazy(fun_, *tracers, **params)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 285, in xla_pmap_impl_lazy
compiled_fun, fingerprint = parallel_callable(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 570, in parallel_callable
pmap_computation = lower_parallel_callable(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 727, in lower_parallel_callable
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 659, in stage_parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2392, in trace_to_jaxpr_final
jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2336, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/mcmc.py", line 416, in _single_chain_mcmc
new_init_state = self.sampler.init(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/hmc.py", line 713, in init
init_params = self._init_state(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/hmc.py", line 657, in _init_state
) = initialize_model(
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/util.py", line 656, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/infer/util.py", line 450, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/1207791297.py", line 17, in __call__
mu_y, y = self.denoise_model(X=X, y=y) # denoise outcome E[y|X]
File "/var/folders/1p/v01fvg3j1cz1hzv988ygj8m00000gn/T/ipykernel_25763/3665807378.py", line 4, in __call__
alpha = numpyro.sample("alpha", dist.Normal(0, 1))
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 222, in sample
msg = apply_stack(initial_msg)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/primitives.py", line 59, in apply_stack
handler.postprocess_message(msg)
File "/Users/kylecaron/Desktop/kylejcaron.github.io/.venv/lib/python3.10/site-packages/numpyro/handlers.py", line 156, in postprocess_message
assert not (
AssertionError: all sites must have unique names but got `alpha` duplicated