Anforderungen  |   Konzepte  |   Entwurf  |   Entwicklung  |   Qualitätssicherung  |   Lebenszyklus  |   Steuerung
 
 
 
 


Quelle  highbd_inv_txfm_neon.c   Sprache: C

 
/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you canzip
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */


#include <arm_neon.h>
#include <assert.h>

#include "av1/common/av1_inv_txfm1d_cfg.h"
#include "av1/common/idct.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#if AOM_ARCH_AARCH64
#define TRANSPOSE_4X4(x0, x1, x2, x3, y0, y1, y2, y3)         \
  do {                                                        \
    int32x4x2_t swap_low = vtrnq_s32(x0, x1);                 \
    int32x4x2_t swap_high = vtrnq_s32(x2, x3);                \
    y0 = vreinterpretq_s32_s64(                               \
        vzip1q_s64(vreinterpretq_s64_s32(swap_low.val[0]),    \
                   vreinterpretq_s64_s32(swap_high.val[0]))); \
    y1 = vreinterpretq_s32_s64(                               \
        vzip1q_s64(vreinterpretq_s64_s32(swap_low.val[1]),    \
                   vreinterpretq_s64_s32(swap_high.val[1]))); \
    y2 = vreinterpretq_s32_s64(                               \
        vzip2q_s64(vreinterpretq_s64_s32(swap_low.val[0]),    \
                   vreinterpretq_s64_s32(swap_high.val[0]))); \
    y3 = vreinterpretq_s32_s64(                               \
        vzip2q_s64(vreinterpretq_s64_s32(swap_low.val[1]),    \
                   vreinterpretq_s64_s32(swap_high.val[1]))); \
  } while (0)
#else
#define TRANSPOSE_4X4(x0, x1, x2, x3, y0, y1, y2, y3)                    \
  do {                                                                   \
    int32x4x2_t swap_low = vtrnq_s32(x0, x1);                            \
    int32x4x2_t swap_high = vtrnq_s32(x2, x3);                           \
    y0 = vextq_s32(vextq_s32(swap_low.val[0], swap_low.val[0], 2),       \
                   swap_high.val[0], 2);                                 \
    y1 = vextq_s32(vextq_s32(swap_low.val[1], swap_low.val[1], 2),       \
                   swap_high.val[1], 2);                                 \
    y2 = vextq_s32(swap_low.val[0],                                      \
                   vextq_s32(swap_high.val[0], swap_high.val[0], 2), 2); \
    y3 = vextq_s32(swap_low.val[1],                                      \
                   vextq_s32(swap_high.val[1], swap_high.val[1], 2), 2); \
  } while (0)
#endif  // AOM_ARCH_AARCH64

static inline void transpose_4x4(const int32x4_t *in, int32x4_t *out) {
  TRANSPOSE_4X4(in[0], in[1], in[2], in[3], out[0], out[1], out[2], out[3]);
}

static inline void transpose_8x8(const int32x4_t *in, int32x4_t *out) {
  TRANSPOSE_4X4(in[0], in[2], in[4], in[6], out[0], out[2], out[4], out[6]);
  TRANSPOSE_4X4(in[1], in[3], in[5], in[7], out[8], out[10], out[12], out[14]);
  TRANSPOSE_4X4(in[8], in[10], in[12], in[14], out[1], out[3], out[5], out[7]);
  TRANSPOSE_4X4(in[9], in[11], in[13], in[15], out[9], out[11], out[13],
                out[15]);
}

static inline void round_shift_array_32_neon(int32x4_t *input,
                                             int32x4_t *output, const int size,
                                             const int bit) {
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  for (int i = 0; i < size; i++) {
    output[i] = vrshlq_s32(input[i], v_bit);
  }
}

static inline void round_shift_rect_array_32_neon(int32x4_t *input,
                                                  int32x4_t *output,
                                                  const int size) {
  for (int i = 0; i < size; i++) {
    const int32x4_t r0 = vmulq_n_s32(input[i], NewInvSqrt2);
    output[i] = vrshrq_n_s32(r0, NewSqrt2Bits);
  }
}

static inline int32x4_t half_btf_neon_r(const int32_t *n0, const int32x4_t *w0,
                                        const int32_t *n1, const int32x4_t *w1,
                                        const int32x4_t *v_bit,
                                        const int32x4_t *rnding) {
  int32x4_t x;
  x = vmlaq_n_s32(*rnding, *w0, *n0);
  x = vmlaq_n_s32(x, *w1, *n1);
  x = vshlq_s32(x, *v_bit);
  return x;
}

static inline int32x4_t half_btf_neon_mode11_r(
    const int32_t *n0, const int32x4_t *w0, const int32_t *n1,
    const int32x4_t *w1, const int32x4_t *v_bit, const int32x4_t *rnding) {
  int32x4_t x;
  x = vmlaq_n_s32(*rnding, *w0, -*n0);
  x = vmlaq_n_s32(x, *w1, -*n1);
  x = vshlq_s32(x, *v_bit);
  return x;
}

static inline int32x4_t half_btf_neon_mode01_r(
    const int32_t *n0, const int32x4_t *w0, const int32_t *n1,
    const int32x4_t *w1, const int32x4_t *v_bit, const int32x4_t *rnding) {
  int32x4_t x;
  x = vmlaq_n_s32(*rnding, *w0, *n0);
  x = vmlsq_n_s32(x, *w1, *n1);
  x = vshlq_s32(x, *v_bit);
  return x;
}

static inline int32x4_t half_btf_neon_mode10_r(
    const int32_t *n0, const int32x4_t *w0, const int32_t *n1,
    const int32x4_t *w1, const int32x4_t *v_bit, const int32x4_t *rnding) {
  int32x4_t x;
  x = vmlaq_n_s32(*rnding, *w1, *n1);
  x = vmlsq_n_s32(x, *w0, *n0);
  x = vshlq_s32(x, *v_bit);
  return x;
}

static inline int32x4_t half_btf_0_neon_r(const int32_t *n0,
                                          const int32x4_t *w0,
                                          const int32x4_t *v_bit,
                                          const int32x4_t *rnding) {
  int32x4_t x;
  x = vmlaq_n_s32(*rnding, *w0, *n0);
  x = vshlq_s32(x, *v_bit);
  return x;
}

static inline int32x4_t half_btf_0_m_neon_r(const int32_t *n0,
                                            const int32x4_t *w0,
                                            const int32x4_t *v_bit,
                                            const int32x4_t *rnding) {
  int32x4_t x;
  x = vmlaq_n_s32(*rnding, *w0, -*n0);
  x = vshlq_s32(x, *v_bit);
  return x;
}

static inline void flip_buf_neon(int32x4_t *in, int32x4_t *out, int size) {
  for (int i = 0; i < size; ++i) {
    out[size - i - 1] = in[i];
  }
}

typedef void (*fwd_transform_1d_neon)(int32x4_t *in, int32x4_t *out, int bit,
                                      const int num_cols);

typedef void (*transform_1d_neon)(int32x4_t *in, int32x4_t *out, int32_t bit,
                                  int32_t do_cols, int32_t bd,
                                  int32_t out_shift);

static inline uint16x8_t highbd_clamp_u16(uint16x8_t *u, const uint16x8_t *min,
                                          const uint16x8_t *max) {
  int16x8_t clamped;
  clamped = vminq_s16(vreinterpretq_s16_u16(*u), vreinterpretq_s16_u16(*max));
  clamped = vmaxq_s16(clamped, vreinterpretq_s16_u16(*min));
  return vreinterpretq_u16_s16(clamped);
}

static inline void round_shift_4x4(int32x4_t *in, int shift) {
  if (shift != 0) {
    const int32x4_t v_shift = vdupq_n_s32(-shift);
    in[0] = vrshlq_s32(in[0], v_shift);
    in[1] = vrshlq_s32(in[1], v_shift);
    in[2] = vrshlq_s32(in[2], v_shift);
    in[3] = vrshlq_s32(in[3], v_shift);
  }
}

static void round_shift_8x8(int32x4_t *in, int shift) {
  assert(shift != 0);
  const int32x4_t v_shift = vdupq_n_s32(-shift);
  in[0] = vrshlq_s32(in[0], v_shift);
  in[1] = vrshlq_s32(in[1], v_shift);
  in[2] = vrshlq_s32(in[2], v_shift);
  in[3] = vrshlq_s32(in[3], v_shift);
  in[4] = vrshlq_s32(in[4], v_shift);
  in[5] = vrshlq_s32(in[5], v_shift);
  in[6] = vrshlq_s32(in[6], v_shift);
  in[7] = vrshlq_s32(in[7], v_shift);
  in[8] = vrshlq_s32(in[8], v_shift);
  in[9] = vrshlq_s32(in[9], v_shift);
  in[10] = vrshlq_s32(in[10], v_shift);
  in[11] = vrshlq_s32(in[11], v_shift);
  in[12] = vrshlq_s32(in[12], v_shift);
  in[13] = vrshlq_s32(in[13], v_shift);
  in[14] = vrshlq_s32(in[14], v_shift);
  in[15] = vrshlq_s32(in[15], v_shift);
}

static void highbd_clamp_s32_neon(int32x4_t *in, int32x4_t *out,
                                  const int32x4_t *clamp_lo,
                                  const int32x4_t *clamp_hi, int size) {
  int32x4_t a0, a1;
  for (int i = 0; i < size; i += 4) {
    a0 = vmaxq_s32(in[i], *clamp_lo);
    out[i] = vminq_s32(a0, *clamp_hi);

    a1 = vmaxq_s32(in[i + 1], *clamp_lo);
    out[i + 1] = vminq_s32(a1, *clamp_hi);

    a0 = vmaxq_s32(in[i + 2], *clamp_lo);
    out[i + 2] = vminq_s32(a0, *clamp_hi);

    a1 = vmaxq_s32(in[i + 3], *clamp_lo);
    out[i + 3] = vminq_s32(a1, *clamp_hi);
  }
}

