detection_output.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 
9 namespace ngraph
10 {
11  namespace op
12  {
/// \brief Attribute bundle passed to the DetectionOutput operation.
///
/// Every member has an in-class initializer, so a default-constructed
/// instance (e.g. the `m_attrs` member of a default-constructed
/// DetectionOutput) never holds an indeterminate value.
/// NOTE(review): the precise meaning of each attribute is defined by the
/// DetectionOutput operation specification, not by this header.
struct DetectionOutputAttrs
{
    /// Number of classes. Value-initialized to 0; callers are expected to
    /// set it explicitly (it had no default before, so no caller can rely
    /// on any particular default value).
    int num_classes = 0;
    /// Label id treated as background (default 0).
    int background_label_id = 0;
    /// Default -1; presumably means "no limit" — confirm against the op spec.
    int top_k = -1;
    bool variance_encoded_in_target = false;
    /// Empty by default.
    std::vector<int> keep_top_k;
    /// Box encoding type; defaults to the Caffe CORNER encoding string.
    std::string code_type = std::string{"caffe.PriorBoxParameter.CORNER"};
    bool share_location = true;
    /// NMS overlap threshold. Value-initialized to 0 so it is never read
    /// uninitialized; callers are expected to set it explicitly.
    float nms_threshold = 0;
    float confidence_threshold = 0;
    bool clip_after_nms = false;
    bool clip_before_nms = false;
    bool decrease_label_id = false;
    bool normalized = false;
    /// Input image height/width (default 1).
    /// NOTE(review): presumably only consulted when `normalized` is false — confirm.
    size_t input_height = 1;
    size_t input_width = 1;
    float objectness_score = 0;
};
32 
33  namespace v0
34  {
35  /// \brief Layer which performs non-max suppression to
36  /// generate detection output using location and confidence predictions
37  class NGRAPH_API DetectionOutput : public Op
38  {
39  public:
40  static constexpr NodeTypeInfo type_info{"DetectionOutput", 0};
41  const NodeTypeInfo& get_type_info() const override { return type_info; }
42  DetectionOutput() = default;
43  /// \brief Constructs a DetectionOutput operation
44  ///
45  /// \param box_logits Box logits
46  /// \param class_preds Class predictions
47  /// \param proposals Proposals
48  /// \param aux_class_preds Auxilary class predictions
49  /// \param aux_box_preds Auxilary box predictions
50  /// \param attrs Detection Output attributes
51  DetectionOutput(const Output<Node>& box_logits,
52  const Output<Node>& class_preds,
53  const Output<Node>& proposals,
54  const Output<Node>& aux_class_preds,
55  const Output<Node>& aux_box_preds,
56  const DetectionOutputAttrs& attrs);
57 
58  /// \brief Constructs a DetectionOutput operation
59  ///
60  /// \param box_logits Box logits
61  /// \param class_preds Class predictions
62  /// \param proposals Proposals
63  /// \param attrs Detection Output attributes
64  DetectionOutput(const Output<Node>& box_logits,
65  const Output<Node>& class_preds,
66  const Output<Node>& proposals,
67  const DetectionOutputAttrs& attrs);
68 
69  void validate_and_infer_types() override;
70 
71  virtual std::shared_ptr<Node>
72  clone_with_new_inputs(const OutputVector& new_args) const override;
73 
74  const DetectionOutputAttrs& get_attrs() const { return m_attrs; }
75  virtual bool visit_attributes(AttributeVisitor& visitor) override;
76 
77  private:
78  DetectionOutputAttrs m_attrs;
79  };
80  } // namespace v0
81  using v0::DetectionOutput;
82  } // namespace op
83 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Layer which performs non-max suppression to generate detection output using location and confidence predictions.
Definition: detection_output.hpp:38
DetectionOutput(const Output< Node > &box_logits, const Output< Node > &class_preds, const Output< Node > &proposals, const Output< Node > &aux_class_preds, const Output< Node > &aux_box_preds, const DetectionOutputAttrs &attrs)
Constructs a DetectionOutput operation.
DetectionOutput(const Output< Node > &box_logits, const Output< Node > &class_preds, const Output< Node > &proposals, const DetectionOutputAttrs &attrs)
Constructs a DetectionOutput operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
const NodeTypeInfo & get_type_info() const override
Definition: detection_output.hpp:41
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27
Definition: detection_output.hpp:14