My Project
PatternMatch.h
Go to the documentation of this file.
1 //===- PatternMatch.h - PatternMatcher classes ------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_PATTERNMATCHER_H
10 #define MLIR_PATTERNMATCHER_H
11 
12 #include "mlir/IR/Builders.h"
13 
14 namespace mlir {
15 
16 class PatternRewriter;
17 
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit class
20 //===----------------------------------------------------------------------===//
21 
30  enum { ImpossibleToMatchSentinel = 65535 };
31 
32 public:
33  /*implicit*/ PatternBenefit(unsigned benefit);
34  PatternBenefit(const PatternBenefit &) = default;
35  PatternBenefit &operator=(const PatternBenefit &) = default;
36 
38  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }
39 
41  // corresponding pattern isImpossibleToMatch() then this aborts.
42  unsigned short getBenefit() const;
43 
44  bool operator==(const PatternBenefit &rhs) const {
45  return representation == rhs.representation;
46  }
47  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
48  bool operator<(const PatternBenefit &rhs) const {
49  return representation < rhs.representation;
50  }
51 
52 private:
53  PatternBenefit() : representation(ImpossibleToMatchSentinel) {}
54  unsigned short representation;
55 };
56 
/// Base class for state that a pattern's match phase hands to its rewrite
/// phase (see matchSuccess() / RewritePattern::matchAndRewrite). It is only
/// meaningful as a base: concrete patterns derive from it to carry whatever
/// they computed while matching.
class PatternState {
public:
  /// Virtual so that derived state is destroyed correctly when it is owned
  /// through a std::unique_ptr<PatternState>.
  virtual ~PatternState() = default;

protected:
  // Must be subclassed: the constructor is protected so that PatternState
  // itself cannot be instantiated directly, only derived state types.
  PatternState() = default;
};
68 
73 
74 //===----------------------------------------------------------------------===//
75 // Pattern class
76 //===----------------------------------------------------------------------===//
77 
82 class Pattern {
83 public:
89  PatternBenefit getBenefit() const { return benefit; }
90 
93  OperationName getRootKind() const { return rootKind; }
94 
95  //===--------------------------------------------------------------------===//
96  // Implementation hooks for patterns to implement.
97  //===--------------------------------------------------------------------===//
98 
103  virtual PatternMatchResult match(Operation *op) const = 0;
104 
105  virtual ~Pattern() {}
106 
107  //===--------------------------------------------------------------------===//
108  // Helper methods to simplify pattern implementations
109  //===--------------------------------------------------------------------===//
110 
112  static PatternMatchResult matchFailure() { return None; }
113 
116  matchSuccess(std::unique_ptr<PatternState> state = {}) const {
117  return PatternMatchResult(std::move(state));
118  }
119 
120 protected:
123  Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
124 
125 private:
126  const OperationName rootKind;
127  const PatternBenefit benefit;
128 
129  virtual void anchor();
130 };
131 
142 class RewritePattern : public Pattern {
143 public:
149  virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
150  PatternRewriter &rewriter) const;
151 
157  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
158 
164  PatternMatchResult match(Operation *op) const override;
165 
170  PatternRewriter &rewriter) const {
171  if (auto matchResult = match(op)) {
172  rewrite(op, std::move(*matchResult), rewriter);
173  return matchSuccess();
174  }
175  return matchFailure();
176  }
177 
180  ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
181 
182 protected:
185  RewritePattern(StringRef rootName, PatternBenefit benefit,
186  MLIRContext *context)
187  : Pattern(rootName, benefit, context) {}
191  RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
192  PatternBenefit benefit, MLIRContext *context);
193 
197 };
198 
202 template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
206  : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
207 
209  void rewrite(Operation *op, std::unique_ptr<PatternState> state,
210  PatternRewriter &rewriter) const final {
211  rewrite(cast<SourceOp>(op), std::move(state), rewriter);
212  }
213  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
214  rewrite(cast<SourceOp>(op), rewriter);
215  }
216  PatternMatchResult match(Operation *op) const final {
217  return match(cast<SourceOp>(op));
218  }
220  PatternRewriter &rewriter) const final {
221  return matchAndRewrite(cast<SourceOp>(op), rewriter);
222  }
223 
226  virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
227  PatternRewriter &rewriter) const {
228  rewrite(op, rewriter);
229  }
230  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
231  llvm_unreachable("must override matchAndRewrite or a rewrite method");
232  }
233  virtual PatternMatchResult match(SourceOp op) const {
234  llvm_unreachable("must override match or matchAndRewrite");
235  }
236  virtual PatternMatchResult matchAndRewrite(SourceOp op,
237  PatternRewriter &rewriter) const {
238  if (auto matchResult = match(op)) {
239  rewrite(op, std::move(*matchResult), rewriter);
240  return matchSuccess();
241  }
242  return matchFailure();
243  }
244 };
245 
246 //===----------------------------------------------------------------------===//
247 // PatternRewriter class
248 //===----------------------------------------------------------------------===//
249 
260 class PatternRewriter : public OpBuilder {
261 public:
264  template <typename OpTy, typename... Args>
265  OpTy create(Location location, Args... args) {
266  OperationState state(location, OpTy::getOperationName());
267  OpTy::build(this, state, args...);
268  auto *op = createOperation(state);
269  auto result = dyn_cast<OpTy>(op);
270  assert(result && "Builder didn't return the right type");
271  return result;
272  }
273 
277  template <typename OpTy, typename... Args>
278  OpTy createChecked(Location location, Args... args) {
279  OperationState state(location, OpTy::getOperationName());
280  OpTy::build(this, state, args...);
281  auto *op = createOperation(state);
282 
283  // If the Operation we produce is valid, return it.
284  if (!OpTy::verifyInvariants(op)) {
285  auto result = dyn_cast<OpTy>(op);
286  assert(result && "Builder didn't return the right type");
287  return result;
288  }
289 
290  // Otherwise, the error message got emitted. Just remove the operation
291  // we made.
292  op->erase();
293  return OpTy();
294  }
295 
298  virtual Operation *insert(Operation *op) = 0;
299 
304  virtual void inlineRegionBefore(Region &region, Region &parent,
305  Region::iterator before);
306  void inlineRegionBefore(Region &region, Block *before);
307 
312  virtual void cloneRegionBefore(Region &region, Region &parent,
313  Region::iterator before,
314  BlockAndValueMapping &mapping);
315  void cloneRegionBefore(Region &region, Region &parent,
316  Region::iterator before);
317  void cloneRegionBefore(Region &region, Block *before);
318 
325  virtual void replaceOp(Operation *op, ValueRange newValues,
326  ValueRange valuesToRemoveIfDead);
327  void replaceOp(Operation *op, ValueRange newValues) {
328  replaceOp(op, newValues, llvm::None);
329  }
330 
333  template <typename OpTy, typename... Args>
334  void replaceOpWithNewOp(Operation *op, Args &&... args) {
335  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
336  replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
337  }
338 
342  template <typename OpTy, typename... Args>
343  void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op,
344  Args &&... args) {
345  auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
346  replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
347  valuesToRemoveIfDead);
348  }
349 
351  virtual void eraseOp(Operation *op);
352 
357  virtual void mergeBlocks(Block *source, Block *dest,
358  ValueRange argValues = llvm::None);
359 
362  virtual Block *splitBlock(Block *block, Block::iterator before);
363 
369  virtual void startRootUpdate(Operation *op) {}
370 
374  virtual void finalizeRootUpdate(Operation *op) {}
375 
378  virtual void cancelRootUpdate(Operation *op) {}
379 
383  template <typename CallableT>
384  void updateRootInPlace(Operation *root, CallableT &&callable) {
385  startRootUpdate(root);
386  callable();
387  finalizeRootUpdate(root);
388  }
389 
390 protected:
391  explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
392  virtual ~PatternRewriter();
393 
394  // These are the callback methods that subclasses can choose to implement if
395  // they would like to be notified about certain types of mutations.
396 
400  virtual void notifyRootReplaced(Operation *op) {}
401 
405  virtual void notifyOperationRemoved(Operation *op) {}
406 
407 private:
410  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
411  ValueRange valuesToRemoveIfDead);
412 };
413 
414 //===----------------------------------------------------------------------===//
415 // Pattern-driven rewriters
416 //===----------------------------------------------------------------------===//
417 
419  using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
420 
421 public:
422  PatternListT::iterator begin() { return patterns.begin(); }
423  PatternListT::iterator end() { return patterns.end(); }
424  PatternListT::const_iterator begin() const { return patterns.begin(); }
425  PatternListT::const_iterator end() const { return patterns.end(); }
426  void clear() { patterns.clear(); }
427 
428  //===--------------------------------------------------------------------===//
429  // Pattern Insertion
430  //===--------------------------------------------------------------------===//
431 
435  template <typename... Ts, typename ConstructorArg,
436  typename... ConstructorArgs,
437  typename = std::enable_if_t<sizeof...(Ts) != 0>>
438  void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
439  // The following expands a call to emplace_back for each of the pattern
440  // types 'Ts'. This magic is necessary due to a limitation in the places
441  // that a parameter pack can be expanded in c++11.
442  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
443  using dummy = int[];
444  (void)dummy{
445  0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
446  }
447 
448 private:
449  PatternListT patterns;
450 };
451 
457 public:
459  explicit RewritePatternMatcher(const OwningRewritePatternList &patterns);
460 
463  bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
464 
465 private:
467  void operator=(const RewritePatternMatcher &) = delete;
468 
471  std::vector<RewritePattern *> patterns;
472 };
473 
483  const OwningRewritePatternList &patterns);
486  const OwningRewritePatternList &patterns);
487 } // end namespace mlir
488 
489 #endif // MLIR_PATTERNMATCHER_H
bool operator<(const PatternBenefit &rhs) const
Definition: PatternMatch.h:48
Definition: InferTypeOpInterface.cpp:20
virtual void notifyOperationRemoved(Operation *op)
Definition: PatternMatch.h:405
Definition: Region.h:23
virtual PatternMatchResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:236
Definition: PatternMatch.h:260
PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Definition: PatternMatch.h:219
Definition: Operation.h:27
virtual void notifyRootReplaced(Operation *op)
Definition: PatternMatch.h:400
virtual void finalizeRootUpdate(Operation *op)
Definition: PatternMatch.h:374
OperationName getRootKind() const
Definition: PatternMatch.h:93
virtual void cancelRootUpdate(Operation *op)
Definition: PatternMatch.h:378
Block represents an ordered list of Operations.
Definition: Block.h:21
virtual ~Pattern()
Definition: PatternMatch.h:105
PatternListT::iterator begin()
Definition: PatternMatch.h:422
Definition: PatternMatch.h:456
BlockListType::iterator iterator
Definition: Region.h:41
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:230
OpTy create(Location location, Args... args)
Definition: PatternMatch.h:265
PatternListT::const_iterator end() const
Definition: PatternMatch.h:425
Optional< std::unique_ptr< PatternState > > PatternMatchResult
Definition: PatternMatch.h:72
Definition: LLVM.h:40
PatternBenefit & operator=(const PatternBenefit &)=default
virtual PatternMatchResult match(SourceOp op) const
Definition: PatternMatch.h:233
ArrayRef< OperationName > getGeneratedOps() const
Definition: PatternMatch.h:180
void insert(ConstructorArg &&arg, ConstructorArgs &&... args)
Definition: PatternMatch.h:438
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:501
Definition: Location.h:52
Definition: PatternMatch.h:82
Definition: PatternMatch.h:142
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
Definition: PatternMatch.cpp:20
static PatternBenefit impossibleToMatch()
Definition: PatternMatch.h:37
Definition: LLVM.h:37
void replaceOp(Operation *op, ValueRange newValues)
Definition: PatternMatch.h:327
OpListType::iterator iterator
Definition: Block.h:107
void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op, Args &&... args)
Definition: PatternMatch.h:343
Definition: PatternMatch.h:29
bool applyPatternsGreedily(Operation *op, const OwningRewritePatternList &patterns)
Definition: GreedyPatternRewriteDriver.cpp:218
Definition: LLVM.h:38
void updateRootInPlace(Operation *root, CallableT &&callable)
Definition: PatternMatch.h:384
void clear()
Definition: PatternMatch.h:426
PatternBenefit getBenefit() const
Definition: PatternMatch.h:89
void rewrite(Operation *op, PatternRewriter &rewriter) const final
Definition: PatternMatch.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:107
Definition: OperationSupport.h:261
PatternRewriter(MLIRContext *ctx)
Definition: PatternMatch.h:391
Definition: PatternMatch.h:60
SmallVector< OperationName, 2 > generatedOps
Definition: PatternMatch.h:196
Definition: BlockAndValueMapping.h:26
PatternMatchResult matchSuccess(std::unique_ptr< PatternState > state={}) const
This method indicates that a match was found, optionally carrying the PatternState produced while matching.
Definition: PatternMatch.h:116
virtual PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const
Definition: PatternMatch.h:169
PatternState()
Definition: PatternMatch.h:66
Definition: LLVM.h:35
void replaceOpWithNewOp(Operation *op, Args &&... args)
Definition: PatternMatch.h:334
Definition: PatternMatch.h:202
Definition: PatternMatch.h:418
OpTy createChecked(Location location, Args... args)
Definition: PatternMatch.h:278
bool isImpossibleToMatch() const
Definition: PatternMatch.h:38
PatternMatchResult match(Operation *op) const final
Definition: PatternMatch.h:216
Definition: MLIRContext.h:34
bool operator==(const PatternBenefit &rhs) const
Definition: PatternMatch.h:44
static PatternMatchResult matchFailure()
This method indicates that no match was found.
Definition: PatternMatch.h:112
PatternListT::iterator end()
Definition: PatternMatch.h:423
Definition: StandardTypes.h:63
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:205
virtual void startRootUpdate(Operation *op)
Definition: PatternMatch.h:369
RewritePattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context)
Definition: PatternMatch.h:185
void rewrite(Operation *op, std::unique_ptr< PatternState > state, PatternRewriter &rewriter) const final
Wrappers around the RewritePattern methods that pass the derived op type.
Definition: PatternMatch.h:209
virtual void rewrite(SourceOp op, std::unique_ptr< PatternState > state, PatternRewriter &rewriter) const
Definition: PatternMatch.h:226
Definition: OperationSupport.h:203
virtual ~PatternState()
Definition: PatternMatch.h:62
Definition: Builders.h:158
Definition: OperationSupport.h:640
bool operator!=(const PatternBenefit &rhs) const
Definition: PatternMatch.h:47
PatternListT::const_iterator begin() const
Definition: PatternMatch.h:424