broadcast.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/axis_set.hpp"
8 #include "ngraph/op/op.hpp"
9 #include "ngraph/op/util/attr_types.hpp"
10 #include "ngraph/op/util/broadcast_base.hpp"
11 
12 namespace ngraph
13 {
14  namespace op
15  {
16  namespace v3
17  {
18  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
19  /// input as needed along the new axes.
20  class NGRAPH_API Broadcast : public util::BroadcastBase
21  {
22  public:
23  static constexpr NodeTypeInfo type_info{"Broadcast", 3};
24  const NodeTypeInfo& get_type_info() const override { return type_info; }
25  /// \brief Constructs a broadcast operation.
26  Broadcast() = default;
27  /// \brief Constructs a broadcast operation.
28  ///
29  /// \param arg The input tensor to be broadcast.
30  /// \param target_shape The shape of the output tensor.
31  /// \param axes_mapping The axis positions (0-based) in the result that correspond
32  /// to input axes. 'Arg' tensor is broadcast along the
33  /// remaining axes.
34  /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
35  /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
36  /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
37  /// \param broadcast_spec Broadcast specification to use for determining broadcast
38  /// axes. 'axes_mapping' should not be provided if mode other
39  /// than explicit (none) is used.
40  Broadcast(const Output<Node>& arg,
41  const Output<Node>& target_shape,
42  const Output<Node>& axes_mapping,
43  const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
44 
45  /// \brief Constructs a broadcast operation.
46  ///
47  /// \param arg The input tensor to be broadcast.
48  /// \param target_shape The shape of the output tensor.
49  /// \param broadcast_spec Broadcast specification to use for determining broadcast
50  /// axes
51  Broadcast(const Output<Node>& arg,
52  const Output<Node>& target_shape,
53  const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
54 
55  bool visit_attributes(AttributeVisitor& visitor) override;
56 
57  std::shared_ptr<Node>
58  clone_with_new_inputs(const OutputVector& new_args) const override;
59 
60  // \return Broadcast Specification.
61  const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
62  void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
63  {
64  m_mode = broadcast_spec;
65  }
66 
67  void validate_and_infer_types() override;
68 
69  /// \return true and the AxisSet if broadcast axes can be fully determined.
70  std::pair<bool, AxisSet> get_broadcast_axes() const override;
71  bool evaluate(const HostTensorVector& outputs,
72  const HostTensorVector& inputs) const override;
73  bool has_evaluate() const override;
74 
75  private:
76  bool broadcast_evaluate(const HostTensorVector& outputs,
77  const HostTensorVector& inputs) const;
78  };
79  } // namespace v3
80 
81  namespace v1
82  {
83  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
84  /// input as needed along the new axes.
85  class NGRAPH_API Broadcast : public util::BroadcastBase
86  {
87  public:
88  static constexpr NodeTypeInfo type_info{"Broadcast", 1};
89  const NodeTypeInfo& get_type_info() const override { return type_info; }
90  /// \brief Constructs a broadcast operation.
91  Broadcast() = default;
92  /// \brief Constructs a broadcast operation.
93  ///
94  /// \param arg The input tensor to be broadcast.
95  /// \param target_shape The shape of the output tensor.
96  /// \param axes_mapping The axis positions (0-based) in the result that correspond
97  /// to input axes. 'Arg' tensor is broadcast along the
98  /// remaining axes.
99  /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
100  /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
101  /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
102  /// \param broadcast_spec Broadcast specification to use for determining broadcast
103  /// axes. 'axes_mapping' is ignored if broadcast_spec is not
104  /// NONE
106  const Output<Node>& target_shape,
107  const Output<Node>& axes_mapping,
108  const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
109 
110  /// \brief Constructs a broadcast operation.
111  ///
112  /// \param arg The input tensor to be broadcast.
113  /// \param target_shape The shape of the output tensor.
114  /// \param broadcast_spec Broadcast specification to use for determining broadcast
115  /// axes
117  const Output<Node>& target_shape,
118  const AutoBroadcastSpec& broadcast_spec =
119  AutoBroadcastSpec(AutoBroadcastType::NUMPY));
120 
121  bool visit_attributes(AttributeVisitor& visitor) override;
122 
123  std::shared_ptr<Node>
124  clone_with_new_inputs(const OutputVector& new_args) const override;
125 
126  /// \return Broadcast Specification.
127  const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
128  void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
129  {
130  m_broadcast_spec = broadcast_spec;
131  }
132 
133  void validate_and_infer_types() override;
134  bool evaluate(const HostTensorVector& outputs,
135  const HostTensorVector& inputs) const override;
136  bool has_evaluate() const override;
137 
138  protected:
139  AutoBroadcastSpec m_broadcast_spec;
140  };
141  } // namespace v1
142  } // namespace op
143 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Definition: broadcast_base.hpp:19
Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
Definition: broadcast.hpp:86
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
Broadcast()=default
Constructs a broadcast operation.
const AutoBroadcastSpec & get_broadcast_spec() const
Definition: broadcast.hpp:127
const NodeTypeInfo & get_type_info() const override
Definition: broadcast.hpp:89
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const AutoBroadcastSpec &broadcast_spec=AutoBroadcastSpec(AutoBroadcastType::NUMPY))
Constructs a broadcast operation.
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const AutoBroadcastSpec &broadcast_spec=AutoBroadcastSpec())
Constructs a broadcast operation.
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
Definition: broadcast.hpp:21
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const BroadcastModeSpec &broadcast_spec=BroadcastType::NUMPY)
Constructs a broadcast operation.
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const BroadcastModeSpec &broadcast_spec=BroadcastType::EXPLICIT)
Constructs a broadcast operation.
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
Broadcast()=default
Constructs a broadcast operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::pair< bool, AxisSet > get_broadcast_axes() const override
const NodeTypeInfo & get_type_info() const override
Definition: broadcast.hpp:24
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27
Implicit broadcast specification.
Definition: attr_types.hpp:311
Implicit broadcast specification.
Definition: attr_types.hpp:370