pattern.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "ngraph/node.hpp"
22 
23 namespace ngraph
24 {
25  namespace pattern
26  {
27  namespace op
28  {
29  class Label;
30  }
31 
32  class Matcher;
33  class MatchState;
34 
35  using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
36  using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
37  using PatternValueMaps = std::vector<PatternValueMap>;
38 
39  using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
40 
41  PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
42  PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
43 
44  template <typename T>
45  std::function<bool(std::shared_ptr<Node>)> has_class()
46  {
47  auto pred = [](std::shared_ptr<Node> node) -> bool { return is_type<T>(node); };
48 
49  return pred;
50  }
51 
52  NGRAPH_API
53  std::function<bool(Output<Node>)> consumers_count(size_t n);
54 
55  NGRAPH_API
56  std::function<bool(Output<Node>)> has_static_dim(size_t pos);
57 
58  NGRAPH_API
59  std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
60 
61  NGRAPH_API
62  std::function<bool(Output<Node>)> has_static_shape();
63 
64  NGRAPH_API
65  std::function<bool(Output<Node>)> type_matches(const element::Type& type);
66 
67  NGRAPH_API
68  std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
69 
70  namespace op
71  {
72  using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
73  using ValuePredicate = std::function<bool(const Output<Node>& value)>;
74 
75  NGRAPH_API
76  ValuePredicate as_value_predicate(NodePredicate pred);
77 
78  class NGRAPH_API Pattern : public Node
79  {
80  public:
81  /// \brief \p a base class for \sa Skip and \sa Label
82  ///
83  Pattern(const OutputVector& patterns, ValuePredicate pred)
84  : Node(patterns)
85  , m_predicate(pred)
86  {
87  if (!m_predicate)
88  {
89  m_predicate = [](const Output<Node>&) { return true; };
90  }
91  }
92 
93  Pattern(const OutputVector& patterns)
94  : Pattern(patterns, nullptr)
95  {
96  }
97 
98  virtual std::shared_ptr<Node>
99  clone_with_new_inputs(const OutputVector& /* new_args */) const override
100  {
101  throw ngraph_error("Uncopyable");
102  }
103 
104  ValuePredicate get_predicate() const;
105 
106  protected:
107  ValuePredicate m_predicate;
108  };
109  }
110  }
111 }
ngraph::pattern::op::Pattern::Pattern
Pattern(const OutputVector &patterns, ValuePredicate pred)
A base class for the Skip and Label pattern nodes.
Definition: pattern.hpp:83
ngraph::pattern::op::Pattern
Definition: pattern.hpp:79
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28