24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
/// Static, compile-time type descriptor, brace-initialized with the string
/// "RNNCell" and the integer 0. It is the object returned by get_type_info().
/// NOTE(review): the second field is presumably an op version number — confirm
/// against the NodeTypeInfo definition, which is not visible here.
56 static constexpr NodeTypeInfo type_info{
"RNNCell", 0};
/// Overrides the base-class accessor and returns a const reference to the
/// static `type_info` member of this class.
/// \return Reference to the static descriptor; it refers to a static object,
///         so it stays valid independently of any particular node instance.
57 const NodeTypeInfo&
get_type_info()
const override {
return type_info; }
83 const Output<Node>& X,
84 const Output<Node>& initial_hidden_state,
85 const Output<Node>& W,
86 const Output<Node>& R,
87 std::size_t hidden_size,
88 const std::vector<std::string>& activations = std::vector<std::string>{
"tanh"},
89 const std::vector<float>& activations_alpha = {},
90 const std::vector<float>& activations_beta = {},
119 const Output<Node>& X,
120 const Output<Node>& initial_hidden_state,
121 const Output<Node>& W,
122 const Output<Node>& R,
123 const Output<Node>& B,
124 std::size_t hidden_size,
125 const std::vector<std::string>& activations = std::vector<std::string>{
"tanh"},
126 const std::vector<float>& activations_alpha = {},
127 const std::vector<float>& activations_beta = {},
/// Declaration only; the definition lives outside this header excerpt.
/// Overrides a base-class virtual, takes a vector of replacement outputs and
/// returns a shared pointer to a Node.
/// NOTE(review): by its name this presumably builds a copy of this cell wired
/// to `new_args` — confirm the expected argument count/order in the .cpp.
/// \param new_args Outputs to be used as the inputs of the returned node.
/// \return Shared pointer to a Node.
132 std::shared_ptr<Node>
133 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Declaration only; the definition is not visible in this excerpt.
/// Const member that returns an Output<Node> by value.
/// NOTE(review): presumably supplies the bias input for the constructor
/// overload that takes no `B` argument — verify the produced shape and values
/// against the implementation.
141 Output<Node> get_default_bias_input()
const;
/// Compile-time constant fixed to 1.
/// NOTE(review): presumably the number of gates in this cell type, used when
/// checking the W/R/B input shapes — confirm against validate_and_infer_types().
148 static constexpr std::size_t s_gates_count{1};
Class for single RNN cell node.
Definition: rnn_cell.hpp:54
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:66
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RNNCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs RNNCell node.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
const NodeTypeInfo & get_type_info() const override
Definition: rnn_cell.hpp:57
RNNCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f)
Constructs RNNCell node.
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70