node.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
#include <atomic>
#include <cstring>
#include <deque>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/check.hpp"
#include "ngraph/coordinate.hpp"
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/deprecated.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node_input.hpp"
#include "ngraph/node_output.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/output_vector.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type.hpp"
48 
49 namespace ngraph
50 {
51  template <typename NodeType>
52  class Input;
53 
54  template <typename NodeType>
55  class Output;
56 
57  class AttributeVisitor;
58  class Variant;
59  class Node;
60 
61  class Function;
62 
63  namespace runtime
64  {
65  class HostTensor;
66  }
67  using HostTensor = runtime::HostTensor;
68  using HostTensorPtr = std::shared_ptr<HostTensor>;
69  using HostTensorVector = std::vector<HostTensorPtr>;
70 
71  namespace op
72  {
73  struct AutoBroadcastSpec;
74 
75  namespace v0
76  {
77  class Result;
78  }
79  } // namespace op
80 
81  namespace pattern
82  {
83  class Matcher;
84  }
85 
86  using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
87 
88  NGRAPH_API
89  std::string node_validation_failure_loc_string(const Node* node);
90 
91  const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
92  size_t i);
93  NGRAPH_API
94  const NodeVector& check_single_output_args(const NodeVector& args);
95 
96  const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
97  size_t i);
98 
99  NGRAPH_API
100  OutputVector as_output_vector(const NodeVector& args);
101  NGRAPH_API
102  NodeVector as_node_vector(const OutputVector& values);
103  /// Returns a ResultVector referencing values.
104  NGRAPH_API
105  ResultVector as_result_vector(const OutputVector& values);
106 
107  /// Alias useful for cloning
108  using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
109 
/// \brief Emits one `case` label of an evaluator switch so that the label's element type and
///        the template argument given to evaluate<>() can never disagree.
///
/// The expansion stops just before evaluate's call arguments. The caller appends them and
/// then a `break`:
///
///     switch (arg0->get_element_type())
///     {
///         TYPE_CASE(i8)(arg0, arg1, out, broadcast_spec);
///         break;
///         TYPE_CASE(i16)(arg0, arg1, out, broadcast_spec);
///         break;
///     }
///
/// TYPE_CASE(a)(args...) therefore becomes
///     case element::Type_t::a: rc = evaluate<element::Type_t::a>(args...)
/// The enclosing scope must declare `rc` and provide a function template `evaluate` whose
/// template parameter is an element::Type_t.
///
/// \note Leaving out the `break` makes control fall through into the next case. That case
///       then calls evaluate<> for a different element type and overwrites `rc`.
#define TYPE_CASE(a)                                                                               \
    case element::Type_t::a:                                                                       \
        rc = evaluate<element::Type_t::a>
