Class for the LSTM sequence node.
#include <lstm_sequence.hpp>
Public Types | |
using | direction = RecurrentSequenceDirection |
Public Member Functions | |
size_t | get_default_output_index () const override |
LSTMSequence (const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &sequence_lengths, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const Output< Node > &P, const std::int64_t hidden_size, const direction lstm_direction, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< float > activations_alpha={}, const std::vector< float > activations_beta={}, const std::vector< std::string > activations={"sigmoid", "tanh", "tanh"}, const float clip_threshold=0, const bool input_forget=false) | |
LSTMSequence (const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &sequence_lengths, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const std::int64_t hidden_size, const direction lstm_direction, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, const std::vector< std::string > &activations={"sigmoid", "tanh", "tanh"}, const float clip_threshold=0, const bool input_forget=false) | |
virtual void | validate_and_infer_types () override |
bool | visit_attributes (AttributeVisitor &visitor) override |
virtual OutputVector | decompose_op () const override |
virtual std::shared_ptr< Node > | clone_with_new_inputs (const OutputVector &new_args) const override |
std::vector< float > | get_activations_alpha () const |
std::vector< float > | get_activations_beta () const |
std::vector< std::string > | get_activations () const |
float | get_clip_threshold () const |
direction | get_direction () const |
std::int64_t | get_hidden_size () const |
bool | get_input_forget () const |
LSTMWeightsFormat | get_weights_format () const |
Data Fields | |
NGRAPH_RTTI_DECLARATION | |
Class for the LSTM sequence node.