manager.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <list>
8 #include <memory>
9 #include <typeinfo>
10 #include <vector>
11 
12 #include "ngraph/pass/pass.hpp"
13 #include "ngraph/pass/validate.hpp"
14 
15 namespace ngraph
16 {
17  namespace pass
18  {
19  class NGRAPH_API Manager
20  {
21  public:
22  Manager();
23  ~Manager();
24 
25  //// \brief Construct Manager with shared PassConfig instance
26  explicit Manager(std::shared_ptr<PassConfig> pass_config);
27 
28  /// \brief Register given transformation class type to execution list
29  /// Example below show the basic usage of pass::Manager
30  ///
31  /// pass::Manager manager;
32  /// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
33  /// manager.run_passes(f);
34  ///
35  /// For some purposes transformation can be registered and disabled by default.
36  ///
37  /// manager.register_pass<MyTransformation, false>();
38  ///
39  /// \return shared_ptr to the transformation instance
40  template <typename T, bool Enable = true, class... Args>
41  std::shared_ptr<T> register_pass(Args&&... args)
42  {
43  auto rc = push_pass<T>(std::forward<Args>(args)...);
44  rc->set_pass_config(m_pass_config);
45  if (m_per_pass_validation)
46  {
47  push_pass<Validate>();
48  }
49  if (!Enable && !m_pass_config->is_enabled<T>())
50  {
51  m_pass_config->disable<T>();
52  }
53  return rc;
54  }
55 
56  void run_passes(std::shared_ptr<Function>);
57 
58  void set_pass_visualization(bool new_state) { m_visualize = new_state; }
59  /// \brief Set flag to enable/disable running Validate pass after executing
60  /// each registered pass
61  /// \param new_state Value "true" enables Validate pass run; "false", otherwise
62  void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
63  /// \brief Callback is a lambda function that can be used by registered transformations.
64  /// The main purpose of this callback is to provide a way for plugins to disable/enable
65  /// transformations based on some conditions. In some cases plugins may want not to
66  /// execute some
67  /// transformations.
68  /// For example plugin can disable unpleasant decompositions because of performance
69  /// reasons for
70  /// some cases.
71  /// Callback example:
72  /// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
73  /// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
74  /// nullptr;
75  /// };
76  /// This callback returns true in case of DepthToSpace operation. So when execution
77  /// DepthToSpace
78  /// decomposition pass will check is this decomposition needed or plugin can execute
79  /// this
80  /// operation directly. And of course on transformation side we need to have a response
81  /// for this
82  /// callback.
83  /// if (transformation_callback(batch_to_space)) {
84  /// return false;
85  /// }
86  /// \param callback lamda function that returns true in case if node is supported by
87  /// plugin and
88  /// transformation is not needed
89  NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
90  void set_callback(const param_callback& callback)
91  {
92  m_pass_config->set_callback(callback);
93  }
94  /// \return PassConfig shared object. This object is used for transformations pipeline
95  /// configuration.
96  /// This object allows to disable/enable transformations execution, set callback to
97  /// particular
98  /// transformation. For mo details see PassConfig class.
99  std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
100 
101  protected:
102  template <typename T, class... Args>
103  std::shared_ptr<T> push_pass(Args&&... args)
104  {
105  static_assert(std::is_base_of<pass::PassBase, T>::value,
106  "pass not derived from pass base");
107  auto pass = std::make_shared<T>(std::forward<Args>(args)...);
108  auto pass_base = std::static_pointer_cast<PassBase>(pass);
109  m_pass_list.push_back(pass_base);
110  return pass;
111  }
112 
113  std::shared_ptr<PassConfig> m_pass_config;
114  std::vector<std::shared_ptr<PassBase>> m_pass_list;
115  bool m_visualize = false;
116  bool m_per_pass_validation = true;
117  };
118  } // namespace pass
119 } // namespace ngraph
Definition: manager.hpp:20
std::shared_ptr< T > register_pass(Args &&... args)
Register given transformation class type to execution list. The example below shows the basic usage of pass...
Definition: manager.hpp:41
std::shared_ptr< PassConfig > get_pass_config()
Definition: manager.hpp:99
void set_per_pass_validation(bool new_state)
Set flag to enable/disable running Validate pass after executing each registered pass.
Definition: manager.hpp:62
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16