ie_cnn_network_iterator.hpp
Go to the documentation of this file.
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the CNNNetworkIterator class
7  * @file ie_cnn_network_iterator.hpp
8  */
9 #pragma once
10 #include <utility>
11 #include <unordered_set>
12 #include <list>
13 #include <iterator>
14 
15 #include "ie_locked_memory.hpp"
16 #include "ie_icnn_network.hpp"
17 
18 namespace InferenceEngine {
19 namespace details {
20 
21 /**
22  * @brief This class enables range loops for CNNNetwork objects
23  */
24 class CNNNetworkIterator {
25  std::unordered_set<CNNLayer*> visited;
26  std::list<CNNLayerPtr> nextLayersTovisit;
27  InferenceEngine::CNNLayerPtr currentLayer;
28  ICNNNetwork * network = nullptr;
29 
30  public:
31  /**
32  * iterator trait definitions
33  */
34  typedef std::forward_iterator_tag iterator_category;
35  typedef CNNLayerPtr value_type;
36  typedef int difference_type;
37  typedef CNNLayerPtr pointer;
38  typedef CNNLayerPtr reference;
39 
40  /**
41  * @brief Default constructor
42  */
43  CNNNetworkIterator() = default;
44  /**
45  * @brief Constructor. Creates an iterator for specified CNNNetwork instance.
46  * @param network Network to iterate. Make sure the network object is not destroyed before iterator goes out of scope.
47  */
48  explicit CNNNetworkIterator(ICNNNetwork * network) {
49  InputsDataMap inputs;
50  network->getInputsInfo(inputs);
51  if (!inputs.empty()) {
52  auto & nextLayers = inputs.begin()->second->getInputData()->getInputTo();
53  if (!nextLayers.empty()) {
54  currentLayer = nextLayers.begin()->second;
55  nextLayersTovisit.push_back(currentLayer);
56  visited.insert(currentLayer.get());
57  }
58  }
59  }
60 
61  /**
62  * @brief Performs pre-increment
63  * @return This CNNNetworkIterator instance
64  */
65  CNNNetworkIterator &operator++() {
66  currentLayer = next();
67  return *this;
68  }
69 
70  /**
71  * @brief Performs post-increment.
72  * Implementation does not follow the std interface since only move semantics is used
73  */
74  void operator++(int) {
75  currentLayer = next();
76  }
77 
78  /**
79  * @brief Checks if the given iterator is not equal to this one
80  * @param that Iterator to compare with
81  * @return true if the given iterator is not equal to this one, false - otherwise
82  */
83  bool operator!=(const CNNNetworkIterator &that) const {
84  return !operator==(that);
85  }
86 
87  /**
88  * @brief Gets const layer pointer referenced by this iterator
89  */
90  const CNNLayerPtr &operator*() const {
91  if (nullptr == currentLayer) {
92  THROW_IE_EXCEPTION << "iterator out of bound";
93  }
94  return currentLayer;
95  }
96 
97  /**
98  * @brief Gets a layer pointer referenced by this iterator
99  */
100  CNNLayerPtr &operator*() {
101  if (nullptr == currentLayer) {
102  THROW_IE_EXCEPTION << "iterator out of bound";
103  }
104  return currentLayer;
105  }
106  /**
107  * @brief Compares the given iterator with this one
108  * @param that Iterator to compare with
109  * @return true if the given iterator is equal to this one, false - otherwise
110  */
111  bool operator==(const CNNNetworkIterator &that) const {
112  return network == that.network && currentLayer == that.currentLayer;
113  }
114 
115  private:
116  /**
117  * @brief implementation based on BFS
118  */
119  CNNLayerPtr next() {
120  if (nextLayersTovisit.empty()) {
121  return nullptr;
122  }
123 
124  auto nextLayer = nextLayersTovisit.front();
125  nextLayersTovisit.pop_front();
126 
127  // visit child that not visited
128  for (auto && output : nextLayer->outData) {
129  for (auto && child : output->getInputTo()) {
130  if (visited.find(child.second.get()) == visited.end()) {
131  nextLayersTovisit.push_back(child.second);
132  visited.insert(child.second.get());
133  }
134  }
135  }
136 
137  // visit parents
138  for (auto && parent : nextLayer->insData) {
139  auto parentLayer = parent.lock()->getCreatorLayer().lock();
140  if (parentLayer && visited.find(parentLayer.get()) == visited.end()) {
141  nextLayersTovisit.push_back(parentLayer);
142  visited.insert(parentLayer.get());
143  }
144  }
145 
146  return nextLayersTovisit.empty() ? nullptr : nextLayersTovisit.front();
147  }
148 };
149 } // namespace details
150 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11
std::shared_ptr< CNNLayer > CNNLayerPtr
A smart pointer to the CNNLayer.
Definition: ie_common.h:36
This is a header file for the ICNNNetwork class.
std::map< std::string, InputInfo::Ptr > InputsDataMap
A collection that contains string as key, and InputInfo smart pointer as value.
Definition: ie_input_info.hpp:193
A header file for generic LockedMemory<> and different variations of locks.