bfloat16.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <cmath>
20 #include <iostream>
21 #include <limits>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "ngraph/ngraph_visibility.hpp"
27 
// Default rounding mode for float -> bfloat16 conversions. It is applied only when the build
// has not already selected ROUND_MODE_TO_NEAREST, ROUND_MODE_TO_NEAREST_EVEN or
// ROUND_MODE_TRUNCATE (e.g. with -D...), so the mode can be chosen without editing this header.
// NOTE(review): the converting constructor is defined inline in this header, so every
// translation unit in a program must agree on the mode.
#if !defined ROUND_MODE_TO_NEAREST && !defined ROUND_MODE_TO_NEAREST_EVEN &&                    \
    !defined ROUND_MODE_TRUNCATE
#define ROUND_MODE_TO_NEAREST_EVEN
#endif
29 
30 namespace ngraph
31 {
32  class NGRAPH_API bfloat16
33  {
34  public:
35  constexpr bfloat16()
36  : m_value{0}
37  {
38  }
39  bfloat16(float value)
40  : m_value
41  {
42 #if defined ROUND_MODE_TO_NEAREST
43  round_to_nearest(value)
44 #elif defined ROUND_MODE_TO_NEAREST_EVEN
45  round_to_nearest_even(value)
46 #elif defined ROUND_MODE_TRUNCATE
47  truncate(value)
48 #else
49 #error \
50  "ROUNDING_MODE must be one of ROUND_MODE_TO_NEAREST, ROUND_MODE_TO_NEAREST_EVEN, or ROUND_MODE_TRUNCATE"
51 #endif
52  }
53  {
54  }
55 
56  template <typename I>
57  explicit bfloat16(I value)
58  : m_value{bfloat16{static_cast<float>(value)}.m_value}
59  {
60  }
61 
62  std::string to_string() const;
63  size_t size() const;
64  template <typename T>
65  bool operator==(const T& other) const;
66  template <typename T>
67  bool operator!=(const T& other) const
68  {
69  return !(*this == other);
70  }
71  template <typename T>
72  bool operator<(const T& other) const;
73  template <typename T>
74  bool operator<=(const T& other) const;
75  template <typename T>
76  bool operator>(const T& other) const;
77  template <typename T>
78  bool operator>=(const T& other) const;
79  template <typename T>
80  bfloat16 operator+(const T& other) const;
81  template <typename T>
82  bfloat16 operator+=(const T& other);
83  template <typename T>
84  bfloat16 operator-(const T& other) const;
85  template <typename T>
86  bfloat16 operator-=(const T& other);
87  template <typename T>
88  bfloat16 operator*(const T& other) const;
89  template <typename T>
90  bfloat16 operator*=(const T& other);
91  template <typename T>
92  bfloat16 operator/(const T& other) const;
93  template <typename T>
94  bfloat16 operator/=(const T& other);
95  operator float() const;
96 
97  static std::vector<float> to_float_vector(const std::vector<bfloat16>&);
98  static std::vector<bfloat16> from_float_vector(const std::vector<float>&);
99  static constexpr bfloat16 from_bits(uint16_t bits) { return bfloat16(bits, true); }
100  uint16_t to_bits() const;
101  friend std::ostream& operator<<(std::ostream& out, const bfloat16& obj)
102  {
103  out << static_cast<float>(obj);
104  return out;
105  }
106 
107 #define cu32(x) (F32(x).i)
108 
109  static uint16_t round_to_nearest_even(float x)
110  {
111  return static_cast<uint16_t>((cu32(x) + ((cu32(x) & 0x00010000) >> 1)) >> 16);
112  }
113 
114  static uint16_t round_to_nearest(float x)
115  {
116  return static_cast<uint16_t>((cu32(x) + 0x8000) >> 16);
117  }
118 
119  static uint16_t truncate(float x) { return static_cast<uint16_t>((cu32(x)) >> 16); }
120  private:
121  constexpr bfloat16(uint16_t x, bool)
122  : m_value{x}
123  {
124  }
125  union F32 {
126  F32(float val)
127  : f{val}
128  {
129  }
130  F32(uint32_t val)
131  : i{val}
132  {
133  }
134  float f;
135  uint32_t i;
136  };
137 
138  uint16_t m_value;
139  };
140 
141  template <typename T>
142  bool bfloat16::operator==(const T& other) const
143  {
144 #if defined(__GNUC__)
145 #pragma GCC diagnostic push
146 #pragma GCC diagnostic ignored "-Wfloat-equal"
147 #endif
148  return (static_cast<float>(*this) == static_cast<float>(other));
149 #if defined(__GNUC__)
150 #pragma GCC diagnostic pop
151 #endif
152  }
153 
154  template <typename T>
155  bool bfloat16::operator<(const T& other) const
156  {
157  return (static_cast<float>(*this) < static_cast<float>(other));
158  }
159 
160  template <typename T>
161  bool bfloat16::operator<=(const T& other) const
162  {
163  return (static_cast<float>(*this) <= static_cast<float>(other));
164  }
165 
166  template <typename T>
167  bool bfloat16::operator>(const T& other) const
168  {
169  return (static_cast<float>(*this) > static_cast<float>(other));
170  }
171 
172  template <typename T>
173  bool bfloat16::operator>=(const T& other) const
174  {
175  return (static_cast<float>(*this) >= static_cast<float>(other));
176  }
177 
178  template <typename T>
179  bfloat16 bfloat16::operator+(const T& other) const
180  {
181  return {static_cast<float>(*this) + static_cast<float>(other)};
182  }
183 
184  template <typename T>
185  bfloat16 bfloat16::operator+=(const T& other)
186  {
187  return *this = *this + other;
188  }
189 
190  template <typename T>
191  bfloat16 bfloat16::operator-(const T& other) const
192  {
193  return {static_cast<float>(*this) - static_cast<float>(other)};
194  }
195 
196  template <typename T>
197  bfloat16 bfloat16::operator-=(const T& other)
198  {
199  return *this = *this - other;
200  }
201 
202  template <typename T>
203  bfloat16 bfloat16::operator*(const T& other) const
204  {
205  return {static_cast<float>(*this) * static_cast<float>(other)};
206  }
207 
208  template <typename T>
209  bfloat16 bfloat16::operator*=(const T& other)
210  {
211  return *this = *this * other;
212  }
213 
214  template <typename T>
215  bfloat16 bfloat16::operator/(const T& other) const
216  {
217  return {static_cast<float>(*this) / static_cast<float>(other)};
218  }
219 
220  template <typename T>
221  bfloat16 bfloat16::operator/=(const T& other)
222  {
223  return *this = *this / other;
224  }
225 }
226 
227 namespace std
228 {
229  template <>
230  class numeric_limits<ngraph::bfloat16>
231  {
232  public:
233  static constexpr bool is_specialized = true;
234  static constexpr ngraph::bfloat16 min() noexcept
235  {
236  return ngraph::bfloat16::from_bits(0x007F);
237  }
238  static constexpr ngraph::bfloat16 max() noexcept
239  {
240  return ngraph::bfloat16::from_bits(0x7F7F);
241  }
242  static constexpr ngraph::bfloat16 lowest() noexcept
243  {
244  return ngraph::bfloat16::from_bits(0xFF7F);
245  }
246  static constexpr int digits = 7;
247  static constexpr int digits10 = 2;
248  static constexpr bool is_signed = true;
249  static constexpr bool is_integer = false;
250  static constexpr bool is_exact = false;
251  static constexpr int radix = 2;
252  static constexpr ngraph::bfloat16 epsilon() noexcept
253  {
254  return ngraph::bfloat16::from_bits(0x3C00);
255  }
256  static constexpr ngraph::bfloat16 round_error() noexcept
257  {
258  return ngraph::bfloat16::from_bits(0x3F00);
259  }
260  static constexpr int min_exponent = -125;
261  static constexpr int min_exponent10 = -37;
262  static constexpr int max_exponent = 128;
263  static constexpr int max_exponent10 = 38;
264  static constexpr bool has_infinity = true;
265  static constexpr bool has_quiet_NaN = true;
266  static constexpr bool has_signaling_NaN = true;
267  static constexpr float_denorm_style has_denorm = denorm_absent;
268  static constexpr bool has_denorm_loss = false;
269  static constexpr ngraph::bfloat16 infinity() noexcept
270  {
271  return ngraph::bfloat16::from_bits(0x7F80);
272  }
273  static constexpr ngraph::bfloat16 quiet_NaN() noexcept
274  {
275  return ngraph::bfloat16::from_bits(0x7FC0);
276  }
277  static constexpr ngraph::bfloat16 signaling_NaN() noexcept
278  {
279  return ngraph::bfloat16::from_bits(0x7FC0);
280  }
281  static constexpr ngraph::bfloat16 denorm_min() noexcept
282  {
283  return ngraph::bfloat16::from_bits(0);
284  }
285  static constexpr bool is_iec559 = false;
286  static constexpr bool is_bounded = false;
287  static constexpr bool is_modulo = false;
288  static constexpr bool traps = false;
289  static constexpr bool tinyness_before = false;
290  static constexpr float_round_style round_style = round_to_nearest;
291  };
292 }
Definition: bfloat16.hpp:33
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
PartialShape operator+(const PartialShape &s1, const PartialShape &s2)
Elementwise addition of two PartialShape objects.