input.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <memory>
20 
21 #include "ngraph/descriptor/tensor.hpp"
22 
23 namespace ngraph
24 {
25  class Node;
26 
27  namespace descriptor
28  {
29  class Output;
30 
31  // Describes a tensor that is an input to an op, directly or indirectly via a tuple
32  class NGRAPH_API Input
33  {
34  friend class ngraph::Node;
35 
36  public:
37  /// \param node The node that owns this input
38  /// \param index The position of this this tensor in all input tensors
39  /// \param output The output that supplies a value for this input
40  Input(Node* node, size_t index, Output& output);
41  /// \brief Create an Input that is not connected to an output
42  /// \param node The node that owns this input
43  /// \param index The position of this this tensor in all input tensors
44  Input(Node* node, size_t index);
45  ~Input();
46 
47  /// \return the node that this is an input of
48  std::shared_ptr<Node> get_node() const;
49 
50  /// \return the raw pointer to the node that this is an input of
51  Node* get_raw_pointer_node() const { return m_node; }
52  /// \return the position within all supplied tensors of this input
53  size_t get_index() const { return m_index; }
54  /// \return the connected output
55  const Output& get_output() const { return *m_output; }
56  /// \return the connected output
57  Output& get_output() { return *m_output; }
58  /// \return true if an output is connected to the input.
59  bool has_output() const { return m_output != nullptr; }
60  /// \return the tensor of the connected output
61  const Tensor& get_tensor() const;
62 
63  /// \return the tensor of the connected output
65 
66  /// \brief Replace the current output that supplies a value for this input with output i
67  /// of node
68  void replace_output(std::shared_ptr<Node> node, size_t i);
69  /// \brief Replace the current output that supplies a value for this input with output
71  /// \brief Remove the output from this input. The node will not be valid until another
72  /// output is supplied.
73  void remove_output();
74 
75  /// \return true if the value of this input is relevant to the output shapes of the
76  /// corresponding node. (Usually this is false.)
77  ///
78  /// See Node::set_input_is_relevant_to_shape for more details.
79  bool get_is_relevant_to_shape() const { return m_is_relevant_to_shape; }
80  /// \return true if the value of this input is relevant to the output value of the
81  /// corresponding node. (Usually this is true.)
82  ///
83  /// See Node::set_input_is_relevant_to_value for more details.
84  bool get_is_relevant_to_value() const { return m_is_relevant_to_value; }
85  protected:
86  /// \return the tensor for the connected output
87  std::shared_ptr<const Tensor> get_tensor_ptr() const;
88 
89  /// \return the tensor for the connected output
90  std::shared_ptr<Tensor> get_tensor_ptr();
91 
92  public:
93  /// \return the shape of the connected output
94  const Shape& get_shape() const;
95 
96  /// \return the partial shape of the connected output
97  const PartialShape& get_partial_shape() const;
98 
99  /// \return the element type of the connected output
101 
102  Input(const Input&) = default;
103  Input(Input&&) = default;
104  Input& operator=(const Input&) = default;
105 
106  protected:
107  // owner of an argument node (in lieu of m_arguments)
108  std::shared_ptr<Node> m_src_node;
109  Node* m_node; // The node we are an input for
110  size_t m_index; // Index into all input tensors
111  Output* m_output;
112 
113  private:
114  bool m_is_relevant_to_shape;
115  bool m_is_relevant_to_value;
116  };
117  }
118 }
ngraph::descriptor::Input::replace_output
void replace_output(Output &output)
Replace the current output that supplies a value for this input with output.
ngraph::descriptor::Input::Input
Input(Node *node, size_t index)
Create an Input that is not connected to an output.
ngraph::descriptor::Input::get_tensor
Tensor & get_tensor()
ngraph::Node::operator=
Node & operator=(const Node &)
Assignment operator.
ngraph::descriptor::Input::get_tensor
const Tensor & get_tensor() const
ngraph::descriptor::Input::get_output
Output & get_output()
Definition: input.hpp:57
ngraph::descriptor::Tensor
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:34
ngraph::descriptor::Input::remove_output
void remove_output()
Remove the output from this input. The node will not be valid until another output is supplied.
ngraph::descriptor::Input
Definition: input.hpp:33
ngraph::descriptor::Input::get_tensor_ptr
std::shared_ptr< Tensor > get_tensor_ptr()
ngraph::descriptor::Input::replace_output
void replace_output(std::shared_ptr< Node > node, size_t i)
Replace the current output that supplies a value for this input with output i of node.
ngraph::element::Type
Definition: element_type.hpp:61
ngraph::descriptor::Input::has_output
bool has_output() const
Definition: input.hpp:59
ngraph::descriptor::Input::get_raw_pointer_node
Node * get_raw_pointer_node() const
Definition: input.hpp:51
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::descriptor::Input::get_node
std::shared_ptr< Node > get_node() const
ngraph::descriptor::Input::get_element_type
const element::Type & get_element_type() const
ngraph::descriptor::Input::get_is_relevant_to_shape
bool get_is_relevant_to_shape() const
Definition: input.hpp:79
ngraph::descriptor::Input::get_tensor_ptr
std::shared_ptr< const Tensor > get_tensor_ptr() const
ngraph::descriptor::Input::get_shape
const Shape & get_shape() const
ngraph::descriptor::Output
Definition: output.hpp:39
ngraph::Node
Definition: node.hpp:131
ngraph::descriptor::Input::get_partial_shape
const PartialShape & get_partial_shape() const
ngraph::Node::Node
Node()=default
Construct an uninitialized Node.
ngraph::descriptor::Input::get_output
const Output & get_output() const
Definition: input.hpp:55
ngraph::descriptor::Input::Input
Input(Node *node, size_t index, Output &output)
ngraph::descriptor::Input::get_is_relevant_to_value
bool get_is_relevant_to_value() const
Definition: input.hpp:84
ngraph::descriptor::Input::get_index
size_t get_index() const
Definition: input.hpp:53
ngraph::Node::output
Output< Node > output(size_t output_index)