static inline uint16x8_t highbd_get_recon_8x8_neon(const uint16x8_t pred,
                                                   int32x4_t res0,
                                                   int32x4_t res1,
                                                   const int bd) {
  const uint16x8_t v_zero = vdupq_n_u16(0);
  int32x4_t min_clip_val = vreinterpretq_s32_u16(v_zero);
  int32x4_t max_clip_val = vdupq_n_s32((1 << bd) - 1);
  uint16x8x2_t x;
  x.val[0] = vreinterpretq_u16_s32(
      vaddw_s16(res0, vreinterpret_s16_u16(vget_low_u16(pred))));
  x.val[1] = vreinterpretq_u16_s32(
      vaddw_s16(res1, vreinterpret_s16_u16(vget_high_u16(pred))));
  x.val[0] = vreinterpretq_u16_s32(
      vmaxq_s32(vreinterpretq_s32_u16(x.val[0]), min_clip_val));
  x.val[0] = vreinterpretq_u16_s32(
      vminq_s32(vreinterpretq_s32_u16(x.val[0]), max_clip_val));
  x.val[1] = vreinterpretq_u16_s32(
      vmaxq_s32(vreinterpretq_s32_u16(x.val[1]), min_clip_val));
  x.val[1] = vreinterpretq_u16_s32(
      vminq_s32(vreinterpretq_s32_u16(x.val[1]), max_clip_val));
  uint16x8_t res = vcombine_u16(vqmovn_u32(vreinterpretq_u32_u16(x.val[0])),
                                vqmovn_u32(vreinterpretq_u32_u16(x.val[1])));
  return res;
}

static inline uint16x4_t highbd_get_recon_4xn_neon(uint16x4_t pred,
                                                   int32x4_t res0,
                                                   const int bd) {
  uint16x4_t x0_ = vreinterpret_u16_s16(
      vmovn_s32(vaddw_s16(res0, vreinterpret_s16_u16(pred))));
  uint16x8_t x0 = vcombine_u16(x0_, x0_);
  const uint16x8_t vmin = vdupq_n_u16(0);
  const uint16x8_t vmax = vdupq_n_u16((1 << bd) - 1);
  x0 = highbd_clamp_u16(&x0, &vmin, &vmax);
  return vget_low_u16(x0);
}

static inline void highbd_write_buffer_4xn_neon(int32x4_t *in, uint16_t *output,
                                                int stride, int flipud,
                                                int height, const int bd) {
  int j = flipud ? (height - 1) : 0;
  const int step = flipud ? -1 : 1;
  for (int i = 0; i < height; ++i, j += step) {
    uint16x4_t v = vld1_u16(output + i * stride);
    uint16x4_t u = highbd_get_recon_4xn_neon(v, in[j], bd);

    vst1_u16(output + i * stride, u);
  }
}

static inline void highbd_write_buffer_8xn_neon(int32x4_t *in, uint16_t *output,
                                                int stride, int flipud,
                                                int height, const int bd) {
  int j = flipud ? (height - 1) : 0;
  const int step = flipud ? -1 : 1;
  for (int i = 0; i < height; ++i, j += step) {
    uint16x8_t v = vld1q_u16(output + i * stride);
    uint16x8_t u = highbd_get_recon_8x8_neon(v, in[j], in[j + height], bd);

    vst1q_u16(output + i * stride, u);
  }
}

static inline void load_buffer_32bit_input(const int32_t *in, int stride,
                                           int32x4_t *out, int out_size) {
  for (int i = 0; i < out_size; ++i) {
    out[i] = vld1q_s32(in + i * stride);
  }
}

static inline void load_buffer_4x4(const int32_t *coeff, int32x4_t *in) {
  in[0] = vld1q_s32(coeff + 0);
  in[1] = vld1q_s32(coeff + 4);
  in[2] = vld1q_s32(coeff + 8);
  in[3] = vld1q_s32(coeff + 12);
}

static void addsub_neon(const int32x4_t in0, const int32x4_t in1,
                        int32x4_t *out0, int32x4_t *out1,
                        const int32x4_t *clamp_lo, const int32x4_t *clamp_hi) {
  int32x4_t a0 = vaddq_s32(in0, in1);
  int32x4_t a1 = vsubq_s32(in0, in1);

  a0 = vmaxq_s32(a0, *clamp_lo);
  a0 = vminq_s32(a0, *clamp_hi);
  a1 = vmaxq_s32(a1, *clamp_lo);
  a1 = vminq_s32(a1, *clamp_hi);

  *out0 = a0;
  *out1 = a1;
}

static void shift_and_clamp_neon(int32x4_t *in0, int32x4_t *in1,
                                 const int32x4_t *clamp_lo,
                                 const int32x4_t *clamp_hi,
                                 const int32x4_t *v_shift) {
  int32x4_t in0_w_offset = vrshlq_s32(*in0, *v_shift);
  int32x4_t in1_w_offset = vrshlq_s32(*in1, *v_shift);

  in0_w_offset = vmaxq_s32(in0_w_offset, *clamp_lo);
  in0_w_offset = vminq_s32(in0_w_offset, *clamp_hi);
  in1_w_offset = vmaxq_s32(in1_w_offset, *clamp_lo);
  in1_w_offset = vminq_s32(in1_w_offset, *clamp_hi);

  *in0 = in0_w_offset;
  *in1 = in1_w_offset;
}

static inline void idct32_stage4_neon(int32x4_t *bf1, const int32_t *cospi,
                                      const int32x4_t *v_bit,
                                      const int32x4_t *rnding) {
  int32x4_t temp1, temp2;
  temp1 = half_btf_neon_mode10_r(&cospi[8], &bf1[17], &cospi[56], &bf1[30],
                                 v_bit, rnding);
  bf1[30] =
      half_btf_neon_r(&cospi[56], &bf1[17], &cospi[8], &bf1[30], v_bit, rnding);
  bf1[17] = temp1;

  temp2 = half_btf_neon_mode11_r(&cospi[56], &bf1[18], &cospi[8], &bf1[29],
                                 v_bit, rnding);
  bf1[29] = half_btf_neon_mode10_r(&cospi[8], &bf1[18], &cospi[56], &bf1[29],
                                   v_bit, rnding);
  bf1[18] = temp2;

  temp1 = half_btf_neon_mode10_r(&cospi[40], &bf1[21], &cospi[24], &bf1[26],
                                 v_bit, rnding);
  bf1[26] = half_btf_neon_r(&cospi[24], &bf1[21], &cospi[40], &bf1[26], v_bit,
                            rnding);
  bf1[21] = temp1;

  temp2 = half_btf_neon_mode11_r(&cospi[24], &bf1[22], &cospi[40], &bf1[25],
                                 v_bit, rnding);
  bf1[25] = half_btf_neon_mode10_r(&cospi[40], &bf1[22], &cospi[24], &bf1[25],
                                   v_bit, rnding);
  bf1[22] = temp2;
}

static inline void idct32_stage5_neon(int32x4_t *bf1, const int32_t *cospi,
                                      const int32x4_t *clamp_lo,
                                      const int32x4_t *clamp_hi,
                                      const int32x4_t *v_bit,
                                      const int32x4_t *rnding) {
  int32x4_t temp1, temp2;
  temp1 = half_btf_neon_mode10_r(&cospi[16], &bf1[9], &cospi[48], &bf1[14],
                                 v_bit, rnding);
  bf1[14] =
      half_btf_neon_r(&cospi[48], &bf1[9], &cospi[16], &bf1[14], v_bit, rnding);
  bf1[9] = temp1;

  temp2 = half_btf_neon_mode11_r(&cospi[48], &bf1[10], &cospi[16], &bf1[13],
                                 v_bit, rnding);
  bf1[13] = half_btf_neon_mode10_r(&cospi[16], &bf1[10], &cospi[48], &bf1[13],
                                   v_bit, rnding);
  bf1[10] = temp2;

  addsub_neon(bf1[16], bf1[19], bf1 + 16, bf1 + 19, clamp_lo, clamp_hi);
  addsub_neon(bf1[17], bf1[18], bf1 + 17, bf1 + 18, clamp_lo, clamp_hi);
  addsub_neon(bf1[23], bf1[20], bf1 + 23, bf1 + 20, clamp_lo, clamp_hi);
  addsub_neon(bf1[22], bf1[21], bf1 + 22, bf1 + 21, clamp_lo, clamp_hi);
  addsub_neon(bf1[24], bf1[27], bf1 + 24, bf1 + 27, clamp_lo, clamp_hi);
  addsub_neon(bf1[25], bf1[26], bf1 + 25, bf1 + 26, clamp_lo, clamp_hi);
  addsub_neon(bf1[31], bf1[28], bf1 + 31, bf1 + 28, clamp_lo, clamp_hi);
  addsub_neon(bf1[30], bf1[29], bf1 + 30, bf1 + 29, clamp_lo, clamp_hi);
}

