matcher.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <algorithm>
20 #include <functional>
21 #include <memory.h>
22 
23 #include "ngraph/node.hpp"
24 #include "ngraph/op/constant.hpp"
25 #include "ngraph/pattern/op/any.hpp"
26 #include "ngraph/pattern/op/any_of.hpp"
27 #include "ngraph/pattern/op/any_output.hpp"
28 #include "ngraph/pattern/op/label.hpp"
29 #include "ngraph/pattern/op/skip.hpp"
30 
31 namespace ngraph
32 {
33  namespace pass
34  {
35  class GraphRewrite;
36  }
37 
38  namespace pattern
39  {
40  class Matcher;
41 
42  class NGRAPH_API MatcherState
43  {
44  public:
46  bool finish(bool is_successful);
47  ~MatcherState();
48 
49  protected:
50  Matcher* m_matcher;
51  PatternValueMap m_pattern_value_map;
52  PatternValueMaps m_pattern_value_maps;
53  size_t m_watermark;
54  size_t m_capture_size;
55  bool m_restore{true};
56  };
57 
58  /// Matcher looks for node patterns in a computation graph. The patterns are described by an
59  /// automaton that is described by an extended computation graph. The matcher executes
60  /// by attempting to match the start node of the pattern to a computation graph value
61  /// (output of a Node). In addition to determing if a match occurs, a pattern node may add
62  /// graph nodes to a list of matched nodes, associate nodes with graph values, and start
63  /// submatches. Submatches add match state changes to the enclosing match if the submatch
64  /// succeeds; otherwise the state is reverted.
65  ///
66  /// The default match behavior of a pattern node with a graph nodes is that the computation
67  /// graph value is added to the end of the matched value list and the match succeeds if the
68  /// node/pattern types match and the input values match. In the case of a commutative node,
69  /// the inputs can match in any order. If the matcher is in strict mode, the graph value
70  /// element type and shape must also match.
71  ///
72  /// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
73  /// descriptions of their match behavior.
74  class NGRAPH_API Matcher
75  {
76  public:
77  using PatternMap = ngraph::pattern::PatternMap;
78 
79  // Avoid implicit string construction from nullptr.
80  Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
81 
82  Matcher() {}
83  Matcher(Output<Node>& pattern_node)
84  : m_pattern_node{pattern_node}
85  {
86  }
87 
88  Matcher(Output<Node>& pattern_node, const std::string& name)
89  : m_pattern_node(pattern_node)
90  , m_name{name}
91  {
92  }
93 
94  /// \brief Constructs a Matcher object
95  ///
96  /// \param pattern_node is a pattern sub graph that will be matched against input graphs
97  /// \param name is a string which is used for logging and disabling a matcher
98  /// \param strict_mode forces a matcher to consider shapes and ET of nodes
99  Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
100  : m_pattern_node(pattern_node)
101  , m_name(name)
102  , m_strict_mode(strict_mode)
103  {
104  }
105 
106  // Some matches should start on a node rather than an output. These three constructors
107  // are transition until we work out the right way to do that.
108  Matcher(std::shared_ptr<Node> pattern_node);
109  Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
110  Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
111 
112  virtual ~Matcher() {}
113  /// \brief Matches a pattern to \p graph_node
114  ///
115  /// \param graph_value is an input graph to be matched against
116  bool match(const Output<Node>& graph_value);
117 
118  bool match(std::shared_ptr<Node> graph_node);
119 
120  /// \brief Matches a pattern to \p graph_node
121  ///
122  /// \param graph_value is an input graph to be matched against
123  /// \param previous_matches contains previous mappings from labels to nodes to use
124  bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
125  bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
126 
127  template <typename T>
128  static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
129  {
130  std::shared_ptr<T> matched;
131  for (auto arg : node->input_values())
132  {
133  if (auto t_casted = as_type_ptr<T>(arg.get_node_shared_ptr()))
134  {
135  if (matched)
136  {
137  throw ngraph_error("There's more than two arguments of the same type");
138  }
139  else
140  {
141  matched = t_casted;
142  }
143  }
144  }
145  return matched;
146  }
147 
148  bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
149  const NodeVector get_matched_nodes() { return as_node_vector(m_matched_list); }
150  const OutputVector& get_matched_values() const { return m_matched_list; }
151  OutputVector& get_matched_values() { return m_matched_list; }
152  void reset() {}
153  const std::string& get_name() { return m_name; }
154  std::shared_ptr<Node> get_pattern() { return m_pattern_node.get_node_shared_ptr(); }
155  Output<Node> get_pattern_value() { return m_pattern_node; }
156  std::shared_ptr<Node> get_match_root();
157  Output<Node> get_match_value();
158  PatternMap get_pattern_map() const;
159  PatternValueMap& get_pattern_value_map() { return m_pattern_map; }
160  PatternValueMaps& get_pattern_value_maps() { return m_pattern_value_maps; }
161  /// \brief Low-level helper to match recurring patterns
162  ///
163  /// \param graph is a graph to be matched against
164  /// \param pattern is a recurring pattern
165  /// \param rpattern specifies a node to recur from next
166  /// \param patterns a map from labels to matches
167 
168  size_t add_node(Output<Node> node);
169 
170  bool virtual match_value(const ngraph::Output<Node>& pattern_value,
171  const ngraph::Output<Node>& graph_value);
172 
173  bool is_strict_mode() { return m_strict_mode; }
174  virtual bool match_arguments(Node* pattern_node,
175  const std::shared_ptr<Node>& graph_node);
176 
177  void capture(const std::set<Node*>& static_nodes);
178 
179  void clear_state();
180 
181  size_t get_number_of_recurrent_matches() const { return m_pattern_value_maps.size(); }
182  NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
183  size_t get_number_of_bound_labels() const;
184  /// \brief Try a match
186 
187  Output<Node> m_match_root;
188  Output<Node> m_pattern_node;
189  PatternValueMap m_pattern_map;
190  PatternValueMaps m_pattern_value_maps;
191  OutputVector m_matched_list;
192 
193  protected:
194  bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
195 
196  std::string m_name{"unnamed"};
197  bool m_strict_mode{false};
198  };
199 
200  class NGRAPH_API RecurrentMatcher
201  {
202  public:
203  /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
204  /// repeating patterns (e.g. RNN, LSTM, GRU cells)
205  ///
206  /// \param initial_pattern is a pattern sub graph describing the initial cell
207  /// \param pattern is a pattern sub graph describing an individual cell
208  /// \param rpattern is a (recurring) label to denote which node the next match should
209  /// start at
210  /// \param correlated_patterns is a set of labels whose bound nodes must remain the same
211  /// across all cells
212  RecurrentMatcher(const Output<Node>& initial_pattern,
213  const Output<Node>& pattern,
214  const std::shared_ptr<Node>& rpattern,
215  const std::set<std::shared_ptr<Node>>& correlated_patterns)
216  : m_initial_pattern(initial_pattern)
217  , m_pattern(pattern)
218  , m_recurrent_pattern(rpattern)
219  , m_correlated_patterns(correlated_patterns)
220  {
221  }
222 
223  /// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
224  /// repeating patterns (e.g. RNN, LSTM, GRU cells)
225  ///
226  /// \param pattern is a pattern sub graph describing an individual cell
227  /// \param rpattern is a (recurring) label to denote which node the next match should
228  /// start at
229  /// \param correlated_patterns is a set of labels whose bound nodes must remain the same
230  /// across all cells
231  RecurrentMatcher(const Output<Node>& pattern,
232  const std::shared_ptr<Node>& rpattern,
233  const std::set<std::shared_ptr<Node>>& correlated_patterns)
234  : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns)
235  {
236  }
237 
238  RecurrentMatcher(const Output<Node>& initial_pattern,
239  const Output<Node>& pattern,
240  const std::shared_ptr<Node>& rpattern,
241  const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
242 
243  RecurrentMatcher(const Output<Node>& pattern,
244  const std::shared_ptr<Node>& rpattern,
245  const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
246  : RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns)
247  {
248  }
249 
250  /// \brief Returns a vector of bound nodes for a given label (used in a pattern
251  /// describing an individual cell
252  NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const
253  {
254  if (m_matches.count(pattern) == 0)
255  {
256  throw ngraph_error("No bound nodes for a given label");
257  }
258 
259  return as_node_vector(m_matches.at(pattern));
260  }
261 
262  size_t get_number_of_recurrent_matches() const
263  {
264  if (m_matches.size() == 0)
265  {
266  return 0;
267  }
268 
269  return (*m_matches.begin()).second.size();
270  }
271 
272  size_t get_number_of_bound_labels() const { return m_matches.size(); }
273  /// \brief Tries to match a pattern for an individual cell to a given \p graph
274  bool match(Output<Node> graph);
275 
276  std::shared_ptr<Node> get_match_root() { return m_match_root.get_node_shared_ptr(); }
277  Output<Node> get_match_value() { return m_match_root; }
278  private:
279  Output<Node> m_initial_pattern;
280  Output<Node> m_pattern;
281  std::shared_ptr<Node> m_recurrent_pattern;
282  const std::set<std::shared_ptr<Node>> m_correlated_patterns;
283  RPatternValueMap m_matches;
284  Output<Node> m_match_root;
285  };
286  }
287 }
ngraph::pattern::RecurrentMatcher::get_bound_nodes_for_pattern
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr< Node > &pattern) const
Returns a vector of bound nodes for a given label (used in a pattern describing an individual cell).
Definition: matcher.hpp:252
ngraph::pattern::RecurrentMatcher
Definition: matcher.hpp:201
ngraph::pattern::MatcherState
Definition: matcher.hpp:43
ngraph::pattern::Matcher::match
bool match(const Output< Node > &graph_value, const PatternMap &previous_matches)
Matches a pattern to graph_value.
ngraph::pattern::Matcher
Definition: matcher.hpp:75
ngraph::pattern::RecurrentMatcher::match
bool match(Output< Node > graph)
Tries to match a pattern for an individual cell to a given graph.
ngraph::pattern::Matcher::add_node
size_t add_node(Output< Node > node)
Low-level helper to match recurring patterns.
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::pattern::Matcher::Matcher
Matcher(const Output< Node > &pattern_node, const std::string &name, bool strict_mode)
Constructs a Matcher object.
Definition: matcher.hpp:99
ngraph::pattern::RecurrentMatcher::RecurrentMatcher
RecurrentMatcher(const Output< Node > &pattern, const std::shared_ptr< Node > &rpattern, const std::set< std::shared_ptr< Node >> &correlated_patterns)
Constructs a RecurrentMatcher object. Recurrent Matchers are used to match repeating patterns (e....
Definition: matcher.hpp:231
ngraph::Output
Definition: node_output.hpp:35
ngraph::pattern::Matcher::start_match
MatcherState start_match()
Try a match.
ngraph::pattern::RecurrentMatcher::RecurrentMatcher
RecurrentMatcher(const Output< Node > &initial_pattern, const Output< Node > &pattern, const std::shared_ptr< Node > &rpattern, const std::set< std::shared_ptr< Node >> &correlated_patterns)
Constructs a RecurrentMatcher object. Recurrent Matchers are used to match repeating patterns (e....
Definition: matcher.hpp:212
ngraph::pattern::Matcher::match
bool match(const Output< Node > &graph_value)
Matches a pattern to graph_value.