pattern.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "ngraph/node.hpp"
10 
11 namespace ngraph
12 {
13  namespace pattern
14  {
15  namespace op
16  {
17  class Label;
18  }
19 
20  class Matcher;
21  class MatchState;
22 
23  using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
24  using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
25  using PatternValueMaps = std::vector<PatternValueMap>;
26 
27  using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
28 
29  PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
30  PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
31 
32  template <typename T>
33  std::function<bool(std::shared_ptr<Node>)> has_class()
34  {
35  auto pred = [](std::shared_ptr<Node> node) -> bool { return is_type<T>(node); };
36 
37  return pred;
38  }
39 
40  NGRAPH_API
41  std::function<bool(Output<Node>)> consumers_count(size_t n);
42 
43  NGRAPH_API
44  std::function<bool(Output<Node>)> has_static_dim(size_t pos);
45 
46  NGRAPH_API
47  std::function<bool(Output<Node>)> has_static_dims(const std::vector<size_t>& dims);
48 
49  NGRAPH_API
50  std::function<bool(Output<Node>)> has_static_shape();
51 
52  NGRAPH_API
53  std::function<bool(Output<Node>)> has_static_rank();
54 
55  NGRAPH_API
56  std::function<bool(Output<Node>)> rank_equals(const Dimension& expected_rank);
57 
58  NGRAPH_API
59  std::function<bool(Output<Node>)> type_matches(const element::Type& type);
60 
61  NGRAPH_API
62  std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
63 
64  namespace op
65  {
66  using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
67  using ValuePredicate = std::function<bool(const Output<Node>& value)>;
68 
69  NGRAPH_API
70  ValuePredicate as_value_predicate(NodePredicate pred);
71 
72  class NGRAPH_API Pattern : public Node
73  {
74  public:
75  /// \brief \p a base class for \sa Skip and \sa Label
76  ///
77  Pattern(const OutputVector& patterns, ValuePredicate pred)
78  : Node(patterns)
79  , m_predicate(pred)
80  {
81  if (!m_predicate)
82  {
83  m_predicate = [](const Output<Node>&) { return true; };
84  }
85  }
86 
87  Pattern(const OutputVector& patterns)
88  : Pattern(patterns, nullptr)
89  {
90  }
91 
92  virtual std::shared_ptr<Node>
93  clone_with_new_inputs(const OutputVector& /* new_args */) const override
94  {
95  throw ngraph_error("Uncopyable");
96  }
97 
98  ValuePredicate get_predicate() const;
99 
100  protected:
101  ValuePredicate m_predicate;
102  };
103  } // namespace op
104  } // namespace pattern
105 } // namespace ngraph
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base error for ngraph runtime errors.
Definition: except.hpp:16
Definition: pattern.hpp:73
Pattern(const OutputVector &patterns, ValuePredicate pred)
A base class for the Skip and Label pattern nodes.
Definition: pattern.hpp:77
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16