ie_inetwork_iterator.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the INetworkIterator class
7  * @file ie_inetwork_iterator.hpp
8  */
9 #pragma once
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <list>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ie_network.hpp>
19 
20 namespace InferenceEngine {
21 namespace details {
22 
23 template<class NT, class LT>
24 class INetworkIterator: public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
25 public:
26  explicit INetworkIterator(NT * network, bool toEnd = false): network(network), currentIdx(0) {
27  if (!network || toEnd)
28  return;
29  const auto& inputs = network->getInputs();
30 
31  std::vector<std::shared_ptr<LT>> allInputs;
32  for (const auto& input : inputs) {
33  allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
34  }
35 
36  forestDFS(allInputs, [&](std::shared_ptr<LT> current) {
37  sortedLayers.push_back(current);
38  }, false);
39 
40  std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
41  currentLayer = getNextLayer();
42  }
43 
44  bool operator!=(const INetworkIterator& that) const {
45  return !operator==(that);
46  }
47 
48  bool operator==(const INetworkIterator& that) const {
49  return network == that.network && currentLayer == that.currentLayer;
50  }
51 
52  typename INetworkIterator::reference operator*() {
53  if (nullptr == currentLayer) {
54  THROW_IE_EXCEPTION << "iterator out of bound";
55  }
56  return currentLayer;
57  }
58 
59  INetworkIterator& operator++() {
60  currentLayer = getNextLayer();
61  return *this;
62  }
63 
64  const INetworkIterator<NT, LT> operator++(int) {
65  INetworkIterator<NT, LT> retval = *this;
66  ++(*this);
67  return retval;
68  }
69 
70 private:
71  std::vector<std::shared_ptr<LT>> sortedLayers;
72  std::shared_ptr<LT> currentLayer;
73  size_t currentIdx;
74  NT *network = nullptr;
75 
76  std::shared_ptr<LT> getNextLayer() {
77  return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] : nullptr;
78  }
79 
80  template<class T>
81  inline void forestDFS(const std::vector<std::shared_ptr<LT>>& heads, const T &visit, bool bVisitBefore) {
82  if (heads.empty()) {
83  return;
84  }
85 
86  std::unordered_map<idx_t, bool> visited;
87  for (auto & layer : heads) {
88  DFS(visited, layer, visit, bVisitBefore);
89  }
90  }
91 
92  template<class T>
93  inline void DFS(std::unordered_map<idx_t, bool> &visited,
94  const std::shared_ptr<LT> &layer,
95  const T &visit,
96  bool visitBefore) {
97  if (layer == nullptr) {
98  return;
99  }
100 
101  if (visitBefore)
102  visit(layer);
103 
104  visited[layer->getId()] = false;
105  for (const auto &connection : network->getLayerConnections(layer->getId())) {
106  if (connection.to().layerId() == layer->getId()) {
107  continue;
108  }
109  const auto outLayer = network->getLayer(connection.to().layerId());
110  if (!outLayer)
111  THROW_IE_EXCEPTION << "Couldn't get layer with id: " << connection.to().layerId();
112  auto i = visited.find(outLayer->getId());
113  if (i != visited.end()) {
114  /**
115  * cycle detected we entered still not completed node
116  */
117  if (!i->second) {
118  THROW_IE_EXCEPTION << "Sorting not possible, due to existed loop.";
119  }
120  continue;
121  }
122 
123  DFS(visited, outLayer, visit, visitBefore);
124  }
125  if (!visitBefore)
126  visit(layer);
127  visited[layer->getId()] = true;
128  }
129 };
130 
131 } // namespace details
132 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11