batch_norm.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <memory>
20 
21 #include "ngraph/deprecated.hpp"
22 #include "ngraph/node.hpp"
23 #include "ngraph/op/op.hpp"
24 
25 namespace ngraph
26 {
27  namespace op
28  {
29  namespace v0
30  {
31  class NGRAPH_API BatchNormInference : public Op
32  {
33  public:
34  static constexpr NodeTypeInfo type_info{"BatchNormInference", 0};
35  const NodeTypeInfo& get_type_info() const override { return type_info; }
36  BatchNormInference() = default;
37  /// \param input [., C, ...]
38  /// \param gamma gamma scaling for normalized value. [C]
39  /// \param beta bias added to the scaled normalized value [C]
40  /// \param mean value for mean normalization [C]
41  /// \param variance value for variance normalization [C]
42  /// \param epsilon Avoids divsion by 0 if input has 0 variance
43  BatchNormInference(const Output<Node>& input,
44  const Output<Node>& gamma,
45  const Output<Node>& beta,
46  const Output<Node>& mean,
47  const Output<Node>& variance,
48  double epsilon);
49 
50  bool visit_attributes(AttributeVisitor& visitor) override;
51 
52  void validate_and_infer_types() override;
53 
54  double get_eps_value() const { return m_epsilon; }
55  void set_eps_value(double epsilon) { m_epsilon = epsilon; }
56  std::shared_ptr<Node>
57  clone_with_new_inputs(const OutputVector& new_args) const override;
58 
59  private:
60  static constexpr size_t INPUT_GAMMA = 0;
61  static constexpr size_t INPUT_BETA = 1;
62  static constexpr size_t INPUT_DATA = 2;
63  static constexpr size_t INPUT_MEAN = 3;
64  static constexpr size_t INPUT_VARIANCE = 4;
65 
66  double m_epsilon;
67  };
68  } // namespace v0
69  using v0::BatchNormInference;
70  }
71 }
ngraph::op::v0::BatchNormInference
Definition: batch_norm.hpp:32
ngraph::op::v0::BatchNormInference::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: batch_norm.hpp:35
ngraph::op::v0::BatchNormInference::validate_and_infer_types
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::v0::BatchNormInference::BatchNormInference
BatchNormInference(const Output< Node > &input, const Output< Node > &gamma, const Output< Node > &beta, const Output< Node > &mean, const Output< Node > &variance, double epsilon)
ngraph::op::Op
Root of all actual ops.
Definition: op.hpp:29