function.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <atomic>
8 #include <initializer_list>
9 #include <list>
10 #include <memory>
11 #include <string>
12 #include <vector>
13 
14 #include "ngraph/ngraph_visibility.hpp"
15 #include "ngraph/node.hpp"
16 #include "ngraph/op/assign.hpp"
17 #include "ngraph/op/parameter.hpp"
18 #include "ngraph/op/read_value.hpp"
19 #include "ngraph/op/result.hpp"
20 #include "ngraph/op/sink.hpp"
21 #include "ngraph/op/util/variable.hpp"
22 
23 namespace ngraph
24 {
25  /// A user-defined function.
26  class NGRAPH_API Function
27  {
28  public:
29  static constexpr DiscreteTypeInfo type_info{"Function", 0};
30  const DiscreteTypeInfo& get_type_info() const { return type_info; }
31  Function(const NodeVector& results,
32  const ParameterVector& parameters,
33  const std::string& name = "");
34 
35  Function(const OutputVector& results,
36  const ParameterVector& parameters,
37  const std::string& name = "");
38 
39  Function(const std::shared_ptr<Node>& result,
40  const ParameterVector& parameters,
41  const std::string& name = "");
42 
43  Function(const ResultVector& results,
44  const ParameterVector& parameters,
45  const std::string& name = "");
46 
47  Function(const ResultVector& results,
48  const SinkVector& sinks,
49  const ParameterVector& parameters,
50  const std::string& name = "");
51 
52  Function(const OutputVector& results,
53  const SinkVector& sinks,
54  const ParameterVector& parameters,
55  const std::string& name = "");
56 
57  Function(const ResultVector& results,
58  const SinkVector& sinks,
59  const ParameterVector& parameters,
60  const VariableVector& variables,
61  const std::string& name = "");
62 
63  Function(const OutputVector& results,
64  const SinkVector& sinks,
65  const ParameterVector& parameters,
66  const VariableVector& variables,
67  const std::string& name = "");
68 
69  Function(const ResultVector& results,
70  const ParameterVector& parameters,
71  const VariableVector& variables,
72  const std::string& name = "");
73 
74  Function(const OutputVector& results,
75  const ParameterVector& parameters,
76  const VariableVector& variables,
77  const std::string& name = "");
78 
79  /// Constructs a Function. Lists of parameters and variables will be generated automatically
80  /// based on traversing the graph from the results.
81  explicit Function(const OutputVector& results, const std::string& name = "");
82 
83  /// Constructs a Function. Lists of parameters and variables will be generated automatically
84  /// based on traversing the graph from the results and the sinks.
85  Function(const OutputVector& results,
86  const SinkVector& sinks,
87  const std::string& name = "");
88 
89  virtual ~Function() = default;
90  /// Return the number of outputs for this function.
91  size_t get_output_size() const;
92 
93  /// Return the op that generates output i
94  std::shared_ptr<Node> get_output_op(size_t i) const;
95 
96  Output<Node> output(size_t i) const;
97 
98  /// Return the element type of output i
99  const element::Type& get_output_element_type(size_t i) const;
100 
101  /// Return the shape of element i
102  const Shape& get_output_shape(size_t i) const;
103 
104  /// Return the partial shape of element i
105  const PartialShape& get_output_partial_shape(size_t i) const;
106 
107  /// Check that there is a single result and return it.
108  std::shared_ptr<Node> get_result() const;
109 
110  /// \brief Get the unique name of the function.
111  /// \returns A const reference to the function's unique name.
112  const std::string& get_name() const;
113 
114  /// \brief Sets a friendly name for a function. This does not overwrite the unique name
115  /// of the function and is retrieved via get_friendly_name(). Used mainly for
116  /// debugging.
117  /// \param name is the friendly name to set
118  void set_friendly_name(const std::string& name);
119 
120  /// \brief Gets the friendly name for a function. If no friendly name has been set via
121  /// set_friendly_name then the function's unique name is returned.
122  /// \returns A const reference to the function's friendly name.
123  const std::string& get_friendly_name() const;
124 
125  std::vector<std::shared_ptr<Node>> get_ops() const;
126  std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
127  void map_unordered_ops(std::function<void(Node*)> f) const;
128 
129  friend std::ostream& operator<<(std::ostream&, const Function&);
130  // updates graph and m_results list
131  void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
132 
133  void validate_nodes_and_infer_types() const;
134 
135  /// \brief Returns the sum of the size of all nodes in the graph plus the size of
136  /// all constant data. This has little value beyond comparing the relative size of
137  /// graphs and should not be considered the actual memory consumption of a graph.
138  size_t get_graph_size() const;
139 
140  /// \brief Returns true if any of the op's defined in the function contains partial shape
141  bool is_dynamic() const;
142 
143  /// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
144  ///
145  /// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
146  /// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
147  ///
148  /// \param parameter_index The index of the parameter to replace.
149  /// \param parameter The parameter to substitute for the `parameter_index`th parameter.
150  void replace_parameter(size_t parameter_index,
151  const std::shared_ptr<op::Parameter>& parameter);
152 
153  using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
154  const std::vector<std::shared_ptr<Node>>& root_nodes)>;
155  void set_topological_sort(topological_sort_t);
156 
157  virtual bool visit_attributes(AttributeVisitor& visitor);
158 
159  /// Return the function parameters
160  const ParameterVector& get_parameters() const { return m_parameters; };
161  /// Return a list of function's outputs
162  const ResultVector& get_results() const { return m_results; };
163  /// Index for parameter, or -1
164  int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
165 
166  /// Index for value or result referencing it, or -1
167  int64_t get_result_index(const Output<Node>& value) const;
168 
169  /// \brief Evaluate the function on inputs, putting results in outputs.
170  /// \param output_tensors Tensors for the outputs to compute. One for each result
171  /// \param input_tensors Tensors for the inputs. One for each inputs.
172  /// \param evaluation_context Storage of additional settings and attributes that can be used
173  /// when evaluating the function. This additional information can be shared across nodes.
174  bool evaluate(const HostTensorVector& output_tensors,
175  const HostTensorVector& input_tensors,
176  EvaluationContext evaluation_context = EvaluationContext()) const;
177 
178  /// \brief Return a list of function's sinks.
179  const SinkVector& get_sinks() const { return m_sinks; }
180  /// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
181  /// manually after all changes.
182  /// \param sinks new sink nodes
183  void add_sinks(const SinkVector& sinks);
184 
185  /// \brief Delete sink node from the list of sinks. Method doesn't delete node from graph.
186  /// \param sink Sink to delete
187  void remove_sink(const std::shared_ptr<op::Sink>& sink);
188 
189  /// \brief Add new Result nodes to the list. Method doesn't validate graph, it should be
190  /// done manually after all changes.
191  /// \param results new Result nodes
192  void add_results(const ResultVector& results);
193 
194  /// \brief Delete Result node from the list of results. Method will not delete node from
195  /// graph.
196  /// \param result Result node to delete
197  void remove_result(const std::shared_ptr<op::Result>& result);
198 
199  /// \brief Add new Parameter nodes to the list.
200  ///
201  /// Method doesn't change or validate graph, it should be done manually.
202  /// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the
203  /// following steps:
204  /// * replace node `ReadValue` by `Parameter` in graph
205  /// * call add_parameter() to add new input to the list
206  /// * call graph validation to check correctness of changes
207  ///
208  /// \param params new Parameter nodes
209  void add_parameters(const ParameterVector& params);
210 
211  /// \brief Delete Parameter node from the list of parameters. Method will not delete node
212  /// from graph. You need to replace Parameter with other operation manually.
213  /// Attention: Indexing of parameters can be changed.
214  ///
215  /// Possible use of method is to replace input by variable. For it the following steps
216  /// should be done:
217  /// * `Parameter` node should be replaced by `ReadValue`
218  /// * call remove_parameter(param) to remove input from the list
219  /// * check if any parameter indexes are saved/used somewhere, update it for all inputs
220  /// because indexes can be changed
221  /// * call graph validation to check all changes
222  ///
223  /// \param param Parameter node to delete
224  void remove_parameter(const std::shared_ptr<op::Parameter>& param);
225 
226  /// \brief Add new variables to the list. Method doesn't validate graph, it should be done
227  /// manually after all changes.
228  /// \param variables new variables to add
229  void add_variables(const VariableVector& variables);
230 
231  /// \brief Delete variable from the list of variables.
232  /// Method doesn't delete nodes that used this variable from the graph.
233  /// \param variable Variable to delete
234  void remove_variable(const VariablePtr& variable);
235 
236  /// \brief Return a list of function's variables.
237  const VariableVector& get_variables() const { return m_variables; }
238 
239  /// \brief Return a variable by specified variable_id.
240  VariablePtr get_variable_by_id(const std::string& variable_id) const;
241 
242  private:
243  Function(const Function&) = delete;
244  Function(const Function&&) = delete;
245  Function& operator=(const Function&) = delete;
246 
247  /// \brief Depending on the options selected,
248  /// checks all the Parameter/Variables are registered in the list of Function
249  /// parameters/variables or finds all Parameters/Variables in a function and registers them.
250  /// \param detect_variables If this flag is true, then it finds all Variables in a function
251  /// and registers them, otherwise checks all the Variables are registered.
252  /// \param detect_parameters If this flag is true, then it finds all Parameters in a
253  /// function and registers them, otherwise checks all the Parameters are registered.
254  void prerequirements(bool detect_variables, bool detect_parameters);
255 
256  static std::atomic<size_t> m_next_instance_id;
257  std::string m_name;
258  const std::string m_unique_name;
259  size_t m_placement{0};
260  topological_sort_t m_topological_sorter;
261 
262  ResultVector m_results;
263  // List of the nodes with side effect in graph.
264  // These nodes are not outputs of graph but should not be removed even if have no children.
265  SinkVector m_sinks;
266  ParameterVector m_parameters;
267  VariableVector m_variables;
268  };
269 
270  template <>
271  class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>>
272  : public DirectValueAccessor<std::shared_ptr<Function>>
273  {
274  public:
275  AttributeAdapter(std::shared_ptr<Function>& value)
277  {
278  }
279 
280  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>",
281  0};
282  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
283  };
284 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Definition: attribute_adapter.hpp:67
A user-defined function.
Definition: function.hpp:27
void remove_sink(const std::shared_ptr< op::Sink > &sink)
Delete sink node from the list of sinks. Method doesn't delete node from graph.
void set_friendly_name(const std::string &name)
Sets a friendly name for a function. This does not overwrite the unique name of the function and is retrieved via get_friendly_name(). Used mainly for debugging.
const SinkVector & get_sinks() const
Return a list of function's sinks.
Definition: function.hpp:179
void add_variables(const VariableVector &variables)
Add new variables to the list. Method doesn't validate graph, it should be done manually after all changes.
bool is_dynamic() const
Returns true if any of the op's defined in the function contains partial shape.
Function(const OutputVector &results, const SinkVector &sinks, const std::string &name="")
size_t get_graph_size() const
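Returns the sum of the size of all nodes in the graph plus the size of all constant data. This has little value beyond comparing the relative size of graphs and should not be considered the actual memory consumption of a graph.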
void add_sinks(const SinkVector &sinks)
Add new sink nodes to the list. Method doesn't validate graph, it should be done manually after all changes.
VariablePtr get_variable_by_id(const std::string &variable_id) const
Return a variable by specified variable_id.
size_t get_output_size() const
Return the number of outputs for this function.
const std::string & get_name() const
Get the unique name of the function.
void add_results(const ResultVector &results)
Add new Result nodes to the list. Method doesn't validate graph, it should be done manually after all changes.
void remove_variable(const VariablePtr &variable)
Delete variable from the list of variables. Method doesn't delete nodes that used this variable from the graph.
const VariableVector & get_variables() const
Return a list of function's variables.
Definition: function.hpp:237
const ParameterVector & get_parameters() const
Return the function parameters.
Definition: function.hpp:160
int64_t get_parameter_index(const std::shared_ptr< op::Parameter > &parameter) const
Index for parameter, or -1.
std::shared_ptr< Node > get_output_op(size_t i) const
Return the op that generates output i.
const element::Type & get_output_element_type(size_t i) const
Return the element type of output i.
std::shared_ptr< Node > get_result() const
Check that there is a single result and return it.
void add_parameters(const ParameterVector &params)
Add new Parameter nodes to the list.
const std::string & get_friendly_name() const
Gets the friendly name for a function. If no friendly name has been set via set_friendly_name then the function's unique name is returned.
const PartialShape & get_output_partial_shape(size_t i) const
Return the partial shape of element i.
const Shape & get_output_shape(size_t i) const
Return the shape of element i.
bool evaluate(const HostTensorVector &output_tensors, const HostTensorVector &input_tensors, EvaluationContext evaluation_context=EvaluationContext()) const
Evaluate the function on inputs, putting results in outputs.
int64_t get_result_index(const Output< Node > &value) const
Index for value or result referencing it, or -1.
void remove_parameter(const std::shared_ptr< op::Parameter > &param)
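Delete Parameter node from the list of parameters. Method will not delete node from graph. You need to replace Parameter with other operation manually. Attention: Indexing of parameters can be changed.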
void replace_parameter(size_t parameter_index, const std::shared_ptr< op::Parameter > &parameter)
Replace the parameter_indexth parameter of the function with parameter.
Function(const OutputVector &results, const std::string &name="")
void remove_result(const std::shared_ptr< op::Result > &result)
Delete Result node from the list of results. Method will not delete node from graph.
const ResultVector & get_results() const
Return a list of function's outputs.
Definition: function.hpp:162
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: element_type.hpp:51
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
NGRAPH_API void replace_node(std::shared_ptr< Node > target, std::shared_ptr< Node > replacement, const std::vector< int64_t > &output_order)
Replace the node target with the node replacement, i.e., redirect all users and control dependencies of target to replacement.
std::map< std::string, std::shared_ptr< Variant > > EvaluationContext
Definition: node.hpp:63
Definition: type.hpp:27