class ngraph::pass::GraphRewrite

Overview

GraphRewrite is a container for MatcherPasses that runs them on a Function efficiently.

#include <graph_rewrite.hpp>

class GraphRewrite: public ngraph::pass::FunctionPass
{
public:
    // construction

    GraphRewrite();
    GraphRewrite(const std::shared_ptr<MatcherPass>& pass);

    // methods

    template <
        typename T,
        bool Enabled = true,
        class... Args,
        typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true
        >
    std::shared_ptr<T> add_matcher(Args&&... args);

    template <
        typename T,
        class... Args,
        typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true
        >
    void add_matcher(Args&&... args);

    void add_matcher(
        const std::shared_ptr<pattern::Matcher>& m,
        const ngraph::graph_rewrite_callback& callback,
        const PassPropertyMask& property
        );

    void add_matcher(
        const std::shared_ptr<pattern::Matcher>& m,
        const ngraph::graph_rewrite_callback& callback
        );

    virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
    virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config);
};

// direct descendants

class BackwardGraphRewrite;
class BidirectionalSequenceDecomposition;
class ConvFusion;
class ConvertConvolutions;
class ConvertReduceToPooling;
class GeluFusion;
class HSigmoidFusion;
class HSwishFusion;
class LinOpSequenceFusion;
class MVNFusion;
class NopElimination;
class NormalizeL2Fusion;
class PadFusion;
class SwishFusion;
class TransposeSinking;

Inherited Members

public:
    // typedefs

    typedef DiscreteTypeInfo type_info_t;

    // methods

    bool get_property(const PassPropertyMask& prop_mask) const;
    void set_name(const std::string& name);
    std::string get_name() const;
    void set_callback(const param_callback& callback);
    virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config);
    std::shared_ptr<PassConfig> get_pass_config();
    bool m_transformation_callback(const std::shared_ptr<const Node>& node);
    bool transformation_callback(const std::shared_ptr<const Node>& node);
    virtual const type_info_t& get_type_info() const = 0;
    virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;

Detailed Documentation

GraphRewrite is a container for MatcherPasses that runs them on a Function efficiently.

The graph rewrite pass executes matcher passes on a Function. To register a MatcherPass, use the add_matcher<T>(args) method, where T is a MatcherPass class.

By default, the graph rewrite pass traverses the Function in topological order and applies the registered matcher passes to each node. However, if every registered matcher pass has a type-based root node in its Matcher pattern, a more efficient mechanism is used to execute them. A Matcher pattern root is type-based if it is an operation from an opset or a pattern::op::WrapType.

Note: when implementing a pattern for a Matcher, make sure that the root node is an operation from an opset or an ngraph::pattern::op::WrapType. This helps GraphRewrite execute matcher passes more efficiently.
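To illustrate why type-based roots can help, here is a minimal, self-contained sketch. It does not use nGraph and does not reproduce GraphRewrite's implementation: ToyNode, ToyMatcher, run_naive, run_indexed and demo_dispatch are hypothetical names invented for this example, and the index keyed by root type is only an assumption about how such a mechanism could work.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in, NOT an nGraph type: a node is reduced to its type name.
struct ToyNode {
    std::string type;
};

// Toy matcher: the root type it expects plus a callback that reports
// whether it changed anything.
struct ToyMatcher {
    std::string root_type;
    std::function<bool(const ToyNode&)> callback;
};

// Default strategy: offer every node to every matcher.
// Returns the number of match attempts.
std::size_t run_naive(const std::vector<ToyNode>& nodes,
                      const std::vector<ToyMatcher>& matchers) {
    std::size_t attempts = 0;
    for (const auto& node : nodes) {
        for (const auto& m : matchers) {
            ++attempts;
            if (m.root_type == node.type) m.callback(node);
        }
    }
    return attempts;
}

// Type-based strategy: index matchers by root type once, then visit only
// the matchers registered for each node's type.
std::size_t run_indexed(const std::vector<ToyNode>& nodes,
                        const std::vector<ToyMatcher>& matchers) {
    std::unordered_map<std::string, std::vector<const ToyMatcher*>> by_type;
    for (const auto& m : matchers) by_type[m.root_type].push_back(&m);

    std::size_t attempts = 0;
    for (const auto& node : nodes) {
        auto it = by_type.find(node.type);
        if (it == by_type.end()) continue;  // no matcher is rooted at this type
        for (const ToyMatcher* m : it->second) {
            ++attempts;
            m->callback(node);
        }
    }
    return attempts;
}

// Self-check: 4 nodes x 2 matchers gives 8 naive attempts, while the index
// only reaches a matcher for the two Add nodes and the one Mul node.
void demo_dispatch() {
    std::vector<ToyNode> nodes{{"Add"}, {"Mul"}, {"Relu"}, {"Add"}};
    auto noop = [](const ToyNode&) { return true; };
    std::vector<ToyMatcher> matchers{{"Add", noop}, {"Mul", noop}};
    assert(run_naive(nodes, matchers) == 8);
    assert(run_indexed(nodes, matchers) == 3);
}
```

In this toy model the saving grows with the number of registered matchers, which is the intuition behind keeping the pattern root type-based.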

Methods

template <
    typename T,
    bool Enabled = true,
    class... Args,
    typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true
    >
std::shared_ptr<T> add_matcher(Args&&... args)

Registers the given transformation class type in the GraphRewrite execution list. All registered transformations are executed in a single graph traversal. The example below shows the basic usage of pass::GraphRewrite.

pass::Manager manager;
auto anchor = manager.register_pass<GraphRewrite>();
anchor->add_matcher<MatcherPassA>();
anchor->add_matcher<MatcherPassB>();
anchor->set_name("CommonMatchers");
manager.run_passes(f);

A transformation can also be registered but left disabled by default:

anchor->add_matcher<MatcherPassB, false>();

Returns:

shared_ptr to the transformation instance
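The effect of the Enabled template argument can be sketched with a toy container. This is a hypothetical model, not the nGraph source: ToyMatcherPass, ToyPassA, ToyPassB, ToyGraphRewrite, enabled_count and demo_enabled_flag are invented names, and storing the flag on the pass object is an assumption made for brevity (the real class may keep that state elsewhere).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

// Toy stand-in for a matcher pass; the flag lives on the pass itself
// (an assumption of this sketch).
struct ToyMatcherPass {
    virtual ~ToyMatcherPass() = default;
    bool enabled = true;
};

struct ToyPassA : ToyMatcherPass {};
struct ToyPassB : ToyMatcherPass {};

class ToyGraphRewrite {
public:
    // Same template shape as the documented signature: T is the pass type,
    // Enabled defaults to true, Args are forwarded to T's constructor.
    template <typename T,
              bool Enabled = true,
              class... Args,
              typename std::enable_if<std::is_base_of<ToyMatcherPass, T>::value, bool>::type = true>
    std::shared_ptr<T> add_matcher(Args&&... args) {
        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
        pass->enabled = Enabled;     // record the compile-time flag
        m_matchers.push_back(pass);  // disabled passes stay registered
        return pass;
    }

    std::size_t size() const { return m_matchers.size(); }

    // Number of registered passes that would actually run.
    std::size_t enabled_count() const {
        std::size_t n = 0;
        for (const auto& m : m_matchers) {
            if (m->enabled) ++n;
        }
        return n;
    }

private:
    std::vector<std::shared_ptr<ToyMatcherPass>> m_matchers;
};

// Self-check: both passes are registered, only one is enabled.
void demo_enabled_flag() {
    ToyGraphRewrite rewrite;
    auto a = rewrite.add_matcher<ToyPassA>();
    auto b = rewrite.add_matcher<ToyPassB, false>();
    assert(a->enabled && !b->enabled);
    assert(rewrite.size() == 2);
    assert(rewrite.enabled_count() == 1);
}
```

The returned shared_ptr lets the caller keep a handle to the registered instance, which mirrors the documented return value.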

template <
    typename T,
    class... Args,
    typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true
    >
void add_matcher(Args&&... args)

Registers the passes from a GraphRewrite class that contains a sequence of matcher passes registered in its constructor. For example:

class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
public:
    NGRAPH_RTTI_DECLARATION;
    LinFusions() {
        add_matcher<ngraph::pass::AddFusion>();
        add_matcher<ngraph::pass::MulFusion>();
    }
};

pass::Manager manager;
auto anchor = manager.register_pass<GraphRewrite>();
anchor->add_matcher<LinFusions>();
anchor->add_matcher<OtherFusions>();
anchor->set_name("CommonFusions");
manager.run_passes(f);

In this case, all matcher passes from the LinFusions pass are merged with the other registered matchers.
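The following self-contained sketch models how the two templated add_matcher overloads can coexist and how a nested rewrite's matchers might be merged into its parent. It is a toy under stated assumptions, not the nGraph implementation: every Toy* name and demo_merge are hypothetical, and constructing a temporary T and copying its matcher list is just one plausible way to achieve the merge.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Toy stand-in for a matcher pass, identified by a name.
struct ToyMatcherPass {
    explicit ToyMatcherPass(std::string n) : name(std::move(n)) {}
    virtual ~ToyMatcherPass() = default;
    std::string name;
};

class ToyGraphRewrite {
public:
    virtual ~ToyGraphRewrite() = default;

    // Overload 1: selected by enable_if when T derives from ToyMatcherPass.
    template <typename T, class... Args,
              typename std::enable_if<std::is_base_of<ToyMatcherPass, T>::value, bool>::type = true>
    std::shared_ptr<T> add_matcher(Args&&... args) {
        auto pass = std::make_shared<T>(std::forward<Args>(args)...);
        m_matchers.push_back(pass);
        return pass;
    }

    // Overload 2: selected when T derives from ToyGraphRewrite. The nested
    // rewrite registers its matchers in its constructor; they are appended
    // to this container so that everything forms one flat list.
    template <typename T, class... Args,
              typename std::enable_if<std::is_base_of<ToyGraphRewrite, T>::value, bool>::type = true>
    void add_matcher(Args&&... args) {
        T nested(std::forward<Args>(args)...);
        for (const auto& m : nested.matchers()) m_matchers.push_back(m);
    }

    const std::vector<std::shared_ptr<ToyMatcherPass>>& matchers() const {
        return m_matchers;
    }

private:
    std::vector<std::shared_ptr<ToyMatcherPass>> m_matchers;
};

struct ToyAddFusion : ToyMatcherPass { ToyAddFusion() : ToyMatcherPass("AddFusion") {} };
struct ToyMulFusion : ToyMatcherPass { ToyMulFusion() : ToyMatcherPass("MulFusion") {} };
struct ToyOtherFusion : ToyMatcherPass { ToyOtherFusion() : ToyMatcherPass("OtherFusion") {} };

// Counterpart of LinFusions in the text: matchers are registered in the ctor.
struct ToyLinFusions : ToyGraphRewrite {
    ToyLinFusions() {
        add_matcher<ToyAddFusion>();
        add_matcher<ToyMulFusion>();
    }
};

// Self-check: the anchor ends up with one flat list of three matchers.
void demo_merge() {
    ToyGraphRewrite anchor;
    anchor.add_matcher<ToyLinFusions>();   // overload 2: merges two matchers
    anchor.add_matcher<ToyOtherFusion>();  // overload 1: adds one matcher
    assert(anchor.matchers().size() == 3);
    assert(anchor.matchers().front()->name == "AddFusion");
    assert(anchor.matchers().back()->name == "OtherFusion");
}
```

Because the two std::is_base_of conditions are mutually exclusive for these types, exactly one overload takes part in overload resolution for any given T, so the call syntax stays the same for both kinds of argument.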

virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)

Sets the PassConfig for a particular transformation instance.

Parameters:

pass_config

A shared_ptr to the PassConfig to use.