Optimize Sums of MakeVectors and Joins #59
Hi! @ricardoV94. I am new to PyMC. I want to work on this. Could you please guide me further?
Updated output:

Sum{axes=None} [id A] 1
 └─ Join [id B] 0
    ├─ 0 [id C]
    ├─ x [id D]
    └─ y [id E]

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]

360 μs ± 8.84 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
293 μs ± 4.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
290 μs ± 466 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
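For anyone picking this up, the identity behind the graphs above can be checked in plain NumPy. This is only a sketch of the math, not PyTensor rewrite code; the shapes and scalar values are made up for illustration, and the second half is an assumed analogue of the MakeVector case (a vector built from scalars).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4))
y = rng.normal(size=(5, 4))

# Full reduction of a join: concatenate first, then sum every element.
sum_of_join = np.concatenate([x, y], axis=0).sum()
# Rewritten form: sum each input on its own, then add the two scalars.
add_of_sums = x.sum() + y.sum()
np.testing.assert_allclose(sum_of_join, add_of_sums)

# MakeVector analogue: summing a vector built from scalars
# is the same as adding the scalars directly.
a, b, c = 1.5, -2.0, 4.25
sum_of_vector = np.array([a, b, c]).sum()
add_of_scalars = a + b + c
np.testing.assert_allclose(sum_of_vector, add_of_scalars)
```

The rewritten form never allocates the joined array, which is presumably where the timing difference comes from.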
Things still missing. There is an optimization for sum of make_vector introduced in #346, but not for other CAReduce ops. We should extend it. There's also a rewrite for Sum/Prod along axis 0 of a Join along axis 0. #888 extends it to any join axis if the reduction is on that same axis. Either still leaves out optimizations where:
For instance:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x", shape=(128, 128, 128))
y = pt.tensor3("y", shape=(128, 128, 128))
joined = pt.join(0, x, y)
out = pt.sum(joined, axis=(1, 2))
fn = pytensor.function([x, y], out)

alt_out = pt.join(
    0,
    pt.sum(x, axis=(1, 2)),
    pt.sum(y, axis=(1, 2)),
)
alt_fn = pytensor.function([x, y], alt_out)

x_test = np.random.normal(size=x.type.shape)
y_test = np.random.normal(size=y.type.shape)
fn.trust_input = True
alt_fn.trust_input = True
np.testing.assert_allclose(fn(x_test, y_test), alt_fn(x_test, y_test))

%timeit fn(x_test, y_test)  # 19.9 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit alt_fn(x_test, y_test)  # 7.73 ms ± 618 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
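To make the axis bookkeeping concrete, here is a rough NumPy sketch of what a more general rewrite would have to do. The helper names `reduce_join` and `reduce_then_combine` are hypothetical and are not part of PyTensor; the sketch assumes `reduce_axes` is a tuple of non-negative axis indices and only covers sums.

```python
import numpy as np


def reduce_join(tensors, join_axis, reduce_axes):
    """Reference form: join first, then sum over reduce_axes."""
    return np.concatenate(tensors, axis=join_axis).sum(axis=reduce_axes)


def reduce_then_combine(tensors, join_axis, reduce_axes):
    """Hypothetical rewritten form: sum each input first, then combine."""
    partial = [t.sum(axis=reduce_axes) for t in tensors]
    if join_axis in reduce_axes:
        # The join axis is reduced away, so the partial sums share a shape
        # and are combined by elementwise addition.
        total = partial[0]
        for p in partial[1:]:
            total = total + p
        return total
    # The join axis survives; its position shifts down by the number of
    # reduced axes that came before it.
    new_axis = join_axis - sum(ax < join_axis for ax in reduce_axes)
    return np.concatenate(partial, axis=new_axis)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
y = rng.normal(size=(2, 3, 4))

# Join on axis 0, reduce on the other axes (the case timed above).
np.testing.assert_allclose(
    reduce_join([x, y], 0, (1, 2)), reduce_then_combine([x, y], 0, (1, 2))
)
# Reduce on several axes, one of which is the join axis.
np.testing.assert_allclose(
    reduce_join([x, y], 0, (0, 2)), reduce_then_combine([x, y], 0, (0, 2))
)
# Join on a middle axis, reduce on an earlier one.
np.testing.assert_allclose(
    reduce_join([x, y], 1, (0,)), reduce_then_combine([x, y], 1, (0,))
)
```

An actual rewrite would also need to handle negative axes and other CAReduce ops such as Prod, which this sketch leaves out.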
Please describe the purpose of filing this issue
Not a drastic improvement by any means, but something we can keep in mind:
Ignoring any axis complexities