// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when
// compiling for that target.
// External include guard in highway.h - see comment there.
// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.
// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL
#include "hwy/base.h"
// Avoid uninitialized warnings in GCC's avx512fintrin.h - see
// https://github.com/google/highway/issues/710)
HWY_DIAGNOSTICS(push)
#if HWY_COMPILER_GCC_ACTUAL
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored
"-Wuninitialized")
HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494,
ignored
"-Wmaybe-uninitialized")
#endif
// Must come before HWY_COMPILER_CLANGCL
#include <immintrin.h>
// AVX2+
#if HWY_COMPILER_CLANGCL
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
#include <avxintrin.h>
// avxintrin defines __m256i and must come before avx2intrin.
#include <avx2intrin.h>
#include <bmi2intrin.h>
// _pext_u64
#include <f16cintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif // HWY_COMPILER_CLANGCL
// For half-width vectors. Already includes base.h.
#include "hwy/ops/shared-inl.h"
// Already included by shared-inl, but do it again to avoid IDE warnings.
#include "hwy/ops/x86_128-inl.h"
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {

// Maps a lane type T to the raw 256-bit intrinsic vector type that holds it.
// Integers (and emulated bf16/f16) share __m256i.
template <typename T>
struct Raw256 {
  using type = __m256i;
};
#if HWY_HAVE_FLOAT16
template <>
struct Raw256<float16_t> {
  using type = __m256h;
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct Raw256<float> {
  using type = __m256;
};
template <>
struct Raw256<double> {
  using type = __m256d;
};

}  // namespace detail
// 256-bit vector of 32 / sizeof(T) lanes of type T, wrapping the raw
// intrinsic vector. Trivially copyable; `raw` is the only member.
template <typename T>
class Vec256 {
  using Raw = typename detail::Raw256<T>::type;

 public:
  using PrivateT = T;                                   // only for DFromV
  static constexpr size_t kPrivateN = 32 / sizeof(T);   // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec256& operator*=(const Vec256 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec256& operator/=(const Vec256 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec256& operator+=(const Vec256 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec256& operator-=(const Vec256 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec256& operator%=(const Vec256 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec256& operator&=(const Vec256 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec256& operator|=(const Vec256 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec256& operator^=(const Vec256 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};
#if HWY_TARGET <= HWY_AVX3

namespace detail {

// Template arg: sizeof(lane type). Selects the AVX-512 mask register type
// with (at least) one bit per 256-bit-vector lane.
template <size_t size>
struct RawMask256 {};
template <>
struct RawMask256<1> {
  using type = __mmask32;
};
template <>
struct RawMask256<2> {
  using type = __mmask16;
};
template <>
struct RawMask256<4> {
  using type = __mmask8;
};
template <>
struct RawMask256<8> {
  using type = __mmask8;
};

}  // namespace detail

// AVX-512 predicate mask: one bit per lane; only the low 32/sizeof(T) bits
// are valid.
template <typename T>
struct Mask256 {
  using Raw = typename detail::RawMask256<sizeof(T)>::type;

  static Mask256<T> FromBits(uint64_t mask_bits) {
    return Mask256<T>{static_cast<Raw>(mask_bits)};
  }

  Raw raw;
};

#else  // AVX2

// FF..FF or 0.
template <typename T>
struct Mask256 {
  typename detail::Raw256<T>::type raw;
};

#endif  // AVX2
#if HWY_TARGET <= HWY_AVX3
namespace detail {

// Used by Expand() emulation, which is required for both AVX3 and AVX2.
// On AVX3 the mask register already is the bit pattern.
template <typename T>
HWY_INLINE uint64_t BitsFromMask(const Mask256<T> mask) {
  return mask.raw;
}

}  // namespace detail
#endif  // HWY_TARGET <= HWY_AVX3
// Descriptor (tag) for a full 256-bit vector of lane type T.
template <typename T>
using Full256 = Simd<T, 32 / sizeof(T), 0>;
// ------------------------------ BitCast

namespace detail {

// Reinterpret any raw vector type as __m256i (no-op casts at runtime).
HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; }
#if HWY_HAVE_FLOAT16
HWY_INLINE __m256i BitCastToInteger(__m256h v) {
  return _mm256_castph_si256(v);
}
#endif  // HWY_HAVE_FLOAT16
HWY_INLINE __m256i BitCastToInteger(__m256 v) {
  return _mm256_castps_si256(v);
}
HWY_INLINE __m256i BitCastToInteger(__m256d v) {
  return _mm256_castpd_si256(v);
}

// Any lane type -> bytes.
template <typename T>
HWY_INLINE Vec256<uint8_t> BitCastToByte(Vec256<T> v) {
  return Vec256<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template <typename T>
struct BitCastFromInteger256 {
  HWY_INLINE __m256i operator()(__m256i v) { return v; }
};
#if HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger256<float16_t> {
  HWY_INLINE __m256h operator()(__m256i v) { return _mm256_castsi256_ph(v); }
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger256<float> {
  HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); }
};
template <>
struct BitCastFromInteger256<double> {
  HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); }
};

