node.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <atomic>
#include <cstring>
#include <deque>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
20 
21 #include "ngraph/attribute_visitor.hpp"
22 #include "ngraph/check.hpp"
23 #include "ngraph/coordinate.hpp"
24 #include "ngraph/coordinate_diff.hpp"
25 #include "ngraph/deprecated.hpp"
26 #include "ngraph/descriptor/input.hpp"
27 #include "ngraph/descriptor/output.hpp"
28 #include "ngraph/descriptor/tensor.hpp"
29 #include "ngraph/node_input.hpp"
30 #include "ngraph/node_output.hpp"
31 #include "ngraph/op/util/attr_types.hpp"
32 #include "ngraph/op/util/op_annotations.hpp"
33 #include "ngraph/op/util/variable.hpp"
34 #include "ngraph/op/util/variable_value.hpp"
35 #include "ngraph/output_vector.hpp"
36 #include "ngraph/strides.hpp"
37 #include "ngraph/type.hpp"
38 
39 namespace ngraph
40 {
// ---- Forward declarations of types referenced by this header ----------------------------
template <typename NodeType>
class Input;
template <typename NodeType>
class Output;

class AttributeVisitor;
class Function;
class Node;
class Variant;

namespace runtime
{
class HostTensor;
} // namespace runtime

namespace op
{
struct AutoBroadcastSpec;

namespace v0
{
class Result;
} // namespace v0
} // namespace op

namespace pattern
{
class Matcher;
} // namespace pattern

// ---- Convenience aliases ------------------------------------------------------------------
using HostTensor = runtime::HostTensor;
/// Shared handle to a host tensor.
using HostTensorPtr = std::shared_ptr<HostTensor>;
/// Sequence of host tensors (the argument type of Node::evaluate()).
using HostTensorVector = std::vector<HostTensorPtr>;

/// EvaluationContext stores and manages a context (additional parameters, values and
/// environment) for evaluating an ngraph::Function: a name-keyed map of Variant values.
using EvaluationContext = std::map<std::string, std::shared_ptr<Variant>>;

/// Vector of shared Result ops (the return type of as_result_vector()).
using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
81 
82  NGRAPH_API
83  std::string node_validation_failure_loc_string(const Node* node);
84 
85  const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
86  size_t i);
87  NGRAPH_API
88  const NodeVector& check_single_output_args(const NodeVector& args);
89 
90  const std::shared_ptr<Node>& check_single_output_arg(const std::shared_ptr<Node>& node,
91  size_t i);
92 
93  NGRAPH_API
94  OutputVector as_output_vector(const NodeVector& args);
95  NGRAPH_API
96  NodeVector as_node_vector(const OutputVector& values);
97  /// Returns a ResultVector referencing values.
98  NGRAPH_API
99  ResultVector as_result_vector(const OutputVector& values);
100 
101  /// Alias useful for cloning
102  using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
103 
/// \brief Case-label helper for evaluator switch statements: it guarantees that the
///        `case` label and the evaluate<> template argument name the same element type.
///
/// Typical use inside an evaluate_*() function (a variable named `rc` must be in scope,
/// because the expansion assigns to it):
///
///     switch (arg0->get_element_type())
///     {
///         TYPE_CASE(i8)(arg0, arg1, out, broadcast_spec);
///         break;
///         TYPE_CASE(i16)(arg0, arg1, out, broadcast_spec);
///         break;
///         ...
///
/// TYPE_CASE(et) expands to
///     case element::Type_t::et: rc = evaluate<element::Type_t::et>
/// and the caller appends the argument list.
///
/// \note Put a `break` after every statement; otherwise execution falls through into the
///       next case and generates a runtime error.
#define TYPE_CASE(ELEMENT_TYPE) \
    case element::Type_t::ELEMENT_TYPE: rc = evaluate<element::Type_t::ELEMENT_TYPE>
