gelu.hpp
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"

namespace ngraph
{
    namespace op
    {
        NGRAPH_SUPPRESS_DEPRECATED_START
        namespace v0
        {
            /// \brief Gaussian Error Linear Unit
            /// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) ))
            class NGRAPH_API Gelu : public ngraph::op::util::FusedOp
            {
            public:
                static constexpr NodeTypeInfo type_info{"Gelu", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                Gelu();
                /// \brief Constructs a Gelu operation.
                ///
                /// \param data Input tensor
                Gelu(const Output<Node>& data);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual OutputVector decompose_op() const override;

                void pre_validate_and_infer_types() override;

                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;
            };
        } // namespace v0
        using v0::Gelu;
        NGRAPH_SUPPRESS_DEPRECATED_END

        /// \brief Specifies the approximation to calculate Gelu
        enum class GeluApproximationMode
        {
            TANH,
            ERF
        };
        NGRAPH_API std::ostream& operator<<(std::ostream& s, const GeluApproximationMode& type);

        namespace v7
        {
            /// \brief Gaussian Error Linear Unit
            /// f(x) = 0.5 * x * (1 + erf( x / sqrt(2) )) for "approximation" = "erf"
            /// f(x) = 0.5 * x * (1 + tanh( sqrt(2 / pi) * (x + 0.044715 * x^3) )) for
            /// "approximation" = "tanh"
            class NGRAPH_API Gelu : public util::UnaryElementwiseArithmetic
            {
            public:
                NGRAPH_RTTI_DECLARATION;

                Gelu() = default;
                /// \brief Constructs a Gelu operation.
                ///
                /// \param data Input tensor
                /// \param mode Approximation mode
                Gelu(const Output<Node>& data,
                     GeluApproximationMode mode = GeluApproximationMode::ERF);

                bool visit_attributes(AttributeVisitor& visitor) override;

                void validate_and_infer_types() override;

                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) const override;
                bool has_evaluate() const override;

                std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                GeluApproximationMode get_approximation_mode() const;

            private:
                GeluApproximationMode m_approximation_mode = GeluApproximationMode::ERF;
            };
        } // namespace v7
    } // namespace op
    template <>
    class NGRAPH_API AttributeAdapter<op::GeluApproximationMode>
        : public EnumAttributeAdapterBase<op::GeluApproximationMode>
    {
    public:
        AttributeAdapter(op::GeluApproximationMode& value)
            : EnumAttributeAdapterBase<op::GeluApproximationMode>(value)
        {
        }

        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::GeluApproximationMode>",
                                                    0};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };
} // namespace ngraph
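
For orientation (not part of the header): a minimal usage sketch of the v7 op, assuming a 2021-era nGraph build where the standard "ngraph/op/parameter.hpp" and "ngraph/op/gelu.hpp" headers are available.

#include <memory>

#include "ngraph/op/gelu.hpp"
#include "ngraph/op/parameter.hpp"

int main()
{
    using namespace ngraph;

    // A 2x3 f32 placeholder for the input tensor.
    auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3});

    // v7 Gelu with the tanh approximation; omitting the second argument
    // selects GeluApproximationMode::ERF, the default declared above.
    auto gelu = std::make_shared<op::v7::Gelu>(data, op::GeluApproximationMode::TANH);

    return gelu->get_approximation_mode() == op::GeluApproximationMode::TANH ? 0 : 1;
}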
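
The two formulas in the class comments can be sanity-checked with a small standalone sketch (plain C++, independent of nGraph); the tanh form is the usual Hendrycks-Gimpel approximation and stays close to the exact erf form over typical activation ranges.

#include <cmath>
#include <cstdio>

// Exact form: 0.5 * x * (1 + erf(x / sqrt(2)))
double gelu_erf(double x)
{
    return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

// Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
double gelu_tanh(double x)
{
    const double pi = 3.14159265358979323846;
    return 0.5 * x * (1.0 + std::tanh(std::sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)));
}

int main()
{
    for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0})
        std::printf("x=%5.2f  erf=%9.6f  tanh=%9.6f\n", x, gelu_erf(x), gelu_tanh(x));
    return 0;
}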