// Bytes -> the lane type selected by D.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec256<uint8_t> v) {
  return VFromD<D>{BitCastFromInteger256<TFromD<D>>()(v.raw)};
}

}  // namespace detail
// Reinterprets the bits of `v` as lanes of the type selected by `d`.
template <class D, HWY_IF_V_SIZE_D(D, 32), typename FromT>
HWY_API VFromD<D> BitCast(D d, Vec256<FromT> v) {
  // Round-trip through bytes so only 2*N (not N*N) overloads are needed.
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
// ------------------------------ Zero

// Returns an all-zero vector. Cannot use VFromD here because it is defined in
// terms of Zero.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec256<TFromD<D>> Zero(D /* tag */) {
  return Vec256<TFromD<D>>{_mm256_setzero_si256()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
HWY_API Vec256<bfloat16_t> Zero(D /* tag */) {
  return Vec256<bfloat16_t>{_mm256_setzero_si256()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Zero(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec256<float16_t>{_mm256_setzero_ph()};
#else
  // Without native f16, float16_t lanes live in an integer vector.
  return Vec256<float16_t>{_mm256_setzero_si256()};
#endif  // HWY_HAVE_FLOAT16
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Zero(D /* tag */) {
  return Vec256<float>{_mm256_setzero_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Zero(D /* tag */) {
  return Vec256<double>{_mm256_setzero_pd()};
}
// ------------------------------ Set

// Broadcasts `t` to all lanes. Casts to the intrinsic's argument type because
// set1 takes (signed) char/short/int/long long regardless of lane signedness.
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi8(static_cast<char>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi16(static_cast<short>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi32(static_cast<int>(t))};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm256_set1_epi64x(static_cast<long long>(t))};  // NOLINT
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Set(D /* tag */, float16_t t) {
  return Vec256<float16_t>{_mm256_set1_ph(t)};
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<float> Set(D /* tag */, float t) {
  return Vec256<float>{_mm256_set1_ps(t)};
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<double> Set(D /* tag */, double t) {
  return Vec256<double>{_mm256_set1_pd(t)};
}
HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored
"-Wuninitialized")
// Returns a vector with uninitialized elements.
template <
class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API VFromD<D> Undefined(D
/* tag */) {
// Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC
// generate an XOR instruction.
return VFromD<D>{_mm256_undefined_si256()};
}
template <
class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
HWY_API Vec256<bfloat16_t> Undefined(D
/* tag */) {
return Vec256<bfloat16_t>{_mm256_undefined_si256()};
}
template <
class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_API Vec256<float16_t> Undefined(D
/* tag */) {
#if HWY_HAVE_FLOAT16
return Vec256<float16_t>{_mm256_undefined_ph()};
#else
return Vec256<float16_t>{_mm256_undefined_si256()};
#endif
}
template <
class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_API Vec256<
float> Undefined(D
/* tag */) {
return Vec256<
float>{_mm256_undefined_ps()};
}
template <
class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_API Vec256<
double> Undefined(D
/* tag */) {
return Vec256<
double>{_mm256_undefined_pd()};
}
HWY_DIAGNOSTICS(pop)
// ------------------------------ ResizeBitCast

// 32-byte vector to 32-byte vector (or 64-byte vector to 64-byte vector on
// AVX3): same size, so this is a plain BitCast.
template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, HWY_MAX_LANES_V(FromV) * sizeof(TFromV<FromV>))>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, v);
}

// 32-byte vector to 16-byte vector (or 64-byte vector to 32-byte vector on
// AVX3): keep the lower half, then BitCast.
template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
          HWY_IF_V_SIZE_D(D,
                          (HWY_MAX_LANES_V(FromV) * sizeof(TFromV<FromV>)) / 2)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const DFromV<decltype(v)> d_from;
  const Half<decltype(d_from)> dh_from;
  return BitCast(d, LowerHalf(dh_from, v));
}

// 32-byte vector (or 64-byte vector on AVX3) to <= 8-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
          HWY_IF_V_SIZE_LE_D(D, 8)>
HWY_API VFromD<D> ResizeBitCast(D /*d*/, FromV v) {
  return VFromD<D>{ResizeBitCast(Full128<TFromD<D>>(), v).raw};
}

// <= 16-byte vector to 32-byte vector: widen via 128->256 cast.
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec256<uint8_t>{_mm256_castsi128_si256(
                        ResizeBitCast(Full128<uint8_t>(), v).raw)});
}
// ------------------------------ Dup128VecFromValues

// Builds one 128-bit block from the given lane values and duplicates it into
// both halves of the 256-bit vector.
template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9,
                                      TFromD<D> t10, TFromD<D> t11,
                                      TFromD<D> t12, TFromD<D> t13,
                                      TFromD<D> t14, TFromD<D> t15) {
  return VFromD<D>{_mm256_setr_epi8(
      static_cast<char>(t0), static_cast<char>(t1), static_cast<char>(t2),
      static_cast<char>(t3), static_cast<char>(t4), static_cast<char>(t5),
      static_cast<char>(t6), static_cast<char>(t7), static_cast<char>(t8),
      static_cast<char>(t9), static_cast<char>(t10), static_cast<char>(t11),
      static_cast<char>(t12), static_cast<char>(t13), static_cast<char>(t14),
      static_cast<char>(t15), static_cast<char>(t0), static_cast<char>(t1),
      static_cast<char>(t2), static_cast<char>(t3), static_cast<char>(t4),
      static_cast<char>(t5), static_cast<char>(t6), static_cast<char>(t7),
      static_cast<char>(t8), static_cast<char>(t9), static_cast<char>(t10),
      static_cast<char>(t11), static_cast<char>(t12), static_cast<char>(t13),
      static_cast<char>(t14), static_cast<char>(t15))};
}

template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{_mm256_setr_epi16(
      static_cast<int16_t>(t0), static_cast<int16_t>(t1),
      static_cast<int16_t>(t2), static_cast<int16_t>(t3),
      static_cast<int16_t>(t4), static_cast<int16_t>(t5),
      static_cast<int16_t>(t6), static_cast<int16_t>(t7),
      static_cast<int16_t>(t0), static_cast<int16_t>(t1),
      static_cast<int16_t>(t2), static_cast<int16_t>(t3),
      static_cast<int16_t>(t4), static_cast<int16_t>(t5),
      static_cast<int16_t>(t6), static_cast<int16_t>(t7))};
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{_mm256_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2,
                                  t3, t4, t5, t6, t7)};
}
#endif

template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm256_setr_epi32(
      static_cast<int32_t>(t0), static_cast<int32_t>(t1),
      static_cast<int32_t>(t2), static_cast<int32_t>(t3),
      static_cast<int32_t>(t0), static_cast<int32_t>(t1),
      static_cast<int32_t>(t2), static_cast<int32_t>(t3))};
}

template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm256_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3)};
}

template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{
      _mm256_setr_epi64x(static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                         static_cast<int64_t>(t0), static_cast<int64_t>(t1))};
}

