graph_rewrite.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <deque>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
13 
14 namespace ngraph
15 {
16  using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
17  using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
18  using recurrent_graph_rewrite_callback =
19  std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
20  using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
21  namespace pass
22  {
23  /// \brief MatcherPass is a basic block for pattern based transformations. It describes
24  /// pattern and
25  /// action that is applied if pattern is matched.
26  ///
27  /// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
28  /// and
29  /// finally registered by using \sa register_matcher. MatcherPass can be executed on node
30  /// within
31  /// \sa apply method. To run matcher pass on Function use GraphRewrite.
32  /// In addition MatcherPass provides a way for adding new operations into GraphRewrite
33  /// execution
34  /// queue. That means that operations that were created inside transformation callback can
35  /// be added
36  /// for matching. To register node use \sa register_new_node method. GraphRewrite
37  /// automatically
38  /// takes registered nodes and put them to execution queue. If multiple nodes were register
39  /// make
40  /// sure that they were registered in topological order.
41  /// Note: when implementing pattern for Matcher make sure that root node is an operation
42  /// from opset
43  /// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
44  /// passes more
45  /// efficient.
46 
47  class NGRAPH_API MatcherPass : public ngraph::pass::PassBase
48  {
49  public:
50  NGRAPH_RTTI_DECLARATION;
51 
52  MatcherPass() = default;
53 
54  MatcherPass(const MatcherPass&) = delete;
55  MatcherPass& operator=(const MatcherPass&) = delete;
56 
57  explicit MatcherPass(
58  const std::string& name,
59  const std::shared_ptr<pattern::Matcher>& m,
60  const handler_callback& handler,
61  const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
62  : PassBase()
63  , m_handler(handler)
64  , m_matcher(m)
65  {
66  set_name(name);
67  set_property(property, true);
68  }
69 
70  bool apply(std::shared_ptr<ngraph::Node> node);
71 
72  template <typename T, class... Args>
73  std::shared_ptr<T> register_new_node(Args&&... args)
74  {
75  auto node = std::make_shared<T>(std::forward<Args>(args)...);
76  m_new_nodes.push_back(node);
77  return node;
78  }
79 
80  const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes()
81  {
82  return m_new_nodes;
83  }
84  void clear_new_nodes() { m_new_nodes.clear(); }
85  std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
86 
87  protected:
88  void register_matcher(
89  const std::shared_ptr<pattern::Matcher>& m,
90  const ngraph::graph_rewrite_callback& callback,
91  const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
92 
93  private:
94  handler_callback m_handler;
95  std::shared_ptr<pattern::Matcher> m_matcher;
96  std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
97  };
98 
99  /// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
100  /// in
101  /// efficient way
102  ///
103  /// Graph rewrite pass is used for matcher passes execution on Function.
104  /// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
105  /// class.
106  /// As a default algorithm graph rewrite pass traverse Function in topological order and
107  /// applies
108  /// registered matcher passes for each node. But if all registered matcher passes have type
109  /// based
110  /// root node in Matcher pattern then efficient mechanism is used to execute them.
111  /// Matcher pattern root is type based if it's operation from opset or
112  /// pattern::op::WrapType.
113  /// Note: when implementing pattern for Matcher make sure that root node is an operation
114  /// from opset
115  /// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
116  /// passes more
117  /// efficient.
118 
119  class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass
120  {
121  public:
122  NGRAPH_RTTI_DECLARATION;
123 
124  GraphRewrite() = default;
125 
126  explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
127  : FunctionPass()
128  {
129  m_matchers.push_back(pass);
130  }
131 
132  /// \brief Register given transformation class type to GraphRewrite execution list
133  /// All registered transformations will be executed in a single graph traversal.
134  /// Example below show the basic usage of pass::GraphRewrite
135  ///
136  /// pass::Manager manager;
137  /// auto anchor = manager.register_pass<GraphRewrite>();
138  /// anchor->add_matcher<MatcherPassA>();
139  /// anchor->add_matcher<MatcherPassB>();
140  /// anchor->set_name("CommonMatchers");
141  /// manager.run_passes(f);
142  ///
143  /// For some purposes transformation can be registered and disabled by default.
144  ///
145  /// anchor->add_matcher<MatcherPassB, false>();
146  ///
147  /// \return shared_ptr to the transformation instance
148  template <typename T,
149  bool Enabled = true,
150  class... Args,
151  typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value,
152  bool>::type = true>
153  std::shared_ptr<T> add_matcher(Args&&... args)
154  {
155  static_assert(std::is_base_of<pass::MatcherPass, T>::value,
156  "pass not derived from MatcherPass");
157  auto pass = std::make_shared<T>(std::forward<Args>(args)...);
158  auto pass_config = get_pass_config();
159  pass->set_pass_config(pass_config);
160  if (!Enabled && !pass_config->is_enabled<T>())
161  {
162  pass_config->disable<T>();
163  }
164  m_matchers.push_back(pass);
165  return pass;
166  }
167 
168  /// \brief Register passes from GraphRewrite class that contains sequence of matcher
169  /// passes registered in its ctor.
170  /// For example:
171  ///
172  /// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
173  /// public:
174  /// NGRAPH_RTTI_DECLARATION;
175  /// Fusions() {
176  /// add_matcher<ngraph::pass::AddFusion>();
177  /// add_matcher<ngraph::pass::MulFusion>();
178  /// }
179  /// };
180  ///
181  /// pass::Manager manager;
182  /// auto anchor = manager.register_pass<GraphRewrite>();
183  /// anchor->add_matcher<LinFusions>();
184  /// anchor->add_matcher<OtherFusions>();
185  /// anchor->set_name("CommonFusions");
186  /// manager.run_passes(f);
187  ///
188  /// In this case all matcher passes from LinFusions pass will be united with other
189  /// registered matchers.
190  template <typename T,
191  class... Args,
192  typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value,
193  bool>::type = true>
194  void add_matcher(Args&&... args)
195  {
196  static_assert(std::is_base_of<pass::GraphRewrite, T>::value,
197  "pass not derived from GraphRewrite");
198  auto pass = std::make_shared<T>(std::forward<Args>(args)...);
199  auto pass_config = get_pass_config();
200 
201  for (auto& matcher : pass->m_matchers)
202  {
203  pass->set_pass_config(pass_config);
204  m_matchers.push_back(matcher);
205  }
206  }
207 
208  NGRAPH_DEPRECATED("Use MatcherPass instead")
209  void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
210  const ngraph::graph_rewrite_callback& callback,
211  const PassPropertyMask& property);
212 
213  NGRAPH_DEPRECATED("Use MatcherPass instead")
214  void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
215  const ngraph::graph_rewrite_callback& callback);
216 
217  bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
218 
219  void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
220 
221  protected:
222  bool apply_matcher_passes(std::shared_ptr<Function> f,
223  std::deque<std::shared_ptr<Node>> nodes_to_run);
224 
225  bool m_enable_shape_inference = false;
226 
227  std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
228  };
229 
230  class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite
231  {
232  public:
233  NGRAPH_RTTI_DECLARATION;
234 
235  BackwardGraphRewrite() = default;
236 
237  explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass)
238  : GraphRewrite(pass)
239  {
240  }
241 
242  bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
243  };
244 
246  {
247  public:
248  RecurrentGraphRewrite(size_t num_iters = 10)
249  : FunctionPass()
250  , m_num_iters(num_iters)
251  {
252  }
253 
254  void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
255  const ngraph::recurrent_graph_rewrite_callback& callback,
256  const PassPropertyMask& property);
257 
258  // TODO: This interface may deprecate after all passes are refactored.
259  void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
260  const ngraph::recurrent_graph_rewrite_callback& callback);
261 
262  virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
263 
264  private:
265  size_t m_num_iters;
266 
267  std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
268  };
269  } // namespace pass
270 } // namespace ngraph
A user-defined function.
Definition: function.hpp:27
Definition: node.hpp:127
Definition: graph_rewrite.hpp:231
Definition: pass.hpp:94
GraphRewrite is a container for MatcherPasses that allows to run them on Function in efficient way.
Definition: graph_rewrite.hpp:120
void add_matcher(Args &&... args)
Register passes from GraphRewrite class that contains sequence of matcher passes registered in its ct...
Definition: graph_rewrite.hpp:194
std::shared_ptr< T > add_matcher(Args &&... args)
Register given transformation class type to GraphRewrite execution list All registered transformation...
Definition: graph_rewrite.hpp:153
MatcherPass is a basic block for pattern based transformations. It describes pattern and action that ...
Definition: graph_rewrite.hpp:48
Definition: pass.hpp:33
Class representing a transformations config that is used for disabling/enabling transformations regis...
Definition: pass_config.hpp:59
Definition: graph_rewrite.hpp:246
Definition: matcher.hpp:63
Definition: matcher.hpp:189
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16