122 
123  /// Nodes are the backbone of the graph of Value dataflow. Every node has
124  /// zero or more nodes as arguments and one value, which is either a tensor
125  /// or a (possibly empty) tuple of values.
126  class NGRAPH_API Node : public std::enable_shared_from_this<Node>
127  {
128  // For access to m_outputs.
129  friend class descriptor::Input;
130 
131  // For access to m_inputs and m_outputs.
132  template <typename NodeType>
133  friend class Input;
134 
135  // For access to m_outputs.
136  template <typename NodeType>
137  friend class Output;
138 
139  public:
140  /// \brief Verifies that attributes and inputs are consistent and computes output shapes
141  /// and element types. Must be implemented by concrete child classes so that it
142  /// can be run any number of times.
143  ///
144  /// Throws if the node is invalid.
145  virtual void validate_and_infer_types();
146 
147  // Called in constructors during transition
148  void constructor_validate_and_infer_types();
149 
151 
152  protected:
153  /// \brief Construct an unitialized Node
154  Node() = default;
155  /// \brief Copying a node
156  Node(const Node&);
157  /// \brief Assignment operator
158  Node& operator=(const Node&);
159 
160  /// \brief Construct an unitialized Node
161  /// \param output_size Number of outputs for this node
162  Node(size_t output_size);
163 
164  /// \brief Constructor for Node subclasses that have metaclasses.
165  /// \param arguments Output i will connect to input i
166  /// \param output_size Number of outputs for this node
167  Node(const OutputVector& arguments, size_t output_size = 1);
168  /// \brief Moves nodes that would be deleted from inputs to nodes to avoid stack overflows
169  /// on deep networks.
170  void safe_delete(NodeVector& nodes, bool recurse);
171 
172  /// \brief Marks an input as being relevant or irrelevant to the output shapes of this
173  /// node.
174  /// \param i The index of the input to mark as relevant or irrelevant.
175  /// \param relevant true if the input is relevant to output shapes, false otherwise.
176  ///
177  /// This is used by the shape specialization pass to know which nodes must be statically
178  /// evaluated in order to complete shape specialization. (For example, the shape input of
179  /// DynReshape must be evaluated statically in order for the output shape to be
180  /// determined.) By default, all inputs are marked as shape-irrelevant. Overrides of
181  /// validate_and_infer_types should call this function to mark shape-relevant inputs.
182  void set_input_is_relevant_to_shape(size_t i, bool relevant = true);
183 
184  /// \brief Marks an input as being relevant or irrelevant to the output values of this
185  /// node.
186  /// \param i The index of the input to mark as relevant or irrelevant.
187  /// \param relevant true if the input is relevant to output values, false otherwise.
188  ///
189  /// This is used by the shape specialization pass to cut short evaluation in cases where
190  /// an input value does not actually have any effect on the output value of the node. (As
191  /// of this writing, the only example of this is ShapeOf.) By default, all inputs are
192  /// marked as value-relevant. Overrides of validate_and_infer_types should call this
193  /// function to mark value-irrelevant inputs.
194  void set_input_is_relevant_to_value(size_t i, bool relevant = true);
195 
196  public:
197  virtual ~Node();
198 
199  virtual bool visit_attributes(AttributeVisitor&) { return false; }
200  /// \returns the autobroadcasr spec
201  virtual const op::AutoBroadcastSpec& get_autob() const;
202 
203  /// \brief Allows to get information about availability of evaluate method for the current
204  /// operation
205  // \returns true if evaluate is available
206  virtual bool has_evaluate() const;
207  /// \brief Evaluates the op on input_values putting results in output_values
208  /// \param output_values Tensors for the outputs to compute. One for each result
209  /// \param input_values Tensors for the inputs. One for each inputs.
210  /// \returns true if successful
211  virtual bool evaluate(const HostTensorVector& output_values,
212  const HostTensorVector& input_values) const;
213  /// \brief Evaluates the op on input_values putting results in output_values
214  /// \param output_values Tensors for the outputs to compute. One for each result
215  /// \param input_values Tensors for the inputs. One for each inputs.
216  /// \param evaluation_context Storage of additional settings and attributes that can be used
217  /// when evaluating the op.
218  /// \returns true if successful
219  virtual bool evaluate(const HostTensorVector& output_values,
220  const HostTensorVector& input_values,
221  const EvaluationContext& evaluationContext) const;
222  virtual bool evaluate_lower(const HostTensorVector& output_values) const;
223  virtual bool evaluate_upper(const HostTensorVector& output_values) const;
224 
225  virtual bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values);
226  /// \brief Decomposes the FusedOp into a sub-graph consisting of core ngraph ops
227  ///
228  /// \return A vector of nodes comprising the sub-graph. The order of output
229  /// tensors must match the match output tensors of the FusedOp
230  virtual OutputVector decompose_op() const { return OutputVector(); }
231  /// Returns the NodeTypeInfo for the node's class.
232  /// During transition to type_info, returns a dummy type_info for Node if the class
233  /// has not been updated yet.
234  virtual const type_info_t& get_type_info() const = 0;
235  const char* get_type_name() const { return get_type_info().name; }
236  /// Sets/replaces the arguments with new arguments.
237  void set_arguments(const NodeVector& arguments);
238  /// Sets/replaces the arguments with new arguments.
239  void set_arguments(const OutputVector& arguments);
240  /// Sets/replaces the arguments with new arguments.
241  void set_argument(size_t position, const Output<Node>& argument);
242 
243  void set_output_type(size_t i,
244  const element::Type& element_type,
245  const PartialShape& pshape);
246 
247  /// Sets the number of outputs
248  void set_output_size(size_t output_size);
249 
250  void invalidate_values();
251  virtual void revalidate_and_infer_types()
252  {
253  invalidate_values();
254  validate_and_infer_types();
255  }
256  /// \brief Get the string name for the type of the node, such as `Add` or `Multiply`.
257  /// The class name, must not contain spaces as it is used for codegen.
258  /// \returns A const reference to the node's type name
259  virtual std::string description() const;
260  /// \brief Get the unique name of the node.
261  /// \returns A const reference to the node's unique name.
262  const std::string& get_name() const;
263 
264  /// \brief Sets a friendly name for a node. This does not overwrite the unique name
265  /// of the node and is retrieved via get_friendly_name(). Used mainly for debugging.
266  /// The friendly name may be set exactly once.
267  /// \param name is the friendly name to set
268  void set_friendly_name(const std::string& name);
269 
270  /// \brief Gets the friendly name for a node. If no friendly name has been set via
271  /// set_friendly_name then the node's unique name is returned.
272  /// \returns A const reference to the node's friendly name.
273  const std::string& get_friendly_name() const;
274 
275  virtual bool is_dynamic() const;
276  size_t get_instance_id() const { return m_instance_id; }
277  /// \brief Writes a description of a node to a stream
278  /// \param os The stream; should be returned
279  /// \param depth How many levels of inputs to describe
280  /// \returns The stream os
281  virtual std::ostream& write_description(std::ostream& os, uint32_t depth = 0) const;
282 
283  /// Get control dependencies registered on the node
284  const std::vector<std::shared_ptr<Node>>& get_control_dependencies() const;
285 
286  /// Get nodes dependent on this node
287  const std::vector<Node*>& get_control_dependents() const;
288 
289  /// This node cannot execute until node executes
290  void add_control_dependency(std::shared_ptr<Node> node);
291 
292  /// Remove the dependency of this node on node
293  void remove_control_dependency(std::shared_ptr<Node> node);
294 
295  /// Remove all dependencies from this node
297 
298  /// Remove this node as a dependency from all dependent nodes
300 
301  /// This node absorbs the control dependencies of source_node
302  void add_node_control_dependencies(std::shared_ptr<Node> source_node);
303 
304  /// This node becomes a dependent of every node dependent on source_node
305  void add_node_control_dependents(std::shared_ptr<Node> source_node);
306 
307  /// This node's control dependencies are replaced by replacement
308  void transfer_control_dependents(std::shared_ptr<Node> replacement);
309 
310  /// Returns the number of outputs from the node.
311  size_t get_output_size() const;
312 
313  /// Returns the element type for output i
314  const element::Type& get_output_element_type(size_t i) const;
315 
316  /// Checks that there is exactly one output and returns its element type
317  // TODO: deprecate in favor of node->get_output_element_type(0) with a suitable check in
318  // the calling code, or updates to the calling code if it is making an invalid assumption
319  // of only one output.
321 
322  /// Returns the shape for output i
323  const Shape& get_output_shape(size_t i) const;
324 
325  /// Returns the partial shape for output i
326  const PartialShape& get_output_partial_shape(size_t i) const;
327 
328  /// Return the output to use when converting to an Output<Node> with no index specified.
329  /// Throws when not supported.
331  Output<Node> get_default_output();
332 
333  /// Returns the output of the default output, or throws if there is none
334  virtual size_t get_default_output_index() const;
335  /// Throws no default
336  size_t no_default_index() const;
337 
338  /// Checks that there is exactly one output and returns its shape
339  // TODO: deprecate in favor of node->get_output_shape(0) with a suitable check in the
340  // calling code, or updates to the calling code if it is making an invalid assumption of
341  // only one output.
342  const Shape& get_shape() const;
343 
344  /// Returns the tensor for output or input i
346  descriptor::Tensor& get_input_tensor(size_t i) const;
347 
348  /// Returns the tensor name for output i
349  NGRAPH_DEPRECATED(
350  "The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
351  const std::string& get_output_tensor_name(size_t i) const;
352 
353  std::set<Input<Node>> get_output_target_inputs(size_t i) const;
354 
355  /// Returns the number of inputs for the op
356  size_t get_input_size() const;
357 
358  /// Returns the element type of input i
359  // TODO: deprecate in favor of node->get_input_element_type(i)
360  const element::Type& get_input_element_type(size_t i) const;
361 
362  /// Returns the shape of input i
363  // TODO: deprecate in favor of node->get_input_shape(i)
364  const Shape& get_input_shape(size_t i) const;
365 
366  /// Returns the partial shape of input i
367  // TODO: deprecate in favor of node->get_input_partial_shape(i)
368  const PartialShape& get_input_partial_shape(size_t i) const;
369 
370  /// Returns the tensor name for input i
371  NGRAPH_DEPRECATED(
372  "The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
373  const std::string& get_input_tensor_name(size_t i) const;
374 
375  std::unordered_set<descriptor::Tensor*> liveness_new_list;
376  std::unordered_set<descriptor::Tensor*> liveness_free_list;
377 
378  Node* get_input_node_ptr(size_t index) const;
379  std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
380  Output<Node> get_input_source_output(size_t i) const;
381 
382  public:
383  virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const = 0;
384 
385  std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
386 
387  std::shared_ptr<Node> copy_with_new_inputs(
388  const OutputVector& inputs,
389  const std::vector<std::shared_ptr<Node>>& control_dependencies) const;
390 
391  /// True if this and node have one output with same element type and shape
392  bool has_same_type(std::shared_ptr<const Node> node) const;
393 
394  using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
395 
396  RTMap& get_rt_info() { return m_rt_info; }
397  const RTMap& get_rt_info() const { return m_rt_info; }
398  const std::unordered_set<std::string>& get_provenance_tags() const;
399  void add_provenance_tag(const std::string& tag);
400  template <typename T>
401  void add_provenance_tags(T tag_set)
402  {
403  for (auto tag : tag_set)
404  {
405  add_provenance_tag(tag);
406  }
407  }
408  /// \brief Adds tag_set to this node and all intermediate nodes above base
409  void add_provenance_tags_above(const OutputVector& base,
410  const std::unordered_set<std::string>& tag_set);
411  void remove_provenance_tag(const std::string& tag);
412  /// \brief Add node to additional nodes that receive tags
413  void add_provenance_group_member(const std::shared_ptr<Node>& node);
414  /// \brief Remove node to additional nodes that receive tags
415  void remove_provenance_group_member(const std::shared_ptr<Node>& node);
416  /// \brief Replace current_node with replacement_node and transfer tags
417  void replace_provenance_group_member(const std::shared_ptr<Node>& current_node,
418  const std::shared_ptr<Node>& replacement_node);
419  /// \return Provenance group nodes
420  const std::set<std::shared_ptr<Node>>& get_provenance_group_members() const;
421 
422  /// \brief Add all nodes between this node and nodes in base as additional nodes to receive
423  /// provenance tags.
424  std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
425 
426  // to be used when nodes are replaced
427  void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
428 
429  /// Transfer provenance tags to replacement
430  void transfer_provenance_tags(const std::shared_ptr<Node>& replacement);
431 
432  /// Get all the nodes that uses the current node
433  NodeVector get_users(bool check_is_used = false) const;
434 
435  /// \return Version of this node
436  virtual size_t get_version() const { return get_type_info().version; }
437  virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
438  /// Use instance ids for comparison instead of memory addresses to improve determinism
439  bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
440  /// \return A vector containing a handle for each of this node's inputs, in order.
441  // TODO: Rename to get_inputs()?
442  std::vector<Input<Node>> inputs();
443 
444  /// \return A vector containing a handle for each of this node's inputs, in order.
445  std::vector<Input<const Node>> inputs() const;
446 
447  /// \return A vector containing the values for each input
448  std::vector<Output<Node>> input_values() const;
449 
450  /// \return A vector containing a handle for each of this node's outputs, in order.
451  // TODO: Rename to get_outputs()?
452  std::vector<Output<Node>> outputs();
453 
454  /// \return A vector containing a handle for each of this node's outputs, in order.
455  std::vector<Output<const Node>> outputs() const;
456 
457  /// \return A handle to the `input_index`th input of this node.
458  /// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
459  Input<Node> input(size_t input_index);
460 
461  /// \return A handle to the `input_index`th input of this node.
462  /// \throw std::out_of_range if the node does not have at least `input_index+1` inputs.
463  Input<const Node> input(size_t input_index) const;
464 
465  Output<Node> input_value(size_t input_index) const;
466 
467  /// \return A handle to the `output_index`th output of this node.
468  /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
469  Output<Node> output(size_t output_index);
470 
471  /// \return A handle to the `output_index`th output of this node.
472  /// \throw std::out_of_range if the node does not have at least `output_index+1` outputs.
473  Output<const Node> output(size_t output_index) const;
474 
475  void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations)
476  {
477  m_op_annotations = op_annotations;
478  }
479  std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const
480  {
481  return m_op_annotations;
482  }
483 
484  virtual bool match_value(pattern::Matcher* matcher,
485  const Output<Node>& pattern_value,
486  const Output<Node>& graph_value);
487 
488  virtual bool match_node(pattern::Matcher* matcher, const Output<Node>& graph_value);
489 
490  private:
491  descriptor::Input& get_input_descriptor(size_t position);
492  descriptor::Output& get_output_descriptor(size_t position);
493 
494  std::vector<Node*> m_control_dependents;
495  std::vector<std::shared_ptr<Node>> m_control_dependencies;
496  std::string m_node_type;
497  size_t m_instance_id{m_next_instance_id.fetch_add(1)};
498  std::string m_friendly_name;
499  std::string m_unique_name;
500  static std::atomic<size_t> m_next_instance_id;
501  std::unordered_set<std::string> m_provenance_tags;
502  std::set<std::shared_ptr<Node>> m_provenance_group;
503  std::deque<descriptor::Input> m_inputs;
504  std::deque<descriptor::Output> m_outputs;
505  std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
506  std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
507  };
508 
509  using NodeTypeInfo = Node::type_info_t;
510 
511  NGRAPH_API std::ostream& operator<<(std::ostream&, const Node&);
512  NGRAPH_API std::ostream& operator<<(std::ostream&, const Node*);
513 
/// Identity macro. Wrapping an invocation in it forces one more rescan of the result, which
/// NGRAPH_RTTI_DEFINITION relies on so that the list built from __VA_ARGS__ is split into
/// separate macro arguments (presumably a workaround for MSVC's traditional preprocessor,
/// which otherwise forwards __VA_ARGS__ as a single argument).
#define _NGRAPH_RTTI_EXPAND(TOKENS) TOKENS