static inline void idct32_stage6_neon(int32x4_t *bf1, const int32_t *cospi,
                                      const int32x4_t *clamp_lo,
                                      const int32x4_t *clamp_hi,
                                      const int32x4_t *v_bit,
                                      const int32x4_t *rnding) {
  int32x4_t temp1, temp2;
  temp1 = half_btf_neon_mode10_r(&cospi[32], &bf1[5], &cospi[32], &bf1[6],
                                 v_bit, rnding);
  bf1[6] =
      half_btf_neon_r(&cospi[32], &bf1[5], &cospi[32], &bf1[6], v_bit, rnding);
  bf1[5] = temp1;

  addsub_neon(bf1[8], bf1[11], bf1 + 8, bf1 + 11, clamp_lo, clamp_hi);
  addsub_neon(bf1[9], bf1[10], bf1 + 9, bf1 + 10, clamp_lo, clamp_hi);
  addsub_neon(bf1[15], bf1[12], bf1 + 15, bf1 + 12, clamp_lo, clamp_hi);
  addsub_neon(bf1[14], bf1[13], bf1 + 14, bf1 + 13, clamp_lo, clamp_hi);

  temp1 = half_btf_neon_mode10_r(&cospi[16], &bf1[18], &cospi[48], &bf1[29],
                                 v_bit, rnding);
  bf1[29] = half_btf_neon_r(&cospi[48], &bf1[18], &cospi[16], &bf1[29], v_bit,
                            rnding);
  bf1[18] = temp1;
  temp2 = half_btf_neon_mode10_r(&cospi[16], &bf1[19], &cospi[48], &bf1[28],
                                 v_bit, rnding);
  bf1[28] = half_btf_neon_r(&cospi[48], &bf1[19], &cospi[16], &bf1[28], v_bit,
                            rnding);
  bf1[19] = temp2;
  temp1 = half_btf_neon_mode11_r(&cospi[48], &bf1[20], &cospi[16], &bf1[27],
                                 v_bit, rnding);
  bf1[27] = half_btf_neon_mode10_r(&cospi[16], &bf1[20], &cospi[48], &bf1[27],
                                   v_bit, rnding);
  bf1[20] = temp1;
  temp2 = half_btf_neon_mode11_r(&cospi[48], &bf1[21], &cospi[16], &bf1[26],
                                 v_bit, rnding);
  bf1[26] = half_btf_neon_mode10_r(&cospi[16], &bf1[21], &cospi[48], &bf1[26],
                                   v_bit, rnding);
  bf1[21] = temp2;
}

static inline void idct32_stage7_neon(int32x4_t *bf1, const int32_t *cospi,
                                      const int32x4_t *clamp_lo,
                                      const int32x4_t *clamp_hi,
                                      const int32x4_t *v_bit,
                                      const int32x4_t *rnding) {
  int32x4_t temp1, temp2;
  addsub_neon(bf1[0], bf1[7], bf1 + 0, bf1 + 7, clamp_lo, clamp_hi);
  addsub_neon(bf1[1], bf1[6], bf1 + 1, bf1 + 6, clamp_lo, clamp_hi);
  addsub_neon(bf1[2], bf1[5], bf1 + 2, bf1 + 5, clamp_lo, clamp_hi);
  addsub_neon(bf1[3], bf1[4], bf1 + 3, bf1 + 4, clamp_lo, clamp_hi);
  temp1 = half_btf_neon_mode10_r(&cospi[32], &bf1[10], &cospi[32], &bf1[13],
                                 v_bit, rnding);
  bf1[13] = half_btf_neon_r(&cospi[32], &bf1[10], &cospi[32], &bf1[13], v_bit,
                            rnding);
  bf1[10] = temp1;
  temp2 = half_btf_neon_mode10_r(&cospi[32], &bf1[11], &cospi[32], &bf1[12],
                                 v_bit, rnding);
  bf1[12] = half_btf_neon_r(&cospi[32], &bf1[11], &cospi[32], &bf1[12], v_bit,
                            rnding);
  bf1[11] = temp2;

  addsub_neon(bf1[16], bf1[23], bf1 + 16, bf1 + 23, clamp_lo, clamp_hi);
  addsub_neon(bf1[17], bf1[22], bf1 + 17, bf1 + 22, clamp_lo, clamp_hi);
  addsub_neon(bf1[18], bf1[21], bf1 + 18, bf1 + 21, clamp_lo, clamp_hi);
  addsub_neon(bf1[19], bf1[20], bf1 + 19, bf1 + 20, clamp_lo, clamp_hi);
  addsub_neon(bf1[31], bf1[24], bf1 + 31, bf1 + 24, clamp_lo, clamp_hi);
  addsub_neon(bf1[30], bf1[25], bf1 + 30, bf1 + 25, clamp_lo, clamp_hi);
  addsub_neon(bf1[29], bf1[26], bf1 + 29, bf1 + 26, clamp_lo, clamp_hi);
  addsub_neon(bf1[28], bf1[27], bf1 + 28, bf1 + 27, clamp_lo, clamp_hi);
}

static inline void idct32_stage8_neon(int32x4_t *bf1, const int32_t *cospi,
                                      const int32x4_t *clamp_lo,
                                      const int32x4_t *clamp_hi,
                                      const int32x4_t *v_bit,
                                      const int32x4_t *rnding) {
  int32x4_t temp1, temp2;
  addsub_neon(bf1[0], bf1[15], bf1 + 0, bf1 + 15, clamp_lo, clamp_hi);
  addsub_neon(bf1[1], bf1[14], bf1 + 1, bf1 + 14, clamp_lo, clamp_hi);
  addsub_neon(bf1[2], bf1[13], bf1 + 2, bf1 + 13, clamp_lo, clamp_hi);
  addsub_neon(bf1[3], bf1[12], bf1 + 3, bf1 + 12, clamp_lo, clamp_hi);
  addsub_neon(bf1[4], bf1[11], bf1 + 4, bf1 + 11, clamp_lo, clamp_hi);
  addsub_neon(bf1[5], bf1[10], bf1 + 5, bf1 + 10, clamp_lo, clamp_hi);
  addsub_neon(bf1[6], bf1[9], bf1 + 6, bf1 + 9, clamp_lo, clamp_hi);
  addsub_neon(bf1[7], bf1[8], bf1 + 7, bf1 + 8, clamp_lo, clamp_hi);
  temp1 = half_btf_neon_mode10_r(&cospi[32], &bf1[20], &cospi[32], &bf1[27],
                                 v_bit, rnding);
  bf1[27] = half_btf_neon_r(&cospi[32], &bf1[20], &cospi[32], &bf1[27], v_bit,
                            rnding);
  bf1[20] = temp1;
  temp2 = half_btf_neon_mode10_r(&cospi[32], &bf1[21], &cospi[32], &bf1[26],
                                 v_bit, rnding);
  bf1[26] = half_btf_neon_r(&cospi[32], &bf1[21], &cospi[32], &bf1[26], v_bit,
                            rnding);
  bf1[21] = temp2;
  temp1 = half_btf_neon_mode10_r(&cospi[32], &bf1[22], &cospi[32], &bf1[25],
                                 v_bit, rnding);
  bf1[25] = half_btf_neon_r(&cospi[32], &bf1[22], &cospi[32], &bf1[25], v_bit,
                            rnding);
  bf1[22] = temp1;
  temp2 = half_btf_neon_mode10_r(&cospi[32], &bf1[23], &cospi[32], &bf1[24],
                                 v_bit, rnding);
  bf1[24] = half_btf_neon_r(&cospi[32], &bf1[23], &cospi[32], &bf1[24], v_bit,
                            rnding);
  bf1[23] = temp2;
}

static inline void idct32_stage9_neon(int32x4_t *bf1, int32x4_t *out,
                                      const int do_cols, const int bd,
                                      const int out_shift,
                                      const int32x4_t *clamp_lo,
                                      const int32x4_t *clamp_hi) {
  addsub_neon(bf1[0], bf1[31], out + 0, out + 31, clamp_lo, clamp_hi);
  addsub_neon(bf1[1], bf1[30], out + 1, out + 30, clamp_lo, clamp_hi);
  addsub_neon(bf1[2], bf1[29], out + 2, out + 29, clamp_lo, clamp_hi);
  addsub_neon(bf1[3], bf1[28], out + 3, out + 28, clamp_lo, clamp_hi);
  addsub_neon(bf1[4], bf1[27], out + 4, out + 27, clamp_lo, clamp_hi);
  addsub_neon(bf1[5], bf1[26], out + 5, out + 26, clamp_lo, clamp_hi);
  addsub_neon(bf1[6], bf1[25], out + 6, out + 25, clamp_lo, clamp_hi);
  addsub_neon(bf1[7], bf1[24], out + 7, out + 24, clamp_lo, clamp_hi);
  addsub_neon(bf1[8], bf1[23], out + 8, out + 23, clamp_lo, clamp_hi);
  addsub_neon(bf1[9], bf1[22], out + 9, out + 22, clamp_lo, clamp_hi);
  addsub_neon(bf1[10], bf1[21], out + 10, out + 21, clamp_lo, clamp_hi);
  addsub_neon(bf1[11], bf1[20], out + 11, out + 20, clamp_lo, clamp_hi);
  addsub_neon(bf1[12], bf1[19], out + 12, out + 19, clamp_lo, clamp_hi);
  addsub_neon(bf1[13], bf1[18], out + 13, out + 18, clamp_lo, clamp_hi);
  addsub_neon(bf1[14], bf1[17], out + 14, out + 17, clamp_lo, clamp_hi);
  addsub_neon(bf1[15], bf1[16], out + 15, out + 16, clamp_lo, clamp_hi);

  if (!do_cols) {
    const int log_range_out = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo_out = vdupq_n_s32(-(1 << (log_range_out - 1)));
    const int32x4_t clamp_hi_out = vdupq_n_s32((1 << (log_range_out - 1)) - 1);
    for (int i = 0; i < 32; i += 8) {
      round_shift_4x4(out + i, out_shift);
      round_shift_4x4(out + i + 4, out_shift);
    }
    highbd_clamp_s32_neon(out, out, &clamp_lo_out, &clamp_hi_out, 32);
  }
}

static void neg_shift_neon(const int32x4_t *in0, const int32x4_t *in1,
                           int32x4_t *out0, int32x4_t *out1,
                           const int32x4_t *clamp_lo, const int32x4_t *clamp_hi,
                           const int32x4_t *v_shift, int32x4_t *offset) {
  int32x4_t a0 = vaddq_s32(*offset, *in0);
  int32x4_t a1 = vsubq_s32(*offset, *in1);

  a0 = vshlq_s32(a0, *v_shift);
  a1 = vshlq_s32(a1, *v_shift);

  a0 = vmaxq_s32(a0, *clamp_lo);
  a0 = vminq_s32(a0, *clamp_hi);
  a1 = vmaxq_s32(a1, *clamp_lo);
  a1 = vminq_s32(a1, *clamp_hi);

  *out0 = a0;
  *out1 = a1;
}

