21 #include "ngraph/axis_set.hpp"
22 #include "ngraph/op/constant.hpp"
23 #include "ngraph/op/op.hpp"
36 using SortType = TopKSortType;
37 using Mode = TopKMode;
58 const std::string& mode,
59 const std::string& sort,
72 virtual std::shared_ptr<Node>
73 clone_with_new_inputs(
const OutputVector& new_args)
const override;
82 void set_axis(
const int64_t axis);
83 Mode get_mode()
const {
return m_mode; }
84 void set_mode(
const Mode mode) { m_mode = mode; }
85 SortType get_sort_type()
const {
return m_sort; }
86 void set_sort_type(
const SortType sort) { m_sort = sort; }
87 element::Type get_index_element_type()
const {
return m_index_element_type; }
88 void set_index_element_type(
const element::Type& index_element_type)
90 m_index_element_type = index_element_type;
100 const HostTensorVector& inputs)
const override;
104 uint64_t m_normalized_axis;
109 virtual size_t read_k_from_constant_node(
const std::shared_ptr<Node>& node,
112 template <
typename T>
113 size_t validate_and_get_k(
const std::shared_ptr<op::Constant>& k_constant)
const;
114 Shape compute_output_shape(
const std::string& node_description,
116 const int64_t k)
const;
117 void set_axis(
const Rank input_rank,
const int64_t axis);
147 const std::string& mode,
148 const std::string& sort,
159 virtual std::shared_ptr<Node>
160 clone_with_new_inputs(
const OutputVector& new_args)
const override;
163 const HostTensorVector& inputs)
const override;
167 read_k_from_constant_node(
const std::shared_ptr<Node>& node,
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Class representing a dimension, which may be dynamic (undetermined until runtime).
Definition: dimension.hpp:35
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Definition: element_type.hpp:61
Root of all actual ops.
Definition: op.hpp:29
Computes indices and values of the k maximum/minimum values for each slice along the specified axis.
Definition: topk.hpp:34
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
virtual size_t get_version() const override
Definition: topk.hpp:75
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on `inputs`, writing the results into `outputs`.
TopK()=default
Constructs a TopK operation.
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is described by the i32 data type.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:40
size_t get_k() const
Returns the value of K, if available.
size_t get_default_output_index() const override
Returns the index of the default output, or throws if there is none.
Definition: topk.hpp:98
int64_t get_provided_axis() const
Returns axis value before normalization.
Definition: topk.hpp:81
uint64_t get_axis() const
Returns axis value after normalization.
Computes indices and values of the k maximum/minimum values for each slice along the specified axis.
Definition: topk.hpp:126
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on `inputs`, writing the results into `outputs`.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:129
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is described by the i32 data type.
TopK()=default
Constructs a TopK operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28