tensor.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <memory>
20 #include <string>
21 #include <unordered_set>
22 
23 #include "ngraph/partial_shape.hpp"
24 #include "ngraph/shape.hpp"
25 #include "ngraph/type/element_type.hpp"
26 
27 namespace ngraph
28 {
// Forward declaration only: this header refers to Node through pointers.
class Node;

namespace runtime
{
    // Forward declaration only: this header handles HostTensor through shared_ptr.
    class HostTensor;
} // namespace runtime

/// \brief Shared-ownership handle to a runtime::HostTensor.
using HostTensorPtr = std::shared_ptr<runtime::HostTensor>;
36  namespace descriptor
37  {
38  /// \brief Compile-time descriptor of a first-class value that is a tensor.
39  class NGRAPH_API Tensor
40  {
41  Tensor(const Tensor&) = delete;
42  Tensor& operator=(const Tensor&) = delete;
43 
44  public:
45  Tensor(const element::Type& element_type,
46  const PartialShape& pshape,
47  const std::string& name);
48  Tensor(const element::Type& element_type,
49  const PartialShape& pshape,
50  Node* node,
51  size_t node_output_number);
52 
53  NGRAPH_DEPRECATED("get_name() is deprecated! Please use get_names() instead.")
54  const std::string& get_name() const;
55  NGRAPH_DEPRECATED("set_name() is deprecated! Please use set_names() instead.")
56  void set_name(const std::string& name);
57 
58  const std::unordered_set<std::string>& get_names() const;
59  void set_names(const std::unordered_set<std::string>& names);
60  void set_tensor_type(const element::Type& element_type, const PartialShape& pshape);
61  void set_element_type(const element::Type& elemenet_type);
62  void set_partial_shape(const PartialShape& partial_shape);
63 
64  /// \brief sets lower bound value description
65  void set_lower_value(const HostTensorPtr& value);
66  /// \brief sets upper bound value description
67  void set_upper_value(const HostTensorPtr& value);
68  /// \brief unsets bound value descriptions
70 
71  const element::Type& get_element_type() const { return m_element_type; }
72  const Shape& get_shape() const;
73  const PartialShape& get_partial_shape() const { return m_partial_shape; }
74  /// \brief gets lower bound value description
75  HostTensorPtr get_lower_value() const { return m_lower_value; }
76  /// \brief gets upper bound value description
77  HostTensorPtr get_upper_value() const { return m_upper_value; }
78  /// \brief checks if lower and upper bound are set and point to the same HostTensor
79  bool has_and_set_bound() const
80  {
81  return m_upper_value != nullptr && m_upper_value == m_lower_value;
82  }
83  size_t size() const;
84 
85  protected:
86  element::Type m_element_type;
87 
88  // TODO(amprocte): For now we are maintaining both m_shape and m_partial_shape fields,
89  // with m_shape possibly being invalid (get_shape will throw an exception if it
90  // is). This is because get_shape() returns a const reference. I think ideally we
91  // should refactor so that get_shape returns by value.
92  Shape m_shape;
93  PartialShape m_partial_shape;
94  Node* m_node{nullptr};
95  HostTensorPtr m_lower_value, m_upper_value;
96  size_t m_node_output_number{0};
97 
98  std::string m_name;
99  std::unordered_set<std::string> m_names;
100  };
101 
102  NGRAPH_API
103  std::ostream& operator<<(std::ostream&, const ngraph::descriptor::Tensor&);
104  }
105 }
Definition: node.hpp:132
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Compile-time descriptor of a first-class value that is a tensor.
Definition: tensor.hpp:40
void set_upper_value(const HostTensorPtr &value)
sets upper bound value description
HostTensorPtr get_upper_value() const
gets upper bound value description
Definition: tensor.hpp:77
void set_lower_value(const HostTensorPtr &value)
sets lower bound value description
HostTensorPtr get_lower_value() const
gets lower bound value description
Definition: tensor.hpp:75
bool has_and_set_bound() const
checks if lower and upper bound are set and point to the same HostTensor
Definition: tensor.hpp:79
void invalidate_values()
unsets bound value descriptions
Definition: element_type.hpp:61
Definition: host_tensor.hpp:38
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28