VariadicSplit

Versioned name: VariadicSplit-1

Category: Data movement

Short description: VariadicSplit operation splits an input tensor into chunks along a given axis. The chunks may have different lengths, as specified by the split_lengths input tensor.

Detailed Description

VariadicSplit operation splits the input tensor data into chunks along the dimension given by axis, which is a scalar or a tensor with shape [1]. It produces one output tensor per element of the additional input tensor split_lengths. The shape of the i-th output tensor is equal to the shape of data, except for the dimension along axis, which is split_lengths[i].

\[shape\_output\_tensor = [data.shape[0], data.shape[1], \dotsc , split\_lengths[i], \dotsc , data.shape[D-1]]\]

Where D is the rank of input tensor data. The sum of elements in split_lengths must match data.shape[axis].
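
The shape rule above can be sketched in plain Python. This is only an illustration and not part of the specification: the helper name output_shapes is hypothetical, and the sketch assumes a non-negative axis and a split_lengths list without a -1 element.

```python
def output_shapes(data_shape, axis, split_lengths):
    """Return the shape of every output tensor (illustrative helper)."""
    # The lengths must cover the split dimension exactly.
    if sum(split_lengths) != data_shape[axis]:
        raise ValueError("sum of split_lengths must match data.shape[axis]")
    shapes = []
    for length in split_lengths:
        shape = list(data_shape)
        shape[axis] = length  # only the dimension along axis changes
        shapes.append(shape)
    return shapes


shapes = output_shapes([6, 12, 10, 24], axis=0, split_lengths=[1, 2, 3])
```

Here shapes holds one entry per element of split_lengths, each differing from the data shape only in dimension 0.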

Attributes: VariadicSplit operation has no attributes.

Inputs

  • 1: data. A tensor of type T1 and arbitrary shape. Required.

  • 2: axis. The axis of data along which to split. A scalar or tensor with shape [1] of type T2, with a value in the range [-rank(data), rank(data) - 1]. Negative values address dimensions from the end. Required.

  • 3: split_lengths. The sizes of the output tensors along the split axis. A 1D tensor of type T2. The number of elements in split_lengths determines the number of outputs. The sum of elements in split_lengths must match data.shape[axis]. As an exception, split_lengths can contain a single -1 element, which stands for all remaining items along the specified axis that are not consumed by the other elements. Required.
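
The handling of a negative axis and of a -1 element in split_lengths can be sketched as follows. The helper names normalize_axis and resolve_split_lengths are hypothetical and used only for illustration. The sketch also assumes that a resolved length must not be negative, which the text above does not state explicitly.

```python
def normalize_axis(axis, rank):
    """Map an axis in [-rank, rank - 1] to its non-negative equivalent."""
    if not -rank <= axis <= rank - 1:
        raise ValueError("axis is out of range for the given rank")
    return axis + rank if axis < 0 else axis


def resolve_split_lengths(split_lengths, dim):
    """Replace a single -1 with the part of dim not consumed by other entries."""
    if split_lengths.count(-1) > 1:
        raise ValueError("split_lengths may contain at most one -1")
    consumed = sum(n for n in split_lengths if n != -1)
    resolved = [dim - consumed if n == -1 else n for n in split_lengths]
    # Assumption: lengths must be non-negative and cover dim exactly.
    if any(n < 0 for n in resolved) or sum(resolved) != dim:
        raise ValueError("split_lengths do not match data.shape[axis]")
    return resolved


axis = normalize_axis(-4, rank=4)                 # last-but-three dimension of a 4D tensor
lengths = resolve_split_lengths([-1, 2], dim=6)   # -1 takes what is left after 2
```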

Outputs

  • Multiple outputs: Tensors of type T1. The i-th output has the same shape as the data input tensor except for the dimension along axis, which is split_lengths[i] if split_lengths[i] != -1. Otherwise, that dimension is the part of data.shape[axis] not consumed by the other elements, as described for the split_lengths input.
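
As a rough illustration of how the outputs relate to data, the following sketch splits a tensor represented as nested Python lists. It is not the reference implementation: the function variadic_split below is a hypothetical stand-in, and it assumes a non-negative axis and split_lengths with any -1 already replaced by its actual value.

```python
def variadic_split(data, axis, split_lengths):
    """Split nested lists into consecutive chunks along axis."""
    if axis == 0:
        chunks, start = [], 0
        for length in split_lengths:
            chunks.append(data[start:start + length])
            start += length
        return chunks
    # Split every sub-list along axis - 1, then regroup the pieces by chunk index.
    parts = [variadic_split(sub, axis - 1, split_lengths) for sub in data]
    return [[p[i] for p in parts] for i in range(len(split_lengths))]


data = [[0, 1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10, 11]]  # shape [2, 6]
outputs = variadic_split(data, axis=1, split_lengths=[1, 2, 3])
```

The three outputs have shapes [2, 1], [2, 2] and [2, 3], and concatenating them along axis 1 gives back data.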

Types

  • T1: any supported type.

  • T2: any integer type.

Examples

 <layer id="1" type="VariadicSplit" ...>
     <input>
         <port id="0">            <!-- some data -->
             <dim>6</dim>
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
         <port id="1">            <!-- axis: 0 -->
         </port>
         <port id="2">
             <dim>3</dim>         <!-- split_lengths: [1, 2, 3] -->
         </port>
     </input>
     <output>
         <port id="3">
             <dim>1</dim>
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
         <port id="4">
             <dim>2</dim>
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
         <port id="5">
             <dim>3</dim>
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
     </output>
 </layer>

 <layer id="1" type="VariadicSplit" ...>
     <input>
         <port id="0">            <!-- some data -->
             <dim>6</dim>
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
         <port id="1">            <!-- axis: 0 -->
         </port>
         <port id="2">
             <dim>2</dim>         <!-- split_lengths: [-1, 2] -->
         </port>
     </input>
     <output>
         <port id="3">
             <dim>4</dim>         <!--  4 = 6 - 2  -->
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
         <port id="4">
             <dim>2</dim>
             <dim>12</dim>
             <dim>10</dim>
             <dim>24</dim>
         </port>
     </output>
 </layer>