constant.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cmath>
20 #include <cstring>
21 #include <sstream>
22 
23 #include "ngraph/coordinate_diff.hpp"
24 #include "ngraph/node.hpp"
25 #include "ngraph/runtime/aligned_buffer.hpp"
26 #include "ngraph/runtime/host_tensor.hpp"
27 #include "ngraph/runtime/shared_buffer.hpp"
28 #include "ngraph/type/element_type.hpp"
29 #include "ngraph/type/element_type_traits.hpp"
30 #include "ngraph/util.hpp"
31 
32 namespace ngraph
33 {
34  namespace op
35  {
36  namespace v0
37  {
38  /// \brief Class for constants.
39  class NGRAPH_API Constant : public Op
40  {
41  public:
42  static constexpr NodeTypeInfo type_info{"Constant", 0};
43  const NodeTypeInfo& get_type_info() const override { return type_info; }
44  Constant() = default;
45 
46  /// \brief Initialize a constant from tensor
47  /// \param tensor The tensor with data
48  Constant(const std::shared_ptr<runtime::Tensor>& tensor);
49 
50  /// \brief Constructs a tensor constant.
51  ///
52  /// \param type The element type of the tensor constant.
53  /// \param shape The shape of the tensor constant.
54  /// \param values A vector of literals for initializing the tensor constant. The
55  /// size of values must match the size of the shape.
56  template <typename T>
57  Constant(const element::Type& type,
58  const Shape& shape,
59  const std::vector<T>& values)
60  : Constant(type, shape)
61  {
62  NODE_VALIDATION_CHECK(
63  this,
64  values.size() == 1 || values.size() == shape_size(m_shape),
65  "Did not get the expected number of literals for a constant of shape ",
66  m_shape,
67  " (got ",
68  values.size(),
69  ", expected ",
70  (shape_size(m_shape) == 1 ? "" : "1 or "),
71  shape_size(m_shape),
72  ").");
73 
74  if (values.size() == 1)
75  {
76  write_values(std::vector<T>(shape_size(m_shape), values[0]));
77  }
78  else
79  {
80  write_values(values);
81  }
82  constructor_validate_and_infer_types();
83  m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
84  }
85 
86  /// \brief Create unitialized constant
87  Constant(const element::Type& type, const Shape& shape);
88  /// \brief Constructs a uniform tensor constant.
89  ///
90  /// \param type The element type of the tensor constant.
91  /// \param shape The shape of the tensor constant.
92  /// \param value A scalar for initializing the uniform tensor constant. The
93  /// value is broadcast to the specified shape.
94  template <class T,
95  class = typename std::enable_if<std::is_fundamental<T>::value>::type>
96  Constant(const element::Type& type, const Shape& shape, T value)
97  : Constant(type, shape)
98  {
99  auto size = shape_size(m_shape);
100 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
101 #pragma GCC diagnostic push
102 #pragma GCC diagnostic error "-Wswitch"
103 #pragma GCC diagnostic error "-Wswitch-enum"
104 #endif
105  switch (type)
106  {
107  case element::Type_t::boolean:
108  std::fill_n(
109  get_data_ptr_nc<element::Type_t::boolean>(),
110  size,
111  static_cast<
113  value));
114  break;
115  case element::Type_t::bf16:
116  std::fill_n(
117  get_data_ptr_nc<element::Type_t::bf16>(),
118  size,
119  static_cast<
121  value));
122  break;
123  case element::Type_t::f16:
124  std::fill_n(
125  get_data_ptr_nc<element::Type_t::f16>(),
126  size,
127  static_cast<
129  value));
130  break;
131  case element::Type_t::f32:
132  std::fill_n(
133  get_data_ptr_nc<element::Type_t::f32>(),
134  size,
135  static_cast<
137  value));
138  break;
139  case element::Type_t::f64:
140  std::fill_n(
141  get_data_ptr_nc<element::Type_t::f64>(),
142  size,
143  static_cast<
145  value));
146  break;
147  case element::Type_t::i8:
148  std::fill_n(
149  get_data_ptr_nc<element::Type_t::i8>(),
150  size,
151  static_cast<
153  value));
154  break;
155  case element::Type_t::i16:
156  std::fill_n(
157  get_data_ptr_nc<element::Type_t::i16>(),
158  size,
159  static_cast<
161  value));
162  break;
163  case element::Type_t::i32:
164  std::fill_n(
165  get_data_ptr_nc<element::Type_t::i32>(),
166  size,
167  static_cast<
169  value));
170  break;
171  case element::Type_t::i64:
172  std::fill_n(
173  get_data_ptr_nc<element::Type_t::i64>(),
174  size,
175  static_cast<
177  value));
178  break;
179  case element::Type_t::u8:
180  std::fill_n(
181  get_data_ptr_nc<element::Type_t::u8>(),
182  size,
183  static_cast<
185  value));
186  break;
187  case element::Type_t::u16:
188  std::fill_n(
189  get_data_ptr_nc<element::Type_t::u16>(),
190  size,
191  static_cast<
193  value));
194  break;
195  case element::Type_t::u32:
196  std::fill_n(
197  get_data_ptr_nc<element::Type_t::u32>(),
198  size,
199  static_cast<
201  value));
202  break;
203  case element::Type_t::u64:
204  std::fill_n(
205  get_data_ptr_nc<element::Type_t::u64>(),
206  size,
207  static_cast<
209  value));
210  break;
211  case element::Type_t::u1: throw std::runtime_error("unsupported type");
212  case element::Type_t::undefined: throw std::runtime_error("unsupported type");
213  case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
214  }
215 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
216 #pragma GCC diagnostic pop
217 #endif
218  constructor_validate_and_infer_types();
219  m_all_elements_bitwise_identical = true;
220  }
221 
222  /// \brief Constructs a tensor constant
223  /// This constructor is mainly to support deserialization of constants.
224  ///
225  /// \param type The element type of the tensor constant.
226  /// \param shape The shape of the tensor constant.
227  /// \param values A list of string values to use as the constant data.
228  Constant(const element::Type& type,
229  const Shape& shape,
230  const std::vector<std::string>& values);
231 
232  /// \brief Constructs a tensor constant with the supplied data
233  ///
234  /// \param type The element type of the tensor constant.
235  /// \param shape The shape of the tensor constant.
236  /// \param data A void* to constant data.
237  Constant(const element::Type& type, const Shape& shape, const void* data);
238 
239  /// \brief Constructs a tensor constant with the supplied data
240  ///
241  /// \param type The element type of the tensor constant.
242  /// \param shape The shape of the tensor constant.
243  /// \param data A pointer to pre-allocated shared data.
244  template <typename T>
245  Constant(const element::Type& type,
246  const Shape& shape,
247  std::shared_ptr<runtime::SharedBuffer<T>> data)
248  : m_element_type(type)
249  , m_shape(shape)
250  {
251  m_data = data;
252  constructor_validate_and_infer_types();
253  }
254 
255  Constant(const Constant& other);
256  Constant& operator=(const Constant&) = delete;
257 
258  virtual ~Constant() override;
259 
260  void validate_and_infer_types() override
261  {
262  infer_element_type();
263  set_output_type(0, m_element_type, m_shape);
264  }
265 
266  bool visit_attributes(AttributeVisitor& visitor) override;
267 
268  bool evaluate(const HostTensorVector& outputs,
269  const HostTensorVector& inputs) const override;
270  bool evaluate_lower(const HostTensorVector& outputs) const override;
271  bool evaluate_upper(const HostTensorVector& outputs) const override;
272 
273  // Don't constant fold a constant; it would make a copy
274  bool constant_fold(OutputVector& outputs, const OutputVector& inputs) override
275  {
276  return false;
277  }
278 
279  /// \brief Returns the value of the constant node as a Shape object
280  /// Can only be used on element::i64 nodes and interprets
281  /// negative values as zeros.
283  /// \brief Returns the value of the constant node as a Strides
284  /// object
285  /// Can only be used on element::i64 nodes and interprets
286  /// negative values as zeros.
288  /// \brief Returns the value of the constant node as a Coordinate
289  /// object
290  /// Can only be used on element::i64 nodes and interprets
291  /// negative values as zeros.
293  /// \brief Returns the value of the constant node as a
294  /// CoordinateDiff object
295  /// Can only be used on element::i64 nodes.
297  /// \brief Returns the value of the constant node as an AxisVector
298  /// object
299  /// Can only be used on element::i64 nodes and interprets
300  /// negative values as zeros.
302  /// \brief Returns the value of the constant node as an AxisSet
303  /// object
304  /// Can only be used on element::i64 nodes and interprets
305  /// negative values as zeros.
306  /// Repeated values are allowed.
308 
309  /// \brief Update Constant shape. New shape size must equal to the data elements
310  /// count
311  ///
312  /// \param shape The shape of the tensor constant.
313  void set_data_shape(const Shape& shape);
314 
315  /// \brief Wrapper around constructing a shared_ptr of a Constant
316  ///
317  /// \param type The element type of the tensor constant.
318  /// \param shape The shape of the tensor constant.
319  /// \param values A vector of values to use as the constant data.
320  template <typename T>
321  static std::shared_ptr<op::v0::Constant>
322  create(const element::Type& type, Shape shape, const std::vector<T> values)
323  {
324  auto result = std::make_shared<op::v0::Constant>(type, shape, values);
325  result->validate_and_infer_types();
326  return result;
327  }
328 
329  /// \brief Wrapper around constructing a shared_ptr of a Constant
330  ///
331  /// \param type The element type of the tensor constant.
332  /// \param shape The shape of the tensor constant.
333  /// \param values An initializer_list of values to use as the constant data.
334  template <typename T>
335  static std::shared_ptr<op::v0::Constant>
336  create(const element::Type& type, Shape shape, std::initializer_list<T> values)
337  {
338  auto result =
339  std::make_shared<op::v0::Constant>(type, shape, std::vector<T>{values});
340  result->validate_and_infer_types();
341  return result;
342  }
343 
344  virtual std::shared_ptr<Node>
345  clone_with_new_inputs(const OutputVector& new_args) const override;
346 
347  /// \return The initialization literals for the tensor constant.
348  std::vector<std::string> get_value_strings() const;
349 
350  template <typename T>
351  std::vector<T> get_vector() const
352  {
353  const T* p = get_data_ptr<T>();
354  if (p == nullptr)
355  throw std::runtime_error("Cannot create vector! Buffer is not allocated.");
356  return std::vector<T>(p, p + shape_size(m_shape));
357  }
358 
359  /// \brief Return the Constant's value as a vector cast to type T
360  ///
361  /// \tparam T Type to which data vector's entries will be cast.
362  /// \return Constant's data vector.
363  template <typename T>
364  std::vector<T> cast_vector() const
365  {
366  auto source_type = get_element_type();
367  std::vector<T> rc;
368 
369 #if defined(_MSC_VER)
370 #pragma warning(push)
371 #pragma warning(disable : 4244)
372 #endif
373  switch (source_type)
374  {
375  case element::Type_t::boolean:
376  {
377  cast_vector<char>(rc);
378  break;
379  }
380  case element::Type_t::bf16:
381  {
382  cast_vector<bfloat16>(rc);
383  break;
384  }
385  case element::Type_t::f16:
386  {
387  cast_vector<float16>(rc);
388  break;
389  }
390  case element::Type_t::f32:
391  {
392  cast_vector<float>(rc);
393  break;
394  }
395  case element::Type_t::f64:
396  {
397  cast_vector<double>(rc);
398  break;
399  }
400  case element::Type_t::i8:
401  {
402  cast_vector<int8_t>(rc);
403  break;
404  }
405  case element::Type_t::i16:
406  {
407  cast_vector<int16_t>(rc);
408  break;
409  }
410  case element::Type_t::i32:
411  {
412  cast_vector<int32_t>(rc);
413  break;
414  }
415  case element::Type_t::i64:
416  {
417  cast_vector<int64_t>(rc);
418  break;
419  }
420  case element::Type_t::u8:
421  {
422  cast_vector<uint8_t>(rc);
423  break;
424  }
425  case element::Type_t::u16:
426  {
427  cast_vector<uint16_t>(rc);
428  break;
429  }
430  case element::Type_t::u32:
431  {
432  cast_vector<uint32_t>(rc);
433  break;
434  }
435  case element::Type_t::u64:
436  {
437  cast_vector<uint64_t>(rc);
438  break;
439  }
440  default: throw std::runtime_error("unsupported type");
441  }
442 #if defined(_MSC_VER)
443 #pragma warning(pop)
444 #endif
445  return rc;
446  }
447 
448  const void* get_data_ptr() const { return (m_data ? m_data->get_ptr() : nullptr); }
449  template <typename T>
450  const T* get_data_ptr() const
451  {
452  if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
453  {
454  throw ngraph_error("Buffer over-read");
455  }
456 
457  return static_cast<const T*>(get_data_ptr());
458  }
459 
460  template <element::Type_t ET>
461  const typename element_type_traits<ET>::value_type* get_data_ptr() const
462  {
463  NGRAPH_CHECK(ET == get_element_type(),
464  "get_data_ptr() called for incorrect element type.");
465  return static_cast<const typename element_type_traits<ET>::value_type*>(
466  get_data_ptr());
467  }
468 
469  bool get_all_data_elements_bitwise_identical() const
470  {
471  return m_all_elements_bitwise_identical;
472  }
473  std::string convert_value_to_string(size_t index) const;
474 
475  /**
476  * \brief Allows to avoid buffer allocation on the visit_attributes call
477  */
479  {
480  m_alloc_buffer_on_visit_attributes = val;
481  }
482 
483  protected:
484  template <typename IN_T, typename OUT_T>
485  void cast_vector(std::vector<OUT_T>& output_vector) const
486  {
487  auto source_vector = get_vector<IN_T>();
488  output_vector.reserve(source_vector.size());
489 
490  std::transform(source_vector.begin(),
491  source_vector.end(),
492  std::back_inserter(output_vector),
493  [](IN_T c) { return static_cast<OUT_T>(c); });
494  }
495 
496  void allocate_buffer();
497 
498  void* get_data_ptr_nc() { return (m_data ? m_data->get_ptr() : nullptr); }
499  template <element::Type_t ET>
500  typename element_type_traits<ET>::value_type* get_data_ptr_nc()
501  {
502  NGRAPH_CHECK(ET == get_element_type(),
503  "get_data_ptr_nc() called for incorrect element type.");
504  return static_cast<typename element_type_traits<ET>::value_type*>(
505  get_data_ptr_nc());
506  }
507 
508  Constant(const OutputVector& args)
509  : Op(args)
510  , m_shape({})
511  {
512  }
513 
514  virtual void infer_element_type() {}
515  template <typename T>
516  void write_values(const std::vector<T>& values)
517  {
518  write_to_buffer(
519  m_element_type, m_shape, values, get_data_ptr_nc(), shape_size(m_shape));
520  }
521 
522  template <typename T, typename U>
523  static void write_buffer(void* target, const std::vector<U>& source, size_t count)
524  {
525  T* p = reinterpret_cast<T*>(target);
526  for (size_t i = 0; i < count; i++)
527  {
528  p[i] = static_cast<T>(source[i]);
529  }
530  }
531 
532  template <typename T>
533  static void write_to_buffer(const element::Type& target_type,
534  const Shape& /* target_shape */,
535  const std::vector<T>& source,
536  void* target,
537  size_t target_element_count)
538  {
539  if (source.size() != target_element_count)
540  {
541  throw std::runtime_error("Constant initializer does not match shape");
542  }
543 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
544 #pragma GCC diagnostic push
545 #pragma GCC diagnostic error "-Wswitch"
546 #pragma GCC diagnostic error "-Wswitch-enum"
547 #endif
548  switch (target_type)
549  {
550  case element::Type_t::boolean:
551  write_buffer<char, T>(target, source, target_element_count);
552  break;
553  case element::Type_t::bf16:
554  write_buffer<bfloat16, T>(target, source, target_element_count);
555  break;
556  case element::Type_t::f16:
557  write_buffer<float16, T>(target, source, target_element_count);
558  break;
559  case element::Type_t::f32:
560  write_buffer<float, T>(target, source, target_element_count);
561  break;
562  case element::Type_t::f64:
563  write_buffer<double, T>(target, source, target_element_count);
564  break;
565  case element::Type_t::i8:
566  write_buffer<int8_t, T>(target, source, target_element_count);
567  break;
568  case element::Type_t::i16:
569  write_buffer<int16_t, T>(target, source, target_element_count);
570  break;
571  case element::Type_t::i32:
572  write_buffer<int32_t, T>(target, source, target_element_count);
573  break;
574  case element::Type_t::i64:
575  write_buffer<int64_t, T>(target, source, target_element_count);
576  break;
577  case element::Type_t::u8:
578  write_buffer<uint8_t, T>(target, source, target_element_count);
579  break;
580  case element::Type_t::u16:
581  write_buffer<uint16_t, T>(target, source, target_element_count);
582  break;
583  case element::Type_t::u32:
584  write_buffer<uint32_t, T>(target, source, target_element_count);
585  break;
586  case element::Type_t::u64:
587  write_buffer<uint64_t, T>(target, source, target_element_count);
588  break;
589  case element::Type_t::u1: throw std::runtime_error("unsupported type");
590  case element::Type_t::undefined: throw std::runtime_error("unsupported type");
591  case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
592  }
593 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
594 #pragma GCC diagnostic pop
595 #endif
596  }
597 
598  static constexpr size_t host_alignment() { return 64; }
599  element::Type m_element_type;
600  Shape m_shape{};
601  std::shared_ptr<runtime::AlignedBuffer> m_data;
602  bool m_all_elements_bitwise_identical;
603  bool are_all_data_elements_bitwise_identical() const;
604  bool m_alloc_buffer_on_visit_attributes = true;
605  };
606  }
607  using v0::Constant;
608  }
609 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A set of axes.
Definition: axis_set.hpp:31
A vector of axes.
Definition: axis_vector.hpp:30
A difference (signed) of tensor element coordinates.
Definition: coordinate_diff.hpp:30
Coordinates for a tensor element.
Definition: coordinate.hpp:30
Shape for a tensor.
Definition: shape.hpp:31
Strides for a tensor.
Definition: strides.hpp:30
Definition: element_type.hpp:61
Base error for ngraph runtime errors.
Definition: except.hpp:28
Root of all actual ops.
Definition: op.hpp:29
Class for constants.
Definition: constant.hpp:40
std::vector< T > cast_vector() const
Return the Constant's value as a vector cast to type T.
Definition: constant.hpp:364
std::vector< std::string > get_value_strings() const
void alloc_buffer_on_visit_attributes(bool val)
Allows avoiding buffer allocation on the visit_attributes call.
Definition: constant.hpp:478
Constant(const element::Type &type, const Shape &shape, const void *data)
Constructs a tensor constant with the supplied data.
AxisSet get_axis_set_val() const
Returns the value of the constant node as an AxisSet object. Can only be used on element::i64 nodes and interprets negative values as zeros. Repeated values are allowed.
Constant(const std::shared_ptr< runtime::Tensor > &tensor)
Initialize a constant from tensor.
AxisVector get_axis_vector_val() const
Returns the value of the constant node as an AxisVector object. Can only be used on element::i64 nodes and interprets negative values as zeros.
const NodeTypeInfo & get_type_info() const override
Definition: constant.hpp:43
Constant(const element::Type &type, const Shape &shape)
Creates an uninitialized constant.
Constant(const element::Type &type, const Shape &shape, const std::vector< std::string > &values)
Constructs a tensor constant. This constructor is mainly to support deserialization of constants.
CoordinateDiff get_coordinate_diff_val() const
Returns the value of the constant node as a CoordinateDiff object. Can only be used on element::i64 nodes.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
void set_data_shape(const Shape &shape)
Update Constant shape. New shape size must equal to the data elements count.
Strides get_strides_val() const
Returns the value of the constant node as a Strides object. Can only be used on element::i64 nodes and interprets negative values as zeros.
Constant(const element::Type &type, const Shape &shape, T value)
Constructs a uniform tensor constant.
Definition: constant.hpp:96
static std::shared_ptr< op::v0::Constant > create(const element::Type &type, Shape shape, const std::vector< T > values)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:322
Constant(const element::Type &type, const Shape &shape, std::shared_ptr< runtime::SharedBuffer< T >> data)
Constructs a tensor constant with the supplied data.
Definition: constant.hpp:245
Shape get_shape_val() const
Returns the value of the constant node as a Shape object. Can only be used on element::i64 nodes and interprets negative values as zeros.
Coordinate get_coordinate_val() const
Returns the value of the constant node as a Coordinate object. Can only be used on element::i64 nodes and interprets negative values as zeros.
static std::shared_ptr< op::v0::Constant > create(const element::Type &type, Shape shape, std::initializer_list< T > values)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:336
Constant(const element::Type &type, const Shape &shape, const std::vector< T > &values)
Constructs a tensor constant.
Definition: constant.hpp:57
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Definition: constant.hpp:260
SharedBuffer class to store a pointer to a pre-allocated buffer.
Definition: shared_buffer.hpp:30
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:71
Definition: type.hpp:39
Definition: element_type_traits.hpp:25