class ov::pass::pattern::RecurrentMatcher

Overview

#include <matcher.hpp>

/// Matcher for repeating patterns such as RNN, LSTM or GRU cells.
///
/// Declaration synopsis only; member semantics are taken from the reference
/// text that accompanies this class.
class RecurrentMatcher {
public:
    // ----------------------------------------------------------------------
    // Construction
    // ----------------------------------------------------------------------

    /// @brief Creates a matcher whose first cell is described separately.
    /// @param initial_pattern     pattern subgraph describing the initial cell
    /// @param pattern             pattern subgraph describing an individual cell
    /// @param rpattern            (recurring) label marking the node at which
    ///                            the next match should start
    /// @param correlated_patterns labels whose bound nodes must remain the same
    ///                            across all cells
    RecurrentMatcher(const Output<Node>& initial_pattern,
                     const Output<Node>& pattern,
                     const std::shared_ptr<Node>& rpattern,
                     const std::set<std::shared_ptr<Node>>& correlated_patterns);

    /// @brief Creates a matcher using a single pattern for every cell.
    /// @param pattern             pattern subgraph describing an individual cell
    /// @param rpattern            (recurring) label marking the node at which
    ///                            the next match should start
    /// @param correlated_patterns labels whose bound nodes must remain the same
    ///                            across all cells
    RecurrentMatcher(const Output<Node>& pattern,
                     const std::shared_ptr<Node>& rpattern,
                     const std::set<std::shared_ptr<Node>>& correlated_patterns);

    /// @brief Overload of the initial-cell constructor that takes the
    ///        correlated labels as `op::Label` pointers.
    /// NOTE(review): not separately documented; presumably equivalent to the
    /// `std::set<std::shared_ptr<Node>>` overload — confirm.
    RecurrentMatcher(const Output<Node>& initial_pattern,
                     const Output<Node>& pattern,
                     const std::shared_ptr<Node>& rpattern,
                     const std::set<std::shared_ptr<op::Label>>& correlated_patterns);

    /// @brief Overload of the single-pattern constructor that takes the
    ///        correlated labels as `op::Label` pointers.
    /// NOTE(review): not separately documented; presumably equivalent to the
    /// `std::set<std::shared_ptr<Node>>` overload — confirm.
    RecurrentMatcher(const Output<Node>& pattern,
                     const std::shared_ptr<Node>& rpattern,
                     const std::set<std::shared_ptr<op::Label>>& correlated_patterns);

    // ----------------------------------------------------------------------
    // Queries and matching
    // ----------------------------------------------------------------------

    /// @brief Returns the nodes bound to @p pattern, a label used in the
    ///        pattern describing an individual cell.
    NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const;

    /// @return Number of recurrent matches.
    /// NOTE(review): presumably the number of matched cells — confirm.
    size_t get_number_of_recurrent_matches() const;

    /// @return Number of bound labels (semantics not documented here).
    size_t get_number_of_bound_labels() const;

    /// @brief Tries to match the individual-cell pattern against @p graph.
    /// @return Outcome of the match attempt (presumably true on success).
    bool match(Output<Node> graph);

    /// @return Root node of the match (semantics not documented here).
    std::shared_ptr<Node> get_match_root();

    /// @return Output value of the match (semantics not documented here).
    Output<Node> get_match_value();
};

Detailed Documentation

Construction

RecurrentMatcher(
    const Output<Node>& initial_pattern,
    const Output<Node>& pattern,
    const std::shared_ptr<Node>& rpattern,
    const std::set<std::shared_ptr<Node>>& correlated_patterns
    )

Constructs a RecurrentMatcher object. Recurrent matchers are used to match repeating patterns (e.g. RNN, LSTM, or GRU cells).

Parameters:

initial_pattern

is a pattern subgraph describing the initial cell

pattern

is a pattern subgraph describing an individual cell

rpattern

is a (recurring) label denoting the node at which the next match should start

correlated_patterns

is a set of labels whose bound nodes must remain the same across all cells

RecurrentMatcher(
    const Output<Node>& pattern,
    const std::shared_ptr<Node>& rpattern,
    const std::set<std::shared_ptr<Node>>& correlated_patterns
    )

Constructs a RecurrentMatcher object. Recurrent matchers are used to match repeating patterns (e.g. RNN, LSTM, or GRU cells).

Parameters:

pattern

is a pattern subgraph describing an individual cell

rpattern

is a (recurring) label denoting the node at which the next match should start

correlated_patterns

is a set of labels whose bound nodes must remain the same across all cells

Methods

NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const

Returns a vector of the nodes bound to a given label (a label used in the pattern describing an individual cell).

bool match(Output<Node> graph)

Tries to match a pattern for an individual cell to a given graph.