Optimize Sums of MakeVectors and Joins #59

Open
ricardoV94 opened this issue Nov 29, 2022 · 3 comments

@ricardoV94
Member

Please describe the purpose of filing this issue

Not a drastic improvement by any means, but something we can keep in mind:

reduce(pt.concatenate(tensors)) -> reduce([reduce(t) for t in tensors])

Ignoring any axis complexities:

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.vector("x")
y = pt.vector("y")

f1 = pytensor.function([x, y], pt.sum(pt.concatenate((x, y))))
f2 = pytensor.function([x, y], pt.sum((pt.sum(x), pt.sum(y))))
f3 = pytensor.function([x, y], pt.add(pt.sum(x), pt.sum(y)))

pytensor.dprint(f1)
print()
pytensor.dprint(f2)
print()
pytensor.dprint(f3)

x_val = np.random.rand(100_000)
y_val = np.random.rand(200_000)

%timeit f1(x_val, y_val)
%timeit f2(x_val, y_val)
%timeit f3(x_val, y_val)
Sum{acc_dtype=float64} [id A] ''   1
 |Join [id B] ''   0
   |TensorConstant{0} [id C]
   |x [id D]
   |y [id E]

Sum{acc_dtype=float64} [id A] ''   3
 |MakeVector{dtype='float64'} [id B] ''   2
   |Sum{acc_dtype=float64} [id C] ''   1
   | |x [id D]
   |Sum{acc_dtype=float64} [id E] ''   0
     |y [id F]

Elemwise{Add}[(0, 0)] [id A] ''   2
 |Sum{acc_dtype=float64} [id B] ''   1
 | |x [id C]
 |Sum{acc_dtype=float64} [id D] ''   0
   |y [id E]
544 µs ± 27.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
270 µs ± 5.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
270 µs ± 8.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
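The gap is consistent with `f1` materializing the 300k-element concatenated array before reducing it, while `f2`/`f3` reduce each operand directly. The identity the proposed rewrite relies on can be checked in plain NumPy (an illustration only, not PyTensor internals):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(100_000)
y = rng.random(200_000)

# Reducing the concatenation first allocates a temporary of len(x) + len(y).
joined_sum = np.concatenate((x, y)).sum()

# Reducing each operand first avoids that temporary entirely.
split_sum = x.sum() + y.sum()

# Equal up to floating-point accumulation order.
assert np.isclose(joined_sum, split_sum)
```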
@manulpatel

Hi @ricardoV94! I am new to PyMC and would like to work on this. Could you please guide me further?

@HangenYuu
Contributor

HangenYuu commented Jul 22, 2024

Updated output:

Sum{axes=None} [id A] 1
 └─ Join [id B] 0
    ├─ 0 [id C]
    ├─ x [id D]
    └─ y [id E]

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]

360 μs ± 8.84 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
293 μs ± 4.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
290 μs ± 466 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

f2 and f3 are now optimized to the same graph (in particular, the MakeVector in f2 is eliminated). Only f1 is still stuck with the Join Op.

@ricardoV94
Member Author

Things are still missing. There is an optimization for the sum of a make_vector, introduced in #346, but not for other CAReduce ops. We should extend it.

There's also a rewrite for the Sum/Prod of a Join along axis 0. #888 extends it to any join axis, as long as the reduction is along that same axis.

Either way, that still leaves out optimizations where:

  1. We concatenate and then sum over all the axes (an easy extension of Merge consecutive reduces #888)
  2. We concatenate and sum over some but not all axes (the summed axes may not even include the concatenation axis).
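Both cases boil down to simple array identities. A NumPy sketch of what the rewritten graphs would have to compute (illustrative only; the actual rewrites would operate on the PyTensor graph):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 5, 6))
y = rng.random((3, 5, 6))
joined = np.concatenate((x, y), axis=0)

# Case 1: reduce over all axes -> combine the per-operand full reductions.
assert np.isclose(joined.sum(), x.sum() + y.sum())

# Case 2: reduce over axes excluding the join axis -> join the
# per-operand partial reductions along the join axis.
np.testing.assert_allclose(
    joined.sum(axis=(1, 2)),
    np.concatenate((x.sum(axis=(1, 2)), y.sum(axis=(1, 2))), axis=0),
)
```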

For instance:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x", shape=(128, 128, 128))
y = pt.tensor3("y", shape=(128, 128, 128))
joined = pt.join(0, x, y)
out = pt.sum(joined, axis=(1, 2))
fn = pytensor.function([x, y], out)

alt_out = pt.join(
    0,
    pt.sum(x, axis=(1, 2)),
    pt.sum(y, axis=(1, 2)),
)
alt_fn = pytensor.function([x, y], alt_out)

x_test = np.random.normal(size=x.type.shape)
y_test = np.random.normal(size=y.type.shape)

fn.trust_input = True
alt_fn.trust_input = True
np.testing.assert_allclose(fn(x_test, y_test), alt_fn(x_test, y_test))
%timeit fn(x_test, y_test)  # 19.9 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit alt_fn(x_test, y_test)  # 7.73 ms ± 618 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
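Extending this beyond Sum (the "other CAReduce" point above) relies on each reduction having a matching pairwise combiner: add for Sum, multiply for Prod, maximum for Max, and so on. A NumPy sketch of those identities (illustrative, not the PyTensor implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(5)
y = rng.random(7)
joined = np.concatenate((x, y))

# Prod of a join decomposes multiplicatively,
assert np.isclose(joined.prod(), x.prod() * y.prod())

# and Max/Min decompose via the pairwise max/min of the partial results.
assert joined.max() == max(x.max(), y.max())
assert joined.min() == min(x.min(), y.min())
```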
