mvn.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/node.hpp"
8 #include "ngraph/op/op.hpp"
9 #include "ngraph/op/util/fused_op.hpp"
10 
11 namespace ngraph
12 {
13  namespace op
14  {
15  NGRAPH_SUPPRESS_DEPRECATED_START
16 
17  namespace v0
18  {
19  /// \brief Operator performing Mean Variance Normalization
20  ///
21  class NGRAPH_API MVN : public ngraph::op::util::FusedOp
22  {
23  public:
24  NGRAPH_RTTI_DECLARATION;
25 
26  MVN();
27  /// \brief Constructs an MVN operation.
28  ///
29  /// \param data Input tensor with data
30  /// \param normalize_variance flag that denotes whether to perform variance
31  /// normalization.
32  /// \param across_channels flag that denotes if mean values are shared across
33  /// channels.
34  /// \param eps the number to be added to the variance to avoid division by zero when
35  /// normalizing the value
36  ///
37  MVN(const Output<Node>& data,
38  bool across_channels = true,
39  bool normalize_variance = true,
40  double eps = 1e-9);
41 
42  /// \brief Constructs an MVN operation.
43  ///
44  /// \param data Input tensor with data
45  /// \param reduction_axes A list of axes, along which to reduce.
46  /// \param normalize_variance flag that denotes whether to perform variance
47  /// normalization.
48  /// \param eps the number to be added to the variance to avoid division by zero when
49  /// normalizing the value
50  ///
51  MVN(const Output<Node>& data,
52  AxisSet reduction_axes,
53  bool normalize_variance = true,
54  double eps = 1e-9);
55 
56  virtual OutputVector decompose_op() const override;
57 
58  virtual void validate_and_infer_types() override;
59 
60  virtual bool visit_attributes(AttributeVisitor& visitor) override;
61 
62  virtual std::shared_ptr<Node>
63  clone_with_new_inputs(const OutputVector& new_args) const override;
64 
65  double get_eps() const { return m_eps; }
66  bool get_across_channels() const { return m_across_channels; }
67  bool get_normalize_variance() const { return m_normalize_variance; }
68  AxisSet get_reduction_axes() const { return m_reduction_axes; }
69  void set_reduction_axes(AxisSet axes) { m_reduction_axes = axes; }
70 
71  private:
72  double m_eps = 1e-9;
73  bool m_across_channels;
74  bool m_normalize_variance;
75  AxisSet m_reduction_axes;
76  };
77  } // namespace v0
78  using v0::MVN;
79 
80  NGRAPH_SUPPRESS_DEPRECATED_END
81 
/// \brief Specifies how eps is applied in MVN.
enum class MVNEpsMode
{
    INSIDE_SQRT = 0,  ///< Apply eps inside sqrt
    OUTSIDE_SQRT = 1, ///< Apply eps outside sqrt
};
90 
91  NGRAPH_API
92  std::ostream& operator<<(std::ostream& s, const MVNEpsMode& type);
93 
94  namespace v6
95  {
96  /// \brief Operator performing Mean Variance Normalization
97  ///
98  class NGRAPH_API MVN : public ngraph::op::Op
99  {
100  public:
101  NGRAPH_RTTI_DECLARATION;
102 
103  MVN() = default;
104  /// \brief Constructs an MVN operation.
105  ///
106  /// \param data Input tensor with data
107  /// \param reduction_axes A list of axes, along which to reduce.
108  /// \param normalize_variance flag that denotes whether to perform variance
109  /// normalization.
110  /// \param eps the number to be added to the variance to avoid division by zero when
111  /// normalizing the value
112  /// \param eps_mode the mode of applying epsilon
113  ///
114  MVN(const Output<Node>& data,
115  const Output<Node>& reduction_axes,
116  bool normalize_variance,
117  float eps,
118  MVNEpsMode eps_mode);
119 
120  bool visit_attributes(AttributeVisitor& visitor) override;
121  void validate_and_infer_types() override;
122 
123  std::shared_ptr<Node>
124  clone_with_new_inputs(const OutputVector& new_args) const override;
125 
126  float get_eps() const { return m_eps; }
127  bool get_normalize_variance() const { return m_normalize_variance; }
128  MVNEpsMode get_eps_mode() const { return m_eps_mode; }
129 
130  private:
131  bool m_normalize_variance = true;
132  float m_eps = (float)1e-6;
133  MVNEpsMode m_eps_mode = MVNEpsMode::INSIDE_SQRT;
134  };
135  } // namespace v6
136  } // namespace op
137 
138  template <>
139  class NGRAPH_API AttributeAdapter<op::MVNEpsMode>
140  : public EnumAttributeAdapterBase<op::MVNEpsMode>
141  {
142  public:
145  {
146  }
147 
148  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::MVNEpsMode>", 0};
149  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
150  };
151 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A set of axes.
Definition: axis_set.hpp:19
Access an enum via a string.
Definition: attribute_adapter.hpp:168
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Operator performing Mean Variance Normalization.
Definition: mvn.hpp:22
MVN(const Output< Node > &data, AxisSet reduction_axes, bool normalize_variance=true, double eps=1e-9)
Constructs an MVN operation.
MVN(const Output< Node > &data, bool across_channels=true, bool normalize_variance=true, double eps=1e-9)
Constructs an MVN operation.
Operator performing Mean Variance Normalization.
Definition: mvn.hpp:99
MVN(const Output< Node > &data, const Output< Node > &reduction_axes, bool normalize_variance, float eps, MVNEpsMode eps_mode)
Constructs an MVN operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
MVNEpsMode
Specifies how eps is applied in MVN.
Definition: mvn.hpp:84
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27