pattern.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "ngraph/node.hpp"
22 
23 namespace ngraph
24 {
25  namespace pattern
26  {
27  namespace op
28  {
29  class Label;
30  }
31 
32  class Matcher;
33  class MatchState;
34 
35  using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
36  using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
37  using PatternValueMaps = std::vector<PatternValueMap>;
38 
39  using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
40 
41  PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
42  PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
43 
44  template <typename T>
45  std::function<bool(std::shared_ptr<Node>)> has_class()
46  {
47  auto pred = [](std::shared_ptr<Node> node) -> bool { return is_type<T>(node); };
48 
49  return pred;
50  }
51 
52  NGRAPH_API
53  std::function<bool(Output<Node>)> consumers_count(size_t n);
54 
55  NGRAPH_API
56  std::function<bool(Output<Node>)> has_static_dim(size_t pos);
57 
58  NGRAPH_API
59  std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
60 
61  NGRAPH_API
62  std::function<bool(Output<Node>)> has_static_shape();
63 
64  NGRAPH_API
65  std::function<bool(Output<Node>)> has_static_rank();
66 
67  NGRAPH_API
68  std::function<bool(Output<Node>)> type_matches(const element::Type& type);
69 
70  NGRAPH_API
71  std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
72 
73  namespace op
74  {
75  using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
76  using ValuePredicate = std::function<bool(const Output<Node>& value)>;
77 
78  NGRAPH_API
79  ValuePredicate as_value_predicate(NodePredicate pred);
80 
81  class NGRAPH_API Pattern : public Node
82  {
83  public:
84  /// \brief \p a base class for \sa Skip and \sa Label
85  ///
86  Pattern(const OutputVector& patterns, ValuePredicate pred)
87  : Node(patterns)
88  , m_predicate(pred)
89  {
90  if (!m_predicate)
91  {
92  m_predicate = [](const Output<Node>&) { return true; };
93  }
94  }
95 
96  Pattern(const OutputVector& patterns)
97  : Pattern(patterns, nullptr)
98  {
99  }
100 
101  virtual std::shared_ptr<Node>
102  clone_with_new_inputs(const OutputVector& /* new_args */) const override
103  {
104  throw ngraph_error("Uncopyable");
105  }
106 
107  ValuePredicate get_predicate() const;
108 
109  protected:
110  ValuePredicate m_predicate;
111  };
112  }
113  }
114 }
Definition: node.hpp:132
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Base error for ngraph runtime errors.
Definition: except.hpp:28
Definition: pattern.hpp:82
Pattern(const OutputVector &patterns, ValuePredicate pred)
A base class for Skip and Label.
Definition: pattern.hpp:86
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28