template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{_mm256_setr_pd(t0, t1, t0, t1)};
}
// ================================================== LOGICAL

// ------------------------------ And

template <typename T>
HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_and_si256(
                        BitCast(du, a).raw, BitCast(du, b).raw)});
}

HWY_API Vec256<float> And(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_and_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> And(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_and_pd(a.raw, b.raw)};
}

// ------------------------------ AndNot

// Returns ~not_mask & mask.
template <typename T>
HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) {
  const DFromV<decltype(mask)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_andnot_si256(
                        BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
}

HWY_API Vec256<float> AndNot(Vec256<float> not_mask, Vec256<float> mask) {
  return Vec256<float>{_mm256_andnot_ps(not_mask.raw, mask.raw)};
}
HWY_API Vec256<double> AndNot(Vec256<double> not_mask, Vec256<double> mask) {
  return Vec256<double>{_mm256_andnot_pd(not_mask.raw, mask.raw)};
}

// ------------------------------ Or

template <typename T>
HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_or_si256(
                        BitCast(du, a).raw, BitCast(du, b).raw)});
}

HWY_API Vec256<float> Or(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_or_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Or(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_or_pd(a.raw, b.raw)};
}

// ------------------------------ Xor

template <typename T>
HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm256_xor_si256(
                        BitCast(du, a).raw, BitCast(du, b).raw)});
}

HWY_API Vec256<float> Xor(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_xor_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Xor(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_xor_pd(a.raw, b.raw)};
}
// ------------------------------ Not

template <typename T>
HWY_API Vec256<T> Not(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  using TU = MakeUnsigned<T>;
#if HWY_TARGET <= HWY_AVX3
  // Single ternlog instruction; truth table 0x55 computes ~first operand.
  const __m256i vu = BitCast(RebindToUnsigned<decltype(d)>(), v).raw;
  return BitCast(d, Vec256<TU>{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)});
#else
  return Xor(v, BitCast(d, Vec256<TU>{_mm256_set1_epi32(-1)}));
#endif
}
// ------------------------------ Xor3

template <typename T>
HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) {
#if HWY_TARGET <= HWY_AVX3
  // Single ternlog; truth table 0x96 is the parity (xor) of three inputs.
  const DFromV<decltype(x1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m256i ret = _mm256_ternarylogic_epi64(
      BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
  return BitCast(d, VU{ret});
#else
  return Xor(x1, Xor(x2, x3));
#endif
}

// ------------------------------ Or3

template <typename T>
HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
#if HWY_TARGET <= HWY_AVX3
  // Truth table 0xFE: 1 unless all three inputs are 0, i.e. a | b | c.
  const DFromV<decltype(o1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m256i ret = _mm256_ternarylogic_epi64(
      BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
  return BitCast(d, VU{ret});
#else
  return Or(o1, Or(o2, o3));
#endif
}

// ------------------------------ OrAnd

template <typename T>
HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) {
#if HWY_TARGET <= HWY_AVX3
  // Truth table 0xF8: o | (a1 & a2).
  const DFromV<decltype(o)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m256i ret = _mm256_ternarylogic_epi64(
      BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
  return BitCast(d, VU{ret});
#else
  return Or(o, And(a1, a2));
#endif
}

// ------------------------------ IfVecThenElse

template <typename T>
HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) {
#if HWY_TARGET <= HWY_AVX3
  // Truth table 0xCA: bitwise select (mask & yes) | (~mask & no).
  const DFromV<decltype(yes)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw,
                                                 BitCast(du, yes).raw,
                                                 BitCast(du, no).raw, 0xCA)});
#else
  return IfThenElse(MaskFromVec(mask), yes, no);
#endif
}
// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) {
  return Xor(a, b);
}
// ------------------------------ PopulationCount

// 8/16 require BITALG, 32/64 require VPOPCNTDQ.
#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

