23 #include "ngraph/node.hpp"
24 #include "ngraph/op/constant.hpp"
25 #include "ngraph/pattern/op/any.hpp"
26 #include "ngraph/pattern/op/any_of.hpp"
27 #include "ngraph/pattern/op/any_output.hpp"
28 #include "ngraph/pattern/op/label.hpp"
29 #include "ngraph/pattern/op/skip.hpp"
46 bool finish(
bool is_successful);
51 PatternValueMap m_pattern_value_map;
52 PatternValueMaps m_pattern_value_maps;
54 size_t m_capture_size;
77 using PatternMap = ngraph::pattern::PatternMap;
80 Matcher(
const std::shared_ptr<Node> pattern_node, std::nullptr_t name) =
delete;
83 Matcher(Output<Node>& pattern_node)
84 : m_pattern_node{pattern_node}
88 Matcher(Output<Node>& pattern_node,
const std::string& name)
89 : m_pattern_node(pattern_node)
99 Matcher(
const Output<Node>& pattern_node,
const std::string& name,
bool strict_mode)
100 : m_pattern_node(pattern_node)
102 , m_strict_mode(strict_mode)
108 Matcher(std::shared_ptr<Node> pattern_node);
109 Matcher(std::shared_ptr<Node> pattern_node,
const std::string& name);
110 Matcher(std::shared_ptr<Node> pattern_node,
const std::string& name,
bool strict_mode);
116 bool match(
const Output<Node>& graph_value);
118 bool match(std::shared_ptr<Node> graph_node);
124 bool match(
const Output<Node>& graph_value,
const PatternMap& previous_matches);
125 bool match(
const Output<Node>& graph_value,
const PatternValueMap& previous_matches);
127 template <
typename T>
128 static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
130 std::shared_ptr<T> matched;
131 for (
auto arg : node->input_values())
133 if (
auto t_casted = as_type_ptr<T>(arg.get_node_shared_ptr()))
137 throw ngraph_error(
"There's more than two arguments of the same type");
148 bool is_contained_match(
const NodeVector& exclusions = {},
bool ignore_unused =
true);
149 const NodeVector get_matched_nodes() {
return as_node_vector(m_matched_list); }
150 const OutputVector& get_matched_values()
const {
return m_matched_list; }
151 OutputVector& get_matched_values() {
return m_matched_list; }
153 const std::string& get_name() {
return m_name; }
154 std::shared_ptr<Node> get_pattern() {
return m_pattern_node.get_node_shared_ptr(); }
155 Output<Node> get_pattern_value() {
return m_pattern_node; }
156 std::shared_ptr<Node> get_match_root();
157 Output<Node> get_match_value();
158 PatternMap get_pattern_map()
const;
159 PatternValueMap& get_pattern_value_map() {
return m_pattern_map; }
160 PatternValueMaps& get_pattern_value_maps() {
return m_pattern_value_maps; }
173 bool is_strict_mode() {
return m_strict_mode; }
174 virtual bool match_arguments(Node* pattern_node,
175 const std::shared_ptr<Node>& graph_node);
177 void capture(
const std::set<Node*>& static_nodes);
181 size_t get_number_of_recurrent_matches()
const {
return m_pattern_value_maps.size(); }
182 NodeVector get_bound_nodes_for_pattern(
const Output<Node>& pattern)
const;
183 size_t get_number_of_bound_labels()
const;
187 Output<Node> m_match_root;
188 Output<Node> m_pattern_node;
189 PatternValueMap m_pattern_map;
190 PatternValueMaps m_pattern_value_maps;
191 OutputVector m_matched_list;
194 bool match_permutation(
const OutputVector& pattern_args,
const OutputVector& args);
196 std::string m_name{
"unnamed"};
197 bool m_strict_mode{
false};
213 const Output<Node>& pattern,
214 const std::shared_ptr<Node>& rpattern,
215 const std::set<std::shared_ptr<Node>>& correlated_patterns)
216 : m_initial_pattern(initial_pattern)
218 , m_recurrent_pattern(rpattern)
219 , m_correlated_patterns(correlated_patterns)
232 const std::shared_ptr<Node>& rpattern,
233 const std::set<std::shared_ptr<Node>>& correlated_patterns)
239 const Output<Node>& pattern,
240 const std::shared_ptr<Node>& rpattern,
241 const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
244 const std::shared_ptr<Node>& rpattern,
245 const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
254 if (m_matches.count(pattern) == 0)
256 throw ngraph_error(
"No bound nodes for a given label");
259 return as_node_vector(m_matches.at(pattern));
262 size_t get_number_of_recurrent_matches()
const
264 if (m_matches.size() == 0)
269 return (*m_matches.begin()).second.size();
272 size_t get_number_of_bound_labels()
const {
return m_matches.size(); }
276 std::shared_ptr<Node> get_match_root() {
return m_match_root.get_node_shared_ptr(); }
277 Output<Node> get_match_value() {
return m_match_root; }
279 Output<Node> m_initial_pattern;
280 Output<Node> m_pattern;
281 std::shared_ptr<Node> m_recurrent_pattern;
282 const std::set<std::shared_ptr<Node>> m_correlated_patterns;
283 RPatternValueMap m_matches;
284 Output<Node> m_match_root;