reshape.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/axis_vector.hpp"
8 #include "ngraph/node.hpp"
9 #include "ngraph/op/op.hpp"
10 #include "ngraph/runtime/host_tensor.hpp"
11 
12 namespace ngraph
13 {
14  namespace op
15  {
16  namespace v1
17  {
18  /// \brief Tensor dynamic reshape operation.
19  ///
20  /// "Converts" an input tensor into a new shape with the same number of elements.
21  /// This op does not touch the actual data. If needed, use Transpose for that purpose.
22  ///
23  class NGRAPH_API Reshape : public Op
24  {
25  public:
26  NGRAPH_RTTI_DECLARATION;
27  Reshape() = default;
28  /// \brief Constructs a dynamic reshape operation. This operation does not perform
29  /// transpose.
30  ///
31  /// \param arg The tensor to be reshaped.
32  /// \param shape_pattern The node that defines output shape shape_pattern.
33  /// If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape
34  /// must
35  /// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
36  /// A value of -1 is allowed for at most one dimension, in which case the
37  /// dimension size is inferred based on element count of input tensor.
38  /// \param special_zero Treats zeros in `shape_pattern` as wildcard flags indicating
39  /// a
40  /// copy from input shape at the same index.
41  ///
42  Reshape(const Output<Node>& arg,
43  const Output<Node>& shape_pattern,
44  bool special_zero);
45 
46  bool visit_attributes(AttributeVisitor& visitor) override;
47  void validate_and_infer_types() override;
48 
49  size_t get_version() const override { return 1; }
50  virtual std::shared_ptr<Node>
51  clone_with_new_inputs(const OutputVector& new_args) const override;
52 
53  bool get_special_zero() const { return m_special_zero; }
54  void set_special_zero(bool special_zero) { m_special_zero = special_zero; }
55  bool evaluate(const HostTensorVector& outputs,
56  const HostTensorVector& inputs) const override;
57  bool has_evaluate() const override;
58  bool evaluate_lower(const HostTensorVector& outputs) const override;
59  bool evaluate_upper(const HostTensorVector& outputs) const override;
60  bool constant_fold(OutputVector& output_values,
61  const OutputVector& inputs_values) override;
62 
63  protected:
64  bool m_special_zero;
65  bool evaluate_reshape(const HostTensorVector& outputs,
66  const HostTensorVector& inputs) const;
67 
68  private:
69  void calculate_output_shape(std::vector<Dimension>& reshape_pattern,
70  const int64_t& minus_one_idx,
71  const PartialShape& input_pshape,
72  std::vector<Dimension>& output_shape) const;
73  };
74  } // namespace v1
75  } // namespace op
76 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Root of all actual ops.
Definition: op.hpp:17
Tensor dynamic reshape operation.
Definition: reshape.hpp:24
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Reshape(const Output< Node > &arg, const Output< Node > &shape_pattern, bool special_zero)
Constructs a dynamic reshape operation. This operation does not perform transpose.
size_t get_version() const override
Definition: reshape.hpp:49
bool has_evaluate() const override
Reports whether an evaluate() implementation is available for the current operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on `inputs`, putting the results in `outputs`.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16