ie_cnn_network_iterator.hpp
Go to the documentation of this file.
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the CNNNetworkIterator class
7  *
8  * @file ie_cnn_network_iterator.hpp
9  */
10 #pragma once
11 #include <iterator>
12 #include <list>
13 #include <unordered_set>
14 #include <utility>
15 
16 #include "ie_api.h"
17 #include "ie_icnn_network.hpp"
18 #include "ie_locked_memory.hpp"
19 
20 namespace InferenceEngine {
21 namespace details {
22 
23 /**
24  * @deprecated Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2020.3
25  * @brief This class enables range loops for CNNNetwork objects
26  */
27 class INFERENCE_ENGINE_INTERNAL("Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2020.3")
28 CNNNetworkIterator {
29  IE_SUPPRESS_DEPRECATED_START
30 
31  std::unordered_set<CNNLayer*> visited;
32  std::list<CNNLayerPtr> nextLayersTovisit;
33  InferenceEngine::CNNLayerPtr currentLayer;
34  ICNNNetwork* network = nullptr;
35 
36 public:
37  /**
38  * iterator trait definitions
39  */
40  typedef std::forward_iterator_tag iterator_category;
41  typedef CNNLayerPtr value_type;
42  typedef int difference_type;
43  typedef CNNLayerPtr pointer;
44  typedef CNNLayerPtr reference;
45 
46  /**
47  * @brief Default constructor
48  */
49  CNNNetworkIterator() = default;
50  /**
51  * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
52  * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of
53  * scope.
54  */
55  explicit CNNNetworkIterator(const ICNNNetwork* network) {
56  InputsDataMap inputs;
57  network->getInputsInfo(inputs);
58  if (!inputs.empty()) {
59  auto& nextLayers = inputs.begin()->second->getInputData()->getInputTo();
60  if (!nextLayers.empty()) {
61  currentLayer = nextLayers.begin()->second;
62  nextLayersTovisit.push_back(currentLayer);
63  visited.insert(currentLayer.get());
64  }
65  }
66  }
67 
68  /**
69  * @brief Performs pre-increment
70  * @return This CNNNetworkIterator instance
71  */
72  CNNNetworkIterator& operator++() {
73  currentLayer = next();
74  return *this;
75  }
76 
77  /**
78  * @brief Performs post-increment.
79  * Implementation does not follow the std interface since only move semantics is used
80  */
81  void operator++(int) {
82  currentLayer = next();
83  }
84 
85  /**
86  * @brief Checks if the given iterator is not equal to this one
87  * @param that Iterator to compare with
88  * @return true if the given iterator is not equal to this one, false - otherwise
89  */
90  bool operator!=(const CNNNetworkIterator& that) const {
91  return !operator==(that);
92  }
93 
94  /**
95  * @brief Gets const layer pointer referenced by this iterator
96  */
97  const CNNLayerPtr& operator*() const {
98  if (nullptr == currentLayer) {
99  THROW_IE_EXCEPTION << "iterator out of bound";
100  }
101  return currentLayer;
102  }
103 
104  /**
105  * @brief Gets a layer pointer referenced by this iterator
106  */
107  CNNLayerPtr& operator*() {
108  if (nullptr == currentLayer) {
109  THROW_IE_EXCEPTION << "iterator out of bound";
110  }
111  return currentLayer;
112  }
113  /**
114  * @brief Compares the given iterator with this one
115  * @param that Iterator to compare with
116  * @return true if the given iterator is equal to this one, false - otherwise
117  */
118  bool operator==(const CNNNetworkIterator& that) const {
119  return network == that.network && currentLayer == that.currentLayer;
120  }
121 
122 private:
123  /**
124  * @brief implementation based on BFS
125  */
126  CNNLayerPtr next() {
127  if (nextLayersTovisit.empty()) {
128  return nullptr;
129  }
130 
131  auto nextLayer = nextLayersTovisit.front();
132  nextLayersTovisit.pop_front();
133 
134  // visit child that not visited
135  for (auto&& output : nextLayer->outData) {
136  for (auto&& child : output->getInputTo()) {
137  if (visited.find(child.second.get()) == visited.end()) {
138  nextLayersTovisit.push_back(child.second);
139  visited.insert(child.second.get());
140  }
141  }
142  }
143 
144  // visit parents
145  for (auto&& parent : nextLayer->insData) {
146  auto parentLayer = parent.lock()->getCreatorLayer().lock();
147  if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
148  nextLayersTovisit.push_back(parentLayer);
149  visited.insert(parentLayer.get());
150  }
151  }
152 
153  return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
154  }
155 
156  IE_SUPPRESS_DEPRECATED_END
157 };
158 } // namespace details
159 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
Inference Engine API.
Definition: ie_argmax_layer.hpp:15
This is a header file for the ICNNNetwork class.
Definition: ie_cnn_network.h:27
A header file for generic LockedMemory<> and different variations of locks.
std::shared_ptr< CNNLayer > CNNLayerPtr
A smart pointer to the CNNLayer.
Definition: ie_common.h:39
The macro defines a symbol import/export mechanism essential for Microsoft Windows(R) OS...
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:160