namespace detail {

// Dispatch on lane size; each maps to the matching vpopcnt instruction.
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi8(v.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi16(v.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi32(v.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec256<T> v) {
  return Vec256<T>{_mm256_popcnt_epi64(v.raw)};
}

}  // namespace detail

// Per-lane count of set bits.
template <typename T>
HWY_API Vec256<T> PopulationCount(Vec256<T> v) {
  return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
}

#endif  // HWY_TARGET <= HWY_AVX3_DL
// ================================================== MASK
#if HWY_TARGET <= HWY_AVX3
// ------------------------------ IfThenElse

// Returns mask ? b : a.
namespace detail {

// Templates for signed/unsigned integer of a particular size.
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<1> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi8(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<2> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi16(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<4> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi32(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElse(hwy::SizeTag<8> /* tag */, Mask256<T> mask,
                                Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_mask_blend_epi64(mask.raw, no.raw, yes.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) {
  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> IfThenElse(Mask256<float16_t> mask,
                                     Vec256<float16_t> yes,
                                     Vec256<float16_t> no) {
  return Vec256<float16_t>{_mm256_mask_blend_ph(mask.raw, no.raw, yes.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> IfThenElse(Mask256<float> mask, Vec256<float> yes,
                                 Vec256<float> no) {
  return Vec256<float>{_mm256_mask_blend_ps(mask.raw, no.raw, yes.raw)};
}
HWY_API Vec256<double> IfThenElse(Mask256<double> mask, Vec256<double> yes,
                                  Vec256<double> no) {
  return Vec256<double>{_mm256_mask_blend_pd(mask.raw, no.raw, yes.raw)};
}
namespace detail {

// Returns mask ? yes : 0, dispatched on lane size (maskz_mov zeroes
// unselected lanes).
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi8(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi16(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi32(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256<T> mask,
                                    Vec256<T> yes) {
  return Vec256<T>{_mm256_maskz_mov_epi64(mask.raw, yes.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) {
  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
}
HWY_API Vec256<float> IfThenElseZero(Mask256<float> mask, Vec256<float> yes) {
  return Vec256<float>{_mm256_maskz_mov_ps(mask.raw, yes.raw)};
}
HWY_API Vec256<double> IfThenElseZero(Mask256<double> mask,
                                      Vec256<double> yes) {
  return Vec256<double>{_mm256_maskz_mov_pd(mask.raw, yes.raw)};
}
namespace detail {

// Returns mask ? 0 : no, dispatched on lane size. Zeroing is achieved by
// masked x ^ x (or x - x where xor is unavailable).
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16.
  return Vec256<T>{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  return Vec256<T>{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  return Vec256<T>{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec256<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256<T> mask,
                                    Vec256<T> no) {
  return Vec256<T>{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) {
  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
}
HWY_API Vec256<float> IfThenZeroElse(Mask256<float> mask, Vec256<float> no) {
  return Vec256<float>{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
}
HWY_API Vec256<double> IfThenZeroElse(Mask256<double> mask,
                                      Vec256<double> no) {
  return Vec256<double>{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
}

// Returns v if v >= 0 (sign bit clear), else 0.
template <typename T>
HWY_API Vec256<T> ZeroIfNegative(const Vec256<T> v) {
  static_assert(IsSigned<T>(), "Only for float");
  // AVX3 MaskFromVec only looks at the MSB
  return IfThenZeroElse(MaskFromVec(v), v);
}
// ------------------------------ Mask logical

namespace detail {

// Mask And/AndNot, dispatched on lane size. Uses the k-register intrinsics
// when the compiler provides them, otherwise plain integer ops on the raw
// mask bits.
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> And(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kand_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw & b.raw)};
#endif
}

// AndNot computes ~a & b.
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(~a.raw & b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> AndNot(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                             const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kandn_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(~a.raw & b.raw)};
#endif
}
// Mask Or, dispatched on lane size.
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw | b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Or(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                         const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw | b.raw)};
#endif
}

// Mask Xor, dispatched on lane size.
template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<1> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<2> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<4> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw ^ b.raw)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> Xor(hwy::SizeTag<8> /*tag*/, const Mask256<T> a,
                          const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(a.raw ^ b.raw)};
#endif
}
// Mask xnor (~(a ^ b)), masked down to the valid lane bits.
template <typename T>
HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<1> /*tag*/,
                                       const Mask256<T> a,
                                       const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxnor_mask32(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask32>(~(a.raw ^ b.raw) & 0xFFFFFFFF)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<2> /*tag*/,
                                       const Mask256<T> a,
                                       const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxnor_mask16(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask16>(~(a.raw ^ b.raw) & 0xFFFF)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<4> /*tag*/,
                                       const Mask256<T> a,
                                       const Mask256<T> b) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{_kxnor_mask8(a.raw, b.raw)};
#else
  return Mask256<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xFF)};
#endif
}
template <typename T>
HWY_INLINE Mask256<T> ExclusiveNeither(hwy::SizeTag<8> /*tag*/,
                                       const Mask256<T> a,
                                       const Mask256<T> b) {
  // Only 4 of the 8 mask bits are valid; clear the rest.
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{static_cast<__mmask8>(_kxnor_mask8(a.raw, b.raw) & 0xF)};
#else
  return Mask256<T>{static_cast<__mmask8>(~(a.raw ^ b.raw) & 0xF)};
#endif
}
// UnmaskedNot returns ~m.raw without zeroing out any invalid bits
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask256<T> UnmaskedNot(const Mask256<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{static_cast<__mmask32>(_knot_mask32(m.raw))};
#else
  return Mask256<T>{static_cast<__mmask32>(~m.raw)};
#endif
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE Mask256<T> UnmaskedNot(const Mask256<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{static_cast<__mmask16>(_knot_mask16(m.raw))};
#else
  return Mask256<T>{static_cast<__mmask16>(~m.raw)};
#endif
}
template <typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))>
HWY_INLINE Mask256<T> UnmaskedNot(const Mask256<T> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  return Mask256<T>{static_cast<__mmask8>(_knot_mask8(m.raw))};
#else
  return Mask256<T>{static_cast<__mmask8>(~m.raw)};
#endif
}

template <typename T>
HWY_INLINE Mask256<T> Not(hwy::SizeTag<1> /*tag*/, const Mask256<T> m) {
  // sizeof(T) == 1: simply return ~m as all 32 bits of m are valid
  return UnmaskedNot(m);
}
template <typename T>
HWY_INLINE Mask256<T> Not(hwy::SizeTag<2> /*tag*/, const Mask256<T> m) {
  // sizeof(T) == 2: simply return ~m as all 16 bits of m are valid
  return UnmaskedNot(m);
}
template <typename T>
HWY_INLINE Mask256<T> Not(hwy::SizeTag<4> /*tag*/, const Mask256<T> m) {
  // sizeof(T) == 4: simply return ~m as all 8 bits of m are valid
  return UnmaskedNot(m);
}
template <typename T>
HWY_INLINE Mask256<T> Not(hwy::SizeTag<8> /*tag*/, const Mask256<T> m) {
  // sizeof(T) == 8: need to zero out the upper 4 bits of ~m as only the lower
  // 4 bits of m are valid
  // Return (~m) & 0x0F
  return AndNot(hwy::SizeTag<8>(), m, Mask256<T>::FromBits(uint64_t{0x0F}));
}

}  // namespace detail
// Public mask logic: dispatch on lane size to the detail:: implementations.
template <typename T>
HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) {
  return detail::And(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) {
  return detail::AndNot(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) {
  return detail::Or(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) {
  return detail::Xor(hwy::SizeTag<sizeof(T)>(), a, b);
}

template <typename T>
HWY_API Mask256<T> Not(const Mask256<T> m) {
  // Flip only the valid bits.
  return detail::Not(hwy::SizeTag<sizeof(T)>(), m);
}

template <typename T>
HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) {
  return detail::ExclusiveNeither(hwy::SizeTag<sizeof(T)>(), a, b);
}
// Concatenates two 16-lane masks into one 32-lane mask: hi in bits [31:16],
// lo in bits [15:0].
template <class D, HWY_IF_LANES_D(D, 32)>
HWY_API MFromD<D> CombineMasks(D /*d*/, MFromD<Half<D>> hi,
                               MFromD<Half<D>> lo) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  const __mmask32 combined_mask = _mm512_kunpackw(
      static_cast<__mmask32>(hi.raw), static_cast<__mmask32>(lo.raw));
#else
  const auto combined_mask =
      ((static_cast<uint32_t>(hi.raw) << 16) | (lo.raw & 0xFFFFu));
#endif
  return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(combined_mask)};
}
// Extracts the upper 16 lanes of a 32-lane mask by shifting right 16 bits.
template <class D, HWY_IF_LANES_D(D, 16)>
HWY_API MFromD<D> UpperHalfOfMask(D /*d*/, MFromD<Twice<D>> m) {
#if HWY_COMPILER_HAS_MASK_INTRINSICS
  const auto shifted_mask = _kshiftri_mask32(static_cast<__mmask32>(m.raw), 16);
#else
  const auto shifted_mask = static_cast<uint32_t>(m.raw) >> 16;
#endif
  return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(shifted_mask)};
}
#else // AVX2
// ------------------------------ Mask

// Mask and Vec are the same (true = FF..FF).
template <typename T>
HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
  return Mask256<T>{v.raw};
}

template <typename T>
HWY_API Vec256<T> VecFromMask(const Mask256<T> v) {
  return Vec256<T>{v.raw};
}
// ------------------------------ IfThenElse

// mask ? yes : no
template <typename T, HWY_IF_NOT_FLOAT3264(T)>
HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) {
  return Vec256<T>{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)};
}
HWY_API Vec256<float> IfThenElse(Mask256<float> mask, Vec256<float> yes,
                                 Vec256<float> no) {
  return Vec256<float>{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)};
}
HWY_API Vec256<double> IfThenElse(Mask256<double> mask, Vec256<double> yes,
                                  Vec256<double> no) {
  return Vec256<double>{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)};
}
// mask ? yes : 0
template <typename T>
HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) {
  const DFromV<decltype(yes)> d;
  return yes & VecFromMask(d, mask);
}

// mask ? 0 : no
template <typename T>
HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) {
  const DFromV<decltype(no)> d;
  return AndNot(VecFromMask(d, mask), no);
}

// Returns v where non-negative, else 0. Relies on the comparison mask having
// all bits set in negative lanes.
template <typename T>
HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) {
  static_assert(IsSigned<T>(), "Only for float");
  const DFromV<decltype(v)> d;
  const auto zero = Zero(d);
  // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes
  return IfThenElse(MaskFromVec(v), zero, v);
}
// ------------------------------ Mask logical

// AVX2 masks are full vectors, so mask logic reuses the vector ops.
template <typename T>
HWY_API Mask256<T> Not(const Mask256<T> m) {
  const Full256<T> d;
  return MaskFromVec(Not(VecFromMask(d, m)));
}

template <typename T>
HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b))));
}
#endif // HWY_TARGET <= HWY_AVX3
// ================================================== COMPARE
#if HWY_TARGET <= HWY_AVX3
// Comparisons set a mask bit to 1 if the condition is true, else 0.

