bucketize.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/op/op.hpp"
20 
21 namespace ngraph
22 {
23  namespace op
24  {
25  namespace v3
26  {
27  /// \brief Operation that bucketizes the input based on boundaries
28  class NGRAPH_API Bucketize : public Op
29  {
30  public:
31  static constexpr NodeTypeInfo type_info{"Bucketize", 3};
32  const NodeTypeInfo& get_type_info() const override { return type_info; }
33  Bucketize() = default;
34  /// \brief Constructs a Bucketize node
35 
36  /// \param data Input data to bucketize
37  /// \param buckets 1-D of sorted unique boundaries for buckets
38  /// \param output_type Output tensor type, "i64" or "i32", defaults to i64
39  /// \param with_right_bound indicates whether bucket includes the right or left
40  /// edge of interval. default true = includes right edge
41  Bucketize(const Output<Node>& data,
42  const Output<Node>& buckets,
43  const element::Type output_type = element::i64,
44  const bool with_right_bound = true);
45 
46  virtual void validate_and_infer_types() override;
47  virtual bool visit_attributes(AttributeVisitor& visitor) override;
48 
49  virtual std::shared_ptr<Node>
50  clone_with_new_inputs(const OutputVector& inputs) const override;
51 
52  element::Type get_output_type() const { return m_output_type; }
53  void set_output_type(element::Type output_type) { m_output_type = output_type; }
54  // Overload collision with method on Node
55  using Node::set_output_type;
56 
57  bool get_with_right_bound() const { return m_with_right_bound; }
58  void set_with_right_bound(bool with_right_bound)
59  {
60  m_with_right_bound = with_right_bound;
61  }
62 
63  private:
64  element::Type m_output_type;
65  bool m_with_right_bound;
66  };
67  }
68  using v3::Bucketize;
69  }
70 }
ngraph::op::v3::Bucketize::validate_and_infer_types
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::v3::Bucketize::Bucketize
Bucketize(const Output< Node > &data, const Output< Node > &buckets, const element::Type output_type=element::i64, const bool with_right_bound=true)
Constructs a Bucketize node.
ngraph::element::Type
Definition: element_type.hpp:61
ngraph::op::v3::Bucketize::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: bucketize.hpp:32
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::v3::Bucketize
Operation that bucketizes the input based on boundaries.
Definition: bucketize.hpp:29
ngraph::op::Op
Root of all actual ops.
Definition: op.hpp:29