lstm_sequence.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstddef>
8 #include <cstdint>
9 #include <memory>
10 #include <string>
11 #include <vector>
12 
13 #include "ngraph/node.hpp"
14 #include "ngraph/op/constant.hpp"
15 #include "ngraph/op/lstm_cell.hpp"
16 #include "ngraph/op/util/attr_types.hpp"
17 #include "ngraph/op/util/fused_op.hpp"
18 #include "ngraph/op/util/rnn_cell_base.hpp"
19 
20 namespace ngraph
21 {
22  namespace op
23  {
24  namespace v0
25  {
26  NGRAPH_SUPPRESS_DEPRECATED_START
27 
28  ///
29  /// \brief Class for lstm sequence node.
30  ///
31  /// \note It follows notation and equations defined as in ONNX standard:
32  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
33  ///
34  /// \sa LSTMCell, RNNCell, GRUCell
35  ///
36  ///
37  class NGRAPH_API LSTMSequence : public util::FusedOp
38  {
39  public:
40  NGRAPH_RTTI_DECLARATION;
41  LSTMSequence();
42 
44 
45  size_t get_default_output_index() const override { return no_default_index(); }
46  explicit LSTMSequence(const Output<Node>& X,
47  const Output<Node>& initial_hidden_state,
48  const Output<Node>& initial_cell_state,
49  const Output<Node>& sequence_lengths,
50  const Output<Node>& W,
51  const Output<Node>& R,
52  const Output<Node>& B,
53  const Output<Node>& P,
54  const std::int64_t hidden_size,
55  const direction lstm_direction,
56  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
57  const std::vector<float> activations_alpha = {},
58  const std::vector<float> activations_beta = {},
59  const std::vector<std::string> activations = {"sigmoid",
60  "tanh",
61  "tanh"},
62  const float clip_threshold = 0,
63  const bool input_forget = false);
64 
65  explicit LSTMSequence(const Output<Node>& X,
66  const Output<Node>& initial_hidden_state,
67  const Output<Node>& initial_cell_state,
68  const Output<Node>& sequence_lengths,
69  const Output<Node>& W,
70  const Output<Node>& R,
71  const Output<Node>& B,
72  const std::int64_t hidden_size,
73  const direction lstm_direction,
74  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
75  const std::vector<float>& activations_alpha = {},
76  const std::vector<float>& activations_beta = {},
77  const std::vector<std::string>& activations = {"sigmoid",
78  "tanh",
79  "tanh"},
80  const float clip_threshold = 0,
81  const bool input_forget = false);
82 
83  virtual void validate_and_infer_types() override;
84  bool visit_attributes(AttributeVisitor& visitor) override;
85  virtual OutputVector decompose_op() const override;
86 
87  virtual std::shared_ptr<Node>
88  clone_with_new_inputs(const OutputVector& new_args) const override;
89 
90  std::vector<float> get_activations_alpha() const { return m_activations_alpha; }
91  std::vector<float> get_activations_beta() const { return m_activations_beta; }
92  std::vector<std::string> get_activations() const { return m_activations; }
93  float get_clip_threshold() const { return m_clip_threshold; }
94  direction get_direction() const { return m_direction; }
95  std::int64_t get_hidden_size() const { return m_hidden_size; }
96  bool get_input_forget() const { return m_input_forget; }
97  LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
98 
99  private:
100  ///
101  /// \brief Gets the masked value according to sequence length in a batch.
102  ///
103  /// \note Zeros out values or sets them to default value for inputs with
104  /// sequence length shorter than currently procssed time step.
105  ///
106  /// \param[in] data The input value.
107  /// \param[in] time_step The current time step denoting sequence length.
108  /// \param[in] batch_axis The batch axis index of data tensor.
109  /// \param[in] default_value The default value for masked elements.
110  ///
111  /// \return The masked value.
112  ///
113  std::shared_ptr<Node>
114  get_masked_node(const Output<Node>& data,
115  std::int32_t time_step,
116  std::size_t batch_axis = 0,
117  const Output<Node>& default_value = Output<Node>()) const;
118 
119  OutputVector lstm_pass(bool is_reverse = false) const;
120 
121  // Split(bi-directional) and squeeze input data to remove 'num_direction' dimension.
122  std::shared_ptr<Node> prepare_input(Output<Node> node,
123  bool is_reverse,
124  size_t num_direction_axis = 0) const;
125 
126  std::vector<float> m_activations_alpha;
127  std::vector<float> m_activations_beta;
128  std::vector<std::string> m_activations;
129  float m_clip_threshold;
130  direction m_direction;
131  std::int64_t m_hidden_size;
132  bool m_input_forget;
133  LSTMWeightsFormat m_weights_format;
134  };
135 
136  NGRAPH_SUPPRESS_DEPRECATED_END
137  } // namespace v0
138 
139  namespace v5
140  {
141  ///
142  /// \brief Class for lstm sequence node.
143  ///
144  /// \note It follows notation and equations defined as in ONNX standard:
145  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
146  ///
147  /// \sa LSTMCell, RNNCell, GRUCell
148  ///
149  ///
150  class NGRAPH_API LSTMSequence : public util::RNNCellBase
151  {
152  public:
153  NGRAPH_RTTI_DECLARATION;
154  LSTMSequence() = default;
155 
157 
158  size_t get_default_output_index() const override { return no_default_index(); }
159  explicit LSTMSequence(const Output<Node>& X,
160  const Output<Node>& initial_hidden_state,
161  const Output<Node>& initial_cell_state,
162  const Output<Node>& sequence_lengths,
163  const Output<Node>& W,
164  const Output<Node>& R,
165  const Output<Node>& B,
166  const std::int64_t hidden_size,
167  const direction lstm_direction,
168  const std::vector<float>& activations_alpha = {},
169  const std::vector<float>& activations_beta = {},
170  const std::vector<std::string>& activations = {"sigmoid",
171  "tanh",
172  "tanh"},
173  const float clip = 0.f)
174  : RNNCellBase(
175  {X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
176  hidden_size,
177  clip,
178  activations,
179  activations_alpha,
180  activations_beta)
181  , m_direction(lstm_direction)
182  {
183  constructor_validate_and_infer_types();
184  }
185 
186  void validate_and_infer_types() override;
187  bool visit_attributes(AttributeVisitor& visitor) override;
188 
189  std::shared_ptr<Node>
190  clone_with_new_inputs(const OutputVector& new_args) const override;
191 
192  direction get_direction() const { return m_direction; }
193 
194  private:
195  direction m_direction;
196  };
197  } // namespace v5
198  } // namespace op
199 
200 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for lstm sequence node.
Definition: lstm_sequence.hpp:38
Class for lstm sequence node.
Definition: lstm_sequence.hpp:151
size_t get_default_output_index() const override
Returns the index of the default output, or throws if there is none.
Definition: lstm_sequence.hpp:158
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RecurrentSequenceDirection
This class defines possible recurrent sequence directions.
Definition: attr_types.hpp:424
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16