arithmetic_reduction.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 
9 namespace ngraph
10 {
11  namespace op
12  {
13  namespace util
14  {
15  /// \brief Abstract base class for arithmetic reduction operations, i.e., operations
16  /// where chosen axes of the input tensors are eliminated (reduced out) by
17  /// repeated application of a particular binary arithmetic operation.
18  class NGRAPH_API ArithmeticReduction : public Op
19  {
20  protected:
21  /// \brief Constructs an arithmetic reduction operation.
23 
24  /// \brief Constructs an arithmetic reduction operation.
25  ///
26  /// \param arg Output that produces the first input tensor.
27  /// \param reduction_axes The axis positions (0-based) to be eliminated.
28  ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
29 
30  public:
31  NGRAPH_RTTI_DECLARATION;
32  void validate_and_infer_types() override;
33 
34  /// \return true if reduction axes are constant else false.
36 
37  /// \return The axis positions (0-based) to be eliminated through reduction.
38  /// \throws CheckFailure if the reduction axes are not constant. (Use
39  /// reduction_axes_constant to check.)
40  const AxisSet get_reduction_axes() const;
41 
42  /// \brief Change the reduction axes
43  void set_reduction_axes(const AxisSet& reduction_axes);
44  };
45  } // namespace util
46  } // namespace op
47 } // namespace ngraph
A set of axes.
Definition: axis_set.hpp:19
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
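Abstract base class for arithmetic reduction operations, i.e., operations where chosen axes of the input tensors are eliminated (reduced out) by repeated application of a particular binary arithmetic operation.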
Definition: arithmetic_reduction.hpp:19
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
void set_reduction_axes(const AxisSet &reduction_axes)
Change the reduction axes.
const AxisSet get_reduction_axes() const
ArithmeticReduction()
Constructs an arithmetic reduction operation.
ArithmeticReduction(const Output< Node > &arg, const Output< Node > &reduction_axes)
Constructs an arithmetic reduction operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16