class ov::pass::PassConfig

Overview

Class representing a transformations config that is used for disabling/enabling transformations registered inside pass::Manager; it also allows setting a callback for all transformations or for a particular transformation.

#include <pass_config.hpp>

class PassConfig
{
public:
    // methods

    void disable(const DiscreteTypeInfo& type_info);

    template <
        class T,
        typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    void disable();

    void enable(const DiscreteTypeInfo& type_info);

    template <
        class T,
        typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    void enable();

    void set_callback(const param_callback& callback);

    template <typename... Args>
    std::enable_if<sizeof...(Args)==0>::type set_callback(const param_callback& callback);

    template <
        typename T,
        class... Args,
        typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    void set_callback(const param_callback& callback);

    param_callback get_callback(const DiscreteTypeInfo& type_info) const;

    template <
        class T,
        typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    param_callback get_callback() const;

    bool is_disabled(const DiscreteTypeInfo& type_info) const;

    template <
        class T,
        typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    bool is_disabled() const;

    bool is_enabled(const DiscreteTypeInfo& type_info) const;

    template <
        class T,
        typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
        >
    bool is_enabled() const;

    void add_disabled_passes(const PassConfig& rhs);
};

Detailed Documentation

Class representing a transformations config that is used for disabling/enabling transformations registered inside pass::Manager; it also allows setting a callback for all transformations or for a particular transformation.

When a pass::Manager is created, all passes registered inside it, including nested passes, share the same PassConfig instance. To work with this class, first get the shared instance by calling the manager.get_pass_config() method. You can then disable or enable passes based on their type_info. For example:

pass::Manager manager;
manager.register_pass<CommonOptimizations>();
auto pass_config = manager.get_pass_config();
pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
                                     // CommonOptimizations pipeline
manager.run_passes(f);
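The sharing rule can be illustrated without OpenVINO. The sketch below is a hypothetical, simplified model (`ToyConfig`, `ToyPass`, `ToyManager` and their members are invented for illustration, and the pipeline is flattened rather than nested): the manager owns one `shared_ptr` to a config and hands that same pointer to every registered pass, so a `disable()` made through the manager is seen by every pass.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for PassConfig: a set of disabled pass names.
struct ToyConfig {
    std::set<std::string> disabled;
    void disable(const std::string& name) { disabled.insert(name); }
    bool is_disabled(const std::string& name) const { return disabled.count(name) > 0; }
};

// Hypothetical pass: remembers the config instance it was given.
struct ToyPass {
    std::string name;
    std::shared_ptr<ToyConfig> config;
};

// Hypothetical manager: every registered pass receives the SAME config instance.
class ToyManager {
public:
    ToyManager() : config_(std::make_shared<ToyConfig>()) {}

    void register_pass(const std::string& name) { passes_.push_back(ToyPass{name, config_}); }

    std::shared_ptr<ToyConfig> get_pass_config() const { return config_; }

    // Returns the names of the passes that actually ran.
    std::vector<std::string> run_passes() const {
        std::vector<std::string> ran;
        for (const auto& pass : passes_) {
            if (!pass.config->is_disabled(pass.name)) {
                ran.push_back(pass.name);
            }
        }
        return ran;
    }

private:
    std::shared_ptr<ToyConfig> config_;
    std::vector<ToyPass> passes_;
};
```

Because the config is shared by pointer rather than copied, it does not matter whether `disable()` is called before or after the passes are registered.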

Sometimes a transformation has to be called manually from inside another transformation. In that case, before running the nested transformation you must check that it is not disabled, and then set the current PassConfig instance on it. For example:

// Inside a MatcherPass callback or inside the FunctionPass run_on_function() method,
// call get_pass_config() to get the shared instance of PassConfig
auto pass_config = get_pass_config();

// Before running the nested transformation, check whether it is disabled
if (!pass_config->is_disabled<ConvertGELU>()) {
    ConvertGELU pass;
    pass.set_pass_config(pass_config);  // share the outer PassConfig
    pass.apply(node);
}

Following this pattern inside your transformations guarantees that nested transformations respect the disable/enable settings and callbacks stored in the shared PassConfig.
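The same check-then-share pattern can be exercised in a standalone model. This is a hypothetical sketch, not OpenVINO code: `ToyConfig`, `ToyConvertGELU`, `run_outer`, and the idea of modeling a node as the name of its operation are all invented for illustration.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>

// Hypothetical stand-in for PassConfig (only the part used here).
struct ToyConfig {
    std::set<std::string> disabled;
    bool is_disabled(const std::string& name) const { return disabled.count(name) > 0; }
};

// Hypothetical nested pass; a "node" is modeled as the name of its operation.
struct ToyConvertGELU {
    std::shared_ptr<ToyConfig> config;
    void set_pass_config(std::shared_ptr<ToyConfig> c) { config = std::move(c); }
    // Rewrites a "Gelu" node and reports whether anything changed.
    bool apply(std::string& node) const {
        if (node != "Gelu") return false;
        node = "DecomposedGelu";
        return true;
    }
};

// Outer pass body following the documented pattern: check the shared
// config first, then hand the same config to the nested pass.
bool run_outer(const std::shared_ptr<ToyConfig>& pass_config, std::string& node) {
    if (pass_config->is_disabled("ConvertGELU")) {
        return false;  // nested pass disabled: node left untouched
    }
    ToyConvertGELU pass;
    pass.set_pass_config(pass_config);
    return pass.apply(node);
}
```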

Methods

void disable(const DiscreteTypeInfo& type_info)

Disable transformation by its type_info.

Parameters:

type_info

Transformation type_info

template <
    class T,
    typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
    >
void disable()

Disable transformation by its class type (based on type_info)
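A plausible way for the class-type overload to work is to forward the class's static type information to the `type_info` overload. The sketch below models that idea with hypothetical types (`ToyTypeInfo`, `ToyConvertGELU`, `ToyConfig`); the static accessor name `get_type_info_static()` is an assumption made for the model and is not taken from this page.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for DiscreteTypeInfo: identified by name only.
struct ToyTypeInfo {
    std::string name;
    bool operator<(const ToyTypeInfo& other) const { return name < other.name; }
};

// Hypothetical transformation class exposing its static type info.
struct ToyConvertGELU {
    static const ToyTypeInfo& get_type_info_static() {
        static const ToyTypeInfo info{"ConvertGELU"};
        return info;
    }
};

class ToyConfig {
public:
    void disable(const ToyTypeInfo& type_info) { disabled_.insert(type_info); }

    // The class-type overload forwards the class's static type info.
    template <class T>
    void disable() { disable(T::get_type_info_static()); }

    bool is_disabled(const ToyTypeInfo& type_info) const { return disabled_.count(type_info) > 0; }

    template <class T>
    bool is_disabled() const { return is_disabled(T::get_type_info_static()); }

private:
    std::set<ToyTypeInfo> disabled_;
};
```

Both overloads end up touching the same entry, so a pass disabled by class type is also reported as disabled when queried by its `type_info`.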

void enable(const DiscreteTypeInfo& type_info)

Enable transformation by its type_info.

Parameters:

type_info

Transformation type_info

template <
    class T,
    typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
    >
void enable()

Enable transformation by its class type (based on type_info)
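A typical use of `enable()` is to turn back on a transformation that some pipeline disabled. The hypothetical model below assumes that `enable()` and `disable()` are mutually exclusive (each call removes the type from the opposite set); this page does not state that rule explicitly, so treat it as an assumption. Types are modeled as plain strings.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical model of the enable/disable bookkeeping.
// Assumption: enabling a type removes it from the disabled set and vice versa.
class ToyConfig {
public:
    void disable(const std::string& type) {
        enabled_.erase(type);
        disabled_.insert(type);
    }
    void enable(const std::string& type) {
        disabled_.erase(type);
        enabled_.insert(type);
    }
    bool is_disabled(const std::string& type) const { return disabled_.count(type) > 0; }
    bool is_enabled(const std::string& type) const { return enabled_.count(type) > 0; }

private:
    std::set<std::string> disabled_;
    std::set<std::string> enabled_;
};
```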

void set_callback(const param_callback& callback)

Set a global callback that applies to all kinds of transformations.

template <
    typename T,
    class... Args,
    typename std::enable_if<!ngraph::HasTypeInfoMember<T>::value, bool>::type = true
    >
void set_callback(const param_callback& callback)

Set a callback for specific transformation class types.

The example below shows how to set a callback for one or more passes using this method.

pass_config->set_callback<ov::pass::ConvertBatchToSpace,
                          ov::pass::ConvertSpaceToBatch>(
    [](const std::shared_ptr<const ov::Node>& node) -> bool {
        // Returning true skips the transformation for this node:
        // disable it unless the input rank is static and equal to 4
        const auto input_rank = node->get_input_partial_shape(0).rank();
        return input_rank.is_dynamic() || input_rank.get_length() != 4;
    });

Note that a callback has an effect only if the transformation itself checks it. Inside a transformation, the check looks like this:

if (transformation_callback(node)) {
    return false; // exit from transformation
}
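The contract between the callback and the transformation can be modeled in isolation. In this hypothetical sketch, `ToyNode`, `ToyCallback`, and `convert_batch_to_space` are invented names: the transformation body consults the callback first and leaves the node unchanged when the callback returns true.

```cpp
#include <cassert>
#include <functional>

// Hypothetical node: only the input rank matters for this example.
struct ToyNode {
    int input_rank;
    bool converted = false;
};

using ToyCallback = std::function<bool(const ToyNode&)>;

// Hypothetical transformation body; returns true if it changed the node.
// The callback is consulted first: true means "skip this node".
bool convert_batch_to_space(ToyNode& node, const ToyCallback& transformation_callback) {
    if (transformation_callback(node)) {
        return false;  // exit from transformation, node unchanged
    }
    node.converted = true;
    return true;
}
```

With a callback that returns true for every rank other than 4, only 4D nodes are converted, which matches the intent of the `set_callback` example above.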

param_callback get_callback(const DiscreteTypeInfo& type_info) const

Get the callback for the given transformation type_info.

If no callback was set for the given transformation type, the global callback is returned. If no global callback was set either, the default callback is returned.
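The three-level lookup can be written down as a small standalone model. This is a hypothetical sketch (`ToyConfig`, `ToyCallback`, and modeling a node as an `int` are invented); in particular, the assumption that the default callback returns false (never skips) is not stated on this page.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

using ToyCallback = std::function<bool(int)>;  // hypothetical: a node is modeled as an int

class ToyConfig {
public:
    void set_callback(ToyCallback cb) { global_ = std::move(cb); }
    void set_callback(const std::string& type, ToyCallback cb) { per_type_[type] = std::move(cb); }

    // Lookup order: per-type callback, then global callback, then default callback.
    ToyCallback get_callback(const std::string& type) const {
        auto it = per_type_.find(type);
        if (it != per_type_.end()) return it->second;
        if (global_) return global_;
        return default_;
    }

private:
    std::map<std::string, ToyCallback> per_type_;
    ToyCallback global_;  // empty until set_callback(cb) is called
    // Assumption: the default callback never skips a transformation.
    ToyCallback default_ = [](int) { return false; };
};
```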

Parameters:

type_info

Transformation type_info

template <
    class T,
    typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
    >
param_callback get_callback() const

Get the callback for the given transformation class type.

Returns:

callback lambda function

bool is_disabled(const DiscreteTypeInfo& type_info) const

Check whether the transformation type is disabled.

Parameters:

type_info

Transformation type_info

Returns:

true if transformation type was disabled and false otherwise

template <
    class T,
    typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
    >
bool is_disabled() const

Check whether the transformation class type is disabled.

Returns:

true if transformation type was disabled and false otherwise

bool is_enabled(const DiscreteTypeInfo& type_info) const

Check whether the transformation type is force enabled.

Parameters:

type_info

Transformation type_info

Returns:

true if transformation type was force enabled and false otherwise

template <
    class T,
    typename std::enable_if<ngraph::HasTypeInfoMember<T>::value, bool>::type = true
    >
bool is_enabled() const

Check whether the transformation class type is force enabled.

Returns:

true if transformation type was force enabled and false otherwise
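Because `is_enabled()` reports only *force* enabled types, `!is_disabled(t)` does not imply `is_enabled(t)`: a type that was never touched is in neither state. The hypothetical helper below (`PassState`, `ToyConfig`, and `state_of` are invented for illustration) classifies a type into one of three states.

```cpp
#include <cassert>
#include <set>
#include <string>

enum class PassState { Default, Disabled, ForceEnabled };

// Hypothetical stand-in exposing the two queries documented above.
struct ToyConfig {
    std::set<std::string> disabled;
    std::set<std::string> enabled;
    bool is_disabled(const std::string& t) const { return disabled.count(t) > 0; }
    bool is_enabled(const std::string& t) const { return enabled.count(t) > 0; }
};

// A type that was never disabled or force enabled falls through to Default.
PassState state_of(const ToyConfig& cfg, const std::string& type) {
    if (cfg.is_disabled(type)) return PassState::Disabled;
    if (cfg.is_enabled(type)) return PassState::ForceEnabled;
    return PassState::Default;
}
```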