My Project
StructuredOpsUtils.h
Go to the documentation of this file.
1 //===- StructuredOpsUtils.h - Utilities used by structured ops --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities that operate on standard types and are
10 // useful across multiple dialects that use structured ops abstractions. These
11 // abstractions consist of defining custom operations that encode and transport
12 // information about their semantics (e.g. type of iterators like parallel,
13 // reduction, etc.) as attributes.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #ifndef MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
18 #define MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
19 
20 #include "mlir/IR/Attributes.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/StringRef.h"
23 
24 namespace mlir {
27 static constexpr StringLiteral getIndexingMapsAttrName() {
28  return StringLiteral("indexing_maps");
29 }
30 
33 static constexpr StringLiteral getIteratorTypesAttrName() {
34  return StringLiteral("iterator_types");
35 }
36 
39 static constexpr StringLiteral getArgsInAttrName() {
40  return StringLiteral("args_in");
41 }
42 
45 static constexpr StringLiteral getArgsOutAttrName() {
46  return StringLiteral("args_out");
47 }
48 
51 static constexpr StringLiteral getDocAttrName() { return StringLiteral("doc"); }
52 
55 static constexpr StringLiteral getFunAttrName() { return StringLiteral("fun"); }
56 
59 static constexpr StringLiteral getLibraryCallAttrName() {
60  return StringLiteral("library_call");
61 }
62 
64 inline static constexpr StringLiteral getParallelIteratorTypeName() {
65  return StringLiteral("parallel");
66 }
67 
69 inline static constexpr StringLiteral getReductionIteratorTypeName() {
70  return StringLiteral("reduction");
71 }
72 
74 inline static constexpr StringLiteral getWindowIteratorTypeName() {
75  return StringLiteral("window");
76 }
77 
79 inline static ArrayRef<StringRef> getAllIteratorTypeNames() {
80  static StringRef names[3] = {getParallelIteratorTypeName(),
81  getReductionIteratorTypeName(),
82  getWindowIteratorTypeName()};
83  return llvm::makeArrayRef(names);
84 }
85 
87 inline unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes) {
88  auto names = getAllIteratorTypeNames();
89  (void)names;
90  assert(llvm::is_contained(names, name));
91  return llvm::count_if(iteratorTypes, [name](Attribute a) {
92  return a.cast<StringAttr>().getValue() == name;
93  });
94 }
95 
96 inline unsigned getNumIterators(ArrayAttr iteratorTypes) {
97  unsigned res = 0;
98  for (auto n : getAllIteratorTypeNames())
99  res += getNumIterators(n, iteratorTypes);
100  return res;
101 }
102 
103 } // end namespace mlir
104 
105 #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H
Definition: InferTypeOpInterface.cpp:20
U cast() const
Definition: Attributes.h:1353
Definition: Attributes.h:198
unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes)
Returns the number of iterators of a certain type.
Definition: StructuredOpsUtils.h:87
Definition: Attributes.h:53
Definition: Attributes.h:428