pass.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <list>
20 #include <memory>
21 #include <vector>
22 
23 #include "ngraph/deprecated.hpp"
24 #include "ngraph/function.hpp"
25 #include "ngraph/node.hpp"
26 #include "ngraph/util.hpp"
27 
28 namespace ngraph
29 {
30  namespace pass
31  {
/// Individual properties a pass can declare; each enumerator is a distinct bit
/// so that values can be combined in a PassPropertyMask.
enum class PassProperty : uint32_t
{
    /// Pass requires node shapes to be static.
    REQUIRE_STATIC_SHAPE = 1u << 0,
    /// Pass transformation will change the function's dynamic state.
    CHANGE_DYNAMIC_STATE = 1u << 1,
};
39 
40  typedef EnumMask<PassProperty> PassPropertyMask;
41  const PassPropertyMask all_pass_property_off;
42  using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
43 
44  class NGRAPH_API PassBase
45  {
46  friend class Manager;
47 
48  public:
49  PassBase();
50  virtual ~PassBase() {}
51  /// Check if this pass has all the pass properties.
52  bool get_property(const PassPropertyMask& prop_mask) const;
53 
54  void set_name(const std::string& name) { m_name = name; }
55  std::string get_name() const;
56 
57  void set_callback(const param_callback& callback);
58 
59  using type_info_t = DiscreteTypeInfo;
60 
61  virtual const type_info_t& get_type_info() const = 0;
62 
63  protected:
64  void set_property(const PassPropertyMask& prop, bool value);
65 
66  param_callback m_transformation_callback =
67  [](const std::shared_ptr<const Node>&) -> bool { return false; };
68  bool m_has_default_callback = true;
69 
70  private:
71  PassPropertyMask m_property;
72  std::string m_name;
73  };
74 
75  class NGRAPH_API FunctionPass : public PassBase
76  {
77  public:
78  NGRAPH_RTTI_DECLARATION;
79  virtual ~FunctionPass();
80  virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
81  };
82 
83  class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass
84  : public PassBase
85  {
86  public:
87  NGRAPH_RTTI_DECLARATION;
88  virtual ~NodePass();
89  virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
90  };
91 
// Forward declaration; Manager is defined elsewhere.
class Manager;

/// Categories of fusion passes. Enumerators are distinct bits (ALL_FUSIONS has
/// every bit set) so they can be combined in a FusionTypeMask.
enum class FusionType : uint32_t
{
    /// `DIFFERENTIABLE_FUSIONS` produce ops that support autodiff,
    /// i.e. ops that implement `generate_adjoints`.
    DIFFERENTIABLE_FUSIONS = 1u << 0,
    REGULAR_FUSIONS = 1u << 1,
    /// `FOP_FUSIONS` produce ops in the FusedOps category that might not be
    /// supported by all backends.
    FOP_FUSIONS = 1u << 2,
    ALL_FUSIONS = 0xFFFFFFFFu
};
104  typedef EnumMask<FusionType> FusionTypeMask;
105  }
106 }
ngraph::pass::PassBase::get_property
bool get_property(const PassPropertyMask &prop_mask) const
Check if this pass has all the pass properties.
ngraph::set_callback
void set_callback(param_callback callback)
Callback is a lambda function that can be used by registered transformations. The main purpose of thi...
Definition: manager.hpp:77
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::pass::PassBase
Definition: pass.hpp:45
ngraph::pass::FunctionPass
Definition: pass.hpp:76