My Project
StandardTypes.h
Go to the documentation of this file.
1 //===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_STANDARDTYPES_H
10 #define MLIR_IR_STANDARDTYPES_H
11 
12 #include "mlir/IR/Types.h"
13 
14 namespace llvm {
15 struct fltSemantics;
16 } // namespace llvm
17 
18 namespace mlir {
19 class AffineExpr;
20 class AffineMap;
21 class FloatType;
22 class IndexType;
23 class IntegerType;
24 class Location;
25 class MLIRContext;
26 
27 namespace detail {
28 
29 struct IntegerTypeStorage;
30 struct ShapedTypeStorage;
31 struct VectorTypeStorage;
32 struct RankedTensorTypeStorage;
33 struct UnrankedTensorTypeStorage;
34 struct MemRefTypeStorage;
35 struct UnrankedMemRefTypeStorage;
36 struct ComplexTypeStorage;
37 struct TupleTypeStorage;
38 
39 } // namespace detail
40 
41 namespace StandardTypes {
42 enum Kind {
43  // Floating point.
44  BF16 = Type::Kind::FIRST_STANDARD_TYPE,
45  F16,
46  F32,
47  F64,
50 
51  // Target pointer sized integer, used (e.g.) in affine mappings.
53 
54  // Derived types.
64 };
65 
66 } // namespace StandardTypes
67 
70 class IndexType : public Type::TypeBase<IndexType, Type> {
71 public:
72  using Base::Base;
73 
75  static IndexType get(MLIRContext *context);
76 
78  static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
79 };
80 
83  : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
84 public:
85  using Base::Base;
86 
90  static IntegerType get(unsigned width, MLIRContext *context);
91 
95  static IntegerType getChecked(unsigned width, MLIRContext *context,
96  Location location);
97 
99  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
100  MLIRContext *context,
101  unsigned width);
102 
104  unsigned getWidth() const;
105 
107  static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
108 
110  static constexpr unsigned kMaxWidth = 4096;
111 };
112 
113 class FloatType : public Type::TypeBase<FloatType, Type> {
114 public:
115  using Base::Base;
116 
117  static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
118 
119  // Convenience factories.
121  return get(StandardTypes::BF16, ctx);
122  }
123  static FloatType getF16(MLIRContext *ctx) {
124  return get(StandardTypes::F16, ctx);
125  }
126  static FloatType getF32(MLIRContext *ctx) {
127  return get(StandardTypes::F32, ctx);
128  }
129  static FloatType getF64(MLIRContext *ctx) {
130  return get(StandardTypes::F64, ctx);
131  }
132 
134  static bool kindof(unsigned kind) {
137  }
138 
140  unsigned getWidth();
141 
143  const llvm::fltSemantics &getFloatSemantics();
144 };
145 
152  : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
153 public:
154  using Base::Base;
155 
157  static ComplexType get(Type elementType);
158 
162  static ComplexType getChecked(Type elementType, Location location);
163 
165  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
166  MLIRContext *context,
167  Type elementType);
168 
169  Type getElementType();
170 
171  static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
172 };
173 
178 class ShapedType : public Type {
179 public:
181  using Type::Type;
182 
183  // TODO(ntv): merge these two special values in a single one used everywhere.
184  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
185  // hard to track.
186  static constexpr int64_t kDynamicSize = -1;
187  static constexpr int64_t kDynamicStrideOrOffset =
188  std::numeric_limits<int64_t>::min();
189 
191  Type getElementType() const;
192 
195  unsigned getElementTypeBitWidth() const;
196 
198  int64_t getNumElements() const;
199 
201  int64_t getRank() const;
202 
205  bool hasRank() const;
206 
208  ArrayRef<int64_t> getShape() const;
209 
213  bool hasStaticShape() const;
214 
216  bool hasStaticShape(ArrayRef<int64_t> shape) const;
217 
220  int64_t getNumDynamicDims() const;
221 
224  int64_t getDimSize(int64_t i) const;
225 
228  unsigned getDynamicDimIndex(unsigned index) const;
229 
236  int64_t getSizeInBits() const;
237 
239  static bool classof(Type type) {
240  return type.getKind() == StandardTypes::Vector ||
244  type.getKind() == StandardTypes::MemRef;
245  }
246 
248  static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
249  static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
250  return dStrideOrOffset == kDynamicStrideOrOffset;
251  }
252 };
253 
257  : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
258 public:
259  using Base::Base;
260 
263  static VectorType get(ArrayRef<int64_t> shape, Type elementType);
264 
269  static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
270  Location location);
271 
273  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
274  MLIRContext *context,
275  ArrayRef<int64_t> shape,
276  Type elementType);
277 
280  static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
281 
282  ArrayRef<int64_t> getShape() const;
283 
285  static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
286 };
287 
290 class TensorType : public ShapedType {
291 public:
292  using ShapedType::ShapedType;
293 
295  static bool isValidElementType(Type type) {
296  // Note: Non standard/builtin types are allowed to exist within tensor
297  // types. Dialects are expected to verify that tensor types have a valid
298  // element type within that dialect.
299  return type.isIntOrFloat() || type.isa<ComplexType>() ||
300  type.isa<VectorType>() || type.isa<OpaqueType>() ||
301  (type.getKind() > Type::Kind::LAST_STANDARD_TYPE);
302  }
303 
305  static bool classof(Type type) {
306  return type.getKind() == StandardTypes::RankedTensor ||
308  }
309 };
310 
315  : public Type::TypeBase<RankedTensorType, TensorType,
316  detail::RankedTensorTypeStorage> {
317 public:
318  using Base::Base;
319 
322  static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
323 
328  static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
329  Location location);
330 
332  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
333  MLIRContext *context,
334  ArrayRef<int64_t> shape,
335  Type elementType);
336 
337  ArrayRef<int64_t> getShape() const;
338 
339  static bool kindof(unsigned kind) {
340  return kind == StandardTypes::RankedTensor;
341  }
342 };
343 
347  : public Type::TypeBase<UnrankedTensorType, TensorType,
348  detail::UnrankedTensorTypeStorage> {
349 public:
350  using Base::Base;
351 
354  static UnrankedTensorType get(Type elementType);
355 
360  static UnrankedTensorType getChecked(Type elementType, Location location);
361 
363  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
364  MLIRContext *context,
365  Type elementType);
366 
368 
369  static bool kindof(unsigned kind) {
370  return kind == StandardTypes::UnrankedTensor;
371  }
372 };
373 
375 class BaseMemRefType : public ShapedType {
376 public:
377  using ShapedType::ShapedType;
378 
380  static bool classof(Type type) {
381  return type.getKind() == StandardTypes::MemRef ||
383  }
384 };
385 
390 class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
391  detail::MemRefTypeStorage> {
392 public:
393  using Base::Base;
394 
399  static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
400  ArrayRef<AffineMap> affineMapComposition = {},
401  unsigned memorySpace = 0);
402 
409  static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
410  ArrayRef<AffineMap> affineMapComposition,
411  unsigned memorySpace, Location location);
412 
413  ArrayRef<int64_t> getShape() const;
414 
417  ArrayRef<AffineMap> getAffineMaps() const;
418 
420  unsigned getMemorySpace() const;
421 
422  // TODO(ntv): merge these two special values in a single one used everywhere.
423  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
424  // hard to track.
425  static constexpr int64_t kDynamicSize = -1;
426  static int64_t getDynamicStrideOrOffset() {
427  return ShapedType::kDynamicStrideOrOffset;
428  }
429 
430  static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
431 
432 private:
436  static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
437  ArrayRef<AffineMap> affineMapComposition,
438  unsigned memorySpace, Optional<Location> location);
439  using Base::getImpl;
440 };
441 
445  : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
446  detail::UnrankedMemRefTypeStorage> {
447 public:
448  using Base::Base;
449 
452  static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
453 
458  static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
459  Location location);
460 
462  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
463  MLIRContext *context,
464  Type elementType,
465  unsigned memorySpace);
466 
468 
470  unsigned getMemorySpace() const;
471  static bool kindof(unsigned kind) {
472  return kind == StandardTypes::UnrankedMemRef;
473  }
474 };
475 
482  : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
483 public:
484  using Base::Base;
485 
488  static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);
489 
491  static TupleType get(MLIRContext *context) { return get({}, context); }
492 
494  ArrayRef<Type> getTypes() const;
495 
501 
503  size_t size() const;
504 
507  iterator begin() const { return getTypes().begin(); }
508  iterator end() const { return getTypes().end(); }
509 
511  Type getType(size_t index) const {
512  assert(index < size() && "invalid index for tuple type");
513  return getTypes()[index];
514  }
515 
516  static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
517 };
518 
521 class NoneType : public Type::TypeBase<NoneType, Type> {
522 public:
523  using Base::Base;
524 
526  static NoneType get(MLIRContext *context);
527 
528  static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
529 };
530 
553  SmallVectorImpl<int64_t> &strides,
554  int64_t &offset);
557  AffineExpr &offset);
558 
577  MLIRContext *context);
578 
584 
586 bool isStrided(MemRefType t);
587 
588 } // end namespace mlir
589 
590 #endif // MLIR_IR_STANDARDTYPES_H
Definition: InferTypeOpInterface.cpp:20
Definition: StandardTypes.h:61
Definition: PassRegistry.cpp:413
static bool kindof(unsigned kind)
Definition: StandardTypes.h:369
Integer types can have arbitrary bitwidth up to a large fixed limit.
Definition: StandardTypes.h:82
static bool kindof(unsigned kind)
Definition: StandardTypes.h:516
Definition: Attributes.h:139
Definition: StandardTypes.h:59
static bool kindof(unsigned kind)
Definition: StandardTypes.h:528
Definition: Attributes.h:129
Shaped Type Storage.
Definition: TypeDetail.h:112
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset)
Definition: StandardTypes.h:249
static bool kindof(unsigned kind)
Methods for supporting type inquiry through isa, cast, and dyn_cast.
Definition: StandardTypes.h:285
static FloatType getF64(MLIRContext *ctx)
Definition: StandardTypes.h:129
SmallVector< Type, 10 > getFlattenedTypes(TupleType t)
Definition: TypeUtilities.cpp:35
Definition: StandardTypes.h:113
Definition: StandardTypes.h:52
bool isIntOrFloat()
Return true if this is an integer or a float type.
Definition: StandardTypes.cpp:45
Definition: LLVM.h:40
Definition: StandardTypes.h:314
Definition: StandardTypes.h:151
static FloatType getF32(MLIRContext *ctx)
Definition: StandardTypes.h:126
Definition: StandardTypes.h:46
Kind
Definition: StandardTypes.h:42
Definition: StandardTypes.h:47
Definition: LLVM.h:34
static bool kindof(unsigned kind)
Methods for supporting type inquiry through isa, cast, and dyn_cast.
Definition: StandardTypes.h:134
Definition: StandardTypes.h:178
static bool isValidElementType(Type type)
Return true if the specified element type is ok in a tensor.
Definition: StandardTypes.h:295
Definition: Location.h:52
static bool kindof(unsigned kind)
Support method to enable LLVM-style type casting.
Definition: StandardTypes.h:78
static FloatType getBF16(MLIRContext *ctx)
Definition: StandardTypes.h:120
Definition: StandardTypes.h:49
static bool kindof(unsigned kind)
Definition: StandardTypes.h:171
Definition: LogicalResult.h:18
Definition: StandardTypes.h:56
Definition: LLVM.h:37
unsigned getKind() const
Return the classification for this type.
Definition: Types.cpp:22
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
Definition: StandardTypes.cpp:734
Definition: StandardTypes.h:58
Type getType(size_t index) const
Return the element type at index 'index'.
Definition: StandardTypes.h:511
Definition: StandardTypes.h:62
Definition: AffineExpr.h:66
Definition: StandardTypes.h:390
Definition: StandardTypes.h:45
Definition: StandardTypes.h:481
Definition: StandardTypes.h:44
static bool isValidElementType(Type t)
Definition: StandardTypes.h:280
Definition: AffineMap.h:37
ArrayRef< int64_t > getShape() const
Definition: StandardTypes.h:367
Definition: StandardTypes.h:60
static bool classof(Type type)
Methods for supporting type inquiry through isa, cast, and dyn_cast.
Definition: StandardTypes.h:380
Definition: StandardTypes.h:290
static bool kindof(unsigned kind)
Definition: StandardTypes.h:430
ArrayRef< int64_t > getShape() const
Definition: StandardTypes.h:467
static FloatType getF16(MLIRContext *ctx)
Definition: StandardTypes.h:123
Definition: Types.h:84
Definition: StandardTypes.h:55
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< AffineExpr > &strides, AffineExpr &offset)
Definition: StandardTypes.cpp:523
Definition: StandardTypes.h:256
Definition: StandardTypes.h:57
Base MemRef for Ranked and Unranked variants.
Definition: StandardTypes.h:375
Definition: StandardTypes.h:444
Definition: Types.h:219
Definition: StorageUniquerSupport.h:30
static bool kindof(unsigned kind)
Definition: StandardTypes.h:471
Definition: MLIRContext.h:34
AffineMap makeStridedLinearLayoutMap(ArrayRef< int64_t > strides, int64_t offset, MLIRContext *context)
Definition: StandardTypes.cpp:666
static constexpr bool isDynamic(int64_t dSize)
Whether the given dimension size indicates a dynamic dimension.
Definition: StandardTypes.h:248
MemRefType canonicalizeStridedLayout(MemRefType t)
Definition: StandardTypes.cpp:707
static bool kindof(unsigned kind)
Definition: StandardTypes.h:339
Definition: StandardTypes.h:63
Definition: StandardTypes.h:346
iterator end() const
Definition: StandardTypes.h:508
static bool classof(Type type)
Methods for supporting type inquiry through isa, cast, and dyn_cast.
Definition: StandardTypes.h:239
bool isa() const
Definition: Types.h:254
Definition: StandardTypes.h:70
iterator begin() const
Definition: StandardTypes.h:507
static bool classof(Type type)
Methods for supporting type inquiry through isa, cast, and dyn_cast.
Definition: StandardTypes.h:305
static bool kindof(unsigned kind)
Methods for supporting type inquiry through isa, cast, and dyn_cast.
Definition: StandardTypes.h:107
ArrayRef< Type >::iterator iterator
Iterate over the held elements.
Definition: StandardTypes.h:506
static int64_t getDynamicStrideOrOffset()
Definition: StandardTypes.h:426
Definition: StandardTypes.h:521