Overview of Transformations API

This guide contains all the information you need to start implementing nGraph transformations.

Prerequisites

Before creating a transformation, do the following:

  • Make sure that there is no transformation with the same functionality in the Transformation Library

  • Learn how the Transformation Library is structured and how transformations are organized

  • Understand where to put your transformation code

Transformation Library Structure

The Transformation Library is a target library named inference_engine_transformations. It is independent of Inference Engine and is located in the inference-engine/src/transformations directory.

The transformations root directory contains two folders:

  • ngraph_ops - Contains internal opset operations that are common for plugins.

  • transformations - Includes all transformations, utils, runtime info attributes, and pass managers.

All internal operations and transformations located inside the Transformation Library can be used inside plugins. All legacy operations and transformations were moved to a legacy library and are not recommended to be used.

Transformation Flow Layers

Transformation flow in the transformation library has several layers:

  1. Pass managers - Execute any type of transformations and provide additional debug capabilities.

  2. Transformations - Perform a particular transformation algorithm on ngraph::Function.

  3. Low-level functions - Take a set of nodes and perform some transformation action. They are not mandatory and all transformation code can be located inside the transformation. But if some transformation parts can potentially be reused in other transformations, we suggest keeping them as separate functions.

Location for Your Transformation Code

To decide where to store your transformation code, please follow these rules:

  1. If it is a plugin-specific transformation and cannot be reused by other plugins, keep the source code inside the plugin.

  2. If this transformation relates to opset operation conversion or optimization, keep the sources inside the Transformation Library.

After you decide where to store your transformation code, you can start developing your own nGraph transformation.

ngraph::Function and graph representation

An nGraph function is a simple structure: it stores shared pointers to the ngraph::op::Parameter, ngraph::op::Result, and ngraph::op::Sink operations that are the inputs, outputs, and sinks of the graph. Sinks have no consumers and are not included in the results vector. All other operations hold each other via shared pointers: a child operation holds its parent (a hard link). If an operation has no consumers and is not a Result or Sink operation, its shared pointer counter drops to zero, so it is destroyed and is no longer accessible. Each operation in ngraph::Function has the std::shared_ptr<ngraph::Node> type.
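The ownership rule above can be illustrated with a minimal plain C++ model. This is not the nGraph API: ToyNode and ToyFunction are hypothetical names, and the sketch only assumes that a node keeps std::shared_ptr links to its producers while the function keeps only its results.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy node (not ngraph::Node): a child holds strong
// std::shared_ptr links to its parents, i.e. the "hard link" rule.
struct ToyNode {
    std::string name;
    std::vector<std::shared_ptr<ToyNode>> inputs;  // producers (parents)

    explicit ToyNode(std::string n, std::vector<std::shared_ptr<ToyNode>> in = {})
        : name(std::move(n)), inputs(std::move(in)) {}
};

// Hypothetical toy function: like ngraph::Function, it owns only its
// results; every other node stays alive through the inputs vectors.
struct ToyFunction {
    std::vector<std::shared_ptr<ToyNode>> results;
};
```

In this sketch, clearing results releases the last strong reference to the chain, so every node in it is destroyed; an operation detached from all consumers in a real ngraph::Function disappears for the same reason.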

For examples of how to build an nGraph function, see the Build nGraph Function page.

Transformation types

nGraph has three main transformation types:

_images/transformations_structure.png

ngraph::pass::FunctionPass

ngraph::pass::FunctionPass is used for transformations that take the entire ngraph::Function as input and process it.

Template for FunctionPass transformation class

// template_function_transformation.hpp
class ngraph::pass::MyFunctionTransformation : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
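// template_function_transformation.cpp
#include <iostream>  // std::cout

using namespace ngraph;  // lets the code below use unqualified names such as pass::, NodeVector, Input<Node>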
NGRAPH_RTTI_DEFINITION(ngraph::pass::MyFunctionTransformation, "MyFunctionTransformation", 0);

bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
    RUN_ON_FUNCTION_SCOPE(MyFunctionTransformation);
    // Example transformation code
    NodeVector nodes;

    // Traverse nGraph Function in topological order
    for (auto& node : f->get_ordered_ops()) {
        // Check that number of input and output ports are equal to 1
        if (node->inputs().size() == 1 && node->outputs().size() == 1) {
            // Check that input and output shapes are fully defined (not dynamic) and the number of consumers equals 1
            Input<Node> input = node->input(0);
            Output<Node> output = node->output(0);
            if (input.get_partial_shape().is_static() && output.get_partial_shape().is_static() && output.get_target_inputs().size() == 1) {
                nodes.push_back(node);
            }
        }
    }

    // Print types and names for collected nodes
    for (auto& node : nodes) {
        std::cout << "Type: " << node->get_type_info().name << std::endl << "Name: " << node->get_friendly_name() << std::endl;
    }

    // Return false because we didn't change nGraph Function
    return false;
}

When using ngraph::pass::FunctionPass, you need to override the run_on_function method, where you write the transformation code. The return value is true if the original function has changed during the transformation (a new operation was added, an operation was replaced, or node attributes were changed); otherwise, it is false. For the transformation API, refer to the Working with ngraph::Function section. ngraph::pass::FunctionPass-based transformations can also be executed via pass::Manager; see the examples in the Using pass manager section.
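The run_on_function contract (visit operations in topological order, return true only when the function changed) can be sketched in plain C++ under simplifying assumptions. The toy graph, ordered_ops, and rename_old_ops are hypothetical helpers, not nGraph calls, and the "OldOp"/"NewOp" type tags stand in for a real rewrite.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical toy graph node, not the nGraph API.
struct ToyNode {
    std::string type;
    std::vector<std::shared_ptr<ToyNode>> inputs;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Producers-before-consumers order (depth-first post-order from the
// results), similar in spirit to Function::get_ordered_ops().
std::vector<ToyNodePtr> ordered_ops(const std::vector<ToyNodePtr>& results) {
    std::vector<ToyNodePtr> order;
    std::unordered_set<const ToyNode*> visited;
    std::function<void(const ToyNodePtr&)> visit = [&](const ToyNodePtr& n) {
        if (!visited.insert(n.get()).second) return;  // already placed
        for (const auto& in : n->inputs) visit(in);   // parents first
        order.push_back(n);
    };
    for (const auto& r : results) visit(r);
    return order;
}

// Toy "function pass" following the run_on_function contract:
// returns true only if it actually modified the graph.
bool rename_old_ops(const std::vector<ToyNodePtr>& results) {
    bool changed = false;
    for (const auto& node : ordered_ops(results)) {
        if (node->type == "OldOp") {
            node->type = "NewOp";  // stand-in for a real rewrite
            changed = true;
        }
    }
    return changed;
}
```

Running the toy pass a second time returns false, because nothing is left to change; a real pass reports "no change" in the same way.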

ngraph::pass::MatcherPass

ngraph::pass::MatcherPass is used for pattern-based transformations.

Template for MatcherPass transformation class

// transformations/template_pattern_transformation.hpp
/**
 * @ingroup ie_transformation_common_api
 * @brief Add transformation description.
 */
class ngraph::pass::DecomposeDivideMatcher : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    DecomposeDivideMatcher();
};
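// template_pattern_transformation.cpp
using namespace ngraph;  // lets the code below use unqualified names such as pattern:: and opset3::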
NGRAPH_RTTI_DEFINITION(ngraph::pass::DecomposeDivideMatcher, "DecomposeDivideMatcher", 0);

ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
    MATCHER_SCOPE(DecomposeDivideMatcher);
    // Pattern example
    auto input0 = pattern::any_input();
    auto input1 = pattern::any_input();
    auto div = std::make_shared<ngraph::opset3::Divide>(input0, input1);

    ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
        auto div = std::dynamic_pointer_cast<ngraph::opset3::Divide>(m.get_match_root());
        // This transformation cannot be applied to integer input data types
        if (!div || div->input(0).get_element_type().is_integral()) {
            return false;
        }

        // Decompose Divide into Multiply with Power operations
        auto pow = std::make_shared<ngraph::opset3::Power>(div->input_value(1), opset3::Constant::create(div->get_input_element_type(1), Shape {1}, {-1}));

        auto mul = std::make_shared<ngraph::opset3::Multiply>(div->input_value(0), pow);

        // Save original name to last operation in replacement sub-graph
        mul->set_friendly_name(div->get_friendly_name());

        // Copy runtime info attributes to newly created operation
        ngraph::copy_runtime_info(div, {pow, mul});

        // Replace Divide operation with Multiply
        ngraph::replace_node(div, mul);

        // Return true as the root node was changed
        return true;
    };

    // Register pattern with Divide operation as a pattern root node
    auto m = std::make_shared<ngraph::pattern::Matcher>(div, matcher_name);  // matcher_name is created by MATCHER_SCOPE
    // Register Matcher
    register_matcher(m, callback);
}

To use ngraph::pass::MatcherPass, you need to complete these steps:

  1. Create a pattern

  2. Implement a callback

  3. Register the pattern and Matcher

  4. Execute MatcherPass

So let’s go through each of these steps.

Create a pattern

A pattern is a single-root ngraph::Function. The only difference is that you do not need to create a function object: you just create and connect opset or special pattern operations, then take the last created operation and make it the root of the pattern. This root node is used as the root node in pattern matching.

Note

Any nodes in a pattern that have no consumers and are not registered as root will not be used in pattern matching.

// Pattern example
auto input = std::make_shared<ngraph::opset3::Parameter>(element::i64, Shape{1});
auto shapeof = std::make_shared<ngraph::opset3::ShapeOf>(input);

// Create Matcher with Parameter->ShapeOf pattern
auto m = std::make_shared<ngraph::pattern::Matcher>(shapeof, "MyPatternBasedTransformation");

The Parameter operation in the example above has its type and shape specified. These attributes are needed only to create the Parameter operation and are not used in pattern matching.

For more pattern examples, refer to the pattern matching section.

Implement callback

A callback is an action applied to every occurrence of the pattern. In general, a callback is a lambda function that takes the Matcher object with the detected subgraph.

ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
    // Get root node
    std::shared_ptr<Node> root_node = m.get_match_root();

    // Get all nodes matched by pattern
    NodeVector nodes = m.get_matched_nodes();

    // Transformation code
    return false;
};

The example above shows the callback structure and how the Matcher can be used to access the nodes detected by the pattern. The callback returns true if the root node was replaced and another pattern cannot be applied to the same root node; otherwise, it returns false.

Note

Avoid manipulating nodes that come after the root node in topological order. Doing so may break GraphRewrite execution, which expects all nodes that come after the root node in topological order to be valid and usable in pattern matching.

MatcherPass can also report newly created nodes so that they take part in additional pattern matching. If the MatcherPass was registered in pass::Manager or pass::GraphRewrite, these registered nodes are added for additional pattern matching, which means that the matcher passes registered in pass::GraphRewrite will also be applied to them.

The example below shows how a single MatcherPass can fuse a sequence of operations using the register_new_node method.

NGRAPH_RTTI_DEFINITION(ngraph::pass::ReluReluFusionMatcher, "ReluReluFusionMatcher", 0);

ngraph::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
    MATCHER_SCOPE(ReluReluFusionMatcher);
    auto m_relu1 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>(pattern::consumers_count(1));
    auto m_relu2 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>({m_relu1});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        // Map that helps to connect labels with matched outputs
        auto& node_to_output = m.get_pattern_value_map();

        // Create a new Relu operation and register it for additional execution
        auto new_relu = register_new_node<ngraph::opset3::Relu>(node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));

        // Copy runtime info attributes to newly created operation
        ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu);

        // Save last Relu name to new Relu operation
        new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());

        // Replace Relu->Relu with Relu
        ngraph::replace_node(m.get_match_root(), new_relu);

        // Return true as the root node was changed
        return true;
    };

    // Register pattern with Relu operation as a pattern root node
    auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, matcher_name);  // matcher_name is created by MATCHER_SCOPE
    // Register Matcher
    register_matcher(m, callback);
}

Note

If you register multiple nodes, add them in topological order. These nodes are not sorted topologically because sorting is a time-consuming operation.

Register pattern and Matcher

The last step is to register Matcher and callback inside the MatcherPass pass. To do this, call the register_matcher method.

Note

Only one matcher can be registered for a single MatcherPass class.

// Register matcher and callback
register_matcher(m, callback);

Execute MatcherPass

MatcherPass has multiple ways to be executed:

  • Run on a single node - it can be useful if you want to run MatcherPass inside another transformation.

    if (ngraph::pass::DecomposeDivideMatcher().apply(node)) {
        // successful execution (root node was replaced)
    }
  • Run on ngraph::Function using GraphRewrite - this approach runs a MatcherPass on the whole ngraph::Function. Moreover, multiple MatcherPass transformations can be registered in a single GraphRewrite to be executed in a single graph traversal.

    // Two matcher passes will run simultaneously in a single graph traversal
    ngraph::pass::GraphRewrite pass;
    pass.add_matcher<ngraph::pass::DecomposeDivideMatcher>();
    pass.add_matcher<ngraph::pass::ReluReluFusionMatcher>();
    pass.run_on_function(f);
  • Run on ngraph::Function using pass::Manager - this approach lets you register a MatcherPass for execution on ngraph::Function like any other transformation type.

    // Two matchers will run independently (two independent graph traversals)
    // pass::Manager automatically creates GraphRewrite container for each MatcherPass
    pass::Manager manager;
    manager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
    manager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
    manager.run_passes(f);

ngraph::pass::GraphRewrite

The GraphRewrite pass runs multiple matcher passes on ngraph::Function in a single graph traversal. Example:

// Two matcher passes will run simultaneously in a single graph traversal
ngraph::pass::GraphRewrite pass;
pass.add_matcher<ngraph::pass::DecomposeDivideMatcher>();
pass.add_matcher<ngraph::pass::ReluReluFusionMatcher>();
pass.run_on_function(f);

In addition, GraphRewrite handles nodes that were registered by MatcherPasses during their execution. These nodes are added to the beginning of the sequence of nodes for pattern matching.

Note

When using pass::Manager, a temporary GraphRewrite is used to execute a single MatcherPass.

GraphRewrite has two algorithms for executing MatcherPasses. The first algorithm is straightforward: it applies each MatcherPass to the current node in registration order.

_images/graph_rewrite_execution.png

This is not efficient when many passes are registered. So GraphRewrite first checks whether all MatcherPass patterns have a type-based root node (that is, the type of the root node is not hidden inside a predicate). If they do, it creates a map from root node types to the registered MatcherPasses, which avoids the additional cost of applying each MatcherPass to each node.

_images/graph_rewrite_efficient_search.png

Note

GraphRewrite execution algorithm cannot be set manually and depends only on root nodes registered inside MatcherPasses.
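The two GraphRewrite strategies can be modeled with a small dispatcher. This is a hypothetical sketch (ToyGraphRewrite, ToyMatcher, and an empty root_type meaning "root type hidden in a predicate" are all assumptions), not the real implementation; it only shows why a type-keyed map saves work compared with trying every matcher on every node.

```cpp
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical toy node: only the type matters for dispatch here.
struct ToyNode { std::string type; };

// Hypothetical toy matcher: root_type is empty when the root type is
// "hidden" in a predicate; callback returns true if it rewrote the node.
struct ToyMatcher {
    std::string root_type;
    std::function<bool(ToyNode&)> callback;
};

class ToyGraphRewrite {
public:
    void add_matcher(ToyMatcher m) {
        if (m.root_type.empty()) all_type_based_ = false;
        matchers_.push_back(std::move(m));
    }

    // nodes are assumed to be in topological order; returns true on any change.
    bool run(std::vector<ToyNode>& nodes) const {
        std::unordered_map<std::string, std::vector<const ToyMatcher*>> by_type;
        if (all_type_based_)
            for (const auto& m : matchers_) by_type[m.root_type].push_back(&m);

        bool changed = false;
        for (auto& node : nodes) {
            if (all_type_based_) {
                // Map-based path: try only matchers whose root type fits.
                auto it = by_type.find(node.type);
                if (it == by_type.end()) continue;
                for (const ToyMatcher* m : it->second)
                    if (m->callback(node)) { changed = true; break; }  // root replaced
            } else {
                // Straightforward path: every matcher, in registration order.
                for (const auto& m : matchers_)
                    if (m.callback(node)) { changed = true; break; }  // root replaced
            }
        }
        return changed;
    }

private:
    std::vector<ToyMatcher> matchers_;
    bool all_type_based_ = true;
};
```

In the map-based path, a node whose type no matcher is rooted at costs a single lookup instead of one callback per registered matcher.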

Pattern Matching

Sometimes patterns cannot be expressed via regular nGraph operations, or doing so is too complicated. For example, you may want to detect a Convolution->Add subgraph without specifying a particular input type for the Convolution operation, or create a pattern where some operations can have different types. For these cases, nGraph provides additional helpers to construct patterns for GraphRewrite transformations.

There are two main helpers:

  1. ngraph::pattern::any_input - helps to express inputs if their types are undefined.

  2. ngraph::pattern::wrap_type<T> - helps to express nodes of pattern without specifying node attributes.

Let's go through an example to better understand how it works:

Note

Node attributes do not participate in pattern matching and are needed only for operations creation. Only operation types participate in pattern matching.

The example below shows basic usage of pattern::any_input. Here we construct a Multiply pattern with an arbitrary first input and a Constant as the second input. Because Multiply is a commutative operation, the order in which we set the inputs (any_input/Constant or Constant/any_input) does not matter: both cases will be matched.

// Detect Multiply with arbitrary first input and second as Constant
// ngraph::pattern::op::Label - represent arbitrary input
auto input = ngraph::pattern::any_input();
auto value = ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0.5});
auto mul = std::make_shared<ngraph::opset3::Multiply>(input, value);
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MultiplyMatcher");

This example shows how to construct a pattern for an operation with an arbitrary number of inputs.

// Detect Concat operation with arbitrary number of inputs
auto concat = ngraph::pattern::wrap_type<ngraph::opset3::Concat>();
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMatcher");

This example shows how to use a predicate to construct a pattern. It also shows how to match a pattern manually on a given node.

// Detect Multiply->Add sequence where mul has exactly one consumer
auto mul = ngraph::pattern::wrap_type<ngraph::opset3::Multiply>(ngraph::pattern::consumers_count(1) /* check consumers count */);
auto add = ngraph::pattern::wrap_type<ngraph::opset3::Add>({mul, ngraph::pattern::any_input()});
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "MultiplyAddMatcher");
// Matcher can be used to match pattern manually on given node
if (m->match(node->output(0))) {
    // Successfully matched
}

Note

Be careful with manual matching because Matcher object holds matched nodes. To clear a match, use the m->clear_state() method.
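To see how any_input, wrap_type, a consumers-count predicate, and commutative matching fit together, here is a toy recursive matcher. It is a simplified, hypothetical model (ToyNode, ToyPattern, and these free functions are not the ngraph::pattern API): types are plain strings, consumer counts are stored by hand, and commutativity is hard-coded for Add and Multiply.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy graph node; consumers is maintained by hand here.
struct ToyNode {
    std::string type;
    std::vector<std::shared_ptr<ToyNode>> inputs;
    int consumers = 0;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Hypothetical pattern node: empty type = wildcard (like any_input),
// empty inputs = inputs not checked (like wrap_type<T>() without inputs).
struct ToyPattern {
    std::string type;
    std::vector<std::shared_ptr<ToyPattern>> inputs;
    std::function<bool(const ToyNode&)> predicate;  // optional extra check
};
using ToyPatternPtr = std::shared_ptr<ToyPattern>;

ToyPatternPtr any_input() { return std::make_shared<ToyPattern>(); }

ToyPatternPtr wrap_type(std::string type, std::vector<ToyPatternPtr> inputs = {},
                        std::function<bool(const ToyNode&)> pred = nullptr) {
    return std::make_shared<ToyPattern>(ToyPattern{std::move(type), std::move(inputs), std::move(pred)});
}

bool match(const ToyPatternPtr& p, const ToyNodePtr& n) {
    if (p->type.empty()) return true;                     // wildcard matches anything
    if (p->type != n->type) return false;                 // only the type is compared
    if (p->predicate && !p->predicate(*n)) return false;  // e.g. consumers count
    if (p->inputs.empty()) return true;                   // inputs left unconstrained
    if (p->inputs.size() != n->inputs.size()) return false;

    bool in_order = true;
    for (std::size_t i = 0; i < p->inputs.size() && in_order; ++i)
        in_order = match(p->inputs[i], n->inputs[i]);
    if (in_order) return true;

    // Commutative binary ops (hard-coded in this sketch): try swapped inputs.
    const bool commutative = n->type == "Add" || n->type == "Multiply";
    return commutative && n->inputs.size() == 2 &&
           match(p->inputs[0], n->inputs[1]) && match(p->inputs[1], n->inputs[0]);
}
```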

Working with ngraph::Function

In this chapter, we review the nGraph API that allows us to manipulate ngraph::Function.

ngraph::Node input and output ports

First of all, let's talk about ngraph::Node input/output ports. Each nGraph operation has input and output ports, except for operations of the Result, Parameter, or Constant type.

Every port belongs to its node, so through a port we can access the parent node, get the shape and type of a particular input/output, get all consumers of an output port, and get the producer node of an input port. An output port can also be used to set inputs for newly created operations.

Let's look at the code example.

// Let's suppose that node is opset3::Convolution operation
// as we know opset3::Convolution has two input ports (data, weights) and one output port
Input <Node> data = node->input(0);
Input <Node> weights = node->input(1);
Output <Node> output = node->output(0);

// Getting shape and type
auto pshape = data.get_partial_shape();
auto el_type = data.get_element_type();

// Getting parent for input port
Output <Node> parent_output;
parent_output = data.get_source_output();

// A shorter way to get the parent output of an input port
parent_output = node->input_value(0);

// Getting all consumers for output port
auto consumers = output.get_target_inputs();

You may notice that we usually construct operations in this way:

std::shared_ptr<Node> neg_const = opset1::Constant::create(node->get_input_element_type(0), Shape{1}, {-1});
Output<Node> data = node->input_value(0);
auto neg = std::make_shared<ngraph::opset1::Multiply>(data, neg_const);

In this example, the opset1::Multiply operation takes Output<Node> and std::shared_ptr<Node> as inputs, but its constructor takes both as Output<Node>. In this case, std::shared_ptr<Node> is automatically converted to Output<Node> if the node has exactly one output port; otherwise, the conversion raises an exception.
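The automatic std::shared_ptr<Node> to Output<Node> conversion can be mimicked with an implicit converting constructor. In this hypothetical sketch, ToyNode, ToyOutput, and describe_input are made-up names; only the rule itself (exactly one output port, otherwise an exception) comes from the text above.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical toy types, not ngraph::Node / ngraph::Output<Node>.
struct ToyNode {
    std::string name;
    std::size_t output_count = 1;
};

// A toy output port: (node, output index).
struct ToyOutput {
    std::shared_ptr<ToyNode> node;
    std::size_t index = 0;

    // Explicit port selection, like node->output(i).
    ToyOutput(std::shared_ptr<ToyNode> n, std::size_t i) : node(std::move(n)), index(i) {
        if (index >= node->output_count) throw std::out_of_range("no such output port");
    }
    // Implicit on purpose, to imitate shared_ptr<Node> -> Output<Node>:
    // only allowed when the node has exactly one output.
    ToyOutput(std::shared_ptr<ToyNode> n) : node(std::move(n)) {
        if (node->output_count != 1)
            throw std::invalid_argument(node->name + " has " + std::to_string(node->output_count) +
                                        " outputs; select a port explicitly");
    }
};

// A toy consumer that accepts ports, the way opset constructors accept Output<Node>.
std::string describe_input(const ToyOutput& out) {
    return out.node->name + ":" + std::to_string(out.index);
}
```

This is also why the guide recommends passing an explicit output port whenever the producer may have several outputs.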

ngraph::Node replacement

nGraph provides two ways to replace a node: via an nGraph helper function and directly via port methods. We review both of them.

Let's start with the nGraph helper functions. The most popular one is ngraph::replace_node(old_node, new_node).

Let's review a real replacement case, where a Negative operation is replaced with Multiply.

_images/ngraph_replace_node.png
bool ngraph_replace_node(std::shared_ptr<Node> node) {
    // Step 1. Verify that node has opset3::Negative type
    auto neg = std::dynamic_pointer_cast<ngraph::opset3::Negative>(node);
    if (!neg) {
        return false;
    }

    // Step 2. Create opset3::Multiply operation where the first input is negative operation input and second as Constant with -1 value
    auto mul = std::make_shared<ngraph::opset3::Multiply>(neg->input_value(0),
                                                          opset3::Constant::create(neg->get_element_type(), Shape{1}, {-1}));

    mul->set_friendly_name(neg->get_friendly_name());
    ngraph::copy_runtime_info(neg, mul);

    // Step 3. Replace Negative operation with Multiply operation
    ngraph::replace_node(neg, mul);

    // Step 4. Negative operation will be removed automatically because all its consumers were moved to Multiply operation
    return true;
}

ngraph::replace_node requires both operations to have the same number of output ports; otherwise, it raises an exception.

An alternative way to do the same replacement is the following:

// All neg->output(0) consumers will be moved to mul->output(0) port
neg->output(0).replace(mul->output(0));
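What replace_node does to the graph can be sketched as "check that the output counts match, then reconnect every consumer of the old node to the same output index of the new node". The toy types and toy_replace_node below are hypothetical, and consumers are found by scanning a node list because the sketch keeps no back-links.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical toy graph, not the nGraph API.
struct ToyNode;
struct ToyPort {  // producer node + output index, like Output<Node>
    std::shared_ptr<ToyNode> node;
    std::size_t index;
};
struct ToyNode {
    std::string type;
    std::size_t output_count;
    std::vector<ToyPort> inputs;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Sketch of replace_node semantics: output counts must match, and every
// consumer of old_node's output i is reconnected to new_node's output i.
void toy_replace_node(const ToyNodePtr& old_node, const ToyNodePtr& new_node,
                      const std::vector<ToyNodePtr>& graph) {
    if (old_node->output_count != new_node->output_count)
        throw std::invalid_argument("replace_node: output port counts differ");
    for (const auto& consumer : graph)
        for (auto& in : consumer->inputs)
            if (in.node == old_node) in.node = new_node;  // port index is kept
}
```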

Another transformation example is insertion.

_images/ngraph_insert_node.png
// Step 1. Let's suppose that we have a node with a single output port and we want to insert an additional operation new_node after it
void insert_example(std::shared_ptr<ngraph::Node> node) {
    // Get all consumers for node
    auto consumers = node->output(0).get_target_inputs();

    // Step 2. Create a new node. Let it be opset3::Relu.
    auto new_node = std::make_shared<ngraph::opset3::Relu>(node);

    // Step 3. Reconnect all consumers to new_node
    for (auto input : consumers) {
        input.replace_source_output(new_node);
    }
}

An alternative way to insert an operation is to make a copy of the node and use replace_node:

void insert_example_with_copy(std::shared_ptr<ngraph::Node> node) {
    // Make a node copy
    auto node_copy = node->clone_with_new_inputs(node->input_values());
    // Create new node
    auto new_node = std::make_shared<ngraph::opset3::Relu>(node_copy);
    ngraph::replace_node(node, new_node);
}
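The insertion steps above depend on collecting the consumers before creating the new node: the new node itself reads from node, so collecting afterwards would make it one of the consumers to reconnect. A hypothetical toy version (consumers_of and insert_after are not nGraph functions) shows the ordering:

```cpp
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy graph (single-output nodes), not the nGraph API.
struct ToyNode {
    std::string type;
    std::vector<std::shared_ptr<ToyNode>> inputs;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// All nodes in `graph` that read from `node`.
std::vector<ToyNodePtr> consumers_of(const ToyNodePtr& node, const std::vector<ToyNodePtr>& graph) {
    std::vector<ToyNodePtr> out;
    for (const auto& n : graph)
        if (std::find(n->inputs.begin(), n->inputs.end(), node) != n->inputs.end())
            out.push_back(n);
    return out;
}

// Insert a new node of `type` right after `node`.
ToyNodePtr insert_after(const ToyNodePtr& node, const std::string& type, std::vector<ToyNodePtr>& graph) {
    auto consumers = consumers_of(node, graph);                        // Step 1: before new_node exists
    auto new_node = std::make_shared<ToyNode>(ToyNode{type, {node}});  // Step 2: new_node reads node
    graph.push_back(new_node);
    for (const auto& c : consumers)                                    // Step 3: reconnect old consumers
        std::replace(c->inputs.begin(), c->inputs.end(), node, new_node);
    return new_node;
}
```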

ngraph::Node elimination

Another type of node replacement is node elimination.

To eliminate an operation, nGraph has a special method that considers all limitations related to Inference Engine.

// Suppose we have a node that we want to remove
bool success = replace_output_update_name(node->output(0), node->input_value(0));

If the replacement succeeds, replace_output_update_name automatically preserves the friendly name and runtime info.

Transformation conditional compilation

The Transformation Library has two internal macros to support the conditional compilation feature.

  • MATCHER_SCOPE(region) - allows disabling the MatcherPass if the matcher is not used. The region name should be unique. This macro creates a local variable matcher_name, which you should use as the matcher name.

  • RUN_ON_FUNCTION_SCOPE(region) - allows disabling the run_on_function pass if it is not used. The region name should be unique.

Transformation writing essentials

When developing a transformation, you need to follow these transformation rules:

1. Operation Set (OpSet)

Use the latest version of OpSet in your transformation. An exception is op_conversion transformations, where different opsets can be used.

#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset3.hpp>

2. Dynamic Shape and Rank

nGraph has two types for shape representation: ngraph::Shape represents a static shape, and ngraph::PartialShape represents a dynamic shape, meaning that the rank or some of the dimensions are dynamic (undefined). ngraph::PartialShape can be converted to ngraph::Shape using the get_shape() method if all dimensions are static; otherwise, the conversion raises an exception.

auto partial_shape = node->input(0).get_partial_shape(); // get zero input partial shape
if (partial_shape.is_dynamic() /* or !partial_shape.is_static() */) {
    return false;
}
auto static_shape = partial_shape.get_shape();

In most cases, before getting a static shape using the get_shape() method, you need to check that the shape is static.

Also, if your transformation requires only the input shape rank or a particular dimension value, do not use the get_shape() method. The example below demonstrates how to avoid using get_shape():

auto partial_shape = node->input(0).get_partial_shape(); // get zero input partial shape

// Check that input shape rank is static
if (!partial_shape.rank().is_static()) {
    return false;
}
auto rank_size = partial_shape.rank().get_length();

// Check that second dimension is not dynamic
if (rank_size < 2 || partial_shape[1].is_dynamic()) {
    return false;
}
auto dim = partial_shape[1].get_length();

Avoiding the get_shape() method makes your transformation more flexible and applicable to more cases.
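The difference between a full get_shape() call and the rank/dimension checks can be modeled with std::optional. ToyPartialShape and channel_dim are hypothetical stand-ins for ngraph::PartialShape and the checks above, and the sketch assumes C++17.

```cpp
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical toy stand-ins for ngraph::Dimension / ngraph::PartialShape.
// A dimension is either known (a value) or dynamic (std::nullopt).
using ToyDim = std::optional<std::int64_t>;

struct ToyPartialShape {
    std::optional<std::vector<ToyDim>> dims;  // std::nullopt = dynamic rank

    bool rank_is_static() const { return dims.has_value(); }
    bool is_static() const {
        if (!dims) return false;
        for (const auto& d : *dims)
            if (!d) return false;
        return true;
    }
    // Like get_shape(): only valid when every dimension is known.
    std::vector<std::int64_t> get_shape() const {
        if (!is_static()) throw std::logic_error("shape is dynamic");
        std::vector<std::int64_t> s;
        for (const auto& d : *dims) s.push_back(*d);
        return s;
    }
};

// The flexible check from the text: it needs only a static rank and a known
// dimension 1, so it also accepts shapes such as {?, 3, ?, ?}.
bool channel_dim(const ToyPartialShape& ps, std::int64_t& out) {
    if (!ps.rank_is_static()) return false;
    const auto& dims = *ps.dims;
    if (dims.size() < 2 || !dims[1]) return false;
    out = *dims[1];
    return true;
}
```

For a shape like {?, 3, ?, ?}, channel_dim succeeds while get_shape() throws, which is the flexibility the guide describes.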

3. Friendly Names

Each ngraph::Node has a unique name (used for nGraph internals) and a friendly name. In transformations, we care only about the friendly name because it represents the name from the intermediate representation (IR). The friendly name is also used as the output tensor name (until there is another way to represent output tensor names), and user code requests intermediate outputs based on these names. To avoid losing the friendly name when replacing a node with another node or subgraph, set the original friendly name to the last node in the replacing subgraph. See the example below.

// Replace Div operation with Power and Multiply sub-graph and set original friendly name to Multiply operation
auto pow = std::make_shared<ngraph::opset1::Power>(div->input(1).get_source_output(),
                                                           op::Constant::create(div->get_input_element_type(1), Shape{1}, {-1}));
auto mul = std::make_shared<ngraph::opset1::Multiply>(div->input(0).get_source_output(), pow);
mul->set_friendly_name(div->get_friendly_name());
ngraph::replace_node(div, mul);

In more advanced cases, when a replaced operation has several outputs and we add additional consumers to its outputs, we decide how to set the friendly name by convention.

4. Runtime Info

Runtime info is a map std::map<std::string, std::shared_ptr<Variant>> located inside the ngraph::Node class. It represents additional attributes of ngraph::Node. These attributes can be set by users or by plugins, and when executing a transformation that changes ngraph::Function, we need to preserve them because they are not propagated automatically. In most cases, transformations have the following types: 1:1 (replace a node with another node), 1:N (replace a node with a subgraph), N:1 (fuse a subgraph into a single node), and N:M (any other transformation). Currently, there is no mechanism that automatically detects transformation types, so this runtime information must be propagated manually. See the examples below.

// Replace Transpose with Reshape operation (1:1)
ngraph::copy_runtime_info(transpose, reshape);
// Replace Div operation with Power and Multiply sub-graph (1:N)
ngraph::copy_runtime_info(div, {pow, mul});
// Fuse Convolution with Add operation (N:1)
ngraph::copy_runtime_info({conv, bias}, {conv_ie});
// Any other transformation that replaces one sub-graph with another sub-graph (N:M)
ngraph::copy_runtime_info({a, b, c}, {e, f});

When a transformation has multiple fusions or decompositions, ngraph::copy_runtime_info must be called once for each case.

Note : copy_runtime_info removes rt_info from the destination nodes. If you want to keep it, also list the destination nodes among the source nodes, like this: copy_runtime_info({a, b, c}, {a, b})
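The copy_runtime_info behavior described above (destination rt_info is replaced, so a node must also be listed as a source to keep its own attributes) can be sketched with a plain std::map. toy_copy_runtime_info is hypothetical, and its "later source wins" conflict rule is an assumption; real attribute merging may differ.

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy node with a runtime-info map; not the nGraph API.
struct ToyNode {
    std::string name;
    std::map<std::string, std::string> rt_info;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// 1. Merge the rt_info of all sources (later sources win on conflicts,
//    a simplification made only for this sketch).
// 2. Overwrite the rt_info of every destination with the merged map.
void toy_copy_runtime_info(const std::vector<ToyNodePtr>& from, const std::vector<ToyNodePtr>& to) {
    std::map<std::string, std::string> merged;  // built first, so a node may be both source and destination
    for (const auto& src : from)
        for (const auto& kv : src->rt_info) merged[kv.first] = kv.second;
    for (const auto& dst : to) dst->rt_info = merged;  // previous destination entries are dropped
}
```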

5. Constant Folding

If your transformation inserts constant subgraphs that need to be folded, do not forget to use ngraph::pass::ConstantFolding() after your transformation, or call constant folding directly for the operation. The example below shows how a constant subgraph can be constructed.

// After ConstantFolding pass Power will be replaced with Constant
auto pow = std::make_shared<ngraph::opset3::Power>(
                    opset3::Constant::create(element::f32, Shape{1}, {2}),
                    opset3::Constant::create(element::f32, Shape{1}, {3}));
auto mul = std::make_shared<ngraph::opset3::Multiply>(input /* not constant input */, pow);

Manual constant folding is preferable to ngraph::pass::ConstantFolding() because it is much faster.

Below you can find an example of manual constant folding:

template <class T>
Output<Node> eltwise_fold(const Output<Node>& input0, const Output<Node>& input1) {
    auto eltwise = std::make_shared<T>(input0, input1);
    OutputVector output(eltwise->get_output_size());
    // If constant folding wasn't successful return eltwise output
    if (!eltwise->constant_fold(output, {input0, input1})) {
        return eltwise->output(0);
    }
    return output[0];
}
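The fallback logic of eltwise_fold (return the folded value when possible, otherwise keep the operation) can be reproduced in a self-contained toy IR. ToyNode, toy_eltwise_fold, and toy_power are hypothetical; unlike real nGraph element-wise operations, this sketch does not broadcast and folds only equal-sized constants.

```cpp
#include <cmath>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: a node is either a constant (values set) or an
// operation over inputs. Not the nGraph API.
struct ToyNode {
    std::string type;                              // "Constant", "Power", "Multiply", ...
    std::vector<std::shared_ptr<ToyNode>> inputs;
    std::vector<double> values;                    // used only by "Constant"
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Build an element-wise binary node and fold it immediately when both inputs
// are constants of equal size; otherwise return the unfolded node.
ToyNodePtr toy_eltwise_fold(const std::string& type, ToyNodePtr a, ToyNodePtr b,
                            const std::function<double(double, double)>& op) {
    auto node = std::make_shared<ToyNode>(ToyNode{type, {a, b}, {}});
    if (a->type != "Constant" || b->type != "Constant" || a->values.size() != b->values.size())
        return node;  // not foldable: keep the operation in the graph
    std::vector<double> out(a->values.size());
    for (std::size_t i = 0; i < out.size(); ++i) out[i] = op(a->values[i], b->values[i]);
    return std::make_shared<ToyNode>(ToyNode{"Constant", {}, out});
}

// The Power(2, 3) example from the text folds to a Constant holding 8.
ToyNodePtr toy_power(ToyNodePtr base, ToyNodePtr exponent) {
    return toy_eltwise_fold("Power", std::move(base), std::move(exponent),
                            [](double x, double y) { return std::pow(x, y); });
}
```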

Common mistakes in transformations

In transformation development process:

  • Do not use the deprecated nGraph API. Deprecated methods have the NGRAPH_DEPRECATED macro in their definitions.

  • Do not pass shared_ptr<Node> as an input for another node if the type of the node is unknown or if it has multiple outputs. Use an explicit output port instead.

  • If you replace a node with another node that produces a different shape, remember that the new shape will not be propagated until the first validate_nodes_and_infer_types call for ngraph::Function. If you are using pass::Manager, it automatically calls this method after each transformation execution.

  • Do not forget to call the ngraph::pass::ConstantFolding pass if your transformation creates constant subgraphs.

  • Use the latest OpSet if you are not developing a downgrade transformation pass.

  • When developing a callback for ngraph::pass::MatcherPass, do not change nodes that come after the root node in topological order.

Using pass manager

ngraph::pass::Manager is a container class that can store a list of transformations and execute them. The main idea of this class is to provide a high-level representation for a grouped list of transformations. It can register and apply any transformation type on a function. In addition, ngraph::pass::Manager has extended debug capabilities (find more information in the How to debug transformations section).

The example below shows basic usage of ngraph::pass::Manager:

pass::Manager manager;
manager.register_pass<ngraph::pass::MyFunctionTransformation>();
// Two matchers will run independently (two independent graph traversals)
// pass::Manager automatically creates GraphRewrite container for each MatcherPass
manager.register_pass<ngraph::pass::DecomposeDivideMatcher>();
manager.register_pass<ngraph::pass::ReluReluFusionMatcher>();
manager.run_passes(f);

Another example shows how multiple matcher passes can be united into a single GraphRewrite.

// Register anchor GraphRewrite pass inside manager that will execute two matchers simultaneously
pass::Manager manager;
auto anchor = manager.register_pass<ngraph::pass::GraphRewrite>();
anchor->add_matcher<ngraph::pass::DecomposeDivideMatcher>();
anchor->add_matcher<ngraph::pass::ReluReluFusionMatcher>();
manager.run_passes(f);
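The register-then-run flow of pass::Manager can be sketched as follows. ToyManager, ToyPass, and RenameOp are hypothetical names; the sketch only assumes that passes run in registration order, that register_pass returns the created pass so it can be configured further (as with the GraphRewrite anchor above), and that each pass reports whether it changed the function.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of a pass manager; not ngraph::pass::Manager.
struct ToyFunction { std::vector<std::string> ops; };

struct ToyPass {
    virtual ~ToyPass() = default;
    virtual bool run(ToyFunction& f) = 0;  // true if f was changed
};

class ToyManager {
public:
    // Constructs T, stores it, and returns it for further configuration.
    template <class T, class... Args>
    std::shared_ptr<T> register_pass(Args&&... args) {
        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
        passes_.push_back(pass);
        return pass;
    }
    // Runs passes in registration order; returns true if any pass changed f.
    bool run_passes(ToyFunction& f) {
        bool changed = false;
        for (auto& p : passes_) changed = p->run(f) || changed;  // run() always executes
        return changed;
    }

private:
    std::vector<std::shared_ptr<ToyPass>> passes_;
};

// Hypothetical example pass: renames every op of one type to another type.
struct RenameOp : ToyPass {
    std::string from, to;
    RenameOp(std::string f, std::string t) : from(std::move(f)), to(std::move(t)) {}
    bool run(ToyFunction& fn) override {
        bool changed = false;
        for (auto& op : fn.ops)
            if (op == from) { op = to; changed = true; }
        return changed;
    }
};
```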

Note

nGraph used to have the pass::PassConfig class for transformation pipeline manipulation.

This mechanism is now obsolete, and the pass::PassConfig class will be removed in a future release.

How to debug transformations

The most popular tool for transformations debugging is the ngraph::pass::VisualizeTree transformation, which visualizes ngraph::Function.

Usage example:

void visualization_example(std::shared_ptr<ngraph::Function> f) {
    ngraph::pass::Manager manager;

    // Serialize ngraph::Function to before.svg file before transformation
    manager.register_pass<ngraph::pass::VisualizeTree>("/path/to/file/before.svg");

    // Run your transformation
    // manager.register_pass<ngraph::pass::MyTransformation>();

    // Serialize ngraph::Function to after.svg file after transformation
    manager.register_pass<ngraph::pass::VisualizeTree>("/path/to/file/after.svg");

    manager.run_passes(f);
}

ngraph::pass::VisualizeTree can be parametrized via environment variables:

NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES=1 - visualize shapes
NGRAPH_VISUALIZE_TREE_OUTPUT_TYPES=1  - visualize types

Note: The current VisualizeTree does not have a user-friendly interface, and it will be changed in the near future. The intention is to move visualization capabilities inside transformations.
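The dot files mentioned in this section use the Graphviz DOT language. As a rough illustration of what such a file contains, the following standard C++ sketch emits DOT text for a toy graph. ToyNode and to_dot are hypothetical, and the output is not the exact format produced by VisualizeTree. The show_shapes flag only imitates the idea behind NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES=1 by adding each node's shape to its label.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy node: op name, printable shape, and indices of its inputs.
struct ToyNode {
    std::string name;
    std::string shape;
    std::vector<std::size_t> inputs;
};

// Emits Graphviz DOT text: one DOT node per toy node and one edge per input.
// With show_shapes == true the shape is appended to each label.
std::string to_dot(const std::vector<ToyNode>& nodes, bool show_shapes) {
    std::ostringstream out;
    out << "digraph g {\n";
    for (std::size_t i = 0; i < nodes.size(); ++i) {
        out << "  n" << i << " [label=\"" << nodes[i].name;
        if (show_shapes) {
            // A backslash-n escape inside a DOT label starts a new line.
            out << "\\n" << nodes[i].shape;
        }
        out << "\"];\n";
    }
    for (std::size_t i = 0; i < nodes.size(); ++i) {
        for (std::size_t input : nodes[i].inputs) {
            out << "  n" << input << " -> n" << i << ";\n";
        }
    }
    out << "}\n";
    return out.str();
}
```

A file containing this text can be rendered with Graphviz, for example: dot -Tsvg graph.dot -o graph.svg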

If you use ngraph::pass::Manager to run a sequence of transformations, you can get additional debug capabilities by setting the following environment variables:

NGRAPH_PROFILE_PASS_ENABLE=1 - enables performance measurement for each transformation and prints execution status
NGRAPH_ENABLE_VISUALIZE_TRACING=1 - enables visualization after each transformation. By default, it saves dot and svg files.

Note: Make sure that dot (from Graphviz) is installed on your machine; otherwise, only the dot file is saved and the svg file is silently skipped.
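The idea behind NGRAPH_PROFILE_PASS_ENABLE=1 can be illustrated with a small standard C++ sketch that times each pass when a flag is set. This is not how nGraph implements profiling: env_flag_enabled, NamedPass, and run_with_optional_profiling are hypothetical, and treating only the exact value "1" as enabled is an assumption of the sketch.

```cpp
#include <cassert>
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Returns true when the environment variable is set to exactly "1"
// (an assumption of this sketch).
bool env_flag_enabled(const char* name) {
    const char* value = std::getenv(name);
    return value != nullptr && std::strcmp(value, "1") == 0;
}

// Hypothetical named pass; `run` stands in for the transformation body.
struct NamedPass {
    std::string name;
    std::function<void()> run;
};

// Runs the passes in order. When `profile` is true, prints each pass name
// with its wall-clock duration measured by std::chrono::steady_clock.
// Returns the number of timing lines printed.
int run_with_optional_profiling(const std::vector<NamedPass>& passes, bool profile) {
    int reported = 0;
    for (const auto& pass : passes) {
        const auto start = std::chrono::steady_clock::now();
        pass.run();
        const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - start);
        if (profile) {
            std::cout << pass.name << ": " << elapsed.count() << " us\n";
            ++reported;
        }
    }
    return reported;
}
```

In a program, the flag could be read once and passed in, for example run_with_optional_profiling(passes, env_flag_enabled("NGRAPH_PROFILE_PASS_ENABLE")).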

Disabling/Enabling specific transformations for plugin X

The transformation library provides plugins with transformations like CommonOptimizations, which contains a predefined sequence of transformations. It also provides a tool that helps to disable or partially disable particular transformations in a transformation pipeline. For example, if a plugin uses the CommonOptimizations transformation and needs to disable the ConvertGELU transformation, take the PassConfig instance from pass::Manager inside the plugin and call its disable method:

ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::CommonOptimizations>();

auto pass_config = manager.get_pass_config();
pass_config->disable<ngraph::pass::ConvertGELU>();

manager.run_passes(f);
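Conceptually, disable<T>() records a pass type in the configuration, and the pipeline skips that type when it runs. The following standard C++ sketch models this with std::type_index. ToyPassConfig, ToyManager, and the Toy* passes are hypothetical and only approximate the behavior described above; they are not nGraph's implementation.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Hypothetical toy config: remembers which pass types are disabled.
class ToyPassConfig {
public:
    template <typename T> void disable() { m_disabled.insert(std::type_index(typeid(T))); }
    template <typename T> void enable() { m_disabled.erase(std::type_index(typeid(T))); }
    bool is_disabled(const std::type_index& type) const { return m_disabled.count(type) != 0; }

private:
    std::set<std::type_index> m_disabled;
};

// Hypothetical pass interface; run() appends the pass name to a list.
struct ToyPass {
    virtual ~ToyPass() = default;
    virtual void run(std::vector<std::string>& ran) = 0;
};

struct ToyConvertGELU : ToyPass {
    void run(std::vector<std::string>& ran) override { ran.push_back("ConvertGELU"); }
};

struct ToyFuseReLU : ToyPass {
    void run(std::vector<std::string>& ran) override { ran.push_back("FuseReLU"); }
};

// Manager that checks the config right before running each pass.
class ToyManager {
public:
    template <typename T> void register_pass() { m_passes.push_back(std::make_shared<T>()); }
    std::shared_ptr<ToyPassConfig> get_pass_config() const { return m_config; }

    void run_passes(std::vector<std::string>& ran) {
        for (const auto& pass_ptr : m_passes) {
            const ToyPass& pass = *pass_ptr;
            // typeid on a polymorphic reference yields the dynamic (derived) type
            if (!m_config->is_disabled(std::type_index(typeid(pass)))) {
                pass_ptr->run(ran);
            }
        }
    }

private:
    std::vector<std::shared_ptr<ToyPass>> m_passes;
    std::shared_ptr<ToyPassConfig> m_config = std::make_shared<ToyPassConfig>();
};
```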

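In some cases, a transformation needs to be disabled only when a certain condition holds. To do this, set a callback for the transformation through PassConfig and call the callback inside the transformation; if the callback returns true for a node, the transformation leaves that node unchanged: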

// Set callback to particular transformation with specific condition
auto pass_config = manager.get_pass_config();
pass_config->set_callback<ngraph::pass::ConvertSpaceToDepth,
                          ngraph::pass::ConvertDepthToSpace>(
        [](const std::shared_ptr<const Node> &node) -> bool {
            return node->input_value(0).get_shape().size() <= 5lu &&
                   node->input_value(0).get_shape().size() == node->get_output_shape(0).size();
        });

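// Update the transformation to call the callback
ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
    auto node = m.get_match_root();
    // Skip this node if the callback set through PassConfig returns true
    if (transformation_callback(node)) {
        return false;
    }
    // transformation code (return true from here if the graph was modified)
    return false;
};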
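The callback contract shown above (the transformation asks the callback first and leaves the node unchanged when the callback returns true) can be sketched in standard C++ as follows. ToyNode, ToyPassConfig, and ToyConvertDepthToSpace are hypothetical stand-ins, and the rank check in the usage only mirrors the spirit of the condition in the example, not its exact semantics.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <typeindex>
#include <typeinfo>
#include <utility>

// Hypothetical toy node: only the input rank matters for this example.
struct ToyNode {
    std::size_t input_rank;
    bool converted;
};

// Returns true when the transformation should SKIP the node.
using ToyCallback = std::function<bool(const ToyNode&)>;

// Hypothetical toy config: one callback per pass type.
class ToyPassConfig {
public:
    template <typename T> void set_callback(ToyCallback cb) {
        m_callbacks[std::type_index(typeid(T))] = std::move(cb);
    }
    // Without a registered callback the node is never skipped.
    bool should_skip(const std::type_index& type, const ToyNode& node) const {
        const auto it = m_callbacks.find(type);
        return it != m_callbacks.end() && it->second(node);
    }

private:
    std::map<std::type_index, ToyCallback> m_callbacks;
};

// Hypothetical transformation that consults the callback before rewriting,
// like the "if (transformation_callback(node)) return false;" check above.
class ToyConvertDepthToSpace {
public:
    explicit ToyConvertDepthToSpace(const ToyPassConfig& config) : m_config(config) {}

    bool transformation_callback(const ToyNode& node) const {
        return m_config.should_skip(std::type_index(typeid(ToyConvertDepthToSpace)), node);
    }

    // Returns true if the node was rewritten.
    bool apply(ToyNode& node) const {
        if (transformation_callback(node)) {
            return false;  // the plugin asked to keep this node as is
        }
        node.converted = true;  // stands in for the real rewrite
        return true;
    }

private:
    const ToyPassConfig& m_config;
};
```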

Some pass::Manager pipelines inside transformations register transformations that are disabled by default and can be enabled inside plugins:

// Example of a transformation that is disabled by default
{
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConvertPadToGroupConvolution, false>();
    manager.run_passes(f);
}

// Inside a plugin: enable a transformation that is disabled by default
{
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::CommonOptimizations>();
    auto pass_config = manager.get_pass_config();
    pass_config->enable<ngraph::pass::ConvertPadToGroupConvolution>();
    manager.run_passes(f);
}
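One way to model a transformation that is registered but disabled by default is a boolean template parameter that, when false, marks the pass as disabled in the shared configuration; enable<T>() then removes that mark. The following standard C++ sketch works under that assumption. All names are hypothetical, and this is not nGraph's code.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Hypothetical toy config: the set of currently disabled pass types.
class ToyPassConfig {
public:
    template <typename T> void disable() { m_disabled.insert(std::type_index(typeid(T))); }
    template <typename T> void enable() { m_disabled.erase(std::type_index(typeid(T))); }
    bool is_disabled(const std::type_index& type) const { return m_disabled.count(type) != 0; }

private:
    std::set<std::type_index> m_disabled;
};

struct ToyPass {
    virtual ~ToyPass() = default;
    virtual void run(std::vector<std::string>& ran) = 0;
};

struct ToyConvertPad : ToyPass {
    void run(std::vector<std::string>& ran) override { ran.push_back("ConvertPad"); }
};

class ToyManager {
public:
    // Enabled == false registers the pass but marks it as disabled,
    // mirroring the idea of register_pass<T, false>().
    template <typename T, bool Enabled = true>
    void register_pass() {
        m_passes.push_back(std::make_shared<T>());
        if (!Enabled) {
            m_config->disable<T>();
        }
    }
    std::shared_ptr<ToyPassConfig> get_pass_config() const { return m_config; }

    void run_passes(std::vector<std::string>& ran) {
        for (const auto& pass_ptr : m_passes) {
            const ToyPass& pass = *pass_ptr;
            if (!m_config->is_disabled(std::type_index(typeid(pass)))) {
                pass_ptr->run(ran);
            }
        }
    }

private:
    std::vector<std::shared_ptr<ToyPass>> m_passes;
    std::shared_ptr<ToyPassConfig> m_config = std::make_shared<ToyPassConfig>();
};
```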

The PassConfig instance taken from pass::Manager is shared across all registered transformations, including nested ones, so it does not matter whether you configure it before or after registering passes.
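This order independence can be explained by two properties: nested pipelines hold a pointer to the same configuration object, and the configuration is only read when passes actually run. The following standard C++ sketch models both. ToyPassConfig and ToyPipeline are hypothetical, and keying the configuration by pass name is a simplification (the PassConfig examples above select passes by type).

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy config keyed by pass name.
struct ToyPassConfig {
    std::set<std::string> disabled;
};

// Hypothetical pipeline of named steps plus nested pipelines that all
// share one config object through std::shared_ptr.
class ToyPipeline {
public:
    explicit ToyPipeline(std::shared_ptr<ToyPassConfig> config) : m_config(std::move(config)) {}

    void add_step(const std::string& name) { m_steps.push_back(name); }

    // The nested pipeline receives the same config pointer, not a copy.
    ToyPipeline& add_nested() {
        m_nested.push_back(std::unique_ptr<ToyPipeline>(new ToyPipeline(m_config)));
        return *m_nested.back();
    }

    // The config is read only here, at run time, so it does not matter
    // whether it was changed before or after the steps were added.
    void run(std::vector<std::string>& ran) const {
        for (const auto& step : m_steps) {
            if (m_config->disabled.count(step) == 0) {
                ran.push_back(step);
            }
        }
        for (const auto& nested : m_nested) {
            nested->run(ran);
        }
    }

private:
    std::shared_ptr<ToyPassConfig> m_config;
    std::vector<std::string> m_steps;
    std::vector<std::unique_ptr<ToyPipeline>> m_nested;
};
```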

Transformations testing

If you are developing a new transformation inside a plugin, add tests to the template_plugin/tests/functional/transformations folder. There are two types of tests: nGraph reader tests located in inference-engine/tests/functional/inference_engine/ngraph_reader, and transformation tests located in inference-engine/tests/functional/inference_engine/transformations. Reader tests are IR-based and test end-to-end conversion from IR to CNNNetwork. Transformation tests test single nGraph transformations or low-level functions that are used inside transformations.

The basic transformation test looks like this:

TEST(TransformationTests, DISABLED_TemplateTest) {
    std::shared_ptr<ngraph::Function> f, f_ref;
    // f - ngraph::Function for applying transformation
    // f_ref - ngraph::Function that is expected after applying transformation
    {
        // Example function
        auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
        auto divide_constant = ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
        auto divide = std::make_shared<ngraph::opset3::Divide>(data, divide_constant);

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{divide}, ngraph::ParameterVector{data});

        // This transformation initializes runtime info attributes
        ngraph::pass::InitNodeInfo().run_on_function(f);

        // Run transformation
        // ngraph::pass::MyTransformation().run_on_function(f);

        // Check that all runtime info attributes were correctly propagated after applying the transformation
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Example reference function
        auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
        auto divide_constant = ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1.5});
        auto pow = std::make_shared<ngraph::opset3::Power>(divide_constant,
                                                           ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {-1}));
        auto mul = std::make_shared<ngraph::opset3::Multiply>(data, pow);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{data});
    }

    // Check that the processed function and the expected function are the same
    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
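compare_functions is used above through res.first (whether the functions match) and res.second (a message printed when the assertion fails). The following toy comparison in standard C++ returns a std::pair<bool, std::string>, matching that usage. It compares only a flat list of op names, whereas a real function comparison has to compare graph structure, so ToyFunction and toy_compare are hypothetical illustrations only.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy "function": op type names in topological order.
using ToyFunction = std::vector<std::string>;

// Returns {true, ""} when both functions match, or {false, message}
// describing the first difference found.
std::pair<bool, std::string> toy_compare(const ToyFunction& actual, const ToyFunction& expected) {
    if (actual.size() != expected.size()) {
        return {false, "op count differs: " + std::to_string(actual.size()) + " vs " +
                           std::to_string(expected.size())};
    }
    for (std::size_t i = 0; i < actual.size(); ++i) {
        if (actual[i] != expected[i]) {
            return {false, "op " + std::to_string(i) + " differs: " + actual[i] + " vs " +
                               expected[i]};
        }
    }
    return {true, ""};
}
```

In this model, an untransformed function (Parameter, Constant, Divide) fails against the reference (Parameter, Constant, Constant, Power, Multiply) with a descriptive message, which is what the test above would report while MyTransformation is still commented out.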