arithmetic_reductions_keep_dims.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/op/op.hpp"
20 #include "ngraph/op/util/arithmetic_reduction.hpp"
21 
22 namespace ngraph
23 {
24  namespace op
25  {
26  namespace util
27  {
29  {
30  protected:
31  ArithmeticReductionKeepDims() = default; // Protected: constructible only from derived reduction ops (or friends).
32 
33  /// \param arg The tensor to be reduced.
34  /// \param reduction_axes The axis positions (0-based) to reduce over.
35  /// \param keep_dims If true, each reduced axis is kept in the output with dimension 1; if false, reduced axes are eliminated from the output shape.
37  const Output<Node>& reduction_axes,
38  bool keep_dims = false);
39 
40  bool visit_attributes(AttributeVisitor& visitor) override; // Lets \p visitor visit this op's attributes; NOTE(review): presumably includes keep_dims — confirm in the .cpp.
41 
42  public:
43  void validate_and_infer_types() override; // Verifies that inputs and attributes are consistent and computes output shapes and element types.
44 
45  /// \return true if each reduced axis is kept in the output with dimension 1;
46  /// false if reduced axes are eliminated from the output shape.
47  bool get_keep_dims() const { return m_keep_dims; }
48  void set_keep_dims(bool keep_dims) { m_keep_dims = keep_dims; } // Sets the flag returned by get_keep_dims().
49  private:
50  bool m_keep_dims = false; // Backing field for get_keep_dims()/set_keep_dims(); defaults to false.
51  };
52  }
53  }
54 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Definition: arithmetic_reductions_keep_dims.hpp:29
bool get_keep_dims() const
Definition: arithmetic_reductions_keep_dims.hpp:47
ArithmeticReductionKeepDims(const Output< Node > &arg, const Output< Node > &reduction_axes, bool keep_dims=false)
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Abstract base class for arithmetic reduction operations, i.e., operations where chosen axes of the in...
Definition: arithmetic_reduction.hpp:31
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28