My Project
Pattern.h
Go to the documentation of this file.
1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_TABLEGEN_PATTERN_H_
15 #define MLIR_TABLEGEN_PATTERN_H_
16 
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/TableGen/Argument.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/StringSet.h"
23 
24 namespace llvm {
25 class DagInit;
26 class Init;
27 class Record;
28 } // end namespace llvm
29 
30 namespace mlir {
31 namespace tblgen {
32 
33 // Mapping from TableGen Record to Operator wrapper object.
34 //
35 // We allocate each wrapper object in heap to make sure the pointer to it is
36 // valid throughout the lifetime of this map. This is important because this map
37 // is shared among multiple patterns to avoid creating the wrapper object for
38 // the same op again and again. But this map will continuously grow.
39 using RecordOperatorMap =
41 
42 class Pattern;
43 
44 // Wrapper class providing helper methods for accessing TableGen DAG leaves
45 // used inside Patterns. This class is lightweight and designed to be used like
46 // values.
47 //
48 // A TableGen DAG construct is of the syntax
49 // `(operator, arg0, arg1, ...)`.
50 //
51 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects
52 // for handy helper methods. It only works on `arg*`s that are not nested DAG
53 // constructs.
54 class DagLeaf {
55 public:
56  explicit DagLeaf(const llvm::Init *def) : def(def) {}
57 
58  // Returns true if this DAG leaf is not specified in the pattern. That is, it
59  // places no further constraints/transforms and just carries over the original
60  // value.
61  bool isUnspecified() const;
62 
63  // Returns true if this DAG leaf is matching an operand. That is, it specifies
64  // a type constraint.
65  bool isOperandMatcher() const;
66 
67  // Returns true if this DAG leaf is matching an attribute. That is, it
68  // specifies an attribute constraint.
69  bool isAttrMatcher() const;
70 
71  // Returns true if this DAG leaf is wrapping native code call.
72  bool isNativeCodeCall() const;
73 
74  // Returns true if this DAG leaf is specifying a constant attribute.
75  bool isConstantAttr() const;
76 
77  // Returns true if this DAG leaf is specifying an enum attribute case.
78  bool isEnumAttrCase() const;
79 
80  // Returns this DAG leaf as a constraint. Asserts if fails.
81  Constraint getAsConstraint() const;
82 
83  // Returns this DAG leaf as an constant attribute. Asserts if fails.
84  ConstantAttr getAsConstantAttr() const;
85 
86  // Returns this DAG leaf as an enum attribute case.
87  // Precondition: isEnumAttrCase()
88  EnumAttrCase getAsEnumAttrCase() const;
89 
90  // Returns the matching condition template inside this DAG leaf. Assumes the
91  // leaf is an operand/attribute matcher and asserts otherwise.
92  std::string getConditionTemplate() const;
93 
94  // Returns the native code call template inside this DAG leaf.
95  // Precondition: isNativeCodeCall()
96  StringRef getNativeCodeTemplate() const;
97 
98  void print(raw_ostream &os) const;
99 
100 private:
101  // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
102  // also a subclass of the given `superclass`.
103  bool isSubClassOf(StringRef superclass) const;
104 
105  const llvm::Init *def;
106 };
107 
108 // Wrapper class providing helper methods for accessing TableGen DAG constructs
109 // used inside Patterns. This class is lightweight and designed to be used like
110 // values.
111 //
112 // A TableGen DAG construct is of the syntax
113 // `(operator, arg0, arg1, ...)`.
114 //
115 // When used inside Patterns, `operator` corresponds to some dialect op, or
116 // a known list of verbs that defines special transformation actions. This
117 // `arg*` can be a nested DAG construct. This class provides getters to
118 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper
119 // methods.
120 //
121 // A null DagNode contains a nullptr and converts to false implicitly.
122 class DagNode {
123 public:
124  explicit DagNode(const llvm::DagInit *node) : node(node) {}
125 
126  // Implicit bool converter that returns true if this DagNode is not a null
127  // DagNode.
128  operator bool() const { return node != nullptr; }
129 
130  // Returns the symbol bound to this DAG node.
131  StringRef getSymbol() const;
132 
133  // Returns the operator wrapper object corresponding to the dialect op matched
134  // by this DAG. The operator wrapper will be queried from the given `mapper`
135  // and created in it if not existing.
136  Operator &getDialectOp(RecordOperatorMap *mapper) const;
137 
138  // Returns the number of operations recursively involved in the DAG tree
139  // rooted from this node.
140  int getNumOps() const;
141 
142  // Returns the number of immediate arguments to this DAG node.
143  int getNumArgs() const;
144 
145  // Returns true if the `index`-th argument is a nested DAG construct.
146  bool isNestedDagArg(unsigned index) const;
147 
148  // Gets the `index`-th argument as a nested DAG construct if possible. Returns
149  // null DagNode otherwise.
150  DagNode getArgAsNestedDag(unsigned index) const;
151 
152  // Gets the `index`-th argument as a DAG leaf.
153  DagLeaf getArgAsLeaf(unsigned index) const;
154 
155  // Returns the specified name of the `index`-th argument.
156  StringRef getArgName(unsigned index) const;
157 
158  // Returns true if this DAG construct means to replace with an existing SSA
159  // value.
160  bool isReplaceWithValue() const;
161 
162  // Returns true if this DAG node is wrapping native code call.
163  bool isNativeCodeCall() const;
164 
165  // Returns true if this DAG node is an operation.
166  bool isOperation() const;
167 
168  // Returns the native code call template inside this DAG node.
169  // Precondition: isNativeCodeCall()
170  StringRef getNativeCodeTemplate() const;
171 
172  void print(raw_ostream &os) const;
173 
174 private:
175  const llvm::DagInit *node; // nullptr means null DagNode
176 };
177 
178 // A class for maintaining information for symbols bound in patterns and
179 // provides methods for resolving them according to specific use cases.
180 //
181 // Symbols can be bound to
182 //
183 // * Op arguments and op results in the source pattern and
184 // * Op results in result patterns.
185 //
186 // Symbols can be referenced in result patterns and additional constraints to
187 // the pattern.
188 //
189 // For example, in
190 //
191 // ```
192 // def : Pattern<
193 // (SrcOp:$results1 $arg0, %arg1),
194 // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
195 // ```
196 //
197 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
198 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
199 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
200 //
201 // If a symbol binds to a multi-result op and it does not have the `__N`
202 // suffix, the symbol is expanded to represent all results generated by the
203 // multi-result op. If the symbol has a `__N` suffix, then it will expand to
204 // only the N-th *static* result as declared in ODS, and that can still
205 // corresponds to multiple *dynamic* values if the N-th *static* result is
206 // variadic.
207 //
208 // This class keeps track of such symbols and resolves them into their bound
209 // values in a suitable way.
211 public:
212  explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
213 
214  // Class for information regarding a symbol.
215  class SymbolInfo {
216  public:
217  // Returns a string for defining a variable named as `name` to store the
218  // value bound by this symbol.
219  std::string getVarDecl(StringRef name) const;
220 
221  private:
222  // Allow SymbolInfoMap to access private methods.
223  friend class SymbolInfoMap;
224 
225  // What kind of entity this symbol represents:
226  // * Attr: op attribute
227  // * Operand: op operand
228  // * Result: op result
229  // * Value: a value not attached to an op (e.g., from NativeCodeCall)
230  enum class Kind : uint8_t { Attr, Operand, Result, Value };
231 
232  // Creates a SymbolInfo instance. `index` is only used for `Attr` and
233  // `Operand` so should be negative for `Result` and `Value` kind.
234  SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
235 
236  // Static methods for creating SymbolInfo.
237  static SymbolInfo getAttr(const Operator *op, int index) {
238  return SymbolInfo(op, Kind::Attr, index);
239  }
240  static SymbolInfo getOperand(const Operator *op, int index) {
241  return SymbolInfo(op, Kind::Operand, index);
242  }
243  static SymbolInfo getResult(const Operator *op) {
244  return SymbolInfo(op, Kind::Result, llvm::None);
245  }
246  static SymbolInfo getValue() {
247  return SymbolInfo(nullptr, Kind::Value, llvm::None);
248  }
249 
250  // Returns the number of static values this symbol corresponds to.
251  // A static value is an operand/result declared in ODS. Normally a symbol
252  // only represents one static value, but symbols bound to op results can
253  // represent more than one if the op is a multi-result op.
254  int getStaticValueCount() const;
255 
256  // Returns a string containing the C++ expression for referencing this
257  // symbol as a value (if this symbol represents one static value) or a value
258  // range (if this symbol represents multiple static values). `name` is the
259  // name of the C++ variable that this symbol bounds to. `index` should only
260  // be used for indexing results. `fmt` is used to format each value.
261  // `separator` is used to separate values if this is a value range.
262  std::string getValueAndRangeUse(StringRef name, int index, const char *fmt,
263  const char *separator) const;
264 
265  // Returns a string containing the C++ expression for referencing this
266  // symbol as a value range regardless of how many static values this symbol
267  // represents. `name` is the name of the C++ variable that this symbol
268  // bounds to. `index` should only be used for indexing results. `fmt` is
269  // used to format each value. `separator` is used to separate values in the
270  // range.
271  std::string getAllRangeUse(StringRef name, int index, const char *fmt,
272  const char *separator) const;
273 
274  const Operator *op; // The op where the bound entity belongs
275  Kind kind; // The kind of the bound entity
276  // The argument index (for `Attr` and `Operand` only)
277  Optional<int> argIndex;
278  };
279 
280  using BaseT = llvm::StringMap<SymbolInfo>;
281 
282  // Iterators for accessing all symbols.
283  using iterator = BaseT::iterator;
284  iterator begin() { return symbolInfoMap.begin(); }
285  iterator end() { return symbolInfoMap.end(); }
286 
287  // Const iterators for accessing all symbols.
288  using const_iterator = BaseT::const_iterator;
289  const_iterator begin() const { return symbolInfoMap.begin(); }
290  const_iterator end() const { return symbolInfoMap.end(); }
291 
292  // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
293  // Returns false if `symbol` is already bound.
294  bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
295 
296  // Binds the given `symbol` to the results the given `op`. Returns false if
297  // `symbol` is already bound.
298  bool bindOpResult(StringRef symbol, const Operator &op);
299 
300  // Registers the given `symbol` as bound to a value. Returns false if `symbol`
301  // is already bound.
302  bool bindValue(StringRef symbol);
303 
304  // Returns true if the given `symbol` is bound.
305  bool contains(StringRef symbol) const;
306 
307  // Returns an iterator to the information of the given symbol named as `key`.
308  const_iterator find(StringRef key) const;
309 
310  // Returns the number of static values of the given `symbol` corresponds to.
311  // A static value is a operand/result declared in ODS. Normally a symbol only
312  // represents one static value, but symbols bound to op results can represent
313  // more than one if the op is a multi-result op.
314  int getStaticValueCount(StringRef symbol) const;
315 
316  // Returns a string containing the C++ expression for referencing this
317  // symbol as a value (if this symbol represents one static value) or a value
318  // range (if this symbol represents multiple static values). `fmt` is used to
319  // format each value. `separator` is used to separate values if `symbol`
320  // represents a value range.
321  std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}",
322  const char *separator = ", ") const;
323 
324  // Returns a string containing the C++ expression for referencing this
325  // symbol as a value range regardless of how many static values this symbol
326  // represents. `fmt` is used to format each value. `separator` is used to
327  // separate values in the range.
328  std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
329  const char *separator = ", ") const;
330 
331  // Splits the given `symbol` into a value pack name and an index. Returns the
332  // value pack name and writes the index to `index` on success. Returns
333  // `symbol` itself if it does not contain an index.
334  //
335  // We can use `name__N` to access the `N`-th value in the value pack bound to
336  // `name`. `name` is typically the results of an multi-result op.
337  static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
338 
339 private:
340  llvm::StringMap<SymbolInfo> symbolInfoMap;
341 
342  // Pattern instantiation location. This is intended to be used as parameter
343  // to PrintFatalError() to report errors.
345 };
346 
347 // Wrapper class providing helper methods for accessing MLIR Pattern defined
348 // in TableGen. This class should closely reflect what is defined as class
349 // `Pattern` in TableGen. This class contains maps so it is not intended to be
350 // used as values.
351 class Pattern {
352 public:
353  explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper);
354 
355  // Returns the source pattern to match.
356  DagNode getSourcePattern() const;
357 
358  // Returns the number of result patterns generated by applying this rewrite
359  // rule.
360  int getNumResultPatterns() const;
361 
362  // Returns the DAG tree root node of the `index`-th result pattern.
363  DagNode getResultPattern(unsigned index) const;
364 
365  // Collects all symbols bound in the source pattern into `infoMap`.
366  void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
367 
368  // Collects all symbols bound in result patterns into `infoMap`.
369  void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
370 
371  // Returns the op that the root node of the source pattern matches.
372  const Operator &getSourceRootOp();
373 
374  // Returns the operator wrapper object corresponding to the given `node`'s DAG
375  // operator.
376  Operator &getDialectOp(DagNode node);
377 
378  // Returns the constraints.
379  std::vector<AppliedConstraint> getConstraints() const;
380 
381  // Returns the benefit score of the pattern.
382  int getBenefit() const;
383 
384  using IdentifierLine = std::pair<StringRef, unsigned>;
385 
386  // Returns the file location of the pattern (buffer identifier + line number
387  // pair).
388  std::vector<IdentifierLine> getLocation() const;
389 
390 private:
391  // Recursively collects all bound symbols inside the DAG tree rooted
392  // at `tree` and updates the given `infoMap`.
393  void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
394  bool isSrcPattern);
395 
396  // The TableGen definition of this pattern.
397  const llvm::Record &def;
398 
399  // All operators.
400  // TODO(antiagainst): we need a proper context manager, like MLIRContext,
401  // for managing the lifetime of shared entities.
402  RecordOperatorMap *recordOpMap;
403 };
404 
405 } // end namespace tblgen
406 } // end namespace mlir
407 
408 #endif // MLIR_TABLEGEN_PATTERN_H_
Definition: InferTypeOpInterface.cpp:20
Definition: PassRegistry.cpp:413
Definition: LLVM.h:48
Definition: Attribute.h:126
Definition: LLVM.h:40
iterator begin()
Definition: Pattern.h:284
Definition: Operator.h:41
Definition: Pattern.h:122
Definition: LLVM.h:37
Definition: Constraint.h:30
llvm::StringMap< SymbolInfo > BaseT
Definition: Pattern.h:280
BaseT::const_iterator const_iterator
Definition: Pattern.h:288
DenseMap< const llvm::Record *, std::unique_ptr< Operator > > RecordOperatorMap
Definition: Pattern.h:40
Definition: Pattern.h:210
Definition: Value.h:38
DagLeaf(const llvm::Init *def)
Definition: Pattern.h:56
std::pair< StringRef, unsigned > IdentifierLine
Definition: Pattern.h:384
Definition: Pattern.h:351
const_iterator begin() const
Definition: Pattern.h:289
BaseT::iterator iterator
Definition: Pattern.h:283
void print(OpAsmPrinter &p, AffineIfOp op)
Definition: AffineOps.cpp:1671
Definition: StandardTypes.h:63
Definition: Pattern.h:54
iterator end()
Definition: Pattern.h:285
Definition: Attribute.h:108
SymbolInfoMap(ArrayRef< llvm::SMLoc > loc)
Definition: Pattern.h:212
DagNode(const llvm::DagInit *node)
Definition: Pattern.h:124
const_iterator end() const
Definition: Pattern.h:290