My Project
FunctionSupport.h
Go to the documentation of this file.
1 //===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines support types for Operations that represent function-like
10 // constructs to use.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_FUNCTIONSUPPORT_H
15 #define MLIR_IR_FUNCTIONSUPPORT_H
16 
17 #include "mlir/IR/OpDefinition.h"
18 #include "llvm/ADT/SmallString.h"
19 
20 namespace mlir {
21 
22 namespace impl {
23 
25 inline StringRef getTypeAttrName() { return "type"; }
26 
28 inline StringRef getArgAttrName(unsigned arg, SmallVectorImpl<char> &out) {
29  out.clear();
30  return ("arg" + Twine(arg)).toStringRef(out);
31 }
32 
34 inline StringRef getResultAttrName(unsigned arg, SmallVectorImpl<char> &out) {
35  out.clear();
36  return ("result" + Twine(arg)).toStringRef(out);
37 }
38 
42 inline DictionaryAttr getArgAttrDict(Operation *op, unsigned index) {
43  SmallString<8> nameOut;
44  return op->getAttrOfType<DictionaryAttr>(getArgAttrName(index, nameOut));
45 }
46 
50 inline DictionaryAttr getResultAttrDict(Operation *op, unsigned index) {
51  SmallString<8> nameOut;
52  return op->getAttrOfType<DictionaryAttr>(getResultAttrName(index, nameOut));
53 }
54 
56 inline ArrayRef<NamedAttribute> getArgAttrs(Operation *op, unsigned index) {
57  auto argDict = getArgAttrDict(op, index);
58  return argDict ? argDict.getValue() : llvm::None;
59 }
60 
62 inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
63  auto resultDict = getResultAttrDict(op, index);
64  return resultDict ? resultDict.getValue() : llvm::None;
65 }
66 
67 } // namespace impl
68 
69 namespace OpTrait {
70 
99 template <typename ConcreteType>
100 class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
101 public:
103  static LogicalResult verifyTrait(Operation *op);
104 
105  //===--------------------------------------------------------------------===//
106  // Body Handling
107  //===--------------------------------------------------------------------===//
108 
110  bool isExternal() { return empty(); }
111 
112  Region &getBody() { return this->getOperation()->getRegion(0); }
113 
115  void eraseBody() {
116  getBody().dropAllReferences();
117  getBody().getBlocks().clear();
118  }
119 
122  BlockListType &getBlocks() { return getBody().getBlocks(); }
123 
124  // Iteration over the block in the function.
125  using iterator = BlockListType::iterator;
126  using reverse_iterator = BlockListType::reverse_iterator;
127 
128  iterator begin() { return getBody().begin(); }
129  iterator end() { return getBody().end(); }
130  reverse_iterator rbegin() { return getBody().rbegin(); }
131  reverse_iterator rend() { return getBody().rend(); }
132 
133  bool empty() { return getBody().empty(); }
134  void push_back(Block *block) { getBody().push_back(block); }
135  void push_front(Block *block) { getBody().push_front(block); }
136 
137  Block &back() { return getBody().back(); }
138  Block &front() { return getBody().front(); }
139 
143  LogicalResult verifyBody();
144 
145  //===--------------------------------------------------------------------===//
146  // Type Attribute Handling
147  //===--------------------------------------------------------------------===//
148 
151 
153  return this->getOperation()->template getAttrOfType<TypeAttr>(
154  getTypeAttrName());
155  }
156 
158  auto typeAttr = getTypeAttr();
159  if (!typeAttr)
160  return false;
161  return typeAttr.getValue() != Type{};
162  }
163 
164  //===--------------------------------------------------------------------===//
165  // Argument Handling
166  //===--------------------------------------------------------------------===//
167 
168  unsigned getNumArguments() {
169  return static_cast<ConcreteType *>(this)->getNumFuncArguments();
170  }
171 
172  unsigned getNumResults() {
173  return static_cast<ConcreteType *>(this)->getNumFuncResults();
174  }
175 
177  BlockArgument getArgument(unsigned idx) {
178  return getBlocks().front().getArgument(idx);
179  }
180 
181  // Supports non-const operand iteration.
183  args_iterator args_begin() { return front().args_begin(); }
184  args_iterator args_end() { return front().args_end(); }
186  return {args_begin(), args_end()};
187  }
188 
189  //===--------------------------------------------------------------------===//
190  // Argument Attributes
191  //===--------------------------------------------------------------------===//
192 
199 
202  return ::mlir::impl::getArgAttrs(this->getOperation(), index);
203  }
204 
207  for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
208  result.emplace_back(getArgAttrDict(i));
209  }
210 
213  Attribute getArgAttr(unsigned index, Identifier name) {
214  auto argDict = getArgAttrDict(index);
215  return argDict ? argDict.get(name) : nullptr;
216  }
217  Attribute getArgAttr(unsigned index, StringRef name) {
218  auto argDict = getArgAttrDict(index);
219  return argDict ? argDict.get(name) : nullptr;
220  }
221 
222  template <typename AttrClass>
223  AttrClass getArgAttrOfType(unsigned index, Identifier name) {
224  return getArgAttr(index, name).template dyn_cast_or_null<AttrClass>();
225  }
226  template <typename AttrClass>
227  AttrClass getArgAttrOfType(unsigned index, StringRef name) {
228  return getArgAttr(index, name).template dyn_cast_or_null<AttrClass>();
229  }
230 
232  void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
233  void setArgAttrs(unsigned index, NamedAttributeList attributes);
235  assert(attributes.size() == getNumArguments());
236  for (unsigned i = 0, e = attributes.size(); i != e; ++i)
237  setArgAttrs(i, attributes[i]);
238  }
239 
242  void setArgAttr(unsigned index, Identifier name, Attribute value);
243  void setArgAttr(unsigned index, StringRef name, Attribute value) {
244  setArgAttr(index, Identifier::get(name, this->getOperation()->getContext()),
245  value);
246  }
247 
249  NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
250  Identifier name);
251 
252  //===--------------------------------------------------------------------===//
253  // Result Attributes
254  //===--------------------------------------------------------------------===//
255 
262 
265  return ::mlir::impl::getResultAttrs(this->getOperation(), index);
266  }
267 
270  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
271  result.emplace_back(getResultAttrDict(i));
272  }
273 
276  Attribute getResultAttr(unsigned index, Identifier name) {
277  auto argDict = getResultAttrDict(index);
278  return argDict ? argDict.get(name) : nullptr;
279  }
280  Attribute getResultAttr(unsigned index, StringRef name) {
281  auto argDict = getResultAttrDict(index);
282  return argDict ? argDict.get(name) : nullptr;
283  }
284 
285  template <typename AttrClass>
286  AttrClass getResultAttrOfType(unsigned index, Identifier name) {
287  return getResultAttr(index, name).template dyn_cast_or_null<AttrClass>();
288  }
289  template <typename AttrClass>
290  AttrClass getResultAttrOfType(unsigned index, StringRef name) {
291  return getResultAttr(index, name).template dyn_cast_or_null<AttrClass>();
292  }
293 
295  void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes);
296  void setResultAttrs(unsigned index, NamedAttributeList attributes);
298  assert(attributes.size() == getNumResults());
299  for (unsigned i = 0, e = attributes.size(); i != e; ++i)
300  setResultAttrs(i, attributes[i]);
301  }
302 
305  void setResultAttr(unsigned index, Identifier name, Attribute value);
306  void setResultAttr(unsigned index, StringRef name, Attribute value) {
307  setResultAttr(index,
308  Identifier::get(name, this->getOperation()->getContext()),
309  value);
310  }
311 
313  NamedAttributeList::RemoveResult removeResultAttr(unsigned index,
314  Identifier name);
315 
316 protected:
319  static StringRef getArgAttrName(unsigned index, SmallVectorImpl<char> &out) {
321  }
322 
326  DictionaryAttr getArgAttrDict(unsigned index) {
327  assert(index < getNumArguments() && "invalid argument number");
328  return ::mlir::impl::getArgAttrDict(this->getOperation(), index);
329  }
330 
333  static StringRef getResultAttrName(unsigned index,
334  SmallVectorImpl<char> &out) {
336  }
337 
342  assert(index < getNumResults() && "invalid result number");
343  return ::mlir::impl::getResultAttrDict(this->getOperation(), index);
344  }
345 
349 };
350 
353 template <typename ConcreteType>
355  auto funcOp = cast<ConcreteType>(this->getOperation());
356 
357  if (funcOp.isExternal())
358  return success();
359 
360  unsigned numArguments = funcOp.getNumArguments();
361  if (funcOp.front().getNumArguments() != numArguments)
362  return funcOp.emitOpError("entry block must have ")
363  << numArguments << " arguments to match function signature";
364 
365  return success();
366 }
367 
368 template <typename ConcreteType>
370  MLIRContext *ctx = op->getContext();
371  auto funcOp = cast<ConcreteType>(op);
372 
373  if (!funcOp.isTypeAttrValid())
374  return funcOp.emitOpError("requires a type attribute '")
375  << getTypeAttrName() << '\'';
376 
377  if (failed(funcOp.verifyType()))
378  return failure();
379 
380  for (unsigned i = 0, e = funcOp.getNumArguments(); i != e; ++i) {
381  // Verify that all of the argument attributes are dialect attributes, i.e.
382  // that they contain a dialect prefix in their name. Call the dialect, if
383  // registered, to verify the attributes themselves.
384  for (auto attr : funcOp.getArgAttrs(i)) {
385  if (!attr.first.strref().contains('.'))
386  return funcOp.emitOpError("arguments may only have dialect attributes");
387  auto dialectNamePair = attr.first.strref().split('.');
388  if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
389  if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
390  /*argIndex=*/i, attr)))
391  return failure();
392  }
393  }
394  }
395 
396  for (unsigned i = 0, e = funcOp.getNumResults(); i != e; ++i) {
397  // Verify that all of the result attributes are dialect attributes, i.e.
398  // that they contain a dialect prefix in their name. Call the dialect, if
399  // registered, to verify the attributes themselves.
400  for (auto attr : funcOp.getResultAttrs(i)) {
401  if (!attr.first.strref().contains('.'))
402  return funcOp.emitOpError("results may only have dialect attributes");
403  auto dialectNamePair = attr.first.strref().split('.');
404  if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
405  if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
406  /*resultIndex=*/i,
407  attr)))
408  return failure();
409  }
410  }
411  }
412 
413  // Check that the op has exactly one region for the body.
414  if (op->getNumRegions() != 1)
415  return funcOp.emitOpError("expects one region");
416 
417  return funcOp.verifyBody();
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // Function Argument Attribute.
422 //===----------------------------------------------------------------------===//
423 
425 template <typename ConcreteType>
427  unsigned index, ArrayRef<NamedAttribute> attributes) {
428  assert(index < getNumArguments() && "invalid argument number");
429  SmallString<8> nameOut;
430  getArgAttrName(index, nameOut);
431 
432  if (attributes.empty())
433  return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
434  Operation *op = this->getOperation();
435  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
436 }
437 
438 template <typename ConcreteType>
440  NamedAttributeList attributes) {
441  assert(index < getNumArguments() && "invalid argument number");
442  SmallString<8> nameOut;
443  if (auto newAttr = attributes.getDictionary())
444  return this->getOperation()->setAttr(getArgAttrName(index, nameOut),
445  newAttr);
446  static_cast<ConcreteType *>(this)->removeAttr(getArgAttrName(index, nameOut));
447 }
448 
451 template <typename ConcreteType>
453  Attribute value) {
454  auto curAttr = getArgAttrDict(index);
455  NamedAttributeList attrList(curAttr);
456  attrList.set(name, value);
457 
458  // If the attribute changed, then set the new arg attribute list.
459  if (curAttr != attrList.getDictionary())
460  setArgAttrs(index, attrList);
461 }
462 
464 template <typename ConcreteType>
467  // Build an attribute list and remove the attribute at 'name'.
468  NamedAttributeList attrList(getArgAttrDict(index));
469  auto result = attrList.remove(name);
470 
471  // If the attribute was removed, then update the argument dictionary.
473  setArgAttrs(index, attrList);
474  return result;
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // Function Result Attribute.
479 //===----------------------------------------------------------------------===//
480 
482 template <typename ConcreteType>
484  unsigned index, ArrayRef<NamedAttribute> attributes) {
485  assert(index < getNumResults() && "invalid result number");
486  SmallString<8> nameOut;
487  getResultAttrName(index, nameOut);
488 
489  if (attributes.empty())
490  return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
491  Operation *op = this->getOperation();
492  op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
493 }
494 
495 template <typename ConcreteType>
497  NamedAttributeList attributes) {
498  assert(index < getNumResults() && "invalid result number");
499  SmallString<8> nameOut;
500  if (auto newAttr = attributes.getDictionary())
501  return this->getOperation()->setAttr(getResultAttrName(index, nameOut),
502  newAttr);
503  static_cast<ConcreteType *>(this)->removeAttr(
504  getResultAttrName(index, nameOut));
505 }
506 
509 template <typename ConcreteType>
511  Attribute value) {
512  auto curAttr = getResultAttrDict(index);
513  NamedAttributeList attrList(curAttr);
514  attrList.set(name, value);
515 
516  // If the attribute changed, then set the new arg attribute list.
517  if (curAttr != attrList.getDictionary())
518  setResultAttrs(index, attrList);
519 }
520 
522 template <typename ConcreteType>
525  // Build an attribute list and remove the attribute at 'name'.
526  NamedAttributeList attrList(getResultAttrDict(index));
527  auto result = attrList.remove(name);
528 
529  // If the attribute was removed, then update the result dictionary.
531  setResultAttrs(index, attrList);
532  return result;
533 }
534 
535 } // end namespace OpTrait
536 
537 } // end namespace mlir
538 
539 #endif // MLIR_IR_FUNCTIONSUPPORT_H
void setAllResultAttrs(ArrayRef< NamedAttributeList > attributes)
Definition: FunctionSupport.h:297
AttrClass getAttrOfType(Identifier name)
Definition: Operation.h:289
Definition: InferTypeOpInterface.cpp:20
static StringRef getTypeAttrName()
Return the name of the attribute used for function types.
Definition: FunctionSupport.h:150
void setAttr(Identifier name, Attribute value)
Definition: Operation.h:299
StringRef getResultAttrName(unsigned arg, SmallVectorImpl< char > &out)
Return the name of the attribute used for function results.
Definition: FunctionSupport.h:34
Definition: Region.h:23
Attribute getArgAttr(unsigned index, Identifier name)
Definition: FunctionSupport.h:213
reverse_iterator rbegin()
Definition: FunctionSupport.h:130
Definition: Operation.h:27
BlockArgument getArgument(unsigned idx)
Gets argument.
Definition: FunctionSupport.h:177
Operation & back()
Definition: Block.h:119
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:360
Definition: Attributes.h:519
Block represents an ordered list of Operations.
Definition: Block.h:21
LogicalResult verifyType()
Definition: FunctionSupport.h:348
iterator_range< args_iterator > getArguments()
Definition: FunctionSupport.h:185
DictionaryAttr getResultAttrDict(unsigned index)
Definition: FunctionSupport.h:341
void getAllResultAttrs(SmallVectorImpl< NamedAttributeList > &result)
Return all result attributes of this function.
Definition: FunctionSupport.h:269
AttrClass getResultAttrOfType(unsigned index, Identifier name)
Definition: FunctionSupport.h:286
bool failed(LogicalResult result)
Definition: LogicalResult.h:45
static Identifier get(StringRef str, MLIRContext *context)
Return an identifier for the specified string.
Definition: MLIRContext.cpp:426
Definition: Identifier.h:26
static DictionaryAttr get(ArrayRef< NamedAttribute > value, MLIRContext *context)
Definition: Attributes.cpp:88
AttrClass getArgAttrOfType(unsigned index, Identifier name)
Definition: FunctionSupport.h:223
Definition: Attributes.h:269
Operation & front()
Definition: Block.h:120
Dialect * getRegisteredDialect(StringRef name)
Definition: MLIRContext.cpp:315
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.cpp:248
BlockListType & getBlocks()
Definition: FunctionSupport.h:122
args_iterator args_end()
Definition: FunctionSupport.h:184
Definition: LLVM.h:34
BlockArgListType::iterator args_iterator
Definition: Block.h:70
ArrayRef< NamedAttribute > getResultAttrs(unsigned index)
Return all of the attributes for the result at 'index'.
Definition: FunctionSupport.h:264
AttrClass getArgAttrOfType(unsigned index, StringRef name)
Definition: FunctionSupport.h:227
LogicalResult success(bool isSuccess=true)
Definition: LogicalResult.h:25
RemoveResult remove(Identifier name)
Definition: Attributes.cpp:1082
Definition: LogicalResult.h:18
LogicalResult failure(bool isFailure=true)
Definition: LogicalResult.h:32
DictionaryAttr getArgAttrDict(unsigned index)
Definition: FunctionSupport.h:326
Definition: LLVM.h:37
Definition: LLVM.h:36
static StringRef getResultAttrName(unsigned index, SmallVectorImpl< char > &out)
Definition: FunctionSupport.h:333
bool isExternal()
Returns true if this function is external, i.e. it has no body.
Definition: FunctionSupport.h:110
DictionaryAttr getDictionary() const
Definition: Attributes.h:1389
TypeAttr getTypeAttr()
Definition: FunctionSupport.h:152
iterator end()
Definition: FunctionSupport.h:129
Definition: Attributes.h:53
Block & front()
Definition: FunctionSupport.h:138
Attribute getResultAttr(unsigned index, Identifier name)
Definition: FunctionSupport.h:276
Block::args_iterator args_iterator
Definition: FunctionSupport.h:182
unsigned getNumArguments()
Definition: FunctionSupport.h:168
void setArgAttr(unsigned index, StringRef name, Attribute value)
Definition: FunctionSupport.h:243
ArrayRef< NamedAttribute > getArgAttrs(unsigned index)
Return all of the attributes for the argument at 'index'.
Definition: FunctionSupport.h:201
Definition: Attributes.h:1374
bool empty()
Definition: FunctionSupport.h:133
Block arguments are values.
Definition: Value.h:235
args_iterator args_begin()
Definition: FunctionSupport.h:183
Definition: Types.h:84
ArrayRef< NamedAttribute > getResultAttrs(Operation *op, unsigned index)
Return all of the attributes for the result at 'index'.
Definition: FunctionSupport.h:62
DictionaryAttr getArgAttrDict(Operation *op, unsigned index)
Definition: FunctionSupport.h:42
Region & getBody()
Definition: FunctionSupport.h:112
reverse_iterator rend()
Definition: FunctionSupport.h:131
void eraseBody()
Delete all blocks from this function.
Definition: FunctionSupport.h:115
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
Definition: FunctionSupport.h:25
Definition: OpDefinition.h:386
void setResultAttr(unsigned index, StringRef name, Attribute value)
Definition: FunctionSupport.h:306
ArrayRef< NamedAttribute > getArgAttrs(Operation *op, unsigned index)
Return all of the attributes for the argument at 'index'.
Definition: FunctionSupport.h:56
Attribute getResultAttr(unsigned index, StringRef name)
Definition: FunctionSupport.h:280
Definition: LLVM.h:50
BlockListType::iterator iterator
Definition: FunctionSupport.h:125
StringRef getArgAttrName(unsigned arg, SmallVectorImpl< char > &out)
Return the name of the attribute used for function arguments.
Definition: FunctionSupport.h:28
AttrClass getResultAttrOfType(unsigned index, StringRef name)
Definition: FunctionSupport.h:290
Definition: MLIRContext.h:34
static StringRef getArgAttrName(unsigned index, SmallVectorImpl< char > &out)
Definition: FunctionSupport.h:319
void push_back(Block *block)
Definition: FunctionSupport.h:134
Block & back()
Definition: FunctionSupport.h:137
void setAllArgAttrs(ArrayRef< NamedAttributeList > attributes)
Definition: FunctionSupport.h:234
bool isTypeAttrValid()
Definition: FunctionSupport.h:157
Definition: StandardTypes.h:63
Region::BlockListType BlockListType
This is the list of blocks in the function.
Definition: FunctionSupport.h:121
RemoveResult
Definition: Attributes.h:1405
Attribute getArgAttr(unsigned index, StringRef name)
Definition: FunctionSupport.h:217
unsigned getNumResults()
Definition: FunctionSupport.h:172
DictionaryAttr getResultAttrDict(Operation *op, unsigned index)
Definition: FunctionSupport.h:50
iterator begin()
Definition: FunctionSupport.h:128
void set(Identifier name, Attribute value)
Definition: Attributes.cpp:1062
Definition: FunctionSupport.h:100
BlockListType::reverse_iterator reverse_iterator
Definition: FunctionSupport.h:126
void push_front(Block *block)
Definition: FunctionSupport.h:135
llvm::iplist< Block > BlockListType
Definition: Region.h:37
void getAllArgAttrs(SmallVectorImpl< NamedAttributeList > &result)
Return all argument attributes of this function.
Definition: FunctionSupport.h:206