My Project
FoldUtils.h
Go to the documentation of this file.
1 //===- FoldUtils.h - Operation Fold Utilities -------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file declares various operation folding utilities. These
10 // utilities are intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TRANSFORMS_FOLDUTILS_H
15 #define MLIR_TRANSFORMS_FOLDUTILS_H
16 
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Dialect.h"
20 
21 namespace mlir {
22 class Operation;
23 class Value;
24 
25 //===--------------------------------------------------------------------===//
26 // Operation Folding Interface
27 //===--------------------------------------------------------------------===//
28 
32  : public DialectInterface::Base<OpFolderDialectInterface> {
33 public:
34  OpFolderDialectInterface(Dialect *dialect) : Base(dialect) {}
35 
41  virtual bool shouldMaterializeInto(Region *region) const { return false; }
42 };
43 
44 //===--------------------------------------------------------------------===//
45 // OperationFolder
46 //===--------------------------------------------------------------------===//
47 
51 public:
52  OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
53 
61  tryToFold(Operation *op,
62  function_ref<void(Operation *)> processGeneratedConstants = nullptr,
63  function_ref<void(Operation *)> preReplaceAction = nullptr);
64 
70  void notifyRemoval(Operation *op);
71 
75  template <typename OpTy, typename... Args>
76  void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
77  Location location, Args &&... args) {
78  Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
79  if (failed(tryToFold(op, results)))
80  results.assign(op->result_begin(), op->result_end());
81  else if (op->getNumResults() != 0)
82  op->erase();
83  }
84 
86  template <typename OpTy, typename... Args>
87  typename std::enable_if<OpTy::template hasTrait<OpTrait::OneResult>(),
88  Value>::type
89  create(OpBuilder &builder, Location location, Args &&... args) {
90  SmallVector<Value, 1> results;
91  create<OpTy>(builder, results, location, std::forward<Args>(args)...);
92  return results.front();
93  }
94 
96  template <typename OpTy, typename... Args>
97  typename std::enable_if<OpTy::template hasTrait<OpTrait::ZeroResult>(),
98  OpTy>::type
99  create(OpBuilder &builder, Location location, Args &&... args) {
100  auto op = builder.create<OpTy>(location, std::forward<Args>(args)...);
101  SmallVector<Value, 0> unused;
102  (void)tryToFold(op.getOperation(), unused);
103 
104  // Folding cannot remove a zero-result operation, so for convenience we
105  // continue to return it.
106  return op;
107  }
108 
109 private:
114  using ConstantMap =
116 
119  LogicalResult tryToFold(
120  Operation *op, SmallVectorImpl<Value> &results,
121  function_ref<void(Operation *)> processGeneratedConstants = nullptr);
122 
125  Operation *tryGetOrCreateConstant(ConstantMap &uniquedConstants,
126  Dialect *dialect, OpBuilder &builder,
127  Attribute value, Type type, Location loc);
128 
132 
136 
139 };
140 
141 } // end namespace mlir
142 
143 #endif // MLIR_TRANSFORMS_FOLDUTILS_H
Definition: InferTypeOpInterface.cpp:20
OperationFolder(MLIRContext *ctx)
Definition: FoldUtils.h:52
Definition: Region.h:23
virtual bool shouldMaterializeInto(Region *region) const
Definition: FoldUtils.h:41
Definition: Operation.h:27
void create(OpBuilder &builder, SmallVectorImpl< Value > &results, Location location, Args &&... args)
Definition: FoldUtils.h:76
Definition: DialectInterface.h:27
Definition: LLVM.h:48
OpFolderDialectInterface(Dialect *dialect)
Definition: FoldUtils.h:34
Definition: LLVM.h:49
bool failed(LogicalResult result)
Definition: LogicalResult.h:45
Definition: DialectInterface.h:150
std::enable_if< OpTy::template hasTrait< OpTrait::ZeroResult >(), OpTy >::type create(OpBuilder &builder, Location location, Args &&... args)
Overload to create or fold a zero result operation.
Definition: FoldUtils.h:99
Definition: LLVM.h:34
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:501
Definition: Location.h:52
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.cpp:548
Definition: LogicalResult.h:18
OpTy create(Location location, Args &&... args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:294
Definition: Attributes.h:53
Definition: Dialect.h:39
Definition: Types.h:84
Definition: FoldUtils.h:31
Definition: Value.h:38
Definition: FoldUtils.h:50
Definition: LLVM.h:35
result_iterator result_end()
Definition: Operation.h:253
std::enable_if< OpTy::template hasTrait< OpTrait::OneResult >(), Value >::type create(OpBuilder &builder, Location location, Args &&... args)
Overload to create or fold a single result operation.
Definition: FoldUtils.h:89
Definition: MLIRContext.h:34
Definition: Builders.h:158
result_iterator result_begin()
Definition: Operation.h:252