factory.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <functional>
20 #include <mutex>
21 #include <unordered_map>
22 
23 #include "ngraph/ngraph_visibility.hpp"
24 
25 namespace ngraph
26 {
27  NGRAPH_API std::mutex& get_registry_mutex();
28 
29  /// \brief Registry of factories that can construct objects derived from BASE_TYPE
30  template <typename BASE_TYPE>
32  {
33  public:
34  using Factory = std::function<BASE_TYPE*()>;
35  using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
36 
37  // \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
38  template <typename DERIVED_TYPE>
39  static Factory get_default_factory()
40  {
41  return []() { return new DERIVED_TYPE(); };
42  }
43 
44  /// \brief Register a custom factory for type_info
45  void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory)
46  {
47  std::lock_guard<std::mutex> guard(get_registry_mutex());
48  m_factory_map[type_info] = factory;
49  }
50 
51  /// \brief Register a custom factory for DERIVED_TYPE
52  template <typename DERIVED_TYPE>
53  void register_factory(Factory factory)
54  {
55  register_factory(DERIVED_TYPE::type_info, factory);
56  }
57 
58  /// \brief Register the defualt constructor factory for DERIVED_TYPE
59  template <typename DERIVED_TYPE>
61  {
62  register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
63  }
64 
65  /// \brief Check to see if a factory is registered
66  bool has_factory(const typename BASE_TYPE::type_info_t& info)
67  {
68  std::lock_guard<std::mutex> guard(get_registry_mutex());
69  return m_factory_map.find(info) != m_factory_map.end();
70  }
71 
72  /// \brief Check to see if DERIVED_TYPE has a registered factory
73  template <typename DERIVED_TYPE>
74  bool has_factory()
75  {
76  return has_factory(DERIVED_TYPE::type_info);
77  }
78 
79  /// \brief Create an instance for type_info
80  BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const
81  {
82  std::lock_guard<std::mutex> guard(get_registry_mutex());
83  auto it = m_factory_map.find(type_info);
84  return it == m_factory_map.end() ? nullptr : it->second();
85  }
86 
87  /// \brief Create an instance using factory for DERIVED_TYPE
88  template <typename DERIVED_TYPE>
89  BASE_TYPE* create() const
90  {
91  return create(DERIVED_TYPE::type_info);
92  }
93 
94  /// \brief Get the factory for BASE_TYPE
96 
97  protected:
98  FactoryMap m_factory_map;
99  };
100 }
ngraph::FactoryRegistry
Registry of factories that can construct objects derived from BASE_TYPE.
Definition: factory.hpp:32
ngraph::FactoryRegistry::get
static FactoryRegistry< BASE_TYPE > & get()
Get the factory for BASE_TYPE.
ngraph::FactoryRegistry::create
BASE_TYPE * create() const
Create an instance using factory for DERIVED_TYPE.
Definition: factory.hpp:89
ngraph::FactoryRegistry::register_factory
void register_factory(Factory factory)
Register a custom factory for DERIVED_TYPE.
Definition: factory.hpp:53
ngraph::FactoryRegistry::register_factory
void register_factory()
Register the default constructor factory for DERIVED_TYPE.
Definition: factory.hpp:60
ngraph::FactoryRegistry::has_factory
bool has_factory(const typename BASE_TYPE::type_info_t &info)
Check to see if a factory is registered.
Definition: factory.hpp:66
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::FactoryRegistry::register_factory
void register_factory(const typename BASE_TYPE::type_info_t &type_info, Factory factory)
Register a custom factory for type_info.
Definition: factory.hpp:45
ngraph::FactoryRegistry::has_factory
bool has_factory()
Check to see if DERIVED_TYPE has a registered factory.
Definition: factory.hpp:74
ngraph::FactoryRegistry::create
BASE_TYPE * create(const typename BASE_TYPE::type_info_t &type_info) const
Create an instance for type_info.
Definition: factory.hpp:80