My Project
OpImplementation.h
Go to the documentation of this file.
1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains classes used by the implementation details of Op types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_OPIMPLEMENTATION_H
14 #define MLIR_IR_OPIMPLEMENTATION_H
15 
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/SMLoc.h"
20 #include "llvm/Support/raw_ostream.h"
21 
22 namespace mlir {
23 
24 class Builder;
25 
26 //===----------------------------------------------------------------------===//
27 // OpAsmPrinter
28 //===----------------------------------------------------------------------===//
29 
32 class OpAsmPrinter {
33 public:
35  virtual ~OpAsmPrinter();
36  virtual raw_ostream &getStream() const = 0;
37 
39  virtual void printOperand(Value value) = 0;
40 
/// Print a comma separated list of operands taken from any container that
/// exposes begin()/end(); delegates to the iterator-pair overload.
template <typename ContainerType>
void printOperands(const ContainerType &container) {
  auto first = container.begin();
  auto last = container.end();
  printOperands(first, last);
}
46 
48  template <typename IteratorType>
49  void printOperands(IteratorType it, IteratorType end) {
50  if (it == end)
51  return;
52  printOperand(*it);
53  for (++it; it != end; ++it) {
54  getStream() << ", ";
55  printOperand(*it);
56  }
57  }
58  virtual void printType(Type type) = 0;
59  virtual void printAttribute(Attribute attr) = 0;
60 
63  virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0;
64 
70  ArrayRef<StringRef> elidedAttrs = {}) = 0;
71 
74  virtual void
76  ArrayRef<StringRef> elidedAttrs = {}) = 0;
77 
79  virtual void printGenericOp(Operation *op) = 0;
80 
82  virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
83  bool printBlockTerminators = true) = 0;
84 
89  virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
90 
95  virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
96  ValueRange operands) = 0;
97 
100  if (types.empty())
101  return;
102  auto &os = getStream() << " -> ";
103  bool wrapped = types.size() != 1 || types[0].isa<FunctionType>();
104  if (wrapped)
105  os << '(';
106  interleaveComma(types, *this);
107  if (wrapped)
108  os << ')';
109  }
110 
113  auto &os = getStream();
114  os << "(";
115  interleaveComma(op->getNonSuccessorOperands(), os, [&](Value operand) {
116  if (operand)
117  printType(operand->getType());
118  else
119  os << "<<NULL>";
120  });
121  os << ") -> ";
122  if (op->getNumResults() == 1 &&
123  !op->getResult(0)->getType().isa<FunctionType>()) {
124  printType(op->getResult(0)->getType());
125  } else {
126  os << '(';
127  interleaveComma(op->getResultTypes(), os);
128  os << ')';
129  }
130  }
131 
136  virtual void printSymbolName(StringRef symbolRef) = 0;
137 
138 private:
139  OpAsmPrinter(const OpAsmPrinter &) = delete;
140  void operator=(const OpAsmPrinter &) = delete;
141 };
142 
143 // Make the implementations convenient to use.
145  p.printOperand(value);
146  return p;
147 }
148 
149 template <typename T,
150  typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
151  !std::is_convertible<T &, Value &>::value,
152  T>::type * = nullptr>
153 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
154  p.printOperands(values);
155  return p;
156 }
157 
159  p.printType(type);
160  return p;
161 }
162 
164  p.printAttribute(attr);
165  return p;
166 }
167 
168 // Support printing anything that isn't convertible to one of the above types,
169 // even if it isn't exactly one of them. For example, we want to print
170 // FunctionType with the Type version above, not have it match this.
171 template <typename T, typename std::enable_if<
172  !std::is_convertible<T &, Value &>::value &&
173  !std::is_convertible<T &, Type &>::value &&
174  !std::is_convertible<T &, Attribute &>::value &&
175  !std::is_convertible<T &, ValueRange>::value &&
176  !llvm::is_one_of<T, bool>::value,
177  T>::type * = nullptr>
178 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
179  p.getStream() << other;
180  return p;
181 }
182 
183 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
184  return p << (value ? StringRef("true") : "false");
185 }
186 
187 template <typename IteratorT>
188 inline OpAsmPrinter &
191  interleaveComma(types, p);
192  return p;
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // OpAsmParser
197 //===----------------------------------------------------------------------===//
198 
214 class OpAsmParser {
215 public:
216  virtual ~OpAsmParser();
217 
219  virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
220  const Twine &message = {}) = 0;
221 
224  virtual Builder &getBuilder() const = 0;
225 
228  virtual llvm::SMLoc getCurrentLocation() = 0;
229  ParseResult getCurrentLocation(llvm::SMLoc *loc) {
230  *loc = getCurrentLocation();
231  return success();
232  }
233 
235  virtual llvm::SMLoc getNameLoc() const = 0;
236 
237  // These methods emit an error and return failure or success. This allows
238  // these to be chained together into a linear sequence of || expressions in
239  // many cases.
240 
246  virtual Operation *parseGenericOperation(Block *insertBlock,
247  Block::iterator insertPt) = 0;
248 
249  //===--------------------------------------------------------------------===//
250  // Token Parsing
251  //===--------------------------------------------------------------------===//
252 
254  virtual ParseResult parseArrow() = 0;
255 
257  virtual ParseResult parseOptionalArrow() = 0;
258 
260  virtual ParseResult parseColon() = 0;
261 
263  virtual ParseResult parseOptionalColon() = 0;
264 
266  virtual ParseResult parseComma() = 0;
267 
269  virtual ParseResult parseOptionalComma() = 0;
270 
272  virtual ParseResult parseEqual() = 0;
273 
275  virtual ParseResult parseLess() = 0;
276 
278  virtual ParseResult parseGreater() = 0;
279 
281  ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
282  auto loc = getCurrentLocation();
283  if (parseOptionalKeyword(keyword))
284  return emitError(loc, "expected '") << keyword << "'" << msg;
285  return success();
286  }
287 
289  ParseResult parseKeyword(StringRef *keyword) {
290  auto loc = getCurrentLocation();
291  if (parseOptionalKeyword(keyword))
292  return emitError(loc, "expected valid keyword");
293  return success();
294  }
295 
297  virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
298 
300  virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
301 
303  virtual ParseResult parseLParen() = 0;
304 
306  virtual ParseResult parseOptionalLParen() = 0;
307 
309  virtual ParseResult parseRParen() = 0;
310 
312  virtual ParseResult parseOptionalRParen() = 0;
313 
315  virtual ParseResult parseLSquare() = 0;
316 
318  virtual ParseResult parseOptionalLSquare() = 0;
319 
321  virtual ParseResult parseRSquare() = 0;
322 
324  virtual ParseResult parseOptionalRSquare() = 0;
325 
327  virtual ParseResult parseOptionalEllipsis() = 0;
328 
329  //===--------------------------------------------------------------------===//
330  // Attribute Parsing
331  //===--------------------------------------------------------------------===//
332 
335  ParseResult parseAttribute(Attribute &result, StringRef attrName,
337  return parseAttribute(result, Type(), attrName, attrs);
338  }
339 
341  template <typename AttrType>
342  ParseResult parseAttribute(AttrType &result, StringRef attrName,
344  return parseAttribute(result, Type(), attrName, attrs);
345  }
346 
350  virtual ParseResult
351  parseAttribute(Attribute &result, Type type, StringRef attrName,
353 
355  template <typename AttrType>
356  ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
358  llvm::SMLoc loc = getCurrentLocation();
359 
360  // Parse any kind of attribute.
361  Attribute attr;
362  if (parseAttribute(attr, type, attrName, attrs))
363  return failure();
364 
365  // Check for the right kind of attribute.
366  result = attr.dyn_cast<AttrType>();
367  if (!result)
368  return emitError(loc, "invalid kind of attribute specified");
369 
370  return success();
371  }
372 
374  virtual ParseResult
375  parseOptionalAttrDict(SmallVectorImpl<NamedAttribute> &result) = 0;
376 
379  virtual ParseResult
380  parseOptionalAttrDictWithKeyword(SmallVectorImpl<NamedAttribute> &result) = 0;
381 
382  //===--------------------------------------------------------------------===//
383  // Identifier Parsing
384  //===--------------------------------------------------------------------===//
385 
388  ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
390  if (failed(parseOptionalSymbolName(result, attrName, attrs)))
391  return emitError(getCurrentLocation())
392  << "expected valid '@'-identifier for symbol name";
393  return success();
394  }
395 
398  virtual ParseResult
399  parseOptionalSymbolName(StringAttr &result, StringRef attrName,
401 
402  //===--------------------------------------------------------------------===//
403  // Operand Parsing
404  //===--------------------------------------------------------------------===//
405 
407  struct OperandType {
408  llvm::SMLoc location; // Location of the token.
409  StringRef name; // Value name, e.g. %42 or %abc
410  unsigned number; // Number, e.g. 12 for an operand like %xyz#12
411  };
412 
414  virtual ParseResult parseOperand(OperandType &result) = 0;
415 
/// The delimiter kind requested around a list parsed by parseOperandList,
/// parseTrailingOperandList or parseRegionArgumentList. Enumerator order is
/// kept as-is. The per-enumerator notes below are inferred from the names;
/// confirm them against the parser implementation.
enum class Delimiter {
  /// No delimiter around the list.
  None,
  /// List presumably enclosed in parentheses `(...)`.
  Paren,
  /// List presumably enclosed in square brackets `[...]`.
  Square,
  /// Presumably like Paren, but the parentheses may be absent entirely.
  OptionalParen,
  /// Presumably like Square, but the brackets may be absent entirely.
  OptionalSquare,
};
430 
433  virtual ParseResult
434  parseOperandList(SmallVectorImpl<OperandType> &result,
435  int requiredOperandCount = -1,
436  Delimiter delimiter = Delimiter::None) = 0;
438  Delimiter delimiter) {
439  return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
440  }
441 
445  virtual ParseResult
446  parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
447  int requiredOperandCount = -1,
448  Delimiter delimiter = Delimiter::None) = 0;
450  Delimiter delimiter) {
451  return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
452  delimiter);
453  }
454 
456  virtual ParseResult resolveOperand(const OperandType &operand, Type type,
457  SmallVectorImpl<Value> &result) = 0;
458 
463  SmallVectorImpl<Value> &result) {
464  for (auto elt : operands)
465  if (resolveOperand(elt, type, result))
466  return failure();
467  return success();
468  }
469 
474  ArrayRef<Type> types, llvm::SMLoc loc,
475  SmallVectorImpl<Value> &result) {
476  if (operands.size() != types.size())
477  return emitError(loc)
478  << operands.size() << " operands present, but expected "
479  << types.size();
480 
481  for (unsigned i = 0, e = operands.size(); i != e; ++i)
482  if (resolveOperand(operands[i], types[i], result))
483  return failure();
484  return success();
485  }
486 
490  virtual ParseResult
491  parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
492  StringRef attrName,
494 
495  //===--------------------------------------------------------------------===//
496  // Region Parsing
497  //===--------------------------------------------------------------------===//
498 
506  virtual ParseResult parseRegion(Region &region,
507  ArrayRef<OperandType> arguments,
508  ArrayRef<Type> argTypes,
509  bool enableNameShadowing = false) = 0;
510 
512  virtual ParseResult parseOptionalRegion(Region &region,
513  ArrayRef<OperandType> arguments,
514  ArrayRef<Type> argTypes,
515  bool enableNameShadowing = false) = 0;
516 
519  virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
520 
525  virtual ParseResult
526  parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
527  int requiredOperandCount = -1,
528  Delimiter delimiter = Delimiter::None) = 0;
529  virtual ParseResult
531  Delimiter delimiter) {
532  return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
533  delimiter);
534  }
535 
537  virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
538 
539  //===--------------------------------------------------------------------===//
540  // Successor Parsing
541  //===--------------------------------------------------------------------===//
542 
544  virtual ParseResult
545  parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
546 
547  //===--------------------------------------------------------------------===//
548  // Type Parsing
549  //===--------------------------------------------------------------------===//
550 
552  virtual ParseResult parseType(Type &result) = 0;
553 
555  virtual ParseResult
556  parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
557 
559  virtual ParseResult parseColonType(Type &result) = 0;
560 
562  template <typename TypeType> ParseResult parseColonType(TypeType &result) {
563  llvm::SMLoc loc = getCurrentLocation();
564 
565  // Parse any kind of type.
566  Type type;
567  if (parseColonType(type))
568  return failure();
569 
570  // Check for the right kind of attribute.
571  result = type.dyn_cast<TypeType>();
572  if (!result)
573  return emitError(loc, "invalid kind of type specified");
574 
575  return success();
576  }
577 
579  virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
580 
583  virtual ParseResult
584  parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
585 
587  ParseResult parseKeywordType(const char *keyword, Type &result) {
588  return failure(parseKeyword(keyword) || parseType(result));
589  }
590 
595  result.push_back(type);
596  return success();
597  }
598 
603  SmallVectorImpl<Type> &result) {
604  result.append(types.begin(), types.end());
605  return success();
606  }
607 
608 private:
611  ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
612  bool isOperandList,
613  int requiredOperandCount,
614  Delimiter delimiter);
615 };
616 
617 //===--------------------------------------------------------------------===//
618 // Dialect OpAsm interface.
619 //===--------------------------------------------------------------------===//
620 
624 
626  : public DialectInterface::Base<OpAsmDialectInterface> {
627 public:
628  OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
629 
637  SmallVectorImpl<std::pair<unsigned, StringRef>> &aliases) const {}
640  virtual void getAttributeAliases(
641  SmallVectorImpl<std::pair<Attribute, StringRef>> &aliases) const {}
643  virtual void
644  getTypeAliases(SmallVectorImpl<std::pair<Type, StringRef>> &aliases) const {}
645 
648  virtual void getAsmResultNames(Operation *op,
649  OpAsmSetValueNameFn setNameFn) const {}
650 
653  virtual void getAsmBlockArgumentNames(Block *block,
654  OpAsmSetValueNameFn setNameFn) const {}
655 };
656 
657 //===--------------------------------------------------------------------===//
658 // Operation OpAsm interface.
659 //===--------------------------------------------------------------------===//
660 
662 #include "mlir/IR/OpAsmInterface.h.inc"
663 
664 } // end namespace mlir
665 
666 #endif
This is the representation of an operand reference.
Definition: OpImplementation.h:407
virtual ParseResult parseRegionArgumentList(SmallVectorImpl< OperandType > &result, Delimiter delimiter)
Definition: OpImplementation.h:530
Definition: InferTypeOpInterface.cpp:20
ParseResult resolveOperands(ArrayRef< OperandType > operands, ArrayRef< Type > types, llvm::SMLoc loc, SmallVectorImpl< Value > &result)
Definition: OpImplementation.h:473
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
Definition: Region.h:23
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Definition: OpImplementation.h:462
ParseResult parseAttribute(AttrType &result, StringRef attrName, SmallVectorImpl< NamedAttribute > &attrs)
Parse an attribute of a specific kind and type.
Definition: OpImplementation.h:342
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Definition: OpImplementation.h:602
ParseResult parseTrailingOperandList(SmallVectorImpl< OperandType > &result, Delimiter delimiter)
Definition: OpImplementation.h:449
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, SmallVectorImpl< NamedAttribute > &attrs)
Parse an attribute of a specific kind and type.
Definition: OpImplementation.h:356
Definition: Operation.h:27
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: OpImplementation.h:112
Definition: DialectInterface.h:27
Definition: Attributes.h:139
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
Definition: OpImplementation.h:281
Definition: Diagnostics.h:320
OpAsmPrinter()
Definition: OpImplementation.h:34
Block represents an ordered list of Operations.
Definition: Block.h:21
Definition: OpImplementation.h:625
Definition: LLVM.h:49
virtual void shadowRegionArgs(Region &region, ValueRange namesToUse)=0
virtual void getAttributeAliases(SmallVectorImpl< std::pair< Attribute, StringRef >> &aliases) const
Definition: OpImplementation.h:640
bool failed(LogicalResult result)
Definition: LogicalResult.h:45
Function types map from a list of inputs to a list of results.
Definition: Types.h:190
Definition: OpImplementation.h:214
void printOperands(IteratorType it, IteratorType end)
Print a comma separated list of operands.
Definition: OpImplementation.h:49
virtual ~OpAsmPrinter()
Definition: AsmPrinter.cpp:50
ParseResult parseSymbolName(StringAttr &result, StringRef attrName, SmallVectorImpl< NamedAttribute > &attrs)
Definition: OpImplementation.h:388
unsigned number
Definition: OpImplementation.h:410
virtual void printSymbolName(StringRef symbolRef)=0
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
Definition: OpImplementation.h:43
Definition: LLVM.h:34
operand_range getNonSuccessorOperands()
Return the operands of this operation that are not successor arguments.
Definition: Operation.cpp:559
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.cpp:548
LogicalResult success(bool isSuccess=true)
Definition: LogicalResult.h:25
This class implements iteration on the types of a given range of values.
Definition: OperationSupport.h:540
LogicalResult failure(bool isFailure=true)
Definition: LogicalResult.h:32
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true)=0
Prints a region.
Definition: LLVM.h:37
OpListType::iterator iterator
Definition: Block.h:107
Type getType() const
Return the type of this value.
Definition: Value.cpp:34
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
U dyn_cast() const
Definition: Types.h:258
Definition: Attributes.h:53
auto map(Fn fun, IterType begin, IterType end) -> SmallVector< typename std::result_of< Fn(decltype(*begin))>::type, 8 >
Map with iterators.
Definition: Functional.h:28
Definition: Dialect.h:39
void interleaveComma(const Container &c, raw_ostream &os, UnaryFunctor each_fn)
Definition: STLExtras.h:81
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Definition: OpImplementation.h:594
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:246
Definition: OpImplementation.h:32
Definition: Attributes.h:428
virtual raw_ostream & getStream() const =0
Type parseType(llvm::StringRef typeStr, MLIRContext *context)
Definition: Types.h:84
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
Definition: OpImplementation.h:587
ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, Delimiter delimiter)
Definition: OpImplementation.h:437
Definition: Value.h:38
Delimiter
Definition: OpImplementation.h:418
Definition: Attributes.h:175
OpAsmDialectInterface(Dialect *dialect)
Definition: OpImplementation.h:628
ParseResult parseAttribute(Attribute &result, StringRef attrName, SmallVectorImpl< NamedAttribute > &attrs)
Definition: OpImplementation.h:335
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Definition: Diagnostics.cpp:301
Definition: Builders.h:47
Definition: LLVM.h:50
virtual void printAttribute(Attribute attr)=0
virtual void getAsmBlockArgumentNames(Block *block, OpAsmSetValueNameFn setNameFn) const
Definition: OpImplementation.h:653
U dyn_cast() const
Definition: Attributes.h:1347
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
ParseResult parseColonType(TypeType &result)
Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
Definition: OpImplementation.h:562
llvm::SMLoc location
Definition: OpImplementation.h:408
virtual void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const
Definition: OpImplementation.h:648
Definition: StandardTypes.h:63
void printOptionalArrowTypeList(ArrayRef< Type > types)
Print an optional arrow followed by a type list.
Definition: OpImplementation.h:99
virtual void printType(Type type)=0
raw_ostream & operator<<(raw_ostream &os, SubViewOp::Range &range)
Definition: Ops.cpp:2759
virtual void printGenericOp(Operation *op)=0
Print the entire operation with the default generic assembly form.
ParseResult parseKeyword(StringRef *keyword)
Parse a keyword into 'keyword'.
Definition: OpImplementation.h:289
bool isa() const
Definition: Types.h:254
virtual void getAttributeKindAliases(SmallVectorImpl< std::pair< unsigned, StringRef >> &aliases) const
Definition: OpImplementation.h:636
Definition: OpDefinition.h:36
virtual void printSuccessorAndUseList(Operation *term, unsigned index)=0
Definition: OperationSupport.h:640
StringRef name
Definition: OpImplementation.h:409
ParseResult getCurrentLocation(llvm::SMLoc *loc)
Definition: OpImplementation.h:229
result_type_range getResultTypes()
Definition: Operation.h:264
virtual void getTypeAliases(SmallVectorImpl< std::pair< Type, StringRef >> &aliases) const
Hook for defining Type aliases.
Definition: OpImplementation.h:644