broadcast_base.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #include "ngraph/axis_set.hpp"
18 #include "ngraph/axis_vector.hpp"
19 #include "ngraph/op/op.hpp"
20 #include "ngraph/op/util/attr_types.hpp"
21 
22 #pragma once
23 
24 namespace ngraph
25 {
26  namespace op
27  {
28  namespace util
29  {
  /// \brief Shared base for broadcast-style ops. All constructors are
  ///        protected, so it is only instantiated through derived classes.
  // NOTE(review): this text is a Doxygen source-browser rendering. The leading
  // number on each code line is the original header's line number, and header
  // lines 43, 54 and 80 were dropped by the rendering (see the notes below).
30  class NGRAPH_API BroadcastBase : public Op
31  {
32  protected:
  /// \brief Default constructor, available to derived classes only.
33  BroadcastBase() = default;
34  /// \brief Constructs a broadcast operation.
35  ///
36  /// \param arg The input tensor to be broadcast.
37  /// \param target_shape The shape of the output tensor.
38  /// \param axes_mapping The axis positions (0-based) in the result that correspond
39  /// to input axes.
40  /// \param broadcast_mode Broadcast specification to use for determining broadcast
41  /// axes (default: BroadcastType::EXPLICIT). 'axes_mapping' should not be
  /// provided for modes other than EXPLICIT; the overload without
  /// 'axes_mapping' exists for those modes (TODO confirm against the implementation).
42  ///
  // NOTE(review): header line 43, which opens this declaration, is missing from
  // this rendering. Per the member summary, the full signature is:
  //   BroadcastBase(const Output<Node>& arg,
  //                 const Output<Node>& target_shape,
  //                 const Output<Node>& axes_mapping,
  //                 const BroadcastModeSpec& broadcast_mode = BroadcastType::EXPLICIT);
44  const Output<Node>& target_shape,
45  const Output<Node>& axes_mapping,
46  const BroadcastModeSpec& broadcast_mode = BroadcastType::EXPLICIT);
47 
48  /// \brief Constructs a broadcast operation.
49  ///
50  /// \param arg The input tensor to be broadcast.
51  /// \param target_shape The shape of the output tensor.
52  /// \param broadcast_mode Broadcast specification to use for determining broadcast
53  /// axes (default: BroadcastType::NUMPY).
  // NOTE(review): header line 54, which opens this declaration, is missing from
  // this rendering. Per the member summary, the full signature is:
  //   BroadcastBase(const Output<Node>& arg,
  //                 const Output<Node>& target_shape,
  //                 const BroadcastModeSpec& broadcast_mode = BroadcastType::NUMPY);
55  const Output<Node>& target_shape,
56  const BroadcastModeSpec& broadcast_mode = BroadcastType::NUMPY);
57 
58  public:
  /// \brief Verifies that attributes and inputs are consistent and computes
  ///        output shapes and element types.
59  void validate_and_infer_types() override;
  /// \brief Reports which output axes are produced by broadcasting.
60  /// \return true and the AxisSet if broadcast axes can be fully determined.
61  virtual std::pair<bool, AxisSet> get_broadcast_axes() const;
62 
  /// \brief Evaluates the op on the input tensors, writing results to outputs.
63  bool evaluate(const HostTensorVector& outputs,
64  const HostTensorVector& inputs) const override;
65 
66  protected:
  /// Broadcast specification for this op (presumably the broadcast_mode passed
  /// to the constructor — confirm in the implementation).
67  BroadcastModeSpec m_mode;
68 
  /// \brief Broadcasts arg0 into out, given precomputed broadcast axes and the
  ///        output shape (inferred from parameter names — confirm).
  // NOTE(review): pair_broadcast_axes and output_shape are taken by const
  // value, so every call copies an AxisSet and a Shape. Passing them by
  // const& would avoid that, but the out-of-line definition must change too.
69  bool evaluate_broadcast(const HostTensorPtr& arg0,
70  const HostTensorPtr& out,
71  const std::pair<bool, AxisSet> pair_broadcast_axes,
72  const Shape output_shape) const;
73 
  /// \brief Evaluation helper that broadcasts arg0 into out along
  ///        broadcast_axes (inferred from parameter names — confirm).
74  bool evaluate(const HostTensorPtr& arg0,
75  const HostTensorPtr& out,
76  const AxisSet& broadcast_axes) const;
  /// Overrides of the base-class bound-evaluation hooks (presumably lower and
  /// upper value bounds — confirm in the base class).
77  bool evaluate_lower(const HostTensorVector& outputs) const override;
78  bool evaluate_upper(const HostTensorVector& outputs) const override;
79 
  /// \brief Computes the result shape for PDPD-mode broadcasting ("PDPD"
  ///        presumably denotes PaddlePaddle-style semantics — confirm).
  // NOTE(review): header line 80 (this declaration's return type) is missing
  // from this rendering; presumably PartialShape — confirm in the real header.
81  get_result_shape_pdpd(const PartialShape& arg0_shape,
82  const PartialShape& target_shape,
83  const op::BroadcastModeSpec& broadcast_spec) const;
84 
  /// \brief Validates target_shape against arg_shape for NUMPY-mode
  ///        broadcasting (inferred from the name — confirm failure behavior).
85  void validate_target_shape_numpy(const PartialShape& arg_shape,
86  const PartialShape& target_shape) const;
87 
  /// \brief Computes broadcast axes for NUMPY/PDPD modes from static shapes.
  /// \return A flag plus the broadcast axes, presumably with the same meaning
  ///         as get_broadcast_axes() — confirm.
88  static std::pair<bool, AxisSet>
89  get_broadcast_axes_numpy_pdpd(const Shape& arg_shape,
90  const Shape& result_shape,
91  const op::BroadcastModeSpec& broadcast_spec);
92 
  /// \brief Computes broadcast axes from an explicit axes mapping (the "_none"
  ///        suffix presumably refers to the NONE/EXPLICIT mode — confirm).
  // NOTE(review): 'target_shape' is a size_t here, presumably the target rank
  // rather than a shape. axes_mapping_val is a by-value AxisVector, so it is
  // copied on every call.
93  static std::pair<bool, AxisSet>
94  get_broadcast_axes_none(const AxisVector axes_mapping_val,
95  const size_t target_shape);
96 
  /// \brief Validates target_shape against arg_shape and the explicit axes
  ///        mapping (inferred from the name — confirm failure behavior).
97  void validate_target_shape_none(const PartialShape& arg_shape,
98  const AxisVector& axes_mapping_val,
99  const PartialShape& target_shape) const;
100 
  /// \brief Reads the target shape from the input1 host tensor (presumably the
  ///        op's second input — confirm).
101  Shape get_target_shape(const HostTensorPtr& input1) const;
102  };
103  }
104  }
105 }
A set of axes.
Definition: axis_set.hpp:31
A vector of axes.
Definition: axis_vector.hpp:30
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Root of all actual ops.
Definition: op.hpp:29
Definition: broadcast_base.hpp:31
virtual std::pair< bool, AxisSet > get_broadcast_axes() const
BroadcastBase(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const BroadcastModeSpec &broadcast_mode=BroadcastType::EXPLICIT)
Constructs a broadcast operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent, and computes output shapes and element types. (Summary truncated.)
BroadcastBase(const Output< Node > &arg, const Output< Node > &target_shape, const BroadcastModeSpec &broadcast_mode=BroadcastType::NUMPY)
Constructs a broadcast operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Implicit broadcast specification.
Definition: attr_types.hpp:379