pass_config.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <list>
20 #include <memory>
21 #include <vector>
22 
23 #include "ngraph/deprecated.hpp"
24 #include "ngraph/function.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/util.hpp"
27 
28 namespace ngraph
29 {
30  namespace pass
31  {
32  using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
33  using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
34 
35  /// \brief Class representing a transformations config that is used for disabling/enabling
36  /// transformations registered inside pass::Manager and also allows to set callback for all
37  /// transformations or for particular transformation.
38  ///
39  /// When pass::Manager is created all passes registered inside this manager including nested
40  /// passes will share the same instance of PassConfig class.
41  /// To work with this class first you need to get shared instance of this class by calling
42  /// manager.get_pass_config() method. Then you will be able to disable/enable passes based
43  /// on transformations type_info. For example:
44  ///
45  /// pass::Manager manager;
46  /// manager.register_pass<CommonOptimizations>();
47  /// auto pass_config = manager.get_pass_config();
48  /// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
49  /// // CommonOptimizations pipeline
50  /// manager.run_passes(f);
51  ///
52  /// Sometimes it is needed to call transformation inside other transformation manually. And
53  /// for that case before running transformation you need manually check that this pass is
54  /// not disabled and then you need to set current PassConfig instance to this
55  /// transformation. For example:
56  ///
57  /// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
58  /// // you need to call get_pass_config() method to get shared instance of PassConfig
59  /// auto pass_config = get_pass_config();
60  ///
61  /// // Before running nested transformation you need to check is it disabled or not
62  /// if (!pass_config->is_disabled<ConvertGELU>()) {
63  /// auto pass = ConvertGELU();
64  /// pass->set_pass_config(pass_config);
65  /// pass.apply(node);
66  /// }
67  ///
68  /// Following this logic inside your transformations you will guaranty that transformations
69  /// will be executed in a right way.
70  class NGRAPH_API PassConfig
71  {
72  public:
73  /// \brief Disable transformation by its type_info
74  /// \param type_info Transformation type_info
75  void disable(const DiscreteTypeInfo& type_info);
76  /// \brief Disable transformation by its class type (based on type_info)
77  template <typename T>
78  void disable()
79  {
80  NGRAPH_SUPPRESS_DEPRECATED_START
81  disable(T::type_info);
82  NGRAPH_SUPPRESS_DEPRECATED_END
83  }
84 
85  /// \brief Enable transformation by its type_info
86  /// \param type_info Transformation type_info
87  void enable(const DiscreteTypeInfo& type_info);
88  /// \brief Enable transformation by its class type (based on type_info)
89  template <typename T>
90  void enable()
91  {
92  NGRAPH_SUPPRESS_DEPRECATED_START
93  enable(T::type_info);
94  NGRAPH_SUPPRESS_DEPRECATED_END
95  }
96 
97  /// \brief Set callback for all kind of transformations
98  void set_callback(const param_callback& callback) { m_callback = callback; }
99  template <typename... Args>
100  typename std::enable_if<sizeof...(Args) == 0>::type
101  set_callback(const param_callback& callback)
102  {
103  }
104 
105  /// \brief Set callback for particular transformation class types
106  ///
107  /// Example below show how to set callback for one or multiple passes using this method.
108  ///
109  /// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
110  /// ngraph::pass::ConvertSpaceToBatch>(
111  /// [](const_node_ptr &node) -> bool {
112  /// // Disable transformations for cases when input shape rank is not
113  /// equal to 4
114  /// const auto input_shape_rank =
115  /// node->get_output_partial_shape(0).rank().get_length();
116  /// if (input_shape_rank != 4) {
117  /// return false;
118  /// }
119  /// return true;
120  /// });
121  ///
122  /// Note that inside transformations you must provide code that work with this callback.
123  /// See example below:
124  ///
125  /// if (transformation_callback(node)) {
126  /// return false; // exit from transformation
127  /// }
128  ///
129  template <typename T, class... Args>
130  void set_callback(const param_callback& callback)
131  {
132  m_callback_map[T::type_info] = callback;
133  set_callback<Args...>(callback);
134  }
135 
136  /// \brief Get callback for given transformation type_info
137  /// \param type_info Transformation type_info
138  ///
139  /// In case if callback wasn't set for given transformation type then global callback
140  /// will be returned. But if even global callback wasn't set then default callback will
141  /// be returned.
142  param_callback get_callback(const DiscreteTypeInfo& type_info) const;
143 
144  /// \brief Get callback for given transformation class type
145  /// \return callback lambda function
146  template <typename T>
147  param_callback get_callback() const
148  {
149  NGRAPH_SUPPRESS_DEPRECATED_START
150  return get_callback(T::type_info);
151  NGRAPH_SUPPRESS_DEPRECATED_END
152  }
153 
154  /// \brief Check either transformation type is disabled or not
155  /// \param type_info Transformation type_info
156  /// \return true if transformation type was disabled and false otherwise
157  bool is_disabled(const DiscreteTypeInfo& type_info) const
158  {
159  return m_disabled.count(type_info);
160  }
161 
162  /// \brief Check either transformation class type is disabled or not
163  /// \return true if transformation type was disabled and false otherwise
164  template <typename T>
165  bool is_disabled() const
166  {
167  NGRAPH_SUPPRESS_DEPRECATED_START
168  return is_disabled(T::type_info);
169  NGRAPH_SUPPRESS_DEPRECATED_END
170  }
171 
172  /// \brief Check either transformation type is force enabled or not
173  /// \param type_info Transformation type_info
174  /// \return true if transformation type was force enabled and false otherwise
175  bool is_enabled(const DiscreteTypeInfo& type_info) const
176  {
177  return m_enabled.count(type_info);
178  }
179 
180  /// \brief Check either transformation class type is force enabled or not
181  /// \return true if transformation type was force enabled and false otherwise
182  template <typename T>
183  bool is_enabled() const
184  {
185  return is_enabled(T::type_info);
186  }
187 
188  void add_disabled_passes(const PassConfig& rhs);
189 
190  private:
191  param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
192  return false;
193  };
194  param_callback_map m_callback_map;
195  std::unordered_set<DiscreteTypeInfo> m_disabled;
196  std::unordered_set<DiscreteTypeInfo> m_enabled;
197  };
198  }
199 }
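Class representing a transformations config that is used for disabling/enabling transformations registered inside pass::Manager, and that also allows setting a callback for all transformations or for a particular transformation.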
Definition: pass_config.hpp:71
param_callback get_callback(const DiscreteTypeInfo &type_info) const
Get callback for given transformation type_info.
bool is_enabled(const DiscreteTypeInfo &type_info) const
Check whether transformation type is force enabled or not.
Definition: pass_config.hpp:175
void disable(const DiscreteTypeInfo &type_info)
Disable transformation by its type_info.
param_callback get_callback() const
Get callback for given transformation class type.
Definition: pass_config.hpp:147
bool is_disabled() const
Check whether transformation class type is disabled or not.
Definition: pass_config.hpp:165
void enable()
Enable transformation by its class type (based on type_info)
Definition: pass_config.hpp:90
void set_callback(const param_callback &callback)
Set callback for particular transformation class types.
Definition: pass_config.hpp:130
void set_callback(const param_callback &callback)
Set callback for all kinds of transformations.
Definition: pass_config.hpp:98
void enable(const DiscreteTypeInfo &type_info)
Enable transformation by its type_info.
void disable()
Disable transformation by its class type (based on type_info)
Definition: pass_config.hpp:78
bool is_enabled() const
Check whether transformation class type is force enabled or not.
Definition: pass_config.hpp:183
bool is_disabled(const DiscreteTypeInfo &type_info) const
Check whether transformation type is disabled or not.
Definition: pass_config.hpp:157
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39