15 #include <unordered_map> 16 #include <unordered_set> 23 template <
class NT,
class LT>
// Iterator over the layers of a network.
//   NT - network type; the visible code calls getInputs(),
//        getLayerConnections(id) and getLayer(id) on it.
//   LT - layer type; the visible code calls getId() on it.
// The iterator category is std::input_iterator_tag and the value type is
// std::shared_ptr<LT>.
// NOTE(review): std::iterator is deprecated since C++17; operator*() below
// relies on the `reference` typedef inherited from it.
// NOTE(review): INFERENCE_ENGINE_NN_BUILDER_DEPRECATED is defined outside this
// excerpt; presumably it attaches a deprecation attribute - confirm.
24 class INFERENCE_ENGINE_NN_BUILDER_DEPRECATED INetworkIterator
25 :
public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
// Binds the iterator to `network` and resets the cursor index to 0 without
// computing any layer ordering. sortedLayers stays empty and currentLayer
// stays default-constructed (null).
// NOTE(review): `toEnd` is not read in this body; presumably this overload only
// exists to build the past-the-end iterator - confirm against callers.
27 explicit INetworkIterator(NT* network,
bool toEnd): network(network), currentIdx(0) {}
// Binds the iterator to `network`, computes the layer ordering into
// sortedLayers and positions the iterator on its first element.
// NOTE(review): several original lines (including the call that receives the
// lambda below, and the closing braces) are missing from this excerpt.
28 explicit INetworkIterator(NT* network): network(network), currentIdx(0) {
30 const auto& inputs = network->getInputs();
// Collect the network inputs as std::shared_ptr<LT>. dynamic_pointer_cast
// yields a null pointer for any input that is not an LT.
32 std::vector<std::shared_ptr<LT>> allInputs;
33 for (
const auto& input : inputs) {
34 allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
// Visitor callback: appends each visited layer to sortedLayers.
// Presumably passed to forestDFS(allInputs, ...) - TODO confirm, the call
// itself is not visible here.
39 [&](std::shared_ptr<LT> current) {
40 sortedLayers.push_back(current);
// The collected order is reversed before iteration starts.
// NOTE(review): std::reverse lives in <algorithm>, which is not among the
// includes visible in this excerpt - confirm it is included.
44 std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
// Load the first layer (or null when sortedLayers is empty).
45 currentLayer = getNextLayer();
48 IE_SUPPRESS_DEPRECATED_START
// Inequality: the exact negation of operator==.
50 bool operator!=(
const INetworkIterator& that)
const {
51 return !operator==(that);
// Two iterators are equal when they refer to the same network object and
// currently hold the same layer pointer. The cursor index is not compared, so
// any two exhausted iterators over the same network (currentLayer == null)
// compare equal.
54 bool operator==(
const INetworkIterator& that)
const {
55 return network == that.network && currentLayer == that.currentLayer;
// Dereference. The return type is the `reference` typedef of the iterator.
// A null currentLayer is checked first; how that case is handled and the
// return statement are not visible in this excerpt.
58 typename INetworkIterator::reference operator*() {
59 if (
nullptr == currentLayer) {
// Pre-increment: advances to the next entry of sortedLayers; currentLayer
// becomes null once the sequence is exhausted (see getNextLayer()).
// The return statement is not visible in this excerpt.
65 INetworkIterator& operator++() {
66 currentLayer = getNextLayer();
// Post-increment. Takes a copy of the current iterator state (the copy
// includes the sortedLayers vector). The remaining statements are not visible
// in this excerpt; presumably it advances *this and returns the copy - confirm.
70 const INetworkIterator<NT, LT> operator++(
int) {
71 INetworkIterator<NT, LT> retval = *
this;
76 IE_SUPPRESS_DEPRECATED_END
// Layer sequence the iterator walks over; filled by the single-argument
// constructor.
79 std::vector<std::shared_ptr<LT>> sortedLayers;
// Layer the iterator currently points at; null when there is none.
80 std::shared_ptr<LT> currentLayer;
// Network being iterated; stored as a raw pointer and compared by operator==.
// NOTE(review): the declaration of currentIdx (used by the constructors and
// getNextLayer()) is not visible in this excerpt.
81 NT* network =
nullptr;
// Returns sortedLayers[currentIdx] and then increments currentIdx, or returns
// null (without touching currentIdx) once currentIdx has reached the end of
// sortedLayers.
84 std::shared_ptr<LT> getNextLayer() {
85 return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] :
nullptr;
// Runs DFS() from every layer in `heads`, sharing a single `visited` map
// across all of them.
//   heads        - starting layers.
//   visit        - callable invoked with each visited layer; its type T is
//                  declared on a line not visible in this excerpt.
//   bVisitBefore - forwarded to DFS() as its visitBefore argument.
// NOTE(review): idx_t is declared outside this excerpt.
89 inline void forestDFS(
const std::vector<std::shared_ptr<LT>>& heads,
const T& visit,
bool bVisitBefore) {
94 std::unordered_map<idx_t, bool> visited;
95 for (
auto& layer : heads) {
96 DFS(visited, layer, visit, bVisitBefore);
// Recursive depth-first traversal starting at `layer`, following the
// connections returned by network->getLayerConnections(layer->getId()).
//   visited - maps a layer id to a flag: false is stored when the layer is
//             entered, true once its traversal has finished.
//   layer   - layer to start from.
//   visit   - callable invoked with `layer`, either before descending (when
//             visitBefore is true) or after all connections were processed.
// NOTE(review): the template header and the line declaring the last parameter
// are missing from this excerpt; the body reads it as `visitBefore`.
101 inline void DFS(std::unordered_map<idx_t, bool>& visited,
const std::shared_ptr<LT>& layer,
const T& visit,
// A null layer is checked before anything else; the handling statement is not
// visible here.
103 if (layer ==
nullptr) {
// Pre-order visit.
107 if (visitBefore) visit(layer);
// Mark the layer as entered but not yet finished.
109 visited[layer->getId()] =
false;
110 for (
const auto& connection : network->getLayerConnections(layer->getId())) {
// Connections whose destination is this same layer are singled out; the body
// is not visible (presumably they are skipped - confirm).
111 if (connection.to().layerId() == layer->getId()) {
114 const auto outLayer = network->getLayer(connection.to().layerId());
// Fail when the destination id does not resolve to a layer.
115 if (!outLayer)
THROW_IE_EXCEPTION <<
"Couldn't get layer with id: " << connection.to().layerId();
// Destinations already present in `visited` take a separate path whose body is
// not visible; presumably it uses the false/true flag to detect cycles -
// TODO confirm.
116 auto i = visited.find(outLayer->getId());
117 if (i != visited.end()) {
// Descend into the destination layer.
127 DFS(visited, outLayer, visit, visitBefore);
// Post-order visit, then mark the layer as finished.
129 if (!visitBefore) visit(layer);
130 visited[layer->getId()] =
true;
#define THROW_IE_EXCEPTION
A macro used to throw an exception with a descriptive message.
Definition: ie_exception.hpp:25
Inference Engine API.
Definition: ie_argmax_layer.hpp:15
A header file for the Inference Engine Network interface.