/// Declares the RTTI members of a class. Use it in the public section of a class that needs
/// type identification beyond C++ RTTI, followed by a semicolon. It is recommended for every
/// class derived from ngraph::Node so that pattern matching can identify it.
///
/// It declares:
///   - `type_info`: a static constant kept for backward compatibility with the old Node RTTI;
///   - `get_type_info()`: an override of Node::get_type_info() that returns the same object
///     as get_type_info_static();
///   - `get_type_info_static()`: a static function returning a reference to an object equal
///     to `type_info` (not necessarily the same object).
///
/// Example:
///
///     class MyOp : public Node
///     {
///     public:
///         // Don't use Node as a parent for type_info: it adds no value and is prohibited.
///         NGRAPH_RTTI_DECLARATION;
///         ...
///     };
///
///     class MyInheritedOp : public MyOp
///     {
///     public:
///         NGRAPH_RTTI_DECLARATION;
///         ...
///     };
///
/// Complete the type identification of the class with NGRAPH_RTTI_DEFINITION.
#define NGRAPH_RTTI_DECLARATION \
    static const ::ngraph::Node::type_info_t type_info; \
    const ::ngraph::Node::type_info_t& get_type_info() const override; \
    static const ::ngraph::Node::type_info_t& get_type_info_static()

// Shared tail of both definition variants: get_type_info() forwards to
// get_type_info_static(), and the legacy `type_info` member is initialized from it.
#define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS_NAME) \
    const ::ngraph::Node::type_info_t& CLASS_NAME::get_type_info() const \
    { \
        return get_type_info_static(); \
    } \
    const ::ngraph::Node::type_info_t CLASS_NAME::type_info = CLASS_NAME::get_type_info_static()

