tensor_iterator.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <vector>
20 
21 #include "ngraph/function.hpp"
22 #include "ngraph/op/parameter.hpp"
23 #include "ngraph/op/util/sub_graph_base.hpp"
24 
25 namespace ngraph
26 {
27  namespace op
28  {
29  namespace v0
30  {
31  /// \brief Iterate a body over tensors, accumulating into tensors. Derives from op::util::SubGraphOp.
32  class NGRAPH_API TensorIterator : public op::util::SubGraphOp
33  {
34  public:
35  static constexpr NodeTypeInfo type_info{"TensorIterator", 0}; ///< Class-wide type descriptor; this is the object get_type_info() returns.
36  const NodeTypeInfo& get_type_info() const override { return type_info; } ///< \return a reference to the static type_info of this class.
37  bool visit_attributes(AttributeVisitor& visitor) override; ///< Takes the AttributeVisitor to apply; declared only, the definition is not in this header.
38 
39  TensorIterator() = default; ///< Compiler-generated default constructor.
40  explicit TensorIterator(const OutputVector& values); ///< NOTE(review): presumably `values` become the op's inputs; definition not visible here -- confirm.
41 
42  std::shared_ptr<Node>
43  clone_with_new_inputs(const OutputVector& new_args) const override; ///< \return a Node built from `new_args`; cloning details are defined outside this header.
44  /// \return the body of the iteration (a copy of the m_body shared pointer)
45  std::shared_ptr<Function> get_body() const { return m_body; }
46  /// \param body the function to use as the body of the iteration; stored in m_body
47  void set_body(const std::shared_ptr<Function>& body) { m_body = body; }
48  void validate_and_infer_types() override; ///< Verifies that attributes and inputs are consistent and computes output shapes and element types.
49  void revalidate_and_infer_types_for_body_ops(); ///< NOTE(review): judging by the name, re-runs validation/type inference on the ops of the body; definition not visible -- confirm.
50  /// \return the body of the iteration
51  std::shared_ptr<Function> get_function() override;
52 
53  private:
54  void try_to_set_num_iterations_if_no_slice_inputs(); ///< NOTE(review): judging by the name, sets the iteration count when no slice inputs exist; definition not visible -- confirm.
55  };
56  }
57  using v0::TensorIterator; // Makes v0::TensorIterator reachable as op::TensorIterator.
58  }
59 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Abstract base class for sub-graph based ops, i.e., ops that have a sub-graph.
Definition: sub_graph_base.hpp:31
Iterate a body over tensors, accumulating into tensors.
Definition: tensor_iterator.hpp:33
void set_body(const std::shared_ptr< Function > &body)
Definition: tensor_iterator.hpp:47
const NodeTypeInfo & get_type_info() const override
Definition: tensor_iterator.hpp:36
std::shared_ptr< Function > get_body() const
Definition: tensor_iterator.hpp:45
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::shared_ptr< Function > get_function() override
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39