rnn_cell.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstddef>
8 #include <memory>
9 #include <string>
10 #include <vector>
11 
12 #include "ngraph/node.hpp"
13 #include "ngraph/op/util/activation_functions.hpp"
14 #include "ngraph/op/util/fused_op.hpp"
15 #include "ngraph/op/util/rnn_cell_base.hpp"
16 
17 namespace ngraph
18 {
19  namespace op
20  {
21  namespace v0
22  {
23  ///
24  /// \brief Class for single RNN cell node.
25  ///
26  /// \note It follows notation and equations defined as in ONNX standard:
27  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN
28  ///
29  /// \note It calculates following equations:
30  ///
31  /// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
32  ///
33  /// * - Is a dot product,
34  /// f - is activation functions.
35  ///
36  /// \note This class represents only single *cell* (for current time step)
37  /// and not the whole RNN Sequence layer
38  ///
39  /// \sa LSTMSequence, LSTMCell, GRUCell
40  ///
41  class NGRAPH_API RNNCell : public util::RNNCellBase
42  {
43  public:
44  static constexpr NodeTypeInfo type_info{"RNNCell", 0};
45  const NodeTypeInfo& get_type_info() const override { return type_info; }
46  RNNCell();
47  ///
48  /// \brief Constructs RNNCell node.
49  ///
50  /// \param[in] X The input tensor with shape: [batch_size,
51  /// input_size].
52  /// \param[in] initial_hidden_state The hidden state tensor at current time step
53  /// with shape: [batch_size, hidden_size].
54  /// \param[in] W The weight tensor with shape: [hidden_size,
55  /// input_size].
56  /// \param[in] R The recurrence weight tensor with shape:
57  /// [hidden_size, hidden_size].
58  /// \param[in] hidden_size The number of hidden units for recurrent cell.
59  /// \param[in] activations The vector of activation functions used inside
60  /// recurrent cell.
61  /// \param[in] activations_alpha The vector of alpha parameters for activation
62  /// functions in order respective to activation
63  /// list.
64  /// \param[in] activations_beta The vector of beta parameters for activation
65  /// functions in order respective to activation
66  /// list.
67  /// \param[in] clip The value defining clipping range [-clip,
68  /// clip] on input of activation functions.
69  ///
71  const Output<Node>& X,
72  const Output<Node>& initial_hidden_state,
73  const Output<Node>& W,
74  const Output<Node>& R,
75  std::size_t hidden_size,
76  const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
77  const std::vector<float>& activations_alpha = {},
78  const std::vector<float>& activations_beta = {},
79  float clip = 0.f);
80 
81  ///
82  /// \brief Constructs RNNCell node.
83  ///
84  /// \param[in] X The input tensor with shape: [batch_size,
85  /// input_size].
86  /// \param[in] initial_hidden_state The hidden state tensor at current time step
87  /// with shape: [batch_size, hidden_size].
88  /// \param[in] W The weight tensor with shape: [hidden_size,
89  /// input_size].
90  /// \param[in] R The recurrence weight tensor with shape:
91  /// [hidden_size, hidden_size].
92  /// \param[in] B The bias tensor for input gate with shape:
93  /// [hidden_size].
94  /// \param[in] hidden_size The number of hidden units for recurrent cell.
95  /// \param[in] activations The vector of activation functions used inside
96  /// recurrent cell.
97  /// \param[in] activations_alpha The vector of alpha parameters for activation
98  /// functions in order respective to activation
99  /// list.
100  /// \param[in] activations_beta The vector of beta parameters for activation
101  /// functions in order respective to activation
102  /// list.
103  /// \param[in] clip The value defining clipping range [-clip,
104  /// clip] on input of activation functions.
105  ///
107  const Output<Node>& X,
108  const Output<Node>& initial_hidden_state,
109  const Output<Node>& W,
110  const Output<Node>& R,
111  const Output<Node>& B,
112  std::size_t hidden_size,
113  const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
114  const std::vector<float>& activations_alpha = {},
115  const std::vector<float>& activations_beta = {},
116  float clip = 0.f);
117 
118  void validate_and_infer_types() override;
119  bool visit_attributes(AttributeVisitor& visitor) override;
120  std::shared_ptr<Node>
121  clone_with_new_inputs(const OutputVector& new_args) const override;
122 
123  private:
124  ///
125  /// \brief Creates the default bias input initialized with zeros.
126  ///
127  /// \return The object of Output class.
128  ///
129  Output<Node> get_default_bias_input() const;
130 
131  ///
132  /// \brief The Activation function f.
133  ///
134  util::ActivationFunction m_activation_f;
135 
136  static constexpr std::size_t s_gates_count{1};
137  };
138  } // namespace v0
139  } // namespace op
140 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:70
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for single RNN cell node.
Definition: rnn_cell.hpp:42
RNNCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs RNNCell node.
const NodeTypeInfo & get_type_info() const override
Definition: rnn_cell.hpp:45
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RNNCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs RNNCell node.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27