node_output.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <cstring>
#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>

#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
15 
16 namespace ngraph
17 {
18  class Node;
19 
20  template <typename NodeType>
21  class Input;
22 
23  template <typename NodeType>
24  class Output
25  {
26  };
27 
28  class Variant;
29 
30  /// \brief A handle for one of a node's outputs.
31  template <>
32  class NGRAPH_API Output<Node>
33  {
34  public:
35  /// \brief Constructs a Output.
36  /// \param node A pointer to the node for the output handle.
37  /// \param index The index of the output.
38  Output(Node* node, size_t index);
39 
40  /// \brief Constructs a Output.
41  /// \param node A `shared_ptr` to the node for the output handle.
42  /// \param index The index of the output.
43  ///
44  /// TODO: Make a plan to deprecate this.
45  Output(const std::shared_ptr<Node>& node, size_t index);
46 
47  /// \brief Constructs a Output, referencing the zeroth output of the node.
48  /// \param node A `shared_ptr` to the node for the output handle.
49  template <typename T>
50  Output(const std::shared_ptr<T>& node)
51  : Output(node ? node->get_default_output() : Output<Node>())
52  {
53  }
54 
55  /// A null output
56  Output() = default;
57 
58  void reset();
59 
60  /// This output position for a different node
61  Output<Node> for_node(const std::shared_ptr<Node>& node);
62  /// \return A pointer to the node referred to by this output handle.
63  Node* get_node() const;
64  /// \return A `shared_ptr` to the node referred to by this output handle.
65  ///
66  /// TODO: Make a plan to deprecate this.
67  std::shared_ptr<Node> get_node_shared_ptr() const;
68 
69  /// \return The index of the output referred to by this output handle.
70  size_t get_index() const;
71  /// \return A reference to the tensor descriptor for this output.
73  /// \return A shared point to the tensor ptr for this output.
74  std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
75  /// \return The element type of the output referred to by this output handle.
77  /// \return The shape of the output referred to by this output handle.
78  const Shape& get_shape() const;
79  /// \return The partial shape of the output referred to by this output handle.
81 
82  using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
83  /// \return The reference to runtime info map
84  RTMap& get_rt_info();
85  /// \return The constant reference to runtime info map
86  const RTMap& get_rt_info() const;
87 
88  /// \return A set containing handles for all inputs targeted by the output referenced by
89  /// this output handle.
90  std::set<Input<Node>> get_target_inputs() const;
91 
92  /// \brief Removes a target input from the output referenced by this output handle.
93  /// \param target_input The target input to remove.
94  ///
95  // TODO(amprocte): Investigate whether this really ought to be public.
96  void remove_target_input(const Input<Node>& target_input) const;
97 
98  /// \brief Replace all users of this value with replacement
99  void replace(const Output<Node>& replacement);
100 
101  bool operator==(const Output& other) const;
102  bool operator!=(const Output& other) const;
103  bool operator<(const Output& other) const;
104  bool operator>(const Output& other) const;
105  bool operator<=(const Output& other) const;
106  bool operator>=(const Output& other) const;
107 
108  private:
109  std::shared_ptr<Node> m_node;
110  size_t m_index{0};
111  };
112 
113  template <>
114  class NGRAPH_API Output<const Node>
115  {
116  public:
117  /// \brief Constructs a Output.
118  /// \param node A pointer to the node for the output handle.
119  /// \param index The index of the output.
120  Output(const Node* node, size_t index);
121 
122  /// \brief Constructs a Output.
123  /// \param node A `shared_ptr` to the node for the output handle.
124  /// \param index The index of the output.
125  ///
126  /// TODO: Make a plan to deprecate this.
127  Output(const std::shared_ptr<const Node>& node, size_t index);
128 
129  /// \brief Constructs a Output, referencing the zeroth output of the node.
130  /// \param node A `shared_ptr` to the node for the output handle.
131  template <typename T>
132  Output(const std::shared_ptr<T>& node)
133  : Output(node ? node->get_default_output() : Output<const Node>())
134  {
135  }
136 
137  /// A null output
138  Output() = default;
139 
140  void reset();
141 
142  /// This output position for a different node
143  Output<const Node> for_node(const std::shared_ptr<const Node>& node);
144 
145  /// \return A pointer to the node referred to by this output handle.
146  const Node* get_node() const;
147  /// \return A `shared_ptr` to the node referred to by this output handle.
148  ///
149  /// TODO: Make a plan to deprecate this.
150  std::shared_ptr<const Node> get_node_shared_ptr() const;
151  /// \return The index of the output referred to by this output handle.
152  size_t get_index() const;
153  /// \return A reference to the tensor descriptor for this output.
155  /// \return A shared point to the tensor ptr for this output.
156  std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
157  /// \return The element type of the output referred to by this output handle.
159  /// \return The shape of the output referred to by this output handle.
160  const Shape& get_shape() const;
161  /// \return The partial shape of the output referred to by this output handle.
163 
164  using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
165  /// \return The constant reference to runtime info map
166  const RTMap& get_rt_info() const;
167  /// \return A set containing handles for all inputs targeted by the output referenced by
168  /// this output handle.
169  std::set<Input<Node>> get_target_inputs() const;
170 
171  bool operator==(const Output& other) const;
172  bool operator!=(const Output& other) const;
173  bool operator<(const Output& other) const;
174  bool operator>(const Output& other) const;
175  bool operator<=(const Output& other) const;
176  bool operator>=(const Output& other) const;
177 
178  private:
179  std::shared_ptr<const Node> m_node;
180  size_t m_index{0};
181  };
182 
183  NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
184  NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
185 } // namespace ngraph
A handle for one of a node's inputs.
Definition: node_input.hpp:32
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
void remove_target_input(const Input< Node > &target_input) const
Removes a target input from the output referenced by this output handle.
Output(const std::shared_ptr< Node > &node, size_t index)
Constructs an Output.
std::set< Input< Node > > get_target_inputs() const
const element::Type & get_element_type() const
Node * get_node() const
const PartialShape & get_partial_shape() const
std::shared_ptr< Node > get_node_shared_ptr() const
const RTMap & get_rt_info() const
void replace(const Output< Node > &replacement)
Replace all users of this value with replacement.
Output< Node > for_node(const std::shared_ptr< Node > &node)
This output position for a different node.
descriptor::Tensor & get_tensor() const
std::shared_ptr< descriptor::Tensor > get_tensor_ptr() const
Output(const std::shared_ptr< T > &node)
Constructs an Output, referencing the zeroth output of the node.
Definition: node_output.hpp:50
size_t get_index() const
Output()=default
A null output.
Output(Node *node, size_t index)
Constructs an Output.
const Shape & get_shape() const
Definition: node_output.hpp:115
Output(const Node *node, size_t index)
Constructs an Output.
Output< const Node > for_node(const std::shared_ptr< const Node > &node)
This output position for a different node.
Output()=default
A null output.
const element::Type & get_element_type() const
const RTMap & get_rt_info() const
descriptor::Tensor & get_tensor() const
std::set< Input< Node > > get_target_inputs() const
std::shared_ptr< const Node > get_node_shared_ptr() const
Output(const std::shared_ptr< T > &node)
Constructs an Output, referencing the zeroth output of the node.
Definition: node_output.hpp:132
Output(const std::shared_ptr< const Node > &node, size_t index)
Constructs an Output.
const Node * get_node() const
const Shape & get_shape() const
const PartialShape & get_partial_shape() const
std::shared_ptr< descriptor::Tensor > get_tensor_ptr() const
Definition: node_output.hpp:25
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: variant.hpp:18
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:28
Definition: element_type.hpp:51
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16