gru_cell_decomposition.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <vector>
8 #include <memory>
9 
10 #include <transformations_visibility.hpp>
11 
12 #include <ngraph/pass/graph_rewrite.hpp>
13 
14 namespace ngraph {
15 namespace pass {
16 
17 class TRANSFORMATIONS_API GRUCellDecomposition;
18 
19 } // namespace pass
20 } // namespace ngraph
21 
22 /**
23  * @ingroup ie_transformation_common_api
24  * @brief GRUCellDecomposition transformation decomposes GRUCell layer with inputs X, H, W, R, B
25  * to Add, Split, MatMul, Multiply and Subtract ops according to the formula:
26  (.) - Denotes element-wise multiplication.
27  * - Denotes dot product.
28  f, g - are activation functions
29 
30  zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
31  rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
32  ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # when linear_before_reset := false # (default)
33  ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset:= true
34  Ht = (1 - zt) (.) ht + zt (.) Ht-1
35  * *
36  */
37 
38 class ngraph::pass::GRUCellDecomposition: public ngraph::pass::MatcherPass {
39 public:
40  NGRAPH_RTTI_DECLARATION;
42 };
GRUCellDecomposition transformation decomposes GRUCell layer with inputs X, H, W, R,...
Definition: gru_cell_decomposition.hpp:38
ngraph namespace
Definition: add_fake_quantize_fusion.hpp:14