Skip to content

Commit bdc3e6c

Browse files
authored
[MLIR][python bindings] invalidate ops after PassManager run (#69746)
Fixes #69730 (also see https://reviews.llvm.org/D155543). There are two things outstanding (why I didn't land before): 1. add some C API tests for `mlirOperationWalk`; 2. potentially refactor how the invalidation in `run` works; the first version of the code looked like this: ```cpp if (invalidateOps) { auto *context = op.getOperation().getContext().get(); MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, void *userData) { PyMlirContext *context = static_cast<PyMlirContext *>(userData); context->setOperationInvalid(op); }; auto numRegions = mlirOperationGetNumRegions(op.getOperation().get()); for (int i = 0; i < numRegions; ++i) { MlirRegion region = mlirOperationGetRegion(op.getOperation().get(), i); for (MlirBlock block = mlirRegionGetFirstBlock(region); !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) for (MlirOperation childOp = mlirBlockGetFirstOperation(block); !mlirOperationIsNull(childOp); childOp = mlirOperationGetNextInBlock(childOp)) mlirOperationWalk(childOp, invalidatingCallback, context, MlirWalkPostOrder); } } ``` This is verbose and ugly but it has the important benefit of not executing `mlirOperationEqual(rootOp->get(), op)` for every op underneath the root op. Supposing there's no desire for the slightly more efficient but highly convoluted approach, I can land this "posthaste". But, since we have eyes on this now, any suggestions or approaches (or needs/concerns) are welcome.
1 parent 7e3d110 commit bdc3e6c

File tree

7 files changed

+218
-9
lines changed

7 files changed

+218
-9
lines changed

mlir/include/mlir-c/IR.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void);
7373
///
7474
/// A named attribute is essentially a (name, attribute) pair where the name is
7575
/// a string.
76-
7776
struct MlirNamedAttribute {
7877
MlirIdentifier name;
7978
MlirAttribute attribute;
@@ -698,6 +697,24 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
698697
/// ownership is transferred to the block of the other operation.
699698
MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
700699
MlirOperation other);
700+
701+
/// Traversal order for operation walk.
702+
typedef enum MlirWalkOrder {
703+
MlirWalkPreOrder,
704+
MlirWalkPostOrder
705+
} MlirWalkOrder;
706+
707+
/// Operation walker type. The handler is passed an (opaque) reference to an
708+
/// operation a pointer to a `userData`.
709+
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
710+
711+
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
712+
/// `*userData` is passed to the callback as well and can be used to tunnel some
713+
/// some context or other data into the callback.
714+
MLIR_CAPI_EXPORTED
715+
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
716+
void *userData, MlirWalkOrder walkOrder);
717+
701718
//===----------------------------------------------------------------------===//
702719
// Region API.
703720
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() {
635635
return numInvalidated;
636636
}
637637

638+
void PyMlirContext::setOperationInvalid(MlirOperation op) {
639+
if (liveOperations.contains(op.ptr))
640+
liveOperations[op.ptr].second->setInvalid();
641+
}
642+
638643
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
639644

