12 #include "ngraph/node.hpp"
13 #include "ngraph/op/util/activation_functions.hpp"
14 #include "ngraph/op/util/fused_op.hpp"
15 #include "ngraph/op/util/rnn_cell_base.hpp"
54 std::size_t hidden_size);
83 std::size_t hidden_size,
84 const std::vector<std::string>& activations,
85 const std::vector<float>& activations_alpha,
86 const std::vector<float>& activations_beta,
88 bool linear_before_reset);
130 std::size_t hidden_size,
131 const std::vector<std::string>& activations =
132 std::vector<std::string>{
"sigmoid",
"tanh"},
133 const std::vector<float>& activations_alpha = {},
134 const std::vector<float>& activations_beta = {},
136 bool linear_before_reset =
false);
/// Node cloning hook (overrides the base declaration; const, so it does not
/// modify this node). Returns a std::shared_ptr<Node>.
/// NOTE(review): presumably builds a new GRUCell with the same attributes that
/// consumes @p new_args as its inputs — confirm against the .cpp definition.
/// @param new_args  replacement inputs for the cloned node.
140 virtual std::shared_ptr<Node>
141 clone_with_new_inputs(
const OutputVector& new_args)
const override;
/// Returns the stored m_linear_before_reset flag.
/// NOTE(review): presumably this is the value passed as the
/// `linear_before_reset` constructor argument (default false) — confirm in the .cpp.
143 bool get_linear_before_reset()
const {
return m_linear_before_reset; }
/// Helper declared here, defined elsewhere.
/// NOTE(review): the name suggests it supplies a default bias (B) input for the
/// constructor overload that takes no B argument — confirm in the .cpp.
147 void add_default_bias_input();
/// Compile-time gate count for this cell (fixed at 3).
/// NOTE(review): presumably the update, reset and hidden gates of a GRU — confirm.
158 static constexpr std::size_t s_gates_count{3};
/// Flag read by get_linear_before_reset().
/// NOTE(review): presumably initialized from the `linear_before_reset`
/// constructor argument, whose declared default is false — confirm in the .cpp.
165 bool m_linear_before_reset;
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing an activation function used in RNN cells.
Definition: activation_functions.hpp:70
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for the GRU cell node.
Definition: gru_cell.hpp:32
const NodeTypeInfo & get_type_info() const override
Definition: gru_cell.hpp:35
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta, float clip, bool linear_before_reset)
Constructs a GRUCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size)
Constructs a GRUCell node.
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool linear_before_reset=false)
Constructs a GRUCell node.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16