Quellcodebibliothek Statistik Leitseite products/Sources/formale Sprachen/C/Firefox/third_party/aom/av1/encoder/x86/   (Browser von der Mozilla Stiftung Version 136.0.1©)  Datei vom 10.2.2025 mit Größe 124 kB image not shown  

Quelle  highbd_fwd_txfm_avx2.c   Sprache: C

 
/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h> /*AVX2*/

#include "config/aom_config.h"
#include "config/av1_rtcd.h"
#include "av1/common/av1_txfm.h"
#include "av1/encoder/av1_fwd_txfm1d_cfg.h"
#include "aom_dsp/txfm_common.h"
#include "aom_ports/mem.h"
#include "aom_dsp/x86/txfm_common_sse2.h"
#include "aom_dsp/x86/txfm_common_avx2.h"

static inline void load_buffer_8x8_avx2(const int16_t *input, __m256i *out,
                                        int stride, int flipud, int fliplr,
                                        int shift) {
  __m128i out1[8];
  if (!flipud) {
    out1[0] = _mm_load_si128((const __m128i *)(input + 0 * stride));
    out1[1] = _mm_load_si128((const __m128i *)(input + 1 * stride));
    out1[2] = _mm_load_si128((const __m128i *)(input + 2 * stride));
    out1[3] = _mm_load_si128((const __m128i *)(input + 3 * stride));
    out1[4] = _mm_load_si128((const __m128i *)(input + 4 * stride));
    out1[5] = _mm_load_si128((const __m128i *)(input + 5 * stride));
    out1[6] = _mm_load_si128((const __m128i *)(input + 6 * stride));
    out1[7] = _mm_load_si128((const __m128i *)(input + 7 * stride));

  } else {
    out1[7] = _mm_load_si128((const __m128i *)(input + 0 * stride));
    out1[6] = _mm_load_si128((const __m128i *)(input + 1 * stride));
    out1[5] = _mm_load_si128((const __m128i *)(input + 2 * stride));
    out1[4] = _mm_load_si128((const __m128i *)(input + 3 * stride));
    out1[3] = _mm_load_si128((const __m128i *)(input + 4 * stride));
    out1[2] = _mm_load_si128((const __m128i *)(input + 5 * stride));
    out1[1] = _mm_load_si128((const __m128i *)(input + 6 * stride));
    out1[0] = _mm_load_si128((const __m128i *)(input + 7 * stride));
  }
  if (!fliplr) {
    out[0] = _mm256_cvtepi16_epi32(out1[0]);
    out[1] = _mm256_cvtepi16_epi32(out1[1]);
    out[2] = _mm256_cvtepi16_epi32(out1[2]);
    out[3] = _mm256_cvtepi16_epi32(out1[3]);
    out[4] = _mm256_cvtepi16_epi32(out1[4]);
    out[5] = _mm256_cvtepi16_epi32(out1[5]);
    out[6] = _mm256_cvtepi16_epi32(out1[6]);
    out[7] = _mm256_cvtepi16_epi32(out1[7]);

  } else {
    out[0] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[0]));
    out[1] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[1]));
    out[2] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[2]));
    out[3] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[3]));
    out[4] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[4]));
    out[5] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[5]));
    out[6] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[6]));
    out[7] = _mm256_cvtepi16_epi32(mm_reverse_epi16(out1[7]));
  }
  out[0] = _mm256_slli_epi32(out[0], shift);
  out[1] = _mm256_slli_epi32(out[1], shift);
  out[2] = _mm256_slli_epi32(out[2], shift);
  out[3] = _mm256_slli_epi32(out[3], shift);
  out[4] = _mm256_slli_epi32(out[4], shift);
  out[5] = _mm256_slli_epi32(out[5], shift);
  out[6] = _mm256_slli_epi32(out[6], shift);
  out[7] = _mm256_slli_epi32(out[7], shift);
}
static inline void col_txfm_8x8_rounding(__m256i *in, int shift) {
  const __m256i rounding = _mm256_set1_epi32(1 << (shift - 1));

  in[0] = _mm256_add_epi32(in[0], rounding);
  in[1] = _mm256_add_epi32(in[1], rounding);
  in[2] = _mm256_add_epi32(in[2], rounding);
  in[3] = _mm256_add_epi32(in[3], rounding);
  in[4] = _mm256_add_epi32(in[4], rounding);
  in[5] = _mm256_add_epi32(in[5], rounding);
  in[6] = _mm256_add_epi32(in[6], rounding);
  in[7] = _mm256_add_epi32(in[7], rounding);

  in[0] = _mm256_srai_epi32(in[0], shift);
  in[1] = _mm256_srai_epi32(in[1], shift);
  in[2] = _mm256_srai_epi32(in[2], shift);
  in[3] = _mm256_srai_epi32(in[3], shift);
  in[4] = _mm256_srai_epi32(in[4], shift);
  in[5] = _mm256_srai_epi32(in[5], shift);
  in[6] = _mm256_srai_epi32(in[6], shift);
  in[7] = _mm256_srai_epi32(in[7], shift);
}
static inline void load_buffer_8x16_avx2(const int16_t *input, __m256i *out,
                                         int stride, int flipud, int fliplr,
                                         int shift) {
  const int16_t *topL = input;
  const int16_t *botL = input + 8 * stride;

  const int16_t *tmp;

  if (flipud) {
    tmp = topL;
    topL = botL;
    botL = tmp;
  }
  load_buffer_8x8_avx2(topL, out, stride, flipud, fliplr, shift);
  load_buffer_8x8_avx2(botL, out + 8, stride, flipud, fliplr, shift);
}
static inline void load_buffer_16xn_avx2(const int16_t *input, __m256i *out,
                                         int stride, int height, int outstride,
                                         int flipud, int fliplr) {
  __m256i out1[64];
  if (!flipud) {
    for (int i = 0; i < height; i++) {
      out1[i] = _mm256_loadu_si256((const __m256i *)(input + i * stride));
    }
  } else {
    for (int i = 0; i < height; i++) {
      out1[(height - 1) - i] =
          _mm256_loadu_si256((const __m256i *)(input + i * stride));
    }
  }
  if (!fliplr) {
    for (int i = 0; i < height; i++) {
      out[i * outstride] =
          _mm256_cvtepi16_epi32(_mm256_castsi256_si128(out1[i]));
      out[i * outstride + 1] =
          _mm256_cvtepi16_epi32(_mm256_extractf128_si256(out1[i], 1));
    }
  } else {
    for (int i = 0; i < height; i++) {
      out[i * outstride + 1] = _mm256_cvtepi16_epi32(
          mm_reverse_epi16(_mm256_castsi256_si128(out1[i])));
      out[i * outstride + 0] = _mm256_cvtepi16_epi32(
          mm_reverse_epi16(_mm256_extractf128_si256(out1[i], 1)));
    }
  }
}

static void fwd_txfm_transpose_8x8_avx2(const __m256i *in, __m256i *out,
                                        const int instride,
                                        const int outstride) {
  __m256i u0, u1, u2, u3, u4, u5, u6, u7;
  __m256i x0, x1;

  u0 = _mm256_unpacklo_epi32(in[0 * instride], in[1 * instride]);
  u1 = _mm256_unpackhi_epi32(in[0 * instride], in[1 * instride]);

  u2 = _mm256_unpacklo_epi32(in[2 * instride], in[3 * instride]);
  u3 = _mm256_unpackhi_epi32(in[2 * instride], in[3 * instride]);

  u4 = _mm256_unpacklo_epi32(in[4 * instride], in[5 * instride]);
  u5 = _mm256_unpackhi_epi32(in[4 * instride], in[5 * instride]);

  u6 = _mm256_unpacklo_epi32(in[6 * instride], in[7 * instride]);
  u7 = _mm256_unpackhi_epi32(in[6 * instride], in[7 * instride]);

  x0 = _mm256_unpacklo_epi64(u0, u2);
  x1 = _mm256_unpacklo_epi64(u4, u6);
  out[0 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x20);
  out[4 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x31);

  x0 = _mm256_unpackhi_epi64(u0, u2);
  x1 = _mm256_unpackhi_epi64(u4, u6);
  out[1 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x20);
  out[5 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x31);

  x0 = _mm256_unpacklo_epi64(u1, u3);
  x1 = _mm256_unpacklo_epi64(u5, u7);
  out[2 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x20);
  out[6 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x31);

  x0 = _mm256_unpackhi_epi64(u1, u3);
  x1 = _mm256_unpackhi_epi64(u5, u7);
  out[3 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x20);
  out[7 * outstride] = _mm256_permute2f128_si256(x0, x1, 0x31);
}
static inline void round_shift_32_8xn_avx2(__m256i *in, int size, int bit,
                                           int stride) {
  if (bit < 0) {
    bit = -bit;
    __m256i round = _mm256_set1_epi32(1 << (bit - 1));
    for (int i = 0; i < size; ++i) {
      in[stride * i] = _mm256_add_epi32(in[stride * i], round);
      in[stride * i] = _mm256_srai_epi32(in[stride * i], bit);
    }
  } else if (bit > 0) {
    for (int i = 0; i < size; ++i) {
      in[stride * i] = _mm256_slli_epi32(in[stride * i], bit);
    }
  }
}
static inline void store_buffer_avx2(const __m256i *const in, int32_t *out,
                                     const int stride, const int out_size) {
  for (int i = 0; i < out_size; ++i) {
    _mm256_store_si256((__m256i *)(out), in[i]);
    out += stride;
  }
}
static inline void fwd_txfm_transpose_16x16_avx2(const __m256i *in,
                                                 __m256i *out) {
  fwd_txfm_transpose_8x8_avx2(&in[0], &out[0], 2, 2);
  fwd_txfm_transpose_8x8_avx2(&in[1], &out[16], 2, 2);
  fwd_txfm_transpose_8x8_avx2(&in[16], &out[1], 2, 2);
  fwd_txfm_transpose_8x8_avx2(&in[17], &out[17], 2, 2);
}

