My Project
Visitors.h
Go to the documentation of this file.
1 //===- Visitors.h - Utilities for visiting operations -----------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines utilities for walking and visiting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_VISITORS_H
14 #define MLIR_IR_VISITORS_H
15 
16 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 namespace mlir {
21 class Diagnostic;
22 class InFlightDiagnostic;
23 class Operation;
24 
27 class WalkResult {
28  enum ResultEnum { Interrupt, Advance } result;
29 
30 public:
31  WalkResult(ResultEnum result) : result(result) {}
32 
35  : result(failed(result) ? Interrupt : Advance) {}
36 
38  WalkResult(Diagnostic &&) : result(Interrupt) {}
39  WalkResult(InFlightDiagnostic &&) : result(Interrupt) {}
40 
41  bool operator==(const WalkResult &rhs) const { return result == rhs.result; }
42 
43  static WalkResult interrupt() { return {Interrupt}; }
44  static WalkResult advance() { return {Advance}; }
45 
47  bool wasInterrupted() const { return result == Interrupt; }
48 };
49 
50 namespace detail {
/// Helper declarations, never defined, used only in unevaluated contexts to
/// deduce the first parameter type of a callable. Overloads cover:
///   - plain function pointers,
///   - pointers to non-const and const member functions,
///   - any class type with a single, non-template `operator()` (e.g. a
///     non-generic lambda), by forwarding to the member-pointer overloads.
template <typename Ret, typename Arg>
auto first_argument_type(Ret (*)(Arg)) -> Arg;
template <typename Ret, typename F, typename Arg>
auto first_argument_type(Ret (F::*)(Arg)) -> Arg;
template <typename Ret, typename F, typename Arg>
auto first_argument_type(Ret (F::*)(Arg) const) -> Arg;
template <typename F>
auto first_argument_type(F) -> decltype(first_argument_type(&F::operator()));

/// Type definition of the first argument to the given callable 'T'.
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
63 
65 void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
66 
71  function_ref<WalkResult(Operation *op)> callback);
72 
73 // Below are a set of functions to walk nested operations. Users should favor
74 // the direct `walk` methods on the IR classes (Operation/Block/etc.) over these
75 // methods. They are also templated to allow for statically dispatching based
76 // upon the type of the callback function.
77 
83 template <
84  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
85  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
86 typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
87 walkOperations(Operation *op, FuncTy &&callback) {
88  return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
89 }
90 
97 template <
98  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
99  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
100 typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
101  std::is_same<RetT, void>::value,
102  RetT>::type
103 walkOperations(Operation *op, FuncTy &&callback) {
104  auto wrapperFn = [&](Operation *op) {
105  if (auto derivedOp = dyn_cast<ArgT>(op))
106  callback(derivedOp);
107  };
108  return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
109 }
110 
121 template <
122  typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
123  typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
124 typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
125  std::is_same<RetT, WalkResult>::value,
126  RetT>::type
127 walkOperations(Operation *op, FuncTy &&callback) {
128  auto wrapperFn = [&](Operation *op) {
129  if (auto derivedOp = dyn_cast<ArgT>(op))
130  return callback(derivedOp);
131  return WalkResult::advance();
132  };
133  return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
134 }
135 
/// Utility to provide the return type of a templated walk method: the type
/// that `walkOperations` produces when invoked with a callback of type
/// `CallbackT`.
template <typename CallbackT>
using walkResultType =
    decltype(walkOperations(nullptr, std::declval<CallbackT>()));
139 } // end namespace detail
140 
141 } // namespace mlir
142 
143 #endif
Definition: InferTypeOpInterface.cpp:20
WalkResult(InFlightDiagnostic &&)
Definition: Visitors.h:39
Definition: Operation.h:27
Definition: Diagnostics.h:320
bool wasInterrupted() const
Returns if the walk was interrupted.
Definition: Visitors.h:47
Definition: LLVM.h:49
bool failed(LogicalResult result)
Definition: LogicalResult.h:45
Definition: Diagnostics.h:173
Definition: LogicalResult.h:18
WalkResult(LogicalResult result)
Allow LogicalResult to interrupt the walk on failure.
Definition: Visitors.h:34
std::enable_if<!std::is_same< ArgT, Operation * >::value &&std::is_same< RetT, WalkResult >::value, RetT >::type walkOperations(Operation *op, FuncTy &&callback)
Definition: Visitors.h:127
static WalkResult advance()
Definition: Visitors.h:44
void walkOperations(Operation *op, function_ref< void(Operation *op)> callback)
Walk all of the operations nested under and including the given operation.
Definition: Visitors.cpp:15
static WalkResult interrupt()
Definition: Visitors.h:43
Definition: Visitors.h:27
decltype(first_argument_type(std::declval< T >())) first_argument
Type definition of the first argument to the given callable 'T'.
Definition: Visitors.h:62
bool operator==(const WalkResult &rhs) const
Definition: Visitors.h:41
Arg first_argument_type(Ret(F::*)(Arg))
WalkResult(ResultEnum result)
Definition: Visitors.h:31
decltype(walkOperations(nullptr, std::declval< FnT >())) walkResultType
Utility to provide the return type of a templated walk method.
Definition: Visitors.h:138
WalkResult(Diagnostic &&)
Allow diagnostics to interrupt the walk.
Definition: Visitors.h:38