ngraph::pass::PassConfig Class Reference

A class representing a transformation configuration. It is used to disable or enable transformations registered inside pass::Manager, and it also allows setting a callback for all transformations or for a particular transformation.

#include <pass_config.hpp>

Public Member Functions

void disable (const DiscreteTypeInfo &type_info)
 Disable a transformation by its type_info.

template<typename T >
void disable ()
 Disable a transformation by its class type (based on type_info).

void enable (const DiscreteTypeInfo &type_info)
 Enable a transformation by its type_info.

template<typename T >
void enable ()
 Enable a transformation by its class type (based on type_info).

void set_callback (const param_callback &callback)
 Set a callback for all kinds of transformations.

template<typename... Args>
std::enable_if< sizeof...(Args)==0 >::type set_callback (const param_callback &callback)
 Terminating overload of the variadic set_callback recursion (selected when no class types remain).

template<typename T , class... Args>
void set_callback (const param_callback &callback)
 Set a callback for particular transformation class types.

param_callback get_callback (const DiscreteTypeInfo &type_info) const
 Get the callback for the given transformation type_info.

template<typename T >
param_callback get_callback () const
 Get the callback for the given transformation class type.

bool is_disabled (const DiscreteTypeInfo &type_info) const
 Check whether a transformation type is disabled.

template<typename T >
bool is_disabled () const
 Check whether a transformation class type is disabled.

bool is_enabled (const DiscreteTypeInfo &type_info) const
 Check whether a transformation type is force enabled.

template<typename T >
bool is_enabled () const
 Check whether a transformation class type is force enabled.

void add_disabled_passes (const PassConfig &rhs)


Detailed Description


When a pass::Manager is created, all passes registered inside this manager, including nested passes, share the same PassConfig instance. To work with this class, first get the shared instance by calling the manager.get_pass_config() method. You can then disable or enable passes based on their transformation type_info. For example:

pass::Manager manager;
manager.register_pass<CommonOptimizations>();
auto pass_config = manager.get_pass_config();
pass_config->disable<ConvertGELU>(); // disables the ConvertGELU pass nested
                                     // inside the CommonOptimizations pipeline
manager.run_passes(f);               // f: the std::shared_ptr<ngraph::Function> to transform
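The sharing described above can be illustrated with a small, self-contained C++ toy model. It does not use ngraph: ToyPassConfig and ToyManager are hypothetical stand-ins, pass names replace DiscreteTypeInfo, and the nested CommonOptimizations pipeline is flattened into a plain list. It only shows why disabling a pass through the object returned by get_pass_config() changes what the manager later runs.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for PassConfig (hypothetical; not the ngraph class).
// Pass names replace DiscreteTypeInfo as the lookup key.
struct ToyPassConfig {
    std::set<std::string> disabled;

    void disable(const std::string& name) { disabled.insert(name); }
    bool is_disabled(const std::string& name) const { return disabled.count(name) != 0; }
};

// Toy stand-in for pass::Manager. The nested pipeline is flattened into one
// list, and every pass consults the same shared config object.
struct ToyManager {
    std::shared_ptr<ToyPassConfig> config = std::make_shared<ToyPassConfig>();
    std::vector<std::string> passes;

    void register_pass(const std::string& name) { passes.push_back(name); }

    // Hands out the shared pointer itself, not a copy of the config.
    std::shared_ptr<ToyPassConfig> get_pass_config() const { return config; }

    // Runs the passes that are not disabled and returns their names in order.
    std::vector<std::string> run_passes() const {
        std::vector<std::string> ran;
        for (const auto& name : passes) {
            if (!config->is_disabled(name)) {
                ran.push_back(name);
            }
        }
        return ran;
    }
};
```

Because get_pass_config() returns the shared pointer rather than a copy, calling disable() on the returned object is visible to every pass the manager runs.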

Sometimes a transformation has to be called manually from inside another transformation. In that case, before running the nested transformation, you must check manually that it is not disabled and then pass the current PassConfig instance to it. For example:

// Inside a MatcherPass callback or a FunctionPass::run_on_function() method,
// call get_pass_config() to get the shared PassConfig instance
auto pass_config = get_pass_config();

// Before running the nested transformation, check whether it is disabled
if (!pass_config->is_disabled<ConvertGELU>()) {
    auto pass = ConvertGELU();
    pass.set_pass_config(pass_config); // pass is an object, not a pointer
    pass.apply(node);
}

Following this pattern inside your transformations guarantees that nested transformations honor the same disabled passes and callbacks as the rest of the pipeline.
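Why the set_pass_config() step matters can be sketched with another hypothetical toy model (not the ngraph API; an int stands in for a node, and ToyPassConfig, ToyNestedPass and run_nested are made-up names). It assumes that a freshly constructed pass starts with its own private config, so callbacks registered on the shared instance are only seen after the shared instance is handed over.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>

// Hypothetical toy types that only mimic the guarded nested call.
struct ToyPassConfig {
    std::set<std::string> disabled;
    std::map<std::string, std::function<bool(int)>> callbacks;

    bool is_disabled(const std::string& name) const { return disabled.count(name) != 0; }

    // True means "skip this node"; without a registered callback nothing is skipped.
    bool skip(const std::string& name, int node) const {
        auto it = callbacks.find(name);
        return it != callbacks.end() && it->second(node);
    }
};

struct ToyNestedPass {
    std::string name = "ConvertGELU";
    // Assumption: a freshly constructed pass owns a private, empty config...
    std::shared_ptr<ToyPassConfig> config = std::make_shared<ToyPassConfig>();

    // ...until the caller hands it the shared one.
    void set_pass_config(std::shared_ptr<ToyPassConfig> cfg) { config = std::move(cfg); }

    // Returns true if the node was transformed.
    bool apply(int node) const { return !config->skip(name, node); }
};

// The guarded call described above; share_config toggles the set_pass_config() step.
bool run_nested(const std::shared_ptr<ToyPassConfig>& shared, int node, bool share_config) {
    if (shared->is_disabled("ConvertGELU")) {
        return false;
    }
    ToyNestedPass pass;
    if (share_config) {
        pass.set_pass_config(shared);
    }
    return pass.apply(node);
}
```

With share_config set to false, the nested pass ignores a callback that the rest of the pipeline respects, which is the inconsistency the pattern above avoids.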

Member Function Documentation

◆ disable()

void ngraph::pass::PassConfig::disable ( const DiscreteTypeInfo & type_info )

Disable a transformation by its type_info.

Parameters
type_info  Transformation type_info

◆ enable()

void ngraph::pass::PassConfig::enable ( const DiscreteTypeInfo & type_info )

Enable a transformation by its type_info.

Parameters
type_info  Transformation type_info

◆ get_callback() [1/2]

template<typename T >
param_callback ngraph::pass::PassConfig::get_callback ( ) const
inline

Get the callback for the given transformation class type.

Returns
The callback function for the given transformation class type.

◆ get_callback() [2/2]

param_callback ngraph::pass::PassConfig::get_callback ( const DiscreteTypeInfo & type_info ) const

Get the callback for the given transformation type_info.

Parameters
type_info  Transformation type_info

If no callback was set for the given transformation type, the global callback is returned. If no global callback was set either, the default callback is returned.
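The lookup order described above (per-type callback, then global callback, then default callback) can be sketched as follows. This is a toy model under stated assumptions, not the ngraph implementation: std::type_index stands in for DiscreteTypeInfo, Node, PassA and PassB are placeholders, and the default callback is assumed to return false, meaning "do not skip". In this sketch, setting a global callback simply replaces the default.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <typeindex>
#include <typeinfo>

struct Node { int rank; };  // toy node, not ngraph::Node
using param_callback = std::function<bool(const std::shared_ptr<const Node>&)>;

// Hypothetical toy config; std::type_index stands in for DiscreteTypeInfo.
class ToyPassConfig {
public:
    // Global callback: used for every type that has no callback of its own.
    void set_callback(const param_callback& cb) { m_callback = cb; }

    // Per-type callback.
    template <typename T>
    void set_callback(const param_callback& cb) { m_callback_map[typeid(T)] = cb; }

    // Per-type callback first, otherwise the global (or default) one.
    param_callback get_callback(std::type_index type) const {
        auto it = m_callback_map.find(type);
        return it != m_callback_map.end() ? it->second : m_callback;
    }

    template <typename T>
    param_callback get_callback() const { return get_callback(typeid(T)); }

private:
    // Assumed default: never ask a transformation to skip a node.
    param_callback m_callback = [](const std::shared_ptr<const Node>&) { return false; };
    std::map<std::type_index, param_callback> m_callback_map;
};

struct PassA {};  // placeholder transformation types
struct PassB {};
```

A type with its own callback is unaffected by later changes to the global callback, while every other type falls through to it.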

◆ is_disabled() [1/2]

template<typename T >
bool ngraph::pass::PassConfig::is_disabled ( ) const
inline

Check whether a transformation class type is disabled.

Returns
true if the transformation type was disabled, false otherwise.

◆ is_disabled() [2/2]

bool ngraph::pass::PassConfig::is_disabled ( const DiscreteTypeInfo & type_info ) const
inline

Check whether a transformation type is disabled.

Parameters
type_info  Transformation type_info
Returns
true if the transformation type was disabled, false otherwise.

◆ is_enabled() [1/2]

template<typename T >
bool ngraph::pass::PassConfig::is_enabled ( ) const
inline

Check whether a transformation class type is force enabled.

Returns
true if the transformation type was force enabled, false otherwise.

◆ is_enabled() [2/2]

bool ngraph::pass::PassConfig::is_enabled ( const DiscreteTypeInfo & type_info ) const
inline

Check whether a transformation type is force enabled.

Parameters
type_info  Transformation type_info
Returns
true if the transformation type was force enabled, false otherwise.
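The distinction between "disabled" and "force enabled" can be sketched with a toy model. It assumes, without confirmation from this page, that disable() and enable() are mutually exclusive, each removing the type from the other set, so a type is in one of three states: untouched, disabled, or force enabled. std::type_index stands in for DiscreteTypeInfo, and ConvertGELU here is an empty placeholder type, not the real pass.

```cpp
#include <cassert>
#include <set>
#include <typeindex>
#include <typeinfo>

// Hypothetical toy model; std::type_index stands in for DiscreteTypeInfo.
class ToyPassConfig {
public:
    // Assumption: each call moves the type out of the opposite set.
    void disable(std::type_index t) { m_enabled.erase(t); m_disabled.insert(t); }
    void enable(std::type_index t) { m_disabled.erase(t); m_enabled.insert(t); }
    bool is_disabled(std::type_index t) const { return m_disabled.count(t) != 0; }
    bool is_enabled(std::type_index t) const { return m_enabled.count(t) != 0; }

    // Template overloads forward to the type-key versions, as the docs describe.
    template <typename T> void disable() { disable(typeid(T)); }
    template <typename T> void enable() { enable(typeid(T)); }
    template <typename T> bool is_disabled() const { return is_disabled(typeid(T)); }
    template <typename T> bool is_enabled() const { return is_enabled(typeid(T)); }

private:
    std::set<std::type_index> m_disabled;
    std::set<std::type_index> m_enabled;
};

struct ConvertGELU {};  // placeholder type, not ngraph's ConvertGELU
```

Note that in this model a type that was never touched is neither disabled nor force enabled, so is_enabled() returning false does not imply the pass is off.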

◆ set_callback()

template<typename T , class... Args>
void ngraph::pass::PassConfig::set_callback ( const param_callback &  callback)
inline

Set a callback for particular transformation class types.

The example below shows how to set a callback for one or more passes using this method.

pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
                          ngraph::pass::ConvertSpaceToBatch>(
         [](const std::shared_ptr<const ngraph::Node> &node) -> bool {
              // Disable the transformations (return true) for cases when the
              // input shape rank is dynamic or not equal to 4
              const auto input_shape_rank =
                  node->get_input_partial_shape(0).rank();
              return input_shape_rank.is_dynamic() ||
                     input_shape_rank.get_length() != 4;
          });

Note that each transformation must contain code that consults this callback: when the callback returns true, the transformation should exit without modifying the node. See the example below:

if (transformation_callback(node)) {
    return false; // exit from transformation
}
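Both halves of this contract can be modeled in self-contained C++: the variadic set_callback registering one callback for several types (with the zero-argument std::enable_if overload from the member list ending the recursion), and a transformation that exits when its callback returns true. All types here are hypothetical placeholders rather than the real ngraph passes, std::type_index stands in for DiscreteTypeInfo, and the default callback is assumed to return false.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <type_traits>
#include <typeindex>
#include <typeinfo>

struct Node { int rank; };  // toy node, not ngraph::Node
using param_callback = std::function<bool(const std::shared_ptr<const Node>&)>;

// Placeholder types named after the passes in the example; not the real ngraph passes.
struct ConvertBatchToSpace {};
struct ConvertSpaceToBatch {};
struct OtherPass {};

class ToyPassConfig {
public:
    // Recursion terminator: only viable when no class types remain.
    template <typename... Args>
    typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback&) {}

    // Registers the callback for T, then recurses over the remaining types.
    template <typename T, class... Args>
    void set_callback(const param_callback& cb) {
        m_callback_map[typeid(T)] = cb;
        set_callback<Args...>(cb);
    }

    param_callback get_callback(std::type_index type) const {
        auto it = m_callback_map.find(type);
        if (it != m_callback_map.end()) {
            return it->second;
        }
        // Assumed default: never ask the transformation to skip.
        return [](const std::shared_ptr<const Node>&) { return false; };
    }

private:
    std::map<std::type_index, param_callback> m_callback_map;
};

// Toy transformation body: a true callback result means "skip this node".
template <typename PassT>
bool toy_transform(const ToyPassConfig& cfg, const std::shared_ptr<const Node>& node) {
    if (cfg.get_callback(typeid(PassT))(node)) {
        return false;  // exit from transformation
    }
    return true;  // pretend the node was rewritten
}
```

The std::enable_if overload is what lets set_callback<Args...>(cb) resolve once the pack is empty; without it, the call with an empty pack would have no viable candidate.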

The documentation for this class was generated from the following file: