gru_cell.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 #include "ngraph/op/util/fused_op.hpp"
27 #include "ngraph/op/util/rnn_cell_base.hpp"
28 
29 namespace ngraph
30 {
31  namespace op
32  {
33  namespace v3
34  {
35  ///
36  /// \brief Class for GRU cell node.
37  ///
38  /// \note It follows notation and equations defined as in ONNX standard:
39  /// https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU
40  ///
41  /// Note this class represents only single *cell* and not whole GRU *layer*.
42  ///
43  class NGRAPH_API GRUCell : public util::RNNCellBase
44  {
45  public:
46  static constexpr NodeTypeInfo type_info{"GRUCell", 3};
47  const NodeTypeInfo& get_type_info() const override { return type_info; }
48  GRUCell();
49  ///
50  /// \brief Constructs GRUCell node.
51  ///
52  /// \param[in] X The input tensor with shape: [batch_size,
53  /// input_size].
54  /// \param[in] initial_hidden_state The hidden state tensor at current time step
55  /// with shape: [batch_size, hidden_size].
56  /// \param[in] W The weight tensor with shape:
57  /// [gates_count * hidden_size, input_size].
58  /// \param[in] R The recurrence weight tensor with shape:
59  /// [gates_count * hidden_size, hidden_size].
60  /// \param[in] hidden_size The number of hidden units for recurrent cell.
61  ///
62  GRUCell(const Output<Node>& X,
63  const Output<Node>& initial_hidden_state,
64  const Output<Node>& W,
65  const Output<Node>& R,
66  std::size_t hidden_size);
67 
68  ///
69  /// \brief Constructs GRUCell node.
70  ///
71  /// \param[in] X The input tensor with shape: [batch_size,
72  /// input_size].
73  /// \param[in] initial_hidden_state The hidden state tensor at current time step
74  /// with shape: [batch_size, hidden_size].
75  /// \param[in] W The weight tensor with shape:
76  /// [gates_count * hidden_size, input_size].
77  /// \param[in] R The recurrence weight tensor with shape:
78  /// [gates_count * hidden_size, hidden_size].
79  /// \param[in] hidden_size The number of hidden units for recurrent cell.
80  /// \param[in] activations The vector of activation functions used inside
81  /// recurrent cell.
82  /// \param[in] activations_alpha The vector of alpha parameters for activation
83  /// functions in order respective to activation
84  /// list.
85  /// \param[in] activations_beta The vector of beta parameters for activation
86  /// functions in order respective to activation
87  /// list.
88  /// \param[in] clip The value defining clipping range [-clip,
89  /// clip] on input of activation functions.
90  ///
91  GRUCell(const Output<Node>& X,
92  const Output<Node>& initial_hidden_state,
93  const Output<Node>& W,
94  const Output<Node>& R,
95  std::size_t hidden_size,
96  const std::vector<std::string>& activations,
97  const std::vector<float>& activations_alpha,
98  const std::vector<float>& activations_beta,
99  float clip,
100  bool linear_before_reset);
101 
102  ///
103  /// \brief Constructs GRUCell node.
104  ///
105  /// \param[in] X The input tensor with shape: [batch_size,
106  /// input_size].
107  /// \param[in] initial_hidden_state The hidden state tensor at current time step
108  /// with shape: [batch_size, hidden_size].
109  /// \param[in] W The weight tensor with shape: [gates_count *
110  /// hidden_size, input_size].
111  /// \param[in] R The recurrence weight tensor with shape:
112  /// [gates_count * hidden_size, hidden_size].
113  /// \param[in] hidden_size The number of hidden units for recurrent cell.
114  /// \param[in] B The sum of biases (weight and recurrence) for
115  /// update, reset and hidden gates.
116  /// If linear_before_reset := true then biases for
117  /// hidden gates are
118  /// placed separately (weight and recurrence).
119  /// Shape: [gates_count * hidden_size] if
120  /// linear_before_reset := false
121  /// Shape: [(gates_count + 1) * hidden_size] if
122  /// linear_before_reset := true
123  /// \param[in] activations The vector of activation functions used inside
124  /// recurrent cell.
125  /// \param[in] activations_alpha The vector of alpha parameters for activation
126  /// functions in order respective to activation
127  /// list.
128  /// \param[in] activations_beta The vector of beta parameters for activation
129  /// functions in order respective to activation
130  /// list.
131  /// \param[in] clip The value defining clipping range [-clip,
132  /// clip] on input of activation functions.
133  /// \param[in] linear_before_reset Whether or not to apply the linear
134  /// transformation before multiplying by the
135  /// output of the reset gate.
136  ///
137  GRUCell(const Output<Node>& X,
138  const Output<Node>& initial_hidden_state,
139  const Output<Node>& W,
140  const Output<Node>& R,
141  const Output<Node>& B,
142  std::size_t hidden_size,
143  const std::vector<std::string>& activations =
144  std::vector<std::string>{"sigmoid", "tanh"},
145  const std::vector<float>& activations_alpha = {},
146  const std::vector<float>& activations_beta = {},
147  float clip = 0.f,
148  bool linear_before_reset = false);
149 
150  virtual void validate_and_infer_types() override;
151  bool visit_attributes(AttributeVisitor& visitor) override;
152  virtual std::shared_ptr<Node>
153  clone_with_new_inputs(const OutputVector& new_args) const override;
154 
155  bool get_linear_before_reset() const { return m_linear_before_reset; }
156  private:
157  /// brief Add and initialize bias input to all zeros.
158  void add_default_bias_input();
159 
160  ///
161  /// \brief The Activation function f.
162  ///
163  util::ActivationFunction m_activation_f;
164  ///
165  /// \brief The Activation function g.
166  ///
167  util::ActivationFunction m_activation_g;
168 
169  static constexpr std::size_t s_gates_count{3};
170  ///
171  /// \brief Control whether or not apply the linear transformation.
172  ///
173  /// \note The linear transformation may be applied when computing the output of
174  /// hidden gate. It's done before multiplying by the output of the reset gate.
175  ///
176  bool m_linear_before_reset;
177  };
178  }
179  }
180 }
ngraph::op::v3::GRUCell::GRUCell
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size)
Constructs GRUCell node.
ngraph::op::v3::GRUCell::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: gru_cell.hpp:47
ngraph::op::util::RNNCellBase
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:66
ngraph::op::v3::GRUCell::GRUCell
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, const Output< Node > &B, std::size_t hidden_size, const std::vector< std::string > &activations=std::vector< std::string >{"sigmoid", "tanh"}, const std::vector< float > &activations_alpha={}, const std::vector< float > &activations_beta={}, float clip=0.f, bool linear_before_reset=false)
Constructs GRUCell node.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::op::v3::GRUCell::validate_and_infer_types
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::util::ActivationFunction
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::v3::GRUCell
Class for GRU cell node.
Definition: gru_cell.hpp:44
ngraph::op::v3::GRUCell::GRUCell
GRUCell(const Output< Node > &X, const Output< Node > &initial_hidden_state, const Output< Node > &W, const Output< Node > &R, std::size_t hidden_size, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta, float clip, bool linear_before_reset)
Constructs GRUCell node.