ngraph.utils.broadcasting.get_broadcast_axes

ngraph.utils.broadcasting.get_broadcast_axes(output_shape: List[int], input_shape: List[int], axis: int = None) → _pyngraph.AxisSet

Generate a list of broadcast axes for ngraph++ broadcast.

Informally, a broadcast “adds” axes to the input tensor, replicating elements from the input tensor as needed to fill the new dimensions. This function calculates which of the output axes are added in this way.

Parameters
  • output_shape – The new shape for the output tensor.

  • input_shape – The shape of the input tensor.

  • axis – The axis along which we want to replicate elements.

Returns

The indices of added axes.
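
Example

The following is a minimal pure-Python sketch of the idea described above, not the library's implementation. The helper name sketch_broadcast_axes is hypothetical, and it returns a plain Python set rather than an AxisSet. It assumes that when axis is None the input axes line up with the trailing axes of the output shape, and that otherwise axis is the output position where the input axes start; both points are assumptions to check against the ngraph source.

```python
from typing import List, Optional, Set


def sketch_broadcast_axes(
    output_shape: List[int],
    input_shape: List[int],
    axis: Optional[int] = None,
) -> Set[int]:
    """Hypothetical stand-in: return the output axes that a broadcast adds."""
    # Assumption: with no axis given, the input axes are aligned with the
    # trailing axes of the output shape.
    if axis is None:
        start = len(output_shape) - len(input_shape)
    else:
        # Assumption: axis marks where the input axes begin in the output.
        start = axis

    # Output axes that are already covered by the input tensor.
    covered = set(range(start, start + len(input_shape)))

    # Every remaining output axis is one that the broadcast adds.
    return {i for i in range(len(output_shape)) if i not in covered}


# Input of shape [4] broadcast to [2, 3, 4]: axes 0 and 1 are added.
print(sorted(sketch_broadcast_axes([2, 3, 4], [4])))          # [0, 1]
# Input of shape [3] placed at output axis 1: axes 0 and 2 are added.
print(sorted(sketch_broadcast_axes([2, 3, 4], [3], axis=1)))  # [0, 2]
```

Under these assumptions, an input whose shape already equals the output shape yields an empty set, since no axes need to be added.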