factory.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <functional>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "ngraph/ngraph_visibility.hpp"
12 
13 namespace ngraph
14 {
15  NGRAPH_API std::mutex& get_registry_mutex();
16 
17  /// \brief Registry of factories that can construct objects derived from BASE_TYPE
18  template <typename BASE_TYPE>
20  {
21  public:
22  using Factory = std::function<BASE_TYPE*()>;
23  using FactoryMap = std::unordered_map<typename BASE_TYPE::type_info_t, Factory>;
24 
25  // \brief Get the default factory for DERIVED_TYPE. Specialize as needed.
26  template <typename DERIVED_TYPE>
27  static Factory get_default_factory()
28  {
29  return []() { return new DERIVED_TYPE(); };
30  }
31 
32  /// \brief Register a custom factory for type_info
33  void register_factory(const typename BASE_TYPE::type_info_t& type_info, Factory factory)
34  {
35  std::lock_guard<std::mutex> guard(get_registry_mutex());
36  m_factory_map[type_info] = factory;
37  }
38 
39  /// \brief Register a custom factory for DERIVED_TYPE
40  template <typename DERIVED_TYPE>
41  void register_factory(Factory factory)
42  {
43  register_factory(DERIVED_TYPE::type_info, factory);
44  }
45 
46  /// \brief Register the defualt constructor factory for DERIVED_TYPE
47  template <typename DERIVED_TYPE>
49  {
50  register_factory<DERIVED_TYPE>(get_default_factory<DERIVED_TYPE>());
51  }
52 
53  /// \brief Check to see if a factory is registered
54  bool has_factory(const typename BASE_TYPE::type_info_t& info)
55  {
56  std::lock_guard<std::mutex> guard(get_registry_mutex());
57  return m_factory_map.find(info) != m_factory_map.end();
58  }
59 
60  /// \brief Check to see if DERIVED_TYPE has a registered factory
61  template <typename DERIVED_TYPE>
62  bool has_factory()
63  {
64  return has_factory(DERIVED_TYPE::type_info);
65  }
66 
67  /// \brief Create an instance for type_info
68  BASE_TYPE* create(const typename BASE_TYPE::type_info_t& type_info) const
69  {
70  std::lock_guard<std::mutex> guard(get_registry_mutex());
71  auto it = m_factory_map.find(type_info);
72  return it == m_factory_map.end() ? nullptr : it->second();
73  }
74 
75  /// \brief Create an instance using factory for DERIVED_TYPE
76  template <typename DERIVED_TYPE>
77  BASE_TYPE* create() const
78  {
79  return create(DERIVED_TYPE::type_info);
80  }
81 
82  protected:
83  FactoryMap m_factory_map;
84  };
85 } // namespace ngraph
Registry of factories that can construct objects derived from BASE_TYPE.
Definition: factory.hpp:20
void register_factory()
Register the default constructor factory for DERIVED_TYPE.
Definition: factory.hpp:48
void register_factory(Factory factory)
Register a custom factory for DERIVED_TYPE.
Definition: factory.hpp:41
bool has_factory(const typename BASE_TYPE::type_info_t &info)
Check to see if a factory is registered.
Definition: factory.hpp:54
BASE_TYPE * create() const
Create an instance using factory for DERIVED_TYPE.
Definition: factory.hpp:77
BASE_TYPE * create(const typename BASE_TYPE::type_info_t &type_info) const
Create an instance for type_info.
Definition: factory.hpp:68
bool has_factory()
Check to see if DERIVED_TYPE has a registered factory.
Definition: factory.hpp:62
void register_factory(const typename BASE_TYPE::type_info_t &type_info, Factory factory)
Register a custom factory for type_info.
Definition: factory.hpp:33
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16