// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// 512-bit AVX512 vectors and operations.
// External include guard in highway.h - see comment there.
// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.
// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL
#include "hwy/base.h"
// Avoid uninitialized warnings in GCC's avx512fintrin.h - see
// https://github.com/google/highway/issues/710)
HWY_DIAGNOSTICS(push)
#if HWY_COMPILER_GCC_ACTUAL
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored
"-Wuninitialized")
HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494,
ignored
"-Wmaybe-uninitialized")
#endif
#include <immintrin.h>
// AVX2+
#if HWY_COMPILER_CLANGCL
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
// clang-format off
#include <smmintrin.h>
#include <avxintrin.h>
// avxintrin defines __m256i and must come before avx2intrin.
#include <avx2intrin.h>
#include <f16cintrin.h>
#include <fmaintrin.h>
#include <avx512fintrin.h>
#include <avx512vlintrin.h>
#include <avx512bwintrin.h>
#include <avx512vlbwintrin.h>
#include <avx512dqintrin.h>
#include <avx512vldqintrin.h>
#include <avx512cdintrin.h>
#include <avx512vlcdintrin.h>
#if HWY_TARGET <= HWY_AVX3_DL
#include <avx512bitalgintrin.h>
#include <avx512vlbitalgintrin.h>
#include <avx512vbmiintrin.h>
#include <avx512vbmivlintrin.h>
#include <avx512vbmi2intrin.h>
#include <avx512vlvbmi2intrin.h>
#include <avx512vpopcntdqintrin.h>
#include <avx512vpopcntdqvlintrin.h>
#include <avx512vnniintrin.h>
#include <avx512vlvnniintrin.h>
// Must come after avx512fintrin, else will not define 512-bit intrinsics.
#include <vaesintrin.h>
#include <vpclmulqdqintrin.h>
#include <gfniintrin.h>
#endif // HWY_TARGET <= HWY_AVX3_DL
#if HWY_TARGET <= HWY_AVX3_SPR
#include <avx512fp16intrin.h>
#include <avx512vlfp16intrin.h>
#endif // HWY_TARGET <= HWY_AVX3_SPR
// clang-format on
#endif // HWY_COMPILER_CLANGCL
// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_256-inl.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {

// Maps a lane type T to the native 512-bit register type.
template <typename T>
struct Raw512 {
  using type = __m512i;
};
#if HWY_HAVE_FLOAT16
template <>
struct Raw512<float16_t> {
  using type = __m512h;
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct Raw512<float> {
  using type = __m512;
};
template <>
struct Raw512<double> {
  using type = __m512d;
};

// Maps sizeof(lane type) to the AVX-512 mask register type (one bit per lane).
// Template arg: sizeof(lane type)
template <size_t size>
struct RawMask512 {};
template <>
struct RawMask512<1> {
  using type = __mmask64;
};
template <>
struct RawMask512<2> {
  using type = __mmask32;
};
template <>
struct RawMask512<4> {
  using type = __mmask16;
};
template <>
struct RawMask512<8> {
  using type = __mmask8;
};

}  // namespace detail
// Wrapper over the raw 512-bit register; lane type T determines the Raw type.
template <typename T>
class Vec512 {
  using Raw = typename detail::Raw512<T>::type;

 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 64 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec512& operator*=(const Vec512 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec512& operator/=(const Vec512 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec512& operator+=(const Vec512 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec512& operator-=(const Vec512 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec512& operator%=(const Vec512 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec512& operator&=(const Vec512 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec512& operator|=(const Vec512 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec512& operator^=(const Vec512 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};
// Mask register: one bit per lane.
// Mask register wrapper: one bit per lane, type chosen by lane size.
template <typename T>
struct Mask512 {
  using Raw = typename detail::RawMask512<sizeof(T)>::type;
  Raw raw;
};
// Descriptor (tag) for a full 512-bit vector of lane type T.
template <typename T>
using Full512 = Simd<T, 64 / sizeof(T), 0>;
// ------------------------------ BitCast
namespace detail {

// Reinterprets any raw 512-bit register as __m512i (no-op for integers).
HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; }
#if HWY_HAVE_FLOAT16
HWY_INLINE __m512i BitCastToInteger(__m512h v) {
  return _mm512_castph_si512(v);
}
#endif  // HWY_HAVE_FLOAT16
HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); }
HWY_INLINE __m512i BitCastToInteger(__m512d v) {
  return _mm512_castpd_si512(v);
}

// Views any 512-bit vector as a vector of bytes.
template <typename T>
HWY_INLINE Vec512<uint8_t> BitCastToByte(Vec512<T> v) {
  return Vec512<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template <typename T>
struct BitCastFromInteger512 {
  HWY_INLINE __m512i operator()(__m512i v) { return v; }
};
#if HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger512<float16_t> {
  HWY_INLINE __m512h operator()(__m512i v) { return _mm512_castsi512_ph(v); }
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger512<float> {
  HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); }
};
template <>
struct BitCastFromInteger512<double> {
  HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); }
};

// Converts a byte vector back to the lane type selected by D.
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec512<uint8_t> v) {
  return VFromD<D>{BitCastFromInteger512<TFromD<D>>()(v.raw)};
}

}  // namespace detail
// Reinterprets the bits of a 512-bit vector as the lane type selected by d.
template <class D, HWY_IF_V_SIZE_D(D, 64), typename FromT>
HWY_API VFromD<D> BitCast(D d, Vec512<FromT> v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
// ------------------------------ Set
// Broadcasts a single 8-bit lane to the whole vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi8(static_cast<char>(t))};  // NOLINT
}
// Broadcasts a single 16-bit lane to the whole vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi16(static_cast<short>(t))};  // NOLINT
}
// Broadcasts a single 32-bit lane to the whole vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi32(static_cast<int>(t))};
}
// Broadcasts a single 64-bit lane to the whole vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi64(static_cast<long long>(t))};  // NOLINT
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
// Broadcasts a single f16 lane; only available with native float16 support.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Set(D /* tag */, float16_t t) {
  return Vec512<float16_t>{_mm512_set1_ph(t)};
}
#endif  // HWY_HAVE_FLOAT16
// Broadcasts a single f32 lane to the whole vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Set(D /* tag */, float t) {
  return Vec512<float>{_mm512_set1_ps(t)};
}
// Broadcasts a single f64 lane to the whole vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Set(D /* tag */, double t) {
  return Vec512<double>{_mm512_set1_pd(t)};
}
// ------------------------------ Zero (Set)
// GCC pre-9.1 lacked setzero, so use Set instead.
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
// Cannot use VFromD here because it is defined in terms of Zero.
// GCC < 9.1 fallback: no setzero intrinsics, so broadcast a zero lane instead.
// Cannot use VFromD here because it is defined in terms of Zero.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_SPECIAL_FLOAT_D(D)>
HWY_API Vec512<TFromD<D>> Zero(D d) {
  return Set(d, TFromD<D>{0});
}
// BitCast is defined below, but the Raw type is the same, so use that.
// GCC < 9.1 bf16 fallback. BitCast is defined below, but the Raw type is the
// same, so reuse the unsigned Set's raw register directly.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Zero(D /* tag */) {
  const RebindToUnsigned<D> du;
  return Vec512<bfloat16_t>{Set(du, 0).raw};
}
// GCC < 9.1 f16 fallback: zero via the same-width unsigned type's raw register.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Zero(D /* tag */) {
  const RebindToUnsigned<D> du;
  return Vec512<float16_t>{Set(du, 0).raw};
}
#else
// All-zero integer vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec512<TFromD<D>> Zero(D /* tag */) {
  return Vec512<TFromD<D>>{_mm512_setzero_si512()};
}
// All-zero bf16 vector (bf16 is stored in an integer register).
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Zero(D /* tag */) {
  return Vec512<bfloat16_t>{_mm512_setzero_si512()};
}
// All-zero f16 vector: native setzero_ph when supported, else integer zero.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Zero(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec512<float16_t>{_mm512_setzero_ph()};
#else
  return Vec512<float16_t>{_mm512_setzero_si512()};
