14 #include <unordered_map> 15 #include <unordered_set> 22 template <
class NT,
class LT>
23 class INFERENCE_ENGINE_NN_BUILDER_DEPRECATED INetworkIterator
24 :
public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
26 explicit INetworkIterator(NT* network,
bool toEnd): network(network), currentIdx(0) {}
27 explicit INetworkIterator(NT* network): network(network), currentIdx(0) {
29 const auto& inputs = network->getInputs();
31 std::vector<std::shared_ptr<LT>> allInputs;
32 for (
const auto& input : inputs) {
33 allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
38 [&](std::shared_ptr<LT> current) {
39 sortedLayers.push_back(current);
43 std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
44 currentLayer = getNextLayer();
47 IE_SUPPRESS_DEPRECATED_START
49 bool operator!=(
const INetworkIterator& that)
const {
50 return !operator==(that);
53 bool operator==(
const INetworkIterator& that)
const {
54 return network == that.network && currentLayer == that.currentLayer;
57 typename INetworkIterator::reference operator*() {
58 if (
nullptr == currentLayer) {
64 INetworkIterator& operator++() {
65 currentLayer = getNextLayer();
69 const INetworkIterator<NT, LT> operator++(
int) {
70 INetworkIterator<NT, LT> retval = *
this;
75 IE_SUPPRESS_DEPRECATED_END
78 std::vector<std::shared_ptr<LT>> sortedLayers;
79 std::shared_ptr<LT> currentLayer;
80 NT* network =
nullptr;
83 std::shared_ptr<LT> getNextLayer() {
84 return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] :
nullptr;
88 inline void forestDFS(
const std::vector<std::shared_ptr<LT>>& heads,
const T& visit,
bool bVisitBefore) {
93 std::unordered_map<idx_t, bool> visited;
94 for (
auto& layer : heads) {
95 DFS(visited, layer, visit, bVisitBefore);
100 inline void DFS(std::unordered_map<idx_t, bool>& visited,
const std::shared_ptr<LT>& layer,
const T& visit,
102 if (layer ==
nullptr) {
106 if (visitBefore) visit(layer);
108 visited[layer->getId()] =
false;
109 for (
const auto& connection : network->getLayerConnections(layer->getId())) {
110 if (connection.to().layerId() == layer->getId()) {
113 const auto outLayer = network->getLayer(connection.to().layerId());
114 if (!outLayer)
THROW_IE_EXCEPTION <<
"Couldn't get layer with id: " << connection.to().layerId();
115 auto i = visited.find(outLayer->getId());
116 if (i != visited.end()) {
126 DFS(visited, outLayer, visit, visitBefore);
128 if (!visitBefore) visit(layer);
129 visited[layer->getId()] =
true;
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:25
Inference Engine API.
Definition: ie_argmax_layer.hpp:15
A header file for the Inference Engine Network interface.