// Mask bits are lane-size-agnostic, so only a same-size check is needed.
template <class DTo, HWY_IF_V_SIZE_D(DTo, 32), typename TFrom>
HWY_API MFromD<DTo> RebindMask(DTo /*tag*/, Mask256<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size");
  return MFromD<DTo>{m.raw};
}
namespace detail {

// Per-lane "any bit of (v & bit) set" via the AVX-512 test instructions.
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<1> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi8_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<2> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi16_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<4> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi32_mask(v.raw, bit.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> TestBit(hwy::SizeTag<8> /*tag*/, const Vec256<T> v,
                              const Vec256<T> bit) {
  return Mask256<T>{_mm256_test_epi64_mask(v.raw, bit.raw)};
}

}  // namespace detail
template <typename T>
HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return detail::TestBit(hwy::SizeTag<sizeof(T)>(), v, bit);
}
// ------------------------------ Equality

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi8_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi16_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi32_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask256<T> operator==(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator==(Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_EQ_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask256<float> operator==(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}

HWY_API Mask256<double> operator==(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)};
}
// ------------------------------ Inequality

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi8_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi16_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi32_mask(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask256<T> operator!=(const Vec256<T> a, const Vec256<T> b) {
  return Mask256<T>{_mm256_cmpneq_epi64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator!=(Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask256<float> operator!=(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
}

HWY_API Mask256<double> operator!=(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)};
}
// ------------------------------ Strict inequality

HWY_API Mask256<int8_t> operator>(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Mask256<int8_t>{_mm256_cmpgt_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask256<int16_t> operator>(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Mask256<int16_t>{_mm256_cmpgt_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask256<int32_t> operator>(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Mask256<int32_t>{_mm256_cmpgt_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask256<int64_t> operator>(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Mask256<int64_t>{_mm256_cmpgt_epi64_mask(a.raw, b.raw)};
}

HWY_API Mask256<uint8_t> operator>(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Mask256<uint8_t>{_mm256_cmpgt_epu8_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint16_t> operator>(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Mask256<uint16_t>{_mm256_cmpgt_epu16_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint32_t> operator>(Vec256<uint32_t> a, Vec256<uint32_t> b) {
  return Mask256<uint32_t>{_mm256_cmpgt_epu32_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint64_t> operator>(Vec256<uint64_t> a, Vec256<uint64_t> b) {
  return Mask256<uint64_t>{_mm256_cmpgt_epu64_mask(a.raw, b.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator>(Vec256<float16_t> a, Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GT_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask256<float> operator>(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)};
}
HWY_API Mask256<double> operator>(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)};
}
// ------------------------------ Weak inequality

#if HWY_HAVE_FLOAT16
HWY_API Mask256<float16_t> operator>=(Vec256<float16_t> a,
                                      Vec256<float16_t> b) {
  // Work around warnings in the intrinsic definitions (passing -1 as a mask).
  HWY_DIAGNOSTICS(push)
  HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
  return Mask256<float16_t>{_mm256_cmp_ph_mask(a.raw, b.raw, _CMP_GE_OQ)};
  HWY_DIAGNOSTICS(pop)
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Mask256<float> operator>=(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)};
}
HWY_API Mask256<double> operator>=(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)};
}

HWY_API Mask256<int8_t> operator>=(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Mask256<int8_t>{_mm256_cmpge_epi8_mask(a.raw, b.raw)};
}
HWY_API Mask256<int16_t> operator>=(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Mask256<int16_t>{_mm256_cmpge_epi16_mask(a.raw, b.raw)};
}
HWY_API Mask256<int32_t> operator>=(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Mask256<int32_t>{_mm256_cmpge_epi32_mask(a.raw, b.raw)};
}
HWY_API Mask256<int64_t> operator>=(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Mask256<int64_t>{_mm256_cmpge_epi64_mask(a.raw, b.raw)};
}

HWY_API Mask256<uint8_t> operator>=(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Mask256<uint8_t>{_mm256_cmpge_epu8_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint16_t> operator>=(const Vec256<uint16_t> a,
                                     const Vec256<uint16_t> b) {
  return Mask256<uint16_t>{_mm256_cmpge_epu16_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint32_t> operator>=(const Vec256<uint32_t> a,
                                     const Vec256<uint32_t> b) {
  return Mask256<uint32_t>{_mm256_cmpge_epu32_mask(a.raw, b.raw)};
}
HWY_API Mask256<uint64_t> operator>=(const Vec256<uint64_t> a,
                                     const Vec256<uint64_t> b) {
  return Mask256<uint64_t>{_mm256_cmpge_epu64_mask(a.raw, b.raw)};
}
// ------------------------------ Mask

namespace detail {

// Extracts the sign bit of each lane into a compact mask register.
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi8_mask(v.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi16_mask(v.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi32_mask(v.raw)};
}
template <typename T>
HWY_INLINE Mask256<T> MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256<T> v) {
  return Mask256<T>{_mm256_movepi64_mask(v.raw)};
}

}  // namespace detail
template <typename T, HWY_IF_NOT_FLOAT(T)>
HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
  return detail::MaskFromVec(hwy::SizeTag<sizeof(T)>(), v);
}
// There do not seem to be native floating-point versions of these instructions.
template <typename T, HWY_IF_FLOAT(T)>
HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
  const RebindToSigned<DFromV<decltype(v)>> di;
  return Mask256<T>{MaskFromVec(BitCast(di, v)).raw};
}
// Expands each mask bit to an all-ones / all-zeros lane.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec256<T> VecFromMask(const Mask256<T> v) {
  return Vec256<T>{_mm256_movm_epi8(v.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Vec256<T> VecFromMask(const Mask256<T> v) {
  return Vec256<T>{_mm256_movm_epi16(v.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Vec256<T> VecFromMask(const Mask256<T> v) {
  return Vec256<T>{_mm256_movm_epi32(v.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Vec256<T> VecFromMask(const Mask256<T> v) {
  return Vec256<T>{_mm256_movm_epi64(v.raw)};
}

#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> VecFromMask(const Mask256<float16_t> v) {
  return Vec256<float16_t>{_mm256_castsi256_ph(_mm256_movm_epi16(v.raw))};
}
#endif  // HWY_HAVE_FLOAT16

HWY_API Vec256<float> VecFromMask(const Mask256<float> v) {
  return Vec256<float>{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))};
}
HWY_API Vec256<double> VecFromMask(const Mask256<double> v) {
  return Vec256<double>{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))};
}
#else // AVX2
// Comparisons fill a lane with 1-bits if the condition is true, else 0.

template <class DTo, HWY_IF_V_SIZE_D(DTo, 32), typename TFrom>
HWY_API MFromD<DTo> RebindMask(DTo d_to, Mask256<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size");
  const Full256<TFrom> dfrom;
  // Masks are full vectors here; reinterpret the lanes in the new type.
  return MaskFromVec(BitCast(d_to, VecFromMask(dfrom, m)));
}
template <typename T>
HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return (v & bit) == bit;
}
// ------------------------------ Equality

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi8(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI16(T)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi16(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI32(T)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi32(a.raw, b.raw)};
}
template <typename T, HWY_IF_UI64(T)>
HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
  return Mask256<T>{_mm256_cmpeq_epi64(a.raw, b.raw)};
}

HWY_API Mask256<float> operator==(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)};
}
HWY_API Mask256<double> operator==(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)};
}
// ------------------------------ Inequality

// No integer != instruction on AVX2; invert the equality mask instead.
template <typename T, HWY_IF_NOT_FLOAT3264(T)>
HWY_API Mask256<T> operator!=(Vec256<T> a, Vec256<T> b) {
  return Not(a == b);
}

HWY_API Mask256<float> operator!=(Vec256<float> a, Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)};
}
HWY_API Mask256<double> operator!=(Vec256<double> a, Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)};
}
// ------------------------------ Strict inequality

// Tag dispatch instead of SFINAE for MSVC 2017 compatibility
namespace detail {

// Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8
// to perform an unsigned comparison instead of the intended signed. Workaround
// is to cast to an explicitly signed type. See https://godbolt.org/z/PL7Ujy
#if HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 903
#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1
#else
#define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0
#endif

HWY_API Mask256<int8_t> Gt(hwy::SignedTag /*tag*/, Vec256<int8_t> a,
                           Vec256<int8_t> b) {
#if HWY_AVX2_GCC_CMPGT8_WORKAROUND
  using i8x32 = signed char __attribute__((__vector_size__(32)));
  return Mask256<int8_t>{static_cast<__m256i>(reinterpret_cast<i8x32>(a.raw) >
                                              reinterpret_cast<i8x32>(b.raw))};
#else
  return Mask256<int8_t>{_mm256_cmpgt_epi8(a.raw, b.raw)};
#endif
}
HWY_API Mask256<int16_t> Gt(hwy::SignedTag /*tag*/, Vec256<int16_t> a,
                            Vec256<int16_t> b) {
  return Mask256<int16_t>{_mm256_cmpgt_epi16(a.raw, b.raw)};
}
HWY_API Mask256<int32_t> Gt(hwy::SignedTag /*tag*/, Vec256<int32_t> a,
                            Vec256<int32_t> b) {
  return Mask256<int32_t>{_mm256_cmpgt_epi32(a.raw, b.raw)};
}
HWY_API Mask256<int64_t> Gt(hwy::SignedTag /*tag*/, Vec256<int64_t> a,
                            Vec256<int64_t> b) {
  return Mask256<int64_t>{_mm256_cmpgt_epi64(a.raw, b.raw)};
}

// Unsigned: flip the sign bit of both operands, then use the signed compare.
template <typename T>
HWY_INLINE Mask256<T> Gt(hwy::UnsignedTag /*tag*/, Vec256<T> a, Vec256<T> b) {
  const Full256<T> du;
  const RebindToSigned<decltype(du)> di;
  const Vec256<T> msb = Set(du, (LimitsMax<T>() >> 1) + 1);
  return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb)));
}

HWY_API Mask256<float> Gt(hwy::FloatTag /*tag*/, Vec256<float> a,
                          Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)};
}
HWY_API Mask256<double> Gt(hwy::FloatTag /*tag*/, Vec256<double> a,
                           Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)};
}

}  // namespace detail
template <typename T>
HWY_API Mask256<T> operator>(Vec256<T> a, Vec256<T> b) {
  return detail::Gt(hwy::TypeTag<T>(), a, b);
}
// ------------------------------ Weak inequality