#endif
}
// All-zero f32 vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Zero(D /* tag */) {
  return Vec512<float>{_mm512_setzero_ps()};
}
// All-zero f64 vector.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Zero(D /* tag */) {
  return Vec512<double>{_mm512_setzero_pd()};
}
#endif // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
// ------------------------------ Undefined
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored
"-Wuninitialized")
// Returns a vector with uninitialized elements.
template <
class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec512<TFromD<D>> Undefined(D
/* tag */) {
// Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC
// generate an XOR instruction.
return Vec512<TFromD<D>>{_mm512_undefined_epi32()};
}
template <
class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Undefined(D
/* tag */) {
return Vec512<bfloat16_t>{_mm512_undefined_epi32()};
}
template <
class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Undefined(D
/* tag */) {
#if HWY_HAVE_FLOAT16
return Vec512<float16_t>{_mm512_undefined_ph()};
#else
return Vec512<float16_t>{_mm512_undefined_epi32()};
#endif
}
template <
class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<
float> Undefined(D
/* tag */) {
return Vec512<
float>{_mm512_undefined_ps()};
}
template <
class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<
double> Undefined(D
/* tag */) {
return Vec512<
double>{_mm512_undefined_pd()};
}
HWY_DIAGNOSTICS(pop)
// ------------------------------ ResizeBitCast
// 64-byte vector to 16-byte vector
// Truncating bitcast: keeps the lower 128 bits of a 512-bit vector.
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 64),
          HWY_IF_V_SIZE_D(D, 16)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const Full512<uint8_t> du8_full;
  return BitCast(
      d, Vec128<uint8_t>{_mm512_castsi512_si128(BitCast(du8_full, v).raw)});
}
// <= 16-byte vector to 64-byte vector
// Widening bitcast from <= 16-byte vectors; upper bits are unspecified.
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const Full128<uint8_t> du8_128;
  return BitCast(d, Vec512<uint8_t>{_mm512_castsi128_si512(
                        ResizeBitCast(du8_128, v).raw)});
}
// 32-byte vector to 64-byte vector
// Widening bitcast from 32-byte vectors; upper bits are unspecified.
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const Full256<uint8_t> du8_256;
  return BitCast(
      d, Vec512<uint8_t>{_mm512_castsi256_si512(BitCast(du8_256, v).raw)});
}
// ------------------------------ Dup128VecFromValues
// Fills each 128-bit block with the same 16 byte lanes t0..t15.
template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
                                      TFromD<D> t11, TFromD<D> t12,
                                      TFromD<D> t13, TFromD<D> t14,
                                      TFromD<D> t15) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // GCC < 9 lacks _mm512_set_epi8: build one 128-bit block and broadcast it.
  return BroadcastBlock<0>(ResizeBitCast(
      d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3, t4, t5, t6,
                             t7, t8, t9, t10, t11, t12, t13, t14, t15)));
#else
  (void)d;
  // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic
  // available. set_epi8 lists lanes from most- to least-significant, so each
  // 16-byte block is written t15..t0, repeated four times.
  return VFromD<D>{_mm512_set_epi8(
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0),
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0),
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0),
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0))};
#endif
}
// Fills each 128-bit block with the same 8 16-bit lanes t0..t7.
template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // GCC < 9 lacks _mm512_set_epi16: build one 128-bit block and broadcast it.
  return BroadcastBlock<0>(
      ResizeBitCast(d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3,
                                           t4, t5, t6, t7)));