static inline __m256i av1_half_btf_avx2(const __m256i *w0, const __m256i *n0,
                                        const __m256i *w1, const __m256i *n1,
                                        const __m256i *rounding, int bit) {
  __m256i x, y;

  x = _mm256_mullo_epi32(*w0, *n0);
  y = _mm256_mullo_epi32(*w1, *n1);
  x = _mm256_add_epi32(x, y);
  x = _mm256_add_epi32(x, *rounding);
  x = _mm256_srai_epi32(x, bit);
  return x;
}
#define btf_32_avx2_type0(w0, w1, in0, in1, out0, out1, bit) \
  do {                                                       \
    const __m256i ww0 = _mm256_set1_epi32(w0);               \
    const __m256i ww1 = _mm256_set1_epi32(w1);               \
    const __m256i in0_w0 = _mm256_mullo_epi32(in0, ww0);     \
    const __m256i in1_w1 = _mm256_mullo_epi32(in1, ww1);     \
    out0 = _mm256_add_epi32(in0_w0, in1_w1);                 \
    round_shift_32_8xn_avx2(&out0, 1, -bit, 1);              \
    const __m256i in0_w1 = _mm256_mullo_epi32(in0, ww1);     \
    const __m256i in1_w0 = _mm256_mullo_epi32(in1, ww0);     \
    out1 = _mm256_sub_epi32(in0_w1, in1_w0);                 \
    round_shift_32_8xn_avx2(&out1, 1, -bit, 1);              \
  } while (0)

#define btf_32_type0_avx2_new(ww0, ww1, in0, in1, out0, out1, r, bit) \
  do {                                                                \
    const __m256i in0_w0 = _mm256_mullo_epi32(in0, ww0);              \
    const __m256i in1_w1 = _mm256_mullo_epi32(in1, ww1);              \
    out0 = _mm256_add_epi32(in0_w0, in1_w1);                          \
    out0 = _mm256_add_epi32(out0, r);                                 \
    out0 = _mm256_srai_epi32(out0, bit);                              \
    const __m256i in0_w1 = _mm256_mullo_epi32(in0, ww1);              \
    const __m256i in1_w0 = _mm256_mullo_epi32(in1, ww0);              \
    out1 = _mm256_sub_epi32(in0_w1, in1_w0);                          \
    out1 = _mm256_add_epi32(out1, r);                                 \
    out1 = _mm256_srai_epi32(out1, bit);                              \
  } while (0)

typedef void (*transform_1d_avx2)(__m256i *in, __m256i *out,
                                  const int8_t cos_bit, int instride,
                                  int outstride);
