tensor_iterator.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <vector>
8 
9 #include "ngraph/function.hpp"
10 #include "ngraph/op/parameter.hpp"
11 #include "ngraph/op/util/sub_graph_base.hpp"
12 
13 namespace ngraph
14 {
15  namespace op
16  {
17  namespace v0
18  {
19  /// \brief Iterate a body over tensors, accumulating into tensors.
20  class NGRAPH_API TensorIterator : public op::util::SubGraphOp
21  {
22  public:
23  static constexpr NodeTypeInfo type_info{"TensorIterator", 0};
24  const NodeTypeInfo& get_type_info() const override { return type_info; }
25  bool visit_attributes(AttributeVisitor& visitor) override;
26 
27  TensorIterator() = default;
28  explicit TensorIterator(const OutputVector& values);
29 
30  std::shared_ptr<Node>
31  clone_with_new_inputs(const OutputVector& new_args) const override;
32  /// \return the body of the iteration
33  std::shared_ptr<Function> get_body() const { return m_body; }
34  /// \param body set the body of the iteration
35  void set_body(const std::shared_ptr<Function>& body) { m_body = body; }
36  void validate_and_infer_types() override;
37  void revalidate_and_infer_types_for_body_ops();
38  /// \return the body of the iteration
39  std::shared_ptr<Function> get_function() override;
40 
41  private:
42  void try_to_set_num_iterations_if_no_slice_inputs();
43  };
44  } // namespace v0
45  using v0::TensorIterator;
46  } // namespace op
47 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Abstract base class for sub-graph-based ops, i.e. ops that have a sub-graph.
Definition: sub_graph_base.hpp:19
Iterate a body over tensors, accumulating into tensors.
Definition: tensor_iterator.hpp:21
void set_body(const std::shared_ptr< Function > &body)
Definition: tensor_iterator.hpp:35
const NodeTypeInfo & get_type_info() const override
Definition: tensor_iterator.hpp:24
std::shared_ptr< Function > get_body() const
Definition: tensor_iterator.hpp:33
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::shared_ptr< Function > get_function() override
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27