function.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <atomic>
20 #include <initializer_list>
21 #include <list>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "ngraph/ngraph_visibility.hpp"
27 #include "ngraph/node.hpp"
28 #include "ngraph/op/parameter.hpp"
29 #include "ngraph/op/result.hpp"
30 
31 namespace ngraph
32 {
33  /// A user-defined function.
34  class NGRAPH_API Function
35  {
36  public:
37  static constexpr DiscreteTypeInfo type_info{"Function", 0};
38  const DiscreteTypeInfo& get_type_info() const { return type_info; }
39  Function(const NodeVector& results,
40  const ParameterVector& parameters,
41  const std::string& name = "");
42 
43  Function(const OutputVector& results,
44  const ParameterVector& parameters,
45  const std::string& name = "");
46 
47  Function(const std::shared_ptr<Node>& result,
48  const ParameterVector& parameters,
49  const std::string& name = "");
50 
51  Function(const ResultVector& results,
52  const ParameterVector& parameters,
53  const std::string& name = "");
54 
55  virtual ~Function() {}
56  /// Return the number of outputs for this function.
57  size_t get_output_size() const;
58 
59  /// Return the op that generates output i
60  std::shared_ptr<Node> get_output_op(size_t i) const;
61 
62  Output<Node> output(size_t i) const;
63 
64  /// Return the element type of output i
65  const element::Type& get_output_element_type(size_t i) const;
66 
67  /// Return the shape of element i
68  const Shape& get_output_shape(size_t i) const;
69 
70  /// Return the partial shape of element i
71  const PartialShape& get_output_partial_shape(size_t i) const;
72 
73  /// Check that there is a single result and return it.
74  std::shared_ptr<Node> get_result() const;
75 
76  /// \brief Get the unique name of the function.
77  /// \returns A const reference to the function's unique name.
78  const std::string& get_name() const;
79 
80  /// \brief Sets a friendly name for a function. This does not overwrite the unique name
81  /// of the function and is retrieved via get_friendly_name(). Used mainly for
82  /// debugging.
83  /// \param name is the friendly name to set
84  void set_friendly_name(const std::string& name);
85 
86  /// \brief Gets the friendly name for a function. If no friendly name has been set via
87  /// set_friendly_name then the function's unique name is returned.
88  /// \returns A const reference to the function's friendly name.
89  const std::string& get_friendly_name() const;
90 
91  std::vector<std::shared_ptr<Node>> get_ops() const;
92  std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
93  void map_unordered_ops(std::function<void(Node*)> f) const;
94 
95  friend std::ostream& operator<<(std::ostream&, const Function&);
96  // updates graph and m_results list
97  void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
98 
99  void validate_nodes_and_infer_types();
100 
101  /// \brief Returns the sum of the size of all nodes in the graph plus the size of
102  /// all constant data. This has little value beyond comparing the relative size of
103  /// graphs and should not be considered the actual memory consumption of a graph.
104  size_t get_graph_size() const;
105 
106  /// \brief Returns true if any of the op's defined in the function contains partial shape
107  bool is_dynamic() const;
108 
109  /// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
110  ///
111  /// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
112  /// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
113  ///
114  /// \param parameter_index The index of the parameter to replace.
115  /// \param parameter The parameter to substitute for the `parameter_index`th parameter.
116  void replace_parameter(size_t parameter_index,
117  const std::shared_ptr<op::Parameter>& parameter);
118 
119  using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
120  const std::vector<std::shared_ptr<Node>>& root_nodes)>;
121  void set_topological_sort(topological_sort_t);
122 
123  virtual bool visit_attributes(AttributeVisitor& visitor);
124 
125  /// Return the function parameters
126  const ParameterVector& get_parameters() const { return m_parameters; };
127  /// Return a list of function's outputs
128  const ResultVector& get_results() const { return m_results; };
129  /// Index for parameter, or -1
130  int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
131 
132  /// Index for value or result referencing it, or -1
133  int64_t get_result_index(const Output<Node>& value) const;
134 
135  /// \brief Evaluate the function on inputs, putting results in outputs.
136  /// \param outputs Tensors for the outputs to compute. One for each result
137  /// \param inputs Tensors for the inputs. One for each inputs.
138  bool evaluate(const HostTensorVector& output_tensors,
139  const HostTensorVector& input_tensors) const;
140 
141  private:
142  Function(const Function&) = delete;
143  Function(const Function&&) = delete;
144  Function& operator=(const Function&) = delete;
145 
146  static std::atomic<size_t> m_next_instance_id;
147  std::string m_name;
148  const std::string m_unique_name;
149  size_t m_placement{0};
150  topological_sort_t m_topological_sorter;
151 
152  ResultVector m_results;
153  ParameterVector m_parameters;
154  };
155 
156  template <>
157  class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public VisitorAdapter
158  {
159  public:
160  AttributeAdapter(std::shared_ptr<Function>& ref);
161 
162  bool visit_attributes(AttributeVisitor& visitor) override;
163 
164  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<shared_ptr<Function>>", 0};
165  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
166  protected:
167  std::shared_ptr<Function>& m_ref;
168  };
169 }
ngraph::Function::get_result_index
int64_t get_result_index(const Output< Node > &value) const
Index for value or result referencing it, or -1.
ngraph::Function::get_friendly_name
const std::string & get_friendly_name() const
Gets the friendly name for a function. If no friendly name has been set via set_friendly_name, then the function's unique name is returned.
ngraph::PartialShape
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
ngraph::Function::get_result
std::shared_ptr< Node > get_result() const
Check that there is a single result and return it.
ngraph::Function::get_output_op
std::shared_ptr< Node > get_output_op(size_t i) const
Return the op that generates output i.
ngraph::Function::get_output_shape
const Shape & get_output_shape(size_t i) const
Return the shape of output i.
ngraph::Function::evaluate
bool evaluate(const HostTensorVector &output_tensors, const HostTensorVector &input_tensors) const
Evaluate the function on inputs, putting results in outputs.
ngraph::Function::get_graph_size
size_t get_graph_size() const
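Returns the sum of the size of all nodes in the graph plus the size of all constant data. This has little value beyond comparing the relative size of graphs and should not be considered the actual memory consumption of a graph.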
ngraph::Function::get_output_size
size_t get_output_size() const
Return the number of outputs for this function.
ngraph::Output< Node >
A handle for one of a node's outputs.
Definition: node_output.hpp:41
ngraph::replace_node
NGRAPH_API void replace_node(std::shared_ptr< Node > target, std::shared_ptr< Node > replacement, const std::vector< int64_t > &output_order)
Replace the node target with the node replacement, i.e., redirect all users and control dependencies ...
ngraph::Function::get_output_partial_shape
const PartialShape & get_output_partial_shape(size_t i) const
Return the partial shape of output i.
ngraph::DiscreteTypeInfo
Definition: type.hpp:39
ngraph::Shape
Shape for a tensor.
Definition: shape.hpp:31
ngraph::Function::get_parameters
const ParameterVector & get_parameters() const
Return the function parameters.
Definition: function.hpp:126
ngraph::Function::get_output_element_type
const element::Type & get_output_element_type(size_t i) const
Return the element type of output i.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::Function::get_parameter_index
int64_t get_parameter_index(const std::shared_ptr< op::Parameter > &parameter) const
Index for parameter, or -1.
ngraph::Function::get_results
const ResultVector & get_results() const
Return a list of the function's outputs.
Definition: function.hpp:128
ngraph::AttributeAdapter< std::shared_ptr< Function > >::get_type_info
const DiscreteTypeInfo & get_type_info() const override
Type info enables identification of the value accessor, as well as is_type and as_type.
Definition: function.hpp:165
ngraph::Function::replace_parameter
void replace_parameter(size_t parameter_index, const std::shared_ptr< op::Parameter > &parameter)
Replace the parameter at position parameter_index of the function with parameter.
ngraph::Function::is_dynamic
bool is_dynamic() const
Returns true if any of the ops defined in the function has a partial (dynamic) shape.
ngraph::Function::set_friendly_name
void set_friendly_name(const std::string &name)
Sets a friendly name for a function. This does not overwrite the unique name of the function and is retrieved via get_friendly_name(). Used mainly for debugging.
ngraph::Function
A user-defined function.
Definition: function.hpp:35
ngraph::Node
Definition: node.hpp:131
ngraph::Function::get_name
const std::string & get_name() const
Get the unique name of the function.