ie_inetwork_iterator.hpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the INetworkIterator class
7  * @file ie_inetwork_iterator.hpp
8  */
9 #pragma once
10 #include <utility>
11 #include <unordered_map>
12 #include <unordered_set>
13 #include <list>
14 #include <iterator>
15 #include <memory>
16 #include <vector>
17 
18 #include <ie_network.hpp>
19 
20 namespace InferenceEngine {
21 namespace details {
22 
23 template<class NT, class LT>
24 class INetworkIterator: public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
25 public:
26  explicit INetworkIterator(NT * network, bool toEnd): network(network), currentIdx(0) {}
27  explicit INetworkIterator(NT * network): network(network), currentIdx(0) {
28  if (!network)
29  return;
30  const auto& inputs = network->getInputs();
31 
32  std::vector<std::shared_ptr<LT>> allInputs;
33  for (const auto& input : inputs) {
34  allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
35  }
36 
37  forestDFS(allInputs, [&](std::shared_ptr<LT> current) {
38  sortedLayers.push_back(current);
39  }, false);
40 
41  std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
42  currentLayer = getNextLayer();
43  }
44 
45  bool operator!=(const INetworkIterator& that) const {
46  return !operator==(that);
47  }
48 
49  bool operator==(const INetworkIterator& that) const {
50  return network == that.network && currentLayer == that.currentLayer;
51  }
52 
53  typename INetworkIterator::reference operator*() {
54  if (nullptr == currentLayer) {
55  THROW_IE_EXCEPTION << "iterator out of bound";
56  }
57  return currentLayer;
58  }
59 
60  INetworkIterator& operator++() {
61  currentLayer = getNextLayer();
62  return *this;
63  }
64 
65  const INetworkIterator<NT, LT> operator++(int) {
66  INetworkIterator<NT, LT> retval = *this;
67  ++(*this);
68  return retval;
69  }
70 
71 private:
72  std::vector<std::shared_ptr<LT>> sortedLayers;
73  std::shared_ptr<LT> currentLayer;
74  size_t currentIdx;
75  NT *network = nullptr;
76 
77  std::shared_ptr<LT> getNextLayer() {
78  return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] : nullptr;
79  }
80 
81  template<class T>
82  inline void forestDFS(const std::vector<std::shared_ptr<LT>>& heads, const T &visit, bool bVisitBefore) {
83  if (heads.empty()) {
84  return;
85  }
86 
87  std::unordered_map<idx_t, bool> visited;
88  for (auto & layer : heads) {
89  DFS(visited, layer, visit, bVisitBefore);
90  }
91  }
92 
93  template<class T>
94  inline void DFS(std::unordered_map<idx_t, bool> &visited,
95  const std::shared_ptr<LT> &layer,
96  const T &visit,
97  bool visitBefore) {
98  if (layer == nullptr) {
99  return;
100  }
101 
102  if (visitBefore)
103  visit(layer);
104 
105  visited[layer->getId()] = false;
106  for (const auto &connection : network->getLayerConnections(layer->getId())) {
107  if (connection.to().layerId() == layer->getId()) {
108  continue;
109  }
110  const auto outLayer = network->getLayer(connection.to().layerId());
111  if (!outLayer)
112  THROW_IE_EXCEPTION << "Couldn't get layer with id: " << connection.to().layerId();
113  auto i = visited.find(outLayer->getId());
114  if (i != visited.end()) {
115  /**
116  * cycle detected we entered still not completed node
117  */
118  if (!i->second) {
119  THROW_IE_EXCEPTION << "Sorting not possible, due to existed loop.";
120  }
121  continue;
122  }
123 
124  DFS(visited, outLayer, visit, visitBefore);
125  }
126  if (!visitBefore)
127  visit(layer);
128  visited[layer->getId()] = true;
129  }
130 };
131 
132 } // namespace details
133 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11