// sub_graph_base.hpp
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cstdint>
#include <memory>
#include <vector>

#include <ngraph/op/parameter.hpp>
#include "ngraph/op/op.hpp"

10 namespace ngraph
11 {
12  namespace op
13  {
14  namespace util
15  {
16  /// \brief Abstract base class for sub-graph based ops, i.e ops that have sub-graph
17  ///
18  class NGRAPH_API SubGraphOp : public Op
19  {
20  public:
21  NGRAPH_RTTI_DECLARATION;
22  /// \brief Describes a connection between a SubGraphOp input and the body.
24  {
25  protected:
26  ///
27  /// \brief Constructs a new instance.
28  ///
29  /// \param input_index Position of the SubGraphOp input
30  /// \param body_parameter_index Body parameter to receive input
31  ///
32  InputDescription(uint64_t input_index, uint64_t body_parameter_index);
33  InputDescription() = default;
34 
35  public:
37  virtual ~InputDescription() = default;
38  virtual std::shared_ptr<InputDescription> copy() const = 0;
39 
40  virtual const type_info_t& get_type_info() const = 0;
41 
42  uint64_t m_input_index{0};
43  uint64_t m_body_parameter_index{0};
44  };
45 
46  ///
47  /// \brief Describes a body input formed from slices of an input to
48  /// SubGraphOp.
49  ///
50  class NGRAPH_API SliceInputDescription : public InputDescription
51  {
52  public:
53  static constexpr type_info_t type_info{"SliceInputDescription", 0};
54  const type_info_t& get_type_info() const override { return type_info; }
55  ///
56  /// \brief Constructs a new instance.
57  ///
58  /// \param input_index Position of the SubGraphOp input
59  /// \param body_parameter_index Body parameter position to receive input
60  /// \param start First index for slices
61  /// \param stride Step amount for slices
62  /// \param part_size Width of slices
63  /// \param end Last index for slices
64  /// \param axis Axis being sliced
65  ///
66  SliceInputDescription(uint64_t input_index,
67  uint64_t body_parameter_index,
68  int64_t start,
69  int64_t stride,
70  int64_t part_size,
71  int64_t end,
72  int64_t axis);
73  SliceInputDescription() = default;
74  std::shared_ptr<InputDescription> copy() const override;
75  int64_t m_start{0};
76  int64_t m_stride{0};
77  int64_t m_part_size{0};
78  int64_t m_end{0};
79  int64_t m_axis{0};
80  };
81 
82  ///
83  /// \brief Describes a body input initialized from a SubGraphOp input on
84  /// the first iteration, and then a body output thereafter.
85  ///
86  class NGRAPH_API MergedInputDescription : public InputDescription
87  {
88  public:
89  static constexpr type_info_t type_info{"MergedInputDescription", 0};
90  const type_info_t& get_type_info() const override { return type_info; }
91  ///
92  /// \brief Constructs a new instance.
93  ///
94  /// \param input_index Position of the SubGraphOp input
95  /// supplying a value to body_parameter for
96  /// the initial iteration.
97  /// \param body_parameter_index Body parameter position to receive input.
98  /// \param body_value_index Body value to supply body_parameter for
99  /// successive
100  /// iterations.
101  ///
102  MergedInputDescription(uint64_t input_index,
103  uint64_t body_parameter_index,
104  uint64_t body_value_index);
105  MergedInputDescription() = default;
106  std::shared_ptr<InputDescription> copy() const override;
107  uint64_t m_body_value_index{0};
108  };
109 
110  ///
111  /// \brief Describes a body input initialized from a SubGraphOp input on
112  /// the first iteration, and invariant thereafter.
113  ///
114  class NGRAPH_API InvariantInputDescription : public InputDescription
115  {
116  public:
117  static constexpr type_info_t type_info{"InvariantInputDescription", 0};
118  const type_info_t& get_type_info() const override { return type_info; }
119  ///
120  /// \brief Constructs a new instance.
121  ///
122  /// \param input_index Position of the SubGraphOp input
123  /// \param body_parameter_index Body parameter to receive input
124  ///
125  InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
126  InvariantInputDescription() = default;
127  std::shared_ptr<InputDescription> copy() const override;
128  };
129 
130  /// \brief Describes how a SubGraphOp output is produced from the body.
132  {
133  protected:
134  ///
135  /// \brief Constructs a new instance.
136  ///
137  /// \param body_value_index A body value that produces the output
138  /// \param output_index The SubGraphOp output index
139  ///
140  OutputDescription(uint64_t body_value_index, uint64_t output_index);
141  OutputDescription() = default;
142 
143  public:
145  virtual ~OutputDescription() = default;
146  virtual std::shared_ptr<OutputDescription> copy() const = 0;
147  virtual const type_info_t& get_type_info() const = 0;
148 
149  uint64_t m_body_value_index{0};
150  uint64_t m_output_index{0};
151  };
152 
153  /// \brief Produces an output by concatenating an output from each iteration
154  class NGRAPH_API ConcatOutputDescription : public OutputDescription
155  {
156  public:
157  static constexpr type_info_t type_info{"ConcatOutputDescription", 0};
158  const type_info_t& get_type_info() const override { return type_info; }
159  ///
160  /// \brief Constructs a new instance.
161  ///
162  /// \param body_value_index A body value that produces the output
163  /// \param output_index The SubGraphOp output index
164  /// \param start First index for slices
165  /// \param stride Step amount for slices
166  /// \param part_size Width of slices
167  /// \param end Last index for slices
168  /// \param axis Axis being sliced
169  ///
170  ConcatOutputDescription(uint64_t body_value_index,
171  uint64_t output_index,
172  int64_t start,
173  int64_t stride,
174  int64_t part_size,
175  int64_t end,
176  int64_t axis);
177  ConcatOutputDescription() = default;
178 
179  std::shared_ptr<OutputDescription> copy() const override;
180  int64_t m_start{0};
181  int64_t m_stride{0};
182  int64_t m_part_size{0};
183  int64_t m_end{0};
184  int64_t m_axis{0};
185  };
186 
187  /// \brief Produces an output from a specific iteration
188  class NGRAPH_API BodyOutputDescription : public OutputDescription
189  {
190  public:
191  static constexpr type_info_t type_info{"BodyOutputDescription", 0};
192  const type_info_t& get_type_info() const override { return type_info; }
193  ///
194  /// \brief Constructs a new instance.
195  ///
196  /// \param body_value_index A body value that produces the output
197  /// \param output_index The SubGraphOp output index
198  /// \param iteration which iteration (typically -1, final) will
199  /// supply the value
200  ///
201  BodyOutputDescription(uint64_t body_value_index,
202  uint64_t output_index,
203  int64_t iteration);
204  BodyOutputDescription() = default;
205  std::shared_ptr<OutputDescription> copy() const override;
206  int64_t m_iteration{0};
207  };
208 
209  virtual std::shared_ptr<Function> get_function() { return m_body; };
210  virtual std::shared_ptr<const Function> get_function() const { return m_body; };
211  virtual void set_function(const std::shared_ptr<Function>& func) { m_body = func; };
212  /// \return a reference to the input descriptions.
213  const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
214  {
215  return m_input_descriptions;
216  }
217  /// \return a reference to the input descriptions. Can add input descriptions
218  /// before
219  /// validation.
220  std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions()
221  {
222  return m_input_descriptions;
223  }
224  /// \return a reference to the output descriptions.
225  const std::vector<std::shared_ptr<OutputDescription>>&
227  {
228  return m_output_descriptions;
229  }
230  /// \return a reference to the output descriptions. Can add output descriptions
231  /// before
232  /// validation.
233  std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions()
234  {
235  return m_output_descriptions;
236  }
237 
238  ///
239  /// \brief Indicate that a body parameter comes from slices of a value
240  ///
241  /// \param parameter The parameter to receive the slices
242  /// \param value The value to be sliced. This will be added as an input to
243  /// SubGraphOp.
244  /// \param start First index on axis of the slicing
245  /// \param stride Stepping of the slice
246  /// \param part_size Size of the slice on axis
247  /// \param end The last index on axis of the slicing
248  /// \param axis The axis to slice along
249  ///
250  virtual void set_sliced_input(const std::shared_ptr<Parameter>& parameter,
251  const Output<Node>& value,
252  int64_t start,
253  int64_t stride,
254  int64_t part_size,
255  int64_t end,
256  int64_t axis);
257  ///
258  /// \brief Indicates that a body parameter has an initial value in the first
259  /// iteration and computed value thereafter
260  ///
261  /// \param[in] body_parameter The body parameter
262  /// \param initial_value Value for the parameter in first iteration. This
263  /// will be added as an input to Loop.
264  /// \param successive_value Value for the parameter in successive iterations.
265  /// The value is what is active in the most recent
266  /// completed iteration.
267  ///
268  virtual void set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
269  const Output<Node>& initial_value,
270  const Output<Node>& successive_value);
271  ///
272  /// \brief Indicates that a body parameter has an invariant value during
273  /// iteration that may depend on values computed outside of the
274  /// iteration.
275  ///
276  /// \param body_parameter The body parameter
277  /// \param value The value supplied as an input to the block
278  ///
279  virtual void set_invariant_input(const std::shared_ptr<Parameter>& body_parameter,
280  const Output<Node>& value);
281  ///
282  /// \brief Gets a value for a particular iteration point
283  ///
284  /// \param body_value The value
285  /// \param iteration The iteration that supplies the value. Negative values
286  /// are from the last iteration.
287  /// Default value -1 (the last iteration).
288  ///
289  /// \return The iterator value.
290  ///
291  virtual Output<Node> get_iter_value(const Output<Node>& body_value,
292  int64_t iteration = -1);
293  ///
294  /// \brief Concatenates slices from all iterations
295  ///
296  /// \param value The value supplying slice values from each iteration.
297  /// \param start First index on axis of the slicing
298  /// \param stride Stepping of the slice
299  /// \param part_size Size of the slice on axis
300  /// \param end The last index on axis of the slicing
301  /// \param axis The axis to slice along
302  ///
303  /// \return The concatenated slices.
304  ///
306  int64_t start,
307  int64_t stride,
308  int64_t part_size,
309  int64_t end,
310  int64_t axis);
311 
312  SubGraphOp(const SubGraphOp&) = delete;
313  SubGraphOp(SubGraphOp&&) = default;
314 
315  SubGraphOp& operator=(const SubGraphOp&) = delete;
316  SubGraphOp& operator=(SubGraphOp&&) = default;
317 
318  int64_t get_num_iterations() const { return m_num_iterations; }
319 
320  protected:
321  int64_t m_num_iterations =
322  -1; // -1 means infinity for Loop op, inconsistent for TensorIterator
323 
324  // Find an input corresponding to value, adding one if necessary.
325  Input<Node> input_for_value(const Output<Node>& value);
326 
327  SubGraphOp() = default;
328 
329  explicit SubGraphOp(const OutputVector& args);
330 
331  std::shared_ptr<Function> m_body;
332  std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
333  m_input_descriptions;
334  std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
335  m_output_descriptions;
336  };
337  using InputDescriptionPtr = std::shared_ptr<util::SubGraphOp::InputDescription>;
338  using OutputDescriptionPtr = std::shared_ptr<util::SubGraphOp::OutputDescription>;
339  using InputDescriptionVector = std::vector<InputDescriptionPtr>;
340  using OutputDescriptionVector = std::vector<OutputDescriptionPtr>;
341  } // namespace util
342  } // namespace op
343 
344  template <>
345  class NGRAPH_API AttributeAdapter<
346  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
347  : public DirectValueAccessor<
348  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
349  {
350  public:
352  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>& value)
354  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>(
355  value)
356  {
357  }
358 
359  static constexpr DiscreteTypeInfo type_info{
360  "AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
361  "InputDescription>>>",
362  0};
363  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
364  };
365 
366  template <>
367  class NGRAPH_API AttributeAdapter<
368  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
369  : public DirectValueAccessor<
370  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
371  {
372  public:
374  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>& value)
376  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>(
377  value)
378  {
379  }
380 
381  static constexpr DiscreteTypeInfo type_info{
382  "AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
383  "OutputDescription>>>",
384  0};
385  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
386  };
387 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Definition: attribute_adapter.hpp:67
A handle for one of a node's inputs.
Definition: node_input.hpp:32
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Produces an output from a specific iteration.
Definition: sub_graph_base.hpp:189
BodyOutputDescription(uint64_t body_value_index, uint64_t output_index, int64_t iteration)
Constructs a new instance.
Produces an output by concatenating an output from each iteration.
Definition: sub_graph_base.hpp:155
ConcatOutputDescription(uint64_t body_value_index, uint64_t output_index, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Constructs a new instance.
Describes a connection between a SubGraphOp input and the body.
Definition: sub_graph_base.hpp:24
InputDescription(uint64_t input_index, uint64_t body_parameter_index)
Constructs a new instance.
Describes a body input initialized from a SubGraphOp input on the first iteration, and invariant thereafter.
Definition: sub_graph_base.hpp:115
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index)
Constructs a new instance.
Describes a body input initialized from a SubGraphOp input on the first iteration, and then a body output thereafter.
Definition: sub_graph_base.hpp:87
MergedInputDescription(uint64_t input_index, uint64_t body_parameter_index, uint64_t body_value_index)
Constructs a new instance.
Describes how a SubGraphOp output is produced from the body.
Definition: sub_graph_base.hpp:132
OutputDescription(uint64_t body_value_index, uint64_t output_index)
Constructs a new instance.
Describes a body input formed from slices of an input to SubGraphOp.
Definition: sub_graph_base.hpp:51
SliceInputDescription(uint64_t input_index, uint64_t body_parameter_index, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Constructs a new instance.
Abstract base class for sub-graph based ops, i.e., ops that have a sub-graph.
Definition: sub_graph_base.hpp:19
virtual void set_sliced_input(const std::shared_ptr< Parameter > &parameter, const Output< Node > &value, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Indicate that a body parameter comes from slices of a value.
virtual Output< Node > get_iter_value(const Output< Node > &body_value, int64_t iteration=-1)
Gets a value for a particular iteration point.
virtual void set_invariant_input(const std::shared_ptr< Parameter > &body_parameter, const Output< Node > &value)
Indicates that a body parameter has an invariant value during iteration that may depend on values computed outside of the iteration.
std::vector< std::shared_ptr< InputDescription > > & get_input_descriptions()
Definition: sub_graph_base.hpp:220
const std::vector< std::shared_ptr< OutputDescription > > & get_output_descriptions() const
Definition: sub_graph_base.hpp:226
virtual void set_merged_input(const std::shared_ptr< Parameter > &body_parameter, const Output< Node > &initial_value, const Output< Node > &successive_value)
Indicates that a body parameter has an initial value in the first iteration and computed value thereafter.
virtual Output< Node > get_concatenated_slices(const Output< Node > &value, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Concatenates slices from all iterations.
std::vector< std::shared_ptr< OutputDescription > > & get_output_descriptions()
Definition: sub_graph_base.hpp:233
const std::vector< std::shared_ptr< InputDescription > > & get_input_descriptions() const
Definition: sub_graph_base.hpp:213
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27