pass.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <list>
20 #include <memory>
21 #include <vector>
22 
23 #include "ngraph/deprecated.hpp"
24 #include "ngraph/function.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/pass/pass_config.hpp"
27 #include "ngraph/util.hpp"
28 
29 namespace ngraph
30 {
31  namespace pass
32  {
/// Flags describing a pass. Each enumerator occupies its own bit so that
/// several of them can be combined in a single mask.
enum class PassProperty : uint32_t
{
    /// Pass requires node shapes to be static
    REQUIRE_STATIC_SHAPE = 1u << 0,
    /// Pass transformation will change the function's dynamic state
    CHANGE_DYNAMIC_STATE = 1u << 1,
};
40 
41  typedef EnumMask<PassProperty> PassPropertyMask;
42  const PassPropertyMask all_pass_property_off;
43 
44  class NGRAPH_API PassBase
45  {
46  friend class Manager;
47 
48  public:
49  PassBase();
50  virtual ~PassBase() {}
51  /// Check if this pass has all the pass properties.
52  bool get_property(const PassPropertyMask& prop_mask) const;
53 
54  void set_name(const std::string& name) { m_name = name; }
55  std::string get_name() const;
56 
57  /// \brief Set callback for particular transformation type.
58  /// This method set global callback. For more details see PassConfig class
59  /// documentation.
60  /// \param callback lambda function that takes node and returns bool
61  void set_callback(const param_callback& callback);
62 
63  /// \brief Set PassConfig for particular transformation instance
64  /// \param pass_config is a PassConfig shared_ptr
65  virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config)
66  {
67  m_pass_config = pass_config;
68  }
69 
70  /// \brief Allows to access PassConfig shared instance
71  /// \return Shared instance of PassConfig class
72  std::shared_ptr<PassConfig> get_pass_config() { return m_pass_config; }
73  /// \brief Applies callback for given node. By default callback returns false.
74  /// This method remains here only for backward compatibility and will be removed
75  /// after all transformations are moved to transformation_callback() method.
76  /// \return result of callback execution for given node
77  NGRAPH_DEPRECATED("Please use transformation_callback method instead")
78  bool m_transformation_callback(const std::shared_ptr<const Node>& node)
79  {
80  return m_pass_config->get_callback(get_type_info())(node);
81  }
82 
83  /// \brief Applies callback for given node. By default callback returns false.
84  /// \param node which will be used inside callback
85  /// \return result of callback execution for given node
86  bool transformation_callback(const std::shared_ptr<const Node>& node)
87  {
88  return m_pass_config->get_callback(get_type_info())(node);
89  }
90 
91  using type_info_t = DiscreteTypeInfo;
92 
93  virtual const type_info_t& get_type_info() const = 0;
94 
95  protected:
96  void set_property(const PassPropertyMask& prop, bool value);
97 
98  private:
99  PassPropertyMask m_property;
100 
101  std::string m_name;
102  std::shared_ptr<PassConfig> m_pass_config;
103  };
104 
105  class NGRAPH_API FunctionPass : public PassBase
106  {
107  public:
108  NGRAPH_RTTI_DECLARATION;
109  virtual ~FunctionPass();
110  virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
111  };
112 
113  class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass
114  : public PassBase
115  {
116  public:
117  NGRAPH_RTTI_DECLARATION;
118  virtual ~NodePass();
119  virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
120  };
121 
// Forward declaration of the pass manager.
class Manager;

/// Fusion categories. The first three enumerators each occupy a distinct
/// bit; ALL_FUSIONS has every bit of the underlying type set.
enum class FusionType : uint32_t
{
    /// `DIFFERENTIABLE_FUSIONS` produce ops that support autodiff,
    /// i.e. implement `generate_adjoints`
    DIFFERENTIABLE_FUSIONS = 1u << 0,
    REGULAR_FUSIONS = 1u << 1,
    /// `FOP_FUSIONS` produce ops in the FusedOps category that might
    /// not be supported by all backends
    FOP_FUSIONS = 1u << 2,
    ALL_FUSIONS = 0xFFFFFFFFu
};
134  typedef EnumMask<FusionType> FusionTypeMask;
135  }
136 }
Definition: node.hpp:132
Definition: pass.hpp:106
Definition: manager.hpp:32
MatcherPass is a basic block for pattern based transformations. It describes pattern and action that ...
Definition: graph_rewrite.hpp:60
Definition: pass.hpp:45
void set_callback(const param_callback &callback)
Set callback for particular transformation type. This method sets a global callback. For more details see the PassConfig class documentation.
bool get_property(const PassPropertyMask &prop_mask) const
Check if this pass has all the pass properties.
std::shared_ptr< PassConfig > get_pass_config()
Allows to access PassConfig shared instance.
Definition: pass.hpp:72
virtual void set_pass_config(const std::shared_ptr< PassConfig > &pass_config)
Set PassConfig for particular transformation instance.
Definition: pass.hpp:65
bool transformation_callback(const std::shared_ptr< const Node > &node)
Applies callback for given node. By default callback returns false.
Definition: pass.hpp:86
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39