loop.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <vector>
8 
9 #include "ngraph/factory_adapter.hpp"
10 #include "ngraph/function.hpp"
11 #include "ngraph/op/constant.hpp"
12 #include "ngraph/op/parameter.hpp"
13 #include "ngraph/op/tensor_iterator.hpp"
14 #include "ngraph/op/util/sub_graph_base.hpp"
15 
16 namespace ngraph
17 {
18  namespace op
19  {
20  namespace v5
21  {
22  /// \brief Iterate a body over tensors, accumulating into tensors.
23  class NGRAPH_API Loop : public op::util::SubGraphOp
24  {
25  public:
26  /// \brief Allows to define the purpose of inputs/outputs in the body
28  {
29  SpecialBodyPorts() = default;
30  SpecialBodyPorts(int64_t in_current_iteration_input_idx,
31  int64_t in_body_condition_output_idx)
32  : current_iteration_input_idx(in_current_iteration_input_idx)
33  , body_condition_output_idx(in_body_condition_output_idx)
34  {
35  }
36  // -1 means the input is not provided, this input is optional
37  int64_t current_iteration_input_idx = -1;
38  // -1 means the output is not provided,
39  // this output is required, throw an exception if not provided
40  int64_t body_condition_output_idx = -1;
41  };
42 
43  NGRAPH_RTTI_DECLARATION;
44 
45  /// \brief Constructs a Loop operation.
46  Loop() = default;
47 
48  /// \brief Constructs a Loop operation.
49  ///
50  /// \param trip_count Node specifies the maximum number of iterations.
51  /// \param execution_condition Node determines whether to execute the first
52  /// iteration or not.
53  Loop(const Output<Node>& trip_count, const Output<Node>& execution_condition);
54 
56  int64_t start,
57  int64_t stride,
58  int64_t part_size,
59  int64_t end,
60  int64_t axis) override;
61 
62  void set_special_body_ports(const SpecialBodyPorts& special_body_ports)
63  {
64  m_special_body_ports = special_body_ports;
65  }
66 
67  SpecialBodyPorts get_special_body_ports() const { return m_special_body_ports; }
68  void validate_and_infer_types() override;
69  bool visit_attributes(AttributeVisitor& visitor) override;
70  std::shared_ptr<Node>
71  clone_with_new_inputs(const OutputVector& new_args) const override;
72 
73  bool evaluate(const HostTensorVector& outputs,
74  const HostTensorVector& inputs) const override;
75  bool has_evaluate() const override;
76 
77  protected:
78  Loop(const Loop&);
79 
80  private:
81  void clone_to(Loop& dst, const OutputVector& new_args) const;
82 
83  SpecialBodyPorts m_special_body_ports;
84  };
85  } // namespace v5
86  } // namespace op
87 
88  template <>
89  class NGRAPH_API AttributeAdapter<op::v5::Loop::SpecialBodyPorts>
90  : public DirectValueAccessor<op::v5::Loop::SpecialBodyPorts>
91  {
92  public:
95  {
96  }
97 
98  static constexpr DiscreteTypeInfo type_info{
99  "AttributeAdapter<op::v5::Loop::SpecialBodyPorts>", 0};
100  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
101  };
102 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Definition: attribute_adapter.hpp:67
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Abstract base class for sub-graph based ops, i.e. ops that have a sub-graph.
Definition: sub_graph_base.hpp:19
Iterate a body over tensors, accumulating into tensors.
Definition: loop.hpp:24
Loop()=default
Constructs a Loop operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on the input values, putting the results in the output values.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Output< Node > get_concatenated_slices(const Output< Node > &value, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis) override
Concatenates slices from all iterations.
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
Loop(const Output< Node > &trip_count, const Output< Node > &execution_condition)
Constructs a Loop operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27
Allows defining the purpose of inputs/outputs in the body.
Definition: loop.hpp:28