topk.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 
9 #include "ngraph/axis_set.hpp"
10 #include "ngraph/op/constant.hpp"
11 #include "ngraph/op/op.hpp"
12 
13 namespace ngraph
14 {
15  namespace op
16  {
17  namespace v1
18  {
19  /// \brief Computes indices and values of the k maximum/minimum values
20  /// for each slice along specified axis.
21  class NGRAPH_API TopK : public Op
22  {
23  public:
24  using SortType = TopKSortType;
25  using Mode = TopKMode;
26 
27  static constexpr NodeTypeInfo type_info{"TopK", 1};
28  const NodeTypeInfo& get_type_info() const override { return type_info; }
29  /// \brief Constructs a TopK operation
30  TopK() = default;
31  /// \brief Constructs a TopK operation with two outputs: values and indices.
32  /// By default the indices output is described by i32 data type.
33  ///
34  /// \param data The input tensor
35  /// \param k Specifies how many maximum/minimum elements should be computed
36  /// (note: scalar input tensor)
37  /// \param axis The axis along which to compute top k indices
38  /// \param mode Specifies which operation (min or max) is used to select
39  /// the biggest element of two.
40  /// \param sort Specifies order of output elements and/or indices
41  /// Accepted values: none, index, value
42  /// \param index_element_type Specyfies type of produced indices
43  TopK(const Output<Node>& data,
44  const Output<Node>& k,
45  const int64_t axis,
46  const std::string& mode,
47  const std::string& sort,
48  const element::Type& index_element_type = element::i32);
49 
50  TopK(const Output<Node>& data,
51  const Output<Node>& k,
52  const int64_t axis,
53  const Mode mode,
54  const SortType sort,
55  const element::Type& index_element_type = element::i32);
56 
57  bool visit_attributes(AttributeVisitor& visitor) override;
58  void validate_and_infer_types() override;
59 
60  virtual std::shared_ptr<Node>
61  clone_with_new_inputs(const OutputVector& new_args) const override;
62 
63  virtual size_t get_version() const override { return 1; }
64  /// \brief Returns axis value after normalization
65  /// \note If input rank required to normalization is dynamic, the exception is
66  /// thrown
67  uint64_t get_axis() const;
68  /// \brief Returns axis value before normalization
69  int64_t get_provided_axis() const { return m_axis; }
70  void set_axis(const int64_t axis);
71  Mode get_mode() const { return m_mode; }
72  void set_mode(const Mode mode) { m_mode = mode; }
73  SortType get_sort_type() const { return m_sort; }
74  void set_sort_type(const SortType sort) { m_sort = sort; }
75  element::Type get_index_element_type() const { return m_index_element_type; }
76  void set_index_element_type(const element::Type& index_element_type)
77  {
78  m_index_element_type = index_element_type;
79  }
80  /// \brief Returns the value of K, if available
81  ///
82  /// \note If the second input to this op is a constant, the value is retrieved
83  /// and returned. If the input is not constant(dynamic) this method returns 0
84  size_t get_k() const;
85  void set_k(size_t k);
86  size_t get_default_output_index() const override { return no_default_index(); }
87  bool evaluate(const HostTensorVector& outputs,
88  const HostTensorVector& inputs) const override;
89  bool has_evaluate() const override;
90 
91  protected:
92  int64_t m_axis;
93  uint64_t m_normalized_axis;
94  Mode m_mode;
95  SortType m_sort;
96  element::Type m_index_element_type{element::i32};
97 
98  virtual size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
99  const element::Type& k_element_type) const;
100 
101  template <typename T>
102  size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
103  Shape compute_output_shape(const std::string& node_description,
104  const PartialShape input_partial_shape,
105  const int64_t k) const;
106  void set_axis(const Rank input_rank, const int64_t axis);
107  };
108  } // namespace v1
109 
110  namespace v3
111  {
112  /// \brief Computes indices and values of the k maximum/minimum values
113  /// for each slice along specified axis.
114  class NGRAPH_API TopK : public v1::TopK
115  {
116  public:
117  static constexpr NodeTypeInfo type_info{"TopK", 3};
118  const NodeTypeInfo& get_type_info() const override { return type_info; }
119  /// \brief Constructs a TopK operation
120  TopK() = default;
121  /// \brief Constructs a TopK operation with two outputs: values and indices.
122  /// By default the indices output is described by i32 data type.
123  ///
124  /// \param data The input tensor
125  /// \param k Specifies how many maximum/minimum elements should be computed
126  /// (note: scalar input tensor)
127  /// \param axis The axis along which to compute top k indices
128  /// \param mode Specifies which operation (min or max) is used to select
129  /// the biggest element of two.
130  /// \param sort Specifies order of output elements and/or indices
131  /// Accepted values: none, index, value
132  /// \param index_element_type Specyfies type of produced indices
133  TopK(const Output<Node>& data,
134  const Output<Node>& k,
135  const int64_t axis,
136  const std::string& mode,
137  const std::string& sort,
138  const element::Type& index_element_type = element::i32);
139 
140  TopK(const Output<Node>& data,
141  const Output<Node>& k,
142  const int64_t axis,
143  const Mode mode,
144  const SortType sort,
145  const element::Type& index_element_type = element::i32);
146  bool visit_attributes(AttributeVisitor& visitor) override;
147  void validate_and_infer_types() override;
148  virtual std::shared_ptr<Node>
149  clone_with_new_inputs(const OutputVector& new_args) const override;
150 
151  bool evaluate(const HostTensorVector& outputs,
152  const HostTensorVector& inputs) const override;
153  bool has_evaluate() const override;
154 
155  protected:
156  virtual size_t
157  read_k_from_constant_node(const std::shared_ptr<Node>& node,
158  const element::Type& k_element_type) const override;
159  };
160  } // namespace v3
161  } // namespace op
162 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Class representing a dimension, which may be dynamic (undetermined until runtime),...
Definition: dimension.hpp:23
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: element_type.hpp:51
Root of all actual ops.
Definition: op.hpp:17
Computes indices and values of the k maximum/minimum values for each slice along specified axis.
Definition: topk.hpp:22
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
virtual size_t get_version() const override
Definition: topk.hpp:63
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
TopK()=default
Constructs a TopK operation.
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is described by i32 data type.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:28
size_t get_k() const
Returns the value of K, if available.
size_t get_default_output_index() const override
Returns the index of the default output, or throws if there is none.
Definition: topk.hpp:86
int64_t get_provided_axis() const
Returns axis value before normalization.
Definition: topk.hpp:69
uint64_t get_axis() const
Returns axis value after normalization.
Computes indices and values of the k maximum/minimum values for each slice along specified axis.
Definition: topk.hpp:115
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:118
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is described by i32 data type.
TopK()=default
Constructs a TopK operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27