ie_cnn_network_iterator.hpp
Go to the documentation of this file.
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the CNNNetworkIterator class
7  *
8  * @file ie_cnn_network_iterator.hpp
9  */
10 #pragma once
11 #include <iterator>
12 #include <list>
13 #include <unordered_set>
14 #include <utility>
15 
16 #include "ie_icnn_network.hpp"
17 #include "ie_locked_memory.hpp"
18 
19 namespace InferenceEngine {
20 namespace details {
21 
22 /**
23  * @brief This class enables range loops for CNNNetwork objects
24  */
25 class CNNNetworkIterator {
26  std::unordered_set<CNNLayer*> visited;
27  std::list<CNNLayerPtr> nextLayersTovisit;
28  InferenceEngine::CNNLayerPtr currentLayer;
29  ICNNNetwork* network = nullptr;
30 
31 public:
32  /**
33  * iterator trait definitions
34  */
35  typedef std::forward_iterator_tag iterator_category;
36  typedef CNNLayerPtr value_type;
37  typedef int difference_type;
38  typedef CNNLayerPtr pointer;
39  typedef CNNLayerPtr reference;
40 
41  /**
42  * @brief Default constructor
43  */
44  CNNNetworkIterator() = default;
45  /**
46  * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
47  * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of
48  * scope.
49  */
50  explicit CNNNetworkIterator(ICNNNetwork* network) {
51  InputsDataMap inputs;
52  network->getInputsInfo(inputs);
53  if (!inputs.empty()) {
54  auto& nextLayers = inputs.begin()->second->getInputData()->getInputTo();
55  if (!nextLayers.empty()) {
56  currentLayer = nextLayers.begin()->second;
57  nextLayersTovisit.push_back(currentLayer);
58  visited.insert(currentLayer.get());
59  }
60  }
61  }
62 
63  /**
64  * @brief Performs pre-increment
65  * @return This CNNNetworkIterator instance
66  */
67  CNNNetworkIterator& operator++() {
68  currentLayer = next();
69  return *this;
70  }
71 
72  /**
73  * @brief Performs post-increment.
74  * Implementation does not follow the std interface since only move semantics is used
75  */
76  void operator++(int) {
77  currentLayer = next();
78  }
79 
80  /**
81  * @brief Checks if the given iterator is not equal to this one
82  * @param that Iterator to compare with
83  * @return true if the given iterator is not equal to this one, false - otherwise
84  */
85  bool operator!=(const CNNNetworkIterator& that) const {
86  return !operator==(that);
87  }
88 
89  /**
90  * @brief Gets const layer pointer referenced by this iterator
91  */
92  const CNNLayerPtr& operator*() const {
93  if (nullptr == currentLayer) {
94  THROW_IE_EXCEPTION << "iterator out of bound";
95  }
96  return currentLayer;
97  }
98 
99  /**
100  * @brief Gets a layer pointer referenced by this iterator
101  */
102  CNNLayerPtr& operator*() {
103  if (nullptr == currentLayer) {
104  THROW_IE_EXCEPTION << "iterator out of bound";
105  }
106  return currentLayer;
107  }
108  /**
109  * @brief Compares the given iterator with this one
110  * @param that Iterator to compare with
111  * @return true if the given iterator is equal to this one, false - otherwise
112  */
113  bool operator==(const CNNNetworkIterator& that) const {
114  return network == that.network && currentLayer == that.currentLayer;
115  }
116 
117 private:
118  /**
119  * @brief implementation based on BFS
120  */
121  CNNLayerPtr next() {
122  if (nextLayersTovisit.empty()) {
123  return nullptr;
124  }
125 
126  auto nextLayer = nextLayersTovisit.front();
127  nextLayersTovisit.pop_front();
128 
129  // visit child that not visited
130  for (auto&& output : nextLayer->outData) {
131  for (auto&& child : output->getInputTo()) {
132  if (visited.find(child.second.get()) == visited.end()) {
133  nextLayersTovisit.push_back(child.second);
134  visited.insert(child.second.get());
135  }
136  }
137  }
138 
139  // visit parents
140  for (auto&& parent : nextLayer->insData) {
141  auto parentLayer = parent.lock()->getCreatorLayer().lock();
142  if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
143  nextLayersTovisit.push_back(parentLayer);
144  visited.insert(parentLayer.get());
145  }
146  }
147 
148  return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
149  }
150 };
151 } // namespace details
152 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
Inference Engine API.
Definition: ie_argmax_layer.hpp:15
This is a header file for the ICNNNetwork class.
A header file for generic LockedMemory<> and different variations of locks.
std::shared_ptr< CNNLayer > CNNLayerPtr
A smart pointer to the CNNLayer.
Definition: ie_common.h:39
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:164