assign.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/sink.hpp"
8 #include "ngraph/op/util/variable.hpp"
9 #include "ngraph/op/util/variable_extension.hpp"
10 
11 namespace ngraph
12 {
13  namespace op
14  {
15  class NGRAPH_API AssignBase : public Sink, public VariableExtension
16  {
17  public:
18  NGRAPH_RTTI_DECLARATION;
19  AssignBase() = default;
20  /// \brief Constructs an AssignBase operation.
21  explicit AssignBase(const OutputVector& arguments)
22  : Sink(arguments)
23  {
24  }
25  };
26 
27  namespace v3
28  {
29  /// \brief Assign operation sets an input value to the variable with `variable_id`
30  class NGRAPH_API Assign : public AssignBase
31  {
32  public:
33  NGRAPH_RTTI_DECLARATION;
34  Assign() = default;
35 
36  /// \brief Constructs an Assign operation.
37  ///
38  /// \param new_value Node that produces the input tensor.
39  /// \param variable_id identifier of the variable to be updated.
40  Assign(const Output<Node>& new_value, const std::string& variable_id);
41 
42  void validate_and_infer_types() override;
43  std::string get_variable_id() const override { return m_variable_id; }
44 
45  std::shared_ptr<Node>
46  clone_with_new_inputs(const OutputVector& new_args) const override;
47 
48  bool visit_attributes(AttributeVisitor& visitor) override;
49 
50  private:
51  std::string m_variable_id;
52  };
53  } // namespace v3
54  namespace v6
55  {
56  /// \brief Assign operation sets an input value to the variable with `variable_id`
57  class NGRAPH_API Assign : public AssignBase
58  {
59  public:
60  NGRAPH_RTTI_DECLARATION;
61  Assign() = default;
62 
63  /// \brief Constructs an Assign operation.
64  ///
65  /// \param new_value Node that produces the input tensor.
66  /// \param variable Class for storing and synchronizing element types, shapes and
67  /// identifiers
68  /// between pairs of Assign/ReadValue nodes.
69  Assign(const Output<Node>& new_value, const std::shared_ptr<Variable>& variable);
70 
71  void validate_and_infer_types() override;
72 
73  std::shared_ptr<Node>
74  clone_with_new_inputs(const OutputVector& new_args) const override;
75 
76  bool visit_attributes(AttributeVisitor& visitor) override;
77 
78  std::string get_variable_id() const override
79  {
80  NGRAPH_CHECK(m_variable,
81  "Variable is not initialized. Variable_id is unavailable");
82  return m_variable->get_info().variable_id;
83  }
84  bool evaluate(const HostTensorVector& outputs,
85  const HostTensorVector& inputs,
86  const EvaluationContext& evaluation_context) const override;
87  bool has_evaluate() const override;
88  bool constant_fold(OutputVector& output_values,
89  const OutputVector& inputs_values) override;
90  };
91  } // namespace v6
92  } // namespace op
93 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Definition: variable_extension.hpp:13
Definition: assign.hpp:16
AssignBase(const OutputVector &arguments)
Constructs an AssignBase operation.
Definition: assign.hpp:21
Root of nodes that can be sink nodes.
Definition: sink.hpp:17
Assign operation sets an input value to the variable with variable_id
Definition: assign.hpp:31
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::string get_variable_id() const override
Returns the identifier of the corresponding variable.
Definition: assign.hpp:43
Assign(const Output< Node > &new_value, const std::string &variable_id)
Constructs an Assign operation.
Assign operation sets an input value to the variable with variable_id
Definition: assign.hpp:58
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs, const EvaluationContext &evaluation_context) const override
Evaluates the op on input_values putting results in output_values.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Assign(const Output< Node > &new_value, const std::shared_ptr< Variable > &variable)
Constructs an Assign operation.
std::string get_variable_id() const override
Returns the identifier of the corresponding variable.
Definition: assign.hpp:78
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
std::map< std::string, std::shared_ptr< Variant > > EvaluationContext
Definition: node.hpp:63