host_tensor.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <memory>
20 
21 #include "ngraph/node.hpp"
22 #include "ngraph/runtime/tensor.hpp"
23 #include "ngraph/type/element_type.hpp"
24 #include "ngraph/type/element_type_traits.hpp"
25 
26 namespace ngraph
27 {
28  namespace op
29  {
30  namespace v0
31  {
32  class Constant;
33  }
34  }
35  namespace runtime
36  {
37  class NGRAPH_API HostTensor : public ngraph::runtime::Tensor
38  {
39  public:
40  HostTensor(const element::Type& element_type,
41  const Shape& shape,
42  void* memory_pointer,
43  const std::string& name = "");
44  HostTensor(const element::Type& element_type,
45  const Shape& shape,
46  const std::string& name = "");
47  HostTensor(const element::Type& element_type,
48  const PartialShape& partial_shape,
49  const std::string& name = "");
50  HostTensor(const std::string& name = "");
51  explicit HostTensor(const Output<Node>&);
52  explicit HostTensor(const std::shared_ptr<op::v0::Constant>& constant);
53  virtual ~HostTensor() override;
54 
55  void initialize(const std::shared_ptr<op::v0::Constant>& constant);
56 
57  void* get_data_ptr();
58  const void* get_data_ptr() const;
59 
60  template <typename T>
61  T* get_data_ptr()
62  {
63  return static_cast<T*>(get_data_ptr());
64  }
65 
66  template <typename T>
67  const T* get_data_ptr() const
68  {
69  return static_cast<T*>(get_data_ptr());
70  }
71 
72  template <element::Type_t ET>
73  typename element_type_traits<ET>::value_type* get_data_ptr()
74  {
75  NGRAPH_CHECK(ET == get_element_type(),
76  "get_data_ptr() called for incorrect element type.");
77  return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr());
78  }
79 
80  template <element::Type_t ET>
81  const typename element_type_traits<ET>::value_type* get_data_ptr() const
82  {
83  NGRAPH_CHECK(ET == get_element_type(),
84  "get_data_ptr() called for incorrect element type.");
85  return static_cast<typename element_type_traits<ET>::value_type>(get_data_ptr());
86  }
87 
88  /// \brief Write bytes directly into the tensor
89  /// \param p Pointer to source of data
90  /// \param n Number of bytes to write, must be integral number of elements.
91  void write(const void* p, size_t n) override;
92 
93  /// \brief Read bytes directly from the tensor
94  /// \param p Pointer to destination for data
95  /// \param n Number of bytes to read, must be integral number of elements.
96  void read(void* p, size_t n) const override;
97 
98  bool get_is_allocated() const;
99  /// \brief Set the element type. Must be compatible with the current element type.
100  /// \param element_type The element type
101  void set_element_type(const element::Type& element_type);
102  /// \brief Set the actual shape of the tensor compatibly with the partial shape.
103  /// \param shape The shape being set
104  void set_shape(const Shape& shape);
105  /// \brief Set the shape of a node from an input
106  /// \param arg The input argument
107  void set_unary(const HostTensorPtr& arg);
108  /// \brief Set the shape of the tensor using broadcast rules
109  /// \param autob The broadcast mode
110  /// \param arg0 The first argument
111  /// \param arg1 The second argument
113  const HostTensorPtr& arg0,
114  const HostTensorPtr& arg1);
115  /// \brief Set the shape of the tensor using broadcast rules
116  /// \param autob The broadcast mode
117  /// \param arg0 The first argument
118  /// \param arg1 The second argument
119  /// \param element_type The output element type
121  const HostTensorPtr& arg0,
122  const HostTensorPtr& arg1,
123  const element::Type& element_type);
124 
125  private:
126  void allocate_buffer();
127  HostTensor(const HostTensor&) = delete;
128  HostTensor(HostTensor&&) = delete;
129  HostTensor& operator=(const HostTensor&) = delete;
130 
131  void* m_memory_pointer{nullptr};
132  void* m_allocated_buffer_pool{nullptr};
133  void* m_aligned_buffer_pool{nullptr};
134  size_t m_buffer_size;
135  };
136  }
137 }
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Definition: element_type.hpp:61
Definition: host_tensor.hpp:38
void set_shape(const Shape &shape)
Set the actual shape of the tensor; it must be compatible with the tensor's partial shape.
void write(const void *p, size_t n) override
Write bytes directly into the tensor.
void read(void *p, size_t n) const override
Read bytes directly from the tensor.
void set_element_type(const element::Type &element_type)
Set the element type. Must be compatible with the current element type.
void set_unary(const HostTensorPtr &arg)
Set this tensor's shape from an input tensor.
void set_broadcast(const op::AutoBroadcastSpec &autob, const HostTensorPtr &arg0, const HostTensorPtr &arg1)
Set the shape of the tensor using broadcast rules.
void set_broadcast(const op::AutoBroadcastSpec &autob, const HostTensorPtr &arg0, const HostTensorPtr &arg1, const element::Type &element_type)
Set the shape of the tensor using broadcast rules.
Definition: tensor.hpp:32
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: element_type_traits.hpp:25
Implicit broadcast specification.
Definition: attr_types.hpp:323