12 #include "ngraph/node.hpp"
13 #include "ngraph/op/util/activation_functions.hpp"
14 #include "ngraph/op/util/fused_op.hpp"
15 #include "ngraph/op/util/rnn_cell_base.hpp"
21 enum class LSTMWeightsFormat
98 std::size_t hidden_size,
99 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
100 const std::vector<std::string>& activations =
101 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
102 const std::vector<float>& activations_alpha = {},
103 const std::vector<float>& activations_beta = {},
105 bool input_forget =
false);
144 std::size_t hidden_size,
145 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
146 const std::vector<std::string>& activations =
147 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
148 const std::vector<float>& activations_alpha = {},
149 const std::vector<float>& activations_beta = {},
151 bool input_forget =
false);
194 std::size_t hidden_size,
195 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
196 const std::vector<std::string>& activations =
197 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
198 const std::vector<float>& activations_alpha = {},
199 const std::vector<float>& activations_beta = {},
201 bool input_forget =
false);
/// Node-cloning hook overriding the base-class virtual.
/// NOTE(review): presumably returns a copy of this LSTMCell wired to new_args;
/// the definition is not shown in this listing — confirm against the implementation.
205 virtual std::shared_ptr<Node>
206 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Returns the stored m_input_forget flag (in-class default: false).
/// NOTE(review): presumably populated from the constructors' `input_forget`
/// argument (default false); the constructor bodies are not shown here.
208 bool get_input_forget()
const {
return m_input_forget; }
/// Returns the stored weights layout (m_weights_format).
/// NOTE(review): presumably populated from the constructors' `weights_format`
/// argument, whose default is LSTMWeightsFormat::IFCO; confirm in the .cpp.
209 LSTMWeightsFormat get_weights_format()
const {
return m_weights_format; }
/// Flag exposed through get_input_forget(); defaults to false in-class.
240 bool m_input_forget =
false;
/// Weights layout exposed through get_weights_format(). It has no in-class
/// initializer, so every constructor must set it — NOTE(review): confirm.
245 LSTMWeightsFormat m_weights_format;
/// Compile-time counts used by this cell. NOTE(review): presumably the four
/// LSTM gates (cf. the IFCO weights format) and three peephole weight vectors;
/// verify against their uses in the implementation.
247 static constexpr std::size_t s_gates_count{4};
248 static constexpr std::size_t s_peepholes_count{3};
315 std::size_t hidden_size,
316 const std::vector<std::string>& activations =
317 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
318 const std::vector<float>& activations_alpha = {},
319 const std::vector<float>& activations_beta = {},
355 std::size_t hidden_size,
356 const std::vector<std::string>& activations =
357 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
358 const std::vector<float>& activations_alpha = {},
359 const std::vector<float>& activations_beta = {},
/// Node-cloning hook; `override` makes this virtual even without the keyword.
/// NOTE(review): presumably returns a copy of this LSTMCell wired to new_args;
/// the definition is not shown in this listing.
365 std::shared_ptr<Node>
366 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Compile-time gate count (4). NOTE(review): presumably the four LSTM gates;
/// unlike the other LSTMCell variant in this header, no peephole count is declared here.
389 static constexpr std::size_t s_gates_count{4};
/// Stream-insertion operator for op::LSTMWeightsFormat; returns the stream.
/// NOTE(review): presumably prints the enumerator's name; the definition is not shown here.
395 std::ostream& operator<<(std::ostream& s,
const op::LSTMWeightsFormat& type);
/// Static type identity for this adapter: name "AttributeAdapter<op::LSTMWeightsFormat>"
/// with second field 1 (NOTE(review): presumably a version number — confirm in DiscreteTypeInfo).
407 static constexpr
DiscreteTypeInfo type_info{
"AttributeAdapter<op::LSTMWeightsFormat>", 1};
/// Returns the static type_info declared just above.
408 const DiscreteTypeInfo& get_type_info()
const override {
return type_info; }
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Access an enum via a string.
Definition: attribute_adapter.hpp:168
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:70
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for a single LSTM cell node.
Definition: lstm_cell.hpp:59
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs an LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs an LSTMCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const Output< Node > &P, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs an LSTMCell node.
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:62
Class for a single LSTM cell node.
Definition: lstm_cell.hpp:281
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs an LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs an LSTMCell node.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:284
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16