// Definition variant that links the parent class's type info into this class's type info.
#define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS_NAME, TYPE_NAME, VERSION_INDEX, PARENT_CLASS) \
    const ::ngraph::Node::type_info_t& CLASS_NAME::get_type_info_static() \
    { \
        static const ::ngraph::Node::type_info_t type_info_static{ \
            TYPE_NAME, VERSION_INDEX, &PARENT_CLASS::get_type_info_static()}; \
        return type_info_static; \
    } \
    _NGRAPH_RTTI_DEFINITION_COMMON(CLASS_NAME)

// Definition variant without a parent class.
#define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS_NAME, TYPE_NAME, VERSION_INDEX) \
    const ::ngraph::Node::type_info_t& CLASS_NAME::get_type_info_static() \
    { \
        static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, VERSION_INDEX}; \
        return type_info_static; \
    } \
    _NGRAPH_RTTI_DEFINITION_COMMON(CLASS_NAME)

// Yields its fifth argument. NGRAPH_RTTI_DEFINITION appends the WITH_PARENT and NO_PARENT
// implementations after the user's arguments, so four user arguments select WITH_PARENT
// and three select NO_PARENT.
#define _NGRAPH_RTTI_DEFINITION_SELECTOR(ARG1, ARG2, ARG3, ARG4, CHOSEN, ...) CHOSEN

