evaluator.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <map>
20 #include <stack>
21 #include <utility>
22 
23 #include "ngraph/node.hpp"
24 #include "ngraph/shape.hpp"
25 #include "ngraph/type/element_type_traits.hpp"
26 
27 namespace ngraph
28 {
29  /// \brief Execute handlers on a subgraph to compute values
30  ///
31  ///
32  template <typename V>
33  class Evaluator
34  {
35  public:
36  /// \brief values we compute for outputs
37  using value_map = std::map<RawNodeOutput, V>;
38 
39  /// \brief Handler for a computation of a value about an op
40  ///
41  /// A handler is passed a Node* and a vector of computed input values. The handler should
42  /// return a vector of computed output values.
43  using op_handler = std::function<std::vector<V>(Node* op, std::vector<V>& inputs)>;
44 
45  /// \brief Table of ops with handlers
46  using op_handler_map = std::map<Node::type_info_t, op_handler>;
47 
48  /// \brief construct handler using the provided op handlers.
49  ///
50  /// Evaluations share previously computed values so that calls on multiple nodes can share
51  /// work. All state is kept in the value map, which is accessible for clearing or seeding
52  /// with
53  /// Evaluator::get_value_map().
54  ///
55  /// \param Handlers for ops. Pairs of Node::type_info_t and handler functions.
56  Evaluator(const op_handler_map& handlers, value_map& values)
57  : m_handlers(handlers)
58  , m_value_map(values)
59  {
60  }
61 
62  /// \brief Retrieves the value_map, which holds all Output<Node> value associations.
63  value_map& get_value_map() { return m_value_map; }
64  const value_map& get_value_map() const { return m_value_map; }
65  /// \brief If set, handles all ops
66  const op_handler& get_univeral_handler() const { return m_universal_handler; }
67  /// \brief If set, handles all ops not in the handlers
68  const op_handler& get_default_handler() const { return m_default_handler; }
69  /// \brief If set, handles all ops
70  void set_univeral_handler(const op_handler& handler) { m_universal_handler = handler; }
71  /// \brief If set, handles all ops not in the handlers
72  void set_default_handler(const op_handler& handler) { m_default_handler = handler; }
73  protected:
74  op_handler get_handler(Node* node)
75  {
76  op_handler handler = m_universal_handler;
77  if (!handler)
78  {
79  auto it = m_handlers.find(node->get_type_info());
80  if (it == m_handlers.end())
81  {
82  handler = m_default_handler;
83  }
84  else
85  {
86  handler = it->second;
87  }
88  }
89  return handler;
90  }
91 
92  class Inst;
93  using InstPtr = std::unique_ptr<Inst>;
94  using InstStack = std::stack<InstPtr>;
95 
96  /// \brief Intstructions for evaluations state machine
97  class Inst
98  {
99  protected:
100  Inst(Node* node)
101  : m_node(node)
102  {
103  }
104 
105  public:
106  virtual ~Inst() {}
107  virtual void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) = 0;
108  Node* get_node() { return m_node; }
109  protected:
110  Node* m_node;
111  };
112 
113  /// \brief Ensure value has been analyzed
114  class ValueInst : public Inst
115  {
116  public:
117  ValueInst(const Output<Node>& value)
118  : Inst(value.get_node())
119  , m_index(value.get_index())
120  {
121  }
122 
123  ValueInst(const RawNodeOutput& value)
124  : Inst(value.node)
125  , m_index(value.index)
126  {
127  }
128 
129  void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
130  {
131  // Request to analyze this value if we can
132  if (auto handler = evaluator.get_handler(node))
133  {
134  // Ensure the inputs are processed and then execute the op handler
135  inst_stack.push(InstPtr(new ExecuteInst(node, handler)));
136  for (auto v : node->input_values())
137  {
138  inst_stack.push(InstPtr(new ValueInst(v)));
139  }
140  }
141  else
142  {
143  // We don't know how to handle this op, so mark the outputs as unknown
144  for (auto output : node->outputs())
145  {
146  evaluator.get_value_map()[output] = V();
147  }
148  }
149  }
150 
151  private:
152  int64_t m_index;
153  };
154 
155  /// \brief All arguments have been handled; execute the node handler
156  class ExecuteInst : public Inst
157  {
158  public:
159  ExecuteInst(Node* node, op_handler& handler)
160  : Inst(node)
161  , m_handler(handler)
162  {
163  }
164 
165  void handle(Evaluator& evaluator, InstStack& inst_stack, Node* node) override
166  {
167  // Request to execute the handleer. Pass what we know about the inputs to the
168  // handler and associate the results with the outputs
169  std::vector<V> inputs;
170  for (auto v : node->input_values())
171  {
172  inputs.push_back(evaluator.get_value_map().at(v));
173  }
174  std::vector<V> outputs = m_handler(node, inputs);
175  for (size_t i = 0; i < outputs.size(); ++i)
176  {
177  evaluator.get_value_map()[node->output(i)] = outputs[i];
178  }
179  }
180 
181  private:
182  op_handler m_handler;
183  };
184 
185  public:
186  /// \brief Determine information about value
187  V evaluate(const Output<Node>& value)
188  {
189  InstStack inst_stack;
190  inst_stack.push(InstPtr(new ValueInst(value)));
191  while (!inst_stack.empty())
192  {
193  InstPtr inst;
194  std::swap(inst_stack.top(), inst);
195  inst_stack.pop();
196  auto node = inst->get_node();
197  if (m_value_map.find(node->output(0)) != m_value_map.end())
198  {
199  // Already computed
200  continue;
201  }
202  inst->handle(*this, inst_stack, node);
203  }
204  return m_value_map.at(value);
205  }
206 
207  protected:
208  op_handler m_universal_handler;
209  op_handler_map m_handlers;
210  op_handler m_default_handler;
211  value_map& m_value_map;
212  };
213 }
All arguments have been handled; execute the node handler.
Definition: evaluator.hpp:157
Instructions for the evaluation state machine.
Definition: evaluator.hpp:98
Ensure value has been analyzed.
Definition: evaluator.hpp:115
Execute handlers on a subgraph to compute values.
Definition: evaluator.hpp:34
value_map & get_value_map()
Retrieves the value_map, which holds all Output<Node> value associations.
Definition: evaluator.hpp:63
V evaluate(const Output< Node > &value)
Determine information about value.
Definition: evaluator.hpp:187
std::map< RawNodeOutput, V > value_map
values we compute for outputs
Definition: evaluator.hpp:37
const op_handler & get_univeral_handler() const
If set, handles all ops.
Definition: evaluator.hpp:66
std::map< Node::type_info_t, op_handler > op_handler_map
Table of ops with handlers.
Definition: evaluator.hpp:46
const op_handler & get_default_handler() const
If set, handles all ops not in the handlers.
Definition: evaluator.hpp:68
void set_univeral_handler(const op_handler &handler)
If set, handles all ops.
Definition: evaluator.hpp:70
void set_default_handler(const op_handler &handler)
If set, handles all ops not in the handlers.
Definition: evaluator.hpp:72
Evaluator(const op_handler_map &handlers, value_map &values)
construct handler using the provided op handlers.
Definition: evaluator.hpp:56
std::function< std::vector< V >(Node *op, std::vector< V > &inputs)> op_handler
Handler for a computation of a value about an op.
Definition: evaluator.hpp:43
Definition: node.hpp:132
virtual const type_info_t & get_type_info() const =0
Output< Node > output(size_t output_index)
std::vector< Output< Node > > outputs()
std::vector< Output< Node > > input_values() const
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Node * get_node() const
size_t get_index() const
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: node.hpp:606