binarize_weights.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <transformations_visibility.hpp>
8 #include <ngraph/pass/graph_rewrite.hpp>
9 
10 namespace ngraph {
11 namespace pass {
12 
// Forward declaration of the pass; the full class definition appears later in this header,
// after the documentation comment.
// NOTE(review): TRANSFORMATIONS_API is presumably the export/visibility macro provided by
// transformations_visibility.hpp — confirm.
13 class TRANSFORMATIONS_API BinarizeWeights;
14 
15 } // namespace pass
16 } // namespace ngraph
17 
18 /**
19  * @ingroup ie_transformation_common_api
20  * @brief This transformation converts weights to -1/+1 form
21  * and applies normalization factors to output low/high and after Convolution.
22  * For example, following graph
23  *
24  *       ....   ....  out_low out_high         weights ..  .. out_low out_high
25  *         |      |      |       |                |     |   |    |       |
26  *      +--------------------------+           +--------------------------+
27  *      | FakeQuantize (levels==2) |           | FakeQuantize (levels==2) |
28  *      |     (on activations)     |           |       (on weights)       |
29  *      +--------------------------+           +--------------------------+
30  *                    |                                      |
31  *                    |                                      |
32  *                    -----------------    -------------------
33  *                                    |    |
34  *                                    v    v
35  *                               +-------------+
36  *                               | Convolution |
37  *                               +-------------+
38  *                                      |
39  *                                      v
40  *
41  * is transformed to:
42  *
43  *                 normalized normalized
44  *       ....   ....  out_low out_high
45  *         |      |      |       |
46  *      +--------------------------+           +--------------------------+
47  *      | FakeQuantize (levels==2) |           |         Constant         |
48  *      |     (on activations)     |           | (with converted weights) |
49  *      +--------------------------+           +--------------------------+
50  *                    |                                      |
51  *                    |                                      |
52  *                    -----------------    -------------------
53  *                                    |    |
54  *                                    v    v
55  *                               +-------------+
56  *                               | Convolution |
57  *                               +-------------+
58  *                                      |
59  *                                      v
60  *                               +------------+     +---------------------------------------------------------------+
61  *                               |  Multiply  | <---| Constant (normalization factor coming from FQ on activations) |
62  *                               +------------+     +---------------------------------------------------------------+
63  *                                      |
64  *                                      v
65  *                               +------------+     +-----------------------------------------------------------+
66  *                               |  Multiply  | <---| Constant (normalization factor coming from FQ on weights) |
67  *                               +------------+     +-----------------------------------------------------------+
68  *                                      |
69  *                                      v
70  *
71  * Normalization factors are chosen based on the output_high value.
72  * If output_high is zero, the normalization factor is equal to output_low; otherwise it is equal to output_high.
73  */
74 
75 class ngraph::pass::BinarizeWeights : public ngraph::pass::MatcherPass {
    // Pass class derived from ngraph::pass::MatcherPass; the transformation it performs is
    // described in the documentation comment preceding this definition.
76 public:
    // NOTE(review): NGRAPH_RTTI_DECLARATION is a macro defined outside this header; presumably
    // it declares the run-time type information members for this pass — confirm.
77  NGRAPH_RTTI_DECLARATION;
    // NOTE(review): the embedded listing numbering jumps from 77 to 79 here, so a member
    // declaration (possibly the constructor) may be missing from this view — verify against
    // the original header before editing this class.
79 };
This transformation converts weights to -1/+1 form and applies normalization factors to output low/hi...
Definition: binarize_weights.hpp:75
ngraph namespace
Definition: add_fake_quantize_fusion.hpp:14