/// Complementary to NGRAPH_RTTI_DECLARATION: _defines_ the items it _declares_.
/// Use it outside the class definition, in a single place so that the ODR holds.
///
/// \param CLASS is the C++ name of the class where NGRAPH_RTTI_DECLARATION was applied.
/// \param TYPE_NAME is a string literal (const char*) naming the class in the type
///        identification namespace. The name is your choice, but it should be unique among
///        all NGRAPH_RTTI_DECLARATION-enabled classes that can be used together in one
///        transformation flow.
/// \param VERSION_INDEX is an unsigned integer that distinguishes different versions of
///        operations sharing the same TYPE_NAME.
/// \param PARENT_CLASS is an optional direct or indirect parent class. Specify it only when
///        operations derived from a common base class need to be matched as a group. Do not
///        use Node: it is the base of every operation and cannot define a meaningful subset.
///        PARENT_CLASS must itself use NGRAPH_RTTI_{DECLARATION/DEFINITION}.
///
/// Examples (matching the declarations shown for NGRAPH_RTTI_DECLARATION):
///
///     NGRAPH_RTTI_DEFINITION(MyOp, "MyOp", 1);
///     NGRAPH_RTTI_DEFINITION(MyInheritedOp, "MyInheritedOp", 1, MyOp);
///
/// For convenience, TYPE_NAME and CLASS are recommended to be the same.
#define NGRAPH_RTTI_DEFINITION(...) \
    _NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR( \
        __VA_ARGS__, _NGRAPH_RTTI_DEFINITION_WITH_PARENT, _NGRAPH_RTTI_DEFINITION_NO_PARENT)( \
        __VA_ARGS__))
