Latex representation for SymbolicDistributions #5793


Closed

Conversation

larryshamalama
Member

Closes #5616.

This PR adds a LaTeX representation for SymbolicDistributions, which, unlike regular Distributions, do not construct rv_ops; this is why their LaTeX representation is currently broken.

I would like some advice on how to carry this out. Dispatching was recommended, but the initial class that was called does not seem to be encoded in the TensorVariables of SymbolicDistributions. I currently set it up so that a new str_for_symbolic_dist function is called for the _print_names of SymbolicDistributions, which can be initialized automatically for such distributions (except for zero-inflated ones, because they are created via pm.Mixture). Here are the internal changes that I came up with; I'm happy to discuss these points:

  • cls is now passed to _zero_inflated_mixture so zero-inflated distributions don't come out as "Mixture"
  • Akin to str_for_dist, a str_for_symbolic_dist function is created in printing.py, in which pattern matching can be used to determine how to print the corresponding distribution. I have yet to address this, and perhaps this is where dispatching would be useful?
  • _print_name attributes are initialized for SymbolicDistributions

Some printing references that have been suggested: AePPL printing and symbolic-pymc printing. Both use dispatching, but only the latter uses singledispatch.
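As a rough illustration of the singledispatch approach used in those references (all class names below are hypothetical stand-ins, not PyMC or Aesara internals):

```python
from functools import singledispatch

# Hypothetical stand-ins for Aesara Ops; not the real class hierarchy.
class Op:
    pass

class NormalRV(Op):
    name = "normal"

class CensoredRV(Op):
    pass

@singledispatch
def str_for_op(op):
    # Base case: a generic fallback name.
    return "Dist"

@str_for_op.register
def _(op: NormalRV):
    # RandomVariable-like Ops carry their distribution name directly.
    return op.name.capitalize()

@str_for_op.register
def _(op: CensoredRV):
    return "Censored"
```

The point is that the print name is computed from the Op at print time, rather than stored on the variable at creation time.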

@codecov

codecov bot commented May 23, 2022

Codecov Report

Merging #5793 (60c3407) into main (7b239bb) will increase coverage by 0.27%.
The diff coverage is 74.51%.


@@            Coverage Diff             @@
##             main    #5793      +/-   ##
==========================================
+ Coverage   89.36%   89.64%   +0.27%     
==========================================
  Files          74       73       -1     
  Lines       13757    13257     -500     
==========================================
- Hits        12294    11884     -410     
+ Misses       1463     1373      -90     
Impacted Files Coverage Δ
pymc/variational/__init__.py 100.00% <ø> (ø)
pymc/printing.py 60.00% <9.80%> (-26.14%) ⬇️
pymc/variational/approximations.py 86.51% <62.50%> (+20.44%) ⬆️
pymc/distributions/timeseries.py 78.64% <90.47%> (+1.17%) ⬆️
pymc/aesaraf.py 91.95% <100.00%> (+0.06%) ⬆️
pymc/distributions/discrete.py 99.73% <100.00%> (+<0.01%) ⬆️
pymc/distributions/distribution.py 90.94% <100.00%> (+0.11%) ⬆️
pymc/distributions/logprob.py 97.65% <100.00%> (+0.19%) ⬆️
pymc/distributions/mixture.py 95.72% <100.00%> (-1.04%) ⬇️
pymc/sampling.py 82.52% <100.00%> (-6.19%) ⬇️
... and 15 more


@ricardoV94 ricardoV94 left a comment


Left some comments below. We need some yummy tests!

pymc/printing.py Outdated

# below is copy-pasted from str_for_dist
if include_params:
# first 3 args are always (rng, size, dtype), rest is relevant for distribution

This is not true for SymbolicDists. We might just want to exclude any inputs that are Shared RandomState/Generator variables.
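The exclusion suggested here could look roughly like this (the classes are minimal stand-ins for Aesara's SharedVariable and RandomType, not the real API):

```python
# Sketch of the suggested filtering: when printing parameters, drop any
# inputs that are shared RandomState/Generator variables.

class RandomType:
    """Stand-in for aesara.tensor.random.type.RandomType."""

class SharedVariable:
    """Stand-in for aesara.compile.sharedvalue.SharedVariable."""
    def __init__(self, type_):
        self.type = type_

def non_rng_inputs(inputs):
    """Keep only inputs that are not shared RNG variables."""
    return [
        inp
        for inp in inputs
        if not (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType))
    ]
```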

# functools.partial(str_for_dist, formatting="latex"), rv_out
# )

set_print_name(cls, rv_out)

If we clone a variable (e.g., during a graph rewrite) this will be lost, right? It's not a big issue for now, but if that's the case we should add a comment stating so.

rv_out = Mixture.dist(weights, comp_dists, **kwargs)

# overriding Mixture _print_name
set_print_name(cls, rv_out)

The alternative would be to be more clever about the Mixture print name and check on the spot:

  1. If it has 2 components and one is a zero constant, call it "ZeroInflatedX", where "X" is the name of the nonzero component, like "ZeroInflatedBinomial"
  2. If all the components follow the same distribution call it "XMixture" like "NormalMixture" or "GammaMixture"
  3. If there are two different components, call it "X-YMixture" like "Gamma-ExponentialMixture"
  4. Otherwise call it simply Mixture

This is where the dispatching idea becomes a bit more powerful. You can have an arbitrarily complex function that specializes on an Op and can look at its inputs at evaluation time to figure out a nice name.

The same type of logic could be used for Censored (dispatched on at.clip) to call it "CensoredX" like "CensoredNormal", or the future RandomWalk to call them "XRandomWalk" (perhaps with a special case for Normal, where we call it "Gaussian")

The dispatching itself only means we don't need to define the name at creation time (and others can overwrite it more easily). The more fundamental difference is that it uses a per-op function to decide what name to give to the distribution.

The basecase would be the RandomVariable Op which does what was already done eagerly before this PR.

This is just an idea. Feel free to investigate something like this or leave it as a separate enhancement/feature request issue!
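A toy version of the four-rule naming heuristic above, using plain strings in place of component distributions (real code would inspect the Op's inputs; the zero-constant check is reduced to the literal name "Zero" here):

```python
def mixture_print_name(components):
    """Toy name-picking logic for a mixture, per the four rules above."""
    distinct = list(dict.fromkeys(components))  # dedupe, keep order
    if len(components) == 2 and "Zero" in components:
        # Rule 1: two components, one a zero constant -> "ZeroInflatedX".
        other = next(c for c in components if c != "Zero")
        return f"ZeroInflated{other}"
    if len(distinct) == 1:
        # Rule 2: all components share a distribution -> "XMixture".
        return f"{distinct[0]}Mixture"
    if len(distinct) == 2:
        # Rule 3: exactly two different components -> "X-YMixture".
        return f"{distinct[0]}-{distinct[1]}Mixture"
    # Rule 4: anything else is simply a Mixture.
    return "Mixture"
```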

@ricardoV94 ricardoV94 added this to the v4.0.0 milestone May 24, 2022
@larryshamalama
Member Author

larryshamalama commented May 30, 2022

Notes for myself for tomorrow or Wednesday:

  • How to deal with TensorConstants or random variable constants (like in the ZeroInflatedPoisson)
  • Single component mixture representation

And, of course, adding the aforementioned yummy tests.

@larryshamalama
Member Author

Need to check:

  • Censored distributions
  • Mixtures with non-constant mixing probabilities
  • Where UnmeasurableConstantRV comes from

@larryshamalama
Member Author

Mixing weights for mixture components need to be discussed, because it would be hard to delineate them from nonzero_p in zero-inflated distributions. Also, in [p, 1 - p], 1 - p appears as an Elemwise.

@larryshamalama larryshamalama marked this pull request as ready for review June 1, 2022 21:32
@larryshamalama
Member Author

Note: I just marked this as ready for review, but it is not ready to be merged. It is still missing tests and a decision on how to properly handle mixture weights.

@twiecki twiecki modified the milestones: v4.0.0, v4.1.0 Jun 3, 2022
@ricardoV94
Member

ricardoV94 commented Jun 3, 2022

@twiecki this needs to be a blocker for v4, otherwise you get errors when creating models with Distributions in an interactive environment like Jupyter or IPython.

@ricardoV94 ricardoV94 modified the milestones: v4.1.0, v4.0.0 Jun 3, 2022
@ricardoV94 ricardoV94 modified the milestones: v4.0.0, v4.1.0 Jun 3, 2022
@michaelosthege
Member

@larryshamalama @ricardoV94 how much work is left on this one? Can we include it in the 4.0.2 milestone?

@michaelosthege michaelosthege modified the milestones: v4.1.0, v4.2.0 Jul 2, 2022
@larryshamalama
Member Author

To do it well, I think there is still quite some work to be done. We wanted to do it the hackish way in preparation for v4, but we temporarily have this even more hackish PR merged. I just saw that you added this to the v4.2.0 milestone, and I think that's appropriate since I'd like to work on this slowly.

@larryshamalama larryshamalama marked this pull request as draft August 2, 2022 21:31
@larryshamalama larryshamalama mentioned this pull request Aug 28, 2022
5 tasks
@ricardoV94
Member

Closing this in favor of #6072


Successfully merging this pull request may close these issues.

Latex representation broken for models with SymbolicDistributions