convert_precision.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <algorithm>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
11 
12 #include <transformations_visibility.hpp>
13 
14 #include <ngraph/pass/pass.hpp>
15 #include <ngraph/opsets/opset3.hpp>
16 #include <ngraph/validation_util.hpp>
17 #include <ngraph/rt_info.hpp>
18 #include <ngraph/pass/graph_rewrite.hpp>
19 
20 
21 namespace ngraph {
22 namespace pass {
23 
24 class NGRAPH_API ConvertPrecision;
25 
26 } // namespace pass
27 } // namespace ngraph
28 
29 /**
30  * @ingroup ie_transformation_common_api
31  * @brief ConvertPrecision transformation converts the precision of an entire ngraph::Function
32  * List of supported precision conversions:
33  * FROM -> TO
34  * u8 -> i32
35  * u16 -> i32
36  * u32 -> i32
37  * u64 -> i32
38  * i64 -> i32
39  * f16 -> f32
40  * bool -> u8
41  * bool -> i32
42  *
43  * For all operations from opset1-opset4 these conversions can be applied without inserting Convert operations.
44  * That is possible because all operations that produce the "FROM" type can also produce the "TO" type. For these operations
45  * we have created special fuse_type_into_<type> functions (see the .cpp file) that perform type fusion
46  * into the operation.
47  *
48  * List of operations that are supported by this transformation for i64 -> i32 conversion:
49  * opset4::Parameter
50  * opset4::Convert
51  * opset4::ShapeOf
52  * opset3::NonMaxSuppression
53  * opset4::NonMaxSuppression
54  * opset4::TopK
55  * opset4::NonZero
56  * opset4::Bucketize
57  *
58  * List of operations that are supported by this transformation for bool -> u8 conversion:
59  * LogicalAnd
60  * LogicalNot
61  * LogicalOr
62  * LogicalXor
63  * ReduceLogicalAnd
64  * ReduceLogicalOr
65  * Equal
66  * NotEqual
67  * Greater
68  * GreaterEqual
69  * Less
70  * LessEqual
71  */
72 
73 using type_to_fuse_map = std::unordered_map<ngraph::NodeTypeInfo, std::function<bool(const std::shared_ptr<ngraph::Node>&, ngraph::element::Type, size_t idx)>>;
74 using precisions_array = std::vector<std::pair<ngraph::element::Type, ngraph::element::Type>>;
75 
76 class ngraph::pass::ConvertPrecision : public ngraph::pass::FunctionPass {
77 public:
78  NGRAPH_RTTI_DECLARATION;
79  ConvertPrecision(ngraph::element::Type_t from, ngraph::element::Type_t to, type_to_fuse_map additional_type_to_fuse_map = {})
80  : FunctionPass(),
81  m_precisions(precisions_array {{ from, to }}),
82  m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}
83 
84  ConvertPrecision(const precisions_array& precisions, const type_to_fuse_map & additional_type_to_fuse_map = {})
85  : FunctionPass(),
86  m_precisions(precisions),
87  m_additional_type_to_fuse_map(additional_type_to_fuse_map) {}
88 
89  bool run_on_function(std::shared_ptr<Function> f) override;
90 private:
91  precisions_array m_precisions;
92  type_to_fuse_map m_additional_type_to_fuse_map;
93 };
Definition: convert_precision.hpp:76
ngraph namespace
Definition: add_fake_quantize_fusion.hpp:14