24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
66 std::size_t hidden_size);
95 std::size_t hidden_size,
96 const std::vector<std::string>& activations,
97 const std::vector<float>& activations_alpha,
98 const std::vector<float>& activations_beta,
100 bool linear_before_reset);
142 std::size_t hidden_size,
143 const std::vector<std::string>& activations =
144 std::vector<std::string>{
"sigmoid",
"tanh"},
145 const std::vector<float>& activations_alpha = {},
146 const std::vector<float>& activations_beta = {},
148 bool linear_before_reset =
false);
/// @brief Override of the base-class virtual `clone_with_new_inputs`.
/// @param new_args Inputs to use for the newly created node.
/// @return A new node (as `std::shared_ptr<Node>`) built from @p new_args.
/// NOTE(review): presumably the clone keeps this cell's attributes
/// (hidden_size, activations, clip, linear_before_reset) — confirm in gru_cell.cpp.
152 virtual std::shared_ptr<Node>
153 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// @brief Read-only accessor for the `linear_before_reset` attribute.
/// @return The value stored in m_linear_before_reset.
/// NOTE(review): presumably set from the constructor argument of the same
/// name, whose declared default is `false` — confirm in gru_cell.cpp.
155 bool get_linear_before_reset()
const {
return m_linear_before_reset; }
/// NOTE(review): by its name, presumably supplies a default bias input `B`
/// for the constructor overloads that take no `B` argument — confirm in gru_cell.cpp.
158 void add_default_bias_input();
/// Number of gates in this GRU cell, fixed at compile time to 3.
/// NOTE(review): conventionally update, reset and candidate/hidden gates;
/// the gate naming is not visible in this header.
169 static constexpr std::size_t s_gates_count{3};
/// Storage returned by get_linear_before_reset().
176 bool m_linear_before_reset;
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:67
Class for GRU cell node.
Definition: gru_cell.hpp:44
const NodeTypeInfo & get_type_info() const override
Definition: gru_cell.hpp:47
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta, float clip, bool linear_before_reset)
Constructs a GRUCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size)
Constructs a GRUCell node.
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool linear_before_reset=false)
Constructs a GRUCell node.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28