Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/highway/hwy/ops/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 281 kB image not shown  

Quelle  x86_512-inl.h   Sprache: C

 
// Copyright 2019 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 512-bit AVX512 vectors and operations.
// External include guard in highway.h - see comment there.

// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.

// Must come before HWY_DIAGNOSTICS and HWY_COMPILER_CLANGCL
#include "hwy/base.h"

// Avoid uninitialized warnings in GCC's avx512fintrin.h - see
// https://github.com/google/highway/issues/710)
HWY_DIAGNOSTICS(push)
#if HWY_COMPILER_GCC_ACTUAL
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")
HWY_DIAGNOSTICS_OFF(disable : 4701 4703 6001 26494,
                    ignored "-Wmaybe-uninitialized")
#endif

#include <immintrin.h>  // AVX2+

#if HWY_COMPILER_CLANGCL
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
// clang-format off
#include <smmintrin.h>

#include <avxintrin.h>
// avxintrin defines __m256i and must come before avx2intrin.
#include <avx2intrin.h>
#include <f16cintrin.h>
#include <fmaintrin.h>

#include <avx512fintrin.h>
#include <avx512vlintrin.h>
#include <avx512bwintrin.h>
#include <avx512vlbwintrin.h>
#include <avx512dqintrin.h>
#include <avx512vldqintrin.h>
#include <avx512cdintrin.h>
#include <avx512vlcdintrin.h>

#if HWY_TARGET <= HWY_AVX3_DL
#include <avx512bitalgintrin.h>
#include <avx512vlbitalgintrin.h>
#include <avx512vbmiintrin.h>
#include <avx512vbmivlintrin.h>
#include <avx512vbmi2intrin.h>
#include <avx512vlvbmi2intrin.h>
#include <avx512vpopcntdqintrin.h>
#include <avx512vpopcntdqvlintrin.h>
#include <avx512vnniintrin.h>
#include <avx512vlvnniintrin.h>
// Must come after avx512fintrin, else will not define 512-bit intrinsics.
#include <vaesintrin.h>
#include <vpclmulqdqintrin.h>
#include <gfniintrin.h>
#endif  // HWY_TARGET <= HWY_AVX3_DL

#if HWY_TARGET <= HWY_AVX3_SPR
#include <avx512fp16intrin.h>
#include <avx512vlfp16intrin.h>
#endif  // HWY_TARGET <= HWY_AVX3_SPR

// clang-format on
#endif  // HWY_COMPILER_CLANGCL

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_256-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {

template <typename T>
struct Raw512 {
  using type = __m512i;
};
#if HWY_HAVE_FLOAT16
template <>
struct Raw512<float16_t> {
  using type = __m512h;
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct Raw512<float> {
  using type = __m512;
};
template <>
struct Raw512<double> {
  using type = __m512d;
};

// Template arg: sizeof(lane type)
template <size_t size>
struct RawMask512 {};
template <>
struct RawMask512<1> {
  using type = __mmask64;
};
template <>
struct RawMask512<2> {
  using type = __mmask32;
};
template <>
struct RawMask512<4> {
  using type = __mmask16;
};
template <>
struct RawMask512<8> {
  using type = __mmask8;
};

}  // namespace detail

template <typename T>
class Vec512 {
  using Raw = typename detail::Raw512<T>::type;

 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 64 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec512& operator*=(const Vec512 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec512& operator/=(const Vec512 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec512& operator+=(const Vec512 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec512& operator-=(const Vec512 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec512& operator%=(const Vec512 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec512& operator&=(const Vec512 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec512& operator|=(const Vec512 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec512& operator^=(const Vec512 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

// Mask register: one bit per lane.
template <typename T>
struct Mask512 {
  using Raw = typename detail::RawMask512<sizeof(T)>::type;
  Raw raw;
};

template <typename T>
using Full512 = Simd<T, 64 / sizeof(T), 0>;

// ------------------------------ BitCast

namespace detail {

HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; }
#if HWY_HAVE_FLOAT16
HWY_INLINE __m512i BitCastToInteger(__m512h v) {
  return _mm512_castph_si512(v);
}
#endif  // HWY_HAVE_FLOAT16
HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); }
HWY_INLINE __m512i BitCastToInteger(__m512d v) {
  return _mm512_castpd_si512(v);
}

template <typename T>
HWY_INLINE Vec512<uint8_t> BitCastToByte(Vec512<T> v) {
  return Vec512<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template <typename T>
struct BitCastFromInteger512 {
  HWY_INLINE __m512i operator()(__m512i v) { return v; }
};
#if HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger512<float16_t> {
  HWY_INLINE __m512h operator()(__m512i v) { return _mm512_castsi512_ph(v); }
};
#endif  // HWY_HAVE_FLOAT16
template <>
struct BitCastFromInteger512<float> {
  HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); }
};
template <>
struct BitCastFromInteger512<double> {
  HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); }
};

template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec512<uint8_t> v) {
  return VFromD<D>{BitCastFromInteger512<TFromD<D>>()(v.raw)};
}

}  // namespace detail

template <class D, HWY_IF_V_SIZE_D(D, 64), typename FromT>
HWY_API VFromD<D> BitCast(D d, Vec512<FromT> v) {
  return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}

// ------------------------------ Set

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi8(static_cast<char>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi16(static_cast<short>(t))};  // NOLINT
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi32(static_cast<int>(t))};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
  return VFromD<D>{_mm512_set1_epi64(static_cast<long long>(t))};  // NOLINT
}
// bfloat16_t is handled by x86_128-inl.h.
#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Set(D /* tag */, float16_t t) {
  return Vec512<float16_t>{_mm512_set1_ph(t)};
}
#endif  // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Set(D /* tag */, float t) {
  return Vec512<float>{_mm512_set1_ps(t)};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Set(D /* tag */, double t) {
  return Vec512<double>{_mm512_set1_pd(t)};
}

// ------------------------------ Zero (Set)

// GCC pre-9.1 lacked setzero, so use Set instead.
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900

