OpenVINO Matcher Pass

ov::pass::MatcherPass is used for pattern-based transformations.

Template for MatcherPass transformation class

// transformations/template_pattern_transformation.hpp
/**
 * @ingroup ie_transformation_common_api
 * @brief Add transformation description.
 */
class ov::pass::DecomposeDivideMatcher : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("DecomposeDivideMatcher", "0");
    DecomposeDivideMatcher();
};
// template_pattern_transformation.cpp
ov::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
    MATCHER_SCOPE(DecomposeDivideMatcher);
    // Pattern example
    auto input0 = pattern::any_input();
    auto input1 = pattern::any_input();
    auto div = std::make_shared<ov::opset3::Divide>(input0, input1);

    ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
        auto div = std::dynamic_pointer_cast<ov::opset3::Divide>(m.get_match_root());
        // We can not apply this transformation in case with integer input data type
        if (!div || div->input(0).get_element_type().is_integral()) {
            return false;
        }

        // Decompose Divide into Multiply with Power operations
        auto pow = std::make_shared<ov::opset3::Power>(
            div->input_value(1),
            ov::opset3::Constant::create(div->get_input_element_type(1), ov::Shape{1}, {-1}));

        auto mul = std::make_shared<ov::opset3::Multiply>(div->input_value(0), pow);

        // Save original name to last operation in replacement sub-graph
        mul->set_friendly_name(div->get_friendly_name());

        // Copy runtime info attributes to newly created operation
        ov::copy_runtime_info(div, {pow, mul});

        // Replace Divide operation with Multiply
        ov::replace_node(div, mul);

        // Return true as the root node was changed
        return true;
    };

    // Register pattern with Divide operation as a pattern root node
    auto m = std::make_shared<ov::pass::pattern::Matcher>(div, "ConvertDivide");
    // Register Matcher
    register_matcher(m, callback);
}
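The rewrite relies on the identity x / y = x * y^(-1). As a rough analogy for why the callback returns false on integral inputs, the plain C++ sketch below (no OpenVINO involved; all function names are hypothetical) compares the two forms in floating-point and in integer arithmetic. It uses C++ `int` division to stand in for an integer reciprocal and says nothing about how the Power operation itself treats integer tensors.

```cpp
#include <cassert>
#include <cmath>

// Plain C++ arithmetic only; no OpenVINO types are involved.

// Floating point: x / y and x * y^(-1) agree (up to rounding).
double divide_f(double x, double y) { return x / y; }
double decomposed_f(double x, double y) { return x * std::pow(y, -1.0); }

// Integers: the reciprocal 1 / y truncates toward zero for |y| > 1,
// so rewriting x / y as x * (1 / y) changes the result.
int divide_i(int x, int y) { return x / y; }
int decomposed_i(int x, int y) { return x * (1 / y); }
```

Here `divide_f(7.0, 2.0)` and `decomposed_f(7.0, 2.0)` both give 3.5, while `divide_i(7, 2)` is 3 and `decomposed_i(7, 2)` is 0, because `1 / 2` truncates to 0.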

To use ov::pass::MatcherPass, you need to complete these steps:

  1. Create a pattern

  2. Implement a callback

  3. Register the pattern and Matcher

  4. Execute MatcherPass

So let’s go through each of these steps.

Create a pattern

A pattern is a single-root ov::Model. The only difference is that you do not need to create a model object: you just create and connect opset or special pattern operations, and then take the last created operation and make it the root of the pattern. This root node is used as the root node in pattern matching.

Note

Any nodes in a pattern that have no consumers and are not registered as root will not be used in pattern matching.

// Pattern example
auto input = std::make_shared<ov::opset8::Parameter>(ov::element::i64, ov::Shape{1});
auto shapeof = std::make_shared<ov::opset8::ShapeOf>(input);

// Create Matcher with Parameter->ShapeOf pattern
auto m = std::make_shared<ov::pass::pattern::Matcher>(shapeof, "MyPatternBasedTransformation");

The Parameter operation in the example above has its type and shape specified. These attributes are needed only to create the Parameter operation object and are not used in pattern matching.
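A toy model of how a pattern is matched: matching starts at the pattern root and walks its inputs, comparing only operation types and graph structure. Everything here (`Node`, `make`, `matches`, the `attrs` string) is hypothetical and far simpler than ov::pass::pattern::Matcher; it only illustrates why the Parameter's type and shape do not matter and why pattern nodes that are not reachable from the root are never consulted.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only (not the OpenVINO API).
struct Node {
    std::string type;   // operation type, e.g. "Parameter" or "ShapeOf"
    std::string attrs;  // attributes such as element type and shape
    std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr make(std::string type, std::string attrs, std::vector<NodePtr> inputs = {}) {
    return std::make_shared<Node>(Node{std::move(type), std::move(attrs), std::move(inputs)});
}

// Start at the pattern root and walk its inputs. Only op types and graph
// structure are compared; `attrs` is deliberately ignored. Pattern nodes that
// are not reachable from the root are never visited.
bool matches(const NodePtr& pattern, const NodePtr& node) {
    if (pattern->type != node->type || pattern->inputs.size() != node->inputs.size()) {
        return false;
    }
    for (std::size_t i = 0; i < pattern->inputs.size(); ++i) {
        if (!matches(pattern->inputs[i], node->inputs[i])) {
            return false;
        }
    }
    return true;
}
```

With this toy, a pattern `ShapeOf(Parameter "i64 {1}")` matches a graph `ShapeOf(Parameter "f32 {1,3,224,224}")`, but not `Relu(Parameter ...)`.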

For more pattern examples, refer to the pattern matching section.

Implement callback

A callback is an action applied to every occurrence of the pattern. In general, a callback is a lambda function that takes a Matcher object holding the detected subgraph.

ov::graph_rewrite_callback callback = [](ov::pass::pattern::Matcher& m) {
    // Get root node
    std::shared_ptr<ov::Node> root_node = m.get_match_root();

    // Get all nodes matched by pattern
    ov::NodeVector nodes = m.get_matched_nodes();

    // Transformation code
    return false;
};

The example above shows the callback structure and how the Matcher can be used to access the nodes detected by the pattern. The callback returns true if the root node was replaced and no other pattern can be applied to the same root node; otherwise, it returns false.
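The return-value contract can be modeled with a toy driver. In the sketch below, `Callback`, `apply_first`, and the op-type strings are hypothetical simplifications (a node is just its type name), not the OpenVINO API: callbacks are tried on one root in order, and the first one that returns true ends the search, because the root it was matched on has been replaced.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy: the "root" is just an op type name. A callback returns
// true only if it replaced the root, mirroring the contract described above.
using Callback = std::function<bool(std::string& root)>;

// Try callbacks on one root in order and stop at the first that returns true:
// once the root is replaced, no other pattern is applied to the same root.
// Returns the index of the callback that fired, or -1 if none did.
int apply_first(std::string& root, const std::vector<Callback>& callbacks) {
    for (std::size_t i = 0; i < callbacks.size(); ++i) {
        if (callbacks[i](root)) {
            return static_cast<int>(i);
        }
    }
    return -1;
}
```

A callback that only inspects the root and returns false lets the next callback run; a callback that returns true stops the search.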

Note

It is not recommended to manipulate nodes that are below the root node. This may affect GraphRewrite execution, as all nodes that come after the root node in topological order are expected to be valid and usable in pattern matching.

MatcherPass also provides functionality for reporting newly created nodes so that they can be used in additional pattern matching. If the MatcherPass was registered in ov::pass::Manager or ov::pass::GraphRewrite, these registered nodes are added for additional pattern matching. This means that the matcher passes registered in ov::pass::GraphRewrite will be applied to these nodes.

The example below shows how a single MatcherPass can fuse a sequence of operations using the register_new_node method.

ov::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
    MATCHER_SCOPE(ReluReluFusionMatcher);
    auto m_relu1 = ov::pass::pattern::wrap_type<ov::opset3::Relu>(pattern::consumers_count(1));
    auto m_relu2 = ov::pass::pattern::wrap_type<ov::opset3::Relu>({m_relu1});

    ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        // Map that helps to connect labels with matched outputs
        auto& node_to_output = m.get_pattern_value_map();

        // Create a new Relu operation and register it for additional pattern matching
        auto new_relu =
            register_new_node<ov::opset3::Relu>(node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));

        // Copy runtime info attributes to newly created operation
        ov::copy_runtime_info(m.get_matched_nodes(), new_relu);

        // Save last Relu name to new Relu operation
        new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());

        // Replace Relu->Relu with Relu
        ov::replace_node(m.get_match_root(), new_relu);

        // Return true as the root node was changed
        return true;
    };

    // Register pattern with Relu operation as a pattern root node
    auto m = std::make_shared<ov::pass::pattern::Matcher>(m_relu2, "ReluReluFusion");
    // Register Matcher
    register_matcher(m, callback);
}

Note

If you register multiple nodes, add them in topological order. The registered nodes are not sorted topologically, because sorting is a time-consuming operation.
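The effect of register_new_node can be sketched with a toy worklist. Everything below (`ToyMatcher`, `run_toy_rewrite`, and op names such as `PowerFolded`) is hypothetical and much simpler than GraphRewrite: nodes registered by a callback are appended to the worklist in the order they were registered, without any sorting, and every matcher then gets a chance to match them.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy, not the OpenVINO implementation. A matcher gets the current
// node (an op type name) and may append newly created nodes to `new_nodes`,
// playing the role of register_new_node().
using ToyMatcher = std::function<bool(const std::string& node, std::vector<std::string>& new_nodes)>;

// Visits nodes in the given order. Nodes registered by a callback are appended
// to the worklist in registration order (no sorting), so every matcher gets a
// chance to match them too. Returns the order in which nodes were visited.
std::vector<std::string> run_toy_rewrite(std::vector<std::string> worklist,
                                         const std::vector<ToyMatcher>& matchers) {
    std::vector<std::string> visited;
    for (std::size_t i = 0; i < worklist.size(); ++i) {  // the worklist may grow
        const std::string node = worklist[i];  // copy: insert() may reallocate
        visited.push_back(node);
        std::vector<std::string> new_nodes;
        for (const auto& m : matchers) {
            if (m(node, new_nodes)) break;  // root replaced: stop trying matchers
        }
        worklist.insert(worklist.end(), new_nodes.begin(), new_nodes.end());
    }
    return visited;
}
```

Because registration order is preserved as-is, a callback that registers a consumer before its producer would have the consumer visited first, which is why the registered nodes should already be in topological order.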

Register pattern and Matcher

The last step is to register the Matcher and the callback inside the MatcherPass. To do this, call the register_matcher method.

Note

Only one matcher can be registered for a single MatcherPass class.

// Register matcher and callback
register_matcher(m, callback);
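To make the single-matcher rule concrete, here is a toy class with exactly one slot for a pattern and its callback. `ToyMatcherPass`, `ToyDecomposeDivide`, and the string-based "pattern" (just a root op type) are hypothetical; only the method names register_matcher and apply, and the idea of registering in the constructor, are taken from the text.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Hypothetical toy, not ov::pass::MatcherPass: the pass owns exactly one
// (pattern, callback) pair, mirroring the one-matcher-per-class rule.
class ToyMatcherPass {
public:
    using Callback = std::function<bool(std::string& root)>;
    virtual ~ToyMatcherPass() = default;

    // Toy counterpart of running the pass on a single node.
    bool apply(std::string& node) const {
        if (!callback_ || node != root_type_) return false;  // no matcher or no match
        return callback_(node);                               // callback decides
    }

protected:
    // Toy counterpart of register_matcher(): the pattern is only a root op type.
    void register_matcher(std::string root_type, Callback callback) {
        root_type_ = std::move(root_type);
        callback_ = std::move(callback);
    }

private:
    std::string root_type_;
    Callback callback_;
};

// A derived pass registers its matcher in the constructor, like the template above.
class ToyDecomposeDivide : public ToyMatcherPass {
public:
    ToyDecomposeDivide() {
        register_matcher("Divide", [](std::string& root) {
            root = "Multiply";  // stand-in for building Multiply(x, Power(y, -1))
            return true;
        });
    }
};
```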

Execute MatcherPass

There are several ways to execute a MatcherPass:

  • Run on a single node - it can be useful if you want to run MatcherPass inside another transformation.

    if (ov::pass::DecomposeDivideMatcher().apply(node)) {
        // successful execution (root node was replaced)
    }
  • Run on ov::Model using GraphRewrite - this approach lets you run a MatcherPass on the whole ov::Model. Moreover, multiple MatcherPass transformations can be registered in a single GraphRewrite and executed in a single graph traversal.

    // Two matcher passes will run simultaneously in a single graph traversal
    ov::pass::GraphRewrite pass;
    pass.add_matcher<ov::pass::DecomposeDivideMatcher>();
    pass.add_matcher<ov::pass::ReluReluFusionMatcher>();
    pass.run_on_model(f);
  • Run on ov::Model using ov::pass::Manager - this approach lets you register a MatcherPass for execution on ov::Model in the same way as other transformation types.

    // Two matchers will run independently (two independent graph traversals)
    // pass::Manager automatically creates GraphRewrite container for each MatcherPass
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::DecomposeDivideMatcher>();
    manager.register_pass<ov::pass::ReluReluFusionMatcher>();
    manager.run_passes(f);
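The difference between the two execution modes can be shown with a toy model in which a graph is a list of op type names and a matcher rewrites a node in place. The matchers, op names (`Scale` is made up), and functions below are hypothetical, and the model ignores register_new_node; it only illustrates that a single shared traversal tries all matchers at each node, while separate traversals let a later matcher see what an earlier one produced.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy: a graph is a list of op type names in topological order and
// a matcher rewrites a node in place, returning true if it replaced it.
using ToyMatcher = std::function<bool(std::string& node)>;

// GraphRewrite-like: one traversal; at each node the matchers are tried in
// order until one of them replaces the node.
int run_in_one_traversal(std::vector<std::string>& graph, const std::vector<ToyMatcher>& matchers) {
    for (auto& node : graph) {
        for (const auto& m : matchers) {
            if (m(node)) break;
        }
    }
    return 1;  // number of traversals
}

// Manager-like: every matcher gets its own traversal over the whole graph.
int run_in_separate_traversals(std::vector<std::string>& graph, const std::vector<ToyMatcher>& matchers) {
    int traversals = 0;
    for (const auto& m : matchers) {
        ++traversals;
        for (auto& node : graph) m(node);
    }
    return traversals;
}
```

In this toy, with a Divide->Multiply matcher followed by a Multiply->Scale matcher, the single traversal leaves a `Multiply`, while the two separate traversals end with `Scale`.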

Pattern Matching

Sometimes patterns cannot be expressed via regular operations, or expressing them that way is too complicated. For example, you may want to detect a Convolution->Add subgraph without specifying a particular input type for the Convolution operation, or create a pattern where some operations can have different types. For these cases, OpenVINO™ provides additional helpers to construct patterns for GraphRewrite transformations.

There are two main helpers:

  1. ov::pass::pattern::any_input - helps express inputs whose types are undefined.

  2. ov::pass::pattern::wrap_type<T> - helps express pattern nodes without specifying node attributes.

Let’s go through some examples to better understand how these helpers work:

Note

Node attributes do not participate in pattern matching and are needed only for operations creation. Only operation types participate in pattern matching.

The example below shows basic usage of ov::pass::pattern::any_input. Here we construct a Multiply pattern with an arbitrary first input and a Constant as the second input. Because Multiply is a commutative operation, the order in which the inputs are set (any_input/Constant or Constant/any_input) does not matter: both cases will be matched.

// Detect Multiply with arbitrary first input and second as Constant
// ov::pass::pattern::op::Label - represents an arbitrary input
auto input = ov::pass::pattern::any_input();
auto value = ov::opset8::Constant::create(ov::element::f32, ov::Shape{1}, {0.5});
auto mul = std::make_shared<ov::opset8::Multiply>(input, value);
auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, "MultiplyMatcher");
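A toy version of commutative matching: the pattern Multiply(Any, Constant) is compared with the graph inputs in order and, for commutative ops, also in swapped order. `Node`, `make`, `matches`, and the `Any` marker (standing in for any_input) are hypothetical; the real Matcher handles more cases, and this toy ignores attributes such as the Constant's value.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR and matcher, not ov::pass::pattern.
struct Node {
    std::string type;  // in a pattern, "Any" plays the role of any_input()
    std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr make(std::string type, std::vector<NodePtr> inputs = {}) {
    return std::make_shared<Node>(Node{std::move(type), std::move(inputs)});
}

bool is_commutative(const std::string& type) { return type == "Multiply" || type == "Add"; }

bool matches(const NodePtr& p, const NodePtr& n) {
    if (p->type == "Any") return true;  // arbitrary input
    if (p->type != n->type || p->inputs.size() != n->inputs.size()) return false;
    // Compare inputs in order or, if `swapped`, in reverse order.
    auto inputs_match = [&](bool swapped) {
        for (std::size_t i = 0; i < p->inputs.size(); ++i) {
            std::size_t j = swapped ? p->inputs.size() - 1 - i : i;
            if (!matches(p->inputs[i], n->inputs[j])) return false;
        }
        return true;
    };
    if (inputs_match(false)) return true;
    // Only binary commutative ops get a second chance with swapped inputs.
    return is_commutative(p->type) && p->inputs.size() == 2 && inputs_match(true);
}
```

With this toy, Multiply(Any, Constant) matches both Multiply(Parameter, Constant) and Multiply(Constant, Parameter), while the same shape of pattern for the non-commutative Divide matches only one input order.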

This example shows how to construct a pattern when an operation has an arbitrary number of inputs.

// Detect Concat operation with arbitrary number of inputs
auto concat = ov::pass::pattern::wrap_type<ov::opset8::Concat>();
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat, "ConcatMatcher");
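In a toy matcher, a pattern node without input patterns can simply skip the input check, which is one way to get the "arbitrary number of inputs" behavior described above. `Node`, `PatternNode`, and `matches` are hypothetical; note that this simplification makes an empty input list mean "unconstrained" for every op type.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR and pattern node, not ov::pass::pattern.
struct Node {
    std::string type;
    std::vector<std::shared_ptr<Node>> inputs;
};

struct PatternNode {
    std::string type;
    // Empty = no input patterns given (like wrap_type<T>() without inputs):
    // any number of inputs of any kind is accepted.
    std::vector<std::shared_ptr<PatternNode>> inputs;
};

bool matches(const PatternNode& p, const Node& n) {
    if (p.type != n.type) return false;
    if (p.inputs.empty()) return true;  // inputs unconstrained
    if (p.inputs.size() != n.inputs.size()) return false;
    for (std::size_t i = 0; i < p.inputs.size(); ++i) {
        if (!matches(*p.inputs[i], *n.inputs[i])) return false;
    }
    return true;
}
```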

This example shows how to use a predicate to construct a pattern. It also shows how to match a pattern manually on a given node.

// Detect Multiply->Add sequence where mul has exactly one consumer
auto mul = ov::pass::pattern::wrap_type<ov::opset8::Multiply>(ov::pass::pattern::consumers_count(1) /* check consumers count */);
auto add = ov::pass::pattern::wrap_type<ov::opset8::Add>({mul, ov::pass::pattern::any_input()});
auto m = std::make_shared<ov::pass::pattern::Matcher>(add, "MultiplyAddMatcher");
// Matcher can be used to match pattern manually on given node
if (m->match(node->output(0))) {
    // Successfully matched
}
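A toy version of the consumers_count(1) predicate: the Multiply->Add pattern only matches while the Multiply output feeds exactly one node. `Node`, `count_consumers`, `consumers_count`, and `match_mul_add` are hypothetical stand-ins; the toy counts consumers by scanning a node list and checks only input 0 of the Add, so unlike the commutative matching described above it does not try the swapped order.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR (not the OpenVINO API).
struct Node {
    std::string type;
    std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;
using Predicate = std::function<bool(const NodePtr&)>;

// Count how many nodes in `graph` use `n` as an input.
std::size_t count_consumers(const NodePtr& n, const std::vector<NodePtr>& graph) {
    std::size_t count = 0;
    for (const auto& g : graph) {
        for (const auto& in : g->inputs) {
            if (in == n) ++count;
        }
    }
    return count;
}

// Toy counterpart of consumers_count(k): captures a snapshot of the graph by value.
Predicate consumers_count(std::size_t k, std::vector<NodePtr> graph) {
    return [k, graph](const NodePtr& n) { return count_consumers(n, graph) == k; };
}

// Match Add(Multiply, any) where the Multiply satisfies `mul_pred`.
bool match_mul_add(const NodePtr& add, const Predicate& mul_pred) {
    if (add->type != "Add" || add->inputs.size() != 2) return false;
    const NodePtr& mul = add->inputs[0];
    return mul->type == "Multiply" && mul_pred(mul);
}
```

A plausible reason for such a check is that fusing a Multiply with other consumers would require keeping the original Multiply alive for them.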

Note

Be careful with manual matching, because the Matcher object holds the matched nodes. To clear a match, use the m->clear_state() method.