node.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cstddef>
20 #include <string>
21 
22 #include "ngraph/except.hpp"
23 #include "ngraph/node.hpp"
24 #include "onnx_import/utils/onnx_importer_visibility.hpp"
25 
// Only a declaration of the ONNX protobuf message type is needed: this header
// refers to NodeProto solely by const reference, so its definition can stay out.
namespace ONNX_NAMESPACE
{
    class NodeProto;
} // namespace ONNX_NAMESPACE
31 
32 namespace ngraph
33 {
34  namespace onnx_import
35  {
36  namespace error
37  {
38  namespace node
39  {
41  {
42  explicit UnknownAttribute(const std::string& node, const std::string& name)
43  : ngraph_error{"Node (" + node + "): unknown attribute \'" + name + "\'"}
44  {
45  }
46  };
47 
48  } // namespace node
49 
50  } // namespace error
51 
// Forward declarations. These types appear in this header only in declarations
// (by const reference or as declared return types), so their definitions are not
// required here.
class Tensor;
class Subgraph;
class Graph;
56 
57  class ONNX_IMPORTER_API Node
58  {
59  public:
60  Node() = delete;
61  Node(const ONNX_NAMESPACE::NodeProto& node_proto, const Graph& graph);
62 
63  Node(Node&&) noexcept;
64  Node(const Node&);
65 
66  Node& operator=(Node&&) noexcept = delete;
67  Node& operator=(const Node&) = delete;
68 
69  OutputVector get_ng_inputs() const;
70  OutputVector get_ng_nodes() const;
71  const std::string& domain() const;
72  const std::string& op_type() const;
73  const std::string& get_name() const;
74 
75  /// \brief Describe the ONNX Node to make debugging graphs easier
76  /// Function will return the Node's name if it has one, or the names of its outputs.
77  /// \return Description of Node
78  const std::string& get_description() const;
79 
80  const std::vector<std::reference_wrapper<const std::string>>& get_output_names() const;
81  const std::string& output(int index) const;
82  std::size_t get_outputs_size() const;
83 
84  bool has_attribute(const std::string& name) const;
85 
86  template <typename T>
87  T get_attribute_value(const std::string& name, T default_value) const;
88 
89  template <typename T>
90  T get_attribute_value(const std::string& name) const;
91 
92  private:
93  class Impl;
94  // In this case we need custom deleter, because Impl is an incomplete
95  // type. Node's are elements of std::vector. Without custom deleter
96  // compilation fails; the compiler is unable to parameterize an allocator's
97  // default deleter due to incomple type.
98  std::unique_ptr<Impl, void (*)(Impl*)> m_pimpl;
99  };
100 
101  template <>
102  ONNX_IMPORTER_API float Node::get_attribute_value(const std::string& name,
103  float default_value) const;
104 
105  template <>
106  ONNX_IMPORTER_API double Node::get_attribute_value(const std::string& name,
107  double default_value) const;
108 
109  template <>
110  ONNX_IMPORTER_API std::int64_t Node::get_attribute_value(const std::string& name,
111  std::int64_t default_value) const;
112 
113  template <>
114  ONNX_IMPORTER_API std::string Node::get_attribute_value(const std::string& name,
115  std::string default_value) const;
116 
117  template <>
118  ONNX_IMPORTER_API Tensor Node::get_attribute_value(const std::string& name,
119  Tensor default_value) const;
120 
121  template <>
122  ONNX_IMPORTER_API Graph Node::get_attribute_value(const std::string& name,
123  Graph default_value) const;
124 
125  template <>
126  ONNX_IMPORTER_API std::vector<float>
127  Node::get_attribute_value(const std::string& name,
128  std::vector<float> default_value) const;
129 
130  template <>
131  ONNX_IMPORTER_API std::vector<double>
132  Node::get_attribute_value(const std::string& name,
133  std::vector<double> default_value) const;
134 
135  template <>
136  ONNX_IMPORTER_API std::vector<std::int64_t>
137  Node::get_attribute_value(const std::string& name,
138  std::vector<std::int64_t> default_value) const;
139 
140  template <>
141  ONNX_IMPORTER_API std::vector<std::size_t>
142  Node::get_attribute_value(const std::string& name,
143  std::vector<std::size_t> default_value) const;
144 
145  template <>
146  ONNX_IMPORTER_API std::vector<std::string>
147  Node::get_attribute_value(const std::string& name,
148  std::vector<std::string> default_value) const;
149 
150  template <>
151  ONNX_IMPORTER_API std::vector<Tensor>
152  Node::get_attribute_value(const std::string& name,
153  std::vector<Tensor> default_value) const;
154 
155  template <>
156  ONNX_IMPORTER_API std::vector<Graph>
157  Node::get_attribute_value(const std::string& name,
158  std::vector<Graph> default_value) const;
159 
160  template <>
161  ONNX_IMPORTER_API float Node::get_attribute_value(const std::string& name) const;
162 
163  template <>
164  ONNX_IMPORTER_API double Node::get_attribute_value(const std::string& name) const;
165 
166  template <>
167  ONNX_IMPORTER_API std::int64_t Node::get_attribute_value(const std::string& name) const;
168 
169  template <>
170  ONNX_IMPORTER_API std::size_t Node::get_attribute_value(const std::string& name) const;
171 
172  template <>
173  ONNX_IMPORTER_API std::string Node::get_attribute_value(const std::string& name) const;
174 
175  template <>
176  ONNX_IMPORTER_API Tensor Node::get_attribute_value(const std::string& name) const;
177 
178  template <>
179  ONNX_IMPORTER_API Subgraph Node::get_attribute_value(const std::string& name) const;
180 
181  template <>
182  ONNX_IMPORTER_API std::vector<float>
183  Node::get_attribute_value(const std::string& name) const;
184 
185  template <>
186  ONNX_IMPORTER_API std::vector<double>
187  Node::get_attribute_value(const std::string& name) const;
188 
189  template <>
190  ONNX_IMPORTER_API std::vector<std::int64_t>
191  Node::get_attribute_value(const std::string& name) const;
192 
193  template <>
194  ONNX_IMPORTER_API std::vector<std::size_t>
195  Node::get_attribute_value(const std::string& name) const;
196 
197  template <>
198  ONNX_IMPORTER_API std::vector<std::string>
199  Node::get_attribute_value(const std::string& name) const;
200 
201  template <>
202  ONNX_IMPORTER_API std::vector<Tensor>
203  Node::get_attribute_value(const std::string& name) const;
204 
205  template <>
206  ONNX_IMPORTER_API std::vector<Graph>
207  Node::get_attribute_value(const std::string& name) const;
208 
209  inline std::ostream& operator<<(std::ostream& outs, const Node& node)
210  {
211  return (outs << "<Node(" << node.op_type() << "): " << node.get_description() << ">");
212  }
213 
214  } // namespace onnx_import
215 
216 } // namespace ngraph
Base error for ngraph runtime errors.
Definition: except.hpp:28
Definition: node.hpp:58
const std::string & get_description() const
Describe the ONNX Node to make debugging graphs easier. The function returns the Node's name if it has one, or the names of its outputs.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28