node_output.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
#include <cstddef>
#include <cstring>
#include <iosfwd>
#include <memory>
#include <set>
#include <unordered_set>

#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
26 
27 namespace ngraph
28 {
29  class Node;
30 
31  template <typename NodeType>
32  class Input;
33 
34  template <typename NodeType>
35  class Output
36  {
37  };
38 
39  /// \brief A handle for one of a node's outputs.
40  template <>
41  class NGRAPH_API Output<Node>
42  {
43  public:
44  /// \brief Constructs a Output.
45  /// \param node A pointer to the node for the output handle.
46  /// \param index The index of the output.
47  Output(Node* node, size_t index);
48 
49  /// \brief Constructs a Output.
50  /// \param node A `shared_ptr` to the node for the output handle.
51  /// \param index The index of the output.
52  ///
53  /// TODO: Make a plan to deprecate this.
54  Output(const std::shared_ptr<Node>& node, size_t index);
55 
56  /// \brief Constructs a Output, referencing the zeroth output of the node.
57  /// \param node A `shared_ptr` to the node for the output handle.
58  template <typename T>
59  Output(const std::shared_ptr<T>& node)
60  : Output(node ? node->get_default_output() : Output<Node>())
61  {
62  }
63 
64  /// A null output
65  Output() = default;
66 
67  void reset();
68 
69  /// This output position for a different node
70  Output<Node> for_node(const std::shared_ptr<Node>& node);
71  /// \return A pointer to the node referred to by this output handle.
72  Node* get_node() const;
73  /// \return A `shared_ptr` to the node referred to by this output handle.
74  ///
75  /// TODO: Make a plan to deprecate this.
76  std::shared_ptr<Node> get_node_shared_ptr() const;
77 
78  /// \return The index of the output referred to by this output handle.
79  size_t get_index() const;
80  /// \return A reference to the tensor descriptor for this output.
82  /// \return A shared point to the tensor ptr for this output.
83  std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
84  /// \return The element type of the output referred to by this output handle.
86  /// \return The shape of the output referred to by this output handle.
87  const Shape& get_shape() const;
88  /// \return The partial shape of the output referred to by this output handle.
90 
91  /// \return A set containing handles for all inputs targeted by the output referenced by
92  /// this output handle.
93  std::set<Input<Node>> get_target_inputs() const;
94 
95  /// \brief Removes a target input from the output referenced by this output handle.
96  /// \param target_input The target input to remove.
97  ///
98  // TODO(amprocte): Investigate whether this really ought to be public.
99  void remove_target_input(const Input<Node>& target_input) const;
100 
101  /// \brief Replace all users of this value with replacement
102  void replace(const Output<Node>& replacement);
103 
104  bool operator==(const Output& other) const;
105  bool operator!=(const Output& other) const;
106  bool operator<(const Output& other) const;
107  bool operator>(const Output& other) const;
108  bool operator<=(const Output& other) const;
109  bool operator>=(const Output& other) const;
110 
111  private:
112  std::shared_ptr<Node> m_node;
113  size_t m_index{0};
114  };
115 
116  template <>
117  class NGRAPH_API Output<const Node>
118  {
119  public:
120  /// \brief Constructs a Output.
121  /// \param node A pointer to the node for the output handle.
122  /// \param index The index of the output.
123  Output(const Node* node, size_t index);
124 
125  /// \brief Constructs a Output.
126  /// \param node A `shared_ptr` to the node for the output handle.
127  /// \param index The index of the output.
128  ///
129  /// TODO: Make a plan to deprecate this.
130  Output(const std::shared_ptr<const Node>& node, size_t index);
131 
132  /// \brief Constructs a Output, referencing the zeroth output of the node.
133  /// \param node A `shared_ptr` to the node for the output handle.
134  template <typename T>
135  Output(const std::shared_ptr<T>& node)
136  : Output(node ? node->get_default_output() : Output<const Node>())
137  {
138  }
139 
140  /// A null output
141  Output() = default;
142 
143  void reset();
144 
145  /// This output position for a different node
146  Output<const Node> for_node(const std::shared_ptr<const Node>& node);
147 
148  /// \return A pointer to the node referred to by this output handle.
149  const Node* get_node() const;
150  /// \return A `shared_ptr` to the node referred to by this output handle.
151  ///
152  /// TODO: Make a plan to deprecate this.
153  std::shared_ptr<const Node> get_node_shared_ptr() const;
154  /// \return The index of the output referred to by this output handle.
155  size_t get_index() const;
156  /// \return A reference to the tensor descriptor for this output.
158  /// \return A shared point to the tensor ptr for this output.
159  std::shared_ptr<descriptor::Tensor> get_tensor_ptr() const;
160  /// \return The element type of the output referred to by this output handle.
162  /// \return The shape of the output referred to by this output handle.
163  const Shape& get_shape() const;
164  /// \return The partial shape of the output referred to by this output handle.
166 
167  /// \return A set containing handles for all inputs targeted by the output referenced by
168  /// this output handle.
169  std::set<Input<Node>> get_target_inputs() const;
170 
171  bool operator==(const Output& other) const;
172  bool operator!=(const Output& other) const;
173  bool operator<(const Output& other) const;
174  bool operator>(const Output& other) const;
175  bool operator<=(const Output& other) const;
176  bool operator>=(const Output& other) const;
177 
178  private:
179  std::shared_ptr<const Node> m_node;
180  size_t m_index{0};
181  };
182 
183  NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<Node>& output);
184  NGRAPH_API std::ostream& operator<<(std::ostream& out, const Output<const Node>& output);
185 }
A handle for one of a node's inputs.
Definition: node_input.hpp:41
Definition: node.hpp:132
A handle for one of a node's outputs.
Definition: node_output.hpp:42
void remove_target_input(const Input< Node > &target_input) const
Removes a target input from the output referenced by this output handle.
Output(const std::shared_ptr< Node > &node, size_t index)
Constructs an Output.
std::set< Input< Node > > get_target_inputs() const
const element::Type & get_element_type() const
Node * get_node() const
const PartialShape & get_partial_shape() const
std::shared_ptr< Node > get_node_shared_ptr() const
void replace(const Output< Node > &replacement)
Replace all users of this value with replacement.
Output< Node > for_node(const std::shared_ptr< Node > &node)
This output position for a different node.
descriptor::Tensor & get_tensor() const
std::shared_ptr< descriptor::Tensor > get_tensor_ptr() const
Output(const std::shared_ptr< T > &node)
Constructs an Output, referencing the zeroth output of the node.
Definition: node_output.hpp:59
size_t get_index() const
Output()=default
A null output.
Output(Node *node, size_t index)
Constructs an Output.
const Shape & get_shape() const
Definition: node_output.hpp:118
Output(const Node *node, size_t index)
Constructs an Output.
Output< const Node > for_node(const std::shared_ptr< const Node > &node)
This output position for a different node.
Output()=default
A null output.
const element::Type & get_element_type() const
descriptor::Tensor & get_tensor() const
std::set< Input< Node > > get_target_inputs() const
std::shared_ptr< const Node > get_node_shared_ptr() const
Output(const std::shared_ptr< T > &node)
Constructs an Output, referencing the zeroth output of the node.
Definition: node_output.hpp:135
Output(const std::shared_ptr< const Node > &node, size_t index)
Constructs an Output.
const Node * get_node() const
const Shape & get_shape() const
const PartialShape & get_partial_shape() const
std::shared_ptr< descriptor::Tensor > get_tensor_ptr() const
Definition: node_output.hpp:36
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:40
Definition: element_type.hpp:61
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28