loop.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <vector>
20 
21 #include "ngraph/factory_adapter.hpp"
22 #include "ngraph/function.hpp"
23 #include "ngraph/op/constant.hpp"
24 #include "ngraph/op/parameter.hpp"
25 #include "ngraph/op/tensor_iterator.hpp"
26 #include "ngraph/op/util/sub_graph_base.hpp"
27 
28 namespace ngraph
29 {
30  namespace op
31  {
32  namespace v5
33  {
34  /// \brief Iterate a body over tensors, accumulating into tensors.
35  class NGRAPH_API Loop : public op::util::SubGraphOp
36  {
37  public:
38  /// \brief Allows to define the purpose of inputs/outputs in the body
40  {
41  SpecialBodyPorts() = default;
42  SpecialBodyPorts(int64_t in_current_iteration_input_idx,
43  int64_t in_body_condition_output_idx)
44  : current_iteration_input_idx(in_current_iteration_input_idx)
45  , body_condition_output_idx(in_body_condition_output_idx)
46  {
47  }
48  // -1 means the input is not provided, this input is optional
49  int64_t current_iteration_input_idx = -1;
50  // -1 means the output is not provided,
51  // this output is required, throw an exception if not provided
52  int64_t body_condition_output_idx = -1;
53  };
54 
55  NGRAPH_RTTI_DECLARATION;
56 
57  /// \brief Constructs a Loop operation.
58  Loop() = default;
59 
60  /// \brief Constructs a Loop operation.
61  ///
62  /// \param trip_count Node specifies the maximum number of iterations.
63  /// \param execution_condition Node determines whether to execute the first
64  /// iteration or not.
65  Loop(const Output<Node>& trip_count, const Output<Node>& execution_condition);
66 
68  int64_t start,
69  int64_t stride,
70  int64_t part_size,
71  int64_t end,
72  int64_t axis) override;
73 
74  void set_special_body_ports(const SpecialBodyPorts& special_body_ports)
75  {
76  m_special_body_ports = special_body_ports;
77  }
78 
79  SpecialBodyPorts get_special_body_ports() const { return m_special_body_ports; }
80  void validate_and_infer_types() override;
81  bool visit_attributes(AttributeVisitor& visitor) override;
82  std::shared_ptr<Node>
83  clone_with_new_inputs(const OutputVector& new_args) const override;
84 
85  bool evaluate(const HostTensorVector& outputs,
86  const HostTensorVector& inputs) const override;
87 
88  protected:
89  Loop(const Loop&);
90 
91  private:
92  void clone_to(Loop& dst, const OutputVector& new_args) const;
93 
94  SpecialBodyPorts m_special_body_ports;
95  };
96  }
97  }
98 
99  template <>
100  class NGRAPH_API AttributeAdapter<op::v5::Loop::SpecialBodyPorts>
101  : public DirectValueAccessor<op::v5::Loop::SpecialBodyPorts>
102  {
103  public:
106  {
107  }
108 
109  static constexpr DiscreteTypeInfo type_info{
110  "AttributeAdapter<op::v5::Loop::SpecialBodyPorts>", 0};
111  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
112  };
113 }
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Definition: attribute_adapter.hpp:79
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Abstract base class for sub-graph-based ops, i.e., ops that have a sub-graph.
Definition: sub_graph_base.hpp:31
Iterate a body over tensors, accumulating into tensors.
Definition: loop.hpp:36
Loop()=default
Constructs a Loop operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Output< Node > get_concatenated_slices(const Output< Node > &value, int64_t start, int64_t stride, int64_t part_size, int64_t end, int64_t axis) override
Concatenates slices from all iterations.
Loop(const Output< Node > &trip_count, const Output< Node > &execution_condition)
Constructs a Loop operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39
Defines the purpose of special inputs/outputs in the body.
Definition: loop.hpp:40