manager.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <list>
20 #include <memory>
21 #include <typeinfo>
22 #include <vector>
23 
24 #include "ngraph/pass/pass.hpp"
25 #include "ngraph/pass/validate.hpp"
26 
27 namespace ngraph
28 {
29  namespace pass
30  {
31  class NGRAPH_API Manager
32  {
33  public:
34  Manager();
35  ~Manager();
36 
37  //// \brief Construct Manager with shared PassConfig instance
38  explicit Manager(std::shared_ptr<PassConfig> pass_config);
39 
40  /// \brief Register given transformation class type to execution list
41  /// Example below show the basic usage of pass::Manager
42  ///
43  /// pass::Manager manager;
44  /// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
45  /// manager.run_passes(f);
46  ///
47  /// For some purposes transformation can be registered and disabled by default.
48  ///
49  /// manager.register_pass<MyTransformation, false>();
50  ///
51  /// \return shared_ptr to the transformation instance
52  template <typename T, bool Enable = true, class... Args>
53  std::shared_ptr<T> register_pass(Args&&... args)
54  {
55  auto rc = push_pass<T>(std::forward<Args>(args)...);
56  rc->set_pass_config(m_pass_config);
57  if (m_per_pass_validation)
58  {
59  push_pass<Validate>();
60  }
61  if (!Enable && !m_pass_config->is_enabled<T>())
62  {
63  m_pass_config->disable<T>();
64  }
65  return rc;
66  }
67 
68  void run_passes(std::shared_ptr<Function>);
69 
70  void set_pass_visualization(bool new_state) { m_visualize = new_state; }
71  /// \brief Set flag to enable/disable running Validate pass after executing
72  /// each registered pass
73  /// \param new_state Value "true" enables Validate pass run; "false", otherwise
74  void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
75  /// \brief Callback is a lambda function that can be used by registered transformations.
76  /// The main purpose of this callback is to provide a way for plugins to disable/enable
77  /// transformations based on some conditions. In some cases plugins may want not to
78  /// execute some
79  /// transformations.
80  /// For example plugin can disable unpleasant decompositions because of performance
81  /// reasons for
82  /// some cases.
83  /// Callback example:
84  /// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
85  /// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
86  /// nullptr;
87  /// };
88  /// This callback returns true in case of DepthToSpace operation. So when execution
89  /// DepthToSpace
90  /// decomposition pass will check is this decomposition needed or plugin can execute
91  /// this
92  /// operation directly. And of course on transformation side we need to have a response
93  /// for this
94  /// callback.
95  /// if (transformation_callback(batch_to_space)) {
96  /// return false;
97  /// }
98  /// \param callback lamda function that returns true in case if node is supported by
99  /// plugin and
100  /// transformation is not needed
101  NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
102  void set_callback(const param_callback& callback)
103  {
104  m_pass_config->set_callback(callback);
105  }
106  /// \return PassConfig shared object. This object is used for transformations pipeline
107  /// configuration.
108  /// This object allows to disable/enable transformations execution, set callback to
109  /// particular
110  /// transformation. For mo details see PassConfig class.
111  std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
112  protected:
113  template <typename T, class... Args>
114  std::shared_ptr<T> push_pass(Args&&... args)
115  {
116  static_assert(std::is_base_of<pass::PassBase, T>::value,
117  "pass not derived from pass base");
118  auto pass = std::make_shared<T>(std::forward<Args>(args)...);
119  auto pass_base = std::static_pointer_cast<PassBase>(pass);
120  m_pass_list.push_back(pass_base);
121  return pass;
122  }
123 
124  std::shared_ptr<PassConfig> m_pass_config;
125  std::vector<std::shared_ptr<PassBase>> m_pass_list;
126  bool m_visualize = false;
127  bool m_per_pass_validation = true;
128  };
129  }
130 }
Definition: manager.hpp:32
std::shared_ptr< T > register_pass(Args &&... args)
Register given transformation class type to execution list. Example below shows the basic usage of pass...
Definition: manager.hpp:53
std::shared_ptr< PassConfig > get_pass_config()
Definition: manager.hpp:111
void set_per_pass_validation(bool new_state)
Set flag to enable/disable running Validate pass after executing each registered pass.
Definition: manager.hpp:74
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28