My Project
Matchers.h
Go to the documentation of this file.
1 //===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides a simple and efficient mechanism for performing general
10 // tree-based pattern matching over MLIR. This mechanism is inspired by LLVM's
11 // include/llvm/IR/PatternMatch.h.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_MATCHERS_H
16 #define MLIR_MATCHERS_H
17 
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/StandardTypes.h"
20 
21 namespace mlir {
22 
23 namespace detail {
24 
27 template <
28  typename AttrClass,
29  // Require AttrClass to be a derived class from Attribute and get its
30  // value type
31  typename ValueType =
32  typename std::enable_if<std::is_base_of<Attribute, AttrClass>::value,
33  AttrClass>::type::ValueType,
34  // Require the ValueType is not void
35  typename = typename std::enable_if<!std::is_void<ValueType>::value>::type>
37  ValueType *bind_value;
38 
40  attr_value_binder(ValueType *bv) : bind_value(bv) {}
41 
42  bool match(const Attribute &attr) {
43  if (auto intAttr = attr.dyn_cast<AttrClass>()) {
44  *bind_value = intAttr.getValue();
45  return true;
46  }
47  return false;
48  }
49 };
50 
53 template <typename AttrT> struct constant_op_binder {
54  AttrT *bind_value;
55 
58  constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
59 
60  bool match(Operation *op) {
61  if (op->getNumOperands() > 0 || op->getNumResults() != 1)
62  return false;
63  if (!op->hasNoSideEffect())
64  return false;
65 
67  if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
68  if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
69  if ((*bind_value = attr.dyn_cast<AttrT>()))
70  return true;
71  }
72  }
73  return false;
74  }
75 };
76 
81 
84 
85  bool match(Operation *op) {
86  Attribute attr;
87  if (!constant_op_binder<Attribute>(&attr).match(op))
88  return false;
89  auto type = op->getResult(0)->getType();
90 
91  if (type.isIntOrIndex()) {
93  }
94  if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
95  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
97  .match(splatAttr.getSplatValue());
98  }
99  }
100  return false;
101  }
102 };
103 
106 template <int64_t TargetValue> struct constant_int_value_matcher {
107  bool match(Operation *op) {
108  APInt value;
109  return constant_int_op_binder(&value).match(op) && TargetValue == value;
110  }
111 };
112 
115 template <int64_t TargetNotValue> struct constant_int_not_value_matcher {
116  bool match(Operation *op) {
117  APInt value;
118  return constant_int_op_binder(&value).match(op) && TargetNotValue != value;
119  }
120 };
121 
123 template <typename OpClass> struct op_matcher {
124  bool match(Operation *op) { return isa<OpClass>(op); }
125 };
126 
129 template <typename T, typename OperationOrValue>
131  decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
132 
134 template <typename MatcherClass>
135 typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
136  MatcherClass, Value>::value,
137  bool>
138 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
139  return matcher.match(op->getOperand(idx));
140 }
141 
143 template <typename MatcherClass>
144 typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
145  MatcherClass, Operation *>::value,
146  bool>
147 matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
148  if (auto defOp = op->getOperand(idx)->getDefiningOp())
149  return matcher.match(defOp);
150  return false;
151 }
152 
155  bool match(Value op) const { return true; }
156 };
157 
160  PatternMatcherValue(Value val) : value(val) {}
161  bool match(Value val) const { return val == value; }
163 };
164 
/// Invokes `callback(index, element)` once for every element of `tuple`, in
/// increasing index order. `index` is passed as a
/// `std::integral_constant<std::size_t, I>`, which converts implicitly to
/// `std::size_t`. Elements come from `std::get<I>` on the (named, hence
/// lvalue) tuple parameter, so the callback receives lvalue references to the
/// elements and may mutate them when the tuple is non-const.
///
/// The pack is expanded inside a braced initializer list, which guarantees
/// left-to-right evaluation without requiring C++17 fold expressions.
template <typename TupleT, class CallbackT, std::size_t... Is>
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
                             std::index_sequence<Is...>) {
  // The callback result is cast to void so a user-defined `operator,` on the
  // callback's return type cannot hijack the `(expr, 0)` expansion.
  (void)std::initializer_list<int>{
      0, ((void)callback(std::integral_constant<std::size_t, Is>{},
                         std::get<Is>(tuple)),
          0)...};
}
173 
174 template <typename... Tys, typename CallbackT>
175 constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
176  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
177  std::make_index_sequence<sizeof...(Tys)>{});
178 }
179 
181 template <typename OpType, typename... OperandMatchers>
183  RecursivePatternMatcher(OperandMatchers... matchers)
184  : operandMatchers(matchers...) {}
185  bool match(Operation *op) {
186  if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
187  return false;
188  bool res = true;
189  enumerate(operandMatchers, [&](size_t index, auto &matcher) {
190  res &= matchOperandOrValueAtIndex(op, index, matcher);
191  });
192  return res;
193  }
194  std::tuple<OperandMatchers...> operandMatchers;
195 };
196 
197 } // end namespace detail
198 
201 template <typename AttrT>
204 }
205 
209 }
210 
212 template <typename OpClass> inline detail::op_matcher<OpClass> m_Op() {
214 }
215 
219 }
220 
225 }
226 
228 template <typename Pattern>
229 inline bool matchPattern(Value value, const Pattern &pattern) {
230  // TODO: handle other cases
231  if (auto *op = value->getDefiningOp())
232  return const_cast<Pattern &>(pattern).match(op);
233  return false;
234 }
235 
237 template <typename Pattern>
238 inline bool matchPattern(Operation *op, const Pattern &pattern) {
239  return const_cast<Pattern &>(pattern).match(op);
240 }
241 
246  return detail::constant_int_op_binder(bind_value);
247 }
248 
249 template <typename OpType, typename... Matchers>
250 auto m_Op(Matchers... matchers) {
251  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
252 }
253 
254 namespace matchers {
255 inline auto m_Any() { return detail::AnyValueMatcher(); }
256 inline auto m_Val(Value v) { return detail::PatternMatcherValue(v); }
257 } // namespace matchers
258 
259 } // end namespace mlir
260 
261 #endif // MLIR_MATCHERS_H
Definition: InferTypeOpInterface.cpp:20
constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback, std::index_sequence< Is... >)
Definition: Matchers.h:166
Definition: Matchers.h:53
detail::constant_int_op_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Definition: Matchers.h:245
Definition: Operation.h:27
Binds to a specific value and matches it.
Definition: Matchers.h:159
detail::constant_int_not_value_matcher< 0 > m_NonZero()
Definition: Matchers.h:223
Value getOperand(unsigned idx)
Definition: Operation.h:207
AttrT * bind_value
Definition: Matchers.h:54
unsigned getNumOperands()
Definition: Operation.h:205
attr_value_binder(ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:40
ValueType * bind_value
Definition: Matchers.h:37
Value value
Definition: Matchers.h:162
constant_int_op_binder(IntegerAttr::ValueType *bv)
Creates a matcher instance that binds the value to bv if match succeeds.
Definition: Matchers.h:83
bool succeeded(LogicalResult result)
Definition: LogicalResult.h:39
detail::constant_int_value_matcher< 1 > m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:207
Definition: StandardTypes.h:314
PatternMatcherValue(Value val)
Definition: Matchers.h:160
Definition: PatternMatch.h:82
auto m_Any()
Definition: Matchers.h:255
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.cpp:548
RecursivePatternMatcher: matches an operation of kind OpType whose operands are each accepted by the corresponding operand matcher.
Definition: Matchers.h:182
bool match(Operation *op)
Definition: Matchers.h:85
bool hasNoSideEffect()
Returns true if the operation has no side effects.
Definition: Operation.h:456
std::tuple< OperandMatchers... > operandMatchers
Definition: Matchers.h:194
bool match(Value op) const
Definition: Matchers.h:155
bool match(Operation *op)
Definition: Matchers.h:60
detail::constant_op_binder< AttrT > m_Constant(AttrT *bind_value)
Definition: Matchers.h:202
Type getType() const
Return the type of this value.
Definition: Value.cpp:34
std::enable_if_t< is_detected< detail::has_operation_or_value_matcher_t, MatcherClass, Value >::value, bool > matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher)
Statically switch to a Value matcher.
Definition: Matchers.h:138
typename detail::detector< void, Op, Args... >::value_t is_detected
Definition: STLExtras.h:125
Definition: Attributes.h:53
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:175
bool match(Operation *op)
Definition: Matchers.h:185
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:246
bool match(Operation *op)
Definition: Matchers.h:107
IntegerAttr::ValueType * bind_value
Definition: Matchers.h:80
decltype(std::declval< T >().match(std::declval< OperationOrValue >())) has_operation_or_value_matcher_t
Definition: Matchers.h:131
Definition: Value.h:38
Definition: Attributes.h:1193
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition: Matchers.h:212
Definition: LLVM.h:35
The matcher that matches a certain kind of op.
Definition: Matchers.h:123
Definition: StandardTypes.h:256
Terminal matcher, always returns true.
Definition: Matchers.h:154
RecursivePatternMatcher(OperandMatchers... matchers)
Definition: Matchers.h:183
Definition: Matchers.h:79
U dyn_cast() const
Definition: Attributes.h:1347
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:229
bool match(Operation *op)
Definition: Matchers.h:124
constant_op_binder(AttrT *bind_value)
Definition: Matchers.h:58
bool match(Value val) const
Definition: Matchers.h:161
Operation * getDefiningOp() const
Definition: Value.cpp:71
auto m_Val(Value v)
Definition: Matchers.h:256
APInt ValueType
Definition: Attributes.h:348
detail::constant_int_value_matcher< 0 > m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:217
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation using the Op's registered foldHook.
Definition: Operation.cpp:603
Definition: StandardTypes.h:63
bool match(Operation *op)
Definition: Matchers.h:116
bool match(const Attribute &attr)
Definition: Matchers.h:42
Definition: Matchers.h:36