node_input.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstring>
8 #include <map>
9 
10 #include "ngraph/descriptor/tensor.hpp"
11 #include "ngraph/partial_shape.hpp"
12 #include "ngraph/shape.hpp"
13 #include "ngraph/type/element_type.hpp"
14 
15 namespace ngraph
16 {
17  class Node;
18 
19  template <typename NodeType>
20  class Output;
21 
22  template <typename NodeType>
23  class Input
24  {
25  };
26 
27  class Variant;
28 
29  /// \brief A handle for one of a node's inputs.
30  template <>
31  class NGRAPH_API Input<Node>
32  {
33  public:
34  /// \brief Constructs a Input.
35  /// \param node Pointer to the node for the input handle.
36  /// \param index The index of the input.
37  Input(Node* node, size_t index);
38 
39  /// \return A pointer to the node referenced by this input handle.
40  Node* get_node() const;
41  /// \return The index of the input referred to by this input handle.
42  size_t get_index() const;
43  /// \return The element type of the input referred to by this input handle.
45  /// \return The shape of the input referred to by this input handle.
46  const Shape& get_shape() const;
47  /// \return The partial shape of the input referred to by this input handle.
49  /// \return A handle to the output that is connected to this input.
51  /// \return A reference to the tensor descriptor for this input.
53  /// \return A shared pointer to the tensor descriptor for this input.
54  std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
55  /// \return true if this input is relevant to its node's output shapes; else false.
57  /// \return true if this input is relevant to its node's output values; else false.
59 
60  /// \brief Replaces the source output of this input.
61  /// \param new_source_output A handle for the output that will replace this input's source.
62  void replace_source_output(const Output<Node>& new_source_output) const;
63 
64  using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
65  /// \return The reference to runtime info map
66  RTMap& get_rt_info();
67  /// \return The constant reference to runtime info map
68  const RTMap& get_rt_info() const;
69 
70  bool operator==(const Input& other) const;
71  bool operator!=(const Input& other) const;
72  bool operator<(const Input& other) const;
73  bool operator>(const Input& other) const;
74  bool operator<=(const Input& other) const;
75  bool operator>=(const Input& other) const;
76 
77  private:
78  Node* const m_node;
79  const size_t m_index;
80  };
81 
82  /// \brief A handle for one of a node's inputs.
83  template <>
84  class NGRAPH_API Input<const Node>
85  {
86  public:
87  /// \brief Constructs a Input.
88  /// \param node Pointer to the node for the input handle.
89  /// \param index The index of the input.
90  Input(const Node* node, size_t index);
91 
92  /// \return A pointer to the node referenced by this input handle.
93  const Node* get_node() const;
94  /// \return The index of the input referred to by this input handle.
95  size_t get_index() const;
96  /// \return The element type of the input referred to by this input handle.
98  /// \return The shape of the input referred to by this input handle.
99  const Shape& get_shape() const;
100  /// \return The partial shape of the input referred to by this input handle.
102  /// \return A handle to the output that is connected to this input.
104  /// \return A reference to the tensor descriptor for this input.
106  /// \return A shared pointer to the tensor descriptor for this input.
107  std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
108  /// \return true if this input is relevant to its node's output shapes; else false.
110  /// \return true if this input is relevant to its node's output values; else false.
112 
113  using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
114  /// \return The constant reference to runtime info map
115  const RTMap& get_rt_info() const;
116 
117  bool operator==(const Input& other) const;
118  bool operator!=(const Input& other) const;
119  bool operator<(const Input& other) const;
120  bool operator>(const Input& other) const;
121  bool operator<=(const Input& other) const;
122  bool operator>=(const Input& other) const;
123 
124  private:
125  const Node* const m_node;
126  const size_t m_index;
127  };
128 
129  NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<Node>& input);
130  NGRAPH_API std::ostream& operator<<(std::ostream& out, const Input<const Node>& input);
131 } // namespace ngraph
A handle for one of a node's inputs.
Definition: node_input.hpp:32
const Shape & get_shape() const
const RTMap & get_rt_info() const
void replace_source_output(const Output< Node > &new_source_output) const
Replaces the source output of this input.
bool get_is_relevant_to_shapes() const
descriptor::Tensor & get_tensor() const
const element::Type & get_element_type() const
Input(Node *node, size_t index)
Constructs an Input.
std::shared_ptr< descriptor::Tensor > get_tensor_ptr() const
Node * get_node() const
Output< Node > get_source_output() const
const PartialShape & get_partial_shape() const
size_t get_index() const
bool get_is_relevant_to_values() const
A handle for one of a node's inputs.
Definition: node_input.hpp:85
const Node * get_node() const
descriptor::Tensor & get_tensor() const
bool get_is_relevant_to_shapes() const
const element::Type & get_element_type() const
std::shared_ptr< descriptor::Tensor > get_tensor_ptr() const
Output< Node > get_source_output() const
Input(const Node *node, size_t index)
Constructs an Input.
const PartialShape & get_partial_shape() const
bool get_is_relevant_to_values() const
const Shape & get_shape() const
const RTMap & get_rt_info() const
Definition: node_input.hpp:24
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: variant.hpp:18
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:28
Definition: element_type.hpp:51
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16