ie_inetwork_iterator.hpp
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 /**
6  * @brief A header file for the INetworkIterator class
7  * @file ie_inetwork_iterator.hpp
8  */
9 #pragma once
10 #include <utility>
11 #include <unordered_map>
12 #include <unordered_set>
13 #include <list>
14 #include <iterator>
15 #include <memory>
16 #include <vector>
17 
18 #include <ie_inetwork.hpp>
19 
20 namespace InferenceEngine {
21 namespace details {
22 
23 template<class NT, class LT>
24 class INetworkIterator: public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
25 public:
26  explicit INetworkIterator(NT * network, bool toEnd = false): network(network), currentIdx(0) {
27  if (!network || toEnd)
28  return;
29  const auto& inputs = network->getInputs();
30 
31  std::vector<std::shared_ptr<LT>> allInputs;
32  for (const auto& input : inputs) {
33  allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
34  }
35 
36  bool res = forestDFS(allInputs, [&](std::shared_ptr<LT> current) {
37  sortedLayers.push_back(current);
38  }, false);
39 
40  if (!res) {
41  THROW_IE_EXCEPTION << "Sorting not possible, due to existed loop.";
42  }
43 
44  std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
45  currentLayer = getNextLayer();
46  }
47  bool operator!=(const INetworkIterator& that) const {
48  return !operator==(that);
49  }
50  bool operator==(const INetworkIterator& that) const {
51  return network == that.network && currentLayer == that.currentLayer;
52  }
53  typename INetworkIterator::reference operator*() {
54  if (nullptr == currentLayer) {
55  THROW_IE_EXCEPTION << "iterator out of bound";
56  }
57  return currentLayer;
58  }
59 
60  INetworkIterator& operator++() {
61  currentLayer = getNextLayer();
62  return *this;
63  }
64 
65  const INetworkIterator<NT, LT> operator++(int) {
66  INetworkIterator<NT, LT> retval = *this;
67  ++(*this);
68  return retval;
69  }
70 
71 private:
72  std::vector<std::shared_ptr<LT>> sortedLayers;
73  std::shared_ptr<LT> currentLayer;
74  size_t currentIdx;
75  NT *network = nullptr;
76 
77  std::shared_ptr<LT> getNextLayer() {
78  return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] : nullptr;
79  }
80 
81  template<class T>
82  inline bool forestDFS(const std::vector<std::shared_ptr<LT>>& heads, const T &visit, bool bVisitBefore) {
83  if (heads.empty()) {
84  return true;
85  }
86 
87  std::unordered_map<idx_t, bool> visited;
88  for (auto & layer : heads) {
89  if (!DFS(visited, layer, visit, bVisitBefore)) {
90  return false;
91  }
92  }
93  return true;
94  }
95 
96  template<class T>
97  inline bool DFS(std::unordered_map<idx_t, bool> &visited,
98  const std::shared_ptr<LT> &layer,
99  const T &visit,
100  bool visitBefore) {
101  if (layer == nullptr) {
102  return true;
103  }
104 
105  if (visitBefore)
106  visit(layer);
107 
108  visited[layer->getId()] = false;
109  for (const auto &connection : network->getLayerConnections(layer->getId())) {
110  if (connection.to().layerId() == layer->getId()) {
111  continue;
112  }
113  const auto outLayer = network->getLayer(connection.to().layerId());
114  auto i = visited.find(outLayer->getId());
115  if (i != visited.end()) {
116  /**
117  * cycle detected we entered still not completed node
118  */
119  if (!i->second) {
120  return false;
121  }
122  continue;
123  }
124 
125  if (!DFS(visited, outLayer, visit, visitBefore)) {
126  return false;
127  }
128  }
129  if (!visitBefore)
130  visit(layer);
131  visited[layer->getId()] = true;
132  return true;
133  }
134 };
135 
136 } // namespace details
137 } // namespace InferenceEngine
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11
a header file for the Inference Engine Network interface