rnn_cell_base.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstddef>
8 #include <memory>
9 #include <string>
10 #include <vector>
11 
12 #include "ngraph/node.hpp"
13 #include "ngraph/op/util/activation_functions.hpp"
14 
15 namespace ngraph
16 {
17  namespace op
18  {
19  namespace util
20  {
/// \brief Gate layout of packed LSTM weights.
///
/// Each enumerator spells the order in which the four gate blocks appear;
/// the trailing comment names the framework(s) using that layout.
/// NOTE(review): the letters presumably stand for Forget, Input, Cell
/// (candidate) and Output gates - confirm against the implementation of
/// convert_lstm_node_format.
enum class LSTMWeightsFormat : int
{
    FICO = 0, // IE
    ICOF = 1, // PyTorch
    IFCO = 2, // DNNL, TF, MxNet
    IFOC = 3, // Caffe
    IOFC = 4, // ONNX
};
29 
30  ///
31  /// \brief Change data format of provided node.
32  ///
33  /// \param[in] node The input node to be permuted.
34  ///
35  ///
36  /// \param[in] from_format Original node weights format.
37  ///
38  ///
39  /// \param[in] to_format Weights format to convert to.
40  ///
41  /// \return Node representing reshaped tensor according to `to_format` weights
42  /// format.
43  ///
44  std::shared_ptr<Node> NGRAPH_API
45  convert_lstm_node_format(const Output<Node>& node,
46  LSTMWeightsFormat from_format,
47  LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO,
48  int64_t axis = 0);
49 
50  /// \brief Base class for all recurrent network cells.
51  ///
52  /// \note It holds all common attributes.
53  ///
54  class NGRAPH_API RNNCellBase : public Op
55  {
56  public:
57  ///
58  /// \brief Constructs a RNNCellBase class.
59  ///
60  /// \param[in] hidden_size The number of hidden units for recurrent cell.
61  /// \param[in] clip The value defining clipping range [-clip, clip]
62  /// on input of activation functions.
63  /// \param[in] activations The vector of activation functions used inside
64  /// recurrent cell.
65  /// \param[in] activations_alpha The vector of alpha parameters for activation
66  /// functions in order respective to activation list.
67  /// \param[in] activations_beta The vector of beta parameters for activation
68  /// functions in order respective to activation list.
69  ///
70  RNNCellBase(const OutputVector& args,
71  std::size_t hidden_size,
72  float clip,
73  const std::vector<std::string>& activations,
74  const std::vector<float>& activations_alpha,
75  const std::vector<float>& activations_beta);
76 
77  RNNCellBase();
78  virtual ~RNNCellBase() = default;
79 
80  ///
81  /// \brief Validates static rank and dimension for provided input parameters.
82  /// Additionally input_size dimension is checked for X and W inputs.
83  ///
84  ///
85  /// \param[in] input Vector with RNN-Cell op inputs in following order:
86  /// X, initial_hidden_state, W, R and B.
87  ///
88  void validate_input_rank_dimension(const std::vector<ngraph::PartialShape>& input);
89 
90  virtual bool visit_attributes(AttributeVisitor& visitor);
91  std::size_t get_hidden_size() const { return m_hidden_size; }
92  float get_clip() const { return m_clip; }
93  const std::vector<std::string>& get_activations() const { return m_activations; }
94  const std::vector<float>& get_activations_alpha() const
95  {
96  return m_activations_alpha;
97  }
98  const std::vector<float>& get_activations_beta() const
99  {
100  return m_activations_beta;
101  }
102 
103  protected:
104  ///
105  /// \brief Constructs activation function object.
106  ///
107  /// \param[in] idx The index of the activation function name.
108  ///
109  /// \return The object representing activation function.
110  ///
112  ///
113  /// \brief Creates node with element-wise add operation with numpy
114  /// broadcasting.
115  ///
116  /// \param[in] lhs The left hand side argument node.
117  /// \param[in] rhs The right hand side argument node.
118  ///
119  /// \return Node with element-wise add operation.
120  ///
121  static std::shared_ptr<Node> add(const Output<Node>& lhs, const Output<Node>& rhs);
122  ///
123  /// \brief Creates node with element-wise subtract operation with numpy
124  /// broadcasting.
125  ///
126  /// \param[in] lhs The left hand side argument node.
127  /// \param[in] rhs The right hand side argument node.
128  ///
129  /// \return Node with element-wise subtract operation.
130  ///
131  static std::shared_ptr<Node> sub(const Output<Node>& lhs, const Output<Node>& rhs);
132  ///
133  /// \brief Creates node with element-wise multiply operation with numpy
134  /// broadcasting.
135  ///
136  /// \param[in] lhs The left hand side argument node.
137  /// \param[in] rhs The right hand side argument node.
138  ///
139  /// \return Node with element-wise multiply operation.
140  ///
141  static std::shared_ptr<Node> mul(const Output<Node>& lhs, const Output<Node>& rhs);
142  ///
143  /// \brief Creates node with element-wise clip operation with numpy
144  /// broadcasting.
145  ///
146  /// \param[in] data The input tensor for clipping.
147  ///
148  /// \return Node with element-wise clip operation.
149  ///
150  std::shared_ptr<Node> clip(const Output<Node>& data) const;
151 
152  protected:
153  std::size_t m_hidden_size;
154  float m_clip;
155  std::vector<std::string> m_activations;
156  std::vector<float> m_activations_alpha;
157  std::vector<float> m_activations_beta;
158  };
159  } // namespace util
160  } // namespace op
161 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Class representing activation function used in RNN cells.
Definition: activation_functions.hpp:70
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
static std::shared_ptr< Node > add(const Output< Node > &lhs, const Output< Node > &rhs)
Creates node with element-wise add operation with numpy broadcasting.
ActivationFunction get_activation_function(std::size_t idx) const
Constructs activation function object.
void validate_input_rank_dimension(const std::vector< ngraph::PartialShape > &input)
Validates static rank and dimension for provided input parameters. Additionally input_size dimension is checked for X and W inputs.
std::shared_ptr< Node > clip(const Output< Node > &data) const
Creates node with element-wise clip operation with numpy broadcasting.
static std::shared_ptr< Node > mul(const Output< Node > &lhs, const Output< Node > &rhs)
Creates node with element-wise multiply operation with numpy broadcasting.
RNNCellBase(const OutputVector &args, std::size_t hidden_size, float clip, const std::vector< std::string > &activations, const std::vector< float > &activations_alpha, const std::vector< float > &activations_beta)
Constructs a RNNCellBase class.
static std::shared_ptr< Node > sub(const Output< Node > &lhs, const Output< Node > &rhs)
Creates node with element-wise subtract operation with numpy broadcasting.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16