9 #include "ngraph/axis_set.hpp"
10 #include "ngraph/op/constant.hpp"
11 #include "ngraph/op/op.hpp"
24 using SortType = TopKSortType;
25 using Mode = TopKMode;
46 const std::string& mode,
47 const std::string& sort,
60 virtual std::shared_ptr<Node>
61 clone_with_new_inputs(
const OutputVector& new_args)
const override;
70 void set_axis(
const int64_t axis);
71 Mode get_mode()
const {
return m_mode; }
72 void set_mode(
const Mode mode) { m_mode = mode; }
73 SortType get_sort_type()
const {
return m_sort; }
74 void set_sort_type(
const SortType sort) { m_sort = sort; }
75 element::Type get_index_element_type()
const {
return m_index_element_type; }
76 void set_index_element_type(
const element::Type& index_element_type)
78 m_index_element_type = index_element_type;
88 const HostTensorVector& inputs)
const override;
93 uint64_t m_normalized_axis;
98 virtual size_t read_k_from_constant_node(
const std::shared_ptr<Node>& node,
101 template <
typename T>
102 size_t validate_and_get_k(
const std::shared_ptr<op::Constant>& k_constant)
const;
103 Shape compute_output_shape(
const std::string& node_description,
105 const int64_t k)
const;
106 void set_axis(
const Rank input_rank,
const int64_t axis);
136 const std::string& mode,
137 const std::string& sort,
148 virtual std::shared_ptr<Node>
149 clone_with_new_inputs(
const OutputVector& new_args)
const override;
152 const HostTensorVector& inputs)
const override;
157 read_k_from_constant_node(
const std::shared_ptr<Node>& node,
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Class representing a dimension, which may be dynamic (undetermined until runtime),...
Definition: dimension.hpp:23
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: element_type.hpp:51
Root of all actual ops.
Definition: op.hpp:17
Computes indices and values of the k maximum/minimum values for each slice along specified axis.
Definition: topk.hpp:22
bool has_evaluate() const override
Allows to get information about availability of evaluate method for the current operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
virtual size_t get_version() const override
Definition: topk.hpp:63
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
TopK()=default
Constructs a TopK operation.
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is de...
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:28
size_t get_k() const
Returns the value of K, if available.
size_t get_default_output_index() const override
Returns the output of the default output, or throws if there is none.
Definition: topk.hpp:86
int64_t get_provided_axis() const
Returns axis value before normalization.
Definition: topk.hpp:69
uint64_t get_axis() const
Returns axis value after normalization.
Computes indices and values of the k maximum/minimum values for each slice along specified axis.
Definition: topk.hpp:115
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:118
bool has_evaluate() const override
Allows to get information about availability of evaluate method for the current operation.
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is de...
TopK()=default
Constructs a TopK operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16