My Project
Statistics.h
Go to the documentation of this file.
1 //===- Statistics.h - Collects statistics over tensors ----------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines adapters for extracting various (per layer and per axis)
10 // statistics over tensors.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_QUANTIZER_SUPPORT_STATISTICS_H
15 #define MLIR_QUANTIZER_SUPPORT_STATISTICS_H
16 
17 #include "mlir/IR/Attributes.h"
18 
19 namespace mlir {
20 namespace quantizer {
21 
 // Data members of the statistics record for a tensor axis (or the whole
 // tensor). Each member has an in-class initializer of zero.
 // NOTE(review): the exact semantics of each field (e.g. whether `variance`
 // is a population or sample variance) are not visible here -- confirm against
 // the code that fills them in.
24  int64_t sampleSize = 0;
25  double minValue = 0;
26  double maxValue = 0;
27  double mean = 0;
28  double variance = 0;
29 
 // Value constructor: copies each argument into the member of the same name
 // through the member-initializer list. The constructor body is empty, so no
 // validation of the arguments is performed here.
31  TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
32  double mean, double variance)
33  : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
34  mean(mean), variance(variance) {}
 // Resets this object by assigning a default-constructed
 // TensorAxisStatistics over it.
35  void clear() { *this = TensorAxisStatistics(); }
36 };
37 
40 public:
 // Virtual, compiler-defaulted destructor for this base class, so objects of
 // derived statistics classes are destroyed correctly through a base pointer.
41  virtual ~AbstractTensorStatistics() = default;
42 
 // Base implementation of the whole-tensor query: it does not touch `stats`
 // and always returns false.
 // NOTE(review): false presumably signals "no statistics available" -- confirm
 // against the overriding implementations and callers.
45  virtual bool get(TensorAxisStatistics &stats) const { return false; }
46 
 // Base implementation always returns false; derived classes may override.
 // NOTE(review): presumably indicates whether per-axis queries (getAxisCount /
 // getForAxis) are meaningful for this object -- confirm.
49  virtual bool supportsPerAxis() const { return false; }
50 
 // Count of axes supported in a per-axis query. The base implementation
 // reports 0.
52  virtual unsigned getAxisCount() const { return 0; }
53 
 // Base implementation of the per-axis query: ignores `axis`, does not touch
 // `stats`, and always returns false.
 // NOTE(review): the valid range of `axis` is presumably
 // [0, getAxisCount()) -- confirm against overriding implementations.
56  virtual bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const {
57  return false;
58  }
59 };
60 
68 public:
 // Stores the given Attribute in the private `attr` member; the body is empty.
 // NOTE(review): this single-argument constructor is not marked `explicit`,
 // so an Attribute converts implicitly to AttributeTensorStatistics -- confirm
 // that this is intended.
69  AttributeTensorStatistics(Attribute attr) : attr(attr) {}
70 
 // Declaration only: overrides the virtual `get` of the base class. The
 // definition is not in this header.
 // NOTE(review): presumably fills `stats` from the stored `attr` and returns
 // true on success -- confirm in the implementation file.
71  bool get(TensorAxisStatistics &stats) const override;
72 
73  // TODO: Implement per-axis.
74 
75 private:
 // The Attribute supplied to the constructor, held by value.
76  Attribute attr;
77 };
78 
// Declaration only: stream-insertion operator for TensorAxisStatistics that
// takes and returns a raw_ostream reference. The definition (and therefore
// the output format) is not in this header.
// NOTE(review): raw_ostream is not declared in the visible part of this
// header; presumably it is made available through "mlir/IR/Attributes.h" --
// confirm.
79 raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats);
80 
81 } // end namespace quantizer
82 } // end namespace mlir
83 
84 #endif // MLIR_QUANTIZER_SUPPORT_STATISTICS_H
Definition: InferTypeOpInterface.cpp:20
double minValue
Definition: Statistics.h:25
virtual bool supportsPerAxis() const
Definition: Statistics.h:49
AttributeTensorStatistics(Attribute attr)
Definition: Statistics.h:69
Statistics about a tensor axis (or the whole tensor).
Definition: Statistics.h:23
virtual unsigned getAxisCount() const
Count of axes supported in a per-axis query.
Definition: Statistics.h:52
double maxValue
Definition: Statistics.h:26
TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue, double mean, double variance)
Definition: Statistics.h:31
raw_ostream & operator<<(raw_ostream &os, const CAGNode &node)
Definition: ConstraintAnalysisGraph.h:352
Definition: Attributes.h:53
double mean
Definition: Statistics.h:27
double variance
Definition: Statistics.h:28
virtual bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const
Definition: Statistics.h:56
int64_t sampleSize
Definition: Statistics.h:24
TensorAxisStatistics()
Definition: Statistics.h:30
void clear()
Definition: Statistics.h:35
Base class for querying statistics about a tensor.
Definition: Statistics.h:39