11 #include <unordered_map>
12 #include <unordered_set>
18 #include <ie_network.hpp>
// NOTE(review): this listing was extracted from a rendered source page; several
// original lines (closing braces, access specifiers, the currentIdx and
// template<class T> declarations, and parts of some method bodies) are not
// present here.
//
// INetworkIterator: input iterator over the layers of a network of type NT,
// yielding std::shared_ptr<LT>. The begin constructor precomputes the full list
// of layers reachable from the network inputs; operator++ walks that list.
// NOTE(review): std::iterator is deprecated since C++17; consider declaring the
// iterator_category / value_type / difference_type / pointer / reference member
// aliases directly instead of inheriting from it.
23 template<
class NT,
class LT>
24 class INetworkIterator:
public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
// End-iterator constructor. It stores only the network and a zero index, so
// currentLayer remains an empty shared_ptr (the "end" state that operator==
// compares against) and sortedLayers stays empty.
// NOTE(review): the toEnd argument is never read; it only selects this overload.
26 explicit INetworkIterator(NT * network,
bool toEnd): network(network), currentIdx(0) {}
// Begin-iterator constructor.
// 1. Collects the network inputs, cast to LT. std::dynamic_pointer_cast yields
//    nullptr for inputs that are not LT; DFS has a nullptr check for that case,
//    but its handling is not shown here.
// 2. Runs forestDFS from all of them, appending every visited layer to
//    sortedLayers.
// 3. Reverses sortedLayers and positions the iterator on its first element.
// Presumably the result is a topological order when the visitor runs after a
// layer's successors (bVisitBefore == false) — TODO confirm the argument passed
// on the line not shown here.
27 explicit INetworkIterator(NT * network): network(network), currentIdx(0) {
30 const auto& inputs = network->getInputs();
32 std::vector<std::shared_ptr<LT>> allInputs;
33 for (
const auto& input : inputs) {
34 allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
// The visitor appends each visited layer; the order is fixed up by the reverse below.
37 forestDFS(allInputs, [&](std::shared_ptr<LT> current) {
38 sortedLayers.push_back(current);
41 std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
// If no layers were collected, getNextLayer() returns nullptr and this
// iterator compares equal to the end iterator for the same network.
42 currentLayer = getNextLayer();
// Negation of operator==.
45 bool operator!=(
const INetworkIterator& that)
const {
46 return !operator==(that);
// Two iterators are equal when they refer to the same network object (pointer
// identity) and currently hold the same layer pointer. An exhausted iterator
// holds nullptr, matching the end iterator.
49 bool operator==(
const INetworkIterator& that)
const {
50 return network == that.network && currentLayer == that.currentLayer;
// Dereference: returns the current layer. Dereferencing an exhausted or end
// iterator (currentLayer == nullptr) takes the branch below; its body is not
// shown, presumably an error via THROW_IE_EXCEPTION — confirm.
53 typename INetworkIterator::reference operator*() {
54 if (
nullptr == currentLayer) {
// Prefix increment: advance to the next precomputed layer (nullptr once exhausted).
60 INetworkIterator& operator++() {
61 currentLayer = getNextLayer();
// Postfix increment: starts by copying the current state into retval; the
// remaining lines are not shown.
// NOTE(review): returning a const value prevents moving from the result; a
// plain value return is the usual convention.
65 const INetworkIterator<NT, LT> operator++(
int) {
66 INetworkIterator<NT, LT> retval = *
this;
// Layers in iteration order, computed once by the begin constructor.
72 std::vector<std::shared_ptr<LT>> sortedLayers;
// Layer the iterator currently points at; empty (nullptr) means end/exhausted.
73 std::shared_ptr<LT> currentLayer;
// Non-owning pointer to the iterated network; also part of equality.
75 NT *network =
nullptr;
// Returns sortedLayers[currentIdx] and post-increments currentIdx while entries
// remain; returns nullptr once the list is exhausted.
77 std::shared_ptr<LT> getNextLayer() {
78 return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] :
nullptr;
// Runs DFS from each head in turn, sharing one visited map (keyed by layer id)
// across all heads.
// visit:        callable invoked with each layer.
// bVisitBefore: forwarded to DFS. Presumably it selects whether visit runs
//               before or after a layer's successors — confirm in the lines of
//               DFS not shown.
82 inline void forestDFS(
const std::vector<std::shared_ptr<LT>>& heads,
const T &visit,
bool bVisitBefore) {
87 std::unordered_map<idx_t, bool> visited;
88 for (
auto & layer : heads) {
89 DFS(visited, layer, visit, bVisitBefore);
// Recursive depth-first traversal from `layer` along the network's layer
// connections. visited[id] is set to false on entry (in progress) and to true
// on exit (finished).
94 inline void DFS(std::unordered_map<idx_t, bool> &visited,
95 const std::shared_ptr<LT> &layer,
// Null layers (e.g. a failed dynamic_pointer_cast upstream) are handled here;
// the body is not shown.
98 if (layer ==
nullptr) {
105 visited[layer->getId()] =
false;
// Connections whose destination is this layer itself are singled out below
// (presumably skipped as incoming edges — body not shown).
106 for (
const auto &connection : network->getLayerConnections(layer->getId())) {
107 if (connection.to().layerId() == layer->getId()) {
110 const auto outLayer = network->getLayer(connection.to().layerId());
// Already-seen destinations take the branch below. Its handling is not shown;
// presumably finished layers are skipped and in-progress ones (false) indicate
// a cycle — confirm.
113 auto i = visited.find(outLayer->getId());
114 if (i != visited.end()) {
124 DFS(visited, outLayer, visit, visitBefore);
// All outgoing connections processed: mark this layer finished.
128 visited[layer->getId()] =
true;
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11