graph_rewrite.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <functional>
20 #include <memory>
21 #include <set>
22 
23 #include "ngraph/pass/pass.hpp"
24 #include "ngraph/pattern/matcher.hpp"
25 
26 namespace ngraph
27 {
28  using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
29  using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
30  using recurrent_graph_rewrite_callback =
31  std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
32  using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
33  namespace pass
34  {
/// \brief MatcherPass is a basic building block for pattern-based transformations.
/// It describes a pattern and the action that is applied when the pattern is matched.
///
/// A MatcherPass consists of a Matcher and a matcher_pass_callback that must be
/// implemented and then registered with register_matcher(). A MatcherPass can be
/// executed on a single node with the apply() method. To run a matcher pass on a
/// whole Function, use GraphRewrite.
///
/// In addition, MatcherPass provides a way to add new operations to the GraphRewrite
/// execution queue, so that operations created inside the transformation callback can
/// themselves be matched. To register such a node, use the register_new_node() method.
/// GraphRewrite automatically takes the registered nodes and puts them into the
/// execution queue. If multiple nodes are registered, make sure that they are
/// registered in topological order.
///
/// Note: when implementing a pattern for a Matcher, make sure that the root node is an
/// operation from an opset or is an ngraph::pattern::op::WrapType. This helps
/// GraphRewrite execute matcher passes more efficiently.
58 
59  class NGRAPH_API MatcherPass : public ngraph::pass::PassBase
60  {
61  public:
62  NGRAPH_RTTI_DECLARATION;
63 
64  MatcherPass() = default;
65 
66  MatcherPass(const MatcherPass&) = delete;
67  MatcherPass& operator=(const MatcherPass&) = delete;
68 
69  explicit MatcherPass(
70  const std::string& name,
71  const std::shared_ptr<pattern::Matcher>& m,
72  const handler_callback& handler,
73  const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
74  : PassBase()
75  , m_handler(handler)
76  , m_matcher(m)
77  {
78  set_name(name);
79  set_property(property, true);
80  }
81 
82  bool apply(std::shared_ptr<ngraph::Node> node);
83 
84  template <typename T, class... Args>
85  std::shared_ptr<T> register_new_node(Args&&... args)
86  {
87  auto node = std::make_shared<T>(std::forward<Args>(args)...);
88  m_new_nodes.push_back(node);
89  return node;
90  }
91 
92  const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes()
93  {
94  return m_new_nodes;
95  }
96  void clear_new_nodes() { m_new_nodes.clear(); }
97  std::shared_ptr<pattern::Matcher> get_matcher() { return m_matcher; }
98  protected:
99  void register_matcher(
100  const std::shared_ptr<pattern::Matcher>& m,
101  const ngraph::graph_rewrite_callback& callback,
102  const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
103 
104  private:
105  handler_callback m_handler;
106  std::shared_ptr<pattern::Matcher> m_matcher;
107  std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
108  };
109 
/// \brief GraphRewrite is a container for MatcherPasses that allows running them on a
/// Function in an efficient way.
///
/// The graph rewrite pass is used to execute matcher passes on a Function.
/// To register a MatcherPass, use the add_matcher<T>(args) method, where T is a
/// MatcherPass class.
/// By default the graph rewrite pass traverses the Function in topological order and
/// applies the registered matcher passes to each node. However, if all registered
/// matcher passes have a type-based root node in their Matcher pattern, a more
/// efficient mechanism is used to execute them.
/// A Matcher pattern root is type-based if it is an operation from an opset or a
/// pattern::op::WrapType.
///
/// Note: when implementing a pattern for a Matcher, make sure that the root node is an
/// operation from an opset or is an ngraph::pattern::op::WrapType. This helps
/// GraphRewrite execute matcher passes more efficiently.
129 
130  class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass
131  {
132  public:
133  NGRAPH_RTTI_DECLARATION;
134 
135  GraphRewrite() = default;
136 
137  explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass)
138  : FunctionPass()
139  {
140  m_matchers.push_back(pass);
141  }
142 
143  /// \brief Register given transformation class type to GraphRewrite execution list
144  /// All registered transformations will be executed in a single graph traversal.
145  /// Example below show the basic usage of pass::GraphRewrite
146  ///
147  /// pass::Manager manager;
148  /// auto anchor = manager.register_pass<GraphRewrite>();
149  /// anchor->add_matcher<MatcherPassA>();
150  /// anchor->add_matcher<MatcherPassB>();
151  /// anchor->set_name("CommonMatchers");
152  /// manager.run_passes(f);
153  ///
154  /// For some purposes transformation can be registered and disabled by default.
155  ///
156  /// anchor->add_matcher<MatcherPassB, false>();
157  ///
158  /// \return shared_ptr to the transformation instance
159  template <typename T,
160  bool Enabled = true,
161  class... Args,
162  typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value,
163  bool>::type = true>
164  std::shared_ptr<T> add_matcher(Args&&... args)
165  {
166  static_assert(std::is_base_of<pass::MatcherPass, T>::value,
167  "pass not derived from MatcherPass");
168  auto pass = std::make_shared<T>(std::forward<Args>(args)...);
169  auto pass_config = get_pass_config();
170  pass->set_pass_config(pass_config);
171  if (!Enabled && !pass_config->is_enabled<T>())
172  {
173  pass_config->disable<T>();
174  }
175  m_matchers.push_back(pass);
176  return pass;
177  }
178 
179  /// \brief Register passes from GraphRewrite class that contains sequence of matcher
180  /// passes registered in its ctor.
181  /// For example:
182  ///
183  /// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
184  /// public:
185  /// NGRAPH_RTTI_DECLARATION;
186  /// Fusions() {
187  /// add_matcher<ngraph::pass::AddFusion>();
188  /// add_matcher<ngraph::pass::MulFusion>();
189  /// }
190  /// };
191  ///
192  /// pass::Manager manager;
193  /// auto anchor = manager.register_pass<GraphRewrite>();
194  /// anchor->add_matcher<LinFusions>();
195  /// anchor->add_matcher<OtherFusions>();
196  /// anchor->set_name("CommonFusions");
197  /// manager.run_passes(f);
198  ///
199  /// In this case all matcher passes from LinFusions pass will be united with other
200  /// registered matchers.
201  template <typename T,
202  class... Args,
203  typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value,
204  bool>::type = true>
205  void add_matcher(Args&&... args)
206  {
207  static_assert(std::is_base_of<pass::GraphRewrite, T>::value,
208  "pass not derived from GraphRewrite");
209  auto pass = std::make_shared<T>(std::forward<Args>(args)...);
210  auto pass_config = get_pass_config();
211 
212  for (auto& matcher : pass->m_matchers)
213  {
214  pass->set_pass_config(pass_config);
215  m_matchers.push_back(matcher);
216  }
217  }
218 
219  NGRAPH_DEPRECATED("Use MatcherPass instead")
220  void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
221  const ngraph::graph_rewrite_callback& callback,
222  const PassPropertyMask& property);
223 
224  NGRAPH_DEPRECATED("Use MatcherPass instead")
225  void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
226  const ngraph::graph_rewrite_callback& callback);
227 
228  bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
229 
230  void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
231 
232  protected:
233  bool m_enable_shape_inference = false;
234 
235  std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
236  };
237 
238  class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass
239  {
240  public:
241  RecurrentGraphRewrite(size_t num_iters = 10)
242  : FunctionPass()
243  , m_num_iters(num_iters)
244  {
245  }
246 
247  void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
248  const ngraph::recurrent_graph_rewrite_callback& callback,
249  const PassPropertyMask& property);
250 
251  // TODO: This interface may deprecate after all passes are refactored.
252  void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
253  const ngraph::recurrent_graph_rewrite_callback& callback);
254 
255  virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
256 
257  private:
258  size_t m_num_iters;
259 
260  std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
261  };
262  } // namespace pass
263 } // namespace ngraph
A user-defined function.
Definition: function.hpp:36
Definition: pass.hpp:106
GraphRewrite is a container for MatcherPasses that allows to run them on Function in efficient way.
Definition: graph_rewrite.hpp:131
void add_matcher(Args &&... args)
Register passes from GraphRewrite class that contains sequence of matcher passes registered in its ct...
Definition: graph_rewrite.hpp:205
std::shared_ptr< T > add_matcher(Args &&... args)
Register given transformation class type to GraphRewrite execution list All registered transformation...
Definition: graph_rewrite.hpp:164
MatcherPass is a basic block for pattern based transformations. It describes pattern and action that ...
Definition: graph_rewrite.hpp:60
Definition: pass.hpp:45
Class representing a transformations config that is used for disabling/enabling transformations regis...
Definition: pass_config.hpp:71
Definition: graph_rewrite.hpp:239
Definition: matcher.hpp:75
Definition: matcher.hpp:201
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28