19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22 #include "ngraph/op/util/broadcast_base.hpp"
35 static constexpr NodeTypeInfo type_info{
"Broadcast", 3};
36 const NodeTypeInfo&
get_type_info()
const override {
return type_info; }
53 const Output<Node>& target_shape,
54 const Output<Node>& axes_mapping,
64 const Output<Node>& target_shape,
70 clone_with_new_inputs(
const OutputVector& new_args)
const override;
76 m_mode = broadcast_spec;
83 bool evaluate(
const HostTensorVector& outputs,
84 const HostTensorVector& inputs)
const override;
95 static constexpr NodeTypeInfo type_info{
"Broadcast", 1};
96 const NodeTypeInfo&
get_type_info()
const override {
return type_info; }
113 const Output<Node>& target_shape,
114 const Output<Node>& axes_mapping,
124 const Output<Node>& target_shape,
130 std::shared_ptr<Node>
131 clone_with_new_inputs(
const OutputVector& new_args)
const override;
137 m_broadcast_spec = broadcast_spec;
141 bool evaluate(
const HostTensorVector& outputs,
142 const HostTensorVector& inputs)
const override;
151 NGRAPH_SUPPRESS_DEPRECATED_START
154 class NGRAPH_DEPRECATED(
155 "This operation is deprecated and will be removed soon. "
156 "Use v1::Broadcast instead of it.") NGRAPH_API Broadcast :
public Op
159 static constexpr NodeTypeInfo type_info{
"Broadcast", 0};
160 const NodeTypeInfo& get_type_info()
const override {
return type_info; }
162 Broadcast() =
default;
170 Broadcast(
const Output<Node>& arg,
172 const AxisSet& broadcast_axes);
173 bool visit_attributes(AttributeVisitor& visitor)
override;
174 void validate_and_infer_types()
override;
176 std::shared_ptr<Node>
177 clone_with_new_inputs(
const OutputVector& new_args)
const override;
180 const AxisSet& get_broadcast_axes()
const {
return m_broadcast_axes; }
181 void set_broadcast_axes(
const AxisSet& broadcast_axes)
183 m_broadcast_axes = broadcast_axes;
185 const Shape& get_broadcast_shape()
const {
return m_shape; }
186 void set_broadcast_shape(
const Shape& shape) { m_shape = shape; }
187 bool evaluate(
const HostTensorVector& outputs,
188 const HostTensorVector& inputs)
const override;
191 Broadcast(
const OutputVector& args,
193 const AxisSet& broadcast_axes);
195 virtual void infer_shape() {}
197 AxisSet m_broadcast_axes;
201 class NGRAPH_DEPRECATED(
202 "This operation is deprecated and will be removed soon. Please don't use it.")
203 NGRAPH_API BroadcastLike :
public v0::Broadcast
206 static constexpr NodeTypeInfo type_info{
"BroadcastLike", 0};
207 const NodeTypeInfo& get_type_info()
const override {
return type_info; }
209 BroadcastLike() =
default;
219 BroadcastLike(
const Output<Node>& arg,
220 const Output<Node>& like_arg,
221 const AxisSet& initial_broadcast_axes);
222 bool visit_attributes(AttributeVisitor& visitor)
override;
223 std::shared_ptr<Node>
224 clone_with_new_inputs(
const OutputVector& new_args)
const override;
226 void infer_shape()
override;
227 const AxisSet& get_initial_broadcast_axes()
const
229 return m_initial_broadcast_axes;
231 void set_initial_broadcast_axes(
const AxisSet& initial_broadcast_axes)
233 m_initial_broadcast_axes = initial_broadcast_axes;
237 AxisSet m_initial_broadcast_axes;
239 NGRAPH_SUPPRESS_DEPRECATED_END
242 NGRAPH_SUPPRESS_DEPRECATED_START
244 using v0::BroadcastLike;
245 NGRAPH_SUPPRESS_DEPRECATED_END