namespace detail {

// For integers, a >= b is simply !(b > a).
template <typename T>
HWY_INLINE Mask256<T> Ge(hwy::SignedTag tag, Vec256<T> a, Vec256<T> b) {
  return Not(Gt(tag, b, a));
}
template <typename T>
HWY_INLINE Mask256<T> Ge(hwy::UnsignedTag tag, Vec256<T> a, Vec256<T> b) {
  return Not(Gt(tag, b, a));
}

HWY_INLINE Mask256<float> Ge(hwy::FloatTag /*tag*/, Vec256<float> a,
                             Vec256<float> b) {
  return Mask256<float>{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)};
}
HWY_INLINE Mask256<double> Ge(hwy::FloatTag /*tag*/, Vec256<double> a,
                              Vec256<double> b) {
  return Mask256<double>{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)};
}

}  // namespace detail
template <typename T>
HWY_API Mask256<T> operator>=(Vec256<T> a, Vec256<T> b) {
  return detail::Ge(hwy::TypeTag<T>(), a, b);
}
#endif // HWY_TARGET <= HWY_AVX3
// ------------------------------ Reversed comparisons

template <typename T>
HWY_API Mask256<T> operator<(const Vec256<T> a, const Vec256<T> b) {
  return b > a;
}

template <typename T>
HWY_API Mask256<T> operator<=(const Vec256<T> a, const Vec256<T> b) {
  return b >= a;
}
// ------------------------------ Min (Gt, IfThenElse)

