sub_graph_base.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <ngraph/op/parameter.hpp>
20 #include "ngraph/op/op.hpp"
21 
22 namespace ngraph
23 {
24  namespace op
25  {
26  namespace util
27  {
28  /// \brief Abstract base class for sub-graph based ops, i.e ops that have sub-graph
29  ///
30  class NGRAPH_API SubGraphOp : public Op
31  {
32  public:
33  NGRAPH_RTTI_DECLARATION;
34  /// \brief Describes a connection between a SubGraphOp input and the body.
36  {
37  protected:
38  ///
39  /// \brief Constructs a new instance.
40  ///
41  /// \param input_index Position of the SubGraphOp input
42  /// \param body_parameter_index Body parameter to receive input
43  ///
44  InputDescription(uint64_t input_index, uint64_t body_parameter_index);
45  InputDescription() = default;
46 
47  public:
49  virtual ~InputDescription() = default;
50  virtual std::shared_ptr<InputDescription> copy() const = 0;
51 
52  virtual const type_info_t& get_type_info() const = 0;
53 
54  uint64_t m_input_index{0};
55  uint64_t m_body_parameter_index{0};
56  };
57 
58  ///
59  /// \brief Describes a body input formed from slices of an input to
60  /// SubGraphOp.
61  ///
62  class NGRAPH_API SliceInputDescription : public InputDescription
63  {
64  public:
65  static constexpr type_info_t type_info{"SliceInputDescription", 0};
66  const type_info_t& get_type_info() const override { return type_info; }
67  ///
68  /// \brief Constructs a new instance.
69  ///
70  /// \param input_index Position of the SubGraphOp input
71  /// \param body_parameter_index Body parameter position to receive input
72  /// \param start First index for slices
73  /// \param stride Step amount for slices
74  /// \param part_size Width of slices
75  /// \param end Last index for slices
76  /// \param axis Axis being sliced
77  ///
78  SliceInputDescription(uint64_t input_index,
79  uint64_t body_parameter_index,
80  int64_t start,
81  int64_t stride,
82  int64_t part_size,
83  int64_t end,
84  int64_t axis);
85  SliceInputDescription() = default;
86  std::shared_ptr<InputDescription> copy() const override;
87  int64_t m_start{0};
88  int64_t m_stride{0};
89  int64_t m_part_size{0};
90  int64_t m_end{0};
91  int64_t m_axis{0};
92  };
93 
94  ///
95  /// \brief Describes a body input initialized from a SubGraphOp input on
96  /// the first iteration, and then a body output thereafter.
97  ///
98  class NGRAPH_API MergedInputDescription : public InputDescription
99  {
100  public:
101  static constexpr type_info_t type_info{"MergedInputDescription", 0};
102  const type_info_t& get_type_info() const override { return type_info; }
103  ///
104  /// \brief Constructs a new instance.
105  ///
106  /// \param input_index Position of the SubGraphOp input
107  /// supplying a value to body_parameter for
108  /// the initial iteration.
109  /// \param body_parameter_index Body parameter position to receive input.
110  /// \param body_value_index Body value to supply body_parameter for
111  /// successive
112  /// iterations.
113  ///
114  MergedInputDescription(uint64_t input_index,
115  uint64_t body_parameter_index,
116  uint64_t body_value_index);
117  MergedInputDescription() = default;
118  std::shared_ptr<InputDescription> copy() const override;
119  uint64_t m_body_value_index{0};
120  };
121 
122  ///
123  /// \brief Describes a body input initialized from a SubGraphOp input on
124  /// the first iteration, and invariant thereafter.
125  ///
126  class NGRAPH_API InvariantInputDescription : public InputDescription
127  {
128  public:
129  static constexpr type_info_t type_info{"InvariantInputDescription", 0};
130  const type_info_t& get_type_info() const override { return type_info; }
131  ///
132  /// \brief Constructs a new instance.
133  ///
134  /// \param input_index Position of the SubGraphOp input
135  /// \param body_parameter_index Body parameter to receive input
136  ///
137  InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
138  InvariantInputDescription() = default;
139  std::shared_ptr<InputDescription> copy() const override;
140  };
141 
142  /// \brief Describes how a SubGraphOp output is produced from the body.
144  {
145  protected:
146  ///
147  /// \brief Constructs a new instance.
148  ///
149  /// \param body_value_index A body value that produces the output
150  /// \param output_index The SubGraphOp output index
151  ///
152  OutputDescription(uint64_t body_value_index, uint64_t output_index);
153  OutputDescription() = default;
154 
155  public:
157  virtual ~OutputDescription() = default;
158  virtual std::shared_ptr<OutputDescription> copy() const = 0;
159  virtual const type_info_t& get_type_info() const = 0;
160 
161  uint64_t m_body_value_index{0};
162  uint64_t m_output_index{0};
163  };
164 
165  /// \brief Produces an output by concatenating an output from each iteration
166  class NGRAPH_API ConcatOutputDescription : public OutputDescription
167  {
168  public:
169  static constexpr type_info_t type_info{"ConcatOutputDescription", 0};
170  const type_info_t& get_type_info() const override { return type_info; }
171  ///
172  /// \brief Constructs a new instance.
173  ///
174  /// \param body_value_index A body value that produces the output
175  /// \param output_index The SubGraphOp output index
176  /// \param start First index for slices
177  /// \param stride Step amount for slices
178  /// \param part_size Width of slices
179  /// \param end Last index for slices
180  /// \param axis Axis being sliced
181  ///
182  ConcatOutputDescription(uint64_t body_value_index,
183  uint64_t output_index,
184  int64_t start,
185  int64_t stride,
186  int64_t part_size,
187  int64_t end,
188  int64_t axis);
189  ConcatOutputDescription() = default;
190 
191  std::shared_ptr<OutputDescription> copy() const override;
192  int64_t m_start{0};
193  int64_t m_stride{0};
194  int64_t m_part_size{0};
195  int64_t m_end{0};
196  int64_t m_axis{0};
197  };
198 
199  /// \brief Produces an output from a specific iteration
200  class NGRAPH_API BodyOutputDescription : public OutputDescription
201  {
202  public:
203  static constexpr type_info_t type_info{"BodyOutputDescription", 0};
204  const type_info_t& get_type_info() const override { return type_info; }
205  ///
206  /// \brief Constructs a new instance.
207  ///
208  /// \param body_value_index A body value that produces the output
209  /// \param output_index The SubGraphOp output index
210  /// \param iteration which iteration (typically -1, final) will
211  /// supply the value
212  ///
213  BodyOutputDescription(uint64_t body_value_index,
214  uint64_t output_index,
215  int64_t iteration);
216  BodyOutputDescription() = default;
217  std::shared_ptr<OutputDescription> copy() const override;
218  int64_t m_iteration{0};
219  };
220 
221  virtual std::shared_ptr<Function> get_function() { return m_body; };
222  virtual std::shared_ptr<const Function> get_function() const { return m_body; };
223  virtual void set_function(const std::shared_ptr<Function>& func) { m_body = func; };
224  /// \return a reference to the input descriptions.
225  const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const
226  {
227  return m_input_descriptions;
228  }
229  /// \return a reference to the input descriptions. Can add input descriptions
230  /// before
231  /// validation.
232  std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions()
233  {
234  return m_input_descriptions;
235  }
236  /// \return a reference to the output descriptions.
237  const std::vector<std::shared_ptr<OutputDescription>>&
239  {
240  return m_output_descriptions;
241  }
242  /// \return a reference to the output descriptions. Can add output descriptions
243  /// before
244  /// validation.
245  std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions()
246  {
247  return m_output_descriptions;
248  }
249 
250  ///
251  /// \brief Indicate that a body parameter comes from slices of a value
252  ///
253  /// \param parameter The parameter to receive the slices
254  /// \param value The value to be sliced. This will be added as an input to
255  /// SubGraphOp.
256  /// \param start First index on axis of the slicing
257  /// \param stride Stepping of the slice
258  /// \param part_size Size of the slice on axis
259  /// \param end The last index on axis of the slicing
260  /// \param axis The axis to slice along
261  ///
262  virtual void set_sliced_input(const std::shared_ptr<Parameter>& parameter,
263  const Output<Node>& value,
264  int64_t start,
265  int64_t stride,
266  int64_t part_size,
267  int64_t end,
268  int64_t axis);
269  ///
270  /// \brief Indicates that a body parameter has an initial value in the first
271  /// iteration and computed value thereafter
272  ///
273  /// \param[in] body_parameter The body parameter
274  /// \param initial_value Value for the parameter in first iteration. This
275  /// will be added as an input to Loop.
276  /// \param successive_value Value for the parameter in successive iterations.
277  /// The value is what is active in the most recent
278  /// completed iteration.
279  ///
280  virtual void set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
281  const Output<Node>& initial_value,
282  const Output<Node>& successive_value);
283  ///
284  /// \brief Indicates that a body parameter has an invariant value during
285  /// iteration that may depend on values computed outside of the
286  /// iteration.
287  ///
288  /// \param body_parameter The body parameter
289  /// \param value The value supplied as an input to the block
290  ///
291  virtual void set_invariant_input(const std::shared_ptr<Parameter>& body_parameter,
292  const Output<Node>& value);
293  ///
294  /// \brief Gets a value for a particular iteration point
295  ///
296  /// \param body_value The value
297  /// \param iteration The iteration that supplies the value. Negative values
298  /// are from the last iteration.
299  /// Default value -1 (the last iteration).
300  ///
301  /// \return The iterator value.
302  ///
303  virtual Output<Node> get_iter_value(const Output<Node>& body_value,
304  int64_t iteration = -1);
305  ///
306  /// \brief Concatenates slices from all iterations
307  ///
308  /// \param value The value supplying slice values from each iteration.
309  /// \param start First index on axis of the slicing
310  /// \param stride Stepping of the slice
311  /// \param part_size Size of the slice on axis
312  /// \param end The last index on axis of the slicing
313  /// \param axis The axis to slice along
314  ///
315  /// \return The concatenated slices.
316  ///
318  int64_t start,
319  int64_t stride,
320  int64_t part_size,
321  int64_t end,
322  int64_t axis);
323 
324  SubGraphOp(const SubGraphOp&) = delete;
325  SubGraphOp(SubGraphOp&&) = default;
326 
327  SubGraphOp& operator=(const SubGraphOp&) = delete;
328  SubGraphOp& operator=(SubGraphOp&&) = default;
329 
330  int64_t get_num_iterations() const { return m_num_iterations; }
331  protected:
332  int64_t m_num_iterations =
333  -1; // -1 means infinity for Loop op, inconsistent for TensorIterator
334 
335  // Find an input corresponding to value, adding one if necessary.
336  Input<Node> input_for_value(const Output<Node>& value);
337 
338  SubGraphOp() = default;
339 
340  explicit SubGraphOp(const OutputVector& args);
341 
342  std::shared_ptr<Function> m_body;
343  std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
344  m_input_descriptions;
345  std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
346  m_output_descriptions;
347  };
348  using InputDescriptionPtr = std::shared_ptr<util::SubGraphOp::InputDescription>;
349  using OutputDescriptionPtr = std::shared_ptr<util::SubGraphOp::OutputDescription>;
350  using InputDescriptionVector = std::vector<InputDescriptionPtr>;
351  using OutputDescriptionVector = std::vector<OutputDescriptionPtr>;
352  }
353  }
354 
355  template <>
356  class NGRAPH_API AttributeAdapter<
357  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
358  : public DirectValueAccessor<
359  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
360  {
361  public:
363  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>& value)
365  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>(
366  value)
367  {
368  }
369 
370  static constexpr DiscreteTypeInfo type_info{
371  "AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
372  "InputDescription>>>",
373  0};
374  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
375  };
376 
377  template <>
378  class NGRAPH_API AttributeAdapter<
379  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
380  : public DirectValueAccessor<
381  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
382  {
383  public:
385  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>& value)
387  std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>(
388  value)
389  {
390  }
391 
392  static constexpr DiscreteTypeInfo type_info{
393  "AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
394  "OutputDescription>>>",
395  0};
396  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
397  };
398 }
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Definition: attribute_adapter.hpp:79
A handle for one of a node's inputs.
Definition: node_input.hpp:41
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Root of all actual ops.
Definition: op.hpp:29
Produces an output from a specific iteration.
Definition: sub_graph_base.hpp:201
BodyOutputDescription(uint64_t body_value_index, uint64_t output_index, int64_t iteration)
Constructs a new instance.
Produces an output by concatenating an output from each iteration.
Definition: sub_graph_base.hpp:167
ConcatOutputDescription(uint64_t body_value_index, uint64_t output_index, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Constructs a new instance.
Describes a connection between a SubGraphOp input and the body.
Definition: sub_graph_base.hpp:36
InputDescription(uint64_t input_index, uint64_t body_parameter_index)
Constructs a new instance.
Describes a body input initialized from a SubGraphOp input on the first iteration, and invariant thereafter.
Definition: sub_graph_base.hpp:127
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index)
Constructs a new instance.
Describes a body input initialized from a SubGraphOp input on the first iteration, and then a body output thereafter.
Definition: sub_graph_base.hpp:99
MergedInputDescription(uint64_t input_index, uint64_t body_parameter_index, uint64_t body_value_index)
Constructs a new instance.
Describes how a SubGraphOp output is produced from the body.
Definition: sub_graph_base.hpp:144
OutputDescription(uint64_t body_value_index, uint64_t output_index)
Constructs a new instance.
Describes a body input formed from slices of an input to SubGraphOp.
Definition: sub_graph_base.hpp:63
SliceInputDescription(uint64_t input_index, uint64_t body_parameter_index, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Constructs a new instance.
Abstract base class for sub-graph based ops, i.e. ops that have a sub-graph.
Definition: sub_graph_base.hpp:31
virtual void set_sliced_input(const std::shared_ptr< Parameter > &parameter, const Output< Node > &value, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Indicate that a body parameter comes from slices of a value.
virtual Output< Node > get_iter_value(const Output< Node > &body_value, int64_t iteration=-1)
Gets a value for a particular iteration point.
virtual void set_invariant_input(const std::shared_ptr< Parameter > &body_parameter, const Output< Node > &value)
Indicates that a body parameter has an invariant value during iteration that may depend on values computed outside of the iteration.
std::vector< std::shared_ptr< InputDescription > > & get_input_descriptions()
Definition: sub_graph_base.hpp:232
const std::vector< std::shared_ptr< OutputDescription > > & get_output_descriptions() const
Definition: sub_graph_base.hpp:238
virtual void set_merged_input(const std::shared_ptr< Parameter > &body_parameter, const Output< Node > &initial_value, const Output< Node > &successive_value)
Indicates that a body parameter has an initial value in the first iteration and computed value thereafter.
virtual Output< Node > get_concatenated_slices(const Output< Node > &value, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis)
Concatenates slices from all iterations.
std::vector< std::shared_ptr< OutputDescription > > & get_output_descriptions()
Definition: sub_graph_base.hpp:245
const std::vector< std::shared_ptr< InputDescription > > & get_input_descriptions() const
Definition: sub_graph_base.hpp:225
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39