ie_inetwork_iterator.hpp
Go to the documentation of this file.
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the INetworkIterator class
7  *
8  * @file
9  */
10 #pragma once
11 #include <ie_network.hpp>
12 #include <iterator>
13 #include <list>
14 #include <memory>
15 #include <unordered_map>
16 #include <unordered_set>
17 #include <utility>
18 #include <vector>
19 
20 namespace InferenceEngine {
21 namespace details {
22 
23 template <class NT, class LT>
24 class INFERENCE_ENGINE_NN_BUILDER_DEPRECATED INetworkIterator
25  : public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
26 public:
27  explicit INetworkIterator(NT* network, bool toEnd): network(network), currentIdx(0) {}
28  explicit INetworkIterator(NT* network): network(network), currentIdx(0) {
29  if (!network) return;
30  const auto& inputs = network->getInputs();
31 
32  std::vector<std::shared_ptr<LT>> allInputs;
33  for (const auto& input : inputs) {
34  allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
35  }
36 
37  forestDFS(
38  allInputs,
39  [&](std::shared_ptr<LT> current) {
40  sortedLayers.push_back(current);
41  },
42  false);
43 
44  std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
45  currentLayer = getNextLayer();
46  }
47 
48  IE_SUPPRESS_DEPRECATED_START
49 
50  bool operator!=(const INetworkIterator& that) const {
51  return !operator==(that);
52  }
53 
54  bool operator==(const INetworkIterator& that) const {
55  return network == that.network && currentLayer == that.currentLayer;
56  }
57 
58  typename INetworkIterator::reference operator*() {
59  if (nullptr == currentLayer) {
60  THROW_IE_EXCEPTION << "iterator out of bound";
61  }
62  return currentLayer;
63  }
64 
65  INetworkIterator& operator++() {
66  currentLayer = getNextLayer();
67  return *this;
68  }
69 
70  const INetworkIterator<NT, LT> operator++(int) {
71  INetworkIterator<NT, LT> retval = *this;
72  ++(*this);
73  return retval;
74  }
75 
76  IE_SUPPRESS_DEPRECATED_END
77 
78 private:
79  std::vector<std::shared_ptr<LT>> sortedLayers;
80  std::shared_ptr<LT> currentLayer;
81  NT* network = nullptr;
82  size_t currentIdx;
83 
84  std::shared_ptr<LT> getNextLayer() {
85  return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] : nullptr;
86  }
87 
88  template <class T>
89  inline void forestDFS(const std::vector<std::shared_ptr<LT>>& heads, const T& visit, bool bVisitBefore) {
90  if (heads.empty()) {
91  return;
92  }
93 
94  std::unordered_map<idx_t, bool> visited;
95  for (auto& layer : heads) {
96  DFS(visited, layer, visit, bVisitBefore);
97  }
98  }
99 
100  template <class T>
101  inline void DFS(std::unordered_map<idx_t, bool>& visited, const std::shared_ptr<LT>& layer, const T& visit,
102  bool visitBefore) {
103  if (layer == nullptr) {
104  return;
105  }
106 
107  if (visitBefore) visit(layer);
108 
109  visited[layer->getId()] = false;
110  for (const auto& connection : network->getLayerConnections(layer->getId())) {
111  if (connection.to().layerId() == layer->getId()) {
112  continue;
113  }
114  const auto outLayer = network->getLayer(connection.to().layerId());
115  if (!outLayer) THROW_IE_EXCEPTION << "Couldn't get layer with id: " << connection.to().layerId();
116  auto i = visited.find(outLayer->getId());
117  if (i != visited.end()) {
118  /**
119  * cycle detected we entered still not completed node
120  */
121  if (!i->second) {
122  THROW_IE_EXCEPTION << "Sorting not possible, due to existed loop.";
123  }
124  continue;
125  }
126 
127  DFS(visited, outLayer, visit, visitBefore);
128  }
129  if (!visitBefore) visit(layer);
130  visited[layer->getId()] = true;
131  }
132 };
133 
134 } // namespace details
135 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
Inference Engine API.
Definition: ie_argmax_layer.hpp:15
A header file for the Inference Engine Network interface