static void idct4x4_neon(int32x4_t *in, int32x4_t *out, int bit, int do_cols,
                         int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
  int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
  int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
  int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));

  int32x4_t u0, u1, u2, u3;
  int32x4_t v0, v1, v2, v3, x, y;

  // Stage 0-1-2

  u0 = in[0];
  u1 = in[1];
  u2 = in[2];
  u3 = in[3];

  const int32x4_t v_bit = vdupq_n_s32(-bit);

  x = vmlaq_n_s32(rnding, u0, cospi[32]);
  y = vmulq_n_s32(u2, cospi[32]);
  v0 = vaddq_s32(x, y);
  v0 = vshlq_s32(v0, v_bit);

  v1 = vsubq_s32(x, y);
  v1 = vshlq_s32(v1, v_bit);

  x = vmlaq_n_s32(rnding, u1, cospi[48]);
  v2 = vmlsq_n_s32(x, u3, cospi[16]);
  v2 = vshlq_s32(v2, v_bit);

  x = vmlaq_n_s32(rnding, u1, cospi[16]);
  v3 = vmlaq_n_s32(x, u3, cospi[48]);
  v3 = vshlq_s32(v3, v_bit);
  // Stage 3
  addsub_neon(v0, v3, out + 0, out + 3, &clamp_lo, &clamp_hi);
  addsub_neon(v1, v2, out + 1, out + 2, &clamp_lo, &clamp_hi);

  if (!do_cols) {
    log_range = AOMMAX(16, bd + 6);
    clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
    clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
    const int32x4_t v_shift = vdupq_n_s32(-out_shift);
    shift_and_clamp_neon(out + 0, out + 3, &clamp_lo, &clamp_hi, &v_shift);
    shift_and_clamp_neon(out + 1, out + 2, &clamp_lo, &clamp_hi, &v_shift);
  }
}

static void iadst4x4_neon(int32x4_t *in, int32x4_t *out, int bit, int do_cols,
                          int bd, int out_shift) {
  const int32_t *sinpi = sinpi_arr(bit);
  const int32x4_t zero = vdupq_n_s32(0);
  int64x2_t rnding = vdupq_n_s64(1ll << (bit + 4 - 1));
  const int32x2_t mul = vdup_n_s32(1 << 4);
  int32x4_t t;
  int32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
  int32x4_t x0, x1, x2, x3;
  int32x4_t u0, u1, u2, u3;

  x0 = in[0];
  x1 = in[1];
  x2 = in[2];
  x3 = in[3];

  s0 = vmulq_n_s32(x0, sinpi[1]);
  s1 = vmulq_n_s32(x0, sinpi[2]);
  s2 = vmulq_n_s32(x1, sinpi[3]);
  s3 = vmulq_n_s32(x2, sinpi[4]);
  s4 = vmulq_n_s32(x2, sinpi[1]);
  s5 = vmulq_n_s32(x3, sinpi[2]);
  s6 = vmulq_n_s32(x3, sinpi[4]);
  t = vsubq_s32(x0, x2);
  s7 = vaddq_s32(t, x3);

  t = vaddq_s32(s0, s3);
  s0 = vaddq_s32(t, s5);
  t = vsubq_s32(s1, s4);
  s1 = vsubq_s32(t, s6);
  s3 = s2;
  s2 = vmulq_n_s32(s7, sinpi[3]);

  u0 = vaddq_s32(s0, s3);
  u1 = vaddq_s32(s1, s3);
  u2 = s2;
  t = vaddq_s32(s0, s1);
  u3 = vsubq_s32(t, s3);

  // u0
  int32x4x2_t u0x;
  u0x.val[0] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u0)), mul));
  u0x.val[0] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u0x.val[0]), rnding));

  u0 = vextq_s32(u0, zero, 1);
  u0x.val[1] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u0)), mul));
  u0x.val[1] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u0x.val[1]), rnding));

  u0x.val[0] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u0x.val[0]), vreinterpretq_s16_s32(zero), 1));
  u0x.val[1] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u0x.val[1]), vreinterpretq_s16_s32(zero), 1));

  u0x = vzipq_s32(u0x.val[0], u0x.val[1]);
#if AOM_ARCH_AARCH64
  u0 = vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(u0x.val[0]),
                                        vreinterpretq_s64_s32(u0x.val[1])));
#else
  u0 = vcombine_s32(vget_low_s32(u0x.val[0]), vget_low_s32(u0x.val[1]));
#endif  // AOM_ARCH_AARCH64
  // u1
  int32x4x2_t u1x;
  u1x.val[0] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u1)), mul));
  u1x.val[0] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u1x.val[0]), rnding));

  u1 = vextq_s32(u1, zero, 1);
  u1x.val[1] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u1)), mul));
  u1x.val[1] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u1x.val[1]), rnding));

  u1x.val[0] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u1x.val[0]), vreinterpretq_s16_s32(zero), 1));
  u1x.val[1] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u1x.val[1]), vreinterpretq_s16_s32(zero), 1));

  u1x = vzipq_s32(u1x.val[0], u1x.val[1]);
#if AOM_ARCH_AARCH64
  u1 = vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(u1x.val[0]),
                                        vreinterpretq_s64_s32(u1x.val[1])));
#else
  u1 = vcombine_s32(vget_low_s32(u1x.val[0]), vget_low_s32(u1x.val[1]));
#endif  // AOM_ARCH_AARCH64

  // u2
  int32x4x2_t u2x;
  u2x.val[0] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u2)), mul));
  u2x.val[0] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u2x.val[0]), rnding));

  u2 = vextq_s32(u2, zero, 1);
  u2x.val[1] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u2)), mul));
  u2x.val[1] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u2x.val[1]), rnding));

  u2x.val[0] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u2x.val[0]), vreinterpretq_s16_s32(zero), 1));
  u2x.val[1] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u2x.val[1]), vreinterpretq_s16_s32(zero), 1));

  u2x = vzipq_s32(u2x.val[0], u2x.val[1]);
#if AOM_ARCH_AARCH64
  u2 = vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(u2x.val[0]),
                                        vreinterpretq_s64_s32(u2x.val[1])));
#else
  u2 = vcombine_s32(vget_low_s32(u2x.val[0]), vget_low_s32(u2x.val[1]));
#endif  // AOM_ARCH_AARCH64

  // u3
  int32x4x2_t u3x;
  u3x.val[0] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u3)), mul));
  u3x.val[0] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u3x.val[0]), rnding));

  u3 = vextq_s32(u3, zero, 1);
  u3x.val[1] = vreinterpretq_s32_s64(
      vmull_s32(vmovn_s64(vreinterpretq_s64_s32(u3)), mul));
  u3x.val[1] = vreinterpretq_s32_s64(
      vaddq_s64(vreinterpretq_s64_s32(u3x.val[1]), rnding));

  u3x.val[0] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u3x.val[0]), vreinterpretq_s16_s32(zero), 1));
  u3x.val[1] = vreinterpretq_s32_s16(vextq_s16(
      vreinterpretq_s16_s32(u3x.val[1]), vreinterpretq_s16_s32(zero), 1));

  u3x = vzipq_s32(u3x.val[0], u3x.val[1]);
#if AOM_ARCH_AARCH64
  u3 = vreinterpretq_s32_s64(vzip1q_s64(vreinterpretq_s64_s32(u3x.val[0]),
                                        vreinterpretq_s64_s32(u3x.val[1])));
#else
  u3 = vcombine_s32(vget_low_s32(u3x.val[0]), vget_low_s32(u3x.val[1]));
#endif  // AOM_ARCH_AARCH64

  out[0] = u0;
  out[1] = u1;
  out[2] = u2;
  out[3] = u3;

  if (!do_cols) {
    const int log_range = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
    const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
    round_shift_4x4(out, out_shift);
    highbd_clamp_s32_neon(out, out, &clamp_lo, &clamp_hi, 4);
  }
}

static void write_buffer_4x4(int32x4_t *in, uint16_t *output, int stride,
                             int fliplr, int flipud, int shift, int bd) {
  uint32x4_t u0, u1, u2, u3;
  uint16x4_t v0, v1, v2, v3;
  round_shift_4x4(in, shift);

  v0 = vld1_u16(output + 0 * stride);
  v1 = vld1_u16(output + 1 * stride);
  v2 = vld1_u16(output + 2 * stride);
  v3 = vld1_u16(output + 3 * stride);

  if (fliplr) {
    u0 = vrev64q_u32(vreinterpretq_u32_s32(in[0]));
    in[0] = vreinterpretq_s32_u32(vextq_u32(u0, u0, 2));
    u0 = vrev64q_u32(vreinterpretq_u32_s32(in[1]));
    in[1] = vreinterpretq_s32_u32(vextq_u32(u0, u0, 2));
    u0 = vrev64q_u32(vreinterpretq_u32_s32(in[2]));
    in[2] = vreinterpretq_s32_u32(vextq_u32(u0, u0, 2));
    u0 = vrev64q_u32(vreinterpretq_u32_s32(in[3]));
    in[3] = vreinterpretq_s32_u32(vextq_u32(u0, u0, 2));
  }

  if (flipud) {
    u0 = vaddw_u16(vreinterpretq_u32_s32(in[3]), v0);
    u1 = vaddw_u16(vreinterpretq_u32_s32(in[2]), v1);
    u2 = vaddw_u16(vreinterpretq_u32_s32(in[1]), v2);
    u3 = vaddw_u16(vreinterpretq_u32_s32(in[0]), v3);
  } else {
    u0 = vaddw_u16(vreinterpretq_u32_s32(in[0]), v0);
    u1 = vaddw_u16(vreinterpretq_u32_s32(in[1]), v1);
    u2 = vaddw_u16(vreinterpretq_u32_s32(in[2]), v2);
    u3 = vaddw_u16(vreinterpretq_u32_s32(in[3]), v3);
  }

  uint16x8_t u4 = vcombine_u16(vqmovn_u32(u0), vqmovn_u32(u1));
  uint16x8_t u5 = vcombine_u16(vqmovn_u32(u2), vqmovn_u32(u3));
  const uint16x8_t vmin = vdupq_n_u16(0);
  const uint16x8_t vmax = vdupq_n_u16((1 << bd) - 1);
  u4 = highbd_clamp_u16(&u4, &vmin, &vmax);
  u5 = highbd_clamp_u16(&u5, &vmin, &vmax);

  vst1_u16(output + 0 * stride, vget_low_u16(u4));
  vst1_u16(output + 1 * stride, vget_high_u16(u4));
  vst1_u16(output + 2 * stride, vget_low_u16(u5));
  vst1_u16(output + 3 * stride, vget_high_u16(u5));
}

