11 #include <unordered_map>
12 #include <unordered_set>
23 template<
class NT,
class LT>
24 class INetworkIterator:
public std::iterator<std::input_iterator_tag, std::shared_ptr<LT>> {
26 explicit INetworkIterator(NT * network,
bool toEnd =
false): network(network), currentIdx(0) {
27 if (!network || toEnd)
29 const auto& inputs = network->getInputs();
31 std::vector<std::shared_ptr<LT>> allInputs;
32 for (
const auto& input : inputs) {
33 allInputs.push_back(std::dynamic_pointer_cast<LT>(input));
36 bool res = forestDFS(allInputs, [&](std::shared_ptr<LT> current) {
37 sortedLayers.push_back(current);
44 std::reverse(std::begin(sortedLayers), std::end(sortedLayers));
45 currentLayer = getNextLayer();
47 bool operator!=(
const INetworkIterator& that)
const {
48 return !operator==(that);
50 bool operator==(
const INetworkIterator& that)
const {
51 return network == that.network && currentLayer == that.currentLayer;
53 typename INetworkIterator::reference operator*() {
54 if (
nullptr == currentLayer) {
60 INetworkIterator& operator++() {
61 currentLayer = getNextLayer();
65 const INetworkIterator<NT, LT> operator++(
int) {
66 INetworkIterator<NT, LT> retval = *
this;
72 std::vector<std::shared_ptr<LT>> sortedLayers;
73 std::shared_ptr<LT> currentLayer;
75 NT *network =
nullptr;
77 std::shared_ptr<LT> getNextLayer() {
78 return (sortedLayers.size() > currentIdx) ? sortedLayers[currentIdx++] :
nullptr;
82 inline bool forestDFS(
const std::vector<std::shared_ptr<LT>>& heads,
const T &visit,
bool bVisitBefore) {
87 std::unordered_map<idx_t, bool> visited;
88 for (
auto & layer : heads) {
89 if (!DFS(visited, layer, visit, bVisitBefore)) {
97 inline bool DFS(std::unordered_map<idx_t, bool> &visited,
98 const std::shared_ptr<LT> &layer,
101 if (layer ==
nullptr) {
108 visited[layer->getId()] =
false;
109 for (
const auto &connection : network->getLayerConnections(layer->getId())) {
110 if (connection.to().layerId() == layer->getId()) {
113 const auto outLayer = network->getLayer(connection.to().layerId());
114 auto i = visited.find(outLayer->getId());
115 if (i != visited.end()) {
125 if (!DFS(visited, outLayer, visit, visitBefore)) {
131 visited[layer->getId()] =
true;
#define THROW_IE_EXCEPTION
A macro used to throw the exception with a notable description.
Definition: ie_exception.hpp:22
Definition: ie_argmax_layer.hpp:11
A header file for the Inference Engine Network interface.