127 
128  /// Nodes are the backbone of the graph of Value dataflow. Every node has
129  /// zero or more nodes as arguments and one value, which is either a tensor
130  /// or a (possibly empty) tuple of values.
131  class NGRAPH_API Node : public std::enable_shared_from_this<Node>
132  {
133  // For access to m_outputs.
134  friend class descriptor::Input;
135 
136  // For access to m_inputs and m_outputs.
137  template <typename NodeType>
138  friend class Input;
139 
140  // For access to m_outputs.
141  template <typename NodeType>
142  friend class Output;
143 
144  public:
145  /// \brief Verifies that attributes and inputs are consistent and computes output shapes
146  /// and element types. Must be implemented by concrete child classes so that it
147  /// can be run any number of times.
148  ///
149  /// Throws if the node is invalid.
150  virtual void validate_and_infer_types();
151 
152  // Called in constructors during transition
153  void constructor_validate_and_infer_types();
154 
156 
157  protected:
158  /// \brief Construct an unitialized Node
159  Node() = default;
160  /// \brief Copying a node
161  Node(const Node&);
162  /// \brief Assignment operator
163  Node& operator=(const Node&);
164 
165  /// \brief Construct an unitialized Node
166  /// \param output_size Number of outputs for this node
167  Node(size_t output_size);
168 
169  /// \brief Constructor for Node subclasses that have metaclasses.
170  /// \param arguments Output i will connect to input i
171  /// \param output_size Number of outputs for this node
172  Node(const OutputVector& arguments, size_t output_size = 1);
173  /// \brief Moves nodes that would be deleted from inputs to nodes to avoid stack overflows
174  /// on deep networks.
175  void safe_delete(NodeVector& nodes, bool recurse);
176 
177  /// \brief Marks an input as being relevant or irrelevant to the output shapes of this
178  /// node.
179  /// \param i The index of the input to mark as relevant or irrelevant.
180  /// \param relevant true if the input is relevant to output shapes, false otherwise.
181  ///
182  /// This is used by the shape specialization pass to know which nodes must be statically
183  /// evaluated in order to complete shape specialization. (For example, the shape input of
184  /// DynReshape must be evaluated statically in order for the output shape to be
185  /// determined.) By default, all inputs are marked as shape-irrelevant. Overrides of
186  /// validate_and_infer_types should call this function to mark shape-relevant inputs.
187  void set_input_is_relevant_to_shape(size_t i, bool relevant = true);
188 
189  /// \brief Marks an input as being relevant or irrelevant to the output values of this
190  /// node.
191  /// \param i The index of the input to mark as relevant or irrelevant.
192  /// \param relevant true if the input is relevant to output values, false otherwise.
193  ///
194  /// This is used by the shape specialization pass to cut short evaluation in cases where
195  /// an input value does not actually have any effect on the output value of the node. (As
196  /// of this writing, the only example of this is ShapeOf.) By default, all inputs are
197  /// marked as value-relevant. Overrides of validate_and_infer_types should call this
198  /// function to mark value-irrelevant inputs.
199  void set_input_is_relevant_to_value(size_t i, bool relevant = true);
200 
201  public:
202  virtual ~Node();
203 
204  virtual bool visit_attributes(AttributeVisitor&) { return false; }
205  /// \returns the autobroadcasr spec
206  virtual const op::AutoBroadcastSpec& get_autob() const;
207  /// \brief Evaluates the op on input_values putting results in output_values
208  /// \returns true if successful
209  virtual bool evaluate(const HostTensorVector& output_values,
210  const HostTensorVector& input_values) const;
211  virtual bool evaluate_lower(const HostTensorVector& output_values) const;
212  virtual bool evaluate_upper(const HostTensorVector& output_values) const;
213 
214  virtual bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values);
215  /// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
216  ///
217  /// \return A vector of nodes comprising the sub-graph. The order of output
218  /// tensors must match the match output tensors of the FusedOp
219  virtual OutputVector decompose_op() const { return OutputVector(); }
220  /// Returns the NodeTypeInfo for the node's class.
221  /// During transition to type_info, returns a dummy type_info for Node if the class
222  /// has not been updated yet.
223  virtual const type_info_t& get_type_info() const = 0;
224  const char* get_type_name() const { return get_type_info().name; }
225  /// Sets/replaces the arguments with new arguments.
226  void set_arguments(const NodeVector& arguments);
227  /// Sets/replaces the arguments with new arguments.
228  void set_arguments(const OutputVector& arguments);
229  /// Sets/replaces the arguments with new arguments.
230  void set_argument(size_t position, const Output<Node>& argument);
231 
232  void set_output_type(size_t i,
233  const element::Type& element_type,
234  const PartialShape& pshape);
235 
236  /// Sets the number of outputs
237  void set_output_size(size_t output_size);
238 
239  void invalidate_values();
240  virtual void revalidate_and_infer_types()
241  {
242  invalidate_values();
243  validate_and_infer_types();
244  }
245  /// \brief Get the string name for the type of the node, such as `Add` or `Multiply`.
246  /// The class name, must not contain spaces as it is used for codegen.
247  /// \returns A const reference to the node's type name
248  virtual std::string description() const;
249  /// \brief Get the unique name of the node.
250  /// \returns A const reference to the node's unique name.
251  const std::string& get_name() const;
252 
253  /// \brief Sets a friendly name for a node. This does not overwrite the unique name
254  /// of the node and is retrieved via get_friendly_name(). Used mainly for debugging.
255  /// The friendly name may be set exactly once.
256  /// \param name is the friendly name to set
257  void set_friendly_name(const std::string& name);
258 
259  /// \brief Gets the friendly name for a node. If no friendly name has been set via
260  /// set_friendly_name then the node's unique name is returned.
261  /// \returns A const reference to the node's friendly name.
262  const std::string& get_friendly_name() const;
263 
264  virtual bool is_dynamic() const;
265  size_t get_instance_id() const { return m_instance_id; }
266  /// \brief Writes a description of a node to a stream
267  /// \param os The stream; should be returned
268  /// \param depth How many levels of inputs to describe
269  /// \returns The stream os
270  virtual std::ostream& write_description(std::ostream& os, uint32_t depth = 0) const;
271 
272  /// Get control dependencies registered on the node
273  const std::vector<std::shared_ptr<Node>>& get_control_dependencies() const;
274 
275  /// Get nodes dependent on this node
276  const std::vector<Node*>& get_control_dependents() const;
277 
278  /// This node cannot execute until node executes
279  void add_control_dependency(std::shared_ptr<Node> node);
280 
281  /// Remove the dependency of this node on node
282  void remove_control_dependency(std::shared_ptr<Node> node);
283 
284  /// Remove all dependencies from this node
286 
287  /// Remove this node as a dependency from all dependent nodes
289 
290  /// This node absorbs the control dependencies of source_node
291  void add_node_control_dependencies(std::shared_ptr<Node> source_node);
292 
293  /// This node becomes a dependent of every node dependent on source_node
294  void add_node_control_dependents(std::shared_ptr<Node> source_node);
295 
296  /// This node's control dependencies are replaced by replacement
297  void transfer_control_dependents(std::shared_ptr<Node> replacement);
298 
299  /// Returns the number of outputs from the node.
300  size_t get_output_size() const;
301 
302  /// Returns the element type for output i
303  const element::Type& get_output_element_type(size_t i) const;
304 
305  /// Checks that there is exactly one output and returns its element type
306  // TODO: deprecate in favor of node->get_output_element_type(0) with a suitable check in
307  // the calling code, or updates to the calling code if it is making an invalid assumption
308  // of only one output.
310 
311  /// Returns the shape for output i
312  const Shape& get_output_shape(size_t i) const;
313 
314  /// Returns the partial shape for output i
315  const PartialShape& get_output_partial_shape(size_t i) const;
316 
317  /// Return the output to use when converting to an Output<Node> with no index specified.
318  /// Throws when not supported.
320  Output<Node> get_default_output();
321 
322  /// Returns the output of the default output, or throws if there is none
323  virtual size_t get_default_output_index() const;
324  /// Throws no default
325  size_t no_default_index() const;
326 
327  /// Checks that there is exactly one output and returns its shape
328  // TODO: deprecate in favor of node->get_output_shape(0) with a suitable check in the
329  // calling code, or updates to the calling code if it is making an invalid assumption of
330  // only one output.
331  const Shape& get_shape() const;
332 
333  /// Returns the tensor for output or input i
335  descriptor::Tensor& get_input_tensor(size_t i) const;
336 
337  /// Returns the tensor name for output i
338  NGRAPH_DEPRECATED(
339  "The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
340  const std::string& get_output_tensor_name(size_t i) const;
341 
342  std::set<Input<Node>> get_output_target_inputs(size_t i) const;
343 
344  /// Returns the number of inputs for the op
345  size_t get_input_size() const;
346 
347  /// Returns the element type of input i
348  // TODO: deprecate in favor of node->get_input_element_type(i)
349  const element::Type& get_input_element_type(size_t i) const;
350 
351  /// Returns the shape of input i
352  // TODO: deprecate in favor of node->get_input_shape(i)
353  const Shape& get_input_shape(size_t i) const;
354 
355  /// Returns the partial shape of input i
356  // TODO: deprecate in favor of node->get_input_partial_shape(i)
357  const PartialShape& get_input_partial_shape(size_t i) const;
358 
359  /// Returns the tensor name for input i
360  NGRAPH_DEPRECATED(
361  "The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
362  const std::string& get_input_tensor_name(size_t i) const;
363 
364  std::unordered_set<descriptor::Tensor*> liveness_new_list;
365  std::unordered_set<descriptor::Tensor*> liveness_free_list;
366 
367  Node* get_input_node_ptr(size_t index) const;
368  std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
369  Output<Node> get_input_source_output(size_t i) const;
370 
371  public:
372  virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const = 0;
373 
374  std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
375 
376  std::shared_ptr<Node> copy_with_new_inputs(
377  const OutputVector& inputs,
378  const std::vector<std::shared_ptr<Node>>& control_dependencies) const;
379 
380  /// True if this and node have one output with same element type and shape
381  bool has_same_type(std::shared_ptr<const Node> node) const;
382 
383  using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
384 
385  RTMap& get_rt_info() { return m_rt_info; }
386  const RTMap& get_rt_info() const { return m_rt_info; }
387  const std::unordered_set<std::string>& get_provenance_tags() const;
388  void add_provenance_tag(const std::string& tag);
389  template <typename T>
390  void add_provenance_tags(T tag_set)
391  {
392  for (auto tag : tag_set)
393  {
394  add_provenance_tag(tag);
395  }
396  }
397  /// \brief Adds tag_set to this node and all intermediate nodes above base
398  void add_provenance_tags_above(const OutputVector& base,
399  const std::unordered_set<std::string>& tag_set);
400  void remove_provenance_tag(const std::string& tag);
401  /// \brief Add node to additional nodes that receive tags
402  void add_provenance_group_member(const std::shared_ptr<Node>& node);
403  /// \brief Remove node to additional nodes that receive tags
404  void remove_provenance_group_member(const std::shared_ptr<Node>& node);
405  /// \brief Replace current_node with replacement_node and transfer tags
406  void replace_provenance_group_member(const std::shared_ptr<Node>& current_node,
407  const std::shared_ptr<Node>& replacement_node);
408  /// \return Provenance group nodes
409  const std::set<std::shared_ptr<Node>>& get_provenance_group_members() const;
410 
411  /// \brief Add all nodes between this node and nodes in base as additional nodes to receive
412  /// provenance tags.
413  std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
414 
415  // to be used when nodes are replaced
416  void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
417 
418  /// Transfer provenance tags to replacement
419  void transfer_provenance_tags(const std::shared_ptr<Node>& replacement);
420 
421  /// Get all the nodes that uses the current node
422  NodeVector get_users(bool check_is_used = false) const;
423 
424  /// \return Version of this node
425  virtual size_t get_version() const { return get_type_info().version; }
426  virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
427  /// Use instance ids for comparison instead of memory addresses to improve determinism
428  bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
429  /// \return A vector containing a handle for each of this node's inputs, in order.
430  // TODO: Rename to get_inputs()?
431  std::vector<Input<Node>> inputs();
432 
433  /// \return A vector containing a handle for each of this node's inputs, in order.
434  std::vector<Input<const Node>> inputs() const;
435 
436  /// \return A vector containing the values for each input
437  std::vector<Output<Node>> input_values() const;
438 
439  /// \return A vector containing a handle for each of this node's outputs, in order.
440  // TODO: Rename to get_outputs()?
441  std::vector<Output<Node>> outputs();
442 
443  /// \return A vector containing a handle for each of this node's outputs, in order.
444  std::vector<Output<const Node>> outputs() const;
445 
446  /// \return A handle to the `input_index`th input of this node.
447  /// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
448  Input<Node> input(size_t input_index);
449 
450  /// \return A handle to the `input_index`th input of this node.
451  /// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
452  Input<const Node> input(size_t input_index) const;
453 
454  Output<Node> input_value(size_t input_index) const;
455 
456  /// \return A handle to the `output_index`th output of this node.
457  /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
458  Output<Node> output(size_t output_index);
459 
460  /// \return A handle to the `output_index`th output of this node.
461  /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
462  Output<const Node> output(size_t output_index) const;
463 
464  void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations)
465  {
466  m_op_annotations = op_annotations;
467  }
468  std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const
469  {
470  return m_op_annotations;
471  }
472 
473  virtual bool match_value(pattern::Matcher* matcher,
474  const Output<Node>& pattern_value,
475  const Output<Node>& graph_value);
476 
477  virtual bool match_node(pattern::Matcher* matcher, const Output<Node>& graph_value);
478 
479  private:
480  descriptor::Input& get_input_descriptor(size_t position);
481  descriptor::Output& get_output_descriptor(size_t position);
482 
483  std::vector<Node*> m_control_dependents;
484  std::vector<std::shared_ptr<Node>> m_control_dependencies;
485  std::string m_node_type;
486  size_t m_instance_id{m_next_instance_id.fetch_add(1)};
487  std::string m_friendly_name;
488  std::string m_unique_name;
489  static std::atomic<size_t> m_next_instance_id;
490  std::unordered_set<std::string> m_provenance_tags;
491  std::set<std::shared_ptr<Node>> m_provenance_group;
492  std::deque<descriptor::Input> m_inputs;
493  std::deque<descriptor::Output> m_outputs;
494  std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
495  std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
496  };
497 
498  using NodeTypeInfo = Node::type_info_t;
499 
500  NGRAPH_API std::ostream& operator<<(std::ostream&, const Node&);
501  NGRAPH_API std::ostream& operator<<(std::ostream&, const Node*);
502 
// Forces one extra expansion pass over its argument. NGRAPH_RTTI_DEFINITION wraps the
// selected macro call in it. NOTE(review): presumably a workaround for MSVC's traditional
// preprocessor, which forwards __VA_ARGS__ as a single argument - confirm before removing.
#define _NGRAPH_RTTI_EXPAND(X) X

