29 #include <unordered_map>
30 #include <unordered_set>
33 #include "ngraph/attribute_visitor.hpp"
34 #include "ngraph/check.hpp"
35 #include "ngraph/coordinate.hpp"
36 #include "ngraph/coordinate_diff.hpp"
37 #include "ngraph/deprecated.hpp"
38 #include "ngraph/descriptor/input.hpp"
39 #include "ngraph/descriptor/output.hpp"
40 #include "ngraph/descriptor/tensor.hpp"
41 #include "ngraph/node_input.hpp"
42 #include "ngraph/node_output.hpp"
43 #include "ngraph/op/util/attr_types.hpp"
44 #include "ngraph/op/util/op_annotations.hpp"
45 #include "ngraph/output_vector.hpp"
46 #include "ngraph/strides.hpp"
47 #include "ngraph/type.hpp"
51 template <
typename NodeType>
54 template <
typename NodeType>
57 class AttributeVisitor;
// Shorthand aliases for runtime::HostTensor and for vectors of shared pointers to it;
// HostTensorVector is the parameter type of Node::evaluate / evaluate_lower / evaluate_upper.
67 using HostTensor = runtime::HostTensor;
68 using HostTensorPtr = std::shared_ptr<HostTensor>;
69 using HostTensorVector = std::vector<HostTensorPtr>;
// Forward declaration of the implicit broadcast specification returned by Node::get_autob().
73 struct AutoBroadcastSpec;
// A vector of shared pointers to op::v0::Result nodes (see as_result_vector).
86 using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
// Builds the location string for a node-validation failure; NodeValidationFailure's
// constructor passes its result to CheckFailure alongside the explanation text.
89 std::string node_validation_failure_loc_string(
const Node* node);
91 const std::shared_ptr<Node>& check_single_output_arg(
const std::shared_ptr<Node>& node,
// Checks a NodeVector and returns it by reference. NOTE(review): presumably verifies that
// every node has a single output, by analogy with check_single_output_arg — confirm in node.cpp.
94 const NodeVector& check_single_output_args(
const NodeVector& args);
96 const std::shared_ptr<Node>& check_single_output_arg(
const std::shared_ptr<Node>& node,
// Builds an OutputVector from the given NodeVector (definition not shown here).
100 OutputVector as_output_vector(
const NodeVector& args);
// Builds a NodeVector from the given OutputVector (definition not shown here).
102 NodeVector as_node_vector(
const OutputVector& values);
// Maps raw Node pointers to owning shared pointers; documented as an alias useful for cloning
// (e.g. original node -> its clone).
108 using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
// Expands to a `case element::Type_t::a:` label followed by the start of the assignment
// `rc = evaluate<element::Type_t::a>`. The expansion ends without an argument list or
// `break;`, so the use site must append `(args...)` and terminate the case itself.
// Assumes a variable `rc` and a function template `evaluate` are visible at the use site.
125 #define TYPE_CASE(a) \
126 case element::Type_t::a: rc = evaluate<element::Type_t::a>
131 class NGRAPH_API
Node :
public std::enable_shared_from_this<Node>
137 template <
typename NodeType>
141 template <
typename NodeType>
153 void constructor_validate_and_infer_types();
172 Node(
const OutputVector& arguments,
size_t output_size = 1);
// Evaluates the op on input_values, putting the results in output_values.
// NOTE(review): the meaning of the bool return (presumably "evaluation supported/succeeded")
// is not visible here — confirm in node.cpp.
209 virtual bool evaluate(
const HostTensorVector& output_values,
210 const HostTensorVector& input_values)
const;
// Output-only evaluation variants. NOTE(review): presumably compute lower/upper bounds of the
// outputs, judging by the names — confirm against the definitions.
211 virtual bool evaluate_lower(
const HostTensorVector& output_values)
const;
212 virtual bool evaluate_upper(
const HostTensorVector& output_values)
const;
// Attempts constant folding given inputs_values, writing folded values into output_values.
// NOTE(review): confirm the bool return semantics against the definition.
214 virtual bool constant_fold(OutputVector& output_values,
const OutputVector& inputs_values);
224 const char* get_type_name()
const {
return get_type_info().name; }
232 void set_output_type(
size_t i,
239 void invalidate_values();
240 virtual void revalidate_and_infer_types()
243 validate_and_infer_types();
264 virtual bool is_dynamic()
const;
265 size_t get_instance_id()
const {
return m_instance_id; }
339 "The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.")
340 const std::
string& get_output_tensor_name(
size_t i) const;
342 std::set<
Input<
Node>> get_output_target_inputs(
size_t i) const;
345 size_t get_input_size() const;
349 const element::Type& get_input_element_type(
size_t i) const;
353 const
Shape& get_input_shape(
size_t i) const;
361 "The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.")
362 const std::
string& get_input_tensor_name(
size_t i) const;
364 std::unordered_set<descriptor::Tensor*> liveness_new_list;
365 std::unordered_set<descriptor::Tensor*> liveness_free_list;
367 Node* get_input_node_ptr(
size_t index) const;
368 std::shared_ptr<
Node> get_input_node_shared_ptr(
size_t index) const;
369 Output<
Node> get_input_source_output(
size_t i) const;
372 virtual std::shared_ptr<
Node> clone_with_new_inputs(const OutputVector& inputs) const = 0;
374 std::shared_ptr<
Node> copy_with_new_inputs(const OutputVector& new_args) const;
376 std::shared_ptr<
Node> copy_with_new_inputs(
377 const OutputVector& inputs,
378 const std::vector<std::shared_ptr<
Node>>& control_dependencies) const;
381 bool has_same_type(std::shared_ptr<const
Node> node) const;
383 using RTMap = std::map<std::
string, std::shared_ptr<
Variant>>;
385 RTMap& get_rt_info() {
return m_rt_info; }
386 const RTMap& get_rt_info()
const {
return m_rt_info; }
387 const std::unordered_set<std::string>& get_provenance_tags()
const;
388 void add_provenance_tag(
const std::string& tag);
389 template <
typename T>
390 void add_provenance_tags(T tag_set)
392 for (
auto tag : tag_set)
394 add_provenance_tag(tag);
399 const std::unordered_set<std::string>& tag_set);
400 void remove_provenance_tag(
const std::string& tag);
407 const std::shared_ptr<Node>& replacement_node);
416 void merge_provenance_tags_from(
const std::shared_ptr<const Node>& source);
425 virtual size_t get_version()
const {
return get_type_info().version; }
426 virtual std::shared_ptr<Node> get_default_value()
const {
return nullptr; }
428 bool operator<(
const Node& other)
const {
return m_instance_id < other.m_instance_id; }
434 std::vector<Input<const Node>>
inputs()
const;
444 std::vector<Output<const Node>>
outputs()
const;
464 void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations)
466 m_op_annotations = op_annotations;
468 std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations()
const
470 return m_op_annotations;
473 virtual bool match_value(pattern::Matcher* matcher,
474 const Output<Node>& pattern_value,
475 const Output<Node>& graph_value);
477 virtual bool match_node(pattern::Matcher* matcher,
const Output<Node>& graph_value);
480 descriptor::Input& get_input_descriptor(
size_t position);
481 descriptor::Output& get_output_descriptor(
size_t position);
// Nodes that depend on this node (returned by get_control_dependents); non-owning raw pointers.
483 std::vector<Node*> m_control_dependents;
// Nodes this node must wait for (returned by get_control_dependencies); held as shared_ptrs.
484 std::vector<std::shared_ptr<Node>> m_control_dependencies;
// NOTE(review): not referenced in the visible code — purpose to be confirmed.
485 std::string m_node_type;
// Per-node id drawn from the shared counter at construction; operator< compares these.
486 size_t m_instance_id{m_next_instance_id.fetch_add(1)};
// Name set via set_friendly_name() and returned by get_friendly_name().
487 std::string m_friendly_name;
// Unique node name returned by get_name().
488 std::string m_unique_name;
// Shared id counter; atomic fetch_add hands each constructed node a distinct id.
489 static std::atomic<size_t> m_next_instance_id;
// Provenance tags attached to this node.
490 std::unordered_set<std::string> m_provenance_tags;
// Additional nodes that receive this node's provenance tags (see add_provenance_group_member).
491 std::set<std::shared_ptr<Node>> m_provenance_group;
// Input/output descriptors. NOTE(review): std::deque is presumably chosen so element
// addresses stay stable as descriptors are appended — confirm.
492 std::deque<descriptor::Input> m_inputs;
493 std::deque<descriptor::Output> m_outputs;
// Backing storage for set_op_annotations / get_op_annotations.
494 std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
// Backing storage for get_rt_info() (same type as RTMap).
495 std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
// Convenience alias for Node's type-info type.
498 using NodeTypeInfo = Node::type_info_t;
// Stream insertion for a Node and for a Node pointer.
// NOTE(review): output format is defined elsewhere — confirm in node.cpp.
500 NGRAPH_API std::ostream& operator<<(std::ostream&,
const Node&);
501 NGRAPH_API std::ostream& operator<<(std::ostream&,
const Node*);
// Identity macro. NGRAPH_RTTI_DEFINITION wraps its selector call in it to force an extra
// rescan of the expansion. NOTE(review): presumably a workaround for MSVC's __VA_ARGS__
// forwarding behaviour — confirm.
503 #define _NGRAPH_RTTI_EXPAND(X) X
// Place inside a Node subclass to declare its RTTI members:
//   - a static type_info object,
//   - an override of get_type_info(),
//   - a static get_type_info_static().
// The last declaration has no trailing semicolon, so the use site supplies it.
// The matching definitions come from NGRAPH_RTTI_DEFINITION in a source file.
542 #define NGRAPH_RTTI_DECLARATION \
543 static const ::ngraph::Node::type_info_t type_info; \
544 const ::ngraph::Node::type_info_t& get_type_info() const override; \
545 static const ::ngraph::Node::type_info_t& get_type_info_static()
547 #define _NGRAPH_RTTI_DEFINITION_COMMON(CLASS) \
548 const ::ngraph::Node::type_info_t& CLASS::get_type_info() const \
550 return get_type_info_static(); \
552 const ::ngraph::Node::type_info_t CLASS::type_info = CLASS::get_type_info_static()
553 #define _NGRAPH_RTTI_DEFINITION_WITH_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX, PARENT_CLASS) \
554 const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() \
556 static const ::ngraph::Node::type_info_t type_info_static{ \
557 TYPE_NAME, _VERSION_INDEX, &PARENT_CLASS::get_type_info_static()}; \
558 return type_info_static; \
560 _NGRAPH_RTTI_DEFINITION_COMMON(CLASS)
562 #define _NGRAPH_RTTI_DEFINITION_NO_PARENT(CLASS, TYPE_NAME, _VERSION_INDEX) \
563 const ::ngraph::Node::type_info_t& CLASS::get_type_info_static() \
565 static const ::ngraph::Node::type_info_t type_info_static{TYPE_NAME, _VERSION_INDEX}; \
566 return type_info_static; \
568 _NGRAPH_RTTI_DEFINITION_COMMON(CLASS)
// Argument-count dispatcher: expands to its fifth argument. NGRAPH_RTTI_DEFINITION passes its
// own arguments followed by _NGRAPH_RTTI_DEFINITION_WITH_PARENT and
// _NGRAPH_RTTI_DEFINITION_NO_PARENT. Four user arguments (CLASS, NAME, VERSION, PARENT) therefore
// select WITH_PARENT, and three select NO_PARENT.
570 #define _NGRAPH_RTTI_DEFINITION_SELECTOR(_1, _2, _3, _4, NAME, ...) NAME
599 #define NGRAPH_RTTI_DEFINITION(...) \
600 _NGRAPH_RTTI_EXPAND(_NGRAPH_RTTI_DEFINITION_SELECTOR( \
601 __VA_ARGS__, _NGRAPH_RTTI_DEFINITION_WITH_PARENT, _NGRAPH_RTTI_DEFINITION_NO_PARENT)( \
627 return node == other.node && index == other.index;
629 bool operator!=(
const RawNodeOutput& other)
const {
return !(*
this == other); }
632 return node < other.node || (node == other.node && index < other.index);
636 return node > other.node || (node == other.node && index > other.index);
638 bool operator<=(
const RawNodeOutput& other)
const {
return !(*
this > other); }
639 bool operator>=(
const RawNodeOutput& other)
const {
return !(*
this < other); }
650 static constexpr
DiscreteTypeInfo type_info{
"AttributeAdapter<std::shared_ptr<Node>>", 0};
653 std::shared_ptr<Node>& m_ref;
664 static constexpr
DiscreteTypeInfo type_info{
"AttributeAdapter<NodeVector>", 0};
// Ordered map from RawNodeOutput (node pointer + output index) to Output<Node>. Ordering comes
// from RawNodeOutput::operator<.
// NOTE(review): that operator compares raw Node pointers with the built-in `<`. For pointers to
// unrelated objects the standard leaves that result unspecified, while std::less<Node*> is
// guaranteed to give a strict total order. Consider using std::less there so the map's
// ordering is well-defined.
670 using RawNodeOutputMap = std::map<RawNodeOutput, Output<Node>>;
677 const std::string& explanation)
678 :
CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation)
// Node-validation check. Forwards to NGRAPH_CHECK_HELPER (ngraph/check.hpp) with
// ::ngraph::NodeValidationFailure as the exception type and `node` as its context. As used in
// check_new_args_count, the first variadic argument is the condition and the remaining ones
// are concatenated into the failure message.
683 #define NODE_VALIDATION_CHECK(node, ...) \
684 NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), __VA_ARGS__)
688 template <
typename T>
689 void check_new_args_count(
const Node* node, T new_args)
691 NODE_VALIDATION_CHECK(node,
692 new_args.size() == node->input_values().size(),
693 "clone_with_new_inputs() expected ",
694 node->input_values().size(),
696 (node->input_values().size() == 1 ?
"" :
"s"),
const DiscreteTypeInfo & get_type_info() const override
type info enables identification of the value accessor, as well as is_type and as_type.
Definition: node.hpp:665
const DiscreteTypeInfo & get_type_info() const override
type info enables identification of the value accessor, as well as is_type and as_type.
Definition: node.hpp:651
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Base class for check failure exceptions.
Definition: check.hpp:43
Node()=default
Construct an uninitialized Node.
void clear_control_dependencies()
Remove all dependencies from this node.
const Shape & get_shape() const
Checks that there is exactly one output and returns its shape.
void set_arguments(const OutputVector &arguments)
Sets/replaces the arguments with new arguments.
void add_node_control_dependencies(std::shared_ptr< Node > source_node)
This node absorbs the control dependencies of source_node.
void remove_provenance_group_member(const std::shared_ptr< Node > &node)
Remove node from the additional nodes that receive tags.
std::vector< Input< Node > > inputs()
Node(size_t output_size)
Construct an uninitialized Node.
std::vector< Output< const Node > > outputs() const
Output< const Node > output(size_t output_index) const
NodeVector get_users(bool check_is_used=false) const
Get all the nodes that use the current node.
const std::vector< std::shared_ptr< Node > > & get_control_dependencies() const
Get control dependencies registered on the node.
virtual void validate_and_infer_types()
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Input< const Node > input(size_t input_index) const
void add_provenance_tags_above(const OutputVector &base, const std::unordered_set< std::string > &tag_set)
Adds tag_set to this node and all intermediate nodes above base.
void add_provenance_group_member(const std::shared_ptr< Node > &node)
Add node to additional nodes that receive tags.
virtual OutputVector decompose_op() const
Decomposes the FusedOp into a sub-graph consisting of core ngraph ops.
Definition: node.hpp:219
descriptor::Tensor & get_output_tensor(size_t i) const
Returns the tensor for output or input i.
Node(const OutputVector &arguments, size_t output_size=1)
Constructor for Node subclasses that have metaclasses.
std::shared_ptr< Node > add_provenance_group_members_above(const OutputVector &base)
Add all nodes between this node and nodes in base as additional nodes to receive provenance tags.
size_t get_output_size() const
Returns the number of outputs from the node.
void set_arguments(const NodeVector &arguments)
Sets/replaces the arguments with new arguments.
Input< Node > input(size_t input_index)
const std::string & get_friendly_name() const
Gets the friendly name for a node. If no friendly name has been set via set_friendly_name then the no...
virtual bool evaluate(const HostTensorVector &output_values, const HostTensorVector &input_values) const
Evaluates the op on input_values putting results in output_values.
void set_input_is_relevant_to_shape(size_t i, bool relevant=true)
Marks an input as being relevant or irrelevant to the output shapes of this node.
void replace_provenance_group_member(const std::shared_ptr< Node > &current_node, const std::shared_ptr< Node > &replacement_node)
Replace current_node with replacement_node and transfer tags.
virtual const op::AutoBroadcastSpec & get_autob() const
virtual size_t get_default_output_index() const
Returns the output of the default output, or throws if there is none.
virtual const type_info_t & get_type_info() const =0
Node(const Node &)
Copying a node.
const element::Type & get_element_type() const
Checks that there is exactly one output and returns its element type.
Output< Node > output(size_t output_index)
void add_control_dependency(std::shared_ptr< Node > node)
This node cannot execute until node executes.
std::vector< Input< const Node > > inputs() const
void clear_control_dependents()
Remove this node as a dependency from all dependent nodes.
Node & operator=(const Node &)
Assignment operator.
std::vector< Output< Node > > outputs()
void remove_control_dependency(std::shared_ptr< Node > node)
Remove the dependency of this node on node.
size_t no_default_index() const
Throws no default.
void add_node_control_dependents(std::shared_ptr< Node > source_node)
This node becomes a dependent of every node dependent on source_node.
const std::vector< Node * > & get_control_dependents() const
Get nodes dependent on this node.
virtual std::ostream & write_description(std::ostream &os, uint32_t depth=0) const
Writes a description of a node to a stream.
void set_friendly_name(const std::string &name)
Sets a friendly name for a node. This does not overwrite the unique name of the node and is retrieved...
void transfer_control_dependents(std::shared_ptr< Node > replacement)
This node's control dependencies are replaced by replacement.
const std::string & get_name() const
Get the unique name of the node.
virtual size_t get_version() const
Definition: node.hpp:425
const std::set< std::shared_ptr< Node > > & get_provenance_group_members() const
const Shape & get_output_shape(size_t i) const
Returns the shape for output i.
void set_output_size(size_t output_size)
Sets the number of outputs.
void set_argument(size_t position, const Output< Node > &argument)
Sets/replaces the arguments with new arguments.
void transfer_provenance_tags(const std::shared_ptr< Node > &replacement)
Transfer provenance tags to replacement.
const PartialShape & get_output_partial_shape(size_t i) const
Returns the partial shape for output i.
bool operator<(const Node &other) const
Use instance ids for comparison instead of memory addresses to improve determinism.
Definition: node.hpp:428
void set_input_is_relevant_to_value(size_t i, bool relevant=true)
Marks an input as being relevant or irrelevant to the output values of this node.
void safe_delete(NodeVector &nodes, bool recurse)
Moves nodes that would be deleted from inputs to nodes to avoid stack overflows on deep networks.
Output< const Node > get_default_output() const
const element::Type & get_output_element_type(size_t i) const
Returns the element type for output i.
std::vector< Output< Node > > input_values() const
virtual std::string description() const
Get the string name for the type of the node, such as Add or Multiply. The class name,...
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Definition: node_output.hpp:118
Definition: node_output.hpp:36
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Definition: variant.hpp:30
Adapters will see visitor.
Definition: attribute_adapter.hpp:194
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:40
Definition: element_type.hpp:61
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
NGRAPH_API ResultVector as_result_vector(const OutputVector &values)
Returns a ResultVector referencing values.
std::unordered_map< ngraph::Node *, std::shared_ptr< ngraph::Node > > NodeMap
Alias useful for cloning.
Definition: node.hpp:108
Implicit broadcast specification.
Definition: attr_types.hpp:323