read_value.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/op/op.hpp"
20 #include "ngraph/op/util/variable.hpp"
21 
22 namespace ngraph
23 {
24  namespace op
25  {
26  class NGRAPH_API ReadValueBase : public Op
27  {
28  public:
29  NGRAPH_RTTI_DECLARATION;
30 
31  ReadValueBase() = default;
32 
33  /// \brief Constructs an AssignBase operation.
34  explicit ReadValueBase(const OutputVector& arguments)
35  : Op(arguments)
36  {
37  }
38 
39  /// \brief Sets the identifier of corresponding variable
40  ///
41  /// \param variable_id New identifier of the variable.
42  virtual void set_variable_id(const std::string& variable_id){};
43 
44  /// \brief Returns the identifier of corresponding variable.
45  virtual std::string get_variable_id() const = 0;
46 
47  /// \brief Returns variable connected to this node.
48  virtual std::shared_ptr<ngraph::Variable> get_variable() const { return m_variable; }
49  /// \brief Sets a new variable to be connected to this node.
50  ///
51  /// \param variable New variable to be connected to this node.
52  virtual void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
53  {
54  m_variable = variable;
55  }
56 
57  protected:
58  std::shared_ptr<ngraph::Variable> m_variable;
59  };
60  namespace v3
61  {
62  /// \brief ReadValue operation creates the variable with `variable_id` and returns value
63  /// of this variable.
64  class NGRAPH_API ReadValue : public ReadValueBase
65  {
66  public:
67  NGRAPH_RTTI_DECLARATION;
68  ReadValue() = default;
69 
70  /// \brief Constructs a ReadValue operation.
71  ///
72  /// \param init_value Node that produces the input tensor.
73  /// \param variable_id identificator of the variable to create.
74  ReadValue(const Output<Node>& init_value, const std::string& variable_id);
75 
76  void validate_and_infer_types() override;
77 
78  std::shared_ptr<Node>
79  clone_with_new_inputs(const OutputVector& new_args) const override;
80 
81  bool visit_attributes(AttributeVisitor& visitor) override;
82 
83  std::string get_variable_id() const override { return m_variable_id; }
84  void set_variable_id(const std::string& variable_id) override
85  {
86  m_variable_id = variable_id;
87  }
88 
89  private:
90  std::string m_variable_id;
91  };
92  }
93 
94  namespace v6
95  {
96  /// \brief ReadValue operation gets an input value from the variable with `variable_id`
97  /// and returns it as an output.
98  class NGRAPH_API ReadValue : public ReadValueBase
99  {
100  public:
101  NGRAPH_RTTI_DECLARATION;
102  ReadValue() = default;
103 
104  /// \brief Constructs a ReadValue operation.
105  ///
106  /// \param init_value Node that produces the input tensor.
107  /// \param variable Class for storing and synchronizing element types, shapes and
108  /// identifiers
109  /// between pairs of Assign/ReadValue nodes.
110  ReadValue(const Output<Node>& init_value,
111  const std::shared_ptr<Variable>& variable);
112 
113  void validate_and_infer_types() override;
114 
115  void revalidate_and_infer_types() override;
116 
117  std::shared_ptr<Node>
118  clone_with_new_inputs(const OutputVector& new_args) const override;
119 
120  bool visit_attributes(AttributeVisitor& visitor) override;
121 
122  std::string get_variable_id() const override
123  {
124  NGRAPH_CHECK(m_variable,
125  "Variable is not initialized. Variable_id is unavailable");
126  return m_variable->get_info().variable_id;
127  }
128  };
129  }
130  }
131 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Root of all actual ops.
Definition: op.hpp:29
Definition: read_value.hpp:27
virtual void set_variable(const std::shared_ptr< ngraph::Variable > &variable)
Sets a new variable to be connected to this node.
Definition: read_value.hpp:52
virtual std::string get_variable_id() const =0
Returns the identifier of corresponding variable.
ReadValueBase(const OutputVector &arguments)
Constructs a ReadValueBase operation.
Definition: read_value.hpp:34
virtual std::shared_ptr< ngraph::Variable > get_variable() const
Returns variable connected to this node.
Definition: read_value.hpp:48
virtual void set_variable_id(const std::string &variable_id)
Sets the identifier of corresponding variable.
Definition: read_value.hpp:42
ReadValue operation creates the variable with variable_id and returns value of this variable.
Definition: read_value.hpp:65
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::string get_variable_id() const override
Returns the identifier of corresponding variable.
Definition: read_value.hpp:83
void set_variable_id(const std::string &variable_id) override
Sets the identifier of corresponding variable.
Definition: read_value.hpp:84
ReadValue(const Output< Node > &init_value, const std::string &variable_id)
Constructs a ReadValue operation.
ReadValue operation gets an input value from the variable with variable_id and returns it as an output.
Definition: read_value.hpp:99
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::string get_variable_id() const override
Returns the identifier of corresponding variable.
Definition: read_value.hpp:122
ReadValue(const Output< Node > &init_value, const std::shared_ptr< Variable > &variable)
Constructs a ReadValue operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28