reshape.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/axis_vector.hpp"
20 #include "ngraph/node.hpp"
21 #include "ngraph/op/op.hpp"
22 #include "ngraph/runtime/host_tensor.hpp"
23 
24 namespace ngraph
25 {
26  namespace op
27  {
28  namespace v1
29  {
30  /// \brief Tensor dynamic reshape operation.
31  ///
32  /// "Converts" an input tensor into a new shape with the same number of elements.
33  /// This op does not touch the actual data. If needed, use Transpose for that purpose.
34  ///
35  class NGRAPH_API Reshape : public Op
36  {
37  public:
38  NGRAPH_RTTI_DECLARATION;
39  Reshape() = default;
40  /// \brief Constructs a dynamic reshape operation. This operation does not perform
41  /// transpose.
42  ///
43  /// \param arg The tensor to be reshaped.
44  /// \param shape_pattern The node that defines the output shape.
45  /// If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output
46  /// shape must be of the form \f$(b_0,\dots,b_{j-1})\f$ where
47  /// \f$\Pi(a_i) = \Pi(b_i)\f$.
48  /// A value of -1 is allowed for at most one dimension, in which case that
49  /// dimension's size is inferred from the element count of the input tensor.
50  /// \param special_zero Treats zeros in `shape_pattern` as wildcard flags
51  /// indicating a copy of the input dimension at the same
52  /// index.
53  ///
54  Reshape(const Output<Node>& arg,
55  const Output<Node>& shape_pattern,
56  bool special_zero);
57 
58  bool visit_attributes(AttributeVisitor& visitor) override;
59  void validate_and_infer_types() override;
60 
61  size_t get_version() const override { return 1; }
62  std::shared_ptr<Node>
63  clone_with_new_inputs(const OutputVector& new_args) const override;
64  /// Accessors for the `special_zero` attribute (see the constructor's \p special_zero).
65  bool get_special_zero() const { return m_special_zero; } ///< Current `special_zero` flag.
66  void set_special_zero(bool special_zero) { m_special_zero = special_zero; } ///< Replaces the flag.
67  bool evaluate(const HostTensorVector& outputs,
68  const HostTensorVector& inputs) const override;
69  bool evaluate_lower(const HostTensorVector& outputs) const override;
70  bool evaluate_upper(const HostTensorVector& outputs) const override;
71  bool constant_fold(OutputVector& output_values,
72  const OutputVector& inputs_values) override;
73 
74  protected:
75  bool m_special_zero{false}; ///< Initialized so a default-constructed node never reads an indeterminate value.
76  bool evaluate_reshape(const HostTensorVector& outputs,
77  const HostTensorVector& inputs) const;
78  };
79  }
80  }
81 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Root of all actual ops.
Definition: op.hpp:29
Tensor dynamic reshape operation.
Definition: reshape.hpp:36
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Reshape(const Output< Node > &arg, const Output< Node > &shape_pattern, bool special_zero)
Constructs a dynamic reshape operation. This operation does not perform transpose.
size_t get_version() const override
Definition: reshape.hpp:61
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28