25 #include "ngraph/node.hpp"
26 #include "ngraph/op/constant.hpp"
27 #include "ngraph/op/lstm_cell.hpp"
28 #include "ngraph/op/util/attr_types.hpp"
29 #include "ngraph/op/util/fused_op.hpp"
30 #include "ngraph/op/util/rnn_cell_base.hpp"
38 NGRAPH_SUPPRESS_DEPRECATED_START
52 NGRAPH_RTTI_DECLARATION;
57 size_t get_default_output_index()
const override {
return no_default_index(); }
66 const std::int64_t hidden_size,
68 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
69 const std::vector<float> activations_alpha = {},
70 const std::vector<float> activations_beta = {},
71 const std::vector<std::string> activations = {
"sigmoid",
74 const float clip_threshold = 0,
75 const bool input_forget =
false);
84 const std::int64_t hidden_size,
86 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
87 const std::vector<float>& activations_alpha = {},
88 const std::vector<float>& activations_beta = {},
89 const std::vector<std::string>& activations = {
"sigmoid",
92 const float clip_threshold = 0,
93 const bool input_forget =
false);
95 virtual void validate_and_infer_types()
override;
97 virtual OutputVector decompose_op()
const override;
99 virtual std::shared_ptr<Node>
100 clone_with_new_inputs(
const OutputVector& new_args)
const override;
102 std::vector<float> get_activations_alpha()
const {
return m_activations_alpha; }
103 std::vector<float> get_activations_beta()
const {
return m_activations_beta; }
104 std::vector<std::string> get_activations()
const {
return m_activations; }
105 float get_clip_threshold()
const {
return m_clip_threshold; }
106 direction get_direction()
const {
return m_direction; }
107 std::int64_t get_hidden_size()
const {
return m_hidden_size; }
108 bool get_input_forget()
const {
return m_input_forget; }
109 LSTMWeightsFormat get_weights_format()
const {
return m_weights_format; }
124 std::shared_ptr<Node>
126 std::int32_t time_step,
127 std::size_t batch_axis = 0,
130 OutputVector lstm_pass(
bool is_reverse =
false)
const;
135 size_t num_direction_axis = 0)
const;
137 std::vector<float> m_activations_alpha;
138 std::vector<float> m_activations_beta;
139 std::vector<std::string> m_activations;
140 float m_clip_threshold;
142 std::int64_t m_hidden_size;
144 LSTMWeightsFormat m_weights_format;
147 NGRAPH_SUPPRESS_DEPRECATED_END
164 NGRAPH_RTTI_DECLARATION;
177 const std::int64_t hidden_size,
178 const direction lstm_direction,
179 const std::vector<float>& activations_alpha = {},
180 const std::vector<float>& activations_beta = {},
181 const std::vector<std::string>& activations = {
"sigmoid",
184 const float clip = 0.f)
186 {X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
192 , m_direction(lstm_direction)
194 constructor_validate_and_infer_types();
200 std::shared_ptr<Node>
201 clone_with_new_inputs(
const OutputVector& new_args)
const override;
203 direction get_direction()
const {
return m_direction; }
205 direction m_direction;
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:67
Class for lstm sequence node.
Definition: lstm_sequence.hpp:50
Class for lstm sequence node.
Definition: lstm_sequence.hpp:162
size_t get_default_output_index() const override
Returns the output of the default output, or throws if there is none.
Definition: lstm_sequence.hpp:169
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RecurrentSequenceDirection
This class defines possible recurrent sequence directions.
Definition: attr_types.hpp:432
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28