#else
  (void)d;
  // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic
  // available. Lanes are listed most- to least-significant: t7..t0, four times.
  return VFromD<D>{_mm512_set_epi16(
      static_cast<int16_t>(t7), static_cast<int16_t>(t6),
      static_cast<int16_t>(t5), static_cast<int16_t>(t4),
      static_cast<int16_t>(t3), static_cast<int16_t>(t2),
      static_cast<int16_t>(t1), static_cast<int16_t>(t0),
      static_cast<int16_t>(t7), static_cast<int16_t>(t6),
      static_cast<int16_t>(t5), static_cast<int16_t>(t4),
      static_cast<int16_t>(t3), static_cast<int16_t>(t2),
      static_cast<int16_t>(t1), static_cast<int16_t>(t0),
      static_cast<int16_t>(t7), static_cast<int16_t>(t6),
      static_cast<int16_t>(t5), static_cast<int16_t>(t4),
      static_cast<int16_t>(t3), static_cast<int16_t>(t2),
      static_cast<int16_t>(t1), static_cast<int16_t>(t0),
      static_cast<int16_t>(t7), static_cast<int16_t>(t6),
      static_cast<int16_t>(t5), static_cast<int16_t>(t4),
      static_cast<int16_t>(t3), static_cast<int16_t>(t2),
      static_cast<int16_t>(t1), static_cast<int16_t>(t0))};
#endif
}
#if HWY_HAVE_FLOAT16
// Fills each 128-bit block with the same 8 f16 lanes; setr lists lanes from
// least- to most-significant, so t0..t7 repeats in order.
template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2,
                                  t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5,
                                  t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)};
}
#endif
// Fills each 128-bit block with the same 4 32-bit lanes t0..t3.
template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm512_setr_epi32(
      static_cast<int32_t>(t0), static_cast<int32_t>(t1),
      static_cast<int32_t>(t2), static_cast<int32_t>(t3),
      static_cast<int32_t>(t0), static_cast<int32_t>(t1),
      static_cast<int32_t>(t2), static_cast<int32_t>(t3),
      static_cast<int32_t>(t0), static_cast<int32_t>(t1),
      static_cast<int32_t>(t2), static_cast<int32_t>(t3),
      static_cast<int32_t>(t0), static_cast<int32_t>(t1),
      static_cast<int32_t>(t2), static_cast<int32_t>(t3))};
}
// Fills each 128-bit block with the same 4 f32 lanes t0..t3.
template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2,
                                  t3, t0, t1, t2, t3)};
}
// Fills each 128-bit block with the same 2 64-bit lanes t0, t1.
template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{_mm512_setr_epi64(
      static_cast<int64_t>(t0), static_cast<int64_t>(t1),
      static_cast<int64_t>(t0), static_cast<int64_t>(t1),
      static_cast<int64_t>(t0), static_cast<int64_t>(t1),
      static_cast<int64_t>(t0), static_cast<int64_t>(t1))};
}
// Fills each 128-bit block with the same 2 f64 lanes t0, t1.
template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)};
}
// ----------------------------- Iota
namespace detail {

// Iota0: returns a vector whose lane i holds the value i.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE VFromD<D> Iota0(D d) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // GCC < 9 lacks set_epi8/16; load from a constant table instead.
  alignas(64) static constexpr TFromD<D> kIota[64] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
      32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
      48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
  return Load(d, kIota);
#else
  (void)d;
  // set_epi8 lists lanes from most- to least-significant: 63 down to 0.
  return VFromD<D>{_mm512_set_epi8(
      static_cast<char>(63), static_cast<char>(62), static_cast<char>(61),
      static_cast<char>(60), static_cast<char>(59), static_cast<char>(58),
      static_cast<char>(57), static_cast<char>(56), static_cast<char>(55),
      static_cast<char>(54), static_cast<char>(53), static_cast<char>(52),
      static_cast<char>(51), static_cast<char>(50), static_cast<char>(49),
      static_cast<char>(48), static_cast<char>(47), static_cast<char>(46),
      static_cast<char>(45), static_cast<char>(44), static_cast<char>(43),
      static_cast<char>(42), static_cast<char>(41), static_cast<char>(40),
      static_cast<char>(39), static_cast<char>(38), static_cast<char>(37),
      static_cast<char>(36), static_cast<char>(35), static_cast<char>(34),
      static_cast<char>(33), static_cast<char>(32), static_cast<char>(31),
      static_cast<char>(30), static_cast<char>(29), static_cast<char>(28),
      static_cast<char>(27), static_cast<char>(26), static_cast<char>(25),
      static_cast<char>(24), static_cast<char>(23), static_cast<char>(22),
      static_cast<char>(21), static_cast<char>(20), static_cast<char>(19),
      static_cast<char>(18), static_cast<char>(17), static_cast<char>(16),
      static_cast<char>(15), static_cast<char>(14), static_cast<char>(13),
      static_cast<char>(12), static_cast<char>(11), static_cast<char>(10),
      static_cast<char>(9), static_cast<char>(8), static_cast<char>(7),
      static_cast<char>(6), static_cast<char>(5), static_cast<char>(4),
      static_cast<char>(3), static_cast<char>(2), static_cast<char>(1),
      static_cast<char>(0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)>
HWY_INLINE VFromD<D> Iota0(D d) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // GCC < 9 lacks set_epi8/16; load from a constant table instead.
  alignas(64) static constexpr TFromD<D> kIota[32] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
  return Load(d, kIota);
#else
  (void)d;
  return VFromD<D>{_mm512_set_epi16(
      int16_t{31}, int16_t{30}, int16_t{29}, int16_t{28}, int16_t{27},
      int16_t{26}, int16_t{25}, int16_t{24}, int16_t{23}, int16_t{22},
      int16_t{21}, int16_t{20}, int16_t{19}, int16_t{18}, int16_t{17},
      int16_t{16}, int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12},
      int16_t{11}, int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6},
      int16_t{5}, int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})};
