Skip to content

Rewrite specifically for Sum and Prod to remove Join #951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

HangenYuu
Copy link
Contributor

Description

Rewrite specifically for Sum and Prod to remove Join. So graph that starts as

Sum{axes=None} [id A] 1
 └─ Join [id B] 0
    ├─ 0 [id C]
    ├─ x [id D]
    └─ y [id E]

will become

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94
Copy link
Member

This may overlap/be redundant with #888

Copy link
Contributor Author

@HangenYuu HangenYuu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My rewrite depends on #346 to get to the final value for Sum as the implementation introduces a MakeVector. For it to work for Prod, I need to add a similar one for Prod.

@register_canonicalize
@register_uncanonicalize
@register_specialize
@node_rewriter([Sum, Prod])
Copy link
Member

@ricardoV94 ricardoV94 Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When there's nothing special about Sum and Prod we should apply the rewrites to all CAReduce operations, which also include stuff like Max, All, Any, ..., of which Sum/Prod are just two more instances

See the related PR I linked to in the comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, 2nd time. I did not see this one show up in #59 😄.

Copy link
Member

@ricardoV94 ricardoV94 Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No they are not the same exactly.

There's a rewrite for reduction along axis0 for a join along axis 0.

My PR extends this to any axis.

There's then the question of multiple axis, of which axis=None is the most extreme (all axes). This PR can cover that case.

We may also want to think about multiple but not all axis. In what cases can we reduce first and join later?

@ricardoV94
Copy link
Member

My rewrite depends on #346 to get to the final value for Sum as the implementation introduces a MakeVector. For it to work for Prod, I need to add a similar one for Prod.

#346 was also myopic in treating sum of make vectors more special than any CAReduce of makevector

@HangenYuu HangenYuu closed this Jul 22, 2024
@HangenYuu HangenYuu reopened this Jul 22, 2024
@HangenYuu
Copy link
Contributor Author

I think you should link #888 to close #59 then.

@HangenYuu HangenYuu closed this Jul 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Optimize Sums of MakeVectors and Joins
2 participants