MCX interface

This part of the documentation covers the public interface of mcx.

Interactions with the Model

mcx.seed(model: mcx.model.model, rng_key: jax._src.numpy.lax_numpy.ndarray)

Wrap the model’s calling function so that the RNG key is split automatically on each call.
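A minimal sketch of what automatic key splitting can look like, assuming the wrapped function takes a JAX PRNG key as its first positional argument. The names `auto_split` and `draw_normal` are hypothetical and this is not MCX's implementation; the real `mcx.seed` may use a different calling convention.

```python
import jax


def auto_split(sample_fn, rng_key):
    """Hypothetical sketch: wrap `sample_fn(rng_key, *args)` so every call
    consumes a fresh subkey split from an internally stored key."""
    key = rng_key

    def seeded(*args, **kwargs):
        nonlocal key
        # Keep one half of the split for future calls, use the other half now.
        key, subkey = jax.random.split(key)
        return sample_fn(subkey, *args, **kwargs)

    return seeded


def draw_normal(rng_key, loc, scale):
    # Stand-in for a model's calling function: one draw from Normal(loc, scale).
    return loc + scale * jax.random.normal(rng_key)


sampler = auto_split(draw_normal, jax.random.PRNGKey(0))
a = sampler(0.0, 1.0)
b = sampler(0.0, 1.0)  # a different subkey is used, so the value differs
```

Each call advances the stored key, so repeated calls give fresh draws, while a wrapper built from the same initial key replays the same sequence.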

class mcx.model(model_fn: function)

MCX representation of a probabilistic model (or program).

Models are expressed as generative functions. The expression of the model within the function should be as close to the mathematical expression as possible. The only difference from standard Python code is the use of the “<~” operator for random variable assignments. Model definitions are Python functions decorated with the @mcx.model decorator. Calling the model returns samples from the prior predictive distribution.
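To make the generative-function idea concrete, the sketch below shows a schematic “<~” model in a comment, followed by a plain-JAX function that does by hand what a prior predictive draw does: sample each random variable in order, feeding earlier draws into later distributions. The model, its priors, the unqualified distribution names in the comment, and `toy_regression_prior_predictive` are all hypothetical; the code does not use MCX.

```python
import jax
import jax.numpy as jnp

# Schematic MCX-style definition (illustrative only, not run here):
#
#   @mcx.model
#   def toy_regression(x):
#       scale <~ Exponential(1.0)
#       coef <~ Normal(0.0, 1.0)
#       y <~ Normal(coef * x, scale)
#       return y


def toy_regression_prior_predictive(rng_key, x):
    """Hypothetical hand-written equivalent: one draw of y from the prior predictive."""
    k_scale, k_coef, k_y = jax.random.split(rng_key, 3)
    scale = jax.random.exponential(k_scale)  # scale ~ Exponential(1)
    coef = jax.random.normal(k_coef)  # coef ~ Normal(0, 1)
    # y ~ Normal(coef * x, scale), drawn via the location-scale transform
    y = coef * x + scale * jax.random.normal(k_y, shape=x.shape)
    return y


x = jnp.linspace(0.0, 1.0, 5)
y = toy_regression_prior_predictive(jax.random.PRNGKey(0), x)
```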

A model is a representation of a probabilistic graphical model and, as such, implicitly defines a multivariate probability distribution. The model class therefore inherits from the Distribution class and implements the logpdf and sample methods; sample returns samples from the joint probability distribution.
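Because a graphical model factorizes into one conditional per random variable, the joint log-density is the sum of those conditional log-densities. A minimal sketch for a hypothetical toy model (scale ~ Exponential(1), coef ~ Normal(0, 1), y ~ Normal(coef * x, scale)), written by hand with `jax.scipy.stats.norm`; `joint_logpdf` is an illustrative name, not MCX's generated function.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def joint_logpdf(x, scale, coef, y):
    """Hypothetical log p(scale, coef, y | x) for the toy model; assumes scale > 0."""
    lp_scale = -scale  # log Exponential(1) density: log(1) - 1 * scale
    lp_coef = norm.logpdf(coef, 0.0, 1.0)  # log Normal(0, 1) density
    lp_y = jnp.sum(norm.logpdf(y, coef * x, scale))  # independent observations
    return lp_scale + lp_coef + lp_y


# At scale=1, coef=0, y=0 the value is -1 - 6 * 0.5 * log(2 * pi).
lp = joint_logpdf(jnp.zeros(5), 1.0, 0.0, jnp.zeros(5))
```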

Model expressions are parsed into an internal graph representation that can be conditioned on data and compiled into a logpdf or a forward sampler. The results are pure functions that can be JIT-compiled with JAX, differentiated, and run on GPUs and TPUs. The graph can be inspected and modified at runtime.
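Since the compiled logpdf is a pure function, standard JAX transformations apply to it directly. The sketch below uses a hand-written logpdf for a hypothetical model (coef ~ Normal(0, 1), y ~ Normal(coef * x, 1)) as a stand-in for the function MCX would generate, and shows JIT compilation, differentiation and vectorization with `jax.jit`, `jax.grad` and `jax.vmap`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def logpdf(coef, x, y):
    # Hypothetical stand-in for a compiled model logpdf.
    return norm.logpdf(coef, 0.0, 1.0) + jnp.sum(norm.logpdf(y, coef * x, 1.0))


x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([0.1, 0.9, 2.2])

fast_logpdf = jax.jit(logpdf)  # compiled with XLA
grad_logpdf = jax.jit(jax.grad(logpdf))  # derivative with respect to coef
batched_logpdf = jax.vmap(logpdf, in_axes=(0, None, None))  # many coef values at once

coefs = jnp.linspace(-1.0, 1.0, 4)
values = batched_logpdf(coefs, x, y)
```

For this model the gradient is -coef + sum((y - coef * x) * x), so at coef = 0 it equals sum(x * y) = 5.3.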

model_fn

The function that contains the MCX model definition.

graph

The internal representation of the model as a graphical model.

namespace

The namespace in which the function is called.

__call__:

Return a sample from the prior predictive distribution.

sample:

Return samples from the joint probability distribution.

logpdf:

Return the value of the log-probability density function of the implied multivariate probability distribution.

seed:

Seed the model with an auto-updating PRNGKey so the sampling methods do not need to be called with a new key each time.
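The difference between `__call__` (prior predictive: only the returned variable) and `sample` (joint: every random variable) can be illustrated with a small hypothetical class. `ToyModel`, its key-first calling convention and the dict return type are assumptions for illustration, not MCX's actual interface.

```python
import jax
import jax.numpy as jnp


class ToyModel:
    """Hypothetical stand-in for a model object; not MCX's implementation.

    Random variables: scale ~ Exponential(1), coef ~ Normal(0, 1),
    y ~ Normal(coef * x, scale). The model "returns" y.
    """

    def sample(self, rng_key, x):
        # Joint draw: every random variable in the model.
        k_scale, k_coef, k_y = jax.random.split(rng_key, 3)
        scale = jax.random.exponential(k_scale)
        coef = jax.random.normal(k_coef)
        y = coef * x + scale * jax.random.normal(k_y, shape=x.shape)
        return {"scale": scale, "coef": coef, "y": y}

    def __call__(self, rng_key, x):
        # Prior predictive draw: keep only the returned variable.
        return self.sample(rng_key, x)["y"]


model = ToyModel()
x = jnp.linspace(0.0, 1.0, 5)
joint = model.sample(jax.random.PRNGKey(0), x)
y = model(jax.random.PRNGKey(0), x)  # same key, so same y as in `joint`
```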


Forward sampling