batch_norm.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 
9 #include "ngraph/deprecated.hpp"
10 #include "ngraph/node.hpp"
11 #include "ngraph/op/op.hpp"
12 
13 namespace ngraph
14 {
15  namespace op
16  {
17  namespace v0
18  {
19  class NGRAPH_API BatchNormInference : public Op
20  {
21  public:
22  NGRAPH_RTTI_DECLARATION;
23  BatchNormInference() = default;
24  /// \param input [., C, ...]
25  /// \param gamma gamma scaling for normalized value. [C]
26  /// \param beta bias added to the scaled normalized value [C]
27  /// \param mean value for mean normalization [C]
28  /// \param variance value for variance normalization [C]
29  /// \param epsilon Avoids divsion by 0 if input has 0 variance
31  const Output<Node>& gamma,
32  const Output<Node>& beta,
33  const Output<Node>& mean,
34  const Output<Node>& variance,
35  double epsilon);
36 
37  bool visit_attributes(AttributeVisitor& visitor) override;
38 
39  void validate_and_infer_types() override;
40 
41  double get_eps_value() const { return m_epsilon; }
42  void set_eps_value(double epsilon) { m_epsilon = epsilon; }
43  std::shared_ptr<Node>
44  clone_with_new_inputs(const OutputVector& new_args) const override;
45 
46  private:
47  static constexpr size_t INPUT_GAMMA = 0;
48  static constexpr size_t INPUT_BETA = 1;
49  static constexpr size_t INPUT_DATA = 2;
50  static constexpr size_t INPUT_MEAN = 3;
51  static constexpr size_t INPUT_VARIANCE = 4;
52 
53  double m_epsilon;
54  };
55  } // namespace v0
56  namespace v5
57  {
58  class NGRAPH_API BatchNormInference : public Op
59  {
60  public:
61  NGRAPH_RTTI_DECLARATION;
62  BatchNormInference() = default;
63  /// \param input [., C, ...]
64  /// \param gamma gamma scaling for normalized value. [C]
65  /// \param beta bias added to the scaled normalized value [C]
66  /// \param mean value for mean normalization [C]
67  /// \param variance value for variance normalization [C]
68  /// \param epsilon Avoids divsion by 0 if input has 0 variance
70  const Output<Node>& gamma,
71  const Output<Node>& beta,
72  const Output<Node>& mean,
73  const Output<Node>& variance,
74  double epsilon);
75 
76  bool visit_attributes(AttributeVisitor& visitor) override;
77 
78  void validate_and_infer_types() override;
79 
80  double get_eps_value() const { return m_epsilon; }
81  void set_eps_value(double epsilon) { m_epsilon = epsilon; }
82  std::shared_ptr<Node>
83  clone_with_new_inputs(const OutputVector& new_args) const override;
84 
85  private:
86  static constexpr size_t INPUT_DATA = 0;
87  static constexpr size_t INPUT_GAMMA = 1;
88  static constexpr size_t INPUT_BETA = 2;
89  static constexpr size_t INPUT_MEAN = 3;
90  static constexpr size_t INPUT_VARIANCE = 4;
91 
92  double m_epsilon;
93  };
94  } // namespace v5
95  } // namespace op
96 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Definition: batch_norm.hpp:20
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
BatchNormInference(const Output< Node > &input, const Output< Node > &gamma, const Output< Node > &beta, const Output< Node > &mean, const Output< Node > &variance, double epsilon)
Definition: batch_norm.hpp:59
BatchNormInference(const Output< Node > &input, const Output< Node > &gamma, const Output< Node > &beta, const Output< Node > &mean, const Output< Node > &variance, double epsilon)
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16