diff options
author | trav90 <travawine@palemoon.org> | 2018-10-15 21:45:30 -0500 |
---|---|---|
committer | trav90 <travawine@palemoon.org> | 2018-10-15 21:45:30 -0500 |
commit | 68569dee1416593955c1570d638b3d9250b33012 (patch) | |
tree | d960f017cd7eba3f125b7e8a813789ee2e076310 /third_party/aom/av1/encoder/x86 | |
parent | 07c17b6b98ed32fcecff15c083ab0fd878de3cf0 (diff) | |
download | UXP-68569dee1416593955c1570d638b3d9250b33012.tar UXP-68569dee1416593955c1570d638b3d9250b33012.tar.gz UXP-68569dee1416593955c1570d638b3d9250b33012.tar.lz UXP-68569dee1416593955c1570d638b3d9250b33012.tar.xz UXP-68569dee1416593955c1570d638b3d9250b33012.zip |
Import aom library
This is the reference implementation for the Alliance for Open Media's av1 video code.
The commit used was 4d668d7feb1f8abd809d1bca0418570a7f142a36.
Diffstat (limited to 'third_party/aom/av1/encoder/x86')
-rw-r--r-- | third_party/aom/av1/encoder/x86/av1_highbd_quantize_sse4.c | 193 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/av1_quantize_sse2.c | 211 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/av1_quantize_ssse3_x86_64.asm | 204 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/av1_ssim_opt_x86_64.asm | 219 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/dct_intrin_sse2.c | 3884 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/dct_sse2.asm | 87 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/dct_ssse3.c | 469 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/error_intrin_avx2.c | 73 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/error_sse2.asm | 125 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/highbd_block_error_intrin_sse2.c | 72 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/highbd_fwd_txfm_sse4.c | 1895 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/hybrid_fwd_txfm_avx2.c | 1678 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/temporal_filter_apply_sse2.asm | 215 | ||||
-rw-r--r-- | third_party/aom/av1/encoder/x86/wedge_utils_sse2.c | 254 |
14 files changed, 9579 insertions, 0 deletions
diff --git a/third_party/aom/av1/encoder/x86/av1_highbd_quantize_sse4.c b/third_party/aom/av1/encoder/x86/av1_highbd_quantize_sse4.c new file mode 100644 index 000000000..fa5626002 --- /dev/null +++ b/third_party/aom/av1/encoder/x86/av1_highbd_quantize_sse4.c @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <smmintrin.h> +#include <stdint.h> + +#include "./av1_rtcd.h" +#include "aom_dsp/aom_dsp_common.h" + +// Coefficient quantization phase 1 +// param[0-2] : rounding/quan/dequan constants +static INLINE void quantize_coeff_phase1(__m128i *coeff, const __m128i *param, + const int shift, const int scale, + __m128i *qcoeff, __m128i *dquan, + __m128i *sign) { + const __m128i zero = _mm_setzero_si128(); + const __m128i one = _mm_set1_epi32(1); + + *sign = _mm_cmplt_epi32(*coeff, zero); + *sign = _mm_or_si128(*sign, one); + *coeff = _mm_abs_epi32(*coeff); + + qcoeff[0] = _mm_add_epi32(*coeff, param[0]); + qcoeff[1] = _mm_unpackhi_epi32(qcoeff[0], zero); + qcoeff[0] = _mm_unpacklo_epi32(qcoeff[0], zero); + + qcoeff[0] = _mm_mul_epi32(qcoeff[0], param[1]); + qcoeff[0] = _mm_srli_epi64(qcoeff[0], shift); + dquan[0] = _mm_mul_epi32(qcoeff[0], param[2]); + dquan[0] = _mm_srli_epi64(dquan[0], scale); +} + +// Coefficient quantization phase 2 +static INLINE void quantize_coeff_phase2(__m128i *qcoeff, __m128i *dquan, + const __m128i *sign, + const __m128i *param, const int shift, + const int scale, tran_low_t *qAddr, + tran_low_t *dqAddr) { + __m128i mask0L = _mm_set_epi32(-1, -1, 0, 0); + __m128i mask0H = _mm_set_epi32(0, 0, -1, -1); + + qcoeff[1] = _mm_mul_epi32(qcoeff[1], param[1]); + qcoeff[1] = _mm_srli_epi64(qcoeff[1], shift); + dquan[1] = _mm_mul_epi32(qcoeff[1], param[2]); + dquan[1] = _mm_srli_epi64(dquan[1], scale); + + // combine L&H + qcoeff[0] = _mm_shuffle_epi32(qcoeff[0], 0xd8); + qcoeff[1] = _mm_shuffle_epi32(qcoeff[1], 0x8d); + + qcoeff[0] = _mm_and_si128(qcoeff[0], mask0H); + qcoeff[1] = _mm_and_si128(qcoeff[1], mask0L); + + dquan[0] = _mm_shuffle_epi32(dquan[0], 0xd8); + dquan[1] = _mm_shuffle_epi32(dquan[1], 0x8d); + + dquan[0] = _mm_and_si128(dquan[0], mask0H); + dquan[1] = _mm_and_si128(dquan[1], mask0L); + + qcoeff[0] = _mm_or_si128(qcoeff[0], qcoeff[1]); + dquan[0] = _mm_or_si128(dquan[0], dquan[1]); + + qcoeff[0] = _mm_sign_epi32(qcoeff[0], *sign); + dquan[0] = _mm_sign_epi32(dquan[0], *sign); + + _mm_storeu_si128((__m128i *)qAddr, qcoeff[0]); + _mm_storeu_si128((__m128i *)dqAddr, dquan[0]); +} + +static INLINE void find_eob(tran_low_t *qcoeff_ptr, const int16_t *iscan, + __m128i *eob) { + const __m128i zero = _mm_setzero_si128(); + __m128i mask, iscanIdx; + const __m128i q0 = _mm_loadu_si128((__m128i const *)qcoeff_ptr); + const __m128i q1 = _mm_loadu_si128((__m128i const *)(qcoeff_ptr + 4)); + __m128i nz_flag0 = _mm_cmpeq_epi32(q0, zero); + __m128i nz_flag1 = _mm_cmpeq_epi32(q1, zero); + + nz_flag0 = _mm_cmpeq_epi32(nz_flag0, zero); + nz_flag1 = _mm_cmpeq_epi32(nz_flag1, zero); + + mask = _mm_packs_epi32(nz_flag0, nz_flag1); + iscanIdx = _mm_loadu_si128((__m128i const *)iscan); + iscanIdx = _mm_sub_epi16(iscanIdx, mask); + iscanIdx = _mm_and_si128(iscanIdx, mask); + *eob = _mm_max_epi16(*eob, iscanIdx); +} + +static INLINE uint16_t get_accumulated_eob(__m128i *eob) { + __m128i eob_shuffled; + uint16_t eobValue; + eob_shuffled = _mm_shuffle_epi32(*eob, 0xe); + *eob = _mm_max_epi16(*eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(*eob, 0xe); + *eob = _mm_max_epi16(*eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(*eob, 0x1); + *eob = _mm_max_epi16(*eob, eob_shuffled); + eobValue = _mm_extract_epi16(*eob, 0); + return eobValue; +} + +void av1_highbd_quantize_fp_sse4_1( + const tran_low_t *coeff_ptr, intptr_t count, int skip_block, + const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, + tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, + const int16_t *scan, const int16_t *iscan, int log_scale) { + __m128i coeff[2], qcoeff[2], dequant[2], qparam[3], coeff_sign; + __m128i eob = _mm_setzero_si128(); + const tran_low_t *src = coeff_ptr; + tran_low_t *quanAddr = qcoeff_ptr; + tran_low_t *dquanAddr = dqcoeff_ptr; + const int shift = 16 - log_scale; + const int coeff_stride = 4; + const int quan_stride = coeff_stride; + (void)skip_block; + (void)zbin_ptr; + (void)quant_shift_ptr; + (void)scan; + + memset(quanAddr, 0, count * sizeof(quanAddr[0])); + memset(dquanAddr, 0, count * sizeof(dquanAddr[0])); + + if (!skip_block) { + coeff[0] = _mm_loadu_si128((__m128i const *)src); + + qparam[0] = + _mm_set_epi32(round_ptr[1], round_ptr[1], round_ptr[1], round_ptr[0]); + qparam[1] = _mm_set_epi64x(quant_ptr[1], quant_ptr[0]); + qparam[2] = _mm_set_epi64x(dequant_ptr[1], dequant_ptr[0]); + + // DC and first 3 AC + quantize_coeff_phase1(&coeff[0], qparam, shift, log_scale, qcoeff, dequant, + &coeff_sign); + + // update round/quan/dquan for AC + qparam[0] = _mm_unpackhi_epi64(qparam[0], qparam[0]); + qparam[1] = _mm_set_epi64x(quant_ptr[1], quant_ptr[1]); + qparam[2] = _mm_set_epi64x(dequant_ptr[1], dequant_ptr[1]); + + quantize_coeff_phase2(qcoeff, dequant, &coeff_sign, qparam, shift, + log_scale, quanAddr, dquanAddr); + + // next 4 AC + coeff[1] = _mm_loadu_si128((__m128i const *)(src + coeff_stride)); + quantize_coeff_phase1(&coeff[1], qparam, shift, log_scale, qcoeff, dequant, + &coeff_sign); + quantize_coeff_phase2(qcoeff, dequant, &coeff_sign, qparam, shift, + log_scale, quanAddr + quan_stride, + dquanAddr + quan_stride); + + find_eob(quanAddr, iscan, &eob); + + count -= 8; + + // loop for the rest of AC + while (count > 0) { + src += coeff_stride << 1; + quanAddr += quan_stride << 1; + dquanAddr += quan_stride << 1; + iscan += quan_stride << 1; + + coeff[0] = _mm_loadu_si128((__m128i const *)src); + coeff[1] = _mm_loadu_si128((__m128i const *)(src + coeff_stride)); + + quantize_coeff_phase1(&coeff[0], qparam, shift, log_scale, qcoeff, + dequant, &coeff_sign); + quantize_coeff_phase2(qcoeff, dequant, &coeff_sign, qparam, shift, + log_scale, quanAddr, dquanAddr); + + quantize_coeff_phase1(&coeff[1], qparam, shift, log_scale, qcoeff, + dequant, &coeff_sign); + quantize_coeff_phase2(qcoeff, dequant, &coeff_sign, qparam, shift, + log_scale, quanAddr + quan_stride, + dquanAddr + quan_stride); + + find_eob(quanAddr, iscan, &eob); + + count -= 8; + } + *eob_ptr = get_accumulated_eob(&eob); + } else { + *eob_ptr = 0; + } +} diff --git a/third_party/aom/av1/encoder/x86/av1_quantize_sse2.c b/third_party/aom/av1/encoder/x86/av1_quantize_sse2.c new file mode 100644 index 000000000..f9c95b6cb --- /dev/null +++ b/third_party/aom/av1/encoder/x86/av1_quantize_sse2.c @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <emmintrin.h> +#include <xmmintrin.h> + +#include "./av1_rtcd.h" +#include "aom/aom_integer.h" + +void av1_quantize_fp_sse2(const int16_t *coeff_ptr, intptr_t n_coeffs, + int skip_block, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, int16_t *qcoeff_ptr, + int16_t *dqcoeff_ptr, const int16_t *dequant_ptr, + uint16_t *eob_ptr, const int16_t *scan_ptr, + const int16_t *iscan_ptr) { + __m128i zero; + __m128i thr; + int16_t nzflag; + (void)scan_ptr; + (void)zbin_ptr; + (void)quant_shift_ptr; + + coeff_ptr += n_coeffs; + iscan_ptr += n_coeffs; + qcoeff_ptr += n_coeffs; + dqcoeff_ptr += n_coeffs; + n_coeffs = -n_coeffs; + zero = _mm_setzero_si128(); + + if (!skip_block) { + __m128i eob; + __m128i round, quant, dequant; + { + __m128i coeff0, coeff1; + + // Setup global values + { + round = _mm_load_si128((const __m128i *)round_ptr); + quant = _mm_load_si128((const __m128i *)quant_ptr); + dequant = _mm_load_si128((const __m128i *)dequant_ptr); + } + + { + __m128i coeff0_sign, coeff1_sign; + __m128i qcoeff0, qcoeff1; + __m128i qtmp0, qtmp1; + // Do DC and first 15 AC + coeff0 = _mm_load_si128((const __m128i *)(coeff_ptr + n_coeffs)); + coeff1 = _mm_load_si128((const __m128i *)(coeff_ptr + n_coeffs) + 1); + + // Poor man's sign extract + coeff0_sign = _mm_srai_epi16(coeff0, 15); + coeff1_sign = _mm_srai_epi16(coeff1, 15); + qcoeff0 = _mm_xor_si128(coeff0, coeff0_sign); + qcoeff1 = _mm_xor_si128(coeff1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + qcoeff0 = _mm_adds_epi16(qcoeff0, round); + round = _mm_unpackhi_epi64(round, round); + qcoeff1 = _mm_adds_epi16(qcoeff1, round); + qtmp0 = _mm_mulhi_epi16(qcoeff0, quant); + quant = _mm_unpackhi_epi64(quant, quant); + qtmp1 = _mm_mulhi_epi16(qcoeff1, quant); + + // Reinsert signs + qcoeff0 = _mm_xor_si128(qtmp0, coeff0_sign); + qcoeff1 = _mm_xor_si128(qtmp1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), qcoeff0); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, qcoeff1); + + coeff0 = _mm_mullo_epi16(qcoeff0, dequant); + dequant = _mm_unpackhi_epi64(dequant, dequant); + coeff1 = _mm_mullo_epi16(qcoeff1, dequant); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), coeff0); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, coeff1); + } + + { + // Scan for eob + __m128i zero_coeff0, zero_coeff1; + __m128i nzero_coeff0, nzero_coeff1; + __m128i iscan0, iscan1; + __m128i eob1; + zero_coeff0 = _mm_cmpeq_epi16(coeff0, zero); + zero_coeff1 = _mm_cmpeq_epi16(coeff1, zero); + nzero_coeff0 = _mm_cmpeq_epi16(zero_coeff0, zero); + nzero_coeff1 = _mm_cmpeq_epi16(zero_coeff1, zero); + iscan0 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs)); + iscan1 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs) + 1); + // Add one to convert from indices to counts + iscan0 = _mm_sub_epi16(iscan0, nzero_coeff0); + iscan1 = _mm_sub_epi16(iscan1, nzero_coeff1); + eob = _mm_and_si128(iscan0, nzero_coeff0); + eob1 = _mm_and_si128(iscan1, nzero_coeff1); + eob = _mm_max_epi16(eob, eob1); + } + n_coeffs += 8 * 2; + } + + thr = _mm_srai_epi16(dequant, 1); + + // AC only loop + while (n_coeffs < 0) { + __m128i coeff0, coeff1; + { + __m128i coeff0_sign, coeff1_sign; + __m128i qcoeff0, qcoeff1; + __m128i qtmp0, qtmp1; + + coeff0 = _mm_load_si128((const __m128i *)(coeff_ptr + n_coeffs)); + coeff1 = _mm_load_si128((const __m128i *)(coeff_ptr + n_coeffs) + 1); + + // Poor man's sign extract + coeff0_sign = _mm_srai_epi16(coeff0, 15); + coeff1_sign = _mm_srai_epi16(coeff1, 15); + qcoeff0 = _mm_xor_si128(coeff0, coeff0_sign); + qcoeff1 = _mm_xor_si128(coeff1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + nzflag = _mm_movemask_epi8(_mm_cmpgt_epi16(qcoeff0, thr)) | + _mm_movemask_epi8(_mm_cmpgt_epi16(qcoeff1, thr)); + + if (nzflag) { + qcoeff0 = _mm_adds_epi16(qcoeff0, round); + qcoeff1 = _mm_adds_epi16(qcoeff1, round); + qtmp0 = _mm_mulhi_epi16(qcoeff0, quant); + qtmp1 = _mm_mulhi_epi16(qcoeff1, quant); + + // Reinsert signs + qcoeff0 = _mm_xor_si128(qtmp0, coeff0_sign); + qcoeff1 = _mm_xor_si128(qtmp1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), qcoeff0); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, qcoeff1); + + coeff0 = _mm_mullo_epi16(qcoeff0, dequant); + coeff1 = _mm_mullo_epi16(qcoeff1, dequant); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), coeff0); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, coeff1); + } else { + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, zero); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, zero); + } + } + + if (nzflag) { + // Scan for eob + __m128i zero_coeff0, zero_coeff1; + __m128i nzero_coeff0, nzero_coeff1; + __m128i iscan0, iscan1; + __m128i eob0, eob1; + zero_coeff0 = _mm_cmpeq_epi16(coeff0, zero); + zero_coeff1 = _mm_cmpeq_epi16(coeff1, zero); + nzero_coeff0 = _mm_cmpeq_epi16(zero_coeff0, zero); + nzero_coeff1 = _mm_cmpeq_epi16(zero_coeff1, zero); + iscan0 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs)); + iscan1 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs) + 1); + // Add one to convert from indices to counts + iscan0 = _mm_sub_epi16(iscan0, nzero_coeff0); + iscan1 = _mm_sub_epi16(iscan1, nzero_coeff1); + eob0 = _mm_and_si128(iscan0, nzero_coeff0); + eob1 = _mm_and_si128(iscan1, nzero_coeff1); + eob0 = _mm_max_epi16(eob0, eob1); + eob = _mm_max_epi16(eob, eob0); + } + n_coeffs += 8 * 2; + } + + // Accumulate EOB + { + __m128i eob_shuffled; + eob_shuffled = _mm_shuffle_epi32(eob, 0xe); + eob = _mm_max_epi16(eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(eob, 0xe); + eob = _mm_max_epi16(eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(eob, 0x1); + eob = _mm_max_epi16(eob, eob_shuffled); + *eob_ptr = _mm_extract_epi16(eob, 1); + } + } else { + do { + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, zero); + n_coeffs += 8 * 2; + } while (n_coeffs < 0); + *eob_ptr = 0; + } +} diff --git a/third_party/aom/av1/encoder/x86/av1_quantize_ssse3_x86_64.asm b/third_party/aom/av1/encoder/x86/av1_quantize_ssse3_x86_64.asm new file mode 100644 index 000000000..ad4ae274e --- /dev/null +++ b/third_party/aom/av1/encoder/x86/av1_quantize_ssse3_x86_64.asm @@ -0,0 +1,204 @@ +; +; Copyright (c) 2016, Alliance for Open Media. All rights reserved +; +; This source code is subject to the terms of the BSD 2 Clause License and +; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +; was not distributed with this source code in the LICENSE file, you can +; obtain it at www.aomedia.org/license/software. If the Alliance for Open +; Media Patent License 1.0 was not distributed with this source code in the +; PATENTS file, you can obtain it at www.aomedia.org/license/patent. +; + +; + +%define private_prefix av1 + +%include "third_party/x86inc/x86inc.asm" + +SECTION_RODATA +pw_1: times 8 dw 1 + +SECTION .text + +%macro QUANTIZE_FP 2 +cglobal quantize_%1, 0, %2, 15, coeff, ncoeff, skip, zbin, round, quant, \ + shift, qcoeff, dqcoeff, dequant, \ + eob, scan, iscan + cmp dword skipm, 0 + jne .blank + + ; actual quantize loop - setup pointers, rounders, etc. + movifnidn coeffq, coeffmp + movifnidn ncoeffq, ncoeffmp + mov r2, dequantmp + movifnidn zbinq, zbinmp + movifnidn roundq, roundmp + movifnidn quantq, quantmp + mova m1, [roundq] ; m1 = round + mova m2, [quantq] ; m2 = quant +%ifidn %1, fp_32x32 + pcmpeqw m5, m5 + psrlw m5, 15 + paddw m1, m5 + psrlw m1, 1 ; m1 = (m1 + 1) / 2 +%endif + mova m3, [r2q] ; m3 = dequant + mov r3, qcoeffmp + mov r4, dqcoeffmp + mov r5, iscanmp +%ifidn %1, fp_32x32 + psllw m2, 1 +%endif + pxor m5, m5 ; m5 = dedicated zero + + lea coeffq, [ coeffq+ncoeffq*2] + lea r5q, [ r5q+ncoeffq*2] + lea r3q, [ r3q+ncoeffq*2] + lea r4q, [r4q+ncoeffq*2] + neg ncoeffq + + ; get DC and first 15 AC coeffs + mova m9, [ coeffq+ncoeffq*2+ 0] ; m9 = c[i] + mova m10, [ coeffq+ncoeffq*2+16] ; m10 = c[i] + pabsw m6, m9 ; m6 = abs(m9) + pabsw m11, m10 ; m11 = abs(m10) + pcmpeqw m7, m7 + + paddsw m6, m1 ; m6 += round + punpckhqdq m1, m1 + paddsw m11, m1 ; m11 += round + pmulhw m8, m6, m2 ; m8 = m6*q>>16 + punpckhqdq m2, m2 + pmulhw m13, m11, m2 ; m13 = m11*q>>16 + psignw m8, m9 ; m8 = reinsert sign + psignw m13, m10 ; m13 = reinsert sign + mova [r3q+ncoeffq*2+ 0], m8 + mova [r3q+ncoeffq*2+16], m13 +%ifidn %1, fp_32x32 + pabsw m8, m8 + pabsw m13, m13 +%endif + pmullw m8, m3 ; r4[i] = r3[i] * q + punpckhqdq m3, m3 + pmullw m13, m3 ; r4[i] = r3[i] * q +%ifidn %1, fp_32x32 + psrlw m8, 1 + psrlw m13, 1 + psignw m8, m9 + psignw m13, m10 + psrlw m0, m3, 2 +%else + psrlw m0, m3, 1 +%endif + mova [r4q+ncoeffq*2+ 0], m8 + mova [r4q+ncoeffq*2+16], m13 + pcmpeqw m8, m5 ; m8 = c[i] == 0 + pcmpeqw m13, m5 ; m13 = c[i] == 0 + mova m6, [ r5q+ncoeffq*2+ 0] ; m6 = scan[i] + mova m11, [ r5q+ncoeffq*2+16] ; m11 = scan[i] + psubw m6, m7 ; m6 = scan[i] + 1 + psubw m11, m7 ; m11 = scan[i] + 1 + pandn m8, m6 ; m8 = max(eob) + pandn m13, m11 ; m13 = max(eob) + pmaxsw m8, m13 + add ncoeffq, mmsize + jz .accumulate_eob + +.ac_only_loop: + mova m9, [ coeffq+ncoeffq*2+ 0] ; m9 = c[i] + mova m10, [ coeffq+ncoeffq*2+16] ; m10 = c[i] + pabsw m6, m9 ; m6 = abs(m9) + pabsw m11, m10 ; m11 = abs(m10) + + pcmpgtw m7, m6, m0 + pcmpgtw m12, m11, m0 + pmovmskb r6d, m7 + pmovmskb r2d, m12 + + or r6, r2 + jz .skip_iter + + pcmpeqw m7, m7 + + paddsw m6, m1 ; m6 += round + paddsw m11, m1 ; m11 += round + pmulhw m14, m6, m2 ; m14 = m6*q>>16 + pmulhw m13, m11, m2 ; m13 = m11*q>>16 + psignw m14, m9 ; m14 = reinsert sign + psignw m13, m10 ; m13 = reinsert sign + mova [r3q+ncoeffq*2+ 0], m14 + mova [r3q+ncoeffq*2+16], m13 +%ifidn %1, fp_32x32 + pabsw m14, m14 + pabsw m13, m13 +%endif + pmullw m14, m3 ; r4[i] = r3[i] * q + pmullw m13, m3 ; r4[i] = r3[i] * q +%ifidn %1, fp_32x32 + psrlw m14, 1 + psrlw m13, 1 + psignw m14, m9 + psignw m13, m10 +%endif + mova [r4q+ncoeffq*2+ 0], m14 + mova [r4q+ncoeffq*2+16], m13 + pcmpeqw m14, m5 ; m14 = c[i] == 0 + pcmpeqw m13, m5 ; m13 = c[i] == 0 + mova m6, [ r5q+ncoeffq*2+ 0] ; m6 = scan[i] + mova m11, [ r5q+ncoeffq*2+16] ; m11 = scan[i] + psubw m6, m7 ; m6 = scan[i] + 1 + psubw m11, m7 ; m11 = scan[i] + 1 + pandn m14, m6 ; m14 = max(eob) + pandn m13, m11 ; m13 = max(eob) + pmaxsw m8, m14 + pmaxsw m8, m13 + add ncoeffq, mmsize + jl .ac_only_loop + + jmp .accumulate_eob +.skip_iter: + mova [r3q+ncoeffq*2+ 0], m5 + mova [r3q+ncoeffq*2+16], m5 + mova [r4q+ncoeffq*2+ 0], m5 + mova [r4q+ncoeffq*2+16], m5 + add ncoeffq, mmsize + jl .ac_only_loop + +.accumulate_eob: + ; horizontally accumulate/max eobs and write into [eob] memory pointer + mov r2, eobmp + pshufd m7, m8, 0xe + pmaxsw m8, m7 + pshuflw m7, m8, 0xe + pmaxsw m8, m7 + pshuflw m7, m8, 0x1 + pmaxsw m8, m7 + pextrw r6, m8, 0 + mov [r2], r6 + RET + + ; skip-block, i.e. just write all zeroes +.blank: + mov r0, dqcoeffmp + movifnidn ncoeffq, ncoeffmp + mov r2, qcoeffmp + mov r3, eobmp + + lea r0q, [r0q+ncoeffq*2] + lea r2q, [r2q+ncoeffq*2] + neg ncoeffq + pxor m7, m7 +.blank_loop: + mova [r0q+ncoeffq*2+ 0], m7 + mova [r0q+ncoeffq*2+16], m7 + mova [r2q+ncoeffq*2+ 0], m7 + mova [r2q+ncoeffq*2+16], m7 + add ncoeffq, mmsize + jl .blank_loop + mov word [r3q], 0 + RET +%endmacro + +INIT_XMM ssse3 +QUANTIZE_FP fp, 7 +QUANTIZE_FP fp_32x32, 7 diff --git a/third_party/aom/av1/encoder/x86/av1_ssim_opt_x86_64.asm b/third_party/aom/av1/encoder/x86/av1_ssim_opt_x86_64.asm new file mode 100644 index 000000000..dcc697ba3 --- /dev/null +++ b/third_party/aom/av1/encoder/x86/av1_ssim_opt_x86_64.asm @@ -0,0 +1,219 @@ +; +; Copyright (c) 2016, Alliance for Open Media. All rights reserved +; +; This source code is subject to the terms of the BSD 2 Clause License and +; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +; was not distributed with this source code in the LICENSE file, you can +; obtain it at www.aomedia.org/license/software. If the Alliance for Open +; Media Patent License 1.0 was not distributed with this source code in the +; PATENTS file, you can obtain it at www.aomedia.org/license/patent. +; + +; + +%include "aom_ports/x86_abi_support.asm" + +; tabulate_ssim - sums sum_s,sum_r,sum_sq_s,sum_sq_r, sum_sxr +%macro TABULATE_SSIM 0 + paddusw xmm15, xmm3 ; sum_s + paddusw xmm14, xmm4 ; sum_r + movdqa xmm1, xmm3 + pmaddwd xmm1, xmm1 + paddd xmm13, xmm1 ; sum_sq_s + movdqa xmm2, xmm4 + pmaddwd xmm2, xmm2 + paddd xmm12, xmm2 ; sum_sq_r + pmaddwd xmm3, xmm4 + paddd xmm11, xmm3 ; sum_sxr +%endmacro + +; Sum across the register %1 starting with q words +%macro SUM_ACROSS_Q 1 + movdqa xmm2,%1 + punpckldq %1,xmm0 + punpckhdq xmm2,xmm0 + paddq %1,xmm2 + movdqa xmm2,%1 + punpcklqdq %1,xmm0 + punpckhqdq xmm2,xmm0 + paddq %1,xmm2 +%endmacro + +; Sum across the register %1 starting with q words +%macro SUM_ACROSS_W 1 + movdqa xmm1, %1 + punpcklwd %1,xmm0 + punpckhwd xmm1,xmm0 + paddd %1, xmm1 + SUM_ACROSS_Q %1 +%endmacro +;void ssim_parms_sse2( +; unsigned char *s, +; int sp, +; unsigned char *r, +; int rp +; unsigned long *sum_s, +; unsigned long *sum_r, +; unsigned long *sum_sq_s, +; unsigned long *sum_sq_r, +; unsigned long *sum_sxr); +; +; TODO: Use parm passing through structure, probably don't need the pxors +; ( calling app will initialize to 0 ) could easily fit everything in sse2 +; without too much hastle, and can probably do better estimates with psadw +; or pavgb At this point this is just meant to be first pass for calculating +; all the parms needed for 16x16 ssim so we can play with dssim as distortion +; in mode selection code. +global sym(av1_ssim_parms_16x16_sse2) PRIVATE +sym(av1_ssim_parms_16x16_sse2): + push rbp + mov rbp, rsp + SHADOW_ARGS_TO_STACK 9 + SAVE_XMM 15 + push rsi + push rdi + ; end prolog + + mov rsi, arg(0) ;s + mov rcx, arg(1) ;sp + mov rdi, arg(2) ;r + mov rax, arg(3) ;rp + + pxor xmm0, xmm0 + pxor xmm15,xmm15 ;sum_s + pxor xmm14,xmm14 ;sum_r + pxor xmm13,xmm13 ;sum_sq_s + pxor xmm12,xmm12 ;sum_sq_r + pxor xmm11,xmm11 ;sum_sxr + + mov rdx, 16 ;row counter +.NextRow: + + ;grab source and reference pixels + movdqu xmm5, [rsi] + movdqu xmm6, [rdi] + movdqa xmm3, xmm5 + movdqa xmm4, xmm6 + punpckhbw xmm3, xmm0 ; high_s + punpckhbw xmm4, xmm0 ; high_r + + TABULATE_SSIM + + movdqa xmm3, xmm5 + movdqa xmm4, xmm6 + punpcklbw xmm3, xmm0 ; low_s + punpcklbw xmm4, xmm0 ; low_r + + TABULATE_SSIM + + add rsi, rcx ; next s row + add rdi, rax ; next r row + + dec rdx ; counter + jnz .NextRow + + SUM_ACROSS_W xmm15 + SUM_ACROSS_W xmm14 + SUM_ACROSS_Q xmm13 + SUM_ACROSS_Q xmm12 + SUM_ACROSS_Q xmm11 + + mov rdi,arg(4) + movd [rdi], xmm15; + mov rdi,arg(5) + movd [rdi], xmm14; + mov rdi,arg(6) + movd [rdi], xmm13; + mov rdi,arg(7) + movd [rdi], xmm12; + mov rdi,arg(8) + movd [rdi], xmm11; + + ; begin epilog + pop rdi + pop rsi + RESTORE_XMM + UNSHADOW_ARGS + pop rbp + ret + +;void ssim_parms_sse2( +; unsigned char *s, +; int sp, +; unsigned char *r, +; int rp +; unsigned long *sum_s, +; unsigned long *sum_r, +; unsigned long *sum_sq_s, +; unsigned long *sum_sq_r, +; unsigned long *sum_sxr); +; +; TODO: Use parm passing through structure, probably don't need the pxors +; ( calling app will initialize to 0 ) could easily fit everything in sse2 +; without too much hastle, and can probably do better estimates with psadw +; or pavgb At this point this is just meant to be first pass for calculating +; all the parms needed for 16x16 ssim so we can play with dssim as distortion +; in mode selection code. +global sym(av1_ssim_parms_8x8_sse2) PRIVATE +sym(av1_ssim_parms_8x8_sse2): + push rbp + mov rbp, rsp + SHADOW_ARGS_TO_STACK 9 + SAVE_XMM 15 + push rsi + push rdi + ; end prolog + + mov rsi, arg(0) ;s + mov rcx, arg(1) ;sp + mov rdi, arg(2) ;r + mov rax, arg(3) ;rp + + pxor xmm0, xmm0 + pxor xmm15,xmm15 ;sum_s + pxor xmm14,xmm14 ;sum_r + pxor xmm13,xmm13 ;sum_sq_s + pxor xmm12,xmm12 ;sum_sq_r + pxor xmm11,xmm11 ;sum_sxr + + mov rdx, 8 ;row counter +.NextRow: + + ;grab source and reference pixels + movq xmm3, [rsi] + movq xmm4, [rdi] + punpcklbw xmm3, xmm0 ; low_s + punpcklbw xmm4, xmm0 ; low_r + + TABULATE_SSIM + + add rsi, rcx ; next s row + add rdi, rax ; next r row + + dec rdx ; counter + jnz .NextRow + + SUM_ACROSS_W xmm15 + SUM_ACROSS_W xmm14 + SUM_ACROSS_Q xmm13 + SUM_ACROSS_Q xmm12 + SUM_ACROSS_Q xmm11 + + mov rdi,arg(4) + movd [rdi], xmm15; + mov rdi,arg(5) + movd [rdi], xmm14; + mov rdi,arg(6) + movd [rdi], xmm13; + mov rdi,arg(7) + movd [rdi], xmm12; + mov rdi,arg(8) + movd [rdi], xmm11; + + ; begin epilog + pop rdi + pop rsi + RESTORE_XMM + UNSHADOW_ARGS + pop rbp + ret diff --git a/third_party/aom/av1/encoder/x86/dct_intrin_sse2.c b/third_party/aom/av1/encoder/x86/dct_intrin_sse2.c new file mode 100644 index 000000000..37c4b0d88 --- /dev/null +++ b/third_party/aom/av1/encoder/x86/dct_intrin_sse2.c @@ -0,0 +1,3884 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <assert.h> +#include <emmintrin.h> // SSE2 + +#include "./aom_dsp_rtcd.h" +#include "./av1_rtcd.h" +#include "aom_dsp/txfm_common.h" +#include "aom_dsp/x86/fwd_txfm_sse2.h" +#include "aom_dsp/x86/synonyms.h" +#include "aom_dsp/x86/txfm_common_sse2.h" +#include "aom_ports/mem.h" + +static INLINE void load_buffer_4x4(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr) { + const __m128i k__nonzero_bias_a = _mm_setr_epi16(0, 1, 1, 1, 1, 1, 1, 1); + const __m128i k__nonzero_bias_b = _mm_setr_epi16(1, 0, 0, 0, 0, 0, 0, 0); + __m128i mask; + + if (!flipud) { + in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + in[2] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride)); + in[3] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride)); + } else { + in[0] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride)); + in[1] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride)); + in[2] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + in[3] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + } + + if (fliplr) { + in[0] = _mm_shufflelo_epi16(in[0], 0x1b); + in[1] = _mm_shufflelo_epi16(in[1], 0x1b); + in[2] = _mm_shufflelo_epi16(in[2], 0x1b); + in[3] = _mm_shufflelo_epi16(in[3], 0x1b); + } + + in[0] = _mm_slli_epi16(in[0], 4); + in[1] = _mm_slli_epi16(in[1], 4); + in[2] = _mm_slli_epi16(in[2], 4); + in[3] = _mm_slli_epi16(in[3], 4); + + mask = _mm_cmpeq_epi16(in[0], k__nonzero_bias_a); + in[0] = _mm_add_epi16(in[0], mask); + in[0] = _mm_add_epi16(in[0], k__nonzero_bias_b); +} + +static INLINE void write_buffer_4x4(tran_low_t *output, __m128i *res) { + const __m128i kOne = _mm_set1_epi16(1); + __m128i in01 = _mm_unpacklo_epi64(res[0], res[1]); + __m128i in23 = _mm_unpacklo_epi64(res[2], res[3]); + __m128i out01 = _mm_add_epi16(in01, kOne); + __m128i out23 = _mm_add_epi16(in23, kOne); + out01 = _mm_srai_epi16(out01, 2); + out23 = _mm_srai_epi16(out23, 2); + store_output(&out01, (output + 0 * 8)); + store_output(&out23, (output + 1 * 8)); +} + +static INLINE void transpose_4x4(__m128i *res) { + // Combine and transpose + // 00 01 02 03 20 21 22 23 + // 10 11 12 13 30 31 32 33 + const __m128i tr0_0 = _mm_unpacklo_epi16(res[0], res[1]); + const __m128i tr0_1 = _mm_unpackhi_epi16(res[0], res[1]); + + // 00 10 01 11 02 12 03 13 + // 20 30 21 31 22 32 23 33 + res[0] = _mm_unpacklo_epi32(tr0_0, tr0_1); + res[2] = _mm_unpackhi_epi32(tr0_0, tr0_1); + + // 00 10 20 30 01 11 21 31 + // 02 12 22 32 03 13 23 33 + // only use the first 4 16-bit integers + res[1] = _mm_unpackhi_epi64(res[0], res[0]); + res[3] = _mm_unpackhi_epi64(res[2], res[2]); +} + +static void fdct4_sse2(__m128i *in) { + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_p08_p24 = pair_set_epi16(cospi_8_64, cospi_24_64); + const __m128i k__cospi_p24_m08 = pair_set_epi16(cospi_24_64, -cospi_8_64); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + + __m128i u[4], v[4]; + u[0] = _mm_unpacklo_epi16(in[0], in[1]); + u[1] = _mm_unpacklo_epi16(in[3], in[2]); + + v[0] = _mm_add_epi16(u[0], u[1]); + v[1] = _mm_sub_epi16(u[0], u[1]); + + u[0] = _mm_madd_epi16(v[0], k__cospi_p16_p16); // 0 + u[1] = _mm_madd_epi16(v[0], k__cospi_p16_m16); // 2 + u[2] = _mm_madd_epi16(v[1], k__cospi_p08_p24); // 1 + u[3] = _mm_madd_epi16(v[1], k__cospi_p24_m08); // 3 + + v[0] = _mm_add_epi32(u[0], k__DCT_CONST_ROUNDING); + v[1] = _mm_add_epi32(u[1], k__DCT_CONST_ROUNDING); + v[2] = _mm_add_epi32(u[2], k__DCT_CONST_ROUNDING); + v[3] = _mm_add_epi32(u[3], k__DCT_CONST_ROUNDING); + u[0] = _mm_srai_epi32(v[0], DCT_CONST_BITS); + u[1] = _mm_srai_epi32(v[1], DCT_CONST_BITS); + u[2] = _mm_srai_epi32(v[2], DCT_CONST_BITS); + u[3] = _mm_srai_epi32(v[3], DCT_CONST_BITS); + + in[0] = _mm_packs_epi32(u[0], u[1]); + in[1] = _mm_packs_epi32(u[2], u[3]); + transpose_4x4(in); +} + +static void fadst4_sse2(__m128i *in) { + const __m128i k__sinpi_p01_p02 = pair_set_epi16(sinpi_1_9, sinpi_2_9); + const __m128i k__sinpi_p04_m01 = pair_set_epi16(sinpi_4_9, -sinpi_1_9); + const __m128i k__sinpi_p03_p04 = pair_set_epi16(sinpi_3_9, sinpi_4_9); + const __m128i k__sinpi_m03_p02 = pair_set_epi16(-sinpi_3_9, sinpi_2_9); + const __m128i k__sinpi_p03_p03 = _mm_set1_epi16((int16_t)sinpi_3_9); + const __m128i kZero = _mm_set1_epi16(0); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + __m128i u[8], v[8]; + __m128i in7 = _mm_add_epi16(in[0], in[1]); + + u[0] = _mm_unpacklo_epi16(in[0], in[1]); + u[1] = _mm_unpacklo_epi16(in[2], in[3]); + u[2] = _mm_unpacklo_epi16(in7, kZero); + u[3] = _mm_unpacklo_epi16(in[2], kZero); + u[4] = _mm_unpacklo_epi16(in[3], kZero); + + v[0] = _mm_madd_epi16(u[0], k__sinpi_p01_p02); // s0 + s2 + v[1] = _mm_madd_epi16(u[1], k__sinpi_p03_p04); // s4 + s5 + v[2] = _mm_madd_epi16(u[2], k__sinpi_p03_p03); // x1 + v[3] = _mm_madd_epi16(u[0], k__sinpi_p04_m01); // s1 - s3 + v[4] = _mm_madd_epi16(u[1], k__sinpi_m03_p02); // -s4 + s6 + v[5] = _mm_madd_epi16(u[3], k__sinpi_p03_p03); // s4 + v[6] = _mm_madd_epi16(u[4], k__sinpi_p03_p03); + + u[0] = _mm_add_epi32(v[0], v[1]); + u[1] = _mm_sub_epi32(v[2], v[6]); + u[2] = _mm_add_epi32(v[3], v[4]); + u[3] = _mm_sub_epi32(u[2], u[0]); + u[4] = _mm_slli_epi32(v[5], 2); + u[5] = _mm_sub_epi32(u[4], v[5]); + u[6] = _mm_add_epi32(u[3], u[5]); + + v[0] = _mm_add_epi32(u[0], k__DCT_CONST_ROUNDING); + v[1] = _mm_add_epi32(u[1], k__DCT_CONST_ROUNDING); + v[2] = _mm_add_epi32(u[2], k__DCT_CONST_ROUNDING); + v[3] = _mm_add_epi32(u[6], k__DCT_CONST_ROUNDING); + + u[0] = _mm_srai_epi32(v[0], DCT_CONST_BITS); + u[1] = _mm_srai_epi32(v[1], DCT_CONST_BITS); + u[2] = _mm_srai_epi32(v[2], DCT_CONST_BITS); + u[3] = _mm_srai_epi32(v[3], DCT_CONST_BITS); + + in[0] = _mm_packs_epi32(u[0], u[2]); + in[1] = _mm_packs_epi32(u[1], u[3]); + transpose_4x4(in); +} + +#if CONFIG_EXT_TX +static void fidtx4_sse2(__m128i *in) { + const __m128i k__zero_epi16 = _mm_set1_epi16((int16_t)0); + const __m128i k__sqrt2_epi16 = _mm_set1_epi16((int16_t)Sqrt2); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + + __m128i v0, v1, v2, v3; + __m128i u0, u1, u2, u3; + + v0 = _mm_unpacklo_epi16(in[0], k__zero_epi16); + v1 = _mm_unpacklo_epi16(in[1], k__zero_epi16); + v2 = _mm_unpacklo_epi16(in[2], k__zero_epi16); + v3 = _mm_unpacklo_epi16(in[3], k__zero_epi16); + + u0 = _mm_madd_epi16(v0, k__sqrt2_epi16); + u1 = _mm_madd_epi16(v1, k__sqrt2_epi16); + u2 = _mm_madd_epi16(v2, k__sqrt2_epi16); + u3 = _mm_madd_epi16(v3, k__sqrt2_epi16); + + v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + + u0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + u1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + u2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + u3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + + in[0] = _mm_packs_epi32(u0, u2); + in[1] = _mm_packs_epi32(u1, u3); + transpose_4x4(in); +} +#endif // CONFIG_EXT_TX + +void av1_fht4x4_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in[4]; + + switch (tx_type) { + case DCT_DCT: aom_fdct4x4_sse2(input, output, stride); break; + case ADST_DCT: + load_buffer_4x4(input, in, stride, 0, 0); + fadst4_sse2(in); + fdct4_sse2(in); + write_buffer_4x4(output, in); + break; + case DCT_ADST: + load_buffer_4x4(input, in, stride, 0, 0); + fdct4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; + case ADST_ADST: + load_buffer_4x4(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_4x4(input, in, stride, 1, 0); + fadst4_sse2(in); + fdct4_sse2(in); + write_buffer_4x4(output, in); + break; + case DCT_FLIPADST: + load_buffer_4x4(input, in, stride, 0, 1); + fdct4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; + case FLIPADST_FLIPADST: + load_buffer_4x4(input, in, stride, 1, 1); + fadst4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; + case ADST_FLIPADST: + load_buffer_4x4(input, in, stride, 0, 1); + fadst4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; + case FLIPADST_ADST: + load_buffer_4x4(input, in, stride, 1, 0); + fadst4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; + case IDTX: + load_buffer_4x4(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in); + write_buffer_4x4(output, in); + break; + case V_DCT: + load_buffer_4x4(input, in, stride, 0, 0); + fdct4_sse2(in); + fidtx4_sse2(in); + write_buffer_4x4(output, in); + break; + case H_DCT: + load_buffer_4x4(input, in, stride, 0, 0); + fidtx4_sse2(in); + fdct4_sse2(in); + write_buffer_4x4(output, in); + break; + case V_ADST: + load_buffer_4x4(input, in, stride, 0, 0); + fadst4_sse2(in); + fidtx4_sse2(in); + write_buffer_4x4(output, in); + break; + case H_ADST: + load_buffer_4x4(input, in, stride, 0, 0); + fidtx4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; + case V_FLIPADST: + load_buffer_4x4(input, in, stride, 1, 0); + fadst4_sse2(in); + fidtx4_sse2(in); + write_buffer_4x4(output, in); + break; + case H_FLIPADST: + load_buffer_4x4(input, in, stride, 0, 1); + fidtx4_sse2(in); + fadst4_sse2(in); + write_buffer_4x4(output, in); + break; +#endif // CONFIG_EXT_TX + default: assert(0); + } +} + +void av1_fdct8x8_quant_sse2(const int16_t *input, int stride, + int16_t *coeff_ptr, intptr_t n_coeffs, + int skip_block, const int16_t *zbin_ptr, + const int16_t *round_ptr, const int16_t *quant_ptr, + const int16_t *quant_shift_ptr, int16_t *qcoeff_ptr, + int16_t *dqcoeff_ptr, const int16_t *dequant_ptr, + uint16_t *eob_ptr, const int16_t *scan_ptr, + const int16_t *iscan_ptr) { + __m128i zero; + int pass; + // Constants + // When we use them, in one case, they are all the same. In all others + // it's a pair of them that we need to repeat four times. This is done + // by constructing the 32 bit constant corresponding to that pair. + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_p24_p08 = pair_set_epi16(cospi_24_64, cospi_8_64); + const __m128i k__cospi_m08_p24 = pair_set_epi16(-cospi_8_64, cospi_24_64); + const __m128i k__cospi_p28_p04 = pair_set_epi16(cospi_28_64, cospi_4_64); + const __m128i k__cospi_m04_p28 = pair_set_epi16(-cospi_4_64, cospi_28_64); + const __m128i k__cospi_p12_p20 = pair_set_epi16(cospi_12_64, cospi_20_64); + const __m128i k__cospi_m20_p12 = pair_set_epi16(-cospi_20_64, cospi_12_64); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + // Load input + __m128i in0 = _mm_load_si128((const __m128i *)(input + 0 * stride)); + __m128i in1 = _mm_load_si128((const __m128i *)(input + 1 * stride)); + __m128i in2 = _mm_load_si128((const __m128i *)(input + 2 * stride)); + __m128i in3 = _mm_load_si128((const __m128i *)(input + 3 * stride)); + __m128i in4 = _mm_load_si128((const __m128i *)(input + 4 * stride)); + __m128i in5 = _mm_load_si128((const __m128i *)(input + 5 * stride)); + __m128i in6 = _mm_load_si128((const __m128i *)(input + 6 * stride)); + __m128i in7 = _mm_load_si128((const __m128i *)(input + 7 * stride)); + __m128i *in[8]; + int index = 0; + + (void)scan_ptr; + (void)zbin_ptr; + (void)quant_shift_ptr; + (void)coeff_ptr; + + // Pre-condition input (shift by two) + in0 = _mm_slli_epi16(in0, 2); + in1 = _mm_slli_epi16(in1, 2); + in2 = _mm_slli_epi16(in2, 2); + in3 = _mm_slli_epi16(in3, 2); + in4 = _mm_slli_epi16(in4, 2); + in5 = _mm_slli_epi16(in5, 2); + in6 = _mm_slli_epi16(in6, 2); + in7 = _mm_slli_epi16(in7, 2); + + in[0] = &in0; + in[1] = &in1; + in[2] = &in2; + in[3] = &in3; + in[4] = &in4; + in[5] = &in5; + in[6] = &in6; + in[7] = &in7; + + // We do two passes, first the columns, then the rows. The results of the + // first pass are transposed so that the same column code can be reused. The + // results of the second pass are also transposed so that the rows (processed + // as columns) are put back in row positions. + for (pass = 0; pass < 2; pass++) { + // To store results of each pass before the transpose. + __m128i res0, res1, res2, res3, res4, res5, res6, res7; + // Add/subtract + const __m128i q0 = _mm_add_epi16(in0, in7); + const __m128i q1 = _mm_add_epi16(in1, in6); + const __m128i q2 = _mm_add_epi16(in2, in5); + const __m128i q3 = _mm_add_epi16(in3, in4); + const __m128i q4 = _mm_sub_epi16(in3, in4); + const __m128i q5 = _mm_sub_epi16(in2, in5); + const __m128i q6 = _mm_sub_epi16(in1, in6); + const __m128i q7 = _mm_sub_epi16(in0, in7); + // Work on first four results + { + // Add/subtract + const __m128i r0 = _mm_add_epi16(q0, q3); + const __m128i r1 = _mm_add_epi16(q1, q2); + const __m128i r2 = _mm_sub_epi16(q1, q2); + const __m128i r3 = _mm_sub_epi16(q0, q3); + // Interleave to do the multiply by constants which gets us into 32bits + const __m128i t0 = _mm_unpacklo_epi16(r0, r1); + const __m128i t1 = _mm_unpackhi_epi16(r0, r1); + const __m128i t2 = _mm_unpacklo_epi16(r2, r3); + const __m128i t3 = _mm_unpackhi_epi16(r2, r3); + const __m128i u0 = _mm_madd_epi16(t0, k__cospi_p16_p16); + const __m128i u1 = _mm_madd_epi16(t1, k__cospi_p16_p16); + const __m128i u2 = _mm_madd_epi16(t0, k__cospi_p16_m16); + const __m128i u3 = _mm_madd_epi16(t1, k__cospi_p16_m16); + const __m128i u4 = _mm_madd_epi16(t2, k__cospi_p24_p08); + const __m128i u5 = _mm_madd_epi16(t3, k__cospi_p24_p08); + const __m128i u6 = _mm_madd_epi16(t2, k__cospi_m08_p24); + const __m128i u7 = _mm_madd_epi16(t3, k__cospi_m08_p24); + // dct_const_round_shift + const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + const __m128i v4 = _mm_add_epi32(u4, k__DCT_CONST_ROUNDING); + const __m128i v5 = _mm_add_epi32(u5, k__DCT_CONST_ROUNDING); + const __m128i v6 = _mm_add_epi32(u6, k__DCT_CONST_ROUNDING); + const __m128i v7 = _mm_add_epi32(u7, k__DCT_CONST_ROUNDING); + const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + const __m128i w4 = _mm_srai_epi32(v4, DCT_CONST_BITS); + const __m128i w5 = _mm_srai_epi32(v5, DCT_CONST_BITS); + const __m128i w6 = _mm_srai_epi32(v6, DCT_CONST_BITS); + const __m128i w7 = _mm_srai_epi32(v7, DCT_CONST_BITS); + // Combine + res0 = _mm_packs_epi32(w0, w1); + res4 = _mm_packs_epi32(w2, w3); + res2 = _mm_packs_epi32(w4, w5); + res6 = _mm_packs_epi32(w6, w7); + } + // Work on next four results + { + // Interleave to do the multiply by constants which gets us into 32bits + const __m128i d0 = _mm_unpacklo_epi16(q6, q5); + const __m128i d1 = _mm_unpackhi_epi16(q6, q5); + const __m128i e0 = _mm_madd_epi16(d0, k__cospi_p16_m16); + const __m128i e1 = _mm_madd_epi16(d1, k__cospi_p16_m16); + const __m128i e2 = _mm_madd_epi16(d0, k__cospi_p16_p16); + const __m128i e3 = _mm_madd_epi16(d1, k__cospi_p16_p16); + // dct_const_round_shift + const __m128i f0 = _mm_add_epi32(e0, k__DCT_CONST_ROUNDING); + const __m128i f1 = _mm_add_epi32(e1, k__DCT_CONST_ROUNDING); + const __m128i f2 = _mm_add_epi32(e2, k__DCT_CONST_ROUNDING); + const __m128i f3 = _mm_add_epi32(e3, k__DCT_CONST_ROUNDING); + const __m128i s0 = _mm_srai_epi32(f0, DCT_CONST_BITS); + const __m128i s1 = _mm_srai_epi32(f1, DCT_CONST_BITS); + const __m128i s2 = _mm_srai_epi32(f2, DCT_CONST_BITS); + const __m128i s3 = _mm_srai_epi32(f3, DCT_CONST_BITS); + // Combine + const __m128i r0 = _mm_packs_epi32(s0, s1); + const __m128i r1 = _mm_packs_epi32(s2, s3); + // Add/subtract + const __m128i x0 = _mm_add_epi16(q4, r0); + const __m128i x1 = _mm_sub_epi16(q4, r0); + const __m128i x2 = _mm_sub_epi16(q7, r1); + const __m128i x3 = _mm_add_epi16(q7, r1); + // Interleave to do the multiply by constants which gets us into 32bits + const __m128i t0 = _mm_unpacklo_epi16(x0, x3); + const __m128i t1 = _mm_unpackhi_epi16(x0, x3); + const __m128i t2 = _mm_unpacklo_epi16(x1, x2); + const __m128i t3 = _mm_unpackhi_epi16(x1, x2); + const __m128i u0 = _mm_madd_epi16(t0, k__cospi_p28_p04); + const __m128i u1 = _mm_madd_epi16(t1, k__cospi_p28_p04); + const __m128i u2 = _mm_madd_epi16(t0, k__cospi_m04_p28); + const __m128i u3 = _mm_madd_epi16(t1, k__cospi_m04_p28); + const __m128i u4 = _mm_madd_epi16(t2, k__cospi_p12_p20); + const __m128i u5 = _mm_madd_epi16(t3, k__cospi_p12_p20); + const __m128i u6 = _mm_madd_epi16(t2, k__cospi_m20_p12); + const __m128i u7 = _mm_madd_epi16(t3, k__cospi_m20_p12); + // dct_const_round_shift + const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + const __m128i v4 = _mm_add_epi32(u4, k__DCT_CONST_ROUNDING); + const __m128i v5 = _mm_add_epi32(u5, k__DCT_CONST_ROUNDING); + const __m128i v6 = _mm_add_epi32(u6, k__DCT_CONST_ROUNDING); + const __m128i v7 = _mm_add_epi32(u7, k__DCT_CONST_ROUNDING); + const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + const __m128i w4 = _mm_srai_epi32(v4, DCT_CONST_BITS); + const __m128i w5 = _mm_srai_epi32(v5, DCT_CONST_BITS); + const __m128i w6 = _mm_srai_epi32(v6, DCT_CONST_BITS); + const __m128i w7 = _mm_srai_epi32(v7, DCT_CONST_BITS); + // Combine + res1 = _mm_packs_epi32(w0, w1); + res7 = _mm_packs_epi32(w2, w3); + res5 = _mm_packs_epi32(w4, w5); + res3 = _mm_packs_epi32(w6, w7); + } + // Transpose the 8x8. + { + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + // 40 41 42 43 44 45 46 47 + // 50 51 52 53 54 55 56 57 + // 60 61 62 63 64 65 66 67 + // 70 71 72 73 74 75 76 77 + const __m128i tr0_0 = _mm_unpacklo_epi16(res0, res1); + const __m128i tr0_1 = _mm_unpacklo_epi16(res2, res3); + const __m128i tr0_2 = _mm_unpackhi_epi16(res0, res1); + const __m128i tr0_3 = _mm_unpackhi_epi16(res2, res3); + const __m128i tr0_4 = _mm_unpacklo_epi16(res4, res5); + const __m128i tr0_5 = _mm_unpacklo_epi16(res6, res7); + const __m128i tr0_6 = _mm_unpackhi_epi16(res4, res5); + const __m128i tr0_7 = _mm_unpackhi_epi16(res6, res7); + // 00 10 01 11 02 12 03 13 + // 20 30 21 31 22 32 23 33 + // 04 14 05 15 06 16 07 17 + // 24 34 25 35 26 36 27 37 + // 40 50 41 51 42 52 43 53 + // 60 70 61 71 62 72 63 73 + // 54 54 55 55 56 56 57 57 + // 64 74 65 75 66 76 67 77 + const __m128i tr1_0 = _mm_unpacklo_epi32(tr0_0, tr0_1); + const __m128i tr1_1 = _mm_unpacklo_epi32(tr0_2, tr0_3); + const __m128i tr1_2 = _mm_unpackhi_epi32(tr0_0, tr0_1); + const __m128i tr1_3 = _mm_unpackhi_epi32(tr0_2, tr0_3); + const __m128i tr1_4 = _mm_unpacklo_epi32(tr0_4, tr0_5); + const __m128i tr1_5 = _mm_unpacklo_epi32(tr0_6, tr0_7); + const __m128i tr1_6 = _mm_unpackhi_epi32(tr0_4, tr0_5); + const __m128i tr1_7 = _mm_unpackhi_epi32(tr0_6, tr0_7); + // 00 10 20 30 01 11 21 31 + // 40 50 60 70 41 51 61 71 + // 02 12 22 32 03 13 23 33 + // 42 52 62 72 43 53 63 73 + // 04 14 24 34 05 15 21 36 + // 44 54 64 74 45 55 61 76 + // 06 16 26 36 07 17 27 37 + // 46 56 66 76 47 57 67 77 + in0 = _mm_unpacklo_epi64(tr1_0, tr1_4); + in1 = _mm_unpackhi_epi64(tr1_0, tr1_4); + in2 = _mm_unpacklo_epi64(tr1_2, tr1_6); + in3 = _mm_unpackhi_epi64(tr1_2, tr1_6); + in4 = _mm_unpacklo_epi64(tr1_1, tr1_5); + in5 = _mm_unpackhi_epi64(tr1_1, tr1_5); + in6 = _mm_unpacklo_epi64(tr1_3, tr1_7); + in7 = _mm_unpackhi_epi64(tr1_3, tr1_7); + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 + } + } + // Post-condition output and store it + { + // Post-condition (division by two) + // division of two 16 bits signed numbers using shifts + // n / 2 = (n - (n >> 15)) >> 1 + const __m128i sign_in0 = _mm_srai_epi16(in0, 15); + const __m128i sign_in1 = _mm_srai_epi16(in1, 15); + const __m128i sign_in2 = _mm_srai_epi16(in2, 15); + const __m128i sign_in3 = _mm_srai_epi16(in3, 15); + const __m128i sign_in4 = _mm_srai_epi16(in4, 15); + const __m128i sign_in5 = _mm_srai_epi16(in5, 15); + const __m128i sign_in6 = _mm_srai_epi16(in6, 15); + const __m128i sign_in7 = _mm_srai_epi16(in7, 15); + in0 = _mm_sub_epi16(in0, sign_in0); + in1 = _mm_sub_epi16(in1, sign_in1); + in2 = _mm_sub_epi16(in2, sign_in2); + in3 = _mm_sub_epi16(in3, sign_in3); + in4 = _mm_sub_epi16(in4, sign_in4); + in5 = _mm_sub_epi16(in5, sign_in5); + in6 = _mm_sub_epi16(in6, sign_in6); + in7 = _mm_sub_epi16(in7, sign_in7); + in0 = _mm_srai_epi16(in0, 1); + in1 = _mm_srai_epi16(in1, 1); + in2 = _mm_srai_epi16(in2, 1); + in3 = _mm_srai_epi16(in3, 1); + in4 = _mm_srai_epi16(in4, 1); + in5 = _mm_srai_epi16(in5, 1); + in6 = _mm_srai_epi16(in6, 1); + in7 = _mm_srai_epi16(in7, 1); + } + + iscan_ptr += n_coeffs; + qcoeff_ptr += n_coeffs; + dqcoeff_ptr += n_coeffs; + n_coeffs = -n_coeffs; + zero = _mm_setzero_si128(); + + if (!skip_block) { + __m128i eob; + __m128i round, quant, dequant; + { + __m128i coeff0, coeff1; + + // Setup global values + { + round = _mm_load_si128((const __m128i *)round_ptr); + quant = _mm_load_si128((const __m128i *)quant_ptr); + dequant = _mm_load_si128((const __m128i *)dequant_ptr); + } + + { + __m128i coeff0_sign, coeff1_sign; + __m128i qcoeff0, qcoeff1; + __m128i qtmp0, qtmp1; + // Do DC and first 15 AC + coeff0 = *in[0]; + coeff1 = *in[1]; + + // Poor man's sign extract + coeff0_sign = _mm_srai_epi16(coeff0, 15); + coeff1_sign = _mm_srai_epi16(coeff1, 15); + qcoeff0 = _mm_xor_si128(coeff0, coeff0_sign); + qcoeff1 = _mm_xor_si128(coeff1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + qcoeff0 = _mm_adds_epi16(qcoeff0, round); + round = _mm_unpackhi_epi64(round, round); + qcoeff1 = _mm_adds_epi16(qcoeff1, round); + qtmp0 = _mm_mulhi_epi16(qcoeff0, quant); + quant = _mm_unpackhi_epi64(quant, quant); + qtmp1 = _mm_mulhi_epi16(qcoeff1, quant); + + // Reinsert signs + qcoeff0 = _mm_xor_si128(qtmp0, coeff0_sign); + qcoeff1 = _mm_xor_si128(qtmp1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), qcoeff0); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, qcoeff1); + + coeff0 = _mm_mullo_epi16(qcoeff0, dequant); + dequant = _mm_unpackhi_epi64(dequant, dequant); + coeff1 = _mm_mullo_epi16(qcoeff1, dequant); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), coeff0); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, coeff1); + } + + { + // Scan for eob + __m128i zero_coeff0, zero_coeff1; + __m128i nzero_coeff0, nzero_coeff1; + __m128i iscan0, iscan1; + __m128i eob1; + zero_coeff0 = _mm_cmpeq_epi16(coeff0, zero); + zero_coeff1 = _mm_cmpeq_epi16(coeff1, zero); + nzero_coeff0 = _mm_cmpeq_epi16(zero_coeff0, zero); + nzero_coeff1 = _mm_cmpeq_epi16(zero_coeff1, zero); + iscan0 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs)); + iscan1 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs) + 1); + // Add one to convert from indices to counts + iscan0 = _mm_sub_epi16(iscan0, nzero_coeff0); + iscan1 = _mm_sub_epi16(iscan1, nzero_coeff1); + eob = _mm_and_si128(iscan0, nzero_coeff0); + eob1 = _mm_and_si128(iscan1, nzero_coeff1); + eob = _mm_max_epi16(eob, eob1); + } + n_coeffs += 8 * 2; + } + + // AC only loop + index = 2; + while (n_coeffs < 0) { + __m128i coeff0, coeff1; + { + __m128i coeff0_sign, coeff1_sign; + __m128i qcoeff0, qcoeff1; + __m128i qtmp0, qtmp1; + + assert(index < (int)(sizeof(in) / sizeof(in[0])) - 1); + coeff0 = *in[index]; + coeff1 = *in[index + 1]; + + // Poor man's sign extract + coeff0_sign = _mm_srai_epi16(coeff0, 15); + coeff1_sign = _mm_srai_epi16(coeff1, 15); + qcoeff0 = _mm_xor_si128(coeff0, coeff0_sign); + qcoeff1 = _mm_xor_si128(coeff1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + qcoeff0 = _mm_adds_epi16(qcoeff0, round); + qcoeff1 = _mm_adds_epi16(qcoeff1, round); + qtmp0 = _mm_mulhi_epi16(qcoeff0, quant); + qtmp1 = _mm_mulhi_epi16(qcoeff1, quant); + + // Reinsert signs + qcoeff0 = _mm_xor_si128(qtmp0, coeff0_sign); + qcoeff1 = _mm_xor_si128(qtmp1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), qcoeff0); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, qcoeff1); + + coeff0 = _mm_mullo_epi16(qcoeff0, dequant); + coeff1 = _mm_mullo_epi16(qcoeff1, dequant); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), coeff0); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, coeff1); + } + + { + // Scan for eob + __m128i zero_coeff0, zero_coeff1; + __m128i nzero_coeff0, nzero_coeff1; + __m128i iscan0, iscan1; + __m128i eob0, eob1; + zero_coeff0 = _mm_cmpeq_epi16(coeff0, zero); + zero_coeff1 = _mm_cmpeq_epi16(coeff1, zero); + nzero_coeff0 = _mm_cmpeq_epi16(zero_coeff0, zero); + nzero_coeff1 = _mm_cmpeq_epi16(zero_coeff1, zero); + iscan0 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs)); + iscan1 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs) + 1); + // Add one to convert from indices to counts + iscan0 = _mm_sub_epi16(iscan0, nzero_coeff0); + iscan1 = _mm_sub_epi16(iscan1, nzero_coeff1); + eob0 = _mm_and_si128(iscan0, nzero_coeff0); + eob1 = _mm_and_si128(iscan1, nzero_coeff1); + eob0 = _mm_max_epi16(eob0, eob1); + eob = _mm_max_epi16(eob, eob0); + } + n_coeffs += 8 * 2; + index += 2; + } + + // Accumulate EOB + { + __m128i eob_shuffled; + eob_shuffled = _mm_shuffle_epi32(eob, 0xe); + eob = _mm_max_epi16(eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(eob, 0xe); + eob = _mm_max_epi16(eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(eob, 0x1); + eob = _mm_max_epi16(eob, eob_shuffled); + *eob_ptr = _mm_extract_epi16(eob, 1); + } + } else { + do { + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, zero); + n_coeffs += 8 * 2; + } while (n_coeffs < 0); + *eob_ptr = 0; + } +} + +// load 8x8 array +static INLINE void load_buffer_8x8(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr) { + if (!flipud) { + in[0] = _mm_load_si128((const __m128i *)(input + 0 * stride)); + in[1] = _mm_load_si128((const __m128i *)(input + 1 * stride)); + in[2] = _mm_load_si128((const __m128i *)(input + 2 * stride)); + in[3] = _mm_load_si128((const __m128i *)(input + 3 * stride)); + in[4] = _mm_load_si128((const __m128i *)(input + 4 * stride)); + in[5] = _mm_load_si128((const __m128i *)(input + 5 * stride)); + in[6] = _mm_load_si128((const __m128i *)(input + 6 * stride)); + in[7] = _mm_load_si128((const __m128i *)(input + 7 * stride)); + } else { + in[0] = _mm_load_si128((const __m128i *)(input + 7 * stride)); + in[1] = _mm_load_si128((const __m128i *)(input + 6 * stride)); + in[2] = _mm_load_si128((const __m128i *)(input + 5 * stride)); + in[3] = _mm_load_si128((const __m128i *)(input + 4 * stride)); + in[4] = _mm_load_si128((const __m128i *)(input + 3 * stride)); + in[5] = _mm_load_si128((const __m128i *)(input + 2 * stride)); + in[6] = _mm_load_si128((const __m128i *)(input + 1 * stride)); + in[7] = _mm_load_si128((const __m128i *)(input + 0 * stride)); + } + + if (fliplr) { + in[0] = mm_reverse_epi16(in[0]); + in[1] = mm_reverse_epi16(in[1]); + in[2] = mm_reverse_epi16(in[2]); + in[3] = mm_reverse_epi16(in[3]); + in[4] = mm_reverse_epi16(in[4]); + in[5] = mm_reverse_epi16(in[5]); + in[6] = mm_reverse_epi16(in[6]); + in[7] = mm_reverse_epi16(in[7]); + } + + in[0] = _mm_slli_epi16(in[0], 2); + in[1] = _mm_slli_epi16(in[1], 2); + in[2] = _mm_slli_epi16(in[2], 2); + in[3] = _mm_slli_epi16(in[3], 2); + in[4] = _mm_slli_epi16(in[4], 2); + in[5] = _mm_slli_epi16(in[5], 2); + in[6] = _mm_slli_epi16(in[6], 2); + in[7] = _mm_slli_epi16(in[7], 2); +} + +// right shift and rounding +static INLINE void right_shift_8x8(__m128i *res, const int bit) { + __m128i sign0 = _mm_srai_epi16(res[0], 15); + __m128i sign1 = _mm_srai_epi16(res[1], 15); + __m128i sign2 = _mm_srai_epi16(res[2], 15); + __m128i sign3 = _mm_srai_epi16(res[3], 15); + __m128i sign4 = _mm_srai_epi16(res[4], 15); + __m128i sign5 = _mm_srai_epi16(res[5], 15); + __m128i sign6 = _mm_srai_epi16(res[6], 15); + __m128i sign7 = _mm_srai_epi16(res[7], 15); + + if (bit == 2) { + const __m128i const_rounding = _mm_set1_epi16(1); + res[0] = _mm_adds_epi16(res[0], const_rounding); + res[1] = _mm_adds_epi16(res[1], const_rounding); + res[2] = _mm_adds_epi16(res[2], const_rounding); + res[3] = _mm_adds_epi16(res[3], const_rounding); + res[4] = _mm_adds_epi16(res[4], const_rounding); + res[5] = _mm_adds_epi16(res[5], const_rounding); + res[6] = _mm_adds_epi16(res[6], const_rounding); + res[7] = _mm_adds_epi16(res[7], const_rounding); + } + + res[0] = _mm_sub_epi16(res[0], sign0); + res[1] = _mm_sub_epi16(res[1], sign1); + res[2] = _mm_sub_epi16(res[2], sign2); + res[3] = _mm_sub_epi16(res[3], sign3); + res[4] = _mm_sub_epi16(res[4], sign4); + res[5] = _mm_sub_epi16(res[5], sign5); + res[6] = _mm_sub_epi16(res[6], sign6); + res[7] = _mm_sub_epi16(res[7], sign7); + + if (bit == 1) { + res[0] = _mm_srai_epi16(res[0], 1); + res[1] = _mm_srai_epi16(res[1], 1); + res[2] = _mm_srai_epi16(res[2], 1); + res[3] = _mm_srai_epi16(res[3], 1); + res[4] = _mm_srai_epi16(res[4], 1); + res[5] = _mm_srai_epi16(res[5], 1); + res[6] = _mm_srai_epi16(res[6], 1); + res[7] = _mm_srai_epi16(res[7], 1); + } else { + res[0] = _mm_srai_epi16(res[0], 2); + res[1] = _mm_srai_epi16(res[1], 2); + res[2] = _mm_srai_epi16(res[2], 2); + res[3] = _mm_srai_epi16(res[3], 2); + res[4] = _mm_srai_epi16(res[4], 2); + res[5] = _mm_srai_epi16(res[5], 2); + res[6] = _mm_srai_epi16(res[6], 2); + res[7] = _mm_srai_epi16(res[7], 2); + } +} + +// write 8x8 array +static INLINE void write_buffer_8x8(tran_low_t *output, __m128i *res, + int stride) { + store_output(&res[0], (output + 0 * stride)); + store_output(&res[1], (output + 1 * stride)); + store_output(&res[2], (output + 2 * stride)); + store_output(&res[3], (output + 3 * stride)); + store_output(&res[4], (output + 4 * stride)); + store_output(&res[5], (output + 5 * stride)); + store_output(&res[6], (output + 6 * stride)); + store_output(&res[7], (output + 7 * stride)); +} + +// perform in-place transpose +static INLINE void array_transpose_8x8(__m128i *in, __m128i *res) { + const __m128i tr0_0 = _mm_unpacklo_epi16(in[0], in[1]); + const __m128i tr0_1 = _mm_unpacklo_epi16(in[2], in[3]); + const __m128i tr0_2 = _mm_unpackhi_epi16(in[0], in[1]); + const __m128i tr0_3 = _mm_unpackhi_epi16(in[2], in[3]); + const __m128i tr0_4 = _mm_unpacklo_epi16(in[4], in[5]); + const __m128i tr0_5 = _mm_unpacklo_epi16(in[6], in[7]); + const __m128i tr0_6 = _mm_unpackhi_epi16(in[4], in[5]); + const __m128i tr0_7 = _mm_unpackhi_epi16(in[6], in[7]); + // 00 10 01 11 02 12 03 13 + // 20 30 21 31 22 32 23 33 + // 04 14 05 15 06 16 07 17 + // 24 34 25 35 26 36 27 37 + // 40 50 41 51 42 52 43 53 + // 60 70 61 71 62 72 63 73 + // 44 54 45 55 46 56 47 57 + // 64 74 65 75 66 76 67 77 + const __m128i tr1_0 = _mm_unpacklo_epi32(tr0_0, tr0_1); + const __m128i tr1_1 = _mm_unpacklo_epi32(tr0_4, tr0_5); + const __m128i tr1_2 = _mm_unpackhi_epi32(tr0_0, tr0_1); + const __m128i tr1_3 = _mm_unpackhi_epi32(tr0_4, tr0_5); + const __m128i tr1_4 = _mm_unpacklo_epi32(tr0_2, tr0_3); + const __m128i tr1_5 = _mm_unpacklo_epi32(tr0_6, tr0_7); + const __m128i tr1_6 = _mm_unpackhi_epi32(tr0_2, tr0_3); + const __m128i tr1_7 = _mm_unpackhi_epi32(tr0_6, tr0_7); + // 00 10 20 30 01 11 21 31 + // 40 50 60 70 41 51 61 71 + // 02 12 22 32 03 13 23 33 + // 42 52 62 72 43 53 63 73 + // 04 14 24 34 05 15 25 35 + // 44 54 64 74 45 55 65 75 + // 06 16 26 36 07 17 27 37 + // 46 56 66 76 47 57 67 77 + res[0] = _mm_unpacklo_epi64(tr1_0, tr1_1); + res[1] = _mm_unpackhi_epi64(tr1_0, tr1_1); + res[2] = _mm_unpacklo_epi64(tr1_2, tr1_3); + res[3] = _mm_unpackhi_epi64(tr1_2, tr1_3); + res[4] = _mm_unpacklo_epi64(tr1_4, tr1_5); + res[5] = _mm_unpackhi_epi64(tr1_4, tr1_5); + res[6] = _mm_unpacklo_epi64(tr1_6, tr1_7); + res[7] = _mm_unpackhi_epi64(tr1_6, tr1_7); + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 +} + +static void fdct8_sse2(__m128i *in) { + // constants + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_p24_p08 = pair_set_epi16(cospi_24_64, cospi_8_64); + const __m128i k__cospi_m08_p24 = pair_set_epi16(-cospi_8_64, cospi_24_64); + const __m128i k__cospi_p28_p04 = pair_set_epi16(cospi_28_64, cospi_4_64); + const __m128i k__cospi_m04_p28 = pair_set_epi16(-cospi_4_64, cospi_28_64); + const __m128i k__cospi_p12_p20 = pair_set_epi16(cospi_12_64, cospi_20_64); + const __m128i k__cospi_m20_p12 = pair_set_epi16(-cospi_20_64, cospi_12_64); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + __m128i u0, u1, u2, u3, u4, u5, u6, u7; + __m128i v0, v1, v2, v3, v4, v5, v6, v7; + __m128i s0, s1, s2, s3, s4, s5, s6, s7; + + // stage 1 + s0 = _mm_add_epi16(in[0], in[7]); + s1 = _mm_add_epi16(in[1], in[6]); + s2 = _mm_add_epi16(in[2], in[5]); + s3 = _mm_add_epi16(in[3], in[4]); + s4 = _mm_sub_epi16(in[3], in[4]); + s5 = _mm_sub_epi16(in[2], in[5]); + s6 = _mm_sub_epi16(in[1], in[6]); + s7 = _mm_sub_epi16(in[0], in[7]); + + u0 = _mm_add_epi16(s0, s3); + u1 = _mm_add_epi16(s1, s2); + u2 = _mm_sub_epi16(s1, s2); + u3 = _mm_sub_epi16(s0, s3); + // interleave and perform butterfly multiplication/addition + v0 = _mm_unpacklo_epi16(u0, u1); + v1 = _mm_unpackhi_epi16(u0, u1); + v2 = _mm_unpacklo_epi16(u2, u3); + v3 = _mm_unpackhi_epi16(u2, u3); + + u0 = _mm_madd_epi16(v0, k__cospi_p16_p16); + u1 = _mm_madd_epi16(v1, k__cospi_p16_p16); + u2 = _mm_madd_epi16(v0, k__cospi_p16_m16); + u3 = _mm_madd_epi16(v1, k__cospi_p16_m16); + u4 = _mm_madd_epi16(v2, k__cospi_p24_p08); + u5 = _mm_madd_epi16(v3, k__cospi_p24_p08); + u6 = _mm_madd_epi16(v2, k__cospi_m08_p24); + u7 = _mm_madd_epi16(v3, k__cospi_m08_p24); + + // shift and rounding + v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + v4 = _mm_add_epi32(u4, k__DCT_CONST_ROUNDING); + v5 = _mm_add_epi32(u5, k__DCT_CONST_ROUNDING); + v6 = _mm_add_epi32(u6, k__DCT_CONST_ROUNDING); + v7 = _mm_add_epi32(u7, k__DCT_CONST_ROUNDING); + + u0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + u1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + u2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + u3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + u4 = _mm_srai_epi32(v4, DCT_CONST_BITS); + u5 = _mm_srai_epi32(v5, DCT_CONST_BITS); + u6 = _mm_srai_epi32(v6, DCT_CONST_BITS); + u7 = _mm_srai_epi32(v7, DCT_CONST_BITS); + + in[0] = _mm_packs_epi32(u0, u1); + in[2] = _mm_packs_epi32(u4, u5); + in[4] = _mm_packs_epi32(u2, u3); + in[6] = _mm_packs_epi32(u6, u7); + + // stage 2 + // interleave and perform butterfly multiplication/addition + u0 = _mm_unpacklo_epi16(s6, s5); + u1 = _mm_unpackhi_epi16(s6, s5); + v0 = _mm_madd_epi16(u0, k__cospi_p16_m16); + v1 = _mm_madd_epi16(u1, k__cospi_p16_m16); + v2 = _mm_madd_epi16(u0, k__cospi_p16_p16); + v3 = _mm_madd_epi16(u1, k__cospi_p16_p16); + + // shift and rounding + u0 = _mm_add_epi32(v0, k__DCT_CONST_ROUNDING); + u1 = _mm_add_epi32(v1, k__DCT_CONST_ROUNDING); + u2 = _mm_add_epi32(v2, k__DCT_CONST_ROUNDING); + u3 = _mm_add_epi32(v3, k__DCT_CONST_ROUNDING); + + v0 = _mm_srai_epi32(u0, DCT_CONST_BITS); + v1 = _mm_srai_epi32(u1, DCT_CONST_BITS); + v2 = _mm_srai_epi32(u2, DCT_CONST_BITS); + v3 = _mm_srai_epi32(u3, DCT_CONST_BITS); + + u0 = _mm_packs_epi32(v0, v1); + u1 = _mm_packs_epi32(v2, v3); + + // stage 3 + s0 = _mm_add_epi16(s4, u0); + s1 = _mm_sub_epi16(s4, u0); + s2 = _mm_sub_epi16(s7, u1); + s3 = _mm_add_epi16(s7, u1); + + // stage 4 + u0 = _mm_unpacklo_epi16(s0, s3); + u1 = _mm_unpackhi_epi16(s0, s3); + u2 = _mm_unpacklo_epi16(s1, s2); + u3 = _mm_unpackhi_epi16(s1, s2); + + v0 = _mm_madd_epi16(u0, k__cospi_p28_p04); + v1 = _mm_madd_epi16(u1, k__cospi_p28_p04); + v2 = _mm_madd_epi16(u2, k__cospi_p12_p20); + v3 = _mm_madd_epi16(u3, k__cospi_p12_p20); + v4 = _mm_madd_epi16(u2, k__cospi_m20_p12); + v5 = _mm_madd_epi16(u3, k__cospi_m20_p12); + v6 = _mm_madd_epi16(u0, k__cospi_m04_p28); + v7 = _mm_madd_epi16(u1, k__cospi_m04_p28); + + // shift and rounding + u0 = _mm_add_epi32(v0, k__DCT_CONST_ROUNDING); + u1 = _mm_add_epi32(v1, k__DCT_CONST_ROUNDING); + u2 = _mm_add_epi32(v2, k__DCT_CONST_ROUNDING); + u3 = _mm_add_epi32(v3, k__DCT_CONST_ROUNDING); + u4 = _mm_add_epi32(v4, k__DCT_CONST_ROUNDING); + u5 = _mm_add_epi32(v5, k__DCT_CONST_ROUNDING); + u6 = _mm_add_epi32(v6, k__DCT_CONST_ROUNDING); + u7 = _mm_add_epi32(v7, k__DCT_CONST_ROUNDING); + + v0 = _mm_srai_epi32(u0, DCT_CONST_BITS); + v1 = _mm_srai_epi32(u1, DCT_CONST_BITS); + v2 = _mm_srai_epi32(u2, DCT_CONST_BITS); + v3 = _mm_srai_epi32(u3, DCT_CONST_BITS); + v4 = _mm_srai_epi32(u4, DCT_CONST_BITS); + v5 = _mm_srai_epi32(u5, DCT_CONST_BITS); + v6 = _mm_srai_epi32(u6, DCT_CONST_BITS); + v7 = _mm_srai_epi32(u7, DCT_CONST_BITS); + + in[1] = _mm_packs_epi32(v0, v1); + in[3] = _mm_packs_epi32(v4, v5); + in[5] = _mm_packs_epi32(v2, v3); + in[7] = _mm_packs_epi32(v6, v7); + + // transpose + array_transpose_8x8(in, in); +} + +static void fadst8_sse2(__m128i *in) { + // Constants + const __m128i k__cospi_p02_p30 = pair_set_epi16(cospi_2_64, cospi_30_64); + const __m128i k__cospi_p30_m02 = pair_set_epi16(cospi_30_64, -cospi_2_64); + const __m128i k__cospi_p10_p22 = pair_set_epi16(cospi_10_64, cospi_22_64); + const __m128i k__cospi_p22_m10 = pair_set_epi16(cospi_22_64, -cospi_10_64); + const __m128i k__cospi_p18_p14 = pair_set_epi16(cospi_18_64, cospi_14_64); + const __m128i k__cospi_p14_m18 = pair_set_epi16(cospi_14_64, -cospi_18_64); + const __m128i k__cospi_p26_p06 = pair_set_epi16(cospi_26_64, cospi_6_64); + const __m128i k__cospi_p06_m26 = pair_set_epi16(cospi_6_64, -cospi_26_64); + const __m128i k__cospi_p08_p24 = pair_set_epi16(cospi_8_64, cospi_24_64); + const __m128i k__cospi_p24_m08 = pair_set_epi16(cospi_24_64, -cospi_8_64); + const __m128i k__cospi_m24_p08 = pair_set_epi16(-cospi_24_64, cospi_8_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__const_0 = _mm_set1_epi16(0); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + + __m128i u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11, u12, u13, u14, u15; + __m128i v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15; + __m128i w0, w1, w2, w3, w4, w5, w6, w7, w8, w9, w10, w11, w12, w13, w14, w15; + __m128i s0, s1, s2, s3, s4, s5, s6, s7; + __m128i in0, in1, in2, in3, in4, in5, in6, in7; + + // properly aligned for butterfly input + in0 = in[7]; + in1 = in[0]; + in2 = in[5]; + in3 = in[2]; + in4 = in[3]; + in5 = in[4]; + in6 = in[1]; + in7 = in[6]; + + // column transformation + // stage 1 + // interleave and multiply/add into 32-bit integer + s0 = _mm_unpacklo_epi16(in0, in1); + s1 = _mm_unpackhi_epi16(in0, in1); + s2 = _mm_unpacklo_epi16(in2, in3); + s3 = _mm_unpackhi_epi16(in2, in3); + s4 = _mm_unpacklo_epi16(in4, in5); + s5 = _mm_unpackhi_epi16(in4, in5); + s6 = _mm_unpacklo_epi16(in6, in7); + s7 = _mm_unpackhi_epi16(in6, in7); + + u0 = _mm_madd_epi16(s0, k__cospi_p02_p30); + u1 = _mm_madd_epi16(s1, k__cospi_p02_p30); + u2 = _mm_madd_epi16(s0, k__cospi_p30_m02); + u3 = _mm_madd_epi16(s1, k__cospi_p30_m02); + u4 = _mm_madd_epi16(s2, k__cospi_p10_p22); + u5 = _mm_madd_epi16(s3, k__cospi_p10_p22); + u6 = _mm_madd_epi16(s2, k__cospi_p22_m10); + u7 = _mm_madd_epi16(s3, k__cospi_p22_m10); + u8 = _mm_madd_epi16(s4, k__cospi_p18_p14); + u9 = _mm_madd_epi16(s5, k__cospi_p18_p14); + u10 = _mm_madd_epi16(s4, k__cospi_p14_m18); + u11 = _mm_madd_epi16(s5, k__cospi_p14_m18); + u12 = _mm_madd_epi16(s6, k__cospi_p26_p06); + u13 = _mm_madd_epi16(s7, k__cospi_p26_p06); + u14 = _mm_madd_epi16(s6, k__cospi_p06_m26); + u15 = _mm_madd_epi16(s7, k__cospi_p06_m26); + + // addition + w0 = _mm_add_epi32(u0, u8); + w1 = _mm_add_epi32(u1, u9); + w2 = _mm_add_epi32(u2, u10); + w3 = _mm_add_epi32(u3, u11); + w4 = _mm_add_epi32(u4, u12); + w5 = _mm_add_epi32(u5, u13); + w6 = _mm_add_epi32(u6, u14); + w7 = _mm_add_epi32(u7, u15); + w8 = _mm_sub_epi32(u0, u8); + w9 = _mm_sub_epi32(u1, u9); + w10 = _mm_sub_epi32(u2, u10); + w11 = _mm_sub_epi32(u3, u11); + w12 = _mm_sub_epi32(u4, u12); + w13 = _mm_sub_epi32(u5, u13); + w14 = _mm_sub_epi32(u6, u14); + w15 = _mm_sub_epi32(u7, u15); + + // shift and rounding + v8 = _mm_add_epi32(w8, k__DCT_CONST_ROUNDING); + v9 = _mm_add_epi32(w9, k__DCT_CONST_ROUNDING); + v10 = _mm_add_epi32(w10, k__DCT_CONST_ROUNDING); + v11 = _mm_add_epi32(w11, k__DCT_CONST_ROUNDING); + v12 = _mm_add_epi32(w12, k__DCT_CONST_ROUNDING); + v13 = _mm_add_epi32(w13, k__DCT_CONST_ROUNDING); + v14 = _mm_add_epi32(w14, k__DCT_CONST_ROUNDING); + v15 = _mm_add_epi32(w15, k__DCT_CONST_ROUNDING); + + u8 = _mm_srai_epi32(v8, DCT_CONST_BITS); + u9 = _mm_srai_epi32(v9, DCT_CONST_BITS); + u10 = _mm_srai_epi32(v10, DCT_CONST_BITS); + u11 = _mm_srai_epi32(v11, DCT_CONST_BITS); + u12 = _mm_srai_epi32(v12, DCT_CONST_BITS); + u13 = _mm_srai_epi32(v13, DCT_CONST_BITS); + u14 = _mm_srai_epi32(v14, DCT_CONST_BITS); + u15 = _mm_srai_epi32(v15, DCT_CONST_BITS); + + // back to 16-bit and pack 8 integers into __m128i + v0 = _mm_add_epi32(w0, w4); + v1 = _mm_add_epi32(w1, w5); + v2 = _mm_add_epi32(w2, w6); + v3 = _mm_add_epi32(w3, w7); + v4 = _mm_sub_epi32(w0, w4); + v5 = _mm_sub_epi32(w1, w5); + v6 = _mm_sub_epi32(w2, w6); + v7 = _mm_sub_epi32(w3, w7); + + w0 = _mm_add_epi32(v0, k__DCT_CONST_ROUNDING); + w1 = _mm_add_epi32(v1, k__DCT_CONST_ROUNDING); + w2 = _mm_add_epi32(v2, k__DCT_CONST_ROUNDING); + w3 = _mm_add_epi32(v3, k__DCT_CONST_ROUNDING); + w4 = _mm_add_epi32(v4, k__DCT_CONST_ROUNDING); + w5 = _mm_add_epi32(v5, k__DCT_CONST_ROUNDING); + w6 = _mm_add_epi32(v6, k__DCT_CONST_ROUNDING); + w7 = _mm_add_epi32(v7, k__DCT_CONST_ROUNDING); + + v0 = _mm_srai_epi32(w0, DCT_CONST_BITS); + v1 = _mm_srai_epi32(w1, DCT_CONST_BITS); + v2 = _mm_srai_epi32(w2, DCT_CONST_BITS); + v3 = _mm_srai_epi32(w3, DCT_CONST_BITS); + v4 = _mm_srai_epi32(w4, DCT_CONST_BITS); + v5 = _mm_srai_epi32(w5, DCT_CONST_BITS); + v6 = _mm_srai_epi32(w6, DCT_CONST_BITS); + v7 = _mm_srai_epi32(w7, DCT_CONST_BITS); + + in[4] = _mm_packs_epi32(u8, u9); + in[5] = _mm_packs_epi32(u10, u11); + in[6] = _mm_packs_epi32(u12, u13); + in[7] = _mm_packs_epi32(u14, u15); + + // stage 2 + s0 = _mm_packs_epi32(v0, v1); + s1 = _mm_packs_epi32(v2, v3); + s2 = _mm_packs_epi32(v4, v5); + s3 = _mm_packs_epi32(v6, v7); + + u0 = _mm_unpacklo_epi16(in[4], in[5]); + u1 = _mm_unpackhi_epi16(in[4], in[5]); + u2 = _mm_unpacklo_epi16(in[6], in[7]); + u3 = _mm_unpackhi_epi16(in[6], in[7]); + + v0 = _mm_madd_epi16(u0, k__cospi_p08_p24); + v1 = _mm_madd_epi16(u1, k__cospi_p08_p24); + v2 = _mm_madd_epi16(u0, k__cospi_p24_m08); + v3 = _mm_madd_epi16(u1, k__cospi_p24_m08); + v4 = _mm_madd_epi16(u2, k__cospi_m24_p08); + v5 = _mm_madd_epi16(u3, k__cospi_m24_p08); + v6 = _mm_madd_epi16(u2, k__cospi_p08_p24); + v7 = _mm_madd_epi16(u3, k__cospi_p08_p24); + + w0 = _mm_add_epi32(v0, v4); + w1 = _mm_add_epi32(v1, v5); + w2 = _mm_add_epi32(v2, v6); + w3 = _mm_add_epi32(v3, v7); + w4 = _mm_sub_epi32(v0, v4); + w5 = _mm_sub_epi32(v1, v5); + w6 = _mm_sub_epi32(v2, v6); + w7 = _mm_sub_epi32(v3, v7); + + v0 = _mm_add_epi32(w0, k__DCT_CONST_ROUNDING); + v1 = _mm_add_epi32(w1, k__DCT_CONST_ROUNDING); + v2 = _mm_add_epi32(w2, k__DCT_CONST_ROUNDING); + v3 = _mm_add_epi32(w3, k__DCT_CONST_ROUNDING); + v4 = _mm_add_epi32(w4, k__DCT_CONST_ROUNDING); + v5 = _mm_add_epi32(w5, k__DCT_CONST_ROUNDING); + v6 = _mm_add_epi32(w6, k__DCT_CONST_ROUNDING); + v7 = _mm_add_epi32(w7, k__DCT_CONST_ROUNDING); + + u0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + u1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + u2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + u3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + u4 = _mm_srai_epi32(v4, DCT_CONST_BITS); + u5 = _mm_srai_epi32(v5, DCT_CONST_BITS); + u6 = _mm_srai_epi32(v6, DCT_CONST_BITS); + u7 = _mm_srai_epi32(v7, DCT_CONST_BITS); + + // back to 16-bit intergers + s4 = _mm_packs_epi32(u0, u1); + s5 = _mm_packs_epi32(u2, u3); + s6 = _mm_packs_epi32(u4, u5); + s7 = _mm_packs_epi32(u6, u7); + + // stage 3 + u0 = _mm_unpacklo_epi16(s2, s3); + u1 = _mm_unpackhi_epi16(s2, s3); + u2 = _mm_unpacklo_epi16(s6, s7); + u3 = _mm_unpackhi_epi16(s6, s7); + + v0 = _mm_madd_epi16(u0, k__cospi_p16_p16); + v1 = _mm_madd_epi16(u1, k__cospi_p16_p16); + v2 = _mm_madd_epi16(u0, k__cospi_p16_m16); + v3 = _mm_madd_epi16(u1, k__cospi_p16_m16); + v4 = _mm_madd_epi16(u2, k__cospi_p16_p16); + v5 = _mm_madd_epi16(u3, k__cospi_p16_p16); + v6 = _mm_madd_epi16(u2, k__cospi_p16_m16); + v7 = _mm_madd_epi16(u3, k__cospi_p16_m16); + + u0 = _mm_add_epi32(v0, k__DCT_CONST_ROUNDING); + u1 = _mm_add_epi32(v1, k__DCT_CONST_ROUNDING); + u2 = _mm_add_epi32(v2, k__DCT_CONST_ROUNDING); + u3 = _mm_add_epi32(v3, k__DCT_CONST_ROUNDING); + u4 = _mm_add_epi32(v4, k__DCT_CONST_ROUNDING); + u5 = _mm_add_epi32(v5, k__DCT_CONST_ROUNDING); + u6 = _mm_add_epi32(v6, k__DCT_CONST_ROUNDING); + u7 = _mm_add_epi32(v7, k__DCT_CONST_ROUNDING); + + v0 = _mm_srai_epi32(u0, DCT_CONST_BITS); + v1 = _mm_srai_epi32(u1, DCT_CONST_BITS); + v2 = _mm_srai_epi32(u2, DCT_CONST_BITS); + v3 = _mm_srai_epi32(u3, DCT_CONST_BITS); + v4 = _mm_srai_epi32(u4, DCT_CONST_BITS); + v5 = _mm_srai_epi32(u5, DCT_CONST_BITS); + v6 = _mm_srai_epi32(u6, DCT_CONST_BITS); + v7 = _mm_srai_epi32(u7, DCT_CONST_BITS); + + s2 = _mm_packs_epi32(v0, v1); + s3 = _mm_packs_epi32(v2, v3); + s6 = _mm_packs_epi32(v4, v5); + s7 = _mm_packs_epi32(v6, v7); + + // FIXME(jingning): do subtract using bit inversion? + in[0] = s0; + in[1] = _mm_sub_epi16(k__const_0, s4); + in[2] = s6; + in[3] = _mm_sub_epi16(k__const_0, s2); + in[4] = s3; + in[5] = _mm_sub_epi16(k__const_0, s7); + in[6] = s5; + in[7] = _mm_sub_epi16(k__const_0, s1); + + // transpose + array_transpose_8x8(in, in); +} + +#if CONFIG_EXT_TX +static void fidtx8_sse2(__m128i *in) { + in[0] = _mm_slli_epi16(in[0], 1); + in[1] = _mm_slli_epi16(in[1], 1); + in[2] = _mm_slli_epi16(in[2], 1); + in[3] = _mm_slli_epi16(in[3], 1); + in[4] = _mm_slli_epi16(in[4], 1); + in[5] = _mm_slli_epi16(in[5], 1); + in[6] = _mm_slli_epi16(in[6], 1); + in[7] = _mm_slli_epi16(in[7], 1); + + array_transpose_8x8(in, in); +} +#endif // CONFIG_EXT_TX + +void av1_fht8x8_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in[8]; + + switch (tx_type) { + case DCT_DCT: aom_fdct8x8_sse2(input, output, stride); break; + case ADST_DCT: + load_buffer_8x8(input, in, stride, 0, 0); + fadst8_sse2(in); + fdct8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case DCT_ADST: + load_buffer_8x8(input, in, stride, 0, 0); + fdct8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case ADST_ADST: + load_buffer_8x8(input, in, stride, 0, 0); + fadst8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_8x8(input, in, stride, 1, 0); + fadst8_sse2(in); + fdct8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case DCT_FLIPADST: + load_buffer_8x8(input, in, stride, 0, 1); + fdct8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case FLIPADST_FLIPADST: + load_buffer_8x8(input, in, stride, 1, 1); + fadst8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case ADST_FLIPADST: + load_buffer_8x8(input, in, stride, 0, 1); + fadst8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case FLIPADST_ADST: + load_buffer_8x8(input, in, stride, 1, 0); + fadst8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case IDTX: + load_buffer_8x8(input, in, stride, 0, 0); + fidtx8_sse2(in); + fidtx8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case V_DCT: + load_buffer_8x8(input, in, stride, 0, 0); + fdct8_sse2(in); + fidtx8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case H_DCT: + load_buffer_8x8(input, in, stride, 0, 0); + fidtx8_sse2(in); + fdct8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case V_ADST: + load_buffer_8x8(input, in, stride, 0, 0); + fadst8_sse2(in); + fidtx8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case H_ADST: + load_buffer_8x8(input, in, stride, 0, 0); + fidtx8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case V_FLIPADST: + load_buffer_8x8(input, in, stride, 1, 0); + fadst8_sse2(in); + fidtx8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; + case H_FLIPADST: + load_buffer_8x8(input, in, stride, 0, 1); + fidtx8_sse2(in); + fadst8_sse2(in); + right_shift_8x8(in, 1); + write_buffer_8x8(output, in, 8); + break; +#endif // CONFIG_EXT_TX + default: assert(0); + } +} + +static INLINE void load_buffer_16x16(const int16_t *input, __m128i *in0, + __m128i *in1, int stride, int flipud, + int fliplr) { + // Load 4 8x8 blocks + const int16_t *topL = input; + const int16_t *topR = input + 8; + const int16_t *botL = input + 8 * stride; + const int16_t *botR = input + 8 * stride + 8; + + const int16_t *tmp; + + if (flipud) { + // Swap left columns + tmp = topL; + topL = botL; + botL = tmp; + // Swap right columns + tmp = topR; + topR = botR; + botR = tmp; + } + + if (fliplr) { + // Swap top rows + tmp = topL; + topL = topR; + topR = tmp; + // Swap bottom rows + tmp = botL; + botL = botR; + botR = tmp; + } + + // load first 8 columns + load_buffer_8x8(topL, in0, stride, flipud, fliplr); + load_buffer_8x8(botL, in0 + 8, stride, flipud, fliplr); + + // load second 8 columns + load_buffer_8x8(topR, in1, stride, flipud, fliplr); + load_buffer_8x8(botR, in1 + 8, stride, flipud, fliplr); +} + +static INLINE void write_buffer_16x16(tran_low_t *output, __m128i *in0, + __m128i *in1, int stride) { + // write first 8 columns + write_buffer_8x8(output, in0, stride); + write_buffer_8x8(output + 8 * stride, in0 + 8, stride); + // write second 8 columns + output += 8; + write_buffer_8x8(output, in1, stride); + write_buffer_8x8(output + 8 * stride, in1 + 8, stride); +} + +static INLINE void array_transpose_16x16(__m128i *res0, __m128i *res1) { + __m128i tbuf[8]; + array_transpose_8x8(res0, res0); + array_transpose_8x8(res1, tbuf); + array_transpose_8x8(res0 + 8, res1); + array_transpose_8x8(res1 + 8, res1 + 8); + + res0[8] = tbuf[0]; + res0[9] = tbuf[1]; + res0[10] = tbuf[2]; + res0[11] = tbuf[3]; + res0[12] = tbuf[4]; + res0[13] = tbuf[5]; + res0[14] = tbuf[6]; + res0[15] = tbuf[7]; +} + +static INLINE void right_shift_16x16(__m128i *res0, __m128i *res1) { + // perform rounding operations + right_shift_8x8(res0, 2); + right_shift_8x8(res0 + 8, 2); + right_shift_8x8(res1, 2); + right_shift_8x8(res1 + 8, 2); +} + +static void fdct16_8col(__m128i *in) { + // perform 16x16 1-D DCT for 8 columns + __m128i i[8], s[8], p[8], t[8], u[16], v[16]; + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_m16_p16 = pair_set_epi16(-cospi_16_64, cospi_16_64); + const __m128i k__cospi_p24_p08 = pair_set_epi16(cospi_24_64, cospi_8_64); + const __m128i k__cospi_m24_m08 = pair_set_epi16(-cospi_24_64, -cospi_8_64); + const __m128i k__cospi_m08_p24 = pair_set_epi16(-cospi_8_64, cospi_24_64); + const __m128i k__cospi_p28_p04 = pair_set_epi16(cospi_28_64, cospi_4_64); + const __m128i k__cospi_m04_p28 = pair_set_epi16(-cospi_4_64, cospi_28_64); + const __m128i k__cospi_p12_p20 = pair_set_epi16(cospi_12_64, cospi_20_64); + const __m128i k__cospi_m20_p12 = pair_set_epi16(-cospi_20_64, cospi_12_64); + const __m128i k__cospi_p30_p02 = pair_set_epi16(cospi_30_64, cospi_2_64); + const __m128i k__cospi_p14_p18 = pair_set_epi16(cospi_14_64, cospi_18_64); + const __m128i k__cospi_m02_p30 = pair_set_epi16(-cospi_2_64, cospi_30_64); + const __m128i k__cospi_m18_p14 = pair_set_epi16(-cospi_18_64, cospi_14_64); + const __m128i k__cospi_p22_p10 = pair_set_epi16(cospi_22_64, cospi_10_64); + const __m128i k__cospi_p06_p26 = pair_set_epi16(cospi_6_64, cospi_26_64); + const __m128i k__cospi_m10_p22 = pair_set_epi16(-cospi_10_64, cospi_22_64); + const __m128i k__cospi_m26_p06 = pair_set_epi16(-cospi_26_64, cospi_6_64); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + + // stage 1 + i[0] = _mm_add_epi16(in[0], in[15]); + i[1] = _mm_add_epi16(in[1], in[14]); + i[2] = _mm_add_epi16(in[2], in[13]); + i[3] = _mm_add_epi16(in[3], in[12]); + i[4] = _mm_add_epi16(in[4], in[11]); + i[5] = _mm_add_epi16(in[5], in[10]); + i[6] = _mm_add_epi16(in[6], in[9]); + i[7] = _mm_add_epi16(in[7], in[8]); + + s[0] = _mm_sub_epi16(in[7], in[8]); + s[1] = _mm_sub_epi16(in[6], in[9]); + s[2] = _mm_sub_epi16(in[5], in[10]); + s[3] = _mm_sub_epi16(in[4], in[11]); + s[4] = _mm_sub_epi16(in[3], in[12]); + s[5] = _mm_sub_epi16(in[2], in[13]); + s[6] = _mm_sub_epi16(in[1], in[14]); + s[7] = _mm_sub_epi16(in[0], in[15]); + + p[0] = _mm_add_epi16(i[0], i[7]); + p[1] = _mm_add_epi16(i[1], i[6]); + p[2] = _mm_add_epi16(i[2], i[5]); + p[3] = _mm_add_epi16(i[3], i[4]); + p[4] = _mm_sub_epi16(i[3], i[4]); + p[5] = _mm_sub_epi16(i[2], i[5]); + p[6] = _mm_sub_epi16(i[1], i[6]); + p[7] = _mm_sub_epi16(i[0], i[7]); + + u[0] = _mm_add_epi16(p[0], p[3]); + u[1] = _mm_add_epi16(p[1], p[2]); + u[2] = _mm_sub_epi16(p[1], p[2]); + u[3] = _mm_sub_epi16(p[0], p[3]); + + v[0] = _mm_unpacklo_epi16(u[0], u[1]); + v[1] = _mm_unpackhi_epi16(u[0], u[1]); + v[2] = _mm_unpacklo_epi16(u[2], u[3]); + v[3] = _mm_unpackhi_epi16(u[2], u[3]); + + u[0] = _mm_madd_epi16(v[0], k__cospi_p16_p16); + u[1] = _mm_madd_epi16(v[1], k__cospi_p16_p16); + u[2] = _mm_madd_epi16(v[0], k__cospi_p16_m16); + u[3] = _mm_madd_epi16(v[1], k__cospi_p16_m16); + u[4] = _mm_madd_epi16(v[2], k__cospi_p24_p08); + u[5] = _mm_madd_epi16(v[3], k__cospi_p24_p08); + u[6] = _mm_madd_epi16(v[2], k__cospi_m08_p24); + u[7] = _mm_madd_epi16(v[3], k__cospi_m08_p24); + + v[0] = _mm_add_epi32(u[0], k__DCT_CONST_ROUNDING); + v[1] = _mm_add_epi32(u[1], k__DCT_CONST_ROUNDING); + v[2] = _mm_add_epi32(u[2], k__DCT_CONST_ROUNDING); + v[3] = _mm_add_epi32(u[3], k__DCT_CONST_ROUNDING); + v[4] = _mm_add_epi32(u[4], k__DCT_CONST_ROUNDING); + v[5] = _mm_add_epi32(u[5], k__DCT_CONST_ROUNDING); + v[6] = _mm_add_epi32(u[6], k__DCT_CONST_ROUNDING); + v[7] = _mm_add_epi32(u[7], k__DCT_CONST_ROUNDING); + + u[0] = _mm_srai_epi32(v[0], DCT_CONST_BITS); + u[1] = _mm_srai_epi32(v[1], DCT_CONST_BITS); + u[2] = _mm_srai_epi32(v[2], DCT_CONST_BITS); + u[3] = _mm_srai_epi32(v[3], DCT_CONST_BITS); + u[4] = _mm_srai_epi32(v[4], DCT_CONST_BITS); + u[5] = _mm_srai_epi32(v[5], DCT_CONST_BITS); + u[6] = _mm_srai_epi32(v[6], DCT_CONST_BITS); + u[7] = _mm_srai_epi32(v[7], DCT_CONST_BITS); + + in[0] = _mm_packs_epi32(u[0], u[1]); + in[4] = _mm_packs_epi32(u[4], u[5]); + in[8] = _mm_packs_epi32(u[2], u[3]); + in[12] = _mm_packs_epi32(u[6], u[7]); + + u[0] = _mm_unpacklo_epi16(p[5], p[6]); + u[1] = _mm_unpackhi_epi16(p[5], p[6]); + v[0] = _mm_madd_epi16(u[0], k__cospi_m16_p16); + v[1] = _mm_madd_epi16(u[1], k__cospi_m16_p16); + v[2] = _mm_madd_epi16(u[0], k__cospi_p16_p16); + v[3] = _mm_madd_epi16(u[1], k__cospi_p16_p16); + + u[0] = _mm_add_epi32(v[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(v[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(v[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(v[3], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + + u[0] = _mm_packs_epi32(v[0], v[1]); + u[1] = _mm_packs_epi32(v[2], v[3]); + + t[0] = _mm_add_epi16(p[4], u[0]); + t[1] = _mm_sub_epi16(p[4], u[0]); + t[2] = _mm_sub_epi16(p[7], u[1]); + t[3] = _mm_add_epi16(p[7], u[1]); + + u[0] = _mm_unpacklo_epi16(t[0], t[3]); + u[1] = _mm_unpackhi_epi16(t[0], t[3]); + u[2] = _mm_unpacklo_epi16(t[1], t[2]); + u[3] = _mm_unpackhi_epi16(t[1], t[2]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_p28_p04); + v[1] = _mm_madd_epi16(u[1], k__cospi_p28_p04); + v[2] = _mm_madd_epi16(u[2], k__cospi_p12_p20); + v[3] = _mm_madd_epi16(u[3], k__cospi_p12_p20); + v[4] = _mm_madd_epi16(u[2], k__cospi_m20_p12); + v[5] = _mm_madd_epi16(u[3], k__cospi_m20_p12); + v[6] = _mm_madd_epi16(u[0], k__cospi_m04_p28); + v[7] = _mm_madd_epi16(u[1], k__cospi_m04_p28); + + u[0] = _mm_add_epi32(v[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(v[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(v[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(v[3], k__DCT_CONST_ROUNDING); + u[4] = _mm_add_epi32(v[4], k__DCT_CONST_ROUNDING); + u[5] = _mm_add_epi32(v[5], k__DCT_CONST_ROUNDING); + u[6] = _mm_add_epi32(v[6], k__DCT_CONST_ROUNDING); + u[7] = _mm_add_epi32(v[7], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + v[4] = _mm_srai_epi32(u[4], DCT_CONST_BITS); + v[5] = _mm_srai_epi32(u[5], DCT_CONST_BITS); + v[6] = _mm_srai_epi32(u[6], DCT_CONST_BITS); + v[7] = _mm_srai_epi32(u[7], DCT_CONST_BITS); + + in[2] = _mm_packs_epi32(v[0], v[1]); + in[6] = _mm_packs_epi32(v[4], v[5]); + in[10] = _mm_packs_epi32(v[2], v[3]); + in[14] = _mm_packs_epi32(v[6], v[7]); + + // stage 2 + u[0] = _mm_unpacklo_epi16(s[2], s[5]); + u[1] = _mm_unpackhi_epi16(s[2], s[5]); + u[2] = _mm_unpacklo_epi16(s[3], s[4]); + u[3] = _mm_unpackhi_epi16(s[3], s[4]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_m16_p16); + v[1] = _mm_madd_epi16(u[1], k__cospi_m16_p16); + v[2] = _mm_madd_epi16(u[2], k__cospi_m16_p16); + v[3] = _mm_madd_epi16(u[3], k__cospi_m16_p16); + v[4] = _mm_madd_epi16(u[2], k__cospi_p16_p16); + v[5] = _mm_madd_epi16(u[3], k__cospi_p16_p16); + v[6] = _mm_madd_epi16(u[0], k__cospi_p16_p16); + v[7] = _mm_madd_epi16(u[1], k__cospi_p16_p16); + + u[0] = _mm_add_epi32(v[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(v[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(v[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(v[3], k__DCT_CONST_ROUNDING); + u[4] = _mm_add_epi32(v[4], k__DCT_CONST_ROUNDING); + u[5] = _mm_add_epi32(v[5], k__DCT_CONST_ROUNDING); + u[6] = _mm_add_epi32(v[6], k__DCT_CONST_ROUNDING); + u[7] = _mm_add_epi32(v[7], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + v[4] = _mm_srai_epi32(u[4], DCT_CONST_BITS); + v[5] = _mm_srai_epi32(u[5], DCT_CONST_BITS); + v[6] = _mm_srai_epi32(u[6], DCT_CONST_BITS); + v[7] = _mm_srai_epi32(u[7], DCT_CONST_BITS); + + t[2] = _mm_packs_epi32(v[0], v[1]); + t[3] = _mm_packs_epi32(v[2], v[3]); + t[4] = _mm_packs_epi32(v[4], v[5]); + t[5] = _mm_packs_epi32(v[6], v[7]); + + // stage 3 + p[0] = _mm_add_epi16(s[0], t[3]); + p[1] = _mm_add_epi16(s[1], t[2]); + p[2] = _mm_sub_epi16(s[1], t[2]); + p[3] = _mm_sub_epi16(s[0], t[3]); + p[4] = _mm_sub_epi16(s[7], t[4]); + p[5] = _mm_sub_epi16(s[6], t[5]); + p[6] = _mm_add_epi16(s[6], t[5]); + p[7] = _mm_add_epi16(s[7], t[4]); + + // stage 4 + u[0] = _mm_unpacklo_epi16(p[1], p[6]); + u[1] = _mm_unpackhi_epi16(p[1], p[6]); + u[2] = _mm_unpacklo_epi16(p[2], p[5]); + u[3] = _mm_unpackhi_epi16(p[2], p[5]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_m08_p24); + v[1] = _mm_madd_epi16(u[1], k__cospi_m08_p24); + v[2] = _mm_madd_epi16(u[2], k__cospi_m24_m08); + v[3] = _mm_madd_epi16(u[3], k__cospi_m24_m08); + v[4] = _mm_madd_epi16(u[2], k__cospi_m08_p24); + v[5] = _mm_madd_epi16(u[3], k__cospi_m08_p24); + v[6] = _mm_madd_epi16(u[0], k__cospi_p24_p08); + v[7] = _mm_madd_epi16(u[1], k__cospi_p24_p08); + + u[0] = _mm_add_epi32(v[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(v[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(v[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(v[3], k__DCT_CONST_ROUNDING); + u[4] = _mm_add_epi32(v[4], k__DCT_CONST_ROUNDING); + u[5] = _mm_add_epi32(v[5], k__DCT_CONST_ROUNDING); + u[6] = _mm_add_epi32(v[6], k__DCT_CONST_ROUNDING); + u[7] = _mm_add_epi32(v[7], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + v[4] = _mm_srai_epi32(u[4], DCT_CONST_BITS); + v[5] = _mm_srai_epi32(u[5], DCT_CONST_BITS); + v[6] = _mm_srai_epi32(u[6], DCT_CONST_BITS); + v[7] = _mm_srai_epi32(u[7], DCT_CONST_BITS); + + t[1] = _mm_packs_epi32(v[0], v[1]); + t[2] = _mm_packs_epi32(v[2], v[3]); + t[5] = _mm_packs_epi32(v[4], v[5]); + t[6] = _mm_packs_epi32(v[6], v[7]); + + // stage 5 + s[0] = _mm_add_epi16(p[0], t[1]); + s[1] = _mm_sub_epi16(p[0], t[1]); + s[2] = _mm_sub_epi16(p[3], t[2]); + s[3] = _mm_add_epi16(p[3], t[2]); + s[4] = _mm_add_epi16(p[4], t[5]); + s[5] = _mm_sub_epi16(p[4], t[5]); + s[6] = _mm_sub_epi16(p[7], t[6]); + s[7] = _mm_add_epi16(p[7], t[6]); + + // stage 6 + u[0] = _mm_unpacklo_epi16(s[0], s[7]); + u[1] = _mm_unpackhi_epi16(s[0], s[7]); + u[2] = _mm_unpacklo_epi16(s[1], s[6]); + u[3] = _mm_unpackhi_epi16(s[1], s[6]); + u[4] = _mm_unpacklo_epi16(s[2], s[5]); + u[5] = _mm_unpackhi_epi16(s[2], s[5]); + u[6] = _mm_unpacklo_epi16(s[3], s[4]); + u[7] = _mm_unpackhi_epi16(s[3], s[4]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_p30_p02); + v[1] = _mm_madd_epi16(u[1], k__cospi_p30_p02); + v[2] = _mm_madd_epi16(u[2], k__cospi_p14_p18); + v[3] = _mm_madd_epi16(u[3], k__cospi_p14_p18); + v[4] = _mm_madd_epi16(u[4], k__cospi_p22_p10); + v[5] = _mm_madd_epi16(u[5], k__cospi_p22_p10); + v[6] = _mm_madd_epi16(u[6], k__cospi_p06_p26); + v[7] = _mm_madd_epi16(u[7], k__cospi_p06_p26); + v[8] = _mm_madd_epi16(u[6], k__cospi_m26_p06); + v[9] = _mm_madd_epi16(u[7], k__cospi_m26_p06); + v[10] = _mm_madd_epi16(u[4], k__cospi_m10_p22); + v[11] = _mm_madd_epi16(u[5], k__cospi_m10_p22); + v[12] = _mm_madd_epi16(u[2], k__cospi_m18_p14); + v[13] = _mm_madd_epi16(u[3], k__cospi_m18_p14); + v[14] = _mm_madd_epi16(u[0], k__cospi_m02_p30); + v[15] = _mm_madd_epi16(u[1], k__cospi_m02_p30); + + u[0] = _mm_add_epi32(v[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(v[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(v[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(v[3], k__DCT_CONST_ROUNDING); + u[4] = _mm_add_epi32(v[4], k__DCT_CONST_ROUNDING); + u[5] = _mm_add_epi32(v[5], k__DCT_CONST_ROUNDING); + u[6] = _mm_add_epi32(v[6], k__DCT_CONST_ROUNDING); + u[7] = _mm_add_epi32(v[7], k__DCT_CONST_ROUNDING); + u[8] = _mm_add_epi32(v[8], k__DCT_CONST_ROUNDING); + u[9] = _mm_add_epi32(v[9], k__DCT_CONST_ROUNDING); + u[10] = _mm_add_epi32(v[10], k__DCT_CONST_ROUNDING); + u[11] = _mm_add_epi32(v[11], k__DCT_CONST_ROUNDING); + u[12] = _mm_add_epi32(v[12], k__DCT_CONST_ROUNDING); + u[13] = _mm_add_epi32(v[13], k__DCT_CONST_ROUNDING); + u[14] = _mm_add_epi32(v[14], k__DCT_CONST_ROUNDING); + u[15] = _mm_add_epi32(v[15], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + v[4] = _mm_srai_epi32(u[4], DCT_CONST_BITS); + v[5] = _mm_srai_epi32(u[5], DCT_CONST_BITS); + v[6] = _mm_srai_epi32(u[6], DCT_CONST_BITS); + v[7] = _mm_srai_epi32(u[7], DCT_CONST_BITS); + v[8] = _mm_srai_epi32(u[8], DCT_CONST_BITS); + v[9] = _mm_srai_epi32(u[9], DCT_CONST_BITS); + v[10] = _mm_srai_epi32(u[10], DCT_CONST_BITS); + v[11] = _mm_srai_epi32(u[11], DCT_CONST_BITS); + v[12] = _mm_srai_epi32(u[12], DCT_CONST_BITS); + v[13] = _mm_srai_epi32(u[13], DCT_CONST_BITS); + v[14] = _mm_srai_epi32(u[14], DCT_CONST_BITS); + v[15] = _mm_srai_epi32(u[15], DCT_CONST_BITS); + + in[1] = _mm_packs_epi32(v[0], v[1]); + in[9] = _mm_packs_epi32(v[2], v[3]); + in[5] = _mm_packs_epi32(v[4], v[5]); + in[13] = _mm_packs_epi32(v[6], v[7]); + in[3] = _mm_packs_epi32(v[8], v[9]); + in[11] = _mm_packs_epi32(v[10], v[11]); + in[7] = _mm_packs_epi32(v[12], v[13]); + in[15] = _mm_packs_epi32(v[14], v[15]); +} + +static void fadst16_8col(__m128i *in) { + // perform 16x16 1-D ADST for 8 columns + __m128i s[16], x[16], u[32], v[32]; + const __m128i k__cospi_p01_p31 = pair_set_epi16(cospi_1_64, cospi_31_64); + const __m128i k__cospi_p31_m01 = pair_set_epi16(cospi_31_64, -cospi_1_64); + const __m128i k__cospi_p05_p27 = pair_set_epi16(cospi_5_64, cospi_27_64); + const __m128i k__cospi_p27_m05 = pair_set_epi16(cospi_27_64, -cospi_5_64); + const __m128i k__cospi_p09_p23 = pair_set_epi16(cospi_9_64, cospi_23_64); + const __m128i k__cospi_p23_m09 = pair_set_epi16(cospi_23_64, -cospi_9_64); + const __m128i k__cospi_p13_p19 = pair_set_epi16(cospi_13_64, cospi_19_64); + const __m128i k__cospi_p19_m13 = pair_set_epi16(cospi_19_64, -cospi_13_64); + const __m128i k__cospi_p17_p15 = pair_set_epi16(cospi_17_64, cospi_15_64); + const __m128i k__cospi_p15_m17 = pair_set_epi16(cospi_15_64, -cospi_17_64); + const __m128i k__cospi_p21_p11 = pair_set_epi16(cospi_21_64, cospi_11_64); + const __m128i k__cospi_p11_m21 = pair_set_epi16(cospi_11_64, -cospi_21_64); + const __m128i k__cospi_p25_p07 = pair_set_epi16(cospi_25_64, cospi_7_64); + const __m128i k__cospi_p07_m25 = pair_set_epi16(cospi_7_64, -cospi_25_64); + const __m128i k__cospi_p29_p03 = pair_set_epi16(cospi_29_64, cospi_3_64); + const __m128i k__cospi_p03_m29 = pair_set_epi16(cospi_3_64, -cospi_29_64); + const __m128i k__cospi_p04_p28 = pair_set_epi16(cospi_4_64, cospi_28_64); + const __m128i k__cospi_p28_m04 = pair_set_epi16(cospi_28_64, -cospi_4_64); + const __m128i k__cospi_p20_p12 = pair_set_epi16(cospi_20_64, cospi_12_64); + const __m128i k__cospi_p12_m20 = pair_set_epi16(cospi_12_64, -cospi_20_64); + const __m128i k__cospi_m28_p04 = pair_set_epi16(-cospi_28_64, cospi_4_64); + const __m128i k__cospi_m12_p20 = pair_set_epi16(-cospi_12_64, cospi_20_64); + const __m128i k__cospi_p08_p24 = pair_set_epi16(cospi_8_64, cospi_24_64); + const __m128i k__cospi_p24_m08 = pair_set_epi16(cospi_24_64, -cospi_8_64); + const __m128i k__cospi_m24_p08 = pair_set_epi16(-cospi_24_64, cospi_8_64); + const __m128i k__cospi_m16_m16 = _mm_set1_epi16((int16_t)-cospi_16_64); + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_m16_p16 = pair_set_epi16(-cospi_16_64, cospi_16_64); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + const __m128i kZero = _mm_set1_epi16(0); + + u[0] = _mm_unpacklo_epi16(in[15], in[0]); + u[1] = _mm_unpackhi_epi16(in[15], in[0]); + u[2] = _mm_unpacklo_epi16(in[13], in[2]); + u[3] = _mm_unpackhi_epi16(in[13], in[2]); + u[4] = _mm_unpacklo_epi16(in[11], in[4]); + u[5] = _mm_unpackhi_epi16(in[11], in[4]); + u[6] = _mm_unpacklo_epi16(in[9], in[6]); + u[7] = _mm_unpackhi_epi16(in[9], in[6]); + u[8] = _mm_unpacklo_epi16(in[7], in[8]); + u[9] = _mm_unpackhi_epi16(in[7], in[8]); + u[10] = _mm_unpacklo_epi16(in[5], in[10]); + u[11] = _mm_unpackhi_epi16(in[5], in[10]); + u[12] = _mm_unpacklo_epi16(in[3], in[12]); + u[13] = _mm_unpackhi_epi16(in[3], in[12]); + u[14] = _mm_unpacklo_epi16(in[1], in[14]); + u[15] = _mm_unpackhi_epi16(in[1], in[14]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_p01_p31); + v[1] = _mm_madd_epi16(u[1], k__cospi_p01_p31); + v[2] = _mm_madd_epi16(u[0], k__cospi_p31_m01); + v[3] = _mm_madd_epi16(u[1], k__cospi_p31_m01); + v[4] = _mm_madd_epi16(u[2], k__cospi_p05_p27); + v[5] = _mm_madd_epi16(u[3], k__cospi_p05_p27); + v[6] = _mm_madd_epi16(u[2], k__cospi_p27_m05); + v[7] = _mm_madd_epi16(u[3], k__cospi_p27_m05); + v[8] = _mm_madd_epi16(u[4], k__cospi_p09_p23); + v[9] = _mm_madd_epi16(u[5], k__cospi_p09_p23); + v[10] = _mm_madd_epi16(u[4], k__cospi_p23_m09); + v[11] = _mm_madd_epi16(u[5], k__cospi_p23_m09); + v[12] = _mm_madd_epi16(u[6], k__cospi_p13_p19); + v[13] = _mm_madd_epi16(u[7], k__cospi_p13_p19); + v[14] = _mm_madd_epi16(u[6], k__cospi_p19_m13); + v[15] = _mm_madd_epi16(u[7], k__cospi_p19_m13); + v[16] = _mm_madd_epi16(u[8], k__cospi_p17_p15); + v[17] = _mm_madd_epi16(u[9], k__cospi_p17_p15); + v[18] = _mm_madd_epi16(u[8], k__cospi_p15_m17); + v[19] = _mm_madd_epi16(u[9], k__cospi_p15_m17); + v[20] = _mm_madd_epi16(u[10], k__cospi_p21_p11); + v[21] = _mm_madd_epi16(u[11], k__cospi_p21_p11); + v[22] = _mm_madd_epi16(u[10], k__cospi_p11_m21); + v[23] = _mm_madd_epi16(u[11], k__cospi_p11_m21); + v[24] = _mm_madd_epi16(u[12], k__cospi_p25_p07); + v[25] = _mm_madd_epi16(u[13], k__cospi_p25_p07); + v[26] = _mm_madd_epi16(u[12], k__cospi_p07_m25); + v[27] = _mm_madd_epi16(u[13], k__cospi_p07_m25); + v[28] = _mm_madd_epi16(u[14], k__cospi_p29_p03); + v[29] = _mm_madd_epi16(u[15], k__cospi_p29_p03); + v[30] = _mm_madd_epi16(u[14], k__cospi_p03_m29); + v[31] = _mm_madd_epi16(u[15], k__cospi_p03_m29); + + u[0] = _mm_add_epi32(v[0], v[16]); + u[1] = _mm_add_epi32(v[1], v[17]); + u[2] = _mm_add_epi32(v[2], v[18]); + u[3] = _mm_add_epi32(v[3], v[19]); + u[4] = _mm_add_epi32(v[4], v[20]); + u[5] = _mm_add_epi32(v[5], v[21]); + u[6] = _mm_add_epi32(v[6], v[22]); + u[7] = _mm_add_epi32(v[7], v[23]); + u[8] = _mm_add_epi32(v[8], v[24]); + u[9] = _mm_add_epi32(v[9], v[25]); + u[10] = _mm_add_epi32(v[10], v[26]); + u[11] = _mm_add_epi32(v[11], v[27]); + u[12] = _mm_add_epi32(v[12], v[28]); + u[13] = _mm_add_epi32(v[13], v[29]); + u[14] = _mm_add_epi32(v[14], v[30]); + u[15] = _mm_add_epi32(v[15], v[31]); + u[16] = _mm_sub_epi32(v[0], v[16]); + u[17] = _mm_sub_epi32(v[1], v[17]); + u[18] = _mm_sub_epi32(v[2], v[18]); + u[19] = _mm_sub_epi32(v[3], v[19]); + u[20] = _mm_sub_epi32(v[4], v[20]); + u[21] = _mm_sub_epi32(v[5], v[21]); + u[22] = _mm_sub_epi32(v[6], v[22]); + u[23] = _mm_sub_epi32(v[7], v[23]); + u[24] = _mm_sub_epi32(v[8], v[24]); + u[25] = _mm_sub_epi32(v[9], v[25]); + u[26] = _mm_sub_epi32(v[10], v[26]); + u[27] = _mm_sub_epi32(v[11], v[27]); + u[28] = _mm_sub_epi32(v[12], v[28]); + u[29] = _mm_sub_epi32(v[13], v[29]); + u[30] = _mm_sub_epi32(v[14], v[30]); + u[31] = _mm_sub_epi32(v[15], v[31]); + + v[16] = _mm_add_epi32(u[16], k__DCT_CONST_ROUNDING); + v[17] = _mm_add_epi32(u[17], k__DCT_CONST_ROUNDING); + v[18] = _mm_add_epi32(u[18], k__DCT_CONST_ROUNDING); + v[19] = _mm_add_epi32(u[19], k__DCT_CONST_ROUNDING); + v[20] = _mm_add_epi32(u[20], k__DCT_CONST_ROUNDING); + v[21] = _mm_add_epi32(u[21], k__DCT_CONST_ROUNDING); + v[22] = _mm_add_epi32(u[22], k__DCT_CONST_ROUNDING); + v[23] = _mm_add_epi32(u[23], k__DCT_CONST_ROUNDING); + v[24] = _mm_add_epi32(u[24], k__DCT_CONST_ROUNDING); + v[25] = _mm_add_epi32(u[25], k__DCT_CONST_ROUNDING); + v[26] = _mm_add_epi32(u[26], k__DCT_CONST_ROUNDING); + v[27] = _mm_add_epi32(u[27], k__DCT_CONST_ROUNDING); + v[28] = _mm_add_epi32(u[28], k__DCT_CONST_ROUNDING); + v[29] = _mm_add_epi32(u[29], k__DCT_CONST_ROUNDING); + v[30] = _mm_add_epi32(u[30], k__DCT_CONST_ROUNDING); + v[31] = _mm_add_epi32(u[31], k__DCT_CONST_ROUNDING); + + u[16] = _mm_srai_epi32(v[16], DCT_CONST_BITS); + u[17] = _mm_srai_epi32(v[17], DCT_CONST_BITS); + u[18] = _mm_srai_epi32(v[18], DCT_CONST_BITS); + u[19] = _mm_srai_epi32(v[19], DCT_CONST_BITS); + u[20] = _mm_srai_epi32(v[20], DCT_CONST_BITS); + u[21] = _mm_srai_epi32(v[21], DCT_CONST_BITS); + u[22] = _mm_srai_epi32(v[22], DCT_CONST_BITS); + u[23] = _mm_srai_epi32(v[23], DCT_CONST_BITS); + u[24] = _mm_srai_epi32(v[24], DCT_CONST_BITS); + u[25] = _mm_srai_epi32(v[25], DCT_CONST_BITS); + u[26] = _mm_srai_epi32(v[26], DCT_CONST_BITS); + u[27] = _mm_srai_epi32(v[27], DCT_CONST_BITS); + u[28] = _mm_srai_epi32(v[28], DCT_CONST_BITS); + u[29] = _mm_srai_epi32(v[29], DCT_CONST_BITS); + u[30] = _mm_srai_epi32(v[30], DCT_CONST_BITS); + u[31] = _mm_srai_epi32(v[31], DCT_CONST_BITS); + + v[0] = _mm_add_epi32(u[0], u[8]); + v[1] = _mm_add_epi32(u[1], u[9]); + v[2] = _mm_add_epi32(u[2], u[10]); + v[3] = _mm_add_epi32(u[3], u[11]); + v[4] = _mm_add_epi32(u[4], u[12]); + v[5] = _mm_add_epi32(u[5], u[13]); + v[6] = _mm_add_epi32(u[6], u[14]); + v[7] = _mm_add_epi32(u[7], u[15]); + + v[16] = _mm_add_epi32(v[0], v[4]); + v[17] = _mm_add_epi32(v[1], v[5]); + v[18] = _mm_add_epi32(v[2], v[6]); + v[19] = _mm_add_epi32(v[3], v[7]); + v[20] = _mm_sub_epi32(v[0], v[4]); + v[21] = _mm_sub_epi32(v[1], v[5]); + v[22] = _mm_sub_epi32(v[2], v[6]); + v[23] = _mm_sub_epi32(v[3], v[7]); + v[16] = _mm_add_epi32(v[16], k__DCT_CONST_ROUNDING); + v[17] = _mm_add_epi32(v[17], k__DCT_CONST_ROUNDING); + v[18] = _mm_add_epi32(v[18], k__DCT_CONST_ROUNDING); + v[19] = _mm_add_epi32(v[19], k__DCT_CONST_ROUNDING); + v[20] = _mm_add_epi32(v[20], k__DCT_CONST_ROUNDING); + v[21] = _mm_add_epi32(v[21], k__DCT_CONST_ROUNDING); + v[22] = _mm_add_epi32(v[22], k__DCT_CONST_ROUNDING); + v[23] = _mm_add_epi32(v[23], k__DCT_CONST_ROUNDING); + v[16] = _mm_srai_epi32(v[16], DCT_CONST_BITS); + v[17] = _mm_srai_epi32(v[17], DCT_CONST_BITS); + v[18] = _mm_srai_epi32(v[18], DCT_CONST_BITS); + v[19] = _mm_srai_epi32(v[19], DCT_CONST_BITS); + v[20] = _mm_srai_epi32(v[20], DCT_CONST_BITS); + v[21] = _mm_srai_epi32(v[21], DCT_CONST_BITS); + v[22] = _mm_srai_epi32(v[22], DCT_CONST_BITS); + v[23] = _mm_srai_epi32(v[23], DCT_CONST_BITS); + s[0] = _mm_packs_epi32(v[16], v[17]); + s[1] = _mm_packs_epi32(v[18], v[19]); + s[2] = _mm_packs_epi32(v[20], v[21]); + s[3] = _mm_packs_epi32(v[22], v[23]); + + v[8] = _mm_sub_epi32(u[0], u[8]); + v[9] = _mm_sub_epi32(u[1], u[9]); + v[10] = _mm_sub_epi32(u[2], u[10]); + v[11] = _mm_sub_epi32(u[3], u[11]); + v[12] = _mm_sub_epi32(u[4], u[12]); + v[13] = _mm_sub_epi32(u[5], u[13]); + v[14] = _mm_sub_epi32(u[6], u[14]); + v[15] = _mm_sub_epi32(u[7], u[15]); + + v[8] = _mm_add_epi32(v[8], k__DCT_CONST_ROUNDING); + v[9] = _mm_add_epi32(v[9], k__DCT_CONST_ROUNDING); + v[10] = _mm_add_epi32(v[10], k__DCT_CONST_ROUNDING); + v[11] = _mm_add_epi32(v[11], k__DCT_CONST_ROUNDING); + v[12] = _mm_add_epi32(v[12], k__DCT_CONST_ROUNDING); + v[13] = _mm_add_epi32(v[13], k__DCT_CONST_ROUNDING); + v[14] = _mm_add_epi32(v[14], k__DCT_CONST_ROUNDING); + v[15] = _mm_add_epi32(v[15], k__DCT_CONST_ROUNDING); + + v[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS); + v[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS); + v[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS); + v[11] = _mm_srai_epi32(v[11], DCT_CONST_BITS); + v[12] = _mm_srai_epi32(v[12], DCT_CONST_BITS); + v[13] = _mm_srai_epi32(v[13], DCT_CONST_BITS); + v[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS); + v[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS); + + s[4] = _mm_packs_epi32(v[8], v[9]); + s[5] = _mm_packs_epi32(v[10], v[11]); + s[6] = _mm_packs_epi32(v[12], v[13]); + s[7] = _mm_packs_epi32(v[14], v[15]); + // + + s[8] = _mm_packs_epi32(u[16], u[17]); + s[9] = _mm_packs_epi32(u[18], u[19]); + s[10] = _mm_packs_epi32(u[20], u[21]); + s[11] = _mm_packs_epi32(u[22], u[23]); + s[12] = _mm_packs_epi32(u[24], u[25]); + s[13] = _mm_packs_epi32(u[26], u[27]); + s[14] = _mm_packs_epi32(u[28], u[29]); + s[15] = _mm_packs_epi32(u[30], u[31]); + + // stage 2 + u[0] = _mm_unpacklo_epi16(s[8], s[9]); + u[1] = _mm_unpackhi_epi16(s[8], s[9]); + u[2] = _mm_unpacklo_epi16(s[10], s[11]); + u[3] = _mm_unpackhi_epi16(s[10], s[11]); + u[4] = _mm_unpacklo_epi16(s[12], s[13]); + u[5] = _mm_unpackhi_epi16(s[12], s[13]); + u[6] = _mm_unpacklo_epi16(s[14], s[15]); + u[7] = _mm_unpackhi_epi16(s[14], s[15]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_p04_p28); + v[1] = _mm_madd_epi16(u[1], k__cospi_p04_p28); + v[2] = _mm_madd_epi16(u[0], k__cospi_p28_m04); + v[3] = _mm_madd_epi16(u[1], k__cospi_p28_m04); + v[4] = _mm_madd_epi16(u[2], k__cospi_p20_p12); + v[5] = _mm_madd_epi16(u[3], k__cospi_p20_p12); + v[6] = _mm_madd_epi16(u[2], k__cospi_p12_m20); + v[7] = _mm_madd_epi16(u[3], k__cospi_p12_m20); + v[8] = _mm_madd_epi16(u[4], k__cospi_m28_p04); + v[9] = _mm_madd_epi16(u[5], k__cospi_m28_p04); + v[10] = _mm_madd_epi16(u[4], k__cospi_p04_p28); + v[11] = _mm_madd_epi16(u[5], k__cospi_p04_p28); + v[12] = _mm_madd_epi16(u[6], k__cospi_m12_p20); + v[13] = _mm_madd_epi16(u[7], k__cospi_m12_p20); + v[14] = _mm_madd_epi16(u[6], k__cospi_p20_p12); + v[15] = _mm_madd_epi16(u[7], k__cospi_p20_p12); + + u[0] = _mm_add_epi32(v[0], v[8]); + u[1] = _mm_add_epi32(v[1], v[9]); + u[2] = _mm_add_epi32(v[2], v[10]); + u[3] = _mm_add_epi32(v[3], v[11]); + u[4] = _mm_add_epi32(v[4], v[12]); + u[5] = _mm_add_epi32(v[5], v[13]); + u[6] = _mm_add_epi32(v[6], v[14]); + u[7] = _mm_add_epi32(v[7], v[15]); + u[8] = _mm_sub_epi32(v[0], v[8]); + u[9] = _mm_sub_epi32(v[1], v[9]); + u[10] = _mm_sub_epi32(v[2], v[10]); + u[11] = _mm_sub_epi32(v[3], v[11]); + u[12] = _mm_sub_epi32(v[4], v[12]); + u[13] = _mm_sub_epi32(v[5], v[13]); + u[14] = _mm_sub_epi32(v[6], v[14]); + u[15] = _mm_sub_epi32(v[7], v[15]); + + v[8] = _mm_add_epi32(u[8], k__DCT_CONST_ROUNDING); + v[9] = _mm_add_epi32(u[9], k__DCT_CONST_ROUNDING); + v[10] = _mm_add_epi32(u[10], k__DCT_CONST_ROUNDING); + v[11] = _mm_add_epi32(u[11], k__DCT_CONST_ROUNDING); + v[12] = _mm_add_epi32(u[12], k__DCT_CONST_ROUNDING); + v[13] = _mm_add_epi32(u[13], k__DCT_CONST_ROUNDING); + v[14] = _mm_add_epi32(u[14], k__DCT_CONST_ROUNDING); + v[15] = _mm_add_epi32(u[15], k__DCT_CONST_ROUNDING); + + u[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS); + u[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS); + u[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS); + u[11] = _mm_srai_epi32(v[11], DCT_CONST_BITS); + u[12] = _mm_srai_epi32(v[12], DCT_CONST_BITS); + u[13] = _mm_srai_epi32(v[13], DCT_CONST_BITS); + u[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS); + u[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS); + + v[8] = _mm_add_epi32(u[0], u[4]); + v[9] = _mm_add_epi32(u[1], u[5]); + v[10] = _mm_add_epi32(u[2], u[6]); + v[11] = _mm_add_epi32(u[3], u[7]); + v[12] = _mm_sub_epi32(u[0], u[4]); + v[13] = _mm_sub_epi32(u[1], u[5]); + v[14] = _mm_sub_epi32(u[2], u[6]); + v[15] = _mm_sub_epi32(u[3], u[7]); + + v[8] = _mm_add_epi32(v[8], k__DCT_CONST_ROUNDING); + v[9] = _mm_add_epi32(v[9], k__DCT_CONST_ROUNDING); + v[10] = _mm_add_epi32(v[10], k__DCT_CONST_ROUNDING); + v[11] = _mm_add_epi32(v[11], k__DCT_CONST_ROUNDING); + v[12] = _mm_add_epi32(v[12], k__DCT_CONST_ROUNDING); + v[13] = _mm_add_epi32(v[13], k__DCT_CONST_ROUNDING); + v[14] = _mm_add_epi32(v[14], k__DCT_CONST_ROUNDING); + v[15] = _mm_add_epi32(v[15], k__DCT_CONST_ROUNDING); + v[8] = _mm_srai_epi32(v[8], DCT_CONST_BITS); + v[9] = _mm_srai_epi32(v[9], DCT_CONST_BITS); + v[10] = _mm_srai_epi32(v[10], DCT_CONST_BITS); + v[11] = _mm_srai_epi32(v[11], DCT_CONST_BITS); + v[12] = _mm_srai_epi32(v[12], DCT_CONST_BITS); + v[13] = _mm_srai_epi32(v[13], DCT_CONST_BITS); + v[14] = _mm_srai_epi32(v[14], DCT_CONST_BITS); + v[15] = _mm_srai_epi32(v[15], DCT_CONST_BITS); + s[8] = _mm_packs_epi32(v[8], v[9]); + s[9] = _mm_packs_epi32(v[10], v[11]); + s[10] = _mm_packs_epi32(v[12], v[13]); + s[11] = _mm_packs_epi32(v[14], v[15]); + + x[12] = _mm_packs_epi32(u[8], u[9]); + x[13] = _mm_packs_epi32(u[10], u[11]); + x[14] = _mm_packs_epi32(u[12], u[13]); + x[15] = _mm_packs_epi32(u[14], u[15]); + + // stage 3 + u[0] = _mm_unpacklo_epi16(s[4], s[5]); + u[1] = _mm_unpackhi_epi16(s[4], s[5]); + u[2] = _mm_unpacklo_epi16(s[6], s[7]); + u[3] = _mm_unpackhi_epi16(s[6], s[7]); + u[4] = _mm_unpacklo_epi16(x[12], x[13]); + u[5] = _mm_unpackhi_epi16(x[12], x[13]); + u[6] = _mm_unpacklo_epi16(x[14], x[15]); + u[7] = _mm_unpackhi_epi16(x[14], x[15]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_p08_p24); + v[1] = _mm_madd_epi16(u[1], k__cospi_p08_p24); + v[2] = _mm_madd_epi16(u[0], k__cospi_p24_m08); + v[3] = _mm_madd_epi16(u[1], k__cospi_p24_m08); + v[4] = _mm_madd_epi16(u[2], k__cospi_m24_p08); + v[5] = _mm_madd_epi16(u[3], k__cospi_m24_p08); + v[6] = _mm_madd_epi16(u[2], k__cospi_p08_p24); + v[7] = _mm_madd_epi16(u[3], k__cospi_p08_p24); + v[8] = _mm_madd_epi16(u[4], k__cospi_p08_p24); + v[9] = _mm_madd_epi16(u[5], k__cospi_p08_p24); + v[10] = _mm_madd_epi16(u[4], k__cospi_p24_m08); + v[11] = _mm_madd_epi16(u[5], k__cospi_p24_m08); + v[12] = _mm_madd_epi16(u[6], k__cospi_m24_p08); + v[13] = _mm_madd_epi16(u[7], k__cospi_m24_p08); + v[14] = _mm_madd_epi16(u[6], k__cospi_p08_p24); + v[15] = _mm_madd_epi16(u[7], k__cospi_p08_p24); + + u[0] = _mm_add_epi32(v[0], v[4]); + u[1] = _mm_add_epi32(v[1], v[5]); + u[2] = _mm_add_epi32(v[2], v[6]); + u[3] = _mm_add_epi32(v[3], v[7]); + u[4] = _mm_sub_epi32(v[0], v[4]); + u[5] = _mm_sub_epi32(v[1], v[5]); + u[6] = _mm_sub_epi32(v[2], v[6]); + u[7] = _mm_sub_epi32(v[3], v[7]); + u[8] = _mm_add_epi32(v[8], v[12]); + u[9] = _mm_add_epi32(v[9], v[13]); + u[10] = _mm_add_epi32(v[10], v[14]); + u[11] = _mm_add_epi32(v[11], v[15]); + u[12] = _mm_sub_epi32(v[8], v[12]); + u[13] = _mm_sub_epi32(v[9], v[13]); + u[14] = _mm_sub_epi32(v[10], v[14]); + u[15] = _mm_sub_epi32(v[11], v[15]); + + u[0] = _mm_add_epi32(u[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(u[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(u[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(u[3], k__DCT_CONST_ROUNDING); + u[4] = _mm_add_epi32(u[4], k__DCT_CONST_ROUNDING); + u[5] = _mm_add_epi32(u[5], k__DCT_CONST_ROUNDING); + u[6] = _mm_add_epi32(u[6], k__DCT_CONST_ROUNDING); + u[7] = _mm_add_epi32(u[7], k__DCT_CONST_ROUNDING); + u[8] = _mm_add_epi32(u[8], k__DCT_CONST_ROUNDING); + u[9] = _mm_add_epi32(u[9], k__DCT_CONST_ROUNDING); + u[10] = _mm_add_epi32(u[10], k__DCT_CONST_ROUNDING); + u[11] = _mm_add_epi32(u[11], k__DCT_CONST_ROUNDING); + u[12] = _mm_add_epi32(u[12], k__DCT_CONST_ROUNDING); + u[13] = _mm_add_epi32(u[13], k__DCT_CONST_ROUNDING); + u[14] = _mm_add_epi32(u[14], k__DCT_CONST_ROUNDING); + u[15] = _mm_add_epi32(u[15], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + v[4] = _mm_srai_epi32(u[4], DCT_CONST_BITS); + v[5] = _mm_srai_epi32(u[5], DCT_CONST_BITS); + v[6] = _mm_srai_epi32(u[6], DCT_CONST_BITS); + v[7] = _mm_srai_epi32(u[7], DCT_CONST_BITS); + v[8] = _mm_srai_epi32(u[8], DCT_CONST_BITS); + v[9] = _mm_srai_epi32(u[9], DCT_CONST_BITS); + v[10] = _mm_srai_epi32(u[10], DCT_CONST_BITS); + v[11] = _mm_srai_epi32(u[11], DCT_CONST_BITS); + v[12] = _mm_srai_epi32(u[12], DCT_CONST_BITS); + v[13] = _mm_srai_epi32(u[13], DCT_CONST_BITS); + v[14] = _mm_srai_epi32(u[14], DCT_CONST_BITS); + v[15] = _mm_srai_epi32(u[15], DCT_CONST_BITS); + + s[4] = _mm_packs_epi32(v[0], v[1]); + s[5] = _mm_packs_epi32(v[2], v[3]); + s[6] = _mm_packs_epi32(v[4], v[5]); + s[7] = _mm_packs_epi32(v[6], v[7]); + + s[12] = _mm_packs_epi32(v[8], v[9]); + s[13] = _mm_packs_epi32(v[10], v[11]); + s[14] = _mm_packs_epi32(v[12], v[13]); + s[15] = _mm_packs_epi32(v[14], v[15]); + + // stage 4 + u[0] = _mm_unpacklo_epi16(s[2], s[3]); + u[1] = _mm_unpackhi_epi16(s[2], s[3]); + u[2] = _mm_unpacklo_epi16(s[6], s[7]); + u[3] = _mm_unpackhi_epi16(s[6], s[7]); + u[4] = _mm_unpacklo_epi16(s[10], s[11]); + u[5] = _mm_unpackhi_epi16(s[10], s[11]); + u[6] = _mm_unpacklo_epi16(s[14], s[15]); + u[7] = _mm_unpackhi_epi16(s[14], s[15]); + + v[0] = _mm_madd_epi16(u[0], k__cospi_m16_m16); + v[1] = _mm_madd_epi16(u[1], k__cospi_m16_m16); + v[2] = _mm_madd_epi16(u[0], k__cospi_p16_m16); + v[3] = _mm_madd_epi16(u[1], k__cospi_p16_m16); + v[4] = _mm_madd_epi16(u[2], k__cospi_p16_p16); + v[5] = _mm_madd_epi16(u[3], k__cospi_p16_p16); + v[6] = _mm_madd_epi16(u[2], k__cospi_m16_p16); + v[7] = _mm_madd_epi16(u[3], k__cospi_m16_p16); + v[8] = _mm_madd_epi16(u[4], k__cospi_p16_p16); + v[9] = _mm_madd_epi16(u[5], k__cospi_p16_p16); + v[10] = _mm_madd_epi16(u[4], k__cospi_m16_p16); + v[11] = _mm_madd_epi16(u[5], k__cospi_m16_p16); + v[12] = _mm_madd_epi16(u[6], k__cospi_m16_m16); + v[13] = _mm_madd_epi16(u[7], k__cospi_m16_m16); + v[14] = _mm_madd_epi16(u[6], k__cospi_p16_m16); + v[15] = _mm_madd_epi16(u[7], k__cospi_p16_m16); + + u[0] = _mm_add_epi32(v[0], k__DCT_CONST_ROUNDING); + u[1] = _mm_add_epi32(v[1], k__DCT_CONST_ROUNDING); + u[2] = _mm_add_epi32(v[2], k__DCT_CONST_ROUNDING); + u[3] = _mm_add_epi32(v[3], k__DCT_CONST_ROUNDING); + u[4] = _mm_add_epi32(v[4], k__DCT_CONST_ROUNDING); + u[5] = _mm_add_epi32(v[5], k__DCT_CONST_ROUNDING); + u[6] = _mm_add_epi32(v[6], k__DCT_CONST_ROUNDING); + u[7] = _mm_add_epi32(v[7], k__DCT_CONST_ROUNDING); + u[8] = _mm_add_epi32(v[8], k__DCT_CONST_ROUNDING); + u[9] = _mm_add_epi32(v[9], k__DCT_CONST_ROUNDING); + u[10] = _mm_add_epi32(v[10], k__DCT_CONST_ROUNDING); + u[11] = _mm_add_epi32(v[11], k__DCT_CONST_ROUNDING); + u[12] = _mm_add_epi32(v[12], k__DCT_CONST_ROUNDING); + u[13] = _mm_add_epi32(v[13], k__DCT_CONST_ROUNDING); + u[14] = _mm_add_epi32(v[14], k__DCT_CONST_ROUNDING); + u[15] = _mm_add_epi32(v[15], k__DCT_CONST_ROUNDING); + + v[0] = _mm_srai_epi32(u[0], DCT_CONST_BITS); + v[1] = _mm_srai_epi32(u[1], DCT_CONST_BITS); + v[2] = _mm_srai_epi32(u[2], DCT_CONST_BITS); + v[3] = _mm_srai_epi32(u[3], DCT_CONST_BITS); + v[4] = _mm_srai_epi32(u[4], DCT_CONST_BITS); + v[5] = _mm_srai_epi32(u[5], DCT_CONST_BITS); + v[6] = _mm_srai_epi32(u[6], DCT_CONST_BITS); + v[7] = _mm_srai_epi32(u[7], DCT_CONST_BITS); + v[8] = _mm_srai_epi32(u[8], DCT_CONST_BITS); + v[9] = _mm_srai_epi32(u[9], DCT_CONST_BITS); + v[10] = _mm_srai_epi32(u[10], DCT_CONST_BITS); + v[11] = _mm_srai_epi32(u[11], DCT_CONST_BITS); + v[12] = _mm_srai_epi32(u[12], DCT_CONST_BITS); + v[13] = _mm_srai_epi32(u[13], DCT_CONST_BITS); + v[14] = _mm_srai_epi32(u[14], DCT_CONST_BITS); + v[15] = _mm_srai_epi32(u[15], DCT_CONST_BITS); + + in[0] = s[0]; + in[1] = _mm_sub_epi16(kZero, s[8]); + in[2] = s[12]; + in[3] = _mm_sub_epi16(kZero, s[4]); + in[4] = _mm_packs_epi32(v[4], v[5]); + in[5] = _mm_packs_epi32(v[12], v[13]); + in[6] = _mm_packs_epi32(v[8], v[9]); + in[7] = _mm_packs_epi32(v[0], v[1]); + in[8] = _mm_packs_epi32(v[2], v[3]); + in[9] = _mm_packs_epi32(v[10], v[11]); + in[10] = _mm_packs_epi32(v[14], v[15]); + in[11] = _mm_packs_epi32(v[6], v[7]); + in[12] = s[5]; + in[13] = _mm_sub_epi16(kZero, s[13]); + in[14] = s[9]; + in[15] = _mm_sub_epi16(kZero, s[1]); +} + +static void fdct16_sse2(__m128i *in0, __m128i *in1) { + fdct16_8col(in0); + fdct16_8col(in1); + array_transpose_16x16(in0, in1); +} + +static void fadst16_sse2(__m128i *in0, __m128i *in1) { + fadst16_8col(in0); + fadst16_8col(in1); + array_transpose_16x16(in0, in1); +} + +#if CONFIG_EXT_TX +static void fidtx16_sse2(__m128i *in0, __m128i *in1) { + idtx16_8col(in0); + idtx16_8col(in1); + array_transpose_16x16(in0, in1); +} +#endif // CONFIG_EXT_TX + +void av1_fht16x16_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in0[16], in1[16]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fdct16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fdct16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case ADST_DCT: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fdct16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case DCT_ADST: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fdct16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case ADST_ADST: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_16x16(input, in0, in1, stride, 1, 0); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fdct16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case DCT_FLIPADST: + load_buffer_16x16(input, in0, in1, stride, 0, 1); + fdct16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case FLIPADST_FLIPADST: + load_buffer_16x16(input, in0, in1, stride, 1, 1); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case ADST_FLIPADST: + load_buffer_16x16(input, in0, in1, stride, 0, 1); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case FLIPADST_ADST: + load_buffer_16x16(input, in0, in1, stride, 1, 0); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case IDTX: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fidtx16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fidtx16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case V_DCT: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fdct16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fidtx16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case H_DCT: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fidtx16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fdct16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case V_ADST: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fidtx16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case H_ADST: + load_buffer_16x16(input, in0, in1, stride, 0, 0); + fidtx16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case V_FLIPADST: + load_buffer_16x16(input, in0, in1, stride, 1, 0); + fadst16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fidtx16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; + case H_FLIPADST: + load_buffer_16x16(input, in0, in1, stride, 0, 1); + fidtx16_sse2(in0, in1); + right_shift_16x16(in0, in1); + fadst16_sse2(in0, in1); + write_buffer_16x16(output, in0, in1, 16); + break; +#endif // CONFIG_EXT_TX + default: assert(0); break; + } +} + +static INLINE void prepare_4x8_row_first(__m128i *in) { + in[0] = _mm_unpacklo_epi64(in[0], in[2]); + in[1] = _mm_unpacklo_epi64(in[1], in[3]); + transpose_4x4(in); + in[4] = _mm_unpacklo_epi64(in[4], in[6]); + in[5] = _mm_unpacklo_epi64(in[5], in[7]); + transpose_4x4(in + 4); +} + +// Load input into the left-hand half of in (ie, into lanes 0..3 of +// each element of in). The right hand half (lanes 4..7) should be +// treated as being filled with "don't care" values. +static INLINE void load_buffer_4x8(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr) { + const int shift = 2; + if (!flipud) { + in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + in[2] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride)); + in[3] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride)); + in[4] = _mm_loadl_epi64((const __m128i *)(input + 4 * stride)); + in[5] = _mm_loadl_epi64((const __m128i *)(input + 5 * stride)); + in[6] = _mm_loadl_epi64((const __m128i *)(input + 6 * stride)); + in[7] = _mm_loadl_epi64((const __m128i *)(input + 7 * stride)); + } else { + in[0] = _mm_loadl_epi64((const __m128i *)(input + 7 * stride)); + in[1] = _mm_loadl_epi64((const __m128i *)(input + 6 * stride)); + in[2] = _mm_loadl_epi64((const __m128i *)(input + 5 * stride)); + in[3] = _mm_loadl_epi64((const __m128i *)(input + 4 * stride)); + in[4] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride)); + in[5] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride)); + in[6] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + in[7] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + } + + if (fliplr) { + in[0] = _mm_shufflelo_epi16(in[0], 0x1b); + in[1] = _mm_shufflelo_epi16(in[1], 0x1b); + in[2] = _mm_shufflelo_epi16(in[2], 0x1b); + in[3] = _mm_shufflelo_epi16(in[3], 0x1b); + in[4] = _mm_shufflelo_epi16(in[4], 0x1b); + in[5] = _mm_shufflelo_epi16(in[5], 0x1b); + in[6] = _mm_shufflelo_epi16(in[6], 0x1b); + in[7] = _mm_shufflelo_epi16(in[7], 0x1b); + } + + in[0] = _mm_slli_epi16(in[0], shift); + in[1] = _mm_slli_epi16(in[1], shift); + in[2] = _mm_slli_epi16(in[2], shift); + in[3] = _mm_slli_epi16(in[3], shift); + in[4] = _mm_slli_epi16(in[4], shift); + in[5] = _mm_slli_epi16(in[5], shift); + in[6] = _mm_slli_epi16(in[6], shift); + in[7] = _mm_slli_epi16(in[7], shift); + + scale_sqrt2_8x4(in); + scale_sqrt2_8x4(in + 4); + prepare_4x8_row_first(in); +} + +static INLINE void write_buffer_4x8(tran_low_t *output, __m128i *res) { + __m128i in01, in23, in45, in67, sign01, sign23, sign45, sign67; + const int shift = 1; + + // revert the 8x8 txfm's transpose + array_transpose_8x8(res, res); + + in01 = _mm_unpacklo_epi64(res[0], res[1]); + in23 = _mm_unpacklo_epi64(res[2], res[3]); + in45 = _mm_unpacklo_epi64(res[4], res[5]); + in67 = _mm_unpacklo_epi64(res[6], res[7]); + + sign01 = _mm_srai_epi16(in01, 15); + sign23 = _mm_srai_epi16(in23, 15); + sign45 = _mm_srai_epi16(in45, 15); + sign67 = _mm_srai_epi16(in67, 15); + + in01 = _mm_sub_epi16(in01, sign01); + in23 = _mm_sub_epi16(in23, sign23); + in45 = _mm_sub_epi16(in45, sign45); + in67 = _mm_sub_epi16(in67, sign67); + + in01 = _mm_srai_epi16(in01, shift); + in23 = _mm_srai_epi16(in23, shift); + in45 = _mm_srai_epi16(in45, shift); + in67 = _mm_srai_epi16(in67, shift); + + store_output(&in01, (output + 0 * 8)); + store_output(&in23, (output + 1 * 8)); + store_output(&in45, (output + 2 * 8)); + store_output(&in67, (output + 3 * 8)); +} + +void av1_fht4x8_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in[8]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_4x8(input, in, stride, 0, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fdct8_sse2(in); + break; + case ADST_DCT: + load_buffer_4x8(input, in, stride, 0, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fadst8_sse2(in); + break; + case DCT_ADST: + load_buffer_4x8(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fdct8_sse2(in); + break; + case ADST_ADST: + load_buffer_4x8(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_4x8(input, in, stride, 1, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fadst8_sse2(in); + break; + case DCT_FLIPADST: + load_buffer_4x8(input, in, stride, 0, 1); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fdct8_sse2(in); + break; + case FLIPADST_FLIPADST: + load_buffer_4x8(input, in, stride, 1, 1); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; + case ADST_FLIPADST: + load_buffer_4x8(input, in, stride, 0, 1); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; + case FLIPADST_ADST: + load_buffer_4x8(input, in, stride, 1, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; + case IDTX: + load_buffer_4x8(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fidtx8_sse2(in); + break; + case V_DCT: + load_buffer_4x8(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fdct8_sse2(in); + break; + case H_DCT: + load_buffer_4x8(input, in, stride, 0, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fidtx8_sse2(in); + break; + case V_ADST: + load_buffer_4x8(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fadst8_sse2(in); + break; + case H_ADST: + load_buffer_4x8(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fidtx8_sse2(in); + break; + case V_FLIPADST: + load_buffer_4x8(input, in, stride, 1, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fadst8_sse2(in); + break; + case H_FLIPADST: + load_buffer_4x8(input, in, stride, 0, 1); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fidtx8_sse2(in); + break; +#endif + default: assert(0); break; + } + write_buffer_4x8(output, in); +} + +// Load input into the left-hand half of in (ie, into lanes 0..3 of +// each element of in). The right hand half (lanes 4..7) should be +// treated as being filled with "don't care" values. +// The input is split horizontally into two 4x4 +// chunks 'l' and 'r'. Then 'l' is stored in the top-left 4x4 +// block of 'in' and 'r' is stored in the bottom-left block. +// This is to allow us to reuse 4x4 transforms. +static INLINE void load_buffer_8x4(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr) { + const int shift = 2; + if (!flipud) { + in[0] = _mm_loadu_si128((const __m128i *)(input + 0 * stride)); + in[1] = _mm_loadu_si128((const __m128i *)(input + 1 * stride)); + in[2] = _mm_loadu_si128((const __m128i *)(input + 2 * stride)); + in[3] = _mm_loadu_si128((const __m128i *)(input + 3 * stride)); + } else { + in[0] = _mm_loadu_si128((const __m128i *)(input + 3 * stride)); + in[1] = _mm_loadu_si128((const __m128i *)(input + 2 * stride)); + in[2] = _mm_loadu_si128((const __m128i *)(input + 1 * stride)); + in[3] = _mm_loadu_si128((const __m128i *)(input + 0 * stride)); + } + + if (fliplr) { + in[0] = mm_reverse_epi16(in[0]); + in[1] = mm_reverse_epi16(in[1]); + in[2] = mm_reverse_epi16(in[2]); + in[3] = mm_reverse_epi16(in[3]); + } + + in[0] = _mm_slli_epi16(in[0], shift); + in[1] = _mm_slli_epi16(in[1], shift); + in[2] = _mm_slli_epi16(in[2], shift); + in[3] = _mm_slli_epi16(in[3], shift); + + scale_sqrt2_8x4(in); + + in[4] = _mm_shuffle_epi32(in[0], 0xe); + in[5] = _mm_shuffle_epi32(in[1], 0xe); + in[6] = _mm_shuffle_epi32(in[2], 0xe); + in[7] = _mm_shuffle_epi32(in[3], 0xe); +} + +static INLINE void write_buffer_8x4(tran_low_t *output, __m128i *res) { + __m128i out0, out1, out2, out3, sign0, sign1, sign2, sign3; + const int shift = 1; + sign0 = _mm_srai_epi16(res[0], 15); + sign1 = _mm_srai_epi16(res[1], 15); + sign2 = _mm_srai_epi16(res[2], 15); + sign3 = _mm_srai_epi16(res[3], 15); + + out0 = _mm_sub_epi16(res[0], sign0); + out1 = _mm_sub_epi16(res[1], sign1); + out2 = _mm_sub_epi16(res[2], sign2); + out3 = _mm_sub_epi16(res[3], sign3); + + out0 = _mm_srai_epi16(out0, shift); + out1 = _mm_srai_epi16(out1, shift); + out2 = _mm_srai_epi16(out2, shift); + out3 = _mm_srai_epi16(out3, shift); + + store_output(&out0, (output + 0 * 8)); + store_output(&out1, (output + 1 * 8)); + store_output(&out2, (output + 2 * 8)); + store_output(&out3, (output + 3 * 8)); +} + +void av1_fht8x4_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in[8]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_8x4(input, in, stride, 0, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fdct8_sse2(in); + break; + case ADST_DCT: + load_buffer_8x4(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fdct8_sse2(in); + break; + case DCT_ADST: + load_buffer_8x4(input, in, stride, 0, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fadst8_sse2(in); + break; + case ADST_ADST: + load_buffer_8x4(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_8x4(input, in, stride, 1, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fdct8_sse2(in); + break; + case DCT_FLIPADST: + load_buffer_8x4(input, in, stride, 0, 1); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fadst8_sse2(in); + break; + case FLIPADST_FLIPADST: + load_buffer_8x4(input, in, stride, 1, 1); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; + case ADST_FLIPADST: + load_buffer_8x4(input, in, stride, 0, 1); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; + case FLIPADST_ADST: + load_buffer_8x4(input, in, stride, 1, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fadst8_sse2(in); + break; + case IDTX: + load_buffer_8x4(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fidtx8_sse2(in); + break; + case V_DCT: + load_buffer_8x4(input, in, stride, 0, 0); + fdct4_sse2(in); + fdct4_sse2(in + 4); + fidtx8_sse2(in); + break; + case H_DCT: + load_buffer_8x4(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fdct8_sse2(in); + break; + case V_ADST: + load_buffer_8x4(input, in, stride, 0, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fidtx8_sse2(in); + break; + case H_ADST: + load_buffer_8x4(input, in, stride, 0, 0); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fadst8_sse2(in); + break; + case V_FLIPADST: + load_buffer_8x4(input, in, stride, 1, 0); + fadst4_sse2(in); + fadst4_sse2(in + 4); + fidtx8_sse2(in); + break; + case H_FLIPADST: + load_buffer_8x4(input, in, stride, 0, 1); + fidtx4_sse2(in); + fidtx4_sse2(in + 4); + fadst8_sse2(in); + break; +#endif + default: assert(0); break; + } + write_buffer_8x4(output, in); +} + +static INLINE void load_buffer_8x16(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr) { + // Load 2 8x8 blocks + const int16_t *t = input; + const int16_t *b = input + 8 * stride; + + if (flipud) { + const int16_t *const tmp = t; + t = b; + b = tmp; + } + + load_buffer_8x8(t, in, stride, flipud, fliplr); + scale_sqrt2_8x8(in); + load_buffer_8x8(b, in + 8, stride, flipud, fliplr); + scale_sqrt2_8x8(in + 8); +} + +static INLINE void round_power_of_two_signed(__m128i *x, int n) { + const __m128i rounding = _mm_set1_epi16((1 << n) >> 1); + const __m128i sign = _mm_srai_epi16(*x, 15); + const __m128i res = _mm_add_epi16(_mm_add_epi16(*x, rounding), sign); + *x = _mm_srai_epi16(res, n); +} + +static void row_8x16_rounding(__m128i *in, int bits) { + int i; + for (i = 0; i < 16; i++) { + round_power_of_two_signed(&in[i], bits); + } +} + +void av1_fht8x16_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in[16]; + + __m128i *const t = in; // Alias to top 8x8 sub block + __m128i *const b = in + 8; // Alias to bottom 8x8 sub block + + switch (tx_type) { + case DCT_DCT: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fdct8_sse2(t); + fdct8_sse2(b); + row_8x16_rounding(in, 2); + fdct16_8col(in); + break; + case ADST_DCT: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fdct8_sse2(t); + fdct8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case DCT_ADST: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + fdct16_8col(in); + break; + case ADST_ADST: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_8x16(input, in, stride, 1, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fdct8_sse2(t); + fdct8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case DCT_FLIPADST: + load_buffer_8x16(input, in, stride, 0, 1); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + fdct16_8col(in); + break; + case FLIPADST_FLIPADST: + load_buffer_8x16(input, in, stride, 1, 1); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case ADST_FLIPADST: + load_buffer_8x16(input, in, stride, 0, 1); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case FLIPADST_ADST: + load_buffer_8x16(input, in, stride, 1, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case IDTX: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fidtx8_sse2(t); + fidtx8_sse2(b); + row_8x16_rounding(in, 2); + idtx16_8col(in); + break; + case V_DCT: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fidtx8_sse2(t); + fidtx8_sse2(b); + row_8x16_rounding(in, 2); + fdct16_8col(in); + break; + case H_DCT: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fdct8_sse2(t); + fdct8_sse2(b); + row_8x16_rounding(in, 2); + idtx16_8col(in); + break; + case V_ADST: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fidtx8_sse2(t); + fidtx8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case H_ADST: + load_buffer_8x16(input, in, stride, 0, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + idtx16_8col(in); + break; + case V_FLIPADST: + load_buffer_8x16(input, in, stride, 1, 0); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fidtx8_sse2(t); + fidtx8_sse2(b); + row_8x16_rounding(in, 2); + fadst16_8col(in); + break; + case H_FLIPADST: + load_buffer_8x16(input, in, stride, 0, 1); + array_transpose_8x8(t, t); + array_transpose_8x8(b, b); + fadst8_sse2(t); + fadst8_sse2(b); + row_8x16_rounding(in, 2); + idtx16_8col(in); + break; +#endif + default: assert(0); break; + } + write_buffer_8x8(output, t, 8); + write_buffer_8x8(output + 64, b, 8); +} + +static INLINE void load_buffer_16x8(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr) { + // Load 2 8x8 blocks + const int16_t *l = input; + const int16_t *r = input + 8; + + if (fliplr) { + const int16_t *const tmp = l; + l = r; + r = tmp; + } + + // load first 8 columns + load_buffer_8x8(l, in, stride, flipud, fliplr); + scale_sqrt2_8x8(in); + load_buffer_8x8(r, in + 8, stride, flipud, fliplr); + scale_sqrt2_8x8(in + 8); +} + +#define col_16x8_rounding row_8x16_rounding + +void av1_fht16x8_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in[16]; + + __m128i *const l = in; // Alias to left 8x8 sub block + __m128i *const r = in + 8; // Alias to right 8x8 sub block, which we store + // in the second half of the array + + switch (tx_type) { + case DCT_DCT: + load_buffer_16x8(input, in, stride, 0, 0); + fdct8_sse2(l); + fdct8_sse2(r); + col_16x8_rounding(in, 2); + fdct16_8col(in); + break; + case ADST_DCT: + load_buffer_16x8(input, in, stride, 0, 0); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + fdct16_8col(in); + break; + case DCT_ADST: + load_buffer_16x8(input, in, stride, 0, 0); + fdct8_sse2(l); + fdct8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; + case ADST_ADST: + load_buffer_16x8(input, in, stride, 0, 0); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_16x8(input, in, stride, 1, 0); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + fdct16_8col(in); + break; + case DCT_FLIPADST: + load_buffer_16x8(input, in, stride, 0, 1); + fdct8_sse2(l); + fdct8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; + case FLIPADST_FLIPADST: + load_buffer_16x8(input, in, stride, 1, 1); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; + case ADST_FLIPADST: + load_buffer_16x8(input, in, stride, 0, 1); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; + case FLIPADST_ADST: + load_buffer_16x8(input, in, stride, 1, 0); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; + case IDTX: + load_buffer_16x8(input, in, stride, 0, 0); + fidtx8_sse2(l); + fidtx8_sse2(r); + col_16x8_rounding(in, 2); + idtx16_8col(in); + break; + case V_DCT: + load_buffer_16x8(input, in, stride, 0, 0); + fdct8_sse2(l); + fdct8_sse2(r); + col_16x8_rounding(in, 2); + idtx16_8col(in); + break; + case H_DCT: + load_buffer_16x8(input, in, stride, 0, 0); + fidtx8_sse2(l); + fidtx8_sse2(r); + col_16x8_rounding(in, 2); + fdct16_8col(in); + break; + case V_ADST: + load_buffer_16x8(input, in, stride, 0, 0); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + idtx16_8col(in); + break; + case H_ADST: + load_buffer_16x8(input, in, stride, 0, 0); + fidtx8_sse2(l); + fidtx8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; + case V_FLIPADST: + load_buffer_16x8(input, in, stride, 1, 0); + fadst8_sse2(l); + fadst8_sse2(r); + col_16x8_rounding(in, 2); + idtx16_8col(in); + break; + case H_FLIPADST: + load_buffer_16x8(input, in, stride, 0, 1); + fidtx8_sse2(l); + fidtx8_sse2(r); + col_16x8_rounding(in, 2); + fadst16_8col(in); + break; +#endif + default: assert(0); break; + } + array_transpose_8x8(l, l); + array_transpose_8x8(r, r); + write_buffer_8x8(output, l, 16); + write_buffer_8x8(output + 8, r, 16); +} + +// Note: The 16-column 32-element transforms expect their input to be +// split up into a 2x2 grid of 8x16 blocks +static INLINE void fdct32_16col(__m128i *tl, __m128i *tr, __m128i *bl, + __m128i *br) { + fdct32_8col(tl, bl); + fdct32_8col(tr, br); + array_transpose_16x16(tl, tr); + array_transpose_16x16(bl, br); +} + +#if CONFIG_EXT_TX +static INLINE void fidtx32_16col(__m128i *tl, __m128i *tr, __m128i *bl, + __m128i *br) { + int i; + for (i = 0; i < 16; ++i) { + tl[i] = _mm_slli_epi16(tl[i], 2); + tr[i] = _mm_slli_epi16(tr[i], 2); + bl[i] = _mm_slli_epi16(bl[i], 2); + br[i] = _mm_slli_epi16(br[i], 2); + } + array_transpose_16x16(tl, tr); + array_transpose_16x16(bl, br); +} +#endif + +static INLINE void load_buffer_16x32(const int16_t *input, __m128i *intl, + __m128i *intr, __m128i *inbl, + __m128i *inbr, int stride, int flipud, + int fliplr) { + int i; + if (flipud) { + input = input + 31 * stride; + stride = -stride; + } + + for (i = 0; i < 16; ++i) { + intl[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 0)), 2); + intr[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 8)), 2); + inbl[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + (i + 16) * stride + 0)), 2); + inbr[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + (i + 16) * stride + 8)), 2); + } + + if (fliplr) { + __m128i tmp; + for (i = 0; i < 16; ++i) { + tmp = intl[i]; + intl[i] = mm_reverse_epi16(intr[i]); + intr[i] = mm_reverse_epi16(tmp); + tmp = inbl[i]; + inbl[i] = mm_reverse_epi16(inbr[i]); + inbr[i] = mm_reverse_epi16(tmp); + } + } + + scale_sqrt2_8x16(intl); + scale_sqrt2_8x16(intr); + scale_sqrt2_8x16(inbl); + scale_sqrt2_8x16(inbr); +} + +static INLINE void write_buffer_16x32(tran_low_t *output, __m128i *restl, + __m128i *restr, __m128i *resbl, + __m128i *resbr) { + int i; + for (i = 0; i < 16; ++i) { + store_output(&restl[i], output + i * 16 + 0); + store_output(&restr[i], output + i * 16 + 8); + store_output(&resbl[i], output + (i + 16) * 16 + 0); + store_output(&resbr[i], output + (i + 16) * 16 + 8); + } +} + +static INLINE void round_signed_8x8(__m128i *in, const int bit) { + const __m128i rounding = _mm_set1_epi16((1 << bit) >> 1); + __m128i sign0 = _mm_srai_epi16(in[0], 15); + __m128i sign1 = _mm_srai_epi16(in[1], 15); + __m128i sign2 = _mm_srai_epi16(in[2], 15); + __m128i sign3 = _mm_srai_epi16(in[3], 15); + __m128i sign4 = _mm_srai_epi16(in[4], 15); + __m128i sign5 = _mm_srai_epi16(in[5], 15); + __m128i sign6 = _mm_srai_epi16(in[6], 15); + __m128i sign7 = _mm_srai_epi16(in[7], 15); + + in[0] = _mm_add_epi16(_mm_add_epi16(in[0], rounding), sign0); + in[1] = _mm_add_epi16(_mm_add_epi16(in[1], rounding), sign1); + in[2] = _mm_add_epi16(_mm_add_epi16(in[2], rounding), sign2); + in[3] = _mm_add_epi16(_mm_add_epi16(in[3], rounding), sign3); + in[4] = _mm_add_epi16(_mm_add_epi16(in[4], rounding), sign4); + in[5] = _mm_add_epi16(_mm_add_epi16(in[5], rounding), sign5); + in[6] = _mm_add_epi16(_mm_add_epi16(in[6], rounding), sign6); + in[7] = _mm_add_epi16(_mm_add_epi16(in[7], rounding), sign7); + + in[0] = _mm_srai_epi16(in[0], bit); + in[1] = _mm_srai_epi16(in[1], bit); + in[2] = _mm_srai_epi16(in[2], bit); + in[3] = _mm_srai_epi16(in[3], bit); + in[4] = _mm_srai_epi16(in[4], bit); + in[5] = _mm_srai_epi16(in[5], bit); + in[6] = _mm_srai_epi16(in[6], bit); + in[7] = _mm_srai_epi16(in[7], bit); +} + +static INLINE void round_signed_16x16(__m128i *in0, __m128i *in1) { + const int bit = 4; + round_signed_8x8(in0, bit); + round_signed_8x8(in0 + 8, bit); + round_signed_8x8(in1, bit); + round_signed_8x8(in1 + 8, bit); +} + +// Note: +// suffix "t" indicates the transpose operation comes first +static void fdct16t_sse2(__m128i *in0, __m128i *in1) { + array_transpose_16x16(in0, in1); + fdct16_8col(in0); + fdct16_8col(in1); +} + +static void fadst16t_sse2(__m128i *in0, __m128i *in1) { + array_transpose_16x16(in0, in1); + fadst16_8col(in0); + fadst16_8col(in1); +} + +static INLINE void fdct32t_16col(__m128i *tl, __m128i *tr, __m128i *bl, + __m128i *br) { + array_transpose_16x16(tl, tr); + array_transpose_16x16(bl, br); + fdct32_8col(tl, bl); + fdct32_8col(tr, br); +} + +typedef enum transpose_indicator_ { + transpose, + no_transpose, +} transpose_indicator; + +static INLINE void fhalfright32_16col(__m128i *tl, __m128i *tr, __m128i *bl, + __m128i *br, transpose_indicator t) { + __m128i tmpl[16], tmpr[16]; + int i; + + // Copy the bottom half of the input to temporary storage + for (i = 0; i < 16; ++i) { + tmpl[i] = bl[i]; + tmpr[i] = br[i]; + } + + // Generate the bottom half of the output + for (i = 0; i < 16; ++i) { + bl[i] = _mm_slli_epi16(tl[i], 2); + br[i] = _mm_slli_epi16(tr[i], 2); + } + array_transpose_16x16(bl, br); + + // Copy the temporary storage back to the top half of the input + for (i = 0; i < 16; ++i) { + tl[i] = tmpl[i]; + tr[i] = tmpr[i]; + } + + // Generate the top half of the output + scale_sqrt2_8x16(tl); + scale_sqrt2_8x16(tr); + if (t == transpose) + fdct16t_sse2(tl, tr); + else + fdct16_sse2(tl, tr); +} + +// Note on data layout, for both this and the 32x16 transforms: +// So that we can reuse the 16-element transforms easily, +// we want to split the input into 8x16 blocks. +// For 16x32, this means the input is a 2x2 grid of such blocks. +// For 32x16, it means the input is a 4x1 grid. +void av1_fht16x32_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i intl[16], intr[16], inbl[16], inbr[16]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fdct16t_sse2(intl, intr); + fdct16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fdct32t_16col(intl, intr, inbl, inbr); + break; + case ADST_DCT: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fdct16t_sse2(intl, intr); + fdct16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case DCT_ADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fdct32t_16col(intl, intr, inbl, inbr); + break; + case ADST_ADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 0); + fdct16t_sse2(intl, intr); + fdct16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case DCT_FLIPADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 1); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fdct32t_16col(intl, intr, inbl, inbr); + break; + case FLIPADST_FLIPADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 1); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case ADST_FLIPADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 1); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case FLIPADST_ADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 0); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case IDTX: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fidtx16_sse2(intl, intr); + fidtx16_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fidtx32_16col(intl, intr, inbl, inbr); + break; + case V_DCT: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fidtx16_sse2(intl, intr); + fidtx16_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fdct32t_16col(intl, intr, inbl, inbr); + break; + case H_DCT: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fdct16t_sse2(intl, intr); + fdct16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fidtx32_16col(intl, intr, inbl, inbr); + break; + case V_ADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fidtx16_sse2(intl, intr); + fidtx16_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case H_ADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 0); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fidtx32_16col(intl, intr, inbl, inbr); + break; + case V_FLIPADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 1, 0); + fidtx16_sse2(intl, intr); + fidtx16_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fhalfright32_16col(intl, intr, inbl, inbr, transpose); + break; + case H_FLIPADST: + load_buffer_16x32(input, intl, intr, inbl, inbr, stride, 0, 1); + fadst16t_sse2(intl, intr); + fadst16t_sse2(inbl, inbr); + round_signed_16x16(intl, intr); + round_signed_16x16(inbl, inbr); + fidtx32_16col(intl, intr, inbl, inbr); + break; +#endif + default: assert(0); break; + } + write_buffer_16x32(output, intl, intr, inbl, inbr); +} + +static INLINE void load_buffer_32x16(const int16_t *input, __m128i *in0, + __m128i *in1, __m128i *in2, __m128i *in3, + int stride, int flipud, int fliplr) { + int i; + if (flipud) { + input += 15 * stride; + stride = -stride; + } + + for (i = 0; i < 16; ++i) { + in0[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 0)), 2); + in1[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 8)), 2); + in2[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 16)), 2); + in3[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 24)), 2); + } + + if (fliplr) { + for (i = 0; i < 16; ++i) { + __m128i tmp1 = in0[i]; + __m128i tmp2 = in1[i]; + in0[i] = mm_reverse_epi16(in3[i]); + in1[i] = mm_reverse_epi16(in2[i]); + in2[i] = mm_reverse_epi16(tmp2); + in3[i] = mm_reverse_epi16(tmp1); + } + } + + scale_sqrt2_8x16(in0); + scale_sqrt2_8x16(in1); + scale_sqrt2_8x16(in2); + scale_sqrt2_8x16(in3); +} + +static INLINE void write_buffer_32x16(tran_low_t *output, __m128i *res0, + __m128i *res1, __m128i *res2, + __m128i *res3) { + int i; + for (i = 0; i < 16; ++i) { + store_output(&res0[i], output + i * 32 + 0); + store_output(&res1[i], output + i * 32 + 8); + store_output(&res2[i], output + i * 32 + 16); + store_output(&res3[i], output + i * 32 + 24); + } +} + +void av1_fht32x16_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in0[16], in1[16], in2[16], in3[16]; + + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0); + switch (tx_type) { + case DCT_DCT: + fdct16_sse2(in0, in1); + fdct16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fdct32_16col(in0, in1, in2, in3); + break; + case ADST_DCT: + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fdct32_16col(in0, in1, in2, in3); + break; + case DCT_ADST: + fdct16_sse2(in0, in1); + fdct16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; + case ADST_ADST: + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 0); + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fdct32_16col(in0, in1, in2, in3); + break; + case DCT_FLIPADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 1); + fdct16_sse2(in0, in1); + fdct16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; + case FLIPADST_FLIPADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 1); + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; + case ADST_FLIPADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 1); + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; + case FLIPADST_ADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 0); + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; + case IDTX: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0); + fidtx16_sse2(in0, in1); + fidtx16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fidtx32_16col(in0, in1, in2, in3); + break; + case V_DCT: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0); + fdct16_sse2(in0, in1); + fdct16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fidtx32_16col(in0, in1, in2, in3); + break; + case H_DCT: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0); + fidtx16_sse2(in0, in1); + fidtx16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fdct32_16col(in0, in1, in2, in3); + break; + case V_ADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0); + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fidtx32_16col(in0, in1, in2, in3); + break; + case H_ADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 0); + fidtx16_sse2(in0, in1); + fidtx16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; + case V_FLIPADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 1, 0); + fadst16_sse2(in0, in1); + fadst16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fidtx32_16col(in0, in1, in2, in3); + break; + case H_FLIPADST: + load_buffer_32x16(input, in0, in1, in2, in3, stride, 0, 1); + fidtx16_sse2(in0, in1); + fidtx16_sse2(in2, in3); + round_signed_16x16(in0, in1); + round_signed_16x16(in2, in3); + fhalfright32_16col(in0, in1, in2, in3, no_transpose); + break; +#endif + default: assert(0); break; + } + write_buffer_32x16(output, in0, in1, in2, in3); +} + +// Note: +// 32x32 hybrid fwd txfm +// 4x2 grids of 8x16 block. Each block is represented by __m128i in[16] +static INLINE void load_buffer_32x32(const int16_t *input, + __m128i *in0 /*in0[32]*/, + __m128i *in1 /*in1[32]*/, + __m128i *in2 /*in2[32]*/, + __m128i *in3 /*in3[32]*/, int stride, + int flipud, int fliplr) { + if (flipud) { + input += 31 * stride; + stride = -stride; + } + + int i; + for (i = 0; i < 32; ++i) { + in0[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 0)), 2); + in1[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 8)), 2); + in2[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 16)), 2); + in3[i] = _mm_slli_epi16( + _mm_load_si128((const __m128i *)(input + i * stride + 24)), 2); + } + + if (fliplr) { + for (i = 0; i < 32; ++i) { + __m128i tmp1 = in0[i]; + __m128i tmp2 = in1[i]; + in0[i] = mm_reverse_epi16(in3[i]); + in1[i] = mm_reverse_epi16(in2[i]); + in2[i] = mm_reverse_epi16(tmp2); + in3[i] = mm_reverse_epi16(tmp1); + } + } +} + +static INLINE void swap_16x16(__m128i *b0l /*b0l[16]*/, + __m128i *b0r /*b0r[16]*/, + __m128i *b1l /*b1l[16]*/, + __m128i *b1r /*b1r[16]*/) { + int i; + for (i = 0; i < 16; ++i) { + __m128i tmp0 = b1l[i]; + __m128i tmp1 = b1r[i]; + b1l[i] = b0l[i]; + b1r[i] = b0r[i]; + b0l[i] = tmp0; + b0r[i] = tmp1; + } +} + +static INLINE void fdct32(__m128i *in0, __m128i *in1, __m128i *in2, + __m128i *in3) { + fdct32_8col(in0, &in0[16]); + fdct32_8col(in1, &in1[16]); + fdct32_8col(in2, &in2[16]); + fdct32_8col(in3, &in3[16]); + + array_transpose_16x16(in0, in1); + array_transpose_16x16(&in0[16], &in1[16]); + array_transpose_16x16(in2, in3); + array_transpose_16x16(&in2[16], &in3[16]); + + swap_16x16(&in0[16], &in1[16], in2, in3); +} + +static INLINE void fhalfright32(__m128i *in0, __m128i *in1, __m128i *in2, + __m128i *in3) { + fhalfright32_16col(in0, in1, &in0[16], &in1[16], no_transpose); + fhalfright32_16col(in2, in3, &in2[16], &in3[16], no_transpose); + swap_16x16(&in0[16], &in1[16], in2, in3); +} + +#if CONFIG_EXT_TX +static INLINE void fidtx32(__m128i *in0, __m128i *in1, __m128i *in2, + __m128i *in3) { + fidtx32_16col(in0, in1, &in0[16], &in1[16]); + fidtx32_16col(in2, in3, &in2[16], &in3[16]); + swap_16x16(&in0[16], &in1[16], in2, in3); +} +#endif + +static INLINE void round_signed_32x32(__m128i *in0, __m128i *in1, __m128i *in2, + __m128i *in3) { + round_signed_16x16(in0, in1); + round_signed_16x16(&in0[16], &in1[16]); + round_signed_16x16(in2, in3); + round_signed_16x16(&in2[16], &in3[16]); +} + +static INLINE void write_buffer_32x32(__m128i *in0, __m128i *in1, __m128i *in2, + __m128i *in3, tran_low_t *output) { + int i; + for (i = 0; i < 32; ++i) { + store_output(&in0[i], output + i * 32 + 0); + store_output(&in1[i], output + i * 32 + 8); + store_output(&in2[i], output + i * 32 + 16); + store_output(&in3[i], output + i * 32 + 24); + } +} + +void av1_fht32x32_sse2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m128i in0[32], in1[32], in2[32], in3[32]; + + load_buffer_32x32(input, in0, in1, in2, in3, stride, 0, 0); + switch (tx_type) { + case DCT_DCT: + fdct32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fdct32(in0, in1, in2, in3); + break; + case ADST_DCT: + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fdct32(in0, in1, in2, in3); + break; + case DCT_ADST: + fdct32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; + case ADST_ADST: + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 1, 0); + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fdct32(in0, in1, in2, in3); + break; + case DCT_FLIPADST: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 0, 1); + fdct32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; + case FLIPADST_FLIPADST: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 1, 1); + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; + case ADST_FLIPADST: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 0, 1); + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; + case FLIPADST_ADST: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 1, 0); + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; + case IDTX: + fidtx32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fidtx32(in0, in1, in2, in3); + break; + case V_DCT: + fdct32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fidtx32(in0, in1, in2, in3); + break; + case H_DCT: + fidtx32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fdct32(in0, in1, in2, in3); + break; + case V_ADST: + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fidtx32(in0, in1, in2, in3); + break; + case H_ADST: + fidtx32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; + case V_FLIPADST: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 1, 0); + fhalfright32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fidtx32(in0, in1, in2, in3); + break; + case H_FLIPADST: + load_buffer_32x32(input, in0, in1, in2, in3, stride, 0, 1); + fidtx32(in0, in1, in2, in3); + round_signed_32x32(in0, in1, in2, in3); + fhalfright32(in0, in1, in2, in3); + break; +#endif + default: assert(0); + } + write_buffer_32x32(in0, in1, in2, in3, output); +} diff --git a/third_party/aom/av1/encoder/x86/dct_sse2.asm b/third_party/aom/av1/encoder/x86/dct_sse2.asm new file mode 100644 index 000000000..a99db3d6e --- /dev/null +++ b/third_party/aom/av1/encoder/x86/dct_sse2.asm @@ -0,0 +1,87 @@ +; +; Copyright (c) 2016, Alliance for Open Media. All rights reserved +; +; This source code is subject to the terms of the BSD 2 Clause License and +; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +; was not distributed with this source code in the LICENSE file, you can +; obtain it at www.aomedia.org/license/software. If the Alliance for Open +; Media Patent License 1.0 was not distributed with this source code in the +; PATENTS file, you can obtain it at www.aomedia.org/license/patent. +; + +%define private_prefix av1 + +%include "third_party/x86inc/x86inc.asm" + +SECTION .text + +%macro TRANSFORM_COLS 0 + paddw m0, m1 + movq m4, m0 + psubw m3, m2 + psubw m4, m3 + psraw m4, 1 + movq m5, m4 + psubw m5, m1 ;b1 + psubw m4, m2 ;c1 + psubw m0, m4 + paddw m3, m5 + ; m0 a0 + SWAP 1, 4 ; m1 c1 + SWAP 2, 3 ; m2 d1 + SWAP 3, 5 ; m3 b1 +%endmacro + +%macro TRANSPOSE_4X4 0 + ; 00 01 02 03 + ; 10 11 12 13 + ; 20 21 22 23 + ; 30 31 32 33 + punpcklwd m0, m1 ; 00 10 01 11 02 12 03 13 + punpcklwd m2, m3 ; 20 30 21 31 22 32 23 33 + mova m1, m0 + punpckldq m0, m2 ; 00 10 20 30 01 11 21 31 + punpckhdq m1, m2 ; 02 12 22 32 03 13 23 33 +%endmacro + +INIT_XMM sse2 +cglobal fwht4x4, 3, 4, 8, input, output, stride + lea r3q, [inputq + strideq*4] + movq m0, [inputq] ;a1 + movq m1, [inputq + strideq*2] ;b1 + movq m2, [r3q] ;c1 + movq m3, [r3q + strideq*2] ;d1 + + TRANSFORM_COLS + TRANSPOSE_4X4 + SWAP 1, 2 + psrldq m1, m0, 8 + psrldq m3, m2, 8 + TRANSFORM_COLS + TRANSPOSE_4X4 + + psllw m0, 2 + psllw m1, 2 + +%if CONFIG_HIGHBITDEPTH + ; sign extension + mova m2, m0 + mova m3, m1 + punpcklwd m0, m0 + punpcklwd m1, m1 + punpckhwd m2, m2 + punpckhwd m3, m3 + psrad m0, 16 + psrad m1, 16 + psrad m2, 16 + psrad m3, 16 + mova [outputq], m0 + mova [outputq + 16], m2 + mova [outputq + 32], m1 + mova [outputq + 48], m3 +%else + mova [outputq], m0 + mova [outputq + 16], m1 +%endif + + RET diff --git a/third_party/aom/av1/encoder/x86/dct_ssse3.c b/third_party/aom/av1/encoder/x86/dct_ssse3.c new file mode 100644 index 000000000..717a99af8 --- /dev/null +++ b/third_party/aom/av1/encoder/x86/dct_ssse3.c @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <assert.h> +#if defined(_MSC_VER) && _MSC_VER <= 1500 +// Need to include math.h before calling tmmintrin.h/intrin.h +// in certain versions of MSVS. +#include <math.h> +#endif +#include <tmmintrin.h> // SSSE3 + +#include "./av1_rtcd.h" +#include "aom_dsp/x86/inv_txfm_sse2.h" +#include "aom_dsp/x86/txfm_common_sse2.h" + +void av1_fdct8x8_quant_ssse3( + const int16_t *input, int stride, int16_t *coeff_ptr, intptr_t n_coeffs, + int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, + const int16_t *quant_ptr, const int16_t *quant_shift_ptr, + int16_t *qcoeff_ptr, int16_t *dqcoeff_ptr, const int16_t *dequant_ptr, + uint16_t *eob_ptr, const int16_t *scan_ptr, const int16_t *iscan_ptr) { + __m128i zero; + int pass; + // Constants + // When we use them, in one case, they are all the same. In all others + // it's a pair of them that we need to repeat four times. This is done + // by constructing the 32 bit constant corresponding to that pair. + const __m128i k__dual_p16_p16 = dual_set_epi16(23170, 23170); + const __m128i k__cospi_p16_p16 = _mm_set1_epi16((int16_t)cospi_16_64); + const __m128i k__cospi_p16_m16 = pair_set_epi16(cospi_16_64, -cospi_16_64); + const __m128i k__cospi_p24_p08 = pair_set_epi16(cospi_24_64, cospi_8_64); + const __m128i k__cospi_m08_p24 = pair_set_epi16(-cospi_8_64, cospi_24_64); + const __m128i k__cospi_p28_p04 = pair_set_epi16(cospi_28_64, cospi_4_64); + const __m128i k__cospi_m04_p28 = pair_set_epi16(-cospi_4_64, cospi_28_64); + const __m128i k__cospi_p12_p20 = pair_set_epi16(cospi_12_64, cospi_20_64); + const __m128i k__cospi_m20_p12 = pair_set_epi16(-cospi_20_64, cospi_12_64); + const __m128i k__DCT_CONST_ROUNDING = _mm_set1_epi32(DCT_CONST_ROUNDING); + // Load input + __m128i in0 = _mm_load_si128((const __m128i *)(input + 0 * stride)); + __m128i in1 = _mm_load_si128((const __m128i *)(input + 1 * stride)); + __m128i in2 = _mm_load_si128((const __m128i *)(input + 2 * stride)); + __m128i in3 = _mm_load_si128((const __m128i *)(input + 3 * stride)); + __m128i in4 = _mm_load_si128((const __m128i *)(input + 4 * stride)); + __m128i in5 = _mm_load_si128((const __m128i *)(input + 5 * stride)); + __m128i in6 = _mm_load_si128((const __m128i *)(input + 6 * stride)); + __m128i in7 = _mm_load_si128((const __m128i *)(input + 7 * stride)); + __m128i *in[8]; + int index = 0; + + (void)scan_ptr; + (void)zbin_ptr; + (void)quant_shift_ptr; + (void)coeff_ptr; + + // Pre-condition input (shift by two) + in0 = _mm_slli_epi16(in0, 2); + in1 = _mm_slli_epi16(in1, 2); + in2 = _mm_slli_epi16(in2, 2); + in3 = _mm_slli_epi16(in3, 2); + in4 = _mm_slli_epi16(in4, 2); + in5 = _mm_slli_epi16(in5, 2); + in6 = _mm_slli_epi16(in6, 2); + in7 = _mm_slli_epi16(in7, 2); + + in[0] = &in0; + in[1] = &in1; + in[2] = &in2; + in[3] = &in3; + in[4] = &in4; + in[5] = &in5; + in[6] = &in6; + in[7] = &in7; + + // We do two passes, first the columns, then the rows. The results of the + // first pass are transposed so that the same column code can be reused. The + // results of the second pass are also transposed so that the rows (processed + // as columns) are put back in row positions. + for (pass = 0; pass < 2; pass++) { + // To store results of each pass before the transpose. + __m128i res0, res1, res2, res3, res4, res5, res6, res7; + // Add/subtract + const __m128i q0 = _mm_add_epi16(in0, in7); + const __m128i q1 = _mm_add_epi16(in1, in6); + const __m128i q2 = _mm_add_epi16(in2, in5); + const __m128i q3 = _mm_add_epi16(in3, in4); + const __m128i q4 = _mm_sub_epi16(in3, in4); + const __m128i q5 = _mm_sub_epi16(in2, in5); + const __m128i q6 = _mm_sub_epi16(in1, in6); + const __m128i q7 = _mm_sub_epi16(in0, in7); + // Work on first four results + { + // Add/subtract + const __m128i r0 = _mm_add_epi16(q0, q3); + const __m128i r1 = _mm_add_epi16(q1, q2); + const __m128i r2 = _mm_sub_epi16(q1, q2); + const __m128i r3 = _mm_sub_epi16(q0, q3); + // Interleave to do the multiply by constants which gets us into 32bits + const __m128i t0 = _mm_unpacklo_epi16(r0, r1); + const __m128i t1 = _mm_unpackhi_epi16(r0, r1); + const __m128i t2 = _mm_unpacklo_epi16(r2, r3); + const __m128i t3 = _mm_unpackhi_epi16(r2, r3); + + const __m128i u0 = _mm_madd_epi16(t0, k__cospi_p16_p16); + const __m128i u1 = _mm_madd_epi16(t1, k__cospi_p16_p16); + const __m128i u2 = _mm_madd_epi16(t0, k__cospi_p16_m16); + const __m128i u3 = _mm_madd_epi16(t1, k__cospi_p16_m16); + + const __m128i u4 = _mm_madd_epi16(t2, k__cospi_p24_p08); + const __m128i u5 = _mm_madd_epi16(t3, k__cospi_p24_p08); + const __m128i u6 = _mm_madd_epi16(t2, k__cospi_m08_p24); + const __m128i u7 = _mm_madd_epi16(t3, k__cospi_m08_p24); + // dct_const_round_shift + + const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + + const __m128i v4 = _mm_add_epi32(u4, k__DCT_CONST_ROUNDING); + const __m128i v5 = _mm_add_epi32(u5, k__DCT_CONST_ROUNDING); + const __m128i v6 = _mm_add_epi32(u6, k__DCT_CONST_ROUNDING); + const __m128i v7 = _mm_add_epi32(u7, k__DCT_CONST_ROUNDING); + + const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + + const __m128i w4 = _mm_srai_epi32(v4, DCT_CONST_BITS); + const __m128i w5 = _mm_srai_epi32(v5, DCT_CONST_BITS); + const __m128i w6 = _mm_srai_epi32(v6, DCT_CONST_BITS); + const __m128i w7 = _mm_srai_epi32(v7, DCT_CONST_BITS); + // Combine + + res0 = _mm_packs_epi32(w0, w1); + res4 = _mm_packs_epi32(w2, w3); + res2 = _mm_packs_epi32(w4, w5); + res6 = _mm_packs_epi32(w6, w7); + } + // Work on next four results + { + // Interleave to do the multiply by constants which gets us into 32bits + const __m128i d0 = _mm_sub_epi16(q6, q5); + const __m128i d1 = _mm_add_epi16(q6, q5); + const __m128i r0 = _mm_mulhrs_epi16(d0, k__dual_p16_p16); + const __m128i r1 = _mm_mulhrs_epi16(d1, k__dual_p16_p16); + + // Add/subtract + const __m128i x0 = _mm_add_epi16(q4, r0); + const __m128i x1 = _mm_sub_epi16(q4, r0); + const __m128i x2 = _mm_sub_epi16(q7, r1); + const __m128i x3 = _mm_add_epi16(q7, r1); + // Interleave to do the multiply by constants which gets us into 32bits + const __m128i t0 = _mm_unpacklo_epi16(x0, x3); + const __m128i t1 = _mm_unpackhi_epi16(x0, x3); + const __m128i t2 = _mm_unpacklo_epi16(x1, x2); + const __m128i t3 = _mm_unpackhi_epi16(x1, x2); + const __m128i u0 = _mm_madd_epi16(t0, k__cospi_p28_p04); + const __m128i u1 = _mm_madd_epi16(t1, k__cospi_p28_p04); + const __m128i u2 = _mm_madd_epi16(t0, k__cospi_m04_p28); + const __m128i u3 = _mm_madd_epi16(t1, k__cospi_m04_p28); + const __m128i u4 = _mm_madd_epi16(t2, k__cospi_p12_p20); + const __m128i u5 = _mm_madd_epi16(t3, k__cospi_p12_p20); + const __m128i u6 = _mm_madd_epi16(t2, k__cospi_m20_p12); + const __m128i u7 = _mm_madd_epi16(t3, k__cospi_m20_p12); + // dct_const_round_shift + const __m128i v0 = _mm_add_epi32(u0, k__DCT_CONST_ROUNDING); + const __m128i v1 = _mm_add_epi32(u1, k__DCT_CONST_ROUNDING); + const __m128i v2 = _mm_add_epi32(u2, k__DCT_CONST_ROUNDING); + const __m128i v3 = _mm_add_epi32(u3, k__DCT_CONST_ROUNDING); + const __m128i v4 = _mm_add_epi32(u4, k__DCT_CONST_ROUNDING); + const __m128i v5 = _mm_add_epi32(u5, k__DCT_CONST_ROUNDING); + const __m128i v6 = _mm_add_epi32(u6, k__DCT_CONST_ROUNDING); + const __m128i v7 = _mm_add_epi32(u7, k__DCT_CONST_ROUNDING); + const __m128i w0 = _mm_srai_epi32(v0, DCT_CONST_BITS); + const __m128i w1 = _mm_srai_epi32(v1, DCT_CONST_BITS); + const __m128i w2 = _mm_srai_epi32(v2, DCT_CONST_BITS); + const __m128i w3 = _mm_srai_epi32(v3, DCT_CONST_BITS); + const __m128i w4 = _mm_srai_epi32(v4, DCT_CONST_BITS); + const __m128i w5 = _mm_srai_epi32(v5, DCT_CONST_BITS); + const __m128i w6 = _mm_srai_epi32(v6, DCT_CONST_BITS); + const __m128i w7 = _mm_srai_epi32(v7, DCT_CONST_BITS); + // Combine + res1 = _mm_packs_epi32(w0, w1); + res7 = _mm_packs_epi32(w2, w3); + res5 = _mm_packs_epi32(w4, w5); + res3 = _mm_packs_epi32(w6, w7); + } + // Transpose the 8x8. + { + // 00 01 02 03 04 05 06 07 + // 10 11 12 13 14 15 16 17 + // 20 21 22 23 24 25 26 27 + // 30 31 32 33 34 35 36 37 + // 40 41 42 43 44 45 46 47 + // 50 51 52 53 54 55 56 57 + // 60 61 62 63 64 65 66 67 + // 70 71 72 73 74 75 76 77 + const __m128i tr0_0 = _mm_unpacklo_epi16(res0, res1); + const __m128i tr0_1 = _mm_unpacklo_epi16(res2, res3); + const __m128i tr0_2 = _mm_unpackhi_epi16(res0, res1); + const __m128i tr0_3 = _mm_unpackhi_epi16(res2, res3); + const __m128i tr0_4 = _mm_unpacklo_epi16(res4, res5); + const __m128i tr0_5 = _mm_unpacklo_epi16(res6, res7); + const __m128i tr0_6 = _mm_unpackhi_epi16(res4, res5); + const __m128i tr0_7 = _mm_unpackhi_epi16(res6, res7); + // 00 10 01 11 02 12 03 13 + // 20 30 21 31 22 32 23 33 + // 04 14 05 15 06 16 07 17 + // 24 34 25 35 26 36 27 37 + // 40 50 41 51 42 52 43 53 + // 60 70 61 71 62 72 63 73 + // 54 54 55 55 56 56 57 57 + // 64 74 65 75 66 76 67 77 + const __m128i tr1_0 = _mm_unpacklo_epi32(tr0_0, tr0_1); + const __m128i tr1_1 = _mm_unpacklo_epi32(tr0_2, tr0_3); + const __m128i tr1_2 = _mm_unpackhi_epi32(tr0_0, tr0_1); + const __m128i tr1_3 = _mm_unpackhi_epi32(tr0_2, tr0_3); + const __m128i tr1_4 = _mm_unpacklo_epi32(tr0_4, tr0_5); + const __m128i tr1_5 = _mm_unpacklo_epi32(tr0_6, tr0_7); + const __m128i tr1_6 = _mm_unpackhi_epi32(tr0_4, tr0_5); + const __m128i tr1_7 = _mm_unpackhi_epi32(tr0_6, tr0_7); + // 00 10 20 30 01 11 21 31 + // 40 50 60 70 41 51 61 71 + // 02 12 22 32 03 13 23 33 + // 42 52 62 72 43 53 63 73 + // 04 14 24 34 05 15 21 36 + // 44 54 64 74 45 55 61 76 + // 06 16 26 36 07 17 27 37 + // 46 56 66 76 47 57 67 77 + in0 = _mm_unpacklo_epi64(tr1_0, tr1_4); + in1 = _mm_unpackhi_epi64(tr1_0, tr1_4); + in2 = _mm_unpacklo_epi64(tr1_2, tr1_6); + in3 = _mm_unpackhi_epi64(tr1_2, tr1_6); + in4 = _mm_unpacklo_epi64(tr1_1, tr1_5); + in5 = _mm_unpackhi_epi64(tr1_1, tr1_5); + in6 = _mm_unpacklo_epi64(tr1_3, tr1_7); + in7 = _mm_unpackhi_epi64(tr1_3, tr1_7); + // 00 10 20 30 40 50 60 70 + // 01 11 21 31 41 51 61 71 + // 02 12 22 32 42 52 62 72 + // 03 13 23 33 43 53 63 73 + // 04 14 24 34 44 54 64 74 + // 05 15 25 35 45 55 65 75 + // 06 16 26 36 46 56 66 76 + // 07 17 27 37 47 57 67 77 + } + } + // Post-condition output and store it + { + // Post-condition (division by two) + // division of two 16 bits signed numbers using shifts + // n / 2 = (n - (n >> 15)) >> 1 + const __m128i sign_in0 = _mm_srai_epi16(in0, 15); + const __m128i sign_in1 = _mm_srai_epi16(in1, 15); + const __m128i sign_in2 = _mm_srai_epi16(in2, 15); + const __m128i sign_in3 = _mm_srai_epi16(in3, 15); + const __m128i sign_in4 = _mm_srai_epi16(in4, 15); + const __m128i sign_in5 = _mm_srai_epi16(in5, 15); + const __m128i sign_in6 = _mm_srai_epi16(in6, 15); + const __m128i sign_in7 = _mm_srai_epi16(in7, 15); + in0 = _mm_sub_epi16(in0, sign_in0); + in1 = _mm_sub_epi16(in1, sign_in1); + in2 = _mm_sub_epi16(in2, sign_in2); + in3 = _mm_sub_epi16(in3, sign_in3); + in4 = _mm_sub_epi16(in4, sign_in4); + in5 = _mm_sub_epi16(in5, sign_in5); + in6 = _mm_sub_epi16(in6, sign_in6); + in7 = _mm_sub_epi16(in7, sign_in7); + in0 = _mm_srai_epi16(in0, 1); + in1 = _mm_srai_epi16(in1, 1); + in2 = _mm_srai_epi16(in2, 1); + in3 = _mm_srai_epi16(in3, 1); + in4 = _mm_srai_epi16(in4, 1); + in5 = _mm_srai_epi16(in5, 1); + in6 = _mm_srai_epi16(in6, 1); + in7 = _mm_srai_epi16(in7, 1); + } + + iscan_ptr += n_coeffs; + qcoeff_ptr += n_coeffs; + dqcoeff_ptr += n_coeffs; + n_coeffs = -n_coeffs; + zero = _mm_setzero_si128(); + + if (!skip_block) { + __m128i eob; + __m128i round, quant, dequant, thr; + int16_t nzflag; + { + __m128i coeff0, coeff1; + + // Setup global values + { + round = _mm_load_si128((const __m128i *)round_ptr); + quant = _mm_load_si128((const __m128i *)quant_ptr); + dequant = _mm_load_si128((const __m128i *)dequant_ptr); + } + + { + __m128i coeff0_sign, coeff1_sign; + __m128i qcoeff0, qcoeff1; + __m128i qtmp0, qtmp1; + // Do DC and first 15 AC + coeff0 = *in[0]; + coeff1 = *in[1]; + + // Poor man's sign extract + coeff0_sign = _mm_srai_epi16(coeff0, 15); + coeff1_sign = _mm_srai_epi16(coeff1, 15); + qcoeff0 = _mm_xor_si128(coeff0, coeff0_sign); + qcoeff1 = _mm_xor_si128(coeff1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + qcoeff0 = _mm_adds_epi16(qcoeff0, round); + round = _mm_unpackhi_epi64(round, round); + qcoeff1 = _mm_adds_epi16(qcoeff1, round); + qtmp0 = _mm_mulhi_epi16(qcoeff0, quant); + quant = _mm_unpackhi_epi64(quant, quant); + qtmp1 = _mm_mulhi_epi16(qcoeff1, quant); + + // Reinsert signs + qcoeff0 = _mm_xor_si128(qtmp0, coeff0_sign); + qcoeff1 = _mm_xor_si128(qtmp1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), qcoeff0); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, qcoeff1); + + coeff0 = _mm_mullo_epi16(qcoeff0, dequant); + dequant = _mm_unpackhi_epi64(dequant, dequant); + coeff1 = _mm_mullo_epi16(qcoeff1, dequant); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), coeff0); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, coeff1); + } + + { + // Scan for eob + __m128i zero_coeff0, zero_coeff1; + __m128i nzero_coeff0, nzero_coeff1; + __m128i iscan0, iscan1; + __m128i eob1; + zero_coeff0 = _mm_cmpeq_epi16(coeff0, zero); + zero_coeff1 = _mm_cmpeq_epi16(coeff1, zero); + nzero_coeff0 = _mm_cmpeq_epi16(zero_coeff0, zero); + nzero_coeff1 = _mm_cmpeq_epi16(zero_coeff1, zero); + iscan0 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs)); + iscan1 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs) + 1); + // Add one to convert from indices to counts + iscan0 = _mm_sub_epi16(iscan0, nzero_coeff0); + iscan1 = _mm_sub_epi16(iscan1, nzero_coeff1); + eob = _mm_and_si128(iscan0, nzero_coeff0); + eob1 = _mm_and_si128(iscan1, nzero_coeff1); + eob = _mm_max_epi16(eob, eob1); + } + n_coeffs += 8 * 2; + } + + // AC only loop + index = 2; + thr = _mm_srai_epi16(dequant, 1); + while (n_coeffs < 0) { + __m128i coeff0, coeff1; + { + __m128i coeff0_sign, coeff1_sign; + __m128i qcoeff0, qcoeff1; + __m128i qtmp0, qtmp1; + + assert(index < (int)(sizeof(in) / sizeof(in[0])) - 1); + coeff0 = *in[index]; + coeff1 = *in[index + 1]; + + // Poor man's sign extract + coeff0_sign = _mm_srai_epi16(coeff0, 15); + coeff1_sign = _mm_srai_epi16(coeff1, 15); + qcoeff0 = _mm_xor_si128(coeff0, coeff0_sign); + qcoeff1 = _mm_xor_si128(coeff1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + nzflag = _mm_movemask_epi8(_mm_cmpgt_epi16(qcoeff0, thr)) | + _mm_movemask_epi8(_mm_cmpgt_epi16(qcoeff1, thr)); + + if (nzflag) { + qcoeff0 = _mm_adds_epi16(qcoeff0, round); + qcoeff1 = _mm_adds_epi16(qcoeff1, round); + qtmp0 = _mm_mulhi_epi16(qcoeff0, quant); + qtmp1 = _mm_mulhi_epi16(qcoeff1, quant); + + // Reinsert signs + qcoeff0 = _mm_xor_si128(qtmp0, coeff0_sign); + qcoeff1 = _mm_xor_si128(qtmp1, coeff1_sign); + qcoeff0 = _mm_sub_epi16(qcoeff0, coeff0_sign); + qcoeff1 = _mm_sub_epi16(qcoeff1, coeff1_sign); + + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), qcoeff0); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, qcoeff1); + + coeff0 = _mm_mullo_epi16(qcoeff0, dequant); + coeff1 = _mm_mullo_epi16(qcoeff1, dequant); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), coeff0); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, coeff1); + } else { + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, zero); + + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, zero); + } + } + + if (nzflag) { + // Scan for eob + __m128i zero_coeff0, zero_coeff1; + __m128i nzero_coeff0, nzero_coeff1; + __m128i iscan0, iscan1; + __m128i eob0, eob1; + zero_coeff0 = _mm_cmpeq_epi16(coeff0, zero); + zero_coeff1 = _mm_cmpeq_epi16(coeff1, zero); + nzero_coeff0 = _mm_cmpeq_epi16(zero_coeff0, zero); + nzero_coeff1 = _mm_cmpeq_epi16(zero_coeff1, zero); + iscan0 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs)); + iscan1 = _mm_load_si128((const __m128i *)(iscan_ptr + n_coeffs) + 1); + // Add one to convert from indices to counts + iscan0 = _mm_sub_epi16(iscan0, nzero_coeff0); + iscan1 = _mm_sub_epi16(iscan1, nzero_coeff1); + eob0 = _mm_and_si128(iscan0, nzero_coeff0); + eob1 = _mm_and_si128(iscan1, nzero_coeff1); + eob0 = _mm_max_epi16(eob0, eob1); + eob = _mm_max_epi16(eob, eob0); + } + n_coeffs += 8 * 2; + index += 2; + } + + // Accumulate EOB + { + __m128i eob_shuffled; + eob_shuffled = _mm_shuffle_epi32(eob, 0xe); + eob = _mm_max_epi16(eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(eob, 0xe); + eob = _mm_max_epi16(eob, eob_shuffled); + eob_shuffled = _mm_shufflelo_epi16(eob, 0x1); + eob = _mm_max_epi16(eob, eob_shuffled); + *eob_ptr = _mm_extract_epi16(eob, 1); + } + } else { + do { + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(dqcoeff_ptr + n_coeffs) + 1, zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs), zero); + _mm_store_si128((__m128i *)(qcoeff_ptr + n_coeffs) + 1, zero); + n_coeffs += 8 * 2; + } while (n_coeffs < 0); + *eob_ptr = 0; + } +} diff --git a/third_party/aom/av1/encoder/x86/error_intrin_avx2.c b/third_party/aom/av1/encoder/x86/error_intrin_avx2.c new file mode 100644 index 000000000..ae733a1ce --- /dev/null +++ b/third_party/aom/av1/encoder/x86/error_intrin_avx2.c @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <immintrin.h> // AVX2 + +#include "./av1_rtcd.h" +#include "aom/aom_integer.h" + +int64_t av1_block_error_avx2(const int16_t *coeff, const int16_t *dqcoeff, + intptr_t block_size, int64_t *ssz) { + __m256i sse_reg, ssz_reg, coeff_reg, dqcoeff_reg; + __m256i exp_dqcoeff_lo, exp_dqcoeff_hi, exp_coeff_lo, exp_coeff_hi; + __m256i sse_reg_64hi, ssz_reg_64hi; + __m128i sse_reg128, ssz_reg128; + int64_t sse; + int i; + const __m256i zero_reg = _mm256_set1_epi16(0); + + // init sse and ssz registerd to zero + sse_reg = _mm256_set1_epi16(0); + ssz_reg = _mm256_set1_epi16(0); + + for (i = 0; i < block_size; i += 16) { + // load 32 bytes from coeff and dqcoeff + coeff_reg = _mm256_loadu_si256((const __m256i *)(coeff + i)); + dqcoeff_reg = _mm256_loadu_si256((const __m256i *)(dqcoeff + i)); + // dqcoeff - coeff + dqcoeff_reg = _mm256_sub_epi16(dqcoeff_reg, coeff_reg); + // madd (dqcoeff - coeff) + dqcoeff_reg = _mm256_madd_epi16(dqcoeff_reg, dqcoeff_reg); + // madd coeff + coeff_reg = _mm256_madd_epi16(coeff_reg, coeff_reg); + // expand each double word of madd (dqcoeff - coeff) to quad word + exp_dqcoeff_lo = _mm256_unpacklo_epi32(dqcoeff_reg, zero_reg); + exp_dqcoeff_hi = _mm256_unpackhi_epi32(dqcoeff_reg, zero_reg); + // expand each double word of madd (coeff) to quad word + exp_coeff_lo = _mm256_unpacklo_epi32(coeff_reg, zero_reg); + exp_coeff_hi = _mm256_unpackhi_epi32(coeff_reg, zero_reg); + // add each quad word of madd (dqcoeff - coeff) and madd (coeff) + sse_reg = _mm256_add_epi64(sse_reg, exp_dqcoeff_lo); + ssz_reg = _mm256_add_epi64(ssz_reg, exp_coeff_lo); + sse_reg = _mm256_add_epi64(sse_reg, exp_dqcoeff_hi); + ssz_reg = _mm256_add_epi64(ssz_reg, exp_coeff_hi); + } + // save the higher 64 bit of each 128 bit lane + sse_reg_64hi = _mm256_srli_si256(sse_reg, 8); + ssz_reg_64hi = _mm256_srli_si256(ssz_reg, 8); + // add the higher 64 bit to the low 64 bit + sse_reg = _mm256_add_epi64(sse_reg, sse_reg_64hi); + ssz_reg = _mm256_add_epi64(ssz_reg, ssz_reg_64hi); + + // add each 64 bit from each of the 128 bit lane of the 256 bit + sse_reg128 = _mm_add_epi64(_mm256_castsi256_si128(sse_reg), + _mm256_extractf128_si256(sse_reg, 1)); + + ssz_reg128 = _mm_add_epi64(_mm256_castsi256_si128(ssz_reg), + _mm256_extractf128_si256(ssz_reg, 1)); + + // store the results + _mm_storel_epi64((__m128i *)(&sse), sse_reg128); + + _mm_storel_epi64((__m128i *)(ssz), ssz_reg128); + _mm256_zeroupper(); + return sse; +} diff --git a/third_party/aom/av1/encoder/x86/error_sse2.asm b/third_party/aom/av1/encoder/x86/error_sse2.asm new file mode 100644 index 000000000..4680f1fab --- /dev/null +++ b/third_party/aom/av1/encoder/x86/error_sse2.asm @@ -0,0 +1,125 @@ +; +; Copyright (c) 2016, Alliance for Open Media. All rights reserved +; +; This source code is subject to the terms of the BSD 2 Clause License and +; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +; was not distributed with this source code in the LICENSE file, you can +; obtain it at www.aomedia.org/license/software. If the Alliance for Open +; Media Patent License 1.0 was not distributed with this source code in the +; PATENTS file, you can obtain it at www.aomedia.org/license/patent. +; + +; + +%define private_prefix av1 + +%include "third_party/x86inc/x86inc.asm" + +SECTION .text + +; int64_t av1_block_error(int16_t *coeff, int16_t *dqcoeff, intptr_t block_size, +; int64_t *ssz) + +INIT_XMM sse2 +cglobal block_error, 3, 3, 8, uqc, dqc, size, ssz + pxor m4, m4 ; sse accumulator + pxor m6, m6 ; ssz accumulator + pxor m5, m5 ; dedicated zero register + lea uqcq, [uqcq+sizeq*2] + lea dqcq, [dqcq+sizeq*2] + neg sizeq +.loop: + mova m2, [uqcq+sizeq*2] + mova m0, [dqcq+sizeq*2] + mova m3, [uqcq+sizeq*2+mmsize] + mova m1, [dqcq+sizeq*2+mmsize] + psubw m0, m2 + psubw m1, m3 + ; individual errors are max. 15bit+sign, so squares are 30bit, and + ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit) + pmaddwd m0, m0 + pmaddwd m1, m1 + pmaddwd m2, m2 + pmaddwd m3, m3 + ; accumulate in 64bit + punpckldq m7, m0, m5 + punpckhdq m0, m5 + paddq m4, m7 + punpckldq m7, m1, m5 + paddq m4, m0 + punpckhdq m1, m5 + paddq m4, m7 + punpckldq m7, m2, m5 + paddq m4, m1 + punpckhdq m2, m5 + paddq m6, m7 + punpckldq m7, m3, m5 + paddq m6, m2 + punpckhdq m3, m5 + paddq m6, m7 + paddq m6, m3 + add sizeq, mmsize + jl .loop + + ; accumulate horizontally and store in return value + movhlps m5, m4 + movhlps m7, m6 + paddq m4, m5 + paddq m6, m7 +%if ARCH_X86_64 + movq rax, m4 + movq [sszq], m6 +%else + mov eax, sszm + pshufd m5, m4, 0x1 + movq [eax], m6 + movd eax, m4 + movd edx, m5 +%endif + RET + +; Compute the sum of squared difference between two int16_t vectors. +; int64_t av1_block_error_fp(int16_t *coeff, int16_t *dqcoeff, +; intptr_t block_size) + +INIT_XMM sse2 +cglobal block_error_fp, 3, 3, 6, uqc, dqc, size + pxor m4, m4 ; sse accumulator + pxor m5, m5 ; dedicated zero register + lea uqcq, [uqcq+sizeq*2] + lea dqcq, [dqcq+sizeq*2] + neg sizeq +.loop: + mova m2, [uqcq+sizeq*2] + mova m0, [dqcq+sizeq*2] + mova m3, [uqcq+sizeq*2+mmsize] + mova m1, [dqcq+sizeq*2+mmsize] + psubw m0, m2 + psubw m1, m3 + ; individual errors are max. 15bit+sign, so squares are 30bit, and + ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit) + pmaddwd m0, m0 + pmaddwd m1, m1 + ; accumulate in 64bit + punpckldq m3, m0, m5 + punpckhdq m0, m5 + paddq m4, m3 + punpckldq m3, m1, m5 + paddq m4, m0 + punpckhdq m1, m5 + paddq m4, m3 + paddq m4, m1 + add sizeq, mmsize + jl .loop + + ; accumulate horizontally and store in return value + movhlps m5, m4 + paddq m4, m5 +%if ARCH_X86_64 + movq rax, m4 +%else + pshufd m5, m4, 0x1 + movd eax, m4 + movd edx, m5 +%endif + RET diff --git a/third_party/aom/av1/encoder/x86/highbd_block_error_intrin_sse2.c b/third_party/aom/av1/encoder/x86/highbd_block_error_intrin_sse2.c new file mode 100644 index 000000000..777304ace --- /dev/null +++ b/third_party/aom/av1/encoder/x86/highbd_block_error_intrin_sse2.c @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <emmintrin.h> +#include <stdio.h> + +#include "av1/common/common.h" + +int64_t av1_highbd_block_error_sse2(tran_low_t *coeff, tran_low_t *dqcoeff, + intptr_t block_size, int64_t *ssz, + int bps) { + int i, j, test; + uint32_t temp[4]; + __m128i max, min, cmp0, cmp1, cmp2, cmp3; + int64_t error = 0, sqcoeff = 0; + const int shift = 2 * (bps - 8); + const int rounding = shift > 0 ? 1 << (shift - 1) : 0; + + for (i = 0; i < block_size; i += 8) { + // Load the data into xmm registers + __m128i mm_coeff = _mm_load_si128((__m128i *)(coeff + i)); + __m128i mm_coeff2 = _mm_load_si128((__m128i *)(coeff + i + 4)); + __m128i mm_dqcoeff = _mm_load_si128((__m128i *)(dqcoeff + i)); + __m128i mm_dqcoeff2 = _mm_load_si128((__m128i *)(dqcoeff + i + 4)); + // Check if any values require more than 15 bit + max = _mm_set1_epi32(0x3fff); + min = _mm_set1_epi32(0xffffc000); + cmp0 = _mm_xor_si128(_mm_cmpgt_epi32(mm_coeff, max), + _mm_cmplt_epi32(mm_coeff, min)); + cmp1 = _mm_xor_si128(_mm_cmpgt_epi32(mm_coeff2, max), + _mm_cmplt_epi32(mm_coeff2, min)); + cmp2 = _mm_xor_si128(_mm_cmpgt_epi32(mm_dqcoeff, max), + _mm_cmplt_epi32(mm_dqcoeff, min)); + cmp3 = _mm_xor_si128(_mm_cmpgt_epi32(mm_dqcoeff2, max), + _mm_cmplt_epi32(mm_dqcoeff2, min)); + test = _mm_movemask_epi8( + _mm_or_si128(_mm_or_si128(cmp0, cmp1), _mm_or_si128(cmp2, cmp3))); + + if (!test) { + __m128i mm_diff, error_sse2, sqcoeff_sse2; + mm_coeff = _mm_packs_epi32(mm_coeff, mm_coeff2); + mm_dqcoeff = _mm_packs_epi32(mm_dqcoeff, mm_dqcoeff2); + mm_diff = _mm_sub_epi16(mm_coeff, mm_dqcoeff); + error_sse2 = _mm_madd_epi16(mm_diff, mm_diff); + sqcoeff_sse2 = _mm_madd_epi16(mm_coeff, mm_coeff); + _mm_storeu_si128((__m128i *)temp, error_sse2); + error = error + temp[0] + temp[1] + temp[2] + temp[3]; + _mm_storeu_si128((__m128i *)temp, sqcoeff_sse2); + sqcoeff += temp[0] + temp[1] + temp[2] + temp[3]; + } else { + for (j = 0; j < 8; j++) { + const int64_t diff = coeff[i + j] - dqcoeff[i + j]; + error += diff * diff; + sqcoeff += (int64_t)coeff[i + j] * (int64_t)coeff[i + j]; + } + } + } + assert(error >= 0 && sqcoeff >= 0); + error = (error + rounding) >> shift; + sqcoeff = (sqcoeff + rounding) >> shift; + + *ssz = sqcoeff; + return error; +} diff --git a/third_party/aom/av1/encoder/x86/highbd_fwd_txfm_sse4.c b/third_party/aom/av1/encoder/x86/highbd_fwd_txfm_sse4.c new file mode 100644 index 000000000..f201a29aa --- /dev/null +++ b/third_party/aom/av1/encoder/x86/highbd_fwd_txfm_sse4.c @@ -0,0 +1,1895 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ +#include <assert.h> +#include <smmintrin.h> /* SSE4.1 */ + +#include "./av1_rtcd.h" +#include "./aom_config.h" +#include "av1/common/av1_fwd_txfm2d_cfg.h" +#include "av1/common/av1_txfm.h" +#include "av1/common/x86/highbd_txfm_utility_sse4.h" +#include "aom_dsp/txfm_common.h" +#include "aom_dsp/x86/txfm_common_sse2.h" +#include "aom_ports/mem.h" + +static INLINE void load_buffer_4x4(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr, + int shift) { + if (!flipud) { + in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + in[2] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride)); + in[3] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride)); + } else { + in[0] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride)); + in[1] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride)); + in[2] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride)); + in[3] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride)); + } + + if (fliplr) { + in[0] = _mm_shufflelo_epi16(in[0], 0x1b); + in[1] = _mm_shufflelo_epi16(in[1], 0x1b); + in[2] = _mm_shufflelo_epi16(in[2], 0x1b); + in[3] = _mm_shufflelo_epi16(in[3], 0x1b); + } + + in[0] = _mm_cvtepi16_epi32(in[0]); + in[1] = _mm_cvtepi16_epi32(in[1]); + in[2] = _mm_cvtepi16_epi32(in[2]); + in[3] = _mm_cvtepi16_epi32(in[3]); + + in[0] = _mm_slli_epi32(in[0], shift); + in[1] = _mm_slli_epi32(in[1], shift); + in[2] = _mm_slli_epi32(in[2], shift); + in[3] = _mm_slli_epi32(in[3], shift); +} + +// We only use stage-2 bit; +// shift[0] is used in load_buffer_4x4() +// shift[1] is used in txfm_func_col() +// shift[2] is used in txfm_func_row() +static void fdct4x4_sse4_1(__m128i *in, int bit) { + const int32_t *cospi = cospi_arr[bit - cos_bit_min]; + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + __m128i s0, s1, s2, s3; + __m128i u0, u1, u2, u3; + __m128i v0, v1, v2, v3; + + s0 = _mm_add_epi32(in[0], in[3]); + s1 = _mm_add_epi32(in[1], in[2]); + s2 = _mm_sub_epi32(in[1], in[2]); + s3 = _mm_sub_epi32(in[0], in[3]); + + // btf_32_sse4_1_type0(cospi32, cospi32, s[01], u[02], bit); + u0 = _mm_mullo_epi32(s0, cospi32); + u1 = _mm_mullo_epi32(s1, cospi32); + u2 = _mm_add_epi32(u0, u1); + v0 = _mm_sub_epi32(u0, u1); + + u3 = _mm_add_epi32(u2, rnding); + v1 = _mm_add_epi32(v0, rnding); + + u0 = _mm_srai_epi32(u3, bit); + u2 = _mm_srai_epi32(v1, bit); + + // btf_32_sse4_1_type1(cospi48, cospi16, s[23], u[13], bit); + v0 = _mm_mullo_epi32(s2, cospi48); + v1 = _mm_mullo_epi32(s3, cospi16); + v2 = _mm_add_epi32(v0, v1); + + v3 = _mm_add_epi32(v2, rnding); + u1 = _mm_srai_epi32(v3, bit); + + v0 = _mm_mullo_epi32(s2, cospi16); + v1 = _mm_mullo_epi32(s3, cospi48); + v2 = _mm_sub_epi32(v1, v0); + + v3 = _mm_add_epi32(v2, rnding); + u3 = _mm_srai_epi32(v3, bit); + + // Note: shift[1] and shift[2] are zeros + + // Transpose 4x4 32-bit + v0 = _mm_unpacklo_epi32(u0, u1); + v1 = _mm_unpackhi_epi32(u0, u1); + v2 = _mm_unpacklo_epi32(u2, u3); + v3 = _mm_unpackhi_epi32(u2, u3); + + in[0] = _mm_unpacklo_epi64(v0, v2); + in[1] = _mm_unpackhi_epi64(v0, v2); + in[2] = _mm_unpacklo_epi64(v1, v3); + in[3] = _mm_unpackhi_epi64(v1, v3); +} + +static INLINE void write_buffer_4x4(__m128i *res, tran_low_t *output) { + _mm_store_si128((__m128i *)(output + 0 * 4), res[0]); + _mm_store_si128((__m128i *)(output + 1 * 4), res[1]); + _mm_store_si128((__m128i *)(output + 2 * 4), res[2]); + _mm_store_si128((__m128i *)(output + 3 * 4), res[3]); +} + +// Note: +// We implement av1_fwd_txfm2d_4x4(). This function is kept here since +// av1_highbd_fht4x4_c() is not removed yet +void av1_highbd_fht4x4_sse4_1(const int16_t *input, tran_low_t *output, + int stride, int tx_type) { + (void)input; + (void)output; + (void)stride; + (void)tx_type; + assert(0); +} + +static void fadst4x4_sse4_1(__m128i *in, int bit) { + const int32_t *cospi = cospi_arr[bit - cos_bit_min]; + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const __m128i kZero = _mm_setzero_si128(); + __m128i s0, s1, s2, s3; + __m128i u0, u1, u2, u3; + __m128i v0, v1, v2, v3; + + // stage 0 + // stage 1 + // stage 2 + u0 = _mm_mullo_epi32(in[3], cospi8); + u1 = _mm_mullo_epi32(in[0], cospi56); + u2 = _mm_add_epi32(u0, u1); + s0 = _mm_add_epi32(u2, rnding); + s0 = _mm_srai_epi32(s0, bit); + + v0 = _mm_mullo_epi32(in[3], cospi56); + v1 = _mm_mullo_epi32(in[0], cospi8); + v2 = _mm_sub_epi32(v0, v1); + s1 = _mm_add_epi32(v2, rnding); + s1 = _mm_srai_epi32(s1, bit); + + u0 = _mm_mullo_epi32(in[1], cospi40); + u1 = _mm_mullo_epi32(in[2], cospi24); + u2 = _mm_add_epi32(u0, u1); + s2 = _mm_add_epi32(u2, rnding); + s2 = _mm_srai_epi32(s2, bit); + + v0 = _mm_mullo_epi32(in[1], cospi24); + v1 = _mm_mullo_epi32(in[2], cospi40); + v2 = _mm_sub_epi32(v0, v1); + s3 = _mm_add_epi32(v2, rnding); + s3 = _mm_srai_epi32(s3, bit); + + // stage 3 + u0 = _mm_add_epi32(s0, s2); + u2 = _mm_sub_epi32(s0, s2); + u1 = _mm_add_epi32(s1, s3); + u3 = _mm_sub_epi32(s1, s3); + + // stage 4 + v0 = _mm_mullo_epi32(u2, cospi32); + v1 = _mm_mullo_epi32(u3, cospi32); + v2 = _mm_add_epi32(v0, v1); + s2 = _mm_add_epi32(v2, rnding); + u2 = _mm_srai_epi32(s2, bit); + + v2 = _mm_sub_epi32(v0, v1); + s3 = _mm_add_epi32(v2, rnding); + u3 = _mm_srai_epi32(s3, bit); + + // u0, u1, u2, u3 + u2 = _mm_sub_epi32(kZero, u2); + u1 = _mm_sub_epi32(kZero, u1); + + // u0, u2, u3, u1 + // Transpose 4x4 32-bit + v0 = _mm_unpacklo_epi32(u0, u2); + v1 = _mm_unpackhi_epi32(u0, u2); + v2 = _mm_unpacklo_epi32(u3, u1); + v3 = _mm_unpackhi_epi32(u3, u1); + + in[0] = _mm_unpacklo_epi64(v0, v2); + in[1] = _mm_unpackhi_epi64(v0, v2); + in[2] = _mm_unpacklo_epi64(v1, v3); + in[3] = _mm_unpackhi_epi64(v1, v3); +} + +void av1_fwd_txfm2d_4x4_sse4_1(const int16_t *input, int32_t *coeff, + int input_stride, int tx_type, int bd) { + __m128i in[4]; + const TXFM_2D_CFG *cfg = NULL; + + switch (tx_type) { + case DCT_DCT: + cfg = &fwd_txfm_2d_cfg_dct_dct_4; + load_buffer_4x4(input, in, input_stride, 0, 0, cfg->shift[0]); + fdct4x4_sse4_1(in, cfg->cos_bit_col[2]); + fdct4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case ADST_DCT: + cfg = &fwd_txfm_2d_cfg_adst_dct_4; + load_buffer_4x4(input, in, input_stride, 0, 0, cfg->shift[0]); + fadst4x4_sse4_1(in, cfg->cos_bit_col[2]); + fdct4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case DCT_ADST: + cfg = &fwd_txfm_2d_cfg_dct_adst_4; + load_buffer_4x4(input, in, input_stride, 0, 0, cfg->shift[0]); + fdct4x4_sse4_1(in, cfg->cos_bit_col[2]); + fadst4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case ADST_ADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_4; + load_buffer_4x4(input, in, input_stride, 0, 0, cfg->shift[0]); + fadst4x4_sse4_1(in, cfg->cos_bit_col[2]); + fadst4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + cfg = &fwd_txfm_2d_cfg_adst_dct_4; + load_buffer_4x4(input, in, input_stride, 1, 0, cfg->shift[0]); + fadst4x4_sse4_1(in, cfg->cos_bit_col[2]); + fdct4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case DCT_FLIPADST: + cfg = &fwd_txfm_2d_cfg_dct_adst_4; + load_buffer_4x4(input, in, input_stride, 0, 1, cfg->shift[0]); + fdct4x4_sse4_1(in, cfg->cos_bit_col[2]); + fadst4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case FLIPADST_FLIPADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_4; + load_buffer_4x4(input, in, input_stride, 1, 1, cfg->shift[0]); + fadst4x4_sse4_1(in, cfg->cos_bit_col[2]); + fadst4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case ADST_FLIPADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_4; + load_buffer_4x4(input, in, input_stride, 0, 1, cfg->shift[0]); + fadst4x4_sse4_1(in, cfg->cos_bit_col[2]); + fadst4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; + case FLIPADST_ADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_4; + load_buffer_4x4(input, in, input_stride, 1, 0, cfg->shift[0]); + fadst4x4_sse4_1(in, cfg->cos_bit_col[2]); + fadst4x4_sse4_1(in, cfg->cos_bit_row[2]); + write_buffer_4x4(in, coeff); + break; +#endif + default: assert(0); + } + (void)bd; +} + +static INLINE void load_buffer_8x8(const int16_t *input, __m128i *in, + int stride, int flipud, int fliplr, + int shift) { + __m128i u; + if (!flipud) { + in[0] = _mm_load_si128((const __m128i *)(input + 0 * stride)); + in[1] = _mm_load_si128((const __m128i *)(input + 1 * stride)); + in[2] = _mm_load_si128((const __m128i *)(input + 2 * stride)); + in[3] = _mm_load_si128((const __m128i *)(input + 3 * stride)); + in[4] = _mm_load_si128((const __m128i *)(input + 4 * stride)); + in[5] = _mm_load_si128((const __m128i *)(input + 5 * stride)); + in[6] = _mm_load_si128((const __m128i *)(input + 6 * stride)); + in[7] = _mm_load_si128((const __m128i *)(input + 7 * stride)); + } else { + in[0] = _mm_load_si128((const __m128i *)(input + 7 * stride)); + in[1] = _mm_load_si128((const __m128i *)(input + 6 * stride)); + in[2] = _mm_load_si128((const __m128i *)(input + 5 * stride)); + in[3] = _mm_load_si128((const __m128i *)(input + 4 * stride)); + in[4] = _mm_load_si128((const __m128i *)(input + 3 * stride)); + in[5] = _mm_load_si128((const __m128i *)(input + 2 * stride)); + in[6] = _mm_load_si128((const __m128i *)(input + 1 * stride)); + in[7] = _mm_load_si128((const __m128i *)(input + 0 * stride)); + } + + if (fliplr) { + in[0] = mm_reverse_epi16(in[0]); + in[1] = mm_reverse_epi16(in[1]); + in[2] = mm_reverse_epi16(in[2]); + in[3] = mm_reverse_epi16(in[3]); + in[4] = mm_reverse_epi16(in[4]); + in[5] = mm_reverse_epi16(in[5]); + in[6] = mm_reverse_epi16(in[6]); + in[7] = mm_reverse_epi16(in[7]); + } + + u = _mm_unpackhi_epi64(in[4], in[4]); + in[8] = _mm_cvtepi16_epi32(in[4]); + in[9] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[5], in[5]); + in[10] = _mm_cvtepi16_epi32(in[5]); + in[11] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[6], in[6]); + in[12] = _mm_cvtepi16_epi32(in[6]); + in[13] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[7], in[7]); + in[14] = _mm_cvtepi16_epi32(in[7]); + in[15] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[3], in[3]); + in[6] = _mm_cvtepi16_epi32(in[3]); + in[7] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[2], in[2]); + in[4] = _mm_cvtepi16_epi32(in[2]); + in[5] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[1], in[1]); + in[2] = _mm_cvtepi16_epi32(in[1]); + in[3] = _mm_cvtepi16_epi32(u); + + u = _mm_unpackhi_epi64(in[0], in[0]); + in[0] = _mm_cvtepi16_epi32(in[0]); + in[1] = _mm_cvtepi16_epi32(u); + + in[0] = _mm_slli_epi32(in[0], shift); + in[1] = _mm_slli_epi32(in[1], shift); + in[2] = _mm_slli_epi32(in[2], shift); + in[3] = _mm_slli_epi32(in[3], shift); + in[4] = _mm_slli_epi32(in[4], shift); + in[5] = _mm_slli_epi32(in[5], shift); + in[6] = _mm_slli_epi32(in[6], shift); + in[7] = _mm_slli_epi32(in[7], shift); + + in[8] = _mm_slli_epi32(in[8], shift); + in[9] = _mm_slli_epi32(in[9], shift); + in[10] = _mm_slli_epi32(in[10], shift); + in[11] = _mm_slli_epi32(in[11], shift); + in[12] = _mm_slli_epi32(in[12], shift); + in[13] = _mm_slli_epi32(in[13], shift); + in[14] = _mm_slli_epi32(in[14], shift); + in[15] = _mm_slli_epi32(in[15], shift); +} + +static INLINE void col_txfm_8x8_rounding(__m128i *in, int shift) { + const __m128i rounding = _mm_set1_epi32(1 << (shift - 1)); + + in[0] = _mm_add_epi32(in[0], rounding); + in[1] = _mm_add_epi32(in[1], rounding); + in[2] = _mm_add_epi32(in[2], rounding); + in[3] = _mm_add_epi32(in[3], rounding); + in[4] = _mm_add_epi32(in[4], rounding); + in[5] = _mm_add_epi32(in[5], rounding); + in[6] = _mm_add_epi32(in[6], rounding); + in[7] = _mm_add_epi32(in[7], rounding); + in[8] = _mm_add_epi32(in[8], rounding); + in[9] = _mm_add_epi32(in[9], rounding); + in[10] = _mm_add_epi32(in[10], rounding); + in[11] = _mm_add_epi32(in[11], rounding); + in[12] = _mm_add_epi32(in[12], rounding); + in[13] = _mm_add_epi32(in[13], rounding); + in[14] = _mm_add_epi32(in[14], rounding); + in[15] = _mm_add_epi32(in[15], rounding); + + in[0] = _mm_srai_epi32(in[0], shift); + in[1] = _mm_srai_epi32(in[1], shift); + in[2] = _mm_srai_epi32(in[2], shift); + in[3] = _mm_srai_epi32(in[3], shift); + in[4] = _mm_srai_epi32(in[4], shift); + in[5] = _mm_srai_epi32(in[5], shift); + in[6] = _mm_srai_epi32(in[6], shift); + in[7] = _mm_srai_epi32(in[7], shift); + in[8] = _mm_srai_epi32(in[8], shift); + in[9] = _mm_srai_epi32(in[9], shift); + in[10] = _mm_srai_epi32(in[10], shift); + in[11] = _mm_srai_epi32(in[11], shift); + in[12] = _mm_srai_epi32(in[12], shift); + in[13] = _mm_srai_epi32(in[13], shift); + in[14] = _mm_srai_epi32(in[14], shift); + in[15] = _mm_srai_epi32(in[15], shift); +} + +static INLINE void write_buffer_8x8(const __m128i *res, tran_low_t *output) { + _mm_store_si128((__m128i *)(output + 0 * 4), res[0]); + _mm_store_si128((__m128i *)(output + 1 * 4), res[1]); + _mm_store_si128((__m128i *)(output + 2 * 4), res[2]); + _mm_store_si128((__m128i *)(output + 3 * 4), res[3]); + + _mm_store_si128((__m128i *)(output + 4 * 4), res[4]); + _mm_store_si128((__m128i *)(output + 5 * 4), res[5]); + _mm_store_si128((__m128i *)(output + 6 * 4), res[6]); + _mm_store_si128((__m128i *)(output + 7 * 4), res[7]); + + _mm_store_si128((__m128i *)(output + 8 * 4), res[8]); + _mm_store_si128((__m128i *)(output + 9 * 4), res[9]); + _mm_store_si128((__m128i *)(output + 10 * 4), res[10]); + _mm_store_si128((__m128i *)(output + 11 * 4), res[11]); + + _mm_store_si128((__m128i *)(output + 12 * 4), res[12]); + _mm_store_si128((__m128i *)(output + 13 * 4), res[13]); + _mm_store_si128((__m128i *)(output + 14 * 4), res[14]); + _mm_store_si128((__m128i *)(output + 15 * 4), res[15]); +} + +static void fdct8x8_sse4_1(__m128i *in, __m128i *out, int bit) { + const int32_t *cospi = cospi_arr[bit - cos_bit_min]; + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + __m128i u[8], v[8]; + + // Even 8 points 0, 2, ..., 14 + // stage 0 + // stage 1 + u[0] = _mm_add_epi32(in[0], in[14]); + v[7] = _mm_sub_epi32(in[0], in[14]); // v[7] + u[1] = _mm_add_epi32(in[2], in[12]); + u[6] = _mm_sub_epi32(in[2], in[12]); + u[2] = _mm_add_epi32(in[4], in[10]); + u[5] = _mm_sub_epi32(in[4], in[10]); + u[3] = _mm_add_epi32(in[6], in[8]); + v[4] = _mm_sub_epi32(in[6], in[8]); // v[4] + + // stage 2 + v[0] = _mm_add_epi32(u[0], u[3]); + v[3] = _mm_sub_epi32(u[0], u[3]); + v[1] = _mm_add_epi32(u[1], u[2]); + v[2] = _mm_sub_epi32(u[1], u[2]); + + v[5] = _mm_mullo_epi32(u[5], cospim32); + v[6] = _mm_mullo_epi32(u[6], cospi32); + v[5] = _mm_add_epi32(v[5], v[6]); + v[5] = _mm_add_epi32(v[5], rnding); + v[5] = _mm_srai_epi32(v[5], bit); + + u[0] = _mm_mullo_epi32(u[5], cospi32); + v[6] = _mm_mullo_epi32(u[6], cospim32); + v[6] = _mm_sub_epi32(u[0], v[6]); + v[6] = _mm_add_epi32(v[6], rnding); + v[6] = _mm_srai_epi32(v[6], bit); + + // stage 3 + // type 0 + v[0] = _mm_mullo_epi32(v[0], cospi32); + v[1] = _mm_mullo_epi32(v[1], cospi32); + u[0] = _mm_add_epi32(v[0], v[1]); + u[0] = _mm_add_epi32(u[0], rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + u[1] = _mm_sub_epi32(v[0], v[1]); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + // type 1 + v[0] = _mm_mullo_epi32(v[2], cospi48); + v[1] = _mm_mullo_epi32(v[3], cospi16); + u[2] = _mm_add_epi32(v[0], v[1]); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + v[0] = _mm_mullo_epi32(v[2], cospi16); + v[1] = _mm_mullo_epi32(v[3], cospi48); + u[3] = _mm_sub_epi32(v[1], v[0]); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + u[4] = _mm_add_epi32(v[4], v[5]); + u[5] = _mm_sub_epi32(v[4], v[5]); + u[6] = _mm_sub_epi32(v[7], v[6]); + u[7] = _mm_add_epi32(v[7], v[6]); + + // stage 4 + // stage 5 + v[0] = _mm_mullo_epi32(u[4], cospi56); + v[1] = _mm_mullo_epi32(u[7], cospi8); + v[0] = _mm_add_epi32(v[0], v[1]); + v[0] = _mm_add_epi32(v[0], rnding); + out[2] = _mm_srai_epi32(v[0], bit); // buf0[4] + + v[0] = _mm_mullo_epi32(u[4], cospi8); + v[1] = _mm_mullo_epi32(u[7], cospi56); + v[0] = _mm_sub_epi32(v[1], v[0]); + v[0] = _mm_add_epi32(v[0], rnding); + out[14] = _mm_srai_epi32(v[0], bit); // buf0[7] + + v[0] = _mm_mullo_epi32(u[5], cospi24); + v[1] = _mm_mullo_epi32(u[6], cospi40); + v[0] = _mm_add_epi32(v[0], v[1]); + v[0] = _mm_add_epi32(v[0], rnding); + out[10] = _mm_srai_epi32(v[0], bit); // buf0[5] + + v[0] = _mm_mullo_epi32(u[5], cospi40); + v[1] = _mm_mullo_epi32(u[6], cospi24); + v[0] = _mm_sub_epi32(v[1], v[0]); + v[0] = _mm_add_epi32(v[0], rnding); + out[6] = _mm_srai_epi32(v[0], bit); // buf0[6] + + out[0] = u[0]; // buf0[0] + out[8] = u[1]; // buf0[1] + out[4] = u[2]; // buf0[2] + out[12] = u[3]; // buf0[3] + + // Odd 8 points: 1, 3, ..., 15 + // stage 0 + // stage 1 + u[0] = _mm_add_epi32(in[1], in[15]); + v[7] = _mm_sub_epi32(in[1], in[15]); // v[7] + u[1] = _mm_add_epi32(in[3], in[13]); + u[6] = _mm_sub_epi32(in[3], in[13]); + u[2] = _mm_add_epi32(in[5], in[11]); + u[5] = _mm_sub_epi32(in[5], in[11]); + u[3] = _mm_add_epi32(in[7], in[9]); + v[4] = _mm_sub_epi32(in[7], in[9]); // v[4] + + // stage 2 + v[0] = _mm_add_epi32(u[0], u[3]); + v[3] = _mm_sub_epi32(u[0], u[3]); + v[1] = _mm_add_epi32(u[1], u[2]); + v[2] = _mm_sub_epi32(u[1], u[2]); + + v[5] = _mm_mullo_epi32(u[5], cospim32); + v[6] = _mm_mullo_epi32(u[6], cospi32); + v[5] = _mm_add_epi32(v[5], v[6]); + v[5] = _mm_add_epi32(v[5], rnding); + v[5] = _mm_srai_epi32(v[5], bit); + + u[0] = _mm_mullo_epi32(u[5], cospi32); + v[6] = _mm_mullo_epi32(u[6], cospim32); + v[6] = _mm_sub_epi32(u[0], v[6]); + v[6] = _mm_add_epi32(v[6], rnding); + v[6] = _mm_srai_epi32(v[6], bit); + + // stage 3 + // type 0 + v[0] = _mm_mullo_epi32(v[0], cospi32); + v[1] = _mm_mullo_epi32(v[1], cospi32); + u[0] = _mm_add_epi32(v[0], v[1]); + u[0] = _mm_add_epi32(u[0], rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + u[1] = _mm_sub_epi32(v[0], v[1]); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + // type 1 + v[0] = _mm_mullo_epi32(v[2], cospi48); + v[1] = _mm_mullo_epi32(v[3], cospi16); + u[2] = _mm_add_epi32(v[0], v[1]); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + v[0] = _mm_mullo_epi32(v[2], cospi16); + v[1] = _mm_mullo_epi32(v[3], cospi48); + u[3] = _mm_sub_epi32(v[1], v[0]); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + u[4] = _mm_add_epi32(v[4], v[5]); + u[5] = _mm_sub_epi32(v[4], v[5]); + u[6] = _mm_sub_epi32(v[7], v[6]); + u[7] = _mm_add_epi32(v[7], v[6]); + + // stage 4 + // stage 5 + v[0] = _mm_mullo_epi32(u[4], cospi56); + v[1] = _mm_mullo_epi32(u[7], cospi8); + v[0] = _mm_add_epi32(v[0], v[1]); + v[0] = _mm_add_epi32(v[0], rnding); + out[3] = _mm_srai_epi32(v[0], bit); // buf0[4] + + v[0] = _mm_mullo_epi32(u[4], cospi8); + v[1] = _mm_mullo_epi32(u[7], cospi56); + v[0] = _mm_sub_epi32(v[1], v[0]); + v[0] = _mm_add_epi32(v[0], rnding); + out[15] = _mm_srai_epi32(v[0], bit); // buf0[7] + + v[0] = _mm_mullo_epi32(u[5], cospi24); + v[1] = _mm_mullo_epi32(u[6], cospi40); + v[0] = _mm_add_epi32(v[0], v[1]); + v[0] = _mm_add_epi32(v[0], rnding); + out[11] = _mm_srai_epi32(v[0], bit); // buf0[5] + + v[0] = _mm_mullo_epi32(u[5], cospi40); + v[1] = _mm_mullo_epi32(u[6], cospi24); + v[0] = _mm_sub_epi32(v[1], v[0]); + v[0] = _mm_add_epi32(v[0], rnding); + out[7] = _mm_srai_epi32(v[0], bit); // buf0[6] + + out[1] = u[0]; // buf0[0] + out[9] = u[1]; // buf0[1] + out[5] = u[2]; // buf0[2] + out[13] = u[3]; // buf0[3] +} + +static void fadst8x8_sse4_1(__m128i *in, __m128i *out, int bit) { + const int32_t *cospi = cospi_arr[bit - cos_bit_min]; + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi36 = _mm_set1_epi32(cospi[36]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi52 = _mm_set1_epi32(cospi[52]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const __m128i kZero = _mm_setzero_si128(); + __m128i u[8], v[8], x; + + // Even 8 points: 0, 2, ..., 14 + // stage 0 + // stage 1 + // stage 2 + // (1) + u[0] = _mm_mullo_epi32(in[14], cospi4); + x = _mm_mullo_epi32(in[0], cospi60); + u[0] = _mm_add_epi32(u[0], x); + u[0] = _mm_add_epi32(u[0], rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + u[1] = _mm_mullo_epi32(in[14], cospi60); + x = _mm_mullo_epi32(in[0], cospi4); + u[1] = _mm_sub_epi32(u[1], x); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + // (2) + u[2] = _mm_mullo_epi32(in[10], cospi20); + x = _mm_mullo_epi32(in[4], cospi44); + u[2] = _mm_add_epi32(u[2], x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_mullo_epi32(in[10], cospi44); + x = _mm_mullo_epi32(in[4], cospi20); + u[3] = _mm_sub_epi32(u[3], x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + // (3) + u[4] = _mm_mullo_epi32(in[6], cospi36); + x = _mm_mullo_epi32(in[8], cospi28); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[5] = _mm_mullo_epi32(in[6], cospi28); + x = _mm_mullo_epi32(in[8], cospi36); + u[5] = _mm_sub_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + // (4) + u[6] = _mm_mullo_epi32(in[2], cospi52); + x = _mm_mullo_epi32(in[12], cospi12); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_mullo_epi32(in[2], cospi12); + x = _mm_mullo_epi32(in[12], cospi52); + u[7] = _mm_sub_epi32(u[7], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 3 + v[0] = _mm_add_epi32(u[0], u[4]); + v[4] = _mm_sub_epi32(u[0], u[4]); + v[1] = _mm_add_epi32(u[1], u[5]); + v[5] = _mm_sub_epi32(u[1], u[5]); + v[2] = _mm_add_epi32(u[2], u[6]); + v[6] = _mm_sub_epi32(u[2], u[6]); + v[3] = _mm_add_epi32(u[3], u[7]); + v[7] = _mm_sub_epi32(u[3], u[7]); + + // stage 4 + u[0] = v[0]; + u[1] = v[1]; + u[2] = v[2]; + u[3] = v[3]; + + u[4] = _mm_mullo_epi32(v[4], cospi16); + x = _mm_mullo_epi32(v[5], cospi48); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[5] = _mm_mullo_epi32(v[4], cospi48); + x = _mm_mullo_epi32(v[5], cospi16); + u[5] = _mm_sub_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + u[6] = _mm_mullo_epi32(v[6], cospim48); + x = _mm_mullo_epi32(v[7], cospi16); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_mullo_epi32(v[6], cospi16); + x = _mm_mullo_epi32(v[7], cospim48); + u[7] = _mm_sub_epi32(u[7], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 5 + v[0] = _mm_add_epi32(u[0], u[2]); + v[2] = _mm_sub_epi32(u[0], u[2]); + v[1] = _mm_add_epi32(u[1], u[3]); + v[3] = _mm_sub_epi32(u[1], u[3]); + v[4] = _mm_add_epi32(u[4], u[6]); + v[6] = _mm_sub_epi32(u[4], u[6]); + v[5] = _mm_add_epi32(u[5], u[7]); + v[7] = _mm_sub_epi32(u[5], u[7]); + + // stage 6 + u[0] = v[0]; + u[1] = v[1]; + u[4] = v[4]; + u[5] = v[5]; + + v[0] = _mm_mullo_epi32(v[2], cospi32); + x = _mm_mullo_epi32(v[3], cospi32); + u[2] = _mm_add_epi32(v[0], x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_sub_epi32(v[0], x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + v[0] = _mm_mullo_epi32(v[6], cospi32); + x = _mm_mullo_epi32(v[7], cospi32); + u[6] = _mm_add_epi32(v[0], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_sub_epi32(v[0], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 7 + out[0] = u[0]; + out[2] = _mm_sub_epi32(kZero, u[4]); + out[4] = u[6]; + out[6] = _mm_sub_epi32(kZero, u[2]); + out[8] = u[3]; + out[10] = _mm_sub_epi32(kZero, u[7]); + out[12] = u[5]; + out[14] = _mm_sub_epi32(kZero, u[1]); + + // Odd 8 points: 1, 3, ..., 15 + // stage 0 + // stage 1 + // stage 2 + // (1) + u[0] = _mm_mullo_epi32(in[15], cospi4); + x = _mm_mullo_epi32(in[1], cospi60); + u[0] = _mm_add_epi32(u[0], x); + u[0] = _mm_add_epi32(u[0], rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + u[1] = _mm_mullo_epi32(in[15], cospi60); + x = _mm_mullo_epi32(in[1], cospi4); + u[1] = _mm_sub_epi32(u[1], x); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + // (2) + u[2] = _mm_mullo_epi32(in[11], cospi20); + x = _mm_mullo_epi32(in[5], cospi44); + u[2] = _mm_add_epi32(u[2], x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_mullo_epi32(in[11], cospi44); + x = _mm_mullo_epi32(in[5], cospi20); + u[3] = _mm_sub_epi32(u[3], x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + // (3) + u[4] = _mm_mullo_epi32(in[7], cospi36); + x = _mm_mullo_epi32(in[9], cospi28); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[5] = _mm_mullo_epi32(in[7], cospi28); + x = _mm_mullo_epi32(in[9], cospi36); + u[5] = _mm_sub_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + // (4) + u[6] = _mm_mullo_epi32(in[3], cospi52); + x = _mm_mullo_epi32(in[13], cospi12); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_mullo_epi32(in[3], cospi12); + x = _mm_mullo_epi32(in[13], cospi52); + u[7] = _mm_sub_epi32(u[7], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 3 + v[0] = _mm_add_epi32(u[0], u[4]); + v[4] = _mm_sub_epi32(u[0], u[4]); + v[1] = _mm_add_epi32(u[1], u[5]); + v[5] = _mm_sub_epi32(u[1], u[5]); + v[2] = _mm_add_epi32(u[2], u[6]); + v[6] = _mm_sub_epi32(u[2], u[6]); + v[3] = _mm_add_epi32(u[3], u[7]); + v[7] = _mm_sub_epi32(u[3], u[7]); + + // stage 4 + u[0] = v[0]; + u[1] = v[1]; + u[2] = v[2]; + u[3] = v[3]; + + u[4] = _mm_mullo_epi32(v[4], cospi16); + x = _mm_mullo_epi32(v[5], cospi48); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[5] = _mm_mullo_epi32(v[4], cospi48); + x = _mm_mullo_epi32(v[5], cospi16); + u[5] = _mm_sub_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + u[6] = _mm_mullo_epi32(v[6], cospim48); + x = _mm_mullo_epi32(v[7], cospi16); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_mullo_epi32(v[6], cospi16); + x = _mm_mullo_epi32(v[7], cospim48); + u[7] = _mm_sub_epi32(u[7], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 5 + v[0] = _mm_add_epi32(u[0], u[2]); + v[2] = _mm_sub_epi32(u[0], u[2]); + v[1] = _mm_add_epi32(u[1], u[3]); + v[3] = _mm_sub_epi32(u[1], u[3]); + v[4] = _mm_add_epi32(u[4], u[6]); + v[6] = _mm_sub_epi32(u[4], u[6]); + v[5] = _mm_add_epi32(u[5], u[7]); + v[7] = _mm_sub_epi32(u[5], u[7]); + + // stage 6 + u[0] = v[0]; + u[1] = v[1]; + u[4] = v[4]; + u[5] = v[5]; + + v[0] = _mm_mullo_epi32(v[2], cospi32); + x = _mm_mullo_epi32(v[3], cospi32); + u[2] = _mm_add_epi32(v[0], x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_sub_epi32(v[0], x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + v[0] = _mm_mullo_epi32(v[6], cospi32); + x = _mm_mullo_epi32(v[7], cospi32); + u[6] = _mm_add_epi32(v[0], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_sub_epi32(v[0], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 7 + out[1] = u[0]; + out[3] = _mm_sub_epi32(kZero, u[4]); + out[5] = u[6]; + out[7] = _mm_sub_epi32(kZero, u[2]); + out[9] = u[3]; + out[11] = _mm_sub_epi32(kZero, u[7]); + out[13] = u[5]; + out[15] = _mm_sub_epi32(kZero, u[1]); +} + +void av1_fwd_txfm2d_8x8_sse4_1(const int16_t *input, int32_t *coeff, int stride, + int tx_type, int bd) { + __m128i in[16], out[16]; + const TXFM_2D_CFG *cfg = NULL; + + switch (tx_type) { + case DCT_DCT: + cfg = &fwd_txfm_2d_cfg_dct_dct_8; + load_buffer_8x8(input, in, stride, 0, 0, cfg->shift[0]); + fdct8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fdct8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case ADST_DCT: + cfg = &fwd_txfm_2d_cfg_adst_dct_8; + load_buffer_8x8(input, in, stride, 0, 0, cfg->shift[0]); + fadst8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fdct8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case DCT_ADST: + cfg = &fwd_txfm_2d_cfg_dct_adst_8; + load_buffer_8x8(input, in, stride, 0, 0, cfg->shift[0]); + fdct8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fadst8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case ADST_ADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_8; + load_buffer_8x8(input, in, stride, 0, 0, cfg->shift[0]); + fadst8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fadst8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + cfg = &fwd_txfm_2d_cfg_adst_dct_8; + load_buffer_8x8(input, in, stride, 1, 0, cfg->shift[0]); + fadst8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fdct8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case DCT_FLIPADST: + cfg = &fwd_txfm_2d_cfg_dct_adst_8; + load_buffer_8x8(input, in, stride, 0, 1, cfg->shift[0]); + fdct8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fadst8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case FLIPADST_FLIPADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_8; + load_buffer_8x8(input, in, stride, 1, 1, cfg->shift[0]); + fadst8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fadst8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case ADST_FLIPADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_8; + load_buffer_8x8(input, in, stride, 0, 1, cfg->shift[0]); + fadst8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fadst8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; + case FLIPADST_ADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_8; + load_buffer_8x8(input, in, stride, 1, 0, cfg->shift[0]); + fadst8x8_sse4_1(in, out, cfg->cos_bit_col[2]); + col_txfm_8x8_rounding(out, -cfg->shift[1]); + transpose_8x8(out, in); + fadst8x8_sse4_1(in, out, cfg->cos_bit_row[2]); + transpose_8x8(out, in); + write_buffer_8x8(in, coeff); + break; +#endif // CONFIG_EXT_TX + default: assert(0); + } + (void)bd; +} + +// Hybrid Transform 16x16 + +static INLINE void convert_8x8_to_16x16(const __m128i *in, __m128i *out) { + int row_index = 0; + int dst_index = 0; + int src_index = 0; + + // row 0, 1, .., 7 + do { + out[dst_index] = in[src_index]; + out[dst_index + 1] = in[src_index + 1]; + out[dst_index + 2] = in[src_index + 16]; + out[dst_index + 3] = in[src_index + 17]; + dst_index += 4; + src_index += 2; + row_index += 1; + } while (row_index < 8); + + // row 8, 9, ..., 15 + src_index += 16; + do { + out[dst_index] = in[src_index]; + out[dst_index + 1] = in[src_index + 1]; + out[dst_index + 2] = in[src_index + 16]; + out[dst_index + 3] = in[src_index + 17]; + dst_index += 4; + src_index += 2; + row_index += 1; + } while (row_index < 16); +} + +static INLINE void load_buffer_16x16(const int16_t *input, __m128i *out, + int stride, int flipud, int fliplr, + int shift) { + __m128i in[64]; + // Load 4 8x8 blocks + const int16_t *topL = input; + const int16_t *topR = input + 8; + const int16_t *botL = input + 8 * stride; + const int16_t *botR = input + 8 * stride + 8; + + const int16_t *tmp; + + if (flipud) { + // Swap left columns + tmp = topL; + topL = botL; + botL = tmp; + // Swap right columns + tmp = topR; + topR = botR; + botR = tmp; + } + + if (fliplr) { + // Swap top rows + tmp = topL; + topL = topR; + topR = tmp; + // Swap bottom rows + tmp = botL; + botL = botR; + botR = tmp; + } + + // load first 8 columns + load_buffer_8x8(topL, &in[0], stride, flipud, fliplr, shift); + load_buffer_8x8(botL, &in[32], stride, flipud, fliplr, shift); + + // load second 8 columns + load_buffer_8x8(topR, &in[16], stride, flipud, fliplr, shift); + load_buffer_8x8(botR, &in[48], stride, flipud, fliplr, shift); + + convert_8x8_to_16x16(in, out); +} + +static void fdct16x16_sse4_1(__m128i *in, __m128i *out, int bit) { + const int32_t *cospi = cospi_arr[bit - cos_bit_min]; + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi36 = _mm_set1_epi32(cospi[36]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi52 = _mm_set1_epi32(cospi[52]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + __m128i u[16], v[16], x; + const int col_num = 4; + int col; + + // Calculate the column 0, 1, 2, 3 + for (col = 0; col < col_num; ++col) { + // stage 0 + // stage 1 + u[0] = _mm_add_epi32(in[0 * col_num + col], in[15 * col_num + col]); + u[15] = _mm_sub_epi32(in[0 * col_num + col], in[15 * col_num + col]); + u[1] = _mm_add_epi32(in[1 * col_num + col], in[14 * col_num + col]); + u[14] = _mm_sub_epi32(in[1 * col_num + col], in[14 * col_num + col]); + u[2] = _mm_add_epi32(in[2 * col_num + col], in[13 * col_num + col]); + u[13] = _mm_sub_epi32(in[2 * col_num + col], in[13 * col_num + col]); + u[3] = _mm_add_epi32(in[3 * col_num + col], in[12 * col_num + col]); + u[12] = _mm_sub_epi32(in[3 * col_num + col], in[12 * col_num + col]); + u[4] = _mm_add_epi32(in[4 * col_num + col], in[11 * col_num + col]); + u[11] = _mm_sub_epi32(in[4 * col_num + col], in[11 * col_num + col]); + u[5] = _mm_add_epi32(in[5 * col_num + col], in[10 * col_num + col]); + u[10] = _mm_sub_epi32(in[5 * col_num + col], in[10 * col_num + col]); + u[6] = _mm_add_epi32(in[6 * col_num + col], in[9 * col_num + col]); + u[9] = _mm_sub_epi32(in[6 * col_num + col], in[9 * col_num + col]); + u[7] = _mm_add_epi32(in[7 * col_num + col], in[8 * col_num + col]); + u[8] = _mm_sub_epi32(in[7 * col_num + col], in[8 * col_num + col]); + + // stage 2 + v[0] = _mm_add_epi32(u[0], u[7]); + v[7] = _mm_sub_epi32(u[0], u[7]); + v[1] = _mm_add_epi32(u[1], u[6]); + v[6] = _mm_sub_epi32(u[1], u[6]); + v[2] = _mm_add_epi32(u[2], u[5]); + v[5] = _mm_sub_epi32(u[2], u[5]); + v[3] = _mm_add_epi32(u[3], u[4]); + v[4] = _mm_sub_epi32(u[3], u[4]); + v[8] = u[8]; + v[9] = u[9]; + + v[10] = _mm_mullo_epi32(u[10], cospim32); + x = _mm_mullo_epi32(u[13], cospi32); + v[10] = _mm_add_epi32(v[10], x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[13] = _mm_mullo_epi32(u[10], cospi32); + x = _mm_mullo_epi32(u[13], cospim32); + v[13] = _mm_sub_epi32(v[13], x); + v[13] = _mm_add_epi32(v[13], rnding); + v[13] = _mm_srai_epi32(v[13], bit); + + v[11] = _mm_mullo_epi32(u[11], cospim32); + x = _mm_mullo_epi32(u[12], cospi32); + v[11] = _mm_add_epi32(v[11], x); + v[11] = _mm_add_epi32(v[11], rnding); + v[11] = _mm_srai_epi32(v[11], bit); + + v[12] = _mm_mullo_epi32(u[11], cospi32); + x = _mm_mullo_epi32(u[12], cospim32); + v[12] = _mm_sub_epi32(v[12], x); + v[12] = _mm_add_epi32(v[12], rnding); + v[12] = _mm_srai_epi32(v[12], bit); + v[14] = u[14]; + v[15] = u[15]; + + // stage 3 + u[0] = _mm_add_epi32(v[0], v[3]); + u[3] = _mm_sub_epi32(v[0], v[3]); + u[1] = _mm_add_epi32(v[1], v[2]); + u[2] = _mm_sub_epi32(v[1], v[2]); + u[4] = v[4]; + + u[5] = _mm_mullo_epi32(v[5], cospim32); + x = _mm_mullo_epi32(v[6], cospi32); + u[5] = _mm_add_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + u[6] = _mm_mullo_epi32(v[5], cospi32); + x = _mm_mullo_epi32(v[6], cospim32); + u[6] = _mm_sub_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = v[7]; + u[8] = _mm_add_epi32(v[8], v[11]); + u[11] = _mm_sub_epi32(v[8], v[11]); + u[9] = _mm_add_epi32(v[9], v[10]); + u[10] = _mm_sub_epi32(v[9], v[10]); + u[12] = _mm_sub_epi32(v[15], v[12]); + u[15] = _mm_add_epi32(v[15], v[12]); + u[13] = _mm_sub_epi32(v[14], v[13]); + u[14] = _mm_add_epi32(v[14], v[13]); + + // stage 4 + u[0] = _mm_mullo_epi32(u[0], cospi32); + u[1] = _mm_mullo_epi32(u[1], cospi32); + v[0] = _mm_add_epi32(u[0], u[1]); + v[0] = _mm_add_epi32(v[0], rnding); + v[0] = _mm_srai_epi32(v[0], bit); + + v[1] = _mm_sub_epi32(u[0], u[1]); + v[1] = _mm_add_epi32(v[1], rnding); + v[1] = _mm_srai_epi32(v[1], bit); + + v[2] = _mm_mullo_epi32(u[2], cospi48); + x = _mm_mullo_epi32(u[3], cospi16); + v[2] = _mm_add_epi32(v[2], x); + v[2] = _mm_add_epi32(v[2], rnding); + v[2] = _mm_srai_epi32(v[2], bit); + + v[3] = _mm_mullo_epi32(u[2], cospi16); + x = _mm_mullo_epi32(u[3], cospi48); + v[3] = _mm_sub_epi32(x, v[3]); + v[3] = _mm_add_epi32(v[3], rnding); + v[3] = _mm_srai_epi32(v[3], bit); + + v[4] = _mm_add_epi32(u[4], u[5]); + v[5] = _mm_sub_epi32(u[4], u[5]); + v[6] = _mm_sub_epi32(u[7], u[6]); + v[7] = _mm_add_epi32(u[7], u[6]); + v[8] = u[8]; + + v[9] = _mm_mullo_epi32(u[9], cospim16); + x = _mm_mullo_epi32(u[14], cospi48); + v[9] = _mm_add_epi32(v[9], x); + v[9] = _mm_add_epi32(v[9], rnding); + v[9] = _mm_srai_epi32(v[9], bit); + + v[14] = _mm_mullo_epi32(u[9], cospi48); + x = _mm_mullo_epi32(u[14], cospim16); + v[14] = _mm_sub_epi32(v[14], x); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[10] = _mm_mullo_epi32(u[10], cospim48); + x = _mm_mullo_epi32(u[13], cospim16); + v[10] = _mm_add_epi32(v[10], x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[13] = _mm_mullo_epi32(u[10], cospim16); + x = _mm_mullo_epi32(u[13], cospim48); + v[13] = _mm_sub_epi32(v[13], x); + v[13] = _mm_add_epi32(v[13], rnding); + v[13] = _mm_srai_epi32(v[13], bit); + + v[11] = u[11]; + v[12] = u[12]; + v[15] = u[15]; + + // stage 5 + u[0] = v[0]; + u[1] = v[1]; + u[2] = v[2]; + u[3] = v[3]; + + u[4] = _mm_mullo_epi32(v[4], cospi56); + x = _mm_mullo_epi32(v[7], cospi8); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[7] = _mm_mullo_epi32(v[4], cospi8); + x = _mm_mullo_epi32(v[7], cospi56); + u[7] = _mm_sub_epi32(x, u[7]); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + u[5] = _mm_mullo_epi32(v[5], cospi24); + x = _mm_mullo_epi32(v[6], cospi40); + u[5] = _mm_add_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + u[6] = _mm_mullo_epi32(v[5], cospi40); + x = _mm_mullo_epi32(v[6], cospi24); + u[6] = _mm_sub_epi32(x, u[6]); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[8] = _mm_add_epi32(v[8], v[9]); + u[9] = _mm_sub_epi32(v[8], v[9]); + u[10] = _mm_sub_epi32(v[11], v[10]); + u[11] = _mm_add_epi32(v[11], v[10]); + u[12] = _mm_add_epi32(v[12], v[13]); + u[13] = _mm_sub_epi32(v[12], v[13]); + u[14] = _mm_sub_epi32(v[15], v[14]); + u[15] = _mm_add_epi32(v[15], v[14]); + + // stage 6 + v[0] = u[0]; + v[1] = u[1]; + v[2] = u[2]; + v[3] = u[3]; + v[4] = u[4]; + v[5] = u[5]; + v[6] = u[6]; + v[7] = u[7]; + + v[8] = _mm_mullo_epi32(u[8], cospi60); + x = _mm_mullo_epi32(u[15], cospi4); + v[8] = _mm_add_epi32(v[8], x); + v[8] = _mm_add_epi32(v[8], rnding); + v[8] = _mm_srai_epi32(v[8], bit); + + v[15] = _mm_mullo_epi32(u[8], cospi4); + x = _mm_mullo_epi32(u[15], cospi60); + v[15] = _mm_sub_epi32(x, v[15]); + v[15] = _mm_add_epi32(v[15], rnding); + v[15] = _mm_srai_epi32(v[15], bit); + + v[9] = _mm_mullo_epi32(u[9], cospi28); + x = _mm_mullo_epi32(u[14], cospi36); + v[9] = _mm_add_epi32(v[9], x); + v[9] = _mm_add_epi32(v[9], rnding); + v[9] = _mm_srai_epi32(v[9], bit); + + v[14] = _mm_mullo_epi32(u[9], cospi36); + x = _mm_mullo_epi32(u[14], cospi28); + v[14] = _mm_sub_epi32(x, v[14]); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[10] = _mm_mullo_epi32(u[10], cospi44); + x = _mm_mullo_epi32(u[13], cospi20); + v[10] = _mm_add_epi32(v[10], x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[13] = _mm_mullo_epi32(u[10], cospi20); + x = _mm_mullo_epi32(u[13], cospi44); + v[13] = _mm_sub_epi32(x, v[13]); + v[13] = _mm_add_epi32(v[13], rnding); + v[13] = _mm_srai_epi32(v[13], bit); + + v[11] = _mm_mullo_epi32(u[11], cospi12); + x = _mm_mullo_epi32(u[12], cospi52); + v[11] = _mm_add_epi32(v[11], x); + v[11] = _mm_add_epi32(v[11], rnding); + v[11] = _mm_srai_epi32(v[11], bit); + + v[12] = _mm_mullo_epi32(u[11], cospi52); + x = _mm_mullo_epi32(u[12], cospi12); + v[12] = _mm_sub_epi32(x, v[12]); + v[12] = _mm_add_epi32(v[12], rnding); + v[12] = _mm_srai_epi32(v[12], bit); + + out[0 * col_num + col] = v[0]; + out[1 * col_num + col] = v[8]; + out[2 * col_num + col] = v[4]; + out[3 * col_num + col] = v[12]; + out[4 * col_num + col] = v[2]; + out[5 * col_num + col] = v[10]; + out[6 * col_num + col] = v[6]; + out[7 * col_num + col] = v[14]; + out[8 * col_num + col] = v[1]; + out[9 * col_num + col] = v[9]; + out[10 * col_num + col] = v[5]; + out[11 * col_num + col] = v[13]; + out[12 * col_num + col] = v[3]; + out[13 * col_num + col] = v[11]; + out[14 * col_num + col] = v[7]; + out[15 * col_num + col] = v[15]; + } +} + +static void fadst16x16_sse4_1(__m128i *in, __m128i *out, int bit) { + const int32_t *cospi = cospi_arr[bit - cos_bit_min]; + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi10 = _mm_set1_epi32(cospi[10]); + const __m128i cospi54 = _mm_set1_epi32(cospi[54]); + const __m128i cospi18 = _mm_set1_epi32(cospi[18]); + const __m128i cospi46 = _mm_set1_epi32(cospi[46]); + const __m128i cospi26 = _mm_set1_epi32(cospi[26]); + const __m128i cospi38 = _mm_set1_epi32(cospi[38]); + const __m128i cospi34 = _mm_set1_epi32(cospi[34]); + const __m128i cospi30 = _mm_set1_epi32(cospi[30]); + const __m128i cospi42 = _mm_set1_epi32(cospi[42]); + const __m128i cospi22 = _mm_set1_epi32(cospi[22]); + const __m128i cospi50 = _mm_set1_epi32(cospi[50]); + const __m128i cospi14 = _mm_set1_epi32(cospi[14]); + const __m128i cospi58 = _mm_set1_epi32(cospi[58]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + __m128i u[16], v[16], x, y; + const int col_num = 4; + int col; + + // Calculate the column 0, 1, 2, 3 + for (col = 0; col < col_num; ++col) { + // stage 0 + // stage 1 + // stage 2 + v[0] = _mm_mullo_epi32(in[15 * col_num + col], cospi2); + x = _mm_mullo_epi32(in[0 * col_num + col], cospi62); + v[0] = _mm_add_epi32(v[0], x); + v[0] = _mm_add_epi32(v[0], rnding); + v[0] = _mm_srai_epi32(v[0], bit); + + v[1] = _mm_mullo_epi32(in[15 * col_num + col], cospi62); + x = _mm_mullo_epi32(in[0 * col_num + col], cospi2); + v[1] = _mm_sub_epi32(v[1], x); + v[1] = _mm_add_epi32(v[1], rnding); + v[1] = _mm_srai_epi32(v[1], bit); + + v[2] = _mm_mullo_epi32(in[13 * col_num + col], cospi10); + x = _mm_mullo_epi32(in[2 * col_num + col], cospi54); + v[2] = _mm_add_epi32(v[2], x); + v[2] = _mm_add_epi32(v[2], rnding); + v[2] = _mm_srai_epi32(v[2], bit); + + v[3] = _mm_mullo_epi32(in[13 * col_num + col], cospi54); + x = _mm_mullo_epi32(in[2 * col_num + col], cospi10); + v[3] = _mm_sub_epi32(v[3], x); + v[3] = _mm_add_epi32(v[3], rnding); + v[3] = _mm_srai_epi32(v[3], bit); + + v[4] = _mm_mullo_epi32(in[11 * col_num + col], cospi18); + x = _mm_mullo_epi32(in[4 * col_num + col], cospi46); + v[4] = _mm_add_epi32(v[4], x); + v[4] = _mm_add_epi32(v[4], rnding); + v[4] = _mm_srai_epi32(v[4], bit); + + v[5] = _mm_mullo_epi32(in[11 * col_num + col], cospi46); + x = _mm_mullo_epi32(in[4 * col_num + col], cospi18); + v[5] = _mm_sub_epi32(v[5], x); + v[5] = _mm_add_epi32(v[5], rnding); + v[5] = _mm_srai_epi32(v[5], bit); + + v[6] = _mm_mullo_epi32(in[9 * col_num + col], cospi26); + x = _mm_mullo_epi32(in[6 * col_num + col], cospi38); + v[6] = _mm_add_epi32(v[6], x); + v[6] = _mm_add_epi32(v[6], rnding); + v[6] = _mm_srai_epi32(v[6], bit); + + v[7] = _mm_mullo_epi32(in[9 * col_num + col], cospi38); + x = _mm_mullo_epi32(in[6 * col_num + col], cospi26); + v[7] = _mm_sub_epi32(v[7], x); + v[7] = _mm_add_epi32(v[7], rnding); + v[7] = _mm_srai_epi32(v[7], bit); + + v[8] = _mm_mullo_epi32(in[7 * col_num + col], cospi34); + x = _mm_mullo_epi32(in[8 * col_num + col], cospi30); + v[8] = _mm_add_epi32(v[8], x); + v[8] = _mm_add_epi32(v[8], rnding); + v[8] = _mm_srai_epi32(v[8], bit); + + v[9] = _mm_mullo_epi32(in[7 * col_num + col], cospi30); + x = _mm_mullo_epi32(in[8 * col_num + col], cospi34); + v[9] = _mm_sub_epi32(v[9], x); + v[9] = _mm_add_epi32(v[9], rnding); + v[9] = _mm_srai_epi32(v[9], bit); + + v[10] = _mm_mullo_epi32(in[5 * col_num + col], cospi42); + x = _mm_mullo_epi32(in[10 * col_num + col], cospi22); + v[10] = _mm_add_epi32(v[10], x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[11] = _mm_mullo_epi32(in[5 * col_num + col], cospi22); + x = _mm_mullo_epi32(in[10 * col_num + col], cospi42); + v[11] = _mm_sub_epi32(v[11], x); + v[11] = _mm_add_epi32(v[11], rnding); + v[11] = _mm_srai_epi32(v[11], bit); + + v[12] = _mm_mullo_epi32(in[3 * col_num + col], cospi50); + x = _mm_mullo_epi32(in[12 * col_num + col], cospi14); + v[12] = _mm_add_epi32(v[12], x); + v[12] = _mm_add_epi32(v[12], rnding); + v[12] = _mm_srai_epi32(v[12], bit); + + v[13] = _mm_mullo_epi32(in[3 * col_num + col], cospi14); + x = _mm_mullo_epi32(in[12 * col_num + col], cospi50); + v[13] = _mm_sub_epi32(v[13], x); + v[13] = _mm_add_epi32(v[13], rnding); + v[13] = _mm_srai_epi32(v[13], bit); + + v[14] = _mm_mullo_epi32(in[1 * col_num + col], cospi58); + x = _mm_mullo_epi32(in[14 * col_num + col], cospi6); + v[14] = _mm_add_epi32(v[14], x); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[15] = _mm_mullo_epi32(in[1 * col_num + col], cospi6); + x = _mm_mullo_epi32(in[14 * col_num + col], cospi58); + v[15] = _mm_sub_epi32(v[15], x); + v[15] = _mm_add_epi32(v[15], rnding); + v[15] = _mm_srai_epi32(v[15], bit); + + // stage 3 + u[0] = _mm_add_epi32(v[0], v[8]); + u[8] = _mm_sub_epi32(v[0], v[8]); + u[1] = _mm_add_epi32(v[1], v[9]); + u[9] = _mm_sub_epi32(v[1], v[9]); + u[2] = _mm_add_epi32(v[2], v[10]); + u[10] = _mm_sub_epi32(v[2], v[10]); + u[3] = _mm_add_epi32(v[3], v[11]); + u[11] = _mm_sub_epi32(v[3], v[11]); + u[4] = _mm_add_epi32(v[4], v[12]); + u[12] = _mm_sub_epi32(v[4], v[12]); + u[5] = _mm_add_epi32(v[5], v[13]); + u[13] = _mm_sub_epi32(v[5], v[13]); + u[6] = _mm_add_epi32(v[6], v[14]); + u[14] = _mm_sub_epi32(v[6], v[14]); + u[7] = _mm_add_epi32(v[7], v[15]); + u[15] = _mm_sub_epi32(v[7], v[15]); + + // stage 4 + v[0] = u[0]; + v[1] = u[1]; + v[2] = u[2]; + v[3] = u[3]; + v[4] = u[4]; + v[5] = u[5]; + v[6] = u[6]; + v[7] = u[7]; + + v[8] = _mm_mullo_epi32(u[8], cospi8); + x = _mm_mullo_epi32(u[9], cospi56); + v[8] = _mm_add_epi32(v[8], x); + v[8] = _mm_add_epi32(v[8], rnding); + v[8] = _mm_srai_epi32(v[8], bit); + + v[9] = _mm_mullo_epi32(u[8], cospi56); + x = _mm_mullo_epi32(u[9], cospi8); + v[9] = _mm_sub_epi32(v[9], x); + v[9] = _mm_add_epi32(v[9], rnding); + v[9] = _mm_srai_epi32(v[9], bit); + + v[10] = _mm_mullo_epi32(u[10], cospi40); + x = _mm_mullo_epi32(u[11], cospi24); + v[10] = _mm_add_epi32(v[10], x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[11] = _mm_mullo_epi32(u[10], cospi24); + x = _mm_mullo_epi32(u[11], cospi40); + v[11] = _mm_sub_epi32(v[11], x); + v[11] = _mm_add_epi32(v[11], rnding); + v[11] = _mm_srai_epi32(v[11], bit); + + v[12] = _mm_mullo_epi32(u[12], cospim56); + x = _mm_mullo_epi32(u[13], cospi8); + v[12] = _mm_add_epi32(v[12], x); + v[12] = _mm_add_epi32(v[12], rnding); + v[12] = _mm_srai_epi32(v[12], bit); + + v[13] = _mm_mullo_epi32(u[12], cospi8); + x = _mm_mullo_epi32(u[13], cospim56); + v[13] = _mm_sub_epi32(v[13], x); + v[13] = _mm_add_epi32(v[13], rnding); + v[13] = _mm_srai_epi32(v[13], bit); + + v[14] = _mm_mullo_epi32(u[14], cospim24); + x = _mm_mullo_epi32(u[15], cospi40); + v[14] = _mm_add_epi32(v[14], x); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[15] = _mm_mullo_epi32(u[14], cospi40); + x = _mm_mullo_epi32(u[15], cospim24); + v[15] = _mm_sub_epi32(v[15], x); + v[15] = _mm_add_epi32(v[15], rnding); + v[15] = _mm_srai_epi32(v[15], bit); + + // stage 5 + u[0] = _mm_add_epi32(v[0], v[4]); + u[4] = _mm_sub_epi32(v[0], v[4]); + u[1] = _mm_add_epi32(v[1], v[5]); + u[5] = _mm_sub_epi32(v[1], v[5]); + u[2] = _mm_add_epi32(v[2], v[6]); + u[6] = _mm_sub_epi32(v[2], v[6]); + u[3] = _mm_add_epi32(v[3], v[7]); + u[7] = _mm_sub_epi32(v[3], v[7]); + u[8] = _mm_add_epi32(v[8], v[12]); + u[12] = _mm_sub_epi32(v[8], v[12]); + u[9] = _mm_add_epi32(v[9], v[13]); + u[13] = _mm_sub_epi32(v[9], v[13]); + u[10] = _mm_add_epi32(v[10], v[14]); + u[14] = _mm_sub_epi32(v[10], v[14]); + u[11] = _mm_add_epi32(v[11], v[15]); + u[15] = _mm_sub_epi32(v[11], v[15]); + + // stage 6 + v[0] = u[0]; + v[1] = u[1]; + v[2] = u[2]; + v[3] = u[3]; + + v[4] = _mm_mullo_epi32(u[4], cospi16); + x = _mm_mullo_epi32(u[5], cospi48); + v[4] = _mm_add_epi32(v[4], x); + v[4] = _mm_add_epi32(v[4], rnding); + v[4] = _mm_srai_epi32(v[4], bit); + + v[5] = _mm_mullo_epi32(u[4], cospi48); + x = _mm_mullo_epi32(u[5], cospi16); + v[5] = _mm_sub_epi32(v[5], x); + v[5] = _mm_add_epi32(v[5], rnding); + v[5] = _mm_srai_epi32(v[5], bit); + + v[6] = _mm_mullo_epi32(u[6], cospim48); + x = _mm_mullo_epi32(u[7], cospi16); + v[6] = _mm_add_epi32(v[6], x); + v[6] = _mm_add_epi32(v[6], rnding); + v[6] = _mm_srai_epi32(v[6], bit); + + v[7] = _mm_mullo_epi32(u[6], cospi16); + x = _mm_mullo_epi32(u[7], cospim48); + v[7] = _mm_sub_epi32(v[7], x); + v[7] = _mm_add_epi32(v[7], rnding); + v[7] = _mm_srai_epi32(v[7], bit); + + v[8] = u[8]; + v[9] = u[9]; + v[10] = u[10]; + v[11] = u[11]; + + v[12] = _mm_mullo_epi32(u[12], cospi16); + x = _mm_mullo_epi32(u[13], cospi48); + v[12] = _mm_add_epi32(v[12], x); + v[12] = _mm_add_epi32(v[12], rnding); + v[12] = _mm_srai_epi32(v[12], bit); + + v[13] = _mm_mullo_epi32(u[12], cospi48); + x = _mm_mullo_epi32(u[13], cospi16); + v[13] = _mm_sub_epi32(v[13], x); + v[13] = _mm_add_epi32(v[13], rnding); + v[13] = _mm_srai_epi32(v[13], bit); + + v[14] = _mm_mullo_epi32(u[14], cospim48); + x = _mm_mullo_epi32(u[15], cospi16); + v[14] = _mm_add_epi32(v[14], x); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[15] = _mm_mullo_epi32(u[14], cospi16); + x = _mm_mullo_epi32(u[15], cospim48); + v[15] = _mm_sub_epi32(v[15], x); + v[15] = _mm_add_epi32(v[15], rnding); + v[15] = _mm_srai_epi32(v[15], bit); + + // stage 7 + u[0] = _mm_add_epi32(v[0], v[2]); + u[2] = _mm_sub_epi32(v[0], v[2]); + u[1] = _mm_add_epi32(v[1], v[3]); + u[3] = _mm_sub_epi32(v[1], v[3]); + u[4] = _mm_add_epi32(v[4], v[6]); + u[6] = _mm_sub_epi32(v[4], v[6]); + u[5] = _mm_add_epi32(v[5], v[7]); + u[7] = _mm_sub_epi32(v[5], v[7]); + u[8] = _mm_add_epi32(v[8], v[10]); + u[10] = _mm_sub_epi32(v[8], v[10]); + u[9] = _mm_add_epi32(v[9], v[11]); + u[11] = _mm_sub_epi32(v[9], v[11]); + u[12] = _mm_add_epi32(v[12], v[14]); + u[14] = _mm_sub_epi32(v[12], v[14]); + u[13] = _mm_add_epi32(v[13], v[15]); + u[15] = _mm_sub_epi32(v[13], v[15]); + + // stage 8 + v[0] = u[0]; + v[1] = u[1]; + + y = _mm_mullo_epi32(u[2], cospi32); + x = _mm_mullo_epi32(u[3], cospi32); + v[2] = _mm_add_epi32(y, x); + v[2] = _mm_add_epi32(v[2], rnding); + v[2] = _mm_srai_epi32(v[2], bit); + + v[3] = _mm_sub_epi32(y, x); + v[3] = _mm_add_epi32(v[3], rnding); + v[3] = _mm_srai_epi32(v[3], bit); + + v[4] = u[4]; + v[5] = u[5]; + + y = _mm_mullo_epi32(u[6], cospi32); + x = _mm_mullo_epi32(u[7], cospi32); + v[6] = _mm_add_epi32(y, x); + v[6] = _mm_add_epi32(v[6], rnding); + v[6] = _mm_srai_epi32(v[6], bit); + + v[7] = _mm_sub_epi32(y, x); + v[7] = _mm_add_epi32(v[7], rnding); + v[7] = _mm_srai_epi32(v[7], bit); + + v[8] = u[8]; + v[9] = u[9]; + + y = _mm_mullo_epi32(u[10], cospi32); + x = _mm_mullo_epi32(u[11], cospi32); + v[10] = _mm_add_epi32(y, x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[11] = _mm_sub_epi32(y, x); + v[11] = _mm_add_epi32(v[11], rnding); + v[11] = _mm_srai_epi32(v[11], bit); + + v[12] = u[12]; + v[13] = u[13]; + + y = _mm_mullo_epi32(u[14], cospi32); + x = _mm_mullo_epi32(u[15], cospi32); + v[14] = _mm_add_epi32(y, x); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[15] = _mm_sub_epi32(y, x); + v[15] = _mm_add_epi32(v[15], rnding); + v[15] = _mm_srai_epi32(v[15], bit); + + // stage 9 + out[0 * col_num + col] = v[0]; + out[1 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[8]); + out[2 * col_num + col] = v[12]; + out[3 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[4]); + out[4 * col_num + col] = v[6]; + out[5 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[14]); + out[6 * col_num + col] = v[10]; + out[7 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[2]); + out[8 * col_num + col] = v[3]; + out[9 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[11]); + out[10 * col_num + col] = v[15]; + out[11 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[7]); + out[12 * col_num + col] = v[5]; + out[13 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[13]); + out[14 * col_num + col] = v[9]; + out[15 * col_num + col] = _mm_sub_epi32(_mm_set1_epi32(0), v[1]); + } +} + +static void col_txfm_16x16_rounding(__m128i *in, int shift) { + // Note: + // We split 16x16 rounding into 4 sections of 8x8 rounding, + // instead of 4 columns + col_txfm_8x8_rounding(&in[0], shift); + col_txfm_8x8_rounding(&in[16], shift); + col_txfm_8x8_rounding(&in[32], shift); + col_txfm_8x8_rounding(&in[48], shift); +} + +static void write_buffer_16x16(const __m128i *in, tran_low_t *output) { + const int size_8x8 = 16 * 4; + write_buffer_8x8(&in[0], output); + output += size_8x8; + write_buffer_8x8(&in[16], output); + output += size_8x8; + write_buffer_8x8(&in[32], output); + output += size_8x8; + write_buffer_8x8(&in[48], output); +} + +void av1_fwd_txfm2d_16x16_sse4_1(const int16_t *input, int32_t *coeff, + int stride, int tx_type, int bd) { + __m128i in[64], out[64]; + const TXFM_2D_CFG *cfg = NULL; + + switch (tx_type) { + case DCT_DCT: + cfg = &fwd_txfm_2d_cfg_dct_dct_16; + load_buffer_16x16(input, in, stride, 0, 0, cfg->shift[0]); + fdct16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fdct16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case ADST_DCT: + cfg = &fwd_txfm_2d_cfg_adst_dct_16; + load_buffer_16x16(input, in, stride, 0, 0, cfg->shift[0]); + fadst16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fdct16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case DCT_ADST: + cfg = &fwd_txfm_2d_cfg_dct_adst_16; + load_buffer_16x16(input, in, stride, 0, 0, cfg->shift[0]); + fdct16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fadst16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case ADST_ADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_16; + load_buffer_16x16(input, in, stride, 0, 0, cfg->shift[0]); + fadst16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fadst16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + cfg = &fwd_txfm_2d_cfg_adst_dct_16; + load_buffer_16x16(input, in, stride, 1, 0, cfg->shift[0]); + fadst16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fdct16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case DCT_FLIPADST: + cfg = &fwd_txfm_2d_cfg_dct_adst_16; + load_buffer_16x16(input, in, stride, 0, 1, cfg->shift[0]); + fdct16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fadst16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case FLIPADST_FLIPADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_16; + load_buffer_16x16(input, in, stride, 1, 1, cfg->shift[0]); + fadst16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fadst16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case ADST_FLIPADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_16; + load_buffer_16x16(input, in, stride, 0, 1, cfg->shift[0]); + fadst16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fadst16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; + case FLIPADST_ADST: + cfg = &fwd_txfm_2d_cfg_adst_adst_16; + load_buffer_16x16(input, in, stride, 1, 0, cfg->shift[0]); + fadst16x16_sse4_1(in, out, cfg->cos_bit_col[0]); + col_txfm_16x16_rounding(out, -cfg->shift[1]); + transpose_16x16(out, in); + fadst16x16_sse4_1(in, out, cfg->cos_bit_row[0]); + transpose_16x16(out, in); + write_buffer_16x16(in, coeff); + break; +#endif // CONFIG_EXT_TX + default: assert(0); + } + (void)bd; +} diff --git a/third_party/aom/av1/encoder/x86/hybrid_fwd_txfm_avx2.c b/third_party/aom/av1/encoder/x86/hybrid_fwd_txfm_avx2.c new file mode 100644 index 000000000..198e4e4c4 --- /dev/null +++ b/third_party/aom/av1/encoder/x86/hybrid_fwd_txfm_avx2.c @@ -0,0 +1,1678 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <immintrin.h> // avx2 + +#include "./av1_rtcd.h" +#include "./aom_dsp_rtcd.h" + +#include "aom_dsp/x86/fwd_txfm_avx2.h" +#include "aom_dsp/txfm_common.h" +#include "aom_dsp/x86/txfm_common_avx2.h" + +static int32_t get_16x16_sum(const int16_t *input, int stride) { + __m256i r0, r1, r2, r3, u0, u1; + __m256i zero = _mm256_setzero_si256(); + __m256i sum = _mm256_setzero_si256(); + const int16_t *blockBound = input + (stride << 4); + __m128i v0, v1; + + while (input < blockBound) { + r0 = _mm256_loadu_si256((__m256i const *)input); + r1 = _mm256_loadu_si256((__m256i const *)(input + stride)); + r2 = _mm256_loadu_si256((__m256i const *)(input + 2 * stride)); + r3 = _mm256_loadu_si256((__m256i const *)(input + 3 * stride)); + + u0 = _mm256_add_epi16(r0, r1); + u1 = _mm256_add_epi16(r2, r3); + sum = _mm256_add_epi16(sum, u0); + sum = _mm256_add_epi16(sum, u1); + + input += stride << 2; + } + + // unpack 16 int16_t into 2x8 int32_t + u0 = _mm256_unpacklo_epi16(zero, sum); + u1 = _mm256_unpackhi_epi16(zero, sum); + u0 = _mm256_srai_epi32(u0, 16); + u1 = _mm256_srai_epi32(u1, 16); + sum = _mm256_add_epi32(u0, u1); + + u0 = _mm256_srli_si256(sum, 8); + u1 = _mm256_add_epi32(sum, u0); + + v0 = _mm_add_epi32(_mm256_extracti128_si256(u1, 1), + _mm256_castsi256_si128(u1)); + v1 = _mm_srli_si128(v0, 4); + v0 = _mm_add_epi32(v0, v1); + return (int32_t)_mm_extract_epi32(v0, 0); +} + +void aom_fdct16x16_1_avx2(const int16_t *input, tran_low_t *output, + int stride) { + int32_t dc = get_16x16_sum(input, stride); + output[0] = (tran_low_t)(dc >> 1); + _mm256_zeroupper(); +} + +static INLINE void load_buffer_16x16(const int16_t *input, int stride, + int flipud, int fliplr, __m256i *in) { + if (!flipud) { + in[0] = _mm256_loadu_si256((const __m256i *)(input + 0 * stride)); + in[1] = _mm256_loadu_si256((const __m256i *)(input + 1 * stride)); + in[2] = _mm256_loadu_si256((const __m256i *)(input + 2 * stride)); + in[3] = _mm256_loadu_si256((const __m256i *)(input + 3 * stride)); + in[4] = _mm256_loadu_si256((const __m256i *)(input + 4 * stride)); + in[5] = _mm256_loadu_si256((const __m256i *)(input + 5 * stride)); + in[6] = _mm256_loadu_si256((const __m256i *)(input + 6 * stride)); + in[7] = _mm256_loadu_si256((const __m256i *)(input + 7 * stride)); + in[8] = _mm256_loadu_si256((const __m256i *)(input + 8 * stride)); + in[9] = _mm256_loadu_si256((const __m256i *)(input + 9 * stride)); + in[10] = _mm256_loadu_si256((const __m256i *)(input + 10 * stride)); + in[11] = _mm256_loadu_si256((const __m256i *)(input + 11 * stride)); + in[12] = _mm256_loadu_si256((const __m256i *)(input + 12 * stride)); + in[13] = _mm256_loadu_si256((const __m256i *)(input + 13 * stride)); + in[14] = _mm256_loadu_si256((const __m256i *)(input + 14 * stride)); + in[15] = _mm256_loadu_si256((const __m256i *)(input + 15 * stride)); + } else { + in[0] = _mm256_loadu_si256((const __m256i *)(input + 15 * stride)); + in[1] = _mm256_loadu_si256((const __m256i *)(input + 14 * stride)); + in[2] = _mm256_loadu_si256((const __m256i *)(input + 13 * stride)); + in[3] = _mm256_loadu_si256((const __m256i *)(input + 12 * stride)); + in[4] = _mm256_loadu_si256((const __m256i *)(input + 11 * stride)); + in[5] = _mm256_loadu_si256((const __m256i *)(input + 10 * stride)); + in[6] = _mm256_loadu_si256((const __m256i *)(input + 9 * stride)); + in[7] = _mm256_loadu_si256((const __m256i *)(input + 8 * stride)); + in[8] = _mm256_loadu_si256((const __m256i *)(input + 7 * stride)); + in[9] = _mm256_loadu_si256((const __m256i *)(input + 6 * stride)); + in[10] = _mm256_loadu_si256((const __m256i *)(input + 5 * stride)); + in[11] = _mm256_loadu_si256((const __m256i *)(input + 4 * stride)); + in[12] = _mm256_loadu_si256((const __m256i *)(input + 3 * stride)); + in[13] = _mm256_loadu_si256((const __m256i *)(input + 2 * stride)); + in[14] = _mm256_loadu_si256((const __m256i *)(input + 1 * stride)); + in[15] = _mm256_loadu_si256((const __m256i *)(input + 0 * stride)); + } + + if (fliplr) { + mm256_reverse_epi16(&in[0]); + mm256_reverse_epi16(&in[1]); + mm256_reverse_epi16(&in[2]); + mm256_reverse_epi16(&in[3]); + mm256_reverse_epi16(&in[4]); + mm256_reverse_epi16(&in[5]); + mm256_reverse_epi16(&in[6]); + mm256_reverse_epi16(&in[7]); + mm256_reverse_epi16(&in[8]); + mm256_reverse_epi16(&in[9]); + mm256_reverse_epi16(&in[10]); + mm256_reverse_epi16(&in[11]); + mm256_reverse_epi16(&in[12]); + mm256_reverse_epi16(&in[13]); + mm256_reverse_epi16(&in[14]); + mm256_reverse_epi16(&in[15]); + } + + in[0] = _mm256_slli_epi16(in[0], 2); + in[1] = _mm256_slli_epi16(in[1], 2); + in[2] = _mm256_slli_epi16(in[2], 2); + in[3] = _mm256_slli_epi16(in[3], 2); + in[4] = _mm256_slli_epi16(in[4], 2); + in[5] = _mm256_slli_epi16(in[5], 2); + in[6] = _mm256_slli_epi16(in[6], 2); + in[7] = _mm256_slli_epi16(in[7], 2); + in[8] = _mm256_slli_epi16(in[8], 2); + in[9] = _mm256_slli_epi16(in[9], 2); + in[10] = _mm256_slli_epi16(in[10], 2); + in[11] = _mm256_slli_epi16(in[11], 2); + in[12] = _mm256_slli_epi16(in[12], 2); + in[13] = _mm256_slli_epi16(in[13], 2); + in[14] = _mm256_slli_epi16(in[14], 2); + in[15] = _mm256_slli_epi16(in[15], 2); +} + +static INLINE void write_buffer_16x16(const __m256i *in, tran_low_t *output) { + int i; + for (i = 0; i < 16; ++i) { + storeu_output_avx2(&in[i], output + (i << 4)); + } +} + +static void right_shift_16x16(__m256i *in) { + const __m256i one = _mm256_set1_epi16(1); + __m256i s0 = _mm256_srai_epi16(in[0], 15); + __m256i s1 = _mm256_srai_epi16(in[1], 15); + __m256i s2 = _mm256_srai_epi16(in[2], 15); + __m256i s3 = _mm256_srai_epi16(in[3], 15); + __m256i s4 = _mm256_srai_epi16(in[4], 15); + __m256i s5 = _mm256_srai_epi16(in[5], 15); + __m256i s6 = _mm256_srai_epi16(in[6], 15); + __m256i s7 = _mm256_srai_epi16(in[7], 15); + __m256i s8 = _mm256_srai_epi16(in[8], 15); + __m256i s9 = _mm256_srai_epi16(in[9], 15); + __m256i s10 = _mm256_srai_epi16(in[10], 15); + __m256i s11 = _mm256_srai_epi16(in[11], 15); + __m256i s12 = _mm256_srai_epi16(in[12], 15); + __m256i s13 = _mm256_srai_epi16(in[13], 15); + __m256i s14 = _mm256_srai_epi16(in[14], 15); + __m256i s15 = _mm256_srai_epi16(in[15], 15); + + in[0] = _mm256_add_epi16(in[0], one); + in[1] = _mm256_add_epi16(in[1], one); + in[2] = _mm256_add_epi16(in[2], one); + in[3] = _mm256_add_epi16(in[3], one); + in[4] = _mm256_add_epi16(in[4], one); + in[5] = _mm256_add_epi16(in[5], one); + in[6] = _mm256_add_epi16(in[6], one); + in[7] = _mm256_add_epi16(in[7], one); + in[8] = _mm256_add_epi16(in[8], one); + in[9] = _mm256_add_epi16(in[9], one); + in[10] = _mm256_add_epi16(in[10], one); + in[11] = _mm256_add_epi16(in[11], one); + in[12] = _mm256_add_epi16(in[12], one); + in[13] = _mm256_add_epi16(in[13], one); + in[14] = _mm256_add_epi16(in[14], one); + in[15] = _mm256_add_epi16(in[15], one); + + in[0] = _mm256_sub_epi16(in[0], s0); + in[1] = _mm256_sub_epi16(in[1], s1); + in[2] = _mm256_sub_epi16(in[2], s2); + in[3] = _mm256_sub_epi16(in[3], s3); + in[4] = _mm256_sub_epi16(in[4], s4); + in[5] = _mm256_sub_epi16(in[5], s5); + in[6] = _mm256_sub_epi16(in[6], s6); + in[7] = _mm256_sub_epi16(in[7], s7); + in[8] = _mm256_sub_epi16(in[8], s8); + in[9] = _mm256_sub_epi16(in[9], s9); + in[10] = _mm256_sub_epi16(in[10], s10); + in[11] = _mm256_sub_epi16(in[11], s11); + in[12] = _mm256_sub_epi16(in[12], s12); + in[13] = _mm256_sub_epi16(in[13], s13); + in[14] = _mm256_sub_epi16(in[14], s14); + in[15] = _mm256_sub_epi16(in[15], s15); + + in[0] = _mm256_srai_epi16(in[0], 2); + in[1] = _mm256_srai_epi16(in[1], 2); + in[2] = _mm256_srai_epi16(in[2], 2); + in[3] = _mm256_srai_epi16(in[3], 2); + in[4] = _mm256_srai_epi16(in[4], 2); + in[5] = _mm256_srai_epi16(in[5], 2); + in[6] = _mm256_srai_epi16(in[6], 2); + in[7] = _mm256_srai_epi16(in[7], 2); + in[8] = _mm256_srai_epi16(in[8], 2); + in[9] = _mm256_srai_epi16(in[9], 2); + in[10] = _mm256_srai_epi16(in[10], 2); + in[11] = _mm256_srai_epi16(in[11], 2); + in[12] = _mm256_srai_epi16(in[12], 2); + in[13] = _mm256_srai_epi16(in[13], 2); + in[14] = _mm256_srai_epi16(in[14], 2); + in[15] = _mm256_srai_epi16(in[15], 2); +} + +static void fdct16_avx2(__m256i *in) { + // sequence: cospi_L_H = pairs(L, H) and L first + const __m256i cospi_p16_m16 = pair256_set_epi16(cospi_16_64, -cospi_16_64); + const __m256i cospi_p16_p16 = pair256_set_epi16(cospi_16_64, cospi_16_64); + const __m256i cospi_p24_p08 = pair256_set_epi16(cospi_24_64, cospi_8_64); + const __m256i cospi_m08_p24 = pair256_set_epi16(-cospi_8_64, cospi_24_64); + const __m256i cospi_m24_m08 = pair256_set_epi16(-cospi_24_64, -cospi_8_64); + + const __m256i cospi_p28_p04 = pair256_set_epi16(cospi_28_64, cospi_4_64); + const __m256i cospi_m04_p28 = pair256_set_epi16(-cospi_4_64, cospi_28_64); + const __m256i cospi_p12_p20 = pair256_set_epi16(cospi_12_64, cospi_20_64); + const __m256i cospi_m20_p12 = pair256_set_epi16(-cospi_20_64, cospi_12_64); + + const __m256i cospi_p30_p02 = pair256_set_epi16(cospi_30_64, cospi_2_64); + const __m256i cospi_m02_p30 = pair256_set_epi16(-cospi_2_64, cospi_30_64); + + const __m256i cospi_p14_p18 = pair256_set_epi16(cospi_14_64, cospi_18_64); + const __m256i cospi_m18_p14 = pair256_set_epi16(-cospi_18_64, cospi_14_64); + + const __m256i cospi_p22_p10 = pair256_set_epi16(cospi_22_64, cospi_10_64); + const __m256i cospi_m10_p22 = pair256_set_epi16(-cospi_10_64, cospi_22_64); + + const __m256i cospi_p06_p26 = pair256_set_epi16(cospi_6_64, cospi_26_64); + const __m256i cospi_m26_p06 = pair256_set_epi16(-cospi_26_64, cospi_6_64); + + __m256i u0, u1, u2, u3, u4, u5, u6, u7; + __m256i s0, s1, s2, s3, s4, s5, s6, s7; + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i v0, v1, v2, v3; + __m256i x0, x1; + + // 0, 4, 8, 12 + u0 = _mm256_add_epi16(in[0], in[15]); + u1 = _mm256_add_epi16(in[1], in[14]); + u2 = _mm256_add_epi16(in[2], in[13]); + u3 = _mm256_add_epi16(in[3], in[12]); + u4 = _mm256_add_epi16(in[4], in[11]); + u5 = _mm256_add_epi16(in[5], in[10]); + u6 = _mm256_add_epi16(in[6], in[9]); + u7 = _mm256_add_epi16(in[7], in[8]); + + s0 = _mm256_add_epi16(u0, u7); + s1 = _mm256_add_epi16(u1, u6); + s2 = _mm256_add_epi16(u2, u5); + s3 = _mm256_add_epi16(u3, u4); + + // 0, 8 + v0 = _mm256_add_epi16(s0, s3); + v1 = _mm256_add_epi16(s1, s2); + + x0 = _mm256_unpacklo_epi16(v0, v1); + x1 = _mm256_unpackhi_epi16(v0, v1); + + t0 = butter_fly(x0, x1, cospi_p16_p16); + t1 = butter_fly(x0, x1, cospi_p16_m16); + + // 4, 12 + v0 = _mm256_sub_epi16(s1, s2); + v1 = _mm256_sub_epi16(s0, s3); + + x0 = _mm256_unpacklo_epi16(v0, v1); + x1 = _mm256_unpackhi_epi16(v0, v1); + + t2 = butter_fly(x0, x1, cospi_p24_p08); + t3 = butter_fly(x0, x1, cospi_m08_p24); + + // 2, 6, 10, 14 + s0 = _mm256_sub_epi16(u3, u4); + s1 = _mm256_sub_epi16(u2, u5); + s2 = _mm256_sub_epi16(u1, u6); + s3 = _mm256_sub_epi16(u0, u7); + + v0 = s0; // output[4] + v3 = s3; // output[7] + + x0 = _mm256_unpacklo_epi16(s2, s1); + x1 = _mm256_unpackhi_epi16(s2, s1); + + v2 = butter_fly(x0, x1, cospi_p16_p16); // output[5] + v1 = butter_fly(x0, x1, cospi_p16_m16); // output[6] + + s0 = _mm256_add_epi16(v0, v1); // step[4] + s1 = _mm256_sub_epi16(v0, v1); // step[5] + s2 = _mm256_sub_epi16(v3, v2); // step[6] + s3 = _mm256_add_epi16(v3, v2); // step[7] + + // 2, 14 + x0 = _mm256_unpacklo_epi16(s0, s3); + x1 = _mm256_unpackhi_epi16(s0, s3); + + t4 = butter_fly(x0, x1, cospi_p28_p04); + t5 = butter_fly(x0, x1, cospi_m04_p28); + + // 10, 6 + x0 = _mm256_unpacklo_epi16(s1, s2); + x1 = _mm256_unpackhi_epi16(s1, s2); + t6 = butter_fly(x0, x1, cospi_p12_p20); + t7 = butter_fly(x0, x1, cospi_m20_p12); + + // 1, 3, 5, 7, 9, 11, 13, 15 + s0 = _mm256_sub_epi16(in[7], in[8]); // step[8] + s1 = _mm256_sub_epi16(in[6], in[9]); // step[9] + u2 = _mm256_sub_epi16(in[5], in[10]); + u3 = _mm256_sub_epi16(in[4], in[11]); + u4 = _mm256_sub_epi16(in[3], in[12]); + u5 = _mm256_sub_epi16(in[2], in[13]); + s6 = _mm256_sub_epi16(in[1], in[14]); // step[14] + s7 = _mm256_sub_epi16(in[0], in[15]); // step[15] + + in[0] = t0; + in[8] = t1; + in[4] = t2; + in[12] = t3; + in[2] = t4; + in[14] = t5; + in[10] = t6; + in[6] = t7; + + x0 = _mm256_unpacklo_epi16(u5, u2); + x1 = _mm256_unpackhi_epi16(u5, u2); + + s2 = butter_fly(x0, x1, cospi_p16_p16); // step[13] + s5 = butter_fly(x0, x1, cospi_p16_m16); // step[10] + + x0 = _mm256_unpacklo_epi16(u4, u3); + x1 = _mm256_unpackhi_epi16(u4, u3); + + s3 = butter_fly(x0, x1, cospi_p16_p16); // step[12] + s4 = butter_fly(x0, x1, cospi_p16_m16); // step[11] + + u0 = _mm256_add_epi16(s0, s4); // output[8] + u1 = _mm256_add_epi16(s1, s5); + u2 = _mm256_sub_epi16(s1, s5); + u3 = _mm256_sub_epi16(s0, s4); + u4 = _mm256_sub_epi16(s7, s3); + u5 = _mm256_sub_epi16(s6, s2); + u6 = _mm256_add_epi16(s6, s2); + u7 = _mm256_add_epi16(s7, s3); + + // stage 4 + s0 = u0; + s3 = u3; + s4 = u4; + s7 = u7; + + x0 = _mm256_unpacklo_epi16(u1, u6); + x1 = _mm256_unpackhi_epi16(u1, u6); + + s1 = butter_fly(x0, x1, cospi_m08_p24); + s6 = butter_fly(x0, x1, cospi_p24_p08); + + x0 = _mm256_unpacklo_epi16(u2, u5); + x1 = _mm256_unpackhi_epi16(u2, u5); + + s2 = butter_fly(x0, x1, cospi_m24_m08); + s5 = butter_fly(x0, x1, cospi_m08_p24); + + // stage 5 + u0 = _mm256_add_epi16(s0, s1); + u1 = _mm256_sub_epi16(s0, s1); + u2 = _mm256_sub_epi16(s3, s2); + u3 = _mm256_add_epi16(s3, s2); + u4 = _mm256_add_epi16(s4, s5); + u5 = _mm256_sub_epi16(s4, s5); + u6 = _mm256_sub_epi16(s7, s6); + u7 = _mm256_add_epi16(s7, s6); + + // stage 6 + x0 = _mm256_unpacklo_epi16(u0, u7); + x1 = _mm256_unpackhi_epi16(u0, u7); + in[1] = butter_fly(x0, x1, cospi_p30_p02); + in[15] = butter_fly(x0, x1, cospi_m02_p30); + + x0 = _mm256_unpacklo_epi16(u1, u6); + x1 = _mm256_unpackhi_epi16(u1, u6); + in[9] = butter_fly(x0, x1, cospi_p14_p18); + in[7] = butter_fly(x0, x1, cospi_m18_p14); + + x0 = _mm256_unpacklo_epi16(u2, u5); + x1 = _mm256_unpackhi_epi16(u2, u5); + in[5] = butter_fly(x0, x1, cospi_p22_p10); + in[11] = butter_fly(x0, x1, cospi_m10_p22); + + x0 = _mm256_unpacklo_epi16(u3, u4); + x1 = _mm256_unpackhi_epi16(u3, u4); + in[13] = butter_fly(x0, x1, cospi_p06_p26); + in[3] = butter_fly(x0, x1, cospi_m26_p06); +} + +void fadst16_avx2(__m256i *in) { + const __m256i cospi_p01_p31 = pair256_set_epi16(cospi_1_64, cospi_31_64); + const __m256i cospi_p31_m01 = pair256_set_epi16(cospi_31_64, -cospi_1_64); + const __m256i cospi_p05_p27 = pair256_set_epi16(cospi_5_64, cospi_27_64); + const __m256i cospi_p27_m05 = pair256_set_epi16(cospi_27_64, -cospi_5_64); + const __m256i cospi_p09_p23 = pair256_set_epi16(cospi_9_64, cospi_23_64); + const __m256i cospi_p23_m09 = pair256_set_epi16(cospi_23_64, -cospi_9_64); + const __m256i cospi_p13_p19 = pair256_set_epi16(cospi_13_64, cospi_19_64); + const __m256i cospi_p19_m13 = pair256_set_epi16(cospi_19_64, -cospi_13_64); + const __m256i cospi_p17_p15 = pair256_set_epi16(cospi_17_64, cospi_15_64); + const __m256i cospi_p15_m17 = pair256_set_epi16(cospi_15_64, -cospi_17_64); + const __m256i cospi_p21_p11 = pair256_set_epi16(cospi_21_64, cospi_11_64); + const __m256i cospi_p11_m21 = pair256_set_epi16(cospi_11_64, -cospi_21_64); + const __m256i cospi_p25_p07 = pair256_set_epi16(cospi_25_64, cospi_7_64); + const __m256i cospi_p07_m25 = pair256_set_epi16(cospi_7_64, -cospi_25_64); + const __m256i cospi_p29_p03 = pair256_set_epi16(cospi_29_64, cospi_3_64); + const __m256i cospi_p03_m29 = pair256_set_epi16(cospi_3_64, -cospi_29_64); + const __m256i cospi_p04_p28 = pair256_set_epi16(cospi_4_64, cospi_28_64); + const __m256i cospi_p28_m04 = pair256_set_epi16(cospi_28_64, -cospi_4_64); + const __m256i cospi_p20_p12 = pair256_set_epi16(cospi_20_64, cospi_12_64); + const __m256i cospi_p12_m20 = pair256_set_epi16(cospi_12_64, -cospi_20_64); + const __m256i cospi_m28_p04 = pair256_set_epi16(-cospi_28_64, cospi_4_64); + const __m256i cospi_m12_p20 = pair256_set_epi16(-cospi_12_64, cospi_20_64); + const __m256i cospi_p08_p24 = pair256_set_epi16(cospi_8_64, cospi_24_64); + const __m256i cospi_p24_m08 = pair256_set_epi16(cospi_24_64, -cospi_8_64); + const __m256i cospi_m24_p08 = pair256_set_epi16(-cospi_24_64, cospi_8_64); + const __m256i cospi_m16_m16 = _mm256_set1_epi16((int16_t)-cospi_16_64); + const __m256i cospi_p16_p16 = _mm256_set1_epi16((int16_t)cospi_16_64); + const __m256i cospi_p16_m16 = pair256_set_epi16(cospi_16_64, -cospi_16_64); + const __m256i cospi_m16_p16 = pair256_set_epi16(-cospi_16_64, cospi_16_64); + const __m256i zero = _mm256_setzero_si256(); + const __m256i dct_rounding = _mm256_set1_epi32(DCT_CONST_ROUNDING); + __m256i s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15; + __m256i x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15; + __m256i u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11, u12, u13, u14, u15; + __m256i v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15; + __m256i y0, y1; + + // stage 1, s takes low 256 bits; x takes high 256 bits + y0 = _mm256_unpacklo_epi16(in[15], in[0]); + y1 = _mm256_unpackhi_epi16(in[15], in[0]); + s0 = _mm256_madd_epi16(y0, cospi_p01_p31); + x0 = _mm256_madd_epi16(y1, cospi_p01_p31); + s1 = _mm256_madd_epi16(y0, cospi_p31_m01); + x1 = _mm256_madd_epi16(y1, cospi_p31_m01); + + y0 = _mm256_unpacklo_epi16(in[13], in[2]); + y1 = _mm256_unpackhi_epi16(in[13], in[2]); + s2 = _mm256_madd_epi16(y0, cospi_p05_p27); + x2 = _mm256_madd_epi16(y1, cospi_p05_p27); + s3 = _mm256_madd_epi16(y0, cospi_p27_m05); + x3 = _mm256_madd_epi16(y1, cospi_p27_m05); + + y0 = _mm256_unpacklo_epi16(in[11], in[4]); + y1 = _mm256_unpackhi_epi16(in[11], in[4]); + s4 = _mm256_madd_epi16(y0, cospi_p09_p23); + x4 = _mm256_madd_epi16(y1, cospi_p09_p23); + s5 = _mm256_madd_epi16(y0, cospi_p23_m09); + x5 = _mm256_madd_epi16(y1, cospi_p23_m09); + + y0 = _mm256_unpacklo_epi16(in[9], in[6]); + y1 = _mm256_unpackhi_epi16(in[9], in[6]); + s6 = _mm256_madd_epi16(y0, cospi_p13_p19); + x6 = _mm256_madd_epi16(y1, cospi_p13_p19); + s7 = _mm256_madd_epi16(y0, cospi_p19_m13); + x7 = _mm256_madd_epi16(y1, cospi_p19_m13); + + y0 = _mm256_unpacklo_epi16(in[7], in[8]); + y1 = _mm256_unpackhi_epi16(in[7], in[8]); + s8 = _mm256_madd_epi16(y0, cospi_p17_p15); + x8 = _mm256_madd_epi16(y1, cospi_p17_p15); + s9 = _mm256_madd_epi16(y0, cospi_p15_m17); + x9 = _mm256_madd_epi16(y1, cospi_p15_m17); + + y0 = _mm256_unpacklo_epi16(in[5], in[10]); + y1 = _mm256_unpackhi_epi16(in[5], in[10]); + s10 = _mm256_madd_epi16(y0, cospi_p21_p11); + x10 = _mm256_madd_epi16(y1, cospi_p21_p11); + s11 = _mm256_madd_epi16(y0, cospi_p11_m21); + x11 = _mm256_madd_epi16(y1, cospi_p11_m21); + + y0 = _mm256_unpacklo_epi16(in[3], in[12]); + y1 = _mm256_unpackhi_epi16(in[3], in[12]); + s12 = _mm256_madd_epi16(y0, cospi_p25_p07); + x12 = _mm256_madd_epi16(y1, cospi_p25_p07); + s13 = _mm256_madd_epi16(y0, cospi_p07_m25); + x13 = _mm256_madd_epi16(y1, cospi_p07_m25); + + y0 = _mm256_unpacklo_epi16(in[1], in[14]); + y1 = _mm256_unpackhi_epi16(in[1], in[14]); + s14 = _mm256_madd_epi16(y0, cospi_p29_p03); + x14 = _mm256_madd_epi16(y1, cospi_p29_p03); + s15 = _mm256_madd_epi16(y0, cospi_p03_m29); + x15 = _mm256_madd_epi16(y1, cospi_p03_m29); + + // u takes low 256 bits; v takes high 256 bits + u0 = _mm256_add_epi32(s0, s8); + u1 = _mm256_add_epi32(s1, s9); + u2 = _mm256_add_epi32(s2, s10); + u3 = _mm256_add_epi32(s3, s11); + u4 = _mm256_add_epi32(s4, s12); + u5 = _mm256_add_epi32(s5, s13); + u6 = _mm256_add_epi32(s6, s14); + u7 = _mm256_add_epi32(s7, s15); + + u8 = _mm256_sub_epi32(s0, s8); + u9 = _mm256_sub_epi32(s1, s9); + u10 = _mm256_sub_epi32(s2, s10); + u11 = _mm256_sub_epi32(s3, s11); + u12 = _mm256_sub_epi32(s4, s12); + u13 = _mm256_sub_epi32(s5, s13); + u14 = _mm256_sub_epi32(s6, s14); + u15 = _mm256_sub_epi32(s7, s15); + + v0 = _mm256_add_epi32(x0, x8); + v1 = _mm256_add_epi32(x1, x9); + v2 = _mm256_add_epi32(x2, x10); + v3 = _mm256_add_epi32(x3, x11); + v4 = _mm256_add_epi32(x4, x12); + v5 = _mm256_add_epi32(x5, x13); + v6 = _mm256_add_epi32(x6, x14); + v7 = _mm256_add_epi32(x7, x15); + + v8 = _mm256_sub_epi32(x0, x8); + v9 = _mm256_sub_epi32(x1, x9); + v10 = _mm256_sub_epi32(x2, x10); + v11 = _mm256_sub_epi32(x3, x11); + v12 = _mm256_sub_epi32(x4, x12); + v13 = _mm256_sub_epi32(x5, x13); + v14 = _mm256_sub_epi32(x6, x14); + v15 = _mm256_sub_epi32(x7, x15); + + // low 256 bits rounding + u8 = _mm256_add_epi32(u8, dct_rounding); + u9 = _mm256_add_epi32(u9, dct_rounding); + u10 = _mm256_add_epi32(u10, dct_rounding); + u11 = _mm256_add_epi32(u11, dct_rounding); + u12 = _mm256_add_epi32(u12, dct_rounding); + u13 = _mm256_add_epi32(u13, dct_rounding); + u14 = _mm256_add_epi32(u14, dct_rounding); + u15 = _mm256_add_epi32(u15, dct_rounding); + + u8 = _mm256_srai_epi32(u8, DCT_CONST_BITS); + u9 = _mm256_srai_epi32(u9, DCT_CONST_BITS); + u10 = _mm256_srai_epi32(u10, DCT_CONST_BITS); + u11 = _mm256_srai_epi32(u11, DCT_CONST_BITS); + u12 = _mm256_srai_epi32(u12, DCT_CONST_BITS); + u13 = _mm256_srai_epi32(u13, DCT_CONST_BITS); + u14 = _mm256_srai_epi32(u14, DCT_CONST_BITS); + u15 = _mm256_srai_epi32(u15, DCT_CONST_BITS); + + // high 256 bits rounding + v8 = _mm256_add_epi32(v8, dct_rounding); + v9 = _mm256_add_epi32(v9, dct_rounding); + v10 = _mm256_add_epi32(v10, dct_rounding); + v11 = _mm256_add_epi32(v11, dct_rounding); + v12 = _mm256_add_epi32(v12, dct_rounding); + v13 = _mm256_add_epi32(v13, dct_rounding); + v14 = _mm256_add_epi32(v14, dct_rounding); + v15 = _mm256_add_epi32(v15, dct_rounding); + + v8 = _mm256_srai_epi32(v8, DCT_CONST_BITS); + v9 = _mm256_srai_epi32(v9, DCT_CONST_BITS); + v10 = _mm256_srai_epi32(v10, DCT_CONST_BITS); + v11 = _mm256_srai_epi32(v11, DCT_CONST_BITS); + v12 = _mm256_srai_epi32(v12, DCT_CONST_BITS); + v13 = _mm256_srai_epi32(v13, DCT_CONST_BITS); + v14 = _mm256_srai_epi32(v14, DCT_CONST_BITS); + v15 = _mm256_srai_epi32(v15, DCT_CONST_BITS); + + // Saturation pack 32-bit to 16-bit + x8 = _mm256_packs_epi32(u8, v8); + x9 = _mm256_packs_epi32(u9, v9); + x10 = _mm256_packs_epi32(u10, v10); + x11 = _mm256_packs_epi32(u11, v11); + x12 = _mm256_packs_epi32(u12, v12); + x13 = _mm256_packs_epi32(u13, v13); + x14 = _mm256_packs_epi32(u14, v14); + x15 = _mm256_packs_epi32(u15, v15); + + // stage 2 + y0 = _mm256_unpacklo_epi16(x8, x9); + y1 = _mm256_unpackhi_epi16(x8, x9); + s8 = _mm256_madd_epi16(y0, cospi_p04_p28); + x8 = _mm256_madd_epi16(y1, cospi_p04_p28); + s9 = _mm256_madd_epi16(y0, cospi_p28_m04); + x9 = _mm256_madd_epi16(y1, cospi_p28_m04); + + y0 = _mm256_unpacklo_epi16(x10, x11); + y1 = _mm256_unpackhi_epi16(x10, x11); + s10 = _mm256_madd_epi16(y0, cospi_p20_p12); + x10 = _mm256_madd_epi16(y1, cospi_p20_p12); + s11 = _mm256_madd_epi16(y0, cospi_p12_m20); + x11 = _mm256_madd_epi16(y1, cospi_p12_m20); + + y0 = _mm256_unpacklo_epi16(x12, x13); + y1 = _mm256_unpackhi_epi16(x12, x13); + s12 = _mm256_madd_epi16(y0, cospi_m28_p04); + x12 = _mm256_madd_epi16(y1, cospi_m28_p04); + s13 = _mm256_madd_epi16(y0, cospi_p04_p28); + x13 = _mm256_madd_epi16(y1, cospi_p04_p28); + + y0 = _mm256_unpacklo_epi16(x14, x15); + y1 = _mm256_unpackhi_epi16(x14, x15); + s14 = _mm256_madd_epi16(y0, cospi_m12_p20); + x14 = _mm256_madd_epi16(y1, cospi_m12_p20); + s15 = _mm256_madd_epi16(y0, cospi_p20_p12); + x15 = _mm256_madd_epi16(y1, cospi_p20_p12); + + x0 = _mm256_add_epi32(u0, u4); + s0 = _mm256_add_epi32(v0, v4); + x1 = _mm256_add_epi32(u1, u5); + s1 = _mm256_add_epi32(v1, v5); + x2 = _mm256_add_epi32(u2, u6); + s2 = _mm256_add_epi32(v2, v6); + x3 = _mm256_add_epi32(u3, u7); + s3 = _mm256_add_epi32(v3, v7); + + v8 = _mm256_sub_epi32(u0, u4); + v9 = _mm256_sub_epi32(v0, v4); + v10 = _mm256_sub_epi32(u1, u5); + v11 = _mm256_sub_epi32(v1, v5); + v12 = _mm256_sub_epi32(u2, u6); + v13 = _mm256_sub_epi32(v2, v6); + v14 = _mm256_sub_epi32(u3, u7); + v15 = _mm256_sub_epi32(v3, v7); + + v8 = _mm256_add_epi32(v8, dct_rounding); + v9 = _mm256_add_epi32(v9, dct_rounding); + v10 = _mm256_add_epi32(v10, dct_rounding); + v11 = _mm256_add_epi32(v11, dct_rounding); + v12 = _mm256_add_epi32(v12, dct_rounding); + v13 = _mm256_add_epi32(v13, dct_rounding); + v14 = _mm256_add_epi32(v14, dct_rounding); + v15 = _mm256_add_epi32(v15, dct_rounding); + + v8 = _mm256_srai_epi32(v8, DCT_CONST_BITS); + v9 = _mm256_srai_epi32(v9, DCT_CONST_BITS); + v10 = _mm256_srai_epi32(v10, DCT_CONST_BITS); + v11 = _mm256_srai_epi32(v11, DCT_CONST_BITS); + v12 = _mm256_srai_epi32(v12, DCT_CONST_BITS); + v13 = _mm256_srai_epi32(v13, DCT_CONST_BITS); + v14 = _mm256_srai_epi32(v14, DCT_CONST_BITS); + v15 = _mm256_srai_epi32(v15, DCT_CONST_BITS); + + x4 = _mm256_packs_epi32(v8, v9); + x5 = _mm256_packs_epi32(v10, v11); + x6 = _mm256_packs_epi32(v12, v13); + x7 = _mm256_packs_epi32(v14, v15); + + u8 = _mm256_add_epi32(s8, s12); + u9 = _mm256_add_epi32(s9, s13); + u10 = _mm256_add_epi32(s10, s14); + u11 = _mm256_add_epi32(s11, s15); + u12 = _mm256_sub_epi32(s8, s12); + u13 = _mm256_sub_epi32(s9, s13); + u14 = _mm256_sub_epi32(s10, s14); + u15 = _mm256_sub_epi32(s11, s15); + + v8 = _mm256_add_epi32(x8, x12); + v9 = _mm256_add_epi32(x9, x13); + v10 = _mm256_add_epi32(x10, x14); + v11 = _mm256_add_epi32(x11, x15); + v12 = _mm256_sub_epi32(x8, x12); + v13 = _mm256_sub_epi32(x9, x13); + v14 = _mm256_sub_epi32(x10, x14); + v15 = _mm256_sub_epi32(x11, x15); + + u12 = _mm256_add_epi32(u12, dct_rounding); + u13 = _mm256_add_epi32(u13, dct_rounding); + u14 = _mm256_add_epi32(u14, dct_rounding); + u15 = _mm256_add_epi32(u15, dct_rounding); + + u12 = _mm256_srai_epi32(u12, DCT_CONST_BITS); + u13 = _mm256_srai_epi32(u13, DCT_CONST_BITS); + u14 = _mm256_srai_epi32(u14, DCT_CONST_BITS); + u15 = _mm256_srai_epi32(u15, DCT_CONST_BITS); + + v12 = _mm256_add_epi32(v12, dct_rounding); + v13 = _mm256_add_epi32(v13, dct_rounding); + v14 = _mm256_add_epi32(v14, dct_rounding); + v15 = _mm256_add_epi32(v15, dct_rounding); + + v12 = _mm256_srai_epi32(v12, DCT_CONST_BITS); + v13 = _mm256_srai_epi32(v13, DCT_CONST_BITS); + v14 = _mm256_srai_epi32(v14, DCT_CONST_BITS); + v15 = _mm256_srai_epi32(v15, DCT_CONST_BITS); + + x12 = _mm256_packs_epi32(u12, v12); + x13 = _mm256_packs_epi32(u13, v13); + x14 = _mm256_packs_epi32(u14, v14); + x15 = _mm256_packs_epi32(u15, v15); + + // stage 3 + y0 = _mm256_unpacklo_epi16(x4, x5); + y1 = _mm256_unpackhi_epi16(x4, x5); + s4 = _mm256_madd_epi16(y0, cospi_p08_p24); + x4 = _mm256_madd_epi16(y1, cospi_p08_p24); + s5 = _mm256_madd_epi16(y0, cospi_p24_m08); + x5 = _mm256_madd_epi16(y1, cospi_p24_m08); + + y0 = _mm256_unpacklo_epi16(x6, x7); + y1 = _mm256_unpackhi_epi16(x6, x7); + s6 = _mm256_madd_epi16(y0, cospi_m24_p08); + x6 = _mm256_madd_epi16(y1, cospi_m24_p08); + s7 = _mm256_madd_epi16(y0, cospi_p08_p24); + x7 = _mm256_madd_epi16(y1, cospi_p08_p24); + + y0 = _mm256_unpacklo_epi16(x12, x13); + y1 = _mm256_unpackhi_epi16(x12, x13); + s12 = _mm256_madd_epi16(y0, cospi_p08_p24); + x12 = _mm256_madd_epi16(y1, cospi_p08_p24); + s13 = _mm256_madd_epi16(y0, cospi_p24_m08); + x13 = _mm256_madd_epi16(y1, cospi_p24_m08); + + y0 = _mm256_unpacklo_epi16(x14, x15); + y1 = _mm256_unpackhi_epi16(x14, x15); + s14 = _mm256_madd_epi16(y0, cospi_m24_p08); + x14 = _mm256_madd_epi16(y1, cospi_m24_p08); + s15 = _mm256_madd_epi16(y0, cospi_p08_p24); + x15 = _mm256_madd_epi16(y1, cospi_p08_p24); + + u0 = _mm256_add_epi32(x0, x2); + v0 = _mm256_add_epi32(s0, s2); + u1 = _mm256_add_epi32(x1, x3); + v1 = _mm256_add_epi32(s1, s3); + u2 = _mm256_sub_epi32(x0, x2); + v2 = _mm256_sub_epi32(s0, s2); + u3 = _mm256_sub_epi32(x1, x3); + v3 = _mm256_sub_epi32(s1, s3); + + u0 = _mm256_add_epi32(u0, dct_rounding); + v0 = _mm256_add_epi32(v0, dct_rounding); + u1 = _mm256_add_epi32(u1, dct_rounding); + v1 = _mm256_add_epi32(v1, dct_rounding); + u2 = _mm256_add_epi32(u2, dct_rounding); + v2 = _mm256_add_epi32(v2, dct_rounding); + u3 = _mm256_add_epi32(u3, dct_rounding); + v3 = _mm256_add_epi32(v3, dct_rounding); + + u0 = _mm256_srai_epi32(u0, DCT_CONST_BITS); + v0 = _mm256_srai_epi32(v0, DCT_CONST_BITS); + u1 = _mm256_srai_epi32(u1, DCT_CONST_BITS); + v1 = _mm256_srai_epi32(v1, DCT_CONST_BITS); + u2 = _mm256_srai_epi32(u2, DCT_CONST_BITS); + v2 = _mm256_srai_epi32(v2, DCT_CONST_BITS); + u3 = _mm256_srai_epi32(u3, DCT_CONST_BITS); + v3 = _mm256_srai_epi32(v3, DCT_CONST_BITS); + + in[0] = _mm256_packs_epi32(u0, v0); + x1 = _mm256_packs_epi32(u1, v1); + x2 = _mm256_packs_epi32(u2, v2); + x3 = _mm256_packs_epi32(u3, v3); + + // Rounding on s4 + s6, s5 + s7, s4 - s6, s5 - s7 + u4 = _mm256_add_epi32(s4, s6); + u5 = _mm256_add_epi32(s5, s7); + u6 = _mm256_sub_epi32(s4, s6); + u7 = _mm256_sub_epi32(s5, s7); + + v4 = _mm256_add_epi32(x4, x6); + v5 = _mm256_add_epi32(x5, x7); + v6 = _mm256_sub_epi32(x4, x6); + v7 = _mm256_sub_epi32(x5, x7); + + u4 = _mm256_add_epi32(u4, dct_rounding); + u5 = _mm256_add_epi32(u5, dct_rounding); + u6 = _mm256_add_epi32(u6, dct_rounding); + u7 = _mm256_add_epi32(u7, dct_rounding); + + u4 = _mm256_srai_epi32(u4, DCT_CONST_BITS); + u5 = _mm256_srai_epi32(u5, DCT_CONST_BITS); + u6 = _mm256_srai_epi32(u6, DCT_CONST_BITS); + u7 = _mm256_srai_epi32(u7, DCT_CONST_BITS); + + v4 = _mm256_add_epi32(v4, dct_rounding); + v5 = _mm256_add_epi32(v5, dct_rounding); + v6 = _mm256_add_epi32(v6, dct_rounding); + v7 = _mm256_add_epi32(v7, dct_rounding); + + v4 = _mm256_srai_epi32(v4, DCT_CONST_BITS); + v5 = _mm256_srai_epi32(v5, DCT_CONST_BITS); + v6 = _mm256_srai_epi32(v6, DCT_CONST_BITS); + v7 = _mm256_srai_epi32(v7, DCT_CONST_BITS); + + x4 = _mm256_packs_epi32(u4, v4); + in[12] = _mm256_packs_epi32(u5, v5); + x6 = _mm256_packs_epi32(u6, v6); + x7 = _mm256_packs_epi32(u7, v7); + + u0 = _mm256_add_epi32(u8, u10); + v0 = _mm256_add_epi32(v8, v10); + u1 = _mm256_add_epi32(u9, u11); + v1 = _mm256_add_epi32(v9, v11); + u2 = _mm256_sub_epi32(u8, u10); + v2 = _mm256_sub_epi32(v8, v10); + u3 = _mm256_sub_epi32(u9, u11); + v3 = _mm256_sub_epi32(v9, v11); + + u0 = _mm256_add_epi32(u0, dct_rounding); + v0 = _mm256_add_epi32(v0, dct_rounding); + u1 = _mm256_add_epi32(u1, dct_rounding); + v1 = _mm256_add_epi32(v1, dct_rounding); + u2 = _mm256_add_epi32(u2, dct_rounding); + v2 = _mm256_add_epi32(v2, dct_rounding); + u3 = _mm256_add_epi32(u3, dct_rounding); + v3 = _mm256_add_epi32(v3, dct_rounding); + + u0 = _mm256_srai_epi32(u0, DCT_CONST_BITS); + v0 = _mm256_srai_epi32(v0, DCT_CONST_BITS); + u1 = _mm256_srai_epi32(u1, DCT_CONST_BITS); + v1 = _mm256_srai_epi32(v1, DCT_CONST_BITS); + u2 = _mm256_srai_epi32(u2, DCT_CONST_BITS); + v2 = _mm256_srai_epi32(v2, DCT_CONST_BITS); + u3 = _mm256_srai_epi32(u3, DCT_CONST_BITS); + v3 = _mm256_srai_epi32(v3, DCT_CONST_BITS); + + x8 = _mm256_packs_epi32(u0, v0); + in[14] = _mm256_packs_epi32(u1, v1); + x10 = _mm256_packs_epi32(u2, v2); + x11 = _mm256_packs_epi32(u3, v3); + + // Rounding on s12 + s14, s13 + s15, s12 - s14, s13 - s15 + u12 = _mm256_add_epi32(s12, s14); + u13 = _mm256_add_epi32(s13, s15); + u14 = _mm256_sub_epi32(s12, s14); + u15 = _mm256_sub_epi32(s13, s15); + + v12 = _mm256_add_epi32(x12, x14); + v13 = _mm256_add_epi32(x13, x15); + v14 = _mm256_sub_epi32(x12, x14); + v15 = _mm256_sub_epi32(x13, x15); + + u12 = _mm256_add_epi32(u12, dct_rounding); + u13 = _mm256_add_epi32(u13, dct_rounding); + u14 = _mm256_add_epi32(u14, dct_rounding); + u15 = _mm256_add_epi32(u15, dct_rounding); + + u12 = _mm256_srai_epi32(u12, DCT_CONST_BITS); + u13 = _mm256_srai_epi32(u13, DCT_CONST_BITS); + u14 = _mm256_srai_epi32(u14, DCT_CONST_BITS); + u15 = _mm256_srai_epi32(u15, DCT_CONST_BITS); + + v12 = _mm256_add_epi32(v12, dct_rounding); + v13 = _mm256_add_epi32(v13, dct_rounding); + v14 = _mm256_add_epi32(v14, dct_rounding); + v15 = _mm256_add_epi32(v15, dct_rounding); + + v12 = _mm256_srai_epi32(v12, DCT_CONST_BITS); + v13 = _mm256_srai_epi32(v13, DCT_CONST_BITS); + v14 = _mm256_srai_epi32(v14, DCT_CONST_BITS); + v15 = _mm256_srai_epi32(v15, DCT_CONST_BITS); + + x12 = _mm256_packs_epi32(u12, v12); + x13 = _mm256_packs_epi32(u13, v13); + x14 = _mm256_packs_epi32(u14, v14); + x15 = _mm256_packs_epi32(u15, v15); + in[2] = x12; + + // stage 4 + y0 = _mm256_unpacklo_epi16(x2, x3); + y1 = _mm256_unpackhi_epi16(x2, x3); + s2 = _mm256_madd_epi16(y0, cospi_m16_m16); + x2 = _mm256_madd_epi16(y1, cospi_m16_m16); + s3 = _mm256_madd_epi16(y0, cospi_p16_m16); + x3 = _mm256_madd_epi16(y1, cospi_p16_m16); + + y0 = _mm256_unpacklo_epi16(x6, x7); + y1 = _mm256_unpackhi_epi16(x6, x7); + s6 = _mm256_madd_epi16(y0, cospi_p16_p16); + x6 = _mm256_madd_epi16(y1, cospi_p16_p16); + s7 = _mm256_madd_epi16(y0, cospi_m16_p16); + x7 = _mm256_madd_epi16(y1, cospi_m16_p16); + + y0 = _mm256_unpacklo_epi16(x10, x11); + y1 = _mm256_unpackhi_epi16(x10, x11); + s10 = _mm256_madd_epi16(y0, cospi_p16_p16); + x10 = _mm256_madd_epi16(y1, cospi_p16_p16); + s11 = _mm256_madd_epi16(y0, cospi_m16_p16); + x11 = _mm256_madd_epi16(y1, cospi_m16_p16); + + y0 = _mm256_unpacklo_epi16(x14, x15); + y1 = _mm256_unpackhi_epi16(x14, x15); + s14 = _mm256_madd_epi16(y0, cospi_m16_m16); + x14 = _mm256_madd_epi16(y1, cospi_m16_m16); + s15 = _mm256_madd_epi16(y0, cospi_p16_m16); + x15 = _mm256_madd_epi16(y1, cospi_p16_m16); + + // Rounding + u2 = _mm256_add_epi32(s2, dct_rounding); + u3 = _mm256_add_epi32(s3, dct_rounding); + u6 = _mm256_add_epi32(s6, dct_rounding); + u7 = _mm256_add_epi32(s7, dct_rounding); + + u10 = _mm256_add_epi32(s10, dct_rounding); + u11 = _mm256_add_epi32(s11, dct_rounding); + u14 = _mm256_add_epi32(s14, dct_rounding); + u15 = _mm256_add_epi32(s15, dct_rounding); + + u2 = _mm256_srai_epi32(u2, DCT_CONST_BITS); + u3 = _mm256_srai_epi32(u3, DCT_CONST_BITS); + u6 = _mm256_srai_epi32(u6, DCT_CONST_BITS); + u7 = _mm256_srai_epi32(u7, DCT_CONST_BITS); + + u10 = _mm256_srai_epi32(u10, DCT_CONST_BITS); + u11 = _mm256_srai_epi32(u11, DCT_CONST_BITS); + u14 = _mm256_srai_epi32(u14, DCT_CONST_BITS); + u15 = _mm256_srai_epi32(u15, DCT_CONST_BITS); + + v2 = _mm256_add_epi32(x2, dct_rounding); + v3 = _mm256_add_epi32(x3, dct_rounding); + v6 = _mm256_add_epi32(x6, dct_rounding); + v7 = _mm256_add_epi32(x7, dct_rounding); + + v10 = _mm256_add_epi32(x10, dct_rounding); + v11 = _mm256_add_epi32(x11, dct_rounding); + v14 = _mm256_add_epi32(x14, dct_rounding); + v15 = _mm256_add_epi32(x15, dct_rounding); + + v2 = _mm256_srai_epi32(v2, DCT_CONST_BITS); + v3 = _mm256_srai_epi32(v3, DCT_CONST_BITS); + v6 = _mm256_srai_epi32(v6, DCT_CONST_BITS); + v7 = _mm256_srai_epi32(v7, DCT_CONST_BITS); + + v10 = _mm256_srai_epi32(v10, DCT_CONST_BITS); + v11 = _mm256_srai_epi32(v11, DCT_CONST_BITS); + v14 = _mm256_srai_epi32(v14, DCT_CONST_BITS); + v15 = _mm256_srai_epi32(v15, DCT_CONST_BITS); + + in[7] = _mm256_packs_epi32(u2, v2); + in[8] = _mm256_packs_epi32(u3, v3); + + in[4] = _mm256_packs_epi32(u6, v6); + in[11] = _mm256_packs_epi32(u7, v7); + + in[6] = _mm256_packs_epi32(u10, v10); + in[9] = _mm256_packs_epi32(u11, v11); + + in[5] = _mm256_packs_epi32(u14, v14); + in[10] = _mm256_packs_epi32(u15, v15); + + in[1] = _mm256_sub_epi16(zero, x8); + in[3] = _mm256_sub_epi16(zero, x4); + in[13] = _mm256_sub_epi16(zero, x13); + in[15] = _mm256_sub_epi16(zero, x1); +} + +#if CONFIG_EXT_TX +static void fidtx16_avx2(__m256i *in) { txfm_scaling16_avx2(Sqrt2, in); } +#endif + +void av1_fht16x16_avx2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m256i in[16]; + + switch (tx_type) { + case DCT_DCT: + load_buffer_16x16(input, stride, 0, 0, in); + fdct16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fdct16_avx2(in); + break; + case ADST_DCT: + load_buffer_16x16(input, stride, 0, 0, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fdct16_avx2(in); + break; + case DCT_ADST: + load_buffer_16x16(input, stride, 0, 0, in); + fdct16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; + case ADST_ADST: + load_buffer_16x16(input, stride, 0, 0, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; +#if CONFIG_EXT_TX + case FLIPADST_DCT: + load_buffer_16x16(input, stride, 1, 0, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fdct16_avx2(in); + break; + case DCT_FLIPADST: + load_buffer_16x16(input, stride, 0, 1, in); + fdct16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; + case FLIPADST_FLIPADST: + load_buffer_16x16(input, stride, 1, 1, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; + case ADST_FLIPADST: + load_buffer_16x16(input, stride, 0, 1, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; + case FLIPADST_ADST: + load_buffer_16x16(input, stride, 1, 0, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; + case IDTX: + load_buffer_16x16(input, stride, 0, 0, in); + fidtx16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fidtx16_avx2(in); + break; + case V_DCT: + load_buffer_16x16(input, stride, 0, 0, in); + fdct16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fidtx16_avx2(in); + break; + case H_DCT: + load_buffer_16x16(input, stride, 0, 0, in); + fidtx16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fdct16_avx2(in); + break; + case V_ADST: + load_buffer_16x16(input, stride, 0, 0, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fidtx16_avx2(in); + break; + case H_ADST: + load_buffer_16x16(input, stride, 0, 0, in); + fidtx16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; + case V_FLIPADST: + load_buffer_16x16(input, stride, 1, 0, in); + fadst16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fidtx16_avx2(in); + break; + case H_FLIPADST: + load_buffer_16x16(input, stride, 0, 1, in); + fidtx16_avx2(in); + mm256_transpose_16x16(in); + right_shift_16x16(in); + fadst16_avx2(in); + break; +#endif // CONFIG_EXT_TX + default: assert(0); break; + } + mm256_transpose_16x16(in); + write_buffer_16x16(in, output); + _mm256_zeroupper(); +} + +void aom_fdct32x32_1_avx2(const int16_t *input, tran_low_t *output, + int stride) { + // left and upper corner + int32_t sum = get_16x16_sum(input, stride); + // right and upper corner + sum += get_16x16_sum(input + 16, stride); + // left and lower corner + sum += get_16x16_sum(input + (stride << 4), stride); + // right and lower corner + sum += get_16x16_sum(input + (stride << 4) + 16, stride); + + sum >>= 3; + output[0] = (tran_low_t)sum; + _mm256_zeroupper(); +} + +static void mm256_vectors_swap(__m256i *a0, __m256i *a1, const int size) { + int i = 0; + __m256i temp; + while (i < size) { + temp = a0[i]; + a0[i] = a1[i]; + a1[i] = temp; + i++; + } +} + +static void mm256_transpose_32x32(__m256i *in0, __m256i *in1) { + mm256_transpose_16x16(in0); + mm256_transpose_16x16(&in0[16]); + mm256_transpose_16x16(in1); + mm256_transpose_16x16(&in1[16]); + mm256_vectors_swap(&in0[16], in1, 16); +} + +static void prepare_16x16_even(const __m256i *in, __m256i *even) { + even[0] = _mm256_add_epi16(in[0], in[31]); + even[1] = _mm256_add_epi16(in[1], in[30]); + even[2] = _mm256_add_epi16(in[2], in[29]); + even[3] = _mm256_add_epi16(in[3], in[28]); + even[4] = _mm256_add_epi16(in[4], in[27]); + even[5] = _mm256_add_epi16(in[5], in[26]); + even[6] = _mm256_add_epi16(in[6], in[25]); + even[7] = _mm256_add_epi16(in[7], in[24]); + even[8] = _mm256_add_epi16(in[8], in[23]); + even[9] = _mm256_add_epi16(in[9], in[22]); + even[10] = _mm256_add_epi16(in[10], in[21]); + even[11] = _mm256_add_epi16(in[11], in[20]); + even[12] = _mm256_add_epi16(in[12], in[19]); + even[13] = _mm256_add_epi16(in[13], in[18]); + even[14] = _mm256_add_epi16(in[14], in[17]); + even[15] = _mm256_add_epi16(in[15], in[16]); +} + +static void prepare_16x16_odd(const __m256i *in, __m256i *odd) { + odd[0] = _mm256_sub_epi16(in[15], in[16]); + odd[1] = _mm256_sub_epi16(in[14], in[17]); + odd[2] = _mm256_sub_epi16(in[13], in[18]); + odd[3] = _mm256_sub_epi16(in[12], in[19]); + odd[4] = _mm256_sub_epi16(in[11], in[20]); + odd[5] = _mm256_sub_epi16(in[10], in[21]); + odd[6] = _mm256_sub_epi16(in[9], in[22]); + odd[7] = _mm256_sub_epi16(in[8], in[23]); + odd[8] = _mm256_sub_epi16(in[7], in[24]); + odd[9] = _mm256_sub_epi16(in[6], in[25]); + odd[10] = _mm256_sub_epi16(in[5], in[26]); + odd[11] = _mm256_sub_epi16(in[4], in[27]); + odd[12] = _mm256_sub_epi16(in[3], in[28]); + odd[13] = _mm256_sub_epi16(in[2], in[29]); + odd[14] = _mm256_sub_epi16(in[1], in[30]); + odd[15] = _mm256_sub_epi16(in[0], in[31]); +} + +static void collect_16col(const __m256i *even, const __m256i *odd, + __m256i *out) { + // fdct16_avx2() already maps the output + out[0] = even[0]; + out[2] = even[1]; + out[4] = even[2]; + out[6] = even[3]; + out[8] = even[4]; + out[10] = even[5]; + out[12] = even[6]; + out[14] = even[7]; + out[16] = even[8]; + out[18] = even[9]; + out[20] = even[10]; + out[22] = even[11]; + out[24] = even[12]; + out[26] = even[13]; + out[28] = even[14]; + out[30] = even[15]; + + out[1] = odd[0]; + out[17] = odd[1]; + out[9] = odd[2]; + out[25] = odd[3]; + out[5] = odd[4]; + out[21] = odd[5]; + out[13] = odd[6]; + out[29] = odd[7]; + out[3] = odd[8]; + out[19] = odd[9]; + out[11] = odd[10]; + out[27] = odd[11]; + out[7] = odd[12]; + out[23] = odd[13]; + out[15] = odd[14]; + out[31] = odd[15]; +} + +static void collect_coeffs(const __m256i *first_16col_even, + const __m256i *first_16col_odd, + const __m256i *second_16col_even, + const __m256i *second_16col_odd, __m256i *in0, + __m256i *in1) { + collect_16col(first_16col_even, first_16col_odd, in0); + collect_16col(second_16col_even, second_16col_odd, in1); +} + +static void fdct16_odd_avx2(__m256i *in) { + // sequence: cospi_L_H = pairs(L, H) and L first + const __m256i cospi_p16_p16 = pair256_set_epi16(cospi_16_64, cospi_16_64); + const __m256i cospi_m16_p16 = pair256_set_epi16(-cospi_16_64, cospi_16_64); + const __m256i cospi_m08_p24 = pair256_set_epi16(-cospi_8_64, cospi_24_64); + const __m256i cospi_p24_p08 = pair256_set_epi16(cospi_24_64, cospi_8_64); + const __m256i cospi_m24_m08 = pair256_set_epi16(-cospi_24_64, -cospi_8_64); + const __m256i cospi_m04_p28 = pair256_set_epi16(-cospi_4_64, cospi_28_64); + const __m256i cospi_p28_p04 = pair256_set_epi16(cospi_28_64, cospi_4_64); + const __m256i cospi_m28_m04 = pair256_set_epi16(-cospi_28_64, -cospi_4_64); + const __m256i cospi_m20_p12 = pair256_set_epi16(-cospi_20_64, cospi_12_64); + const __m256i cospi_p12_p20 = pair256_set_epi16(cospi_12_64, cospi_20_64); + const __m256i cospi_m12_m20 = pair256_set_epi16(-cospi_12_64, -cospi_20_64); + + const __m256i cospi_p31_p01 = pair256_set_epi16(cospi_31_64, cospi_1_64); + const __m256i cospi_m01_p31 = pair256_set_epi16(-cospi_1_64, cospi_31_64); + const __m256i cospi_p15_p17 = pair256_set_epi16(cospi_15_64, cospi_17_64); + const __m256i cospi_m17_p15 = pair256_set_epi16(-cospi_17_64, cospi_15_64); + const __m256i cospi_p23_p09 = pair256_set_epi16(cospi_23_64, cospi_9_64); + const __m256i cospi_m09_p23 = pair256_set_epi16(-cospi_9_64, cospi_23_64); + const __m256i cospi_p07_p25 = pair256_set_epi16(cospi_7_64, cospi_25_64); + const __m256i cospi_m25_p07 = pair256_set_epi16(-cospi_25_64, cospi_7_64); + const __m256i cospi_p27_p05 = pair256_set_epi16(cospi_27_64, cospi_5_64); + const __m256i cospi_m05_p27 = pair256_set_epi16(-cospi_5_64, cospi_27_64); + const __m256i cospi_p11_p21 = pair256_set_epi16(cospi_11_64, cospi_21_64); + const __m256i cospi_m21_p11 = pair256_set_epi16(-cospi_21_64, cospi_11_64); + const __m256i cospi_p19_p13 = pair256_set_epi16(cospi_19_64, cospi_13_64); + const __m256i cospi_m13_p19 = pair256_set_epi16(-cospi_13_64, cospi_19_64); + const __m256i cospi_p03_p29 = pair256_set_epi16(cospi_3_64, cospi_29_64); + const __m256i cospi_m29_p03 = pair256_set_epi16(-cospi_29_64, cospi_3_64); + + __m256i x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15; + __m256i y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11, y12, y13, y14, y15; + __m256i u0, u1; + + // stage 1 is in prepare_16x16_odd() + + // stage 2 + y0 = in[0]; + y1 = in[1]; + y2 = in[2]; + y3 = in[3]; + + u0 = _mm256_unpacklo_epi16(in[4], in[11]); + u1 = _mm256_unpackhi_epi16(in[4], in[11]); + y4 = butter_fly(u0, u1, cospi_m16_p16); + y11 = butter_fly(u0, u1, cospi_p16_p16); + + u0 = _mm256_unpacklo_epi16(in[5], in[10]); + u1 = _mm256_unpackhi_epi16(in[5], in[10]); + y5 = butter_fly(u0, u1, cospi_m16_p16); + y10 = butter_fly(u0, u1, cospi_p16_p16); + + u0 = _mm256_unpacklo_epi16(in[6], in[9]); + u1 = _mm256_unpackhi_epi16(in[6], in[9]); + y6 = butter_fly(u0, u1, cospi_m16_p16); + y9 = butter_fly(u0, u1, cospi_p16_p16); + + u0 = _mm256_unpacklo_epi16(in[7], in[8]); + u1 = _mm256_unpackhi_epi16(in[7], in[8]); + y7 = butter_fly(u0, u1, cospi_m16_p16); + y8 = butter_fly(u0, u1, cospi_p16_p16); + + y12 = in[12]; + y13 = in[13]; + y14 = in[14]; + y15 = in[15]; + + // stage 3 + x0 = _mm256_add_epi16(y0, y7); + x1 = _mm256_add_epi16(y1, y6); + x2 = _mm256_add_epi16(y2, y5); + x3 = _mm256_add_epi16(y3, y4); + x4 = _mm256_sub_epi16(y3, y4); + x5 = _mm256_sub_epi16(y2, y5); + x6 = _mm256_sub_epi16(y1, y6); + x7 = _mm256_sub_epi16(y0, y7); + x8 = _mm256_sub_epi16(y15, y8); + x9 = _mm256_sub_epi16(y14, y9); + x10 = _mm256_sub_epi16(y13, y10); + x11 = _mm256_sub_epi16(y12, y11); + x12 = _mm256_add_epi16(y12, y11); + x13 = _mm256_add_epi16(y13, y10); + x14 = _mm256_add_epi16(y14, y9); + x15 = _mm256_add_epi16(y15, y8); + + // stage 4 + y0 = x0; + y1 = x1; + y6 = x6; + y7 = x7; + y8 = x8; + y9 = x9; + y14 = x14; + y15 = x15; + + u0 = _mm256_unpacklo_epi16(x2, x13); + u1 = _mm256_unpackhi_epi16(x2, x13); + y2 = butter_fly(u0, u1, cospi_m08_p24); + y13 = butter_fly(u0, u1, cospi_p24_p08); + + u0 = _mm256_unpacklo_epi16(x3, x12); + u1 = _mm256_unpackhi_epi16(x3, x12); + y3 = butter_fly(u0, u1, cospi_m08_p24); + y12 = butter_fly(u0, u1, cospi_p24_p08); + + u0 = _mm256_unpacklo_epi16(x4, x11); + u1 = _mm256_unpackhi_epi16(x4, x11); + y4 = butter_fly(u0, u1, cospi_m24_m08); + y11 = butter_fly(u0, u1, cospi_m08_p24); + + u0 = _mm256_unpacklo_epi16(x5, x10); + u1 = _mm256_unpackhi_epi16(x5, x10); + y5 = butter_fly(u0, u1, cospi_m24_m08); + y10 = butter_fly(u0, u1, cospi_m08_p24); + + // stage 5 + x0 = _mm256_add_epi16(y0, y3); + x1 = _mm256_add_epi16(y1, y2); + x2 = _mm256_sub_epi16(y1, y2); + x3 = _mm256_sub_epi16(y0, y3); + x4 = _mm256_sub_epi16(y7, y4); + x5 = _mm256_sub_epi16(y6, y5); + x6 = _mm256_add_epi16(y6, y5); + x7 = _mm256_add_epi16(y7, y4); + + x8 = _mm256_add_epi16(y8, y11); + x9 = _mm256_add_epi16(y9, y10); + x10 = _mm256_sub_epi16(y9, y10); + x11 = _mm256_sub_epi16(y8, y11); + x12 = _mm256_sub_epi16(y15, y12); + x13 = _mm256_sub_epi16(y14, y13); + x14 = _mm256_add_epi16(y14, y13); + x15 = _mm256_add_epi16(y15, y12); + + // stage 6 + y0 = x0; + y3 = x3; + y4 = x4; + y7 = x7; + y8 = x8; + y11 = x11; + y12 = x12; + y15 = x15; + + u0 = _mm256_unpacklo_epi16(x1, x14); + u1 = _mm256_unpackhi_epi16(x1, x14); + y1 = butter_fly(u0, u1, cospi_m04_p28); + y14 = butter_fly(u0, u1, cospi_p28_p04); + + u0 = _mm256_unpacklo_epi16(x2, x13); + u1 = _mm256_unpackhi_epi16(x2, x13); + y2 = butter_fly(u0, u1, cospi_m28_m04); + y13 = butter_fly(u0, u1, cospi_m04_p28); + + u0 = _mm256_unpacklo_epi16(x5, x10); + u1 = _mm256_unpackhi_epi16(x5, x10); + y5 = butter_fly(u0, u1, cospi_m20_p12); + y10 = butter_fly(u0, u1, cospi_p12_p20); + + u0 = _mm256_unpacklo_epi16(x6, x9); + u1 = _mm256_unpackhi_epi16(x6, x9); + y6 = butter_fly(u0, u1, cospi_m12_m20); + y9 = butter_fly(u0, u1, cospi_m20_p12); + + // stage 7 + x0 = _mm256_add_epi16(y0, y1); + x1 = _mm256_sub_epi16(y0, y1); + x2 = _mm256_sub_epi16(y3, y2); + x3 = _mm256_add_epi16(y3, y2); + x4 = _mm256_add_epi16(y4, y5); + x5 = _mm256_sub_epi16(y4, y5); + x6 = _mm256_sub_epi16(y7, y6); + x7 = _mm256_add_epi16(y7, y6); + + x8 = _mm256_add_epi16(y8, y9); + x9 = _mm256_sub_epi16(y8, y9); + x10 = _mm256_sub_epi16(y11, y10); + x11 = _mm256_add_epi16(y11, y10); + x12 = _mm256_add_epi16(y12, y13); + x13 = _mm256_sub_epi16(y12, y13); + x14 = _mm256_sub_epi16(y15, y14); + x15 = _mm256_add_epi16(y15, y14); + + // stage 8 + u0 = _mm256_unpacklo_epi16(x0, x15); + u1 = _mm256_unpackhi_epi16(x0, x15); + in[0] = butter_fly(u0, u1, cospi_p31_p01); + in[15] = butter_fly(u0, u1, cospi_m01_p31); + + u0 = _mm256_unpacklo_epi16(x1, x14); + u1 = _mm256_unpackhi_epi16(x1, x14); + in[1] = butter_fly(u0, u1, cospi_p15_p17); + in[14] = butter_fly(u0, u1, cospi_m17_p15); + + u0 = _mm256_unpacklo_epi16(x2, x13); + u1 = _mm256_unpackhi_epi16(x2, x13); + in[2] = butter_fly(u0, u1, cospi_p23_p09); + in[13] = butter_fly(u0, u1, cospi_m09_p23); + + u0 = _mm256_unpacklo_epi16(x3, x12); + u1 = _mm256_unpackhi_epi16(x3, x12); + in[3] = butter_fly(u0, u1, cospi_p07_p25); + in[12] = butter_fly(u0, u1, cospi_m25_p07); + + u0 = _mm256_unpacklo_epi16(x4, x11); + u1 = _mm256_unpackhi_epi16(x4, x11); + in[4] = butter_fly(u0, u1, cospi_p27_p05); + in[11] = butter_fly(u0, u1, cospi_m05_p27); + + u0 = _mm256_unpacklo_epi16(x5, x10); + u1 = _mm256_unpackhi_epi16(x5, x10); + in[5] = butter_fly(u0, u1, cospi_p11_p21); + in[10] = butter_fly(u0, u1, cospi_m21_p11); + + u0 = _mm256_unpacklo_epi16(x6, x9); + u1 = _mm256_unpackhi_epi16(x6, x9); + in[6] = butter_fly(u0, u1, cospi_p19_p13); + in[9] = butter_fly(u0, u1, cospi_m13_p19); + + u0 = _mm256_unpacklo_epi16(x7, x8); + u1 = _mm256_unpackhi_epi16(x7, x8); + in[7] = butter_fly(u0, u1, cospi_p03_p29); + in[8] = butter_fly(u0, u1, cospi_m29_p03); +} + +static void fdct32_avx2(__m256i *in0, __m256i *in1) { + __m256i even0[16], even1[16], odd0[16], odd1[16]; + prepare_16x16_even(in0, even0); + fdct16_avx2(even0); + + prepare_16x16_odd(in0, odd0); + fdct16_odd_avx2(odd0); + + prepare_16x16_even(in1, even1); + fdct16_avx2(even1); + + prepare_16x16_odd(in1, odd1); + fdct16_odd_avx2(odd1); + + collect_coeffs(even0, odd0, even1, odd1, in0, in1); + + mm256_transpose_32x32(in0, in1); +} + +static INLINE void write_buffer_32x32(const __m256i *in0, const __m256i *in1, + tran_low_t *output) { + int i = 0; + const int stride = 32; + tran_low_t *coeff = output; + while (i < 32) { + storeu_output_avx2(&in0[i], coeff); + storeu_output_avx2(&in1[i], coeff + 16); + coeff += stride; + i += 1; + } +} + +#if CONFIG_EXT_TX +static void fhalfright32_16col_avx2(__m256i *in) { + int i = 0; + const __m256i zero = _mm256_setzero_si256(); + const __m256i sqrt2 = _mm256_set1_epi16(Sqrt2); + const __m256i dct_rounding = _mm256_set1_epi32(DCT_CONST_ROUNDING); + __m256i x0, x1; + + while (i < 16) { + in[i] = _mm256_slli_epi16(in[i], 2); + x0 = _mm256_unpacklo_epi16(in[i + 16], zero); + x1 = _mm256_unpackhi_epi16(in[i + 16], zero); + x0 = _mm256_madd_epi16(x0, sqrt2); + x1 = _mm256_madd_epi16(x1, sqrt2); + x0 = _mm256_add_epi32(x0, dct_rounding); + x1 = _mm256_add_epi32(x1, dct_rounding); + x0 = _mm256_srai_epi32(x0, DCT_CONST_BITS); + x1 = _mm256_srai_epi32(x1, DCT_CONST_BITS); + in[i + 16] = _mm256_packs_epi32(x0, x1); + i += 1; + } + fdct16_avx2(&in[16]); +} + +static void fhalfright32_avx2(__m256i *in0, __m256i *in1) { + fhalfright32_16col_avx2(in0); + fhalfright32_16col_avx2(in1); + mm256_vectors_swap(in0, &in0[16], 16); + mm256_vectors_swap(in1, &in1[16], 16); + mm256_transpose_32x32(in0, in1); +} +#endif // CONFIG_EXT_TX + +static INLINE void load_buffer_32x32(const int16_t *input, int stride, + int flipud, int fliplr, __m256i *in0, + __m256i *in1) { + // Load 4 16x16 blocks + const int16_t *topL = input; + const int16_t *topR = input + 16; + const int16_t *botL = input + 16 * stride; + const int16_t *botR = input + 16 * stride + 16; + + const int16_t *tmp; + + if (flipud) { + // Swap left columns + tmp = topL; + topL = botL; + botL = tmp; + // Swap right columns + tmp = topR; + topR = botR; + botR = tmp; + } + + if (fliplr) { + // Swap top rows + tmp = topL; + topL = topR; + topR = tmp; + // Swap bottom rows + tmp = botL; + botL = botR; + botR = tmp; + } + + // load first 16 columns + load_buffer_16x16(topL, stride, flipud, fliplr, in0); + load_buffer_16x16(botL, stride, flipud, fliplr, in0 + 16); + + // load second 16 columns + load_buffer_16x16(topR, stride, flipud, fliplr, in1); + load_buffer_16x16(botR, stride, flipud, fliplr, in1 + 16); +} + +static INLINE void right_shift_32x32_16col(int bit, __m256i *in) { + int i = 0; + const __m256i rounding = _mm256_set1_epi16((1 << bit) >> 1); + __m256i sign; + while (i < 32) { + sign = _mm256_srai_epi16(in[i], 15); + in[i] = _mm256_add_epi16(in[i], rounding); + in[i] = _mm256_add_epi16(in[i], sign); + in[i] = _mm256_srai_epi16(in[i], bit); + i += 1; + } +} + +// Positive rounding +static INLINE void right_shift_32x32(__m256i *in0, __m256i *in1) { + const int bit = 4; + right_shift_32x32_16col(bit, in0); + right_shift_32x32_16col(bit, in1); +} + +#if CONFIG_EXT_TX +static void fidtx32_avx2(__m256i *in0, __m256i *in1) { + int i = 0; + while (i < 32) { + in0[i] = _mm256_slli_epi16(in0[i], 2); + in1[i] = _mm256_slli_epi16(in1[i], 2); + i += 1; + } + mm256_transpose_32x32(in0, in1); +} +#endif + +void av1_fht32x32_avx2(const int16_t *input, tran_low_t *output, int stride, + int tx_type) { + __m256i in0[32]; // left 32 columns + __m256i in1[32]; // right 32 columns + + switch (tx_type) { + case DCT_DCT: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fdct32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fdct32_avx2(in0, in1); + break; +#if CONFIG_EXT_TX + case ADST_DCT: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fdct32_avx2(in0, in1); + break; + case DCT_ADST: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fdct32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case ADST_ADST: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case FLIPADST_DCT: + load_buffer_32x32(input, stride, 1, 0, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fdct32_avx2(in0, in1); + break; + case DCT_FLIPADST: + load_buffer_32x32(input, stride, 0, 1, in0, in1); + fdct32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case FLIPADST_FLIPADST: + load_buffer_32x32(input, stride, 1, 1, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case ADST_FLIPADST: + load_buffer_32x32(input, stride, 0, 1, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case FLIPADST_ADST: + load_buffer_32x32(input, stride, 1, 0, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case IDTX: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fidtx32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fidtx32_avx2(in0, in1); + break; + case V_DCT: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fdct32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fidtx32_avx2(in0, in1); + break; + case H_DCT: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fidtx32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fdct32_avx2(in0, in1); + break; + case V_ADST: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fidtx32_avx2(in0, in1); + break; + case H_ADST: + load_buffer_32x32(input, stride, 0, 0, in0, in1); + fidtx32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; + case V_FLIPADST: + load_buffer_32x32(input, stride, 1, 0, in0, in1); + fhalfright32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fidtx32_avx2(in0, in1); + break; + case H_FLIPADST: + load_buffer_32x32(input, stride, 0, 1, in0, in1); + fidtx32_avx2(in0, in1); + right_shift_32x32(in0, in1); + fhalfright32_avx2(in0, in1); + break; +#endif // CONFIG_EXT_TX + default: assert(0); break; + } + write_buffer_32x32(in0, in1, output); + _mm256_zeroupper(); +} diff --git a/third_party/aom/av1/encoder/x86/temporal_filter_apply_sse2.asm b/third_party/aom/av1/encoder/x86/temporal_filter_apply_sse2.asm new file mode 100644 index 000000000..7186b6b92 --- /dev/null +++ b/third_party/aom/av1/encoder/x86/temporal_filter_apply_sse2.asm @@ -0,0 +1,215 @@ +; +; Copyright (c) 2016, Alliance for Open Media. All rights reserved +; +; This source code is subject to the terms of the BSD 2 Clause License and +; the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License +; was not distributed with this source code in the LICENSE file, you can +; obtain it at www.aomedia.org/license/software. If the Alliance for Open +; Media Patent License 1.0 was not distributed with this source code in the +; PATENTS file, you can obtain it at www.aomedia.org/license/patent. +; + +; + + +%include "aom_ports/x86_abi_support.asm" + +; void av1_temporal_filter_apply_sse2 | arg +; (unsigned char *frame1, | 0 +; unsigned int stride, | 1 +; unsigned char *frame2, | 2 +; unsigned int block_width, | 3 +; unsigned int block_height, | 4 +; int strength, | 5 +; int filter_weight, | 6 +; unsigned int *accumulator, | 7 +; unsigned short *count) | 8 +global sym(av1_temporal_filter_apply_sse2) PRIVATE +sym(av1_temporal_filter_apply_sse2): + + push rbp + mov rbp, rsp + SHADOW_ARGS_TO_STACK 9 + SAVE_XMM 7 + GET_GOT rbx + push rsi + push rdi + ALIGN_STACK 16, rax + %define block_width 0 + %define block_height 16 + %define strength 32 + %define filter_weight 48 + %define rounding_bit 64 + %define rbp_backup 80 + %define stack_size 96 + sub rsp, stack_size + mov [rsp + rbp_backup], rbp + ; end prolog + + mov edx, arg(3) + mov [rsp + block_width], rdx + mov edx, arg(4) + mov [rsp + block_height], rdx + movd xmm6, arg(5) + movdqa [rsp + strength], xmm6 ; where strength is used, all 16 bytes are read + + ; calculate the rounding bit outside the loop + ; 0x8000 >> (16 - strength) + mov rdx, 16 + sub rdx, arg(5) ; 16 - strength + movq xmm4, rdx ; can't use rdx w/ shift + movdqa xmm5, [GLOBAL(_const_top_bit)] + psrlw xmm5, xmm4 + movdqa [rsp + rounding_bit], xmm5 + + mov rsi, arg(0) ; src/frame1 + mov rdx, arg(2) ; predictor frame + mov rdi, arg(7) ; accumulator + mov rax, arg(8) ; count + + ; dup the filter weight and store for later + movd xmm0, arg(6) ; filter_weight + pshuflw xmm0, xmm0, 0 + punpcklwd xmm0, xmm0 + movdqa [rsp + filter_weight], xmm0 + + mov rbp, arg(1) ; stride + pxor xmm7, xmm7 ; zero for extraction + + mov rcx, [rsp + block_width] + imul rcx, [rsp + block_height] + add rcx, rdx + cmp dword ptr [rsp + block_width], 8 + jne .temporal_filter_apply_load_16 + +.temporal_filter_apply_load_8: + movq xmm0, [rsi] ; first row + lea rsi, [rsi + rbp] ; += stride + punpcklbw xmm0, xmm7 ; src[ 0- 7] + movq xmm1, [rsi] ; second row + lea rsi, [rsi + rbp] ; += stride + punpcklbw xmm1, xmm7 ; src[ 8-15] + jmp .temporal_filter_apply_load_finished + +.temporal_filter_apply_load_16: + movdqa xmm0, [rsi] ; src (frame1) + lea rsi, [rsi + rbp] ; += stride + movdqa xmm1, xmm0 + punpcklbw xmm0, xmm7 ; src[ 0- 7] + punpckhbw xmm1, xmm7 ; src[ 8-15] + +.temporal_filter_apply_load_finished: + movdqa xmm2, [rdx] ; predictor (frame2) + movdqa xmm3, xmm2 + punpcklbw xmm2, xmm7 ; pred[ 0- 7] + punpckhbw xmm3, xmm7 ; pred[ 8-15] + + ; modifier = src_byte - pixel_value + psubw xmm0, xmm2 ; src - pred[ 0- 7] + psubw xmm1, xmm3 ; src - pred[ 8-15] + + ; modifier *= modifier + pmullw xmm0, xmm0 ; modifer[ 0- 7]^2 + pmullw xmm1, xmm1 ; modifer[ 8-15]^2 + + ; modifier *= 3 + pmullw xmm0, [GLOBAL(_const_3w)] + pmullw xmm1, [GLOBAL(_const_3w)] + + ; modifer += 0x8000 >> (16 - strength) + paddw xmm0, [rsp + rounding_bit] + paddw xmm1, [rsp + rounding_bit] + + ; modifier >>= strength + psrlw xmm0, [rsp + strength] + psrlw xmm1, [rsp + strength] + + ; modifier = 16 - modifier + ; saturation takes care of modifier > 16 + movdqa xmm3, [GLOBAL(_const_16w)] + movdqa xmm2, [GLOBAL(_const_16w)] + psubusw xmm3, xmm1 + psubusw xmm2, xmm0 + + ; modifier *= filter_weight + pmullw xmm2, [rsp + filter_weight] + pmullw xmm3, [rsp + filter_weight] + + ; count + movdqa xmm4, [rax] + movdqa xmm5, [rax+16] + ; += modifier + paddw xmm4, xmm2 + paddw xmm5, xmm3 + ; write back + movdqa [rax], xmm4 + movdqa [rax+16], xmm5 + lea rax, [rax + 16*2] ; count += 16*(sizeof(short)) + + ; load and extract the predictor up to shorts + pxor xmm7, xmm7 + movdqa xmm0, [rdx] + lea rdx, [rdx + 16*1] ; pred += 16*(sizeof(char)) + movdqa xmm1, xmm0 + punpcklbw xmm0, xmm7 ; pred[ 0- 7] + punpckhbw xmm1, xmm7 ; pred[ 8-15] + + ; modifier *= pixel_value + pmullw xmm0, xmm2 + pmullw xmm1, xmm3 + + ; expand to double words + movdqa xmm2, xmm0 + punpcklwd xmm0, xmm7 ; [ 0- 3] + punpckhwd xmm2, xmm7 ; [ 4- 7] + movdqa xmm3, xmm1 + punpcklwd xmm1, xmm7 ; [ 8-11] + punpckhwd xmm3, xmm7 ; [12-15] + + ; accumulator + movdqa xmm4, [rdi] + movdqa xmm5, [rdi+16] + movdqa xmm6, [rdi+32] + movdqa xmm7, [rdi+48] + ; += modifier + paddd xmm4, xmm0 + paddd xmm5, xmm2 + paddd xmm6, xmm1 + paddd xmm7, xmm3 + ; write back + movdqa [rdi], xmm4 + movdqa [rdi+16], xmm5 + movdqa [rdi+32], xmm6 + movdqa [rdi+48], xmm7 + lea rdi, [rdi + 16*4] ; accumulator += 16*(sizeof(int)) + + cmp rdx, rcx + je .temporal_filter_apply_epilog + pxor xmm7, xmm7 ; zero for extraction + cmp dword ptr [rsp + block_width], 16 + je .temporal_filter_apply_load_16 + jmp .temporal_filter_apply_load_8 + +.temporal_filter_apply_epilog: + ; begin epilog + mov rbp, [rsp + rbp_backup] + add rsp, stack_size + pop rsp + pop rdi + pop rsi + RESTORE_GOT + RESTORE_XMM + UNSHADOW_ARGS + pop rbp + ret + +SECTION_RODATA +align 16 +_const_3w: + times 8 dw 3 +align 16 +_const_top_bit: + times 8 dw 1<<15 +align 16 +_const_16w: + times 8 dw 16 diff --git a/third_party/aom/av1/encoder/x86/wedge_utils_sse2.c b/third_party/aom/av1/encoder/x86/wedge_utils_sse2.c new file mode 100644 index 000000000..bf233ca4d --- /dev/null +++ b/third_party/aom/av1/encoder/x86/wedge_utils_sse2.c @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <assert.h> +#include <immintrin.h> + +#include "aom_dsp/x86/synonyms.h" + +#include "aom/aom_integer.h" + +#include "av1/common/reconinter.h" + +#define MAX_MASK_VALUE (1 << WEDGE_WEIGHT_BITS) + +/** + * See av1_wedge_sse_from_residuals_c + */ +uint64_t av1_wedge_sse_from_residuals_sse2(const int16_t *r1, const int16_t *d, + const uint8_t *m, int N) { + int n = -N; + int n8 = n + 8; + + uint64_t csse; + + const __m128i v_mask_max_w = _mm_set1_epi16(MAX_MASK_VALUE); + const __m128i v_zext_q = _mm_set_epi32(0, 0xffffffff, 0, 0xffffffff); + + __m128i v_acc0_q = _mm_setzero_si128(); + + assert(N % 64 == 0); + + r1 += N; + d += N; + m += N; + + do { + const __m128i v_r0_w = xx_load_128(r1 + n); + const __m128i v_r1_w = xx_load_128(r1 + n8); + const __m128i v_d0_w = xx_load_128(d + n); + const __m128i v_d1_w = xx_load_128(d + n8); + const __m128i v_m01_b = xx_load_128(m + n); + + const __m128i v_rd0l_w = _mm_unpacklo_epi16(v_d0_w, v_r0_w); + const __m128i v_rd0h_w = _mm_unpackhi_epi16(v_d0_w, v_r0_w); + const __m128i v_rd1l_w = _mm_unpacklo_epi16(v_d1_w, v_r1_w); + const __m128i v_rd1h_w = _mm_unpackhi_epi16(v_d1_w, v_r1_w); + const __m128i v_m0_w = _mm_unpacklo_epi8(v_m01_b, _mm_setzero_si128()); + const __m128i v_m1_w = _mm_unpackhi_epi8(v_m01_b, _mm_setzero_si128()); + + const __m128i v_m0l_w = _mm_unpacklo_epi16(v_m0_w, v_mask_max_w); + const __m128i v_m0h_w = _mm_unpackhi_epi16(v_m0_w, v_mask_max_w); + const __m128i v_m1l_w = _mm_unpacklo_epi16(v_m1_w, v_mask_max_w); + const __m128i v_m1h_w = _mm_unpackhi_epi16(v_m1_w, v_mask_max_w); + + const __m128i v_t0l_d = _mm_madd_epi16(v_rd0l_w, v_m0l_w); + const __m128i v_t0h_d = _mm_madd_epi16(v_rd0h_w, v_m0h_w); + const __m128i v_t1l_d = _mm_madd_epi16(v_rd1l_w, v_m1l_w); + const __m128i v_t1h_d = _mm_madd_epi16(v_rd1h_w, v_m1h_w); + + const __m128i v_t0_w = _mm_packs_epi32(v_t0l_d, v_t0h_d); + const __m128i v_t1_w = _mm_packs_epi32(v_t1l_d, v_t1h_d); + + const __m128i v_sq0_d = _mm_madd_epi16(v_t0_w, v_t0_w); + const __m128i v_sq1_d = _mm_madd_epi16(v_t1_w, v_t1_w); + + const __m128i v_sum0_q = _mm_add_epi64(_mm_and_si128(v_sq0_d, v_zext_q), + _mm_srli_epi64(v_sq0_d, 32)); + const __m128i v_sum1_q = _mm_add_epi64(_mm_and_si128(v_sq1_d, v_zext_q), + _mm_srli_epi64(v_sq1_d, 32)); + + v_acc0_q = _mm_add_epi64(v_acc0_q, v_sum0_q); + v_acc0_q = _mm_add_epi64(v_acc0_q, v_sum1_q); + + n8 += 16; + n += 16; + } while (n); + + v_acc0_q = _mm_add_epi64(v_acc0_q, _mm_srli_si128(v_acc0_q, 8)); + +#if ARCH_X86_64 + csse = (uint64_t)_mm_cvtsi128_si64(v_acc0_q); +#else + xx_storel_64(&csse, v_acc0_q); +#endif + + return ROUND_POWER_OF_TWO(csse, 2 * WEDGE_WEIGHT_BITS); +} + +/** + * See av1_wedge_sign_from_residuals_c + */ +int av1_wedge_sign_from_residuals_sse2(const int16_t *ds, const uint8_t *m, + int N, int64_t limit) { + int64_t acc; + + __m128i v_sign_d; + __m128i v_acc0_d = _mm_setzero_si128(); + __m128i v_acc1_d = _mm_setzero_si128(); + __m128i v_acc_q; + + // Input size limited to 8192 by the use of 32 bit accumulators and m + // being between [0, 64]. Overflow might happen at larger sizes, + // though it is practically impossible on real video input. + assert(N < 8192); + assert(N % 64 == 0); + + do { + const __m128i v_m01_b = xx_load_128(m); + const __m128i v_m23_b = xx_load_128(m + 16); + const __m128i v_m45_b = xx_load_128(m + 32); + const __m128i v_m67_b = xx_load_128(m + 48); + + const __m128i v_d0_w = xx_load_128(ds); + const __m128i v_d1_w = xx_load_128(ds + 8); + const __m128i v_d2_w = xx_load_128(ds + 16); + const __m128i v_d3_w = xx_load_128(ds + 24); + const __m128i v_d4_w = xx_load_128(ds + 32); + const __m128i v_d5_w = xx_load_128(ds + 40); + const __m128i v_d6_w = xx_load_128(ds + 48); + const __m128i v_d7_w = xx_load_128(ds + 56); + + const __m128i v_m0_w = _mm_unpacklo_epi8(v_m01_b, _mm_setzero_si128()); + const __m128i v_m1_w = _mm_unpackhi_epi8(v_m01_b, _mm_setzero_si128()); + const __m128i v_m2_w = _mm_unpacklo_epi8(v_m23_b, _mm_setzero_si128()); + const __m128i v_m3_w = _mm_unpackhi_epi8(v_m23_b, _mm_setzero_si128()); + const __m128i v_m4_w = _mm_unpacklo_epi8(v_m45_b, _mm_setzero_si128()); + const __m128i v_m5_w = _mm_unpackhi_epi8(v_m45_b, _mm_setzero_si128()); + const __m128i v_m6_w = _mm_unpacklo_epi8(v_m67_b, _mm_setzero_si128()); + const __m128i v_m7_w = _mm_unpackhi_epi8(v_m67_b, _mm_setzero_si128()); + + const __m128i v_p0_d = _mm_madd_epi16(v_d0_w, v_m0_w); + const __m128i v_p1_d = _mm_madd_epi16(v_d1_w, v_m1_w); + const __m128i v_p2_d = _mm_madd_epi16(v_d2_w, v_m2_w); + const __m128i v_p3_d = _mm_madd_epi16(v_d3_w, v_m3_w); + const __m128i v_p4_d = _mm_madd_epi16(v_d4_w, v_m4_w); + const __m128i v_p5_d = _mm_madd_epi16(v_d5_w, v_m5_w); + const __m128i v_p6_d = _mm_madd_epi16(v_d6_w, v_m6_w); + const __m128i v_p7_d = _mm_madd_epi16(v_d7_w, v_m7_w); + + const __m128i v_p01_d = _mm_add_epi32(v_p0_d, v_p1_d); + const __m128i v_p23_d = _mm_add_epi32(v_p2_d, v_p3_d); + const __m128i v_p45_d = _mm_add_epi32(v_p4_d, v_p5_d); + const __m128i v_p67_d = _mm_add_epi32(v_p6_d, v_p7_d); + + const __m128i v_p0123_d = _mm_add_epi32(v_p01_d, v_p23_d); + const __m128i v_p4567_d = _mm_add_epi32(v_p45_d, v_p67_d); + + v_acc0_d = _mm_add_epi32(v_acc0_d, v_p0123_d); + v_acc1_d = _mm_add_epi32(v_acc1_d, v_p4567_d); + + ds += 64; + m += 64; + + N -= 64; + } while (N); + + v_sign_d = _mm_cmplt_epi32(v_acc0_d, _mm_setzero_si128()); + v_acc0_d = _mm_add_epi64(_mm_unpacklo_epi32(v_acc0_d, v_sign_d), + _mm_unpackhi_epi32(v_acc0_d, v_sign_d)); + + v_sign_d = _mm_cmplt_epi32(v_acc1_d, _mm_setzero_si128()); + v_acc1_d = _mm_add_epi64(_mm_unpacklo_epi32(v_acc1_d, v_sign_d), + _mm_unpackhi_epi32(v_acc1_d, v_sign_d)); + + v_acc_q = _mm_add_epi64(v_acc0_d, v_acc1_d); + + v_acc_q = _mm_add_epi64(v_acc_q, _mm_srli_si128(v_acc_q, 8)); + +#if ARCH_X86_64 + acc = (uint64_t)_mm_cvtsi128_si64(v_acc_q); +#else + xx_storel_64(&acc, v_acc_q); +#endif + + return acc > limit; +} + +// Negate under mask +static INLINE __m128i negm_epi16(__m128i v_v_w, __m128i v_mask_w) { + return _mm_sub_epi16(_mm_xor_si128(v_v_w, v_mask_w), v_mask_w); +} + +/** + * av1_wedge_compute_delta_squares_c + */ +void av1_wedge_compute_delta_squares_sse2(int16_t *d, const int16_t *a, + const int16_t *b, int N) { + const __m128i v_neg_w = + _mm_set_epi16(0xffff, 0, 0xffff, 0, 0xffff, 0, 0xffff, 0); + + assert(N % 64 == 0); + + do { + const __m128i v_a0_w = xx_load_128(a); + const __m128i v_b0_w = xx_load_128(b); + const __m128i v_a1_w = xx_load_128(a + 8); + const __m128i v_b1_w = xx_load_128(b + 8); + const __m128i v_a2_w = xx_load_128(a + 16); + const __m128i v_b2_w = xx_load_128(b + 16); + const __m128i v_a3_w = xx_load_128(a + 24); + const __m128i v_b3_w = xx_load_128(b + 24); + + const __m128i v_ab0l_w = _mm_unpacklo_epi16(v_a0_w, v_b0_w); + const __m128i v_ab0h_w = _mm_unpackhi_epi16(v_a0_w, v_b0_w); + const __m128i v_ab1l_w = _mm_unpacklo_epi16(v_a1_w, v_b1_w); + const __m128i v_ab1h_w = _mm_unpackhi_epi16(v_a1_w, v_b1_w); + const __m128i v_ab2l_w = _mm_unpacklo_epi16(v_a2_w, v_b2_w); + const __m128i v_ab2h_w = _mm_unpackhi_epi16(v_a2_w, v_b2_w); + const __m128i v_ab3l_w = _mm_unpacklo_epi16(v_a3_w, v_b3_w); + const __m128i v_ab3h_w = _mm_unpackhi_epi16(v_a3_w, v_b3_w); + + // Negate top word of pairs + const __m128i v_abl0n_w = negm_epi16(v_ab0l_w, v_neg_w); + const __m128i v_abh0n_w = negm_epi16(v_ab0h_w, v_neg_w); + const __m128i v_abl1n_w = negm_epi16(v_ab1l_w, v_neg_w); + const __m128i v_abh1n_w = negm_epi16(v_ab1h_w, v_neg_w); + const __m128i v_abl2n_w = negm_epi16(v_ab2l_w, v_neg_w); + const __m128i v_abh2n_w = negm_epi16(v_ab2h_w, v_neg_w); + const __m128i v_abl3n_w = negm_epi16(v_ab3l_w, v_neg_w); + const __m128i v_abh3n_w = negm_epi16(v_ab3h_w, v_neg_w); + + const __m128i v_r0l_w = _mm_madd_epi16(v_ab0l_w, v_abl0n_w); + const __m128i v_r0h_w = _mm_madd_epi16(v_ab0h_w, v_abh0n_w); + const __m128i v_r1l_w = _mm_madd_epi16(v_ab1l_w, v_abl1n_w); + const __m128i v_r1h_w = _mm_madd_epi16(v_ab1h_w, v_abh1n_w); + const __m128i v_r2l_w = _mm_madd_epi16(v_ab2l_w, v_abl2n_w); + const __m128i v_r2h_w = _mm_madd_epi16(v_ab2h_w, v_abh2n_w); + const __m128i v_r3l_w = _mm_madd_epi16(v_ab3l_w, v_abl3n_w); + const __m128i v_r3h_w = _mm_madd_epi16(v_ab3h_w, v_abh3n_w); + + const __m128i v_r0_w = _mm_packs_epi32(v_r0l_w, v_r0h_w); + const __m128i v_r1_w = _mm_packs_epi32(v_r1l_w, v_r1h_w); + const __m128i v_r2_w = _mm_packs_epi32(v_r2l_w, v_r2h_w); + const __m128i v_r3_w = _mm_packs_epi32(v_r3l_w, v_r3h_w); + + xx_store_128(d, v_r0_w); + xx_store_128(d + 8, v_r1_w); + xx_store_128(d + 16, v_r2_w); + xx_store_128(d + 24, v_r3_w); + + a += 32; + b += 32; + d += 32; + N -= 32; + } while (N); +} |