static void iidentity4_neon(int32x4_t *in, int32x4_t *out, int bit, int do_cols,
                            int bd, int out_shift) {
  (void)bit;
  int32x4_t zero = vdupq_n_s32(0);
  int32x2_t fact = vdup_n_s32(NewSqrt2);
  int32x4x2_t a0;
  const int64x2_t rnding = vdupq_n_s64(1 << (NewSqrt2Bits - 1));

  for (int i = 0; i < 4; i++) {
    a0.val[0] = vreinterpretq_s32_s64(
        vmlal_s32(rnding, vmovn_s64(vreinterpretq_s64_s32(in[i])), fact));
    a0.val[0] = vreinterpretq_s32_s64(
        vshrq_n_s64(vreinterpretq_s64_s32(a0.val[0]), NewSqrt2Bits));
    a0.val[1] = vextq_s32(in[i], zero, 1);
    a0.val[1] = vreinterpretq_s32_s64(
        vmlal_s32(rnding, vmovn_s64(vreinterpretq_s64_s32(a0.val[1])), fact));
    a0.val[1] = vreinterpretq_s32_s64(
        vshrq_n_s64(vreinterpretq_s64_s32(a0.val[1]), NewSqrt2Bits));

    a0 = vzipq_s32(a0.val[0], a0.val[1]);
#if AOM_ARCH_AARCH64
    out[i] = vreinterpretq_s32_s64(vzip1q_s64(
        vreinterpretq_s64_s32(a0.val[0]), vreinterpretq_s64_s32(a0.val[1])));
#else
    out[i] = vextq_s32(vextq_s32(a0.val[0], a0.val[0], 2), a0.val[1], 2);
#endif
  }
  if (!do_cols) {
    const int log_range = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
    const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
    round_shift_4x4(out, out_shift);
    highbd_clamp_s32_neon(out, out, &clamp_lo, &clamp_hi, 4);
  }
}

void av1_inv_txfm2d_add_4x4_neon(const int32_t *input, uint16_t *output,
                                 int stride, TX_TYPE tx_type, int bd) {
  int32x4_t in[4];

  const int8_t *shift = av1_inv_txfm_shift_ls[TX_4X4];

  switch (tx_type) {
    case DCT_DCT:
      load_buffer_4x4(input, in);
      idct4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      idct4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case ADST_DCT:
      load_buffer_4x4(input, in);
      idct4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case DCT_ADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      idct4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case ADST_ADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case FLIPADST_DCT:
      load_buffer_4x4(input, in);
      idct4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd);
      break;
    case DCT_FLIPADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      idct4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 1, 0, -shift[1], bd);
      break;
    case FLIPADST_FLIPADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 1, 1, -shift[1], bd);
      break;
    case ADST_FLIPADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 1, 0, -shift[1], bd);
      break;
    case FLIPADST_ADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd);
      break;
    case IDTX:
      load_buffer_4x4(input, in);
      iidentity4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iidentity4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case V_DCT:
      load_buffer_4x4(input, in);
      iidentity4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      idct4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case H_DCT:
      load_buffer_4x4(input, in);
      idct4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iidentity4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case V_ADST:
      load_buffer_4x4(input, in);
      iidentity4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case H_ADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iidentity4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd);
      break;
    case V_FLIPADST:
      load_buffer_4x4(input, in);
      iidentity4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd);
      break;
    case H_FLIPADST:
      load_buffer_4x4(input, in);
      iadst4x4_neon(in, in, INV_COS_BIT, 0, bd, 0);
      transpose_4x4(in, in);
      iidentity4_neon(in, in, INV_COS_BIT, 1, bd, 0);
      write_buffer_4x4(in, output, stride, 1, 0, -shift[1], bd);
      break;
    default: assert(0);
  }
}

// 8x8
static void load_buffer_8x8(const int32_t *coeff, int32x4_t *in) {
  in[0] = vld1q_s32(coeff + 0);
  in[1] = vld1q_s32(coeff + 4);
  in[2] = vld1q_s32(coeff + 8);
  in[3] = vld1q_s32(coeff + 12);
  in[4] = vld1q_s32(coeff + 16);
  in[5] = vld1q_s32(coeff + 20);
  in[6] = vld1q_s32(coeff + 24);
  in[7] = vld1q_s32(coeff + 28);
  in[8] = vld1q_s32(coeff + 32);
  in[9] = vld1q_s32(coeff + 36);
  in[10] = vld1q_s32(coeff + 40);
  in[11] = vld1q_s32(coeff + 44);
  in[12] = vld1q_s32(coeff + 48);
  in[13] = vld1q_s32(coeff + 52);
  in[14] = vld1q_s32(coeff + 56);
  in[15] = vld1q_s32(coeff + 60);
}

static void idct8x8_neon(int32x4_t *in, int32x4_t *out, int bit, int do_cols,
                         int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
  const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
  const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
  int32x4_t u0, u1, u2, u3, u4, u5, u6, u7;
  int32x4_t v0, v1, v2, v3, v4, v5, v6, v7;
  int32x4_t x, y;
  int col;
  const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  // Note:
  //  Even column: 0, 2, ..., 14
  //  Odd column: 1, 3, ..., 15
  //  one even column plus one odd column constructs one row (8 coeffs)
  //  total we have 8 rows (8x8).
  for (col = 0; col < 2; ++col) {
    // stage 0
    // stage 1
    // stage 2
    u0 = in[0 * 2 + col];
    u1 = in[4 * 2 + col];
    u2 = in[2 * 2 + col];
    u3 = in[6 * 2 + col];

    x = vmulq_n_s32(in[1 * 2 + col], cospi[56]);
    u4 = vmlaq_n_s32(x, in[7 * 2 + col], -cospi[8]);
    u4 = vaddq_s32(u4, rnding);
    u4 = vshlq_s32(u4, v_bit);

    x = vmulq_n_s32(in[1 * 2 + col], cospi[8]);
    u7 = vmlaq_n_s32(x, in[7 * 2 + col], cospi[56]);
    u7 = vaddq_s32(u7, rnding);
    u7 = vshlq_s32(u7, v_bit);

    x = vmulq_n_s32(in[5 * 2 + col], cospi[24]);
    u5 = vmlaq_n_s32(x, in[3 * 2 + col], -cospi[40]);
    u5 = vaddq_s32(u5, rnding);
    u5 = vshlq_s32(u5, v_bit);

    x = vmulq_n_s32(in[5 * 2 + col], cospi[40]);
    u6 = vmlaq_n_s32(x, in[3 * 2 + col], cospi[24]);
    u6 = vaddq_s32(u6, rnding);
    u6 = vshlq_s32(u6, v_bit);

    // stage 3
    x = vmulq_n_s32(u0, cospi[32]);
    y = vmulq_n_s32(u1, cospi[32]);
    v0 = vaddq_s32(x, y);
    v0 = vaddq_s32(v0, rnding);
    v0 = vshlq_s32(v0, v_bit);

    v1 = vsubq_s32(x, y);
    v1 = vaddq_s32(v1, rnding);
    v1 = vshlq_s32(v1, v_bit);

    x = vmulq_n_s32(u2, cospi[48]);
    v2 = vmlaq_n_s32(x, u3, -cospi[16]);
    v2 = vaddq_s32(v2, rnding);
    v2 = vshlq_s32(v2, v_bit);

    x = vmulq_n_s32(u2, cospi[16]);
    v3 = vmlaq_n_s32(x, u3, cospi[48]);
    v3 = vaddq_s32(v3, rnding);
    v3 = vshlq_s32(v3, v_bit);

    addsub_neon(u4, u5, &v4, &v5, &clamp_lo, &clamp_hi);
    addsub_neon(u7, u6, &v7, &v6, &clamp_lo, &clamp_hi);

    // stage 4
    addsub_neon(v0, v3, &u0, &u3, &clamp_lo, &clamp_hi);
    addsub_neon(v1, v2, &u1, &u2, &clamp_lo, &clamp_hi);
    u4 = v4;
    u7 = v7;

    x = vmulq_n_s32(v5, cospi[32]);
    y = vmulq_n_s32(v6, cospi[32]);
    u6 = vaddq_s32(y, x);
    u6 = vaddq_s32(u6, rnding);
    u6 = vshlq_s32(u6, v_bit);

    u5 = vsubq_s32(y, x);
    u5 = vaddq_s32(u5, rnding);
    u5 = vshlq_s32(u5, v_bit);

    // stage 5
    addsub_neon(u0, u7, out + 0 * 2 + col, out + 7 * 2 + col, &clamp_lo,
                &clamp_hi);
    addsub_neon(u1, u6, out + 1 * 2 + col, out + 6 * 2 + col, &clamp_lo,
                &clamp_hi);
    addsub_neon(u2, u5, out + 2 * 2 + col, out + 5 * 2 + col, &clamp_lo,
                &clamp_hi);
    addsub_neon(u3, u4, out + 3 * 2 + col, out + 4 * 2 + col, &clamp_lo,
                &clamp_hi);
  }

  if (!do_cols) {
    const int log_range_out = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo_out = vdupq_n_s32(-(1 << (log_range_out - 1)));
    const int32x4_t clamp_hi_out = vdupq_n_s32((1 << (log_range_out - 1)) - 1);
    round_shift_8x8(out, out_shift);
    highbd_clamp_s32_neon(out, out, &clamp_lo_out, &clamp_hi_out, 16);
  }
}