614 
615  // Like an Output but with a Node* instead of a shared_ptr<Node>
617  {
618  RawNodeOutput(const Output<Node>& value)
619  : node(value.get_node())
620  , index(value.get_index())
621  {
622  }
623  RawNodeOutput(Node* node, size_t index)
624  : node(node)
625  , index(index)
626  {
627  }
628  RawNodeOutput(const RawNodeOutput&) = default;
629  RawNodeOutput() = default;
630  RawNodeOutput& operator=(const RawNodeOutput&) = default;
631 
632  Node* node;
633  size_t index{0};
634 
635  operator Output<Node>() { return Output<Node>(node->shared_from_this(), index); }
636  bool operator==(const RawNodeOutput& other) const
637  {
638  return node == other.node && index == other.index;
639  }
640  bool operator!=(const RawNodeOutput& other) const { return !(*this == other); }
641  bool operator<(const RawNodeOutput& other) const
642  {
643  return node < other.node || (node == other.node && index < other.index);
644  }
645  bool operator>(const RawNodeOutput& other) const
646  {
647  return node > other.node || (node == other.node && index > other.index);
648  }
649  bool operator<=(const RawNodeOutput& other) const { return !(*this > other); }
650  bool operator>=(const RawNodeOutput& other) const { return !(*this < other); }
651  };
652 
653  /// \brief Visits a reference to a node that has been registered with the visitor.
654  template <>
655  class NGRAPH_API AttributeAdapter<std::shared_ptr<Node>> : public VisitorAdapter
656  {
657  public:
658  AttributeAdapter(std::shared_ptr<Node>& value);
659 
660  bool visit_attributes(AttributeVisitor& visitor) override;
661  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Node>>", 0};
662  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
663 
664  protected:
665  std::shared_ptr<Node>& m_ref;
666  };
667 
668  template <>
669  class NGRAPH_API AttributeAdapter<NodeVector> : public VisitorAdapter
670  {
671  public:
672  AttributeAdapter(NodeVector& ref);
673 
674  bool visit_attributes(AttributeVisitor& visitor) override;
675 
676  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<NodeVector>", 0};
677  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
678 
679  protected:
680  NodeVector& m_ref;
681  };
682 
683  using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
684 
685  class NGRAPH_API NodeValidationFailure : public CheckFailure
686  {
687  public:
688  NodeValidationFailure(const CheckLocInfo& check_loc_info,
689  const Node* node,
690  const std::string& explanation)
691  : CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation)
692  {
693  }
694  };
695 } // namespace ngraph
/// Validation check bound to a node. It forwards to NGRAPH_CHECK_HELPER with
/// ::ngraph::NodeValidationFailure as the exception type and the (parenthesized) node as
/// context; the remaining arguments (condition first, then message parts) are passed on
/// unchanged.
#define NODE_VALIDATION_CHECK(NODE, ...) \
    NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (NODE), __VA_ARGS__)
