ie_cnn_network_iterator.hpp
Go to the documentation of this file.
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the CNNNetworkIterator class
7  *
8  * @file ie_cnn_network_iterator.hpp
9  */
10 #pragma once
11 #include <iterator>
12 #include <list>
13 #include <unordered_set>
14 #include <utility>
15 
16 #include "ie_api.h"
17 #include "ie_icnn_network.hpp"
18 #include "ie_locked_memory.hpp"
19 
20 namespace InferenceEngine {
21 namespace details {
22 
23 /**
24  * @deprecated Migrate to IR v10 and work with ngraph::Function directly. The class will be removed in 2021.1
25  * @brief This class enables range loops for CNNNetwork objects
26  */
27 class INFERENCE_ENGINE_INTERNAL("Migrate to IR v10 and work with ngraph::Function directly. The method will be removed in 2021.1")
28 CNNNetworkIterator {
29  IE_SUPPRESS_DEPRECATED_START
30 
31  std::unordered_set<CNNLayer*> visited;
32  std::list<CNNLayerPtr> nextLayersTovisit;
33  InferenceEngine::CNNLayerPtr currentLayer;
34  ICNNNetwork* network = nullptr;
35 
36 public:
37  /**
38  * iterator trait definitions
39  */
40  typedef std::forward_iterator_tag iterator_category;
41  typedef CNNLayerPtr value_type;
42  typedef int difference_type;
43  typedef CNNLayerPtr pointer;
44  typedef CNNLayerPtr reference;
45 
46  /**
47  * @brief Default constructor
48  */
49  CNNNetworkIterator() = default;
50  /**
51  * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
52  * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of
53  * scope.
54  */
55  explicit CNNNetworkIterator(const ICNNNetwork* network) {
56  if (network == nullptr) THROW_IE_EXCEPTION << "ICNNNetwork object is nullptr";
57  InputsDataMap inputs;
58  network->getInputsInfo(inputs);
59  if (!inputs.empty()) {
60  auto& nextLayers = inputs.begin()->second->getInputData()->getInputTo();
61  if (!nextLayers.empty()) {
62  currentLayer = nextLayers.begin()->second;
63  nextLayersTovisit.push_back(currentLayer);
64  visited.insert(currentLayer.get());
65  }
66  }
67  }
68 
69  /**
70  * @brief Performs pre-increment
71  * @return This CNNNetworkIterator instance
72  */
73  CNNNetworkIterator& operator++() {
74  currentLayer = next();
75  return *this;
76  }
77 
78  /**
79  * @brief Performs post-increment.
80  * Implementation does not follow the std interface since only move semantics is used
81  */
82  void operator++(int) {
83  currentLayer = next();
84  }
85 
86  /**
87  * @brief Checks if the given iterator is not equal to this one
88  * @param that Iterator to compare with
89  * @return true if the given iterator is not equal to this one, false - otherwise
90  */
91  bool operator!=(const CNNNetworkIterator& that) const {
92  return !operator==(that);
93  }
94 
95  /**
96  * @brief Gets const layer pointer referenced by this iterator
97  */
98  const CNNLayerPtr& operator*() const {
99  if (nullptr == currentLayer) {
100  THROW_IE_EXCEPTION << "iterator out of bound";
101  }
102  return currentLayer;
103  }
104 
105  /**
106  * @brief Gets a layer pointer referenced by this iterator
107  */
108  CNNLayerPtr& operator*() {
109  if (nullptr == currentLayer) {
110  THROW_IE_EXCEPTION << "iterator out of bound";
111  }
112  return currentLayer;
113  }
114  /**
115  * @brief Compares the given iterator with this one
116  * @param that Iterator to compare with
117  * @return true if the given iterator is equal to this one, false - otherwise
118  */
119  bool operator==(const CNNNetworkIterator& that) const {
120  return network == that.network && currentLayer == that.currentLayer;
121  }
122 
123 private:
124  /**
125  * @brief implementation based on BFS
126  */
127  CNNLayerPtr next() {
128  if (nextLayersTovisit.empty()) {
129  return nullptr;
130  }
131 
132  auto nextLayer = nextLayersTovisit.front();
133  nextLayersTovisit.pop_front();
134 
135  // visit child that not visited
136  for (auto&& output : nextLayer->outData) {
137  for (auto&& child : output->getInputTo()) {
138  if (visited.find(child.second.get()) == visited.end()) {
139  nextLayersTovisit.push_back(child.second);
140  visited.insert(child.second.get());
141  }
142  }
143  }
144 
145  // visit parents
146  for (auto&& parent : nextLayer->insData) {
147  auto parentLayer = parent.lock()->getCreatorLayer().lock();
148  if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
149  nextLayersTovisit.push_back(parentLayer);
150  visited.insert(parentLayer.get());
151  }
152  }
153 
154  return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
155  }
156 
157  IE_SUPPRESS_DEPRECATED_END
158 };
159 } // namespace details
160 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
Definition: cldnn_config.hpp:16
std::shared_ptr< CNNLayer > CNNLayerPtr
A smart pointer to the CNNLayer.
Definition: ie_common.h:39
This is a header file for the ICNNNetwork class.
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:160
Definition: ie_cnn_network.h:27
A header file for generic LockedMemory<> and different variations of locks.
The macro defines a symbol import/export mechanism essential for Microsoft Windows(R) OS...