/// Helper macro that places the necessary RTTI declarations inside a class definition.
/// Use it in the scope of a class that needs type identification beyond what C++ RTTI
/// provides.
/// It is recommended for all classes derived from ngraph::Node, because it enables pattern
/// matching for them. The macro itself takes no arguments. The identification details (type
/// name, version and an optional parent class) are supplied through NGRAPH_RTTI_DEFINITION.
///
/// Applying this macro within a class definition declares:
///  - the static constant type_info, kept for backward compatibility with the old RTTI
///    definition for Node;
///  - the static function get_type_info_static, which returns a reference to an object equal
///    to type_info, but not necessarily the same object;
///  - the virtual function get_type_info, which overrides Node::get_type_info and returns a
///    reference to the same object that get_type_info_static returns.
///
/// Use this macro in a public part of the class definition:
///
///     class MyOp : public Node
///     {
///     public:
///         // Don't use Node as a parent for type_info: it adds no information and is
///         // prohibited.
///         NGRAPH_RTTI_DECLARATION;
///
///         ...
///     };
///
///     class MyInheritedOp : public MyOp
///     {
///     public:
///         NGRAPH_RTTI_DECLARATION;
///
///         ...
///     };
///
/// To complete type identification for a class, use NGRAPH_RTTI_DEFINITION.
///
#define NGRAPH_RTTI_DECLARATION                                                                    \
    static const ::ngraph::Node::type_info_t type_info;                                            \
    const ::ngraph::Node::type_info_t& get_type_info() const override;                             \
    static const ::ngraph::Node::type_info_t& get_type_info_static()
