opset.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
#include <algorithm>
#include <locale>
#include <map>
#include <mutex>
#include <set>
#include <string>

#include "ngraph/factory.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/node.hpp"
15 
16 namespace ngraph
17 {
18  /// \brief Run-time opset information
19  class NGRAPH_API OpSet
20  {
21  static std::mutex& get_mutex();
22 
23  public:
24  OpSet() {}
25  std::set<NodeTypeInfo>::size_type size() const
26  {
27  std::lock_guard<std::mutex> guard(get_mutex());
28  return m_op_types.size();
29  }
30  /// \brief Insert an op into the opset with a particular name and factory
31  void insert(const std::string& name,
32  const NodeTypeInfo& type_info,
33  FactoryRegistry<Node>::Factory factory)
34  {
35  std::lock_guard<std::mutex> guard(get_mutex());
36  m_op_types.insert(type_info);
37  m_name_type_info_map[name] = type_info;
38  m_case_insensitive_type_info_map[to_upper_name(name)] = type_info;
39  m_factory_registry.register_factory(type_info, factory);
40  }
41 
42  /// \brief Insert OP_TYPE into the opset with a special name and the default factory
43  template <typename OP_TYPE>
44  void insert(const std::string& name)
45  {
46  insert(name, OP_TYPE::type_info, FactoryRegistry<Node>::get_default_factory<OP_TYPE>());
47  }
48 
49  /// \brief Insert OP_TYPE into the opset with the default name and factory
50  template <typename OP_TYPE>
51  void insert()
52  {
53  insert<OP_TYPE>(OP_TYPE::type_info.name);
54  }
55 
56  const std::set<NodeTypeInfo>& get_types_info() const { return m_op_types; }
57  /// \brief Create the op named name using it's factory
58  ngraph::Node* create(const std::string& name) const;
59 
60  /// \brief Create the op named name using it's factory
61  ngraph::Node* create_insensitive(const std::string& name) const;
62 
63  /// \brief Return true if OP_TYPE is in the opset
64  bool contains_type(const NodeTypeInfo& type_info) const
65  {
66  std::lock_guard<std::mutex> guard(get_mutex());
67  return m_op_types.find(type_info) != m_op_types.end();
68  }
69 
70  /// \brief Return true if OP_TYPE is in the opset
71  template <typename OP_TYPE>
72  bool contains_type() const
73  {
74  return contains_type(OP_TYPE::type_info);
75  }
76 
77  /// \brief Return true if name is in the opset
78  bool contains_type(const std::string& name) const
79  {
80  std::lock_guard<std::mutex> guard(get_mutex());
81  return m_name_type_info_map.find(name) != m_name_type_info_map.end();
82  }
83 
84  /// \brief Return true if name is in the opset
85  bool contains_type_insensitive(const std::string& name) const
86  {
87  std::lock_guard<std::mutex> guard(get_mutex());
88  return m_case_insensitive_type_info_map.find(to_upper_name(name)) !=
89  m_case_insensitive_type_info_map.end();
90  }
91 
92  /// \brief Return true if node's type is in the opset
93  bool contains_op_type(const Node* node) const
94  {
95  std::lock_guard<std::mutex> guard(get_mutex());
96  return m_op_types.find(node->get_type_info()) != m_op_types.end();
97  }
98 
99  const std::set<NodeTypeInfo>& get_type_info_set() const { return m_op_types; }
100  ngraph::FactoryRegistry<ngraph::Node>& get_factory_registry() { return m_factory_registry; }
101 
102  protected:
103  static std::string to_upper_name(const std::string& name)
104  {
105  std::string upper_name = name;
106  std::locale loc;
107  std::transform(upper_name.begin(),
108  upper_name.end(),
109  upper_name.begin(),
110  [&loc](char c) { return std::toupper(c, loc); });
111  return upper_name;
112  }
113 
114  ngraph::FactoryRegistry<ngraph::Node> m_factory_registry;
115  std::set<NodeTypeInfo> m_op_types;
116  std::map<std::string, NodeTypeInfo> m_name_type_info_map;
117  std::map<std::string, NodeTypeInfo> m_case_insensitive_type_info_map;
118  };
119 
120  const NGRAPH_API OpSet& get_opset1();
121  const NGRAPH_API OpSet& get_opset2();
122  const NGRAPH_API OpSet& get_opset3();
123  const NGRAPH_API OpSet& get_opset4();
124  const NGRAPH_API OpSet& get_opset5();
125  const NGRAPH_API OpSet& get_opset6();
126  const NGRAPH_API OpSet& get_opset7();
127 } // namespace ngraph
Registry of factories that can construct objects derived from BASE_TYPE.
Definition: factory.hpp:20
Definition: node.hpp:127
virtual const type_info_t & get_type_info() const =0
Run-time opset information.
Definition: opset.hpp:20
void insert(const std::string &name)
Insert OP_TYPE into the opset with a special name and the default factory.
Definition: opset.hpp:44
ngraph::Node * create(const std::string &name) const
Create the op named name using its factory.
void insert()
Insert OP_TYPE into the opset with the default name and factory.
Definition: opset.hpp:51
void insert(const std::string &name, const NodeTypeInfo &type_info, FactoryRegistry< Node >::Factory factory)
Insert an op into the opset with a particular name and factory.
Definition: opset.hpp:31
bool contains_type_insensitive(const std::string &name) const
Return true if name is in the opset, ignoring case.
Definition: opset.hpp:85
ngraph::Node * create_insensitive(const std::string &name) const
Create the op named name using its factory.
bool contains_op_type(const Node *node) const
Return true if node's type is in the opset.
Definition: opset.hpp:93
bool contains_type(const NodeTypeInfo &type_info) const
Return true if type_info is in the opset.
Definition: opset.hpp:64
bool contains_type() const
Return true if OP_TYPE is in the opset.
Definition: opset.hpp:72
bool contains_type(const std::string &name) const
Return true if name is in the opset.
Definition: opset.hpp:78
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27