11 #include <transformations_visibility.hpp>
13 #include <ngraph/op/util/op_types.hpp>
14 #include <ngraph/pass/graph_rewrite.hpp>
15 #include <ngraph/opsets/opset1.hpp>
16 #include <ngraph/validation_util.hpp>
17 #include <ngraph/rt_info.hpp>
18 #include <ngraph/pattern/op/wrap_type.hpp>
// Forward declarations of the passes that rewrite opset1 Reduce* operations
// into pooling-based subgraphs. The shared callback in this file maps
// ReduceMean -> AvgPool, ReduceMax -> MaxPool and ReduceSum -> AvgPool * count.
// The per-class op association is taken from the class names; confirm it
// against the class definitions.
24 class TRANSFORMATIONS_API ConvertReduceToPooling;
25 class TRANSFORMATIONS_API ConvertReduceMeanToPooling;
26 class TRANSFORMATIONS_API ConvertReduceMaxToPooling;
27 class TRANSFORMATIONS_API ConvertReduceSumToPooling;
35 ngraph::matcher_pass_callback convert_reduce_to_pooling();
40 NGRAPH_RTTI_DECLARATION;
46 NGRAPH_RTTI_DECLARATION;
52 NGRAPH_RTTI_DECLARATION;
58 NGRAPH_RTTI_DECLARATION;
60 add_matcher<ConvertReduceMeanToPooling>();
61 add_matcher<ConvertReduceMaxToPooling>();
62 add_matcher<ConvertReduceSumToPooling>();
67 ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
68 return [&](ngraph::pattern::Matcher& m) {
69 auto reduce = std::dynamic_pointer_cast<T>(m.get_match_root());
71 if (!reduce || transformation_callback(reduce)) {
75 auto input = reduce->input_value(0);
77 auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(reduce->input_value(1).get_node_shared_ptr());
82 auto axes_vector = axes_node->template cast_vector<int64_t>();
83 const auto input_rank = input.get_partial_shape().rank().get_length();
85 for (
size_t i = 0; i < axes_vector.size(); ++i) {
86 if (axes_vector[i] < 0) {
87 axes_vector[i] += input_rank;
90 std::sort(axes_vector.begin(), axes_vector.end());
93 if (axes_vector.empty()) {
94 return replace_output_update_name(reduce->output(0), input);
97 auto input_shape = input.get_shape();
100 if (std::all_of(axes_vector.begin(), axes_vector.end(),
101 [&input_shape](
const int64_t & axis) { return input_shape[axis] == 1; })) {
102 const auto reshape_shape = reduce->output(0).get_shape();
103 auto reshape = std::make_shared<ngraph::opset1::Reshape>(input,
104 ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape),
true);
106 reshape->set_friendly_name(reduce->get_friendly_name());
107 copy_runtime_info(reduce, reshape);
108 replace_node(reduce, reshape);
113 for (
size_t i = 1; i < axes_vector.size(); ++i) {
114 if (axes_vector[i] - axes_vector[i-1] != 1) {
120 bool spatial_dims_reduction(
true);
121 size_t reduction_dims_count = 1;
122 for (
auto& axis : axes_vector) {
123 reduction_dims_count *= input_shape[axis];
125 spatial_dims_reduction =
false;
140 ngraph::Strides strides;
141 ngraph::Shape pads_begin, pads_end, kernel, shape_begin, shape_end;
143 if (!spatial_dims_reduction || input_shape.size() != 4) {
146 size_t dims_prod = 1, dims_begin = 1, dims_end = 1;
147 for (int64_t i = 0;
static_cast<size_t>(i) < input_shape.size(); ++i) {
148 if (i < *axes_vector.begin()) {
149 dims_begin *= input_shape[i];
150 }
else if (i >= axes_vector.front() && i <= axes_vector.back()) {
151 dims_prod *= input_shape[i];
153 dims_end *= input_shape[i];
158 shape_begin.assign({dims_begin, 1, dims_prod, dims_end});
159 shape_end = reduce->output(0).get_shape();
160 strides.assign({1, 1});
161 pads_begin.assign({0, 0});
162 pads_end.assign({0, 0});
163 kernel.assign({dims_prod, 1});
165 for (
size_t i = 0; i < input_shape.size() - 2; ++i) {
166 strides.push_back(1);
167 pads_begin.push_back(0);
168 pads_end.push_back(0);
171 for (
auto& axis : axes_vector) {
172 kernel[axis-2] = input_shape[axis];
174 if (!reduce->get_keep_dims()) {
175 shape_end = reduce->output(0).get_shape();
194 ngraph::NodeVector new_ops;
196 if (!shape_begin.empty() && shape_begin != input.get_shape()) {
197 input = std::make_shared<ngraph::opset1::Reshape>(input,
198 ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_begin.size()}, shape_begin),
true);
199 input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() +
"/reshape_begin");
200 new_ops.push_back(input.get_node_shared_ptr());
203 if (std::is_same<T, ngraph::opset1::ReduceMean>()) {
204 input = std::make_shared<ngraph::opset1::AvgPool>(input,
210 ngraph::op::RoundingType::FLOOR);
212 input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() +
"/pool");
213 new_ops.push_back(input.get_node_shared_ptr());
214 }
else if (std::is_same<T, ngraph::opset1::ReduceMax>()) {
215 input = std::make_shared<ngraph::opset1::MaxPool>(input,
220 ngraph::op::RoundingType::FLOOR);
222 input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() +
"/pool");
223 new_ops.push_back(input.get_node_shared_ptr());
224 }
else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
226 bool fallback_to_real = input.get_element_type().is_integral();
228 if (fallback_to_real) {
229 input = std::make_shared<ngraph::opset1::Convert>(input, ngraph::element::f32);
230 new_ops.push_back(input.get_node_shared_ptr());
233 input = std::make_shared<ngraph::opset1::AvgPool>(input,
239 ngraph::op::RoundingType::FLOOR);
241 input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() +
"/pool");
242 new_ops.push_back(input.get_node_shared_ptr());
244 input = std::make_shared<ngraph::opset1::Multiply>(input,
245 ngraph::opset1::Constant::create(input.get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
246 input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() +
"/mul");
247 new_ops.push_back(input.get_node_shared_ptr());
249 if (fallback_to_real) {
250 input = std::make_shared<ngraph::opset1::Convert>(input, reduce->output(0).get_element_type());
251 new_ops.push_back(input.get_node_shared_ptr());
257 if (!shape_end.empty() && shape_end != input.get_shape()) {
258 input = std::make_shared<ngraph::opset1::Reshape>(input,
259 ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{shape_end.size()}, shape_end),
true);
260 new_ops.push_back(input.get_node_shared_ptr());
262 input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name());
263 copy_runtime_info(reduce, new_ops);
264 reduce->output(0).replace(input);
Definition: convert_reduce_to_pooling.hpp:32
Definition: convert_reduce_to_pooling.hpp:44
Definition: convert_reduce_to_pooling.hpp:38
Definition: convert_reduce_to_pooling.hpp:50
Definition: convert_reduce_to_pooling.hpp:56
ngraph namespace
Definition: add_fake_quantize_fusion.hpp:14