pass.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <list>
8 #include <memory>
9 #include <vector>
10 
11 #include "ngraph/deprecated.hpp"
12 #include "ngraph/function.hpp"
13 #include "ngraph/node.hpp"
14 #include "ngraph/pass/pass_config.hpp"
15 #include "ngraph/util.hpp"
16 
17 namespace ngraph
18 {
19  namespace pass
20  {
21  enum class PassProperty : uint32_t
22  {
23  // Pass requires node shapes to be static
24  REQUIRE_STATIC_SHAPE = 0x1,
25  // Pass transformation will change the function's dynamic state
26  CHANGE_DYNAMIC_STATE = 1 << 1,
27  };
28 
29  typedef EnumMask<PassProperty> PassPropertyMask;
30  const PassPropertyMask all_pass_property_off;
31 
32  class NGRAPH_API PassBase
33  {
34  friend class Manager;
35 
36  public:
37  PassBase();
38  virtual ~PassBase() {}
39  /// Check if this pass has all the pass properties.
40  bool get_property(const PassPropertyMask& prop_mask) const;
41 
42  void set_name(const std::string& name) { m_name = name; }
43  std::string get_name() const;
44 
45  /// \brief Set callback for particular transformation type.
46  /// This method set global callback. For more details see PassConfig class
47  /// documentation.
48  /// \param callback lambda function that takes node and returns bool
49  void set_callback(const param_callback& callback);
50 
51  /// \brief Set PassConfig for particular transformation instance
52  /// \param pass_config is a PassConfig shared_ptr
53  virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
54  {
55  m_pass_config = pass_config;
56  }
57 
58  /// \brief Allows to access PassConfig shared instance
59  /// \return Shared instance of PassConfig class
60  std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
61  /// \brief Applies callback for given node. By default callback returns false.
62  /// This method remains here only for backward compatibility and will be removed
63  /// after all transformations are moved to transformation_callback() method.
64  /// \return result of callback execution for given node
65  NGRAPH_DEPRECATED("Please use transformation_callback method instead")
66  bool m_transformation_callback(const std::shared_ptr<const Node>& node)
67  {
68  return m_pass_config->get_callback(get_type_info())(node);
69  }
70 
71  /// \brief Applies callback for given node. By default callback returns false.
72  /// \param node which will be used inside callback
73  /// \return result of callback execution for given node
74  bool transformation_callback(const std::shared_ptr<const Node>& node)
75  {
76  return m_pass_config->get_callback(get_type_info())(node);
77  }
78 
79  using type_info_t = DiscreteTypeInfo;
80 
81  virtual const type_info_t& get_type_info() const = 0;
82 
83  protected:
84  void set_property(const PassPropertyMask& prop, bool value);
85 
86  private:
87  PassPropertyMask m_property;
88 
89  std::string m_name;
90  std::shared_ptr<PassConfig> m_pass_config;
91  };
92 
93  class NGRAPH_API FunctionPass : public PassBase
94  {
95  public:
96  NGRAPH_RTTI_DECLARATION;
97  virtual ~FunctionPass();
98  virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
99  };
100 
101  class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass
102  : public PassBase
103  {
104  public:
105  NGRAPH_RTTI_DECLARATION;
106  virtual ~NodePass();
107  virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
108  };
109 
110  class Manager;
111  enum class FusionType : uint32_t
112  {
113  //`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
114  // i.e. implement `generate_adjoints`
115  DIFFERENTIABLE_FUSIONS = 0x1,
116  REGULAR_FUSIONS = 0x2,
117  //`FOP_FUSIONS` produce ops in the FusedOps category that might
118  // not be supported by all backends
119  FOP_FUSIONS = 0x4,
120  ALL_FUSIONS = 0xFFFFFFFF
121  };
122  typedef EnumMask<FusionType> FusionTypeMask;
123  } // namespace pass
124 } // namespace ngraph
Definition: node.hpp:127
Definition: pass.hpp:94
Definition: manager.hpp:20
MatcherPass is a basic building block for pattern-based transformations. It describes a pattern and an action that ...
Definition: graph_rewrite.hpp:48
Definition: pass.hpp:33
void set_callback(const param_callback &callback)
Set callback for particular transformation type. This method sets a global callback. For more details see the PassConfig class documentation.
bool get_property(const PassPropertyMask &prop_mask) const
Check if this pass has all the pass properties.
std::shared_ptr< PassConfig > get_pass_config()
Provides access to the shared PassConfig instance.
Definition: pass.hpp:60
virtual void set_pass_config(const std::shared_ptr< PassConfig > &pass_config)
Set PassConfig for particular transformation instance.
Definition: pass.hpp:53
bool transformation_callback(const std::shared_ptr< const Node > &node)
Applies callback for given node. By default callback returns false.
Definition: pass.hpp:74
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27