ie_cnn_network.h
Go to the documentation of this file.
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file that provides wrapper for ICNNNetwork object
7  *
8  * @file ie_cnn_network.h
9  */
10 #pragma once
11 
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ie_icnn_network.hpp"
#include "ie_blob.h"
#include "ie_common.h"
#include "ie_data.h"
#include "details/ie_exception_conversion.hpp"
#include "ie_extension.h"
24 
// Forward declaration only: this header refers to ngraph::Function exclusively through
// std::shared_ptr, so the complete type is not required here.
namespace ngraph { class Function; }  // namespace ngraph
30 
31 namespace InferenceEngine {
32 
33 /**
34  * @brief This class contains all the information about the Neural Network and the related binary information
35  */
36 class INFERENCE_ENGINE_API_CLASS(CNNNetwork) {
37 public:
38  /**
39  * @brief A default constructor
40  */
41  CNNNetwork() = default;
42 
43  /**
44  * @brief Allows helper class to manage lifetime of network object
45  *
46  * @param network Pointer to the network object
47  */
48  explicit CNNNetwork(std::shared_ptr<ICNNNetwork> network): network(network) {
49  actual = network.get();
50  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
51  }
52 
53  /**
54  * @brief A constructor from ngraph::Function object
55  * This constructor wraps existing ngraph::Function
56  * If you want to avoid modification of original Function, please create a copy
57  * @param network Pointer to the ngraph::Function object
58  * @param exts Vector of pointers to IE extension objects
59  */
60  explicit CNNNetwork(const std::shared_ptr<ngraph::Function>& network,
61  const std::vector<IExtensionPtr>& exts = {});
62 
63  /**
64  * @brief A destructor
65  */
66  virtual ~CNNNetwork() {}
67 
68  /**
69  * @copybrief ICNNNetwork::getOutputsInfo
70  *
71  * Wraps ICNNNetwork::getOutputsInfo
72  *
73  * @return outputs Reference to the OutputsDataMap object
74  */
75  virtual OutputsDataMap getOutputsInfo() const {
76  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
77  OutputsDataMap outputs;
78  actual->getOutputsInfo(outputs);
79  return outputs;
80  }
81 
82  /**
83  * @copybrief ICNNNetwork::getInputsInfo
84  *
85  * Wraps ICNNNetwork::getInputsInfo
86  *
87  * @return inputs Reference to InputsDataMap object
88  */
89  virtual InputsDataMap getInputsInfo() const {
90  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
91  InputsDataMap inputs;
92  actual->getInputsInfo(inputs);
93  return inputs;
94  }
95 
96  /**
97  * @copybrief ICNNNetwork::layerCount
98  *
99  * Wraps ICNNNetwork::layerCount
100  *
101  * @return The number of layers as an integer value
102  */
103  size_t layerCount() const {
104  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
105  return actual->layerCount();
106  }
107 
108  /**
109  * @copybrief ICNNNetwork::getName
110  *
111  * Wraps ICNNNetwork::getName
112  *
113  * @return Network name
114  */
115  const std::string& getName() const {
116  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
117  return actual->getName();
118  }
119 
120  /**
121  * @copybrief ICNNNetwork::setBatchSize
122  *
123  * Wraps ICNNNetwork::setBatchSize
124  *
125  * @param size Size of batch to set
126  */
127  virtual void setBatchSize(const size_t size) {
128  CALL_STATUS_FNC(setBatchSize, size);
129  }
130 
131  /**
132  * @copybrief ICNNNetwork::getBatchSize
133  *
134  * Wraps ICNNNetwork::getBatchSize
135  *
136  * @return The size of batch as a size_t value
137  */
138  virtual size_t getBatchSize() const {
139  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
140  return actual->getBatchSize();
141  }
142 
143  /**
144  * @brief An overloaded operator cast to get pointer on current network
145  *
146  * @return A shared pointer of the current network
147  */
148  operator ICNNNetwork::Ptr() {
149  return network;
150  }
151 
152  /**
153  * @brief An overloaded operator & to get current network
154  *
155  * @return An instance of the current network
156  */
157  operator ICNNNetwork&() {
158  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
159  return *actual;
160  }
161 
162  /**
163  * @brief An overloaded operator & to get current network
164  *
165  * @return A const reference of the current network
166  */
167  operator const ICNNNetwork&() const {
168  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
169  return *actual;
170  }
171 
172  /**
173  * @brief Returns constant nGraph function
174  *
175  * @return constant nGraph function
176  */
177  std::shared_ptr<ngraph::Function> getFunction() {
178  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
179  return actual->getFunction();
180  }
181 
182  /**
183  * @brief Returns constant nGraph function
184  *
185  * @return constant nGraph function
186  */
187  std::shared_ptr<const ngraph::Function> getFunction() const {
188  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
189  return actual->getFunction();
190  }
191 
192  /**
193  * @copybrief ICNNNetwork::addOutput
194  *
195  * Wraps ICNNNetwork::addOutput
196  *
197  * @param layerName Name of the layer
198  * @param outputIndex Index of the output
199  */
200  void addOutput(const std::string& layerName, size_t outputIndex = 0) {
201  CALL_STATUS_FNC(addOutput, layerName, outputIndex);
202  }
203 
204  /**
205  * @brief Helper method to get collect all input shapes with names of corresponding Data objects
206  *
207  * @return Map of pairs: input name and its dimension.
208  */
210  if (actual == nullptr) THROW_IE_EXCEPTION << "CNNNetwork was not initialized.";
212  InputsDataMap inputs;
213  actual->getInputsInfo(inputs);
214  for (const auto& pair : inputs) {
215  auto info = pair.second;
216  if (info) {
217  auto data = info->getInputData();
218  if (data) {
219  shapes[data->getName()] = data->getTensorDesc().getDims();
220  }
221  }
222  }
223  return shapes;
224  }
225 
226  /**
227  * @brief Run shape inference with new input shapes for the network
228  *
229  * @param inputShapes - map of pairs: name of corresponding data and its dimension.
230  */
231  virtual void reshape(const ICNNNetwork::InputShapes& inputShapes) {
232  CALL_STATUS_FNC(reshape, inputShapes);
233  }
234 
235  /**
236  * @brief Serialize network to IR and weights files.
237  *
238  * @param xmlPath Path to output IR file.
239  * @param binPath Path to output weights file. The parameter is skipped in case
240  * of executable graph info serialization.
241  */
242  void serialize(const std::string& xmlPath, const std::string& binPath = "") const {
243  CALL_STATUS_FNC(serialize, xmlPath, binPath);
244  }
245 
246 protected:
247  /**
248  * @brief Network extra interface, might be nullptr
249  */
250  std::shared_ptr<ICNNNetwork> network;
251 
252  /**
253  * @brief A pointer to the current network
254  */
255  ICNNNetwork* actual = nullptr;
256  /**
257  * @brief A pointer to output data
258  */
260 };
261 
262 } // namespace InferenceEngine
std::shared_ptr< ICNNNetwork > network
Network extra interface, might be nullptr.
Definition: ie_cnn_network.h:250
CNNNetwork(const std::shared_ptr< ngraph::Function > &network, const std::vector< IExtensionPtr > &exts={})
A constructor from an ngraph::Function object. This constructor wraps the existing ngraph::Function; if you want to avoid modifying the original Function, please create a copy.
std::shared_ptr< const ngraph::Function > getFunction() const
Returns constant nGraph function.
Definition: ie_cnn_network.h:187
virtual size_t getBatchSize() const
Gets the inference batch size.
Definition: ie_cnn_network.h:138
virtual void reshape(const ICNNNetwork::InputShapes &inputShapes)
Run shape inference with new input shapes for the network.
Definition: ie_cnn_network.h:231
virtual OutputsDataMap getOutputsInfo() const
Gets the network output Data node information. The received info is stored in the given Data node.
Definition: ie_cnn_network.h:75
virtual void setBatchSize(const size_t size)
Changes the inference batch size.
Definition: ie_cnn_network.h:127
size_t layerCount() const
Returns the number of layers in the network as an integer value.
Definition: ie_cnn_network.h:103
CNNNetwork(std::shared_ptr< ICNNNetwork > network)
Allows helper class to manage lifetime of network object.
Definition: ie_cnn_network.h:48
std::shared_ptr< ngraph::Function > getFunction()
Returns the (modifiable) nGraph function.
Definition: ie_cnn_network.h:177
DataPtr output
A pointer to output data.
Definition: ie_cnn_network.h:259
virtual ~CNNNetwork()
A destructor.
Definition: ie_cnn_network.h:66
void addOutput(const std::string &layerName, size_t outputIndex=0)
Adds output to the layer.
Definition: ie_cnn_network.h:200
const std::string & getName() const
Returns the network name.
Definition: ie_cnn_network.h:115
CNNNetwork()=default
A default constructor.
virtual ICNNNetwork::InputShapes getInputShapes() const
Helper method to collect all input shapes with the names of the corresponding Data objects.
Definition: ie_cnn_network.h:209
virtual InputsDataMap getInputsInfo() const
Gets the network input Data node information. The received info is stored in the given InputsDataMap object.
Definition: ie_cnn_network.h:89
void serialize(const std::string &xmlPath, const std::string &binPath="") const
Serialize network to IR and weights files.
Definition: ie_cnn_network.h:242
This is the main interface to describe the NN topology.
Definition: ie_icnn_network.hpp:39
std::shared_ptr< ICNNNetwork > Ptr
A shared pointer to a ICNNNetwork interface.
Definition: ie_icnn_network.hpp:44
std::map< std::string, SizeVector > InputShapes
Map of pairs: name of corresponding data and its dimension.
Definition: ie_icnn_network.hpp:147
A header file for Blob and generic TBlob<>
This is a header file with common inference engine definitions.
std::shared_ptr< Data > DataPtr
Smart pointer to Data.
Definition: ie_common.h:37
This header file defines the main Data representation node.
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
A header file that provides macros to handle no exception methods.
A header file that defines a wrapper class for handling extension instantiation and releasing resources.
This is a header file for the ICNNNetwork class.
std::map< std::string, DataPtr > OutputsDataMap
A collection that contains string as key, and Data smart pointer as value.
Definition: ie_icnn_network.hpp:33
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:161