ngraph.opset6.gather_tree

ngraph.opset6.gather_tree(step_ids: Union[_pyngraph.Node, int, float, numpy.ndarray], parent_idx: Union[_pyngraph.Node, int, float, numpy.ndarray], max_seq_len: Union[_pyngraph.Node, int, float, numpy.ndarray], end_token: Union[_pyngraph.Node, int, float, numpy.ndarray], name: Optional[str] = None) → _pyngraph.Node

Perform the GatherTree operation.

Parameters
  • step_ids – The tensor with the indices selected at each step, of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH].

  • parent_idx – The tensor with the parent beam indices, of the same shape as step_ids.

  • max_seq_len – The tensor with the maximum length of each sequence in the batch, of shape [BATCH_SIZE].

  • end_token – The scalar tensor with the value of the end marker in a sequence.

  • name – Optional name for the output node.

Returns

The new node performing a GatherTree operation.

The GatherTree node generates the complete beams from the indices at each step and the parent beam indices, and outputs them as a tensor final_idx with the same shape as step_ids. GatherTree uses the following logic:

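# step_ids, parent_idx and final_idx have shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
final_idx[:, :, :] = end_token
for batch in range(BATCH_SIZE):
    for beam in range(BEAM_WIDTH):
        max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
        parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
        final_idx[max_sequence_in_beam - 1, batch, beam] = step_ids[max_sequence_in_beam - 1, batch, beam]

        # Walk back through the parent pointers, from the last step to the first
        for level in reversed(range(max_sequence_in_beam - 1)):
            final_idx[level, batch, beam] = step_ids[level, batch, parent]
            parent = parent_idx[level, batch, parent]

        # After the first end_token in a beam, fill the remaining steps with end_token
        finished = False
        for time in range(max_sequence_in_beam):
            if finished:
                final_idx[time, batch, beam] = end_token
            elif final_idx[time, batch, beam] == end_token:
                finished = True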
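The backtracking logic can be sketched as a small pure-Python reference, which is handy for working out expected outputs by hand. This is a sketch under stated assumptions: tensors are modeled as nested lists indexed [time][batch][beam], gather_tree_reference is a hypothetical helper name (it is not part of ngraph and does not build a node), and skipping batch entries whose max_seq_len is zero or negative is an assumption the logic above does not cover. Like the GatherTree-1 specification, it assigns the last valid step directly and pads every step after the first end_token with end_token.

```python
from typing import List

# Nested-list stand-in for a tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
Tensor3D = List[List[List[int]]]


def gather_tree_reference(step_ids: Tensor3D, parent_idx: Tensor3D,
                          max_seq_len: List[int], end_token: int) -> Tensor3D:
    """Hypothetical pure-Python sketch of GatherTree (not the ngraph kernel)."""
    max_time = len(step_ids)
    batch_size = len(step_ids[0])
    beam_width = len(step_ids[0][0])

    # Every output position starts out as end_token
    final_idx = [[[end_token] * beam_width for _ in range(batch_size)]
                 for _ in range(max_time)]

    for batch in range(batch_size):
        for beam in range(beam_width):
            seq_len = min(max_time, max_seq_len[batch])
            if seq_len <= 0:
                continue  # assumption: nothing to gather for this batch entry
            last = seq_len - 1
            parent = parent_idx[last][batch][beam]
            final_idx[last][batch][beam] = step_ids[last][batch][beam]

            # Follow the parent pointers backwards in time
            for level in reversed(range(last)):
                final_idx[level][batch][beam] = step_ids[level][batch][parent]
                parent = parent_idx[level][batch][parent]

            # Once end_token appears, overwrite all later steps with end_token
            finished = False
            for time in range(seq_len):
                if finished:
                    final_idx[time][batch][beam] = end_token
                elif final_idx[time][batch][beam] == end_token:
                    finished = True
    return final_idx


# One batch entry, two beams, three steps; end_token 9 never occurs here
steps = [[[1, 2]], [[3, 4]], [[5, 6]]]
parents = [[[0, 0]], [[1, 0]], [[1, 0]]]
print(gather_tree_reference(steps, parents, [3], 9))
# prints [[[1, 2]], [[4, 3]], [[5, 6]]]
```

Reading the result per beam, beam 0 is reconstructed as 1, 4, 5: at the last step it came from beam 1, which in turn came from beam 0. Nested lists keep the sketch dependency-free; the same loops would work unchanged on a NumPy array of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] if the [t][b][w] indexing were replaced with [t, b, w].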