My Project
Builders.h
Go to the documentation of this file.
1 //===- Builders.h - MLIR Declarative Linalg Builders ------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides intuitive composable interfaces for building structured MLIR
10 // snippets in a declarative fashion.
11 //
12 //===----------------------------------------------------------------------===//
13 #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
14 #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
15 
18 #include "mlir/EDSC/Builders.h"
19 #include "mlir/EDSC/Intrinsics.h"
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/Builders.h"
22 
23 namespace mlir {
24 class BlockArgument;
25 
26 namespace edsc {
/// Iterator kinds accepted by the EDSC Linalg builders; toString() maps each
/// enumerator to its Linalg iterator-type name.
enum class IterType {
  Parallel,
  Reduction,
};
28 
29 inline StringRef toString(IterType t) {
30  switch (t) {
31  case IterType::Parallel:
32  return getParallelIteratorTypeName();
34  return getReductionIteratorTypeName();
35  }
36  llvm_unreachable("Unsupported IterType");
37 }
38 
48  StructuredIndexed(Value v) : value(v) {}
50  return StructuredIndexed(value, indexings);
51  }
52 
53  operator Value() const /* implicit */ { return value; }
54  ArrayRef<AffineExpr> getExprs() { return exprs; }
55 
56 private:
58  : value(v), exprs(indexings.begin(), indexings.end()) {
59  assert(v->getType().isa<MemRefType>() && "MemRefType expected");
60  }
62  : StructuredIndexed(v.getValue(), indexings) {}
63 
64  Value value;
66 };
67 
69 
71  ArrayRef<IterType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
73  function_ref<void(ArrayRef<BlockArgument>)> regionBuilder =
75  ArrayRef<Value> otherValues = {}, ArrayRef<Attribute> otherAttributes = {});
76 
77 namespace ops {
79 using edsc::ValueHandle;
81 
82 //===----------------------------------------------------------------------===//
83 // EDSC builders for linalg generic operations.
84 //===----------------------------------------------------------------------===//
85 
89 
93 
110 
115 
120 
126 
132 
138 
139 // TODO(ntv): Implement more useful pointwise operations on a per-need basis.
140 
149 
150 template <typename Container> Operation *linalg_matmul(Container values) {
151  assert(values.size() == 3 && "Expected exactly 3 values");
152  return linalg_matmul(values[0], values[1], values[2]);
153 }
154 
176 // TODO(ntv) Extend convolution rank with some template magic.
178  ArrayRef<int> strides = {},
179  ArrayRef<int> dilations = {});
180 
181 template <typename Container>
182 Operation *linalg_conv_nhwc(Container values, ArrayRef<int> strides = {},
183  ArrayRef<int> dilations = {}) {
184  assert(values.size() == 3 && "Expected exactly 3 values");
185  return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations);
186 }
187 
209 // TODO(ntv) Extend convolution rank with some template magic.
211  ValueHandle vO, int depth_multiplier = 1,
212  ArrayRef<int> strides = {},
213  ArrayRef<int> dilations = {});
214 
215 template <typename Container>
216 Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier,
217  ArrayRef<int> strides = {},
218  ArrayRef<int> dilations = {}) {
219  assert(values.size() == 3 && "Expected exactly 3 values");
220  return linalg_dilated_conv_nhwc(values[0], values[1], values[2],
221  depth_multiplier, strides, dilations);
222 }
223 
224 } // namespace ops
225 } // namespace edsc
226 } // namespace mlir
227 
228 #endif // MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_
Definition: InferTypeOpInterface.cpp:20
Definition: Operation.h:27
Value getValue() const
Definition: Builders.h:358
Definition: LLVM.h:49
Definition: Builders.h:290
ArrayRef< AffineExpr > getExprs()
Definition: Builders.h:54
Operation * linalg_dilated_conv_nhwc(Container values, int depth_multiplier, ArrayRef< int > strides={}, ArrayRef< int > dilations={})
Definition: Builders.h:216
Operation * linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O)
Definition: Builders.cpp:125
void macRegionBuilder(ArrayRef< BlockArgument > args)
Definition: Builders.cpp:104
Definition: Builders.h:47
IterType
Definition: Builders.h:27
void defaultRegionBuilder(ArrayRef< BlockArgument > args)
Definition: Builders.h:68
Operation * linalg_conv_nhwc(Container values, ArrayRef< int > strides={}, ArrayRef< int > dilations={})
Definition: Builders.h:182
Definition: LLVM.h:37
Type getType() const
Return the type of this value.
Definition: Value.cpp:34
StringRef toString(IterType t)
Definition: Builders.h:29
Definition: StandardTypes.h:390
Operation * makeGenericLinalgOp(ArrayRef< IterType > iteratorTypes, ArrayRef< StructuredIndexed > inputs, ArrayRef< StructuredIndexed > outputs, function_ref< void(ArrayRef< BlockArgument >)> regionBuilder=defaultRegionBuilder, ArrayRef< Value > otherValues={}, ArrayRef< Attribute > otherAttributes={})
Definition: Builders.cpp:35
OperationBuilder< linalg::YieldOp > linalg_yield
Definition: Intrinsics.h:21
Operation * linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O)
Definition: Builders.cpp:157
Operation * linalg_matmul(Container values)
Definition: Builders.h:150
Definition: Value.h:38
StructuredIndexed(Value v)
Definition: Builders.h:48
StructuredIndexed operator()(ArrayRef< AffineExpr > indexings)
Definition: Builders.h:49
Definition: LLVM.h:35
Operation * linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O)
Binary pointwise operation (with broadcast) entry point.
Definition: Builders.cpp:134
bool isa() const
Definition: Types.h:254
Operation * linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O)
Definition: Builders.cpp:148