My Project
VectorTransforms.h
Go to the documentation of this file.
1 //===- VectorTransforms.h - Vector transformations as patterns --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
10 #define DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
11 
12 #include "mlir/IR/PatternMatch.h"
13 
14 namespace mlir {
15 class MLIRContext;
16 class OwningRewritePatternList;
17 
21  MLIRContext *context, OwningRewritePatternList &patterns,
22  ArrayRef<int64_t> coarseVectorShape = {},
23  ArrayRef<int64_t> fineVectorShape = {});
24 
26 // The following Declarative Rewrite Rule (DRR) helpers are used in rewrite
27 // patterns. As such, they must not call into `rewriter.erase/replace` APIs and
28 // it is the responsibility of the enclosing PatternRewriter to erase on
29 // success.
31 
32 namespace vector {
33 
34 // Entry point for unrolling declarative pattern rewrites.
35 // `op` is unrolled to the `targetShape` as follows, for each of its operands:
36 // 1. the unrolled type `unrolledVectorType` and number of unrolled instances
37 // `numUnrolledInstances` are computed from the `targetShape`. For now it is
38 // assumed the unrolling factors divide the vector sizes.
39 // 2. a fakeFork cast op is inserted that takes the operand and returns
40 // `numUnrolledInstances` results of type `unrolledVectorType`.
41 // 3. the original op is cloned `numUnrolledInstances` times, once for each
42 // result of the fakeFork cast op.
43 // 4. a fakeJoin cast op takes all these results and merges them into a single
44 // aggregate vector result whose size matches the original non-unrolled op
45 // operand types.
46 //
47 // Example:
48 //
49 // opA(operand0, operand1) // numUnrolledInstances = 3
50 //
51 // operand0 operand1
52 // | |
53 // fork fork
54 // <----------gather all fork ops --------->
55 // /|\ /|\
56 // f00 f01 f02 f10 f11 f12
57 // <---------- clone op 3 times --------->
58 // opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
59 // \ | /
60 // <-------------------- join ------------------------->
61 //
62 // Other local patterns then kick in iteratively (including DCE) and compose
63 // until all the fakeFork and fakeJoin ops are removed.
64 //
65 // This will be extended in the future to support more advanced use cases than
66 // simple pointwise ops.
67 SmallVector<Value, 1>
68 unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
69  ArrayRef<int64_t> targetShape);
70 
71 } // namespace vector
72 } // namespace mlir
73 
74 #endif // DIALECT_VECTOROPS_VECTORTRANSFORMS_H_
Definition: InferTypeOpInterface.cpp:20
SmallVector< Value, 1 > unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, ArrayRef< int64_t > targetShape)
Definition: VectorTransforms.cpp:465
void populateVectorToVectorConversionPatterns(MLIRContext *context, OwningRewritePatternList &patterns, ArrayRef< int64_t > coarseVectorShape={}, ArrayRef< int64_t > fineVectorShape={})