#ifndef DATATYPE_CONVERSION_H_
#define DATATYPE_CONVERSION_H_

#include <immintrin.h>

#include <bitset>
#include <climits>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#if defined(__AVX512F__)
constexpr size_t kAvx512Bits = 512;
constexpr size_t kAvx512Bytes = kAvx512Bits / 8;
constexpr size_t kAvx512FloatsPerInstr = kAvx512Bytes / sizeof(float);
constexpr size_t kAvx512FloatsPerLoop = kAvx512FloatsPerInstr * 2;
constexpr size_t kAvx512ShortsPerInstr = kAvx512Bytes / sizeof(short);
constexpr size_t kAvx512ShortsPerLoop = kAvx512ShortsPerInstr / 2;
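// Each float-to-short loop iteration consumes two full 512-bit float loads
// (kAvx512FloatsPerLoop), while each short-to-float iteration consumes only
// half a register's worth of shorts (kAvx512ShortsPerLoop): every 16-bit
// input widens to a 32-bit float, so a 256-bit load of shorts fills a full
// 512-bit store of floats.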
  for (size_t i = 0; i < n_elems; i++) {
#if defined(__AVX512F__)
#if defined(DATATYPE_MEMORY_CHECK)
  RtAssert(((n_elems % kAvx512ShortsPerLoop) == 0) &&
               ((reinterpret_cast<intptr_t>(in_buf) % kAvx512Bytes) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % kAvx512Bytes) == 0),
           "Data Alignment not correct before calling into AVX optimizations");
  const bool unaligned =
      ((reinterpret_cast<intptr_t>(in_buf) % kAvx512Bytes) > 0);
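  // Conversion uses the "magic number" trick: each unsigned 16-bit sample is
  // widened to 32 bits and XORed into the low mantissa bits of a large float
  // constant (`magic`, defined in the elided lines above; presumably
  // something like _mm512_set1_ps(float((1 << 23) + (1 << 15)) / 32768.f)).
  // Subtracting `magic` afterwards yields the signed sample as a scaled float
  // without a per-element integer-to-float conversion.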
  const __m512i magic_i = _mm512_castps_si512(magic);
  for (size_t i = 0; i < n_elems; i += kAvx512ShortsPerLoop) {
    // 256-bit load of 16 shorts; fall back to an unaligned load if needed.
    const __m256i val =
        unaligned
            ? _mm256_loadu_si256(reinterpret_cast<const __m256i*>(in_buf + i))
            : _mm256_load_si256(reinterpret_cast<const __m256i*>(in_buf + i));
    const __m512i val_unpacked = _mm512_cvtepu16_epi32(val);
    const __m512i val_f_int = _mm512_xor_si512(val_unpacked, magic_i);
    const __m512 val_f = _mm512_castsi512_ps(val_f_int);
    const __m512 converted = _mm512_sub_ps(val_f, magic);
    _mm512_store_ps(out_buf + i, converted);
  throw std::runtime_error("AVX512 is not supported");
                                           float* out_buf, size_t n_elems) {
#if defined(DATATYPE_MEMORY_CHECK)
  RtAssert(((reinterpret_cast<intptr_t>(in_buf) % kAvx2Bytes) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % kAvx2Bytes) == 0),
           "Data Alignment not correct before calling into AVX optimizations");
  const bool unaligned =
      ((reinterpret_cast<intptr_t>(in_buf) % kAvx2Bytes) > 0);
  const __m256i magic_i = _mm256_castps_si256(magic);
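  // The AVX2 path mirrors the AVX512 path above at half width: each
  // iteration loads 8 shorts (128 bits), widens them to eight 32-bit lanes,
  // and applies the same XOR/subtract conversion before storing 8 floats.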
    const __m128i val =
        unaligned
            ? _mm_loadu_si128(reinterpret_cast<const __m128i*>(in_buf + i))
            : _mm_load_si128(reinterpret_cast<const __m128i*>(in_buf + i));
    const __m256i val_unpacked = _mm256_cvtepu16_epi32(val);
    const __m256i val_f_int = _mm256_xor_si256(val_unpacked, magic_i);
    const __m256 val_f = _mm256_castsi256_ps(val_f_int);
    const __m256 converted = _mm256_sub_ps(val_f, magic);
    _mm256_store_ps(out_buf + i, converted);
#if defined(__AVX512F__)
                                            short* out_buf, size_t n_elems,
                                            size_t n_prefix,
                                            float scale_down_factor) {
#if defined(__AVX512F__)
#if defined(DATATYPE_MEMORY_CHECK)
  constexpr size_t kAvx512ShortPerInstr = kAvx512Bytes / sizeof(short);
  RtAssert(((n_elems % kAvx512FloatsPerInstr) == 0) &&
               ((n_prefix % kAvx512ShortPerInstr) == 0) &&
               ((reinterpret_cast<intptr_t>(in_buf) % kAvx512Bytes) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % kAvx512Bytes) == 0),
           "Data Alignment not correct before calling into AVX optimizations");
  const __m512 scale_factor = _mm512_set1_ps(scale_factor_float);
  const __m512i permute_index = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
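  // _mm512_packs_epi32 packs within each 128-bit lane, interleaving the two
  // source registers lane by lane, so the 64-bit lanes must be permuted back
  // into natural element order before the result is streamed out.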
  for (size_t i = 0; i < n_elems; i += kAvx512FloatsPerLoop) {
    const __m512 in1 = _mm512_load_ps(&in_buf[i]);
    const __m512 in2 = _mm512_load_ps(&in_buf[i + kAvx512FloatsPerInstr]);
    const __m512 scaled_in1 = _mm512_mul_ps(in1, scale_factor);
    const __m512 scaled_in2 = _mm512_mul_ps(in2, scale_factor);
    const __m512i int32_1 = _mm512_cvtps_epi32(scaled_in1);
    const __m512i int32_2 = _mm512_cvtps_epi32(scaled_in2);
    const __m512i short_int16 = _mm512_packs_epi32(int32_1, int32_2);
    const __m512i shuffled =
        _mm512_permutexvar_epi64(permute_index, short_int16);
    _mm512_stream_si512(reinterpret_cast<__m512i*>(&out_buf[i + n_prefix]),
                        shuffled);
    const size_t repeat_idx = n_elems - n_prefix;
    if (i >= repeat_idx) {
      _mm512_stream_si512(reinterpret_cast<__m512i*>(&out_buf[i - repeat_idx]),
                          shuffled);
    }
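    // Blocks converted in the final n_prefix-sized tail are also written to
    // the beginning of out_buf, so the output starts with a copy of its own
    // tail (e.g. a cyclic-prefix-style repetition) without a second pass.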
  unused(scale_down_factor);
  throw std::runtime_error("AVX512 is not supported");
                                        short* out_buf, size_t n_elems,
                                        size_t n_prefix,
                                        float scale_down_factor) {
#if defined(DATATYPE_MEMORY_CHECK)
  constexpr size_t kAvx2ShortPerInstr = kAvx2Bytes / sizeof(short);
  RtAssert(((n_prefix % kAvx2ShortPerInstr) == 0) &&
               ((reinterpret_cast<intptr_t>(in_buf) % kAvx2Bytes) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % kAvx2Bytes) == 0),
           "Data Alignment not correct before calling into AVX optimizations");
  const __m256 scale_factor = _mm256_set1_ps(scale_factor_float);

    const __m256 in1 = _mm256_load_ps(&in_buf[i]);
    const __m256 scaled_in1 = _mm256_mul_ps(in1, scale_factor);
    const __m256 scaled_in2 = _mm256_mul_ps(in2, scale_factor);
    const __m256i integer1 = _mm256_cvtps_epi32(scaled_in1);
    const __m256i integer2 = _mm256_cvtps_epi32(scaled_in2);
    const __m256i short_ints = _mm256_packs_epi32(integer1, integer2);
    const __m256i slice = _mm256_permute4x64_epi64(short_ints, 0xD8);
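    // 0xD8 == 0b11011000 selects the 64-bit lanes in the order [0, 2, 1, 3],
    // undoing the per-128-bit-lane interleaving of _mm256_packs_epi32 so the
    // 16 shorts end up in their original element order.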
    _mm256_stream_si256(reinterpret_cast<__m256i*>(&out_buf[i + n_prefix]),
                        slice);
    const size_t repeat_idx = n_elems - n_prefix;
    if (i >= repeat_idx) {
      _mm256_stream_si256(reinterpret_cast<__m256i*>(&out_buf[i - repeat_idx]),
                          slice);
    }
                                       size_t n_elems, size_t n_prefix = 0,
                                       float scale_down_factor = 1.0f) {
  for (size_t i = 0; i < n_elems; i++) {
    short converted_value;
    const float scaled_value =
    if (scaled_value >= SHRT_MAX) {
      converted_value = SHRT_MAX;
    } else if (scaled_value <= SHRT_MIN) {
      converted_value = SHRT_MIN;
    } else {
      converted_value = static_cast<short>(scaled_value);
    }
    out_buf[i + n_prefix] = converted_value;
  for (size_t i = 0; i < n_prefix; i++) {
    out_buf[i] = out_buf[i + n_elems];
  }
                                     size_t n_elems, size_t n_prefix = 0,
                                     float scale_down_factor = 1.0f) {
#if defined(__AVX512F__)
    const std::complex<float>* in_buf, std::complex<short>* out_buf,
    size_t n_elems, size_t n_prefix, float scale_down_factor) {
  const auto* in = reinterpret_cast<const float*>(in_buf);
  auto* out = reinterpret_cast<short*>(out_buf);
#if defined(__AVX512F__)
#if defined(DATATYPE_MEMORY_CHECK)
  RtAssert((n_elems % 2) == 0,
           "ConvertFloatTo12bitIq n_elems not multiple of 2");
  size_t index_short = 0;
  for (size_t i = 0; i < n_elems; i = i + 2) {
    out_buf[index_short] = (uint8_t)(temp_i >> 4);
    out_buf[index_short + 1] =
        ((uint8_t)(temp_i >> 12)) | ((uint8_t)(temp_q & 0xf0));
    out_buf[index_short + 2] = (uint8_t)(temp_q >> 8);
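    // Packing layout: each I/Q pair of 16-bit words (only their top 12 bits
    // are kept) is squeezed into 3 output bytes, so two 12-bit samples take
    // 24 bits instead of 32.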
    std::cout << "i: " << i << " " << std::bitset<16>(temp_i) << " "
              << std::bitset<16>(temp_q) << " => "
              << std::bitset<8>(out_buf[index_short]) << " "
              << std::bitset<8>(out_buf[index_short + 1]) << " "
              << std::bitset<8>(out_buf[index_short + 2]) << std::endl;
    std::printf("Original: %.4f+%.4fi \n", in_buf[i], in_buf[i + 1]);
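// Helper for the 12-bit IQ decoding path below: widens 16 packed unsigned
// 16-bit IQ words to 32-bit lanes and applies the same XOR/subtract
// magic-number conversion as above, storing 16 floats at out_buf.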
static inline void SimdConvert16bitIqToFloat(__m256i val, float* out_buf,
                                             __m512 magic, __m512i magic_i) {
  __m512i val_unpacked = _mm512_cvtepu16_epi32(val);
  __m512i val_f_int = _mm512_xor_si512(val_unpacked, magic_i);
  __m512 val_f = _mm512_castsi512_ps(val_f_int);
  __m512 converted = _mm512_sub_ps(val_f, magic);
  _mm512_store_ps(out_buf, converted);
#if defined(DATATYPE_MEMORY_CHECK)
  RtAssert(((reinterpret_cast<intptr_t>(in_buf) % kAvx2Bytes) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % kAvx2Bytes) == 0),
           "Convert12bitIqTo16bitIq: Data Alignment not correct before calling "
           "into AVX optimizations");
  for (size_t i = 0; i < n_elems; i += 16) {
    // The results of these two loads are discarded; they only touch the
    // upcoming bytes (effectively a software prefetch of this iteration's
    // input).
    _mm256_loadu_si256((__m256i const*)in_buf);
    _mm256_loadu_si256((__m256i const*)(in_buf + 16));
    __m256i temp_i =
        _mm256_setr_epi16(*(uint16_t*)in_buf, *(uint16_t*)(in_buf + 3),
                          *(uint16_t*)(in_buf + 6), *(uint16_t*)(in_buf + 9),
                          *(uint16_t*)(in_buf + 12), *(uint16_t*)(in_buf + 15),
                          *(uint16_t*)(in_buf + 18), *(uint16_t*)(in_buf + 21),
                          *(uint16_t*)(in_buf + 24), *(uint16_t*)(in_buf + 27),
                          *(uint16_t*)(in_buf + 30), *(uint16_t*)(in_buf + 33),
                          *(uint16_t*)(in_buf + 36), *(uint16_t*)(in_buf + 39),
                          *(uint16_t*)(in_buf + 42), *(uint16_t*)(in_buf + 45));
    __m256i mask_q = _mm256_set1_epi16(0xfff0);
    __m256i temp_q =
        _mm256_setr_epi16(*(uint16_t*)(in_buf + 1), *(uint16_t*)(in_buf + 4),
                          *(uint16_t*)(in_buf + 7), *(uint16_t*)(in_buf + 10),
                          *(uint16_t*)(in_buf + 13), *(uint16_t*)(in_buf + 16),
                          *(uint16_t*)(in_buf + 19), *(uint16_t*)(in_buf + 22),
                          *(uint16_t*)(in_buf + 25), *(uint16_t*)(in_buf + 28),
                          *(uint16_t*)(in_buf + 31), *(uint16_t*)(in_buf + 34),
                          *(uint16_t*)(in_buf + 37), *(uint16_t*)(in_buf + 40),
                          *(uint16_t*)(in_buf + 43), *(uint16_t*)(in_buf + 46));
    temp_q = _mm256_and_si256(temp_q, mask_q);
    temp_i = _mm256_slli_epi16(temp_i, 4);
    __m256i iq_0 = _mm256_unpacklo_epi16(temp_i, temp_q);
    __m256i iq_1 = _mm256_unpackhi_epi16(temp_i, temp_q);
    __m256i output_0 = _mm256_permute2f128_si256(iq_0, iq_1, 0x20);
    __m256i output_1 = _mm256_permute2f128_si256(iq_0, iq_1, 0x31);
    _mm256_store_si256((__m256i*)(out_buf + i * 2), output_0);
    _mm256_store_si256((__m256i*)(out_buf + i * 2 + 16), output_1);
    uint16_t* in_16bits_buf,  // written to below, so not declared const
#if defined(__AVX512F__)
  const __m512 magic =
      _mm512_set1_ps(float((1 << 23) + (1 << 15)) / 131072.f);
  const __m512i magic_i = _mm512_castps_si512(magic);
#else
  const __m256 magic =
      _mm256_set1_ps(float((1 << 23) + (1 << 15)) / 131072.f);
  const __m256i magic_i = _mm256_castps_si256(magic);
  for (size_t i = 0; i < n_elems / 3; i += 32) {
    __m512i temp_i =
        _mm512_set_epi16(*(uint16_t*)(in_buf + 93), *(uint16_t*)(in_buf + 90),
                         *(uint16_t*)(in_buf + 87), *(uint16_t*)(in_buf + 84),
                         *(uint16_t*)(in_buf + 81), *(uint16_t*)(in_buf + 78),
                         *(uint16_t*)(in_buf + 75), *(uint16_t*)(in_buf + 72),
                         *(uint16_t*)(in_buf + 69), *(uint16_t*)(in_buf + 66),
                         *(uint16_t*)(in_buf + 63), *(uint16_t*)(in_buf + 60),
                         *(uint16_t*)(in_buf + 57), *(uint16_t*)(in_buf + 54),
                         *(uint16_t*)(in_buf + 51), *(uint16_t*)(in_buf + 48),
                         *(uint16_t*)(in_buf + 45), *(uint16_t*)(in_buf + 42),
                         *(uint16_t*)(in_buf + 39), *(uint16_t*)(in_buf + 36),
                         *(uint16_t*)(in_buf + 33), *(uint16_t*)(in_buf + 30),
                         *(uint16_t*)(in_buf + 27), *(uint16_t*)(in_buf + 24),
                         *(uint16_t*)(in_buf + 21), *(uint16_t*)(in_buf + 18),
                         *(uint16_t*)(in_buf + 15), *(uint16_t*)(in_buf + 12),
                         *(uint16_t*)(in_buf + 9), *(uint16_t*)(in_buf + 6),
                         *(uint16_t*)(in_buf + 3), *(uint16_t*)(in_buf + 0));
    __m512i mask_q = _mm512_set1_epi16(0xfff0);
    __m512i temp_q =
        _mm512_set_epi16(*(uint16_t*)(in_buf + 94), *(uint16_t*)(in_buf + 91),
                         *(uint16_t*)(in_buf + 88), *(uint16_t*)(in_buf + 85),
                         *(uint16_t*)(in_buf + 82), *(uint16_t*)(in_buf + 79),
                         *(uint16_t*)(in_buf + 76), *(uint16_t*)(in_buf + 73),
                         *(uint16_t*)(in_buf + 70), *(uint16_t*)(in_buf + 67),
                         *(uint16_t*)(in_buf + 64), *(uint16_t*)(in_buf + 61),
                         *(uint16_t*)(in_buf + 58), *(uint16_t*)(in_buf + 55),
                         *(uint16_t*)(in_buf + 52), *(uint16_t*)(in_buf + 49),
                         *(uint16_t*)(in_buf + 46), *(uint16_t*)(in_buf + 43),
                         *(uint16_t*)(in_buf + 40), *(uint16_t*)(in_buf + 37),
                         *(uint16_t*)(in_buf + 34), *(uint16_t*)(in_buf + 31),
                         *(uint16_t*)(in_buf + 28), *(uint16_t*)(in_buf + 25),
                         *(uint16_t*)(in_buf + 22), *(uint16_t*)(in_buf + 19),
                         *(uint16_t*)(in_buf + 16), *(uint16_t*)(in_buf + 13),
                         *(uint16_t*)(in_buf + 10), *(uint16_t*)(in_buf + 7),
                         *(uint16_t*)(in_buf + 4), *(uint16_t*)(in_buf + 1));
    temp_q = _mm512_and_si512(temp_q, mask_q);
    temp_i = _mm512_slli_epi16(temp_i, 4);
    __m512i iq_0 = _mm512_unpacklo_epi16(temp_i, temp_q);
    __m512i iq_1 = _mm512_unpackhi_epi16(temp_i, temp_q);
    __m512i output_0 = _mm512_permutex2var_epi64(
        iq_0, _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0), iq_1);
    __m512i output_1 = _mm512_permutex2var_epi64(
        iq_0, _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4), iq_1);
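    // The two cross-register permutes re-gather the per-lane unpacklo/unpackhi
    // halves so that output_0 holds the IQ pairs of samples 0-15 and output_1
    // those of samples 16-31, in ascending order.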
    SimdConvert16bitIqToFloat(_mm512_extracti64x4_epi64(output_0, 0),
                              out_buf + i * 2, magic, magic_i);
    SimdConvert16bitIqToFloat(_mm512_extracti64x4_epi64(output_0, 1),
                              out_buf + i * 2 + 16, magic, magic_i);
    SimdConvert16bitIqToFloat(_mm512_extracti64x4_epi64(output_1, 0),
                              out_buf + i * 2 + 32, magic, magic_i);
    SimdConvert16bitIqToFloat(_mm512_extracti64x4_epi64(output_1, 1),
                              out_buf + i * 2 + 48, magic, magic_i);
  for (size_t i = 0; i < n_elems / 3; i += 16) {
    __m256i temp_i =
        _mm256_setr_epi16(*(uint16_t*)in_buf, *(uint16_t*)(in_buf + 3),
                          *(uint16_t*)(in_buf + 6), *(uint16_t*)(in_buf + 9),
                          *(uint16_t*)(in_buf + 12), *(uint16_t*)(in_buf + 15),
                          *(uint16_t*)(in_buf + 18), *(uint16_t*)(in_buf + 21),
                          *(uint16_t*)(in_buf + 24), *(uint16_t*)(in_buf + 27),
                          *(uint16_t*)(in_buf + 30), *(uint16_t*)(in_buf + 33),
                          *(uint16_t*)(in_buf + 36), *(uint16_t*)(in_buf + 39),
                          *(uint16_t*)(in_buf + 42), *(uint16_t*)(in_buf + 45));
    __m256i mask_q = _mm256_set1_epi16(0xfff0);
    __m256i temp_q =
        _mm256_setr_epi16(*(uint16_t*)(in_buf + 1), *(uint16_t*)(in_buf + 4),
                          *(uint16_t*)(in_buf + 7), *(uint16_t*)(in_buf + 10),
                          *(uint16_t*)(in_buf + 13), *(uint16_t*)(in_buf + 16),
                          *(uint16_t*)(in_buf + 19), *(uint16_t*)(in_buf + 22),
                          *(uint16_t*)(in_buf + 25), *(uint16_t*)(in_buf + 28),
                          *(uint16_t*)(in_buf + 31), *(uint16_t*)(in_buf + 34),
                          *(uint16_t*)(in_buf + 37), *(uint16_t*)(in_buf + 40),
                          *(uint16_t*)(in_buf + 43), *(uint16_t*)(in_buf + 46));
    temp_q = _mm256_and_si256(temp_q, mask_q);
    temp_i = _mm256_slli_epi16(temp_i, 4);
    __m256i iq_0 = _mm256_unpacklo_epi16(temp_i, temp_q);
    __m256i iq_1 = _mm256_unpackhi_epi16(temp_i, temp_q);
    __m256i output_0 = _mm256_permute2f128_si256(iq_0, iq_1, 0x20);
    __m256i output_1 = _mm256_permute2f128_si256(iq_0, iq_1, 0x31);
    _mm256_store_si256((__m256i*)(in_16bits_buf), output_0);
    _mm256_store_si256((__m256i*)(in_16bits_buf + 16), output_1);
    for (size_t j = 0; j < 2; j++) {
      __m128i val = _mm_load_si128((__m128i*)(in_16bits_buf + j * 16));
      __m128i val1 = _mm_load_si128((__m128i*)(in_16bits_buf + j * 16 + 8));
      __m256i val_unpacked = _mm256_cvtepu16_epi32(val);
      __m256i val_f_int =
          _mm256_xor_si256(val_unpacked, magic_i);
      __m256 val_f = _mm256_castsi256_ps(val_f_int);
      __m256 converted = _mm256_sub_ps(val_f, magic);
      _mm256_store_ps(out_buf + i * 2 + j * 16, converted);
      __m256i val_unpacked1 = _mm256_cvtepu16_epi32(val1);
      __m256i val_f_int1 =
          _mm256_xor_si256(val_unpacked1, magic_i);
      __m256 val_f1 = _mm256_castsi256_ps(val_f_int1);
      __m256 converted1 = _mm256_sub_ps(val_f1, magic);
      _mm256_store_ps(out_buf + i * 2 + j * 16 + 8,
                      converted1);
#if defined(DATATYPE_MEMORY_CHECK)
  RtAssert(((reinterpret_cast<intptr_t>(in_buf) % 64) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % 64) == 0),
           "SimdConvertFloat16ToFloat32: Data Alignment not correct before "
           "calling into AVX optimizations");
  for (size_t i = 0; i < n_elems; i += 16) {
    __m256i val_a = _mm256_load_si256((__m256i*)(in_buf + i / 2));
    __m512 val = _mm512_cvtph_ps(val_a);
    _mm512_store_ps(out_buf + i, val);
  for (size_t i = 0; i < n_elems; i += 8) {
    __m128i val_a = _mm_load_si128((__m128i*)(in_buf + i / 2));
    __m256 val = _mm256_cvtph_ps(val_a);
    _mm256_store_ps(out_buf + i, val);
#if defined(DATATYPE_MEMORY_CHECK)
  RtAssert(((reinterpret_cast<intptr_t>(in_buf) % 64) == 0) &&
               ((reinterpret_cast<intptr_t>(out_buf) % 64) == 0),
           "SimdConvertFloat32ToFloat16: Data Alignment not correct before "
           "calling into AVX optimizations");
  for (size_t i = 0; i < n_elems; i += 16) {
    __m512 val_a = _mm512_load_ps(in_buf + i);
    __m256i val = _mm512_cvtps_ph(val_a, _MM_FROUND_NO_EXC);
    _mm256_store_si256(reinterpret_cast<__m256i*>(out_buf + i / 2), val);
  for (size_t i = 0; i < n_elems; i += 8) {
    __m256 val_a = _mm256_load_ps(in_buf + i);
    __m128i val = _mm256_cvtps_ph(val_a, _MM_FROUND_NO_EXC);
    _mm_store_si128(reinterpret_cast<__m128i*>(out_buf + i / 2), val);
#endif  // DATATYPE_CONVERSION_H_