rnn_cell.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
28 
29 namespace ngraph
30 {
31  namespace op
32  {
33  namespace v0
34  {
35  ///
36  /// \brief Class for single RNN cell node.
37  ///
38  /// \note It follows notation and equations defined as in ONNX standard:
39  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
40  ///
41  /// \note It calculates following equations:
42  ///
43  /// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
44  ///
45  /// * - Is a dot product,
46  /// f - is activation functions.
47  ///
48  /// \note This class represents only single *cell* (for current time step)
49  /// and not the whole RNN Sequence layer
50  ///
51  /// \sa LSTMSequence, LSTMCell, GRUCell
52  ///
53  class NGRAPH_API RNNCell : public util::RNNCellBase
54  {
55  public:
56  static constexpr NodeTypeInfo type_info{"RNNCell", 0};
57  const NodeTypeInfo& get_type_info() const override { return type_info; }
58  RNNCell();
59  ///
60  /// \brief Constructs RNNCell node.
61  ///
62  /// \param[in] X The input tensor with shape: [batch_size,
63  /// input_size].
64  /// \param[in] initial_hidden_state The hidden state tensor at current time step
65  /// with shape: [batch_size, hidden_size].
66  /// \param[in] W The weight tensor with shape: [hidden_size,
67  /// input_size].
68  /// \param[in] R The recurrence weight tensor with shape:
69  /// [hidden_size, hidden_size].
70  /// \param[in] hidden_size The number of hidden units for recurrent cell.
71  /// \param[in] activations The vector of activation functions used inside
72  /// recurrent cell.
73  /// \param[in] activations_alpha The vector of alpha parameters for activation
74  /// functions in order respective to activation
75  /// list.
76  /// \param[in] activations_beta The vector of beta parameters for activation
77  /// functions in order respective to activation
78  /// list.
79  /// \param[in] clip The value defining clipping range [-clip,
80  /// clip] on input of activation functions.
81  ///
83  const Output<Node>& X,
84  const Output<Node>& initial_hidden_state,
85  const Output<Node>& W,
86  const Output<Node>& R,
87  std::size_t hidden_size,
88  const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
89  const std::vector<float>& activations_alpha = {},
90  const std::vector<float>& activations_beta = {},
91  float clip = 0.f);
92 
93  ///
94  /// \brief Constructs RNNCell node.
95  ///
96  /// \param[in] X The input tensor with shape: [batch_size,
97  /// input_size].
98  /// \param[in] initial_hidden_state The hidden state tensor at current time step
99  /// with shape: [batch_size, hidden_size].
100  /// \param[in] W The weight tensor with shape: [hidden_size,
101  /// input_size].
102  /// \param[in] R The recurrence weight tensor with shape:
103  /// [hidden_size, hidden_size].
104  /// \param[in] B The bias tensor for input gate with shape:
105  /// [hidden_size].
106  /// \param[in] hidden_size The number of hidden units for recurrent cell.
107  /// \param[in] activations The vector of activation functions used inside
108  /// recurrent cell.
109  /// \param[in] activations_alpha The vector of alpha parameters for activation
110  /// functions in order respective to activation
111  /// list.
112  /// \param[in] activations_beta The vector of beta parameters for activation
113  /// functions in order respective to activation
114  /// list.
115  /// \param[in] clip The value defining clipping range [-clip,
116  /// clip] on input of activation functions.
117  ///
119  const Output<Node>& X,
120  const Output<Node>& initial_hidden_state,
121  const Output<Node>& W,
122  const Output<Node>& R,
123  const Output<Node>& B,
124  std::size_t hidden_size,
125  const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
126  const std::vector<float>& activations_alpha = {},
127  const std::vector<float>& activations_beta = {},
128  float clip = 0.f);
129 
130  void validate_and_infer_types() override;
131  bool visit_attributes(AttributeVisitor& visitor) override;
132  std::shared_ptr<Node>
133  clone_with_new_inputs(const OutputVector& new_args) const override;
134 
135  private:
136  ///
137  /// \brief Creates the default bias input initialized with zeros.
138  ///
139  /// \return The object of Output class.
140  ///
141  Output<Node> get_default_bias_input() const;
142 
143  ///
144  /// \brief The Activation function f.
145  ///
146  util::ActivationFunction m_activation_f;
147 
148  static constexpr std::size_t s_gates_count{1};
149  };
150  }
151  } // namespace op
152 } // namespace ngraph
ngraph::op::v0::RNNCell
Class for single RNN cell node.
Definition: rnn_cell.hpp:54
ngraph::op::util::RNNCellBase
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:66
ngraph::op::v0::RNNCell::validate_and_infer_types
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::v0::RNNCell::RNNCell
RNNCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs RNNCell node.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::op::v0::RNNCell::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: rnn_cell.hpp:57
ngraph::op::v0::RNNCell::RNNCell
RNNCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs RNNCell node.
ngraph::op::util::ActivationFunction
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70