tensor.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 #include <string>
9 #include <unordered_set>
10 
11 #include "ngraph/partial_shape.hpp"
12 #include "ngraph/shape.hpp"
13 #include "ngraph/type/element_type.hpp"
14 
15 namespace ngraph
16 {
17  class Node;
18 
19  namespace runtime
20  {
21  class HostTensor;
22  }
23  using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
24  namespace descriptor
25  {
26  /// \brief Compile-time descriptor of a first-class value that is a tensor.
27  class NGRAPH_API Tensor
28  {
29  Tensor(const Tensor&) = delete;
30  Tensor& operator=(const Tensor&) = delete;
31 
32  public:
33  Tensor(const element::Type& element_type,
34  const PartialShape& pshape,
35  const std::string& name);
36  Tensor(const element::Type& element_type,
37  const PartialShape& pshape,
38  Node* node,
39  size_t node_output_number);
40 
41  NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
42  const std::string& get_name() const;
43  NGRAPH_DEPRECATED("set_name() is deprecated! Please use set_names() instead.")
44  void set_name(const std::string& name);
45 
46  const std::unordered_set<std::string>& get_names() const;
47  void set_names(const std::unordered_set<std::string>& names);
48  void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
49  void set_element_type(const element::Type& elemenet_type);
50  void set_partial_shape(const PartialShape& partial_shape);
51 
52  /// \brief sets lower bound value description
53  void set_lower_value(const HostTensorPtr& value);
54  /// \brief sets upper bound value description
55  void set_upper_value(const HostTensorPtr& value);
56  /// \brief unsets bound value descriptions
58 
59  const element::Type& get_element_type() const { return m_element_type; }
60  const Shape& get_shape() const;
61  const PartialShape& get_partial_shape() const { return m_partial_shape; }
62  /// \brief gets lower bound value description
63  HostTensorPtr get_lower_value() const { return m_lower_value; }
64  /// \brief gets upper bound value description
65  HostTensorPtr get_upper_value() const { return m_upper_value; }
66  /// \brief checks if lower and upper bound are set and point to the same HostTensor
67  bool has_and_set_bound() const
68  {
69  return m_upper_value != nullptr && m_upper_value == m_lower_value;
70  }
71  size_t size() const;
72 
73  protected:
74  element::Type m_element_type;
75 
76  // TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields,
77  // with m_shape possibly being invalid (get_shape will throw an exception if it
78  // is). This is because get_shape() returns a const reference. I think ideally we
79  // should refactor so that get_shape returns by value.
80  Shape m_shape;
81  PartialShape m_partial_shape;
82  Node* m_node{nullptr};
83  HostTensorPtr m_lower_value, m_upper_value;
84  size_t m_node_output_number{0};
85 
86  std::string m_name;
87  std::unordered_set<std::string> m_names;
88  };
89 
90  NGRAPH_API
91  std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
92  } // namespace descriptor
93 } // namespace ngraph
Definition: node.hpp:127
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:28
void set_upper_value(const HostTensorPtr &value)
sets upper bound value description
HostTensorPtr get_upper_value() const
gets upper bound value description
Definition: tensor.hpp:65
void set_lower_value(const HostTensorPtr &value)
sets lower bound value description
HostTensorPtr get_lower_value() const
gets lower bound value description
Definition: tensor.hpp:63
bool has_and_set_bound() const
checks if lower and upper bound are set and point to the same HostTensor
Definition: tensor.hpp:67
void invalidate_values()
unsets bound value descriptions
Definition: element_type.hpp:51
Definition: host_tensor.hpp:26
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16