// convert_reduce_to_pooling.hpp
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
4 
#pragma once

#include <algorithm>
#include <memory>
#include <vector>

#include <transformations_visibility.hpp>

#include <ngraph/op/util/op_types.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
19 
20 
21 namespace ngraph {
22 namespace pass {
23 
24 class TRANSFORMATIONS_API ConvertReduceToPooling;
25 class TRANSFORMATIONS_API ConvertReduceMeanToPooling;
26 class TRANSFORMATIONS_API ConvertReduceMaxToPooling;
27 class TRANSFORMATIONS_API ConvertReduceSumToPooling;
28 
29 } // namespace pass
30 } // namespace ngraph
31 
32 class ConvertReduceBase : public ngraph::pass::MatcherPass {
33 public:
34  template <class T>
35  ngraph::matcher_pass_callback convert_reduce_to_pooling();
36 };
37 
39 public:
40  NGRAPH_RTTI_DECLARATION;
42 };
43 
45 public:
46  NGRAPH_RTTI_DECLARATION;
48 };
49 
51 public:
52  NGRAPH_RTTI_DECLARATION;
54 };
55 
56 class ngraph::pass::ConvertReduceToPooling: public ngraph::pass::GraphRewrite {
57 public:
58  NGRAPH_RTTI_DECLARATION;
60  add_matcher<ConvertReduceMeanToPooling>();
61  add_matcher<ConvertReduceMaxToPooling>();
62  add_matcher<ConvertReduceSumToPooling>();
63  }
64 };
65 
66 template <class T>
67 ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
68  return [&](ngraph::pattern::Matcher& m) {
69  auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
70 
71  if (!reduce || transformation_callback(reduce)) {
72  return false;
73  }
74 
75  auto input = reduce->input_value(0);
76 
77  auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(reduce->input_value(1).get_node_shared_ptr());
78  if (!axes_node) {
79  return false;
80  }
81 
82  auto axes_vector = axes_node->template cast_vector<int64_t>();
83  const auto input_rank = input.get_partial_shape().rank().get_length();
84  // Transform negative axes into non-negative ones
85  for (size_t i = 0; i < axes_vector.size(); ++i) {
86  if (axes_vector[i] < 0) {
87  axes_vector[i] += input_rank;
88  }
89  }
90  std::sort(axes_vector.begin(), axes_vector.end());
91 
92  // If axes are empty we just remove Reduction operation
93  if (axes_vector.empty()) {
94  return replace_output_update_name(reduce->output(0), input);
95  }
96 
97  auto input_shape = input.get_shape();
98 
99  // If Reduce op reduces only 1 dims we replace it with Reshape
100  if (std::all_of(axes_vector.begin(), axes_vector.end(),
101  [&input_shape](const int64_t & axis) { return input_shape[axis] == 1; })) {
102  const auto reshape_shape = reduce->output(0).get_shape();
103  auto reshape = std::make_shared<ngraph::opset1::Reshape>(input,
104  ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape), true);
105 
106  reshape->set_friendly_name(reduce->get_friendly_name());
107  copy_runtime_info(reduce, reshape);
108  replace_node(reduce, reshape);
109  return true;
110  }
111 
112  // Check that axes are consecutive otherwise this transformation is not applicable
113  for (size_t i = 1; i < axes_vector.size(); ++i) {
114  if (axes_vector[i] - axes_vector[i-1] != 1) {
115  return false;
116  }
117  }
118 
119  // Check either reduction applies to spatial dimensions or not
120  bool spatial_dims_reduction(true);
121  size_t reduction_dims_count = 1;
122  for (auto& axis : axes_vector) {
123  reduction_dims_count *= input_shape[axis];
124  if (axis <= 1) {
125  spatial_dims_reduction = false;
126  }
127  }
128 
129  /*
130  * Prepare default attributes for Pooling operation
131  * pads_begin/pads_end - should be zeros as we don't need any padding
132  * stride - should be filled with ones
133  * kernel - depends on Reduction operation axes
134  *
135  * Also here we decide should we use Reshapes before and after Pooling
136  * shape_begin - if not empty indicates that we need a Reshape before Pooling
137  * shape_end - if not empty indicates that we need a Reshape after Pooling
138  */
139 
140  ngraph::Strides strides;
141  ngraph::Shape pads_begin, pads_end, kernel, shape_begin, shape_end;
142 
143  if (!spatial_dims_reduction || input_shape.size() != 4) {
144  // In case if reduction applies not to spatial dimensions
145  // we have to fit it into 4D Pooling
146  size_t dims_prod = 1, dims_begin = 1, dims_end = 1;
147  for (int64_t i = 0; static_cast<size_t>(i) < input_shape.size(); ++i) {
148  if (i < *axes_vector.begin()) {
149  dims_begin *= input_shape[i];
150  } else if (i >= axes_vector.front() && i <= axes_vector.back()) {
151  dims_prod *= input_shape[i];
152  } else {
153  dims_end *= input_shape[i];
154  }
155  }
156  // The batch dimenstion is repositioned in the shape
157  // only in case of batch dimension reduction
158  shape_begin.assign({dims_begin, 1, dims_prod, dims_end});
159  shape_end = reduce->output(0).get_shape();
160  strides.assign({1, 1});
161  pads_begin.assign({0, 0});
162  pads_end.assign({0, 0});
163  kernel.assign({dims_prod, 1});
164  } else {
165  for (size_t i = 0; i < input_shape.size() - 2; ++i) {
166  strides.push_back(1);
167  pads_begin.push_back(0);
168  pads_end.push_back(0);
169  kernel.push_back(1);
170  }
171  for (auto& axis : axes_vector) {
172  kernel[axis-2] = input_shape[axis];
173  }
174  if (!reduce->get_keep_dims()) {
175  shape_end = reduce->output(0).get_shape();
176  }
177  }
178 
179  /*
180  * ReduceMean => AvgPool
181  * AvgPool->Reshape (in case if keep_dims=False)
182  * Reshape->AvgPool->Reshape (in case if axes doesn't match spatial dims)
183 
184  * ReduceMax => MaxPool
185  * MaxPool->Reshape (in case if keep_dims=False)
186  * Reshape->MaxPool->Reshape (in case if axes doesn't match spatial dims)
187  *
188  * ReduceSum => AvgPool->Multiply
189  * AvgPool->Multiply->Reshape (in case if keep_dims=False)
190  * Reshape->AvgPool->Multiply->Reshape (in case if axes doesn't match spatial dims)
191  *
192  * Note: some of reshape nodes can be optimized if they do nothing.
193  */
194  ngraph::NodeVector new_ops;
195 
196  if (!shape_begin.empty() && shape_begin != input.get_shape()) {
197  input = std::make_shared<ngraph::opset1::Reshape>(input,
198  ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_begin.size()}, shape_begin), true);
199  input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/reshape_begin");
200  new_ops.push_back(input.get_node_shared_ptr());
201  }
202 
203  if (std::is_same<T, ngraph::opset1::ReduceMean>()) {
204  input = std::make_shared<ngraph::opset1::AvgPool>(input,
205  strides,
206  pads_begin,
207  pads_end,
208  kernel,
209  true,
210  ngraph::op::RoundingType::FLOOR);
211 
212  input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
213  new_ops.push_back(input.get_node_shared_ptr());
214  } else if (std::is_same<T, ngraph::opset1::ReduceMax>()) {
215  input = std::make_shared<ngraph::opset1::MaxPool>(input,
216  strides,
217  pads_begin,
218  pads_end,
219  kernel,
220  ngraph::op::RoundingType::FLOOR);
221 
222  input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
223  new_ops.push_back(input.get_node_shared_ptr());
224  } else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
225  // Fallback to real type because of potential data loss in case of integer AVG Pool
226  bool fallback_to_real = input.get_element_type().is_integral();
227 
228  if (fallback_to_real) {
229  input = std::make_shared<ngraph::opset1::Convert>(input, ngraph::element::f32);
230  new_ops.push_back(input.get_node_shared_ptr());
231  }
232 
233  input = std::make_shared<ngraph::opset1::AvgPool>(input,
234  strides,
235  pads_begin,
236  pads_end,
237  kernel,
238  true,
239  ngraph::op::RoundingType::FLOOR);
240 
241  input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
242  new_ops.push_back(input.get_node_shared_ptr());
243 
244  input = std::make_shared<ngraph::opset1::Multiply>(input,
245  ngraph::opset1::Constant::create(input.get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
246  input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
247  new_ops.push_back(input.get_node_shared_ptr());
248 
249  if (fallback_to_real) {
250  input = std::make_shared<ngraph::opset1::Convert>(input, reduce->output(0).get_element_type());
251  new_ops.push_back(input.get_node_shared_ptr());
252  }
253  } else {
254  return false;
255  }
256 
257  if (!shape_end.empty() && shape_end != input.get_shape()) {
258  input = std::make_shared<ngraph::opset1::Reshape>(input,
259  ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_end.size()}, shape_end), true);
260  new_ops.push_back(input.get_node_shared_ptr());
261  }
262  input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
263  copy_runtime_info(reduce, new_ops);
264  reduce->output(0).replace(input);
265  return true;
266  };
267 }
// (doxygen cross-reference residue removed during cleanup)