23 #include "ngraph/pass/pass.hpp"
24 #include "ngraph/pattern/matcher.hpp"
30 using recurrent_graph_rewrite_callback =
32 using handler_callback = std::function<bool(
const std::shared_ptr<Node>& node)>;
62 NGRAPH_RTTI_DECLARATION;
70 const std::string& name,
71 const std::shared_ptr<pattern::Matcher>& m,
72 const handler_callback& handler,
79 set_property(property,
true);
82 bool apply(std::shared_ptr<ngraph::Node> node);
84 template <
typename T,
class... Args>
85 std::shared_ptr<T> register_new_node(Args&&... args)
87 auto node = std::make_shared<T>(std::forward<Args>(args)...);
88 m_new_nodes.push_back(node);
92 const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes()
96 void clear_new_nodes() { m_new_nodes.clear(); } // Empties m_new_nodes, the vector that register_new_node() appends to.
97 std::shared_ptr<pattern::Matcher> get_matcher() { // Accessor for the m_matcher member.
return m_matcher; } // Returned by value, i.e. as a copy of the stored shared_ptr.
99 void register_matcher(
100 const std::shared_ptr<pattern::Matcher>& m,
101 const ngraph::graph_rewrite_callback& callback,
105 handler_callback m_handler;
106 std::shared_ptr<pattern::Matcher> m_matcher;
107 std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
133 NGRAPH_RTTI_DECLARATION;
137 explicit GraphRewrite(
const std::shared_ptr<MatcherPass>& pass)
140 m_matchers.push_back(pass);
159 template <
typename T,
162 typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value,
166 static_assert(std::is_base_of<pass::MatcherPass, T>::value,
167 "pass not derived from MatcherPass");
168 auto pass = std::make_shared<T>(std::forward<Args>(args)...);
169 auto pass_config = get_pass_config();
170 pass->set_pass_config(pass_config);
171 if (!Enabled && !pass_config->is_enabled<T>())
173 pass_config->disable<T>();
175 m_matchers.push_back(pass);
201 template <
typename T,
203 typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value,
207 static_assert(std::is_base_of<pass::GraphRewrite, T>::value,
208 "pass not derived from GraphRewrite");
209 auto pass = std::make_shared<T>(std::forward<Args>(args)...);
210 auto pass_config = get_pass_config();
212 for (
auto& matcher : pass->m_matchers)
214 pass->set_pass_config(pass_config);
215 m_matchers.push_back(matcher);
219 NGRAPH_DEPRECATED(
"Use MatcherPass instead")
220 void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
221 const
ngraph::graph_rewrite_callback& callback,
225 void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
226 const
ngraph::graph_rewrite_callback& callback);
228 bool run_on_function(std::shared_ptr<
ngraph::
Function> f) override;
230 void set_pass_config(const std::shared_ptr<
PassConfig>& pass_config) override;
233 bool m_enable_shape_inference = false;
243 , m_num_iters(num_iters)
247 void add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
248 const ngraph::recurrent_graph_rewrite_callback& callback,
252 void add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
253 const ngraph::recurrent_graph_rewrite_callback& callback);
255 virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
260 std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
A user-defined function.
Definition: function.hpp:36
GraphRewrite is a container for MatcherPasses that allows running them on a Function in an efficient way.
Definition: graph_rewrite.hpp:131
void add_matcher(Args &&... args)
Register passes from a GraphRewrite class that contains a sequence of matcher passes registered in its ct...
Definition: graph_rewrite.hpp:205
std::shared_ptr< T > add_matcher(Args &&... args)
Register the given transformation class type in the GraphRewrite execution list. All registered transformation...
Definition: graph_rewrite.hpp:164
MatcherPass is a basic building block for pattern-based transformations. It describes a pattern and an action that ...
Definition: graph_rewrite.hpp:60
Class representing a transformation config that is used for disabling/enabling transformations regis...
Definition: pass_config.hpp:71
Definition: graph_rewrite.hpp:239
Definition: matcher.hpp:75
Definition: matcher.hpp:201
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28