convert_ti_to_sequences.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <vector>
8 #include <memory>
9 
10 #include <vector>
11 #include <memory>
12 
13 #include <transformations_visibility.hpp>
14 
15 #include <ngraph/pass/graph_rewrite.hpp>
16 
17 namespace ngraph {
18 namespace pass {
19 
// Forward declarations of the three TensorIterator-to-sequence passes.
// The classes themselves are defined further down in this header, outside the
// namespace block, using their fully qualified names (ngraph::pass::...).
// NOTE(review): TRANSFORMATIONS_API presumably is the export/visibility macro
// provided by <transformations_visibility.hpp> -- confirm against that header.
20 class TRANSFORMATIONS_API ConvertTensorIteratorToLSTMSequence;
21 class TRANSFORMATIONS_API ConvertTensorIteratorToRNNSequence;
22 class TRANSFORMATIONS_API ConvertTensorIteratorToGRUSequence;
23 
24 } // namespace pass
25 } // namespace ngraph
26 
27 /**
28  * @ingroup ie_transformation_common_api
29  * @brief Finds all TensorIterator layers, detects the pattern Squeeze->LSTMCell->Unsqueeze in the TensorIterator body,
30  * converts this pattern to an LSTMSequence layer, and replaces the TensorIterator with it.
31  */
32 
33 class ngraph::pass::ConvertTensorIteratorToLSTMSequence: public ngraph::pass::MatcherPass {
34 public:
    // NOTE(review): presumably declares the nGraph RTTI members for this pass;
    // the macro is defined elsewhere -- confirm in the nGraph headers.
35  NGRAPH_RTTI_DECLARATION;
    // NOTE(review): the listing's own numbering jumps from 35 to 37, so one
    // line of this class body is not shown here. Verify against the original
    // header before treating this declaration as complete.
37 };
38 
39 /**
40  * @ingroup ie_transformation_common_api
41  * @brief Finds all TensorIterator layers, detects the pattern Squeeze->RNNCell->Unsqueeze in the TensorIterator body,
42  * converts this pattern to an RNNSequence layer, and replaces the TensorIterator with it.
43  */
44 
45 class ngraph::pass::ConvertTensorIteratorToRNNSequence: public ngraph::pass::MatcherPass {
46 public:
    // NOTE(review): presumably declares the nGraph RTTI members for this pass;
    // the macro is defined elsewhere -- confirm in the nGraph headers.
47  NGRAPH_RTTI_DECLARATION;
    // NOTE(review): the listing's own numbering jumps from 47 to 49, so one
    // line of this class body is not shown here. Verify against the original
    // header before treating this declaration as complete.
49 };
50 
51 /**
52  * @ingroup ie_transformation_common_api
53  * @brief Finds all TensorIterator layers, detects the pattern Squeeze->GRUCell->Unsqueeze in the TensorIterator body,
54  * converts this pattern to a GRUSequence layer, and replaces the TensorIterator with it.
55  */
56 
57 class ngraph::pass::ConvertTensorIteratorToGRUSequence: public ngraph::pass::MatcherPass {
58 public:
    // NOTE(review): presumably declares the nGraph RTTI members for this pass;
    // the macro is defined elsewhere -- confirm in the nGraph headers.
59  NGRAPH_RTTI_DECLARATION;
    // NOTE(review): the listing's own numbering jumps from 59 to 61, so one
    // line of this class body is not shown here. Verify against the original
    // header before treating this declaration as complete.
61 };
Finds all TensorIterator layers, detects the pattern Squeeze->GRUCell->Unsqueeze in the TensorIterato...
Definition: convert_ti_to_sequences.hpp:57
Finds all TensorIterator layers, detects the pattern Squeeze->LSTMCell->Unsqueeze in the TensorIterat...
Definition: convert_ti_to_sequences.hpp:33
Finds all TensorIterator layers, detects the pattern Squeeze->RNNCell->Unsqueeze in the TensorIterato...
Definition: convert_ti_to_sequences.hpp:45
ngraph namespace
Definition: add_fake_quantize_fusion.hpp:14