Class representing a transformation config that is used to disable or enable transformations registered inside pass::Manager; it also allows setting a callback for all transformations or for a particular transformation.
#include <pass_config.hpp>
Public Member Functions

| Return type | Member | Description |
|---|---|---|
| void | disable(const DiscreteTypeInfo &type_info) | Disable a transformation by its type_info. |
| template<typename T> void | disable() | Disable a transformation by its class type (based on type_info). |
| void | enable(const DiscreteTypeInfo &type_info) | Enable a transformation by its type_info. |
| template<typename T> void | enable() | Enable a transformation by its class type (based on type_info). |
| void | set_callback(const param_callback &callback) | Set a callback for all transformations. |
| template<typename... Args> std::enable_if<sizeof...(Args)==0>::type | set_callback(const param_callback &callback) | Overload selected only for an empty type list; it ends the recursion of the variadic overload. |
| template<typename T, class... Args> void | set_callback(const param_callback &callback) | Set a callback for particular transformation class types. |
| param_callback | get_callback(const DiscreteTypeInfo &type_info) const | Get the callback for the given transformation type_info. |
| template<typename T> param_callback | get_callback() const | Get the callback for the given transformation class type. |
| bool | is_disabled(const DiscreteTypeInfo &type_info) const | Check whether a transformation type is disabled. |
| template<typename T> bool | is_disabled() const | Check whether a transformation class type is disabled. |
| bool | is_enabled(const DiscreteTypeInfo &type_info) const | Check whether a transformation type is force-enabled. |
| template<typename T> bool | is_enabled() const | Check whether a transformation class type is force-enabled. |
| void | add_disabled_passes(const PassConfig &rhs) | |
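The two `set_callback` template overloads listed in the table above work as a pair: the variadic one stores the callback for its first type and recurses on the remaining types, and the `enable_if` one, viable only for an empty type list, ends the recursion. The following minimal sketch reproduces that pattern outside nGraph. It is a model under stated assumptions, not the library code: `ToyPassConfig`, `toy_callback` and the `has_callback` helpers are hypothetical, types are keyed by `std::type_index` instead of `DiscreteTypeInfo`, and the empty structs are placeholders for real transformation classes.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <type_traits>
#include <typeindex>
#include <typeinfo>

// Hypothetical stand-in for param_callback: the real one receives a node
// pointer; here it takes an int so the sketch compiles on its own.
using toy_callback = std::function<bool(int)>;

// Hypothetical placeholders standing in for real transformation classes.
struct ConvertBatchToSpace {};
struct ConvertSpaceToBatch {};
struct ConvertGELU {};

class ToyPassConfig {
public:
    // Base case: only viable when the type list is empty, so it stops recursion.
    template <typename... Args>
    typename std::enable_if<sizeof...(Args) == 0>::type
    set_callback(const toy_callback&) {}

    // Recursive case: store the callback for T, then handle the remaining types.
    template <typename T, class... Args>
    void set_callback(const toy_callback& callback) {
        m_callbacks[std::type_index(typeid(T))] = callback;
        set_callback<Args...>(callback);
    }

    // Hypothetical helpers used to inspect what was registered.
    bool has_callback_for(std::type_index t) const { return m_callbacks.count(t) != 0; }
    template <typename T>
    bool has_callback() const { return has_callback_for(typeid(T)); }

private:
    std::map<std::type_index, toy_callback> m_callbacks;
};
```

With this shape, `cfg.set_callback<ConvertBatchToSpace, ConvertSpaceToBatch>(cb)` expands into one store per listed type and then a call to the empty-list overload, which is why the library needs the `enable_if` overload at all.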
When a pass::Manager is created, all passes registered inside this manager, including nested passes, share the same instance of the PassConfig class. To work with this class, first get the shared instance by calling the manager.get_pass_config() method. You can then disable or enable passes based on their type_info. For example:
    pass::Manager manager;
    manager.register_pass<CommonOptimizations>();
    auto pass_config = manager.get_pass_config();
    pass_config->disable<ConvertGELU>();  // this disables the nested pass inside
                                          // the CommonOptimizations pipeline
    manager.run_passes(f);
Sometimes a transformation has to be called manually from inside another transformation. In that case, before running the nested transformation, you need to check manually that it is not disabled and then pass the current PassConfig instance to it. For example:
    // Inside a MatcherPass callback or inside a FunctionPass run_on_function() method,
    // call get_pass_config() to get the shared instance of PassConfig
    auto pass_config = get_pass_config();
    // Before running the nested transformation, check whether it is disabled
    if (!pass_config->is_disabled<ConvertGELU>()) {
        auto pass = ConvertGELU();
        pass.set_pass_config(pass_config);  // share the current config with the nested pass
        pass.apply(node);
    }
Following this pattern inside your transformations guarantees that nested transformations respect the settings of the shared PassConfig.
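A minimal, self-contained sketch of the nested-pass pattern described above, under a heavily simplified model: `ToyPassConfig`, `ToyPass` and `run_nested` are hypothetical names, and passes are identified by name strings instead of `DiscreteTypeInfo`. What it illustrates is that the outer pass checks the shared config first and then hands the same `std::shared_ptr` to the nested pass, so a later `disable` made through any handle is visible everywhere.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <utility>

// Hypothetical, simplified config: disabled passes are tracked by name.
struct ToyPassConfig {
    std::set<std::string> disabled;
    bool is_disabled(const std::string& name) const { return disabled.count(name) != 0; }
};

// Hypothetical pass that remembers which config it was given and how often it ran.
struct ToyPass {
    explicit ToyPass(std::string n) : name(std::move(n)) {}
    std::string name;
    std::shared_ptr<ToyPassConfig> config;  // shared with the outer pass, not copied
    int runs = 0;
    void set_pass_config(const std::shared_ptr<ToyPassConfig>& cfg) { config = cfg; }
    void apply() { ++runs; }  // stand-in for the real graph rewrite
};

// Body of a hypothetical outer pass: check the shared config before running
// the nested pass, then give the nested pass the same config instance.
void run_nested(const std::shared_ptr<ToyPassConfig>& cfg, ToyPass& nested) {
    if (!cfg->is_disabled(nested.name)) {
        nested.set_pass_config(cfg);
        nested.apply();
    }
}
```

Because the config is shared by pointer rather than copied, disabling a pass on the manager's config also affects every nested call site that performs this check.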
void ngraph::pass::PassConfig::disable(const DiscreteTypeInfo &type_info)
Disable transformation by its type_info.
Parameter type_info: Transformation type_info
void ngraph::pass::PassConfig::enable(const DiscreteTypeInfo &type_info)
Enable transformation by its type_info.
Parameter type_info: Transformation type_info
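This reference does not spell out how `enable` and `disable` interact. The sketch below assumes one plausible model, in which the two calls maintain mutually exclusive sets so that the most recent call wins, and in which "force-enabled" (`is_enabled`) means `enable` was called explicitly rather than merely that the pass is not disabled. `ToyPassConfig` is hypothetical and uses strings instead of `DiscreteTypeInfo`; check the actual implementation before relying on this behavior.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical model of enable/disable bookkeeping (assumed, not the library code).
class ToyPassConfig {
public:
    // Disabling removes any explicit enable, so the latest call wins.
    void disable(const std::string& t) { m_enabled.erase(t); m_disabled.insert(t); }
    // Enabling removes any disable and records the pass as force-enabled.
    void enable(const std::string& t) { m_disabled.erase(t); m_enabled.insert(t); }

    bool is_disabled(const std::string& t) const { return m_disabled.count(t) != 0; }
    // True only after an explicit enable(), not just "not disabled".
    bool is_enabled(const std::string& t) const { return m_enabled.count(t) != 0; }

private:
    std::set<std::string> m_disabled;
    std::set<std::string> m_enabled;
};
```

Under this model a pass that was never touched is neither disabled nor force-enabled, which is why the two query functions are not simple negations of each other.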
template<typename T> param_callback ngraph::pass::PassConfig::get_callback() const [inline]
Get the callback for the given transformation class type.
param_callback ngraph::pass::PassConfig::get_callback(const DiscreteTypeInfo &type_info) const
Get the callback for the given transformation type_info.
Parameter type_info: Transformation type_info
If no callback was set for the given transformation type, the global callback is returned. If no global callback was set either, the default callback is returned.
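The lookup order described above (per-type callback, then global callback, then default callback) can be sketched as follows. This is a hypothetical model: `ToyPassConfig`, `set_global_callback` and `set_callback_for` are illustrative names, the callback takes an `int` instead of a node, and the default callback is assumed to return `false`, that is, it never asks a transformation to skip a node.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Hypothetical callback type; the real param_callback receives a node pointer.
using toy_callback = std::function<bool(int)>;

class ToyPassConfig {
public:
    void set_global_callback(const toy_callback& cb) { m_global = cb; }
    void set_callback_for(const std::string& type, const toy_callback& cb) { m_per_type[type] = cb; }

    // Fallback chain: per-type callback, then global callback, then default.
    toy_callback get_callback(const std::string& type) const {
        auto it = m_per_type.find(type);
        if (it != m_per_type.end()) {
            return it->second;
        }
        if (m_global) {  // std::function converts to false while empty
            return m_global;
        }
        return m_default;
    }

private:
    std::map<std::string, toy_callback> m_per_type;
    toy_callback m_global;  // empty until set_global_callback() is called
    // Assumed default: never ask the transformation to skip a node.
    toy_callback m_default = [](int) { return false; };
};
```

Returning a callable in every case means transformation code can always invoke the result without checking for an empty callback first.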
template<typename T> bool ngraph::pass::PassConfig::is_disabled() const [inline]
Check whether the transformation class type is disabled.
bool ngraph::pass::PassConfig::is_disabled(const DiscreteTypeInfo &type_info) const [inline]
Check whether the transformation type is disabled.
Parameter type_info: Transformation type_info
template<typename T> bool ngraph::pass::PassConfig::is_enabled() const [inline]
Check whether the transformation class type is force-enabled.
bool ngraph::pass::PassConfig::is_enabled(const DiscreteTypeInfo &type_info) const [inline]
Check whether the transformation type is force-enabled.
Parameter type_info: Transformation type_info
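template<typename T, class... Args> void ngraph::pass::PassConfig::set_callback(const param_callback &callback) [inline]
Set a callback for particular transformation class types.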
The example below shows how to use this method to set a callback for one or multiple passes.
    pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
                              ngraph::pass::ConvertSpaceToBatch>(
        [](const_node_ptr &node) -> bool {
            // Returning true makes the transformation skip this node, so both
            // transformations are disabled for rank-4 inputs and applied otherwise
            // (for these operations the output rank equals the input rank)
            const auto input_shape_rank = node->get_output_partial_shape(0).rank().get_length();
            if (input_shape_rank != 4) {
                return false;
            }
            return true;
        });
Note that transformations must contain code that honors this callback. See the example below:
    if (transformation_callback(node)) {
        return false;  // exit from transformation
    }
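To make the callback polarity concrete: a callback that returns `true` makes the transformation exit without changing the node, so the `set_callback` example above skips rank-4 nodes and transforms all others. The sketch below models this with a hypothetical `ToyNode` that carries only a rank and a hypothetical `toy_transformation` function; a real transformation would inspect `ngraph::Node` shapes and rewrite the graph.

```cpp
#include <cassert>
#include <functional>

// Hypothetical node: only a rank and a flag recording whether it was rewritten.
struct ToyNode {
    explicit ToyNode(long r) : rank(r) {}
    long rank;
    bool rewritten = false;
};

using toy_callback = std::function<bool(const ToyNode&)>;

// Hypothetical transformation body following the pattern above:
// a callback returning true means "leave this node alone".
bool toy_transformation(ToyNode& node, const toy_callback& transformation_callback) {
    if (transformation_callback(node)) {
        return false;  // exit from transformation; the node is not modified
    }
    node.rewritten = true;  // stand-in for the real graph rewrite
    return true;            // report that the graph changed
}
```

Keeping the early-exit check as the first statement of the transformation is what lets a single `set_callback` call from the outside switch the transformation off for selected nodes.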