shape.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
#include <cstddef>
#include <cstdio>
#include <initializer_list>
#include <iosfwd>
#include <utility>
#include <vector>

#include "ngraph/attribute_adapter.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/strides.hpp"
26 
27 namespace ngraph
28 {
29  /// \brief Shape for a tensor.
30  class Shape : public std::vector<size_t>
31  {
32  public:
33  NGRAPH_API Shape();
34 
35  NGRAPH_API Shape(const std::initializer_list<size_t>& axis_lengths);
36 
37  NGRAPH_API Shape(const std::vector<size_t>& axis_lengths);
38 
39  NGRAPH_API Shape(const Shape& axis_lengths);
40 
41  NGRAPH_API explicit Shape(size_t n, size_t initial_value = 0);
42 
43  NGRAPH_API ~Shape();
44 
45  template <class InputIterator>
46  Shape(InputIterator first, InputIterator last)
47  : std::vector<size_t>(first, last)
48  {
49  }
50 
51  NGRAPH_API Shape& operator=(const Shape& v);
52  NGRAPH_API Shape& operator=(Shape&& v) noexcept;
53  };
54 
55  template <>
56  class NGRAPH_API AttributeAdapter<Shape>
57  : public IndirectVectorValueAccessor<Shape, std::vector<int64_t>>
58 
59  {
60  public:
61  AttributeAdapter(Shape& value)
63  {
64  }
65  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Shape>", 0};
66  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
67  };
68 
/// \brief Number of elements spanned by a shape.
///
/// Multiplies together every axis length in \p shape:
/// - an empty (rank-0) shape yields 1;
/// - any zero-length axis yields 0.
/// The product is accumulated in size_t with no overflow check.
///
/// \param shape Any iterable sequence of axis lengths convertible to size_t.
/// \return The element count.
template <typename SHAPE_TYPE>
size_t shape_size(const SHAPE_TYPE& shape)
{
    size_t element_count = 1;
    for (const auto& axis_length : shape)
    {
        element_count *= axis_length;
    }
    return element_count;
}
80 
/// \brief Row-major (C-order) strides for a shape.
///
/// Each entry of the result is the product of the lengths of all axes after
/// the corresponding axis, so the last entry is always 1. For example,
/// {2, 3, 4} yields {12, 4, 1}. An empty shape yields an empty vector.
///
/// \param shape Sequence of axis lengths supporting size(), rbegin() and rend().
/// \return One stride per axis, as element counts.
template <typename SHAPE_TYPE>
std::vector<size_t> row_major_strides(const SHAPE_TYPE& shape)
{
    std::vector<size_t> strides(shape.size());
    size_t running_product = 1;

    // Walk both sequences from the innermost axis outwards.
    auto dim_it = shape.rbegin();
    auto stride_it = strides.rbegin();
    while (dim_it != shape.rend() && stride_it != strides.rend())
    {
        *stride_it = running_product;
        running_product *= *dim_it;
        ++dim_it;
        ++stride_it;
    }
    return strides;
}
95 
/// \brief Row-major stride of a single axis.
///
/// Computes the product of the lengths of every axis after \p axis. For
/// axis < shape.size(), this equals row_major_strides(shape)[axis]. For the
/// last axis, or an axis at or beyond the rank, the result is 1. The one
/// exception is axis == SIZE_MAX: there axis + 1 wraps to 0, so every length
/// is multiplied in.
///
/// \param shape Indexable sequence of axis lengths.
/// \param axis  Zero-based axis index.
/// \return The stride of \p axis, as an element count.
template <typename SHAPE_TYPE>
size_t row_major_stride(const SHAPE_TYPE& shape, size_t axis)
{
    size_t stride = 1;
    for (size_t i = axis + 1; i < shape.size(); ++i)
    {
        stride *= shape[i];
    }
    return stride;
}
106 
/// \brief Returns true when \p shape has no axes (rank 0), i.e. describes a scalar.
template <typename SHAPE_TYPE>
inline bool is_scalar(const SHAPE_TYPE& shape)
{
    return shape.size() == 0;
}
112 
/// \brief Returns true when \p shape has exactly one axis (rank 1), i.e. describes a vector.
template <typename SHAPE_TYPE>
inline bool is_vector(const SHAPE_TYPE& shape)
{
    return shape.size() == 1;
}
118 
119  NGRAPH_API
120  std::ostream& operator<<(std::ostream& s, const Shape& shape);
121 }
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Definition: attribute_adapter.hpp:137
Shape for a tensor.
Definition: shape.hpp:31
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:71
std::vector< size_t > row_major_strides(const SHAPE_TYPE &shape)
Row-major strides for a shape.
Definition: shape.hpp:83
Definition: type.hpp:39