My Project
LinalgTraits.h
Go to the documentation of this file.
1 //===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
10 #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
11 
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/StandardTypes.h"
16 #include "mlir/Support/LLVM.h"
17 
18 namespace mlir {
19 namespace OpTrait {
20 namespace linalg {
21 
27 template <unsigned N> class NInputs {
28 public:
29  template <typename ConcreteType>
30  class Impl : public OpTrait::TraitBase<ConcreteType, NInputs<N>::Impl> {
31  public:
32  static unsigned getNumInputs() { return N; }
33  };
34 };
35 
41 template <unsigned N> class NOutputs {
42 public:
43  template <typename ConcreteType>
44  class Impl : public OpTrait::TraitBase<ConcreteType, NOutputs<N>::Impl> {
45  public:
46  static unsigned getNumOutputs() { return N; }
47  };
48 };
49 
57 template <typename ConcreteType>
59  : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
60 private:
62  unsigned nInputs() {
63  return cast<ConcreteType>(this->getOperation()).getNumInputs();
64  }
66  unsigned nOutputs() {
67  return cast<ConcreteType>(this->getOperation()).getNumOutputs();
68  }
69 
70 public:
72  Value getInput(unsigned i) {
73  assert(i < nInputs());
74  return this->getOperation()->getOperand(i);
75  }
79  auto it = llvm::find(getInputs(), value);
80  if (it != getInputs().end())
81  return it - getInputs().begin();
82  return llvm::None;
83  }
86  return getInput(i)->getType().template cast<ShapedType>();
87  }
90  auto range = this->getOperation()->getOperands();
91  return {range.begin(), range.begin() + nInputs()};
92  }
94  Value getOutput(unsigned i) {
95  return this->getOperation()->getOperand(nInputs() + i);
96  }
100  auto it = llvm::find(getOutputs(), value);
101  if (it != getOutputs().end())
102  return it - getOutputs().begin();
103  return llvm::None;
104  }
107  return getOutput(i)->getType().template cast<ShapedType>();
108  }
111  return this->getOperation()->getNumResults() == 0 &&
112  llvm::all_of(getInputsAndOutputs(),
113  [](Value v) { return v.getType().isa<MemRefType>(); });
114  }
118  for (Type type : getInputs().getTypes())
119  if (auto t = type.template dyn_cast<RankedTensorType>())
120  res.push_back(t);
121  return res;
122  }
126  for (Type type : getOutputs().getTypes())
127  if (auto t = type.template dyn_cast<RankedTensorType>())
128  res.push_back(t);
129  return res;
130  }
133  auto range = this->getOperation()->getOperands();
134  return {range.begin() + nInputs(),
135  range.begin() + getNumInputsAndOutputs()};
136  }
138  unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
140  ShapedType getShapedType(unsigned i) {
141  return (i < nInputs()) ? getInputShapedType(i)
142  : getOutputShapedType(i - nInputs());
143  }
146  auto range = this->getOperation()->getOperands();
147  return {range.begin(), range.begin() + getNumInputsAndOutputs()};
148  }
149  unsigned getNumParallelLoops() {
150  return getNumIterators(
151  getParallelIteratorTypeName(),
152  cast<ConcreteType>(this->getOperation()).iterator_types());
153  }
154  unsigned getNumReductionLoops() {
155  return getNumIterators(
156  getReductionIteratorTypeName(),
157  cast<ConcreteType>(this->getOperation()).iterator_types());
158  }
159  unsigned getNumWindowLoops() {
160  return getNumIterators(
161  getWindowIteratorTypeName(),
162  cast<ConcreteType>(this->getOperation()).iterator_types());
163  }
164  unsigned getNumLoops() {
165  return getNumIterators(
166  cast<ConcreteType>(this->getOperation()).iterator_types());
167  }
169  auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs();
170  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
171  return failure();
172  return success();
173  }
174 };
175 
176 } // namespace linalg
177 } // namespace OpTrait
178 } // namespace mlir
179 
180 #endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_
Definition: InferTypeOpInterface.cpp:20
Operation::operand_range getInputs()
Return the range over inputs.
Definition: LinalgTraits.h:89
ShapedType getInputShapedType(unsigned i)
Return the i-th input buffer type.
Definition: LinalgTraits.h:85
Definition: Operation.h:27
bool hasBufferSemantics()
Query whether the op has only MemRef input and outputs.
Definition: LinalgTraits.h:110
ShapedType getShapedType(unsigned i)
Return the i-th buffer type.
Definition: LinalgTraits.h:140
operand_range getOperands()
Returns an iterator on the underlying Value's (Value ).
Definition: Operation.h:220
Value getOperand(unsigned idx)
Definition: Operation.h:207
bool failed(LogicalResult result)
Definition: LogicalResult.h:45
Definition: LinalgTraits.h:58
Definition: LLVM.h:40
static LogicalResult verifyTrait(Operation *op)
Definition: LinalgTraits.h:168
unsigned getNumReductionLoops()
Definition: LinalgTraits.h:154
Definition: LinalgTraits.h:41
Definition: StandardTypes.h:178
Value getOutput(unsigned i)
Return the i-th output.
Definition: LinalgTraits.h:94
Definition: LinalgTraits.h:30
Operation::operand_range getOutputs()
Return the range over outputs.
Definition: LinalgTraits.h:132
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.cpp:548
LogicalResult success(bool isSuccess=true)
Definition: LogicalResult.h:25
unsigned getNumLoops()
Definition: LinalgTraits.h:164
Definition: LogicalResult.h:18
LogicalResult failure(bool isFailure=true)
Definition: LogicalResult.h:32
unsigned getNumIterators(StringRef name, ArrayAttr iteratorTypes)
Returns the number of iterators of a certain type.
Definition: StructuredOpsUtils.h:87
Type getType() const
Return the type of this value.
Definition: Value.cpp:34
static unsigned getNumOutputs()
Definition: LinalgTraits.h:46
static unsigned getNumInputs()
Definition: LinalgTraits.h:32
Definition: LinalgTraits.h:27
Operation::operand_range getInputsAndOutputs()
Return the range over inputs and outputs.
Definition: LinalgTraits.h:145
Definition: StandardTypes.h:390
SmallVector< RankedTensorType, 4 > getInputTensorTypes()
Query the subset of input operands that are of ranked tensor type.
Definition: LinalgTraits.h:116
Value getInput(unsigned i)
Return the i-th input value.
Definition: LinalgTraits.h:72
ShapedType getOutputShapedType(unsigned i)
Return the i-th output buffer type.
Definition: LinalgTraits.h:106
Definition: Types.h:84
Definition: Value.h:38
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:768
Definition: LLVM.h:35
Definition: OpDefinition.h:386
Definition: LinalgTraits.h:44
Optional< unsigned > getIndexOfOutput(Value value)
Definition: LinalgTraits.h:99
Optional< unsigned > getIndexOfInput(Value value)
Definition: LinalgTraits.h:78
This class implements the operand iterators for the Operation class.
Definition: OperationSupport.h:559
unsigned getNumParallelLoops()
Definition: LinalgTraits.h:149
Definition: StandardTypes.h:63
mlir::edsc::intrinsics::ValueBuilder< RangeOp > range
Definition: Intrinsics.h:23
unsigned getNumInputsAndOutputs()
Return the number of inputs and outputs.
Definition: LinalgTraits.h:138
SmallVector< RankedTensorType, 4 > getOutputTensorTypes()
Query the subset of output operands that are of ranked tensor type.
Definition: LinalgTraits.h:124
bool isa() const
Definition: Types.h:254
unsigned getNumWindowLoops()
Definition: LinalgTraits.h:159
Operation * getOperation()
Return the ultimate Operation being worked on.
Definition: OpDefinition.h:389