deformable_convolution.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/coordinate_diff.hpp"
8 #include "ngraph/op/op.hpp"
9 #include "ngraph/op/util/attr_types.hpp"
10 
11 namespace ngraph
12 {
13  namespace op
14  {
15  namespace v1
16  {
17  /// \brief DeformableConvolution operation.
18  class NGRAPH_API DeformableConvolution : public Op
19  {
20  public:
21  static constexpr NodeTypeInfo type_info{"DeformableConvolution", 1};
22  const NodeTypeInfo& get_type_info() const override { return type_info; }
23  /// \brief Constructs a conversion operation.
24  DeformableConvolution() = default;
25  /// \brief Constructs a conversion operation.
26  ///
27  /// \param arg Node that produces the input tensor.
28  /// \param offsets Node producing the deformable values tensor.
29  /// \param filters Node producing the filters(kernels) tensor with OIZYX
30  /// layout.
31  /// \param strides Convolution strides.
32  /// \param pads_begin Amount of padding to be added to the beginning along
33  /// each axis. For example in case of a 2D input the value
34  /// of (1, 2) means that 1 element will be added to the
35  /// top and 2 elements to the left.
36  /// \param pads_end Amount of padding to be added to the end along each
37  /// axis.
38  /// \param dilations The distance in width and height between the weights
39  /// in the filters tensor.
40  /// \param auto_pad Specifies how the automatic calculation of padding
41  /// should be done.
42  /// \param group The number of groups which both output and input
43  /// should be split into.
44  /// \param deformable_group The number of groups which deformable values and
45  /// output should be split into along the channel axis.
47  const Output<Node>& offsets,
48  const Output<Node>& filters,
49  const Strides& strides,
50  const CoordinateDiff& pads_begin,
51  const CoordinateDiff& pads_end,
52  const Strides& dilations,
53  const PadType& auto_pad = PadType::EXPLICIT,
54  const int64_t group = 1,
55  const int64_t deformable_group = 1);
56  bool visit_attributes(AttributeVisitor& visitor) override;
57 
58  void validate_and_infer_types() override;
59 
60  const Strides& get_strides() const { return m_strides; }
61  void set_strides(const Strides& strides) { m_strides = strides; }
62  const Strides& get_dilations() const { return m_dilations; }
63  void set_dilations(const Strides& dilations) { m_dilations = dilations; }
64  const CoordinateDiff& get_pads_begin() const { return m_pads_begin; }
65  void set_pads_begin(const CoordinateDiff& pads_begin) { m_pads_begin = pads_begin; }
66  const CoordinateDiff& get_pads_end() const { return m_pads_end; }
67  void set_pads_end(const CoordinateDiff& pads_end) { m_pads_end = pads_end; }
68  const PadType& get_auto_pad() const { return m_auto_pad; }
69  void set_auto_pad(const PadType& auto_pad) { m_auto_pad = auto_pad; }
70  int64_t get_group() const { return m_group; }
71  void set_group(const int64_t group) { m_group = group; }
72  int64_t get_deformable_group() const { return m_deformable_group; }
73  void set_deformable_group(const int64_t deformable_group)
74  {
75  m_deformable_group = deformable_group;
76  }
77  virtual std::shared_ptr<Node>
78  clone_with_new_inputs(const OutputVector& new_args) const override;
79 
80  protected:
81  Strides m_strides;
82  Strides m_dilations;
83  CoordinateDiff m_pads_begin;
84  CoordinateDiff m_pads_end;
85  PadType m_auto_pad;
86  int64_t m_group;
87  int64_t m_deformable_group;
88  };
89  } // namespace v1
90  } // namespace op
91 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A difference (signed) of tensor element coordinates.
Definition: coordinate_diff.hpp:18
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Strides for a tensor.
Definition: strides.hpp:18
Root of all actual ops.
Definition: op.hpp:17
DeformableConvolution operation.
Definition: deformable_convolution.hpp:19
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
DeformableConvolution()=default
Constructs a deformable convolution operation.
const NodeTypeInfo & get_type_info() const override
Definition: deformable_convolution.hpp:22
DeformableConvolution(const Output< Node > &arg, const Output< Node > &offsets, const Output< Node > &filters, const Strides &strides, const CoordinateDiff &pads_begin, const CoordinateDiff &pads_end, const Strides &dilations, const PadType &auto_pad=PadType::EXPLICIT, const int64_t group=1, const int64_t deformable_group=1)
Constructs a deformable convolution operation.
PadType
Padding Type used for Convolution and Pooling
Definition: attr_types.hpp:61
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27