11 #include "ngraph/pass/pass.hpp"
12 #include "ngraph/pattern/matcher.hpp"
18 using recurrent_graph_rewrite_callback =
20 using handler_callback = std::function<bool(
const std::shared_ptr<Node>& node)>;
50 NGRAPH_RTTI_DECLARATION;
58 const std::string& name,
59 const std::shared_ptr<pattern::Matcher>& m,
60 const handler_callback& handler,
67 set_property(property,
true);
70 bool apply(std::shared_ptr<ngraph::Node> node);
72 template <
typename T,
class... Args>
73 std::shared_ptr<T> register_new_node(Args&&... args)
75 auto node = std::make_shared<T>(std::forward<Args>(args)...);
76 m_new_nodes.push_back(node);
80 const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes()
84 void clear_new_nodes() { m_new_nodes.clear(); }
85 std::shared_ptr<pattern::Matcher> get_matcher() {
return m_matcher; }
88 void register_matcher(
89 const std::shared_ptr<pattern::Matcher>& m,
90 const ngraph::graph_rewrite_callback& callback,
94 handler_callback m_handler;
95 std::shared_ptr<pattern::Matcher> m_matcher;
96 std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
122 NGRAPH_RTTI_DECLARATION;
126 explicit GraphRewrite(
const std::shared_ptr<MatcherPass>& pass)
129 m_matchers.push_back(pass);
148 template <
typename T,
151 typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value,
155 static_assert(std::is_base_of<pass::MatcherPass, T>::value,
156 "pass not derived from MatcherPass");
157 auto pass = std::make_shared<T>(std::forward<Args>(args)...);
158 auto pass_config = get_pass_config();
159 pass->set_pass_config(pass_config);
160 if (!Enabled && !pass_config->is_enabled<T>())
162 pass_config->disable<T>();
164 m_matchers.push_back(pass);
190 template <
typename T,
192 typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value,
196 static_assert(std::is_base_of<pass::GraphRewrite, T>::value,
197 "pass not derived from GraphRewrite");
198 auto pass = std::make_shared<T>(std::forward<Args>(args)...);
199 auto pass_config = get_pass_config();
201 for (
auto& matcher : pass->m_matchers)
203 pass->set_pass_config(pass_config);
204 m_matchers.push_back(matcher);
208 NGRAPH_DEPRECATED(
"Use MatcherPass instead")
209 void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
210 const
ngraph::graph_rewrite_callback& callback,
214 void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
215 const
ngraph::graph_rewrite_callback& callback);
217 bool run_on_function(std::shared_ptr<
ngraph::
Function> f) override;
219 void set_pass_config(const std::shared_ptr<
PassConfig>& pass_config) override;
222 bool apply_matcher_passes(std::shared_ptr<
Function> f,
223 std::deque<std::shared_ptr<
Node>> nodes_to_run);
225 bool m_enable_shape_inference = false;
233 NGRAPH_RTTI_DECLARATION;
242 bool run_on_function(std::shared_ptr<ngraph::Function> f)
override;
250 , m_num_iters(num_iters)
254 void add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
255 const ngraph::recurrent_graph_rewrite_callback& callback,
259 void add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
260 const ngraph::recurrent_graph_rewrite_callback& callback);
262 virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
267 std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
A user-defined function.
Definition: function.hpp:27
Definition: graph_rewrite.hpp:231
GraphRewrite is a container for MatcherPasses that allows to run them on Function in efficient way.
Definition: graph_rewrite.hpp:120
void add_matcher(Args &&... args)
Register passes from GraphRewrite class that contains sequence of matcher passes registered in its ct...
Definition: graph_rewrite.hpp:194
std::shared_ptr< T > add_matcher(Args &&... args)
Register given transformation class type to GraphRewrite execution list All registered transformation...
Definition: graph_rewrite.hpp:153
MatcherPass is a basic block for pattern based transformations. It describes pattern and action that ...
Definition: graph_rewrite.hpp:48
Class representing a transformations config that is used for disabling/enabling transformations regis...
Definition: pass_config.hpp:59
Definition: graph_rewrite.hpp:246
Definition: matcher.hpp:63
Definition: matcher.hpp:189
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16