broadcast.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22 #include "ngraph/op/util/broadcast_base.hpp"
23 
24 namespace ngraph
25 {
26  namespace op
27  {
28  namespace v3
29  {
30  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
31  /// input as needed along the new axes.
32  class NGRAPH_API Broadcast : public util::BroadcastBase
33  {
34  public:
35  static constexpr NodeTypeInfo type_info{"Broadcast", 3};
36  const NodeTypeInfo& get_type_info() const override { return type_info; }
37  /// \brief Constructs a broadcast operation.
38  Broadcast() = default;
39  /// \brief Constructs a broadcast operation.
40  ///
41  /// \param arg The input tensor to be broadcast.
42  /// \param target_shape The shape of the output tensor.
43  /// \param axes_mapping The axis positions (0-based) in the result that correspond
44  /// to input axes. 'Arg' tensor is broadcast along the
45  /// remaining axes.
46  /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
47  /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
48  /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
49  /// \param broadcast_spec Broadcast specification to use for determining broadcast
50  /// axes. 'axes_mapping' should not be provided if mode other
51  /// than explicit (none) is used.
52  Broadcast(const Output<Node>& arg,
53  const Output<Node>& target_shape,
54  const Output<Node>& axes_mapping,
55  const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
56 
57  /// \brief Constructs a broadcast operation.
58  ///
59  /// \param arg The input tensor to be broadcast.
60  /// \param target_shape The shape of the output tensor.
61  /// \param broadcast_spec Broadcast specification to use for determining broadcast
62  /// axes
63  Broadcast(const Output<Node>& arg,
64  const Output<Node>& target_shape,
65  const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
66 
67  bool visit_attributes(AttributeVisitor& visitor) override;
68 
69  std::shared_ptr<Node>
70  clone_with_new_inputs(const OutputVector& new_args) const override;
71 
72  // \return Broadcast Specification.
73  const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
74  void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
75  {
76  m_mode = broadcast_spec;
77  }
78 
79  void validate_and_infer_types() override;
80 
81  /// \return true and the AxisSet if broadcast axes can be fully determined.
82  std::pair<bool, AxisSet> get_broadcast_axes() const override;
83  bool evaluate(const HostTensorVector& outputs,
84  const HostTensorVector& inputs) const override;
85  };
86  } // namespace v3
87 
88  namespace v1
89  {
90  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
91  /// input as needed along the new axes.
92  class NGRAPH_API Broadcast : public util::BroadcastBase
93  {
94  public:
95  static constexpr NodeTypeInfo type_info{"Broadcast", 1};
96  const NodeTypeInfo& get_type_info() const override { return type_info; }
97  /// \brief Constructs a broadcast operation.
98  Broadcast() = default;
99  /// \brief Constructs a broadcast operation.
100  ///
101  /// \param arg The input tensor to be broadcast.
102  /// \param target_shape The shape of the output tensor.
103  /// \param axes_mapping The axis positions (0-based) in the result that correspond
104  /// to input axes. 'Arg' tensor is broadcast along the
105  /// remaining axes.
106  /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
107  /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
108  /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
109  /// \param broadcast_spec Broadcast specification to use for determining broadcast
110  /// axes. 'axes_mapping' is ignored if broadcast_spec is not
111  /// NONE
112  Broadcast(const Output<Node>& arg,
113  const Output<Node>& target_shape,
114  const Output<Node>& axes_mapping,
115  const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
116 
117  /// \brief Constructs a broadcast operation.
118  ///
119  /// \param arg The input tensor to be broadcast.
120  /// \param target_shape The shape of the output tensor.
121  /// \param broadcast_spec Broadcast specification to use for determining broadcast
122  /// axes
123  Broadcast(const Output<Node>& arg,
124  const Output<Node>& target_shape,
125  const AutoBroadcastSpec& broadcast_spec =
126  AutoBroadcastSpec(AutoBroadcastType::NUMPY));
127 
128  bool visit_attributes(AttributeVisitor& visitor) override;
129 
130  std::shared_ptr<Node>
131  clone_with_new_inputs(const OutputVector& new_args) const override;
132 
133  /// \return Broadcast Specification.
134  const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
135  void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
136  {
137  m_broadcast_spec = broadcast_spec;
138  }
139 
140  void validate_and_infer_types() override;
141  bool evaluate(const HostTensorVector& outputs,
142  const HostTensorVector& inputs) const override;
143 
144  protected:
145  AutoBroadcastSpec m_broadcast_spec;
146  };
147  } // namespace v1
148 
149  namespace v0
150  {
151  NGRAPH_SUPPRESS_DEPRECATED_START
152  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
153  /// input as needed along the new axes.
154  class NGRAPH_DEPRECATED(
155  "This operation is deprecated and will be removed soon. "
156  "Use v1::Broadcast instead of it.") NGRAPH_API Broadcast : public Op
157  {
158  public:
159  static constexpr NodeTypeInfo type_info{"Broadcast", 0};
160  const NodeTypeInfo& get_type_info() const override { return type_info; }
161  /// \brief Constructs a broadcast operation.
162  Broadcast() = default;
163  /// \brief Constructs a broadcast operation.
164  ///
165  /// \param arg The input tensor to be broadcast.
166  /// \param shape The shape of the output tensor.
167  /// \param broadcast_axes The axis positions (0-based) in the result that are being
168  /// broadcast. The remaining axes in shape must be the same as
169  /// the shape of arg.
170  Broadcast(const Output<Node>& arg,
171  const Shape& shape,
172  const AxisSet& broadcast_axes);
173  bool visit_attributes(AttributeVisitor& visitor) override;
174  void validate_and_infer_types() override;
175 
176  std::shared_ptr<Node>
177  clone_with_new_inputs(const OutputVector& new_args) const override;
178 
179  /// \return A set containing the indices of the broadcast axes (0-based).
180  const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
181  void set_broadcast_axes(const AxisSet& broadcast_axes)
182  {
183  m_broadcast_axes = broadcast_axes;
184  }
185  const Shape& get_broadcast_shape() const { return m_shape; }
186  void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
187  bool evaluate(const HostTensorVector& outputs,
188  const HostTensorVector& inputs) const override;
189 
190  protected:
191  Broadcast(const OutputVector& args,
192  const Shape& shape,
193  const AxisSet& broadcast_axes);
194 
195  virtual void infer_shape() {}
196  Shape m_shape;
197  AxisSet m_broadcast_axes;
198  };
199 
200  /// \brief Broadcast arg to the same shape as like_arg.
201  class NGRAPH_DEPRECATED(
202  "This operation is deprecated and will be removed soon. Please don't use it.")
203  NGRAPH_API BroadcastLike : public v0::Broadcast
204  {
205  public:
206  static constexpr NodeTypeInfo type_info{"BroadcastLike", 0};
207  const NodeTypeInfo& get_type_info() const override { return type_info; }
208  /// \brief Broadcast arg to the same shape as like_arg.
209  BroadcastLike() = default;
210  /// \brief Broadcast arg to the same shape as like_arg.
211  ///
212  /// Once the shape of like_arg is known, this op will be replaced with an equivalent
213  /// Broadcast op.
214  ///
215  /// \param arg The argument to be broadcast.
216  /// \param like_arg Provides the shape for the result.
217  /// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
218  /// arg must be scalar and all axes are broadcast.
219  BroadcastLike(const Output<Node>& arg,
220  const Output<Node>& like_arg,
221  const AxisSet& initial_broadcast_axes);
222  bool visit_attributes(AttributeVisitor& visitor) override;
223  std::shared_ptr<Node>
224  clone_with_new_inputs(const OutputVector& new_args) const override;
225 
226  void infer_shape() override;
227  const AxisSet& get_initial_broadcast_axes() const
228  {
229  return m_initial_broadcast_axes;
230  }
231  void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
232  {
233  m_initial_broadcast_axes = initial_broadcast_axes;
234  }
235 
236  protected:
237  AxisSet m_initial_broadcast_axes;
238  };
239  NGRAPH_SUPPRESS_DEPRECATED_END
240  } // namespace v0
241 
242  NGRAPH_SUPPRESS_DEPRECATED_START
243  using v0::Broadcast;
244  using v0::BroadcastLike;
245  NGRAPH_SUPPRESS_DEPRECATED_END
246  }
247 }
ngraph::op::BroadcastModeSpec
Implicit broadcast specification.
Definition: attr_types.hpp:377
ngraph::op::v3::Broadcast::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: broadcast.hpp:36
ngraph::op::v3::Broadcast::Broadcast
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const BroadcastModeSpec &broadcast_spec=BroadcastType::EXPLICIT)
Constructs a broadcast operation.
ngraph::op::v1::Broadcast::Broadcast
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const AutoBroadcastSpec &broadcast_spec=AutoBroadcastSpec())
Constructs a broadcast operation.
ngraph::op::v3::Broadcast
Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
Definition: broadcast.hpp:33
ngraph::op::AutoBroadcastSpec
Implicit broadcast specification.
Definition: attr_types.hpp:321
ngraph::op::v1::Broadcast::validate_and_infer_types
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::v3::Broadcast::Broadcast
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const BroadcastModeSpec &broadcast_spec=BroadcastType::NUMPY)
Constructs a broadcast operation.
ngraph::op::v3::Broadcast::get_broadcast_axes
std::pair< bool, AxisSet > get_broadcast_axes() const override
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::op::v3::Broadcast::validate_and_infer_types
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::v1::Broadcast::Broadcast
Broadcast()=default
Constructs a broadcast operation.
ngraph::op::v3::Broadcast::Broadcast
Broadcast()=default
Constructs a broadcast operation.
ngraph::op::v1::Broadcast::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: broadcast.hpp:96
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::v1::Broadcast::Broadcast
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const AutoBroadcastSpec &broadcast_spec=AutoBroadcastSpec(AutoBroadcastType::NUMPY))
Constructs a broadcast operation.
ngraph::op::v1::Broadcast
Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
Definition: broadcast.hpp:93
ngraph::op::v1::Broadcast::get_broadcast_spec
const AutoBroadcastSpec & get_broadcast_spec() const
Definition: broadcast.hpp:134
ngraph::op::Op
Root of all actual ops.
Definition: op.hpp:29
ngraph::op::util::BroadcastBase
Definition: broadcast_base.hpp:31