#endif
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_ph(
      float16_t{31}, float16_t{30}, float16_t{29}, float16_t{28}, float16_t{27},
      float16_t{26}, float16_t{25}, float16_t{24}, float16_t{23}, float16_t{22},
      float16_t{21}, float16_t{20}, float16_t{19}, float16_t{18}, float16_t{17},
      float16_t{16}, float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12},
      float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, float16_t{7},
      float16_t{6}, float16_t{5}, float16_t{4}, float16_t{3}, float16_t{2},
      float16_t{1}, float16_t{0})};
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_epi32(
      int32_t{15}, int32_t{14}, int32_t{13}, int32_t{12}, int32_t{11},
      int32_t{10}, int32_t{9}, int32_t{8}, int32_t{7}, int32_t{6}, int32_t{5},
      int32_t{4}, int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_epi64(int64_t{7}, int64_t{6}, int64_t{5},
                                    int64_t{4}, int64_t{3}, int64_t{2},
                                    int64_t{1}, int64_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
                                 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
                                 0.0f)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)};
}

}  // namespace detail
// Returns first, first + 1, ... in ascending lane order.
template <class D, typename T2, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Iota(D d, const T2 first) {
  return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first));
}
// ================================================== LOGICAL
// ------------------------------ Not
// Bitwise NOT via ternary logic (truth table 0x55 == ~a).
template <typename T>
HWY_API Vec512<T> Not(const Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i vu = BitCast(du, v).raw;
  return BitCast(d, VU{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)});
}
// ------------------------------ And
// Bitwise AND; goes through the unsigned type so float16_t also works.
template <typename T>
HWY_API Vec512<T> And(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_and_si512(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}
// Bitwise AND, f32 overload (uses the float-domain intrinsic).
HWY_API Vec512<float> And(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_and_ps(a.raw, b.raw)};
}
// Bitwise AND, f64 overload (uses the double-domain intrinsic).
HWY_API Vec512<double> And(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_and_pd(a.raw, b.raw)};
}
// ------------------------------ AndNot
// Returns ~not_mask & mask.
// Returns ~not_mask & mask.
template <typename T>
HWY_API Vec512<T> AndNot(const Vec512<T> not_mask, const Vec512<T> mask) {
  const DFromV<decltype(mask)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_andnot_si512(
                        BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
}
// Returns ~not_mask & mask, f32 overload.
HWY_API Vec512<float> AndNot(const Vec512<float> not_mask,
                             const Vec512<float> mask) {
  return Vec512<float>{_mm512_andnot_ps(not_mask.raw, mask.raw)};
}
// Returns ~not_mask & mask, f64 overload.
HWY_API Vec512<double> AndNot(const Vec512<double> not_mask,
                              const Vec512<double> mask) {
  return Vec512<double>{_mm512_andnot_pd(not_mask.raw, mask.raw)};
}
// ------------------------------ Or
// Bitwise OR; goes through the unsigned type so float16_t also works.
template <typename T>
HWY_API Vec512<T> Or(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_or_si512(BitCast(du, a).raw,
                                                         BitCast(du, b).raw)});
}
// Bitwise OR, f32 overload.
HWY_API Vec512<float> Or(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_or_ps(a.raw, b.raw)};
}
// Bitwise OR, f64 overload.
HWY_API Vec512<double> Or(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_or_pd(a.raw, b.raw)};
}
// ------------------------------ Xor
// Bitwise XOR; goes through the unsigned type so float16_t also works.
template <typename T>
HWY_API Vec512<T> Xor(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_xor_si512(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}
// Bitwise XOR, f32 overload.
HWY_API Vec512<float> Xor(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_xor_ps(a.raw, b.raw)};
}
// Bitwise XOR, f64 overload.
HWY_API Vec512<double> Xor(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)};
}
// ------------------------------ Xor3
// x1 ^ x2 ^ x3 in a single instruction (ternary-logic table 0x96 = 3-way XOR).
template <typename T>
HWY_API Vec512<T> Xor3(Vec512<T> x1, Vec512<T> x2, Vec512<T> x3) {
  const DFromV<decltype(x1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i bits = _mm512_ternarylogic_epi64(
      BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
  return BitCast(d, VU{bits});
}
// ------------------------------ Or3
// o1 | o2 | o3 in a single instruction (ternary-logic table 0xFE = 3-way OR).
template <typename T>
HWY_API Vec512<T> Or3(Vec512<T> o1, Vec512<T> o2, Vec512<T> o3) {
  const DFromV<decltype(o1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i bits = _mm512_ternarylogic_epi64(
      BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
  return BitCast(d, VU{bits});
}
// ------------------------------ OrAnd
// o | (a1 & a2) in a single instruction (ternary-logic table 0xF8).
template <typename T>
HWY_API Vec512<T> OrAnd(Vec512<T> o, Vec512<T> a1, Vec512<T> a2) {
  const DFromV<decltype(o)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i bits = _mm512_ternarylogic_epi64(
      BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
  return BitCast(d, VU{bits});
}
// ------------------------------ IfVecThenElse
// Bitwise select: (mask & yes) | (~mask & no), ternary-logic table 0xCA.
template <typename T>
HWY_API Vec512<T> IfVecThenElse(Vec512<T> mask, Vec512<T> yes, Vec512<T> no) {
  const DFromV<decltype(yes)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw,
                                                 BitCast(du, yes).raw,
                                                 BitCast(du, no).raw, 0xCA)});
}
// ------------------------------ Operator overloads (internal-only if float)
// Operator overload forwarding to And (internal-only if float).
template <typename T>
HWY_API Vec512<T> operator&(const Vec512<T> a, const Vec512<T> b) {
  return And(a, b);
}
// Operator overload forwarding to Or (internal-only if float).
template <typename T>
HWY_API Vec512<T> operator|(const Vec512<T> a, const Vec512<T> b) {
  return Or(a, b);
}
// Operator overload forwarding to Xor (internal-only if float).
template <typename T>
HWY_API Vec512<T> operator^(const Vec512<T> a, const Vec512<T> b) {
  return Xor(a, b);
}
// ------------------------------ PopulationCount
// 8/16 require BITALG, 32/64 require VPOPCNTDQ.
#if HWY_TARGET <= HWY_AVX3_DL
#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif
namespace detail {

// Per-lane-size popcount implementations, dispatched via hwy::SizeTag.
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi8(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi16(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi32(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi64(v.raw)};
}

}  // namespace detail
// Per-lane popcount; dispatches on lane size to the detail implementations.
template <typename T>
HWY_API Vec512<T> PopulationCount(Vec512<T> v) {
  return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
}
#endif // HWY_TARGET <= HWY_AVX3_DL
// ================================================== MASK
// ------------------------------ FirstN
// Possibilities for constructing a bitmask of N ones:
// - kshift* only consider the lowest byte of the shift count, so they would
// not correctly handle large n.
// - Scalar shifts >= 64 are UB.
// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However,
// we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds.
#if HWY_ARCH_X86_32
namespace detail {

// 32 bit mask is sufficient for lane size >= 2.
template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  Mask512<T> m;
  const uint32_t all = ~uint32_t{0};
  // BZHI only looks at the lower 8 bits of n, but it has been clamped to
  // MaxLanes, which is at most 32.
  m.raw = static_cast<decltype(m.raw)>(_bzhi_u32(all, n));
  return m;
}

#if HWY_COMPILER_MSVC >= 1920 || HWY_COMPILER_GCC_ACTUAL >= 900 || \
    HWY_COMPILER_CLANG || HWY_COMPILER_ICC
// Lane size 1 needs a 64-bit mask, which 32-bit builds assemble from two
// 32-bit halves (lo_mask = bits 0..31, hi_mask = bits 32..63).
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  uint32_t lo_mask;
  uint32_t hi_mask;
  uint32_t hi_mask_len;
#if HWY_COMPILER_GCC
  // Fast path when the compiler can prove n >= 32 (or >= 64) at compile time.
  if (__builtin_constant_p(n >= 32) && n >= 32) {
    if (__builtin_constant_p(n >= 64) && n >= 64) {
      hi_mask_len = 32u;
    } else {
      hi_mask_len = static_cast<uint32_t>(n) - 32u;
    }
    lo_mask = hi_mask = 0xFFFFFFFFu;
  } else  // NOLINT(readability/braces)
#endif
  {
    const uint32_t lo_mask_len = static_cast<uint32_t>(n);
    lo_mask = _bzhi_u32(0xFFFFFFFFu, lo_mask_len);
#if HWY_COMPILER_GCC
    if (__builtin_constant_p(lo_mask_len <= 32) && lo_mask_len <= 32) {
      return Mask512<T>{static_cast<__mmask64>(lo_mask)};
    }
#endif
    // Sets hi_mask_len from (n - 32); the borrow makes hi_mask all-ones
    // when n >= 32 and zero otherwise.
    _addcarry_u32(_subborrow_u32(0, lo_mask_len, 32u, &hi_mask_len),
                  0xFFFFFFFFu, 0u, &hi_mask);
  }
  hi_mask = _bzhi_u32(hi_mask, hi_mask_len);
#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC
  if (__builtin_constant_p((static_cast<uint64_t>(hi_mask) << 32) | lo_mask))
#endif
    return Mask512<T>{static_cast<__mmask64>(
        (static_cast<uint64_t>(hi_mask) << 32) | lo_mask)};
#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC
  else
    // Non-constant case: kunpackd merges the two 32-bit halves in a mask reg.
    return Mask512<T>{_mm512_kunpackd(static_cast<__mmask64>(hi_mask),
                                      static_cast<__mmask64>(lo_mask))};
#endif
}
#else   // HWY_COMPILER..
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0};
  return Mask512<T>{static_cast<__mmask64>(bits)};
}
#endif  // HWY_COMPILER..

}  // namespace detail
#endif // HWY_ARCH_X86_32
// Mask with the first n lanes set (n clamped to the lane count).
template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API MFromD<D> FirstN(D d, size_t n) {
  // Clamping keeps n <= 64, as required by bzhi, which only looks at the
  // lower 8 bits of its count operand.
  n = HWY_MIN(n, MaxLanes(d));
#if HWY_ARCH_X86_64
  MFromD<D> m;
  const uint64_t all = ~uint64_t{0};
  m.raw = static_cast<decltype(m.raw)>(_bzhi_u64(all, n));
  return m;
#else
  return detail::FirstN<TFromD<D>>(n);
#endif  // HWY_ARCH_X86_64
}
// ------------------------------ IfThenElse
// Returns mask ? b : a.
namespace detail {

// Returns mask ? yes : no, per lane; one overload per lane size.
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi8(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi16(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi32(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi64(mask.raw, no.raw, yes.raw)};
}

}  // namespace detail
// Returns mask ? yes : no, per lane; dispatches on lane size.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes,
                             const Vec512<T> no) {
  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
}
#if HWY_HAVE_FLOAT16
// Returns mask ? yes : no, per lane; native f16 blend.
HWY_API Vec512<float16_t> IfThenElse(Mask512<float16_t> mask,
                                     Vec512<float16_t> yes,
                                     Vec512<float16_t> no) {
  return Vec512<float16_t>{_mm512_mask_blend_ph(mask.raw, no.raw, yes.raw)};
}
#endif  // HWY_HAVE_FLOAT16
// float/double blends: lanes with a set mask bit take `yes`, others `no`.
HWY_API Vec512<float> IfThenElse(Mask512<float> mask, Vec512<float> yes,
                                 Vec512<float> no) {
  return Vec512<float>{_mm512_mask_blend_ps(mask.raw, no.raw, yes.raw)};
}

