ie_cnn_network.h
Go to the documentation of this file.
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file that provides wrapper for ICNNNetwork object
7  *
8  * @file ie_cnn_network.h
9  */
10 #pragma once
11 
12 #include <ie_icnn_net_reader.h>
13 
16 #include <ie_icnn_network.hpp>
17 #include <map>
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "ie_blob.h"
24 #include "ie_common.h"
25 #include "ie_data.h"
26 
// Forward declaration only: CNNNetwork refers to ngraph::Function solely through
// std::shared_ptr, so this header does not need to include the nGraph headers.
namespace ngraph {

class Function;

}  // namespace ngraph
32 
33 namespace InferenceEngine {
34 
35 /**
36  * @brief This class contains all the information about the Neural Network and the related binary information
37  */
38 class INFERENCE_ENGINE_API_CLASS(CNNNetwork) {
39 public:
40  /**
41  * @brief A default constructor
42  */
43  CNNNetwork() = default;
44 
45  /**
46  * @brief Allows helper class to manage lifetime of network object
47  *
48  * @param network Pointer to the network object
49  */
50  explicit CNNNetwork(std::shared_ptr<ICNNNetwork> network): network(network) {
51  actual = network.get();
52  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
53  }
54 
55  /**
56  * @brief A constructor from ngraph::Function object
57  * @param network Pointer to the ngraph::Function object
58  */
59  explicit CNNNetwork(const std::shared_ptr<const ngraph::Function>& network);
60 
61  /**
62  * @brief A constructor from ICNNNetReader object
63  *
64  * @param reader Pointer to the ICNNNetReader object
65  */
66  IE_SUPPRESS_DEPRECATED_START
67  explicit CNNNetwork(CNNNetReaderPtr reader_): reader(reader_) {
68  if (reader == nullptr) {
69  THROW_IE_EXCEPTION << "ICNNNetReader was not initialized.";
70  }
71  if ((actual = reader->getNetwork(nullptr)) == nullptr) {
72  THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
73  }
74  }
75  IE_SUPPRESS_DEPRECATED_END
76 
77  /**
78  * @brief A destructor
79  */
80  virtual ~CNNNetwork() {}
81 
82  /**
83  * @deprecated Network precision does not make sence, use precision on egdes. The method will be removed in 2021.1
84  * @copybrief ICNNNetwork::getPrecision
85  *
86  * Wraps ICNNNetwork::getPrecision
87  *
88  * @return A precision type
89  */
90  INFERENCE_ENGINE_DEPRECATED("Network precision does not make sence, use precision on egdes. The method will be removed in 2021.1")
91  virtual Precision getPrecision() const;
92 
93  /**
94  * @copybrief ICNNNetwork::getOutputsInfo
95  *
96  * Wraps ICNNNetwork::getOutputsInfo
97  *
98  * @return outputs Reference to the OutputsDataMap object
99  */
100  virtual OutputsDataMap getOutputsInfo() const {
101  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
102  OutputsDataMap outputs;
103  actual->getOutputsInfo(outputs);
104  return outputs;
105  }
106 
107  /**
108  * @copybrief ICNNNetwork::getInputsInfo
109  *
110  * Wraps ICNNNetwork::getInputsInfo
111  *
112  * @return inputs Reference to InputsDataMap object
113  */
114  virtual InputsDataMap getInputsInfo() const {
115  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
116  InputsDataMap inputs;
117  actual->getInputsInfo(inputs);
118  return inputs;
119  }
120 
121  /**
122  * @copybrief ICNNNetwork::layerCount
123  *
124  * Wraps ICNNNetwork::layerCount
125  *
126  * @return The number of layers as an integer value
127  */
128  size_t layerCount() const {
129  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
130  return actual->layerCount();
131  }
132 
133  /**
134  * @copybrief ICNNNetwork::getName
135  *
136  * Wraps ICNNNetwork::getName
137  *
138  * @return Network name
139  */
140  const std::string& getName() const {
141  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
142  return actual->getName();
143  }
144 
145  /**
146  * @copybrief ICNNNetwork::setBatchSize
147  *
148  * Wraps ICNNNetwork::setBatchSize
149  *
150  * @param size Size of batch to set
151  * @return Status code of the operation
152  */
153  virtual void setBatchSize(const size_t size) {
154  CALL_STATUS_FNC(setBatchSize, size);
155  }
156 
157  /**
158  * @copybrief ICNNNetwork::getBatchSize
159  *
160  * Wraps ICNNNetwork::getBatchSize
161  *
162  * @return The size of batch as a size_t value
163  */
164  virtual size_t getBatchSize() const {
165  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
166  return actual->getBatchSize();
167  }
168 
169  /**
170  * @brief An overloaded operator cast to get pointer on current network
171  *
172  * @return A shared pointer of the current network
173  */
174  operator ICNNNetwork::Ptr() {
175  return network;
176  }
177 
178  /**
179  * @brief An overloaded operator & to get current network
180  *
181  * @return An instance of the current network
182  */
183  operator ICNNNetwork&() {
184  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
185  return *actual;
186  }
187 
188  /**
189  * @brief An overloaded operator & to get current network
190  *
191  * @return A const reference of the current network
192  */
193  operator const ICNNNetwork&() const {
194  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
195  return *actual;
196  }
197 
198  /**
199  * @brief Returns constant nGraph function
200  *
201  * @return constant nGraph function
202  */
203  std::shared_ptr<ngraph::Function> getFunction() {
204  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
205  return actual->getFunction();
206  }
207 
208  /**
209  * @brief Returns constant nGraph function
210  *
211  * @return constant nGraph function
212  */
213  std::shared_ptr<const ngraph::Function> getFunction() const {
214  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
215  return actual->getFunction();
216  }
217 
218  /**
219  * @copybrief ICNNNetwork::addOutput
220  *
221  * Wraps ICNNNetwork::addOutput
222  *
223  * @param layerName Name of the layer
224  * @param outputIndex Index of the output
225  */
226  void addOutput(const std::string& layerName, size_t outputIndex = 0) {
227  CALL_STATUS_FNC(addOutput, layerName, outputIndex);
228  }
229 
230  /**
231  * @deprecated Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2021.1
232  * @copybrief ICNNNetwork::getLayerByName
233  *
234  * Wraps ICNNNetwork::getLayerByName
235  *
236  * @param layerName Given name of the layer
237  * @return Status code of the operation. InferenceEngine::OK if succeeded
238  */
239  INFERENCE_ENGINE_DEPRECATED("Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2021.1")
240  CNNLayerPtr getLayerByName(const char* layerName) const;
241 
242  /**
243  * @deprecated Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1
244  * @brief Begin layer iterator
245  *
246  * Order of layers is implementation specific,
247  * and can be changed in future
248  *
249  * @return Iterator pointing to a layer
250  */
251  IE_SUPPRESS_DEPRECATED_START
252  INFERENCE_ENGINE_DEPRECATED("Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1")
253  details::CNNNetworkIterator begin() const;
254 
255  /**
256  * @deprecated Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1
257  * @brief End layer iterator
258  * @return Iterator pointing to a layer
259  */
260  INFERENCE_ENGINE_DEPRECATED("Use CNNNetwork::getFunction() and work with ngraph::Function directly. The method will be removed in 2021.1")
261  details::CNNNetworkIterator end() const;
262  IE_SUPPRESS_DEPRECATED_END
263 
264  /**
265  * @deprecated Use CNNNetwork::layerCount() instead. The method will be removed in 2021.1
266  * @brief Number of layers in network object
267  *
268  * @return Number of layers.
269  */
270  INFERENCE_ENGINE_DEPRECATED("Use CNNNetwork::layerCount() instead. The method will be removed in 2021.1")
271  size_t size() const;
272 
273  /**
274  * @deprecated Use Core::AddExtension to add an extension to the library
275  * @brief Registers extension within the plugin
276  *
277  * @param extension Pointer to already loaded reader extension with shape propagation implementations
278  */
279  INFERENCE_ENGINE_DEPRECATED("Use Core::AddExtension to add an extension to the library")
280  void AddExtension(InferenceEngine::IShapeInferExtensionPtr extension);
281 
282  /**
283  * @brief Helper method to get collect all input shapes with names of corresponding Data objects
284  *
285  * @return Map of pairs: input name and its dimension.
286  */
288  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
290  InputsDataMap inputs;
291  actual->getInputsInfo(inputs);
292  for (const auto& pair : inputs) {
293  auto info = pair.second;
294  if (info) {
295  auto data = info->getInputData();
296  if (data) {
297  shapes[data->getName()] = data->getTensorDesc().getDims();
298  }
299  }
300  }
301  return shapes;
302  }
303 
304  /**
305  * @brief Run shape inference with new input shapes for the network
306  *
307  * @param inputShapes - map of pairs: name of corresponding data and its dimension.
308  */
309  virtual void reshape(const ICNNNetwork::InputShapes& inputShapes) {
310  CALL_STATUS_FNC(reshape, inputShapes);
311  }
312 
313  /**
314  * @brief Serialize network to IR and weights files.
315  *
316  * @param xmlPath Path to output IR file.
317  * @param binPath Path to output weights file. The parameter is skipped in case
318  * of executable graph info serialization.
319  */
320  void serialize(const std::string& xmlPath, const std::string& binPath = "") const {
321  CALL_STATUS_FNC(serialize, xmlPath, binPath);
322  }
323 
324 protected:
325  /**
326  * @brief Reader extra reference, might be nullptr
327  */
328  IE_SUPPRESS_DEPRECATED_START
330  IE_SUPPRESS_DEPRECATED_END
331  /**
332  * @brief Network extra interface, might be nullptr
333  */
334  std::shared_ptr<ICNNNetwork> network;
335 
336  /**
337  * @brief A pointer to the current network
338  */
339  ICNNNetwork* actual = nullptr;
340  /**
341  * @brief A pointer to output data
342  */
344 };
345 
346 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
virtual ~CNNNetwork()
A destructor.
Definition: ie_cnn_network.h:80
CNNNetwork(std::shared_ptr< ICNNNetwork > network)
Allows the helper class to manage the lifetime of the network object.
Definition: ie_cnn_network.h:50
std::shared_ptr< IShapeInferExtension > IShapeInferExtensionPtr
A shared pointer to a IShapeInferExtension interface.
Definition: ie_iextension.h:366
Definition: cldnn_config.hpp:16
std::shared_ptr< CNNLayer > CNNLayerPtr
A smart pointer to the CNNLayer.
Definition: ie_common.h:39
A header file that provides macros to handle no exception methods.
virtual ICNNNetwork::InputShapes getInputShapes() const
Helper method to collect all input shapes together with the names of the corresponding Data objects.
Definition: ie_cnn_network.h:287
virtual InputsDataMap getInputsInfo() const
Definition: ie_cnn_network.h:114
void addOutput(const std::string &layerName, size_t outputIndex=0)
Definition: ie_cnn_network.h:226
A header file for Blob and generic TBlob<>
A header file for the CNNNetworkIterator class.
This is a header file for the ICNNNetwork class.
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:160
InferenceEngine::details::SOPointer< ICNNNetReader, InferenceEngine::details::SharedObjectLoader > CNNNetReaderPtr
A C++ helper to work with objects created by the IR readers plugin. Implements different interfaces...
Definition: ie_icnn_net_reader.h:154
virtual void reshape(const ICNNNetwork::InputShapes &inputShapes)
Run shape inference with new input shapes for the network.
Definition: ie_cnn_network.h:309
size_t layerCount() const
Definition: ie_cnn_network.h:128
std::shared_ptr< ICNNNetwork > Ptr
A shared pointer to a ICNNNetwork interface.
Definition: ie_icnn_network.hpp:48
std::shared_ptr< ngraph::Function > getFunction()
Returns the (non-const) nGraph function.
Definition: ie_cnn_network.h:203
virtual size_t getBatchSize() const
Definition: ie_cnn_network.h:164
This is the main interface to describe the NN topology.
Definition: ie_icnn_network.hpp:43
Definition: ie_cnn_network.h:27
void serialize(const std::string &xmlPath, const std::string &binPath="") const
Serialize network to IR and weights files.
Definition: ie_cnn_network.h:320
virtual OutputsDataMap getOutputsInfo() const
Definition: ie_cnn_network.h:100
This class contains all the information about the Neural Network and the related binary information...
Definition: ie_cnn_network.h:38
This header file defines the main Data representation node.
std::shared_ptr< ICNNNetwork > network
Network extra interface, might be nullptr.
Definition: ie_cnn_network.h:334
virtual void setBatchSize(const size_t size)
Definition: ie_cnn_network.h:153
DataPtr output
A pointer to output data.
Definition: ie_cnn_network.h:343
A header file that provides interface for network reader that is used to build networks from a given ...
std::shared_ptr< const ngraph::Function > getFunction() const
Returns constant nGraph function.
Definition: ie_cnn_network.h:213
std::shared_ptr< Data > DataPtr
Smart pointer to Data.
Definition: ie_common.h:53
const std::string & getName() const
Definition: ie_cnn_network.h:140
std::map< std::string, DataPtr > OutputsDataMap
A collection that contains string as key, and Data smart pointer as value.
Definition: ie_icnn_network.hpp:37
std::map< std::string, SizeVector > InputShapes
Map of pairs: name of corresponding data and its dimension.
Definition: ie_icnn_network.hpp:206
CNNNetwork(CNNNetReaderPtr reader_)
A constructor from ICNNNetReader object.
Definition: ie_cnn_network.h:67
CNNNetReaderPtr reader
Reader extra reference, might be nullptr.
Definition: ie_cnn_network.h:329
This is a header file with common inference engine definitions.
This class holds precision value and provides precision related operations.
Definition: ie_precision.hpp:22