ie_cnn_network_iterator.hpp
Go to the documentation of this file.
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the CNNNetworkIterator class
7  * @file ie_cnn_network_iterator.hpp
8  */
9 #pragma once
10 #include <iterator>
11 #include <list>
12 #include <unordered_set>
13 #include <utility>
14 
15 #include "ie_icnn_network.hpp"
16 #include "ie_locked_memory.hpp"
17 
18 namespace InferenceEngine {
19 namespace details {
20 
21 /**
22  * @brief This class enables range loops for CNNNetwork objects
23  */
24 class CNNNetworkIterator {
25  std::unordered_set<CNNLayer*> visited;
26  std::list<CNNLayerPtr> nextLayersTovisit;
27  InferenceEngine::CNNLayerPtr currentLayer;
28  ICNNNetwork* network = nullptr;
29 
30 public:
31  /**
32  * iterator trait definitions
33  */
34  typedef std::forward_iterator_tag iterator_category;
35  typedef CNNLayerPtr value_type;
36  typedef int difference_type;
37  typedef CNNLayerPtr pointer;
38  typedef CNNLayerPtr reference;
39 
40  /**
41  * @brief Default constructor
42  */
43  CNNNetworkIterator() = default;
44  /**
45  * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
46  * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of
47  * scope.
48  */
49  explicit CNNNetworkIterator(ICNNNetwork* network) {
50  InputsDataMap inputs;
51  network->getInputsInfo(inputs);
52  if (!inputs.empty()) {
53  auto& nextLayers = inputs.begin()->second->getInputData()->getInputTo();
54  if (!nextLayers.empty()) {
55  currentLayer = nextLayers.begin()->second;
56  nextLayersTovisit.push_back(currentLayer);
57  visited.insert(currentLayer.get());
58  }
59  }
60  }
61 
62  /**
63  * @brief Performs pre-increment
64  * @return This CNNNetworkIterator instance
65  */
66  CNNNetworkIterator& operator++() {
67  currentLayer = next();
68  return *this;
69  }
70 
71  /**
72  * @brief Performs post-increment.
73  * Implementation does not follow the std interface since only move semantics is used
74  */
75  void operator++(int) {
76  currentLayer = next();
77  }
78 
79  /**
80  * @brief Checks if the given iterator is not equal to this one
81  * @param that Iterator to compare with
82  * @return true if the given iterator is not equal to this one, false - otherwise
83  */
84  bool operator!=(const CNNNetworkIterator& that) const {
85  return !operator==(that);
86  }
87 
88  /**
89  * @brief Gets const layer pointer referenced by this iterator
90  */
91  const CNNLayerPtr& operator*() const {
92  if (nullptr == currentLayer) {
93  THROW_IE_EXCEPTION << "iterator out of bound";
94  }
95  return currentLayer;
96  }
97 
98  /**
99  * @brief Gets a layer pointer referenced by this iterator
100  */
101  CNNLayerPtr& operator*() {
102  if (nullptr == currentLayer) {
103  THROW_IE_EXCEPTION << "iterator out of bound";
104  }
105  return currentLayer;
106  }
107  /**
108  * @brief Compares the given iterator with this one
109  * @param that Iterator to compare with
110  * @return true if the given iterator is equal to this one, false - otherwise
111  */
112  bool operator==(const CNNNetworkIterator& that) const {
113  return network == that.network && currentLayer == that.currentLayer;
114  }
115 
116 private:
117  /**
118  * @brief implementation based on BFS
119  */
120  CNNLayerPtr next() {
121  if (nextLayersTovisit.empty()) {
122  return nullptr;
123  }
124 
125  auto nextLayer = nextLayersTovisit.front();
126  nextLayersTovisit.pop_front();
127 
128  // visit child that not visited
129  for (auto&& output : nextLayer->outData) {
130  for (auto&& child : output->getInputTo()) {
131  if (visited.find(child.second.get()) == visited.end()) {
132  nextLayersTovisit.push_back(child.second);
133  visited.insert(child.second.get());
134  }
135  }
136  }
137 
138  // visit parents
139  for (auto&& parent : nextLayer->insData) {
140  auto parentLayer = parent.lock()->getCreatorLayer().lock();
141  if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
142  nextLayersTovisit.push_back(parentLayer);
143  visited.insert(parentLayer.get());
144  }
145  }
146 
147  return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
148  }
149 };
150 } // namespace details
151 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw an exception with a descriptive message.
Definition: ie_exception.hpp:24
Inference Engine API.
Definition: ie_argmax_layer.hpp:11
This is a header file for the ICNNNetwork class.
A header file for generic LockedMemory<> and different variations of locks.
std::shared_ptr< CNNLayer > CNNLayerPtr
A smart pointer to the CNNLayer.
Definition: ie_common.h:37
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:156