Latex representation for SymbolicDistributions #5793
Conversation
Codecov Report
@@ Coverage Diff @@
## main #5793 +/- ##
==========================================
+ Coverage 89.36% 89.64% +0.27%
==========================================
Files 74 73 -1
Lines 13757 13257 -500
==========================================
- Hits 12294 11884 -410
+ Misses 1463 1373 -90
Left some comments below. We need some yummy tests!
pymc/printing.py
# below is copy-pasted from str_for_dist
if include_params:
    # first 3 args are always (rng, size, dtype), rest is relevant for distribution
This is not true for SymbolicDists. We might just want to exclude any inputs that are Shared RandomState/Generator variables.
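The suggestion above — filtering out shared RNG inputs rather than assuming a fixed `(rng, size, dtype)` prefix — could be sketched as follows. This is a minimal illustration with stand-in classes; the `Input` class and `is_shared_rng` flag are hypothetical placeholders, not the real Aesara/PyTensor API.

```python
from dataclasses import dataclass

# Stand-in for a graph input; `is_shared_rng` marks shared
# RandomState/Generator variables (hypothetical attribute, for illustration).
@dataclass
class Input:
    name: str
    is_shared_rng: bool = False

def relevant_dist_params(inputs):
    """Drop shared RNG state instead of assuming the first N inputs
    are always (rng, size, dtype) — which fails for SymbolicDists."""
    return [inp for inp in inputs if not inp.is_shared_rng]

inputs = [Input("rng", is_shared_rng=True), Input("size"), Input("mu"), Input("sigma")]
print([inp.name for inp in relevant_dist_params(inputs)])  # ['size', 'mu', 'sigma']
```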
# functools.partial(str_for_dist, formatting="latex"), rv_out
# )
set_print_name(cls, rv_out)
If we clone a variable (e.g., during a graph rewrite) this will be lost, right? It's not a big issue for now, but if that's the case we should add a comment stating so.
rv_out = Mixture.dist(weights, comp_dists, **kwargs)

# overriding Mixture _print_name
set_print_name(cls, rv_out)
The alternative would be to be more clever about the mixture print name and check on the spot:
- If it has 2 components and one is a zero constant, call it "ZeroInflatedX", where "X" is the name of the nonzero component, like "ZeroInflatedBinomial"
- If all the components follow the same distribution call it "XMixture" like "NormalMixture" or "GammaMixture"
- If there are two different components, call it "X-YMixture" like "Gamma-ExponentialMixture"
- Otherwise call it simply Mixture
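The naming rules above could be sketched as a small pure-Python helper. The function name `mixture_print_name` and the `zero_component` flag are hypothetical, chosen here only to illustrate the heuristic, not PyMC's actual implementation.

```python
def mixture_print_name(component_names, zero_component=False):
    """Heuristic print name for a mixture, following the rules above.

    `component_names`: print names of the non-constant components.
    `zero_component`: True if one component is a constant zero.
    """
    # Deduplicate while preserving first-occurrence order.
    unique = list(dict.fromkeys(component_names))
    if zero_component and len(unique) == 1:
        return f"ZeroInflated{unique[0]}"      # e.g. ZeroInflatedBinomial
    if len(unique) == 1:
        return f"{unique[0]}Mixture"           # e.g. NormalMixture
    if len(unique) == 2:
        return f"{unique[0]}-{unique[1]}Mixture"  # e.g. Gamma-ExponentialMixture
    return "Mixture"                           # fallback

print(mixture_print_name(["Binomial"], zero_component=True))  # ZeroInflatedBinomial
print(mixture_print_name(["Normal", "Normal"]))               # NormalMixture
print(mixture_print_name(["Gamma", "Exponential"]))           # Gamma-ExponentialMixture
```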
This is where the dispatching idea becomes a bit more powerful. You can have an arbitrarily complex function that specializes on an Op
and can look at its inputs at evaluation time to figure out a nice name.
The same type of logic could be used for Censored (dispatched on at.clip) to call it "CensoredX" like "CensoredNormal", or the future RandomWalk to call them "XRandomWalk" (perhaps with a special case for Normal, where we call it "Gaussian")
The dispatching itself only means we don't need to define the name at creation time (and others can overwrite it more easily). The more fundamental difference is that it uses a per-op function to decide what name to give to the distribution.
The base case would be the RandomVariable Op, which does what was already done eagerly before this PR.
This is just an idea. Feel free to investigate something like this or leave it as a separate enhancement/feature request issue!
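The per-Op dispatch idea could look roughly like the following. All names here (`register_print_name`, the stand-in Op classes, the registry) are hypothetical placeholders used to illustrate the mechanism; none of this is PyMC's actual API.

```python
# Stand-in Op classes (hypothetical, for illustration only).
class Op: ...

class RandomVariableOp(Op):
    name = "Normal"

class CensoredOp(Op): ...

# Registry mapping an Op class to a function that computes its print name.
_print_name_fns = {}

def register_print_name(op_type):
    def wrap(fn):
        _print_name_fns[op_type] = fn
        return fn
    return wrap

def print_name(op, *inputs):
    # Walk the MRO so subclasses fall back to the base-case handler.
    for klass in type(op).__mro__:
        if klass in _print_name_fns:
            return _print_name_fns[klass](op, *inputs)
    raise NotImplementedError(f"No print name registered for {type(op)}")

@register_print_name(RandomVariableOp)
def _(op, *inputs):
    # Base case: use the Op's own name, as was done eagerly before this PR.
    return op.name

@register_print_name(CensoredOp)
def _(op, inner_name):
    # A per-Op function can inspect inputs to build a nicer name.
    return f"Censored{inner_name}"

print(print_name(RandomVariableOp()))      # Normal
print(print_name(CensoredOp(), "Normal"))  # CensoredNormal
```

The key property is that the name is decided when the graph is printed, by a function specialized on the Op, rather than fixed at distribution creation time.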
Notes for myself for tomorrow or Wednesday:

Need to check:
- Mixing weights for mixture components need to be discussed because it would be hard to delineate them from
- And of course adding the aforementioned yummy tests
*Note: I just marked this as ready for review, but it is not ready to be merged. It is missing tests and a decision on how to properly address mixture weights.
@twiecki this needs to be a blocker for V4; otherwise you get errors when creating models with Distributions in an interactive environment like Jupyter or IPython.
Force-pushed from 2d29d7c to 60c3407
@larryshamalama @ricardoV94 how much work is left on this one? Can we include it in the 4.0.2 milestone?
To do it well, I think that there is still quite some work to be done. We wanted to do it the hackish way in preparation for v4, but we temporarily have this even more hackish PR merged. I just saw that you added this to the v4.2.0 milestone and I think that's appropriate, since I'd like to work on this slowly.
Closing this in favor of #6072 |
Closes #5616.

This PR adds Latex representation to `SymbolicDistribution`s, which do not construct `rv_op`s like regular `Distribution`s do, hence why the Latex representation is currently broken.

I would like some advice on how to carry this out. Dispatching was recommended, but the initial class that was called does not seem to be encoded in the `TensorVariable`s for `SymbolicDistribution`s. I currently set it up so that a new `str_for_symbolic_dist` function would be called for `SymbolicDistribution`s' `_print_name`s, which can be automatically initialized for such distributions (except for zero-inflated ones, because they are created via `pm.Mixture`). Here are the internal changes that I came up with, but I'm happy to hear discussion about these points:

- `cls` is now passed in `_zero_inflated_mixture` so they don't come out as "Mixture"
- Alongside `str_for_dist`, `str_for_symbolic_dist` is created in `printing.py`, in which pattern matching can be used to determine how to print the corresponding distribution. I have yet to address this, and perhaps this is where dispatching would be useful?
- `_print_name` attributes are initialized for `SymbolicDistribution`s

Some printing references that have been provided: AePPL printing and symbolic-pymc printing. Both of these use dispatching, but only the latter uses `singledispatch`.
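For context, `functools.singledispatch` (the mechanism symbolic-pymc's printing uses) registers a specialized implementation per argument type, with a generic fallback. A minimal demo with stand-in Op classes (the class and function names here are hypothetical):

```python
from functools import singledispatch

# Stand-in Op classes (hypothetical, for illustration only).
class BaseOp: ...
class NormalOp(BaseOp): ...

@singledispatch
def str_for_op(op):
    # Generic fallback when no specialization is registered.
    return "UnknownOp"

@str_for_op.register
def _(op: NormalOp):
    # Specialization selected by the type of the first argument.
    return "Normal"

print(str_for_op(NormalOp()))  # Normal
print(str_for_op(BaseOp()))    # UnknownOp
```

Because dispatch follows the class hierarchy, an Op subclass without its own registration falls back to its parent's handler, which matches the "base case on the RandomVariable Op" idea discussed above.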