broadcast_base.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #include "ngraph/axis_set.hpp"
6 #include "ngraph/axis_vector.hpp"
7 #include "ngraph/op/op.hpp"
8 #include "ngraph/op/util/attr_types.hpp"
9 
10 #pragma once
11 
12 namespace ngraph
13 {
14  namespace op
15  {
16  namespace util
17  {
18  class NGRAPH_API BroadcastBase : public Op
19  {
20  protected:
21  BroadcastBase() = default;
22  /// \brief Constructs a broadcast operation.
23  ///
24  /// \param arg The input tensor to be broadcast.
25  /// \param target_shape The shape of the output tensor.
26  /// \param axes_mapping The axis positions (0-based) in the result that correspond
27  /// to input axes.
28  /// \param broadcast_mode Broadcast specification to use for determining broadcast
29  /// axes. 'axes_mapping' should not be provided if mode other
30  ///
32  const Output<Node>& target_shape,
33  const Output<Node>& axes_mapping,
34  const BroadcastModeSpec& broadcast_mode = BroadcastType::EXPLICIT);
35 
36  /// \brief Constructs a broadcast operation.
37  ///
38  /// \param arg The input tensor to be broadcast.
39  /// \param target_shape The shape of the output tensor.
40  /// \param broadcast_mode Broadcast specification to use for determining broadcast
41  /// axes
43  const Output<Node>& target_shape,
44  const BroadcastModeSpec& broadcast_mode = BroadcastType::NUMPY);
45 
46  public:
47  void validate_and_infer_types() override;
48  /// \return true and the AxisSet if broadcast axes can be fully determined.
49  virtual std::pair<bool, AxisSet> get_broadcast_axes() const;
50 
51  bool evaluate(const HostTensorVector& outputs,
52  const HostTensorVector& inputs) const override;
53 
54  protected:
55  BroadcastModeSpec m_mode;
56 
57  bool evaluate_broadcast(const HostTensorPtr& arg0,
58  const HostTensorPtr& out,
59  const std::pair<bool, AxisSet> pair_broadcast_axes,
60  const Shape output_shape) const;
61 
62  bool evaluate_broadcast(const HostTensorPtr& arg0,
63  const HostTensorPtr& out,
64  const AxisSet& broadcast_axes) const;
65 
66  bool evaluate_lower(const HostTensorVector& outputs) const override;
67  bool evaluate_upper(const HostTensorVector& outputs) const override;
68 
70  get_result_shape_pdpd(const PartialShape& arg0_shape,
71  const PartialShape& target_shape,
72  const op::BroadcastModeSpec& broadcast_spec) const;
73 
74  void validate_target_shape_numpy(const PartialShape& arg_shape,
75  const PartialShape& target_shape) const;
76 
77  static std::pair<bool, AxisSet>
78  get_broadcast_axes_numpy_pdpd(const Shape& arg_shape,
79  const Shape& result_shape,
80  const op::BroadcastModeSpec& broadcast_spec);
81 
82  static std::pair<bool, AxisSet>
83  get_broadcast_axes_none(const AxisVector axes_mapping_val,
84  const size_t target_shape);
85 
86  void validate_target_shape_none(const PartialShape& arg_shape,
87  const AxisVector& axes_mapping_val,
88  const PartialShape& target_shape) const;
89 
90  Shape get_target_shape(const HostTensorPtr& input1) const;
91  };
92  } // namespace util
93  } // namespace op
94 } // namespace ngraph
A set of axes.
Definition: axis_set.hpp:19
A vector of axes.
Definition: axis_vector.hpp:18
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Root of all actual ops.
Definition: op.hpp:17
Definition: broadcast_base.hpp:19
virtual std::pair< bool, AxisSet > get_broadcast_axes() const
BroadcastBase(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const BroadcastModeSpec &broadcast_mode=BroadcastType::EXPLICIT)
Constructs a broadcast operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
BroadcastBase(const Output< Node > &arg, const Output< Node > &target_shape, const BroadcastModeSpec &broadcast_mode=BroadcastType::NUMPY)
Constructs a broadcast operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Implicit broadcast specification.
Definition: attr_types.hpp:370