Implement Symbolic RVs #6072

Merged: 13 commits, Sep 5, 2022

docs/source/api/distributions/utilities.rst (2 changes: 1 addition & 1 deletion)

@@ -7,8 +7,8 @@ Distribution utilities
    :toctree: generated/
 
    Distribution
-   SymbolicDistribution
    Discrete
    Continuous
    NoDistribution
    DensityDist
+   SymbolicRandomVariable

docs/source/api/shape_utils.rst (1 change: 1 addition & 0 deletions)

@@ -20,3 +20,4 @@ This module introduces functions that are made aware of the requested `size_tuple`
    broadcast_distribution_samples
    broadcast_dist_samples_to
    rv_size_is_none
+   change_dist_size

pymc/aesaraf.py (115 changes: 44 additions & 71 deletions)

@@ -32,9 +32,8 @@
 import pandas as pd
 import scipy.sparse as sps
 
-from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import CheckParameterValue
-from aesara import config, scalar
+from aesara import scalar
 from aesara.compile.mode import Mode, get_mode
 from aesara.gradient import grad
 from aesara.graph import node_rewriter
@@ -48,7 +47,7 @@
     walk,
 )
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.op import Op, compute_test_value
+from aesara.graph.op import Op
 from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
 from aesara.scalar.basic import Cast
 from aesara.tensor.basic import _as_tensor_variable
@@ -58,12 +57,10 @@
     RandomGeneratorSharedVariable,
     RandomStateSharedVariable,
 )
-from aesara.tensor.shape import SpecifyShape
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable
 
-from pymc.exceptions import ShapeError
 from pymc.vartypes import continuous_types, isgenerator, typefilter
 
 PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -150,65 +147,6 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVariable:
     return at.as_tensor_variable(df.to_numpy(), *args, **kwargs)
 
 
-def change_rv_size(
-    rv: TensorVariable,
-    new_size: PotentialShapeType,
-    expand: Optional[bool] = False,
-) -> TensorVariable:
-    """Change or expand the size of a `RandomVariable`.
-
-    Parameters
-    ==========
-    rv
-        The old `RandomVariable` output.
-    new_size
-        The new size.
-    expand:
-        Expand the existing size by `new_size`.
-
-    """
-    # Check the dimensionality of the `new_size` kwarg
-    new_size_ndim = np.ndim(new_size)
-    if new_size_ndim > 1:
-        raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
-    elif new_size_ndim == 0:
-        new_size = (new_size,)
-
-    # Extract the RV node that is to be resized, together with its inputs, name and tag
-    assert rv.owner.op is not None
-    if isinstance(rv.owner.op, SpecifyShape):
-        rv = rv.owner.inputs[0]
-    rv_node = rv.owner
-    rng, size, dtype, *dist_params = rv_node.inputs
-    name = rv.name
-    tag = rv.tag
-
-    if expand:
-        shape = tuple(rv_node.op._infer_shape(size, dist_params))
-        size = shape[: len(shape) - rv_node.op.ndim_supp]
-        new_size = tuple(new_size) + tuple(size)
-
-    # Make sure the new size is a tensor. This dtype-aware conversion helps
-    # to not unnecessarily pick up a `Cast` in some cases (see #4652).
-    new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
-
-    new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
-    new_rv = new_rv_node.outputs[-1]
-    new_rv.name = name
-    for k, v in tag.__dict__.items():
-        new_rv.tag.__dict__.setdefault(k, v)
-
-    # Update "traditional" rng default_update, if that was set for old RV
-    default_update = getattr(rng, "default_update", None)
-    if default_update is not None and default_update is rv_node.outputs[0]:
-        rng.default_update = new_rv_node.outputs[0]
-
-    if config.compute_test_value != "off":
-        compute_test_value(new_rv_node)
-
-    return new_rv
-
-
 def extract_rv_and_value_vars(
     var: TensorVariable,
 ) -> Tuple[TensorVariable, TensorVariable]:
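
For context: this PR removes `change_rv_size` in favor of the dispatch-based `change_dist_size` added to `pymc.distributions.shape_utils` (see the censored.py changes below). A minimal usage sketch of the replacement, assuming the v4-era API:

    import pymc as pm
    from pymc.distributions.shape_utils import change_dist_size

    base = pm.Normal.dist(mu=0.0, sigma=1.0, size=(3,))
    resized = change_dist_size(base, (10, 3))              # replace the size outright
    expanded = change_dist_size(base, (10,), expand=True)  # prepend: final shape (10, 3)
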
@@ -926,6 +864,31 @@ def find_rng_nodes(
     ]
 
 
+def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
+    """Replace any RNG nodes upstream of outputs by new RNGs of the same type.
+
+    This can be used when combining a pre-existing graph with a cloned one, to ensure
+    RNGs are unique across the two graphs.
+    """
+    rng_nodes = find_rng_nodes(outputs)
+
+    # Nothing to do here
+    if not rng_nodes:
+        return outputs
+
+    graph = FunctionGraph(outputs=outputs, clone=False)
+    new_rng_nodes: List[Union[np.random.RandomState, np.random.Generator]] = []
+    for rng_node in rng_nodes:
+        rng_cls: type
+        if isinstance(rng_node, at.random.var.RandomStateSharedVariable):
+            rng_cls = np.random.RandomState
+        else:
+            rng_cls = np.random.Generator
+        new_rng_nodes.append(aesara.shared(rng_cls(np.random.PCG64())))
+    graph.replace_all(zip(rng_nodes, new_rng_nodes), import_missing=True)
+    return graph.outputs
+
+
 SeedSequenceSeed = Optional[Union[int, Sequence[int], np.ndarray, np.random.SeedSequence]]
 
 
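A quick sketch of the intended use of `replace_rng_nodes` (illustrative, not part of the diff):

    import aesara.tensor as at
    from pymc.aesaraf import replace_rng_nodes

    x = at.random.normal(size=3)
    y = 2 * x
    # Swap in fresh RNG shared variables, e.g. before compiling a second
    # function that must not share random state with an existing graph
    [y] = replace_rng_nodes([y])
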
@@ -987,6 +950,9 @@ def compile_pymc(
     this function is called within a model context and the model `check_bounds` flag
     is set to False.
     """
+    # Avoid circular import
+    from pymc.distributions.distribution import SymbolicRandomVariable
+
     # Create an update mapping of RandomVariable's RNG so that it is automatically
     # updated after every function call
     rng_updates = {}
@@ -995,22 +961,29 @@
             var
             for var in vars_between(inputs, output_to_list)
             if var.owner
-            and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
+            and isinstance(var.owner.op, (RandomVariable, SymbolicRandomVariable))
             and var not in inputs
         ):
             # All nodes in `vars_between(inputs, outputs)` have owners.
             # But mypy doesn't know, so we just assert it:
             assert random_var.owner.op is not None
             if isinstance(random_var.owner.op, RandomVariable):
                 rng = random_var.owner.inputs[0]
-                if not hasattr(rng, "default_update"):
-                    rng_updates[rng] = random_var.owner.outputs[0]
+                if hasattr(rng, "default_update"):
+                    update_map = {rng: rng.default_update}
                 else:
-                    rng_updates[rng] = rng.default_update
+                    update_map = {rng: random_var.owner.outputs[0]}
             else:
-                update_fn = getattr(random_var.owner.op, "update", None)
-                if update_fn is not None:
-                    rng_updates.update(update_fn(random_var.owner))
+                update_map = random_var.owner.op.update(random_var.owner)
+            # Check that we are not setting different update expressions for the same variables
+            for rng, update in update_map.items():
+                if rng not in rng_updates:
+                    rng_updates[rng] = update
+                # When a variable has multiple outputs, it will be called twice with the same
+                # update expression. We don't want to raise in that case, only if the update
+                # expression is different from the one already registered
+                elif rng_updates[rng] is not update:
+                    raise ValueError(f"Multiple update expressions found for the variable {rng}")
 
     # We always reseed random variables as this provides RNGs with no chances of collision
     if rng_updates:
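
The effect of the update collection above, as an illustrative sketch (not part of the diff): because `compile_pymc` registers RNG updates for both `RandomVariable`s and `SymbolicRandomVariable`s, compiled functions return new draws on every call:

    import aesara.tensor as at
    from pymc.aesaraf import compile_pymc

    x = at.random.normal()
    fn = compile_pymc(inputs=[], outputs=[x])
    # Each call applies the registered RNG update, so the draws differ
    print(fn(), fn())
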
pymc/distributions/__init__.py (10 changes: 7 additions & 3 deletions)

@@ -82,7 +82,7 @@
     Discrete,
     Distribution,
     NoDistribution,
-    SymbolicDistribution,
+    SymbolicRandomVariable,
 )
 from pymc.distributions.mixture import Mixture, NormalMixture
 from pymc.distributions.multivariate import (
@@ -105,9 +105,11 @@
 from pymc.distributions.timeseries import (
     AR,
     GARCH11,
+    EulerMaruyama,
     GaussianRandomWalk,
     MvGaussianRandomWalk,
     MvStudentTRandomWalk,
+    RandomWalk,
 )
 
 __all__ = [
@@ -154,7 +156,7 @@
     "OrderedProbit",
     "DensityDist",
     "Distribution",
-    "SymbolicDistribution",
+    "SymbolicRandomVariable",
     "Continuous",
     "Discrete",
     "NoDistribution",
@@ -171,11 +173,13 @@
     "WishartBartlett",
     "LKJCholeskyCov",
     "LKJCorr",
-    "AR",
     "AsymmetricLaplace",
+    "RandomWalk",
     "GaussianRandomWalk",
     "MvGaussianRandomWalk",
     "MvStudentTRandomWalk",
+    "AR",
+    "EulerMaruyama",
     "GARCH11",
     "SkewNormal",
     "Mixture",
pymc/distributions/censored.py (59 changes: 34 additions & 25 deletions)

@@ -14,16 +14,26 @@
 import aesara.tensor as at
 import numpy as np
 
-from aesara.scalar import Clip
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
-from pymc.aesaraf import change_rv_size
-from pymc.distributions.distribution import SymbolicDistribution, _moment
+from pymc.distributions.distribution import (
+    Distribution,
+    SymbolicRandomVariable,
+    _moment,
+)
+from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
 from pymc.util import check_dist_not_registered
 
 
-class Censored(SymbolicDistribution):
+class CensoredRV(SymbolicRandomVariable):
+    """Censored random variable"""
+
+    inline_aeppl = True
+    _print_name = ("Censored", "\\operatorname{Censored}")
+
+
+class Censored(Distribution):
     r"""
     Censored distribution
 
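For orientation (sketch, not from the diff): a `SymbolicRandomVariable` is an `OpFromGraph`-style op that composes existing random variables into a new measurable variable. It is instantiated with dummy `inputs`/`outputs` defining the inner graph, plus `ndim_supp`, and then called on the actual inputs; `rv_op` below uses exactly this pattern. A hypothetical toy subclass:

    import aesara.tensor as at
    from pymc.distributions.distribution import SymbolicRandomVariable

    class HalvedNormalRV(SymbolicRandomVariable):
        """Hypothetical op: half of a normal draw."""

    x_ = at.random.normal().type()  # dummy input defining the inner graph
    halved = HalvedNormalRV(inputs=[x_], outputs=[x_ / 2], ndim_supp=0)(at.random.normal())
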
@@ -72,9 +82,13 @@ class Censored(SymbolicDistribution):
         censored_normal = pm.Censored("censored_normal", normal_dist, lower=-1, upper=1)
     """
 
+    rv_type = CensoredRV
+
     @classmethod
     def dist(cls, dist, lower, upper, **kwargs):
-        if not isinstance(dist, TensorVariable) or not isinstance(dist.owner.op, RandomVariable):
+        if not isinstance(dist, TensorVariable) or not isinstance(
+            dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+        ):
             raise ValueError(
                 f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
             )
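
A quick usage sketch of the validated API (assumes the v4-era `pm.draw` helper):

    import pymc as pm

    norm = pm.Normal.dist(mu=0.0, sigma=1.0)
    cens = pm.Censored.dist(norm, lower=-1.0, upper=1.0)
    pm.draw(cens)  # draws are clipped to [-1, 1]
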
@@ -85,10 +99,6 @@ def dist(cls, dist, lower, upper, **kwargs):
         check_dist_not_registered(dist)
         return super().dist([dist, lower, upper], **kwargs)
 
-    @classmethod
-    def ndim_supp(cls, *dist_params):
-        return 0
-
     @classmethod
     def rv_op(cls, dist, lower=None, upper=None, size=None):
 
@@ -97,29 +107,28 @@ def rv_op(cls, dist, lower=None, upper=None, size=None):
 
         # When size is not specified, dist may have to be broadcasted according to lower/upper
         dist_shape = size if size is not None else at.broadcast_shape(dist, lower, upper)
-        dist = change_rv_size(dist, dist_shape)
+        dist = change_dist_size(dist, dist_shape)
 
         # Censoring is achieved by clipping the base distribution between lower and upper
-        rv_out = at.clip(dist, lower, upper)
-
-        # Reference nodes to facilitate identification in other classmethods, without
-        # worring about possible dimshuffles
-        rv_out.tag.dist = dist
-        rv_out.tag.lower = lower
-        rv_out.tag.upper = upper
-
-        return rv_out
-
-    @classmethod
-    def change_size(cls, rv, new_size, expand=False):
-        dist = rv.tag.dist
-        lower = rv.tag.lower
-        upper = rv.tag.upper
-        new_dist = change_rv_size(dist, new_size, expand=expand)
-        return cls.rv_op(new_dist, lower, upper)
+        dist_, lower_, upper_ = dist.type(), lower.type(), upper.type()
+        censored_rv_ = at.clip(dist_, lower_, upper_)
+
+        return CensoredRV(
+            inputs=[dist_, lower_, upper_],
+            outputs=[censored_rv_],
+            ndim_supp=0,
+        )(dist, lower, upper)
+
+
+@_change_dist_size.register(CensoredRV)
+def change_censored_size(cls, dist, new_size, expand=False):
+    uncensored_dist, lower, upper = dist.owner.inputs
+    if expand:
+        new_size = tuple(new_size) + tuple(uncensored_dist.shape)
+    return Censored.rv_op(uncensored_dist, lower, upper, size=new_size)
 
 
-@_moment.register(Clip)
+@_moment.register(CensoredRV)
 def moment_censored(op, rv, dist, lower, upper):
     moment = at.switch(
         at.eq(lower, -np.inf),
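
Finally, a sketch of how resizing flows through the new dispatcher for censored variables (illustrative, v4-era API assumed):

    import pymc as pm
    from pymc.distributions.shape_utils import change_dist_size

    cens = pm.Censored.dist(pm.Normal.dist(), lower=-1.0, upper=1.0)
    # Dispatches to change_censored_size registered above: the underlying
    # uncensored dist is rebuilt with the new size and clipped again
    bigger = change_dist_size(cens, new_size=(100,))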