
PPL tools for Aesara

Project description

Join the chat at https://gitter.im/aesara-devs/aeppl

aeppl provides tools for a[e]PPL (a probabilistic programming language, or PPL) written in Aesara.

Features

  • Convert graphs containing Aesara RandomVariables into joint log-probability graphs

  • Transforms for RandomVariables that map constrained support spaces to unconstrained spaces (e.g. the extended real numbers), and a rewrite that automatically applies these transformations throughout a graph

  • Tools for traversing and transforming graphs containing RandomVariables

  • RandomVariable-aware pretty printing and LaTeX output
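The transforms above rest on the change-of-variables formula. aeppl applies its transforms as rewrites on Aesara graphs. The following plain-Python sketch is not aeppl's API: it only shows the arithmetic of a log transform for a positive variable s ~ invgamma(0.5, 0.5), the distribution used in the examples. With u = log(s), the log-density of u is log p_S(exp(u)) plus the log-Jacobian term u. The helper names are hypothetical, and the sketch assumes the (shape, rate) parameterization of the inverse-gamma distribution.

```python
import math

ALPHA, BETA = 0.5, 0.5  # invgamma(0.5, 0.5), as in the examples (shape, rate)


def invgamma_logpdf(s, alpha=ALPHA, beta=BETA):
    """Log-density of an inverse-gamma variable with shape `alpha` and rate `beta`."""
    if s <= 0.0:
        return -math.inf
    return (alpha * math.log(beta) - math.lgamma(alpha)
            - (alpha + 1.0) * math.log(s) - beta / s)


def log_transformed_logpdf(u):
    """Log-density of u = log(s), defined on the whole real line.

    s = exp(u), so |ds/du| = exp(u) and the log-Jacobian term is simply u.
    """
    return invgamma_logpdf(math.exp(u)) + u


def trapezoid(f, lo, hi, n):
    """Composite trapezoidal rule with n intervals on [lo, hi]."""
    h = (hi - lo) / n
    total = 0.5 * (f(lo) + f(hi))
    for k in range(1, n):
        total += f(lo + k * h)
    return total * h


# The transformed density should still integrate to (approximately) 1 over u.
area = trapezoid(lambda u: math.exp(log_transformed_logpdf(u)), -10.0, 60.0, 70_000)
print(area)
```

Dropping the `+ u` term would give a function that is no longer the density of u in general, which is why a value transform has to carry its Jacobian correction into the log-probability.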

Examples

Using aeppl, one can create a joint log-probability graph from a graph containing Aesara RandomVariables:

import aesara
from aesara import tensor as at

from aeppl import joint_logprob, pprint

srng = at.random.RandomStream()

# A simple scale mixture model
S_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0.0, at.sqrt(S_rv))

# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)
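To see what this model encodes, it can be forward-sampled with Python's standard library. This sketch is independent of Aesara and aeppl. It assumes the (shape, rate) reading of invgamma(0.5, 0.5), which matches the -0.5 / s term in the simplified log-probability printed further down, and `sample_scale_mixture` is a hypothetical helper. With these parameters Y is marginally a Student-t variable with one degree of freedom (a standard Cauchy), so the median of |Y| should be close to 1.

```python
import math
import random
import statistics


def sample_scale_mixture(n, seed=0):
    """Draw n samples of Y from S ~ invgamma(0.5, 0.5), Y | S ~ N(0, sqrt(S))."""
    rng = random.Random(seed)
    ys = []
    for _ in range(n):
        # S ~ InvGamma(shape=0.5, rate=0.5) is 1 / Gamma(shape=0.5, scale=1/0.5)
        s = 1.0 / rng.gammavariate(0.5, 2.0)
        # sqrt(S) is the standard deviation, as in srng.normal(0.0, at.sqrt(S_rv))
        ys.append(rng.gauss(0.0, math.sqrt(s)))
    return ys


ys = sample_scale_mixture(100_000)
# For a standard Cauchy, P(|Y| < 1) = 0.5, so this median should be near 1
print(statistics.median([abs(y) for y in ys]))
```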

Log-probability graphs are standard Aesara graphs, so we can compute values with them:

logprob_fn = aesara.function([y, s], logprob)

logprob_fn(-0.5, 1.0)
# array(-2.46287705)
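As a cross-check, the same value can be computed by hand from the two densities, log p(y, s) = log N(y | 0, sqrt(s)) + log invgamma(s | 0.5, 0.5), using only Python's `math` module. The helper functions are hypothetical and assume the (shape, rate) parameterization of the inverse-gamma distribution. The result is about -2.4628771, which agrees with the output above to roughly seven decimal places.

```python
import math


def invgamma_logpdf(s, alpha, beta):
    """Inverse-gamma log-density with shape `alpha` and rate `beta`."""
    if s <= 0.0:
        return -math.inf
    return (alpha * math.log(beta) - math.lgamma(alpha)
            - (alpha + 1.0) * math.log(s) - beta / s)


def normal_logpdf(y, mu, sigma):
    """Normal log-density with mean `mu` and standard deviation `sigma`."""
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma)


def scale_mixture_logprob(y, s):
    """log p(y, s) = log p(y | s) + log p(s) for the scale mixture model."""
    if s <= 0.0:
        # Outside the support of S (and sqrt/log would fail for s <= 0)
        return -math.inf
    return normal_logpdf(y, 0.0, math.sqrt(s)) + invgamma_logpdf(s, 0.5, 0.5)


print(scale_mixture_logprob(-0.5, 1.0))
```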

Graphs can also be pretty printed:

from aeppl import pprint, latex_pprint


# Print the original graph
print(pprint(Y_rv))
# b ~ invgamma(0.5, 0.5) in R, a ~ N(0.0, sqrt(b)**2) in R
# a

print(latex_pprint(Y_rv))
# \begin{equation}
#   \begin{gathered}
#     b \sim \operatorname{invgamma}\left(0.5, 0.5\right)\,  \in \mathbb{R}
#     \\
#     a \sim \operatorname{N}\left(0.0, {\sqrt{b}}^{2}\right)\,  \in \mathbb{R}
#   \end{gathered}
#   \\
#   a
# \end{equation}

# Simplify the graph so that it's easier to read
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.basic import topo_constant_folding


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)


print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
#         ((-0.9189385175704956 +
#           switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
#         -inf) +
#  ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))
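Read term by term, the simplified graph is the inverse-gamma log-density of s (shape α, rate β) plus the normal log-density of y with variance s; the `switch` calls only add support checks (-inf for s < 0 and at s = 0). Here (-1.5 * log(s)) is -(α + 1) log s and (0.5 / s) is β/s. With α = β = 0.5 the inverse-gamma normalizing constant equals -½ log(2π), which is why -0.91893853... appears twice, up to rounding:

```latex
\begin{aligned}
\log p(y, s)
  &= \underbrace{\alpha \log \beta - \log \Gamma(\alpha) - (\alpha + 1) \log s - \frac{\beta}{s}}_{\log \operatorname{invgamma}(s \mid \alpha, \beta)}
   \;+\; \underbrace{-\tfrac{1}{2} \log (2\pi) - \tfrac{1}{2} \left(\frac{y}{\sqrt{s}}\right)^{2} - \log \sqrt{s}}_{\log \operatorname{N}\left(y \mid 0, {\sqrt{s}}^{2}\right)} \\
\alpha \log \beta - \log \Gamma(\alpha)
  &= \tfrac{1}{2} \log \tfrac{1}{2} - \log \sqrt{\pi}
   = -\tfrac{1}{2} \log (2\pi) \approx -0.9189385
  \qquad \left(\alpha = \beta = 0.5,\ \Gamma(\tfrac{1}{2}) = \sqrt{\pi}\right)
\end{aligned}
```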

Joint log-probabilities can also be computed for some terms that are derived from RandomVariables, such as a RandomVariable indexed by another RandomVariable:

# Create a switching model from a Bernoulli distributed index
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
I_rv = srng.bernoulli(0.5, name="I")

M_rv = Z_rv[I_rv]
M_rv.name = "M"

# Compute the joint log-probability for the mixture. Z_rv is not given its own
# value variable: its two components are the mixture components of M_rv.
logprob, (m, i) = joint_logprob(M_rv, I_rv)


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)

print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
#  ((-0.9189385332046727 + (-0.5 * (((m - [-100  100][a]) / [1. 1.][a]) ** 2))) -
#   log([1. 1.][a])))
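The printed graph can be transcribed into plain Python to check what it computes: a log(0.5) Bernoulli term that is -inf outside {0, 1}, plus the log-density of the normal component selected by the index. This sketch assumes that the `a` in the printed graph stands for the same index value as `i`; `mixture_logprob` is a hypothetical helper, not part of aeppl.

```python
import math

# Constants that appear in the printed graph
MUS = [-100.0, 100.0]  # [-100  100]
SIGMAS = [1.0, 1.0]    # [1. 1.]


def mixture_logprob(m, i):
    """Plain-Python transcription of the printed mixture log-probability."""
    # Bernoulli(0.5) term: switch((0 <= i and i <= 1), log(0.5), -inf)
    if not 0 <= i <= 1:
        return -math.inf
    log_p_i = math.log(0.5)
    # Normal term for the component selected by the index
    # (assumption: the printed `a` is the same index value as `i`)
    z = (m - MUS[i]) / SIGMAS[i]
    log_p_m = -0.5 * math.log(2.0 * math.pi) - 0.5 * z**2 - math.log(SIGMAS[i])
    return log_p_i + log_p_m


print(mixture_logprob(100.0, 1))  # m sits at the mean of the selected component
print(mixture_logprob(100.0, 0))  # far from the component centred at -100
print(mixture_logprob(0.0, 2))    # index outside the Bernoulli support: -inf
```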

Installation

The latest release of aeppl can be installed from PyPI using pip:

pip install aeppl

The current development branch of aeppl can be installed from GitHub, also using pip:

pip install git+https://github.com/aesara-devs/aeppl


Download files


Source Distribution

aeppl-nightly-0.1.1.dev20230131.tar.gz (66.3 kB)


Hashes for aeppl-nightly-0.1.1.dev20230131.tar.gz
Algorithm     Hash digest
SHA256        25b99d8f8df5ca7dc60a90f2029d3f23db98c0a672ac437e87ab455728df65ac
MD5           eeb2009e026d016e4b220933de991871
BLAKE2b-256   608e68a9132586e2b64a96f3cb4d4e4beebe0d96ad4b4c36efaca43c5295afc0
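If the tarball is downloaded manually, it can be checked against these digests with Python's `hashlib`. This is a small sketch: the helper names are hypothetical, and the usage line assumes the file was saved to the current directory under its original name.

```python
import hashlib

# SHA256 digest listed above for aeppl-nightly-0.1.1.dev20230131.tar.gz
EXPECTED_SHA256 = "25b99d8f8df5ca7dc60a90f2029d3f23db98c0a672ac437e87ab455728df65ac"


def file_digest_hex(path, algorithm="sha256", chunk_size=1 << 16):
    """Hash a file in chunks and return its hex digest.

    "blake2b-256" maps to BLAKE2b with a 32-byte digest (the BLAKE2b-256 row);
    other names are passed to hashlib.new().
    """
    if algorithm.lower() == "blake2b-256":
        h = hashlib.blake2b(digest_size=32)
    else:
        h = hashlib.new(algorithm)
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_download(path, expected=EXPECTED_SHA256, algorithm="sha256"):
    """Return True if the file's digest matches the expected hex string."""
    return file_digest_hex(path, algorithm) == expected.strip().lower()


# Usage (assumes the tarball is in the current directory):
# verify_download("aeppl-nightly-0.1.1.dev20230131.tar.gz")
```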

