13 #include "ngraph/node.hpp"
14 #include "ngraph/op/constant.hpp"
15 #include "ngraph/op/lstm_cell.hpp"
16 #include "ngraph/op/util/attr_types.hpp"
17 #include "ngraph/op/util/fused_op.hpp"
18 #include "ngraph/op/util/rnn_cell_base.hpp"
26 NGRAPH_SUPPRESS_DEPRECATED_START
40 NGRAPH_RTTI_DECLARATION;
/// Overrides the base-class default-output lookup by returning
/// no_default_index(); presumably this signals that the op has no single
/// default output (NOTE(review): confirm against Node::no_default_index).
45 size_t get_default_output_index()
const override {
return no_default_index(); }
54 const std::int64_t hidden_size,
56 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
57 const std::vector<float> activations_alpha = {},
58 const std::vector<float> activations_beta = {},
59 const std::vector<std::string> activations = {
"sigmoid",
62 const float clip_threshold = 0,
63 const bool input_forget =
false);
72 const std::int64_t hidden_size,
74 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
75 const std::vector<float>& activations_alpha = {},
76 const std::vector<float>& activations_beta = {},
77 const std::vector<std::string>& activations = {
"sigmoid",
80 const float clip_threshold = 0,
81 const bool input_forget =
false);
83 virtual void validate_and_infer_types()
override;
85 virtual OutputVector decompose_op()
const override;
87 virtual std::shared_ptr<Node>
88 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Returns a copy of the activation alpha coefficients (m_activations_alpha)
/// supplied at construction; empty by default.
90 std::vector<float> get_activations_alpha()
const {
return m_activations_alpha; }
/// Returns a copy of the activation beta coefficients (m_activations_beta)
/// supplied at construction; empty by default.
91 std::vector<float> get_activations_beta()
const {
return m_activations_beta; }
/// Returns a copy of the activation function names (m_activations)
/// supplied at construction (the default list starts with "sigmoid").
92 std::vector<std::string> get_activations()
const {
return m_activations; }
/// Returns the clip threshold (m_clip_threshold) supplied at construction;
/// the constructor default is 0.
93 float get_clip_threshold()
const {
return m_clip_threshold; }
/// Returns the recurrence direction (m_direction) this sequence op was
/// configured with.
94 direction get_direction()
const {
return m_direction; }
/// Returns the hidden size (m_hidden_size) passed to the constructor.
95 std::int64_t get_hidden_size()
const {
return m_hidden_size; }
/// Returns the input_forget flag (m_input_forget) passed to the
/// constructor; the constructor default is false.
96 bool get_input_forget()
const {
return m_input_forget; }
/// Returns the gate layout of the weight tensors (m_weights_format); the
/// constructor default is LSTMWeightsFormat::IFCO.
97 LSTMWeightsFormat get_weights_format()
const {
return m_weights_format; }
113 std::shared_ptr<Node>
115 std::int32_t time_step,
116 std::size_t batch_axis = 0,
119 OutputVector lstm_pass(
bool is_reverse =
false)
const;
124 size_t num_direction_axis = 0)
const;
126 std::vector<float> m_activations_alpha;
127 std::vector<float> m_activations_beta;
128 std::vector<std::string> m_activations;
129 float m_clip_threshold;
131 std::int64_t m_hidden_size;
133 LSTMWeightsFormat m_weights_format;
136 NGRAPH_SUPPRESS_DEPRECATED_END
153 NGRAPH_RTTI_DECLARATION;
166 const std::int64_t hidden_size,
167 const direction lstm_direction,
168 const std::vector<float>& activations_alpha = {},
169 const std::vector<float>& activations_beta = {},
170 const std::vector<std::string>& activations = {
"sigmoid",
173 const float clip = 0.f)
175 {X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
181 , m_direction(lstm_direction)
183 constructor_validate_and_infer_types();
189 std::shared_ptr<Node>
190 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Returns the recurrence direction (m_direction), which the constructor
/// initializes from its lstm_direction argument.
192 direction get_direction()
const {
return m_direction; }
195 direction m_direction;
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for LSTM sequence node.
Definition: lstm_sequence.hpp:38
Class for LSTM sequence node.
Definition: lstm_sequence.hpp:151
size_t get_default_output_index() const override
Returns the index of the default output, or throws if there is none.
Definition: lstm_sequence.hpp:158
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RecurrentSequenceDirection
This class defines possible recurrent sequence directions.
Definition: attr_types.hpp:424
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16