24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
33 enum class LSTMWeightsFormat
110 std::size_t hidden_size,
111 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
112 const std::vector<std::string>& activations =
113 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
114 const std::vector<float>& activations_alpha = {},
115 const std::vector<float>& activations_beta = {},
117 bool input_forget =
false);
156 std::size_t hidden_size,
157 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
158 const std::vector<std::string>& activations =
159 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
160 const std::vector<float>& activations_alpha = {},
161 const std::vector<float>& activations_beta = {},
163 bool input_forget =
false);
206 std::size_t hidden_size,
207 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
208 const std::vector<std::string>& activations =
209 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
210 const std::vector<float>& activations_alpha = {},
211 const std::vector<float>& activations_beta = {},
213 bool input_forget =
false);
/// Override of the base-class clone hook for this LSTMCell variant.
/// NOTE(review): presumably returns a new node of the same type built from
/// `new_args`; the expected number of arguments is defined in the
/// implementation (not visible here) — confirm there.
217 virtual std::shared_ptr<Node>
218 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Returns the stored `m_input_forget` flag (false unless set otherwise).
/// NOTE(review): presumably mirrors the constructors' `input_forget`
/// argument, which defaults to false — confirm in the implementation.
220 bool get_input_forget()
const {
return m_input_forget; }
/// Returns the stored weights layout (`m_weights_format`).
/// NOTE(review): presumably the constructors' `weights_format` argument,
/// whose default is LSTMWeightsFormat::IFCO — confirm in the implementation.
221 LSTMWeightsFormat get_weights_format()
const {
return m_weights_format; }
/// Flag returned by get_input_forget(); in-class default is false.
251 bool m_input_forget =
false;
/// Weights layout returned by get_weights_format().
/// NOTE(review): no in-class initializer; presumably set by every constructor.
256 LSTMWeightsFormat m_weights_format;
/// Number of gates per cell (4).
/// NOTE(review): presumably input/forget/cell/output, matching the IFCO format name.
258 static constexpr std::size_t s_gates_count{4};
/// Number of peephole vectors (3).
/// NOTE(review): assumed to describe how the optional P input is packed — confirm.
259 static constexpr std::size_t s_peepholes_count{3};
326 std::size_t hidden_size,
327 const std::vector<std::string>& activations =
328 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
329 const std::vector<float>& activations_alpha = {},
330 const std::vector<float>& activations_beta = {},
366 std::size_t hidden_size,
367 const std::vector<std::string>& activations =
368 std::vector<std::string>{
"sigmoid",
"tanh",
"tanh"},
369 const std::vector<float>& activations_alpha = {},
370 const std::vector<float>& activations_beta = {},
/// Override of the base-class clone hook for this LSTMCell variant
/// (the constructors of this variant take no weights_format/input_forget).
/// NOTE(review): presumably returns a new node of the same type built from
/// `new_args` — confirm the accepted argument count in the implementation.
376 std::shared_ptr<Node>
377 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Number of gates per cell (4) for this LSTMCell variant.
400 static constexpr std::size_t s_gates_count{4};
/// Stream-insertion operator for op::LSTMWeightsFormat; returns the stream `s`.
/// NOTE(review): the textual form written is defined in the implementation,
/// which is not visible here.
406 std::ostream& operator<<(std::ostream& s,
const op::LSTMWeightsFormat& type);
/// Type identity of this adapter: name "AttributeAdapter<op::LSTMWeightsFormat>", version 1.
418 static constexpr
DiscreteTypeInfo type_info{
"AttributeAdapter<op::LSTMWeightsFormat>", 1};
/// Returns the static `type_info` above.
419 const DiscreteTypeInfo& get_type_info()
const override {
return type_info; }
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Access an enum via a string.
Definition: attribute_adapter.hpp:178
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:67
Class for single lstm cell node.
Definition: lstm_cell.hpp:71
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const Output< Node > &P, std::size_t hidden_size, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool input_forget=false)
Constructs LSTMCell node.
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:74
Class for single lstm cell node.
Definition: lstm_cell.hpp:292
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs LSTMCell node.
LSTMCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs LSTMCell node.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
const NodeTypeInfo & get_type_info() const override
Definition: lstm_cell.hpp:295
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28