class ngraph::op::v0::LSTMSequence

Overview

Class for an LSTM sequence node. More…

#include <lstm_sequence.hpp>

class LSTMSequence: public ngraph::op::util::FusedOp
{
public:
    /// Alias for RecurrentSequenceDirection; selects the direction in which
    /// the input sequence is processed (see the `lstm_direction` ctor argument).
    using direction = RecurrentSequenceDirection;

    /// Default constructor.
    /// NOTE(review): presumably used when the node is rebuilt via
    /// visit_attributes() (e.g. deserialization) — confirm.
    LSTMSequence();

    /// Constructs an LSTM sequence node that receives explicit peephole weights @p P.
    ///
    /// Input naming follows the ONNX LSTM operator. Tensor shapes are not
    /// visible from this declaration; see the ONNX specification and
    /// validate_and_infer_types() for the expected layouts.
    ///
    /// @param X                     Input sequence (ONNX `X`).
    /// @param initial_hidden_state  Initial hidden state (ONNX `initial_h`).
    /// @param initial_cell_state    Initial cell state (ONNX `initial_c`).
    /// @param sequence_lengths      Sequence lengths (presumably one per batch
    ///                              element, as ONNX `sequence_lens` — confirm).
    /// @param W                     Input weights (ONNX `W`).
    /// @param R                     Recurrence weights (ONNX `R`).
    /// @param B                     Biases (ONNX `B`).
    /// @param P                     Peephole weights (ONNX `P`).
    /// @param hidden_size           Size of the hidden state.
    /// @param lstm_direction        Direction in which the sequence is processed.
    /// @param weights_format        Layout of the packed weights; defaults to
    ///                              LSTMWeightsFormat::IFCO (presumably the gate
    ///                              order input/forget/cell/output — confirm).
    /// @param activations_alpha     Alpha parameters of the activation functions; empty by default.
    /// @param activations_beta      Beta parameters of the activation functions; empty by default.
    /// @param activations           Activation function names; defaults to {"sigmoid", "tanh", "tanh"}.
    /// @param clip_threshold        Clip threshold; defaults to 0
    ///                              (NOTE(review): presumably 0 disables clipping — confirm).
    /// @param input_forget          Defaults to false; presumably mirrors the ONNX
    ///                              `input_forget` attribute (couple input and forget gates).
    ///
    /// NOTE(review): unlike the overload without @p P, the vector arguments here
    /// are taken by const value rather than by const reference. Aligning them
    /// requires changing the out-of-line definition as well.
    LSTMSequence(const Output<Node>& X,
                 const Output<Node>& initial_hidden_state,
                 const Output<Node>& initial_cell_state,
                 const Output<Node>& sequence_lengths,
                 const Output<Node>& W,
                 const Output<Node>& R,
                 const Output<Node>& B,
                 const Output<Node>& P,
                 const std::int64_t hidden_size,
                 const direction lstm_direction,
                 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
                 const std::vector<float> activations_alpha = {},
                 const std::vector<float> activations_beta = {},
                 const std::vector<std::string> activations = {"sigmoid", "tanh", "tanh"},
                 const float clip_threshold = 0,
                 const bool input_forget = false);

    /// Constructs an LSTM sequence node without an explicit peephole input.
    ///
    /// Arguments have the same meaning and defaults as in the overload that
    /// takes @p P. How peephole weights are provided in this case is not
    /// visible here (NOTE(review): presumably a default/zero tensor — confirm).
    LSTMSequence(const Output<Node>& X,
                 const Output<Node>& initial_hidden_state,
                 const Output<Node>& initial_cell_state,
                 const Output<Node>& sequence_lengths,
                 const Output<Node>& W,
                 const Output<Node>& R,
                 const Output<Node>& B,
                 const std::int64_t hidden_size,
                 const direction lstm_direction,
                 LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
                 const std::vector<float>& activations_alpha = {},
                 const std::vector<float>& activations_beta = {},
                 const std::vector<std::string>& activations = {"sigmoid", "tanh", "tanh"},
                 const float clip_threshold = 0,
                 const bool input_forget = false);

    // --- Node / FusedOp hooks -------------------------------------------

    /// Index of the output treated as this node's default output
    /// (semantics defined by the base Node API).
    size_t get_default_output_index() const;

    /// Presumably validates the inputs/attributes and infers the output
    /// element types and shapes.
    virtual void validate_and_infer_types();

    /// Exposes this node's attributes to @p visitor.
    bool visit_attributes(AttributeVisitor& visitor);

    /// Presumably expresses this fused op as a subgraph of simpler ops and
    /// returns that subgraph's outputs.
    virtual OutputVector decompose_op() const;

    /// Creates a copy of this node wired to @p new_args.
    virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const;

    // --- Attribute accessors (mirror the constructor arguments) ----------

    std::vector<float> get_activations_alpha() const;
    std::vector<float> get_activations_beta() const;
    std::vector<std::string> get_activations() const;
    float get_clip_threshold() const;
    direction get_direction() const;
    std::int64_t get_hidden_size() const;
    bool get_input_forget() const;
    LSTMWeightsFormat get_weights_format() const;
};

Detailed Documentation

Class for an LSTM sequence node.

It follows the notation and equations defined in the ONNX standard: https://github.com/onnx/onnx/blob/main/docs/Operators.md#LSTM

See also:

LSTMCell, RNNCell, GRUCell