23 #include "ngraph/node.hpp"
24 #include "ngraph/op/constant.hpp"
25 #include "ngraph/pattern/op/any.hpp"
26 #include "ngraph/pattern/op/any_of.hpp"
27 #include "ngraph/pattern/op/any_output.hpp"
28 #include "ngraph/pattern/op/label.hpp"
29 #include "ngraph/pattern/op/skip.hpp"
46 bool finish(
bool is_successful);
51 PatternValueMap m_pattern_value_map;
52 PatternValueMaps m_pattern_value_maps;
54 size_t m_capture_size;
77 using PatternMap = ngraph::pattern::PatternMap;
80 Matcher(
const std::shared_ptr<Node> pattern_node, std::nullptr_t name) =
delete;
84 : m_pattern_node{pattern_node}
89 : m_pattern_node(pattern_node)
100 : m_pattern_node(pattern_node)
102 , m_strict_mode(strict_mode)
108 Matcher(std::shared_ptr<Node> pattern_node);
109 Matcher(std::shared_ptr<Node> pattern_node,
const std::string& name);
110 Matcher(std::shared_ptr<Node> pattern_node,
const std::string& name,
bool strict_mode);
118 bool match(std::shared_ptr<Node> graph_node);
125 bool match(
const Output<Node>& graph_value,
const PatternValueMap& previous_matches);
127 template <
typename T>
128 static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
130 std::shared_ptr<T> matched;
131 for (
auto arg : node->input_values())
133 if (
auto t_casted = as_type_ptr<T>(arg.get_node_shared_ptr()))
137 throw ngraph_error(
"There's more than two arguments of the same type");
148 bool is_contained_match(
const NodeVector& exclusions = {},
bool ignore_unused =
true);
149 const NodeVector get_matched_nodes() {
return as_node_vector(m_matched_list); }
150 const OutputVector& get_matched_values()
const {
return m_matched_list; }
151 OutputVector& get_matched_values() {
return m_matched_list; }
153 const std::string& get_name() {
return m_name; }
154 std::shared_ptr<Node> get_pattern() {
return m_pattern_node.get_node_shared_ptr(); }
155 Output<Node> get_pattern_value() {
return m_pattern_node; }
156 std::shared_ptr<Node> get_match_root();
157 Output<Node> get_match_value();
158 PatternMap get_pattern_map()
const;
159 PatternValueMap& get_pattern_value_map() {
return m_pattern_map; }
160 PatternValueMaps& get_pattern_value_maps() {
return m_pattern_value_maps; }
173 bool is_strict_mode() {
return m_strict_mode; }
174 virtual bool match_arguments(
Node* pattern_node,
175 const std::shared_ptr<Node>& graph_node);
177 void capture(
const std::set<Node*>& static_nodes);
181 size_t get_number_of_recurrent_matches()
const {
return m_pattern_value_maps.size(); }
182 NodeVector get_bound_nodes_for_pattern(
const Output<Node>& pattern)
const;
183 size_t get_number_of_bound_labels()
const;
189 PatternValueMap m_pattern_map;
190 PatternValueMaps m_pattern_value_maps;
191 OutputVector m_matched_list;
194 bool match_permutation(
const OutputVector& pattern_args,
const OutputVector& args);
196 std::string m_name{
"unnamed"};
197 bool m_strict_mode{
false};
214 const std::shared_ptr<Node>& rpattern,
215 const std::set<std::shared_ptr<Node>>& correlated_patterns)
216 : m_initial_pattern(initial_pattern)
218 , m_recurrent_pattern(rpattern)
219 , m_correlated_patterns(correlated_patterns)
232 const std::shared_ptr<Node>& rpattern,
233 const std::set<std::shared_ptr<Node>>& correlated_patterns)
240 const std::shared_ptr<Node>& rpattern,
241 const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
244 const std::shared_ptr<Node>& rpattern,
245 const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
254 if (m_matches.count(pattern) == 0)
259 return as_node_vector(m_matches.at(pattern));
262 size_t get_number_of_recurrent_matches()
const
264 if (m_matches.size() == 0)
269 return (*m_matches.begin()).second.size();
272 size_t get_number_of_bound_labels()
const {
return m_matches.size(); }
276 std::shared_ptr<Node> get_match_root() {
return m_match_root.get_node_shared_ptr(); }
281 std::shared_ptr<Node> m_recurrent_pattern;
282 const std::set<std::shared_ptr<Node>> m_correlated_patterns;
283 RPatternValueMap m_matches;
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Base error for ngraph runtime errors.
Definition: except.hpp:28
Definition: matcher.hpp:43
Definition: matcher.hpp:75
Matcher(const Output< Node > &pattern_node, const std::string &name, bool strict_mode)
Constructs a Matcher object.
Definition: matcher.hpp:99
bool match(const Output< Node > &graph_value, const PatternMap &previous_matches)
Matches a pattern to graph_node.
MatcherState start_match()
Try a match.
size_t add_node(Output< Node > node)
Low-level helper to match recurring patterns.
bool match(const Output< Node > &graph_value)
Matches a pattern to graph_node.
Definition: matcher.hpp:201
RecurrentMatcher(const Output< Node > &pattern, const std::shared_ptr< Node > &rpattern, const std::set< std::shared_ptr< Node >> &correlated_patterns)
Constructs a RecurrentMatcher object. Reccurent Matchers are used to match repeating patterns (e....
Definition: matcher.hpp:231
bool match(Output< Node > graph)
Tries to match a pattern for an individual cell to a given graph.
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr< Node > &pattern) const
Returns a vector of bound nodes for a given label (used in a pattern describing an individual cell.
Definition: matcher.hpp:252
RecurrentMatcher(const Output< Node > &initial_pattern, const Output< Node > &pattern, const std::shared_ptr< Node > &rpattern, const std::set< std::shared_ptr< Node >> &correlated_patterns)
Constructs a RecurrentMatcher object. Reccurent Matchers are used to match repeating patterns (e....
Definition: matcher.hpp:212
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28