read_value.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 #include "ngraph/op/util/variable.hpp"
9 #include "ngraph/op/util/variable_extension.hpp"
10 
11 namespace ngraph
12 {
13  namespace op
14  {
15  class NGRAPH_API ReadValueBase : public Op, public VariableExtension
16  {
17  public:
18  NGRAPH_RTTI_DECLARATION;
19 
20  ReadValueBase() = default;
21 
22  /// \brief Constructs a ReadValueBase operation; `arguments` are forwarded to the Op constructor.
23  explicit ReadValueBase(const OutputVector& arguments)
24  : Op(arguments)
25  {
26  }
27  };
28 
29  namespace v3
30  {
31  /// \brief ReadValue operation creates the variable with `variable_id` and returns the
32  /// value of this variable.
33  class NGRAPH_API ReadValue : public ReadValueBase
34  {
35  public:
36  NGRAPH_RTTI_DECLARATION;
37  ReadValue() = default;
38 
39  /// \brief Constructs a ReadValue operation.
40  ///
41  /// \param init_value Node that produces the input tensor.
42  /// \param variable_id Identifier of the variable to create.
43  ReadValue(const Output<Node>& init_value, const std::string& variable_id);
44 
45  void validate_and_infer_types() override;
46 
47  std::shared_ptr<Node>
48  clone_with_new_inputs(const OutputVector& new_args) const override;
49 
50  bool visit_attributes(AttributeVisitor& visitor) override;
51 
52  std::string get_variable_id() const override { return m_variable_id; }
53 
54  private:
55  std::string m_variable_id; // Value returned by get_variable_id().
56  };
57  } // namespace v3
58 
59  namespace v6
60  {
61  /// \brief ReadValue operation gets an input value from the variable with `variable_id`
62  /// and returns it as an output.
63  class NGRAPH_API ReadValue : public ReadValueBase
64  {
65  public:
66  NGRAPH_RTTI_DECLARATION;
67  ReadValue() = default;
68 
69  /// \brief Constructs a ReadValue operation.
70  ///
71  /// \param init_value Node that produces the input tensor.
72  /// \param variable Class for storing and synchronizing element types, shapes and
73  /// identifiers between pairs of Assign/ReadValue nodes.
74  ///
75  ReadValue(const Output<Node>& init_value,
76  const std::shared_ptr<Variable>& variable);
77 
78  void validate_and_infer_types() override;
79 
80  void revalidate_and_infer_types() override;
81 
82  std::shared_ptr<Node>
83  clone_with_new_inputs(const OutputVector& new_args) const override;
84 
85  bool visit_attributes(AttributeVisitor& visitor) override;
86 
87  std::string get_variable_id() const override
88  { // NOTE(review): m_variable is not declared in this class; presumably inherited from VariableExtension -- confirm.
89  NGRAPH_CHECK(m_variable,
90  "Variable is not initialized. Variable_id is unavailable");
91  return m_variable->get_info().variable_id;
92  }
93 
94  bool evaluate(const HostTensorVector& outputs,
95  const HostTensorVector& inputs,
96  const EvaluationContext& evaluation_context) const override;
97  bool has_evaluate() const override;
98 
99  bool constant_fold(OutputVector& output_values,
100  const OutputVector& inputs_values) override;
101  };
102  } // namespace v6
103  } // namespace op
104 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Definition: variable_extension.hpp:13
Root of all actual ops.
Definition: op.hpp:17
Definition: read_value.hpp:16
ReadValueBase(const OutputVector &arguments)
Constructs a ReadValueBase operation.
Definition: read_value.hpp:23
ReadValue operation creates the variable with variable_id and returns value of this variable.
Definition: read_value.hpp:34
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::string get_variable_id() const override
Returns the identifier of corresponding variable.
Definition: read_value.hpp:52
ReadValue(const Output< Node > &init_value, const std::string &variable_id)
Constructs a ReadValue operation.
ReadValue operation gets an input value from the variable with variable_id and returns it as an output.
Definition: read_value.hpp:64
bool has_evaluate() const override
Allows to get information about availability of evaluate method for the current operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs, const EvaluationContext &evaluation_context) const override
Evaluates the op on input_values putting results in output_values.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::string get_variable_id() const override
Returns the identifier of corresponding variable.
Definition: read_value.hpp:87
ReadValue(const Output< Node > &init_value, const std::shared_ptr< Variable > &variable)
Constructs a ReadValue operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
std::map< std::string, std::shared_ptr< Variant > > EvaluationContext
Definition: node.hpp:63