variable_extension.hpp
1 // Copyright (C) 2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <ngraph/runtime/host_tensor.hpp>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
9 
10 namespace ngraph
11 {
12  class NGRAPH_API VariableExtension
13  {
14  public:
15  VariableExtension() = default;
16 
17  /// \brief Returns variable connected to this node.
18  virtual std::shared_ptr<ngraph::Variable> get_variable() const { return m_variable; }
19 
20  /// \brief Sets a new variable to be connected to this node.
21  ///
22  /// \param variable New variable to be connected to this node.
23  virtual void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
24  {
25  m_variable = variable;
26  }
27 
28  /// \brief Sets the identifier to a variable
29  ///
30  /// \param variable_id New identifier of the variable.
31  virtual void set_variable_id(const std::string& variable_id)
32  {
33  m_variable->get_info().variable_id = variable_id;
34  };
35 
36  /// \brief Returns the identifier of corresponding variable.
37  virtual std::string get_variable_id() const = 0;
38 
39  protected:
40  std::shared_ptr<ngraph::Variable> m_variable;
41  };
42 } // namespace ngraph
Definition: variable_extension.hpp:13
virtual std::shared_ptr< ngraph::Variable > get_variable() const
Returns variable connected to this node.
Definition: variable_extension.hpp:18
virtual void set_variable_id(const std::string &variable_id)
Sets the identifier of the variable connected to this node.
Definition: variable_extension.hpp:31
virtual std::string get_variable_id() const =0
Returns the identifier of corresponding variable.
virtual void set_variable(const std::shared_ptr< ngraph::Variable > &variable)
Sets a new variable to be connected to this node.
Definition: variable_extension.hpp:23
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16