698 
699 namespace ngraph
700 {
701  template <typename T>
702  void check_new_args_count(const Node* node, T new_args)
703  {
704  NODE_VALIDATION_CHECK(node,
705  new_args.size() == node->input_values().size(),
706  "clone_with_new_inputs() expected ",
707  node->input_values().size(),
708  " argument",
709  (node->input_values().size() == 1 ? "" : "s"),
710  " but got ",
711  new_args.size());
712  }
713 
714 } // namespace ngraph
const DiscreteTypeInfo & get_type_info() const override
type info enables identification of the value accessor, as well as is_type and as_type.
Definition: node.hpp:677
const DiscreteTypeInfo & get_type_info() const override
type info enables identification of the value accessor, as well as is_type and as_type.
Definition: node.hpp:662
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Base class for check failure exceptions.
Definition: check.hpp:31
A handle for one of a node's inputs.
Definition: node_input.hpp:32
A handle for one of a node's inputs.
Definition: node_input.hpp:85
Definition: node_input.hpp:24
Definition: node.hpp:686
Definition: node.hpp:127
Node()=default
Construct an uninitialized Node.
void clear_control_dependencies()
Remove all dependencies from this node.
const Shape & get_shape() const
Checks that there is exactly one output and returns its shape.
void set_arguments(const OutputVector &arguments)
Sets/replaces the arguments with new arguments.
void add_node_control_dependencies(std::shared_ptr< Node > source_node)
This node absorbs the control dependencies of source_node.
void remove_provenance_group_member(const std::shared_ptr< Node > &node)
Remove node from additional nodes that receive tags.
std::vector< Input< Node > > inputs()
Node(size_t output_size)
Construct an uninitialized Node.
std::vector< Output< const Node > > outputs() const
Output< const Node > output(size_t output_index) const
NodeVector get_users(bool check_is_used=false) const
Get all the nodes that use the current node.
const std::vector< std::shared_ptr< Node > > & get_control_dependencies() const
Get control dependencies registered on the node.
virtual void validate_and_infer_types()
Verifies that attributes and inputs are consistent and computes output shapes and element types....
virtual bool evaluate(const HostTensorVector &output_values, const HostTensorVector &input_values, const EvaluationContext &evaluationContext) const
Evaluates the op on input_values putting results in output_values.
Input< const Node > input(size_t input_index) const
void add_provenance_tags_above(const OutputVector &base, const std::unordered_set< std::string > &tag_set)
Adds tag_set to this node and all intermediate nodes above base.
void add_provenance_group_member(const std::shared_ptr< Node > &node)
Add node to additional nodes that receive tags.
virtual OutputVector decompose_op() const
Decomposes the FusedOp into a sub-graph consisting of core ngraph ops.
Definition: node.hpp:230
descriptor::Tensor & get_output_tensor(size_t i) const
Returns the tensor for output or input i.
Node(const OutputVector &arguments, size_t output_size=1)
Constructor for Node subclasses that have metaclasses.
std::shared_ptr< Node > add_provenance_group_members_above(const OutputVector &base)
Add all nodes between this node and nodes in base as additional nodes to receive provenance tags.
size_t get_output_size() const
Returns the number of outputs from the node.
void set_arguments(const NodeVector &arguments)
Sets/replaces the arguments with new arguments.
Input< Node > input(size_t input_index)
const std::string & get_friendly_name() const
Gets the friendly name for a node. If no friendly name has been set via set_friendly_name then the no...
virtual bool evaluate(const HostTensorVector &output_values, const HostTensorVector &input_values) const
Evaluates the op on input_values putting results in output_values.
void set_input_is_relevant_to_shape(size_t i, bool relevant=true)
Marks an input as being relevant or irrelevant to the output shapes of this node.
void replace_provenance_group_member(const std::shared_ptr< Node > &current_node, const std::shared_ptr< Node > &replacement_node)
Replace current_node with replacement_node and transfer tags.
virtual const op::AutoBroadcastSpec & get_autob() const
virtual size_t get_default_output_index() const
Returns the output of the default output, or throws if there is none.
virtual const type_info_t & get_type_info() const =0
Node(const Node &)
Copying a node.
const element::Type & get_element_type() const
Checks that there is exactly one output and returns its element type.
Output< Node > output(size_t output_index)
void add_control_dependency(std::shared_ptr< Node > node)
This node cannot execute until node executes.
virtual bool has_evaluate() const
Allows to get information about availability of evaluate method for the current operation.
std::vector< Input< const Node > > inputs() const
void clear_control_dependents()
Remove this node as a dependency from all dependent nodes.
Node & operator=(const Node &)
Assignment operator.
std::vector< Output< Node > > outputs()
void remove_control_dependency(std::shared_ptr< Node > node)
Remove the dependency of this node on node.
size_t no_default_index() const
Throws no default.
void add_node_control_dependents(std::shared_ptr< Node > source_node)
This node becomes a dependent of every node dependent on source_node.
const std::vector< Node * > & get_control_dependents() const
Get nodes dependent on this node.
virtual std::ostream & write_description(std::ostream &os, uint32_t depth=0) const
Writes a description of a node to a stream.
void set_friendly_name(const std::string &name)
Sets a friendly name for a node. This does not overwrite the unique name of the node and is retrieved...
void transfer_control_dependents(std::shared_ptr< Node > replacement)
This node's control dependencies are replaced by replacement.
const std::string & get_name() const
Get the unique name of the node.
virtual size_t get_version() const
Definition: node.hpp:436
const std::set< std::shared_ptr< Node > > & get_provenance_group_members() const
const Shape & get_output_shape(size_t i) const
Returns the shape for output i.
void set_output_size(size_t output_size)
Sets the number of outputs.
void set_argument(size_t position, const Output< Node > &argument)
Sets/replaces the arguments with new arguments.
void transfer_provenance_tags(const std::shared_ptr< Node > &replacement)
Transfer provenance tags to replacement.
const PartialShape & get_output_partial_shape(size_t i) const
Returns the partial shape for output i.
bool operator<(const Node &other) const
Use instance ids for comparison instead of memory addresses to improve determinism.
Definition: node.hpp:439
void set_input_is_relevant_to_value(size_t i, bool relevant=true)
Marks an input as being relevant or irrelevant to the output values of this node.
void safe_delete(NodeVector &nodes, bool recurse)
Moves nodes that would be deleted from inputs to nodes to avoid stack overflows on deep networks.
Output< const Node > get_default_output() const
const element::Type & get_output_element_type(size_t i) const
Returns the element type for output i.
std::vector< Output< Node > > input_values() const
virtual std::string description() const
Get the string name for the type of the node, such as Add or Multiply. The class name,...
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Node * get_node() const
size_t get_index() const
Definition: node_output.hpp:115
Definition: node_output.hpp:25
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: variant.hpp:18
Adapters will see visitor.
Definition: attribute_adapter.hpp:185
Definition: input.hpp:21
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:28
Definition: element_type.hpp:51
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
NGRAPH_API ResultVector as_result_vector(const OutputVector &values)
Returns a ResultVector referencing values.
std::map< std::string, std::shared_ptr< Variant > > EvaluationContext
Definition: node.hpp:63
std::unordered_map< ngraph::Node *, std::shared_ptr< ngraph::Node > > NodeMap
Alias useful for cloning.
Definition: node.hpp:102
Definition: check.hpp:23
Definition: type.hpp:27
Definition: node.hpp:617
Implicit broadcast specification.
Definition: attr_types.hpp:311