
PPL tools for Aesara

Join the chat at https://gitter.im/aesara-devs/aeppl

aeppl provides tools for a[e]PPL, that is, a probabilistic programming language (PPL) written in Aesara.

Features

  • Convert graphs containing Aesara RandomVariables into joint log-probability graphs

  • Transforms for RandomVariables that map constrained support spaces to unconstrained spaces (e.g. the extended real numbers), and a rewrite that automatically applies these transformations throughout a graph

  • Tools for traversing and transforming graphs containing RandomVariables

  • RandomVariable-aware pretty printing and LaTeX output
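The transforms described above rely on the change-of-variables rule. The following plain-Python sketch illustrates that rule on its own. It does not use aeppl's transform API, and the function names are hypothetical. It maps a positive inverse-gamma variable S to the unconstrained variable T = log(S). Because s = exp(t), the log-density of T picks up the log-Jacobian term log|ds/dt| = t.

```python
import math


def invgamma_logpdf(s: float, alpha: float, beta: float) -> float:
    """Log-density of an inverse-gamma(alpha, beta) distribution at s."""
    if s <= 0.0:
        return -math.inf  # outside the positive support
    return (
        alpha * math.log(beta)
        - math.lgamma(alpha)
        - (alpha + 1.0) * math.log(s)
        - beta / s
    )


def log_transformed_logpdf(t: float, alpha: float, beta: float) -> float:
    """Log-density of T = log(S) when S ~ inverse-gamma(alpha, beta).

    With s = exp(t), the Jacobian is |ds/dt| = exp(t), so the
    log-density gains an extra `+ t` term.
    """
    return invgamma_logpdf(math.exp(t), alpha, beta) + t


# T ranges over the whole real line; its density should still integrate to ~1.
step = 0.001
grid = [-20.0 + k * step for k in range(100_001)]  # t from -20 to 80
density = [math.exp(log_transformed_logpdf(t, 0.5, 0.5)) for t in grid]
total = step * (sum(density) - 0.5 * (density[0] + density[-1]))  # trapezoid rule
print(round(total, 4))  # 1.0
```

Without the `+ t` term, the numerical integral would no longer equal 1. That term is what makes the unconstrained value a valid reparameterization of the constrained one.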

Examples

Using aeppl, one can create a joint log-probability graph from a graph containing Aesara RandomVariables:

import aesara
from aesara import tensor as at

from aeppl import joint_logprob, pprint

srng = at.random.RandomStream()

# A simple scale mixture model
S_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0.0, at.sqrt(S_rv))

# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)
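For reference, the joint log-density that this graph represents can be written out by hand. The following is a worked equation, not aeppl output. It assumes the usual parameterizations: an inverse-gamma with shape 0.5 and scale 0.5, and a normal whose second argument is the standard deviation, here $\sqrt{s}$.

```latex
\begin{aligned}
\log p(y, s) &= \log \operatorname{InvGamma}(s \mid 0.5, 0.5) + \log \operatorname{N}\!\left(y \mid 0, \sqrt{s}\right) \\
&= \left[0.5 \log 0.5 - \log \Gamma(0.5) - 1.5 \log s - \frac{0.5}{s}\right]
 + \left[-\tfrac{1}{2} \log (2\pi) - \tfrac{1}{2} \left(\frac{y}{\sqrt{s}}\right)^{2} - \log \sqrt{s}\right],
 \qquad s > 0.
\end{aligned}
```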

Log-probability graphs are standard Aesara graphs, so we can compute values with them:

logprob_fn = aesara.function([y, s], logprob)

logprob_fn(-0.5, 1.0)
# array(-2.46287705)
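As a sanity check, the same value can be reproduced outside Aesara using only Python's standard `math` module. The helper names below are hypothetical. The sketch assumes the inverse-gamma is parameterized by shape and scale and the normal by location and standard deviation.

```python
import math


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log-density of a normal distribution with mean mu and std. dev. sigma."""
    z = (x - mu) / sigma
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * z * z


def invgamma_logpdf(s: float, alpha: float, beta: float) -> float:
    """Log-density of an inverse-gamma(alpha, beta) distribution at s."""
    if s <= 0.0:
        return -math.inf
    return alpha * math.log(beta) - math.lgamma(alpha) - (alpha + 1.0) * math.log(s) - beta / s


def scale_mixture_logprob(y: float, s: float) -> float:
    """Joint log-density of S ~ InvGamma(0.5, 0.5) and Y | S ~ N(0, sqrt(S))."""
    logp_s = invgamma_logpdf(s, 0.5, 0.5)
    if logp_s == -math.inf:
        return -math.inf  # zero density outside the support of S
    return logp_s + normal_logpdf(y, 0.0, math.sqrt(s))


print(scale_mixture_logprob(-0.5, 1.0))  # about -2.4628770664
```

This result agrees with aeppl's `array(-2.46287705)` to about seven decimal places. The difference in the eighth place is consistent with one constant in aeppl's graph being folded in single precision (float32) instead of double precision.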

Graphs can also be pretty-printed:

from aeppl import pprint, latex_pprint


# Print the original graph
print(pprint(Y_rv))
# b ~ invgamma(0.5, 0.5) in R, a ~ N(0.0, sqrt(b)**2) in R
# a

print(latex_pprint(Y_rv))
# \begin{equation}
#   \begin{gathered}
#     b \sim \operatorname{invgamma}\left(0.5, 0.5\right)\,  \in \mathbb{R}
#     \\
#     a \sim \operatorname{N}\left(0.0, {\sqrt{b}}^{2}\right)\,  \in \mathbb{R}
#   \end{gathered}
#   \\
#   a
# \end{equation}

# Simplify the graph so that it's easier to read
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.basic import topo_constant_folding


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)


print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
#         ((-0.9189385175704956 +
#           switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
#         -inf) +
#  ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))
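The two constants in this printed graph can be explained directly. The inverse-gamma term contributes 0.5 * log(0.5) - log Gamma(0.5), and the normal term contributes -log(2 * pi) / 2. Since Gamma(0.5) = sqrt(pi), both constants are mathematically equal to about -0.9189385332. The following standard-library sketch checks this. It also shows that -0.9189385175704956 is exactly what that value becomes when rounded to single precision, which suggests the inverse-gamma constant was folded in float32. The `to_float32` helper is hypothetical and not part of Aesara.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Constant from the inverse-gamma(0.5, 0.5) log-density: alpha*log(beta) - log Gamma(alpha)
invgamma_const = 0.5 * math.log(0.5) - math.lgamma(0.5)
# Constant from the normal log-density: -log(2*pi)/2
normal_const = -0.5 * math.log(2.0 * math.pi)

print(invgamma_const, normal_const)  # equal up to floating-point rounding
print(to_float32(normal_const))  # -0.9189385175704956, as in the printed graph
```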

Joint log-probabilities can also be computed for some terms derived from RandomVariables, such as the indexed variable in this switching model:

# Create a switching model from a Bernoulli-distributed index
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
I_rv = srng.bernoulli(0.5, name="I")

M_rv = Z_rv[I_rv]
M_rv.name = "M"

# Compute the joint log-probability for the mixture
logprob, (m, z, i) = joint_logprob(M_rv, Z_rv, I_rv)


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)

print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
#  ((-0.9189385332046727 + (-0.5 * (((m - [-100  100][a]) / [1. 1.][a]) ** 2))) -
#   log([1. 1.][a])))
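The printed graph has two parts. The first is a Bernoulli(0.5) log-probability for the index. The second is a normal log-density for `m` whose mean is selected from [-100, 100] by that index, with standard deviation 1. The following plain-Python sketch transcribes that expression. It assumes that the index printed as `a` holds the same value as `i`, which the printout alone does not make explicit. The function and constant names are hypothetical.

```python
import math

MEANS = (-100.0, 100.0)  # locations of the two normal components
SIGMA = 1.0              # shared standard deviation


def switching_logprob(m: float, i: int) -> float:
    """log P(I = i) + log N(m | MEANS[i], SIGMA), with I ~ Bernoulli(0.5)."""
    if not (0 <= i <= 1):
        return -math.inf  # the index lies outside the Bernoulli support
    log_p_index = math.log(0.5)
    z = (m - MEANS[i]) / SIGMA
    log_p_m = -0.5 * math.log(2.0 * math.pi) - 0.5 * z * z - math.log(SIGMA)
    return log_p_index + log_p_m


print(switching_logprob(100.0, 1))  # m sits at the mean of the selected component
print(switching_logprob(100.0, 0))  # m is 200 standard deviations from the selected mean
```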

Installation

The latest release of aeppl can be installed from PyPI using pip:

pip install aeppl

The current development branch of aeppl can be installed from GitHub, also using pip:

pip install git+https://github.com/aesara-devs/aeppl


Download files


Source Distribution

aeppl-nightly-0.1.0.dev20230125.tar.gz (66.0 kB)

File hashes

Hashes for aeppl-nightly-0.1.0.dev20230125.tar.gz:

Algorithm     Hash digest
SHA256        2cde7eb2b29e7a05a933f74a87f469ecf6ed48b1996514f0f86b6b7e3c1b0bf0
MD5           5409424975809c078c89c9cf616215b1
BLAKE2b-256   716ee85e10911ff138e05404899356eb39350761fea3b40de89b0411098004c9

