lstm_cell.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
28 
29 namespace ngraph
30 {
31  namespace op
32  {
/// \brief Order in which the gate blocks are laid out in LSTM weight tensors.
///
/// Each enumerator names a gate ordering; the letters presumably stand for the
/// Forget, Input, Cell and Output gates (TODO confirm against the op spec).
enum class LSTMWeightsFormat
{
    FICO = 0, ///< IE
    ICOF = 1, ///< PyTorch
    IFCO = 2, ///< DNNL, TF, MxNet
    IFOC = 3, ///< Caffe
    IOFC = 4, ///< ONNX
};
41 
42  namespace v0
43  {
44  ///
45  /// \brief Class for single lstm cell node.
46  ///
47  /// \note Following implementation supports:
48  /// \li \c peepholes Gers & Schmidhuber (2000)
49  /// https://ieeexplore.ieee.org/document/861302
50  /// \li Coupling input and forget gates.
51  ///
52  /// \note It calculates following equations:
53  ///
54  /// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
55  /// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
56  /// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
57  /// Ct = ft (.) Ct-1 + it (.) ct
58  /// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
59  /// Ht = ot (.) h(Ct)
60  ///
61  /// * - Is a dot product,
62  /// (.) - is a Hadamard product (element-wise),
63  /// f, g, h - are activation functions.
64  ///
65  /// \note This class represents only single *cell* (for current time step) and not
66  /// the whole LSTM Sequence layer
67  ///
68  /// \sa LSTMSequence, RNNCell, GRUCell
69  ///
70  class NGRAPH_API LSTMCell : public util::RNNCellBase
71  {
72  public:
73  static constexpr NodeTypeInfo type_info{"LSTMCell", 0};
74  const NodeTypeInfo& get_type_info() const override { return type_info; }
75  LSTMCell();
76  ///
77  /// \brief Constructs LSTMCell node.
78  ///
79  /// \param[in] X The input tensor with shape: [batch_size,
80  /// input_size].
81  /// \param[in] initial_hidden_state The hidden state tensor at current time step
82  /// with shape: [batch_size, hidden_size].
83  /// \param[in] initial_cell_state The cell state tensor at current time step
84  /// with shape: [batch_size, hidden_size].
85  /// \param[in] W The gate weights tensor with shape:
86  /// [4*hidden_size, input_size].
87  /// \param[in] R The recurrence weights tensor with shape:
88  /// [4*hidden_size, hidden_size].
89  /// \param[in] hidden_size The number of hidden units for recurrent cell.
90  /// \param[in] weights_format The order of gates in weights tensors. The
91  /// default format is IFCO since it is used by
92  /// DNNL.
93  /// \param[in] activations The vector of activation functions used inside
94  /// recurrent cell.
95  /// \param[in] activations_alpha The vector of alpha parameters for activation
96  /// functions in order respective to activation
97  /// list.
98  /// \param[in] activations_beta The vector of beta parameters for activation
99  /// functions in order respective to activation
100  /// list.
101  /// \param[in] clip The value defining clipping range [-clip,
102  /// clip] on input of activation functions.
103  /// \param[in] input_forget Controls coupling input and forget gates.
104  ///
106  const Output<Node>& initial_hidden_state,
107  const Output<Node>& initial_cell_state,
108  const Output<Node>& W,
109  const Output<Node>& R,
110  std::size_t hidden_size,
111  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
112  const std::vector<std::string>& activations =
113  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
114  const std::vector<float>& activations_alpha = {},
115  const std::vector<float>& activations_beta = {},
116  float clip = 0.f,
117  bool input_forget = false);
118 
119  ///
120  /// \brief Constructs LSTMCell node.
121  ///
122  /// \param[in] X The input tensor with shape: [batch_size,
123  /// input_size].
124  /// \param[in] initial_hidden_state The hidden state tensor at current time step
125  /// with shape: [batch_size, hidden_size].
126  /// \param[in] initial_cell_state The cell state tensor at current time step
127  /// with shape: [batch_size, hidden_size].
128  /// \param[in] W The weight tensor with shape: [4*hidden_size,
129  /// input_size].
130  /// \param[in] R The recurrence weight tensor with shape:
131  /// [4*hidden_size, hidden_size].
132  /// \param[in] B The bias tensor for gates with shape:
133  /// [4*hidden_size].
134  /// \param[in] hidden_size The number of hidden units for recurrent cell.
135  /// \param[in] weights_format The order of gates in weights tensors. The
136  /// default format is IFCO since it is used by
137  /// DNNL.
138  /// \param[in] activations The vector of activation functions used inside
139  /// recurrent cell.
140  /// \param[in] activations_alpha The vector of alpha parameters for activation
141  /// functions in order respective to activation
142  /// list.
143  /// \param[in] activations_beta The vector of beta parameters for activation
144  /// functions in order respective to activation
145  /// list.
146  /// \param[in] clip The value defining clipping range [-clip,
147  /// clip] on input of activation functions.
148  /// \param[in] input_forget Controls coupling input and forget gates.
149  ///
151  const Output<Node>& initial_hidden_state,
152  const Output<Node>& initial_cell_state,
153  const Output<Node>& W,
154  const Output<Node>& R,
155  const Output<Node>& B,
156  std::size_t hidden_size,
157  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
158  const std::vector<std::string>& activations =
159  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
160  const std::vector<float>& activations_alpha = {},
161  const std::vector<float>& activations_beta = {},
162  float clip = 0.f,
163  bool input_forget = false);
164 
165  ///
166  /// \brief Constructs LSTMCell node.
167  ///
168  /// \param[in] X The input tensor with shape: [batch_size,
169  /// input_size].
170  /// \param[in] initial_hidden_state The hidden state tensor at current time step
171  /// with shape: [batch_size, hidden_size].
172  /// \param[in] initial_cell_state The cell state tensor at current time step
173  /// with shape: [batch_size, hidden_size].
174  /// \param[in] W The weight tensor with shape: [4*hidden_size,
175  /// input_size].
176  /// \param[in] R The recurrence weight tensor with shape:
177  /// [4*hidden_size, hidden_size].
178  /// \param[in] B The bias tensor for gates with shape:
179  /// [4*hidden_size].
180  /// \param[in] P The weight tensor for peepholes with shape:
181  /// [3*hidden_size] - 3 equals to only iof gates.
182  /// The order is: input, output, forget gates.
183  /// \param[in] hidden_size The number of hidden units for recurrent cell.
184  /// \param[in] weights_format The order of gates in weights tensors. The
185  /// default format is IFCO since it is used by
186  /// DNNL.
187  /// \param[in] activations The vector of activation functions used inside
188  /// recurrent cell.
189  /// \param[in] activations_alpha The vector of alpha parameters for activation
190  /// functions in order respective to activation
191  /// list.
192  /// \param[in] activations_beta The vector of beta parameters for activation
193  /// functions in order respective to activation
194  /// list.
195  /// \param[in] clip The value defining clipping range [-clip,
196  /// clip] on input of activation functions.
197  /// \param[in] input_forget Controls coupling input and forget gates.
198  ///
200  const Output<Node>& initial_hidden_state,
201  const Output<Node>& initial_cell_state,
202  const Output<Node>& W,
203  const Output<Node>& R,
204  const Output<Node>& B,
205  const Output<Node>& P,
206  std::size_t hidden_size,
207  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
208  const std::vector<std::string>& activations =
209  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
210  const std::vector<float>& activations_alpha = {},
211  const std::vector<float>& activations_beta = {},
212  float clip = 0.f,
213  bool input_forget = false);
214 
215  virtual void validate_and_infer_types() override;
216  bool visit_attributes(AttributeVisitor& visitor) override;
217  virtual std::shared_ptr<Node>
218  clone_with_new_inputs(const OutputVector& new_args) const override;
219 
220  bool get_input_forget() const { return m_input_forget; }
221  LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
222  private:
223  ///
224  /// \brief Creates the default bias input initialized with zeros.
225  ///
226  /// \return The object of Output class.
227  ///
228  Output<Node> get_default_bias_input() const;
229 
230  ///
231  /// \brief Creates the default peepholes input initialized with zeros.
232  ///
233  /// \return The object of Output class.
234  ///
235  Output<Node> get_default_peepholes_input() const;
236  ///
237  /// \brief The Activation function f.
238  ///
239  util::ActivationFunction m_activation_f;
240  ///
241  /// \brief The Activation function g.
242  ///
243  util::ActivationFunction m_activation_g;
244  ///
245  /// \brief The Activation function h.
246  ///
247  util::ActivationFunction m_activation_h;
248  ///
249  /// \brief Controls whether to couple input and forget gates.
250  ///
251  bool m_input_forget = false;
252 
253  ///
254  /// \brief The order of gates in weights tensors.
255  ///
256  LSTMWeightsFormat m_weights_format;
257 
258  static constexpr std::size_t s_gates_count{4};
259  static constexpr std::size_t s_peepholes_count{3};
260  };
261  } // v0
262 
263  namespace v4
264  {
265  ///
266  /// \brief Class for single lstm cell node.
267  ///
268  /// \note Following implementation supports:
269  /// \li \c peepholes Gers & Schmidhuber (2000)
270  /// https://ieeexplore.ieee.org/document/861302
271  /// \li Coupling input and forget gates.
272  ///
273  /// \note It calculates following equations:
274  ///
275  /// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
276  /// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
277  /// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
278  /// Ct = ft (.) Ct-1 + it (.) ct
279  /// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
280  /// Ht = ot (.) h(Ct)
281  ///
282  /// * - Is a dot product,
283  /// (.) - is a Hadamard product (element-wise),
284  /// f, g, h - are activation functions.
285  ///
286  /// \note This class represents only single *cell* (for current time step) and not
287  /// the whole LSTM Sequence layer
288  ///
289  /// \sa LSTMSequence, RNNCell, GRUCell
290  ///
291  class NGRAPH_API LSTMCell : public util::RNNCellBase
292  {
293  public:
294  static constexpr NodeTypeInfo type_info{"LSTMCell", 4};
295  const NodeTypeInfo& get_type_info() const override { return type_info; }
296  LSTMCell();
297  ///
298  /// \brief Constructs LSTMCell node.
299  ///
300  /// \param[in] X The input tensor with shape: [batch_size,
301  /// input_size].
302  /// \param[in] initial_hidden_state The hidden state tensor at current time step
303  /// with shape: [batch_size, hidden_size].
304  /// \param[in] initial_cell_state The cell state tensor at current time step
305  /// with shape: [batch_size, hidden_size].
306  /// \param[in] W The gate weights tensor with shape:
307  /// [4*hidden_size, input_size].
308  /// \param[in] R The recurrence weights tensor with shape:
309  /// [4*hidden_size, hidden_size].
310  /// \param[in] hidden_size The number of hidden units for recurrent cell.
311  /// \param[in] activations The vector of activation functions used inside
312  /// recurrent cell.
313  /// \param[in] activations_alpha The vector of alpha parameters for activation
314  /// functions in order respective to activation
315  /// list.
316  /// \param[in] activations_beta The vector of beta parameters for activation
317  /// functions in order respective to activation
318  /// list.
319  /// \param[in] clip The value defining clipping range [-clip,
320  /// clip] on input of activation functions.
322  const Output<Node>& initial_hidden_state,
323  const Output<Node>& initial_cell_state,
324  const Output<Node>& W,
325  const Output<Node>& R,
326  std::size_t hidden_size,
327  const std::vector<std::string>& activations =
328  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
329  const std::vector<float>& activations_alpha = {},
330  const std::vector<float>& activations_beta = {},
331  float clip = 0.f);
332 
333  ///
334  /// \brief Constructs LSTMCell node.
335  ///
336  /// \param[in] X The input tensor with shape: [batch_size,
337  /// input_size].
338  /// \param[in] initial_hidden_state The hidden state tensor at current time step
339  /// with shape: [batch_size, hidden_size].
340  /// \param[in] initial_cell_state The cell state tensor at current time step
341  /// with shape: [batch_size, hidden_size].
342  /// \param[in] W The weight tensor with shape: [4*hidden_size,
343  /// input_size].
344  /// \param[in] R The recurrence weight tensor with shape:
345  /// [4*hidden_size, hidden_size].
346  /// \param[in] B The bias tensor for gates with shape:
347  /// [4*hidden_size].
348  /// \param[in] hidden_size The number of hidden units for recurrent cell.
349  /// \param[in] activations The vector of activation functions used inside
350  /// recurrent cell.
351  /// \param[in] activations_alpha The vector of alpha parameters for activation
352  /// functions in order respective to activation
353  /// list.
354  /// \param[in] activations_beta The vector of beta parameters for activation
355  /// functions in order respective to activation
356  /// list.
357  /// \param[in] clip The value defining clipping range [-clip,
358  /// clip] on input of activation functions.
359  ///
361  const Output<Node>& initial_hidden_state,
362  const Output<Node>& initial_cell_state,
363  const Output<Node>& W,
364  const Output<Node>& R,
365  const Output<Node>& B,
366  std::size_t hidden_size,
367  const std::vector<std::string>& activations =
368  std::vector<std::string>{"sigmoid", "tanh", "tanh"},
369  const std::vector<float>& activations_alpha = {},
370  const std::vector<float>& activations_beta = {},
371  float clip = 0.f);
372 
373  void validate_and_infer_types() override;
374 
375  bool visit_attributes(AttributeVisitor& visitor) override;
376  std::shared_ptr<Node>
377  clone_with_new_inputs(const OutputVector& new_args) const override;
378 
379  private:
380  ///
381  /// \brief Creates the default bias input initialized with zeros.
382  ///
383  /// \return The object of Output class.
384  ///
385  Output<Node> get_default_bias_input() const;
386 
387  ///
388  /// \brief The Activation function f.
389  ///
390  util::ActivationFunction m_activation_f;
391  ///
392  /// \brief The Activation function g.
393  ///
394  util::ActivationFunction m_activation_g;
395  ///
396  /// \brief The Activation function h.
397  ///
398  util::ActivationFunction m_activation_h;
399 
400  static constexpr std::size_t s_gates_count{4};
401  };
402  } // v4
403  } // namespace op
404 
405  NGRAPH_API
406  std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);
407 
408  template <>
409  class NGRAPH_API AttributeAdapter<op::LSTMWeightsFormat>
410  : public EnumAttributeAdapterBase<op::LSTMWeightsFormat>
411  {
412  public:
413  AttributeAdapter(op::LSTMWeightsFormat& value)
415  {
416  }
417 
418  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::LSTMWeightsFormat>", 1};
419  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
420  };
421 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Access an enum via a string.
Definition: attribute_adapter.hpp:178
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:67
Class for single lstm cell node.
Definition: lstm_cell.hpp:71
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const Output< Node > &P, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:74
Class for single lstm cell node.
Definition: lstm_cell.hpp:292
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs LSTMCell node.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:295
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39