HWY_API Vec512<double> IfThenElse(Mask512<double> mask, Vec512<double> yes,
                                  Vec512<double> no) {
  return Vec512<double>{_mm512_mask_blend_pd(mask.raw, no.raw, yes.raw)};
}
namespace detail {

// Lane-size-tag overloads: masked move copies `yes` where the mask bit is
// set and zeroes the remaining lanes.

template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)};
}

template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)};
}

template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)};
}

template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)};
}

}  // namespace detail
// Returns mask ? yes : 0.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) {
  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
}

HWY_API Vec512<float> IfThenElseZero(Mask512<float> mask, Vec512<float> yes) {
  return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)};
}

HWY_API Vec512<double> IfThenElseZero(Mask512<double> mask,
                                      Vec512<double> yes) {
  return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)};
}
namespace detail {

// Lane-size-tag overloads: zero the lanes whose mask bit is set, keep `no`
// elsewhere.

template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> no) {
  // There is no masked xor_epi8/16; masked sub (no - no == 0) is just as fast.
  return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
}

template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
}

template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> no) {
  // Masked xor of a value with itself yields zero in the selected lanes.
  return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
}

template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
}

}  // namespace detail
// Returns mask ? 0 : no.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) {
  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
}

HWY_API Vec512<float> IfThenZeroElse(Mask512<float> mask, Vec512<float> no) {
  // Masked xor of `no` with itself zeroes the selected lanes.
  return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
}