static void fdct8_avx2(__m256i *in, __m256i *out, const int8_t bit,
                       const int col_num, const int outstride) {
  const int32_t *cospi = cospi_arr(bit);
  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
  const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]);
  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
  __m256i u[8], v[8];
  for (int col = 0; col < col_num; ++col) {
    u[0] = _mm256_add_epi32(in[0 * col_num + col], in[7 * col_num + col]);
    v[7] = _mm256_sub_epi32(in[0 * col_num + col], in[7 * col_num + col]);
    u[1] = _mm256_add_epi32(in[1 * col_num + col], in[6 * col_num + col]);
    u[6] = _mm256_sub_epi32(in[1 * col_num + col], in[6 * col_num + col]);
    u[2] = _mm256_add_epi32(in[2 * col_num + col], in[5 * col_num + col]);
    u[5] = _mm256_sub_epi32(in[2 * col_num + col], in[5 * col_num + col]);
    u[3] = _mm256_add_epi32(in[3 * col_num + col], in[4 * col_num + col]);
    v[4] = _mm256_sub_epi32(in[3 * col_num + col], in[4 * col_num + col]);
    v[0] = _mm256_add_epi32(u[0], u[3]);
    v[3] = _mm256_sub_epi32(u[0], u[3]);
    v[1] = _mm256_add_epi32(u[1], u[2]);
    v[2] = _mm256_sub_epi32(u[1], u[2]);

    v[5] = _mm256_mullo_epi32(u[5], cospim32);
    v[6] = _mm256_mullo_epi32(u[6], cospi32);
    v[5] = _mm256_add_epi32(v[5], v[6]);
    v[5] = _mm256_add_epi32(v[5], rnding);
    v[5] = _mm256_srai_epi32(v[5], bit);

    u[0] = _mm256_mullo_epi32(u[5], cospi32);
    v[6] = _mm256_mullo_epi32(u[6], cospim32);
    v[6] = _mm256_sub_epi32(u[0], v[6]);
    v[6] = _mm256_add_epi32(v[6], rnding);
    v[6] = _mm256_srai_epi32(v[6], bit);

    // stage 3
    // type 0
    v[0] = _mm256_mullo_epi32(v[0], cospi32);
    v[1] = _mm256_mullo_epi32(v[1], cospi32);
    u[0] = _mm256_add_epi32(v[0], v[1]);
    u[0] = _mm256_add_epi32(u[0], rnding);
    u[0] = _mm256_srai_epi32(u[0], bit);

    u[1] = _mm256_sub_epi32(v[0], v[1]);
    u[1] = _mm256_add_epi32(u[1], rnding);
    u[1] = _mm256_srai_epi32(u[1], bit);

    // type 1
    v[0] = _mm256_mullo_epi32(v[2], cospi48);
    v[1] = _mm256_mullo_epi32(v[3], cospi16);
    u[2] = _mm256_add_epi32(v[0], v[1]);
    u[2] = _mm256_add_epi32(u[2], rnding);
    u[2] = _mm256_srai_epi32(u[2], bit);

    v[0] = _mm256_mullo_epi32(v[2], cospi16);
    v[1] = _mm256_mullo_epi32(v[3], cospi48);
    u[3] = _mm256_sub_epi32(v[1], v[0]);
    u[3] = _mm256_add_epi32(u[3], rnding);
    u[3] = _mm256_srai_epi32(u[3], bit);

    u[4] = _mm256_add_epi32(v[4], v[5]);
    u[5] = _mm256_sub_epi32(v[4], v[5]);
    u[6] = _mm256_sub_epi32(v[7], v[6]);
    u[7] = _mm256_add_epi32(v[7], v[6]);

    // stage 4
    // stage 5
    v[0] = _mm256_mullo_epi32(u[4], cospi56);
    v[1] = _mm256_mullo_epi32(u[7], cospi8);
    v[0] = _mm256_add_epi32(v[0], v[1]);
    v[0] = _mm256_add_epi32(v[0], rnding);
    out[1 * outstride + col] = _mm256_srai_epi32(v[0], bit);  // buf0[4]

    v[0] = _mm256_mullo_epi32(u[4], cospi8);
    v[1] = _mm256_mullo_epi32(u[7], cospi56);
    v[0] = _mm256_sub_epi32(v[1], v[0]);
    v[0] = _mm256_add_epi32(v[0], rnding);
    out[7 * outstride + col] = _mm256_srai_epi32(v[0], bit);  // buf0[7]

    v[0] = _mm256_mullo_epi32(u[5], cospi24);
    v[1] = _mm256_mullo_epi32(u[6], cospi40);
    v[0] = _mm256_add_epi32(v[0], v[1]);
    v[0] = _mm256_add_epi32(v[0], rnding);
    out[5 * outstride + col] = _mm256_srai_epi32(v[0], bit);  // buf0[5]

    v[0] = _mm256_mullo_epi32(u[5], cospi40);
    v[1] = _mm256_mullo_epi32(u[6], cospi24);
    v[0] = _mm256_sub_epi32(v[1], v[0]);
    v[0] = _mm256_add_epi32(v[0], rnding);
    out[3 * outstride + col] = _mm256_srai_epi32(v[0], bit);  // buf0[6]

    out[0 * outstride + col] = u[0];  // buf0[0]
    out[4 * outstride + col] = u[1];  // buf0[1]
    out[2 * outstride + col] = u[2];  // buf0[2]
    out[6 * outstride + col] = u[3];  // buf0[3]
  }
}
static void fadst8_avx2(__m256i *in, __m256i *out, const int8_t bit,
                        const int col_num, const int outstirde) {
  (void)col_num;
  const int32_t *cospi = cospi_arr(bit);
  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
  const __m256i cospim4 = _mm256_set1_epi32(-cospi[4]);
  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
  const __m256i cospim20 = _mm256_set1_epi32(-cospi[20]);
  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
  const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
  const __m256i cospim36 = _mm256_set1_epi32(-cospi[36]);
  const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
  const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]);
  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
  const __m256i zero = _mm256_setzero_si256();
  __m256i u0, u1, u2, u3, u4, u5, u6, u7;
  __m256i v0, v1, v2, v3, v4, v5, v6, v7;
  __m256i x, y;
  for (int col = 0; col < col_num; ++col) {
    u0 = in[0 * col_num + col];
    u1 = _mm256_sub_epi32(zero, in[7 * col_num + col]);
    u2 = _mm256_sub_epi32(zero, in[3 * col_num + col]);
    u3 = in[4 * col_num + col];
    u4 = _mm256_sub_epi32(zero, in[1 * col_num + col]);
    u5 = in[6 * col_num + col];
    u6 = in[2 * col_num + col];
    u7 = _mm256_sub_epi32(zero, in[5 * col_num + col]);

    // stage 2
    v0 = u0;
    v1 = u1;

    x = _mm256_mullo_epi32(u2, cospi32);
    y = _mm256_mullo_epi32(u3, cospi32);
    v2 = _mm256_add_epi32(x, y);
    v2 = _mm256_add_epi32(v2, rnding);
    v2 = _mm256_srai_epi32(v2, bit);

    v3 = _mm256_sub_epi32(x, y);
    v3 = _mm256_add_epi32(v3, rnding);
    v3 = _mm256_srai_epi32(v3, bit);

    v4 = u4;
    v5 = u5;

    x = _mm256_mullo_epi32(u6, cospi32);
    y = _mm256_mullo_epi32(u7, cospi32);
    v6 = _mm256_add_epi32(x, y);
    v6 = _mm256_add_epi32(v6, rnding);
    v6 = _mm256_srai_epi32(v6, bit);

    v7 = _mm256_sub_epi32(x, y);
    v7 = _mm256_add_epi32(v7, rnding);
    v7 = _mm256_srai_epi32(v7, bit);

    // stage 3
    u0 = _mm256_add_epi32(v0, v2);
    u1 = _mm256_add_epi32(v1, v3);
    u2 = _mm256_sub_epi32(v0, v2);
    u3 = _mm256_sub_epi32(v1, v3);
    u4 = _mm256_add_epi32(v4, v6);
    u5 = _mm256_add_epi32(v5, v7);
    u6 = _mm256_sub_epi32(v4, v6);
    u7 = _mm256_sub_epi32(v5, v7);

    // stage 4
    v0 = u0;
    v1 = u1;
    v2 = u2;
    v3 = u3;

    x = _mm256_mullo_epi32(u4, cospi16);
    y = _mm256_mullo_epi32(u5, cospi48);
    v4 = _mm256_add_epi32(x, y);
    v4 = _mm256_add_epi32(v4, rnding);
    v4 = _mm256_srai_epi32(v4, bit);

    x = _mm256_mullo_epi32(u4, cospi48);
    y = _mm256_mullo_epi32(u5, cospim16);
    v5 = _mm256_add_epi32(x, y);
    v5 = _mm256_add_epi32(v5, rnding);
    v5 = _mm256_srai_epi32(v5, bit);

    x = _mm256_mullo_epi32(u6, cospim48);
    y = _mm256_mullo_epi32(u7, cospi16);
    v6 = _mm256_add_epi32(x, y);
    v6 = _mm256_add_epi32(v6, rnding);
    v6 = _mm256_srai_epi32(v6, bit);

    x = _mm256_mullo_epi32(u6, cospi16);
    y = _mm256_mullo_epi32(u7, cospi48);
    v7 = _mm256_add_epi32(x, y);
    v7 = _mm256_add_epi32(v7, rnding);
    v7 = _mm256_srai_epi32(v7, bit);

    // stage 5
    u0 = _mm256_add_epi32(v0, v4);
    u1 = _mm256_add_epi32(v1, v5);
    u2 = _mm256_add_epi32(v2, v6);
    u3 = _mm256_add_epi32(v3, v7);
    u4 = _mm256_sub_epi32(v0, v4);
    u5 = _mm256_sub_epi32(v1, v5);
    u6 = _mm256_sub_epi32(v2, v6);
    u7 = _mm256_sub_epi32(v3, v7);

    // stage 6
    x = _mm256_mullo_epi32(u0, cospi4);
    y = _mm256_mullo_epi32(u1, cospi60);
    v0 = _mm256_add_epi32(x, y);
    v0 = _mm256_add_epi32(v0, rnding);
    v0 = _mm256_srai_epi32(v0, bit);

    x = _mm256_mullo_epi32(u0, cospi60);
    y = _mm256_mullo_epi32(u1, cospim4);
    v1 = _mm256_add_epi32(x, y);
    v1 = _mm256_add_epi32(v1, rnding);
    v1 = _mm256_srai_epi32(v1, bit);

    x = _mm256_mullo_epi32(u2, cospi20);
    y = _mm256_mullo_epi32(u3, cospi44);
    v2 = _mm256_add_epi32(x, y);
    v2 = _mm256_add_epi32(v2, rnding);
    v2 = _mm256_srai_epi32(v2, bit);

    x = _mm256_mullo_epi32(u2, cospi44);
    y = _mm256_mullo_epi32(u3, cospim20);
    v3 = _mm256_add_epi32(x, y);
    v3 = _mm256_add_epi32(v3, rnding);
    v3 = _mm256_srai_epi32(v3, bit);

    x = _mm256_mullo_epi32(u4, cospi36);
    y = _mm256_mullo_epi32(u5, cospi28);
    v4 = _mm256_add_epi32(x, y);
    v4 = _mm256_add_epi32(v4, rnding);
    v4 = _mm256_srai_epi32(v4, bit);

    x = _mm256_mullo_epi32(u4, cospi28);
    y = _mm256_mullo_epi32(u5, cospim36);
    v5 = _mm256_add_epi32(x, y);
    v5 = _mm256_add_epi32(v5, rnding);
    v5 = _mm256_srai_epi32(v5, bit);

    x = _mm256_mullo_epi32(u6, cospi52);
    y = _mm256_mullo_epi32(u7, cospi12);
    v6 = _mm256_add_epi32(x, y);
    v6 = _mm256_add_epi32(v6, rnding);
    v6 = _mm256_srai_epi32(v6, bit);

    x = _mm256_mullo_epi32(u6, cospi12);
    y = _mm256_mullo_epi32(u7, cospim52);
    v7 = _mm256_add_epi32(x, y);
    v7 = _mm256_add_epi32(v7, rnding);
    v7 = _mm256_srai_epi32(v7, bit);

    // stage 7
    out[0 * outstirde + col] = v1;
    out[1 * outstirde + col] = v6;
    out[2 * outstirde + col] = v3;
    out[3 * outstirde + col] = v4;
    out[4 * outstirde + col] = v5;
    out[5 * outstirde + col] = v2;
    out[6 * outstirde + col] = v7;
    out[7 * outstirde + col] = v0;
  }
}
static void idtx8_avx2(__m256i *in, __m256i *out, const int8_t bit, int col_num,
                       int outstride) {
  (void)bit;
  (void)outstride;
  int num_iters = 8 * col_num;
  for (int i = 0; i < num_iters; i += 8) {
    out[i] = _mm256_add_epi32(in[i], in[i]);
    out[i + 1] = _mm256_add_epi32(in[i + 1], in[i + 1]);
    out[i + 2] = _mm256_add_epi32(in[i + 2], in[i + 2]);
    out[i + 3] = _mm256_add_epi32(in[i + 3], in[i + 3]);
    out[i + 4] = _mm256_add_epi32(in[i + 4], in[i + 4]);
    out[i + 5] = _mm256_add_epi32(in[i + 5], in[i + 5]);
    out[i + 6] = _mm256_add_epi32(in[i + 6], in[i + 6]);
    out[i + 7] = _mm256_add_epi32(in[i + 7], in[i + 7]);
  }
}
void av1_fwd_txfm2d_8x8_avx2(const int16_t *input, int32_t *coeff, int stride,
                             TX_TYPE tx_type, int bd) {
  __m256i in[8], out[8];
  const TX_SIZE tx_size = TX_8X8;
  const int8_t *shift = av1_fwd_txfm_shift_ls[tx_size];
  const int txw_idx = get_txw_idx(tx_size);
  const int txh_idx = get_txh_idx(tx_size);
  const int width = tx_size_wide[tx_size];
  const int width_div8 = (width >> 3);

  switch (tx_type) {
    case DCT_DCT:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      fdct8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fdct8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case ADST_DCT:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fdct8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case DCT_ADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      fdct8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case ADST_ADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case FLIPADST_DCT:
      load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fdct8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case DCT_FLIPADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
      fdct8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case FLIPADST_FLIPADST:
      load_buffer_8x8_avx2(input, in, stride, 1, 1, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case ADST_FLIPADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case FLIPADST_ADST:
      load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case IDTX:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case V_DCT:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      fdct8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case H_DCT:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fdct8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case V_ADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case H_ADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 0, shift[0]);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case V_FLIPADST:
      load_buffer_8x8_avx2(input, in, stride, 1, 0, shift[0]);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    case H_FLIPADST:
      load_buffer_8x8_avx2(input, in, stride, 0, 1, shift[0]);
      idtx8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                 width_div8);
      col_txfm_8x8_rounding(out, -shift[1]);
      fwd_txfm_transpose_8x8_avx2(out, in, width_div8, width_div8);
      fadst8_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 8);
      break;
    default: assert(0);
  }
  (void)bd;
}