546 
// Defines CLASS::get_type_info(), which forwards to get_type_info_static(), and the legacy
// static member CLASS::type_info, which is copy-initialized from get_type_info_static().
// The expansion ends without a semicolon; the caller supplies it.
#define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS)                                                      \
    const ::ngraph::Node::type_info_t& CLASS::get_type_info() const                                \
    {                                                                                              \
        return get_type_info_static();                                                             \
    }                                                                                              \
    const ::ngraph::Node::type_info_t CLASS::type_info = CLASS::get_type_info_static()

// Defines CLASS::get_type_info_static(), which returns a function-local static descriptor.
// The descriptor's third initializer points at PARENT_CLASS's descriptor (the parent link).
// The common definitions follow.
// Parameter names avoid the reserved _Uppercase spelling.
#define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS, TYPE_NAME, VERSION_INDEX, PARENT_CLASS)        \
    const ::ngraph::Node::type_info_t& CLASS::get_type_info_static()                               \
    {                                                                                              \
        static const ::ngraph::Node::type_info_t type_info_static{                                 \
            TYPE_NAME, VERSION_INDEX, &PARENT_CLASS::get_type_info_static()};                      \
        return type_info_static;                                                                   \
    }                                                                                              \
    _NGRAPH_RTTI_DEFINITION_COMMON(CLASS)

