evaluator.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <map>
8 #include <stack>
9 #include <utility>
10 
11 #include "ngraph/node.hpp"
12 #include "ngraph/shape.hpp"
13 #include "ngraph/type/element_type_traits.hpp"
14 
15 namespace ngraph
16 {
17  /// \brief Execute handlers on a subgraph to compute values
18  ///
19  ///
20  template <typename V>
21  class Evaluator
22  {
23  public:
24  /// \brief values we compute for outputs
25  using value_map = std::map<RawNodeOutput, V>;
26 
27  /// \brief Handler for a computation of a value about an op
28  ///
29  /// A handler is passed a Node* and a vector of computed input values. The handler should
30  /// return a vector of computed output values.
31  using op_handler = std::function<std::vector<V>(Node* op, std::vector<V>& inputs)>;
32 
33  /// \brief Table of ops with handlers
34  using op_handler_map = std::map<Node::type_info_t, op_handler>;
35 
36  /// \brief construct handler using the provided op handlers.
37  ///
38  /// Evaluations share previously computed values so that calls on multiple nodes can share
39  /// work. All state is kept in the value map, which is accessible for clearing or seeding
40  /// with
41  /// Evaluator::get_value_map().
42  ///
43  /// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
44  Evaluator(const op_handler_map& handlers, value_map& values)
45  : m_handlers(handlers)
46  , m_value_map(values)
47  {
48  }
49 
50  /// \brief Retrieves the value_map, which holds all Output<Node> value associations.
51  value_map& get_value_map() { return m_value_map; }
52  const value_map& get_value_map() const { return m_value_map; }
53  /// \brief If set, handles all ops
54  const op_handler& get_univeral_handler() const { return m_universal_handler; }
55  /// \brief If set, handles all ops not in the handlers
56  const op_handler& get_default_handler() const { return m_default_handler; }
57  /// \brief If set, handles all ops
58  void set_univeral_handler(const op_handler& handler) { m_universal_handler = handler; }
59  /// \brief If set, handles all ops not in the handlers
60  void set_default_handler(const op_handler& handler) { m_default_handler = handler; }
61 
62  protected:
63  op_handler get_handler(Node* node)
64  {
65  op_handler handler = m_universal_handler;
66  if (!handler)
67  {
68  auto it = m_handlers.find(node->get_type_info());
69  if (it == m_handlers.end())
70  {
71  handler = m_default_handler;
72  }
73  else
74  {
75  handler = it->second;
76  }
77  }
78  return handler;
79  }
80 
81  class Inst;
82  using InstPtr = std::unique_ptr<Inst>;
83  using InstStack = std::stack<InstPtr>;
84 
85  /// \brief Intstructions for evaluations state machine
86  class Inst
87  {
88  protected:
89  Inst(Node* node)
90  : m_node(node)
91  {
92  }
93 
94  public:
95  virtual ~Inst() {}
96  virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
97  Node* get_node() { return m_node; }
98 
99  protected:
100  Node* m_node;
101  };
102 
103  /// \brief Ensure value has been analyzed
104  class ValueInst : public Inst
105  {
106  public:
107  ValueInst(const Output<Node>& value)
108  : Inst(value.get_node())
109  , m_index(value.get_index())
110  {
111  }
112 
113  ValueInst(const RawNodeOutput& value)
114  : Inst(value.node)
115  , m_index(value.index)
116  {
117  }
118 
119  void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
120  {
121  // Request to analyze this value if we can
122  if (auto handler = evaluator.get_handler(node))
123  {
124  // Ensure the inputs are processed and then execute the op handler
125  inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
126  for (auto v : node->input_values())
127  {
128  inst_stack.push(InstPtr(new ValueInst(v)));
129  }
130  }
131  else
132  {
133  // We don't know how to handle this op, so mark the outputs as unknown
134  for (auto output : node->outputs())
135  {
136  evaluator.get_value_map()[output] = V();
137  }
138  }
139  }
140 
141  private:
142  int64_t m_index;
143  };
144 
145  /// \brief All arguments have been handled; execute the node handler
146  class ExecuteInst : public Inst
147  {
148  public:
149  ExecuteInst(Node* node, op_handler& handler)
150  : Inst(node)
151  , m_handler(handler)
152  {
153  }
154 
155  void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
156  {
157  // Request to execute the handleer. Pass what we know about the inputs to the
158  // handler and associate the results with the outputs
159  std::vector<V> inputs;
160  for (auto v : node->input_values())
161  {
162  inputs.push_back(evaluator.get_value_map().at(v));
163  }
164  std::vector<V> outputs = m_handler(node, inputs);
165  for (size_t i = 0; i < outputs.size(); ++i)
166  {
167  evaluator.get_value_map()[node->output(i)] = outputs[i];
168  }
169  }
170 
171  private:
172  op_handler m_handler;
173  };
174 
175  public:
176  /// \brief Determine information about value
177  V evaluate(const Output<Node>& value)
178  {
179  InstStack inst_stack;
180  inst_stack.push(InstPtr(new ValueInst(value)));
181  while (!inst_stack.empty())
182  {
183  InstPtr inst;
184  std::swap(inst_stack.top(), inst);
185  inst_stack.pop();
186  auto node = inst->get_node();
187  if (m_value_map.find(node->output(0)) != m_value_map.end())
188  {
189  // Already computed
190  continue;
191  }
192  inst->handle(*this, inst_stack, node);
193  }
194  return m_value_map.at(value);
195  }
196 
197  protected:
198  op_handler m_universal_handler;
199  op_handler_map m_handlers;
200  op_handler m_default_handler;
201  value_map& m_value_map;
202  };
203 } // namespace ngraph
All arguments have been handled; execute the node handler.
Definition: evaluator.hpp:147
Instructions for the evaluation state machine.
Definition: evaluator.hpp:87
Ensure value has been analyzed.
Definition: evaluator.hpp:105
Execute handlers on a subgraph to compute values.
Definition: evaluator.hpp:22
value_map & get_value_map()
Retrieves the value_map, which holds all Output<Node> value associations.
Definition: evaluator.hpp:51
V evaluate(const Output< Node > &value)
Determine information about value.
Definition: evaluator.hpp:177
std::map< RawNodeOutput, V > value_map
values we compute for outputs
Definition: evaluator.hpp:25
const op_handler & get_univeral_handler() const
If set, handles all ops.
Definition: evaluator.hpp:54
std::map< Node::type_info_t, op_handler > op_handler_map
Table of ops with handlers.
Definition: evaluator.hpp:34
const op_handler & get_default_handler() const
If set, handles all ops not in the handlers.
Definition: evaluator.hpp:56
void set_univeral_handler(const op_handler &handler)
If set, handles all ops.
Definition: evaluator.hpp:58
void set_default_handler(const op_handler &handler)
If set, handles all ops not in the handlers.
Definition: evaluator.hpp:60
Evaluator(const op_handler_map &handlers, value_map &values)
Construct an evaluator using the provided op handlers.
Definition: evaluator.hpp:44
std::function< std::vector< V >(Node *op, std::vector< V > &inputs)> op_handler
Handler for a computation of a value about an op.
Definition: evaluator.hpp:31
Definition: node.hpp:127
virtual const type_info_t & get_type_info() const =0
Output< Node > output(size_t output_index)
std::vector< Output< Node > > outputs()
std::vector< Output< Node > > input_values() const
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Node * get_node() const
size_t get_index() const
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: node.hpp:617