rnn_cell_base.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "ngraph/node.hpp"
25 #include "ngraph/op/util/activation_functions.hpp"
26 
27 namespace ngraph
28 {
29  namespace op
30  {
31  namespace util
32  {
///
/// \brief      Layouts in which the LSTM gate weights may be ordered.
///
/// \note       Each enumerator name spells out the gate order of the layout; the
///             trailing comment names the framework(s) associated with it.
///             The letters presumably stand for the forget (F), input (I),
///             cell (C) and output (O) gates -- TODO confirm against the
///             implementation of convert_lstm_node_format.
///
enum class LSTMWeightsFormat
{
    FICO, // IE
    ICOF, // PyTorch
    IFCO, // DNNL, TF, MxNet
    IFOC, // Caffe
    IOFC, // ONNX
};
41 
42  ///
43  /// \brief Change data format of provided node.
44  ///
45  /// \param[in] node The input node to be permuted.
46  ///
47  ///
48  /// \param[in] from_format Original node weights format.
49  ///
50  ///
51  /// \param[in] to_format Weights format to convert to.
52  ///
53  /// \return Node representing reshaped tensor according to `to_format` weights
54  /// format.
55  ///
56  std::shared_ptr<Node> NGRAPH_API
57  convert_lstm_node_format(const Output<Node>& node,
58  LSTMWeightsFormat from_format,
59  LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO);
60 
61  /// \brief Base class for all recurrent network cells.
62  ///
63  /// \note It holds all common attributes.
64  ///
65  class NGRAPH_API RNNCellBase : public Op
66  {
67  public:
68  ///
69  /// \brief Constructs a RNNCellBase class.
70  ///
71  /// \param[in] hidden_size The number of hidden units for recurrent cell.
72  /// \param[in] clip The value defining clipping range [-clip, clip]
73  /// on input of activation functions.
74  /// \param[in] activations The vector of activation functions used inside
75  /// recurrent cell.
76  /// \param[in] activations_alpha The vector of alpha parameters for activation
77  /// functions in order respective to activation list.
78  /// \param[in] activations_beta The vector of beta parameters for activation
79  /// functions in order respective to activation list.
80  ///
81  RNNCellBase(const OutputVector& args,
82  std::size_t hidden_size,
83  float clip,
84  const std::vector<std::string>& activations,
85  const std::vector<float>& activations_alpha,
86  const std::vector<float>& activations_beta);
87 
88  RNNCellBase();
89  virtual ~RNNCellBase() = default;
90 
91  ///
92  /// \brief Validates static rank and dimension for provided input parameters.
93  /// Additionally input_size dimension is checked for X and W inputs.
94  ///
95  ///
96  /// \param[in] input Vector with RNN-Cell op inputs in following order:
97  /// X, initial_hidden_state, W, R and B.
98  ///
99  void validate_input_rank_dimension(const std::vector<ngraph::PartialShape>& input);
100 
101  virtual bool visit_attributes(AttributeVisitor& visitor);
102  std::size_t get_hidden_size() const { return m_hidden_size; }
103  float get_clip() const { return m_clip; }
104  const std::vector<std::string>& get_activations() const { return m_activations; }
105  const std::vector<float>& get_activations_alpha() const
106  {
107  return m_activations_alpha;
108  }
109  const std::vector<float>& get_activations_beta() const
110  {
111  return m_activations_beta;
112  }
113 
114  protected:
115  ///
116  /// \brief Constructs activation function object.
117  ///
118  /// \param[in] idx The index of the activation function name.
119  ///
120  /// \return The object representing activation function.
121  ///
123  ///
124  /// \brief Creates node with element-wise add operation with numpy
125  /// broadcasting.
126  ///
127  /// \param[in] lhs The left hand side argument node.
128  /// \param[in] rhs The right hand side argument node.
129  ///
130  /// \return Node with element-wise add operation.
131  ///
132  static std::shared_ptr<Node> add(const Output<Node>& lhs, const Output<Node>& rhs);
133  ///
134  /// \brief Creates node with element-wise subtract operation with numpy
135  /// broadcasting.
136  ///
137  /// \param[in] lhs The left hand side argument node.
138  /// \param[in] rhs The right hand side argument node.
139  ///
140  /// \return Node with element-wise subtract operation.
141  ///
142  static std::shared_ptr<Node> sub(const Output<Node>& lhs, const Output<Node>& rhs);
143  ///
144  /// \brief Creates node with element-wise multiply operation with numpy
145  /// broadcasting.
146  ///
147  /// \param[in] lhs The left hand side argument node.
148  /// \param[in] rhs The right hand side argument node.
149  ///
150  /// \return Node with element-wise multiply operation.
151  ///
152  static std::shared_ptr<Node> mul(const Output<Node>& lhs, const Output<Node>& rhs);
153  ///
154  /// \brief Creates node with element-wise clip operation with numpy
155  /// broadcasting.
156  ///
157  /// \param[in] data The input tensor for clipping.
158  ///
159  /// \return Node with element-wise clip operation.
160  ///
161  std::shared_ptr<Node> clip(const Output<Node>& data) const;
162 
163  protected:
164  std::size_t m_hidden_size;
165  float m_clip;
166  std::vector<std::string> m_activations;
167  std::vector<float> m_activations_alpha;
168  std::vector<float> m_activations_beta;
169  };
170  } // namespace util
171  } // namespace op
172 } // namespace ngraph
ngraph::op::util::RNNCellBase::validate_input_rank_dimension
void validate_input_rank_dimension(const std::vector< ngraph::PartialShape > &input)
Validates static rank and dimension for provided input parameters. Additionally input_size dimension is checked for X and W inputs.
ngraph::op::util::RNNCellBase
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:66
ngraph::op::util::RNNCellBase::sub
static std::shared_ptr< Node > sub(const Output< Node > &lhs, const Output< Node > &rhs)
Creates node with element-wise subtract operation with numpy broadcasting.
ngraph::op::util::RNNCellBase::mul
static std::shared_ptr< Node > mul(const Output< Node > &lhs, const Output< Node > &rhs)
Creates node with element-wise multiply operation with numpy broadcasting.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::op::util::RNNCellBase::get_activation_function
ActivationFunction get_activation_function(std::size_t idx) const
Constructs activation function object.
ngraph::op::util::RNNCellBase::RNNCellBase
RNNCellBase(const OutputVector &args, std::size_t hidden_size, float clip, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta)
Constructs a RNNCellBase class.
ngraph::op::util::ActivationFunction
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:82
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::util::RNNCellBase::add
static std::shared_ptr< Node > add(const Output< Node > &lhs, const Output< Node > &rhs)
Creates node with element-wise add operation with numpy broadcasting.
ngraph::op::util::RNNCellBase::clip
std::shared_ptr< Node > clip(const Output< Node > &data) const
Creates node with element-wise clip operation with numpy broadcasting.
ngraph::op::Op
Root of all actual ops.
Definition: op.hpp:29