// Same as the WITH_PARENT form, but the descriptor is built from the name and version only.
#define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS, TYPE_NAME, VERSION_INDEX)                         \
    const ::ngraph::Node::type_info_t& CLASS::get_type_info_static()                               \
    {                                                                                              \
        static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, VERSION_INDEX};       \
        return type_info_static;                                                                   \
    }                                                                                              \
    _NGRAPH_RTTI_DEFINITION_COMMON(CLASS)

// Yields its fifth argument. NGRAPH_RTTI_DEFINITION appends the WITH_PARENT and NO_PARENT
// macro names after the user's arguments:
//  - 4 user arguments shift the list so that WITH_PARENT lands in fifth place;
//  - 3 user arguments leave NO_PARENT in fifth place.
#define _NGRAPH_RTTI_DEFINITION_SELECTOR(ARG1, ARG2, ARG3, ARG4, NAME, ...) NAME
571 
/// Complementary to NGRAPH_RTTI_DECLARATION, this helper macro _defines_ the items _declared_
/// by NGRAPH_RTTI_DECLARATION.
/// Use it outside the class definition scope, in a place where the ODR is ensured. The
/// definitions it produces are not inline, so expand it in exactly one translation unit.
///
/// \param CLASS is the C++ name of the class where the corresponding NGRAPH_RTTI_DECLARATION
///        was applied.
/// \param TYPE_NAME a string literal of type const char* that names your class in the type
///        identification namespace. You may choose any name, but it should be unique among
///        all NGRAPH_RTTI_DECLARATION-enabled classes that can be used together in one
///        transformation flow.
/// \param VERSION_INDEX is an unsigned integer index that distinguishes different versions of
///        operations that share the same TYPE_NAME.
/// \param PARENT_CLASS is an optional direct or indirect parent class of this class. Define it
///        only when you need to capture any operation from a group of operations that all
///        derive from a common base class. Don't use Node as a parent: it is the base class of
///        all operations and cannot describe a proper subset of them. PARENT_CLASS must itself
///        define RTTI with the NGRAPH_RTTI_{DECLARATION/DEFINITION} macros.
///
/// The expansion does not end with a semicolon, so terminate every use with one.
///
/// Examples (see the corresponding declarations in the NGRAPH_RTTI_DECLARATION description):
///
///     NGRAPH_RTTI_DEFINITION(MyOp, "MyOp", 1);
///     NGRAPH_RTTI_DEFINITION(MyInheritedOp, "MyInheritedOp", 1, MyOp);
///
/// For convenience, it is recommended that TYPE_NAME and CLASS be the same.
///
#define NGRAPH_RTTI_DEFINITION(...)                                                                \
    _NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR(                                          \
        __VA_ARGS__, _NGRAPH_RTTI_DEFINITION_WITH_PARENT, _NGRAPH_RTTI_DEFINITION_NO_PARENT)(      \
        __VA_ARGS__))