HWY_API Vec512<double> IfThenZeroElse(Mask512<double> mask, Vec512<double> no) {
  return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
}
// Returns yes in lanes where v is negative, otherwise no.
template <typename T>
HWY_API Vec512<T> IfNegativeThenElse(Vec512<T> v, Vec512<T> yes, Vec512<T> no) {
  static_assert(IsSigned<T>(), "Only works for signed/float");
  // AVX3 MaskFromVec only looks at the MSB, which is the sign bit.
  return IfThenElse(MaskFromVec(v), yes, no);
}
// Negates v in lanes where `mask` is negative; other lanes keep v.
template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
HWY_API Vec512<T> IfNegativeThenNegOrUndefIfZero(Vec512<T> mask, Vec512<T> v) {
  // AVX3 MaskFromVec only looks at the MSB. Selected lanes become 0 - v.
  const DFromV<decltype(v)> d;
  return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v);
}
// Returns 0 in lanes where v is negative, otherwise v.
template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec512<T> ZeroIfNegative(const Vec512<T> v) {
  // AVX3 MaskFromVec only looks at the MSB (the sign bit).
  return IfThenZeroElse(MaskFromVec(v), v);
}
// ================================================== ARITHMETIC
// ------------------------------ Addition
// Unsigned
HWY_API Vec512<uint8_t> operator+(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator+(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator+(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator+(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator+(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator+(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator+(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator+(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> operator+(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_add_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> operator+(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_add_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator+(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_add_pd(a.raw, b.raw)};
}
// ------------------------------ Subtraction
// Unsigned
HWY_API Vec512<uint8_t> operator-(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator-(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator-(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator-(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator-(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator-(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator-(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator-(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> operator-(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_sub_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> operator-(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator-(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)};
}
// ------------------------------ SumsOf8
// Sums each group of 8 consecutive u8 lanes into a u64 lane.
HWY_API Vec512<uint64_t> SumsOf8(const Vec512<uint8_t> v) {
  const Full512<uint8_t> d;
  // SAD against an all-zero vector is just the sum of each 8-byte group.
  return Vec512<uint64_t>{_mm512_sad_epu8(v.raw, Zero(d).raw)};
}

// Sums of absolute differences of 8-byte groups of a and b, as u64 lanes.
HWY_API Vec512<uint64_t> SumsOf8AbsDiff(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint64_t>{_mm512_sad_epu8(a.raw, b.raw)};
}
// ------------------------------ SumsOf4
namespace detail {

// U8->U32 SumsOf4: sums each group of 4 consecutive u8 lanes into a u32 lane.
HWY_INLINE Vec512<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/,
                                    hwy::SizeTag<1> /*lane_size_tag*/,
                                    Vec512<uint8_t> v) {
  const DFromV<decltype(v)> d;
  // dbsad against a zero vector yields plain sums of 4 consecutive u8 lanes.
  // The maskz form (mask 0x55555555) zeroes the odd u16 lanes; the desired
  // sums already sit in the even u16 lanes, which thus directly form the u32
  // result lanes.
  return Vec512<uint32_t>{_mm512_maskz_dbsad_epu8(
      static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)};
}

// I8->I32 SumsOf4
// Generic for all vector lengths
template <class V>
HWY_INLINE VFromD<RepartitionToWideX2<DFromV<V>>> SumsOf4(
    hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const RepartitionToWideX2<decltype(d)> di32;
  // Map the i8 domain [-128, 127] to the u8 domain [0, 255] by adding 128,
  // which is equivalent to XOR with the sign bit, then reuse the unsigned
  // path on the bitcast-to-u8 result.
  const auto v_adj = BitCast(du, Xor(v, SignBit(d)));
  // Each i32 sum covers 4 lanes, each biased by +128, so undo 4 * 128 = 512.
  const auto sums =
      BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj));
  return sums + Set(di32, int32_t{-512});
}

}  // namespace detail
// ------------------------------ SumsOfShuffledQuadAbsDiff
#if HWY_TARGET <= HWY_AVX3
template <int kIdx3, int kIdx2, int kIdx1, int kIdx0>
static Vec512<uint16_t> SumsOfShuffledQuadAbsDiff(Vec512<uint8_t> a,
                                                  Vec512<uint8_t> b) {
  static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3");
  static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3");
  static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3");
  static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3");
  // dbsad shuffles the u32 lanes of `a` per the immediate, then computes sums
  // of absolute differences of u8 quads against `b`.
  return Vec512<uint16_t>{
      _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))};
}
#endif
// ------------------------------ SaturatedAdd
// Returns a + b clamped to the destination range.
// Unsigned
HWY_API Vec512<uint8_t> SaturatedAdd(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedAdd(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedAdd(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedAdd(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)};
}
// ------------------------------ SaturatedSub
// Returns a - b clamped to the destination range.
// Unsigned
HWY_API Vec512<uint8_t> SaturatedSub(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedSub(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedSub(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedSub(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)};
}
// ------------------------------ Average
// Returns (a + b + 1) / 2
// Unsigned
HWY_API Vec512<uint8_t> AverageRound(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> AverageRound(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)};
}
// ------------------------------ Abs (Sub)
// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) {
#if HWY_COMPILER_MSVC
  // Workaround for incorrect codegen? (untested due to internal compiler
  // error): max(v, -v) is equivalent to abs for all lanes except LimitsMin().
  const DFromV<decltype(v)> d;
  const auto zero = Zero(d);
  return Vec512<int8_t>{_mm512_max_epi8(v.raw, (zero - v).raw)};
#else
  return Vec512<int8_t>{_mm512_abs_epi8(v.raw)};
#endif
}
HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_abs_epi16(v.raw)};
}
HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_abs_epi32(v.raw)};
}
HWY_API Vec512<int64_t> Abs(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_abs_epi64(v.raw)};
}
// ------------------------------ ShiftLeft
#if HWY_TARGET <= HWY_AVX3_DL
namespace detail {

// Applies the 8x8 bit-matrix `matrix` to every byte of `v` over GF(2) with a
// zero additive constant; the building block for 8-bit shifts below.
template <typename T>
HWY_API Vec512<T> GaloisAffine(Vec512<T> v, Vec512<uint64_t> matrix) {
  return Vec512<T>{_mm512_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)};
}

}  // namespace detail
#endif  // HWY_TARGET <= HWY_AVX3_DL
template <int kBits>
HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)};
}
#if HWY_TARGET <= HWY_AVX3_DL
// Generic for all vector lengths. Must be defined after all GaloisAffine.
template <int kBits, class V, HWY_IF_T_SIZE_V(V, 1)>
HWY_API V ShiftLeft(const V v) {
  const Repartition<uint64_t, DFromV<V>> du64;
  if (kBits == 0) return v;
  if (kBits == 1) return v + v;  // add is cheaper than the affine transform
  // Affine matrix that shifts each byte left by kBits; the second factor
  // masks off bits that would cross into the neighboring byte.
  constexpr uint64_t kMatrix = (0x0102040810204080ULL >> kBits) &
                               (0x0101010101010101ULL * (0xFF >> kBits));
  return detail::GaloisAffine(v, Set(du64, kMatrix));
}
#else   // HWY_TARGET > HWY_AVX3_DL
template <int kBits, typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  // Shift as 16-bit lanes, then clear bits that crossed into the next byte.
  const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v)));
  return kBits == 1
             ? (v + v)
             : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF)));
}
#endif  // HWY_TARGET > HWY_AVX3_DL
// ------------------------------ ShiftRight
template <int kBits>
HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)};
}