640645
pybind11::object PyMlirContext::contextEnter() {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ class PyMlirContext {
209209
/// place.
210210
size_t clearLiveOperations();
211211

212+
/// Sets an operation invalid. This is useful for when some non-bindings
213+
/// code destroys the operation and the bindings need to made aware. For
214+
/// example, in the case when pass manager is run.
215+
void setOperationInvalid(MlirOperation op);
216+
212217
/// Gets the count of live modules associated with this context.
213218
/// Used for testing.
214219
size_t getLiveModuleCount();

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir-c/Pass.h"
1414

1515
namespace py = pybind11;
16+
using namespace py::literals;
1617
using namespace mlir;
1718
using namespace mlir::python;
1819

@@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
6364
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
6465
return new PyPassManager(passManager);
6566
}),
66-
py::arg("anchor_op") = py::str("any"),
67-
py::arg("context") = py::none(),
67+
"anchor_op"_a = py::str("any"), "context"_a = py::none(),
6868
"Create a new PassManager for the current (or provided) Context.")
6969
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
7070
&PyPassManager::getCapsule)
@@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
8282
[](PyPassManager &passManager, bool enable) {
8383
mlirPassManagerEnableVerifier(passManager.get(), enable);
8484
},
85-
py::arg("enable"), "Enable / disable verify-each.")
85+
"enable"_a, "Enable / disable verify-each.")
8686
.def_static(
8787
"parse",
8888
[](const std::string &pipeline, DefaultingPyMlirContext context) {
@@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
9696
throw py::value_error(std::string(errorMsg.join()));
9797
return new PyPassManager(passManager);
9898
},
99-
py::arg("pipeline"), py::arg("context") = py::none(),
99+
"pipeline"_a, "context"_a = py::none(),
100100
"Parse a textual pass-pipeline and return a top-level PassManager "
101101
"that can be applied on a Module. Throw a ValueError if the pipeline "
102102
"can't be parsed")
@@ -111,20 +111,43 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
111111
if (mlirLogicalResultIsFailure(status))
112112
throw py::value_error(std::string(errorMsg.join()));
113113
},
114-
py::arg("pipeline"),
114+
"pipeline"_a,
115115
"Add textual pipeline elements to the pass manager. Throws a "
116116
"ValueError if the pipeline can't be parsed.")
117117
.def(
118118
"run",
119-
[](PyPassManager &passManager, PyOperationBase &op) {
119+
[](PyPassManager &passManager, PyOperationBase &op,
120+
bool invalidateOps) {
121+
if (invalidateOps) {
122+
typedef struct {
123+
PyOperation &rootOp;
124+
bool rootSeen;
125+
} callBackData;
126+
callBackData data{op.getOperation(), false};
127+
// Mark all ops below the op that the passmanager will be rooted
128+
// at (but not op itself - note the preorder) as invalid.
129+
MlirOperationWalkCallback invalidatingCallback =
130+
[](MlirOperation op, void *userData) {
131+
callBackData *data = static_cast<callBackData *>(userData);
132+
if (LLVM_LIKELY(data->rootSeen))
133+
data->rootOp.getOperation()
134+
.getContext()
135+
->setOperationInvalid(op);
136+
else
137+
data->rootSeen = true;
138+
};
139+
mlirOperationWalk(op.getOperation(), invalidatingCallback,
140+
static_cast<void *>(&data), MlirWalkPreOrder);
141+
}
142+
// Actually run the pass manager.
120143
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
121144
MlirLogicalResult status = mlirPassManagerRunOnOp(
122145
passManager.get(), op.getOperation().get());
123146
if (mlirLogicalResultIsFailure(status))
124147
throw MLIRError("Failure while executing pass pipeline",
125148
errors.take());
126149
},
127-
py::arg("operation"),
150+
"operation"_a, "invalidate_ops"_a = true,
128151
"Run the pass manager on the provided operation, raising an "
129152
"MLIRError on failure.")
130153
.def(

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/Types.h"
2626
#include "mlir/IR/Value.h"
2727
#include "mlir/IR/Verifier.h"
28+
#include "mlir/IR/Visitors.h"
2829
#include "mlir/Interfaces/InferTypeOpInterface.h"
2930
#include "mlir/Parser/Parser.h"
3031

@@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
705706
return unwrap(op)->moveBefore(unwrap(other));
706707
}
707708

709+
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
710+
void *userData, MlirWalkOrder walkOrder) {
711+
switch (walkOrder) {
712+
713+
case MlirWalkPreOrder:
714+
unwrap(op)->walk<mlir::WalkOrder::PreOrder>(
715+
[callback, userData](Operation *op) { callback(wrap(op), userData); });
716+
break;
717+
case MlirWalkPostOrder:
718+
unwrap(op)->walk<mlir::WalkOrder::PostOrder>(
719+
[callback, userData](Operation *op) { callback(wrap(op), userData); });
720+
}
721+
}
722+
708723
//===----------------------------------------------------------------------===//
709724
// Region API.
710725
//===----------------------------------------------------------------------===//

mlir/test/CAPI/ir.c

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2210,6 +2210,51 @@ int testSymbolTable(MlirContext ctx) {
22102210
return 0;
22112211
}
22122212

