variadic_split.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/coordinate.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/strides.hpp"
22 
23 namespace ngraph
24 {
25  namespace op
26  {
27  namespace v1
28  {
29  /// \brief VariadicSplit operation splits an input tensor into pieces along some axis.
30  /// The pieces may have variadic lengths depending on "split_lengths" attribute.
31  class NGRAPH_API VariadicSplit : public Op
32  {
33  public:
34  NGRAPH_RTTI_DECLARATION;
35 
36  /// \brief Constructs a variadic split operation.
37  VariadicSplit() = default;
38  /// \brief Constructs a variadic split operation.
39  ///
40  /// \param data The tensor to be split.
41  /// \param axis The index of an axis in "data" along which to perform the
42  /// split.
43  /// \param split_lengths A list containing the sizes of each output tensor
44  /// along the split "axis". Size of "split_lengths" should be equal to the number of
45  ///
46  /// outputs. The sum of split_lengths must match data.shape[axis]
48  const Output<Node>& axis,
49  const Output<Node>& split_lengths);
50 
51  bool visit_attributes(AttributeVisitor& visitor) override;
52 
53  void validate_and_infer_types() override;
54  virtual std::shared_ptr<Node>
55  clone_with_new_inputs(const OutputVector& new_args) const override;
56  size_t get_default_output_index() const override { return no_default_index(); }
57  bool evaluate(const HostTensorVector& outputs,
58  const HostTensorVector& inputs) const override;
59 
60  private:
61  bool evaluate_variadic_split(const HostTensorVector& outputs,
62  const HostTensorVector& inputs) const;
63  };
64  } // namespace v1
65 
66  using v1::VariadicSplit;
67  } // namespace op
68 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Root of all actual ops.
Definition: op.hpp:29
VariadicSplit operation splits an input tensor into pieces along some axis. The pieces may have variadic lengths depending on the "split_lengths" input.
Definition: variadic_split.hpp:32
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
VariadicSplit()=default
Constructs a variadic split operation.
size_t get_default_output_index() const override
Returns the index of the default output, or throws if there is none.
Definition: variadic_split.hpp:56
VariadicSplit(const Output< Node > &data, const Output< Node > &axis, const Output< Node > &split_lengths)
Constructs a variadic split operation.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28