concat.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 
9 #include "ngraph/op/op.hpp"
10 
11 namespace ngraph
12 {
13  namespace op
14  {
15  namespace v0
16  {
17  /// \brief Concatenation operation.
18  class NGRAPH_API Concat : public Op
19  {
20  public:
21  NGRAPH_RTTI_DECLARATION;
22 
23  /// \brief Constructs a concatenation operation.
24  Concat() = default;
25  /// \brief Constructs a concatenation operation.
26  ///
27  /// \param args The outputs producing the input tensors.
28  /// \param axis The axis along which to concatenate the input tensors.
29  Concat(const OutputVector& args, int64_t axis);
30 
31  /// \brief Constructs a concatenation operation.
32  ///
33  /// \param args The nodes producing the input tensors.
34  /// \param axis The axis along which to concatenate the input tensors.
35  Concat(const NodeVector& args, int64_t axis);
36 
37  bool visit_attributes(AttributeVisitor& visitor) override;
38  void validate_and_infer_types() override;
39 
40  virtual std::shared_ptr<Node>
41  clone_with_new_inputs(const OutputVector& new_args) const override;
42 
43  /// \return The concatenation axis.
44  int64_t get_concatenation_axis() const { return m_concat_axis; }
45  void set_concatenation_axis(int64_t concatenation_axis)
46  {
47  m_concat_axis = concatenation_axis;
48  }
49  /// \return The concatenation axis.
50  int64_t get_axis() const { return m_axis; }
51  void set_axis(int64_t axis) { m_axis = axis; }
52  bool evaluate(const HostTensorVector& outputs,
53  const HostTensorVector& inputs) const override;
54  bool has_evaluate() const override;
55  bool evaluate_lower(const HostTensorVector& output_values) const override;
56  bool evaluate_upper(const HostTensorVector& output_values) const override;
57 
58  protected:
59  /// \ brief m_axis stores default value for all iterations
60  int64_t m_axis;
61  /// \brief m_concat_axis stores m_axis plus the number of rank for each iteration
62  int64_t m_concat_axis = -1;
63  };
64  } // namespace v0
65  using v0::Concat;
66  } // namespace op
67 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Root of all actual ops.
Definition: op.hpp:17
Concatenation operation.
Definition: concat.hpp:19
int64_t get_axis() const
Definition: concat.hpp:50
int64_t get_concatenation_axis() const
Definition: concat.hpp:44
Concat(const OutputVector &args, int64_t axis)
Constructs a concatenation operation.
int64_t m_axis
m_axis stores the default value for all iterations.
Definition: concat.hpp:60
bool has_evaluate() const override
Allows to get information about availability of evaluate method for the current operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
Concat()=default
Constructs a concatenation operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Concat(const NodeVector &args, int64_t axis)
Constructs a concatenation operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16