static void iadst8x8_neon(int32x4_t *in, int32x4_t *out, int bit, int do_cols,
                          int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  const int32x4_t kZero = vdupq_n_s32(0);
  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
  const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
  const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
  int32x4_t u[8], v[8], x;
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));
  // stage 0-1-2
  // (1)
  u[0] = vmlaq_n_s32(rnding, in[14], cospi[4]);
  u[0] = vmlaq_n_s32(u[0], in[0], cospi[60]);
  u[0] = vshlq_s32(u[0], v_bit);

  u[1] = vmlaq_n_s32(rnding, in[14], cospi[60]);
  u[1] = vmlsq_n_s32(u[1], in[0], cospi[4]);
  u[1] = vshlq_s32(u[1], v_bit);

  // (2)
  u[2] = vmlaq_n_s32(rnding, in[10], cospi[20]);
  u[2] = vmlaq_n_s32(u[2], in[4], cospi[44]);
  u[2] = vshlq_s32(u[2], v_bit);

  u[3] = vmlaq_n_s32(rnding, in[10], cospi[44]);
  u[3] = vmlsq_n_s32(u[3], in[4], cospi[20]);
  u[3] = vshlq_s32(u[3], v_bit);

  // (3)
  u[4] = vmlaq_n_s32(rnding, in[6], cospi[36]);
  u[4] = vmlaq_n_s32(u[4], in[8], cospi[28]);
  u[4] = vshlq_s32(u[4], v_bit);

  u[5] = vmlaq_n_s32(rnding, in[6], cospi[28]);
  u[5] = vmlsq_n_s32(u[5], in[8], cospi[36]);
  u[5] = vshlq_s32(u[5], v_bit);

  // (4)
  u[6] = vmlaq_n_s32(rnding, in[2], cospi[52]);
  u[6] = vmlaq_n_s32(u[6], in[12], cospi[12]);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vmlaq_n_s32(rnding, in[2], cospi[12]);
  u[7] = vmlsq_n_s32(u[7], in[12], cospi[52]);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 3
  addsub_neon(u[0], u[4], &v[0], &v[4], &clamp_lo, &clamp_hi);
  addsub_neon(u[1], u[5], &v[1], &v[5], &clamp_lo, &clamp_hi);
  addsub_neon(u[2], u[6], &v[2], &v[6], &clamp_lo, &clamp_hi);
  addsub_neon(u[3], u[7], &v[3], &v[7], &clamp_lo, &clamp_hi);

  // stage 4
  u[0] = v[0];
  u[1] = v[1];
  u[2] = v[2];
  u[3] = v[3];

  u[4] = vmlaq_n_s32(rnding, v[4], cospi[16]);
  u[4] = vmlaq_n_s32(u[4], v[5], cospi[48]);
  u[4] = vshlq_s32(u[4], v_bit);

  u[5] = vmlaq_n_s32(rnding, v[4], cospi[48]);
  u[5] = vmlsq_n_s32(u[5], v[5], cospi[16]);
  u[5] = vshlq_s32(u[5], v_bit);

  u[6] = vmlaq_n_s32(rnding, v[7], cospi[16]);
  u[6] = vmlsq_n_s32(u[6], v[6], cospi[48]);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vmlaq_n_s32(rnding, v[7], cospi[48]);
  u[7] = vmlaq_n_s32(u[7], v[6], cospi[16]);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 5
  addsub_neon(u[0], u[2], &v[0], &v[2], &clamp_lo, &clamp_hi);
  addsub_neon(u[1], u[3], &v[1], &v[3], &clamp_lo, &clamp_hi);
  addsub_neon(u[4], u[6], &v[4], &v[6], &clamp_lo, &clamp_hi);
  addsub_neon(u[5], u[7], &v[5], &v[7], &clamp_lo, &clamp_hi);

  // stage 6
  u[0] = v[0];
  u[1] = v[1];
  u[4] = v[4];
  u[5] = v[5];

  v[0] = vmlaq_n_s32(rnding, v[2], cospi[32]);
  x = vmulq_n_s32(v[3], cospi[32]);
  u[2] = vaddq_s32(v[0], x);
  u[2] = vshlq_s32(u[2], v_bit);

  u[3] = vsubq_s32(v[0], x);
  u[3] = vshlq_s32(u[3], v_bit);

  v[0] = vmlaq_n_s32(rnding, v[6], cospi[32]);
  x = vmulq_n_s32(v[7], cospi[32]);
  u[6] = vaddq_s32(v[0], x);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vsubq_s32(v[0], x);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 7
  if (do_cols) {
    out[0] = u[0];
    out[2] = vsubq_s32(kZero, u[4]);
    out[4] = u[6];
    out[6] = vsubq_s32(kZero, u[2]);
    out[8] = u[3];
    out[10] = vsubq_s32(kZero, u[7]);
    out[12] = u[5];
    out[14] = vsubq_s32(kZero, u[1]);
  } else {
    const int log_range_out = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo_out = vdupq_n_s32(-(1 << (log_range_out - 1)));
    const int32x4_t clamp_hi_out = vdupq_n_s32((1 << (log_range_out - 1)) - 1);
    const int32x4_t v_shift = vdupq_n_s32(-out_shift);
    int32x4_t offset = vdupq_n_s32((1 << out_shift) >> 1);
    neg_shift_neon(&u[0], &u[4], out + 0, out + 2, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[6], &u[2], out + 4, out + 6, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[3], &u[7], out + 8, out + 10, &clamp_lo_out,
                   &clamp_hi_out, &v_shift, &offset);
    neg_shift_neon(&u[5], &u[1], out + 12, out + 14, &clamp_lo_out,
                   &clamp_hi_out, &v_shift, &offset);
  }

  // Odd 8 points: 1, 3, ..., 15
  // stage 0
  // stage 1
  // stage 2
  // (1)
  u[0] = vmlaq_n_s32(rnding, in[15], cospi[4]);
  u[0] = vmlaq_n_s32(u[0], in[1], cospi[60]);
  u[0] = vshlq_s32(u[0], v_bit);

  u[1] = vmlaq_n_s32(rnding, in[15], cospi[60]);
  u[1] = vmlsq_n_s32(u[1], in[1], cospi[4]);
  u[1] = vshlq_s32(u[1], v_bit);

  // (2)
  u[2] = vmlaq_n_s32(rnding, in[11], cospi[20]);
  u[2] = vmlaq_n_s32(u[2], in[5], cospi[44]);
  u[2] = vshlq_s32(u[2], v_bit);

  u[3] = vmlaq_n_s32(rnding, in[11], cospi[44]);
  u[3] = vmlsq_n_s32(u[3], in[5], cospi[20]);
  u[3] = vshlq_s32(u[3], v_bit);

  // (3)
  u[4] = vmlaq_n_s32(rnding, in[7], cospi[36]);
  u[4] = vmlaq_n_s32(u[4], in[9], cospi[28]);
  u[4] = vshlq_s32(u[4], v_bit);

  u[5] = vmlaq_n_s32(rnding, in[7], cospi[28]);
  u[5] = vmlsq_n_s32(u[5], in[9], cospi[36]);
  u[5] = vshlq_s32(u[5], v_bit);

  // (4)
  u[6] = vmlaq_n_s32(rnding, in[3], cospi[52]);
  u[6] = vmlaq_n_s32(u[6], in[13], cospi[12]);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vmlaq_n_s32(rnding, in[3], cospi[12]);
  u[7] = vmlsq_n_s32(u[7], in[13], cospi[52]);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 3
  addsub_neon(u[0], u[4], &v[0], &v[4], &clamp_lo, &clamp_hi);
  addsub_neon(u[1], u[5], &v[1], &v[5], &clamp_lo, &clamp_hi);
  addsub_neon(u[2], u[6], &v[2], &v[6], &clamp_lo, &clamp_hi);
  addsub_neon(u[3], u[7], &v[3], &v[7], &clamp_lo, &clamp_hi);

  // stage 4
  u[0] = v[0];
  u[1] = v[1];
  u[2] = v[2];
  u[3] = v[3];

  u[4] = vmlaq_n_s32(rnding, v[4], cospi[16]);
  u[4] = vmlaq_n_s32(u[4], v[5], cospi[48]);
  u[4] = vshlq_s32(u[4], v_bit);

  u[5] = vmlaq_n_s32(rnding, v[4], cospi[48]);
  u[5] = vmlsq_n_s32(u[5], v[5], cospi[16]);
  u[5] = vshlq_s32(u[5], v_bit);

  u[6] = vmlaq_n_s32(rnding, v[7], cospi[16]);
  u[6] = vmlsq_n_s32(u[6], v[6], cospi[48]);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vmlaq_n_s32(rnding, v[6], cospi[16]);
  u[7] = vmlaq_n_s32(u[7], v[7], cospi[48]);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 5
  addsub_neon(u[0], u[2], &v[0], &v[2], &clamp_lo, &clamp_hi);
  addsub_neon(u[1], u[3], &v[1], &v[3], &clamp_lo, &clamp_hi);
  addsub_neon(u[4], u[6], &v[4], &v[6], &clamp_lo, &clamp_hi);
  addsub_neon(u[5], u[7], &v[5], &v[7], &clamp_lo, &clamp_hi);

  // stage 6
  u[0] = v[0];
  u[1] = v[1];
  u[4] = v[4];
  u[5] = v[5];

  v[0] = vmlaq_n_s32(rnding, v[2], cospi[32]);
  x = vmulq_n_s32(v[3], cospi[32]);
  u[2] = vaddq_s32(v[0], x);
  u[2] = vshlq_s32(u[2], v_bit);

  u[3] = vsubq_s32(v[0], x);
  u[3] = vshlq_s32(u[3], v_bit);

  v[0] = vmlaq_n_s32(rnding, v[6], cospi[32]);
  x = vmulq_n_s32(v[7], cospi[32]);
  u[6] = vaddq_s32(v[0], x);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vsubq_s32(v[0], x);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 7
  if (do_cols) {
    out[1] = u[0];
    out[3] = vsubq_s32(kZero, u[4]);
    out[5] = u[6];
    out[7] = vsubq_s32(kZero, u[2]);
    out[9] = u[3];
    out[11] = vsubq_s32(kZero, u[7]);
    out[13] = u[5];
    out[15] = vsubq_s32(kZero, u[1]);
  } else {
    const int log_range_out = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo_out = vdupq_n_s32(-(1 << (log_range_out - 1)));
    const int32x4_t clamp_hi_out = vdupq_n_s32((1 << (log_range_out - 1)) - 1);
    const int32x4_t v_shift = vdupq_n_s32(-out_shift);
    int32x4_t offset = vdupq_n_s32((1 << out_shift) >> 1);
    neg_shift_neon(&u[0], &u[4], out + 1, out + 3, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[6], &u[2], out + 5, out + 7, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[3], &u[7], out + 9, out + 11, &clamp_lo_out,
                   &clamp_hi_out, &v_shift, &offset);
    neg_shift_neon(&u[5], &u[1], out + 13, out + 15, &clamp_lo_out,
                   &clamp_hi_out, &v_shift, &offset);
  }
}