static void fdct16_avx2(__m256i *in, __m256i *out, const int8_t bit,
                        const int col_num, const int outstride) {
  const int32_t *cospi = cospi_arr(bit);
  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
  const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]);
  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
  const __m256i cospi60 = _mm256_set1_epi32(cospi[60]);
  const __m256i cospi4 = _mm256_set1_epi32(cospi[4]);
  const __m256i cospi28 = _mm256_set1_epi32(cospi[28]);
  const __m256i cospi36 = _mm256_set1_epi32(cospi[36]);
  const __m256i cospi44 = _mm256_set1_epi32(cospi[44]);
  const __m256i cospi20 = _mm256_set1_epi32(cospi[20]);
  const __m256i cospi12 = _mm256_set1_epi32(cospi[12]);
  const __m256i cospi52 = _mm256_set1_epi32(cospi[52]);
  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
  __m256i u[16], v[16], x;
  int col;

  // Calculate the column 0, 1, 2, 3
  for (col = 0; col < col_num; ++col) {
    // stage 0
    // stage 1
    u[0] = _mm256_add_epi32(in[0 * col_num + col], in[15 * col_num + col]);
    u[15] = _mm256_sub_epi32(in[0 * col_num + col], in[15 * col_num + col]);
    u[1] = _mm256_add_epi32(in[1 * col_num + col], in[14 * col_num + col]);
    u[14] = _mm256_sub_epi32(in[1 * col_num + col], in[14 * col_num + col]);
    u[2] = _mm256_add_epi32(in[2 * col_num + col], in[13 * col_num + col]);
    u[13] = _mm256_sub_epi32(in[2 * col_num + col], in[13 * col_num + col]);
    u[3] = _mm256_add_epi32(in[3 * col_num + col], in[12 * col_num + col]);
    u[12] = _mm256_sub_epi32(in[3 * col_num + col], in[12 * col_num + col]);
    u[4] = _mm256_add_epi32(in[4 * col_num + col], in[11 * col_num + col]);
    u[11] = _mm256_sub_epi32(in[4 * col_num + col], in[11 * col_num + col]);
    u[5] = _mm256_add_epi32(in[5 * col_num + col], in[10 * col_num + col]);
    u[10] = _mm256_sub_epi32(in[5 * col_num + col], in[10 * col_num + col]);
    u[6] = _mm256_add_epi32(in[6 * col_num + col], in[9 * col_num + col]);
    u[9] = _mm256_sub_epi32(in[6 * col_num + col], in[9 * col_num + col]);
    u[7] = _mm256_add_epi32(in[7 * col_num + col], in[8 * col_num + col]);
    u[8] = _mm256_sub_epi32(in[7 * col_num + col], in[8 * col_num + col]);

    // stage 2
    v[0] = _mm256_add_epi32(u[0], u[7]);
    v[7] = _mm256_sub_epi32(u[0], u[7]);
    v[1] = _mm256_add_epi32(u[1], u[6]);
    v[6] = _mm256_sub_epi32(u[1], u[6]);
    v[2] = _mm256_add_epi32(u[2], u[5]);
    v[5] = _mm256_sub_epi32(u[2], u[5]);
    v[3] = _mm256_add_epi32(u[3], u[4]);
    v[4] = _mm256_sub_epi32(u[3], u[4]);
    v[8] = u[8];
    v[9] = u[9];

    v[10] = _mm256_mullo_epi32(u[10], cospim32);
    x = _mm256_mullo_epi32(u[13], cospi32);
    v[10] = _mm256_add_epi32(v[10], x);
    v[10] = _mm256_add_epi32(v[10], rnding);
    v[10] = _mm256_srai_epi32(v[10], bit);

    v[13] = _mm256_mullo_epi32(u[10], cospi32);
    x = _mm256_mullo_epi32(u[13], cospim32);
    v[13] = _mm256_sub_epi32(v[13], x);
    v[13] = _mm256_add_epi32(v[13], rnding);
    v[13] = _mm256_srai_epi32(v[13], bit);

    v[11] = _mm256_mullo_epi32(u[11], cospim32);
    x = _mm256_mullo_epi32(u[12], cospi32);
    v[11] = _mm256_add_epi32(v[11], x);
    v[11] = _mm256_add_epi32(v[11], rnding);
    v[11] = _mm256_srai_epi32(v[11], bit);

    v[12] = _mm256_mullo_epi32(u[11], cospi32);
    x = _mm256_mullo_epi32(u[12], cospim32);
    v[12] = _mm256_sub_epi32(v[12], x);
    v[12] = _mm256_add_epi32(v[12], rnding);
    v[12] = _mm256_srai_epi32(v[12], bit);
    v[14] = u[14];
    v[15] = u[15];

    // stage 3
    u[0] = _mm256_add_epi32(v[0], v[3]);
    u[3] = _mm256_sub_epi32(v[0], v[3]);
    u[1] = _mm256_add_epi32(v[1], v[2]);
    u[2] = _mm256_sub_epi32(v[1], v[2]);
    u[4] = v[4];

    u[5] = _mm256_mullo_epi32(v[5], cospim32);
    x = _mm256_mullo_epi32(v[6], cospi32);
    u[5] = _mm256_add_epi32(u[5], x);
    u[5] = _mm256_add_epi32(u[5], rnding);
    u[5] = _mm256_srai_epi32(u[5], bit);

    u[6] = _mm256_mullo_epi32(v[5], cospi32);
    x = _mm256_mullo_epi32(v[6], cospim32);
    u[6] = _mm256_sub_epi32(u[6], x);
    u[6] = _mm256_add_epi32(u[6], rnding);
    u[6] = _mm256_srai_epi32(u[6], bit);

    u[7] = v[7];
    u[8] = _mm256_add_epi32(v[8], v[11]);
    u[11] = _mm256_sub_epi32(v[8], v[11]);
    u[9] = _mm256_add_epi32(v[9], v[10]);
    u[10] = _mm256_sub_epi32(v[9], v[10]);
    u[12] = _mm256_sub_epi32(v[15], v[12]);
    u[15] = _mm256_add_epi32(v[15], v[12]);
    u[13] = _mm256_sub_epi32(v[14], v[13]);
    u[14] = _mm256_add_epi32(v[14], v[13]);

    // stage 4
    u[0] = _mm256_mullo_epi32(u[0], cospi32);
    u[1] = _mm256_mullo_epi32(u[1], cospi32);
    v[0] = _mm256_add_epi32(u[0], u[1]);
    v[0] = _mm256_add_epi32(v[0], rnding);
    v[0] = _mm256_srai_epi32(v[0], bit);

    v[1] = _mm256_sub_epi32(u[0], u[1]);
    v[1] = _mm256_add_epi32(v[1], rnding);
    v[1] = _mm256_srai_epi32(v[1], bit);

    v[2] = _mm256_mullo_epi32(u[2], cospi48);
    x = _mm256_mullo_epi32(u[3], cospi16);
    v[2] = _mm256_add_epi32(v[2], x);
    v[2] = _mm256_add_epi32(v[2], rnding);
    v[2] = _mm256_srai_epi32(v[2], bit);

    v[3] = _mm256_mullo_epi32(u[2], cospi16);
    x = _mm256_mullo_epi32(u[3], cospi48);
    v[3] = _mm256_sub_epi32(x, v[3]);
    v[3] = _mm256_add_epi32(v[3], rnding);
    v[3] = _mm256_srai_epi32(v[3], bit);

    v[4] = _mm256_add_epi32(u[4], u[5]);
    v[5] = _mm256_sub_epi32(u[4], u[5]);
    v[6] = _mm256_sub_epi32(u[7], u[6]);
    v[7] = _mm256_add_epi32(u[7], u[6]);
    v[8] = u[8];

    v[9] = _mm256_mullo_epi32(u[9], cospim16);
    x = _mm256_mullo_epi32(u[14], cospi48);
    v[9] = _mm256_add_epi32(v[9], x);
    v[9] = _mm256_add_epi32(v[9], rnding);
    v[9] = _mm256_srai_epi32(v[9], bit);

    v[14] = _mm256_mullo_epi32(u[9], cospi48);
    x = _mm256_mullo_epi32(u[14], cospim16);
    v[14] = _mm256_sub_epi32(v[14], x);
    v[14] = _mm256_add_epi32(v[14], rnding);
    v[14] = _mm256_srai_epi32(v[14], bit);

    v[10] = _mm256_mullo_epi32(u[10], cospim48);
    x = _mm256_mullo_epi32(u[13], cospim16);
    v[10] = _mm256_add_epi32(v[10], x);
    v[10] = _mm256_add_epi32(v[10], rnding);
    v[10] = _mm256_srai_epi32(v[10], bit);

    v[13] = _mm256_mullo_epi32(u[10], cospim16);
    x = _mm256_mullo_epi32(u[13], cospim48);
    v[13] = _mm256_sub_epi32(v[13], x);
    v[13] = _mm256_add_epi32(v[13], rnding);
    v[13] = _mm256_srai_epi32(v[13], bit);

    v[11] = u[11];
    v[12] = u[12];
    v[15] = u[15];

    // stage 5
    u[0] = v[0];
    u[1] = v[1];
    u[2] = v[2];
    u[3] = v[3];

    u[4] = _mm256_mullo_epi32(v[4], cospi56);
    x = _mm256_mullo_epi32(v[7], cospi8);
    u[4] = _mm256_add_epi32(u[4], x);
    u[4] = _mm256_add_epi32(u[4], rnding);
    u[4] = _mm256_srai_epi32(u[4], bit);

    u[7] = _mm256_mullo_epi32(v[4], cospi8);
    x = _mm256_mullo_epi32(v[7], cospi56);
    u[7] = _mm256_sub_epi32(x, u[7]);
    u[7] = _mm256_add_epi32(u[7], rnding);
    u[7] = _mm256_srai_epi32(u[7], bit);

    u[5] = _mm256_mullo_epi32(v[5], cospi24);
    x = _mm256_mullo_epi32(v[6], cospi40);
    u[5] = _mm256_add_epi32(u[5], x);
    u[5] = _mm256_add_epi32(u[5], rnding);
    u[5] = _mm256_srai_epi32(u[5], bit);

    u[6] = _mm256_mullo_epi32(v[5], cospi40);
    x = _mm256_mullo_epi32(v[6], cospi24);
    u[6] = _mm256_sub_epi32(x, u[6]);
    u[6] = _mm256_add_epi32(u[6], rnding);
    u[6] = _mm256_srai_epi32(u[6], bit);

    u[8] = _mm256_add_epi32(v[8], v[9]);
    u[9] = _mm256_sub_epi32(v[8], v[9]);
    u[10] = _mm256_sub_epi32(v[11], v[10]);
    u[11] = _mm256_add_epi32(v[11], v[10]);
    u[12] = _mm256_add_epi32(v[12], v[13]);
    u[13] = _mm256_sub_epi32(v[12], v[13]);
    u[14] = _mm256_sub_epi32(v[15], v[14]);
    u[15] = _mm256_add_epi32(v[15], v[14]);

    // stage 6
    v[0] = u[0];
    v[1] = u[1];
    v[2] = u[2];
    v[3] = u[3];
    v[4] = u[4];
    v[5] = u[5];
    v[6] = u[6];
    v[7] = u[7];

    v[8] = _mm256_mullo_epi32(u[8], cospi60);
    x = _mm256_mullo_epi32(u[15], cospi4);
    v[8] = _mm256_add_epi32(v[8], x);
    v[8] = _mm256_add_epi32(v[8], rnding);
    v[8] = _mm256_srai_epi32(v[8], bit);

    v[15] = _mm256_mullo_epi32(u[8], cospi4);
    x = _mm256_mullo_epi32(u[15], cospi60);
    v[15] = _mm256_sub_epi32(x, v[15]);
    v[15] = _mm256_add_epi32(v[15], rnding);
    v[15] = _mm256_srai_epi32(v[15], bit);

    v[9] = _mm256_mullo_epi32(u[9], cospi28);
    x = _mm256_mullo_epi32(u[14], cospi36);
    v[9] = _mm256_add_epi32(v[9], x);
    v[9] = _mm256_add_epi32(v[9], rnding);
    v[9] = _mm256_srai_epi32(v[9], bit);

    v[14] = _mm256_mullo_epi32(u[9], cospi36);
    x = _mm256_mullo_epi32(u[14], cospi28);
    v[14] = _mm256_sub_epi32(x, v[14]);
    v[14] = _mm256_add_epi32(v[14], rnding);
    v[14] = _mm256_srai_epi32(v[14], bit);

    v[10] = _mm256_mullo_epi32(u[10], cospi44);
    x = _mm256_mullo_epi32(u[13], cospi20);
    v[10] = _mm256_add_epi32(v[10], x);
    v[10] = _mm256_add_epi32(v[10], rnding);
    v[10] = _mm256_srai_epi32(v[10], bit);

    v[13] = _mm256_mullo_epi32(u[10], cospi20);
    x = _mm256_mullo_epi32(u[13], cospi44);
    v[13] = _mm256_sub_epi32(x, v[13]);
    v[13] = _mm256_add_epi32(v[13], rnding);
    v[13] = _mm256_srai_epi32(v[13], bit);

    v[11] = _mm256_mullo_epi32(u[11], cospi12);
    x = _mm256_mullo_epi32(u[12], cospi52);
    v[11] = _mm256_add_epi32(v[11], x);
    v[11] = _mm256_add_epi32(v[11], rnding);
    v[11] = _mm256_srai_epi32(v[11], bit);

    v[12] = _mm256_mullo_epi32(u[11], cospi52);
    x = _mm256_mullo_epi32(u[12], cospi12);
    v[12] = _mm256_sub_epi32(x, v[12]);
    v[12] = _mm256_add_epi32(v[12], rnding);
    v[12] = _mm256_srai_epi32(v[12], bit);

    out[0 * outstride + col] = v[0];
    out[1 * outstride + col] = v[8];
    out[2 * outstride + col] = v[4];
    out[3 * outstride + col] = v[12];
    out[4 * outstride + col] = v[2];
    out[5 * outstride + col] = v[10];
    out[6 * outstride + col] = v[6];
    out[7 * outstride + col] = v[14];
    out[8 * outstride + col] = v[1];
    out[9 * outstride + col] = v[9];
    out[10 * outstride + col] = v[5];
    out[11 * outstride + col] = v[13];
    out[12 * outstride + col] = v[3];
    out[13 * outstride + col] = v[11];
    out[14 * outstride + col] = v[7];
    out[15 * outstride + col] = v[15];
  }
}
static void fadst16_avx2(__m256i *in, __m256i *out, const int8_t bit,
                         const int num_cols, const int outstride) {
  const int32_t *cospi = cospi_arr(bit);
  const __m256i cospi32 = _mm256_set1_epi32(cospi[32]);
  const __m256i cospi48 = _mm256_set1_epi32(cospi[48]);
  const __m256i cospi16 = _mm256_set1_epi32(cospi[16]);
  const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]);
  const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]);
  const __m256i cospi8 = _mm256_set1_epi32(cospi[8]);
  const __m256i cospi56 = _mm256_set1_epi32(cospi[56]);
  const __m256i cospim56 = _mm256_set1_epi32(-cospi[56]);
  const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]);
  const __m256i cospi24 = _mm256_set1_epi32(cospi[24]);
  const __m256i cospim24 = _mm256_set1_epi32(-cospi[24]);
  const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]);
  const __m256i cospi40 = _mm256_set1_epi32(cospi[40]);
  const __m256i cospi2 = _mm256_set1_epi32(cospi[2]);
  const __m256i cospi62 = _mm256_set1_epi32(cospi[62]);
  const __m256i cospim2 = _mm256_set1_epi32(-cospi[2]);
  const __m256i cospi10 = _mm256_set1_epi32(cospi[10]);
  const __m256i cospi54 = _mm256_set1_epi32(cospi[54]);
  const __m256i cospim10 = _mm256_set1_epi32(-cospi[10]);
  const __m256i cospi18 = _mm256_set1_epi32(cospi[18]);
  const __m256i cospi46 = _mm256_set1_epi32(cospi[46]);
  const __m256i cospim18 = _mm256_set1_epi32(-cospi[18]);
  const __m256i cospi26 = _mm256_set1_epi32(cospi[26]);
  const __m256i cospi38 = _mm256_set1_epi32(cospi[38]);
  const __m256i cospim26 = _mm256_set1_epi32(-cospi[26]);
  const __m256i cospi34 = _mm256_set1_epi32(cospi[34]);
  const __m256i cospi30 = _mm256_set1_epi32(cospi[30]);
  const __m256i cospim34 = _mm256_set1_epi32(-cospi[34]);
  const __m256i cospi42 = _mm256_set1_epi32(cospi[42]);
  const __m256i cospi22 = _mm256_set1_epi32(cospi[22]);
  const __m256i cospim42 = _mm256_set1_epi32(-cospi[42]);
  const __m256i cospi50 = _mm256_set1_epi32(cospi[50]);
  const __m256i cospi14 = _mm256_set1_epi32(cospi[14]);
  const __m256i cospim50 = _mm256_set1_epi32(-cospi[50]);
  const __m256i cospi58 = _mm256_set1_epi32(cospi[58]);
  const __m256i cospi6 = _mm256_set1_epi32(cospi[6]);
  const __m256i cospim58 = _mm256_set1_epi32(-cospi[58]);
  const __m256i rnding = _mm256_set1_epi32(1 << (bit - 1));
  const __m256i zero = _mm256_setzero_si256();

  __m256i u[16], v[16], x, y;
  int col;

  for (col = 0; col < num_cols; ++col) {
    // stage 0
    // stage 1
    u[0] = in[0 * num_cols + col];
    u[1] = _mm256_sub_epi32(zero, in[15 * num_cols + col]);
    u[2] = _mm256_sub_epi32(zero, in[7 * num_cols + col]);
    u[3] = in[8 * num_cols + col];
    u[4] = _mm256_sub_epi32(zero, in[3 * num_cols + col]);
    u[5] = in[12 * num_cols + col];
    u[6] = in[4 * num_cols + col];
    u[7] = _mm256_sub_epi32(zero, in[11 * num_cols + col]);
    u[8] = _mm256_sub_epi32(zero, in[1 * num_cols + col]);
    u[9] = in[14 * num_cols + col];
    u[10] = in[6 * num_cols + col];
    u[11] = _mm256_sub_epi32(zero, in[9 * num_cols + col]);
    u[12] = in[2 * num_cols + col];
    u[13] = _mm256_sub_epi32(zero, in[13 * num_cols + col]);
    u[14] = _mm256_sub_epi32(zero, in[5 * num_cols + col]);
    u[15] = in[10 * num_cols + col];

    // stage 2
    v[0] = u[0];
    v[1] = u[1];

    x = _mm256_mullo_epi32(u[2], cospi32);
    y = _mm256_mullo_epi32(u[3], cospi32);
    v[2] = _mm256_add_epi32(x, y);
    v[2] = _mm256_add_epi32(v[2], rnding);
    v[2] = _mm256_srai_epi32(v[2], bit);

    v[3] = _mm256_sub_epi32(x, y);
    v[3] = _mm256_add_epi32(v[3], rnding);
    v[3] = _mm256_srai_epi32(v[3], bit);

    v[4] = u[4];
    v[5] = u[5];

    x = _mm256_mullo_epi32(u[6], cospi32);
    y = _mm256_mullo_epi32(u[7], cospi32);
    v[6] = _mm256_add_epi32(x, y);
    v[6] = _mm256_add_epi32(v[6], rnding);
    v[6] = _mm256_srai_epi32(v[6], bit);

    v[7] = _mm256_sub_epi32(x, y);
    v[7] = _mm256_add_epi32(v[7], rnding);
    v[7] = _mm256_srai_epi32(v[7], bit);

    v[8] = u[8];
    v[9] = u[9];

    x = _mm256_mullo_epi32(u[10], cospi32);
    y = _mm256_mullo_epi32(u[11], cospi32);
    v[10] = _mm256_add_epi32(x, y);
    v[10] = _mm256_add_epi32(v[10], rnding);
    v[10] = _mm256_srai_epi32(v[10], bit);

    v[11] = _mm256_sub_epi32(x, y);
    v[11] = _mm256_add_epi32(v[11], rnding);
    v[11] = _mm256_srai_epi32(v[11], bit);

    v[12] = u[12];
    v[13] = u[13];

    x = _mm256_mullo_epi32(u[14], cospi32);
    y = _mm256_mullo_epi32(u[15], cospi32);
    v[14] = _mm256_add_epi32(x, y);
    v[14] = _mm256_add_epi32(v[14], rnding);
    v[14] = _mm256_srai_epi32(v[14], bit);

    v[15] = _mm256_sub_epi32(x, y);
    v[15] = _mm256_add_epi32(v[15], rnding);
    v[15] = _mm256_srai_epi32(v[15], bit);

    // stage 3
    u[0] = _mm256_add_epi32(v[0], v[2]);
    u[1] = _mm256_add_epi32(v[1], v[3]);
    u[2] = _mm256_sub_epi32(v[0], v[2]);
    u[3] = _mm256_sub_epi32(v[1], v[3]);
    u[4] = _mm256_add_epi32(v[4], v[6]);
    u[5] = _mm256_add_epi32(v[5], v[7]);
    u[6] = _mm256_sub_epi32(v[4], v[6]);
    u[7] = _mm256_sub_epi32(v[5], v[7]);
    u[8] = _mm256_add_epi32(v[8], v[10]);
    u[9] = _mm256_add_epi32(v[9], v[11]);
    u[10] = _mm256_sub_epi32(v[8], v[10]);
    u[11] = _mm256_sub_epi32(v[9], v[11]);
    u[12] = _mm256_add_epi32(v[12], v[14]);
    u[13] = _mm256_add_epi32(v[13], v[15]);
    u[14] = _mm256_sub_epi32(v[12], v[14]);
    u[15] = _mm256_sub_epi32(v[13], v[15]);

    // stage 4
    v[0] = u[0];
    v[1] = u[1];
    v[2] = u[2];
    v[3] = u[3];
    v[4] = av1_half_btf_avx2(&cospi16, &u[4], &cospi48, &u[5], &rnding, bit);
    v[5] = av1_half_btf_avx2(&cospi48, &u[4], &cospim16, &u[5], &rnding, bit);
    v[6] = av1_half_btf_avx2(&cospim48, &u[6], &cospi16, &u[7], &rnding, bit);
    v[7] = av1_half_btf_avx2(&cospi16, &u[6], &cospi48, &u[7], &rnding, bit);
    v[8] = u[8];
    v[9] = u[9];
    v[10] = u[10];
    v[11] = u[11];
    v[12] = av1_half_btf_avx2(&cospi16, &u[12], &cospi48, &u[13], &rnding, bit);
    v[13] =
        av1_half_btf_avx2(&cospi48, &u[12], &cospim16, &u[13], &rnding, bit);
    v[14] =
        av1_half_btf_avx2(&cospim48, &u[14], &cospi16, &u[15], &rnding, bit);
    v[15] = av1_half_btf_avx2(&cospi16, &u[14], &cospi48, &u[15], &rnding, bit);

    // stage 5
    u[0] = _mm256_add_epi32(v[0], v[4]);
    u[1] = _mm256_add_epi32(v[1], v[5]);
    u[2] = _mm256_add_epi32(v[2], v[6]);
    u[3] = _mm256_add_epi32(v[3], v[7]);
    u[4] = _mm256_sub_epi32(v[0], v[4]);
    u[5] = _mm256_sub_epi32(v[1], v[5]);
    u[6] = _mm256_sub_epi32(v[2], v[6]);
    u[7] = _mm256_sub_epi32(v[3], v[7]);
    u[8] = _mm256_add_epi32(v[8], v[12]);
    u[9] = _mm256_add_epi32(v[9], v[13]);
    u[10] = _mm256_add_epi32(v[10], v[14]);
    u[11] = _mm256_add_epi32(v[11], v[15]);
    u[12] = _mm256_sub_epi32(v[8], v[12]);
    u[13] = _mm256_sub_epi32(v[9], v[13]);
    u[14] = _mm256_sub_epi32(v[10], v[14]);
    u[15] = _mm256_sub_epi32(v[11], v[15]);

    // stage 6
    v[0] = u[0];
    v[1] = u[1];
    v[2] = u[2];
    v[3] = u[3];
    v[4] = u[4];
    v[5] = u[5];
    v[6] = u[6];
    v[7] = u[7];
    v[8] = av1_half_btf_avx2(&cospi8, &u[8], &cospi56, &u[9], &rnding, bit);
    v[9] = av1_half_btf_avx2(&cospi56, &u[8], &cospim8, &u[9], &rnding, bit);
    v[10] = av1_half_btf_avx2(&cospi40, &u[10], &cospi24, &u[11], &rnding, bit);
    v[11] =
        av1_half_btf_avx2(&cospi24, &u[10], &cospim40, &u[11], &rnding, bit);
    v[12] = av1_half_btf_avx2(&cospim56, &u[12], &cospi8, &u[13], &rnding, bit);
    v[13] = av1_half_btf_avx2(&cospi8, &u[12], &cospi56, &u[13], &rnding, bit);
    v[14] =
        av1_half_btf_avx2(&cospim24, &u[14], &cospi40, &u[15], &rnding, bit);
    v[15] = av1_half_btf_avx2(&cospi40, &u[14], &cospi24, &u[15], &rnding, bit);

    // stage 7
    u[0] = _mm256_add_epi32(v[0], v[8]);
    u[1] = _mm256_add_epi32(v[1], v[9]);
    u[2] = _mm256_add_epi32(v[2], v[10]);
    u[3] = _mm256_add_epi32(v[3], v[11]);
    u[4] = _mm256_add_epi32(v[4], v[12]);
    u[5] = _mm256_add_epi32(v[5], v[13]);
    u[6] = _mm256_add_epi32(v[6], v[14]);
    u[7] = _mm256_add_epi32(v[7], v[15]);
    u[8] = _mm256_sub_epi32(v[0], v[8]);
    u[9] = _mm256_sub_epi32(v[1], v[9]);
    u[10] = _mm256_sub_epi32(v[2], v[10]);
    u[11] = _mm256_sub_epi32(v[3], v[11]);
    u[12] = _mm256_sub_epi32(v[4], v[12]);
    u[13] = _mm256_sub_epi32(v[5], v[13]);
    u[14] = _mm256_sub_epi32(v[6], v[14]);
    u[15] = _mm256_sub_epi32(v[7], v[15]);

    // stage 8
    v[0] = av1_half_btf_avx2(&cospi2, &u[0], &cospi62, &u[1], &rnding, bit);
    v[1] = av1_half_btf_avx2(&cospi62, &u[0], &cospim2, &u[1], &rnding, bit);
    v[2] = av1_half_btf_avx2(&cospi10, &u[2], &cospi54, &u[3], &rnding, bit);
    v[3] = av1_half_btf_avx2(&cospi54, &u[2], &cospim10, &u[3], &rnding, bit);
    v[4] = av1_half_btf_avx2(&cospi18, &u[4], &cospi46, &u[5], &rnding, bit);
    v[5] = av1_half_btf_avx2(&cospi46, &u[4], &cospim18, &u[5], &rnding, bit);
    v[6] = av1_half_btf_avx2(&cospi26, &u[6], &cospi38, &u[7], &rnding, bit);
    v[7] = av1_half_btf_avx2(&cospi38, &u[6], &cospim26, &u[7], &rnding, bit);
    v[8] = av1_half_btf_avx2(&cospi34, &u[8], &cospi30, &u[9], &rnding, bit);
    v[9] = av1_half_btf_avx2(&cospi30, &u[8], &cospim34, &u[9], &rnding, bit);
    v[10] = av1_half_btf_avx2(&cospi42, &u[10], &cospi22, &u[11], &rnding, bit);
    v[11] =
        av1_half_btf_avx2(&cospi22, &u[10], &cospim42, &u[11], &rnding, bit);
    v[12] = av1_half_btf_avx2(&cospi50, &u[12], &cospi14, &u[13], &rnding, bit);
    v[13] =
        av1_half_btf_avx2(&cospi14, &u[12], &cospim50, &u[13], &rnding, bit);
    v[14] = av1_half_btf_avx2(&cospi58, &u[14], &cospi6, &u[15], &rnding, bit);
    v[15] = av1_half_btf_avx2(&cospi6, &u[14], &cospim58, &u[15], &rnding, bit);

    // stage 9
    out[0 * outstride + col] = v[1];
    out[1 * outstride + col] = v[14];
    out[2 * outstride + col] = v[3];
    out[3 * outstride + col] = v[12];
    out[4 * outstride + col] = v[5];
    out[5 * outstride + col] = v[10];
    out[6 * outstride + col] = v[7];
    out[7 * outstride + col] = v[8];
    out[8 * outstride + col] = v[9];
    out[9 * outstride + col] = v[6];
    out[10 * outstride + col] = v[11];
    out[11 * outstride + col] = v[4];
    out[12 * outstride + col] = v[13];
    out[13 * outstride + col] = v[2];
    out[14 * outstride + col] = v[15];
    out[15 * outstride + col] = v[0];
  }
}
static void idtx16_avx2(__m256i *in, __m256i *out, const int8_t bit,
                        int col_num, const int outstride) {
  (void)bit;
  (void)outstride;
  __m256i fact = _mm256_set1_epi32(2 * NewSqrt2);
  __m256i offset = _mm256_set1_epi32(1 << (NewSqrt2Bits - 1));
  __m256i a_low;

  int num_iters = 16 * col_num;
  for (int i = 0; i < num_iters; i++) {
    a_low = _mm256_mullo_epi32(in[i], fact);
    a_low = _mm256_add_epi32(a_low, offset);
    out[i] = _mm256_srai_epi32(a_low, NewSqrt2Bits);
  }
}
static const transform_1d_avx2 col_highbd_txfm8x16_arr[TX_TYPES] = {
  fdct16_avx2,   // DCT_DCT
  fadst16_avx2,  // ADST_DCT
  fdct16_avx2,   // DCT_ADST
  fadst16_avx2,  // ADST_ADST
  fadst16_avx2,  // FLIPADST_DCT
  fdct16_avx2,   // DCT_FLIPADST
  fadst16_avx2,  // FLIPADST_FLIPADST
  fadst16_avx2,  // ADST_FLIPADST
  fadst16_avx2,  // FLIPADST_ADST
  idtx16_avx2,   // IDTX
  fdct16_avx2,   // V_DCT
  idtx16_avx2,   // H_DCT
  fadst16_avx2,  // V_ADST
  idtx16_avx2,   // H_ADST
  fadst16_avx2,  // V_FLIPADST
  idtx16_avx2    // H_FLIPADST
};
static const transform_1d_avx2 row_highbd_txfm8x8_arr[TX_TYPES] = {
  fdct8_avx2,   // DCT_DCT
  fdct8_avx2,   // ADST_DCT
  fadst8_avx2,  // DCT_ADST
  fadst8_avx2,  // ADST_ADST
  fdct8_avx2,   // FLIPADST_DCT
  fadst8_avx2,  // DCT_FLIPADST
  fadst8_avx2,  // FLIPADST_FLIPADST
  fadst8_avx2,  // ADST_FLIPADST
  fadst8_avx2,  // FLIPADST_ADST
  idtx8_avx2,   // IDTX
  idtx8_avx2,   // V_DCT
  fdct8_avx2,   // H_DCT
  idtx8_avx2,   // V_ADST
  fadst8_avx2,  // H_ADST
  idtx8_avx2,   // V_FLIPADST
  fadst8_avx2   // H_FLIPADST
};
void av1_fwd_txfm2d_8x16_avx2(const int16_t *input, int32_t *coeff, int stride,
                              TX_TYPE tx_type, int bd) {
  __m256i in[16], out[16];
  const int8_t *shift = av1_fwd_txfm_shift_ls[TX_8X16];
  const int txw_idx = get_txw_idx(TX_8X16);
  const int txh_idx = get_txh_idx(TX_8X16);
  const transform_1d_avx2 col_txfm = col_highbd_txfm8x16_arr[tx_type];
  const transform_1d_avx2 row_txfm = row_highbd_txfm8x8_arr[tx_type];
  const int8_t bit = av1_fwd_cos_bit_col[txw_idx][txh_idx];
  int ud_flip, lr_flip;
  get_flip_cfg(tx_type, &ud_flip, &lr_flip);

  load_buffer_8x16_avx2(input, in, stride, ud_flip, lr_flip, shift[0]);
  col_txfm(in, out, bit, 1, 1);
  col_txfm_8x8_rounding(out, -shift[1]);
  col_txfm_8x8_rounding(&out[8], -shift[1]);
  fwd_txfm_transpose_8x8_avx2(out, in, 1, 2);
  fwd_txfm_transpose_8x8_avx2(&out[8], &in[1], 1, 2);
  row_txfm(in, out, bit, 2, 2);
  round_shift_rect_array_32_avx2(out, in, 16, -shift[2], NewSqrt2);
  store_buffer_avx2(in, coeff, 8, 16);
  (void)bd;
}
static const transform_1d_avx2 col_highbd_txfm8x8_arr[TX_TYPES] = {
  fdct8_avx2,   // DCT_DCT
  fadst8_avx2,  // ADST_DCT
  fdct8_avx2,   // DCT_ADST
  fadst8_avx2,  // ADST_ADST
  fadst8_avx2,  // FLIPADST_DCT
  fdct8_avx2,   // DCT_FLIPADST
  fadst8_avx2,  // FLIPADST_FLIPADST
  fadst8_avx2,  // ADST_FLIPADST
  fadst8_avx2,  // FLIPADST_ADST
  idtx8_avx2,   // IDTX
  fdct8_avx2,   // V_DCT
  idtx8_avx2,   // H_DCT
  fadst8_avx2,  // V_ADST
  idtx8_avx2,   // H_ADST
  fadst8_avx2,  // V_FLIPADST
  idtx8_avx2    // H_FLIPADST
};
static const transform_1d_avx2 row_highbd_txfm8x16_arr[TX_TYPES] = {
  fdct16_avx2,   // DCT_DCT
  fdct16_avx2,   // ADST_DCT
  fadst16_avx2,  // DCT_ADST
  fadst16_avx2,  // ADST_ADST
  fdct16_avx2,   // FLIPADST_DCT
  fadst16_avx2,  // DCT_FLIPADST
  fadst16_avx2,  // FLIPADST_FLIPADST
  fadst16_avx2,  // ADST_FLIPADST
  fadst16_avx2,  // FLIPADST_ADST
  idtx16_avx2,   // IDTX
  idtx16_avx2,   // V_DCT
  fdct16_avx2,   // H_DCT
  idtx16_avx2,   // V_ADST
  fadst16_avx2,  // H_ADST
  idtx16_avx2,   // V_FLIPADST
  fadst16_avx2   // H_FLIPADST
};
void av1_fwd_txfm2d_16x8_avx2(const int16_t *input, int32_t *coeff, int stride,
                              TX_TYPE tx_type, int bd) {
  __m256i in[16], out[16];
  const int8_t *shift = av1_fwd_txfm_shift_ls[TX_16X8];
  const int txw_idx = get_txw_idx(TX_16X8);
  const int txh_idx = get_txh_idx(TX_16X8);
  const transform_1d_avx2 col_txfm = col_highbd_txfm8x8_arr[tx_type];
  const transform_1d_avx2 row_txfm = row_highbd_txfm8x16_arr[tx_type];
  const int8_t bit = av1_fwd_cos_bit_col[txw_idx][txh_idx];
  int ud_flip, lr_flip;
  get_flip_cfg(tx_type, &ud_flip, &lr_flip);

  load_buffer_16xn_avx2(input, in, stride, 8, 2, ud_flip, lr_flip);
  round_shift_32_8xn_avx2(in, 16, shift[0], 1);
  col_txfm(in, out, bit, 2, 2);
  round_shift_32_8xn_avx2(out, 16, shift[1], 1);
  fwd_txfm_transpose_8x8_avx2(out, in, 2, 1);
  fwd_txfm_transpose_8x8_avx2(&out[1], &in[8], 2, 1);
  row_txfm(in, out, bit, 1, 1);
  round_shift_rect_array_32_avx2(out, out, 16, -shift[2], NewSqrt2);
  store_buffer_avx2(out, coeff, 8, 16);
  (void)bd;
}
void av1_fwd_txfm2d_16x16_avx2(const int16_t *input, int32_t *coeff, int stride,
                               TX_TYPE tx_type, int bd) {
  __m256i in[32], out[32];
  const TX_SIZE tx_size = TX_16X16;
  const int8_t *shift = av1_fwd_txfm_shift_ls[tx_size];
  const int txw_idx = get_txw_idx(tx_size);
  const int txh_idx = get_txh_idx(tx_size);
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  const int width_div8 = (width >> 3);
  const int width_div16 = (width >> 4);
  const int size = (height << 1);
  switch (tx_type) {
    case DCT_DCT:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fdct16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fdct16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case ADST_DCT:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fdct16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case DCT_ADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fdct16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case ADST_ADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case FLIPADST_DCT:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fdct16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case DCT_FLIPADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 1);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fdct16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case FLIPADST_FLIPADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 1);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case ADST_FLIPADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 1);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case FLIPADST_ADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case IDTX:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      idtx16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      idtx16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case V_DCT:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fdct16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      idtx16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case H_DCT:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      idtx16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fdct16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case V_ADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      idtx16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case H_ADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      idtx16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case V_FLIPADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 1, 0);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      fadst16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                   width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      idtx16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                  width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    case H_FLIPADST:
      load_buffer_16xn_avx2(input, in, stride, height, width_div8, 0, 1);
      round_shift_32_8xn_avx2(in, size, shift[0], width_div16);
      idtx16_avx2(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], width_div8,
                  width_div8);
      round_shift_32_8xn_avx2(out, size, shift[1], width_div16);
      fwd_txfm_transpose_16x16_avx2(out, in);
      fadst16_avx2(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], width_div8,
                   width_div8);
      store_buffer_avx2(out, coeff, 8, 32);
      break;
    default: assert(0);
  }
  (void)bd;
}
static inline void fdct32_avx2(__m256i *input, __m256i *output,
                               const int8_t cos_bit, const int instride,
                               const int outstride) {
  __m256i buf0[32];
  __m256i buf1[32];
  const int32_t *cospi;
  int startidx = 0 * instride;
  int endidx = 31 * instride;
  // stage 0
  // stage 1
  buf1[0] = _mm256_add_epi32(input[startidx], input[endidx]);
  buf1[31] = _mm256_sub_epi32(input[startidx], input[endidx]);
  startidx += instride;
  endidx -= instride;
  buf1[1] = _mm256_add_epi32(input[startidx], input[endidx]);
  buf1[30] = _mm256_sub_epi32(input[startidx], input[endidx]);
  startidx += instride;
  endidx -= instride;
  buf1[2] = _mm256_add_epi32(input[startidx], input[endidx]);
  buf1[29] = _mm256_sub_epi32(input[startidx], input[endidx]);
  startidx += instride;
  endidx -= instride;
  buf1[3] = _mm256_add_epi32(input[startidx], input[endidx]);
  buf1[28] = _mm256_sub_epi32(input[startidx], input[endidx]);
--> --------------------

--> maximum size reached

--> --------------------

Messung V0.5
C=92 H=95 G=93

¤ Dauer der Verarbeitung: 0.19 Sekunden  (vorverarbeitet)  ¤

*© Formatika GbR, Deutschland






Wurzel

Suchen

Beweissystem der NASA

Beweissystem Isabelle

NIST Cobol Testsuite

Cephes Mathematical Library

Wiener Entwicklungsmethode

Haftungshinweis

Die Informationen auf dieser Webseite wurden nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit, noch Qualität der bereit gestellten Informationen zugesichert.

Bemerkung:

Die farbliche Syntaxdarstellung und die Messung sind noch experimentell.