host_tensor.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 
9 #include "ngraph/node.hpp"
10 #include "ngraph/runtime/tensor.hpp"
11 #include "ngraph/type/element_type.hpp"
12 #include "ngraph/type/element_type_traits.hpp"
13 
14 namespace ngraph
15 {
16  namespace op
17  {
18  namespace v0
19  {
20  class Constant;
21  }
22  } // namespace op
23  namespace runtime
24  {
25  class NGRAPH_API HostTensor : public ngraph::runtime::Tensor
26  {
27  public:
28  HostTensor(const element::Type& element_type,
29  const Shape& shape,
30  void* memory_pointer,
31  const std::string& name = "");
32  HostTensor(const element::Type& element_type,
33  const Shape& shape,
34  const std::string& name = "");
35  HostTensor(const element::Type& element_type,
36  const PartialShape& partial_shape,
37  const std::string& name = "");
38  HostTensor(const std::string& name = "");
39  explicit HostTensor(const Output<Node>&);
40  explicit HostTensor(const std::shared_ptr<op::v0::Constant>& constant);
41  virtual ~HostTensor() override;
42 
43  void initialize(const std::shared_ptr<op::v0::Constant>& constant);
44 
45  void* get_data_ptr();
46  const void* get_data_ptr() const;
47 
48  template <typename T>
49  T* get_data_ptr()
50  {
51  return static_cast<T*>(get_data_ptr());
52  }
53 
54  template <typename T>
55  const T* get_data_ptr() const
56  {
57  return static_cast<T*>(get_data_ptr());
58  }
59 
60  template <element::Type_t ET>
61  typename element_type_traits<ET>::value_type* get_data_ptr()
62  {
63  NGRAPH_CHECK(ET == get_element_type(),
64  "get_data_ptr() called for incorrect element type.");
65  return static_cast<typename element_type_traits<ET>::value_type*>(get_data_ptr());
66  }
67 
68  template <element::Type_t ET>
69  const typename element_type_traits<ET>::value_type* get_data_ptr() const
70  {
71  NGRAPH_CHECK(ET == get_element_type(),
72  "get_data_ptr() called for incorrect element type.");
73  return static_cast<typename element_type_traits<ET>::value_type>(get_data_ptr());
74  }
75 
76  /// \brief Write bytes directly into the tensor
77  /// \param p Pointer to source of data
78  /// \param n Number of bytes to write, must be integral number of elements.
79  void write(const void* p, size_t n) override;
80 
81  /// \brief Read bytes directly from the tensor
82  /// \param p Pointer to destination for data
83  /// \param n Number of bytes to read, must be integral number of elements.
84  void read(void* p, size_t n) const override;
85 
86  bool get_is_allocated() const;
87  /// \brief Set the element type. Must be compatible with the current element type.
88  /// \param element_type The element type
89  void set_element_type(const element::Type& element_type);
90  /// \brief Set the actual shape of the tensor compatibly with the partial shape.
91  /// \param shape The shape being set
92  void set_shape(const Shape& shape);
93  /// \brief Set the shape of a node from an input
94  /// \param arg The input argument
95  void set_unary(const HostTensorPtr& arg);
96  /// \brief Set the shape of the tensor using broadcast rules
97  /// \param autob The broadcast mode
98  /// \param arg0 The first argument
99  /// \param arg1 The second argument
101  const HostTensorPtr& arg0,
102  const HostTensorPtr& arg1);
103  /// \brief Set the shape of the tensor using broadcast rules
104  /// \param autob The broadcast mode
105  /// \param arg0 The first argument
106  /// \param arg1 The second argument
107  /// \param element_type The output element type
109  const HostTensorPtr& arg0,
110  const HostTensorPtr& arg1,
111  const element::Type& element_type);
112 
113  private:
114  void allocate_buffer();
115  HostTensor(const HostTensor&) = delete;
116  HostTensor(HostTensor&&) = delete;
117  HostTensor& operator=(const HostTensor&) = delete;
118 
119  void* m_memory_pointer{nullptr};
120  void* m_allocated_buffer_pool{nullptr};
121  void* m_aligned_buffer_pool{nullptr};
122  size_t m_buffer_size;
123  };
124  } // namespace runtime
125 } // namespace ngraph
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Shape for a tensor.
Definition: shape.hpp:19
Definition: element_type.hpp:51
Definition: host_tensor.hpp:26
void set_shape(const Shape &shape)
Set the actual shape of the tensor compatibly with the partial shape.
void write(const void *p, size_t n) override
Write bytes directly into the tensor.
void read(void *p, size_t n) const override
Read bytes directly from the tensor.
void set_element_type(const element::Type &element_type)
Set the element type. Must be compatible with the current element type.
void set_unary(const HostTensorPtr &arg)
Set the shape of this tensor from an input tensor.
void set_broadcast(const op::AutoBroadcastSpec &autob, const HostTensorPtr &arg0, const HostTensorPtr &arg1)
Set the shape of the tensor using broadcast rules.
void set_broadcast(const op::AutoBroadcastSpec &autob, const HostTensorPtr &arg0, const HostTensorPtr &arg1, const element::Type &element_type)
Set the shape of the tensor using broadcast rules.
Definition: tensor.hpp:20
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: element_type_traits.hpp:13
Implicit broadcast specification.
Definition: attr_types.hpp:311