gru_cell.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstddef>
8 #include <memory>
9 #include <string>
10 #include <vector>
11 
12 #include "ngraph/node.hpp"
13 #include "ngraph/op/util/activation_functions.hpp"
14 #include "ngraph/op/util/fused_op.hpp"
15 #include "ngraph/op/util/rnn_cell_base.hpp"
16 
17 namespace ngraph
18 {
19  namespace op
20  {
21  namespace v3
22  {
23  ///
24  /// \brief Class for GRU cell node.
25  ///
26  /// \note It follows notation and equations defined as in ONNX standard:
27  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
28  ///
29  /// Note this class represents only single *cell* and not whole GRU *layer*.
30  ///
31  class NGRAPH_API GRUCell : public util::RNNCellBase
32  {
33  public:
34  static constexpr NodeTypeInfo type_info{"GRUCell", 3};
35  const NodeTypeInfo& get_type_info() const override { return type_info; }
36  GRUCell();
37  ///
38  /// \brief Constructs GRUCell node.
39  ///
40  /// \param[in] X The input tensor with shape: [batch_size,
41  /// input_size].
42  /// \param[in] initial_hidden_state The hidden state tensor at current time step
43  /// with shape: [batch_size, hidden_size].
44  /// \param[in] W The weight tensor with shape:
45  /// [gates_count * hidden_size, input_size].
46  /// \param[in] R The recurrence weight tensor with shape:
47  /// [gates_count * hidden_size, hidden_size].
48  /// \param[in] hidden_size The number of hidden units for recurrent cell.
49  ///
51  const Output<Node>& initial_hidden_state,
52  const Output<Node>& W,
53  const Output<Node>& R,
54  std::size_t hidden_size);
55 
56  ///
57  /// \brief Constructs GRUCell node.
58  ///
59  /// \param[in] X The input tensor with shape: [batch_size,
60  /// input_size].
61  /// \param[in] initial_hidden_state The hidden state tensor at current time step
62  /// with shape: [batch_size, hidden_size].
63  /// \param[in] W The weight tensor with shape:
64  /// [gates_count * hidden_size, input_size].
65  /// \param[in] R The recurrence weight tensor with shape:
66  /// [gates_count * hidden_size, hidden_size].
67  /// \param[in] hidden_size The number of hidden units for recurrent cell.
68  /// \param[in] activations The vector of activation functions used inside
69  /// recurrent cell.
70  /// \param[in] activations_alpha The vector of alpha parameters for activation
71  /// functions in order respective to activation
72  /// list.
73  /// \param[in] activations_beta The vector of beta parameters for activation
74  /// functions in order respective to activation
75  /// list.
76  /// \param[in] clip The value defining clipping range [-clip,
77  /// clip] on input of activation functions.
78  ///
80  const Output<Node>& initial_hidden_state,
81  const Output<Node>& W,
82  const Output<Node>& R,
83  std::size_t hidden_size,
84  const std::vector<std::string>& activations,
85  const std::vector<float>& activations_alpha,
86  const std::vector<float>& activations_beta,
87  float clip,
88  bool linear_before_reset);
89 
90  ///
91  /// \brief Constructs GRUCell node.
92  ///
93  /// \param[in] X The input tensor with shape: [batch_size,
94  /// input_size].
95  /// \param[in] initial_hidden_state The hidden state tensor at current time step
96  /// with shape: [batch_size, hidden_size].
97  /// \param[in] W The weight tensor with shape: [gates_count *
98  /// hidden_size, input_size].
99  /// \param[in] R The recurrence weight tensor with shape:
100  /// [gates_count * hidden_size, hidden_size].
101  /// \param[in] hidden_size The number of hidden units for recurrent cell.
102  /// \param[in] B The sum of biases (weight and recurrence) for
103  /// update, reset and hidden gates.
104  /// If linear_before_reset := true then biases for
105  /// hidden gates are
106  /// placed separately (weight and recurrence).
107  /// Shape: [gates_count * hidden_size] if
108  /// linear_before_reset := false
109  /// Shape: [(gates_count + 1) * hidden_size] if
110  /// linear_before_reset := true
111  /// \param[in] activations The vector of activation functions used inside
112  /// recurrent cell.
113  /// \param[in] activations_alpha The vector of alpha parameters for activation
114  /// functions in order respective to activation
115  /// list.
116  /// \param[in] activations_beta The vector of beta parameters for activation
117  /// functions in order respective to activation
118  /// list.
119  /// \param[in] clip The value defining clipping range [-clip,
120  /// clip] on input of activation functions.
121  /// \param[in] linear_before_reset Whether or not to apply the linear
122  /// transformation before multiplying by the
123  /// output of the reset gate.
124  ///
126  const Output<Node>& initial_hidden_state,
127  const Output<Node>& W,
128  const Output<Node>& R,
129  const Output<Node>& B,
130  std::size_t hidden_size,
131  const std::vector<std::string>& activations =
132  std::vector<std::string>{"sigmoid", "tanh"},
133  const std::vector<float>& activations_alpha = {},
134  const std::vector<float>& activations_beta = {},
135  float clip = 0.f,
136  bool linear_before_reset = false);
137 
138  virtual void validate_and_infer_types() override;
139  bool visit_attributes(AttributeVisitor& visitor) override;
140  virtual std::shared_ptr<Node>
141  clone_with_new_inputs(const OutputVector& new_args) const override;
142 
143  bool get_linear_before_reset() const { return m_linear_before_reset; }
144 
145  private:
146  /// brief Add and initialize bias input to all zeros.
147  void add_default_bias_input();
148 
149  ///
150  /// \brief The Activation function f.
151  ///
152  util::ActivationFunction m_activation_f;
153  ///
154  /// \brief The Activation function g.
155  ///
156  util::ActivationFunction m_activation_g;
157 
158  static constexpr std::size_t s_gates_count{3};
159  ///
160  /// \brief Control whether or not apply the linear transformation.
161  ///
162  /// \note The linear transformation may be applied when computing the output of
163  /// hidden gate. It's done before multiplying by the output of the reset gate.
164  ///
165  bool m_linear_before_reset;
166  };
167  } // namespace v3
168  } // namespace op
169 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:70
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Class for GRU cell node.
Definition: gru_cell.hpp:32
const NodeTypeInfo & get_type_info() const override
Definition: gru_cell.hpp:35
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta, float clip, bool linear_before_reset)
Constructs GRUCell node.
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size)
Constructs GRUCell node.
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool linear_before_reset=false)
Constructs GRUCell node.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27