// Cannot use VFromD here because it is defined in terms of Zero.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_SPECIAL_FLOAT_D(D)>
HWY_API Vec512<TFromD<D>> Zero(D d) {
  return Set(d, TFromD<D>{0});
}
// BitCast is defined below, but the Raw type is the same, so use that.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Zero(D /* tag */) {
  const RebindToUnsigned<D> du;
  return Vec512<bfloat16_t>{Set(du, 0).raw};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Zero(D /* tag */) {
  const RebindToUnsigned<D> du;
  return Vec512<float16_t>{Set(du, 0).raw};
}

#else

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec512<TFromD<D>> Zero(D /* tag */) {
  return Vec512<TFromD<D>>{_mm512_setzero_si512()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Zero(D /* tag */) {
  return Vec512<bfloat16_t>{_mm512_setzero_si512()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Zero(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec512<float16_t>{_mm512_setzero_ph()};
#else
  return Vec512<float16_t>{_mm512_setzero_si512()};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Zero(D /* tag */) {
  return Vec512<float>{_mm512_setzero_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Zero(D /* tag */) {
  return Vec512<double>{_mm512_setzero_pd()};
}

#endif  // HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900

// ------------------------------ Undefined

HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")

// Returns a vector with uninitialized elements.
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
HWY_API Vec512<TFromD<D>> Undefined(D /* tag */) {
  // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC
  // generate an XOR instruction.
  return Vec512<TFromD<D>>{_mm512_undefined_epi32()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_BF16_D(D)>
HWY_API Vec512<bfloat16_t> Undefined(D /* tag */) {
  return Vec512<bfloat16_t>{_mm512_undefined_epi32()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_API Vec512<float16_t> Undefined(D /* tag */) {
#if HWY_HAVE_FLOAT16
  return Vec512<float16_t>{_mm512_undefined_ph()};
#else
  return Vec512<float16_t>{_mm512_undefined_epi32()};
#endif
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_API Vec512<float> Undefined(D /* tag */) {
  return Vec512<float>{_mm512_undefined_ps()};
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_API Vec512<double> Undefined(D /* tag */) {
  return Vec512<double>{_mm512_undefined_pd()};
}

HWY_DIAGNOSTICS(pop)

// ------------------------------ ResizeBitCast

// 64-byte vector to 16-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 64),
          HWY_IF_V_SIZE_D(D, 16)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec128<uint8_t>{_mm512_castsi512_si128(
                        BitCast(Full512<uint8_t>(), v).raw)});
}

// <= 16-byte vector to 64-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec512<uint8_t>{_mm512_castsi128_si512(
                        ResizeBitCast(Full128<uint8_t>(), v).raw)});
}

// 32-byte vector to 64-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, Vec512<uint8_t>{_mm512_castsi256_si512(
                        BitCast(Full256<uint8_t>(), v).raw)});
}

// ------------------------------ Dup128VecFromValues

template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
                                      TFromD<D> t11, TFromD<D> t12,
                                      TFromD<D> t13, TFromD<D> t14,
                                      TFromD<D> t15) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  return BroadcastBlock<0>(ResizeBitCast(
      d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3, t4, t5, t6,
                             t7, t8, t9, t10, t11, t12, t13, t14, t15)));
#else
  (void)d;
  // Need to use _mm512_set_epi8 as there is no _mm512_setr_epi8 intrinsic
  // available
  return VFromD<D>{_mm512_set_epi8(
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0), static_cast<char>(t15), static_cast<char>(t14),
      static_cast<char>(t13), static_cast<char>(t12), static_cast<char>(t11),
      static_cast<char>(t10), static_cast<char>(t9), static_cast<char>(t8),
      static_cast<char>(t7), static_cast<char>(t6), static_cast<char>(t5),
      static_cast<char>(t4), static_cast<char>(t3), static_cast<char>(t2),
      static_cast<char>(t1), static_cast<char>(t0), static_cast<char>(t15),
      static_cast<char>(t14), static_cast<char>(t13), static_cast<char>(t12),
      static_cast<char>(t11), static_cast<char>(t10), static_cast<char>(t9),
      static_cast<char>(t8), static_cast<char>(t7), static_cast<char>(t6),
      static_cast<char>(t5), static_cast<char>(t4), static_cast<char>(t3),
      static_cast<char>(t2), static_cast<char>(t1), static_cast<char>(t0),
      static_cast<char>(t15), static_cast<char>(t14), static_cast<char>(t13),
      static_cast<char>(t12), static_cast<char>(t11), static_cast<char>(t10),
      static_cast<char>(t9), static_cast<char>(t8), static_cast<char>(t7),
      static_cast<char>(t6), static_cast<char>(t5), static_cast<char>(t4),
      static_cast<char>(t3), static_cast<char>(t2), static_cast<char>(t1),
      static_cast<char>(t0))};
#endif
}

template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  return BroadcastBlock<0>(
      ResizeBitCast(d, Dup128VecFromValues(Full128<TFromD<D>>(), t0, t1, t2, t3,
                                           t4, t5, t6, t7)));
#else
  (void)d;
  // Need to use _mm512_set_epi16 as there is no _mm512_setr_epi16 intrinsic
  // available
  return VFromD<D>{
      _mm512_set_epi16(static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0),
                       static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0),
                       static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0),
                       static_cast<int16_t>(t7), static_cast<int16_t>(t6),
                       static_cast<int16_t>(t5), static_cast<int16_t>(t4),
                       static_cast<int16_t>(t3), static_cast<int16_t>(t2),
                       static_cast<int16_t>(t1), static_cast<int16_t>(t0))};
#endif
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  return VFromD<D>{_mm512_setr_ph(t0, t1, t2, t3, t4, t5, t6, t7, t0, t1, t2,
                                  t3, t4, t5, t6, t7, t0, t1, t2, t3, t4, t5,
                                  t6, t7, t0, t1, t2, t3, t4, t5, t6, t7)};
}
#endif

template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{
      _mm512_setr_epi32(static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3),
                        static_cast<int32_t>(t0), static_cast<int32_t>(t1),
                        static_cast<int32_t>(t2), static_cast<int32_t>(t3))};
}

template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  return VFromD<D>{_mm512_setr_ps(t0, t1, t2, t3, t0, t1, t2, t3, t0, t1, t2,
                                  t3, t0, t1, t2, t3)};
}

