My Project
IndexIntrinsicsOpLowering.h
Go to the documentation of this file.
1 //===- IndexIntrinsicsOpLowering.h - GPU IndexOps Lowering class *- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
10 
14 
15 #include "llvm/ADT/StringSwitch.h"
16 
17 namespace mlir {
18 
19 // Rewrite pattern that replaces Op with XOp, YOp, or ZOp depending on the
20 // dimension ("x", "y" or "z") that Op operates on; any other dimension is a
21 // match failure. Op is assumed to return an `std.index` value and XOp, YOp and
22 // ZOp an `llvm.i32` value; that value is sign-extended or truncated to
23 // `indexBitwidth` (pointer size from the data layout) when it is not 32.
24 template <typename Op, typename XOp, typename YOp, typename ZOp>
26 private:
27  enum dimension { X = 0, Y = 1, Z = 2, invalid }; // invalid: not "x"/"y"/"z"
28  unsigned indexBitwidth; // bit width the i32 result is extended/truncated to
29 
30  static dimension dimensionToIndex(Op op) { // op.dimension() string -> enum
31  return llvm::StringSwitch<dimension>(op.dimension())
32  .Case("x", X)
33  .Case("y", Y)
34  .Case("z", Z)
35  .Default(invalid); // unknown strings map to `invalid`
36  }
37 
38  static unsigned getIndexBitWidth(LLVMTypeConverter &type_converter) {
39  auto dialect = type_converter.getDialect();
40  return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits(); // NOTE(review): pointer size is used as the index width — confirm they always agree
41  }
42 
43 public:
45  : LLVMOpLowering(Op::getOperationName(),
46  lowering_.getDialect()->getContext(), lowering_),
47  indexBitwidth(getIndexBitWidth(lowering_)) {} // computed once per pattern
48 
49  // Replace `op` with the XOp/YOp/ZOp for its dimension, resized to indexBitwidth.
52  ConversionPatternRewriter &rewriter) const override {
53  auto loc = op->getLoc();
54  auto dialect = lowering.getDialect();
55  Value newOp;
56  switch (dimensionToIndex(cast<Op>(op))) { // each case builds an i32-typed op
57  case X:
58  newOp = rewriter.create<XOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
59  break;
60  case Y:
61  newOp = rewriter.create<YOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
62  break;
63  case Z:
64  newOp = rewriter.create<ZOp>(loc, LLVM::LLVMType::getInt32Ty(dialect));
65  break;
66  default:
67  return matchFailure(); // `invalid` dimension: leave the op untouched
68  }
69 
70  if (indexBitwidth > 32) { // widen the i32 result to the index width
71  newOp = rewriter.create<LLVM::SExtOp>(
72  loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
73  } else if (indexBitwidth < 32) { // narrow the i32 result to the index width
74  newOp = rewriter.create<LLVM::TruncOp>(
75  loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp);
76  } // exactly 32 bits: the i32 value is used as-is
77 
78  rewriter.replaceOp(op, {newOp}); // op's single result is replaced by newOp
79  return matchSuccess();
80  }
81 };
82 
83 } // namespace mlir
84 
85 #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_
Definition: InferTypeOpInterface.cpp:20
Definition: IndexIntrinsicsOpLowering.h:25
Definition: Operation.h:27
static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits)
Utilities used to generate integer types.
Definition: LLVMDialect.cpp:1589
LLVMTypeConverter & lowering
Definition: ConvertStandardToLLVM.h:239
OpTy create(Location location, Args... args)
Definition: PatternMatch.h:265
Definition: LLVM.h:40
void replaceOp(Operation *op, ValueRange newValues, ValueRange valuesToRemoveIfDead) override
PatternRewriter hook for replacing the results of an operation.
Definition: DialectConversion.cpp:841
Definition: LLVM.h:37
static LLVMType getInt32Ty(LLVMDialect *dialect)
Definition: LLVMDialect.h:112
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:107
Definition: Value.h:38
PatternMatchResult matchSuccess(std::unique_ptr< PatternState > state={}) const
This method indicates that a match was found and has the specified cost.
Definition: PatternMatch.h:116
GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_)
Definition: IndexIntrinsicsOpLowering.h:44
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: ConvertStandardToLLVM.h:37
Definition: DialectConversion.h:311
llvm::Module & getLLVMModule()
Definition: LLVMDialect.cpp:1436
Definition: OpDefinition.h:949
static PatternMatchResult matchFailure()
This method indicates that no match was found.
Definition: PatternMatch.h:112
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const override
Hook for derived classes to implement combined matching and rewriting.
Definition: IndexIntrinsicsOpLowering.h:51
Definition: ConvertStandardToLLVM.h:231
LLVM::LLVMDialect * getDialect()
Returns the LLVM dialect.
Definition: ConvertStandardToLLVM.h:63