24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
/// Static type descriptor for this op, initialized with the name "GRUCell"
/// and the number 3.
/// NOTE(review): the second NodeTypeInfo field is presumably the op version
/// (opset) — confirm against NodeTypeInfo's definition.
46 static constexpr NodeTypeInfo type_info{
"GRUCell", 3};
/// \brief Returns the type descriptor of this node.
/// \return A reference to the static `type_info` member ("GRUCell", 3).
47 const NodeTypeInfo&
get_type_info()
const override {
return type_info; }
63 const Output<Node>& initial_hidden_state,
64 const Output<Node>& W,
65 const Output<Node>& R,
66 std::size_t hidden_size);
92 const Output<Node>& initial_hidden_state,
93 const Output<Node>& W,
94 const Output<Node>& R,
95 std::size_t hidden_size,
96 const std::vector<std::string>& activations,
97 const std::vector<float>& activations_alpha,
98 const std::vector<float>& activations_beta,
100 bool linear_before_reset);
138 const Output<Node>& initial_hidden_state,
139 const Output<Node>& W,
140 const Output<Node>& R,
141 const Output<Node>& B,
142 std::size_t hidden_size,
143 const std::vector<std::string>& activations =
144 std::vector<std::string>{
"sigmoid",
"tanh"},
145 const std::vector<float>& activations_alpha = {},
146 const std::vector<float>& activations_beta = {},
148 bool linear_before_reset =
false);
/// \brief Override of the Node cloning hook, taking the replacement inputs in
///        `new_args`.
/// NOTE(review): only the declaration is visible here. Presumably it builds a
/// new GRUCell from `new_args` with this node's attributes. Confirm this, and
/// any arity checks on `new_args`, in the .cpp definition.
152 virtual std::shared_ptr<Node>
153 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// \brief Reports the linear_before_reset setting of this cell.
/// \return The stored `m_linear_before_reset` flag.
/// NOTE(review): presumably initialized from the constructors'
/// `linear_before_reset` argument (which defaults to false in the overload
/// that takes B). The initialization is not visible here; confirm.
155 bool get_linear_before_reset()
const {
return m_linear_before_reset; }
/// NOTE(review): the definition is not visible. Presumably it attaches a
/// default bias input B for the constructor overloads that take no `B`
/// argument. Confirm the bias shape and values in the .cpp.
158 void add_default_bias_input();
/// Compile-time gate count for this cell (3).
/// NOTE(review): presumably the GRU update, reset and hidden gates, used to
/// size the W/R/B gate dimension. Confirm against the validation code.
169 static constexpr std::size_t s_gates_count{3};
/// linear_before_reset flag, exposed through get_linear_before_reset().
176 bool m_linear_before_reset;
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size)
Constructs GRUCell node.
const NodeTypeInfo & get_type_info() const override
Definition: gru_cell.hpp:47
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:66
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool linear_before_reset=false)
Constructs GRUCell node.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
Class for GRU cell node.
Definition: gru_cell.hpp:44
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta, float clip, bool linear_before_reset)
Constructs GRUCell node.