template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{
      _mm512_setr_epi64(static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                        static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                        static_cast<int64_t>(t0), static_cast<int64_t>(t1),
                        static_cast<int64_t>(t0), static_cast<int64_t>(t1))};
}

template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
  return VFromD<D>{_mm512_setr_pd(t0, t1, t0, t1, t0, t1, t0, t1)};
}

// ----------------------------- Iota

namespace detail {

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE VFromD<D> Iota0(D d) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  alignas(64) static constexpr TFromD<D> kIota[64] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
      32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
      48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63};
  return Load(d, kIota);
#else
  (void)d;
  return VFromD<D>{_mm512_set_epi8(
      static_cast<char>(63), static_cast<char>(62), static_cast<char>(61),
      static_cast<char>(60), static_cast<char>(59), static_cast<char>(58),
      static_cast<char>(57), static_cast<char>(56), static_cast<char>(55),
      static_cast<char>(54), static_cast<char>(53), static_cast<char>(52),
      static_cast<char>(51), static_cast<char>(50), static_cast<char>(49),
      static_cast<char>(48), static_cast<char>(47), static_cast<char>(46),
      static_cast<char>(45), static_cast<char>(44), static_cast<char>(43),
      static_cast<char>(42), static_cast<char>(41), static_cast<char>(40),
      static_cast<char>(39), static_cast<char>(38), static_cast<char>(37),
      static_cast<char>(36), static_cast<char>(35), static_cast<char>(34),
      static_cast<char>(33), static_cast<char>(32), static_cast<char>(31),
      static_cast<char>(30), static_cast<char>(29), static_cast<char>(28),
      static_cast<char>(27), static_cast<char>(26), static_cast<char>(25),
      static_cast<char>(24), static_cast<char>(23), static_cast<char>(22),
      static_cast<char>(21), static_cast<char>(20), static_cast<char>(19),
      static_cast<char>(18), static_cast<char>(17), static_cast<char>(16),
      static_cast<char>(15), static_cast<char>(14), static_cast<char>(13),
      static_cast<char>(12), static_cast<char>(11), static_cast<char>(10),
      static_cast<char>(9), static_cast<char>(8), static_cast<char>(7),
      static_cast<char>(6), static_cast<char>(5), static_cast<char>(4),
      static_cast<char>(3), static_cast<char>(2), static_cast<char>(1),
      static_cast<char>(0))};
#endif
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI16_D(D)>
HWY_INLINE VFromD<D> Iota0(D d) {
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 900
  // Missing set_epi8/16.
  alignas(64) static constexpr TFromD<D> kIota[32] = {
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31};
  return Load(d, kIota);
#else
  (void)d;
  return VFromD<D>{_mm512_set_epi16(
      int16_t{31}, int16_t{30}, int16_t{29}, int16_t{28}, int16_t{27},
      int16_t{26}, int16_t{25}, int16_t{24}, int16_t{23}, int16_t{22},
      int16_t{21}, int16_t{20}, int16_t{19}, int16_t{18}, int16_t{17},
      int16_t{16}, int16_t{15}, int16_t{14}, int16_t{13}, int16_t{12},
      int16_t{11}, int16_t{10}, int16_t{9}, int16_t{8}, int16_t{7}, int16_t{6},
      int16_t{5}, int16_t{4}, int16_t{3}, int16_t{2}, int16_t{1}, int16_t{0})};
#endif
}

#if HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F16_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_ph(
      float16_t{31}, float16_t{30}, float16_t{29}, float16_t{28}, float16_t{27},
      float16_t{26}, float16_t{25}, float16_t{24}, float16_t{23}, float16_t{22},
      float16_t{21}, float16_t{20}, float16_t{19}, float16_t{18}, float16_t{17},
      float16_t{16}, float16_t{15}, float16_t{14}, float16_t{13}, float16_t{12},
      float16_t{11}, float16_t{10}, float16_t{9}, float16_t{8}, float16_t{7},
      float16_t{6}, float16_t{5}, float16_t{4}, float16_t{3}, float16_t{2},
      float16_t{1}, float16_t{0})};
}
#endif  // HWY_HAVE_FLOAT16

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_epi32(
      int32_t{15}, int32_t{14}, int32_t{13}, int32_t{12}, int32_t{11},
      int32_t{10}, int32_t{9}, int32_t{8}, int32_t{7}, int32_t{6}, int32_t{5},
      int32_t{4}, int32_t{3}, int32_t{2}, int32_t{1}, int32_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_UI64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_epi64(int64_t{7}, int64_t{6}, int64_t{5},
                                    int64_t{4}, int64_t{3}, int64_t{2},
                                    int64_t{1}, int64_t{0})};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F32_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f,
                                 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
                                 0.0f)};
}

template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_F64_D(D)>
HWY_INLINE VFromD<D> Iota0(D /*d*/) {
  return VFromD<D>{_mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)};
}

}  // namespace detail

template <class D, typename T2, HWY_IF_V_SIZE_D(D, 64)>
HWY_API VFromD<D> Iota(D d, const T2 first) {
  return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first));
}

// ================================================== LOGICAL

// ------------------------------ Not

template <typename T>
HWY_API Vec512<T> Not(const Vec512<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i vu = BitCast(du, v).raw;
  return BitCast(d, VU{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)});
}

// ------------------------------ And

template <typename T>
HWY_API Vec512<T> And(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_and_si512(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}

HWY_API Vec512<floatAnd(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_and_ps(a.raw, b.raw)};
}
HWY_API Vec512<doubleAnd(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_and_pd(a.raw, b.raw)};
}

// ------------------------------ AndNot

// Returns ~not_mask & mask.
template <typename T>
HWY_API Vec512<T> AndNot(const Vec512<T> not_mask, const Vec512<T> mask) {
  const DFromV<decltype(mask)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_andnot_si512(
                        BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
}
HWY_API Vec512<float> AndNot(const Vec512<float> not_mask,
                             const Vec512<float> mask) {
  return Vec512<float>{_mm512_andnot_ps(not_mask.raw, mask.raw)};
}
HWY_API Vec512<double> AndNot(const Vec512<double> not_mask,
                              const Vec512<double> mask) {
  return Vec512<double>{_mm512_andnot_pd(not_mask.raw, mask.raw)};
}

// ------------------------------ Or

template <typename T>
HWY_API Vec512<T> Or(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_or_si512(BitCast(du, a).raw,
                                                         BitCast(du, b).raw)});
}

