lstm_sequence.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "ngraph/node.hpp"
26 #include "ngraph/op/constant.hpp"
27 #include "ngraph/op/lstm_cell.hpp"
28 #include "ngraph/op/util/attr_types.hpp"
29 #include "ngraph/op/util/fused_op.hpp"
30 #include "ngraph/op/util/rnn_cell_base.hpp"
31 
32 namespace ngraph
33 {
34  namespace op
35  {
36  namespace v0
37  {
38  NGRAPH_SUPPRESS_DEPRECATED_START
39 
40  ///
41  /// \brief Class for lstm sequence node (declared in namespace op::v0).
42  ///
43  /// \note It follows notation and equations defined as in ONNX standard:
44  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
45  ///
46  /// \sa LSTMCell, RNNCell, GRUCell
47  ///
48  /// \note Derives from util::FusedOp and overrides decompose_op().
49  class NGRAPH_API LSTMSequence : public util::FusedOp
50  {
51  public:
52  NGRAPH_RTTI_DECLARATION;
    /// \brief Default constructor; only declared here, the definition is not in this header.
53  LSTMSequence();
54 
56  // NOTE(review): `direction` is used below but is never declared in this listing,
    // and the embedded numbering skips line 55. Presumably an alias for
    // RecurrentSequenceDirection was dropped - confirm against the original header.
    /// \brief Always returns no_default_index().
    /// NOTE(review): presumably this marks the node as having no default output - confirm.
57  size_t get_default_output_index() const override { return no_default_index(); }
    ///
    /// \brief Constructor overload taking eight Output<Node> arguments, including \p P.
    ///
    /// \param[in] X, initial_hidden_state, initial_cell_state, sequence_lengths,
    ///            W, R, B, P   Node outputs consumed by this operation.
    /// \param[in] hidden_size       Hidden size attribute.
    /// \param[in] lstm_direction    Sequence direction attribute.
    /// \param[in] weights_format    Defaults to LSTMWeightsFormat::IFCO.
    /// \param[in] activations_alpha Defaults to an empty vector.
    /// \param[in] activations_beta  Defaults to an empty vector.
    /// \param[in] activations       Defaults to {"sigmoid", "tanh", "tanh"}.
    /// \param[in] clip_threshold    Defaults to 0.
    /// \param[in] input_forget      Defaults to false.
    ///
    /// NOTE(review): P presumably carries the peephole weights of the ONNX LSTM
    /// operator referenced in the class documentation - confirm.
    /// NOTE(review): the three vector parameters are taken by const value here but
    /// by const reference in the overload without P. Changing that would require
    /// the out-of-line definition to change as well.
    ///
58  explicit LSTMSequence(const Output<Node>& X,
59  const Output<Node>& initial_hidden_state,
60  const Output<Node>& initial_cell_state,
61  const Output<Node>& sequence_lengths,
62  const Output<Node>& W,
63  const Output<Node>& R,
64  const Output<Node>& B,
65  const Output<Node>& P,
66  const std::int64_t hidden_size,
67  const direction lstm_direction,
68  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
69  const std::vector<float> activations_alpha = {},
70  const std::vector<float> activations_beta = {},
71  const std::vector<std::string> activations = {"sigmoid",
72  "tanh",
73  "tanh"},
74  const float clip_threshold = 0,
75  const bool input_forget = false);
76 
    ///
    /// \brief Constructor overload without the \p P argument (seven Output<Node> arguments).
    ///
    /// The remaining parameters and their defaults match the overload above, except
    /// that the vector parameters are taken by const reference.
    ///
77  explicit LSTMSequence(const Output<Node>& X,
78  const Output<Node>& initial_hidden_state,
79  const Output<Node>& initial_cell_state,
80  const Output<Node>& sequence_lengths,
81  const Output<Node>& W,
82  const Output<Node>& R,
83  const Output<Node>& B,
84  const std::int64_t hidden_size,
85  const direction lstm_direction,
86  LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
87  const std::vector<float>& activations_alpha = {},
88  const std::vector<float>& activations_beta = {},
89  const std::vector<std::string>& activations = {"sigmoid",
90  "tanh",
91  "tanh"},
92  const float clip_threshold = 0,
93  const bool input_forget = false);
94 
    // Overrides of base-class virtuals; all are defined outside this header.
95  virtual void validate_and_infer_types() override;
96  bool visit_attributes(AttributeVisitor& visitor) override;
97  virtual OutputVector decompose_op() const override;
98 
99  virtual std::shared_ptr<Node>
100  clone_with_new_inputs(const OutputVector& new_args) const override;
101 
     // Read-only accessors. Each returns the matching m_* member by value, so the
     // vector-valued ones hand back a copy.
102  std::vector<float> get_activations_alpha() const { return m_activations_alpha; }
103  std::vector<float> get_activations_beta() const { return m_activations_beta; }
104  std::vector<std::string> get_activations() const { return m_activations; }
105  float get_clip_threshold() const { return m_clip_threshold; }
106  direction get_direction() const { return m_direction; }
107  std::int64_t get_hidden_size() const { return m_hidden_size; }
108  bool get_input_forget() const { return m_input_forget; }
109  LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
110  private:
111  ///
112  /// \brief Gets the masked value according to sequence length in a batch.
113  ///
114  /// \note Zeros out values or sets them to default value for inputs with
115  /// sequence length shorter than currently processed time step.
116  ///
117  /// \param[in] data The input value.
118  /// \param[in] time_step The current time step denoting sequence length.
119  /// \param[in] batch_axis The batch axis index of data tensor.
120  /// \param[in] default_value The default value for masked elements.
121  ///
122  /// \return The masked value.
123  ///
124  std::shared_ptr<Node>
125  get_masked_node(const Output<Node>& data,
126  std::int32_t time_step,
127  std::size_t batch_axis = 0,
128  const Output<Node>& default_value = Output<Node>()) const;
129 
     // Declared here, defined elsewhere; is_reverse defaults to false.
     // NOTE(review): presumably builds the sub-graph for one direction of the
     // sequence, reversed when is_reverse is true - confirm against the definition.
130  OutputVector lstm_pass(bool is_reverse = false) const;
131 
132  // Split(bi-directional) and squeeze input data to remove 'num_direction' dimension.
     // num_direction_axis defaults to 0.
133  std::shared_ptr<Node> prepare_input(Output<Node> node,
134  bool is_reverse,
135  size_t num_direction_axis = 0) const;
136 
     // Attribute storage returned by the get_*() accessors above.
     // NOTE(review): none of these members has an in-class initializer; whether the
     // default constructor sets them cannot be seen from this header.
137  std::vector<float> m_activations_alpha;
138  std::vector<float> m_activations_beta;
139  std::vector<std::string> m_activations;
140  float m_clip_threshold;
141  direction m_direction;
142  std::int64_t m_hidden_size;
143  bool m_input_forget;
144  LSTMWeightsFormat m_weights_format;
145  };
146 
147  NGRAPH_SUPPRESS_DEPRECATED_END
148  }
149 
150  namespace v5
151  {
152  ///
153  /// \brief Class for lstm sequence node (declared in namespace op::v5).
154  ///
155  /// \note It follows notation and equations defined as in ONNX standard:
156  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
157  ///
158  /// \sa LSTMCell, RNNCell, GRUCell
159  ///
160  /// \note Derives from util::RNNCellBase and adds a direction attribute.
161  class NGRAPH_API LSTMSequence : public util::RNNCellBase
162  {
163  public:
164  NGRAPH_RTTI_DECLARATION;
     // NOTE(review): the default constructor is defaulted and m_direction has no
     // in-class initializer, so this constructor does not assign it a value.
     // Presumably it is filled in later (e.g. via visit_attributes) - confirm.
165  LSTMSequence() = default;
166 
168  // NOTE(review): `direction` is used below but is never declared in this listing,
     // and the embedded numbering skips line 167. Presumably an alias for
     // RecurrentSequenceDirection was dropped - confirm against the original header.
     /// \brief Always returns no_default_index().
     /// NOTE(review): presumably this marks the node as having no default output - confirm.
169  size_t get_default_output_index() const override { return no_default_index(); }
     ///
     /// \brief Constructs the node from seven Output<Node> arguments and its attributes.
     ///
     /// \param[in] X, initial_hidden_state, initial_cell_state, sequence_lengths,
     ///            W, R, B        Node outputs, forwarded to RNNCellBase in this order.
     /// \param[in] hidden_size       Forwarded to RNNCellBase.
     /// \param[in] lstm_direction    Stored in m_direction.
     /// \param[in] activations_alpha Defaults to an empty vector; forwarded to RNNCellBase.
     /// \param[in] activations_beta  Defaults to an empty vector; forwarded to RNNCellBase.
     /// \param[in] activations       Defaults to {"sigmoid", "tanh", "tanh"}; forwarded
     ///                              to RNNCellBase.
     /// \param[in] clip              Defaults to 0.f; forwarded to RNNCellBase.
     ///
170  explicit LSTMSequence(const Output<Node>& X,
171  const Output<Node>& initial_hidden_state,
172  const Output<Node>& initial_cell_state,
173  const Output<Node>& sequence_lengths,
174  const Output<Node>& W,
175  const Output<Node>& R,
176  const Output<Node>& B,
177  const std::int64_t hidden_size,
178  const direction lstm_direction,
179  const std::vector<float>& activations_alpha = {},
180  const std::vector<float>& activations_beta = {},
181  const std::vector<std::string>& activations = {"sigmoid",
182  "tanh",
183  "tanh"},
184  const float clip = 0.f)
     // The base-class argument order differs from this constructor's parameter
     // order: clip and activations come before activations_alpha and activations_beta.
185  : RNNCellBase(
186  {X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
187  hidden_size,
188  clip,
189  activations,
190  activations_alpha,
191  activations_beta)
192  , m_direction(lstm_direction)
193  {
     // NOTE(review): presumably triggers validate_and_infer_types() for the freshly
     // constructed node - confirm against the base class.
194  constructor_validate_and_infer_types();
195  }
196 
     // Overrides of base-class virtuals; all are defined outside this header.
197  void validate_and_infer_types() override;
198  bool visit_attributes(AttributeVisitor& visitor) override;
199 
200  std::shared_ptr<Node>
201  clone_with_new_inputs(const OutputVector& new_args) const override;
202 
     /// \brief Returns the stored direction attribute (m_direction) by value.
203  direction get_direction() const { return m_direction; }
204  private:
     // Direction attribute; set from lstm_direction by the non-default constructor.
205  direction m_direction;
206  };
207  }
208  } // namespace op
209 
210 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:67
Class for lstm sequence node.
Definition: lstm_sequence.hpp:50
Class for lstm sequence node.
Definition: lstm_sequence.hpp:162
size_t get_default_output_index() const override
Returns the index of the default output, or throws if there is none.
Definition: lstm_sequence.hpp:169
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RecurrentSequenceDirection
This class defines possible recurrent sequence directions.
Definition: attr_types.hpp:432
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28