broadcast.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22 #include "ngraph/op/util/broadcast_base.hpp"
23 
24 namespace ngraph
25 {
26  namespace op
27  {
28  namespace v3
29  {
30  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
31  /// input as needed along the new axes.
32  class NGRAPH_API Broadcast : public util::BroadcastBase
33  {
34  public:
35  static constexpr NodeTypeInfo type_info{"Broadcast", 3};
36  const NodeTypeInfo& get_type_info() const override { return type_info; }
37  /// \brief Constructs a broadcast operation.
38  Broadcast() = default;
39  /// \brief Constructs a broadcast operation.
40  ///
41  /// \param arg The input tensor to be broadcast.
42  /// \param target_shape The shape of the output tensor.
43  /// \param axes_mapping The axis positions (0-based) in the result that correspond
44  /// to input axes. 'Arg' tensor is broadcast along the
45  /// remaining axes.
46  /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
47  /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
48  /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
49  /// \param broadcast_spec Broadcast specification to use for determining broadcast
50  /// axes. 'axes_mapping' should not be provided if mode other
51  /// than explicit (none) is used.
52  Broadcast(const Output<Node>& arg,
53  const Output<Node>& target_shape,
54  const Output<Node>& axes_mapping,
55  const BroadcastModeSpec& broadcast_spec = BroadcastType::EXPLICIT);
56 
57  /// \brief Constructs a broadcast operation.
58  ///
59  /// \param arg The input tensor to be broadcast.
60  /// \param target_shape The shape of the output tensor.
61  /// \param broadcast_spec Broadcast specification to use for determining broadcast
62  /// axes
63  Broadcast(const Output<Node>& arg,
64  const Output<Node>& target_shape,
65  const BroadcastModeSpec& broadcast_spec = BroadcastType::NUMPY);
66 
67  bool visit_attributes(AttributeVisitor& visitor) override;
68 
69  std::shared_ptr<Node>
70  clone_with_new_inputs(const OutputVector& new_args) const override;
71 
72  // \return Broadcast Specification.
73  const BroadcastModeSpec& get_broadcast_spec() const { return m_mode; }
74  void set_broadcast_spec(const BroadcastModeSpec& broadcast_spec)
75  {
76  m_mode = broadcast_spec;
77  }
78 
79  void validate_and_infer_types() override;
80 
81  /// \return true and the AxisSet if broadcast axes can be fully determined.
82  std::pair<bool, AxisSet> get_broadcast_axes() const override;
83  bool evaluate(const HostTensorVector& outputs,
84  const HostTensorVector& inputs) const override;
85 
86  private:
87  bool broadcast_evaluate(const HostTensorVector& outputs,
88  const HostTensorVector& inputs) const;
89  };
90  } // namespace v3
91 
92  namespace v1
93  {
94  /// \brief Operation which "adds" axes to an input tensor, replicating elements from the
95  /// input as needed along the new axes.
96  class NGRAPH_API Broadcast : public util::BroadcastBase
97  {
98  public:
99  static constexpr NodeTypeInfo type_info{"Broadcast", 1};
100  const NodeTypeInfo& get_type_info() const override { return type_info; }
101  /// \brief Constructs a broadcast operation.
102  Broadcast() = default;
103  /// \brief Constructs a broadcast operation.
104  ///
105  /// \param arg The input tensor to be broadcast.
106  /// \param target_shape The shape of the output tensor.
107  /// \param axes_mapping The axis positions (0-based) in the result that correspond
108  /// to input axes. 'Arg' tensor is broadcast along the
109  /// remaining axes.
110  /// E.g., Input Shape - [3, 4], Target Shape - [3, 5, 4, 4]
111  /// axes_mapping - [0, 2] => Broadcast along axes 1 and 3.
112  /// axes_mapping - [0, 3] => Broadcast along axes 1 and 2.
113  /// \param broadcast_spec Broadcast specification to use for determining broadcast
114  /// axes. 'axes_mapping' is ignored if broadcast_spec is not
115  /// NONE
117  const Output<Node>& target_shape,
118  const Output<Node>& axes_mapping,
119  const AutoBroadcastSpec& broadcast_spec = AutoBroadcastSpec());
120 
121  /// \brief Constructs a broadcast operation.
122  ///
123  /// \param arg The input tensor to be broadcast.
124  /// \param target_shape The shape of the output tensor.
125  /// \param broadcast_spec Broadcast specification to use for determining broadcast
126  /// axes
128  const Output<Node>& target_shape,
129  const AutoBroadcastSpec& broadcast_spec =
130  AutoBroadcastSpec(AutoBroadcastType::NUMPY));
131 
132  bool visit_attributes(AttributeVisitor& visitor) override;
133 
134  std::shared_ptr<Node>
135  clone_with_new_inputs(const OutputVector& new_args) const override;
136 
137  /// \return Broadcast Specification.
138  const AutoBroadcastSpec& get_broadcast_spec() const { return m_broadcast_spec; }
139  void set_broadcast_spec(const AutoBroadcastSpec& broadcast_spec)
140  {
141  m_broadcast_spec = broadcast_spec;
142  }
143 
144  void validate_and_infer_types() override;
145  bool evaluate(const HostTensorVector& outputs,
146  const HostTensorVector& inputs) const override;
147 
148  protected:
149  AutoBroadcastSpec m_broadcast_spec;
150  };
151  } // namespace v1
152  }
153 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Definition: broadcast_base.hpp:31
Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
Definition: broadcast.hpp:97
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
Broadcast()=default
Constructs a broadcast operation.
const AutoBroadcastSpec & get_broadcast_spec() const
Definition: broadcast.hpp:138
const NodeTypeInfo & get_type_info() const override
Definition: broadcast.hpp:100
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const AutoBroadcastSpec &broadcast_spec=AutoBroadcastSpec(AutoBroadcastType::NUMPY))
Constructs a broadcast operation.
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const AutoBroadcastSpec &broadcast_spec=AutoBroadcastSpec())
Constructs a broadcast operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
Definition: broadcast.hpp:33
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const BroadcastModeSpec &broadcast_spec=BroadcastType::NUMPY)
Constructs a broadcast operation.
Broadcast(const Output< Node > &arg, const Output< Node > &target_shape, const Output< Node > &axes_mapping, const BroadcastModeSpec &broadcast_spec=BroadcastType::EXPLICIT)
Constructs a broadcast operation.
Broadcast()=default
Constructs a broadcast operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::pair< bool, AxisSet > get_broadcast_axes() const override
const NodeTypeInfo & get_type_info() const override
Definition: broadcast.hpp:36
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39
Implicit broadcast specification.
Definition: attr_types.hpp:323
Implicit broadcast specification.
Definition: attr_types.hpp:379