HWY_API Vec512<floatOr(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_or_ps(a.raw, b.raw)};
}
HWY_API Vec512<doubleOr(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_or_pd(a.raw, b.raw)};
}

// ------------------------------ Xor

template <typename T>
HWY_API Vec512<T> Xor(const Vec512<T> a, const Vec512<T> b) {
  const DFromV<decltype(a)> d;  // for float16_t
  const RebindToUnsigned<decltype(d)> du;
  return BitCast(d, VFromD<decltype(du)>{_mm512_xor_si512(BitCast(du, a).raw,
                                                          BitCast(du, b).raw)});
}

HWY_API Vec512<floatXor(const Vec512<float> a, const Vec512<float> b) {
  return Vec512<float>{_mm512_xor_ps(a.raw, b.raw)};
}
HWY_API Vec512<doubleXor(const Vec512<double> a, const Vec512<double> b) {
  return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)};
}

// ------------------------------ Xor3
template <typename T>
HWY_API Vec512<T> Xor3(Vec512<T> x1, Vec512<T> x2, Vec512<T> x3) {
  const DFromV<decltype(x1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i ret = _mm512_ternarylogic_epi64(
      BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
  return BitCast(d, VU{ret});
}

// ------------------------------ Or3
template <typename T>
HWY_API Vec512<T> Or3(Vec512<T> o1, Vec512<T> o2, Vec512<T> o3) {
  const DFromV<decltype(o1)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i ret = _mm512_ternarylogic_epi64(
      BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
  return BitCast(d, VU{ret});
}

// ------------------------------ OrAnd
template <typename T>
HWY_API Vec512<T> OrAnd(Vec512<T> o, Vec512<T> a1, Vec512<T> a2) {
  const DFromV<decltype(o)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  const __m512i ret = _mm512_ternarylogic_epi64(
      BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
  return BitCast(d, VU{ret});
}

// ------------------------------ IfVecThenElse
template <typename T>
HWY_API Vec512<T> IfVecThenElse(Vec512<T> mask, Vec512<T> yes, Vec512<T> no) {
  const DFromV<decltype(yes)> d;
  const RebindToUnsigned<decltype(d)> du;
  using VU = VFromD<decltype(du)>;
  return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw,
                                                 BitCast(du, yes).raw,
                                                 BitCast(du, no).raw, 0xCA)});
}

// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec512<T> operator&(const Vec512<T> a, const Vec512<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec512<T> operator|(const Vec512<T> a, const Vec512<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec512<T> operator^(const Vec512<T> a, const Vec512<T> b) {
  return Xor(a, b);
}

// ------------------------------ PopulationCount

// 8/16 require BITALG, 32/64 require VPOPCNTDQ.
#if HWY_TARGET <= HWY_AVX3_DL

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

namespace detail {

template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi8(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi16(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi32(v.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec512<T> v) {
  return Vec512<T>{_mm512_popcnt_epi64(v.raw)};
}

}  // namespace detail

template <typename T>
HWY_API Vec512<T> PopulationCount(Vec512<T> v) {
  return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
}

#endif  // HWY_TARGET <= HWY_AVX3_DL

// ================================================== MASK

// ------------------------------ FirstN

// Possibilities for constructing a bitmask of N ones:
// - kshift* only consider the lowest byte of the shift count, so they would
//   not correctly handle large n.
// - Scalar shifts >= 64 are UB.
// - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However,
//   we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds.

#if HWY_ARCH_X86_32
namespace detail {

// 32 bit mask is sufficient for lane size >= 2.
template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  Mask512<T> m;
  const uint32_t all = ~uint32_t{0};
  // BZHI only looks at the lower 8 bits of n, but it has been clamped to
  // MaxLanes, which is at most 32.
  m.raw = static_cast<decltype(m.raw)>(_bzhi_u32(all, n));
  return m;
}

#if HWY_COMPILER_MSVC >= 1920 || HWY_COMPILER_GCC_ACTUAL >= 900 || \
    HWY_COMPILER_CLANG || HWY_COMPILER_ICC
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  uint32_t lo_mask;
  uint32_t hi_mask;
  uint32_t hi_mask_len;
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(n >= 32) && n >= 32) {
    if (__builtin_constant_p(n >= 64) && n >= 64) {
      hi_mask_len = 32u;
    } else {
      hi_mask_len = static_cast<uint32_t>(n) - 32u;
    }
    lo_mask = hi_mask = 0xFFFFFFFFu;
  } else  // NOLINT(readability/braces)
#endif
  {
    const uint32_t lo_mask_len = static_cast<uint32_t>(n);
    lo_mask = _bzhi_u32(0xFFFFFFFFu, lo_mask_len);

#if HWY_COMPILER_GCC
    if (__builtin_constant_p(lo_mask_len <= 32) && lo_mask_len <= 32) {
      return Mask512<T>{static_cast<__mmask64>(lo_mask)};
    }
#endif

    _addcarry_u32(_subborrow_u32(0, lo_mask_len, 32u, &hi_mask_len),
                  0xFFFFFFFFu, 0u, &hi_mask);
  }
  hi_mask = _bzhi_u32(hi_mask, hi_mask_len);
#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC
  if (__builtin_constant_p((static_cast<uint64_t>(hi_mask) << 32) | lo_mask))
#endif
    return Mask512<T>{static_cast<__mmask64>(
        (static_cast<uint64_t>(hi_mask) << 32) | lo_mask)};
#if HWY_COMPILER_GCC && !HWY_COMPILER_ICC
  else
    return Mask512<T>{_mm512_kunpackd(static_cast<__mmask64>(hi_mask),
                                      static_cast<__mmask64>(lo_mask))};
#endif
}
#else   // HWY_COMPILER..
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE Mask512<T> FirstN(size_t n) {
  const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t{0};
  return Mask512<T>{static_cast<__mmask64>(bits)};
}
#endif  // HWY_COMPILER..
}  // namespace detail
#endif  // HWY_ARCH_X86_32

