Skip to content

Implement scalar loop for iterative gradients #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 14, 2023

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 25, 2023

Used in the derivatives of betainc, gammainc, gammaincc and hyp2f1.

It should be considerably faster than the old numpy loops, and easier to dispatch to numba and jax as well.

Alternative to #174

Closes #83

This is the kind of C code that is generated for the gammaincc gradient inner loop:

Toggle C code
{
    bool until = 1;

    // Set carried inputs to initial state variables
    // Expressions in the loop will depend on these ones
    npy_float64 V41_i_copy0 = V41_i;
    npy_float64 V41_i_copy1 = V41_i;
    npy_float64 V41_i_copy2 = V41_i;
    npy_float64 V43_i_copy3 = V43_i;
    npy_float64 V43_i_copy4 = V43_i;
    npy_float64 V43_i_copy5 = V43_i;
    npy_int8 V39_i_copy6 = V39_i;
    npy_int8 V39_i_copy7 = V39_i;
    npy_int8 V39_i_copy8 = V39_i;
    npy_float64 V41_i_copy9 = V41_i;
    npy_int8 V39_i_copy10 = V39_i;
    npy_int8 V27_i_copy11 = V27_i;
    npy_int32 V37_i_copy12 = V37_i;

    // Loop away
    for (npy_int32 i = 0; i < V45_i; i++)
    {
        npy_float64 V47_tmp1;
        V47_tmp1 = V35_i + V37_i_copy12;

        npy_float64 V47_tmp2;
        V47_tmp2 = 1.0 / V47_tmp1;

        npy_float64 V47_tmp3;
        V47_tmp3 = V43_i_copy3 - V41_i_copy9;

        npy_float64 V47_tmp4;
        V47_tmp4 = exp((npy_float64)V47_tmp3);

        npy_int8 V47_tmp5;
        V47_tmp5 = V39_i_copy6 * V39_i_copy10;

        npy_float64 V47_tmp6;
        V47_tmp6 = V47_tmp5 * V47_tmp4;

        npy_float64 V47_tmp7;
        V47_tmp7 = V47_tmp6 + V47_tmp2;

        npy_float64 V47_tmp8;
        V47_tmp8 = fabs(V47_tmp7);

        npy_float64 V47_tmp9;
        V47_tmp9 = log((npy_float64)V47_tmp8);

        V25_i = V37_i_copy12 + 1;

        npy_float64 V47_tmp10;
        V47_tmp10 = V31_i + V37_i_copy12;

        npy_float64 V47_tmp11;
        V47_tmp11 = V47_tmp10 * V25_i;

        npy_float64 V47_tmp12;
        V47_tmp12 = V33_i + V37_i_copy12;

        npy_float64 V47_tmp13;
        V47_tmp13 = V47_tmp1 * V47_tmp12;

        npy_float64 V47_tmp14;
        V47_tmp14 = V47_tmp13 / V47_tmp11;

        npy_float64 V47_tmp15;
        V47_tmp15 = fabs(V47_tmp14);

        npy_float64 V47_tmp16;
        V47_tmp16 = log((npy_float64)V47_tmp15);

        npy_float64 V47_tmp17;
        V47_tmp17 = V41_i_copy9 + V47_tmp16;

        V19_i = V47_tmp17 + V29_i;

        V7_i = V19_i + V47_tmp9;

        npy_float64 V47_tmp18;
        V47_tmp18 = exp((npy_float64)V7_i);

        npy_bool V47_tmp19;
        V47_tmp19 = (V47_tmp14 > 0);

        npy_int8 V47_tmp20;
        V47_tmp20 = V47_tmp19 ? 1 : -1;

        V21_i = V47_tmp20 * V39_i_copy10;

        npy_bool V47_tmp21;
        V47_tmp21 = (V47_tmp7 > 0);

        npy_int8 V47_tmp22;
        V47_tmp22 = V47_tmp21 ? 1 : -1;

        V13_i = V47_tmp22 * V21_i;

        npy_float64 V47_tmp23;
        V47_tmp23 = V13_i * V47_tmp18;

        npy_float64 V47_tmp24;
        V47_tmp24 = V47_tmp23 * V27_i_copy11;

        npy_float64 V47_tmp25;
        V47_tmp25 = V41_i_copy0 + V47_tmp24;

        npy_bool V47_tmp26;
        V47_tmp26 = (V47_tmp14 == 0);

        V1_i = V47_tmp26 ? V41_i_copy0 : V47_tmp25;

        npy_float64 V47_tmp27;
        V47_tmp27 = 1.0 / V47_tmp12;

        npy_float64 V47_tmp28;
        V47_tmp28 = V43_i_copy4 - V41_i_copy9;

        npy_float64 V47_tmp29;
        V47_tmp29 = exp((npy_float64)V47_tmp28);

        npy_int8 V47_tmp30;
        V47_tmp30 = V39_i_copy7 * V39_i_copy10;

        npy_float64 V47_tmp31;
        V47_tmp31 = V47_tmp30 * V47_tmp29;

        npy_float64 V47_tmp32;
        V47_tmp32 = V47_tmp31 + V47_tmp27;

        npy_float64 V47_tmp33;
        V47_tmp33 = fabs(V47_tmp32);

        npy_float64 V47_tmp34;
        V47_tmp34 = log((npy_float64)V47_tmp33);

        V9_i = V19_i + V47_tmp34;

        npy_float64 V47_tmp35;
        V47_tmp35 = exp((npy_float64)V9_i);

        npy_bool V47_tmp36;
        V47_tmp36 = (V47_tmp32 > 0);

        npy_int8 V47_tmp37;
        V47_tmp37 = V47_tmp36 ? 1 : -1;

        V15_i = V47_tmp37 * V21_i;

        npy_float64 V47_tmp38;
        V47_tmp38 = V15_i * V47_tmp35;

        npy_float64 V47_tmp39;
        V47_tmp39 = V47_tmp38 * V27_i_copy11;

        npy_float64 V47_tmp40;
        V47_tmp40 = V41_i_copy1 + V47_tmp39;

        V3_i = V47_tmp26 ? V41_i_copy1 : V47_tmp40;

        npy_float64 V47_tmp41;
        V47_tmp41 = 1.0 / V47_tmp10;

        npy_float64 V47_tmp42;
        V47_tmp42 = V43_i_copy5 - V41_i_copy9;

        npy_float64 V47_tmp43;
        V47_tmp43 = exp((npy_float64)V47_tmp42);

        npy_int8 V47_tmp44;
        V47_tmp44 = V39_i_copy8 * V39_i_copy10;

        npy_float64 V47_tmp45;
        V47_tmp45 = V47_tmp44 * V47_tmp43;

        npy_float64 V47_tmp46;
        V47_tmp46 = V47_tmp45 - V47_tmp41;

        npy_float64 V47_tmp47;
        V47_tmp47 = fabs(V47_tmp46);

        npy_float64 V47_tmp48;
        V47_tmp48 = log((npy_float64)V47_tmp47);

        V11_i = V19_i + V47_tmp48;

        npy_float64 V47_tmp49;
        V47_tmp49 = exp((npy_float64)V11_i);

        npy_bool V47_tmp50;
        V47_tmp50 = (V47_tmp46 > 0);

        npy_int8 V47_tmp51;
        V47_tmp51 = V47_tmp50 ? 1 : -1;

        V17_i = V47_tmp51 * V21_i;

        npy_float64 V47_tmp52;
        V47_tmp52 = V17_i * V47_tmp49;

        npy_float64 V47_tmp53;
        V47_tmp53 = V47_tmp52 * V27_i_copy11;

        npy_float64 V47_tmp54;
        V47_tmp54 = V41_i_copy2 + V47_tmp53;

        V5_i = V47_tmp26 ? V41_i_copy2 : V47_tmp54;

        V23_i = V27_i_copy11 * V27_i;

        npy_float64 V47_tmp55;
        V47_tmp55 = fabs(V47_tmp53);

        npy_float64 V47_tmp56;
        V47_tmp56 = fabs(V47_tmp39);

        npy_float64 V47_tmp57;
        V47_tmp57 = fabs(V47_tmp24);

        npy_float64 V47_tmp58;
        V47_tmp58 = ((V47_tmp56) > (V47_tmp57) ? (V47_tmp56) : ((V47_tmp57) >= (V47_tmp56) ? (V47_tmp57) : nan("")));

        npy_float64 V47_tmp59;
        V47_tmp59 = ((V47_tmp55) > (V47_tmp58) ? (V47_tmp55) : ((V47_tmp58) >= (V47_tmp55) ? (V47_tmp58) : nan("")));

        npy_bool V47_tmp60;
        V47_tmp60 = (V47_tmp59 <= 1e-14);

        npy_bool V47_tmp61;
        V47_tmp61 = (V25_i > 10);

        npy_bool V47_tmp62;
        V47_tmp62 = (V47_tmp61 & V47_tmp60);

        until = (V47_tmp26 | V47_tmp62);

        // Set carried inputs to output variables
        V41_i_copy0 = V1_i;
        V41_i_copy1 = V3_i;
        V41_i_copy2 = V5_i;
        V43_i_copy3 = V7_i;
        V43_i_copy4 = V9_i;
        V43_i_copy5 = V11_i;
        V39_i_copy6 = V13_i;
        V39_i_copy7 = V15_i;
        V39_i_copy8 = V17_i;
        V41_i_copy9 = V19_i;
        V39_i_copy10 = V21_i;
        V27_i_copy11 = V23_i;
        V37_i_copy12 = V25_i;

        if (until)
        {
            break;
        }
    }

    if (!until)
    {
        PyErr_WarnEx(PyExc_RuntimeWarning, "Until condition in ScalarLoop hyp2f1_grad not reached!", 1);
    }
}

