opset.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <locale>
20 #include <map>
21 #include <mutex>
22 #include <set>
23 
24 #include "ngraph/factory.hpp"
25 #include "ngraph/ngraph_visibility.hpp"
26 #include "ngraph/node.hpp"
27 
28 namespace ngraph
29 {
30  /// \brief Run-time opset information
31  class NGRAPH_API OpSet
32  {
33  static std::mutex& get_mutex();
34 
35  public:
36  OpSet() {}
37  std::set<NodeTypeInfo>::size_type size() const
38  {
39  std::lock_guard<std::mutex> guard(get_mutex());
40  return m_op_types.size();
41  }
42  /// \brief Insert an op into the opset with a particular name and factory
43  void insert(const std::string& name,
44  const NodeTypeInfo& type_info,
45  FactoryRegistry<Node>::Factory factory)
46  {
47  std::lock_guard<std::mutex> guard(get_mutex());
48  m_op_types.insert(type_info);
49  m_name_type_info_map[name] = type_info;
50  m_case_insensitive_type_info_map[to_upper_name(name)] = type_info;
51  m_factory_registry.register_factory(type_info, factory);
52  }
53 
54  /// \brief Insert OP_TYPE into the opset with a special name and the default factory
55  template <typename OP_TYPE>
56  void insert(const std::string& name)
57  {
58  insert(name, OP_TYPE::type_info, FactoryRegistry<Node>::get_default_factory<OP_TYPE>());
59  }
60 
61  /// \brief Insert OP_TYPE into the opset with the default name and factory
62  template <typename OP_TYPE>
63  void insert()
64  {
65  insert<OP_TYPE>(OP_TYPE::type_info.name);
66  }
67 
68  const std::set<NodeTypeInfo>& get_types_info() const { return m_op_types; }
69  /// \brief Create the op named name using it's factory
70  ngraph::Node* create(const std::string& name) const;
71 
72  /// \brief Create the op named name using it's factory
73  ngraph::Node* create_insensitive(const std::string& name) const;
74 
75  /// \brief Return true if OP_TYPE is in the opset
76  bool contains_type(const NodeTypeInfo& type_info) const
77  {
78  std::lock_guard<std::mutex> guard(get_mutex());
79  return m_op_types.find(type_info) != m_op_types.end();
80  }
81 
82  /// \brief Return true if OP_TYPE is in the opset
83  template <typename OP_TYPE>
84  bool contains_type() const
85  {
86  return contains_type(OP_TYPE::type_info);
87  }
88 
89  /// \brief Return true if name is in the opset
90  bool contains_type(const std::string& name) const
91  {
92  std::lock_guard<std::mutex> guard(get_mutex());
93  return m_name_type_info_map.find(name) != m_name_type_info_map.end();
94  }
95 
96  /// \brief Return true if name is in the opset
97  bool contains_type_insensitive(const std::string& name) const
98  {
99  std::lock_guard<std::mutex> guard(get_mutex());
100  return m_case_insensitive_type_info_map.find(to_upper_name(name)) !=
101  m_case_insensitive_type_info_map.end();
102  }
103 
104  /// \brief Return true if node's type is in the opset
105  bool contains_op_type(const Node* node) const
106  {
107  std::lock_guard<std::mutex> guard(get_mutex());
108  return m_op_types.find(node->get_type_info()) != m_op_types.end();
109  }
110 
111  const std::set<NodeTypeInfo>& get_type_info_set() const { return m_op_types; }
112  ngraph::FactoryRegistry<ngraph::Node>& get_factory_registry() { return m_factory_registry; }
113  protected:
114  static std::string to_upper_name(const std::string& name)
115  {
116  std::string upper_name = name;
117  std::locale loc;
118  std::transform(upper_name.begin(),
119  upper_name.end(),
120  upper_name.begin(),
121  [&loc](char c) { return std::toupper(c, loc); });
122  return upper_name;
123  }
124 
125  ngraph::FactoryRegistry<ngraph::Node> m_factory_registry;
126  std::set<NodeTypeInfo> m_op_types;
127  std::map<std::string, NodeTypeInfo> m_name_type_info_map;
128  std::map<std::string, NodeTypeInfo> m_case_insensitive_type_info_map;
129  };
130 
131  const NGRAPH_API OpSet& get_opset1();
132  const NGRAPH_API OpSet& get_opset2();
133  const NGRAPH_API OpSet& get_opset3();
134  const NGRAPH_API OpSet& get_opset4();
135  const NGRAPH_API OpSet& get_opset5();
136  const NGRAPH_API OpSet& get_opset6();
137 }
Registry of factories that can construct objects derived from BASE_TYPE.
Definition: factory.hpp:32
Definition: node.hpp:132
virtual const type_info_t & get_type_info() const =0
Run-time opset information.
Definition: opset.hpp:32
void insert(const std::string &name)
Insert OP_TYPE into the opset with a special name and the default factory.
Definition: opset.hpp:56
ngraph::Node * create(const std::string &name) const
Create the op named name using its factory.
void insert()
Insert OP_TYPE into the opset with the default name and factory.
Definition: opset.hpp:63
void insert(const std::string &name, const NodeTypeInfo &type_info, FactoryRegistry< Node >::Factory factory)
Insert an op into the opset with a particular name and factory.
Definition: opset.hpp:43
bool contains_type_insensitive(const std::string &name) const
Return true if name is in the opset, ignoring case.
Definition: opset.hpp:97
ngraph::Node * create_insensitive(const std::string &name) const
Create the op named name using its factory.
bool contains_op_type(const Node *node) const
Return true if node's type is in the opset.
Definition: opset.hpp:105
bool contains_type(const NodeTypeInfo &type_info) const
Return true if OP_TYPE is in the opset.
Definition: opset.hpp:76
bool contains_type() const
Return true if OP_TYPE is in the opset.
Definition: opset.hpp:84
bool contains_type(const std::string &name) const
Return true if name is in the opset.
Definition: opset.hpp:90
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39