matcher.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <algorithm>
8 #include <functional>
9 #include <memory.h>
10 
11 #include "ngraph/node.hpp"
12 #include "ngraph/op/constant.hpp"
13 #include "ngraph/pattern/op/any.hpp"
14 #include "ngraph/pattern/op/any_of.hpp"
15 #include "ngraph/pattern/op/any_output.hpp"
16 #include "ngraph/pattern/op/label.hpp"
17 #include "ngraph/pattern/op/skip.hpp"
18 
19 namespace ngraph
20 {
21  namespace pass
22  {
23  class GraphRewrite;
24  }
25 
26  namespace pattern
27  {
28  class Matcher;
29 
30  class NGRAPH_API MatcherState
31  {
32  public:
34  bool finish(bool is_successful);
35  ~MatcherState();
36 
37  protected:
38  Matcher* m_matcher;
39  PatternValueMap m_pattern_value_map;
40  PatternValueMaps m_pattern_value_maps;
41  size_t m_watermark;
42  size_t m_capture_size;
43  bool m_restore{true};
44  };
45 
46  /// Matcher looks for node patterns in a computation graph. The patterns are described by an
47  /// automaton that is described by an extended computation graph. The matcher executes
48  /// by attempting to match the start node of the pattern to a computation graph value
49  /// (output of a Node). In addition to determining if a match occurs, a pattern node may add
50  /// graph nodes to a list of matched nodes, associate nodes with graph values, and start
51  /// submatches. Submatches add match state changes to the enclosing match if the submatch
52  /// succeeds; otherwise the state is reverted.
53  ///
54  /// The default match behavior of a pattern node with a graph node is that the computation
55  /// graph value is added to the end of the matched value list and the match succeeds if the
56  /// node/pattern types match and the input values match. In the case of a commutative node,
57  /// the inputs can match in any order. If the matcher is in strict mode, the graph value
58  /// element type and shape must also match.
59  ///
60  /// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
61  /// descriptions of their match behavior.
62  class NGRAPH_API Matcher
63  {
64  public:
65  using PatternMap = ngraph::pattern::PatternMap;
66 
67  // Avoid implicit string construction from nullptr.
68  Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
69 
70  Matcher() {}
71  Matcher(Output<Node>& pattern_node)
72  : m_pattern_node{pattern_node}
73  {
74  }
75 
76  Matcher(Output<Node>& pattern_node, const std::string& name)
77  : m_pattern_node(pattern_node)
78  , m_name{name}
79  {
80  }
81 
82  /// \brief Constructs a Matcher object
83  ///
84  /// \param pattern_node is a pattern sub graph that will be matched against input graphs
85  /// \param name is a string which is used for logging and disabling a matcher
86  /// \param strict_mode forces a matcher to consider shapes and ET of nodes
87  Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
88  : m_pattern_node(pattern_node)
89  , m_name(name)
90  , m_strict_mode(strict_mode)
91  {
92  }
93 
94  // Some matches should start on a node rather than an output. These three constructors
95  // are transition until we work out the right way to do that.
96  Matcher(std::shared_ptr<Node> pattern_node);
97  Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
98  Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
99 
100  virtual ~Matcher() {}
101  /// \brief Matches a pattern to \p graph_node
102  ///
103  /// \param graph_value is an input graph to be matched against
104  bool match(const Output<Node>& graph_value);
105 
106  bool match(std::shared_ptr<Node> graph_node);
107 
108  /// \brief Matches a pattern to \p graph_node
109  ///
110  /// \param graph_value is an input graph to be matched against
111  /// \param previous_matches contains previous mappings from labels to nodes to use
112  bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
113  bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
114 
115  template <typename T>
116  static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
117  {
118  std::shared_ptr<T> matched;
119  for (auto arg : node->input_values())
120  {
121  if (auto t_casted = as_type_ptr<T>(arg.get_node_shared_ptr()))
122  {
123  if (matched)
124  {
125  throw ngraph_error("There's more than two arguments of the same type");
126  }
127  else
128  {
129  matched = t_casted;
130  }
131  }
132  }
133  return matched;
134  }
135 
136  bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
137  const NodeVector get_matched_nodes() { return as_node_vector(m_matched_list); }
138  const OutputVector& get_matched_values() const { return m_matched_list; }
139  OutputVector& get_matched_values() { return m_matched_list; }
140  void reset() {}
141  const std::string& get_name() { return m_name; }
142  std::shared_ptr<Node> get_pattern() { return m_pattern_node.get_node_shared_ptr(); }
143  Output<Node> get_pattern_value() { return m_pattern_node; }
144  std::shared_ptr<Node> get_match_root();
145  Output<Node> get_match_value();
146  PatternMap get_pattern_map() const;
147  PatternValueMap& get_pattern_value_map() { return m_pattern_map; }
148  PatternValueMaps& get_pattern_value_maps() { return m_pattern_value_maps; }
149  /// \brief Low-level helper to match recurring patterns
150  ///
151  /// \param graph is a graph to be matched against
152  /// \param pattern is a recurring pattern
153  /// \param rpattern specifies a node to recur from next
154  /// \param patterns a map from labels to matches
155 
156  size_t add_node(Output<Node> node);
157 
158  bool virtual match_value(const ngraph::Output<Node>& pattern_value,
159  const ngraph::Output<Node>& graph_value);
160 
161  bool is_strict_mode() { return m_strict_mode; }
162  virtual bool match_arguments(Node* pattern_node,
163  const std::shared_ptr<Node>& graph_node);
164 
165  void capture(const std::set<Node*>& static_nodes);
166 
167  void clear_state();
168 
169  size_t get_number_of_recurrent_matches() const { return m_pattern_value_maps.size(); }
170  NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
171  size_t get_number_of_bound_labels() const;
172  /// \brief Try a match
174 
175  Output<Node> m_match_root;
176  Output<Node> m_pattern_node;
177  PatternValueMap m_pattern_map;
178  PatternValueMaps m_pattern_value_maps;
179  OutputVector m_matched_list;
180 
181  protected:
182  bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
183 
184  std::string m_name{"unnamed"};
185  bool m_strict_mode{false};
186  };
187 
188  class NGRAPH_API RecurrentMatcher
189  {
190  public:
191  /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
192  /// repeating patterns (e.g. RNN, LSTM, GRU cells)
193  ///
194  /// \param initial_pattern is a pattern sub graph describing the initial cell
195  /// \param pattern is a pattern sub graph describing an individual cell
196  /// \param rpattern is a (recurring) label to denote which node the next match should
197  /// start at
198  /// \param correlated_patterns is a set of labels whose bound nodes must remain the same
199  /// across all cells
200  RecurrentMatcher(const Output<Node>& initial_pattern,
201  const Output<Node>& pattern,
202  const std::shared_ptr<Node>& rpattern,
203  const std::set<std::shared_ptr<Node>>& correlated_patterns)
204  : m_initial_pattern(initial_pattern)
205  , m_pattern(pattern)
206  , m_recurrent_pattern(rpattern)
207  , m_correlated_patterns(correlated_patterns)
208  {
209  }
210 
211  /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
212  /// repeating patterns (e.g. RNN, LSTM, GRU cells)
213  ///
214  /// \param pattern is a pattern sub graph describing an individual cell
215  /// \param rpattern is a (recurring) label to denote which node the next match should
216  /// start at
217  /// \param correlated_patterns is a set of labels whose bound nodes must remain the same
218  /// across all cells
220  const std::shared_ptr<Node>& rpattern,
221  const std::set<std::shared_ptr<Node>>& correlated_patterns)
222  : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns)
223  {
224  }
225 
226  RecurrentMatcher(const Output<Node>& initial_pattern,
227  const Output<Node>& pattern,
228  const std::shared_ptr<Node>& rpattern,
229  const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
230 
231  RecurrentMatcher(const Output<Node>& pattern,
232  const std::shared_ptr<Node>& rpattern,
233  const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
234  : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns)
235  {
236  }
237 
238  /// \brief Returns a vector of bound nodes for a given label (used in a pattern
239  /// describing an individual cell
240  NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const
241  {
242  if (m_matches.count(pattern) == 0)
243  {
244  throw ngraph_error("No bound nodes for a given label");
245  }
246 
247  return as_node_vector(m_matches.at(pattern));
248  }
249 
250  size_t get_number_of_recurrent_matches() const
251  {
252  if (m_matches.size() == 0)
253  {
254  return 0;
255  }
256 
257  return (*m_matches.begin()).second.size();
258  }
259 
260  size_t get_number_of_bound_labels() const { return m_matches.size(); }
261  /// \brief Tries to match a pattern for an individual cell to a given \p graph
262  bool match(Output<Node> graph);
263 
264  std::shared_ptr<Node> get_match_root() { return m_match_root.get_node_shared_ptr(); }
265  Output<Node> get_match_value() { return m_match_root; }
266 
267  private:
268  Output<Node> m_initial_pattern;
269  Output<Node> m_pattern;
270  std::shared_ptr<Node> m_recurrent_pattern;
271  const std::set<std::shared_ptr<Node>> m_correlated_patterns;
272  RPatternValueMap m_matches;
273  Output<Node> m_match_root;
274  };
275  } // namespace pattern
276 } // namespace ngraph
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base error for ngraph runtime errors.
Definition: except.hpp:16
Definition: matcher.hpp:31
Definition: matcher.hpp:63
Matcher(const Output< Node > &pattern_node, const std::string &name, bool strict_mode)
Constructs a Matcher object.
Definition: matcher.hpp:87
bool match(const Output< Node > &graph_value, const PatternMap &previous_matches)
Matches a pattern to graph_node.
MatcherState start_match()
Try a match.
size_t add_node(Output< Node > node)
Low-level helper to match recurring patterns.
bool match(const Output< Node > &graph_value)
Matches a pattern to graph_node.
Definition: matcher.hpp:189
RecurrentMatcher(const Output< Node > &pattern, const std::shared_ptr< Node > &rpattern, const std::set< std::shared_ptr< Node >> &correlated_patterns)
Constructs a RecurrentMatcher object. Recurrent Matchers are used to match repeating patterns (e....
Definition: matcher.hpp:219
bool match(Output< Node > graph)
Tries to match a pattern for an individual cell to a given graph.
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr< Node > &pattern) const
Returns a vector of bound nodes for a given label (used in a pattern describing an individual cell).
Definition: matcher.hpp:240
RecurrentMatcher(const Output< Node > &initial_pattern, const Output< Node > &pattern, const std::shared_ptr< Node > &rpattern, const std::set< std::shared_ptr< Node >> &correlated_patterns)
Constructs a RecurrentMatcher object. Recurrent Matchers are used to match repeating patterns (e....
Definition: matcher.hpp:200
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16