template <class D, HWY_IF_V_SIZE_D(D, 64)>
HWY_API MFromD<D> FirstN(D d, size_t n) {
  // This ensures `num` <= 255 as required by bzhi, which only looks
  // at the lower 8 bits.
  n = HWY_MIN(n, MaxLanes(d));

#if HWY_ARCH_X86_64
  MFromD<D> m;
  const uint64_t all = ~uint64_t{0};
  m.raw = static_cast<decltype(m.raw)>(_bzhi_u64(all, n));
  return m;
#else
  return detail::FirstN<TFromD<D>>(n);
#endif  // HWY_ARCH_X86_64
}

// ------------------------------ IfThenElse

// Returns mask ? b : a.

namespace detail {

// Templates for signed/unsigned integer of a particular size.
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<1> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi8(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<2> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi16(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<4> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi32(mask.raw, no.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElse(hwy::SizeTag<8> /* tag */,
                                const Mask512<T> mask, const Vec512<T> yes,
                                const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_blend_epi64(mask.raw, no.raw, yes.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenElse(const Mask512<T> mask, const Vec512<T> yes,
                             const Vec512<T> no) {
  return detail::IfThenElse(hwy::SizeTag<sizeof(T)>(), mask, yes, no);
}
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> IfThenElse(Mask512<float16_t> mask,
                                     Vec512<float16_t> yes,
                                     Vec512<float16_t> no) {
  return Vec512<float16_t>{_mm512_mask_blend_ph(mask.raw, no.raw, yes.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> IfThenElse(Mask512<float> mask, Vec512<float> yes,
                                 Vec512<float> no) {
  return Vec512<float>{_mm512_mask_blend_ps(mask.raw, no.raw, yes.raw)};
}
HWY_API Vec512<double> IfThenElse(Mask512<double> mask, Vec512<double> yes,
                                  Vec512<double> no) {
  return Vec512<double>{_mm512_mask_blend_pd(mask.raw, no.raw, yes.raw)};
}

namespace detail {

template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<1> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi8(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<2> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi16(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<4> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi32(mask.raw, yes.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenElseZero(hwy::SizeTag<8> /* tag */,
                                    const Mask512<T> mask,
                                    const Vec512<T> yes) {
  return Vec512<T>{_mm512_maskz_mov_epi64(mask.raw, yes.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenElseZero(const Mask512<T> mask, const Vec512<T> yes) {
  return detail::IfThenElseZero(hwy::SizeTag<sizeof(T)>(), mask, yes);
}
HWY_API Vec512<float> IfThenElseZero(Mask512<float> mask, Vec512<float> yes) {
  return Vec512<float>{_mm512_maskz_mov_ps(mask.raw, yes.raw)};
}
HWY_API Vec512<double> IfThenElseZero(Mask512<double> mask,
                                      Vec512<double> yes) {
  return Vec512<double>{_mm512_maskz_mov_pd(mask.raw, yes.raw)};
}

namespace detail {

template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<1> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16.
  return Vec512<T>{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<2> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<4> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)};
}
template <typename T>
HWY_INLINE Vec512<T> IfThenZeroElse(hwy::SizeTag<8> /* tag */,
                                    const Mask512<T> mask, const Vec512<T> no) {
  return Vec512<T>{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)};
}

}  // namespace detail

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T)>
HWY_API Vec512<T> IfThenZeroElse(const Mask512<T> mask, const Vec512<T> no) {
  return detail::IfThenZeroElse(hwy::SizeTag<sizeof(T)>(), mask, no);
}
HWY_API Vec512<float> IfThenZeroElse(Mask512<float> mask, Vec512<float> no) {
  return Vec512<float>{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)};
}
HWY_API Vec512<double> IfThenZeroElse(Mask512<double> mask, Vec512<double> no) {
  return Vec512<double>{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)};
}

template <typename T>
HWY_API Vec512<T> IfNegativeThenElse(Vec512<T> v, Vec512<T> yes, Vec512<T> no) {
  static_assert(IsSigned<T>(), "Only works for signed/float");
  // AVX3 MaskFromVec only looks at the MSB
  return IfThenElse(MaskFromVec(v), yes, no);
}

template <typename T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2) | (1 << 4))>
HWY_API Vec512<T> IfNegativeThenNegOrUndefIfZero(Vec512<T> mask, Vec512<T> v) {
  // AVX3 MaskFromVec only looks at the MSB
  const DFromV<decltype(v)> d;
  return MaskedSubOr(v, MaskFromVec(mask), Zero(d), v);
}

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec512<T> ZeroIfNegative(const Vec512<T> v) {
  // AVX3 MaskFromVec only looks at the MSB
  return IfThenZeroElse(MaskFromVec(v), v);
}

// ================================================== ARITHMETIC

// ------------------------------ Addition

// Unsigned
HWY_API Vec512<uint8_t> operator+(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator+(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator+(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator+(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator+(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_add_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator+(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_add_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator+(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_add_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator+(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_add_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> operator+(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_add_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<floatoperator+(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_add_ps(a.raw, b.raw)};
}
HWY_API Vec512<doubleoperator+(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_add_pd(a.raw, b.raw)};
}

// ------------------------------ Subtraction

// Unsigned
HWY_API Vec512<uint8_t> operator-(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> operator-(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator-(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator-(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> operator-(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_sub_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> operator-(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_sub_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator-(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_sub_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator-(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_sub_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> operator-(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_sub_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> operator-(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_sub_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> operator-(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_sub_pd(a.raw, b.raw)};
}

// ------------------------------ SumsOf8
HWY_API Vec512<uint64_t> SumsOf8(const Vec512<uint8_t> v) {
  const Full512<uint8_t> d;
  return Vec512<uint64_t>{_mm512_sad_epu8(v.raw, Zero(d).raw)};
}

HWY_API Vec512<uint64_t> SumsOf8AbsDiff(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint64_t>{_mm512_sad_epu8(a.raw, b.raw)};
}

// ------------------------------ SumsOf4
namespace detail {

HWY_INLINE Vec512<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/,
                                    hwy::SizeTag<1> /*lane_size_tag*/,
                                    Vec512<uint8_t> v) {
  const DFromV<decltype(v)> d;

  // _mm512_maskz_dbsad_epu8 is used below as the odd uint16_t lanes need to be
  // zeroed out and the sums of the 4 consecutive lanes are already in the
  // even uint16_t lanes of the _mm512_maskz_dbsad_epu8 result.
  return Vec512<uint32_t>{_mm512_maskz_dbsad_epu8(
      static_cast<__mmask32>(0x55555555), v.raw, Zero(d).raw, 0)};
}

// I8->I32 SumsOf4
// Generic for all vector lengths
template <class V>
HWY_INLINE VFromD<RepartitionToWideX2<DFromV<V>>> SumsOf4(
    hwy::SignedTag /*type_tag*/, hwy::SizeTag<1> /*lane_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const RepartitionToWideX2<decltype(d)> di32;

  // Adjust the values of v to be in the 0..255 range by adding 128 to each lane
  // of v (which is the same as an bitwise XOR of each i8 lane by 128) and then
  // bitcasting the Xor result to an u8 vector.
  const auto v_adj = BitCast(du, Xor(v, SignBit(d)));

  // Need to add -512 to each i32 lane of the result of the
  // SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj) operation to account
  // for the adjustment made above.
  return BitCast(di32, SumsOf4(hwy::UnsignedTag(), hwy::SizeTag<1>(), v_adj)) +
         Set(di32, int32_t{-512});
}

}  // namespace detail

// ------------------------------ SumsOfShuffledQuadAbsDiff

#if HWY_TARGET <= HWY_AVX3
template <int kIdx3, int kIdx2, int kIdx1, int kIdx0>
static Vec512<uint16_t> SumsOfShuffledQuadAbsDiff(Vec512<uint8_t> a,
                                                  Vec512<uint8_t> b) {
  static_assert(0 <= kIdx0 && kIdx0 <= 3, "kIdx0 must be between 0 and 3");
  static_assert(0 <= kIdx1 && kIdx1 <= 3, "kIdx1 must be between 0 and 3");
  static_assert(0 <= kIdx2 && kIdx2 <= 3, "kIdx2 must be between 0 and 3");
  static_assert(0 <= kIdx3 && kIdx3 <= 3, "kIdx3 must be between 0 and 3");
  return Vec512<uint16_t>{
      _mm512_dbsad_epu8(b.raw, a.raw, _MM_SHUFFLE(kIdx3, kIdx2, kIdx1, kIdx0))};
}
#endif

// ------------------------------ SaturatedAdd

// Returns a + b clamped to the destination range.

// Unsigned
HWY_API Vec512<uint8_t> SaturatedAdd(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_adds_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedAdd(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_adds_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedAdd(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_adds_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedAdd(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_adds_epi16(a.raw, b.raw)};
}

// ------------------------------ SaturatedSub

// Returns a - b clamped to the destination range.

// Unsigned
HWY_API Vec512<uint8_t> SaturatedSub(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_subs_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> SaturatedSub(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_subs_epu16(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> SaturatedSub(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_subs_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> SaturatedSub(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_subs_epi16(a.raw, b.raw)};
}

// ------------------------------ Average

// Returns (a + b + 1) / 2

// Unsigned
HWY_API Vec512<uint8_t> AverageRound(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_avg_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> AverageRound(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_avg_epu16(a.raw, b.raw)};
}

// ------------------------------ Abs (Sub)

// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
HWY_API Vec512<int8_t> Abs(const Vec512<int8_t> v) {
#if HWY_COMPILER_MSVC
  // Workaround for incorrect codegen? (untested due to internal compiler error)
  const DFromV<decltype(v)> d;
  const auto zero = Zero(d);
  return Vec512<int8_t>{_mm512_max_epi8(v.raw, (zero - v).raw)};
#else
  return Vec512<int8_t>{_mm512_abs_epi8(v.raw)};
#endif
}
HWY_API Vec512<int16_t> Abs(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_abs_epi16(v.raw)};
}
HWY_API Vec512<int32_t> Abs(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_abs_epi32(v.raw)};
}
HWY_API Vec512<int64_t> Abs(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_abs_epi64(v.raw)};
}

// ------------------------------ ShiftLeft

#if HWY_TARGET <= HWY_AVX3_DL
namespace detail {
template <typename T>
HWY_API Vec512<T> GaloisAffine(Vec512<T> v, Vec512<uint64_t> matrix) {
  return Vec512<T>{_mm512_gf2p8affine_epi64_epi8(v.raw, matrix.raw, 0)};
}
}  // namespace detail
#endif  // HWY_TARGET <= HWY_AVX3_DL

template <int kBits>
HWY_API Vec512<uint16_t> ShiftLeft(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftLeft(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftLeft(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_slli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftLeft(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_slli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftLeft(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_slli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftLeft(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_slli_epi64(v.raw, kBits)};
}

#if HWY_TARGET <= HWY_AVX3_DL

// Generic for all vector lengths. Must be defined after all GaloisAffine.
template <int kBits, class V, HWY_IF_T_SIZE_V(V, 1)>
HWY_API V ShiftLeft(const V v) {
  const Repartition<uint64_t, DFromV<V>> du64;
  if (kBits == 0) return v;
  if (kBits == 1) return v + v;
  constexpr uint64_t kMatrix = (0x0102040810204080ULL >> kBits) &
                               (0x0101010101010101ULL * (0xFF >> kBits));
  return detail::GaloisAffine(v, Set(du64, kMatrix));
}

#else  // HWY_TARGET > HWY_AVX3_DL

template <int kBits, typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeft(const Vec512<T> v) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeft<kBits>(BitCast(d16, v)));
  return kBits == 1
             ? (v + v)
             : (shifted & Set(d8, static_cast<T>((0xFF << kBits) & 0xFF)));
}

#endif  // HWY_TARGET > HWY_AVX3_DL

// ------------------------------ ShiftRight

template <int kBits>
HWY_API Vec512<uint16_t> ShiftRight(const Vec512<uint16_t> v) {
  return Vec512<uint16_t>{_mm512_srli_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint32_t> ShiftRight(const Vec512<uint32_t> v) {
  return Vec512<uint32_t>{_mm512_srli_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> ShiftRight(const Vec512<uint64_t> v) {
  return Vec512<uint64_t>{_mm512_srli_epi64(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int16_t> ShiftRight(const Vec512<int16_t> v) {
  return Vec512<int16_t>{_mm512_srai_epi16(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int32_t> ShiftRight(const Vec512<int32_t> v) {
  return Vec512<int32_t>{_mm512_srai_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<int64_t> ShiftRight(const Vec512<int64_t> v) {
  return Vec512<int64_t>{_mm512_srai_epi64(v.raw, kBits)};
}

#if HWY_TARGET <= HWY_AVX3_DL

// Generic for all vector lengths. Must be defined after all GaloisAffine.
template <int kBits, class V, HWY_IF_U8_D(DFromV<V>)>
HWY_API V ShiftRight(const V v) {
  const Repartition<uint64_t, DFromV<V>> du64;
  if (kBits == 0) return v;
  constexpr uint64_t kMatrix =
      (0x0102040810204080ULL << kBits) &
      (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF));
  return detail::GaloisAffine(v, Set(du64, kMatrix));
}

// Generic for all vector lengths. Must be defined after all GaloisAffine.
template <int kBits, class V, HWY_IF_I8_D(DFromV<V>)>
HWY_API V ShiftRight(const V v) {
  const Repartition<uint64_t, DFromV<V>> du64;
  if (kBits == 0) return v;
  constexpr uint64_t kShift =
      (0x0102040810204080ULL << kBits) &
      (0x0101010101010101ULL * ((0xFF << kBits) & 0xFF));
  constexpr uint64_t kSign =
      kBits == 0 ? 0 : (0x8080808080808080ULL >> (64 - (8 * kBits)));
  return detail::GaloisAffine(v, Set(du64, kShift | kSign));
}

#else  // HWY_TARGET > HWY_AVX3_DL

template <int kBits>
HWY_API Vec512<uint8_t> ShiftRight(const Vec512<uint8_t> v) {
  const DFromV<decltype(v)> d8;
  // Use raw instead of BitCast to support N=1.
  const Vec512<uint8_t> shifted{ShiftRight<kBits>(Vec512<uint16_t>{v.raw}).raw};
  return shifted & Set(d8, 0xFF >> kBits);
}

template <int kBits>
HWY_API Vec512<int8_t> ShiftRight(const Vec512<int8_t> v) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  const auto shifted = BitCast(di, ShiftRight<kBits>(BitCast(du, v)));
  const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits));
  return (shifted ^ shifted_sign) - shifted_sign;
}

#endif  //  HWY_TARGET > HWY_AVX3_DL

// ------------------------------ RotateRight

template <int kBits, typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API Vec512<T> RotateRight(const Vec512<T> v) {
  constexpr size_t kSizeInBits = sizeof(T) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  // AVX3 does not support 8/16-bit.
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}

template <int kBits>
HWY_API Vec512<uint32_t> RotateRight(const Vec512<uint32_t> v) {
  static_assert(0 <= kBits && kBits < 32, "Invalid shift count");
  if (kBits == 0) return v;
  return Vec512<uint32_t>{_mm512_ror_epi32(v.raw, kBits)};
}

template <int kBits>
HWY_API Vec512<uint64_t> RotateRight(const Vec512<uint64_t> v) {
  static_assert(0 <= kBits && kBits < 64, "Invalid shift count");
  if (kBits == 0) return v;
  return Vec512<uint64_t>{_mm512_ror_epi64(v.raw, kBits)};
}

// ------------------------------ ShiftLeftSame

// GCC <14 and Clang <11 do not follow the Intel documentation for AVX-512
// shift-with-immediate: the counts should all be unsigned int.
#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1100
using Shift16Count = int;
using Shift3264Count = int;
#elif HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400
// GCC 11.0 requires these, prior versions used a macro+cast and don't care.
using Shift16Count = int;
using Shift3264Count = unsigned int;
#else
// Assume documented behavior. Clang 11, GCC 14 and MSVC 14.28.29910 match this.
using Shift16Count = unsigned int;
using Shift3264Count = unsigned int;
#endif

HWY_API Vec512<uint16_t> ShiftLeftSame(const Vec512<uint16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint16_t>{
        _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<uint16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint32_t> ShiftLeftSame(const Vec512<uint32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint32_t>{
        _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint64_t> ShiftLeftSame(const Vec512<uint64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint64_t>{
        _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int16_t> ShiftLeftSame(const Vec512<int16_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int16_t>{
        _mm512_slli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<int16_t>{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftLeftSame(const Vec512<int32_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int32_t>{
        _mm512_slli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int32_t>{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int64_t> ShiftLeftSame(const Vec512<int64_t> v, const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int64_t>{
        _mm512_slli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int64_t>{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_API Vec512<T> ShiftLeftSame(const Vec512<T> v, const int bits) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<T>((0xFF << bits) & 0xFF));
}

// ------------------------------ ShiftRightSame

HWY_API Vec512<uint16_t> ShiftRightSame(const Vec512<uint16_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint16_t>{
        _mm512_srli_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<uint16_t>{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint32_t> ShiftRightSame(const Vec512<uint32_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint32_t>{
        _mm512_srli_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint32_t>{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<uint64_t> ShiftRightSame(const Vec512<uint64_t> v,
                                        const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<uint64_t>{
        _mm512_srli_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<uint64_t>{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<uint8_t> ShiftRightSame(Vec512<uint8_t> v, const int bits) {
  const DFromV<decltype(v)> d8;
  const RepartitionToWide<decltype(d8)> d16;
  const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits));
  return shifted & Set(d8, static_cast<uint8_t>(0xFF >> bits));
}

HWY_API Vec512<int16_t> ShiftRightSame(const Vec512<int16_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int16_t>{
        _mm512_srai_epi16(v.raw, static_cast<Shift16Count>(bits))};
  }
#endif
  return Vec512<int16_t>{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int32_t> ShiftRightSame(const Vec512<int32_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int32_t>{
        _mm512_srai_epi32(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int32_t>{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))};
}
HWY_API Vec512<int64_t> ShiftRightSame(const Vec512<int64_t> v,
                                       const int bits) {
#if HWY_COMPILER_GCC
  if (__builtin_constant_p(bits)) {
    return Vec512<int64_t>{
        _mm512_srai_epi64(v.raw, static_cast<Shift3264Count>(bits))};
  }
#endif
  return Vec512<int64_t>{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))};
}

HWY_API Vec512<int8_t> ShiftRightSame(Vec512<int8_t> v, const int bits) {
  const DFromV<decltype(v)> di;
  const RebindToUnsigned<decltype(di)> du;
  const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits));
  const auto shifted_sign =
      BitCast(di, Set(du, static_cast<uint8_t>(0x80 >> bits)));
  return (shifted ^ shifted_sign) - shifted_sign;
}

// ------------------------------ Minimum

// Unsigned
HWY_API Vec512<uint8_t> Min(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_min_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Min(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_min_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Min(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_min_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Min(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_min_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Min(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_min_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Min(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_min_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Min(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_min_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Min(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_min_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> Min(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_min_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> Min(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_min_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Min(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_min_pd(a.raw, b.raw)};
}

// ------------------------------ Maximum

// Unsigned
HWY_API Vec512<uint8_t> Max(Vec512<uint8_t> a, Vec512<uint8_t> b) {
  return Vec512<uint8_t>{_mm512_max_epu8(a.raw, b.raw)};
}
HWY_API Vec512<uint16_t> Max(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_max_epu16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> Max(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_max_epu32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> Max(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_max_epu64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int8_t> Max(Vec512<int8_t> a, Vec512<int8_t> b) {
  return Vec512<int8_t>{_mm512_max_epi8(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> Max(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_max_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> Max(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_max_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> Max(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_max_epi64(a.raw, b.raw)};
}

// Float
#if HWY_HAVE_FLOAT16
HWY_API Vec512<float16_t> Max(Vec512<float16_t> a, Vec512<float16_t> b) {
  return Vec512<float16_t>{_mm512_max_ph(a.raw, b.raw)};
}
#endif  // HWY_HAVE_FLOAT16
HWY_API Vec512<float> Max(Vec512<float> a, Vec512<float> b) {
  return Vec512<float>{_mm512_max_ps(a.raw, b.raw)};
}
HWY_API Vec512<double> Max(Vec512<double> a, Vec512<double> b) {
  return Vec512<double>{_mm512_max_pd(a.raw, b.raw)};
}

// ------------------------------ Integer multiplication

// Per-target flag to prevent generic_ops-inl.h from defining 64-bit operator*.
#ifdef HWY_NATIVE_MUL_64
#undef HWY_NATIVE_MUL_64
#else
#define HWY_NATIVE_MUL_64
#endif

// Unsigned
HWY_API Vec512<uint16_t> operator*(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<uint32_t> operator*(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> operator*(Vec512<uint64_t> a, Vec512<uint64_t> b) {
  return Vec512<uint64_t>{_mm512_mullo_epi64(a.raw, b.raw)};
}
HWY_API Vec256<uint64_t> operator*(Vec256<uint64_t> a, Vec256<uint64_t> b) {
  return Vec256<uint64_t>{_mm256_mullo_epi64(a.raw, b.raw)};
}
template <size_t N>
HWY_API Vec128<uint64_t, N> operator*(Vec128<uint64_t, N> a,
                                      Vec128<uint64_t, N> b) {
  return Vec128<uint64_t, N>{_mm_mullo_epi64(a.raw, b.raw)};
}

// Signed
HWY_API Vec512<int16_t> operator*(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mullo_epi16(a.raw, b.raw)};
}
HWY_API Vec512<int32_t> operator*(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int32_t>{_mm512_mullo_epi32(a.raw, b.raw)};
}
HWY_API Vec512<int64_t> operator*(Vec512<int64_t> a, Vec512<int64_t> b) {
  return Vec512<int64_t>{_mm512_mullo_epi64(a.raw, b.raw)};
}
HWY_API Vec256<int64_t> operator*(Vec256<int64_t> a, Vec256<int64_t> b) {
  return Vec256<int64_t>{_mm256_mullo_epi64(a.raw, b.raw)};
}
template <size_t N>
HWY_API Vec128<int64_t, N> operator*(Vec128<int64_t, N> a,
                                     Vec128<int64_t, N> b) {
  return Vec128<int64_t, N>{_mm_mullo_epi64(a.raw, b.raw)};
}
// Returns the upper 16 bits of a * b in each lane.
HWY_API Vec512<uint16_t> MulHigh(Vec512<uint16_t> a, Vec512<uint16_t> b) {
  return Vec512<uint16_t>{_mm512_mulhi_epu16(a.raw, b.raw)};
}
HWY_API Vec512<int16_t> MulHigh(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhi_epi16(a.raw, b.raw)};
}

HWY_API Vec512<int16_t> MulFixedPoint15(Vec512<int16_t> a, Vec512<int16_t> b) {
  return Vec512<int16_t>{_mm512_mulhrs_epi16(a.raw, b.raw)};
}

// Multiplies even lanes (0, 2 ..) and places the double-wide result into
// even and the upper half into its odd neighbor lane.
HWY_API Vec512<int64_t> MulEven(Vec512<int32_t> a, Vec512<int32_t> b) {
  return Vec512<int64_t>{_mm512_mul_epi32(a.raw, b.raw)};
}
HWY_API Vec512<uint64_t> MulEven(Vec512<uint32_t> a, Vec512<uint32_t> b) {
  return Vec512<uint64_t>{_mm512_mul_epu32(a.raw, b.raw)};
}

// ------------------------------ Neg (Sub)

--> --------------------

--> maximum size reached

--> --------------------

Messung V0.5
C=86 H=100 G=93

¤ Dauer der Verarbeitung: 0.23 Sekunden  (vorverarbeitet)  ¤

*© Formatika GbR, Deutschland






Wurzel

Suchen

Beweissystem der NASA

Beweissystem Isabelle

NIST Cobol Testsuite

Cephes Mathematical Library

Wiener Entwicklungsmethode

Haftungshinweis

Die Informationen auf dieser Webseite wurden nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit, noch Qualität der bereit gestellten Informationen zugesichert.

Bemerkung:

Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.