static void iidentity8_neon(int32x4_t *in, int32x4_t *out, int bit, int do_cols,
                            int bd, int out_shift) {
  (void)bit;
  out[0] = vaddq_s32(in[0], in[0]);
  out[1] = vaddq_s32(in[1], in[1]);
  out[2] = vaddq_s32(in[2], in[2]);
  out[3] = vaddq_s32(in[3], in[3]);
  out[4] = vaddq_s32(in[4], in[4]);
  out[5] = vaddq_s32(in[5], in[5]);
  out[6] = vaddq_s32(in[6], in[6]);
  out[7] = vaddq_s32(in[7], in[7]);

  if (!do_cols) {
    const int log_range = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
    const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
    round_shift_4x4(out, out_shift);
    round_shift_4x4(out + 4, out_shift);
    highbd_clamp_s32_neon(out, out, &clamp_lo, &clamp_hi, 8);
  }
}

static uint16x8_t get_recon_8x8(const uint16x8_t pred, int32x4_t res_lo,
                                int32x4_t res_hi, int fliplr, int bd) {
  uint16x8x2_t x;

  if (fliplr) {
    res_lo = vrev64q_s32(res_lo);
    res_lo = vextq_s32(res_lo, res_lo, 2);
    res_hi = vrev64q_s32(res_hi);
    res_hi = vextq_s32(res_hi, res_hi, 2);
    x.val[0] = vreinterpretq_u16_s32(
        vaddw_s16(res_hi, vreinterpret_s16_u16(vget_low_u16(pred))));
    x.val[1] = vreinterpretq_u16_s32(
        vaddw_s16(res_lo, vreinterpret_s16_u16(vget_high_u16(pred))));

  } else {
    x.val[0] = vreinterpretq_u16_s32(
        vaddw_s16(res_lo, vreinterpret_s16_u16(vget_low_u16(pred))));
    x.val[1] = vreinterpretq_u16_s32(
        vaddw_s16(res_hi, vreinterpret_s16_u16(vget_high_u16(pred))));
  }

  uint16x8_t x2 = vcombine_u16(vqmovn_u32(vreinterpretq_u32_u16(x.val[0])),
                               vqmovn_u32(vreinterpretq_u32_u16(x.val[1])));
  const uint16x8_t vmin = vdupq_n_u16(0);
  const uint16x8_t vmax = vdupq_n_u16((1 << bd) - 1);
  return highbd_clamp_u16(&x2, &vmin, &vmax);
}

static void write_buffer_8x8(int32x4_t *in, uint16_t *output, int stride,
                             int fliplr, int flipud, int shift, int bd) {
  uint16x8_t u0, u1, u2, u3, u4, u5, u6, u7;
  uint16x8_t v0, v1, v2, v3, v4, v5, v6, v7;
  round_shift_8x8(in, shift);

  v0 = vld1q_u16(output + 0 * stride);
  v1 = vld1q_u16(output + 1 * stride);
  v2 = vld1q_u16(output + 2 * stride);
  v3 = vld1q_u16(output + 3 * stride);
  v4 = vld1q_u16(output + 4 * stride);
  v5 = vld1q_u16(output + 5 * stride);
  v6 = vld1q_u16(output + 6 * stride);
  v7 = vld1q_u16(output + 7 * stride);

  if (flipud) {
    u0 = get_recon_8x8(v0, in[14], in[15], fliplr, bd);
    u1 = get_recon_8x8(v1, in[12], in[13], fliplr, bd);
    u2 = get_recon_8x8(v2, in[10], in[11], fliplr, bd);
    u3 = get_recon_8x8(v3, in[8], in[9], fliplr, bd);
    u4 = get_recon_8x8(v4, in[6], in[7], fliplr, bd);
    u5 = get_recon_8x8(v5, in[4], in[5], fliplr, bd);
    u6 = get_recon_8x8(v6, in[2], in[3], fliplr, bd);
    u7 = get_recon_8x8(v7, in[0], in[1], fliplr, bd);
  } else {
    u0 = get_recon_8x8(v0, in[0], in[1], fliplr, bd);
    u1 = get_recon_8x8(v1, in[2], in[3], fliplr, bd);
    u2 = get_recon_8x8(v2, in[4], in[5], fliplr, bd);
    u3 = get_recon_8x8(v3, in[6], in[7], fliplr, bd);
    u4 = get_recon_8x8(v4, in[8], in[9], fliplr, bd);
    u5 = get_recon_8x8(v5, in[10], in[11], fliplr, bd);
    u6 = get_recon_8x8(v6, in[12], in[13], fliplr, bd);
    u7 = get_recon_8x8(v7, in[14], in[15], fliplr, bd);
  }

  vst1q_u16(output + 0 * stride, u0);
  vst1q_u16(output + 1 * stride, u1);
  vst1q_u16(output + 2 * stride, u2);
  vst1q_u16(output + 3 * stride, u3);
  vst1q_u16(output + 4 * stride, u4);
  vst1q_u16(output + 5 * stride, u5);
  vst1q_u16(output + 6 * stride, u6);
  vst1q_u16(output + 7 * stride, u7);
}

void av1_inv_txfm2d_add_8x8_neon(const int32_t *input, uint16_t *output,
                                 int stride, TX_TYPE tx_type, int bd) {
  int32x4_t in[16], out[16];
  const int8_t *shift = av1_inv_txfm_shift_ls[TX_8X8];

  switch (tx_type) {
    case DCT_DCT:
      load_buffer_8x8(input, in);
      idct8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      idct8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 0, 0, -shift[1], bd);
      break;
    case DCT_ADST:
      load_buffer_8x8(input, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      idct8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 0, 0, -shift[1], bd);
      break;
    case ADST_DCT:
      load_buffer_8x8(input, in);
      idct8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 0, 0, -shift[1], bd);
      break;
    case ADST_ADST:
      load_buffer_8x8(input, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 0, 0, -shift[1], bd);
      break;
    case FLIPADST_DCT:
      load_buffer_8x8(input, in);
      idct8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 0, 1, -shift[1], bd);
      break;
    case DCT_FLIPADST:
      load_buffer_8x8(input, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      idct8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 1, 0, -shift[1], bd);
      break;
    case ADST_FLIPADST:
      load_buffer_8x8(input, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 1, 0, -shift[1], bd);
      break;
    case FLIPADST_FLIPADST:
      load_buffer_8x8(input, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 1, 1, -shift[1], bd);
      break;
    case FLIPADST_ADST:
      load_buffer_8x8(input, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 0, bd, -shift[0]);
      transpose_8x8(out, in);
      iadst8x8_neon(in, out, INV_COS_BIT, 1, bd, 0);
      write_buffer_8x8(out, output, stride, 0, 1, -shift[1], bd);
      break;
    default: assert(0);
  }
}

static void idct8x8_low1_neon(int32x4_t *in, int32x4_t *out, int bit,
                              int do_cols, int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
  int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
  int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
  int32x4_t x;
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));
  // stage 0-1-2-3
  x = vmulq_n_s32(in[0], cospi[32]);
  x = vaddq_s32(vshlq_s32(x, v_bit), rnding);

  // stage 4-5
  if (!do_cols) {
    const int log_range_out = AOMMAX(16, bd + 6);
    clamp_lo = vdupq_n_s32(-(1 << (log_range_out - 1)));
    clamp_hi = vdupq_n_s32((1 << (log_range_out - 1)) - 1);

    int32x4_t offset = vdupq_n_s32((1 << out_shift) >> 1);
    x = vaddq_s32(x, offset);
    x = vshlq_s32(x, vdupq_n_s32(-out_shift));
  }

  x = vmaxq_s32(x, clamp_lo);
  x = vminq_s32(x, clamp_hi);
  out[0] = x;
  out[1] = x;
  out[2] = x;
  out[3] = x;
  out[4] = x;
  out[5] = x;
  out[6] = x;
  out[7] = x;
}