// Unsigned
HWY_API Vec256<uint8_t> Min(const Vec256<uint8_t> a, const Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_min_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> Min(const Vec256<uint16_t> a,
                             const Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_min_epu16(a.raw, b.raw)};
}
HWY_API Vec256<uint32_t> Min(const Vec256<uint32_t> a,
                             const Vec256<uint32_t> b) {
  return Vec256<uint32_t>{_mm256_min_epu32(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> Min(const Vec256<uint64_t> a,
                             const Vec256<uint64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint64_t>{_mm256_min_epu64(a.raw, b.raw)};
#else
  // No epu64 min on AVX2: emulate the unsigned compare by toggling the sign
  // bit, then select.
  const Full256<uint64_t> du;
  const Full256<int64_t> di;
  const auto msb = Set(du, 1ull << 63);
  const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb));
  return IfThenElse(gt, b, a);
#endif
}

// Signed
HWY_API Vec256<int8_t> Min(const Vec256<int8_t> a, const Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_min_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> Min(const Vec256<int16_t> a, const Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_min_epi16(a.raw, b.raw)};
}
HWY_API Vec256<int32_t> Min(const Vec256<int32_t> a, const Vec256<int32_t> b) {
  return Vec256<int32_t>{_mm256_min_epi32(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> Min(const Vec256<int64_t> a, const Vec256<int64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<int64_t>{_mm256_min_epi64(a.raw, b.raw)};
#else
  return IfThenElse(a < b, a, b);
#endif
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Min(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_min_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Min(const Vec256<float> a, const Vec256<float> b) {
  return Vec256<float>{_mm256_min_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Min(const Vec256<double> a, const Vec256<double> b) {
  return Vec256<double>{_mm256_min_pd(a.raw, b.raw)};
}
// ------------------------------ Max (Gt, IfThenElse)

// Unsigned
HWY_API Vec256<uint8_t> Max(const Vec256<uint8_t> a, const Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_max_epu8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> Max(const Vec256<uint16_t> a,
                             const Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_max_epu16(a.raw, b.raw)};
}
HWY_API Vec256<uint32_t> Max(const Vec256<uint32_t> a,
                             const Vec256<uint32_t> b) {
  return Vec256<uint32_t>{_mm256_max_epu32(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> Max(const Vec256<uint64_t> a,
                             const Vec256<uint64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<uint64_t>{_mm256_max_epu64(a.raw, b.raw)};
#else
  // No epu64 max on AVX2: emulate the unsigned compare by toggling the sign
  // bit, then select.
  const Full256<uint64_t> du;
  const Full256<int64_t> di;
  const auto msb = Set(du, 1ull << 63);
  const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb));
  return IfThenElse(gt, a, b);
#endif
}

// Signed
HWY_API Vec256<int8_t> Max(const Vec256<int8_t> a, const Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_max_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> Max(const Vec256<int16_t> a, const Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_max_epi16(a.raw, b.raw)};
}
HWY_API Vec256<int32_t> Max(const Vec256<int32_t> a, const Vec256<int32_t> b) {
  return Vec256<int32_t>{_mm256_max_epi32(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> Max(const Vec256<int64_t> a, const Vec256<int64_t> b) {
#if HWY_TARGET <= HWY_AVX3
  return Vec256<int64_t>{_mm256_max_epi64(a.raw, b.raw)};
#else
  return IfThenElse(a < b, b, a);
#endif
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> Max(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_max_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> Max(const Vec256<float> a, const Vec256<float> b) {
  return Vec256<float>{_mm256_max_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> Max(const Vec256<double> a, const Vec256<double> b) {
  return Vec256<double>{_mm256_max_pd(a.raw, b.raw)};
}
// ------------------------------ Iota

namespace detail {

// Returns a vector whose lanes are 0, 1, 2, ... (note _mm256_set_* takes
// arguments in highest-to-lowest lane order).
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm256_set_epi8(
      static_cast<char>(31), static_cast<char>(30), static_cast<char>(29),
      static_cast<char>(28), static_cast<char>(27), static_cast<char>(26),
      static_cast<char>(25), static_cast<char>(24), static_cast<char>(23),
      static_cast<char>(22), static_cast<char>(21), static_cast<char>(20),
      static_cast<char>(19), static_cast<char>(18), static_cast<char>(17),
      static_cast<char>(16), static_cast<char>(15), static_cast<char>(14),
      static_cast<char>(13), static_cast<char>(12), static_cast<char>(11),
      static_cast<char>(10), static_cast<char>(9), static_cast<char>(8),
      static_cast<char>(7), static_cast<char>(6), static_cast<char>(5),
      static_cast<char>(4), static_cast<char>(3), static_cast<char>(2),
      static_cast<char>(1), static_cast<char>(0))};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm256_set_epi16(
      int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12}, int16_t{11},
      int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6}, int16_t{5},
      int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})};
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{
      _mm256_set_ph(float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12},
                    float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8},
                    float16_t{7}, float16_t{6}, float16_t{5}, float16_t{4},
                    float16_t{3}, float16_t{2}, float16_t{1}, float16_t{0})};
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm256_set_epi32(int32_t{7}, int32_t{6}, int32_t{5},
                                    int32_t{4}, int32_t{3}, int32_t{2},
                                    int32_t{1}, int32_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{
      _mm256_set_epi64x(int64_t{3}, int64_t{2}, int64_t{1}, int64_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{
      _mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f)};
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm256_set_pd(3.0, 2.0, 1.0, 0.0)};
}

}  // namespace detail
// Returns first + {0, 1, 2, ...} in each lane.
template <class D, HWY_IF_V_SIZE_D(D, 32), typename T2>
HWY_API VFromD<D> Iota(D d, const T2 first) {
  return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first));
}
// ------------------------------ FirstN (Iota, Lt)

// Returns a mask with the first n lanes set (n clamped to the lane count).
template <class D, HWY_IF_V_SIZE_D(D, 32), class M = MFromD<D>>
HWY_API M FirstN(const D d, size_t n) {
  constexpr size_t kN = MaxLanes(d);
  // For AVX3, this ensures `n` <= 255 as required by bzhi, which only looks
  // at the lower 8 bits; for AVX2 and below, this ensures `n` fits in TI.
  n = HWY_MIN(n, kN);
#if HWY_TARGET <= HWY_AVX3
#if HWY_ARCH_X86_64
  const uint64_t all = (1ull << kN) - 1;
  return M::FromBits(_bzhi_u64(all, n));
#else
  const uint32_t all = static_cast<uint32_t>((1ull << kN) - 1);
  return M::FromBits(_bzhi_u32(all, static_cast<uint32_t>(n)));
#endif  // HWY_ARCH_X86_64
#else
  const RebindToSigned<decltype(d)> di;
  // Signed comparisons are cheaper.
  using TI = TFromD<decltype(di)>;
  return RebindMask(d, detail::Iota0(di) < Set(di, static_cast<TI>(n)));
#endif
}
// ================================================== ARITHMETIC

// ------------------------------ Addition

// Unsigned
HWY_API Vec256<uint8_t> operator+(Vec256<uint8_t> a, Vec256<uint8_t> b) {
  return Vec256<uint8_t>{_mm256_add_epi8(a.raw, b.raw)};
}
HWY_API Vec256<uint16_t> operator+(Vec256<uint16_t> a, Vec256<uint16_t> b) {
  return Vec256<uint16_t>{_mm256_add_epi16(a.raw, b.raw)};
}
HWY_API Vec256<uint32_t> operator+(Vec256<uint32_t> a, Vec256<uint32_t> b) {
  return Vec256<uint32_t>{_mm256_add_epi32(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> operator+(Vec256<uint64_t> a, Vec256<uint64_t> b) {
  return Vec256<uint64_t>{_mm256_add_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec256<int8_t> operator+(Vec256<int8_t> a, Vec256<int8_t> b) {
  return Vec256<int8_t>{_mm256_add_epi8(a.raw, b.raw)};
}
HWY_API Vec256<int16_t> operator+(Vec256<int16_t> a, Vec256<int16_t> b) {
  return Vec256<int16_t>{_mm256_add_epi16(a.raw, b.raw)};
}
HWY_API Vec256<int32_t> operator+(Vec256<int32_t> a, Vec256<int32_t> b) {
  return Vec256<int32_t>{_mm256_add_epi32(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> operator+(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Vec256<int64_t>{_mm256_add_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec256<float16_t> operator+(Vec256<float16_t> a, Vec256<float16_t> b) {
  return Vec256<float16_t>{_mm256_add_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec256<float> operator+(Vec256<float> a, Vec256<float> b) {
  return Vec256<float>{_mm256_add_ps(a.raw, b.raw)};
}
HWY_API Vec256<double> operator+(Vec256<double> a, Vec256<double> b) {
  return Vec256<double>{_mm256_add_pd(a.raw, b.raw)};
}
--> --------------------
--> maximum size reached
--> --------------------