// Signed lanes use the arithmetic (sign-extending) shift.
template <int kBits>
HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)};
}
#if HWY_TARGET <= HWY_AVX3_DL
// Generic for all vector lengths. Must be defined after all GaloisAffine.
template <int kBits, class V, HWY_IF_U8_D(DFromV<V>)>
HWY_API V ShiftRight(const V v) {
  const Repartition<uint64_t, DFromV<V>> du64;
  if (kBits == 0) return v;
  // Affine matrix that logically shifts each byte right by kBits.
  constexpr uint64_t kMatrix =
      (0x0102040810204080ULL << kBits) &
      (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF));
  return detail::GaloisAffine(v, Set(du64, kMatrix));
}

// Generic for all vector lengths. Must be defined after all GaloisAffine.
template <int kBits, class V, HWY_IF_I8_D(DFromV<V>)>
HWY_API V ShiftRight(const V v) {
  const Repartition<uint64_t, DFromV<V>> du64;
  if (kBits == 0) return v;
  // Logical-shift part of the matrix, same as the u8 overload.
  constexpr uint64_t kShift =
      (0x0102040810204080ULL << kBits) &
      (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF));
  // Replicates the sign bit into the vacated upper bits. The kBits == 0 guard
  // keeps the shift count valid during constant evaluation even though that
  // case already returned above.
  constexpr uint64_t kSign =
      kBits == 0 ? 0 : (0x8080808080808080ULL >> (64 - (8 * kBits)));
  return detail::GaloisAffine(v, Set(du64, kShift | kSign));
}
#else   // HWY_TARGET > HWY_AVX3_DL
template <int kBits>
HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) {
  const DFromV<decltype(v)> d8;
  // Use raw instead of BitCast to support N=1.
  const Vec512<uint8_t> shifted{ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw};
  // Clear bits shifted in from the neighboring byte.
  return shifted & Set(d8, 0xFF >> kBits);
}

