11 #include <unordered_map>
12 #include <unordered_set>
18 #include <ie_network.hpp>
23 template<
class NT,
class LT>
24 class INetworkIterator:
public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
26 explicit INetworkIterator(NT * network,
bool toEnd =
false): network(network), currentIdx(0) {
27 if (!network || toEnd)
29 const auto& inputs = network->getInputs();
31 std::vector<std::shared_ptr<LT>> allInputs;
32 for (
const auto& input : inputs) {
33 allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
36 forestDFS(allInputs, [&](std::shared_ptr<LT> current) {
37 sortedLayers.push_back(current);
40 std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
41 currentLayer = getNextLayer();
44 bool operator!=(
const INetworkIterator& that)
const {
45 return !operator==(that);
48 bool operator==(
const INetworkIterator& that)
const {
49 return network == that.network && currentLayer == that.currentLayer;
52 typename INetworkIterator::reference operator*() {
53 if (
nullptr == currentLayer) {
59 INetworkIterator& operator++() {
60 currentLayer = getNextLayer();
64 const INetworkIterator<NT, LT> operator++(
int) {
65 INetworkIterator<NT, LT> retval = *
this;
71 std::vector<std::shared_ptr<LT>> sortedLayers;
72 std::shared_ptr<LT> currentLayer;
74 NT *network =
nullptr;
76 std::shared_ptr<LT> getNextLayer() {
77 return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] :
nullptr;
81 inline void forestDFS(
const std::vector<std::shared_ptr<LT>>& heads,
const T &visit,
bool bVisitBefore) {
86 std::unordered_map<idx_t, bool> visited;
87 for (
auto & layer : heads) {
88 DFS(visited, layer, visit, bVisitBefore);
93 inline void DFS(std::unordered_map<idx_t, bool> &visited,
94 const std::shared_ptr<LT> &layer,
97 if (layer ==
nullptr) {
104 visited[layer->getId()] =
false;
105 for (
const auto &connection : network->getLayerConnections(layer->getId())) {
106 if (connection.to().layerId() == layer->getId()) {
109 const auto outLayer = network->getLayer(connection.to().layerId());
112 auto i = visited.find(outLayer->getId());
113 if (i != visited.end()) {
123 DFS(visited, outLayer, visit, visitBefore);
127 visited[layer->getId()] =
true;
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11