PPL tools for Aesara

Join the chat at https://gitter.im/aesara-devs/aeppl

aeppl provides tools for building a probabilistic programming language (PPL) in Aesara.

Features

  • Convert graphs containing Aesara RandomVariables into joint log-probability graphs

  • Transforms for RandomVariables that map constrained support spaces to unconstrained spaces (e.g. the extended real numbers), and a rewrite that automatically applies these transformations throughout a graph

  • Tools for traversing and transforming graphs containing RandomVariables

  • RandomVariable-aware pretty printing and LaTeX output
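To make the transform feature more concrete, here is a minimal pure-Python sketch of what a log transform does to a log-density. It does not use aeppl's API; the function names `invgamma_logpdf` and `log_transformed_logpdf` are hypothetical, and the inverse-gamma parameters (shape 0.5, scale 0.5) are an assumption chosen to match the example in the next section.

```python
import math


def invgamma_logpdf(s, alpha=0.5, beta=0.5):
    """Log-density of an inverse-gamma variable (shape alpha, scale beta).

    The support is constrained to s > 0.
    """
    if s <= 0:
        return -math.inf
    return (
        alpha * math.log(beta)
        - math.lgamma(alpha)
        - (alpha + 1.0) * math.log(s)
        - beta / s
    )


def log_transformed_logpdf(x, alpha=0.5, beta=0.5):
    """Log-density of x = log(s), which is defined for every real x."""
    s = math.exp(x)  # map the unconstrained value back to the support
    log_jacobian = x  # log|ds/dx| = log(exp(x)) = x
    return invgamma_logpdf(s, alpha, beta) + log_jacobian
```

The `log_jacobian` term is what keeps the transformed density normalized; dropping it would give a function of `x` that no longer integrates to one.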

Examples

Using aeppl, one can create a joint log-probability graph from a graph containing Aesara RandomVariables:

import aesara
from aesara import tensor as at

from aeppl import joint_logprob, pprint

srng = at.random.RandomStream()

# A simple scale mixture model
S_rv = srng.invgamma(0.5, 0.5)
Y_rv = srng.normal(0.0, at.sqrt(S_rv))

# Compute the joint log-probability
logprob, (y, s) = joint_logprob(Y_rv, S_rv)

Log-probability graphs are standard Aesara graphs, so they can be compiled and evaluated like any other graph:

logprob_fn = aesara.function([y, s], logprob)

logprob_fn(-0.5, 1.0)
# array(-2.46287705)

Graphs can also be pretty printed:

from aeppl import pprint, latex_pprint


# Print the original graph
print(pprint(Y_rv))
# b ~ invgamma(0.5, 0.5) in R, a ~ N(0.0, sqrt(b)**2) in R
# a

print(latex_pprint(Y_rv))
# \begin{equation}
#   \begin{gathered}
#     b \sim \operatorname{invgamma}\left(0.5, 0.5\right)\,  \in \mathbb{R}
#     \\
#     a \sim \operatorname{N}\left(0.0, {\sqrt{b}}^{2}\right)\,  \in \mathbb{R}
#   \end{gathered}
#   \\
#   a
# \end{equation}

# Simplify the graph so that it's easier to read
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.basic import topo_constant_folding


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)


print(pprint(logprob))
# s in R, y in R
# (switch(s >= 0.0,
#         ((-0.9189385175704956 +
#           switch(s == 0, -inf, (-1.5 * log(s)))) - (0.5 / s)),
#         -inf) +
#  ((-0.9189385332046727 + (-0.5 * ((y / sqrt(s)) ** 2))) - log(sqrt(s))))

Joint log-probabilities can also be computed for some terms that are derived from RandomVariables:

# Create a switching model from a Bernoulli distributed index
Z_rv = srng.normal([-100, 100], 1.0, name="Z")
I_rv = srng.bernoulli(0.5, name="I")

M_rv = Z_rv[I_rv]
M_rv.name = "M"

# Compute the joint log-probability for the mixture
logprob, (m, z, i) = joint_logprob(M_rv, Z_rv, I_rv)


logprob = rewrite_graph(logprob, custom_rewrite=topo_constant_folding)

print(pprint(logprob))
# i in Z, m in R, a in Z
# (switch((0 <= i and i <= 1), -0.6931472, -inf) +
#  ((-0.9189385332046727 + (-0.5 * (((m - [-100  100][a]) / [1. 1.][a]) ** 2))) -
#   log([1. 1.][a])))
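The printed expression can be read as a Bernoulli term for the index plus a normal term for the selected component. A pure-Python sketch of that reading follows; it is not aeppl code, the function names are hypothetical, and it assumes that the index appearing inside the normal term is the same integer-valued index `i`.

```python
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def normal_logpdf(m, mu, sigma):
    """Normal log-density with mean mu and standard deviation sigma."""
    return -LOG_SQRT_2PI - math.log(sigma) - 0.5 * ((m - mu) / sigma) ** 2


def switching_logpdf(m, i, means=(-100.0, 100.0), sds=(1.0, 1.0), p=0.5):
    """Log-density of the switching model at value m and integer index i."""
    if i not in (0, 1):
        # Outside the Bernoulli support, as in the switch(...) term above
        return -math.inf
    index_term = math.log(p) if i == 1 else math.log(1.0 - p)
    # The index selects which component's mean and scale are used
    return index_term + normal_logpdf(m, means[i], sds[i])


print(switching_logpdf(100.0, 1))
print(switching_logpdf(100.0, 0))
```

Evaluating `m = 100.0` under index 0 gives a far lower log-density than under index 1, because the first component is centered at -100.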

Installation

The latest release of aeppl can be installed from PyPI using pip:

pip install aeppl

The current development branch of aeppl can be installed from GitHub, also using pip:

pip install git+https://github.com/aesara-devs/aeppl