2213+
typedef struct {
2214+
const char *x;
2215+
} callBackData;
2216+
2217+
void walkCallBack(MlirOperation op, void *rootOpVoid) {
2218+
fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x,
2219+
mlirIdentifierStr(mlirOperationGetName(op)).data);
2220+
}
2221+
2222+
int testOperationWalk(MlirContext ctx) {
2223+
// CHECK-LABEL: @testOperationWalk
2224+
fprintf(stderr, "@testOperationWalk\n");
2225+
2226+
const char *moduleString = "module {\n"
2227+
" func.func @foo() {\n"
2228+
" %1 = arith.constant 10: i32\n"
2229+
" arith.addi %1, %1: i32\n"
2230+
" return\n"
2231+
" }\n"
2232+
"}";
2233+
MlirModule module =
2234+
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
2235+
2236+
callBackData data;
2237+
data.x = "i love you";
2238+
2239+
// CHECK: i love you: arith.constant
2240+
// CHECK: i love you: arith.addi
2241+
// CHECK: i love you: func.return
2242+
// CHECK: i love you: func.func
2243+
// CHECK: i love you: builtin.module
2244+
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
2245+
(void *)(&data), MlirWalkPostOrder);
2246+
2247+
data.x = "i don't love you";
2248+
// CHECK: i don't love you: builtin.module
2249+
// CHECK: i don't love you: func.func
2250+
// CHECK: i don't love you: arith.constant
2251+
// CHECK: i don't love you: arith.addi
2252+
// CHECK: i don't love you: func.return
2253+
mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack,
2254+
(void *)(&data), MlirWalkPreOrder);
2255+
return 0;
2256+
}
2257+
22132258
int testDialectRegistry(void) {
22142259
fprintf(stderr, "@testDialectRegistry\n");
22152260

@@ -2349,6 +2394,8 @@ int main(void) {
23492394
return 14;
23502395
if (testDialectRegistry())
23512396
return 15;
2397+
if (testOperationWalk(ctx))
2398+
return 16;
23522399

23532400
testExplicitThreadPools();
23542401
testDiagnostics();

mlir/test/python/pass_manager.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from mlir.ir import *
55
from mlir.passmanager import *
66
from mlir.dialects.func import FuncOp
7+
from mlir.dialects.builtin import ModuleOp
8+
79

810
# Log everything to stderr and flush so that we have a unified stream to match
911
# errors/info emitted by MLIR to stderr.
@@ -33,6 +35,7 @@ def testCapsule():
3335

3436
run(testCapsule)
3537

38+
3639
# CHECK-LABEL: TEST: testConstruct
3740
@run
3841
def testConstruct():
@@ -68,6 +71,7 @@ def testParseSuccess():
6871

6972
run(testParseSuccess)
7073

74+
7175
# Verify successful round-trip.
7276
# CHECK-LABEL: TEST: testParseSpacedPipeline
7377
def testParseSpacedPipeline():
@@ -84,6 +88,7 @@ def testParseSpacedPipeline():
8488

8589
run(testParseSpacedPipeline)
8690

91+
8792
# Verify failure on unregistered pass.
8893
# CHECK-LABEL: TEST: testParseFail
8994
def testParseFail():
@@ -102,6 +107,7 @@ def testParseFail():
102107

103108
run(testParseFail)
104109

110+
105111
# Check that adding to a pass manager works
106112
# CHECK-LABEL: TEST: testAdd
107113
@run
@@ -147,6 +153,7 @@ def testRunPipeline():
147153
# CHECK: func.return , 1
148154
run(testRunPipeline)
149155

156+
150157
# CHECK-LABEL: TEST: testRunPipelineError
151158
@run
152159
def testRunPipelineError():
@@ -162,4 +169,94 @@ def testRunPipelineError():
162169
# CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
163170
# CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
164171
# CHECK: >
165-
print(f"Exception: <{e}>")
172+
log(f"Exception: <{e}>")
173+
174+
175+
# CHECK-LABEL: TEST: testPostPassOpInvalidation
176+
@run
177+
def testPostPassOpInvalidation():
178+
with Context() as ctx:
179+
module = ModuleOp.parse(
180+
"""
181+
module {
182+
arith.constant 10
183+
func.func @foo() {
184+
arith.constant 10
185+
return
186+
}
187+
}
188+
"""
189+
)
190+
191+
# CHECK: invalidate_ops=False
192+
log("invalidate_ops=False")
193+
194+
outer_const_op = module.body.operations[0]
195+
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
196+
log(outer_const_op)
197+
198+
func_op = module.body.operations[1]
199+
# CHECK: func.func @[[FOO:.*]]() {
200+
# CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
201+
# CHECK: return
202+
# CHECK: }
203+
log(func_op)
204+
205+
inner_const_op = func_op.body.blocks[0].operations[0]
206+
# CHECK: %[[VAL1]] = arith.constant 10 : i64
207+
log(inner_const_op)
208+
209+
PassManager.parse("builtin.module(canonicalize)").run(
210+
module, invalidate_ops=False
211+
)
212+
# CHECK: func.func @foo() {
213+
# CHECK: return
214+
# CHECK: }
215+
log(func_op)
216+
217+
# CHECK: func.func @foo() {
218+
# CHECK: return
219+
# CHECK: }
220+
log(module)
221+
222+
# CHECK: invalidate_ops=True
223+
log("invalidate_ops=True")
224+
225+
module = ModuleOp.parse(
226+
"""
227+
module {
228+
arith.constant 10
229+
func.func @foo() {
230+
arith.constant 10
231+
return
232+
}
233+
}
234+
"""
235+
)
236+
outer_const_op = module.body.operations[0]
237+
func_op = module.body.operations[1]
238+
inner_const_op = func_op.body.blocks[0].operations[0]
239+
240+
PassManager.parse("builtin.module(canonicalize)").run(module)
241+
try:
242+
log(func_op)
243+
except RuntimeError as e:
244+
# CHECK: the operation has been invalidated
245+
log(e)
246+
247+
try:
248+
log(outer_const_op)
249+
except RuntimeError as e:
250+
# CHECK: the operation has been invalidated
251+
log(e)
252+
253+
try:
254+
log(inner_const_op)
255+
except RuntimeError as e:
256+
# CHECK: the operation has been invalidated
257+
log(e)
258+
259+
# CHECK: func.func @foo() {
260+
# CHECK: return
261+
# CHECK: }
262+
log(module)

0 commit comments

Comments
 (0)