lstm_cell_decomposition.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <vector>
8 #include <memory>
9 
10 #include <transformations_visibility.hpp>
11 
12 #include <ngraph/pass/graph_rewrite.hpp>
13 
14 namespace ngraph {
15 namespace pass {
16 
17 class TRANSFORMATIONS_API LSTMCellDecomposition;
18 
19 } // namespace pass
20 } // namespace ngraph
21 
/**
 * @ingroup ie_transformation_common_api
 * @brief LSTMCellDecomposition transformation decomposes LSTMCell layer with inputs X, H, C, W, R, B
 * to Add, Split, MatMul, Multiply ops according to the formula below, where:
 *  - `(.)` denotes element-wise multiplication;
 *  - `*` denotes dot product (matrix multiplication);
 *  - f, g, h are activation functions.
 *
 * @verbatim
   it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
   ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
   ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
   ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
   Ct = ft (.) Ct-1 + it (.) ct
   Ht = ot (.) h(Ct)
   @endverbatim
 *
 * Here it, ft and ot are the input, forget and output gates, ct is the candidate cell value,
 * Ct-1 and Ht-1 are the previous cell and hidden states, and Ct and Ht are the new ones.
 */
38 
39 class ngraph::pass::LSTMCellDecomposition: public ngraph::pass::MatcherPass {
40 public:
41  NGRAPH_RTTI_DECLARATION;
43 };
LSTMCellDecomposition transformation decomposes LSTMCell layer with inputs X, H, C,...
Definition: lstm_cell_decomposition.hpp:39
ngraph namespace
Definition: add_fake_quantize_fusion.hpp:14