// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Arm SVE[2] vectors (length not known at compile time).
// External include guard in highway.h - see comment there.
#include <arm_sve.h>
#include "hwy/ops/shared-inl.h"
// Arm C215 declares that SVE vector lengths will always be a power of two.
// We default to relying on this, which makes some operations more efficient.
// You can still opt into fixups by setting this to 0 (unsupported).
#ifndef HWY_SVE_IS_POW2
#define HWY_SVE_IS_POW2 1
#endif
#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128
#define HWY_SVE_HAVE_2 1
#else
#define HWY_SVE_HAVE_2 0
#endif
// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available:
// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip.
#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16)
#define HWY_SVE_HAVE_BF16_FEATURE 1
#else
#define HWY_SVE_HAVE_BF16_FEATURE 0
#endif
// HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type
// is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0.
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_COMPILER_GCC_ACTUAL >= 1000
#define HWY_SVE_HAVE_BF16_VEC 1
#else
#define HWY_SVE_HAVE_BF16_VEC 0
#endif
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
template <class V>
struct DFromV_t {};  // specialized in macros
template <class V>
using DFromV = typename DFromV_t<RemoveConst<V>>::type;

template <class V>
using TFromV = TFromD<DFromV<V>>;
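
// For illustration (not normative): once the specializations below are
// defined, DFromV<svuint32_t> is ScalableTag<uint32_t> and
// TFromV<svuint32_t> is uint32_t.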
// ================================================== MACROS
// Generate specializations and function definitions using X macros. Although
// harder to read and debug, writing everything manually is too bulky.
namespace detail {  // for code folding
// Args: BASE, CHAR, BITS, HALF, NAME, OP
// Unsigned:
#define HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) X_MACRO(uint, u, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) X_MACRO(uint, u, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
X_MACRO(uint, u, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \
X_MACRO(uint, u, 64, 32, NAME, OP)
// Signed:
#define HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) X_MACRO(int, s, 8, 8, NAME, OP)
#define HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) X_MACRO(int, s, 16, 8, NAME, OP)
#define HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \
  X_MACRO(int, s, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP) \
  X_MACRO(int, s, 64, 32, NAME, OP)
// Float:
#define HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 16, 16, NAME, OP)
#define HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 32, 16, NAME, OP)
#define HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP) \
  X_MACRO(float, f, 64, 32, NAME, OP)
#define HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP) \
X_MACRO(bfloat, bf, 16, 16, NAME, OP)
#if HWY_SVE_HAVE_BF16_FEATURE
#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(X_MACRO, NAME, OP)
// We have both f16 and bf16, so nothing is emulated.
#define HWY_SVE_IF_EMULATED_D(D) hwy::EnableIf<false>* = nullptr
#define HWY_SVE_IF_NOT_EMULATED_D(D) hwy::EnableIf<true>* = nullptr
#else
#define HWY_SVE_FOREACH_BF16(X_MACRO, NAME, OP)
#define HWY_SVE_IF_EMULATED_D(D) HWY_IF_BF16_D(D)
#define HWY_SVE_IF_NOT_EMULATED_D(D) HWY_IF_NOT_BF16_D(D)
#endif // HWY_SVE_HAVE_BF16_FEATURE
// For all element sizes:
#define HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F32(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F64(X_MACRO, NAME, OP)
// HWY_SVE_FOREACH_F does not include HWY_SVE_FOREACH_BF16 because SVE lacks
// bf16 overloads for some intrinsics (especially less-common arithmetic).
// However, this does include f16 because SVE supports it unconditionally.
#define HWY_SVE_FOREACH_F(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F16(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP)
// Commonly used type categories for a given element size:
#define HWY_SVE_FOREACH_UI08(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U08(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I08(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_UI16(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U16(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I16(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U32(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I32(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U64(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I64(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_UIF3264(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_UI32(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_UI64(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F3264(X_MACRO, NAME, OP)
// Commonly used type categories:
#define HWY_SVE_FOREACH_UI(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH_IF(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)
#define HWY_SVE_FOREACH(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_U(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_I(X_MACRO, NAME, OP) \
HWY_SVE_FOREACH_F(X_MACRO, NAME, OP)
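
// Illustrative expansion (a sketch, not authoritative): given the
// HWY_SVE_RETV_ARGPVV form defined below,
// HWY_SVE_FOREACH_U32(HWY_SVE_RETV_ARGPVV, Add, add) generates approximately
//   HWY_API svuint32_t Add(svuint32_t a, svuint32_t b) {
//     return svadd_u32_x(HWY_SVE_PTRUE(32), a, b);
//   }
// and HWY_SVE_FOREACH repeats this for every supported lane type.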
// Assemble types for use in x-macros
#define HWY_SVE_T(BASE, BITS) BASE##BITS##_t
#define HWY_SVE_D(BASE, BITS, N, POW2) Simd<HWY_SVE_T(BASE, BITS), N, POW2>
#define HWY_SVE_V(BASE, BITS) sv##BASE##BITS##_t
#define HWY_SVE_TUPLE(BASE, BITS, MUL) sv##BASE##BITS##x##MUL##_t
}  // namespace detail
#define HWY_SPECIALIZE(BASE, CHAR, BITS, HALF, NAME, OP) \
template <> \
struct DFromV_t<HWY_SVE_V(BASE, BITS)> { \
using type = ScalableTag<HWY_SVE_T(BASE, BITS)>; \
};
HWY_SVE_FOREACH(HWY_SPECIALIZE, _, _)
#if HWY_SVE_HAVE_BF16_FEATURE || HWY_SVE_HAVE_BF16_VEC
HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
#endif
#undef HWY_SPECIALIZE
// Note: _x (don't-care value for inactive lanes) avoids additional MOVPRFX
// instructions, and we anyway only use it when the predicate is ptrue.
// vector = f(vector), e.g. Not
#define HWY_SVE_RETV_ARGPV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v);   \
  }
#define HWY_SVE_RETV_ARGV(BASE, CHAR, BITS, HALF, NAME, OP)     \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
    return sv##OP##_##CHAR##BITS(v);                            \
  }
// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
#define HWY_SVE_RETV_ARGVN(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }
// vector = f(vector, vector), e.g. Add
#define HWY_SVE_RETV_ARGVV(BASE, CHAR, BITS, HALF, NAME, OP)   \
  HWY_API HWY_SVE_V(BASE, BITS)                                \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(a, b);                        \
  }
// All-true mask
#define HWY_SVE_RETV_ARGPVV(BASE, CHAR, BITS, HALF, NAME, OP)    \
  HWY_API HWY_SVE_V(BASE, BITS)                                  \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {   \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), a, b); \
  }
// User-specified mask. Mask=false value is undefined and must be set by caller
// because SVE instructions take it from one of the two inputs, whereas
// AVX-512, RVV and Highway allow a third argument.
#define HWY_SVE_RETV_ARGMVV(BASE, CHAR, BITS, HALF, NAME, OP)              \
  HWY_API HWY_SVE_V(BASE, BITS)                                            \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_x(m, a, b);                             \
  }
#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                               \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
           HWY_SVE_V(BASE, BITS) c) {                        \
    return sv##OP##_##CHAR##BITS(a, b, c);                   \
  }
// ------------------------------ Lanes
namespace detail {
// Returns actual lanes of a hardware vector without rounding to a power of two.
template <typename T, HWY_IF_T_SIZE(T, 1)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntb_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 2)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcnth_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 4)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntw_pat(SV_ALL);
}
template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_INLINE size_t AllHardwareLanes() {
  return svcntd_pat(SV_ALL);
}
// All-true mask from a macro
#if HWY_SVE_IS_POW2
#define HWY_SVE_ALL_PTRUE(BITS) svptrue_b##BITS()
#define HWY_SVE_PTRUE(BITS) svptrue_b##BITS()
#else
#define HWY_SVE_ALL_PTRUE(BITS) svptrue_pat_b##BITS(SV_ALL)
#define HWY_SVE_PTRUE(BITS) svptrue_pat_b##BITS(SV_POW2)
#endif  // HWY_SVE_IS_POW2
}  // namespace detail
#if HWY_HAVE_SCALABLE
// Returns actual number of lanes after capping by N and shifting. May return 0
// (e.g. for "1/8th" of a u32x4 - would be 1 for 1/8th of u32x8).
template <typename T, size_t N, int kPow2>
HWY_API size_t Lanes(Simd<T, N, kPow2> d) {
  const size_t actual = detail::AllHardwareLanes<T>();
  constexpr size_t kMaxLanes = MaxLanes(d);
  constexpr int kClampedPow2 = HWY_MIN(kPow2, 0);
  // Common case of full vectors: avoid any extra instructions.
  if (detail::IsFull(d)) return actual;
  return HWY_MIN(detail::ScaleByPower(actual, kClampedPow2), kMaxLanes);
}
#endif // HWY_HAVE_SCALABLE
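
// Example (illustrative): with 256-bit SVE vectors, Lanes(ScalableTag<float>())
// returns 8, while the half-length descriptor ScalableTag<float, -1>() yields
// 4.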
// ================================================== MASK INIT
// One mask bit per byte; only the one belonging to the lowest byte is valid.
// ------------------------------ FirstN
#define HWY_SVE_FIRSTN(BASE, CHAR, BITS, HALF, NAME, OP)                       \
  template <size_t N, int kPow2>                                               \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, size_t count) {     \
    const size_t limit = detail::IsFull(d) ? count : HWY_MIN(Lanes(d), count); \
    return sv##OP##_b##BITS##_u32(uint32_t{0}, static_cast<uint32_t>(limit));  \
  }
HWY_SVE_FOREACH(HWY_SVE_FIRSTN, FirstN, whilelt)
HWY_SVE_FOREACH_BF16(HWY_SVE_FIRSTN, FirstN, whilelt)
template <class D, HWY_SVE_IF_EMULATED_D(D)>
svbool_t FirstN(D /* tag */, size_t count) {
  return FirstN(RebindToUnsigned<D>(), count);
}
#undef HWY_SVE_FIRSTN
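
// Example (illustrative): with 128-bit vectors, FirstN(ScalableTag<int32_t>(),
// 2) activates the first two of four i32 lanes; counts of Lanes(d) or more
// activate all lanes.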
template <class D>
using MFromD = svbool_t;
namespace detail {
#define HWY_SVE_WRAP_PTRUE(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                        \
  HWY_API svbool_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) {      \
    return HWY_SVE_PTRUE(BITS);                                         \
  }                                                                     \
  template <size_t N, int kPow2>                                        \
  HWY_API svbool_t All##NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) { \
    return HWY_SVE_ALL_PTRUE(BITS);                                     \
  }
HWY_SVE_FOREACH(HWY_SVE_WRAP_PTRUE, PTrue, ptrue)  // return all-true
HWY_SVE_FOREACH_BF16(HWY_SVE_WRAP_PTRUE, PTrue, ptrue)
#undef HWY_SVE_WRAP_PTRUE
HWY_API svbool_t PFalse() { return svpfalse_b(); }
// Returns all-true if d is HWY_FULL or FirstN(N) after capping N.
//
// This is used in functions that load/store memory; other functions (e.g.
// arithmetic) can ignore d and use PTrue instead.
template <class D>
svbool_t MakeMask(D d) {
return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d));
}
}  // namespace detail
#ifdef HWY_NATIVE_MASK_FALSE
#undef HWY_NATIVE_MASK_FALSE
#else
#define HWY_NATIVE_MASK_FALSE
#endif
template <class D>
HWY_API svbool_t MaskFalse(const D /*d*/) {
return detail::PFalse();
}
// ================================================== INIT
// ------------------------------ Set
// vector = f(d, scalar), e.g. Set
#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, \
                                     HWY_SVE_T(BASE, BITS) arg) {             \
    return sv##OP##_##CHAR##BITS(arg);                                        \
  }
HWY_SVE_FOREACH(HWY_SVE_SET, Set, dup_n)
#if HWY_SVE_HAVE_BF16_FEATURE  // for if-elif chain
HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, dup_n)
#elif HWY_SVE_HAVE_BF16_VEC
// Required for Zero and VFromD
template <class D, HWY_IF_BF16_D(D)>
HWY_API svbfloat16_t Set(D d, bfloat16_t arg) {
return svreinterpret_bf16_u16(
Set(RebindToUnsigned<decltype(d)>(), BitCastScalar<uint16_t>(arg)));
}
#else // neither bf16 feature nor vector: emulate with u16
// Required for Zero and VFromD
template <class D, HWY_IF_BF16_D(D)>
HWY_API svuint16_t Set(D d, bfloat16_t arg) {
const RebindToUnsigned<decltype(d)> du;
return Set(du, BitCastScalar<uint16_t>(arg));
}
#endif // HWY_SVE_HAVE_BF16_FEATURE
#undef HWY_SVE_SET
template <class D>
using VFromD = decltype(Set(D(), TFromD<D>()));
using VBF16 = VFromD<ScalableTag<bfloat16_t>>;
// ------------------------------ Zero
template <class D>
VFromD<D> Zero(D d) {
// Cast to support bfloat16_t.
const RebindToUnsigned<decltype(d)> du;
return BitCast(d, Set(du, 0));
}
// ------------------------------ BitCast
namespace detail {
// u8: no change
#define HWY_SVE_CAST_NOP(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastToByte(HWY_SVE_V(BASE, BITS) v) {  \
    return v;                                                             \
  }                                                                       \
  template <size_t N, int kPow2>                                          \
  HWY_API HWY_SVE_V(BASE, BITS) BitCastFromByte(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v) { \
    return v;                                                             \
  }
// All other types
#define HWY_SVE_CAST(BASE, CHAR, BITS, HALF, NAME, OP)                      \
  HWY_INLINE svuint8_t BitCastToByte(HWY_SVE_V(BASE, BITS) v) {             \
    return sv##OP##_u8_##CHAR##BITS(v);                                     \
  }                                                                         \
  template <size_t N, int kPow2>                                            \
  HWY_INLINE HWY_SVE_V(BASE, BITS)                                          \
      BitCastFromByte(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svuint8_t v) { \
    return sv##OP##_##CHAR##BITS##_u8(v);                                   \
  }
// U08 is special-cased, hence do not use FOREACH.
HWY_SVE_FOREACH_U08(HWY_SVE_CAST_NOP, _, _)
HWY_SVE_FOREACH_I08(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI16(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI32(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_UI64(HWY_SVE_CAST, _, reinterpret)
HWY_SVE_FOREACH_F(HWY_SVE_CAST, _, reinterpret)
#undef HWY_SVE_CAST_NOP
#undef HWY_SVE_CAST
template <class V, HWY_SVE_IF_EMULATED_D(DFromV<V>)>
HWY_INLINE svuint8_t BitCastToByte(V v) {
#if HWY_SVE_HAVE_BF16_VEC
return svreinterpret_u8_bf16(v);
#else
const RebindToUnsigned<DFromV<V>> du;
return BitCastToByte(BitCast(du, v));
#endif
}
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_INLINE VFromD<D> BitCastFromByte(D d, svuint8_t v) {
#if HWY_SVE_HAVE_BF16_VEC
  (void)d;
return svreinterpret_bf16_u8(v);
#else
const RebindToUnsigned<decltype(d)> du;
return BitCastFromByte(du, v);
#endif
}
}  // namespace detail
template <class D, class FromV>
HWY_API VFromD<D> BitCast(D d, FromV v) {
return detail::BitCastFromByte(d, detail::BitCastToByte(v));
}
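
// Usage example (illustrative): BitCast reinterprets lanes without changing
// bits, e.g.
//   const ScalableTag<float> df;
//   const ScalableTag<uint32_t> du;
//   const svuint32_t bits = BitCast(du, Set(df, 1.0f));  // 0x3F800000 per lane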
// ------------------------------ Undefined
#define HWY_SVE_UNDEFINED(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                            \
  HWY_API HWY_SVE_V(BASE, BITS)                             \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */) {       \
    return sv##OP##_##CHAR##BITS();                         \
  }
HWY_SVE_FOREACH(HWY_SVE_UNDEFINED, Undefined, undef)
template <class D, HWY_SVE_IF_EMULATED_D(D)>
VFromD<D> Undefined(D d) {
const RebindToUnsigned<D> du;
return BitCast(d, Undefined(du));
}
// ------------------------------ Tuple
// tuples = f(d, v..), e.g. Create2
#define HWY_SVE_CREATE(BASE, CHAR, BITS, HALF, NAME, OP)                 \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 2)                                   \
      NAME##2(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,                   \
              HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1) {      \
    return sv##OP##2_##CHAR##BITS(v0, v1);                               \
  }                                                                      \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 3) NAME##3(                          \
      HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, HWY_SVE_V(BASE, BITS) v0, \
      HWY_SVE_V(BASE, BITS) v1, HWY_SVE_V(BASE, BITS) v2) {              \
    return sv##OP##3_##CHAR##BITS(v0, v1, v2);                           \
  }                                                                      \
  template <size_t N, int kPow2>                                         \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 4)                                   \
      NAME##4(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,                   \
              HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,        \
              HWY_SVE_V(BASE, BITS) v2, HWY_SVE_V(BASE, BITS) v3) {      \
    return sv##OP##4_##CHAR##BITS(v0, v1, v2, v3);                       \
  }
HWY_SVE_FOREACH(HWY_SVE_CREATE, Create, create)
HWY_SVE_FOREACH_BF16(HWY_SVE_CREATE, Create, create)
#undef HWY_SVE_CREATE
template <class D>
using Vec2 = decltype(Create2(D(), Zero(D()), Zero(D())));
template <class D>
using Vec3 = decltype(Create3(D(), Zero(D()), Zero(D()), Zero(D())));
template <class D>
using Vec4 = decltype(Create4(D(), Zero(D()), Zero(D()), Zero(D()), Zero(D())));
#define HWY_SVE_GET(BASE, CHAR, BITS, HALF, NAME, OP)  \
  template <size_t kIndex>                             \
  HWY_API HWY_SVE_V(BASE, BITS)                        \
      NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple) {    \
    return sv##OP##2_##CHAR##BITS(tuple, kIndex);      \
  }                                                    \
  template <size_t kIndex>                             \
  HWY_API HWY_SVE_V(BASE, BITS)                        \
      NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple) {    \
    return sv##OP##3_##CHAR##BITS(tuple, kIndex);      \
  }                                                    \
  template <size_t kIndex>                             \
  HWY_API HWY_SVE_V(BASE, BITS)                        \
      NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple) {    \
    return sv##OP##4_##CHAR##BITS(tuple, kIndex);      \
  }
HWY_SVE_FOREACH(HWY_SVE_GET, Get, get)
HWY_SVE_FOREACH_BF16(HWY_SVE_GET, Get, get)
#undef HWY_SVE_GET
#define HWY_SVE_SET(BASE, CHAR, BITS, HALF, NAME, OP)                          \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 2)                                         \
      NAME##2(HWY_SVE_TUPLE(BASE, BITS, 2) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##2_##CHAR##BITS(tuple, kIndex, vec);                         \
  }                                                                            \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 3)                                         \
      NAME##3(HWY_SVE_TUPLE(BASE, BITS, 3) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##3_##CHAR##BITS(tuple, kIndex, vec);                         \
  }                                                                            \
  template <size_t kIndex>                                                     \
  HWY_API HWY_SVE_TUPLE(BASE, BITS, 4)                                         \
      NAME##4(HWY_SVE_TUPLE(BASE, BITS, 4) tuple, HWY_SVE_V(BASE, BITS) vec) { \
    return sv##OP##4_##CHAR##BITS(tuple, kIndex, vec);                         \
  }
HWY_SVE_FOREACH(HWY_SVE_SET, Set, set)
HWY_SVE_FOREACH_BF16(HWY_SVE_SET, Set, set)
#undef HWY_SVE_SET
// ------------------------------ ResizeBitCast
// Same as BitCast on SVE
template <class D, class FromV>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
return BitCast(d, v);
}
// ------------------------------ Dup128VecFromValues
template <class D, HWY_IF_I8_D(D)>
HWY_API svint8_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
TFromD<D> t11, TFromD<D> t12,
TFromD<D> t13, TFromD<D> t14,
TFromD<D> t15) {
return svdupq_n_s8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13,
t14, t15);
}
template <class D, HWY_IF_U8_D(D)>
HWY_API svuint8_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
TFromD<D> t11, TFromD<D> t12,
TFromD<D> t13, TFromD<D> t14,
TFromD<D> t15) {
return svdupq_n_u8(t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13,
t14, t15);
}
template <class D, HWY_IF_I16_D(D)>
HWY_API svint16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
TFromD<D> t5, TFromD<D> t6,
TFromD<D> t7) {
return svdupq_n_s16(t0, t1, t2, t3, t4, t5, t6, t7);
}
template <class D, HWY_IF_U16_D(D)>
HWY_API svuint16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
TFromD<D> t5, TFromD<D> t6,
TFromD<D> t7) {
return svdupq_n_u16(t0, t1, t2, t3, t4, t5, t6, t7);
}
template <class D, HWY_IF_F16_D(D)>
HWY_API svfloat16_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
TFromD<D> t2, TFromD<D> t3,
TFromD<D> t4, TFromD<D> t5,
TFromD<D> t6, TFromD<D> t7) {
return svdupq_n_f16(t0, t1, t2, t3, t4, t5, t6, t7);
}
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_API VBF16 Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1, TFromD<D> t2,
TFromD<D> t3, TFromD<D> t4, TFromD<D> t5,
TFromD<D> t6, TFromD<D> t7) {
const RebindToUnsigned<decltype(d)> du;
return BitCast(
d, Dup128VecFromValues(
du, BitCastScalar<uint16_t>(t0), BitCastScalar<uint16_t>(t1),
BitCastScalar<uint16_t>(t2), BitCastScalar<uint16_t>(t3),
BitCastScalar<uint16_t>(t4), BitCastScalar<uint16_t>(t5),
BitCastScalar<uint16_t>(t6), BitCastScalar<uint16_t>(t7)));
}
template <class D, HWY_IF_I32_D(D)>
HWY_API svint32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
return svdupq_n_s32(t0, t1, t2, t3);
}
template <class D, HWY_IF_U32_D(D)>
HWY_API svuint32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                       TFromD<D> t2, TFromD<D> t3) {
return svdupq_n_u32(t0, t1, t2, t3);
}
template <class D, HWY_IF_F32_D(D)>
HWY_API svfloat32_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
                                        TFromD<D> t2, TFromD<D> t3) {
return svdupq_n_f32(t0, t1, t2, t3);
}
template <class D, HWY_IF_I64_D(D)>
HWY_API svint64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
return svdupq_n_s64(t0, t1);
}
template <class D, HWY_IF_U64_D(D)>
HWY_API svuint64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
return svdupq_n_u64(t0, t1);
}
template <class D, HWY_IF_F64_D(D)>
HWY_API svfloat64_t Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
return svdupq_n_f64(t0, t1);
}
// ================================================== LOGICAL
// detail::*N() functions accept a scalar argument to avoid extra Set().
// ------------------------------ Not
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPV, Not, not )  // NOLINT
// ------------------------------ And
namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, AndN, and_n)
}  // namespace detail
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, And, and)
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V And(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, And(BitCast(du, a), BitCast(du, b)));
}
// ------------------------------ Or
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Or, orr)
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Or(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
}
// ------------------------------ Xor
namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, XorN, eor_n)
}  // namespace detail
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Xor, eor)
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor(BitCast(du, a), BitCast(du, b)));
}
// ------------------------------ AndNot
namespace detail {
#define HWY_SVE_RETV_ARGPVN_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_T(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN_SWAP, AndNotN, bic_n)
#undef HWY_SVE_RETV_ARGPVN_SWAP
}  // namespace detail
#define HWY_SVE_RETV_ARGPVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), b, a);   \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV_SWAP, AndNot, bic)
#undef HWY_SVE_RETV_ARGPVV_SWAP
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V AndNot(const V a, const V b) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, AndNot(BitCast(du, a), BitCast(du, b)));
}
// ------------------------------ Xor3
#if HWY_SVE_HAVE_2
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3)
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor3(const V x1, const V x2, const V x3) {
  const DFromV<V> df;
  const RebindToUnsigned<decltype(df)> du;
  return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3)));
}
#else
template <class V>
HWY_API V Xor3(V x1, V x2, V x3) {
  return Xor(x1, Xor(x2, x3));
}
#endif
// ------------------------------ Or3
template <class V>
HWY_API V Or3(V o1, V o2, V o3) {
  return Or(o1, Or(o2, o3));
}
// ------------------------------ OrAnd
template <class V>
HWY_API V OrAnd(const V o, const V a1, const V a2) {
  return Or(o, And(a1, a2));
}
// ------------------------------ PopulationCount
#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif
// Need to return original type instead of unsigned.
#define HWY_SVE_POPCNT(BASE, CHAR, BITS, HALF, NAME, OP)               \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {        \
    return BitCast(DFromV<decltype(v)>(),                              \
                   sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v)); \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_POPCNT, PopulationCount, cnt)
#undef HWY_SVE_POPCNT
// ================================================== SIGN
// ------------------------------ Neg
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Neg, neg)
HWY_API VBF16 Neg(VBF16 v) {
const DFromV<decltype(v)> d;
const RebindToUnsigned<decltype(d)> du;
using TU = TFromD<decltype(du)>;
  return BitCast(d, Xor(BitCast(du, v), Set(du, SignMask<TU>())));
}
// ------------------------------ SaturatedNeg
#if HWY_SVE_HAVE_2
#ifdef HWY_NATIVE_SATURATED_NEG_8_16_32
#undef HWY_NATIVE_SATURATED_NEG_8_16_32
#else
#define HWY_NATIVE_SATURATED_NEG_8_16_32
#endif
#ifdef HWY_NATIVE_SATURATED_NEG_64
#undef HWY_NATIVE_SATURATED_NEG_64
#else
#define HWY_NATIVE_SATURATED_NEG_64
#endif
HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedNeg, qneg)
#endif // HWY_SVE_HAVE_2
// ------------------------------ Abs
HWY_SVE_FOREACH_IF(HWY_SVE_RETV_ARGPV, Abs, abs)
// ------------------------------ SaturatedAbs
#if HWY_SVE_HAVE_2
#ifdef HWY_NATIVE_SATURATED_ABS
#undef HWY_NATIVE_SATURATED_ABS
#else
#define HWY_NATIVE_SATURATED_ABS
#endif
HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPV, SaturatedAbs, qabs)
#endif // HWY_SVE_HAVE_2
// ================================================== ARITHMETIC
// Per-target flags to prevent generic_ops-inl.h defining Add etc.
#ifdef HWY_NATIVE_OPERATOR_REPLACEMENTS
#undef HWY_NATIVE_OPERATOR_REPLACEMENTS
#else
#define HWY_NATIVE_OPERATOR_REPLACEMENTS
#endif
// ------------------------------ Add
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN, AddN, add_n)
}  // namespace detail
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Add, add)
// ------------------------------ Sub
namespace detail {
// Can't use HWY_SVE_RETV_ARGPVN because caller wants to specify pg.
#define HWY_SVE_RETV_ARGPVN_MASK(BASE, CHAR, BITS, HALF, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                             \
      NAME(svbool_t pg, HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS##_z(pg, a, b);                             \
  }
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVN_MASK, SubN, sub_n)
#undef HWY_SVE_RETV_ARGPVN_MASK
}  // namespace detail
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Sub, sub)
// ------------------------------ SumsOf8
HWY_API svuint64_t SumsOf8(const svuint8_t v) {
const ScalableTag<uint32_t> du32;
const ScalableTag<uint64_t> du64;
const svbool_t pg = detail::PTrue(du64);
const svuint32_t sums_of_4 = svdot_n_u32(Zero(du32), v, 1);
// Compute pairwise sum of u32 and extend to u64.
#if HWY_SVE_HAVE_2
return svadalp_u64_x(pg, Zero(du64), sums_of_4);
#else
const svuint64_t hi = svlsr_n_u64_x(pg, BitCast(du64, sums_of_4), 32);
// Isolate the lower 32 bits (to be added to the upper 32 and zero-extended)
const svuint64_t lo = svextw_u64_x(pg, BitCast(du64, sums_of_4));
return Add(hi, lo);
#endif
}
HWY_API svint64_t SumsOf8(const svint8_t v) {
const ScalableTag<int32_t> di32;
const ScalableTag<int64_t> di64;
const svbool_t pg = detail::PTrue(di64);
const svint32_t sums_of_4 = svdot_n_s32(Zero(di32), v, 1);
#if HWY_SVE_HAVE_2
return svadalp_s64_x(pg, Zero(di64), sums_of_4);
#else
const svint64_t hi = svasr_n_s64_x(pg, BitCast(di64, sums_of_4), 32);
// Isolate the lower 32 bits (to be added to the upper 32 and sign-extended)
const svint64_t lo = svextw_s64_x(pg, BitCast(di64, sums_of_4));
return Add(hi, lo);
#endif
}
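
// Example (illustrative): for a u8 input with all lanes equal to 1, each u64
// lane of SumsOf8 is 8, the sum of its group of eight consecutive bytes.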
// ------------------------------ SumsOf2
#if HWY_SVE_HAVE_2
namespace detail {
HWY_INLINE svint16_t SumsOf2(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) {
  const ScalableTag<int16_t> di16;
  const svbool_t pg = detail::PTrue(di16);
  return svadalp_s16_x(pg, Zero(di16), v);
}

HWY_INLINE svuint16_t SumsOf2(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) {
  const ScalableTag<uint16_t> du16;
  const svbool_t pg = detail::PTrue(du16);
  return svadalp_u16_x(pg, Zero(du16), v);
}

HWY_INLINE svint32_t SumsOf2(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) {
  const ScalableTag<int32_t> di32;
  const svbool_t pg = detail::PTrue(di32);
  return svadalp_s32_x(pg, Zero(di32), v);
}

HWY_INLINE svuint32_t SumsOf2(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) {
  const ScalableTag<uint32_t> du32;
  const svbool_t pg = detail::PTrue(du32);
  return svadalp_u32_x(pg, Zero(du32), v);
}

HWY_INLINE svint64_t SumsOf2(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<4> /*lane_size_tag*/, svint32_t v) {
  const ScalableTag<int64_t> di64;
  const svbool_t pg = detail::PTrue(di64);
  return svadalp_s64_x(pg, Zero(di64), v);
}

HWY_INLINE svuint64_t SumsOf2(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<4> /*lane_size_tag*/, svuint32_t v) {
  const ScalableTag<uint64_t> du64;
  const svbool_t pg = detail::PTrue(du64);
  return svadalp_u64_x(pg, Zero(du64), v);
}
}  // namespace detail
#endif // HWY_SVE_HAVE_2
// ------------------------------ SumsOf4
namespace detail {
HWY_INLINE svint32_t SumsOf4(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<1> /*lane_size_tag*/, svint8_t v) {
  return svdot_n_s32(Zero(ScalableTag<int32_t>()), v, 1);
}

HWY_INLINE svuint32_t SumsOf4(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<1> /*lane_size_tag*/, svuint8_t v) {
  return svdot_n_u32(Zero(ScalableTag<uint32_t>()), v, 1);
}

HWY_INLINE svint64_t SumsOf4(hwy::SignedTag /*type_tag*/,
                             hwy::SizeTag<2> /*lane_size_tag*/, svint16_t v) {
  return svdot_n_s64(Zero(ScalableTag<int64_t>()), v, 1);
}

HWY_INLINE svuint64_t SumsOf4(hwy::UnsignedTag /*type_tag*/,
                              hwy::SizeTag<2> /*lane_size_tag*/, svuint16_t v) {
  return svdot_n_u64(Zero(ScalableTag<uint64_t>()), v, 1);
}
}  // namespace detail
// ------------------------------ SaturatedAdd
#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB
#undef HWY_NATIVE_I32_SATURATED_ADDSUB
#else
#define HWY_NATIVE_I32_SATURATED_ADDSUB
#endif
#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB
#undef HWY_NATIVE_U32_SATURATED_ADDSUB
#else
#define HWY_NATIVE_U32_SATURATED_ADDSUB
#endif
#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB
#undef HWY_NATIVE_I64_SATURATED_ADDSUB
#else
#define HWY_NATIVE_I64_SATURATED_ADDSUB
#endif
#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB
#undef HWY_NATIVE_U64_SATURATED_ADDSUB
#else
#define HWY_NATIVE_U64_SATURATED_ADDSUB
#endif
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedAdd, qadd)
// ------------------------------ SaturatedSub
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVV, SaturatedSub, qsub)
// ------------------------------ AbsDiff
#ifdef HWY_NATIVE_INTEGER_ABS_DIFF
#undef HWY_NATIVE_INTEGER_ABS_DIFF
#else
#define HWY_NATIVE_INTEGER_ABS_DIFF
#endif
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, AbsDiff, abd)
// ------------------------------ ShiftLeft[Same]
#define HWY_SVE_SHIFT_N(BASE, CHAR, BITS, HALF, NAME, OP)               \
  template <int kBits>                                                  \
  HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) {         \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, kBits);    \
  }                                                                     \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME##Same(HWY_SVE_V(BASE, BITS) v, HWY_SVE_T(uint, BITS) bits) { \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, bits);     \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT_N, ShiftLeft, lsl_n)
// ------------------------------ ShiftRight[Same]
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT_N, ShiftRight, lsr_n)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT_N, ShiftRight, asr_n)
#undef HWY_SVE_SHIFT_N
// ------------------------------ RotateRight
// TODO(janwas): svxar on SVE2
template <int kBits, class V>
HWY_API V RotateRight(const V v) {
  constexpr size_t kSizeInBits = sizeof(TFromV<V>) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}
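
// Example (illustrative): RotateRight<8>(Set(ScalableTag<uint32_t>(),
// 0x12345678u)) yields 0x78123456 in every lane.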
// ------------------------------ Shl/r
#define HWY_SVE_SHIFT(BASE, CHAR, BITS, HALF, NAME, OP)           \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(BASE, BITS) bits) { \
    const RebindToUnsigned<DFromV<decltype(v)>> du;               \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v,      \
                                     BitCast(du, bits));          \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_SHIFT, Shl, lsl)
HWY_SVE_FOREACH_U(HWY_SVE_SHIFT, Shr, lsr)
HWY_SVE_FOREACH_I(HWY_SVE_SHIFT, Shr, asr)
#undef HWY_SVE_SHIFT
// ------------------------------ Min/Max
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Min, min)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVV, Max, max)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Max, maxnm)
namespace detail {
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MinN, min_n)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGPVN, MaxN, max_n)
}  // namespace detail
// ------------------------------ Mul
// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*.
#ifdef HWY_NATIVE_MUL_8
#undef HWY_NATIVE_MUL_8
#else
#define HWY_NATIVE_MUL_8
#endif
#ifdef HWY_NATIVE_MUL_64
#undef HWY_NATIVE_MUL_64
#else
#define HWY_NATIVE_MUL_64
#endif
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGPVV, Mul, mul)
// ------------------------------ MulHigh
HWY_SVE_FOREACH_UI16(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
// Not part of API, used internally:
HWY_SVE_FOREACH_UI08(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
HWY_SVE_FOREACH_U64(HWY_SVE_RETV_ARGPVV, MulHigh, mulh)
// ------------------------------ MulFixedPoint15
HWY_API svint16_t MulFixedPoint15(svint16_t a, svint16_t b) {
#if HWY_SVE_HAVE_2
return svqrdmulh_s16(a, b);
#else
const DFromV<decltype(a)> d;
const RebindToUnsigned<decltype(d)> du;
const svuint16_t lo = BitCast(du, Mul(a, b));
const svint16_t hi = MulHigh(a, b);
// We want (lo + 0x4000) >> 15, but that can overflow, and if it does we must
// carry that into the result. Instead isolate the top two bits because only
// they can influence the result.
const svuint16_t lo_top2 = ShiftRight<14>(lo);
// Bits 11: add 2, 10: add 1, 01: add 1, 00: add 0.
const svuint16_t rounding = ShiftRight<1>(detail::AddN(lo_top2, 1));
return Add(Add(hi, hi), BitCast(d, rounding));
#endif
}
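
// Worked example (illustrative): for Q15 inputs a = 16384 (0.5) and
// b = 8192 (0.25), a * b = 0x08000000, and (a * b + 0x4000) >> 15 = 4096,
// i.e. 0.125 in Q15.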
// ------------------------------ Div
#ifdef HWY_NATIVE_INT_DIV
#undef HWY_NATIVE_INT_DIV
#else
#define HWY_NATIVE_INT_DIV
#endif
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGPVV, Div, div)
HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGPVV, Div, div)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Div, div)
// ------------------------------ ApproximateReciprocal
#ifdef HWY_NATIVE_F64_APPROX_RECIP
#undef HWY_NATIVE_F64_APPROX_RECIP
#else
#define HWY_NATIVE_F64_APPROX_RECIP
#endif
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)
// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)
// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
#undef HWY_NATIVE_F64_APPROX_RSQRT
#else
#define HWY_NATIVE_F64_APPROX_RSQRT
#endif
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocalSqrt, rsqrte)
// ------------------------------ MulAdd
// Per-target flag to prevent generic_ops-inl.h from defining int MulAdd.
#ifdef HWY_NATIVE_INT_FMA
#undef HWY_NATIVE_INT_FMA
#else
#define HWY_NATIVE_INT_FMA
#endif
#define HWY_SVE_FMA(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API HWY_SVE_V(BASE, BITS)                                         \
      NAME(HWY_SVE_V(BASE, BITS) mul, HWY_SVE_V(BASE, BITS) x,          \
           HWY_SVE_V(BASE, BITS) add) {                                 \
    return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), x, mul, add); \
  }
HWY_SVE_FOREACH(HWY_SVE_FMA, MulAdd, mad)
// ------------------------------ NegMulAdd
HWY_SVE_FOREACH(HWY_SVE_FMA, NegMulAdd, msb)
// ------------------------------ MulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, MulSub, nmsb)
// ------------------------------ NegMulSub
HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)
#undef HWY_SVE_FMA
// ------------------------------ Round etc.
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Floor, rintm)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Ceil, rintp)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Trunc, rintz)
// ================================================== MASK
// ------------------------------ RebindMask
template <class D, typename MFrom>
HWY_API svbool_t RebindMask(const D /*d*/, const MFrom mask) {
  return mask;
}
// ------------------------------ Mask logical
HWY_API svbool_t Not(svbool_t m) {
  // We don't know the lane type, so assume 8-bit. For larger types, this will
  // de-canonicalize the predicate, i.e. set bits to 1 even though they do not
  // correspond to the lowest byte in the lane. Arm says such bits are ignored.
  return svnot_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t And(svbool_t a, svbool_t b) {
  return svand_b_z(b, b, a);  // same order as AndNot for consistency
}
HWY_API svbool_t AndNot(svbool_t a, svbool_t b) {
  return svbic_b_z(b, b, a);  // reversed order like NEON
}
HWY_API svbool_t Or(svbool_t a, svbool_t b) {
  return svsel_b(a, a, b);  // a ? true : b
}
HWY_API svbool_t Xor(svbool_t a, svbool_t b) {
  // a ? !(a & b) : b; in the a-true branch, !(a & b) == !b, hence XOR.
  return svsel_b(a, svnand_b_z(a, a, b), b);
}
HWY_API svbool_t ExclusiveNeither(svbool_t a, svbool_t b) {
  return svnor_b_z(HWY_SVE_PTRUE(8), a, b);  // !a && !b, undefined if a && b.
}
// ------------------------------ CountTrue
#define HWY_SVE_COUNT_TRUE(BASE, CHAR, BITS, HALF, NAME, OP)           \
  template <size_t N, int kPow2>                                       \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) d, svbool_t m) { \
    return sv##OP##_b##BITS(detail::MakeMask(d), m);                   \
  }
HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE, CountTrue, cntp)
#undef HWY_SVE_COUNT_TRUE
// For 16-bit Compress: full vector, not limited to SV_POW2.
namespace detail {
#define HWY_SVE_COUNT_TRUE_FULL(BASE, CHAR, BITS, HALF, NAME, OP)            \
  template <size_t N, int kPow2>                                             \
  HWY_API size_t NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */, svbool_t m) { \
    return sv##OP##_b##BITS(svptrue_b##BITS(), m);                           \
  }
HWY_SVE_FOREACH(HWY_SVE_COUNT_TRUE_FULL, CountTrueFull, cntp)
#undef HWY_SVE_COUNT_TRUE_FULL
}  // namespace detail
// ------------------------------ AllFalse
template <class D>
HWY_API bool AllFalse(D d, svbool_t m) {
return !svptest_any(detail::MakeMask(d), m);
}
// ------------------------------ AllTrue
template <class D>
HWY_API bool AllTrue(D d, svbool_t m) {
return CountTrue(d, m) == Lanes(d);
}
// ------------------------------ FindFirstTrue
template <class D>
HWY_API intptr_t FindFirstTrue(D d, svbool_t m) {
  return AllFalse(d, m) ? intptr_t{-1}
                        : static_cast<intptr_t>(CountTrue(
                              d, svbrkb_b_z(detail::MakeMask(d), m)));
}
// ------------------------------ FindKnownFirstTrue
template <class D>
HWY_API size_t FindKnownFirstTrue(D d, svbool_t m) {
return CountTrue(d, svbrkb_b_z(detail::MakeMask(d), m));
}
// ------------------------------ IfThenElse
#define HWY_SVE_IF_THEN_ELSE(BASE, CHAR, BITS, HALF, NAME, OP)                \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      NAME(svbool_t m, HWY_SVE_V(BASE, BITS) yes, HWY_SVE_V(BASE, BITS) no) { \
    return sv##OP##_##CHAR##BITS(m, yes, no);                                 \
  }
HWY_SVE_FOREACH(HWY_SVE_IF_THEN_ELSE, IfThenElse, sel)
#undef HWY_SVE_IF_THEN_ELSE
template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)>
HWY_API V IfThenElse(const svbool_t mask, V yes, V no) {
const RebindToUnsigned<D> du;
return BitCast(
D(), IfThenElse(RebindMask(du, mask), BitCast(du, yes), BitCast(du, no)));
}
// ------------------------------ IfThenElseZero
template <class V, class D = DFromV<V>, HWY_SVE_IF_NOT_EMULATED_D(D)>
HWY_API V IfThenElseZero(const svbool_t mask, const V yes) {
  return IfThenElse(mask, yes, Zero(D()));
}
template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)>
HWY_API V IfThenElseZero(const svbool_t mask, V yes) {
const RebindToUnsigned<D> du;
return BitCast(D(), IfThenElseZero(RebindMask(du, mask), BitCast(du, yes)));
}
// ------------------------------ IfThenZeroElse
template <class V, class D = DFromV<V>, HWY_SVE_IF_NOT_EMULATED_D(D)>
HWY_API V IfThenZeroElse(const svbool_t mask, const V no) {
  return IfThenElse(mask, Zero(D()), no);
}
template <class V, class D = DFromV<V>, HWY_SVE_IF_EMULATED_D(D)>
HWY_API V IfThenZeroElse(const svbool_t mask, V no) {
const RebindToUnsigned<D> du;
return BitCast(D(), IfThenZeroElse(RebindMask(du, mask), BitCast(du, no)));
}
// ------------------------------ Additional mask logical operations
HWY_API svbool_t SetBeforeFirst(svbool_t m) {
// We don't know the lane type, so assume 8-bit. For larger types, this will
// de-canonicalize the predicate, i.e. set bits to 1 even though they do not
// correspond to the lowest byte in the lane. Arm says such bits are ignored.
return svbrkb_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t SetAtOrBeforeFirst(svbool_t m) {
// We don't know the lane type, so assume 8-bit. For larger types, this will
// de-canonicalize the predicate, i.e. set bits to 1 even though they do not
// correspond to the lowest byte in the lane. Arm says such bits are ignored.
return svbrka_b_z(HWY_SVE_PTRUE(8), m);
}
HWY_API svbool_t SetOnlyFirst(svbool_t m) { return svbrka_b_z(m, m); }
HWY_API svbool_t SetAtOrAfterFirst(svbool_t m) {
return Not(SetBeforeFirst(m));
}
// ------------------------------ PromoteMaskTo
#ifdef HWY_NATIVE_PROMOTE_MASK_TO
#undef HWY_NATIVE_PROMOTE_MASK_TO
#else
#define HWY_NATIVE_PROMOTE_MASK_TO
#endif
template <class DTo, class DFrom,
          HWY_IF_T_SIZE_D(DTo, sizeof(TFromD<DFrom>) * 2)>
HWY_API svbool_t PromoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
return svunpklo_b(m);
}
template <class DTo, class DFrom,
          HWY_IF_T_SIZE_GT_D(DTo, sizeof(TFromD<DFrom>) * 2)>
HWY_API svbool_t PromoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) {
using TFrom = TFromD<DFrom>;
using TWFrom = MakeWide<MakeUnsigned<TFrom>>;
  static_assert(sizeof(TWFrom) > sizeof(TFrom),
                "sizeof(TWFrom) > sizeof(TFrom) must be true");
const Rebind<TWFrom, decltype(d_from)> dw_from;
return PromoteMaskTo(d_to, dw_from, PromoteMaskTo(dw_from, d_from, m));
}
// ------------------------------ DemoteMaskTo
#ifdef HWY_NATIVE_DEMOTE_MASK_TO
#undef HWY_NATIVE_DEMOTE_MASK_TO
#else
#define HWY_NATIVE_DEMOTE_MASK_TO
#endif
template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 1),
          HWY_IF_T_SIZE_D(DFrom, 2)>
HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
return svuzp1_b8(m, m);
}
template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 2),
          HWY_IF_T_SIZE_D(DFrom, 4)>
HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
return svuzp1_b16(m, m);
}
template <class DTo, class DFrom, HWY_IF_T_SIZE_D(DTo, 4),
          HWY_IF_T_SIZE_D(DFrom, 8)>
HWY_API svbool_t DemoteMaskTo(DTo /*d_to*/, DFrom /*d_from*/, svbool_t m) {
return svuzp1_b32(m, m);
}
template <class DTo, class DFrom,
          HWY_IF_T_SIZE_LE_D(DTo, sizeof(TFromD<DFrom>) / 4)>
HWY_API svbool_t DemoteMaskTo(DTo d_to, DFrom d_from, svbool_t m) {
using TFrom = TFromD<DFrom>;
using TNFrom = MakeNarrow<MakeUnsigned<TFrom>>;
  static_assert(sizeof(TNFrom) < sizeof(TFrom),
                "sizeof(TNFrom) < sizeof(TFrom) must be true");
const Rebind<TNFrom, decltype(d_from)> dn_from;
return DemoteMaskTo(d_to, dn_from, DemoteMaskTo(dn_from, d_from, m));
}
// ------------------------------ LowerHalfOfMask
#ifdef HWY_NATIVE_LOWER_HALF_OF_MASK
#undef HWY_NATIVE_LOWER_HALF_OF_MASK
#else
#define HWY_NATIVE_LOWER_HALF_OF_MASK
#endif
template <class D>
HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) {
return m;
}
// ------------------------------ MaskedAddOr etc. (IfThenElse)
#ifdef HWY_NATIVE_MASKED_ARITH
#undef HWY_NATIVE_MASKED_ARITH
#else
#define HWY_NATIVE_MASKED_ARITH
#endif
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
#if HWY_SVE_HAVE_2
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub)
#endif
}  // namespace detail
template <class V, class M>
HWY_API V MaskedMinOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedMin(m, a, b), no);
}
template <class V, class M>
HWY_API V MaskedMaxOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedMax(m, a, b), no);
}
template <class V, class M>
HWY_API V MaskedAddOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedAdd(m, a, b), no);
}
template <class V, class M>
HWY_API V MaskedSubOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedSub(m, a, b), no);
}
template <class V, class M>
HWY_API V MaskedMulOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedMul(m, a, b), no);
}
template <class V, class M,
          HWY_IF_T_SIZE_ONE_OF_V(
              V, (hwy::IsSame<TFromV<V>, hwy::float16_t>() ? (1 << 2) : 0) |
                     (1 << 4) | (1 << 8))>
HWY_API V MaskedDivOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedDiv(m, a, b), no);
}
// I8/U8/I16/U16 MaskedDivOr is implemented after I8/U8/I16/U16 Div
#if HWY_SVE_HAVE_2
template <class V, class M>
HWY_API V MaskedSatAddOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedSatAdd(m, a, b), no);
}
template <class V, class M>
HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
  return IfThenElse(m, detail::MaskedSatSub(m, a, b), no);
}
#else
template <class V, class M>
HWY_API V MaskedSatAddOr(V no, M m, V a, V b) {
  return IfThenElse(m, SaturatedAdd(a, b), no);
}
template <class V, class M>
HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
  return IfThenElse(m, SaturatedSub(a, b), no);
}
#endif
// ================================================== COMPARE
// mask = f(vector, vector)
#define HWY_SVE_COMPARE(BASE, CHAR, BITS, HALF, NAME, OP)                   \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }
#define HWY_SVE_COMPARE_N(BASE, CHAR, BITS, HALF, NAME, OP)                 \
  HWY_API svbool_t NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_T(BASE, BITS) b) { \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(BITS), a, b);                \
  }
// ------------------------------ Eq
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Eq, cmpeq)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, EqN, cmpeq_n)
}  // namespace detail
// ------------------------------ Ne
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Ne, cmpne)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, NeN, cmpne_n)
}  // namespace detail
// ------------------------------ Lt
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Lt, cmplt)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LtN, cmplt_n)
}  // namespace detail
// ------------------------------ Le
HWY_SVE_FOREACH(HWY_SVE_COMPARE, Le, cmple)
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, LeN, cmple_n)
}  // namespace detail
// ------------------------------ Gt/Ge (swapped order)
template <class V>
HWY_API svbool_t Gt(const V a, const V b) {
  return Lt(b, a);
}
template <class V>
HWY_API svbool_t Ge(const V a, const V b) {
  return Le(b, a);
}
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GeN, cmpge_n)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_N, GtN, cmpgt_n)
}  // namespace detail
#undef HWY_SVE_COMPARE
#undef HWY_SVE_COMPARE_N
// ------------------------------ TestBit
template <class V>
HWY_API svbool_t TestBit(const V a, const V bit) {
  return detail::NeN(And(a, bit), 0);
}
// ------------------------------ MaskFromVec (Ne)
template <class V>
HWY_API svbool_t MaskFromVec(const V v) {
using T = TFromV<V>;
return detail::NeN(v, ConvertScalarTo<T>(0));
}
// ------------------------------ VecFromMask
template <class D>
HWY_API VFromD<D> VecFromMask(const D d, svbool_t mask) {
const RebindToSigned<D> di;
// This generates MOV imm, whereas svdup_n_s8_z generates MOV scalar, which
// requires an extra instruction plus M0 pipeline.
return BitCast(d, IfThenElseZero(mask, Set(di, -1)));
}
// ------------------------------ IfVecThenElse (MaskFromVec, IfThenElse)
#if HWY_SVE_HAVE_2
#define HWY_SVE_IF_VEC(BASE, CHAR, BITS, HALF, NAME, OP)          \
  HWY_API HWY_SVE_V(BASE, BITS)                                   \
      NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) yes, \
           HWY_SVE_V(BASE, BITS) no) {                            \
    return sv##OP##_##CHAR##BITS(yes, no, mask);                  \
  }
HWY_SVE_FOREACH_UI(HWY_SVE_IF_VEC, IfVecThenElse, bsl)
#undef HWY_SVE_IF_VEC
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V IfVecThenElse(const V mask, const V yes, const V no) {
const DFromV<V> d;
const RebindToUnsigned<decltype(d)> du;
return BitCast(
d, IfVecThenElse(BitCast(du, mask), BitCast(du, yes), BitCast(du, no)));
}
#else
template <class V>
HWY_API V IfVecThenElse(const V mask, const V yes, const V no) {
  return Or(And(mask, yes), AndNot(mask, no));
}
#endif // HWY_SVE_HAVE_2
// ------------------------------ BitwiseIfThenElse
#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE
#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE
#else
#define HWY_NATIVE_BITWISE_IF_THEN_ELSE
#endif
template <class V>
HWY_API V BitwiseIfThenElse(V mask, V yes, V no) {
return IfVecThenElse(mask, yes, no);
}
// ------------------------------ CopySign (BitwiseIfThenElse)
template <class V>
HWY_API V CopySign(const V magn, const V sign) {
const DFromV<decltype(magn)> d;
return BitwiseIfThenElse(SignBit(d), sign, magn);
}
// ------------------------------ CopySignToAbs
template <class V>
HWY_API V CopySignToAbs(const V abs, const V sign) {
#if HWY_SVE_HAVE_2
// CopySign is more efficient than OrAnd
return CopySign(abs, sign);
#else
const DFromV<V> d;
return OrAnd(abs, SignBit(d), sign);
#endif
}
// ------------------------------ Floating-point classification (Ne)
template <class V>
HWY_API svbool_t IsNaN(const V v) {
  return Ne(v, v);  // could also use cmpuo
}
// Per-target flag to prevent generic_ops-inl.h from defining IsInf / IsFinite.
// We use a fused Set/comparison for IsFinite.
#ifdef HWY_NATIVE_ISINF
#undef HWY_NATIVE_ISINF
#else
#define HWY_NATIVE_ISINF
#endif
template <class V>
HWY_API svbool_t IsInf(const V v) {
using T = TFromV<V>;
const DFromV<decltype(v)> d;
const RebindToUnsigned<decltype(d)> du;
const RebindToSigned<decltype(d)> di;
// 'Shift left' to clear the sign bit
const VFromD<decltype(du)> vu = BitCast(du, v);
const VFromD<decltype(du)> v2 = Add(vu, vu);
// Check for exponent=max and mantissa=0.
const VFromD<decltype(di)> max2 = Set(di, hwy::MaxExponentTimes2<T>());
return RebindMask(d, Eq(v2, BitCast(du, max2)));
}
// Returns whether normal/subnormal/zero.
template <class V>
HWY_API svbool_t IsFinite(const V v) {
using T = TFromV<V>;
const DFromV<decltype(v)> d;
const RebindToUnsigned<decltype(d)> du;
const RebindToSigned<decltype(d)> di;
// cheaper than unsigned comparison
const VFromD<decltype(du)> vu = BitCast(du, v);
// 'Shift left' to clear the sign bit, then right so we can compare with the
// max exponent (cannot compare with MaxExponentTimes2 directly because it is
// negative and non-negative floats would be greater).
const VFromD<decltype(di)> exp =
BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu)));
return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>()));
}
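
// Worked example (illustrative, f32): Inf is 0x7F800000; Add(vu, vu) gives
// 0xFF000000, and shifting right by MantissaBits+1 = 24 leaves 0xFF, which is
// not less than MaxExponentField (0xFF), so IsFinite returns false. Any
// finite value has a smaller exponent field and passes the LtN test.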
// ================================================== MEMORY
// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
#define HWY_SVE_MEM(BASE, CHAR, BITS, HALF, NAME, OP)                         \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      LoadU(HWY_SVE_D(BASE, BITS, N, kPow2) d,                                \
            const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {                   \
    return svld1_##CHAR##BITS(detail::MakeMask(d),                            \
                              detail::NativeLanePointer(p));                  \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_API HWY_SVE_V(BASE, BITS)                                               \
      MaskedLoad(svbool_t m, HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,         \
                 const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {              \
    return svld1_##CHAR##BITS(m, detail::NativeLanePointer(p));               \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_API void StoreU(HWY_SVE_V(BASE, BITS) v,                                \
                      HWY_SVE_D(BASE, BITS, N, kPow2) d,                      \
                      HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {               \
    svst1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), v); \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_API void Stream(HWY_SVE_V(BASE, BITS) v,                                \
                      HWY_SVE_D(BASE, BITS, N, kPow2) d,                      \
                      HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {               \
    svstnt1_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p),   \
                         v);                                                  \
  }                                                                           \
  template <size_t N, int kPow2>                                              \
  HWY_API void BlendedStore(HWY_SVE_V(BASE, BITS) v, svbool_t m,              \
                            HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,          \
                            HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {         \
    svst1_##CHAR##BITS(m, detail::NativeLanePointer(p), v);                   \
  }
HWY_SVE_FOREACH(HWY_SVE_MEM, _, _)
HWY_SVE_FOREACH_BF16(HWY_SVE_MEM, _, _)
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
const RebindToUnsigned<decltype(d)> du;
return BitCast(d, LoadU(du, detail::U16LanePointer(p)));
}
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_API void StoreU(VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT p) {
const RebindToUnsigned<decltype(d)> du;
StoreU(BitCast(du, v), du, detail::U16LanePointer(p));
}
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d,
const TFromD<D>* HWY_RESTRICT p) {
const RebindToUnsigned<decltype(d)> du;
return BitCast(d,
MaskedLoad(RebindMask(du, m), du, detail::U16LanePointer(p)));
}
// MaskedLoadOr is generic and does not require emulation.
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
                          TFromD<D>* HWY_RESTRICT p) {
const RebindToUnsigned<decltype(d)> du;
BlendedStore(BitCast(du, v), RebindMask(du, m), du,
detail::U16LanePointer(p));
}
#undef HWY_SVE_MEM
#if HWY_TARGET != HWY_SVE2_128
namespace detail {
#define HWY_SVE_LOAD_DUP128(BASE, CHAR, BITS, HALF, NAME, OP)   \
  template <size_t N, int kPow2>                                \
  HWY_API HWY_SVE_V(BASE, BITS)                                 \
      NAME(HWY_SVE_D(BASE, BITS, N, kPow2) /* d */,             \
           const HWY_SVE_T(BASE, BITS) * HWY_RESTRICT p) {      \
    /* All-true predicate to load all 128 bits. */              \
    return sv##OP##_##CHAR##BITS(HWY_SVE_PTRUE(8),              \
                                 detail::NativeLanePointer(p)); \
  }
HWY_SVE_FOREACH(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq)
HWY_SVE_FOREACH_BF16(HWY_SVE_LOAD_DUP128, LoadDupFull128, ld1rq)
template <class D, HWY_SVE_IF_EMULATED_D(D)>
HWY_API VFromD<D> LoadDupFull128(D d, const TFromD<D>* HWY_RESTRICT p) {
const RebindToUnsigned<decltype(d)> du;
return BitCast(d, LoadDupFull128(du, detail::U16LanePointer(p)));
}
}  // namespace detail
#endif // HWY_TARGET != HWY_SVE2_128
#if HWY_TARGET == HWY_SVE2_128
// On the HWY_SVE2_128 target, LoadDup128 is the same as LoadU because vectors
// cannot exceed 16 bytes.
template <class D>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
return LoadU(d, p);
}
#else // HWY_TARGET != HWY_SVE2_128
// If D().MaxBytes() <= 16 is true, simply do a LoadU operation.
template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
return LoadU(d, p);
}
// If D().MaxBytes() > 16, the vector must be loaded via ld1rq.
template <class D, HWY_IF_V_SIZE_GT_D(D, 16)>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
return detail::LoadDupFull128(d, p);
}
#endif // HWY_TARGET != HWY_SVE2_128
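
// Usage example (illustrative): broadcast the same four floats to each
// 128-bit block of the vector:
//   const ScalableTag<float> d;
//   alignas(16) static const float lanes[4] = {1.0f, 2.0f, 3.0f, 4.0f};
//   const auto v = LoadDup128(d, lanes);  // 1,2,3,4,1,2,3,4,...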
// ------------------------------ Load/Store
// SVE only requires lane alignment, not natural alignment of the entire
// vector, so Load/Store are the same as LoadU/StoreU.
template <class D>
HWY_API VFromD<D> Load(D d, const TFromD<D>* HWY_RESTRICT p) {
return LoadU(d, p);
}
template <class V, class D>
HWY_API void Store(const V v, D d, TFromD<D>* HWY_RESTRICT p) {
StoreU(v, d, p);
}
// ------------------------------ MaskedLoadOr
// SVE MaskedLoad hard-codes zero, so this requires an extra blend.
template <class D>
HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d,
const TFromD<D>* HWY_RESTRICT p) {
return IfThenElse(m, MaskedLoad(m, d, p), v);
}
// ------------------------------ ScatterOffset/Index
#ifdef HWY_NATIVE_SCATTER
#undef HWY_NATIVE_SCATTER
#else
#define HWY_NATIVE_SCATTER
#endif
#define HWY_SVE_SCATTER_OFFSET(BASE, CHAR, BITS, HALF, NAME, OP) \
  template <size_t N, int kPow2>                                 \
  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v,                     \