603 
604  // Like an Output but with a Node* instead of a shared_ptr<Node>
606  {
607  RawNodeOutput(const Output<Node>& value)
608  : node(value.get_node())
609  , index(value.get_index())
610  {
611  }
612  RawNodeOutput(Node* node, size_t index)
613  : node(node)
614  , index(index)
615  {
616  }
617  RawNodeOutput(const RawNodeOutput&) = default;
618  RawNodeOutput() = default;
619  RawNodeOutput& operator=(const RawNodeOutput&) = default;
620 
621  Node* node;
622  size_t index{0};
623 
624  operator Output<Node>() { return Output<Node>(node->shared_from_this(), index); }
625  bool operator==(const RawNodeOutput& other) const
626  {
627  return node == other.node && index == other.index;
628  }
629  bool operator!=(const RawNodeOutput& other) const { return !(*this == other); }
630  bool operator<(const RawNodeOutput& other) const
631  {
632  return node < other.node || (node == other.node && index < other.index);
633  }
634  bool operator>(const RawNodeOutput& other) const
635  {
636  return node > other.node || (node == other.node && index > other.index);
637  }
638  bool operator<=(const RawNodeOutput& other) const { return !(*this > other); }
639  bool operator>=(const RawNodeOutput& other) const { return !(*this < other); }
640  };
641 
642  /// \brief Visits a reference to a node that has been registered with the visitor.
643  template <>
644  class NGRAPH_API AttributeAdapter<std::shared_ptr<Node>> : public VisitorAdapter
645  {
646  public:
647  AttributeAdapter(std::shared_ptr<Node>& value);
648 
649  bool visit_attributes(AttributeVisitor& visitor) override;
650  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Node>>", 0};
651  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
652  protected:
653  std::shared_ptr<Node>& m_ref;
654  };
655 
656  template <>
657  class NGRAPH_API AttributeAdapter<NodeVector> : public VisitorAdapter
658  {
659  public:
660  AttributeAdapter(NodeVector& ref);
661 
662  bool visit_attributes(AttributeVisitor& visitor) override;
663 
664  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<NodeVector>", 0};
665  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
666  protected:
667  NodeVector& m_ref;
668  };
669 
670  using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
671 
672  class NGRAPH_API NodeValidationFailure : public CheckFailure
673  {
674  public:
675  NodeValidationFailure(const CheckLocInfo& check_loc_info,
676  const Node* node,
677  const std::string& explanation)
678  : CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation)
679  {
680  }
681  };
682 }
// Node-flavoured check: forwards to NGRAPH_CHECK_HELPER with ::ngraph::NodeValidationFailure
// as the failure type and `node` as its context. As check_new_args_count() uses it, the first
// variadic argument is the condition and the rest presumably form the explanation message.
#define NODE_VALIDATION_CHECK(node, ...)                                                           \
    NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
