cum_sum.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/axis_set.hpp"
8 #include "ngraph/op/op.hpp"
9 
10 namespace ngraph
11 {
12  namespace op
13  {
14  namespace v0
15  {
16  /// \brief Tensor cumulative sum operation.
17  ///
18  /// Compute the cumulative sum of the input tensor along the axis specified.
19  ///
20  /// ## Parameters
21  ///
22  /// | | Description |
23  /// | -------------------- |
24  /// --------------------------------------------------------------------------------------------------|
25  /// | `exclusive` | If set to 1 will return exclusive sum in which the top
26  /// element
27  /// is not included. |
28  /// | | In other terms, if set to 1, the j-th output element
29  /// would be
30  /// the
31  /// sum of the first (j-1) elements.|
32  /// | | Otherwise, it would be the sum of the first j elements.
33  /// |
34  ///
35  /// | | Description |
36  /// | -------------------- | -------------------------------------------------- |
37  /// | `reverse` | if set to 1, performs the sum in reverse direction |
38  ///
39  /// ## Inputs
40  ///
41  /// | | Description |
42  /// | ----- | ------------------------------------------------------ |
43  /// | `arg` | An input tensor of any shape and numeric element type. |
44  ///
45  /// | | Description |
46  /// | ----- |
47  /// ------------------------------------------------------------------------------------------------|
48  /// | `axis`| zero dimension tensor specifying axis position along which cumulative sum
49  /// must
50  /// be performed. |
51  ///
52  /// ## Output
53  ///
54  /// | Description |
55  /// |
56  /// ------------------------------------------------------------------------------------|
57  /// | Output tensor of the same type as `arg` with cumulative sums of the arg's elements
58  /// |
59 
60  class NGRAPH_API CumSum : public Op
61  {
62  public:
63  static constexpr NodeTypeInfo type_info{"CumSum", 0};
64  const NodeTypeInfo& get_type_info() const override { return type_info; }
65  /// \brief Constructs a cumulative summation operation.
66  CumSum() = default;
67 
68  /// \brief Constructs a cumulative summation operation.
69  ///
70  /// \param arg The tensor to be summed.
71  /// \param axis zero dimension tensor specifying axis position along which
72  /// cumulative sum must be performed
73  /// \param exclusive if set to true, the top element is not included
74  /// \param reverse if set to true, will perform the sums in reverse direction
75  CumSum(const Output<Node>& arg,
76  const Output<Node>& axis,
77  const bool exclusive = false,
78  const bool reverse = false);
79 
80  /// \brief Constructs a cumulative summation operation with axis = 0
81  ///
82  /// \param arg The tensor to be summed
83  CumSum(const Output<Node>& arg,
84  const bool exclusive = false,
85  const bool reverse = false);
86 
87  virtual std::shared_ptr<Node>
88  clone_with_new_inputs(const OutputVector& new_args) const override;
89 
90  bool visit_attributes(AttributeVisitor& visitor) override;
91  void validate_and_infer_types() override;
92 
93  /// \return The default value for CumSum.
94  virtual std::shared_ptr<Node> get_default_value() const override;
95  bool is_exclusive() const { return m_exclusive; }
96  bool is_reverse() const { return m_reverse; }
97 
98  private:
99  bool m_exclusive;
100  bool m_reverse;
101  };
102  } // namespace v0
103  using v0::CumSum;
104  } // namespace op
105 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Tensor cumulative sum operation.
Definition: cum_sum.hpp:61
CumSum(const Output< Node > &arg, const bool exclusive=false, const bool reverse=false)
Constructs a cumulative summation operation with axis = 0.
CumSum(const Output< Node > &arg, const Output< Node > &axis, const bool exclusive=false, const bool reverse=false)
Constructs a cumulative summation operation.
CumSum()=default
Constructs a cumulative summation operation.
virtual std::shared_ptr< Node > get_default_value() const override
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
const NodeTypeInfo & get_type_info() const override
Definition: cum_sum.hpp:64
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27