ngraph::op::v0::LSTMSequence Class Reference

Class for LSTM sequence node.

#include <lstm_sequence.hpp>

[Inheritance and collaboration diagrams for ngraph::op::v0::LSTMSequence]

Public Types

using direction = RecurrentSequenceDirection
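direction is an alias for RecurrentSequenceDirection. A bidirectional sequence runs one forward and one reverse pass, so per-direction tensors carry two slices (ONNX calls this count num_directions). The sketch below uses a hypothetical local enum whose enumerators FORWARD, REVERSE and BIDIRECTIONAL are assumed to mirror nGraph's; it is an illustration, not the library's code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for RecurrentSequenceDirection (enumerator names assumed).
enum class Direction { FORWARD, REVERSE, BIDIRECTIONAL };

// Number of direction slices in the weight and state tensors:
// a bidirectional sequence keeps one set per pass.
std::size_t num_directions(Direction d) {
    return d == Direction::BIDIRECTIONAL ? 2 : 1;
}

// Order in which time steps are visited by one pass over `seq_len` steps.
// Simplification: ignores per-row sequence lengths.
std::vector<std::size_t> time_order(std::size_t seq_len, bool reverse_pass) {
    std::vector<std::size_t> order(seq_len);
    for (std::size_t t = 0; t < seq_len; ++t) {
        order[t] = reverse_pass ? seq_len - 1 - t : t;
    }
    return order;
}
```

A BIDIRECTIONAL sequence would call time_order once with reverse_pass = false and once with true.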
 

Public Member Functions

size_t get_default_output_index () const override
 
 LSTMSequence (const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &sequence_lengths, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const Output< Node > &P, const std::int64_t hidden_size, const direction lstm_direction, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< float > activations_alpha={}, const std::vector< float > activations_beta={}, const std::vector< std::string > activations={"sigmoid", "tanh", "tanh"}, const float clip_threshold=0, const bool input_forget=false)
 
 LSTMSequence (const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &initial_cell_state, const Output< Node > &sequence_lengths, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, const std::int64_t hidden_size, const direction lstm_direction, LSTMWeightsFormat weights_format=LSTMWeightsFormat::IFCO, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, const std::vector< std::string > &activations={"sigmoid", "tanh", "tanh"}, const float clip_threshold=0, const bool input_forget=false)
 
virtual void validate_and_infer_types () override
 
bool visit_attributes (AttributeVisitor &visitor) override
 
virtual OutputVector decompose_op () const override
 
virtual std::shared_ptr< Node > clone_with_new_inputs (const OutputVector &new_args) const override
 
std::vector< float > get_activations_alpha () const
 
std::vector< float > get_activations_beta () const
 
std::vector< std::string > get_activations () const
 
float get_clip_threshold () const
 
direction get_direction () const
 
std::int64_t get_hidden_size () const
 
bool get_input_forget () const
 
LSTMWeightsFormat get_weights_format () const
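Both constructors take weights_format (default LSTMWeightsFormat::IFCO), which names the order in which the four gate blocks (I = input, F = forget, C = cell, O = output) are packed in the weight tensors; the ONNX LSTM spec packs them as IOFC. This page does not say how nGraph uses the setting internally. The sketch below only shows what converting between two gate orders involves. Gate-order strings stand in for the enum, and the tensor is assumed to be row-major with shape [4 * hidden_size, cols]; both are assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Permute the four gate blocks of a packed, row-major tensor of shape
// [4 * hidden_size, cols] from gate order `from` to gate order `to`
// (for example "IFCO" -> "IOFC"). A bias vector is the cols == 1 case.
std::vector<float> reorder_gates(const std::vector<float>& w,
                                 std::size_t hidden_size, std::size_t cols,
                                 const std::string& from, const std::string& to) {
    assert(from.size() == 4 && to.size() == 4);
    const std::size_t block = hidden_size * cols;  // elements per gate block
    assert(w.size() == 4 * block);
    std::vector<float> out(w.size());
    for (std::size_t dst = 0; dst < 4; ++dst) {
        // Find where the gate wanted at position `dst` sits in the source order.
        const std::size_t src = from.find(to[dst]);
        assert(src != std::string::npos);
        for (std::size_t k = 0; k < block; ++k) {
            out[dst * block + k] = w[src * block + k];
        }
    }
    return out;
}
```

Reordering is a pure permutation of blocks, so converting back with the two orders swapped restores the original tensor.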
 

Data Fields

 NGRAPH_RTTI_DECLARATION
 

Detailed Description

Class for LSTM sequence node.
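validate_and_infer_types checks the input shapes and sets the output types, but this page does not list the rules. Since the op follows the ONNX LSTM operator (see the Note), a sketch of ONNX's output shapes in its default layout (layout = 0) may help orient readers. The tensor layout that nGraph's v0 op actually expects is not stated here and may differ; the function and struct names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>

using Shape3 = std::array<std::int64_t, 3>;
using Shape4 = std::array<std::int64_t, 4>;

struct LstmOutputShapes {
    Shape4 Y;    // [seq_length, num_directions, batch_size, hidden_size]
    Shape3 Y_h;  // [num_directions, batch_size, hidden_size]
    Shape3 Y_c;  // [num_directions, batch_size, hidden_size]
};

// ONNX layout 0: X is [seq_length, batch_size, input_size] and
// W is [num_directions, 4 * hidden_size, input_size].
LstmOutputShapes infer_onnx_lstm_shapes(const Shape3& X, const Shape3& W,
                                        std::int64_t hidden_size) {
    const std::int64_t seq_length = X[0];
    const std::int64_t batch_size = X[1];
    const std::int64_t num_directions = W[0];  // must agree with the direction attribute
    // W packs four gates of hidden_size rows each and must match X's feature size.
    if (W[1] != 4 * hidden_size || W[2] != X[2]) {
        throw std::invalid_argument("W shape is inconsistent with X and hidden_size");
    }
    return {Shape4{seq_length, num_directions, batch_size, hidden_size},
            Shape3{num_directions, batch_size, hidden_size},
            Shape3{num_directions, batch_size, hidden_size}};
}
```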

Note
It follows the notation and equations defined in the ONNX standard: https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
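The ONNX LSTM equations can be sketched as a single time step for one batch row. This is a plain-loop illustration, not nGraph's implementation. It assumes ONNX's packing (gate order i, o, f, c for W, R and B, where B holds Wb followed by Rb; order i, o, f for the peephole tensor P), treats clip_threshold == 0 as "no clipping", and models input_forget as f = 1 - i, a common coupling that ONNX itself does not spell out. The activations default to sigmoid, tanh, tanh, matching the constructors' default activations list.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<float>;
using Act = std::function<float(float)>;

inline float sigmoid(float v) { return 1.f / (1.f + std::exp(-v)); }
inline float tanh_act(float v) { return std::tanh(v); }

struct LstmStepResult {
    Vec h;  // new hidden state H_t
    Vec c;  // new cell state C_t
};

// One step: W is [4H x I], R is [4H x H] (row-major, iofc), B is [8H], P is [3H].
LstmStepResult lstm_step(const Vec& x, const Vec& h_prev, const Vec& c_prev,
                         const Vec& W, const Vec& R, const Vec& B, const Vec& P,
                         std::size_t H, float clip = 0.f, bool input_forget = false,
                         Act act_f = sigmoid, Act act_g = tanh_act, Act act_h = tanh_act) {
    const std::size_t I = x.size();
    assert(h_prev.size() == H && c_prev.size() == H);
    assert(W.size() == 4 * H * I && R.size() == 4 * H * H);
    assert(B.size() == 8 * H && P.size() == 3 * H);
    // Clip the input of every activation to [-clip, clip]; 0 disables clipping (assumption).
    auto cl = [clip](float v) { return clip > 0.f ? std::clamp(v, -clip, clip) : v; };
    // Pre-activation of packed row r: x.W_r + h_prev.R_r + Wb_r + Rb_r.
    auto pre = [&](std::size_t r) {
        float s = B[r] + B[4 * H + r];
        for (std::size_t k = 0; k < I; ++k) s += x[k] * W[r * I + k];
        for (std::size_t k = 0; k < H; ++k) s += h_prev[k] * R[r * H + k];
        return s;
    };
    LstmStepResult out{Vec(H), Vec(H)};
    for (std::size_t j = 0; j < H; ++j) {
        // Row offsets in iofc order: i = 0, o = H, f = 2H, c = 3H; P is Pi, Po, Pf.
        const float i_t = act_f(cl(pre(j) + P[j] * c_prev[j]));
        const float f_t = input_forget
            ? 1.f - i_t  // coupled input/forget gate (assumed form)
            : act_f(cl(pre(2 * H + j) + P[2 * H + j] * c_prev[j]));
        const float g_t = act_g(cl(pre(3 * H + j)));
        const float C_t = f_t * c_prev[j] + i_t * g_t;
        const float o_t = act_f(cl(pre(H + j) + P[H + j] * C_t));
        out.c[j] = C_t;
        out.h[j] = o_t * act_h(cl(C_t));
    }
    return out;
}
```

Note that the output gate's peephole reads the new cell state C_t, while the input and forget peepholes read C_{t-1}, as in the ONNX equations.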
See also
LSTMCell, RNNCell, GRUCell
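decompose_op expresses the fused sequence op in terms of simpler ops; this page does not show how. One common approach is to unroll the sequence into repeated cell steps (compare LSTMCell) while honoring sequence_lengths. The sketch below takes a hypothetical per-step callback in place of a real cell. Zero-filled outputs for padded steps, and a reverse pass that starts at each row's last valid step, are assumptions about the semantics rather than documented nGraph behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using Vec = std::vector<float>;

// Hypothetical cell callback: consumes x_t and updates h and c in place.
using Step = std::function<void(const Vec& x, Vec& h, Vec& c)>;

struct SeqResult {
    std::vector<std::vector<Vec>> Y;  // Y[b][t]: hidden output of row b at step t
    std::vector<Vec> H, C;            // final hidden and cell state per row
};

// Unroll one direction. X[b][t] is row b's input at step t; rows are padded
// to a common length and seq_lengths[b] marks how many steps are valid.
SeqResult unroll(const std::vector<std::vector<Vec>>& X,
                 const std::vector<std::size_t>& seq_lengths,
                 std::vector<Vec> H, std::vector<Vec> C,
                 const Step& step, bool reverse) {
    assert(X.size() == seq_lengths.size() && H.size() == X.size() && C.size() == X.size());
    SeqResult r;
    r.Y.resize(X.size());
    for (std::size_t b = 0; b < X.size(); ++b) {
        const std::size_t T = X[b].size();
        const std::size_t len = seq_lengths[b];
        assert(len <= T);
        // Padded steps keep zero outputs (assumption).
        r.Y[b].assign(T, Vec(H[b].size(), 0.f));
        for (std::size_t k = 0; k < len; ++k) {
            // A reverse pass starts from the row's last valid step.
            const std::size_t t = reverse ? len - 1 - k : k;
            step(X[b][t], H[b], C[b]);
            r.Y[b][t] = H[b];
        }
    }
    r.H = std::move(H);  // state after each row's last valid step
    r.C = std::move(C);
    return r;
}
```

For a bidirectional sequence, the same routine would run once with reverse = false and once with reverse = true, each with its own weights, and the results would be stacked along a direction axis.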

The documentation for this class was generated from the following file: