pass_config.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <list>
8 #include <memory>
9 #include <vector>
10 
11 #include "ngraph/deprecated.hpp"
12 #include "ngraph/function.hpp"
13 #include "ngraph/node.hpp"
14 #include "ngraph/util.hpp"
15 
16 namespace ngraph
17 {
18  namespace pass
19  {
20  using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
21  using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
22 
23  /// \brief Class representing a transformations config that is used for disabling/enabling
24  /// transformations registered inside pass::Manager and also allows to set callback for all
25  /// transformations or for particular transformation.
26  ///
27  /// When pass::Manager is created all passes registered inside this manager including nested
28  /// passes will share the same instance of PassConfig class.
29  /// To work with this class first you need to get shared instance of this class by calling
30  /// manager.get_pass_config() method. Then you will be able to disable/enable passes based
31  /// on transformations type_info. For example:
32  ///
33  /// pass::Manager manager;
34  /// manager.register_pass<CommonOptimizations>();
35  /// auto pass_config = manager.get_pass_config();
36  /// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
37  /// // CommonOptimizations pipeline
38  /// manager.run_passes(f);
39  ///
40  /// Sometimes it is needed to call transformation inside other transformation manually. And
41  /// for that case before running transformation you need manually check that this pass is
42  /// not disabled and then you need to set current PassConfig instance to this
43  /// transformation. For example:
44  ///
45  /// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
46  /// // you need to call get_pass_config() method to get shared instance of PassConfig
47  /// auto pass_config = get_pass_config();
48  ///
49  /// // Before running nested transformation you need to check is it disabled or not
50  /// if (!pass_config->is_disabled<ConvertGELU>()) {
51  /// auto pass = ConvertGELU();
52  /// pass->set_pass_config(pass_config);
53  /// pass.apply(node);
54  /// }
55  ///
56  /// Following this logic inside your transformations you will guaranty that transformations
57  /// will be executed in a right way.
58  class NGRAPH_API PassConfig
59  {
60  public:
61  /// \brief Disable transformation by its type_info
62  /// \param type_info Transformation type_info
63  void disable(const DiscreteTypeInfo& type_info);
64  /// \brief Disable transformation by its class type (based on type_info)
65  template <typename T>
66  void disable()
67  {
68  NGRAPH_SUPPRESS_DEPRECATED_START
69  disable(T::type_info);
70  NGRAPH_SUPPRESS_DEPRECATED_END
71  }
72 
73  /// \brief Enable transformation by its type_info
74  /// \param type_info Transformation type_info
75  void enable(const DiscreteTypeInfo& type_info);
76  /// \brief Enable transformation by its class type (based on type_info)
77  template <typename T>
78  void enable()
79  {
80  NGRAPH_SUPPRESS_DEPRECATED_START
81  enable(T::type_info);
82  NGRAPH_SUPPRESS_DEPRECATED_END
83  }
84 
85  /// \brief Set callback for all kind of transformations
86  void set_callback(const param_callback& callback) { m_callback = callback; }
87  template <typename... Args>
88  typename std::enable_if<sizeof...(Args) == 0>::type
89  set_callback(const param_callback& callback)
90  {
91  }
92 
93  /// \brief Set callback for particular transformation class types
94  ///
95  /// Example below show how to set callback for one or multiple passes using this method.
96  ///
97  /// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
98  /// ngraph::pass::ConvertSpaceToBatch>(
99  /// [](const_node_ptr &node) -> bool {
100  /// // Disable transformations for cases when input shape rank is not
101  /// equal to 4
102  /// const auto input_shape_rank =
103  /// node->get_output_partial_shape(0).rank().get_length();
104  /// if (input_shape_rank != 4) {
105  /// return false;
106  /// }
107  /// return true;
108  /// });
109  ///
110  /// Note that inside transformations you must provide code that work with this callback.
111  /// See example below:
112  ///
113  /// if (transformation_callback(node)) {
114  /// return false; // exit from transformation
115  /// }
116  ///
117  template <typename T, class... Args>
118  void set_callback(const param_callback& callback)
119  {
120  m_callback_map[T::type_info] = callback;
121  set_callback<Args...>(callback);
122  }
123 
124  /// \brief Get callback for given transformation type_info
125  /// \param type_info Transformation type_info
126  ///
127  /// In case if callback wasn't set for given transformation type then global callback
128  /// will be returned. But if even global callback wasn't set then default callback will
129  /// be returned.
130  param_callback get_callback(const DiscreteTypeInfo& type_info) const;
131 
132  /// \brief Get callback for given transformation class type
133  /// \return callback lambda function
134  template <typename T>
135  param_callback get_callback() const
136  {
137  NGRAPH_SUPPRESS_DEPRECATED_START
138  return get_callback(T::type_info);
139  NGRAPH_SUPPRESS_DEPRECATED_END
140  }
141 
142  /// \brief Check either transformation type is disabled or not
143  /// \param type_info Transformation type_info
144  /// \return true if transformation type was disabled and false otherwise
145  bool is_disabled(const DiscreteTypeInfo& type_info) const
146  {
147  return m_disabled.count(type_info);
148  }
149 
150  /// \brief Check either transformation class type is disabled or not
151  /// \return true if transformation type was disabled and false otherwise
152  template <typename T>
153  bool is_disabled() const
154  {
155  NGRAPH_SUPPRESS_DEPRECATED_START
156  return is_disabled(T::type_info);
157  NGRAPH_SUPPRESS_DEPRECATED_END
158  }
159 
160  /// \brief Check either transformation type is force enabled or not
161  /// \param type_info Transformation type_info
162  /// \return true if transformation type was force enabled and false otherwise
163  bool is_enabled(const DiscreteTypeInfo& type_info) const
164  {
165  return m_enabled.count(type_info);
166  }
167 
168  /// \brief Check either transformation class type is force enabled or not
169  /// \return true if transformation type was force enabled and false otherwise
170  template <typename T>
171  bool is_enabled() const
172  {
173  return is_enabled(T::type_info);
174  }
175 
176  void add_disabled_passes(const PassConfig& rhs);
177 
178  private:
179  param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
180  return false;
181  };
182  param_callback_map m_callback_map;
183  std::unordered_set<DiscreteTypeInfo> m_disabled;
184  std::unordered_set<DiscreteTypeInfo> m_enabled;
185  };
186  } // namespace pass
187 } // namespace ngraph
Class representing a transformations config that is used for disabling/enabling transformations regis...
Definition: pass_config.hpp:59
param_callback get_callback(const DiscreteTypeInfo &type_info) const
Get callback for given transformation type_info.
bool is_enabled(const DiscreteTypeInfo &type_info) const
Check whether transformation type is force enabled or not.
Definition: pass_config.hpp:163
void disable(const DiscreteTypeInfo &type_info)
Disable transformation by its type_info.
param_callback get_callback() const
Get callback for given transformation class type.
Definition: pass_config.hpp:135
bool is_disabled() const
Check whether transformation class type is disabled or not.
Definition: pass_config.hpp:153
void enable()
Enable transformation by its class type (based on type_info)
Definition: pass_config.hpp:78
void set_callback(const param_callback &callback)
Set callback for particular transformation class types.
Definition: pass_config.hpp:118
void set_callback(const param_callback &callback)
Set callback for all kind of transformations.
Definition: pass_config.hpp:86
void enable(const DiscreteTypeInfo &type_info)
Enable transformation by its type_info.
void disable()
Disable transformation by its class type (based on type_info)
Definition: pass_config.hpp:66
bool is_enabled() const
Check whether transformation class type is force enabled or not.
Definition: pass_config.hpp:171
bool is_disabled(const DiscreteTypeInfo &type_info) const
Check whether transformation type is disabled or not.
Definition: pass_config.hpp:145
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27