|
235 | 235 | <div class="pytorch-left-menu-search">
|
236 | 236 |
|
237 | 237 | <div class="version">
|
238 |
| - <a href='https://pytorch.org/docs/versions.html'>master (2.0.0a0+gitd52f121 ) ▼</a> |
| 238 | + <a href='https://pytorch.org/docs/versions.html'>master (2.0.0a0+gitd19791e ) ▼</a> |
239 | 239 | </div>
|
240 | 240 |
|
241 | 241 |
|
@@ -1601,6 +1601,20 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
|
1601 | 1601 | <span class="n">lstsq</span><span class="p">,</span>
|
1602 | 1602 | <span class="p">)</span>
|
1603 | 1603 |
|
| 1604 | +<span class="k">class</span> <span class="nc">_TorchCompileInductorWrapper</span><span class="p">:</span> |
| 1605 | + <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">mode</span><span class="p">,</span> <span class="n">passes</span><span class="p">):</span> |
| 1606 | + <span class="kn">from</span> <span class="nn">torch._dynamo.eval_frame</span> <span class="kn">import</span> <span class="n">lookup_backend</span> |
| 1607 | + <span class="kn">from</span> <span class="nn">torch._inductor.config</span> <span class="kn">import</span> <span class="n">InductorConfigContext</span> |
| 1608 | + |
| 1609 | + <span class="bp">self</span><span class="o">.</span><span class="n">compile_fn</span> <span class="o">=</span> <span class="n">lookup_backend</span><span class="p">(</span><span class="s2">"inductor"</span><span class="p">)</span> |
| 1610 | + <span class="bp">self</span><span class="o">.</span><span class="n">cm</span> <span class="o">=</span> <span class="n">InductorConfigContext</span><span class="p">(</span><span class="n">mode</span> <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">passes</span><span class="p">)</span> |
| 1611 | + <span class="bp">self</span><span class="o">.</span><span class="n">_torchdynamo_orig_callable</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compile_fn</span> |
| 1612 | + |
| 1613 | + <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">model_</span><span class="p">,</span> <span class="n">inputs_</span><span class="p">):</span> |
| 1614 | + <span class="k">with</span> <span class="bp">self</span><span class="o">.</span><span class="n">cm</span><span class="p">:</span> |
| 1615 | + <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">compile_fn</span><span class="p">(</span><span class="n">model_</span><span class="p">,</span> <span class="n">inputs_</span><span class="p">)</span> |
| 1616 | + |
| 1617 | + |
1604 | 1618 | <div class="viewcode-block" id="compile"><a class="viewcode-back" href="../generated/torch.compile.html#torch.compile">[docs]</a><span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Callable</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
|
1605 | 1619 | <span class="n">fullgraph</span><span class="p">:</span> <span class="n">builtins</span><span class="o">.</span><span class="n">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
1606 | 1620 | <span class="n">dynamic</span><span class="p">:</span> <span class="n">builtins</span><span class="o">.</span><span class="n">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
|
@@ -1651,22 +1665,12 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
|
1651 | 1665 | <span class="k">return</span> <span class="n">fn</span>
|
1652 | 1666 |
|
1653 | 1667 | <span class="kn">import</span> <span class="nn">torch._dynamo</span>
|
1654 |
| - <span class="kn">from</span> <span class="nn">torch._dynamo.eval_frame</span> <span class="kn">import</span> <span class="n">lookup_backend</span> |
1655 |
| - <span class="kn">from</span> <span class="nn">torch._inductor.config</span> <span class="kn">import</span> <span class="n">InductorConfigContext</span> |
1656 | 1668 | <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">passes</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
1657 | 1669 | <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s2">"Either mode or passes can be specified, but both can't be specified at the same time."</span><span class="p">)</span>
|
1658 | 1670 | <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">passes</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
1659 | 1671 | <span class="n">mode</span> <span class="o">=</span> <span class="s2">"default"</span>
|
1660 | 1672 | <span class="k">if</span> <span class="n">backend</span> <span class="o">==</span> <span class="s2">"inductor"</span><span class="p">:</span>
|
1661 |
| - <span class="n">compile_fn</span> <span class="o">=</span> <span class="n">lookup_backend</span><span class="p">(</span><span class="n">backend</span><span class="p">)</span> |
1662 |
| - <span class="n">cm</span> <span class="o">=</span> <span class="n">InductorConfigContext</span><span class="p">(</span><span class="n">mode</span> <span class="k">if</span> <span class="n">mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="n">passes</span><span class="p">)</span> |
1663 |
| - |
1664 |
| - <span class="k">def</span> <span class="nf">_compile_fn</span><span class="p">(</span><span class="n">model_</span><span class="p">,</span> <span class="n">inputs_</span><span class="p">):</span> |
1665 |
| - <span class="k">with</span> <span class="n">cm</span><span class="p">:</span> |
1666 |
| - <span class="k">return</span> <span class="n">compile_fn</span><span class="p">(</span><span class="n">model_</span><span class="p">,</span> <span class="n">inputs_</span><span class="p">)</span> |
1667 |
| - |
1668 |
| - <span class="n">_compile_fn</span><span class="o">.</span><span class="n">_torchdynamo_orig_callable</span> <span class="o">=</span> <span class="n">compile_fn</span> <span class="c1"># type: ignore[attr-defined]</span> |
1669 |
| - <span class="n">backend</span> <span class="o">=</span> <span class="n">_compile_fn</span> |
| 1673 | + <span class="n">backend</span> <span class="o">=</span> <span class="n">_TorchCompileInductorWrapper</span><span class="p">(</span><span class="n">mode</span><span class="p">,</span> <span class="n">passes</span><span class="p">)</span> |
1670 | 1674 | <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">_dynamo</span><span class="o">.</span><span class="n">optimize</span><span class="p">(</span><span class="n">backend</span><span class="o">=</span><span class="n">backend</span><span class="p">,</span> <span class="n">nopython</span><span class="o">=</span><span class="n">fullgraph</span><span class="p">,</span> <span class="n">dynamic</span><span class="o">=</span><span class="n">dynamic</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)(</span><span class="n">model</span><span class="p">)</span></div>
|
1671 | 1675 |
|
1672 | 1676 |
|
|
0 commit comments