template <int kBits>
HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  // Logical shift, then sign-extend: XOR and subtract of the shifted sign bit
  // propagates the sign into the vacated upper bits.
  const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits));
  return (shifted ^ shifted_sign) - shifted_sign;
}
#endif  // HWY_TARGET > HWY_AVX3_DL
// ------------------------------ RotateRight
template <int kBits, typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API Vec512<T> RotateRight(const Vec512<T> v) {
  constexpr size_t kSizeInBits = sizeof(T) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  // AVX3 lacks 8/16-bit rotates; combine two shifts. HWY_MIN keeps the
  // left-shift count valid when this is instantiated with kBits == 0.
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}

template <int kBits>
HWY_API Vec512<uint32_t> RotateRight(const Vec512<uint32_t> v) {
  static_assert(0 <= kBits && kBits < 32, "Invalid shift count");
  if (kBits == 0) return v;
  return Vec512<uint32_t>{_mm512_ror_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> RotateRight(const Vec512<uint64_t> v) {
  static_assert(0 <= kBits && kBits < 64, "Invalid shift count");
  if (kBits == 0) return v;
  return Vec512<uint64_t>{_mm512_ror_epi64(v.raw, kBits)};
}
// ------------------------------ ShiftLeftSame
// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512
// shift-with-immediate: the counts should all be unsigned int.
// Shift16Count/Shift3264Count are the count-parameter types expected by the
// 16-bit and 32/64-bit shift-with-immediate intrinsics, respectively.
#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100
using Shift16Count =
int;
using Shift3264Count =
int;
#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400
// GCC 11.0 requires these, prior versions used a macro+cast and don't care.
using Shift16Count =
int;
using Shift3264Count =
unsigned int;
#else
// Assume documented behavior. Clang 11, GCC 14 and MSVC 14.28.29910 match this.
using Shift16Count =
unsigned int;
using Shift3264Count =
unsigned int;
#endif
// On GCC, when `bits` is a compile-time constant, use the immediate form to
// avoid moving the count into an XMM register; otherwise use the
// variable-count (sll) form.
HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint16_t>{
        _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint32_t>{
        _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint64_t>{
        _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int16_t>{
        _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int32_t>{
        _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int64_t>{
        _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  // Shift as 16-bit lanes, then clear bits that crossed into the next byte.
  const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF));
}
// ------------------------------ ShiftRightSame
// On GCC, a compile-time-constant count uses the immediate (srli) form;
// otherwise the count is passed in an XMM register (srl).
HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint16_t>{
        _mm512_srli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint32_t>{
        _mm512_srli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint64_t>{
        _mm512_srli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  // Shift as 16-bit lanes, then clear bits shifted in from the next byte.
  const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits));
}
// Signed variants use the arithmetic (sign-extending) shifts srai/sra.
HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int16_t>{
        _mm512_srai_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int32_t>{
        _mm512_srai_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int64_t>{
        _mm512_srai_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  // Logical shift, then sign-extend via XOR/subtract of the shifted sign bit.
  const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
  const auto shifted_sign =
      BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits)));
  return (shifted ^ shifted_sign) - shifted_sign;
}
// ------------------------------ Minimum
// Unsigned
HWY_API Vec512<uint8_t> Min(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Min(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Min(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Min(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Min(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Min(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Min(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Min(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> Min(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_min_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> Min(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_min_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Min(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_min_pd(a.raw, b.raw)};
}
// ------------------------------ Maximum
// Unsigned
HWY_API Vec512<uint8_t> Max(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Max(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Max(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Max(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Max(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Max(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Max(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Max(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> Max(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_max_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> Max(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_max_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Max(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_max_pd(a.raw, b.raw)};
}
// ------------------------------ Integer multiplication
// Per-target flag to prevent generic_ops-inl.h from defining 64-bit operator*.
#ifdef HWY_NATIVE_MUL_64
#undef HWY_NATIVE_MUL_64
#else
#define HWY_NATIVE_MUL_64
#endif
// Unsigned
// AVX3 provides 64-bit mullo for all vector lengths, hence the Vec256/Vec128
// overloads also live here (see HWY_NATIVE_MUL_64 above).
HWY_API Vec512<uint16_t> operator*(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator*(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator*(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_mullo_epi64(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> operator*(Vec256<uint64_t> a, Vec256<uint64_t> b) {
  return Vec256<uint64_t>{_mm256_mullo_epi64(a.raw, b.raw)};
}
template <size_t N>
HWY_API Vec128<uint64_t, N> operator*(Vec128<uint64_t, N> a,
                                      Vec128<uint64_t, N> b) {
  return Vec128<uint64_t, N>{_mm_mullo_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int16_t> operator*(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator*(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator*(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_mullo_epi64(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> operator*(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Vec256<int64_t>{_mm256_mullo_epi64(a.raw, b.raw)};
}
template <size_t N>
HWY_API Vec128<int64_t, N> operator*(Vec128<int64_t, N> a,
                                     Vec128<int64_t, N> b) {
  return Vec128<int64_t, N>{_mm_mullo_epi64(a.raw, b.raw)};
}
// Returns the upper 16 bits of a * b in each lane.
HWY_API Vec512<uint16_t> MulHigh(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> MulHigh(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)};
}

// Rounded Q1.15 multiply: (a * b + (1 << 14)) >> 15 in each lane.
HWY_API Vec512<int16_t> MulFixedPoint15(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhrs_epi16(a.raw, b.raw)};
}
// Multiplies even lanes (0, 2 ..) and places the double-wide result into
// even and the upper half into its odd neighbor lane.
HWY_API Vec512<int64_t> MulEven(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> MulEven(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)};
}
// ------------------------------ Neg (Sub)
--> --------------------
--> maximum size reached
--> --------------------