My Project
Ops.h
Go to the documentation of this file.
1 //===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines convenience types for working with standard operations
10 // in the MLIR operation set.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_STANDARDOPS_OPS_H
15 #define MLIR_DIALECT_STANDARDOPS_OPS_H
16 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/StandardTypes.h"
22 
23 // Pull in all enum type definitions and utility function declarations.
24 #include "mlir/Dialect/StandardOps/OpsEnums.h.inc"
25 
26 namespace mlir {
27 class AffineMap;
28 class Builder;
29 class FuncOp;
30 class OpBuilder;
31 
32 class StandardOpsDialect : public Dialect {
33 public:
35  static StringRef getDialectNamespace() { return "std"; }
36 
39  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
40  Location loc) override;
41 };
42 
// Predicates for the floating-point comparison operation. They are grouped
// below into "always false", ordered comparisons, the ordered check,
// unordered comparisons, the unordered check and "always true".
// NumPredicates is a count sentinel, not a valid predicate.
enum class CmpFPredicate {
  // Sentinel marking the first valid predicate value (aliased by AlwaysFalse).
  FirstValidValue,
  // Always false
  AlwaysFalse = FirstValidValue,
  // Ordered comparisons
  OEQ,
  OGT,
  OGE,
  OLT,
  OLE,
  ONE,
  // Both ordered
  ORD,
  // Unordered comparisons
  UEQ,
  UGT,
  UGE,
  ULT,
  ULE,
  UNE,
  // Any unordered
  UNO,
  // Always true
  AlwaysTrue,
  // Number of predicates.
  NumPredicates
};
73 
74 #define GET_OP_CLASSES
75 #include "mlir/Dialect/StandardOps/Ops.h.inc"
76 
82 class ConstantFloatOp : public ConstantOp {
83 public:
84  using ConstantOp::ConstantOp;
85 
87  static void build(Builder *builder, OperationState &result,
88  const APFloat &value, FloatType type);
89 
90  APFloat getValue() { return getAttrOfType<FloatAttr>("value").getValue(); }
91 
92  static bool classof(Operation *op);
93 };
94 
100 class ConstantIntOp : public ConstantOp {
101 public:
102  using ConstantOp::ConstantOp;
104  static void build(Builder *builder, OperationState &result, int64_t value,
105  unsigned width);
106 
109  static void build(Builder *builder, OperationState &result, int64_t value,
110  Type type);
111 
112  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
113 
114  static bool classof(Operation *op);
115 };
116 
122 class ConstantIndexOp : public ConstantOp {
123 public:
124  using ConstantOp::ConstantOp;
125 
127  static void build(Builder *builder, OperationState &result, int64_t value);
128 
129  int64_t getValue() { return getAttrOfType<IntegerAttr>("value").getInt(); }
130 
131  static bool classof(Operation *op);
132 };
133 
134 // DmaStartOp starts a non-blocking DMA operation that transfers data from a
135 // source memref to a destination memref. The source and destination memref need
136 // not be of the same dimensionality, but need to have the same elemental type.
137 // The operands include the source and destination memref's each followed by its
138 // indices, size of the data transfer in terms of the number of elements (of the
139 // elemental type of the memref), a tag memref with its indices, and optionally
140 // at the end, a stride and a number_of_elements_per_stride arguments. The tag
141 // location is used by a DmaWaitOp to check for completion. The indices of the
142 // source memref, destination memref, and the tag memref have the same
143 // restrictions as any load/store. The optional stride arguments should be of
144 // 'index' type, and specify a stride for the slower memory space (memory space
145 // with a lower memory space id), transferring chunks of
146 // number_of_elements_per_stride every stride until %num_elements are
147 // transferred. Either both or no stride arguments should be specified.
148 //
149 // For example, a DmaStartOp operation that transfers 256 elements of a memref
150 // '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space
151 // 1 at indices [%k, %l], would be specified as follows:
152 //
153 // %num_elements = constant 256
154 // %idx = constant 0 : index
155 // %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4>
156 // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
157 // memref<40 x 128 x f32>, (d0) -> (d0), 0>,
158 // memref<2 x 1024 x f32>, (d0) -> (d0), 1>,
159 // memref<1 x i32>, (d0) -> (d0), 2>
160 //
161 // If %stride and %num_elt_per_stride are specified, the DMA is expected to
162 // transfer %num_elt_per_stride elements every %stride elements apart from
163 // memory space 0 until %num_elements are transferred.
164 //
165 // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
166 // %num_elt_per_stride :
167 //
168 // TODO(mlir-team): add additional operands to allow source and destination
169 // striding, and multiple stride levels.
170 // TODO(andydavis) Consider replacing src/dst memref indices with view memrefs.
172  : public Op<DmaStartOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
173 public:
174  using Op::Op;
175 
176  static void build(Builder *builder, OperationState &result, Value srcMemRef,
177  ValueRange srcIndices, Value destMemRef,
178  ValueRange destIndices, Value numElements, Value tagMemRef,
179  ValueRange tagIndices, Value stride = nullptr,
180  Value elementsPerStride = nullptr);
181 
182  // Returns the source MemRefType for this DMA operation.
183  Value getSrcMemRef() { return getOperand(0); }
184  // Returns the rank (number of indices) of the source MemRefType.
185  unsigned getSrcMemRefRank() {
186  return getSrcMemRef()->getType().cast<MemRefType>().getRank();
187  }
188  // Returns the source memref indices for this DMA operation.
190  return {getOperation()->operand_begin() + 1,
191  getOperation()->operand_begin() + 1 + getSrcMemRefRank()};
192  }
193 
194  // Returns the destination MemRefType for this DMA operations.
195  Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
196  // Returns the rank (number of indices) of the destination MemRefType.
197  unsigned getDstMemRefRank() {
198  return getDstMemRef()->getType().cast<MemRefType>().getRank();
199  }
200  unsigned getSrcMemorySpace() {
201  return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
202  }
203  unsigned getDstMemorySpace() {
204  return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
205  }
206 
207  // Returns the destination memref indices for this DMA operation.
209  return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1,
210  getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 +
211  getDstMemRefRank()};
212  }
213 
214  // Returns the number of elements being transferred by this DMA operation.
216  return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank());
217  }
218 
219  // Returns the Tag MemRef for this DMA operation.
221  return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
222  }
223  // Returns the rank (number of indices) of the tag MemRefType.
224  unsigned getTagMemRefRank() {
225  return getTagMemRef()->getType().cast<MemRefType>().getRank();
226  }
227 
228  // Returns the tag memref index for this DMA operation.
230  unsigned tagIndexStartPos =
231  1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1;
232  return {getOperation()->operand_begin() + tagIndexStartPos,
233  getOperation()->operand_begin() + tagIndexStartPos +
234  getTagMemRefRank()};
235  }
236 
239  return (getSrcMemorySpace() < getDstMemorySpace());
240  }
241 
244  // Assumes that a lower number is for a slower memory space.
245  return (getDstMemorySpace() < getSrcMemorySpace());
246  }
247 
251  unsigned getFasterMemPos() {
252  assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
253  return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
254  }
255 
256  static StringRef getOperationName() { return "std.dma_start"; }
257  static ParseResult parse(OpAsmParser &parser, OperationState &result);
258  void print(OpAsmPrinter &p);
260 
261  LogicalResult fold(ArrayRef<Attribute> cstOperands,
263 
264  bool isStrided() {
265  return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() +
266  1 + 1 + getTagMemRefRank();
267  }
268 
270  if (!isStrided())
271  return nullptr;
272  return getOperand(getNumOperands() - 1 - 1);
273  }
274 
276  if (!isStrided())
277  return nullptr;
278  return getOperand(getNumOperands() - 1);
279  }
280 };
281 
282 // DmaWaitOp blocks until the completion of a DMA operation associated with the
283 // tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
284 // with the same restrictions as any load/store index. %num_elements is the
285 // number of elements associated with the DMA operation. For example:
286 //
287 // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
288 // memref<2048 x f32>, (d0) -> (d0), 0>,
289 // memref<256 x f32>, (d0) -> (d0), 1>
290 // memref<1 x i32>, (d0) -> (d0), 2>
291 // ...
292 // ...
293 // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2>
294 //
296  : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
297 public:
298  using Op::Op;
299 
300  static void build(Builder *builder, OperationState &result, Value tagMemRef,
301  ValueRange tagIndices, Value numElements);
302 
303  static StringRef getOperationName() { return "std.dma_wait"; }
304 
305  // Returns the Tag MemRef associated with the DMA operation being waited on.
306  Value getTagMemRef() { return getOperand(0); }
307 
308  // Returns the tag memref index for this DMA operation.
310  return {getOperation()->operand_begin() + 1,
311  getOperation()->operand_begin() + 1 + getTagMemRefRank()};
312  }
313 
314  // Returns the rank (number of indices) of the tag memref.
315  unsigned getTagMemRefRank() {
316  return getTagMemRef()->getType().cast<MemRefType>().getRank();
317  }
318 
319  // Returns the number of elements transferred in the associated DMA operation.
320  Value getNumElements() { return getOperand(1 + getTagMemRefRank()); }
321 
322  static ParseResult parse(OpAsmParser &parser, OperationState &result);
323  void print(OpAsmPrinter &p);
324  LogicalResult fold(ArrayRef<Attribute> cstOperands,
326 };
327 
330  Operation::operand_iterator end, unsigned numDims,
331  OpAsmPrinter &p);
332 
335  SmallVectorImpl<Value> &operands,
336  unsigned &numDims);
337 
338 raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
339 
340 } // end namespace mlir
341 
342 #endif // MLIR_DIALECT_STANDARDOPS_OPS_H
operand_range::iterator operand_iterator
Definition: Operation.h:214
Definition: InferTypeOpInterface.cpp:20
unsigned getSrcMemRefRank()
Definition: Ops.h:185
unsigned getFasterMemPos()
Definition: Ops.h:251
Definition: Operation.h:27
Definition: Ops.h:100
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list and returns true if parsing failed.
Definition: Ops.cpp:177
Definition: Attributes.h:129
operand_range getDstIndices()
Definition: Ops.h:208
static StringRef getDialectNamespace()
Definition: Ops.h:35
LogicalResult verify(Operation *op)
Definition: Verifier.cpp:264
operand_range getTagIndices()
Definition: Ops.h:309
Definition: StandardTypes.h:113
Value getDstMemRef()
Definition: Ops.h:195
void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &p)
Prints dimension and symbol list.
Definition: Ops.cpp:165
Definition: OpImplementation.h:214
StandardOpsDialect(MLIRContext *context)
Definition: Ops.cpp:148
Definition: LLVM.h:34
Definition: Location.h:52
Definition: Ops.h:82
Definition: LogicalResult.h:18
CmpFPredicate
Definition: Ops.h:46
Definition: LLVM.h:37
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
Definition: StandardTypes.cpp:734
unsigned getSrcMemorySpace()
Definition: Ops.h:200
int64_t getValue()
Definition: Ops.h:112
unsigned getDstMemorySpace()
Definition: Ops.h:203
int64_t getValue()
Definition: Ops.h:129
Value getNumElements()
Definition: Ops.h:215
Value getStride()
Definition: Ops.h:269
Definition: Attributes.h:53
Definition: Dialect.h:39
Value getTagMemRef()
Definition: Ops.h:306
Definition: Ops.h:32
Definition: StandardTypes.h:390
Definition: OpImplementation.h:32
Definition: OperationSupport.h:261
operand_range getTagIndices()
Definition: Ops.h:229
Definition: Ops.h:171
bool isDestMemorySpaceFaster()
Returns true if this is a DMA from a slower memory space to a faster one.
Definition: Ops.h:238
Definition: Ops.h:295
Op()
This is a public constructor. Any op can be initialized to null.
Definition: OpDefinition.h:1029
Value getTagMemRef()
Definition: Ops.h:220
Definition: Types.h:84
operand_range getSrcIndices()
Definition: Ops.h:189
unsigned getDstMemRefRank()
Definition: Ops.h:197
Definition: Value.h:38
bool isSrcMemorySpaceFaster()
Returns true if this is a DMA from a faster memory space to a slower one.
Definition: Ops.h:243
Definition: Builders.h:47
static StringRef getOperationName()
Definition: Ops.h:303
Value getNumElementsPerStride()
Definition: Ops.h:275
APFloat getValue()
Definition: Ops.h:90
Definition: MLIRContext.h:34
void print(OpAsmPrinter &p, AffineIfOp op)
Definition: AffineOps.cpp:1671
This class implements the operand iterators for the Operation class.
Definition: OperationSupport.h:559
Definition: OpDefinition.h:949
static StringRef getOperationName()
Definition: Ops.h:256
unsigned getTagMemRefRank()
Definition: Ops.h:315
unsigned getTagMemRefRank()
Definition: Ops.h:224
mlir::edsc::intrinsics::ValueBuilder< RangeOp > range
Definition: Intrinsics.h:23
raw_ostream & operator<<(raw_ostream &os, SubViewOp::Range &range)
Definition: Ops.cpp:2759
Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) override
Definition: Ops.cpp:159
Definition: OpDefinition.h:36
bool isStrided()
Definition: Ops.h:264
Definition: LinalgTypes.h:20
Definition: Builders.h:158
Definition: OperationSupport.h:640
Value getSrcMemRef()
Definition: Ops.h:183
Definition: Ops.h:122
Value getNumElements()
Definition: Ops.h:320