My Project
OpToFuncCallLowering.h
Go to the documentation of this file.
1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
10 
15 #include "mlir/IR/Builders.h"
16 
17 namespace mlir {
18 
28 template <typename SourceOp>
30 public:
31  explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
32  StringRef f64Func)
33  : LLVMOpLowering(SourceOp::getOperationName(),
34  lowering_.getDialect()->getContext(), lowering_),
35  f32Func(f32Func), f64Func(f64Func) {}
36 
39  ConversionPatternRewriter &rewriter) const override {
40  using LLVM::LLVMFuncOp;
41  using LLVM::LLVMType;
42 
43  static_assert(
44  std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
45  "expected single result op");
46 
47  LLVMType resultType = lowering.convertType(op->getResult(0)->getType())
48  .template cast<LLVM::LLVMType>();
49  LLVMType funcType = getFunctionType(resultType, operands);
50  StringRef funcName = getFunctionName(resultType);
51  if (funcName.empty())
52  return matchFailure();
53 
54  LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
55  auto callOp = rewriter.create<LLVM::CallOp>(
56  op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
57  rewriter.replaceOp(op, {callOp.getResult(0)});
58  return matchSuccess();
59  }
60 
61 private:
62  LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
63  ArrayRef<Value> operands) const {
64  using LLVM::LLVMType;
65  SmallVector<LLVMType, 1> operandTypes;
66  for (Value operand : operands) {
67  operandTypes.push_back(operand->getType().cast<LLVMType>());
68  }
69  return LLVMType::getFunctionTy(resultType, operandTypes,
70  /*isVarArg=*/false);
71  }
72 
73  StringRef getFunctionName(LLVM::LLVMType type) const {
74  if (type.isFloatTy())
75  return f32Func;
76  if (type.isDoubleTy())
77  return f64Func;
78  return "";
79  }
80 
81  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName,
82  LLVM::LLVMType funcType,
83  Operation *op) const {
84  using LLVM::LLVMFuncOp;
85 
86  Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
87  if (funcOp)
88  return cast<LLVMFuncOp>(*funcOp);
89 
90  mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
91  return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
92  }
93 
/// Callee names, copied in at construction. getFunctionName uses them to pick
/// the call target.
const std::string f32Func; // used when the converted result is LLVM float
const std::string f64Func; // used when the converted result is LLVM double
96 };
97 
98 } // namespace mlir
99 
100 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
Definition: InferTypeOpInterface.cpp:20
Definition: Operation.h:27
Definition: OpDefinition.h:29
LLVMTypeConverter & lowering
Definition: ConvertStandardToLLVM.h:239
FlatSymbolRefAttr getSymbolRefAttr(Operation *value)
Definition: Builders.cpp:153
OpTy create(Location location, Args... args)
Definition: PatternMatch.h:265
bool isDoubleTy()
Definition: LLVMDialect.h:60
Definition: LLVM.h:40
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:121
void replaceOp(Operation *op, ValueRange newValues, ValueRange valuesToRemoveIfDead) override
PatternRewriter hook for replacing the results of an operation.
Definition: DialectConversion.cpp:841
Definition: LLVM.h:37
OpTy create(Location location, Args &&... args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:294
bool isFloatTy()
Utilities to identify types.
Definition: LLVMDialect.h:59
Type getType() const
Return the type of this value.
Definition: Value.cpp:34
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:246
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:107
static Operation * lookupNearestSymbolFrom(Operation *from, StringRef symbol)
Definition: SymbolTable.cpp:125
Definition: Value.h:38
PatternMatchResult matchSuccess(std::unique_ptr< PatternState > state={}) const
This method indicates that a match was found and has the specified cost.
Definition: PatternMatch.h:116
Definition: LLVM.h:35
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: ConvertStandardToLLVM.h:37
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef< Value > operands, ConversionPatternRewriter &rewriter) const override
Hook for derived classes to implement combined matching and rewriting.
Definition: OpToFuncCallLowering.h:38
Definition: DialectConversion.h:311
static PatternMatchResult matchFailure()
This method indicates that no match was found.
Definition: PatternMatch.h:112
Definition: ConvertStandardToLLVM.h:231
Definition: Builders.h:158
Definition: OpToFuncCallLowering.h:29
OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func, StringRef f64Func)
Definition: OpToFuncCallLowering.h:31
Type convertType(Type t) override
Definition: ConvertStandardToLLVM.cpp:2149
Definition: LLVMDialect.h:44