lstm_cell.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstddef>
8 #include <memory>
9 #include <string>
10 #include <vector>
11 
12 #include "ngraph/node.hpp"
13 #include "ngraph/op/util/activation_functions.hpp"
14 #include "ngraph/op/util/fused_op.hpp"
15 #include "ngraph/op/util/rnn_cell_base.hpp"
16 
17 namespace ngraph
18 {
19  namespace op
20  {
/// \brief Order of the gates inside the LSTM weights tensors.
///
/// Each enumerator spells the gate order using the letters of the i, f, c, o
/// terms that appear in the LSTMCell equations. The trailing comment on every
/// enumerator names the framework(s) associated with that layout.
enum class LSTMWeightsFormat
{
    FICO, // IE
    ICOF, // PyTorch
    IFCO, // DNNL, TF, MxNet
    IFOC, // Caffe
    IOFC, // ONNX
};
29 
30  namespace v0
31  {
32  ///
33  /// \brief Class for single lstm cell node.
34  ///
35  /// \note Following implementation supports:
36  /// \li \c peepholes Gers & Schmidhuber (2000)
37  /// https://ieeexplore.ieee.org/document/861302
38  /// \li Coupling input and forget gates.
39  ///
40  /// \note It calculates following equations:
41  ///
42  /// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
43  /// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
44  /// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
45  /// Ct = ft (.) Ct-1 + it (.) ct
46  /// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
47  /// Ht = ot (.) h(Ct)
48  ///
49  /// * - Is a dot product,
50  /// (.) - is a Hadamard product (element-wise),
51  /// f, g, h - are activation functions.
52  ///
53  /// \note This class represents only single *cell* (for current time step) and not
54  /// the whole LSTM Sequence layer
55  ///
56  /// \sa LSTMSequence, RNNCell, GRUCell
57  ///
58  class NGRAPH_API LSTMCell : public util::RNNCellBase
59  {
60  public:
61  static constexpr NodeTypeInfo type_info{"LSTMCell", 0};
62  const NodeTypeInfo& get_type_info() const override { return type_info; }
63  LSTMCell();
64  ///
65  /// \brief Constructs LSTMCell node.
66  ///
67  /// \param[in] X The input tensor with shape: [batch_size,
68  /// input_size].
69  /// \param[in] initial_hidden_state The hidden state tensor at current time step
70  /// with shape: [batch_size, hidden_size].
71  /// \param[in] initial_cell_state The cell state tensor at current time step
72  /// with shape: [batch_size, hidden_size].
73  /// \param[in] W The gate weights tensor with shape:
74  /// [4*hidden_size, input_size].
75  /// \param[in] R The recurrence weights tensor with shape:
76  /// [4*hidden_size, hidden_size].
77  /// \param[in] hidden_size The number of hidden units for recurrent cell.
78  /// \param[in] weights_format The order of gates in weights tensors. The
79  /// default format is IFCO since it is used by
80  /// DNNL.
81  /// \param[in] activations The vector of activation functions used inside
82  /// recurrent cell.
83  /// \param[in] activations_alpha The vector of alpha parameters for activation
84  /// functions in order respective to activation
85  /// list.
86  /// \param[in] activations_beta The vector of beta parameters for activation
87  /// functions in order respective to activation
88  /// list.
89  /// \param[in] clip The value defining clipping range [-clip,
90  /// clip] on input of activation functions.
91  /// \param[in] input_forget Controls coupling input and forget gates.
92  ///
94  const Output<Node>& initial_hidden_state,
95  const Output<Node>& initial_cell_state,
96  const Output<Node>& W,
97  const Output<Node>& R,
98  std::size_t hidden_size,
99  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
100  const std::vector<std::string>& activations =
101  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
102  const std::vector<float>& activations_alpha = {},
103  const std::vector<float>& activations_beta = {},
104  float clip = 0.f,
105  bool input_forget = false);
106 
107  ///
108  /// \brief Constructs LSTMCell node.
109  ///
110  /// \param[in] X The input tensor with shape: [batch_size,
111  /// input_size].
112  /// \param[in] initial_hidden_state The hidden state tensor at current time step
113  /// with shape: [batch_size, hidden_size].
114  /// \param[in] initial_cell_state The cell state tensor at current time step
115  /// with shape: [batch_size, hidden_size].
116  /// \param[in] W The weight tensor with shape: [4*hidden_size,
117  /// input_size].
118  /// \param[in] R The recurrence weight tensor with shape:
119  /// [4*hidden_size, hidden_size].
120  /// \param[in] B The bias tensor for gates with shape:
121  /// [4*hidden_size].
122  /// \param[in] hidden_size The number of hidden units for recurrent cell.
123  /// \param[in] weights_format The order of gates in weights tensors. The
124  /// default format is IFCO since it is used by
125  /// DNNL.
126  /// \param[in] activations The vector of activation functions used inside
127  /// recurrent cell.
128  /// \param[in] activations_alpha The vector of alpha parameters for activation
129  /// functions in order respective to activation
130  /// list.
131  /// \param[in] activations_beta The vector of beta parameters for activation
132  /// functions in order respective to activation
133  /// list.
134  /// \param[in] clip The value defining clipping range [-clip,
135  /// clip] on input of activation functions.
136  /// \param[in] input_forget Controls coupling input and forget gates.
137  ///
139  const Output<Node>& initial_hidden_state,
140  const Output<Node>& initial_cell_state,
141  const Output<Node>& W,
142  const Output<Node>& R,
143  const Output<Node>& B,
144  std::size_t hidden_size,
145  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
146  const std::vector<std::string>& activations =
147  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
148  const std::vector<float>& activations_alpha = {},
149  const std::vector<float>& activations_beta = {},
150  float clip = 0.f,
151  bool input_forget = false);
152 
153  ///
154  /// \brief Constructs LSTMCell node.
155  ///
156  /// \param[in] X The input tensor with shape: [batch_size,
157  /// input_size].
158  /// \param[in] initial_hidden_state The hidden state tensor at current time step
159  /// with shape: [batch_size, hidden_size].
160  /// \param[in] initial_cell_state The cell state tensor at current time step
161  /// with shape: [batch_size, hidden_size].
162  /// \param[in] W The weight tensor with shape: [4*hidden_size,
163  /// input_size].
164  /// \param[in] R The recurrence weight tensor with shape:
165  /// [4*hidden_size, hidden_size].
166  /// \param[in] B The bias tensor for gates with shape:
167  /// [4*hidden_size].
168  /// \param[in] P The weight tensor for peepholes with shape:
169  /// [3*hidden_size] - 3 equals to only iof gates.
170  /// The order is: input, output, forget gates.
171  /// \param[in] hidden_size The number of hidden units for recurrent cell.
172  /// \param[in] weights_format The order of gates in weights tensors. The
173  /// default format is IFCO since it is used by
174  /// DNNL.
175  /// \param[in] activations The vector of activation functions used inside
176  /// recurrent cell.
177  /// \param[in] activations_alpha The vector of alpha parameters for activation
178  /// functions in order respective to activation
179  /// list.
180  /// \param[in] activations_beta The vector of beta parameters for activation
181  /// functions in order respective to activation
182  /// list.
183  /// \param[in] clip The value defining clipping range [-clip,
184  /// clip] on input of activation functions.
185  /// \param[in] input_forget Controls coupling input and forget gates.
186  ///
188  const Output<Node>& initial_hidden_state,
189  const Output<Node>& initial_cell_state,
190  const Output<Node>& W,
191  const Output<Node>& R,
192  const Output<Node>& B,
193  const Output<Node>& P,
194  std::size_t hidden_size,
195  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
196  const std::vector<std::string>& activations =
197  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
198  const std::vector<float>& activations_alpha = {},
199  const std::vector<float>& activations_beta = {},
200  float clip = 0.f,
201  bool input_forget = false);
202 
203  virtual void validate_and_infer_types() override;
204  bool visit_attributes(AttributeVisitor& visitor) override;
205  virtual std::shared_ptr<Node>
206  clone_with_new_inputs(const OutputVector& new_args) const override;
207 
208  bool get_input_forget() const { return m_input_forget; }
209  LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
210 
211  private:
212  ///
213  /// \brief Creates the default bias input initialized with zeros.
214  ///
215  /// \return The object of Output class.
216  ///
217  Output<Node> get_default_bias_input() const;
218 
219  ///
220  /// \brief Creates the default peepholes input initialized with zeros.
221  ///
222  /// \return The object of Output class.
223  ///
224  Output<Node> get_default_peepholes_input() const;
225  ///
226  /// \brief The Activation function f.
227  ///
228  util::ActivationFunction m_activation_f;
229  ///
230  /// \brief The Activation function g.
231  ///
232  util::ActivationFunction m_activation_g;
233  ///
234  /// \brief The Activation function h.
235  ///
236  util::ActivationFunction m_activation_h;
237  ///
238  /// \brief Controls whether to couple input and forget gates.
239  ///
240  bool m_input_forget = false;
241 
242  ///
243  /// \brief The order of gates in weights tensors.
244  ///
245  LSTMWeightsFormat m_weights_format;
246 
247  static constexpr std::size_t s_gates_count{4};
248  static constexpr std::size_t s_peepholes_count{3};
249  };
250  } // namespace v0
251 
252  namespace v4
253  {
254  ///
255  /// \brief Class for single lstm cell node.
256  ///
257  /// \note Following implementation supports:
258  /// \li \c peepholes Gers & Schmidhuber (2000)
259  /// https://ieeexplore.ieee.org/document/861302
260  /// \li Coupling input and forget gates.
261  ///
262  /// \note It calculates following equations:
263  ///
264  /// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
265  /// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
266  /// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
267  /// Ct = ft (.) Ct-1 + it (.) ct
268  /// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
269  /// Ht = ot (.) h(Ct)
270  ///
271  /// * - Is a dot product,
272  /// (.) - is a Hadamard product (element-wise),
273  /// f, g, h - are activation functions.
274  ///
275  /// \note This class represents only single *cell* (for current time step) and not
276  /// the whole LSTM Sequence layer
277  ///
278  /// \sa LSTMSequence, RNNCell, GRUCell
279  ///
280  class NGRAPH_API LSTMCell : public util::RNNCellBase
281  {
282  public:
283  static constexpr NodeTypeInfo type_info{"LSTMCell", 4};
284  const NodeTypeInfo& get_type_info() const override { return type_info; }
285  LSTMCell();
286  ///
287  /// \brief Constructs LSTMCell node.
288  ///
289  /// \param[in] X The input tensor with shape: [batch_size,
290  /// input_size].
291  /// \param[in] initial_hidden_state The hidden state tensor at current time step
292  /// with shape: [batch_size, hidden_size].
293  /// \param[in] initial_cell_state The cell state tensor at current time step
294  /// with shape: [batch_size, hidden_size].
295  /// \param[in] W The gate weights tensor with shape:
296  /// [4*hidden_size, input_size].
297  /// \param[in] R The recurrence weights tensor with shape:
298  /// [4*hidden_size, hidden_size].
299  /// \param[in] hidden_size The number of hidden units for recurrent cell.
300  /// \param[in] activations The vector of activation functions used inside
301  /// recurrent cell.
302  /// \param[in] activations_alpha The vector of alpha parameters for activation
303  /// functions in order respective to activation
304  /// list.
305  /// \param[in] activations_beta The vector of beta parameters for activation
306  /// functions in order respective to activation
307  /// list.
308  /// \param[in] clip The value defining clipping range [-clip,
309  /// clip] on input of activation functions.
311  const Output<Node>& initial_hidden_state,
312  const Output<Node>& initial_cell_state,
313  const Output<Node>& W,
314  const Output<Node>& R,
315  std::size_t hidden_size,
316  const std::vector<std::string>& activations =
317  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
318  const std::vector<float>& activations_alpha = {},
319  const std::vector<float>& activations_beta = {},
320  float clip = 0.f);
321 
322  ///
323  /// \brief Constructs LSTMCell node.
324  ///
325  /// \param[in] X The input tensor with shape: [batch_size,
326  /// input_size].
327  /// \param[in] initial_hidden_state The hidden state tensor at current time step
328  /// with shape: [batch_size, hidden_size].
329  /// \param[in] initial_cell_state The cell state tensor at current time step
330  /// with shape: [batch_size, hidden_size].
331  /// \param[in] W The weight tensor with shape: [4*hidden_size,
332  /// input_size].
333  /// \param[in] R The recurrence weight tensor with shape:
334  /// [4*hidden_size, hidden_size].
335  /// \param[in] B The bias tensor for gates with shape:
336  /// [4*hidden_size].
337  /// \param[in] hidden_size The number of hidden units for recurrent cell.
338  /// \param[in] activations The vector of activation functions used inside
339  /// recurrent cell.
340  /// \param[in] activations_alpha The vector of alpha parameters for activation
341  /// functions in order respective to activation
342  /// list.
343  /// \param[in] activations_beta The vector of beta parameters for activation
344  /// functions in order respective to activation
345  /// list.
346  /// \param[in] clip The value defining clipping range [-clip,
347  /// clip] on input of activation functions.
348  ///
350  const Output<Node>& initial_hidden_state,
351  const Output<Node>& initial_cell_state,
352  const Output<Node>& W,
353  const Output<Node>& R,
354  const Output<Node>& B,
355  std::size_t hidden_size,
356  const std::vector<std::string>& activations =
357  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
358  const std::vector<float>& activations_alpha = {},
359  const std::vector<float>& activations_beta = {},
360  float clip = 0.f);
361 
362  void validate_and_infer_types() override;
363 
364  bool visit_attributes(AttributeVisitor& visitor) override;
365  std::shared_ptr<Node>
366  clone_with_new_inputs(const OutputVector& new_args) const override;
367 
368  private:
369  ///
370  /// \brief Creates the default bias input initialized with zeros.
371  ///
372  /// \return The object of Output class.
373  ///
374  Output<Node> get_default_bias_input() const;
375 
376  ///
377  /// \brief The Activation function f.
378  ///
379  util::ActivationFunction m_activation_f;
380  ///
381  /// \brief The Activation function g.
382  ///
383  util::ActivationFunction m_activation_g;
384  ///
385  /// \brief The Activation function h.
386  ///
387  util::ActivationFunction m_activation_h;
388 
389  static constexpr std::size_t s_gates_count{4};
390  };
391  } // namespace v4
392  } // namespace op
393 
394  NGRAPH_API
395  std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);
396 
397  template <>
398  class NGRAPH_API AttributeAdapter<op::LSTMWeightsFormat>
399  : public EnumAttributeAdapterBase<op::LSTMWeightsFormat>
400  {
401  public:
402  AttributeAdapter(op::LSTMWeightsFormat& value)
404  {
405  }
406 
407  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::LSTMWeightsFormat>", 1};
408  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
409  };
410 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Access an enum via a string.
Definition: attribute_adapter.hpp:168
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:70
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for single lstm cell node.
Definition: lstm_cell.hpp:59
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const Output< Node > &P, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:62
Class for single lstm cell node.
Definition: lstm_cell.hpp:281
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs LSTMCell node.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:284
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27