activation_functions.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <memory>
20 #include <string>
21 
22 #include "ngraph/except.hpp"
23 #include "ngraph/node.hpp"
24 
#ifdef _WIN32
#pragma warning(push)
// C4100: unreferenced formal parameter.
#pragma warning(disable : 4100)
#endif

// NG_ATTRIBUTE_UNUSED marks a parameter as intentionally unused so that GCC does not
// warn about it. It only suppresses the diagnostic; it does NOT stop the optimizer from
// removing anything. On clang and every other compiler it expands to nothing.
//
// Both helper macros are #undef'd unconditionally before being (re)defined, so that a
// definition coming from some other header never triggers a macro-redefinition
// diagnostic. Previously only the GCC branch did this.
#undef NG_ATTRIBUTE_UNUSED
#undef UNUSED_PARAMETER

#if (defined(__GNUC__) && !defined(__clang__))
#define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
#define NG_ATTRIBUTE_UNUSED
#endif

// Written after a parameter name: marks the parameter unused AND gives it a default
// argument of 0, e.g. `float alpha UNUSED_PARAMETER` declares an optional float that
// defaults to 0.
#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
41 
42 namespace ngraph
43 {
44  namespace op
45  {
46  namespace util
47  {
48  namespace error
49  {
50  struct UnknownActivationFunction : ngraph_error
51  {
52  UnknownActivationFunction(const std::string& func_name)
53  : ngraph_error{"Unknown activation function: " + func_name}
54  {
55  }
56  };
57  }
58 
59  namespace detail
60  {
61  std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
62  float alpha UNUSED_PARAMETER,
63  float beta UNUSED_PARAMETER);
64  std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg,
65  float alpha UNUSED_PARAMETER,
66  float beta UNUSED_PARAMETER);
67  std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg,
68  float alpha UNUSED_PARAMETER,
69  float beta UNUSED_PARAMETER);
70  std::shared_ptr<Node>
71  hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
72  }
73 
74  using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&,
75  float,
76  float);
77 
78  ///
79  /// \brief Class representing activation function used in RNN cells.
80  ///
81  class NGRAPH_API ActivationFunction
82  {
83  public:
84  ActivationFunction(ActivationFunctionType f, float alpha, float beta);
85  ActivationFunction(ActivationFunctionType f, float alpha);
86  ActivationFunction(ActivationFunctionType f);
87  ActivationFunction() = default;
88 
89  ///
90  /// \brief Calls stored activation function with provided node argument.
91  ///
92  std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;
93 
94  void set_alpha(float alpha) { m_alpha = alpha; }
95  void set_beta(float beta) { m_beta = beta; }
96  private:
97  /// \brief Activation function wrapper.
98  ActivationFunctionType m_function;
99  /// \brief Activation function alpha parameter (may be unused).
100  float m_alpha;
101  /// \brief Activation function beta parameter (may be unused).
102  float m_beta;
103  };
104 
105  /// \brief Gets the activation function by name.
106  ///
107  /// \param[in] func_name The function name
108  ///
109  /// \throws UnknownActivationFunction When provided func_name is unknown.
110  ///
111  /// \return The activation function object.
112  ///
113  ActivationFunction get_activation_func_by_name(const std::string& func_name);
114  } // namespace util
115 
116  } // namespace op
117 
118 } // namespace ngraph
119 
#ifdef _WIN32
// Restore the MSVC warning state saved by the matching push at the top of this header.
#pragma warning(pop)
#endif

// Remove this header's helper macros so they do not leak into including files.
// #undef of a name that is not currently defined is ignored, so no #ifdef guard is needed.
#undef NG_ATTRIBUTE_UNUSED
#undef UNUSED_PARAMETER
ngraph::op::util::error::UnknownActivationFunction
Definition: activation_functions.hpp:51
ngraph::op::util::ActivationFunction::operator()
std::shared_ptr< Node > operator()(const std::shared_ptr< Node > &arg) const
Calls stored activation function with provided node argument.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::op::util::ActivationFunction
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82