static void idct8x8_new_neon(int32x4_t *in, int32x4_t *out, int bit,
                             int do_cols, int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
  const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
  const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
  int32x4_t u0, u1, u2, u3, u4, u5, u6, u7;
  int32x4_t v0, v1, v2, v3, v4, v5, v6, v7;
  int32x4_t x, y;
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));

  // stage 0
  // stage 1
  // stage 2
  u0 = in[0];
  u1 = in[4];
  u2 = in[2];
  u3 = in[6];

  x = vmlaq_n_s32(rnding, in[1], cospi[56]);
  u4 = vmlaq_n_s32(x, in[7], -cospi[8]);
  u4 = vshlq_s32(u4, v_bit);

  x = vmlaq_n_s32(rnding, in[1], cospi[8]);
  u7 = vmlaq_n_s32(x, in[7], cospi[56]);
  u7 = vshlq_s32(u7, v_bit);

  x = vmlaq_n_s32(rnding, in[5], cospi[24]);
  u5 = vmlaq_n_s32(x, in[3], -cospi[40]);
  u5 = vshlq_s32(u5, v_bit);

  x = vmlaq_n_s32(rnding, in[5], cospi[40]);
  u6 = vmlaq_n_s32(x, in[3], cospi[24]);
  u6 = vshlq_s32(u6, v_bit);

  // stage 3
  x = vmlaq_n_s32(rnding, u0, cospi[32]);
  y = vmulq_n_s32(u1, cospi[32]);
  v0 = vaddq_s32(x, y);
  v0 = vshlq_s32(v0, v_bit);

  v1 = vsubq_s32(x, y);
  v1 = vshlq_s32(v1, v_bit);

  x = vmlaq_n_s32(rnding, u2, cospi[48]);
  v2 = vmlaq_n_s32(x, u3, -cospi[16]);
  v2 = vshlq_s32(v2, v_bit);

  x = vmlaq_n_s32(rnding, u2, cospi[16]);
  v3 = vmlaq_n_s32(x, u3, cospi[48]);
  v3 = vshlq_s32(v3, v_bit);

  addsub_neon(u4, u5, &v4, &v5, &clamp_lo, &clamp_hi);
  addsub_neon(u7, u6, &v7, &v6, &clamp_lo, &clamp_hi);

  // stage 4
  addsub_neon(v0, v3, &u0, &u3, &clamp_lo, &clamp_hi);
  addsub_neon(v1, v2, &u1, &u2, &clamp_lo, &clamp_hi);
  u4 = v4;
  u7 = v7;

  x = vmulq_n_s32(v5, cospi[32]);
  y = vmlaq_n_s32(rnding, v6, cospi[32]);
  u6 = vaddq_s32(y, x);
  u6 = vshlq_s32(u6, v_bit);

  u5 = vsubq_s32(y, x);
  u5 = vshlq_s32(u5, v_bit);

  // stage 5
  addsub_neon(u0, u7, out + 0, out + 7, &clamp_lo, &clamp_hi);
  addsub_neon(u1, u6, out + 1, out + 6, &clamp_lo, &clamp_hi);
  addsub_neon(u2, u5, out + 2, out + 5, &clamp_lo, &clamp_hi);
  addsub_neon(u3, u4, out + 3, out + 4, &clamp_lo, &clamp_hi);

  if (!do_cols) {
    const int log_range_out = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo_out = vdupq_n_s32(-(1 << (log_range_out - 1)));
    const int32x4_t clamp_hi_out = vdupq_n_s32((1 << (log_range_out - 1)) - 1);
    round_shift_4x4(out, out_shift);
    round_shift_4x4(out + 4, out_shift);
    highbd_clamp_s32_neon(out, out, &clamp_lo_out, &clamp_hi_out, 8);
  }
}

static void iadst8x8_low1_neon(int32x4_t *in, int32x4_t *out, int bit,
                               int do_cols, int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  int32x4_t u[8], x;
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));
  // stage 0-2

  u[0] = vmlaq_n_s32(rnding, in[0], cospi[60]);
  u[0] = vshlq_s32(u[0], v_bit);

  u[1] = vmlaq_n_s32(rnding, in[0], cospi[4]);
  u[1] = vshlq_s32(vnegq_s32(u[1]), v_bit);

  // stage 3-4
  int32x4_t temp1, temp2;
  temp1 = vmlaq_n_s32(rnding, u[0], cospi[16]);
  temp1 = vmlaq_n_s32(temp1, u[1], cospi[48]);
  temp1 = vshlq_s32(temp1, v_bit);
  u[4] = temp1;

  temp2 = vmlaq_n_s32(rnding, u[0], cospi[48]);
  u[5] = vmlsq_n_s32(temp2, u[1], cospi[16]);
  u[5] = vshlq_s32(u[5], v_bit);

  // stage 5-6
  temp1 = vmlaq_n_s32(rnding, u[0], cospi[32]);
  x = vmulq_n_s32(u[1], cospi[32]);
  u[2] = vaddq_s32(temp1, x);
  u[2] = vshlq_s32(u[2], v_bit);

  u[3] = vsubq_s32(temp1, x);
  u[3] = vshlq_s32(u[3], v_bit);

  temp1 = vmlaq_n_s32(rnding, u[4], cospi[32]);
  x = vmulq_n_s32(u[5], cospi[32]);
  u[6] = vaddq_s32(temp1, x);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vsubq_s32(temp1, x);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 7
  if (do_cols) {
    out[0] = u[0];
    out[1] = vnegq_s32(u[4]);
    out[2] = u[6];
    out[3] = vnegq_s32(u[2]);
    out[4] = u[3];
    out[5] = vnegq_s32(u[7]);
    out[6] = u[5];
    out[7] = vnegq_s32(u[1]);
  } else {
    const int log_range_out = AOMMAX(16, bd + 6);
    const int32x4_t clamp_lo_out = vdupq_n_s32(-(1 << (log_range_out - 1)));
    const int32x4_t clamp_hi_out = vdupq_n_s32((1 << (log_range_out - 1)) - 1);
    const int32x4_t v_shift = vdupq_n_s32(-out_shift);
    int32x4_t offset = vdupq_n_s32((1 << out_shift) >> 1);
    neg_shift_neon(&u[0], &u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[6], &u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[3], &u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
    neg_shift_neon(&u[5], &u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out,
                   &v_shift, &offset);
  }
}

static void iadst8x8_new_neon(int32x4_t *in, int32x4_t *out, int bit,
                              int do_cols, int bd, int out_shift) {
  const int32_t *cospi = cospi_arr(bit);
  // const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));
  const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8));
  const int32x4_t clamp_lo = vdupq_n_s32(-(1 << (log_range - 1)));
  const int32x4_t clamp_hi = vdupq_n_s32((1 << (log_range - 1)) - 1);
  int32x4_t u[8], v[8], x;
  const int32x4_t v_bit = vdupq_n_s32(-bit);
  const int32x4_t rnding = vdupq_n_s32(1 << (bit - 1));
  // stage 0-2

  u[0] = vmlaq_n_s32(rnding, in[7], cospi[4]);
  u[0] = vmlaq_n_s32(u[0], in[0], cospi[60]);
  u[0] = vshlq_s32(u[0], v_bit);

  u[1] = vmlaq_n_s32(rnding, in[7], cospi[60]);
  u[1] = vmlsq_n_s32(u[1], in[0], cospi[4]);
  u[1] = vshlq_s32(u[1], v_bit);

  // (2)
  u[2] = vmlaq_n_s32(rnding, in[5], cospi[20]);
  u[2] = vmlaq_n_s32(u[2], in[2], cospi[44]);
  u[2] = vshlq_s32(u[2], v_bit);

  u[3] = vmlaq_n_s32(rnding, in[5], cospi[44]);
  u[3] = vmlsq_n_s32(u[3], in[2], cospi[20]);
  u[3] = vshlq_s32(u[3], v_bit);

  // (3)
  u[4] = vmlaq_n_s32(rnding, in[3], cospi[36]);
  u[4] = vmlaq_n_s32(u[4], in[4], cospi[28]);
  u[4] = vshlq_s32(u[4], v_bit);

  u[5] = vmlaq_n_s32(rnding, in[3], cospi[28]);
  u[5] = vmlsq_n_s32(u[5], in[4], cospi[36]);
  u[5] = vshlq_s32(u[5], v_bit);

  // (4)
  u[6] = vmulq_n_s32(in[1], cospi[52]);
  u[6] = vmlaq_n_s32(u[6], in[6], cospi[12]);
  u[6] = vaddq_s32(u[6], rnding);
  u[6] = vshlq_s32(u[6], v_bit);

  u[7] = vmulq_n_s32(in[1], cospi[12]);
  u[7] = vmlsq_n_s32(u[7], in[6], cospi[52]);
  u[7] = vaddq_s32(u[7], rnding);
  u[7] = vshlq_s32(u[7], v_bit);

  // stage 3
  addsub_neon(u[0], u[4], &v[0], &v[4], &clamp_lo, &clamp_hi);
  addsub_neon(u[1], u[5], &v[1], &v[5], &clamp_lo, &clamp_hi);
  addsub_neon(u[2], u[6], &v[2], &v[6], &clamp_lo, &clamp_hi);
  addsub_neon(u[3], u[7], &v[3], &v[7], &clamp_lo, &clamp_hi);

  // stage 4
  u[0] = v[0];
  u[1] = v[1];
  u[2] = v[2];
  u[3] = v[3];

  u[4] = vmlaq_n_s32(rnding, v[4], cospi[16]);
  u[4] = vmlaq_n_s32(u[4], v[5], cospi[48]);
  u[4] = vshlq_s32(u[4], v_bit);

  u[5] = vmlaq_n_s32(rnding, v[4], cospi[48]);
  u[5] = vmlsq_n_s32(u[5], v[5], cospi[16]);
  u[5] = vshlq_s32(u[5], v_bit);

--> --------------------

--> maximum size reached

--> --------------------

Messung V0.5
C=97 H=95 G=95

¤ Dauer der Verarbeitung: 0.11 Sekunden  (vorverarbeitet)  ¤

*© Formatika GbR, Deutschland






Wurzel

Suchen

Beweissystem der NASA

Beweissystem Isabelle

NIST Cobol Testsuite

Cephes Mathematical Library

Wiener Entwicklungsmethode

Haftungshinweis

Die Informationen auf dieser Webseite wurden nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit, noch Qualität der bereit gestellten Informationen zugesichert.

Bemerkung:

Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.






                                                                                                                                                                                                                                                                                                                                                                                                     


Neuigkeiten

     Aktuelles
     Motto des Tages

Software

     Produkte
     Quellcodebibliothek

Aktivitäten

     Artikel über Sicherheit
     Anleitung zur Aktivierung von SSL

Muße

     Gedichte
     Musik
     Bilder

Jenseits des Üblichen ....

Besucherstatistik

Besucherstatistik

Monitoring

Montastic status badge