class ngraph::pattern::Matcher

Overview

Matcher looks for node patterns in a computation graph.

#include <matcher.hpp>

class Matcher
{
public:
    // typedefs

    typedef ngraph::pattern::PatternMap PatternMap;

    // fields

    Output<Node> m_match_root;
    Output<Node> m_pattern_node;
    PatternValueMap m_pattern_map;
    PatternValueMaps m_pattern_value_maps;
    OutputVector m_matched_list;

    // construction

    Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
    Matcher();
    Matcher(Output<Node>& pattern_node);
    Matcher(Output<Node>& pattern_node, const std::string& name);

    Matcher(
        const Output<Node>& pattern_node,
        const std::string& name,
        bool strict_mode
        );

    Matcher(std::shared_ptr<Node> pattern_node);
    Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);

    Matcher(
        std::shared_ptr<Node> pattern_node,
        const std::string& name,
        bool strict_mode
        );

    // methods

    bool match(const Output<Node>& graph_value);
    bool match(std::shared_ptr<Node> graph_node);
    bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);

    bool match(
        const Output<Node>& graph_value,
        const PatternValueMap& previous_matches
        );

    bool is_contained_match(
        const NodeVector& exclusions = {},
        bool ignore_unused = true
        );

    const NodeVector get_matched_nodes();
    const OutputVector& get_matched_values() const;
    OutputVector& get_matched_values();
    void reset();
    const std::string& get_name();
    std::shared_ptr<Node> get_pattern();
    Output<Node> get_pattern_value();
    std::shared_ptr<Node> get_match_root();
    Output<Node> get_match_value();
    PatternMap get_pattern_map() const;
    PatternValueMap& get_pattern_value_map();
    PatternValueMaps& get_pattern_value_maps();
    size_t add_node(Output<Node> node);

    virtual bool match_value(
        const ngraph::Output<Node>& pattern_value,
        const ngraph::Output<Node>& graph_value
        );

    bool is_strict_mode();

    virtual bool match_arguments(
        Node* pattern_node,
        const std::shared_ptr<Node>& graph_node
        );

    void capture(const std::set<Node*>& static_nodes);
    void clear_state();
    size_t get_number_of_recurrent_matches() const;
    NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
    size_t get_number_of_bound_labels() const;
    MatcherState start_match();

    template <typename T>
    static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node);
};

Detailed Documentation

Matcher looks for node patterns in a computation graph. Patterns are described by an automaton, which is itself expressed as an extended computation graph. The matcher runs by attempting to match the start node of the pattern to a computation graph value (an output of a Node). In addition to determining whether a match occurs, a pattern node may add graph nodes to the list of matched nodes, associate nodes with graph values, and start submatches. If a submatch succeeds, its state changes are added to the enclosing match; otherwise they are reverted.

By default, when a pattern node is matched against a graph node, the computation graph value is appended to the end of the matched value list, and the match succeeds if the pattern node and graph node have the same type and their input values match. For a commutative node, the inputs may match in any order. In strict mode, the element type and shape of the graph value must also match.

Pattern nodes with different match behavior are defined in ngraph::pattern::op, and each one documents its own match behavior.
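To make the default rules concrete, here is a minimal, self-contained C++ sketch. It does not use nGraph: ToyNode, make, toy_match and the "Any" wildcard type are hypothetical names invented for illustration. It models only what the description states: the graph value is appended to the matched list, the types must agree, inputs are matched recursively, and a commutative node may match its inputs in any order. Strict mode is left out, and the rollback on failure mirrors the submatch behavior described above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy node (not nGraph's Node): an op type plus ordered inputs.
struct ToyNode {
    std::string type;                              // e.g. "Add", "Parameter"
    std::vector<std::shared_ptr<ToyNode>> inputs;  // input values, in order
    bool commutative = false;                      // inputs may match in any order
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

ToyNodePtr make(std::string type, std::vector<ToyNodePtr> inputs = {},
                bool commutative = false) {
    return std::make_shared<ToyNode>(
        ToyNode{std::move(type), std::move(inputs), commutative});
}

// Default rule: append the graph value to `matched`; succeed if the types agree
// and every input matches (in any input order for a commutative node).
// On failure, `matched` is restored to its previous length.
bool toy_match(const ToyNodePtr& pattern, const ToyNodePtr& graph,
               std::vector<ToyNodePtr>& matched) {
    if (pattern->type == "Any") {  // hypothetical wildcard: matches any value
        matched.push_back(graph);
        return true;
    }
    if (pattern->type != graph->type ||
        pattern->inputs.size() != graph->inputs.size()) {
        return false;
    }
    const std::size_t saved = matched.size();
    matched.push_back(graph);
    std::vector<std::size_t> order(graph->inputs.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    do {
        bool all = true;
        for (std::size_t i = 0; i < order.size() && all; ++i) {
            all = toy_match(pattern->inputs[i], graph->inputs[order[i]], matched);
        }
        if (all) return true;
        matched.resize(saved + 1);         // drop partial input matches
        if (!pattern->commutative) break;  // only the given order is allowed
    } while (std::next_permutation(order.begin(), order.end()));
    matched.resize(saved);  // no input order worked: revert completely
    return false;
}
```

Trying every permutation is the simplest way to express "any order"; its cost grows factorially with the number of inputs, which is harmless for the binary commutative operations that typically appear in patterns.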

Construction

Matcher(
    const Output<Node>& pattern_node,
    const std::string& name,
    bool strict_mode
    )

Constructs a Matcher object.

Parameters:

pattern_node

is the pattern subgraph to match against input graphs

name

is a string used for logging and for disabling the matcher

strict_mode

forces the matcher to also compare the shapes and element types (ET) of nodes
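The following minimal sketch illustrates what strict_mode changes, assuming the rule from the detailed description: in strict mode the graph value's element type and shape must match as well as its type. ToyValueInfo and toy_values_match are hypothetical; element types are plain strings and shapes are compared for exact equality. This is a simplification, and the real matcher may use a looser compatibility check (for example, for dynamic shapes).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the type information of a node output.
struct ToyValueInfo {
    std::string op_type;             // operation type, e.g. "Add"
    std::string element_type;        // e.g. "f32"
    std::vector<std::size_t> shape;  // e.g. {2, 3}
};

// Operation types must always agree; element type and shape are
// compared only when strict_mode is enabled.
bool toy_values_match(const ToyValueInfo& pattern, const ToyValueInfo& graph,
                      bool strict_mode) {
    if (pattern.op_type != graph.op_type) return false;
    if (!strict_mode) return true;
    return pattern.element_type == graph.element_type &&
           pattern.shape == graph.shape;
}
```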

Methods

bool match(const Output<Node>& graph_value)

Matches the pattern against graph_value.

Parameters:

graph_value

is the graph value (an output of a Node) to match the pattern against

bool match(const Output<Node>& graph_value, const PatternMap& previous_matches)

Matches the pattern against graph_value.

Parameters:

graph_value

is the graph value (an output of a Node) to match the pattern against

previous_matches

contains existing mappings from labels to nodes that the match starts from
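One way to picture previous_matches is as a set of label bindings the match starts from. The sketch below assumes that a label already present in the map must bind to the same graph value again. ToyPatternMap, bind_label and toy_match_square are hypothetical: labels are identified by name and graph values by integer id, and the pattern is a Multiply(x, x) that reuses one label.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical: labels are identified by name, graph values by integer id.
using ToyPatternMap = std::map<std::string, int>;

// Binds `label` to `graph_value`. If the label is already bound (for example,
// by a previous match), the new value must be the same one.
bool bind_label(ToyPatternMap& bindings, const std::string& label, int graph_value) {
    auto it = bindings.find(label);
    if (it != bindings.end()) return it->second == graph_value;
    bindings.emplace(label, graph_value);
    return true;
}

// Matches the pattern Multiply(x, x) against a node whose inputs are lhs and rhs,
// starting from the bindings in previous_matches.
bool toy_match_square(int lhs, int rhs, const ToyPatternMap& previous_matches) {
    ToyPatternMap bindings = previous_matches;  // copy: seed with earlier matches
    return bind_label(bindings, "x", lhs) && bind_label(bindings, "x", rhs);
}
```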

size_t add_node(Output<Node> node)

Low-level helper used when matching recurring patterns; adds node to the list of matched values.

Parameters:

node

is the graph value to add to the matched value list
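The declared signature takes a single Output<Node> and returns a size_t, but this page does not say what the return value means. One plausible reading, assumed in this sketch, is that the value is appended to the matched value list and its index is returned. ToyMatchedList is a hypothetical stand-in that uses strings in place of graph values.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the matched value list kept by a matcher.
class ToyMatchedList {
public:
    // Appends a matched value and returns the index it was stored at.
    std::size_t add_node(const std::string& value) {
        const std::size_t index = m_values.size();
        m_values.push_back(value);
        return index;
    }
    const std::vector<std::string>& values() const { return m_values; }

private:
    std::vector<std::string> m_values;
};
```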

MatcherState start_match()

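Starts an attempted match. The returned MatcherState records the matcher's current state so that changes made during the attempt can be reverted if the attempt fails.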
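The submatch semantics in the detailed description (state changes are kept if the submatch succeeds and reverted otherwise) fit naturally into an RAII guard. The sketch below shows one such guard. ToyMatcher, ToyMatcherState, finish and try_submatch are hypothetical names, and the real MatcherState may track more than the matched list (for example, the pattern value map).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical matcher that holds only a matched value list.
struct ToyMatcher {
    std::vector<std::string> matched;
};

// Hypothetical guard: records the matcher state when it is created and
// restores that state on destruction unless the attempt reported success.
class ToyMatcherState {
public:
    explicit ToyMatcherState(ToyMatcher& m)
        : m_matcher(m), m_saved_size(m.matched.size()) {}
    ToyMatcherState(const ToyMatcherState&) = delete;
    ToyMatcherState& operator=(const ToyMatcherState&) = delete;
    ~ToyMatcherState() {
        if (!m_keep) m_matcher.matched.resize(m_saved_size);  // revert on failure
    }
    // Records the outcome of the attempt and returns it unchanged.
    bool finish(bool is_successful) {
        m_keep = is_successful;
        return is_successful;
    }

private:
    ToyMatcher& m_matcher;
    std::size_t m_saved_size;
    bool m_keep = false;
};

// A submatch that adds two values, then succeeds or fails as requested.
bool try_submatch(ToyMatcher& m, bool succeed) {
    ToyMatcherState state(m);  // plays the role of start_match()
    m.matched.push_back("x");
    m.matched.push_back("y");
    return state.finish(succeed);  // guard is destroyed after this is evaluated
}
```

Because the guard reverts by default, any early return from a failed attempt restores the state automatically. Only an explicit report of success keeps the changes.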