685 
686 namespace ngraph
687 {
688  template <typename T>
689  void check_new_args_count(const Node* node, T new_args)
690  {
691  NODE_VALIDATION_CHECK(node,
692  new_args.size() == node->input_values().size(),
693  "clone_with_new_inputs() expected ",
694  node->input_values().size(),
695  " argument",
696  (node->input_values().size() == 1 ? "" : "s"),
697  " but got ",
698  new_args.size());
699  }
700 
701 } // namespace ngraph
const DiscreteTypeInfo & get_type_info() const override
type info enables identification of the value accessor, as well as is_type and as_type.
Definition: node.hpp:665
const DiscreteTypeInfo & get_type_info() const override
type info enables identification of the value accessor, as well as is_type and as_type.
Definition: node.hpp:651
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Base class for check failure exceptions.
Definition: check.hpp:43
A handle for one of a node's inputs.
Definition: node_input.hpp:41
A handle for one of a node's inputs.
Definition: node_input.hpp:88
Definition: node_input.hpp:35
Definition: node.hpp:673
Definition: node.hpp:132
Node()=default
Construct an uninitialized Node.
void clear_control_dependencies()
Remove all dependencies from this node.
const Shape & get_shape() const
Checks that there is exactly one output and returns its shape.
void set_arguments(const OutputVector &arguments)
Sets/replaces the arguments with new arguments.
void add_node_control_dependencies(std::shared_ptr< Node > source_node)
This node absorbs the control dependencies of source_node.
void remove_provenance_group_member(const std::shared_ptr< Node > &node)
Remove node from the additional nodes that receive tags.
std::vector< Input< Node > > inputs()
Node(size_t output_size)
Construct an uninitialized Node.
std::vector< Output< const Node > > outputs() const
Output< const Node > output(size_t output_index) const
NodeVector get_users(bool check_is_used=false) const
Get all the nodes that use the current node.
const std::vector< std::shared_ptr< Node > > & get_control_dependencies() const
Get control dependencies registered on the node.
virtual void validate_and_infer_types()
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Input< const Node > input(size_t input_index) const
void add_provenance_tags_above(const OutputVector &base, const std::unordered_set< std::string > &tag_set)
Adds tag_set to this node and all intermediate nodes above base.
void add_provenance_group_member(const std::shared_ptr< Node > &node)
Add node to additional nodes that receive tags.
virtual OutputVector decompose_op() const
Decomposes the FusedOp into a sub-graph consisting of core ngraph ops.
Definition: node.hpp:219
descriptor::Tensor & get_output_tensor(size_t i) const
Returns the tensor for output i.
Node(const OutputVector &arguments, size_t output_size=1)
Constructor for Node subclasses that have metaclasses.
std::shared_ptr< Node > add_provenance_group_members_above(const OutputVector &base)
Add all nodes between this node and nodes in base as additional nodes to receive provenance tags.
size_t get_output_size() const
Returns the number of outputs from the node.
void set_arguments(const NodeVector &arguments)
Sets/replaces the arguments with new arguments.
Input< Node > input(size_t input_index)
const std::string & get_friendly_name() const
Gets the friendly name for a node. If no friendly name has been set via set_friendly_name then the no...
virtual bool evaluate(const HostTensorVector &output_values, const HostTensorVector &input_values) const
Evaluates the op on input_values putting results in output_values.
void set_input_is_relevant_to_shape(size_t i, bool relevant=true)
Marks an input as being relevant or irrelevant to the output shapes of this node.
void replace_provenance_group_member(const std::shared_ptr< Node > &current_node, const std::shared_ptr< Node > &replacement_node)
Replace current_node with replacement_node and transfer tags.
virtual const op::AutoBroadcastSpec & get_autob() const
virtual size_t get_default_output_index() const
Returns the index of the default output, or throws if there is none.
virtual const type_info_t & get_type_info() const =0
Node(const Node &)
Copying a node.
const element::Type & get_element_type() const
Checks that there is exactly one output and returns its element type.
Output< Node > output(size_t output_index)
void add_control_dependency(std::shared_ptr< Node > node)
This node cannot execute until node executes.
std::vector< Input< const Node > > inputs() const
void clear_control_dependents()
Remove this node as a dependency from all dependent nodes.
Node & operator=(const Node &)
Assignment operator.
std::vector< Output< Node > > outputs()
void remove_control_dependency(std::shared_ptr< Node > node)
Remove the dependency of this node on node.
size_t no_default_index() const
Throws no default.
void add_node_control_dependents(std::shared_ptr< Node > source_node)
This node becomes a dependent of every node dependent on source_node.
const std::vector< Node * > & get_control_dependents() const
Get nodes dependent on this node.
virtual std::ostream & write_description(std::ostream &os, uint32_t depth=0) const
Writes a description of a node to a stream.
void set_friendly_name(const std::string &name)
Sets a friendly name for a node. This does not overwrite the unique name of the node and is retrieved...
void transfer_control_dependents(std::shared_ptr< Node > replacement)
This node's control dependents are transferred to replacement.
const std::string & get_name() const
Get the unique name of the node.
virtual size_t get_version() const
Definition: node.hpp:425
const std::set< std::shared_ptr< Node > > & get_provenance_group_members() const
const Shape & get_output_shape(size_t i) const
Returns the shape for output i.
void set_output_size(size_t output_size)
Sets the number of outputs.
void set_argument(size_t position, const Output< Node > &argument)
Sets/replaces the argument at position with a new argument.
void transfer_provenance_tags(const std::shared_ptr< Node > &replacement)
Transfer provenance tags to replacement.
const PartialShape & get_output_partial_shape(size_t i) const
Returns the partial shape for output i.
bool operator<(const Node &other) const
Use instance ids for comparison instead of memory addresses to improve determinism.
Definition: node.hpp:428
void set_input_is_relevant_to_value(size_t i, bool relevant=true)
Marks an input as being relevant or irrelevant to the output values of this node.
void safe_delete(NodeVector &nodes, bool recurse)
Moves nodes that would be deleted from inputs to nodes to avoid stack overflows on deep networks.
Output< const Node > get_default_output() const
const element::Type & get_output_element_type(size_t i) const
Returns the element type for output i.
std::vector< Output< Node > > input_values() const
virtual std::string description() const
Get the string name for the type of the node, such as Add or Multiply. The class name,...
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Node * get_node() const
size_t get_index() const
Definition: node_output.hpp:118
Definition: node_output.hpp:36
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Definition: variant.hpp:30
Adapters will see visitor.
Definition: attribute_adapter.hpp:194
Definition: input.hpp:33
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:40
Definition: element_type.hpp:61
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
NGRAPH_API ResultVector as_result_vector(const OutputVector &values)
Returns a ResultVector referencing values.
std::unordered_map< ngraph::Node *, std::shared_ptr< ngraph::Node > > NodeMap
Alias useful for cloning.
Definition: node.hpp:108
Definition: check.hpp:35
Definition: type.hpp:39
Definition: node.hpp:606
Implicit broadcast specification.
Definition: attr_types.hpp:323