@ricardoV94 ricardoV94 force-pushed the scalar_scan_pure branch 6 times, most recently from 1a36f5c to 2d36eb1 Compare April 27, 2023 14:59
@ricardoV94 ricardoV94 changed the title Implement scalar scan (as pure Op) Implement scalar loop for iterative gradients Apr 27, 2023
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never tried something like this before so I hope it can be a bit useful. The test coverage for some of the errors in loop.py is spotty, but I don't know if this matters. There's also no test for half precision computation.

@ricardoV94
Copy link
Member Author

ricardoV94 commented May 9, 2023

@aseyboldt mentioned we should be able to implement the gradient as another ScalarLoop, via forward auto-diff, which in this case should be equivalent to backward-autodiff.

It should also be possible to fuse the loops, but hopefully that's one thing the compilers can easily figure out, once they get merged inside the same Composite?

@ricardoV94 ricardoV94 force-pushed the scalar_scan_pure branch 5 times, most recently from 660ab98 to 6ce2143 Compare May 9, 2023 13:06
@ricardoV94 ricardoV94 requested review from jessegrabowski and michaelosthege and removed request for jessegrabowski May 9, 2023 13:08
@ricardoV94
Copy link
Member Author

There's also no test for half precision computation.

We have a job that runs on float32 by default if that's what you meant?

Comment on lines +103 to +105
# TODO: We could convert to TensorVariable, optimize graph,
# and then convert back to ScalarVariable.
# This would introduce rewrites like `log(1 + x) -> log1p`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be more useful to make the rewrites also apply to scalar variables?

Copy link
Member Author

@ricardoV94 ricardoV94 May 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be nice, but there is no automatic way of doing that. Almost all of our rewrites are written with TensorVariables in mind and they would probably error out quickly if you passed ScalarVariables as inputs

Comment on lines +284 to +287
# _c_code += 'printf("inputs=[");'
# for i in range(1, len(fgraph.inputs)):
# _c_code += f'printf("%%.16g, ", %(i{i})s);'
# _c_code += 'printf("]\\n");\n'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this about?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was code I used for debug that I imagine would be very helpful for someone trying to debug the c implementation in the future. Same for the other commented code I left

Comment on lines +320 to +322
# _c_code += 'printf("%%ld\\n", i);\n'
# for carry in range(1, 10):
# _c_code += f'printf("\\t %%.g\\n", i, %(i{carry})s_copy{carry-1});\n'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Consider implementing a scalar Scan Op
3 participants