activation_functions.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 #include <string>
9 
10 #include "ngraph/except.hpp"
11 #include "ngraph/node.hpp"
12 
#ifdef _WIN32
#pragma warning(push)
// C4100: "unreferenced formal parameter". The previous warning state is
// restored by the matching `#pragma warning(pop)` at the end of this header.
#pragma warning(disable : 4100)
#endif

// NG_ATTRIBUTE_UNUSED marks a parameter as possibly unused so that GCC does not
// warn about it. It only silences the diagnostic; it does not stop the compiler
// from optimizing the parameter away. Non-GCC compilers (including clang) get an
// empty expansion.
// Any pre-existing definitions are dropped first so that every branch below
// defines the macros without a redefinition warning; both macros are #undef'd
// again at the end of this header.
#undef NG_ATTRIBUTE_UNUSED
#undef UNUSED_PARAMETER
#if (defined(__GNUC__) && !defined(__clang__))
#define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define NG_ATTRIBUTE_UNUSED
#endif

// For use after a parameter name in a declaration, e.g. `float alpha UNUSED_PARAMETER`:
// applies NG_ATTRIBUTE_UNUSED and gives the parameter a default value of 0.
#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
29 
30 namespace ngraph
31 {
32  namespace op
33  {
34  namespace util
35  {
36  namespace error
37  {
39  {
40  UnknownActivationFunction(const std::string& func_name)
41  : ngraph_error{"Unknown activation function: " + func_name}
42  {
43  }
44  };
45  } // namespace error
46 
47  namespace detail
48  {
49  std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
50  float alpha UNUSED_PARAMETER,
51  float beta UNUSED_PARAMETER);
52  std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg,
53  float alpha UNUSED_PARAMETER,
54  float beta UNUSED_PARAMETER);
55  std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg,
56  float alpha UNUSED_PARAMETER,
57  float beta UNUSED_PARAMETER);
58  std::shared_ptr<Node>
59  hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
60  } // namespace detail
61 
62  using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&,
63  float,
64  float);
65 
66  ///
67  /// \brief Class representing activation function used in RNN cells.
68  ///
69  class NGRAPH_API ActivationFunction
70  {
71  public:
72  ActivationFunction(ActivationFunctionType f, float alpha, float beta);
73  ActivationFunction(ActivationFunctionType f, float alpha);
74  ActivationFunction(ActivationFunctionType f);
75  ActivationFunction() = default;
76 
77  ///
78  /// \brief Calls stored activation function with provided node argument.
79  ///
80  std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;
81 
82  void set_alpha(float alpha) { m_alpha = alpha; }
83  void set_beta(float beta) { m_beta = beta; }
84 
85  private:
86  /// \brief Activation function wrapper.
87  ActivationFunctionType m_function;
88  /// \brief Activation function alpha parameter (may be unused).
89  float m_alpha;
90  /// \brief Activation function beta parameter (may be unused).
91  float m_beta;
92  };
93 
94  /// \brief Gets the activation function by name.
95  ///
96  /// \param[in] func_name The function name
97  ///
98  /// \throws UnknownActivationFunction When provided func_name is unknown.
99  ///
100  /// \return The activation function object.
101  ///
102  ActivationFunction get_activation_func_by_name(const std::string& func_name);
103  } // namespace util
104 
105  } // namespace op
106 
107 } // namespace ngraph
108 
#ifdef _WIN32
// Restore the MSVC warning state saved by `#pragma warning(push)` at the top.
#pragma warning(pop)
#endif

// Remove the helper macros so they do not leak into code including this header.
// #undef of a name that is not defined is a no-op, so no #ifdef guard is needed.
#undef UNUSED_PARAMETER
#undef NG_ATTRIBUTE_UNUSED
Base error for ngraph runtime errors.
Definition: except.hpp:16
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:70
std::shared_ptr< Node > operator()(const std::shared_ptr< Node > &arg) const
Calls stored activation function with provided node argument.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: activation_functions.hpp:39