diff options
author | trav90 <travawine@palemoon.org> | 2018-10-19 23:05:00 -0500 |
---|---|---|
committer | trav90 <travawine@palemoon.org> | 2018-10-19 23:05:03 -0500 |
commit | d2499ead93dc4298c0882fe98902acb1b5209f99 (patch) | |
tree | cb0b942aed59e5108f9a3e9d64e7b77854383421 /third_party/aom/av1/common/x86 | |
parent | 41fbdea457bf50c0a43e1c27c5cbf7f0a3a9eb33 (diff) | |
download | UXP-d2499ead93dc4298c0882fe98902acb1b5209f99.tar UXP-d2499ead93dc4298c0882fe98902acb1b5209f99.tar.gz UXP-d2499ead93dc4298c0882fe98902acb1b5209f99.tar.lz UXP-d2499ead93dc4298c0882fe98902acb1b5209f99.tar.xz UXP-d2499ead93dc4298c0882fe98902acb1b5209f99.zip |
Update libaom to commit ID 1e227d41f0616de9548a673a83a21ef990b62591
Diffstat (limited to 'third_party/aom/av1/common/x86')
26 files changed, 5822 insertions, 1166 deletions
diff --git a/third_party/aom/av1/common/x86/av1_convolve_scale_sse4.c b/third_party/aom/av1/common/x86/av1_convolve_scale_sse4.c index 0c5286f9d..d9fb53785 100644 --- a/third_party/aom/av1/common/x86/av1_convolve_scale_sse4.c +++ b/third_party/aom/av1/common/x86/av1_convolve_scale_sse4.c @@ -14,7 +14,6 @@ #include "config/aom_dsp_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/aom_filter.h" #include "av1/common/convolve.h" diff --git a/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.c b/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.c index ae331b40d..5db2ccf6c 100644 --- a/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.c +++ b/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.c @@ -18,6 +18,12 @@ #include "av1/common/x86/av1_inv_txfm_avx2.h" #include "av1/common/x86/av1_inv_txfm_ssse3.h" +// TODO(venkatsanampudi@ittiam.com): move this to header file + +// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5 +static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096, + 4 * 5793 }; + static INLINE void idct16_stage5_avx2(__m256i *x1, const int32_t *cospi, const __m256i _r, int8_t cos_bit) { const __m256i cospi_m32_p32 = pair_set_w16_epi16(-cospi[32], cospi[32]); diff --git a/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.h b/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.h index 7b5b29cf8..f74cbaeaa 100644 --- a/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.h +++ b/third_party/aom/av1/common/x86/av1_inv_txfm_avx2.h @@ -8,8 +8,8 @@ * Media Patent License 1.0 was not distributed with this source code in the * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ -#ifndef AV1_COMMON_X86_AV1_INV_TXFM_AVX2_H_ -#define AV1_COMMON_X86_AV1_INV_TXFM_AVX2_H_ +#ifndef AOM_AV1_COMMON_X86_AV1_INV_TXFM_AVX2_H_ +#define AOM_AV1_COMMON_X86_AV1_INV_TXFM_AVX2_H_ #include <immintrin.h> @@ -68,4 +68,4 @@ void av1_lowbd_inv_txfm2d_add_avx2(const int32_t *input, uint8_t *output, } #endif -#endif // AV1_COMMON_X86_AV1_INV_TXFM_AVX2_H_ +#endif // AOM_AV1_COMMON_X86_AV1_INV_TXFM_AVX2_H_ diff --git a/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.c b/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.c index dd7cee24c..995bc3da4 100644 --- a/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.c +++ b/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.c @@ -16,6 +16,12 @@ #include "av1/common/x86/av1_inv_txfm_ssse3.h" #include "av1/common/x86/av1_txfm_sse2.h" +// TODO(venkatsanampudi@ittiam.com): move this to header file + +// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5 +static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096, + 4 * 5793 }; + // TODO(binpengsmail@gmail.com): replace some for loop with do {} while static void idct4_new_sse2(const __m128i *input, __m128i *output, diff --git a/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.h b/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.h index dc9be25d2..66bd339d1 100644 --- a/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.h +++ b/third_party/aom/av1/common/x86/av1_inv_txfm_ssse3.h @@ -8,8 +8,8 @@ * Media Patent License 1.0 was not distributed with this source code in the * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ -#ifndef AV1_COMMON_X86_AV1_INV_TXFM_SSSE3_H_ -#define AV1_COMMON_X86_AV1_INV_TXFM_SSSE3_H_ +#ifndef AOM_AV1_COMMON_X86_AV1_INV_TXFM_SSSE3_H_ +#define AOM_AV1_COMMON_X86_AV1_INV_TXFM_SSSE3_H_ #include <emmintrin.h> // SSE2 #include <tmmintrin.h> // SSSE3 @@ -94,10 +94,6 @@ static const ITX_TYPE_1D hitx_1d_tab[TX_TYPES] = { IIDENTITY_1D, IADST_1D, IIDENTITY_1D, IFLIPADST_1D, }; -// Sqrt2, Sqrt2^2, Sqrt2^3, Sqrt2^4, Sqrt2^5 -static int32_t NewSqrt2list[TX_SIZES] = { 5793, 2 * 4096, 2 * 5793, 4 * 4096, - 4 * 5793 }; - DECLARE_ALIGNED(16, static const int16_t, av1_eob_to_eobxy_8x8_default[8]) = { 0x0707, 0x0707, 0x0707, 0x0707, 0x0707, 0x0707, 0x0707, 0x0707, }; @@ -233,4 +229,4 @@ void av1_lowbd_inv_txfm2d_add_ssse3(const int32_t *input, uint8_t *output, } // extern "C" #endif -#endif // AV1_COMMON_X86_AV1_INV_TXFM_SSSE3_H_ +#endif // AOM_AV1_COMMON_X86_AV1_INV_TXFM_SSSE3_H_ diff --git a/third_party/aom/av1/common/x86/av1_txfm_sse2.h b/third_party/aom/av1/common/x86/av1_txfm_sse2.h index 721cfe059..77aeb6eb1 100644 --- a/third_party/aom/av1/common/x86/av1_txfm_sse2.h +++ b/third_party/aom/av1/common/x86/av1_txfm_sse2.h @@ -8,8 +8,8 @@ * Media Patent License 1.0 was not distributed with this source code in the * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ -#ifndef AV1_COMMON_X86_AV1_TXFM_SSE2_H_ -#define AV1_COMMON_X86_AV1_TXFM_SSE2_H_ +#ifndef AOM_AV1_COMMON_X86_AV1_TXFM_SSE2_H_ +#define AOM_AV1_COMMON_X86_AV1_TXFM_SSE2_H_ #include <emmintrin.h> // SSE2 @@ -314,4 +314,4 @@ typedef struct { #ifdef __cplusplus } #endif // __cplusplus -#endif // AV1_COMMON_X86_AV1_TXFM_SSE2_H_ +#endif // AOM_AV1_COMMON_X86_AV1_TXFM_SSE2_H_ diff --git a/third_party/aom/av1/common/x86/av1_txfm_sse4.h b/third_party/aom/av1/common/x86/av1_txfm_sse4.h index 367e02096..6cad821b1 100644 --- a/third_party/aom/av1/common/x86/av1_txfm_sse4.h +++ b/third_party/aom/av1/common/x86/av1_txfm_sse4.h @@ -9,8 +9,8 @@ * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ -#ifndef AV1_TXFM_SSE4_H_ -#define AV1_TXFM_SSE4_H_ +#ifndef AOM_AV1_COMMON_X86_AV1_TXFM_SSE4_H_ +#define AOM_AV1_COMMON_X86_AV1_TXFM_SSE4_H_ #include <smmintrin.h> @@ -45,8 +45,9 @@ static INLINE void av1_round_shift_array_32_sse4_1(__m128i *input, static INLINE void av1_round_shift_rect_array_32_sse4_1(__m128i *input, __m128i *output, const int size, - const int bit) { - const __m128i sqrt2 = _mm_set1_epi32(NewSqrt2); + const int bit, + const int val) { + const __m128i sqrt2 = _mm_set1_epi32(val); if (bit > 0) { int i; for (i = 0; i < size; i++) { @@ -68,4 +69,4 @@ static INLINE void av1_round_shift_rect_array_32_sse4_1(__m128i *input, } #endif -#endif // AV1_TXFM_SSE4_H_ +#endif // AOM_AV1_COMMON_X86_AV1_TXFM_SSE4_H_ diff --git a/third_party/aom/av1/common/x86/cfl_simd.h b/third_party/aom/av1/common/x86/cfl_simd.h index 7479ac3e1..3b342cd4e 100644 --- a/third_party/aom/av1/common/x86/cfl_simd.h +++ b/third_party/aom/av1/common/x86/cfl_simd.h @@ -9,6 +9,9 @@ * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ +#ifndef AOM_AV1_COMMON_X86_CFL_SIMD_H_ +#define AOM_AV1_COMMON_X86_CFL_SIMD_H_ + #include "av1/common/blockd.h" // SSSE3 version is optimal for with == 4, we reuse them in AVX2 @@ -236,3 +239,5 @@ void predict_hbd_16x16_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, int dst_stride, int alpha_q3, int bd); void predict_hbd_16x32_ssse3(const int16_t *pred_buf_q3, uint16_t *dst, int dst_stride, int alpha_q3, int bd); + +#endif // AOM_AV1_COMMON_X86_CFL_SIMD_H_ diff --git a/third_party/aom/av1/common/x86/convolve_2d_avx2.c b/third_party/aom/av1/common/x86/convolve_2d_avx2.c index 1099144fe..0acafd044 100644 --- a/third_party/aom/av1/common/x86/convolve_2d_avx2.c +++ b/third_party/aom/av1/common/x86/convolve_2d_avx2.c @@ -11,10 +11,8 @@ #include <immintrin.h> -#include "config/aom_dsp_rtcd.h" #include "config/av1_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/x86/convolve_avx2.h" #include "aom_dsp/x86/convolve_common_intrin.h" #include "aom_dsp/aom_dsp_common.h" diff --git a/third_party/aom/av1/common/x86/convolve_2d_sse2.c b/third_party/aom/av1/common/x86/convolve_2d_sse2.c index 637f83cf7..b1a62a4f6 100644 --- a/third_party/aom/av1/common/x86/convolve_2d_sse2.c +++ b/third_party/aom/av1/common/x86/convolve_2d_sse2.c @@ -11,9 +11,8 @@ #include <emmintrin.h> -#include "config/aom_dsp_rtcd.h" +#include "config/av1_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/aom_filter.h" #include "aom_dsp/x86/convolve_sse2.h" diff --git a/third_party/aom/av1/common/x86/convolve_sse2.c b/third_party/aom/av1/common/x86/convolve_sse2.c index f66dee37d..5016642de 100644 --- a/third_party/aom/av1/common/x86/convolve_sse2.c +++ b/third_party/aom/av1/common/x86/convolve_sse2.c @@ -11,9 +11,8 @@ #include <emmintrin.h> -#include "config/aom_dsp_rtcd.h" +#include "config/av1_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/aom_filter.h" #include "aom_dsp/x86/convolve_common_intrin.h" @@ -76,8 +75,8 @@ static INLINE __m128i convolve_hi_y(const __m128i *const s, return convolve(ss, coeffs); } -void av1_convolve_y_sr_sse2(const uint8_t *src, int src_stride, - const uint8_t *dst, int dst_stride, int w, int h, +void av1_convolve_y_sr_sse2(const uint8_t *src, int src_stride, uint8_t *dst, + int dst_stride, int w, int h, const InterpFilterParams *filter_params_x, const InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, @@ -237,8 +236,8 @@ void av1_convolve_y_sr_sse2(const uint8_t *src, int src_stride, } } -void av1_convolve_x_sr_sse2(const uint8_t *src, int src_stride, - const uint8_t *dst, int dst_stride, int w, int h, +void av1_convolve_x_sr_sse2(const uint8_t *src, int src_stride, uint8_t *dst, + int dst_stride, int w, int h, const InterpFilterParams *filter_params_x, const InterpFilterParams *filter_params_y, const int subpel_x_q4, const int subpel_y_q4, diff --git a/third_party/aom/av1/common/x86/highbd_convolve_2d_avx2.c b/third_party/aom/av1/common/x86/highbd_convolve_2d_avx2.c index 8444ffa93..ae68f0bbb 100644 --- a/third_party/aom/av1/common/x86/highbd_convolve_2d_avx2.c +++ b/third_party/aom/av1/common/x86/highbd_convolve_2d_avx2.c @@ -14,7 +14,6 @@ #include "config/aom_dsp_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/x86/convolve_avx2.h" #include "aom_dsp/x86/synonyms.h" #include "aom_dsp/aom_dsp_common.h" diff --git a/third_party/aom/av1/common/x86/highbd_convolve_2d_sse4.c b/third_party/aom/av1/common/x86/highbd_convolve_2d_sse4.c index eb340523a..3f8dafb4b 100644 --- a/third_party/aom/av1/common/x86/highbd_convolve_2d_sse4.c +++ b/third_party/aom/av1/common/x86/highbd_convolve_2d_sse4.c @@ -15,7 +15,6 @@ #include "config/aom_dsp_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/aom_filter.h" #include "aom_dsp/x86/convolve_sse2.h" diff --git a/third_party/aom/av1/common/x86/highbd_convolve_2d_ssse3.c b/third_party/aom/av1/common/x86/highbd_convolve_2d_ssse3.c index 33183fdee..1d029db39 100644 --- a/third_party/aom/av1/common/x86/highbd_convolve_2d_ssse3.c +++ b/third_party/aom/av1/common/x86/highbd_convolve_2d_ssse3.c @@ -14,7 +14,6 @@ #include "config/aom_dsp_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/aom_dsp_common.h" #include "aom_dsp/aom_filter.h" #include "aom_dsp/x86/convolve_sse2.h" diff --git a/third_party/aom/av1/common/x86/highbd_inv_txfm_avx2.c b/third_party/aom/av1/common/x86/highbd_inv_txfm_avx2.c index debb05a6d..ade2af03e 100644 --- a/third_party/aom/av1/common/x86/highbd_inv_txfm_avx2.c +++ b/third_party/aom/av1/common/x86/highbd_inv_txfm_avx2.c @@ -15,6 +15,9 @@ #include "config/av1_rtcd.h" #include "av1/common/av1_inv_txfm1d_cfg.h" +#include "av1/common/idct.h" +#include "av1/common/x86/av1_inv_txfm_ssse3.h" +#include "av1/common/x86/highbd_txfm_utility_sse4.h" // Note: // Total 32x4 registers to represent 32x32 block coefficients. @@ -27,131 +30,125 @@ // ... ... // v124, v125, v126, v127 -static void transpose_32x32_8x8(const __m256i *in, __m256i *out) { +static INLINE __m256i highbd_clamp_epi16_avx2(__m256i u, int bd) { + const __m256i zero = _mm256_setzero_si256(); + const __m256i one = _mm256_set1_epi16(1); + const __m256i max = _mm256_sub_epi16(_mm256_slli_epi16(one, bd), one); + __m256i clamped, mask; + + mask = _mm256_cmpgt_epi16(u, max); + clamped = _mm256_andnot_si256(mask, u); + mask = _mm256_and_si256(mask, max); + clamped = _mm256_or_si256(mask, clamped); + mask = _mm256_cmpgt_epi16(clamped, zero); + clamped = _mm256_and_si256(clamped, mask); + + return clamped; +} + +static INLINE __m256i highbd_get_recon_16x8_avx2(const __m256i pred, + __m256i res0, __m256i res1, + const int bd) { + __m256i x0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(pred)); + __m256i x1 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(pred, 1)); + + x0 = _mm256_add_epi32(res0, x0); + x1 = _mm256_add_epi32(res1, x1); + x0 = _mm256_packus_epi32(x0, x1); + x0 = _mm256_permute4x64_epi64(x0, 0xd8); + x0 = highbd_clamp_epi16_avx2(x0, bd); + return x0; +} + +static INLINE void highbd_write_buffer_16xn_avx2(__m256i *in, uint16_t *output, + int stride, int flipud, + int height, const int bd) { + int j = flipud ? (height - 1) : 0; + const int step = flipud ? -1 : 1; + for (int i = 0; i < height; ++i, j += step) { + __m256i v = _mm256_loadu_si256((__m256i const *)(output + i * stride)); + __m256i u = highbd_get_recon_16x8_avx2(v, in[j], in[j + height], bd); + + _mm256_storeu_si256((__m256i *)(output + i * stride), u); + } +} + +static INLINE __m256i av1_round_shift_32_avx2(__m256i vec, int bit) { + __m256i tmp, round; + round = _mm256_set1_epi32(1 << (bit - 1)); + tmp = _mm256_add_epi32(vec, round); + return _mm256_srai_epi32(tmp, bit); +} + +static INLINE void av1_round_shift_array_32_avx2(__m256i *input, + __m256i *output, + const int size, + const int bit) { + if (bit > 0) { + int i; + for (i = 0; i < size; i++) { + output[i] = av1_round_shift_32_avx2(input[i], bit); + } + } else { + int i; + for (i = 0; i < size; i++) { + output[i] = _mm256_slli_epi32(input[i], -bit); + } + } +} + +static void transpose_8x8_avx2(const __m256i *in, __m256i *out) { __m256i u0, u1, u2, u3, u4, u5, u6, u7; __m256i x0, x1; - u0 = _mm256_unpacklo_epi32(in[0], in[4]); - u1 = _mm256_unpackhi_epi32(in[0], in[4]); + u0 = _mm256_unpacklo_epi32(in[0], in[1]); + u1 = _mm256_unpackhi_epi32(in[0], in[1]); - u2 = _mm256_unpacklo_epi32(in[8], in[12]); - u3 = _mm256_unpackhi_epi32(in[8], in[12]); + u2 = _mm256_unpacklo_epi32(in[2], in[3]); + u3 = _mm256_unpackhi_epi32(in[2], in[3]); - u4 = _mm256_unpacklo_epi32(in[16], in[20]); - u5 = _mm256_unpackhi_epi32(in[16], in[20]); + u4 = _mm256_unpacklo_epi32(in[4], in[5]); + u5 = _mm256_unpackhi_epi32(in[4], in[5]); - u6 = _mm256_unpacklo_epi32(in[24], in[28]); - u7 = _mm256_unpackhi_epi32(in[24], in[28]); + u6 = _mm256_unpacklo_epi32(in[6], in[7]); + u7 = _mm256_unpackhi_epi32(in[6], in[7]); x0 = _mm256_unpacklo_epi64(u0, u2); x1 = _mm256_unpacklo_epi64(u4, u6); out[0] = _mm256_permute2f128_si256(x0, x1, 0x20); - out[16] = _mm256_permute2f128_si256(x0, x1, 0x31); + out[4] = _mm256_permute2f128_si256(x0, x1, 0x31); x0 = _mm256_unpackhi_epi64(u0, u2); x1 = _mm256_unpackhi_epi64(u4, u6); - out[4] = _mm256_permute2f128_si256(x0, x1, 0x20); - out[20] = _mm256_permute2f128_si256(x0, x1, 0x31); + out[1] = _mm256_permute2f128_si256(x0, x1, 0x20); + out[5] = _mm256_permute2f128_si256(x0, x1, 0x31); x0 = _mm256_unpacklo_epi64(u1, u3); x1 = _mm256_unpacklo_epi64(u5, u7); - out[8] = _mm256_permute2f128_si256(x0, x1, 0x20); - out[24] = _mm256_permute2f128_si256(x0, x1, 0x31); + out[2] = _mm256_permute2f128_si256(x0, x1, 0x20); + out[6] = _mm256_permute2f128_si256(x0, x1, 0x31); x0 = _mm256_unpackhi_epi64(u1, u3); x1 = _mm256_unpackhi_epi64(u5, u7); - out[12] = _mm256_permute2f128_si256(x0, x1, 0x20); - out[28] = _mm256_permute2f128_si256(x0, x1, 0x31); -} - -static void transpose_32x32_16x16(const __m256i *in, __m256i *out) { - transpose_32x32_8x8(&in[0], &out[0]); - transpose_32x32_8x8(&in[1], &out[32]); - transpose_32x32_8x8(&in[32], &out[1]); - transpose_32x32_8x8(&in[33], &out[33]); -} - -static void transpose_32x32(const __m256i *in, __m256i *out) { - transpose_32x32_16x16(&in[0], &out[0]); - transpose_32x32_16x16(&in[2], &out[64]); - transpose_32x32_16x16(&in[64], &out[2]); - transpose_32x32_16x16(&in[66], &out[66]); + out[3] = _mm256_permute2f128_si256(x0, x1, 0x20); + out[7] = _mm256_permute2f128_si256(x0, x1, 0x31); } -static void load_buffer_32x32(const int32_t *coeff, __m256i *in) { +static void load_buffer_32x32(const int32_t *coeff, __m256i *in, + int input_stiride, int size) { int i; - for (i = 0; i < 128; ++i) { - in[i] = _mm256_loadu_si256((const __m256i *)coeff); - coeff += 8; + for (i = 0; i < size; ++i) { + in[i] = _mm256_loadu_si256((const __m256i *)(coeff + i * input_stiride)); } } -static __m256i highbd_clamp_epi32(__m256i x, int bd) { - const __m256i zero = _mm256_setzero_si256(); - const __m256i one = _mm256_set1_epi16(1); - const __m256i max = _mm256_sub_epi16(_mm256_slli_epi16(one, bd), one); - __m256i clamped, mask; - - mask = _mm256_cmpgt_epi16(x, max); - clamped = _mm256_andnot_si256(mask, x); - mask = _mm256_and_si256(mask, max); - clamped = _mm256_or_si256(mask, clamped); - mask = _mm256_cmpgt_epi16(clamped, zero); - clamped = _mm256_and_si256(clamped, mask); - - return clamped; -} - -static void write_buffer_32x32(__m256i *in, uint16_t *output, int stride, - int fliplr, int flipud, int shift, int bd) { - __m256i u0, u1, x0, x1, x2, x3, v0, v1, v2, v3; - const __m256i zero = _mm256_setzero_si256(); - int i = 0; - (void)fliplr; - (void)flipud; - - __m256i round = _mm256_set1_epi32((1 << shift) >> 1); - - while (i < 128) { - u0 = _mm256_loadu_si256((const __m256i *)output); - u1 = _mm256_loadu_si256((const __m256i *)(output + 16)); - - x0 = _mm256_unpacklo_epi16(u0, zero); - x1 = _mm256_unpackhi_epi16(u0, zero); - x2 = _mm256_unpacklo_epi16(u1, zero); - x3 = _mm256_unpackhi_epi16(u1, zero); - - v0 = _mm256_permute2f128_si256(in[i], in[i + 1], 0x20); - v1 = _mm256_permute2f128_si256(in[i], in[i + 1], 0x31); - v2 = _mm256_permute2f128_si256(in[i + 2], in[i + 3], 0x20); - v3 = _mm256_permute2f128_si256(in[i + 2], in[i + 3], 0x31); - - v0 = _mm256_add_epi32(v0, round); - v1 = _mm256_add_epi32(v1, round); - v2 = _mm256_add_epi32(v2, round); - v3 = _mm256_add_epi32(v3, round); - - v0 = _mm256_sra_epi32(v0, _mm_cvtsi32_si128(shift)); - v1 = _mm256_sra_epi32(v1, _mm_cvtsi32_si128(shift)); - v2 = _mm256_sra_epi32(v2, _mm_cvtsi32_si128(shift)); - v3 = _mm256_sra_epi32(v3, _mm_cvtsi32_si128(shift)); - - v0 = _mm256_add_epi32(v0, x0); - v1 = _mm256_add_epi32(v1, x1); - v2 = _mm256_add_epi32(v2, x2); - v3 = _mm256_add_epi32(v3, x3); - - v0 = _mm256_packus_epi32(v0, v1); - v2 = _mm256_packus_epi32(v2, v3); - - v0 = highbd_clamp_epi32(v0, bd); - v2 = highbd_clamp_epi32(v2, bd); - - _mm256_storeu_si256((__m256i *)output, v0); - _mm256_storeu_si256((__m256i *)(output + 16), v2); - output += stride; - i += 4; - } +static INLINE __m256i half_btf_0_avx2(const __m256i *w0, const __m256i *n0, + const __m256i *rounding, int bit) { + __m256i x; + x = _mm256_mullo_epi32(*w0, *n0); + x = _mm256_add_epi32(x, *rounding); + x = _mm256_srai_epi32(x, bit); + return x; } static INLINE __m256i half_btf_avx2(const __m256i *w0, const __m256i *n0, @@ -200,18 +197,549 @@ static void addsub_shift_avx2(const __m256i in0, const __m256i in1, __m256i a0 = _mm256_add_epi32(in0_w_offset, in1); __m256i a1 = _mm256_sub_epi32(in0_w_offset, in1); + a0 = _mm256_sra_epi32(a0, _mm_cvtsi32_si128(shift)); + a1 = _mm256_sra_epi32(a1, _mm_cvtsi32_si128(shift)); + a0 = _mm256_max_epi32(a0, *clamp_lo); a0 = _mm256_min_epi32(a0, *clamp_hi); a1 = _mm256_max_epi32(a1, *clamp_lo); a1 = _mm256_min_epi32(a1, *clamp_hi); - a0 = _mm256_sra_epi32(a0, _mm_cvtsi32_si128(shift)); - a1 = _mm256_sra_epi32(a1, _mm_cvtsi32_si128(shift)); - *out0 = a0; *out1 = a1; } +static INLINE void idct32_stage4_avx2( + __m256i *bf1, const __m256i *cospim8, const __m256i *cospi56, + const __m256i *cospi8, const __m256i *cospim56, const __m256i *cospim40, + const __m256i *cospi24, const __m256i *cospi40, const __m256i *cospim24, + const __m256i *rounding, int bit) { + __m256i temp1, temp2; + temp1 = half_btf_avx2(cospim8, &bf1[17], cospi56, &bf1[30], rounding, bit); + bf1[30] = half_btf_avx2(cospi56, &bf1[17], cospi8, &bf1[30], rounding, bit); + bf1[17] = temp1; + + temp2 = half_btf_avx2(cospim56, &bf1[18], cospim8, &bf1[29], rounding, bit); + bf1[29] = half_btf_avx2(cospim8, &bf1[18], cospi56, &bf1[29], rounding, bit); + bf1[18] = temp2; + + temp1 = half_btf_avx2(cospim40, &bf1[21], cospi24, &bf1[26], rounding, bit); + bf1[26] = half_btf_avx2(cospi24, &bf1[21], cospi40, &bf1[26], rounding, bit); + bf1[21] = temp1; + + temp2 = half_btf_avx2(cospim24, &bf1[22], cospim40, &bf1[25], rounding, bit); + bf1[25] = half_btf_avx2(cospim40, &bf1[22], cospi24, &bf1[25], rounding, bit); + bf1[22] = temp2; +} + +static INLINE void idct32_stage5_avx2( + __m256i *bf1, const __m256i *cospim16, const __m256i *cospi48, + const __m256i *cospi16, const __m256i *cospim48, const __m256i *clamp_lo, + const __m256i *clamp_hi, const __m256i *rounding, int bit) { + __m256i temp1, temp2; + temp1 = half_btf_avx2(cospim16, &bf1[9], cospi48, &bf1[14], rounding, bit); + bf1[14] = half_btf_avx2(cospi48, &bf1[9], cospi16, &bf1[14], rounding, bit); + bf1[9] = temp1; + + temp2 = half_btf_avx2(cospim48, &bf1[10], cospim16, &bf1[13], rounding, bit); + bf1[13] = half_btf_avx2(cospim16, &bf1[10], cospi48, &bf1[13], rounding, bit); + bf1[10] = temp2; + + addsub_avx2(bf1[16], bf1[19], bf1 + 16, bf1 + 19, clamp_lo, clamp_hi); + addsub_avx2(bf1[17], bf1[18], bf1 + 17, bf1 + 18, clamp_lo, clamp_hi); + addsub_avx2(bf1[23], bf1[20], bf1 + 23, bf1 + 20, clamp_lo, clamp_hi); + addsub_avx2(bf1[22], bf1[21], bf1 + 22, bf1 + 21, clamp_lo, clamp_hi); + addsub_avx2(bf1[24], bf1[27], bf1 + 24, bf1 + 27, clamp_lo, clamp_hi); + addsub_avx2(bf1[25], bf1[26], bf1 + 25, bf1 + 26, clamp_lo, clamp_hi); + addsub_avx2(bf1[31], bf1[28], bf1 + 31, bf1 + 28, clamp_lo, clamp_hi); + addsub_avx2(bf1[30], bf1[29], bf1 + 30, bf1 + 29, clamp_lo, clamp_hi); +} + +static INLINE void idct32_stage6_avx2( + __m256i *bf1, const __m256i *cospim32, const __m256i *cospi32, + const __m256i *cospim16, const __m256i *cospi48, const __m256i *cospi16, + const __m256i *cospim48, const __m256i *clamp_lo, const __m256i *clamp_hi, + const __m256i *rounding, int bit) { + __m256i temp1, temp2; + temp1 = half_btf_avx2(cospim32, &bf1[5], cospi32, &bf1[6], rounding, bit); + bf1[6] = half_btf_avx2(cospi32, &bf1[5], cospi32, &bf1[6], rounding, bit); + bf1[5] = temp1; + + addsub_avx2(bf1[8], bf1[11], bf1 + 8, bf1 + 11, clamp_lo, clamp_hi); + addsub_avx2(bf1[9], bf1[10], bf1 + 9, bf1 + 10, clamp_lo, clamp_hi); + addsub_avx2(bf1[15], bf1[12], bf1 + 15, bf1 + 12, clamp_lo, clamp_hi); + addsub_avx2(bf1[14], bf1[13], bf1 + 14, bf1 + 13, clamp_lo, clamp_hi); + + temp1 = half_btf_avx2(cospim16, &bf1[18], cospi48, &bf1[29], rounding, bit); + bf1[29] = half_btf_avx2(cospi48, &bf1[18], cospi16, &bf1[29], rounding, bit); + bf1[18] = temp1; + temp2 = half_btf_avx2(cospim16, &bf1[19], cospi48, &bf1[28], rounding, bit); + bf1[28] = half_btf_avx2(cospi48, &bf1[19], cospi16, &bf1[28], rounding, bit); + bf1[19] = temp2; + temp1 = half_btf_avx2(cospim48, &bf1[20], cospim16, &bf1[27], rounding, bit); + bf1[27] = half_btf_avx2(cospim16, &bf1[20], cospi48, &bf1[27], rounding, bit); + bf1[20] = temp1; + temp2 = half_btf_avx2(cospim48, &bf1[21], cospim16, &bf1[26], rounding, bit); + bf1[26] = half_btf_avx2(cospim16, &bf1[21], cospi48, &bf1[26], rounding, bit); + bf1[21] = temp2; +} + +static INLINE void idct32_stage7_avx2(__m256i *bf1, const __m256i *cospim32, + const __m256i *cospi32, + const __m256i *clamp_lo, + const __m256i *clamp_hi, + const __m256i *rounding, int bit) { + __m256i temp1, temp2; + addsub_avx2(bf1[0], bf1[7], bf1 + 0, bf1 + 7, clamp_lo, clamp_hi); + addsub_avx2(bf1[1], bf1[6], bf1 + 1, bf1 + 6, clamp_lo, clamp_hi); + addsub_avx2(bf1[2], bf1[5], bf1 + 2, bf1 + 5, clamp_lo, clamp_hi); + addsub_avx2(bf1[3], bf1[4], bf1 + 3, bf1 + 4, clamp_lo, clamp_hi); + + temp1 = half_btf_avx2(cospim32, &bf1[10], cospi32, &bf1[13], rounding, bit); + bf1[13] = half_btf_avx2(cospi32, &bf1[10], cospi32, &bf1[13], rounding, bit); + bf1[10] = temp1; + temp2 = half_btf_avx2(cospim32, &bf1[11], cospi32, &bf1[12], rounding, bit); + bf1[12] = half_btf_avx2(cospi32, &bf1[11], cospi32, &bf1[12], rounding, bit); + bf1[11] = temp2; + + addsub_avx2(bf1[16], bf1[23], bf1 + 16, bf1 + 23, clamp_lo, clamp_hi); + addsub_avx2(bf1[17], bf1[22], bf1 + 17, bf1 + 22, clamp_lo, clamp_hi); + addsub_avx2(bf1[18], bf1[21], bf1 + 18, bf1 + 21, clamp_lo, clamp_hi); + addsub_avx2(bf1[19], bf1[20], bf1 + 19, bf1 + 20, clamp_lo, clamp_hi); + addsub_avx2(bf1[31], bf1[24], bf1 + 31, bf1 + 24, clamp_lo, clamp_hi); + addsub_avx2(bf1[30], bf1[25], bf1 + 30, bf1 + 25, clamp_lo, clamp_hi); + addsub_avx2(bf1[29], bf1[26], bf1 + 29, bf1 + 26, clamp_lo, clamp_hi); + addsub_avx2(bf1[28], bf1[27], bf1 + 28, bf1 + 27, clamp_lo, clamp_hi); +} + +static INLINE void idct32_stage8_avx2(__m256i *bf1, const __m256i *cospim32, + const __m256i *cospi32, + const __m256i *clamp_lo, + const __m256i *clamp_hi, + const __m256i *rounding, int bit) { + __m256i temp1, temp2; + addsub_avx2(bf1[0], bf1[15], bf1 + 0, bf1 + 15, clamp_lo, clamp_hi); + addsub_avx2(bf1[1], bf1[14], bf1 + 1, bf1 + 14, clamp_lo, clamp_hi); + addsub_avx2(bf1[2], bf1[13], bf1 + 2, bf1 + 13, clamp_lo, clamp_hi); + addsub_avx2(bf1[3], bf1[12], bf1 + 3, bf1 + 12, clamp_lo, clamp_hi); + addsub_avx2(bf1[4], bf1[11], bf1 + 4, bf1 + 11, clamp_lo, clamp_hi); + addsub_avx2(bf1[5], bf1[10], bf1 + 5, bf1 + 10, clamp_lo, clamp_hi); + addsub_avx2(bf1[6], bf1[9], bf1 + 6, bf1 + 9, clamp_lo, clamp_hi); + addsub_avx2(bf1[7], bf1[8], bf1 + 7, bf1 + 8, clamp_lo, clamp_hi); + + temp1 = half_btf_avx2(cospim32, &bf1[20], cospi32, &bf1[27], rounding, bit); + bf1[27] = half_btf_avx2(cospi32, &bf1[20], cospi32, &bf1[27], rounding, bit); + bf1[20] = temp1; + temp2 = half_btf_avx2(cospim32, &bf1[21], cospi32, &bf1[26], rounding, bit); + bf1[26] = half_btf_avx2(cospi32, &bf1[21], cospi32, &bf1[26], rounding, bit); + bf1[21] = temp2; + temp1 = half_btf_avx2(cospim32, &bf1[22], cospi32, &bf1[25], rounding, bit); + bf1[25] = half_btf_avx2(cospi32, &bf1[22], cospi32, &bf1[25], rounding, bit); + bf1[22] = temp1; + temp2 = half_btf_avx2(cospim32, &bf1[23], cospi32, &bf1[24], rounding, bit); + bf1[24] = half_btf_avx2(cospi32, &bf1[23], cospi32, &bf1[24], rounding, bit); + bf1[23] = temp2; +} + +static INLINE void idct32_stage9_avx2(__m256i *bf1, __m256i *out, + const int do_cols, const int bd, + const int out_shift, + const int log_range) { + if (do_cols) { + addsub_no_clamp_avx2(bf1[0], bf1[31], out + 0, out + 31); + addsub_no_clamp_avx2(bf1[1], bf1[30], out + 1, out + 30); + addsub_no_clamp_avx2(bf1[2], bf1[29], out + 2, out + 29); + addsub_no_clamp_avx2(bf1[3], bf1[28], out + 3, out + 28); + addsub_no_clamp_avx2(bf1[4], bf1[27], out + 4, out + 27); + addsub_no_clamp_avx2(bf1[5], bf1[26], out + 5, out + 26); + addsub_no_clamp_avx2(bf1[6], bf1[25], out + 6, out + 25); + addsub_no_clamp_avx2(bf1[7], bf1[24], out + 7, out + 24); + addsub_no_clamp_avx2(bf1[8], bf1[23], out + 8, out + 23); + addsub_no_clamp_avx2(bf1[9], bf1[22], out + 9, out + 22); + addsub_no_clamp_avx2(bf1[10], bf1[21], out + 10, out + 21); + addsub_no_clamp_avx2(bf1[11], bf1[20], out + 11, out + 20); + addsub_no_clamp_avx2(bf1[12], bf1[19], out + 12, out + 19); + addsub_no_clamp_avx2(bf1[13], bf1[18], out + 13, out + 18); + addsub_no_clamp_avx2(bf1[14], bf1[17], out + 14, out + 17); + addsub_no_clamp_avx2(bf1[15], bf1[16], out + 15, out + 16); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + addsub_shift_avx2(bf1[0], bf1[31], out + 0, out + 31, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[1], bf1[30], out + 1, out + 30, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[2], bf1[29], out + 2, out + 29, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[3], bf1[28], out + 3, out + 28, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[4], bf1[27], out + 4, out + 27, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[5], bf1[26], out + 5, out + 26, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[6], bf1[25], out + 6, out + 25, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[7], bf1[24], out + 7, out + 24, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[8], bf1[23], out + 8, out + 23, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[9], bf1[22], out + 9, out + 22, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[10], bf1[21], out + 10, out + 21, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[11], bf1[20], out + 11, out + 20, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[12], bf1[19], out + 12, out + 19, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[13], bf1[18], out + 13, out + 18, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[14], bf1[17], out + 14, out + 17, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf1[15], bf1[16], out + 15, out + 16, &clamp_lo_out, + &clamp_hi_out, out_shift); + } +} + +static void idct32_low1_avx2(__m256i *in, __m256i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m256i cospi32 = _mm256_set1_epi32(cospi[32]); + const __m256i rounding = _mm256_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1))); + const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1); + __m256i x; + // stage 0 + // stage 1 + // stage 2 + // stage 3 + // stage 4 + // stage 5 + x = _mm256_mullo_epi32(in[0], cospi32); + x = _mm256_add_epi32(x, rounding); + x = _mm256_srai_epi32(x, bit); + + // stage 6 + // stage 7 + // stage 8 + // stage 9 + if (do_cols) { + x = _mm256_max_epi32(x, clamp_lo); + x = _mm256_min_epi32(x, clamp_hi); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + __m256i offset = _mm256_set1_epi32((1 << out_shift) >> 1); + x = _mm256_add_epi32(offset, x); + x = _mm256_sra_epi32(x, _mm_cvtsi32_si128(out_shift)); + x = _mm256_max_epi32(x, clamp_lo_out); + x = _mm256_min_epi32(x, clamp_hi_out); + } + + out[0] = x; + out[1] = x; + out[2] = x; + out[3] = x; + out[4] = x; + out[5] = x; + out[6] = x; + out[7] = x; + out[8] = x; + out[9] = x; + out[10] = x; + out[11] = x; + out[12] = x; + out[13] = x; + out[14] = x; + out[15] = x; + out[16] = x; + out[17] = x; + out[18] = x; + out[19] = x; + out[20] = x; + out[21] = x; + out[22] = x; + out[23] = x; + out[24] = x; + out[25] = x; + out[26] = x; + out[27] = x; + out[28] = x; + out[29] = x; + out[30] = x; + out[31] = x; +} + +static void idct32_low8_avx2(__m256i *in, __m256i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m256i cospi62 = _mm256_set1_epi32(cospi[62]); + const __m256i cospi14 = _mm256_set1_epi32(cospi[14]); + const __m256i cospi54 = _mm256_set1_epi32(cospi[54]); + const __m256i cospi6 = _mm256_set1_epi32(cospi[6]); + const __m256i cospi10 = _mm256_set1_epi32(cospi[10]); + const __m256i cospi2 = _mm256_set1_epi32(cospi[2]); + const __m256i cospim58 = _mm256_set1_epi32(-cospi[58]); + const __m256i cospim50 = _mm256_set1_epi32(-cospi[50]); + const __m256i cospi60 = _mm256_set1_epi32(cospi[60]); + const __m256i cospi12 = _mm256_set1_epi32(cospi[12]); + const __m256i cospi4 = _mm256_set1_epi32(cospi[4]); + const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]); + const __m256i cospi56 = _mm256_set1_epi32(cospi[56]); + const __m256i cospi24 = _mm256_set1_epi32(cospi[24]); + const __m256i cospi40 = _mm256_set1_epi32(cospi[40]); + const __m256i cospi8 = _mm256_set1_epi32(cospi[8]); + const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]); + const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]); + const __m256i cospim56 = _mm256_set1_epi32(-cospi[56]); + const __m256i cospim24 = _mm256_set1_epi32(-cospi[24]); + const __m256i cospi32 = _mm256_set1_epi32(cospi[32]); + const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]); + const __m256i cospi48 = _mm256_set1_epi32(cospi[48]); + const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]); + const __m256i cospi16 = _mm256_set1_epi32(cospi[16]); + const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]); + const __m256i rounding = _mm256_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1))); + const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1); + __m256i bf1[32]; + + { + // stage 0 + // stage 1 + bf1[0] = in[0]; + bf1[4] = in[4]; + bf1[8] = in[2]; + bf1[12] = in[6]; + bf1[16] = in[1]; + bf1[20] = in[5]; + bf1[24] = in[3]; + bf1[28] = in[7]; + + // stage 2 + bf1[31] = half_btf_0_avx2(&cospi2, &bf1[16], &rounding, bit); + bf1[16] = half_btf_0_avx2(&cospi62, &bf1[16], &rounding, bit); + bf1[19] = half_btf_0_avx2(&cospim50, &bf1[28], &rounding, bit); + bf1[28] = half_btf_0_avx2(&cospi14, &bf1[28], &rounding, bit); + bf1[27] = half_btf_0_avx2(&cospi10, &bf1[20], &rounding, bit); + bf1[20] = half_btf_0_avx2(&cospi54, &bf1[20], &rounding, bit); + bf1[23] = half_btf_0_avx2(&cospim58, &bf1[24], &rounding, bit); + bf1[24] = half_btf_0_avx2(&cospi6, &bf1[24], &rounding, bit); + + // stage 3 + bf1[15] = half_btf_0_avx2(&cospi4, &bf1[8], &rounding, bit); + bf1[8] = half_btf_0_avx2(&cospi60, &bf1[8], &rounding, bit); + + bf1[11] = half_btf_0_avx2(&cospim52, &bf1[12], &rounding, bit); + bf1[12] = half_btf_0_avx2(&cospi12, &bf1[12], &rounding, bit); + bf1[17] = bf1[16]; + bf1[18] = bf1[19]; + bf1[21] = bf1[20]; + bf1[22] = bf1[23]; + bf1[25] = bf1[24]; + bf1[26] = bf1[27]; + bf1[29] = bf1[28]; + bf1[30] = bf1[31]; + + // stage 4 + bf1[7] = half_btf_0_avx2(&cospi8, &bf1[4], &rounding, bit); + bf1[4] = half_btf_0_avx2(&cospi56, &bf1[4], &rounding, bit); + + bf1[9] = bf1[8]; + bf1[10] = bf1[11]; + bf1[13] = bf1[12]; + bf1[14] = bf1[15]; + + idct32_stage4_avx2(bf1, &cospim8, &cospi56, &cospi8, &cospim56, &cospim40, + &cospi24, &cospi40, &cospim24, &rounding, bit); + + // stage 5 + bf1[0] = half_btf_0_avx2(&cospi32, &bf1[0], &rounding, bit); + bf1[1] = bf1[0]; + bf1[5] = bf1[4]; + bf1[6] = bf1[7]; + + idct32_stage5_avx2(bf1, &cospim16, &cospi48, &cospi16, &cospim48, &clamp_lo, + &clamp_hi, &rounding, bit); + + // stage 6 + bf1[3] = bf1[0]; + bf1[2] = bf1[1]; + + idct32_stage6_avx2(bf1, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16, + &cospim48, &clamp_lo, &clamp_hi, &rounding, bit); + + // stage 7 + idct32_stage7_avx2(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 8 + idct32_stage8_avx2(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 9 + idct32_stage9_avx2(bf1, out, do_cols, bd, out_shift, log_range); + } +} + +static void idct32_low16_avx2(__m256i *in, __m256i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m256i cospi62 = _mm256_set1_epi32(cospi[62]); + const __m256i cospi30 = _mm256_set1_epi32(cospi[30]); + const __m256i cospi46 = _mm256_set1_epi32(cospi[46]); + const __m256i cospi14 = _mm256_set1_epi32(cospi[14]); + const __m256i cospi54 = _mm256_set1_epi32(cospi[54]); + const __m256i cospi22 = _mm256_set1_epi32(cospi[22]); + const __m256i cospi38 = _mm256_set1_epi32(cospi[38]); + const __m256i cospi6 = _mm256_set1_epi32(cospi[6]); + const __m256i cospi26 = _mm256_set1_epi32(cospi[26]); + const __m256i cospi10 = _mm256_set1_epi32(cospi[10]); + const __m256i cospi18 = _mm256_set1_epi32(cospi[18]); + const __m256i cospi2 = _mm256_set1_epi32(cospi[2]); + const __m256i cospim58 = _mm256_set1_epi32(-cospi[58]); + const __m256i cospim42 = _mm256_set1_epi32(-cospi[42]); + const __m256i cospim50 = _mm256_set1_epi32(-cospi[50]); + const __m256i cospim34 = _mm256_set1_epi32(-cospi[34]); + const __m256i cospi60 = _mm256_set1_epi32(cospi[60]); + const __m256i cospi28 = _mm256_set1_epi32(cospi[28]); + const __m256i cospi44 = _mm256_set1_epi32(cospi[44]); + const __m256i cospi12 = _mm256_set1_epi32(cospi[12]); + const __m256i cospi20 = _mm256_set1_epi32(cospi[20]); + const __m256i cospi4 = _mm256_set1_epi32(cospi[4]); + const __m256i cospim52 = _mm256_set1_epi32(-cospi[52]); + const __m256i cospim36 = _mm256_set1_epi32(-cospi[36]); + const __m256i cospi56 = _mm256_set1_epi32(cospi[56]); + const __m256i cospi24 = _mm256_set1_epi32(cospi[24]); + const __m256i cospi40 = _mm256_set1_epi32(cospi[40]); + const __m256i cospi8 = _mm256_set1_epi32(cospi[8]); + const __m256i cospim40 = _mm256_set1_epi32(-cospi[40]); + const __m256i cospim8 = _mm256_set1_epi32(-cospi[8]); + const __m256i cospim56 = _mm256_set1_epi32(-cospi[56]); + const __m256i cospim24 = _mm256_set1_epi32(-cospi[24]); + const __m256i cospi32 = _mm256_set1_epi32(cospi[32]); + const __m256i cospim32 = _mm256_set1_epi32(-cospi[32]); + const __m256i cospi48 = _mm256_set1_epi32(cospi[48]); + const __m256i cospim48 = _mm256_set1_epi32(-cospi[48]); + const __m256i cospi16 = _mm256_set1_epi32(cospi[16]); + const __m256i cospim16 = _mm256_set1_epi32(-cospi[16]); + const __m256i rounding = _mm256_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1))); + const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1); + __m256i bf1[32]; + + { + // stage 0 + // stage 1 + bf1[0] = in[0]; + bf1[2] = in[8]; + bf1[4] = in[4]; + bf1[6] = in[12]; + bf1[8] = in[2]; + bf1[10] = in[10]; + bf1[12] = in[6]; + bf1[14] = in[14]; + bf1[16] = in[1]; + bf1[18] = in[9]; + bf1[20] = in[5]; + bf1[22] = in[13]; + bf1[24] = in[3]; + bf1[26] = in[11]; + bf1[28] = in[7]; + bf1[30] = in[15]; + + // stage 2 + bf1[31] = half_btf_0_avx2(&cospi2, &bf1[16], &rounding, bit); + bf1[16] = half_btf_0_avx2(&cospi62, &bf1[16], &rounding, bit); + bf1[17] = half_btf_0_avx2(&cospim34, &bf1[30], &rounding, bit); + bf1[30] = half_btf_0_avx2(&cospi30, &bf1[30], &rounding, bit); + bf1[29] = half_btf_0_avx2(&cospi18, &bf1[18], &rounding, bit); + bf1[18] = half_btf_0_avx2(&cospi46, &bf1[18], &rounding, bit); + bf1[19] = half_btf_0_avx2(&cospim50, &bf1[28], &rounding, bit); + bf1[28] = half_btf_0_avx2(&cospi14, &bf1[28], &rounding, bit); + bf1[27] = half_btf_0_avx2(&cospi10, &bf1[20], &rounding, bit); + bf1[20] = half_btf_0_avx2(&cospi54, &bf1[20], &rounding, bit); + bf1[21] = half_btf_0_avx2(&cospim42, &bf1[26], &rounding, bit); + bf1[26] = half_btf_0_avx2(&cospi22, &bf1[26], &rounding, bit); + bf1[25] = half_btf_0_avx2(&cospi26, &bf1[22], &rounding, bit); + bf1[22] = half_btf_0_avx2(&cospi38, &bf1[22], &rounding, bit); + bf1[23] = half_btf_0_avx2(&cospim58, &bf1[24], &rounding, bit); + bf1[24] = half_btf_0_avx2(&cospi6, &bf1[24], &rounding, bit); + + // stage 3 + bf1[15] = half_btf_0_avx2(&cospi4, &bf1[8], &rounding, bit); + bf1[8] = half_btf_0_avx2(&cospi60, &bf1[8], &rounding, bit); + bf1[9] = half_btf_0_avx2(&cospim36, &bf1[14], &rounding, bit); + bf1[14] = half_btf_0_avx2(&cospi28, &bf1[14], &rounding, bit); + bf1[13] = half_btf_0_avx2(&cospi20, &bf1[10], &rounding, bit); + bf1[10] = half_btf_0_avx2(&cospi44, &bf1[10], &rounding, bit); + bf1[11] = half_btf_0_avx2(&cospim52, &bf1[12], &rounding, bit); + bf1[12] = half_btf_0_avx2(&cospi12, &bf1[12], &rounding, bit); + + addsub_avx2(bf1[16], bf1[17], bf1 + 16, bf1 + 17, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[19], bf1[18], bf1 + 19, bf1 + 18, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[20], bf1[21], bf1 + 20, bf1 + 21, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[23], bf1[22], bf1 + 23, bf1 + 22, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[24], bf1[25], bf1 + 24, bf1 + 25, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[27], bf1[26], bf1 + 27, bf1 + 26, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[28], bf1[29], bf1 + 28, bf1 + 29, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[31], bf1[30], bf1 + 31, bf1 + 30, &clamp_lo, &clamp_hi); + + // stage 4 + bf1[7] = half_btf_0_avx2(&cospi8, &bf1[4], &rounding, bit); + bf1[4] = half_btf_0_avx2(&cospi56, &bf1[4], &rounding, bit); + bf1[5] = half_btf_0_avx2(&cospim40, &bf1[6], &rounding, bit); + bf1[6] = half_btf_0_avx2(&cospi24, &bf1[6], &rounding, bit); + + addsub_avx2(bf1[8], bf1[9], bf1 + 8, bf1 + 9, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[11], bf1[10], bf1 + 11, bf1 + 10, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[12], bf1[13], bf1 + 12, bf1 + 13, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[15], bf1[14], bf1 + 15, bf1 + 14, &clamp_lo, &clamp_hi); + + idct32_stage4_avx2(bf1, &cospim8, &cospi56, &cospi8, &cospim56, &cospim40, + &cospi24, &cospi40, &cospim24, &rounding, bit); + + // stage 5 + bf1[0] = half_btf_0_avx2(&cospi32, &bf1[0], &rounding, bit); + bf1[1] = bf1[0]; + bf1[3] = half_btf_0_avx2(&cospi16, &bf1[2], &rounding, bit); + bf1[2] = half_btf_0_avx2(&cospi48, &bf1[2], &rounding, bit); + + addsub_avx2(bf1[4], bf1[5], bf1 + 4, bf1 + 5, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[7], bf1[6], bf1 + 7, bf1 + 6, &clamp_lo, &clamp_hi); + + idct32_stage5_avx2(bf1, &cospim16, &cospi48, &cospi16, &cospim48, &clamp_lo, + &clamp_hi, &rounding, bit); + + // stage 6 + addsub_avx2(bf1[0], bf1[3], bf1 + 0, bf1 + 3, &clamp_lo, &clamp_hi); + addsub_avx2(bf1[1], bf1[2], bf1 + 1, bf1 + 2, &clamp_lo, &clamp_hi); + + idct32_stage6_avx2(bf1, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16, + &cospim48, &clamp_lo, &clamp_hi, &rounding, bit); + + // stage 7 + idct32_stage7_avx2(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 8 + idct32_stage8_avx2(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 9 + idct32_stage9_avx2(bf1, out, do_cols, bd, out_shift, log_range); + } +} + static void idct32_avx2(__m256i *in, __m256i *out, int bit, int do_cols, int bd, int out_shift) { const int32_t *cospi = cospi_arr(bit); @@ -270,43 +798,42 @@ static void idct32_avx2(__m256i *in, __m256i *out, int bit, int do_cols, int bd, const __m256i clamp_lo = _mm256_set1_epi32(-(1 << (log_range - 1))); const __m256i clamp_hi = _mm256_set1_epi32((1 << (log_range - 1)) - 1); __m256i bf1[32], bf0[32]; - int col; - for (col = 0; col < 4; ++col) { + { // stage 0 // stage 1 - bf1[0] = in[0 * 4 + col]; - bf1[1] = in[16 * 4 + col]; - bf1[2] = in[8 * 4 + col]; - bf1[3] = in[24 * 4 + col]; - bf1[4] = in[4 * 4 + col]; - bf1[5] = in[20 * 4 + col]; - bf1[6] = in[12 * 4 + col]; - bf1[7] = in[28 * 4 + col]; - bf1[8] = in[2 * 4 + col]; - bf1[9] = in[18 * 4 + col]; - bf1[10] = in[10 * 4 + col]; - bf1[11] = in[26 * 4 + col]; - bf1[12] = in[6 * 4 + col]; - bf1[13] = in[22 * 4 + col]; - bf1[14] = in[14 * 4 + col]; - bf1[15] = in[30 * 4 + col]; - bf1[16] = in[1 * 4 + col]; - bf1[17] = in[17 * 4 + col]; - bf1[18] = in[9 * 4 + col]; - bf1[19] = in[25 * 4 + col]; - bf1[20] = in[5 * 4 + col]; - bf1[21] = in[21 * 4 + col]; - bf1[22] = in[13 * 4 + col]; - bf1[23] = in[29 * 4 + col]; - bf1[24] = in[3 * 4 + col]; - bf1[25] = in[19 * 4 + col]; - bf1[26] = in[11 * 4 + col]; - bf1[27] = in[27 * 4 + col]; - bf1[28] = in[7 * 4 + col]; - bf1[29] = in[23 * 4 + col]; - bf1[30] = in[15 * 4 + col]; - bf1[31] = in[31 * 4 + col]; + bf1[0] = in[0]; + bf1[1] = in[16]; + bf1[2] = in[8]; + bf1[3] = in[24]; + bf1[4] = in[4]; + bf1[5] = in[20]; + bf1[6] = in[12]; + bf1[7] = in[28]; + bf1[8] = in[2]; + bf1[9] = in[18]; + bf1[10] = in[10]; + bf1[11] = in[26]; + bf1[12] = in[6]; + bf1[13] = in[22]; + bf1[14] = in[14]; + bf1[15] = in[30]; + bf1[16] = in[1]; + bf1[17] = in[17]; + bf1[18] = in[9]; + bf1[19] = in[25]; + bf1[20] = in[5]; + bf1[21] = in[21]; + bf1[22] = in[13]; + bf1[23] = in[29]; + bf1[24] = in[3]; + bf1[25] = in[19]; + bf1[26] = in[11]; + bf1[27] = in[27]; + bf1[28] = in[7]; + bf1[29] = in[23]; + bf1[30] = in[15]; + bf1[31] = in[31]; // stage 2 bf0[0] = bf1[0]; @@ -568,91 +1095,255 @@ static void idct32_avx2(__m256i *in, __m256i *out, int bit, int do_cols, int bd, // stage 9 if (do_cols) { - addsub_no_clamp_avx2(bf0[0], bf0[31], out + 0 * 4 + col, - out + 31 * 4 + col); - addsub_no_clamp_avx2(bf0[1], bf0[30], out + 1 * 4 + col, - out + 30 * 4 + col); - addsub_no_clamp_avx2(bf0[2], bf0[29], out + 2 * 4 + col, - out + 29 * 4 + col); - addsub_no_clamp_avx2(bf0[3], bf0[28], out + 3 * 4 + col, - out + 28 * 4 + col); - addsub_no_clamp_avx2(bf0[4], bf0[27], out + 4 * 4 + col, - out + 27 * 4 + col); - addsub_no_clamp_avx2(bf0[5], bf0[26], out + 5 * 4 + col, - out + 26 * 4 + col); - addsub_no_clamp_avx2(bf0[6], bf0[25], out + 6 * 4 + col, - out + 25 * 4 + col); - addsub_no_clamp_avx2(bf0[7], bf0[24], out + 7 * 4 + col, - out + 24 * 4 + col); - addsub_no_clamp_avx2(bf0[8], bf0[23], out + 8 * 4 + col, - out + 23 * 4 + col); - addsub_no_clamp_avx2(bf0[9], bf0[22], out + 9 * 4 + col, - out + 22 * 4 + col); - addsub_no_clamp_avx2(bf0[10], bf0[21], out + 10 * 4 + col, - out + 21 * 4 + col); - addsub_no_clamp_avx2(bf0[11], bf0[20], out + 11 * 4 + col, - out + 20 * 4 + col); - addsub_no_clamp_avx2(bf0[12], bf0[19], out + 12 * 4 + col, - out + 19 * 4 + col); - addsub_no_clamp_avx2(bf0[13], bf0[18], out + 13 * 4 + col, - out + 18 * 4 + col); - addsub_no_clamp_avx2(bf0[14], bf0[17], out + 14 * 4 + col, - out + 17 * 4 + col); - addsub_no_clamp_avx2(bf0[15], bf0[16], out + 15 * 4 + col, - out + 16 * 4 + col); + addsub_no_clamp_avx2(bf0[0], bf0[31], out + 0, out + 31); + addsub_no_clamp_avx2(bf0[1], bf0[30], out + 1, out + 30); + addsub_no_clamp_avx2(bf0[2], bf0[29], out + 2, out + 29); + addsub_no_clamp_avx2(bf0[3], bf0[28], out + 3, out + 28); + addsub_no_clamp_avx2(bf0[4], bf0[27], out + 4, out + 27); + addsub_no_clamp_avx2(bf0[5], bf0[26], out + 5, out + 26); + addsub_no_clamp_avx2(bf0[6], bf0[25], out + 6, out + 25); + addsub_no_clamp_avx2(bf0[7], bf0[24], out + 7, out + 24); + addsub_no_clamp_avx2(bf0[8], bf0[23], out + 8, out + 23); + addsub_no_clamp_avx2(bf0[9], bf0[22], out + 9, out + 22); + addsub_no_clamp_avx2(bf0[10], bf0[21], out + 10, out + 21); + addsub_no_clamp_avx2(bf0[11], bf0[20], out + 11, out + 20); + addsub_no_clamp_avx2(bf0[12], bf0[19], out + 12, out + 19); + addsub_no_clamp_avx2(bf0[13], bf0[18], out + 13, out + 18); + addsub_no_clamp_avx2(bf0[14], bf0[17], out + 14, out + 17); + addsub_no_clamp_avx2(bf0[15], bf0[16], out + 15, out + 16); } else { - addsub_shift_avx2(bf0[0], bf0[31], out + 0 * 4 + col, out + 31 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[1], bf0[30], out + 1 * 4 + col, out + 30 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[2], bf0[29], out + 2 * 4 + col, out + 29 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[3], bf0[28], out + 3 * 4 + col, out + 28 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[4], bf0[27], out + 4 * 4 + col, out + 27 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[5], bf0[26], out + 5 * 4 + col, out + 26 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[6], bf0[25], out + 6 * 4 + col, out + 25 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[7], bf0[24], out + 7 * 4 + col, out + 24 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[8], bf0[23], out + 8 * 4 + col, out + 23 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[9], bf0[22], out + 9 * 4 + col, out + 22 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[10], bf0[21], out + 10 * 4 + col, - out + 21 * 4 + col, &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[11], bf0[20], out + 11 * 4 + col, - out + 20 * 4 + col, &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[12], bf0[19], out + 12 * 4 + col, - out + 19 * 4 + col, &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[13], bf0[18], out + 13 * 4 + col, - out + 18 * 4 + col, &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[14], bf0[17], out + 14 * 4 + col, - out + 17 * 4 + col, &clamp_lo, &clamp_hi, out_shift); - addsub_shift_avx2(bf0[15], bf0[16], out + 15 * 4 + col, - out + 16 * 4 + col, &clamp_lo, &clamp_hi, out_shift); + const int log_range_out = AOMMAX(16, bd + 6); + const __m256i clamp_lo_out = _mm256_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m256i clamp_hi_out = _mm256_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + addsub_shift_avx2(bf0[0], bf0[31], out + 0, out + 31, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[1], bf0[30], out + 1, out + 30, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[2], bf0[29], out + 2, out + 29, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[3], bf0[28], out + 3, out + 28, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[4], bf0[27], out + 4, out + 27, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[5], bf0[26], out + 5, out + 26, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[6], bf0[25], out + 6, out + 25, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[7], bf0[24], out + 7, out + 24, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[8], bf0[23], out + 8, out + 23, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[9], bf0[22], out + 9, out + 22, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[10], bf0[21], out + 10, out + 21, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[11], bf0[20], out + 11, out + 20, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[12], bf0[19], out + 12, out + 19, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[13], bf0[18], out + 13, out + 18, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[14], bf0[17], out + 14, out + 17, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_avx2(bf0[15], bf0[16], out + 15, out + 16, &clamp_lo_out, + &clamp_hi_out, out_shift); } } } -void av1_inv_txfm2d_add_32x32_avx2(const int32_t *coeff, uint16_t *output, - int stride, TX_TYPE tx_type, int bd) { - __m256i in[128], out[128]; - const int8_t *shift = inv_txfm_shift_ls[TX_32X32]; - const int txw_idx = get_txw_idx(TX_32X32); - const int txh_idx = get_txh_idx(TX_32X32); +typedef void (*transform_1d_avx2)(__m256i *in, __m256i *out, int bit, + int do_cols, int bd, int out_shift); + +static const transform_1d_avx2 + highbd_txfm_all_1d_zeros_w8_arr[TX_SIZES][ITX_TYPES_1D][4] = { + { + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + }, + { { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL } }, + { + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + }, + { { idct32_low1_avx2, idct32_low8_avx2, idct32_low16_avx2, idct32_avx2 }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL } }, + + { { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL } } + }; + +static void highbd_inv_txfm2d_add_no_identity_avx2(const int32_t *input, + uint16_t *output, int stride, + TX_TYPE tx_type, + TX_SIZE tx_size, int eob, + const int bd) { + __m256i buf1[64 * 2]; + int eobx, eoby; + get_eobx_eoby_scan_default(&eobx, &eoby, tx_size, eob); + const int8_t *shift = inv_txfm_shift_ls[tx_size]; + const int txw_idx = get_txw_idx(tx_size); + const int txh_idx = get_txh_idx(tx_size); + const int txfm_size_col = tx_size_wide[tx_size]; + const int txfm_size_row = tx_size_high[tx_size]; + const int buf_size_w_div8 = txfm_size_col >> 3; + const int buf_size_nonzero_w_div8 = (eobx + 8) >> 3; + const int buf_size_nonzero_h_div8 = (eoby + 8) >> 3; + const int input_stride = AOMMIN(32, txfm_size_col); + + const int fun_idx_x = lowbd_txfm_all_1d_zeros_idx[eobx]; + const int fun_idx_y = lowbd_txfm_all_1d_zeros_idx[eoby]; + const transform_1d_avx2 row_txfm = + highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][fun_idx_x]; + const transform_1d_avx2 col_txfm = + highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][fun_idx_y]; + + assert(col_txfm != NULL); + assert(row_txfm != NULL); + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + // 1st stage: column transform + for (int i = 0; i < buf_size_nonzero_h_div8; i++) { + __m256i buf0[32]; + const int32_t *input_row = input + i * input_stride * 8; + for (int j = 0; j < buf_size_nonzero_w_div8; ++j) { + __m256i *buf0_cur = buf0 + j * 8; + load_buffer_32x32(input_row + j * 8, buf0_cur, input_stride, 8); + + transpose_8x8_avx2(&buf0_cur[0], &buf0_cur[0]); + } + + row_txfm(buf0, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]); + + __m256i *_buf1 = buf1 + i * 8; + for (int j = 0; j < buf_size_w_div8; ++j) { + transpose_8x8_avx2(&buf0[j * 8], &_buf1[j * txfm_size_row]); + } + } + // 2nd stage: column transform + for (int i = 0; i < buf_size_w_div8; i++) { + col_txfm(buf1 + i * txfm_size_row, buf1 + i * txfm_size_row, + inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); + + av1_round_shift_array_32_avx2(buf1 + i * txfm_size_row, + buf1 + i * txfm_size_row, txfm_size_row, + -shift[1]); + } + + // write to buffer + { + for (int i = 0; i < (txfm_size_col >> 4); i++) { + highbd_write_buffer_16xn_avx2(buf1 + i * txfm_size_row * 2, + output + 16 * i, stride, ud_flip, + txfm_size_row, bd); + } + } +} + +void av1_highbd_inv_txfm2d_add_universe_avx2(const int32_t *input, + uint8_t *output, int stride, + TX_TYPE tx_type, TX_SIZE tx_size, + int eob, const int bd) { switch (tx_type) { case DCT_DCT: - load_buffer_32x32(coeff, in); - transpose_32x32(in, out); - idct32_avx2(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]); - transpose_32x32(in, out); - idct32_avx2(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_32x32(in, output, stride, 0, 0, -shift[1], bd); + highbd_inv_txfm2d_add_no_identity_avx2(input, CONVERT_TO_SHORTPTR(output), + stride, tx_type, tx_size, eob, bd); break; + default: assert(0); break; + } +} + +void av1_highbd_inv_txfm_add_32x32_avx2(const tran_low_t *input, uint8_t *dest, + int stride, + const TxfmParam *txfm_param) { + const int bd = txfm_param->bd; + const TX_TYPE tx_type = txfm_param->tx_type; + const int32_t *src = cast_to_int32(input); + switch (tx_type) { + case DCT_DCT: + av1_highbd_inv_txfm2d_add_universe_avx2(input, dest, stride, tx_type, + txfm_param->tx_size, + txfm_param->eob, bd); + break; + // Assembly version doesn't support IDTX, so use C version for it. + case IDTX: + av1_inv_txfm2d_add_32x32_c(src, CONVERT_TO_SHORTPTR(dest), stride, + tx_type, bd); + break; + default: assert(0); } } + +void av1_highbd_inv_txfm_add_avx2(const tran_low_t *input, uint8_t *dest, + int stride, const TxfmParam *txfm_param) { + assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]); + const TX_SIZE tx_size = txfm_param->tx_size; + switch (tx_size) { + case TX_32X32: + av1_highbd_inv_txfm_add_32x32_avx2(input, dest, stride, txfm_param); + break; + case TX_16X16: + av1_highbd_inv_txfm_add_16x16_sse4_1(input, dest, stride, txfm_param); + break; + case TX_8X8: + av1_highbd_inv_txfm_add_8x8_sse4_1(input, dest, stride, txfm_param); + break; + case TX_4X8: + av1_highbd_inv_txfm_add_4x8(input, dest, stride, txfm_param); + break; + case TX_8X4: + av1_highbd_inv_txfm_add_8x4(input, dest, stride, txfm_param); + break; + case TX_8X16: + av1_highbd_inv_txfm_add_8x16_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X8: + av1_highbd_inv_txfm_add_16x8_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X32: + av1_highbd_inv_txfm_add_16x32(input, dest, stride, txfm_param); + break; + case TX_32X16: + av1_highbd_inv_txfm_add_32x16(input, dest, stride, txfm_param); + break; + case TX_32X64: + av1_highbd_inv_txfm_add_32x64(input, dest, stride, txfm_param); + break; + case TX_64X32: + av1_highbd_inv_txfm_add_64x32(input, dest, stride, txfm_param); + break; + case TX_4X4: + av1_highbd_inv_txfm_add_4x4_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X4: + av1_highbd_inv_txfm_add_16x4(input, dest, stride, txfm_param); + break; + case TX_4X16: + av1_highbd_inv_txfm_add_4x16(input, dest, stride, txfm_param); + break; + case TX_8X32: + av1_highbd_inv_txfm_add_8x32(input, dest, stride, txfm_param); + break; + case TX_32X8: + av1_highbd_inv_txfm_add_32x8(input, dest, stride, txfm_param); + break; + case TX_64X64: + case TX_16X64: + case TX_64X16: + av1_highbd_inv_txfm2d_add_universe_sse4_1( + input, dest, stride, txfm_param->tx_type, txfm_param->tx_size, + txfm_param->eob, txfm_param->bd); + break; + default: assert(0 && "Invalid transform size"); break; + } +} diff --git a/third_party/aom/av1/common/x86/highbd_inv_txfm_sse4.c b/third_party/aom/av1/common/x86/highbd_inv_txfm_sse4.c index 801a4133b..e29e0baf5 100644 --- a/third_party/aom/av1/common/x86/highbd_inv_txfm_sse4.c +++ b/third_party/aom/av1/common/x86/highbd_inv_txfm_sse4.c @@ -15,8 +15,60 @@ #include "config/av1_rtcd.h" #include "av1/common/av1_inv_txfm1d_cfg.h" +#include "av1/common/idct.h" +#include "av1/common/x86/av1_inv_txfm_ssse3.h" +#include "av1/common/x86/av1_txfm_sse4.h" #include "av1/common/x86/highbd_txfm_utility_sse4.h" +static INLINE __m128i highbd_clamp_epi16(__m128i u, int bd) { + const __m128i zero = _mm_setzero_si128(); + const __m128i one = _mm_set1_epi16(1); + const __m128i max = _mm_sub_epi16(_mm_slli_epi16(one, bd), one); + __m128i clamped, mask; + + mask = _mm_cmpgt_epi16(u, max); + clamped = _mm_andnot_si128(mask, u); + mask = _mm_and_si128(mask, max); + clamped = _mm_or_si128(mask, clamped); + mask = _mm_cmpgt_epi16(clamped, zero); + clamped = _mm_and_si128(clamped, mask); + + return clamped; +} + +static INLINE __m128i highbd_get_recon_8x8_sse4_1(const __m128i pred, + __m128i res0, __m128i res1, + const int bd) { + __m128i x0 = _mm_cvtepi16_epi32(pred); + __m128i x1 = _mm_cvtepi16_epi32(_mm_srli_si128(pred, 8)); + + x0 = _mm_add_epi32(res0, x0); + x1 = _mm_add_epi32(res1, x1); + x0 = _mm_packus_epi32(x0, x1); + x0 = highbd_clamp_epi16(x0, bd); + return x0; +} + +static INLINE void highbd_write_buffer_8xn_sse4_1(__m128i *in, uint16_t *output, + int stride, int flipud, + int height, const int bd) { + int j = flipud ? (height - 1) : 0; + const int step = flipud ? -1 : 1; + for (int i = 0; i < height; ++i, j += step) { + __m128i v = _mm_loadu_si128((__m128i const *)(output + i * stride)); + __m128i u = highbd_get_recon_8x8_sse4_1(v, in[j], in[j + height], bd); + + _mm_storeu_si128((__m128i *)(output + i * stride), u); + } +} + +static INLINE void load_buffer_32bit_input(const int32_t *in, int stride, + __m128i *out, int out_size) { + for (int i = 0; i < out_size; ++i) { + out[i] = _mm_loadu_si128((const __m128i *)(in + i * stride)); + } +} + static INLINE void load_buffer_4x4(const int32_t *coeff, __m128i *in) { in[0] = _mm_load_si128((const __m128i *)(coeff + 0)); in[1] = _mm_load_si128((const __m128i *)(coeff + 4)); @@ -57,18 +109,231 @@ static void addsub_shift_sse4_1(const __m128i in0, const __m128i in1, __m128i a0 = _mm_add_epi32(in0_w_offset, in1); __m128i a1 = _mm_sub_epi32(in0_w_offset, in1); + a0 = _mm_sra_epi32(a0, _mm_cvtsi32_si128(shift)); + a1 = _mm_sra_epi32(a1, _mm_cvtsi32_si128(shift)); + a0 = _mm_max_epi32(a0, *clamp_lo); a0 = _mm_min_epi32(a0, *clamp_hi); a1 = _mm_max_epi32(a1, *clamp_lo); a1 = _mm_min_epi32(a1, *clamp_hi); - a0 = _mm_sra_epi32(a0, _mm_cvtsi32_si128(shift)); - a1 = _mm_sra_epi32(a1, _mm_cvtsi32_si128(shift)); - *out0 = a0; *out1 = a1; } +static INLINE void idct32_stage4_sse4_1( + __m128i *bf1, const __m128i *cospim8, const __m128i *cospi56, + const __m128i *cospi8, const __m128i *cospim56, const __m128i *cospim40, + const __m128i *cospi24, const __m128i *cospi40, const __m128i *cospim24, + const __m128i *rounding, int bit) { + __m128i temp1, temp2; + temp1 = half_btf_sse4_1(cospim8, &bf1[17], cospi56, &bf1[30], rounding, bit); + bf1[30] = half_btf_sse4_1(cospi56, &bf1[17], cospi8, &bf1[30], rounding, bit); + bf1[17] = temp1; + + temp2 = half_btf_sse4_1(cospim56, &bf1[18], cospim8, &bf1[29], rounding, bit); + bf1[29] = + half_btf_sse4_1(cospim8, &bf1[18], cospi56, &bf1[29], rounding, bit); + bf1[18] = temp2; + + temp1 = half_btf_sse4_1(cospim40, &bf1[21], cospi24, &bf1[26], rounding, bit); + bf1[26] = + half_btf_sse4_1(cospi24, &bf1[21], cospi40, &bf1[26], rounding, bit); + bf1[21] = temp1; + + temp2 = + half_btf_sse4_1(cospim24, &bf1[22], cospim40, &bf1[25], rounding, bit); + bf1[25] = + half_btf_sse4_1(cospim40, &bf1[22], cospi24, &bf1[25], rounding, bit); + bf1[22] = temp2; +} + +static INLINE void idct32_stage5_sse4_1( + __m128i *bf1, const __m128i *cospim16, const __m128i *cospi48, + const __m128i *cospi16, const __m128i *cospim48, const __m128i *clamp_lo, + const __m128i *clamp_hi, const __m128i *rounding, int bit) { + __m128i temp1, temp2; + temp1 = half_btf_sse4_1(cospim16, &bf1[9], cospi48, &bf1[14], rounding, bit); + bf1[14] = half_btf_sse4_1(cospi48, &bf1[9], cospi16, &bf1[14], rounding, bit); + bf1[9] = temp1; + + temp2 = + half_btf_sse4_1(cospim48, &bf1[10], cospim16, &bf1[13], rounding, bit); + bf1[13] = + half_btf_sse4_1(cospim16, &bf1[10], cospi48, &bf1[13], rounding, bit); + bf1[10] = temp2; + + addsub_sse4_1(bf1[16], bf1[19], bf1 + 16, bf1 + 19, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[17], bf1[18], bf1 + 17, bf1 + 18, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[23], bf1[20], bf1 + 23, bf1 + 20, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[22], bf1[21], bf1 + 22, bf1 + 21, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[24], bf1[27], bf1 + 24, bf1 + 27, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[25], bf1[26], bf1 + 25, bf1 + 26, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[31], bf1[28], bf1 + 31, bf1 + 28, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[30], bf1[29], bf1 + 30, bf1 + 29, clamp_lo, clamp_hi); +} + +static INLINE void idct32_stage6_sse4_1( + __m128i *bf1, const __m128i *cospim32, const __m128i *cospi32, + const __m128i *cospim16, const __m128i *cospi48, const __m128i *cospi16, + const __m128i *cospim48, const __m128i *clamp_lo, const __m128i *clamp_hi, + const __m128i *rounding, int bit) { + __m128i temp1, temp2; + temp1 = half_btf_sse4_1(cospim32, &bf1[5], cospi32, &bf1[6], rounding, bit); + bf1[6] = half_btf_sse4_1(cospi32, &bf1[5], cospi32, &bf1[6], rounding, bit); + bf1[5] = temp1; + + addsub_sse4_1(bf1[8], bf1[11], bf1 + 8, bf1 + 11, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[9], bf1[10], bf1 + 9, bf1 + 10, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[15], bf1[12], bf1 + 15, bf1 + 12, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[14], bf1[13], bf1 + 14, bf1 + 13, clamp_lo, clamp_hi); + + temp1 = half_btf_sse4_1(cospim16, &bf1[18], cospi48, &bf1[29], rounding, bit); + bf1[29] = + half_btf_sse4_1(cospi48, &bf1[18], cospi16, &bf1[29], rounding, bit); + bf1[18] = temp1; + temp2 = half_btf_sse4_1(cospim16, &bf1[19], cospi48, &bf1[28], rounding, bit); + bf1[28] = + half_btf_sse4_1(cospi48, &bf1[19], cospi16, &bf1[28], rounding, bit); + bf1[19] = temp2; + temp1 = + half_btf_sse4_1(cospim48, &bf1[20], cospim16, &bf1[27], rounding, bit); + bf1[27] = + half_btf_sse4_1(cospim16, &bf1[20], cospi48, &bf1[27], rounding, bit); + bf1[20] = temp1; + temp2 = + half_btf_sse4_1(cospim48, &bf1[21], cospim16, &bf1[26], rounding, bit); + bf1[26] = + half_btf_sse4_1(cospim16, &bf1[21], cospi48, &bf1[26], rounding, bit); + bf1[21] = temp2; +} + +static INLINE void idct32_stage7_sse4_1(__m128i *bf1, const __m128i *cospim32, + const __m128i *cospi32, + const __m128i *clamp_lo, + const __m128i *clamp_hi, + const __m128i *rounding, int bit) { + __m128i temp1, temp2; + addsub_sse4_1(bf1[0], bf1[7], bf1 + 0, bf1 + 7, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[1], bf1[6], bf1 + 1, bf1 + 6, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[2], bf1[5], bf1 + 2, bf1 + 5, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[3], bf1[4], bf1 + 3, bf1 + 4, clamp_lo, clamp_hi); + + temp1 = half_btf_sse4_1(cospim32, &bf1[10], cospi32, &bf1[13], rounding, bit); + bf1[13] = + half_btf_sse4_1(cospi32, &bf1[10], cospi32, &bf1[13], rounding, bit); + bf1[10] = temp1; + temp2 = half_btf_sse4_1(cospim32, &bf1[11], cospi32, &bf1[12], rounding, bit); + bf1[12] = + half_btf_sse4_1(cospi32, &bf1[11], cospi32, &bf1[12], rounding, bit); + bf1[11] = temp2; + + addsub_sse4_1(bf1[16], bf1[23], bf1 + 16, bf1 + 23, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[17], bf1[22], bf1 + 17, bf1 + 22, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[18], bf1[21], bf1 + 18, bf1 + 21, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[19], bf1[20], bf1 + 19, bf1 + 20, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[31], bf1[24], bf1 + 31, bf1 + 24, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[30], bf1[25], bf1 + 30, bf1 + 25, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[29], bf1[26], bf1 + 29, bf1 + 26, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[28], bf1[27], bf1 + 28, bf1 + 27, clamp_lo, clamp_hi); +} + +static INLINE void idct32_stage8_sse4_1(__m128i *bf1, const __m128i *cospim32, + const __m128i *cospi32, + const __m128i *clamp_lo, + const __m128i *clamp_hi, + const __m128i *rounding, int bit) { + __m128i temp1, temp2; + addsub_sse4_1(bf1[0], bf1[15], bf1 + 0, bf1 + 15, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[1], bf1[14], bf1 + 1, bf1 + 14, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[2], bf1[13], bf1 + 2, bf1 + 13, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[3], bf1[12], bf1 + 3, bf1 + 12, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[4], bf1[11], bf1 + 4, bf1 + 11, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[5], bf1[10], bf1 + 5, bf1 + 10, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[6], bf1[9], bf1 + 6, bf1 + 9, clamp_lo, clamp_hi); + addsub_sse4_1(bf1[7], bf1[8], bf1 + 7, bf1 + 8, clamp_lo, clamp_hi); + + temp1 = half_btf_sse4_1(cospim32, &bf1[20], cospi32, &bf1[27], rounding, bit); + bf1[27] = + half_btf_sse4_1(cospi32, &bf1[20], cospi32, &bf1[27], rounding, bit); + bf1[20] = temp1; + temp2 = half_btf_sse4_1(cospim32, &bf1[21], cospi32, &bf1[26], rounding, bit); + bf1[26] = + half_btf_sse4_1(cospi32, &bf1[21], cospi32, &bf1[26], rounding, bit); + bf1[21] = temp2; + temp1 = half_btf_sse4_1(cospim32, &bf1[22], cospi32, &bf1[25], rounding, bit); + bf1[25] = + half_btf_sse4_1(cospi32, &bf1[22], cospi32, &bf1[25], rounding, bit); + bf1[22] = temp1; + temp2 = half_btf_sse4_1(cospim32, &bf1[23], cospi32, &bf1[24], rounding, bit); + bf1[24] = + half_btf_sse4_1(cospi32, &bf1[23], cospi32, &bf1[24], rounding, bit); + bf1[23] = temp2; +} + +static INLINE void idct32_stage9_sse4_1(__m128i *bf1, __m128i *out, + const int do_cols, const int bd, + const int out_shift, + const int log_range) { + if (do_cols) { + addsub_no_clamp_sse4_1(bf1[0], bf1[31], out + 0, out + 31); + addsub_no_clamp_sse4_1(bf1[1], bf1[30], out + 1, out + 30); + addsub_no_clamp_sse4_1(bf1[2], bf1[29], out + 2, out + 29); + addsub_no_clamp_sse4_1(bf1[3], bf1[28], out + 3, out + 28); + addsub_no_clamp_sse4_1(bf1[4], bf1[27], out + 4, out + 27); + addsub_no_clamp_sse4_1(bf1[5], bf1[26], out + 5, out + 26); + addsub_no_clamp_sse4_1(bf1[6], bf1[25], out + 6, out + 25); + addsub_no_clamp_sse4_1(bf1[7], bf1[24], out + 7, out + 24); + addsub_no_clamp_sse4_1(bf1[8], bf1[23], out + 8, out + 23); + addsub_no_clamp_sse4_1(bf1[9], bf1[22], out + 9, out + 22); + addsub_no_clamp_sse4_1(bf1[10], bf1[21], out + 10, out + 21); + addsub_no_clamp_sse4_1(bf1[11], bf1[20], out + 11, out + 20); + addsub_no_clamp_sse4_1(bf1[12], bf1[19], out + 12, out + 19); + addsub_no_clamp_sse4_1(bf1[13], bf1[18], out + 13, out + 18); + addsub_no_clamp_sse4_1(bf1[14], bf1[17], out + 14, out + 17); + addsub_no_clamp_sse4_1(bf1[15], bf1[16], out + 15, out + 16); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + addsub_shift_sse4_1(bf1[0], bf1[31], out + 0, out + 31, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[1], bf1[30], out + 1, out + 30, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[2], bf1[29], out + 2, out + 29, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[3], bf1[28], out + 3, out + 28, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[4], bf1[27], out + 4, out + 27, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[5], bf1[26], out + 5, out + 26, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[6], bf1[25], out + 6, out + 25, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[7], bf1[24], out + 7, out + 24, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[8], bf1[23], out + 8, out + 23, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[9], bf1[22], out + 9, out + 22, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[10], bf1[21], out + 10, out + 21, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[11], bf1[20], out + 11, out + 20, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[12], bf1[19], out + 12, out + 19, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[13], bf1[18], out + 13, out + 18, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[14], bf1[17], out + 14, out + 17, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf1[15], bf1[16], out + 15, out + 16, &clamp_lo_out, + &clamp_hi_out, out_shift); + } +} + static void neg_shift_sse4_1(const __m128i in0, const __m128i in1, __m128i *out0, __m128i *out1, const __m128i *clamp_lo, const __m128i *clamp_hi, @@ -77,14 +342,14 @@ static void neg_shift_sse4_1(const __m128i in0, const __m128i in1, __m128i a0 = _mm_add_epi32(offset, in0); __m128i a1 = _mm_sub_epi32(offset, in1); + a0 = _mm_sra_epi32(a0, _mm_cvtsi32_si128(shift)); + a1 = _mm_sra_epi32(a1, _mm_cvtsi32_si128(shift)); + a0 = _mm_max_epi32(a0, *clamp_lo); a0 = _mm_min_epi32(a0, *clamp_hi); a1 = _mm_max_epi32(a1, *clamp_lo); a1 = _mm_min_epi32(a1, *clamp_hi); - a0 = _mm_sra_epi32(a0, _mm_cvtsi32_si128(shift)); - a1 = _mm_sra_epi32(a1, _mm_cvtsi32_si128(shift)); - *out0 = a0; *out1 = a1; } @@ -96,9 +361,6 @@ static void idct4x4_sse4_1(__m128i *in, int bit, int do_cols, int bd) { const __m128i cospi16 = _mm_set1_epi32(cospi[16]); const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); - const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); - const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); - const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); __m128i u0, u1, u2, u3; __m128i v0, v1, v2, v3, x, y; @@ -135,11 +397,19 @@ static void idct4x4_sse4_1(__m128i *in, int bit, int do_cols, int bd) { v3 = _mm_add_epi32(v3, rnding); v3 = _mm_srai_epi32(v3, bit); - addsub_sse4_1(v0, v3, in + 0, in + 3, &clamp_lo, &clamp_hi); - addsub_sse4_1(v1, v2, in + 1, in + 2, &clamp_lo, &clamp_hi); + if (do_cols) { + addsub_no_clamp_sse4_1(v0, v3, in + 0, in + 3); + addsub_no_clamp_sse4_1(v1, v2, in + 1, in + 2); + } else { + const int log_range = AOMMAX(16, bd + 6); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + addsub_sse4_1(v0, v3, in + 0, in + 3, &clamp_lo, &clamp_hi); + addsub_sse4_1(v1, v2, in + 1, in + 2, &clamp_lo, &clamp_hi); + } } -static void iadst4x4_sse4_1(__m128i *in, int bit) { +static void iadst4x4_sse4_1(__m128i *in, int bit, int do_cols, int bd) { const int32_t *sinpi = sinpi_arr(bit); const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); const __m128i sinpi1 = _mm_set1_epi32((int)sinpi[1]); @@ -197,6 +467,21 @@ static void iadst4x4_sse4_1(__m128i *in, int bit) { u3 = _mm_add_epi32(u3, rnding); u3 = _mm_srai_epi32(u3, bit); + if (!do_cols) { + const int log_range = AOMMAX(16, bd + 6); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + + u0 = _mm_max_epi32(u0, clamp_lo); + u0 = _mm_min_epi32(u0, clamp_hi); + u1 = _mm_max_epi32(u1, clamp_lo); + u1 = _mm_min_epi32(u1, clamp_hi); + u2 = _mm_max_epi32(u2, clamp_lo); + u2 = _mm_min_epi32(u2, clamp_hi); + u3 = _mm_max_epi32(u3, clamp_lo); + u3 = _mm_min_epi32(u3, clamp_hi); + } + in[0] = u0; in[1] = u1; in[2] = u2; @@ -217,22 +502,6 @@ static INLINE void round_shift_4x4(__m128i *in, int shift) { in[3] = _mm_srai_epi32(in[3], shift); } -static INLINE __m128i highbd_clamp_epi16(__m128i u, int bd) { - const __m128i zero = _mm_setzero_si128(); - const __m128i one = _mm_set1_epi16(1); - const __m128i max = _mm_sub_epi16(_mm_slli_epi16(one, bd), one); - __m128i clamped, mask; - - mask = _mm_cmpgt_epi16(u, max); - clamped = _mm_andnot_si128(mask, u); - mask = _mm_and_si128(mask, max); - clamped = _mm_or_si128(mask, clamped); - mask = _mm_cmpgt_epi16(clamped, zero); - clamped = _mm_and_si128(clamped, mask); - - return clamped; -} - static void write_buffer_4x4(__m128i *in, uint16_t *output, int stride, int fliplr, int flipud, int shift, int bd) { const __m128i zero = _mm_setzero_si128(); @@ -304,49 +573,49 @@ void av1_inv_txfm2d_add_4x4_sse4_1(const int32_t *coeff, uint16_t *output, case ADST_DCT: load_buffer_4x4(coeff, in); idct4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); - iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd); break; case DCT_ADST: load_buffer_4x4(coeff, in); - iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); idct4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd); break; case ADST_ADST: load_buffer_4x4(coeff, in); - iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx]); - iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); + iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 0, 0, -shift[1], bd); break; case FLIPADST_DCT: load_buffer_4x4(coeff, in); idct4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); - iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd); break; case DCT_FLIPADST: load_buffer_4x4(coeff, in); - iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); idct4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 1, 0, -shift[1], bd); break; case FLIPADST_FLIPADST: load_buffer_4x4(coeff, in); - iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx]); - iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); + iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 1, 1, -shift[1], bd); break; case ADST_FLIPADST: load_buffer_4x4(coeff, in); - iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx]); - iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); + iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 1, 0, -shift[1], bd); break; case FLIPADST_ADST: load_buffer_4x4(coeff, in); - iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx]); - iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx]); + iadst4x4_sse4_1(in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd); + iadst4x4_sse4_1(in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd); write_buffer_4x4(in, output, stride, 0, 1, -shift[1], bd); break; default: assert(0); @@ -482,14 +751,19 @@ static void idct8x8_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, addsub_no_clamp_sse4_1(u2, u5, out + 2 * 2 + col, out + 5 * 2 + col); addsub_no_clamp_sse4_1(u3, u4, out + 3 * 2 + col, out + 4 * 2 + col); } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); addsub_shift_sse4_1(u0, u7, out + 0 * 2 + col, out + 7 * 2 + col, - &clamp_lo, &clamp_hi, out_shift); + &clamp_lo_out, &clamp_hi_out, out_shift); addsub_shift_sse4_1(u1, u6, out + 1 * 2 + col, out + 6 * 2 + col, - &clamp_lo, &clamp_hi, out_shift); + &clamp_lo_out, &clamp_hi_out, out_shift); addsub_shift_sse4_1(u2, u5, out + 2 * 2 + col, out + 5 * 2 + col, - &clamp_lo, &clamp_hi, out_shift); + &clamp_lo_out, &clamp_hi_out, out_shift); addsub_shift_sse4_1(u3, u4, out + 3 * 2 + col, out + 4 * 2 + col, - &clamp_lo, &clamp_hi, out_shift); + &clamp_lo_out, &clamp_hi_out, out_shift); } } } @@ -651,14 +925,18 @@ static void iadst8x8_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, out[12] = u[5]; out[14] = _mm_sub_epi32(kZero, u[1]); } else { - neg_shift_sse4_1(u[0], u[4], out + 0, out + 2, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(u[6], u[2], out + 4, out + 6, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(u[3], u[7], out + 8, out + 10, &clamp_lo, &clamp_hi, + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(u[0], u[4], out + 0, out + 2, &clamp_lo_out, &clamp_hi_out, out_shift); - neg_shift_sse4_1(u[5], u[1], out + 12, out + 14, &clamp_lo, &clamp_hi, + neg_shift_sse4_1(u[6], u[2], out + 4, out + 6, &clamp_lo_out, &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[3], u[7], out + 8, out + 10, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[5], u[1], out + 12, out + 14, &clamp_lo_out, + &clamp_hi_out, out_shift); } // Odd 8 points: 1, 3, ..., 15 @@ -796,14 +1074,18 @@ static void iadst8x8_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, out[13] = u[5]; out[15] = _mm_sub_epi32(kZero, u[1]); } else { - neg_shift_sse4_1(u[0], u[4], out + 1, out + 3, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(u[6], u[2], out + 5, out + 7, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(u[3], u[7], out + 9, out + 11, &clamp_lo, &clamp_hi, + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(u[0], u[4], out + 1, out + 3, &clamp_lo_out, &clamp_hi_out, out_shift); - neg_shift_sse4_1(u[5], u[1], out + 13, out + 15, &clamp_lo, &clamp_hi, + neg_shift_sse4_1(u[6], u[2], out + 5, out + 7, &clamp_lo_out, &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[3], u[7], out + 9, out + 11, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[5], u[1], out + 13, out + 15, &clamp_lo_out, + &clamp_hi_out, out_shift); } } @@ -976,64 +1258,1141 @@ void av1_inv_txfm2d_add_8x8_sse4_1(const int32_t *coeff, uint16_t *output, } } -// 16x16 -static void load_buffer_16x16(const int32_t *coeff, __m128i *in) { - int i; - for (i = 0; i < 64; ++i) { - in[i] = _mm_load_si128((const __m128i *)(coeff + (i << 2))); +static void idct8x8_low1_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + __m128i x; + + // stage 0 + // stage 1 + // stage 2 + // stage 3 + x = _mm_mullo_epi32(in[0], cospi32); + x = _mm_add_epi32(x, rnding); + x = _mm_srai_epi32(x, bit); + + // stage 4 + // stage 5 + if (!do_cols) { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + __m128i offset = _mm_set1_epi32((1 << out_shift) >> 1); + x = _mm_add_epi32(x, offset); + x = _mm_sra_epi32(x, _mm_cvtsi32_si128(out_shift)); + x = _mm_max_epi32(x, clamp_lo_out); + x = _mm_min_epi32(x, clamp_hi_out); } + + out[0] = x; + out[1] = x; + out[2] = x; + out[3] = x; + out[4] = x; + out[5] = x; + out[6] = x; + out[7] = x; } -static void assign_8x8_input_from_16x16(const __m128i *in, __m128i *in8x8, - int col) { - int i; - for (i = 0; i < 16; i += 2) { - in8x8[i] = in[col]; - in8x8[i + 1] = in[col + 1]; - col += 4; +static void idct8x8_new_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospim8 = _mm_set1_epi32(-cospi[8]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i u0, u1, u2, u3, u4, u5, u6, u7; + __m128i v0, v1, v2, v3, v4, v5, v6, v7; + __m128i x, y; + + // stage 0 + // stage 1 + // stage 2 + u0 = in[0]; + u1 = in[4]; + u2 = in[2]; + u3 = in[6]; + + x = _mm_mullo_epi32(in[1], cospi56); + y = _mm_mullo_epi32(in[7], cospim8); + u4 = _mm_add_epi32(x, y); + u4 = _mm_add_epi32(u4, rnding); + u4 = _mm_srai_epi32(u4, bit); + + x = _mm_mullo_epi32(in[1], cospi8); + y = _mm_mullo_epi32(in[7], cospi56); + u7 = _mm_add_epi32(x, y); + u7 = _mm_add_epi32(u7, rnding); + u7 = _mm_srai_epi32(u7, bit); + + x = _mm_mullo_epi32(in[5], cospi24); + y = _mm_mullo_epi32(in[3], cospim40); + u5 = _mm_add_epi32(x, y); + u5 = _mm_add_epi32(u5, rnding); + u5 = _mm_srai_epi32(u5, bit); + + x = _mm_mullo_epi32(in[5], cospi40); + y = _mm_mullo_epi32(in[3], cospi24); + u6 = _mm_add_epi32(x, y); + u6 = _mm_add_epi32(u6, rnding); + u6 = _mm_srai_epi32(u6, bit); + + // stage 3 + x = _mm_mullo_epi32(u0, cospi32); + y = _mm_mullo_epi32(u1, cospi32); + v0 = _mm_add_epi32(x, y); + v0 = _mm_add_epi32(v0, rnding); + v0 = _mm_srai_epi32(v0, bit); + + v1 = _mm_sub_epi32(x, y); + v1 = _mm_add_epi32(v1, rnding); + v1 = _mm_srai_epi32(v1, bit); + + x = _mm_mullo_epi32(u2, cospi48); + y = _mm_mullo_epi32(u3, cospim16); + v2 = _mm_add_epi32(x, y); + v2 = _mm_add_epi32(v2, rnding); + v2 = _mm_srai_epi32(v2, bit); + + x = _mm_mullo_epi32(u2, cospi16); + y = _mm_mullo_epi32(u3, cospi48); + v3 = _mm_add_epi32(x, y); + v3 = _mm_add_epi32(v3, rnding); + v3 = _mm_srai_epi32(v3, bit); + + addsub_sse4_1(u4, u5, &v4, &v5, &clamp_lo, &clamp_hi); + addsub_sse4_1(u7, u6, &v7, &v6, &clamp_lo, &clamp_hi); + + // stage 4 + addsub_sse4_1(v0, v3, &u0, &u3, &clamp_lo, &clamp_hi); + addsub_sse4_1(v1, v2, &u1, &u2, &clamp_lo, &clamp_hi); + u4 = v4; + u7 = v7; + + x = _mm_mullo_epi32(v5, cospi32); + y = _mm_mullo_epi32(v6, cospi32); + u6 = _mm_add_epi32(y, x); + u6 = _mm_add_epi32(u6, rnding); + u6 = _mm_srai_epi32(u6, bit); + + u5 = _mm_sub_epi32(y, x); + u5 = _mm_add_epi32(u5, rnding); + u5 = _mm_srai_epi32(u5, bit); + + // stage 5 + if (do_cols) { + addsub_no_clamp_sse4_1(u0, u7, out + 0, out + 7); + addsub_no_clamp_sse4_1(u1, u6, out + 1, out + 6); + addsub_no_clamp_sse4_1(u2, u5, out + 2, out + 5); + addsub_no_clamp_sse4_1(u3, u4, out + 3, out + 4); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + addsub_shift_sse4_1(u0, u7, out + 0, out + 7, &clamp_lo_out, &clamp_hi_out, + out_shift); + addsub_shift_sse4_1(u1, u6, out + 1, out + 6, &clamp_lo_out, &clamp_hi_out, + out_shift); + addsub_shift_sse4_1(u2, u5, out + 2, out + 5, &clamp_lo_out, &clamp_hi_out, + out_shift); + addsub_shift_sse4_1(u3, u4, out + 3, out + 4, &clamp_lo_out, &clamp_hi_out, + out_shift); } } -static void swap_addr(uint16_t **output1, uint16_t **output2) { - uint16_t *tmp; - tmp = *output1; - *output1 = *output2; - *output2 = tmp; +static void iadst8x8_low1_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const __m128i kZero = _mm_setzero_si128(); + __m128i u[8], x; + + // stage 0 + // stage 1 + // stage 2 + + x = _mm_mullo_epi32(in[0], cospi60); + u[0] = _mm_add_epi32(x, rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + x = _mm_mullo_epi32(in[0], cospi4); + u[1] = _mm_sub_epi32(kZero, x); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + // stage 3 + // stage 4 + __m128i temp1, temp2; + temp1 = _mm_mullo_epi32(u[0], cospi16); + x = _mm_mullo_epi32(u[1], cospi48); + temp1 = _mm_add_epi32(temp1, x); + temp1 = _mm_add_epi32(temp1, rnding); + temp1 = _mm_srai_epi32(temp1, bit); + u[4] = temp1; + + temp2 = _mm_mullo_epi32(u[0], cospi48); + x = _mm_mullo_epi32(u[1], cospi16); + u[5] = _mm_sub_epi32(temp2, x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + // stage 5 + // stage 6 + temp1 = _mm_mullo_epi32(u[0], cospi32); + x = _mm_mullo_epi32(u[1], cospi32); + u[2] = _mm_add_epi32(temp1, x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_sub_epi32(temp1, x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + temp1 = _mm_mullo_epi32(u[4], cospi32); + x = _mm_mullo_epi32(u[5], cospi32); + u[6] = _mm_add_epi32(temp1, x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_sub_epi32(temp1, x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 7 + if (do_cols) { + out[0] = u[0]; + out[1] = _mm_sub_epi32(kZero, u[4]); + out[2] = u[6]; + out[3] = _mm_sub_epi32(kZero, u[2]); + out[4] = u[3]; + out[5] = _mm_sub_epi32(kZero, u[7]); + out[6] = u[5]; + out[7] = _mm_sub_epi32(kZero, u[1]); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(u[0], u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out, + out_shift); + neg_shift_sse4_1(u[6], u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out, + out_shift); + neg_shift_sse4_1(u[3], u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out, + out_shift); + neg_shift_sse4_1(u[5], u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out, + out_shift); + } } -static void write_buffer_16x16(__m128i *in, uint16_t *output, int stride, - int fliplr, int flipud, int shift, int bd) { - __m128i in8x8[16]; - uint16_t *leftUp = &output[0]; - uint16_t *rightUp = &output[8]; - uint16_t *leftDown = &output[8 * stride]; - uint16_t *rightDown = &output[8 * stride + 8]; +static void iadst8x8_new_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi36 = _mm_set1_epi32(cospi[36]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi52 = _mm_set1_epi32(cospi[52]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const __m128i kZero = _mm_setzero_si128(); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i u[8], v[8], x; + + // stage 0 + // stage 1 + // stage 2 - if (fliplr) { - swap_addr(&leftUp, &rightUp); - swap_addr(&leftDown, &rightDown); + u[0] = _mm_mullo_epi32(in[7], cospi4); + x = _mm_mullo_epi32(in[0], cospi60); + u[0] = _mm_add_epi32(u[0], x); + u[0] = _mm_add_epi32(u[0], rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + u[1] = _mm_mullo_epi32(in[7], cospi60); + x = _mm_mullo_epi32(in[0], cospi4); + u[1] = _mm_sub_epi32(u[1], x); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + // (2) + u[2] = _mm_mullo_epi32(in[5], cospi20); + x = _mm_mullo_epi32(in[2], cospi44); + u[2] = _mm_add_epi32(u[2], x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_mullo_epi32(in[5], cospi44); + x = _mm_mullo_epi32(in[2], cospi20); + u[3] = _mm_sub_epi32(u[3], x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + // (3) + u[4] = _mm_mullo_epi32(in[3], cospi36); + x = _mm_mullo_epi32(in[4], cospi28); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[5] = _mm_mullo_epi32(in[3], cospi28); + x = _mm_mullo_epi32(in[4], cospi36); + u[5] = _mm_sub_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + // (4) + u[6] = _mm_mullo_epi32(in[1], cospi52); + x = _mm_mullo_epi32(in[6], cospi12); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_mullo_epi32(in[1], cospi12); + x = _mm_mullo_epi32(in[6], cospi52); + u[7] = _mm_sub_epi32(u[7], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 3 + addsub_sse4_1(u[0], u[4], &v[0], &v[4], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[5], &v[1], &v[5], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[2], u[6], &v[2], &v[6], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[3], u[7], &v[3], &v[7], &clamp_lo, &clamp_hi); + + // stage 4 + u[0] = v[0]; + u[1] = v[1]; + u[2] = v[2]; + u[3] = v[3]; + + u[4] = _mm_mullo_epi32(v[4], cospi16); + x = _mm_mullo_epi32(v[5], cospi48); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + u[5] = _mm_mullo_epi32(v[4], cospi48); + x = _mm_mullo_epi32(v[5], cospi16); + u[5] = _mm_sub_epi32(u[5], x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + u[6] = _mm_mullo_epi32(v[6], cospim48); + x = _mm_mullo_epi32(v[7], cospi16); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_mullo_epi32(v[6], cospi16); + x = _mm_mullo_epi32(v[7], cospim48); + u[7] = _mm_sub_epi32(u[7], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 5 + addsub_sse4_1(u[0], u[2], &v[0], &v[2], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[3], &v[1], &v[3], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[4], u[6], &v[4], &v[6], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[5], u[7], &v[5], &v[7], &clamp_lo, &clamp_hi); + + // stage 6 + u[0] = v[0]; + u[1] = v[1]; + u[4] = v[4]; + u[5] = v[5]; + + v[0] = _mm_mullo_epi32(v[2], cospi32); + x = _mm_mullo_epi32(v[3], cospi32); + u[2] = _mm_add_epi32(v[0], x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_sub_epi32(v[0], x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + v[0] = _mm_mullo_epi32(v[6], cospi32); + x = _mm_mullo_epi32(v[7], cospi32); + u[6] = _mm_add_epi32(v[0], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_sub_epi32(v[0], x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + // stage 7 + if (do_cols) { + out[0] = u[0]; + out[1] = _mm_sub_epi32(kZero, u[4]); + out[2] = u[6]; + out[3] = _mm_sub_epi32(kZero, u[2]); + out[4] = u[3]; + out[5] = _mm_sub_epi32(kZero, u[7]); + out[6] = u[5]; + out[7] = _mm_sub_epi32(kZero, u[1]); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(u[0], u[4], out + 0, out + 1, &clamp_lo_out, &clamp_hi_out, + out_shift); + neg_shift_sse4_1(u[6], u[2], out + 2, out + 3, &clamp_lo_out, &clamp_hi_out, + out_shift); + neg_shift_sse4_1(u[3], u[7], out + 4, out + 5, &clamp_lo_out, &clamp_hi_out, + out_shift); + neg_shift_sse4_1(u[5], u[1], out + 6, out + 7, &clamp_lo_out, &clamp_hi_out, + out_shift); } +} - if (flipud) { - swap_addr(&leftUp, &leftDown); - swap_addr(&rightUp, &rightDown); +static void idct16x16_low1_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + + { + // stage 0 + // stage 1 + // stage 2 + // stage 3 + // stage 4 + in[0] = _mm_mullo_epi32(in[0], cospi32); + in[0] = _mm_add_epi32(in[0], rnding); + in[0] = _mm_srai_epi32(in[0], bit); + + // stage 5 + // stage 6 + // stage 7 + if (do_cols) { + in[0] = _mm_max_epi32(in[0], clamp_lo); + in[0] = _mm_min_epi32(in[0], clamp_hi); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + __m128i offset = _mm_set1_epi32((1 << out_shift) >> 1); + in[0] = _mm_add_epi32(in[0], offset); + in[0] = _mm_sra_epi32(in[0], _mm_cvtsi32_si128(out_shift)); + in[0] = _mm_max_epi32(in[0], clamp_lo_out); + in[0] = _mm_min_epi32(in[0], clamp_hi_out); + } + + out[0] = in[0]; + out[1] = in[0]; + out[2] = in[0]; + out[3] = in[0]; + out[4] = in[0]; + out[5] = in[0]; + out[6] = in[0]; + out[7] = in[0]; + out[8] = in[0]; + out[9] = in[0]; + out[10] = in[0]; + out[11] = in[0]; + out[12] = in[0]; + out[13] = in[0]; + out[14] = in[0]; + out[15] = in[0]; } +} - // Left-up quarter - assign_8x8_input_from_16x16(in, in8x8, 0); - write_buffer_8x8(in8x8, leftUp, stride, fliplr, flipud, shift, bd); +static void idct16x16_low8_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospim36 = _mm_set1_epi32(-cospi[36]); + const __m128i cospim52 = _mm_set1_epi32(-cospi[52]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i u[16], x, y; + + { + // stage 0 + // stage 1 + u[0] = in[0]; + u[2] = in[4]; + u[4] = in[2]; + u[6] = in[6]; + u[8] = in[1]; + u[10] = in[5]; + u[12] = in[3]; + u[14] = in[7]; + + // stage 2 + u[15] = half_btf_0_sse4_1(&cospi4, &u[8], &rnding, bit); + u[8] = half_btf_0_sse4_1(&cospi60, &u[8], &rnding, bit); + + u[9] = half_btf_0_sse4_1(&cospim36, &u[14], &rnding, bit); + u[14] = half_btf_0_sse4_1(&cospi28, &u[14], &rnding, bit); + + u[13] = half_btf_0_sse4_1(&cospi20, &u[10], &rnding, bit); + u[10] = half_btf_0_sse4_1(&cospi44, &u[10], &rnding, bit); + + u[11] = half_btf_0_sse4_1(&cospim52, &u[12], &rnding, bit); + u[12] = half_btf_0_sse4_1(&cospi12, &u[12], &rnding, bit); + + // stage 3 + u[7] = half_btf_0_sse4_1(&cospi8, &u[4], &rnding, bit); + u[4] = half_btf_0_sse4_1(&cospi56, &u[4], &rnding, bit); + u[5] = half_btf_0_sse4_1(&cospim40, &u[6], &rnding, bit); + u[6] = half_btf_0_sse4_1(&cospi24, &u[6], &rnding, bit); + + addsub_sse4_1(u[8], u[9], &u[8], &u[9], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[11], u[10], &u[11], &u[10], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[12], u[13], &u[12], &u[13], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[15], u[14], &u[15], &u[14], &clamp_lo, &clamp_hi); + + // stage 4 + x = _mm_mullo_epi32(u[0], cospi32); + u[0] = _mm_add_epi32(x, rnding); + u[0] = _mm_srai_epi32(u[0], bit); + u[1] = u[0]; - // Right-up quarter - assign_8x8_input_from_16x16(in, in8x8, 2); - write_buffer_8x8(in8x8, rightUp, stride, fliplr, flipud, shift, bd); + u[3] = half_btf_0_sse4_1(&cospi16, &u[2], &rnding, bit); + u[2] = half_btf_0_sse4_1(&cospi48, &u[2], &rnding, bit); - // Left-down quarter - assign_8x8_input_from_16x16(in, in8x8, 32); - write_buffer_8x8(in8x8, leftDown, stride, fliplr, flipud, shift, bd); + addsub_sse4_1(u[4], u[5], &u[4], &u[5], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[7], u[6], &u[7], &u[6], &clamp_lo, &clamp_hi); - // Right-down quarter - assign_8x8_input_from_16x16(in, in8x8, 34); - write_buffer_8x8(in8x8, rightDown, stride, fliplr, flipud, shift, bd); + x = half_btf_sse4_1(&cospim16, &u[9], &cospi48, &u[14], &rnding, bit); + u[14] = half_btf_sse4_1(&cospi48, &u[9], &cospi16, &u[14], &rnding, bit); + u[9] = x; + y = half_btf_sse4_1(&cospim48, &u[10], &cospim16, &u[13], &rnding, bit); + u[13] = half_btf_sse4_1(&cospim16, &u[10], &cospi48, &u[13], &rnding, bit); + u[10] = y; + + // stage 5 + addsub_sse4_1(u[0], u[3], &u[0], &u[3], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[2], &u[1], &u[2], &clamp_lo, &clamp_hi); + + x = _mm_mullo_epi32(u[5], cospi32); + y = _mm_mullo_epi32(u[6], cospi32); + u[5] = _mm_sub_epi32(y, x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + u[6] = _mm_add_epi32(y, x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + addsub_sse4_1(u[8], u[11], &u[8], &u[11], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[9], u[10], &u[9], &u[10], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[15], u[12], &u[15], &u[12], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[14], u[13], &u[14], &u[13], &clamp_lo, &clamp_hi); + + // stage 6 + addsub_sse4_1(u[0], u[7], &u[0], &u[7], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[6], &u[1], &u[6], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[2], u[5], &u[2], &u[5], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[3], u[4], &u[3], &u[4], &clamp_lo, &clamp_hi); + + x = _mm_mullo_epi32(u[10], cospi32); + y = _mm_mullo_epi32(u[13], cospi32); + u[10] = _mm_sub_epi32(y, x); + u[10] = _mm_add_epi32(u[10], rnding); + u[10] = _mm_srai_epi32(u[10], bit); + + u[13] = _mm_add_epi32(x, y); + u[13] = _mm_add_epi32(u[13], rnding); + u[13] = _mm_srai_epi32(u[13], bit); + + x = _mm_mullo_epi32(u[11], cospi32); + y = _mm_mullo_epi32(u[12], cospi32); + u[11] = _mm_sub_epi32(y, x); + u[11] = _mm_add_epi32(u[11], rnding); + u[11] = _mm_srai_epi32(u[11], bit); + + u[12] = _mm_add_epi32(x, y); + u[12] = _mm_add_epi32(u[12], rnding); + u[12] = _mm_srai_epi32(u[12], bit); + // stage 7 + if (do_cols) { + addsub_no_clamp_sse4_1(u[0], u[15], out + 0, out + 15); + addsub_no_clamp_sse4_1(u[1], u[14], out + 1, out + 14); + addsub_no_clamp_sse4_1(u[2], u[13], out + 2, out + 13); + addsub_no_clamp_sse4_1(u[3], u[12], out + 3, out + 12); + addsub_no_clamp_sse4_1(u[4], u[11], out + 4, out + 11); + addsub_no_clamp_sse4_1(u[5], u[10], out + 5, out + 10); + addsub_no_clamp_sse4_1(u[6], u[9], out + 6, out + 9); + addsub_no_clamp_sse4_1(u[7], u[8], out + 7, out + 8); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + addsub_shift_sse4_1(u[0], u[15], out + 0, out + 15, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[1], u[14], out + 1, out + 14, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[2], u[13], out + 2, out + 13, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[3], u[12], out + 3, out + 12, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[4], u[11], out + 4, out + 11, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[5], u[10], out + 5, out + 10, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[6], u[9], out + 6, out + 9, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(u[7], u[8], out + 7, out + 8, &clamp_lo_out, + &clamp_hi_out, out_shift); + } + } +} + +static void iadst16x16_low1_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const __m128i zero = _mm_setzero_si128(); + __m128i v[16], x, y, temp1, temp2; + + // Calculate the column 0, 1, 2, 3 + { + // stage 0 + // stage 1 + // stage 2 + x = _mm_mullo_epi32(in[0], cospi62); + v[0] = _mm_add_epi32(x, rnding); + v[0] = _mm_srai_epi32(v[0], bit); + + x = _mm_mullo_epi32(in[0], cospi2); + v[1] = _mm_sub_epi32(zero, x); + v[1] = _mm_add_epi32(v[1], rnding); + v[1] = _mm_srai_epi32(v[1], bit); + + // stage 3 + v[8] = v[0]; + v[9] = v[1]; + + // stage 4 + temp1 = _mm_mullo_epi32(v[8], cospi8); + x = _mm_mullo_epi32(v[9], cospi56); + temp1 = _mm_add_epi32(temp1, x); + temp1 = _mm_add_epi32(temp1, rnding); + temp1 = _mm_srai_epi32(temp1, bit); + + temp2 = _mm_mullo_epi32(v[8], cospi56); + x = _mm_mullo_epi32(v[9], cospi8); + temp2 = _mm_sub_epi32(temp2, x); + temp2 = _mm_add_epi32(temp2, rnding); + temp2 = _mm_srai_epi32(temp2, bit); + v[8] = temp1; + v[9] = temp2; + + // stage 5 + v[4] = v[0]; + v[5] = v[1]; + v[12] = v[8]; + v[13] = v[9]; + + // stage 6 + temp1 = _mm_mullo_epi32(v[4], cospi16); + x = _mm_mullo_epi32(v[5], cospi48); + temp1 = _mm_add_epi32(temp1, x); + temp1 = _mm_add_epi32(temp1, rnding); + temp1 = _mm_srai_epi32(temp1, bit); + + temp2 = _mm_mullo_epi32(v[4], cospi48); + x = _mm_mullo_epi32(v[5], cospi16); + temp2 = _mm_sub_epi32(temp2, x); + temp2 = _mm_add_epi32(temp2, rnding); + temp2 = _mm_srai_epi32(temp2, bit); + v[4] = temp1; + v[5] = temp2; + + temp1 = _mm_mullo_epi32(v[12], cospi16); + x = _mm_mullo_epi32(v[13], cospi48); + temp1 = _mm_add_epi32(temp1, x); + temp1 = _mm_add_epi32(temp1, rnding); + temp1 = _mm_srai_epi32(temp1, bit); + + temp2 = _mm_mullo_epi32(v[12], cospi48); + x = _mm_mullo_epi32(v[13], cospi16); + temp2 = _mm_sub_epi32(temp2, x); + temp2 = _mm_add_epi32(temp2, rnding); + temp2 = _mm_srai_epi32(temp2, bit); + v[12] = temp1; + v[13] = temp2; + + // stage 7 + v[2] = v[0]; + v[3] = v[1]; + v[6] = v[4]; + v[7] = v[5]; + v[10] = v[8]; + v[11] = v[9]; + v[14] = v[12]; + v[15] = v[13]; + + // stage 8 + y = _mm_mullo_epi32(v[2], cospi32); + x = _mm_mullo_epi32(v[3], cospi32); + v[2] = _mm_add_epi32(y, x); + v[2] = _mm_add_epi32(v[2], rnding); + v[2] = _mm_srai_epi32(v[2], bit); + + v[3] = _mm_sub_epi32(y, x); + v[3] = _mm_add_epi32(v[3], rnding); + v[3] = _mm_srai_epi32(v[3], bit); + + y = _mm_mullo_epi32(v[6], cospi32); + x = _mm_mullo_epi32(v[7], cospi32); + v[6] = _mm_add_epi32(y, x); + v[6] = _mm_add_epi32(v[6], rnding); + v[6] = _mm_srai_epi32(v[6], bit); + + v[7] = _mm_sub_epi32(y, x); + v[7] = _mm_add_epi32(v[7], rnding); + v[7] = _mm_srai_epi32(v[7], bit); + + y = _mm_mullo_epi32(v[10], cospi32); + x = _mm_mullo_epi32(v[11], cospi32); + v[10] = _mm_add_epi32(y, x); + v[10] = _mm_add_epi32(v[10], rnding); + v[10] = _mm_srai_epi32(v[10], bit); + + v[11] = _mm_sub_epi32(y, x); + v[11] = _mm_add_epi32(v[11], rnding); + v[11] = _mm_srai_epi32(v[11], bit); + + y = _mm_mullo_epi32(v[14], cospi32); + x = _mm_mullo_epi32(v[15], cospi32); + v[14] = _mm_add_epi32(y, x); + v[14] = _mm_add_epi32(v[14], rnding); + v[14] = _mm_srai_epi32(v[14], bit); + + v[15] = _mm_sub_epi32(y, x); + v[15] = _mm_add_epi32(v[15], rnding); + v[15] = _mm_srai_epi32(v[15], bit); + + // stage 9 + if (do_cols) { + out[0] = v[0]; + out[1] = _mm_sub_epi32(_mm_setzero_si128(), v[8]); + out[2] = v[12]; + out[3] = _mm_sub_epi32(_mm_setzero_si128(), v[4]); + out[4] = v[6]; + out[5] = _mm_sub_epi32(_mm_setzero_si128(), v[14]); + out[6] = v[10]; + out[7] = _mm_sub_epi32(_mm_setzero_si128(), v[2]); + out[8] = v[3]; + out[9] = _mm_sub_epi32(_mm_setzero_si128(), v[11]); + out[10] = v[15]; + out[11] = _mm_sub_epi32(_mm_setzero_si128(), v[7]); + out[12] = v[5]; + out[13] = _mm_sub_epi32(_mm_setzero_si128(), v[13]); + out[14] = v[9]; + out[15] = _mm_sub_epi32(_mm_setzero_si128(), v[1]); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = + _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(v[0], v[8], out + 0, out + 1, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[12], v[4], out + 2, out + 3, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[6], v[14], out + 4, out + 5, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[10], v[2], out + 6, out + 7, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[3], v[11], out + 8, out + 9, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[15], v[7], out + 10, out + 11, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[5], v[13], out + 12, out + 13, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[9], v[1], out + 14, out + 15, &clamp_lo_out, + &clamp_hi_out, out_shift); + } + } +} + +static void iadst16x16_low8_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi10 = _mm_set1_epi32(cospi[10]); + const __m128i cospi54 = _mm_set1_epi32(cospi[54]); + const __m128i cospi18 = _mm_set1_epi32(cospi[18]); + const __m128i cospi46 = _mm_set1_epi32(cospi[46]); + const __m128i cospi26 = _mm_set1_epi32(cospi[26]); + const __m128i cospi38 = _mm_set1_epi32(cospi[38]); + const __m128i cospi34 = _mm_set1_epi32(cospi[34]); + const __m128i cospi30 = _mm_set1_epi32(cospi[30]); + const __m128i cospi42 = _mm_set1_epi32(cospi[42]); + const __m128i cospi22 = _mm_set1_epi32(cospi[22]); + const __m128i cospi50 = _mm_set1_epi32(cospi[50]); + const __m128i cospi14 = _mm_set1_epi32(cospi[14]); + const __m128i cospi58 = _mm_set1_epi32(cospi[58]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i u[16], x, y; + + // Calculate the column 0, 1, 2, 3 + { + // stage 0 + // stage 1 + // stage 2 + __m128i zero = _mm_setzero_si128(); + x = _mm_mullo_epi32(in[0], cospi62); + u[0] = _mm_add_epi32(x, rnding); + u[0] = _mm_srai_epi32(u[0], bit); + + x = _mm_mullo_epi32(in[0], cospi2); + u[1] = _mm_sub_epi32(zero, x); + u[1] = _mm_add_epi32(u[1], rnding); + u[1] = _mm_srai_epi32(u[1], bit); + + x = _mm_mullo_epi32(in[2], cospi54); + u[2] = _mm_add_epi32(x, rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + x = _mm_mullo_epi32(in[2], cospi10); + u[3] = _mm_sub_epi32(zero, x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + + x = _mm_mullo_epi32(in[4], cospi46); + u[4] = _mm_add_epi32(x, rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + x = _mm_mullo_epi32(in[4], cospi18); + u[5] = _mm_sub_epi32(zero, x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + x = _mm_mullo_epi32(in[6], cospi38); + u[6] = _mm_add_epi32(x, rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + x = _mm_mullo_epi32(in[6], cospi26); + u[7] = _mm_sub_epi32(zero, x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + u[8] = _mm_mullo_epi32(in[7], cospi34); + u[8] = _mm_add_epi32(u[8], rnding); + u[8] = _mm_srai_epi32(u[8], bit); + + u[9] = _mm_mullo_epi32(in[7], cospi30); + u[9] = _mm_add_epi32(u[9], rnding); + u[9] = _mm_srai_epi32(u[9], bit); + + u[10] = _mm_mullo_epi32(in[5], cospi42); + u[10] = _mm_add_epi32(u[10], rnding); + u[10] = _mm_srai_epi32(u[10], bit); + + u[11] = _mm_mullo_epi32(in[5], cospi22); + u[11] = _mm_add_epi32(u[11], rnding); + u[11] = _mm_srai_epi32(u[11], bit); + + u[12] = _mm_mullo_epi32(in[3], cospi50); + u[12] = _mm_add_epi32(u[12], rnding); + u[12] = _mm_srai_epi32(u[12], bit); + + u[13] = _mm_mullo_epi32(in[3], cospi14); + u[13] = _mm_add_epi32(u[13], rnding); + u[13] = _mm_srai_epi32(u[13], bit); + + u[14] = _mm_mullo_epi32(in[1], cospi58); + u[14] = _mm_add_epi32(u[14], rnding); + u[14] = _mm_srai_epi32(u[14], bit); + + u[15] = _mm_mullo_epi32(in[1], cospi6); + u[15] = _mm_add_epi32(u[15], rnding); + u[15] = _mm_srai_epi32(u[15], bit); + + // stage 3 + addsub_sse4_1(u[0], u[8], &u[0], &u[8], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[9], &u[1], &u[9], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[2], u[10], &u[2], &u[10], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[3], u[11], &u[3], &u[11], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[4], u[12], &u[4], &u[12], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[5], u[13], &u[5], &u[13], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[6], u[14], &u[6], &u[14], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[7], u[15], &u[7], &u[15], &clamp_lo, &clamp_hi); + + // stage 4 + y = _mm_mullo_epi32(u[8], cospi56); + x = _mm_mullo_epi32(u[9], cospi56); + u[8] = _mm_mullo_epi32(u[8], cospi8); + u[8] = _mm_add_epi32(u[8], x); + u[8] = _mm_add_epi32(u[8], rnding); + u[8] = _mm_srai_epi32(u[8], bit); + + x = _mm_mullo_epi32(u[9], cospi8); + u[9] = _mm_sub_epi32(y, x); + u[9] = _mm_add_epi32(u[9], rnding); + u[9] = _mm_srai_epi32(u[9], bit); + + x = _mm_mullo_epi32(u[11], cospi24); + y = _mm_mullo_epi32(u[10], cospi24); + u[10] = _mm_mullo_epi32(u[10], cospi40); + u[10] = _mm_add_epi32(u[10], x); + u[10] = _mm_add_epi32(u[10], rnding); + u[10] = _mm_srai_epi32(u[10], bit); + + x = _mm_mullo_epi32(u[11], cospi40); + u[11] = _mm_sub_epi32(y, x); + u[11] = _mm_add_epi32(u[11], rnding); + u[11] = _mm_srai_epi32(u[11], bit); + + x = _mm_mullo_epi32(u[13], cospi8); + y = _mm_mullo_epi32(u[12], cospi8); + u[12] = _mm_mullo_epi32(u[12], cospim56); + u[12] = _mm_add_epi32(u[12], x); + u[12] = _mm_add_epi32(u[12], rnding); + u[12] = _mm_srai_epi32(u[12], bit); + + x = _mm_mullo_epi32(u[13], cospim56); + u[13] = _mm_sub_epi32(y, x); + u[13] = _mm_add_epi32(u[13], rnding); + u[13] = _mm_srai_epi32(u[13], bit); + + x = _mm_mullo_epi32(u[15], cospi40); + y = _mm_mullo_epi32(u[14], cospi40); + u[14] = _mm_mullo_epi32(u[14], cospim24); + u[14] = _mm_add_epi32(u[14], x); + u[14] = _mm_add_epi32(u[14], rnding); + u[14] = _mm_srai_epi32(u[14], bit); + + x = _mm_mullo_epi32(u[15], cospim24); + u[15] = _mm_sub_epi32(y, x); + u[15] = _mm_add_epi32(u[15], rnding); + u[15] = _mm_srai_epi32(u[15], bit); + + // stage 5 + addsub_sse4_1(u[0], u[4], &u[0], &u[4], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[5], &u[1], &u[5], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[2], u[6], &u[2], &u[6], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[3], u[7], &u[3], &u[7], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[8], u[12], &u[8], &u[12], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[9], u[13], &u[9], &u[13], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[10], u[14], &u[10], &u[14], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[11], u[15], &u[11], &u[15], &clamp_lo, &clamp_hi); + + // stage 6 + x = _mm_mullo_epi32(u[5], cospi48); + y = _mm_mullo_epi32(u[4], cospi48); + u[4] = _mm_mullo_epi32(u[4], cospi16); + u[4] = _mm_add_epi32(u[4], x); + u[4] = _mm_add_epi32(u[4], rnding); + u[4] = _mm_srai_epi32(u[4], bit); + + x = _mm_mullo_epi32(u[5], cospi16); + u[5] = _mm_sub_epi32(y, x); + u[5] = _mm_add_epi32(u[5], rnding); + u[5] = _mm_srai_epi32(u[5], bit); + + x = _mm_mullo_epi32(u[7], cospi16); + y = _mm_mullo_epi32(u[6], cospi16); + u[6] = _mm_mullo_epi32(u[6], cospim48); + u[6] = _mm_add_epi32(u[6], x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + x = _mm_mullo_epi32(u[7], cospim48); + u[7] = _mm_sub_epi32(y, x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + x = _mm_mullo_epi32(u[13], cospi48); + y = _mm_mullo_epi32(u[12], cospi48); + u[12] = _mm_mullo_epi32(u[12], cospi16); + u[12] = _mm_add_epi32(u[12], x); + u[12] = _mm_add_epi32(u[12], rnding); + u[12] = _mm_srai_epi32(u[12], bit); + + x = _mm_mullo_epi32(u[13], cospi16); + u[13] = _mm_sub_epi32(y, x); + u[13] = _mm_add_epi32(u[13], rnding); + u[13] = _mm_srai_epi32(u[13], bit); + + x = _mm_mullo_epi32(u[15], cospi16); + y = _mm_mullo_epi32(u[14], cospi16); + u[14] = _mm_mullo_epi32(u[14], cospim48); + u[14] = _mm_add_epi32(u[14], x); + u[14] = _mm_add_epi32(u[14], rnding); + u[14] = _mm_srai_epi32(u[14], bit); + + x = _mm_mullo_epi32(u[15], cospim48); + u[15] = _mm_sub_epi32(y, x); + u[15] = _mm_add_epi32(u[15], rnding); + u[15] = _mm_srai_epi32(u[15], bit); + + // stage 7 + addsub_sse4_1(u[0], u[2], &u[0], &u[2], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[1], u[3], &u[1], &u[3], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[4], u[6], &u[4], &u[6], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[5], u[7], &u[5], &u[7], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[8], u[10], &u[8], &u[10], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[9], u[11], &u[9], &u[11], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[12], u[14], &u[12], &u[14], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[13], u[15], &u[13], &u[15], &clamp_lo, &clamp_hi); + + // stage 8 + y = _mm_mullo_epi32(u[2], cospi32); + x = _mm_mullo_epi32(u[3], cospi32); + u[2] = _mm_add_epi32(y, x); + u[2] = _mm_add_epi32(u[2], rnding); + u[2] = _mm_srai_epi32(u[2], bit); + + u[3] = _mm_sub_epi32(y, x); + u[3] = _mm_add_epi32(u[3], rnding); + u[3] = _mm_srai_epi32(u[3], bit); + y = _mm_mullo_epi32(u[6], cospi32); + x = _mm_mullo_epi32(u[7], cospi32); + u[6] = _mm_add_epi32(y, x); + u[6] = _mm_add_epi32(u[6], rnding); + u[6] = _mm_srai_epi32(u[6], bit); + + u[7] = _mm_sub_epi32(y, x); + u[7] = _mm_add_epi32(u[7], rnding); + u[7] = _mm_srai_epi32(u[7], bit); + + y = _mm_mullo_epi32(u[10], cospi32); + x = _mm_mullo_epi32(u[11], cospi32); + u[10] = _mm_add_epi32(y, x); + u[10] = _mm_add_epi32(u[10], rnding); + u[10] = _mm_srai_epi32(u[10], bit); + + u[11] = _mm_sub_epi32(y, x); + u[11] = _mm_add_epi32(u[11], rnding); + u[11] = _mm_srai_epi32(u[11], bit); + + y = _mm_mullo_epi32(u[14], cospi32); + x = _mm_mullo_epi32(u[15], cospi32); + u[14] = _mm_add_epi32(y, x); + u[14] = _mm_add_epi32(u[14], rnding); + u[14] = _mm_srai_epi32(u[14], bit); + + u[15] = _mm_sub_epi32(y, x); + u[15] = _mm_add_epi32(u[15], rnding); + u[15] = _mm_srai_epi32(u[15], bit); + + // stage 9 + if (do_cols) { + out[0] = u[0]; + out[1] = _mm_sub_epi32(_mm_setzero_si128(), u[8]); + out[2] = u[12]; + out[3] = _mm_sub_epi32(_mm_setzero_si128(), u[4]); + out[4] = u[6]; + out[5] = _mm_sub_epi32(_mm_setzero_si128(), u[14]); + out[6] = u[10]; + out[7] = _mm_sub_epi32(_mm_setzero_si128(), u[2]); + out[8] = u[3]; + out[9] = _mm_sub_epi32(_mm_setzero_si128(), u[11]); + out[10] = u[15]; + out[11] = _mm_sub_epi32(_mm_setzero_si128(), u[7]); + out[12] = u[5]; + out[13] = _mm_sub_epi32(_mm_setzero_si128(), u[13]); + out[14] = u[9]; + out[15] = _mm_sub_epi32(_mm_setzero_si128(), u[1]); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = + _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(u[0], u[8], out + 0, out + 1, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[12], u[4], out + 2, out + 3, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[6], u[14], out + 4, out + 5, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[10], u[2], out + 6, out + 7, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[3], u[11], out + 8, out + 9, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[15], u[7], out + 10, out + 11, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[5], u[13], out + 12, out + 13, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(u[9], u[1], out + 14, out + 15, &clamp_lo_out, + &clamp_hi_out, out_shift); + } + } } static void idct16x16_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, @@ -1067,27 +2426,26 @@ static void idct16x16_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); __m128i u[16], v[16], x, y; - int col; - for (col = 0; col < 4; ++col) { + { // stage 0 // stage 1 - u[0] = in[0 * 4 + col]; - u[1] = in[8 * 4 + col]; - u[2] = in[4 * 4 + col]; - u[3] = in[12 * 4 + col]; - u[4] = in[2 * 4 + col]; - u[5] = in[10 * 4 + col]; - u[6] = in[6 * 4 + col]; - u[7] = in[14 * 4 + col]; - u[8] = in[1 * 4 + col]; - u[9] = in[9 * 4 + col]; - u[10] = in[5 * 4 + col]; - u[11] = in[13 * 4 + col]; - u[12] = in[3 * 4 + col]; - u[13] = in[11 * 4 + col]; - u[14] = in[7 * 4 + col]; - u[15] = in[15 * 4 + col]; + u[0] = in[0]; + u[1] = in[8]; + u[2] = in[4]; + u[3] = in[12]; + u[4] = in[2]; + u[5] = in[10]; + u[6] = in[6]; + u[7] = in[14]; + u[8] = in[1]; + u[9] = in[9]; + u[10] = in[5]; + u[11] = in[13]; + u[12] = in[3]; + u[13] = in[11]; + u[14] = in[7]; + u[15] = in[15]; // stage 2 v[0] = u[0]; @@ -1200,37 +2558,37 @@ static void idct16x16_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, // stage 7 if (do_cols) { - addsub_no_clamp_sse4_1(v[0], v[15], out + 0 * 4 + col, - out + 15 * 4 + col); - addsub_no_clamp_sse4_1(v[1], v[14], out + 1 * 4 + col, - out + 14 * 4 + col); - addsub_no_clamp_sse4_1(v[2], v[13], out + 2 * 4 + col, - out + 13 * 4 + col); - addsub_no_clamp_sse4_1(v[3], v[12], out + 3 * 4 + col, - out + 12 * 4 + col); - addsub_no_clamp_sse4_1(v[4], v[11], out + 4 * 4 + col, - out + 11 * 4 + col); - addsub_no_clamp_sse4_1(v[5], v[10], out + 5 * 4 + col, - out + 10 * 4 + col); - addsub_no_clamp_sse4_1(v[6], v[9], out + 6 * 4 + col, out + 9 * 4 + col); - addsub_no_clamp_sse4_1(v[7], v[8], out + 7 * 4 + col, out + 8 * 4 + col); + addsub_no_clamp_sse4_1(v[0], v[15], out + 0, out + 15); + addsub_no_clamp_sse4_1(v[1], v[14], out + 1, out + 14); + addsub_no_clamp_sse4_1(v[2], v[13], out + 2, out + 13); + addsub_no_clamp_sse4_1(v[3], v[12], out + 3, out + 12); + addsub_no_clamp_sse4_1(v[4], v[11], out + 4, out + 11); + addsub_no_clamp_sse4_1(v[5], v[10], out + 5, out + 10); + addsub_no_clamp_sse4_1(v[6], v[9], out + 6, out + 9); + addsub_no_clamp_sse4_1(v[7], v[8], out + 7, out + 8); } else { - addsub_shift_sse4_1(v[0], v[15], out + 0 * 4 + col, out + 15 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[1], v[14], out + 1 * 4 + col, out + 14 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[2], v[13], out + 2 * 4 + col, out + 13 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[3], v[12], out + 3 * 4 + col, out + 12 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[4], v[11], out + 4 * 4 + col, out + 11 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[5], v[10], out + 5 * 4 + col, out + 10 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[6], v[9], out + 6 * 4 + col, out + 9 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); - addsub_shift_sse4_1(v[7], v[8], out + 7 * 4 + col, out + 8 * 4 + col, - &clamp_lo, &clamp_hi, out_shift); + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + addsub_shift_sse4_1(v[0], v[15], out + 0, out + 15, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[1], v[14], out + 1, out + 14, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[2], v[13], out + 2, out + 13, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[3], v[12], out + 3, out + 12, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[4], v[11], out + 4, out + 11, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[5], v[10], out + 5, out + 10, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[6], v[9], out + 6, out + 9, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(v[7], v[8], out + 7, out + 8, &clamp_lo_out, + &clamp_hi_out, out_shift); } } } @@ -1269,106 +2627,104 @@ static void iadst16x16_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); __m128i u[16], v[16], x, y; - const int col_num = 4; - int col; // Calculate the column 0, 1, 2, 3 - for (col = 0; col < col_num; ++col) { + { // stage 0 // stage 1 // stage 2 - v[0] = _mm_mullo_epi32(in[15 * col_num + col], cospi2); - x = _mm_mullo_epi32(in[0 * col_num + col], cospi62); + v[0] = _mm_mullo_epi32(in[15], cospi2); + x = _mm_mullo_epi32(in[0], cospi62); v[0] = _mm_add_epi32(v[0], x); v[0] = _mm_add_epi32(v[0], rnding); v[0] = _mm_srai_epi32(v[0], bit); - v[1] = _mm_mullo_epi32(in[15 * col_num + col], cospi62); - x = _mm_mullo_epi32(in[0 * col_num + col], cospi2); + v[1] = _mm_mullo_epi32(in[15], cospi62); + x = _mm_mullo_epi32(in[0], cospi2); v[1] = _mm_sub_epi32(v[1], x); v[1] = _mm_add_epi32(v[1], rnding); v[1] = _mm_srai_epi32(v[1], bit); - v[2] = _mm_mullo_epi32(in[13 * col_num + col], cospi10); - x = _mm_mullo_epi32(in[2 * col_num + col], cospi54); + v[2] = _mm_mullo_epi32(in[13], cospi10); + x = _mm_mullo_epi32(in[2], cospi54); v[2] = _mm_add_epi32(v[2], x); v[2] = _mm_add_epi32(v[2], rnding); v[2] = _mm_srai_epi32(v[2], bit); - v[3] = _mm_mullo_epi32(in[13 * col_num + col], cospi54); - x = _mm_mullo_epi32(in[2 * col_num + col], cospi10); + v[3] = _mm_mullo_epi32(in[13], cospi54); + x = _mm_mullo_epi32(in[2], cospi10); v[3] = _mm_sub_epi32(v[3], x); v[3] = _mm_add_epi32(v[3], rnding); v[3] = _mm_srai_epi32(v[3], bit); - v[4] = _mm_mullo_epi32(in[11 * col_num + col], cospi18); - x = _mm_mullo_epi32(in[4 * col_num + col], cospi46); + v[4] = _mm_mullo_epi32(in[11], cospi18); + x = _mm_mullo_epi32(in[4], cospi46); v[4] = _mm_add_epi32(v[4], x); v[4] = _mm_add_epi32(v[4], rnding); v[4] = _mm_srai_epi32(v[4], bit); - v[5] = _mm_mullo_epi32(in[11 * col_num + col], cospi46); - x = _mm_mullo_epi32(in[4 * col_num + col], cospi18); + v[5] = _mm_mullo_epi32(in[11], cospi46); + x = _mm_mullo_epi32(in[4], cospi18); v[5] = _mm_sub_epi32(v[5], x); v[5] = _mm_add_epi32(v[5], rnding); v[5] = _mm_srai_epi32(v[5], bit); - v[6] = _mm_mullo_epi32(in[9 * col_num + col], cospi26); - x = _mm_mullo_epi32(in[6 * col_num + col], cospi38); + v[6] = _mm_mullo_epi32(in[9], cospi26); + x = _mm_mullo_epi32(in[6], cospi38); v[6] = _mm_add_epi32(v[6], x); v[6] = _mm_add_epi32(v[6], rnding); v[6] = _mm_srai_epi32(v[6], bit); - v[7] = _mm_mullo_epi32(in[9 * col_num + col], cospi38); - x = _mm_mullo_epi32(in[6 * col_num + col], cospi26); + v[7] = _mm_mullo_epi32(in[9], cospi38); + x = _mm_mullo_epi32(in[6], cospi26); v[7] = _mm_sub_epi32(v[7], x); v[7] = _mm_add_epi32(v[7], rnding); v[7] = _mm_srai_epi32(v[7], bit); - v[8] = _mm_mullo_epi32(in[7 * col_num + col], cospi34); - x = _mm_mullo_epi32(in[8 * col_num + col], cospi30); + v[8] = _mm_mullo_epi32(in[7], cospi34); + x = _mm_mullo_epi32(in[8], cospi30); v[8] = _mm_add_epi32(v[8], x); v[8] = _mm_add_epi32(v[8], rnding); v[8] = _mm_srai_epi32(v[8], bit); - v[9] = _mm_mullo_epi32(in[7 * col_num + col], cospi30); - x = _mm_mullo_epi32(in[8 * col_num + col], cospi34); + v[9] = _mm_mullo_epi32(in[7], cospi30); + x = _mm_mullo_epi32(in[8], cospi34); v[9] = _mm_sub_epi32(v[9], x); v[9] = _mm_add_epi32(v[9], rnding); v[9] = _mm_srai_epi32(v[9], bit); - v[10] = _mm_mullo_epi32(in[5 * col_num + col], cospi42); - x = _mm_mullo_epi32(in[10 * col_num + col], cospi22); + v[10] = _mm_mullo_epi32(in[5], cospi42); + x = _mm_mullo_epi32(in[10], cospi22); v[10] = _mm_add_epi32(v[10], x); v[10] = _mm_add_epi32(v[10], rnding); v[10] = _mm_srai_epi32(v[10], bit); - v[11] = _mm_mullo_epi32(in[5 * col_num + col], cospi22); - x = _mm_mullo_epi32(in[10 * col_num + col], cospi42); + v[11] = _mm_mullo_epi32(in[5], cospi22); + x = _mm_mullo_epi32(in[10], cospi42); v[11] = _mm_sub_epi32(v[11], x); v[11] = _mm_add_epi32(v[11], rnding); v[11] = _mm_srai_epi32(v[11], bit); - v[12] = _mm_mullo_epi32(in[3 * col_num + col], cospi50); - x = _mm_mullo_epi32(in[12 * col_num + col], cospi14); + v[12] = _mm_mullo_epi32(in[3], cospi50); + x = _mm_mullo_epi32(in[12], cospi14); v[12] = _mm_add_epi32(v[12], x); v[12] = _mm_add_epi32(v[12], rnding); v[12] = _mm_srai_epi32(v[12], bit); - v[13] = _mm_mullo_epi32(in[3 * col_num + col], cospi14); - x = _mm_mullo_epi32(in[12 * col_num + col], cospi50); + v[13] = _mm_mullo_epi32(in[3], cospi14); + x = _mm_mullo_epi32(in[12], cospi50); v[13] = _mm_sub_epi32(v[13], x); v[13] = _mm_add_epi32(v[13], rnding); v[13] = _mm_srai_epi32(v[13], bit); - v[14] = _mm_mullo_epi32(in[1 * col_num + col], cospi58); - x = _mm_mullo_epi32(in[14 * col_num + col], cospi6); + v[14] = _mm_mullo_epi32(in[1], cospi58); + x = _mm_mullo_epi32(in[14], cospi6); v[14] = _mm_add_epi32(v[14], x); v[14] = _mm_add_epi32(v[14], rnding); v[14] = _mm_srai_epi32(v[14], bit); - v[15] = _mm_mullo_epi32(in[1 * col_num + col], cospi6); - x = _mm_mullo_epi32(in[14 * col_num + col], cospi58); + v[15] = _mm_mullo_epi32(in[1], cospi6); + x = _mm_mullo_epi32(in[14], cospi58); v[15] = _mm_sub_epi32(v[15], x); v[15] = _mm_add_epi32(v[15], rnding); v[15] = _mm_srai_epi32(v[15], bit); @@ -1575,268 +2931,835 @@ static void iadst16x16_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, // stage 9 if (do_cols) { - out[0 * col_num + col] = v[0]; - out[1 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[8]); - out[2 * col_num + col] = v[12]; - out[3 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[4]); - out[4 * col_num + col] = v[6]; - out[5 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[14]); - out[6 * col_num + col] = v[10]; - out[7 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[2]); - out[8 * col_num + col] = v[3]; - out[9 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[11]); - out[10 * col_num + col] = v[15]; - out[11 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[7]); - out[12 * col_num + col] = v[5]; - out[13 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[13]); - out[14 * col_num + col] = v[9]; - out[15 * col_num + col] = _mm_sub_epi32(_mm_setzero_si128(), v[1]); + out[0] = v[0]; + out[1] = _mm_sub_epi32(_mm_setzero_si128(), v[8]); + out[2] = v[12]; + out[3] = _mm_sub_epi32(_mm_setzero_si128(), v[4]); + out[4] = v[6]; + out[5] = _mm_sub_epi32(_mm_setzero_si128(), v[14]); + out[6] = v[10]; + out[7] = _mm_sub_epi32(_mm_setzero_si128(), v[2]); + out[8] = v[3]; + out[9] = _mm_sub_epi32(_mm_setzero_si128(), v[11]); + out[10] = v[15]; + out[11] = _mm_sub_epi32(_mm_setzero_si128(), v[7]); + out[12] = v[5]; + out[13] = _mm_sub_epi32(_mm_setzero_si128(), v[13]); + out[14] = v[9]; + out[15] = _mm_sub_epi32(_mm_setzero_si128(), v[1]); } else { - neg_shift_sse4_1(v[0], v[8], out + 0 * col_num + col, - out + 1 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[12], v[4], out + 2 * col_num + col, - out + 3 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[6], v[14], out + 4 * col_num + col, - out + 5 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[10], v[2], out + 6 * col_num + col, - out + 7 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[3], v[11], out + 8 * col_num + col, - out + 9 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[15], v[7], out + 10 * col_num + col, - out + 11 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[5], v[13], out + 12 * col_num + col, - out + 13 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); - neg_shift_sse4_1(v[9], v[1], out + 14 * col_num + col, - out + 15 * col_num + col, &clamp_lo, &clamp_hi, - out_shift); + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(-(1 << (log_range_out - 1))); + const __m128i clamp_hi_out = + _mm_set1_epi32((1 << (log_range_out - 1)) - 1); + + neg_shift_sse4_1(v[0], v[8], out + 0, out + 1, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[12], v[4], out + 2, out + 3, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[6], v[14], out + 4, out + 5, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[10], v[2], out + 6, out + 7, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[3], v[11], out + 8, out + 9, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[15], v[7], out + 10, out + 11, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[5], v[13], out + 12, out + 13, &clamp_lo_out, + &clamp_hi_out, out_shift); + neg_shift_sse4_1(v[9], v[1], out + 14, out + 15, &clamp_lo_out, + &clamp_hi_out, out_shift); } } } -void av1_inv_txfm2d_add_16x16_sse4_1(const int32_t *coeff, uint16_t *output, - int stride, TX_TYPE tx_type, int bd) { - __m128i in[64], out[64]; - const int8_t *shift = inv_txfm_shift_ls[TX_16X16]; - const int txw_idx = get_txw_idx(TX_16X16); - const int txh_idx = get_txh_idx(TX_16X16); - - switch (tx_type) { - case DCT_DCT: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - idct16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - idct16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 0, 0, -shift[1], bd); - break; - case DCT_ADST: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - idct16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 0, 0, -shift[1], bd); - break; - case ADST_DCT: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - idct16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 0, 0, -shift[1], bd); - break; - case ADST_ADST: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 0, 0, -shift[1], bd); - break; - case FLIPADST_DCT: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - idct16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 0, 1, -shift[1], bd); - break; - case DCT_FLIPADST: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - idct16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 1, 0, -shift[1], bd); - break; - case ADST_FLIPADST: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 1, 0, -shift[1], bd); - break; - case FLIPADST_FLIPADST: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 1, 1, -shift[1], bd); - break; - case FLIPADST_ADST: - load_buffer_16x16(coeff, in); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_16x16(in, out); - iadst16x16_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_16x16(in, output, stride, 0, 1, -shift[1], bd); - break; - default: assert(0); +static INLINE void idct64_stage8_sse4_1( + __m128i *u, const __m128i *cospim32, const __m128i *cospi32, + const __m128i *cospim16, const __m128i *cospi48, const __m128i *cospi16, + const __m128i *cospim48, const __m128i *clamp_lo, const __m128i *clamp_hi, + const __m128i *rnding, int bit) { + int i; + __m128i temp1, temp2, temp3, temp4; + temp1 = half_btf_sse4_1(cospim32, &u[10], cospi32, &u[13], rnding, bit); + u[13] = half_btf_sse4_1(cospi32, &u[10], cospi32, &u[13], rnding, bit); + u[10] = temp1; + temp2 = half_btf_sse4_1(cospim32, &u[11], cospi32, &u[12], rnding, bit); + u[12] = half_btf_sse4_1(cospi32, &u[11], cospi32, &u[12], rnding, bit); + u[11] = temp2; + + for (i = 16; i < 20; ++i) { + addsub_sse4_1(u[i], u[i ^ 7], &u[i], &u[i ^ 7], clamp_lo, clamp_hi); + addsub_sse4_1(u[i ^ 15], u[i ^ 8], &u[i ^ 15], &u[i ^ 8], clamp_lo, + clamp_hi); } + + temp1 = half_btf_sse4_1(cospim16, &u[36], cospi48, &u[59], rnding, bit); + temp2 = half_btf_sse4_1(cospim16, &u[37], cospi48, &u[58], rnding, bit); + temp3 = half_btf_sse4_1(cospim16, &u[38], cospi48, &u[57], rnding, bit); + temp4 = half_btf_sse4_1(cospim16, &u[39], cospi48, &u[56], rnding, bit); + u[56] = half_btf_sse4_1(cospi48, &u[39], cospi16, &u[56], rnding, bit); + u[57] = half_btf_sse4_1(cospi48, &u[38], cospi16, &u[57], rnding, bit); + u[58] = half_btf_sse4_1(cospi48, &u[37], cospi16, &u[58], rnding, bit); + u[59] = half_btf_sse4_1(cospi48, &u[36], cospi16, &u[59], rnding, bit); + u[36] = temp1; + u[37] = temp2; + u[38] = temp3; + u[39] = temp4; + + temp1 = half_btf_sse4_1(cospim48, &u[40], cospim16, &u[55], rnding, bit); + temp2 = half_btf_sse4_1(cospim48, &u[41], cospim16, &u[54], rnding, bit); + temp3 = half_btf_sse4_1(cospim48, &u[42], cospim16, &u[53], rnding, bit); + temp4 = half_btf_sse4_1(cospim48, &u[43], cospim16, &u[52], rnding, bit); + u[52] = half_btf_sse4_1(cospim16, &u[43], cospi48, &u[52], rnding, bit); + u[53] = half_btf_sse4_1(cospim16, &u[42], cospi48, &u[53], rnding, bit); + u[54] = half_btf_sse4_1(cospim16, &u[41], cospi48, &u[54], rnding, bit); + u[55] = half_btf_sse4_1(cospim16, &u[40], cospi48, &u[55], rnding, bit); + u[40] = temp1; + u[41] = temp2; + u[42] = temp3; + u[43] = temp4; } -static void load_buffer_64x64_lower_32x32(const int32_t *coeff, __m128i *in) { - int i, j; +static INLINE void idct64_stage9_sse4_1(__m128i *u, const __m128i *cospim32, + const __m128i *cospi32, + const __m128i *clamp_lo, + const __m128i *clamp_hi, + const __m128i *rnding, int bit) { + int i; + __m128i temp1, temp2, temp3, temp4; + for (i = 0; i < 8; ++i) { + addsub_sse4_1(u[i], u[15 - i], &u[i], &u[15 - i], clamp_lo, clamp_hi); + } - __m128i zero = _mm_setzero_si128(); + temp1 = half_btf_sse4_1(cospim32, &u[20], cospi32, &u[27], rnding, bit); + temp2 = half_btf_sse4_1(cospim32, &u[21], cospi32, &u[26], rnding, bit); + temp3 = half_btf_sse4_1(cospim32, &u[22], cospi32, &u[25], rnding, bit); + temp4 = half_btf_sse4_1(cospim32, &u[23], cospi32, &u[24], rnding, bit); + u[24] = half_btf_sse4_1(cospi32, &u[23], cospi32, &u[24], rnding, bit); + u[25] = half_btf_sse4_1(cospi32, &u[22], cospi32, &u[25], rnding, bit); + u[26] = half_btf_sse4_1(cospi32, &u[21], cospi32, &u[26], rnding, bit); + u[27] = half_btf_sse4_1(cospi32, &u[20], cospi32, &u[27], rnding, bit); + u[20] = temp1; + u[21] = temp2; + u[22] = temp3; + u[23] = temp4; + for (i = 32; i < 40; i++) { + addsub_sse4_1(u[i], u[i ^ 15], &u[i], &u[i ^ 15], clamp_lo, clamp_hi); + } - for (i = 0; i < 32; ++i) { - for (j = 0; j < 8; ++j) { - in[16 * i + j] = - _mm_loadu_si128((const __m128i *)(coeff + 32 * i + 4 * j)); - in[16 * i + j + 8] = zero; - } + for (i = 48; i < 56; i++) { + addsub_sse4_1(u[i ^ 15], u[i], &u[i ^ 15], &u[i], clamp_lo, clamp_hi); + } +} + +static INLINE void idct64_stage10_sse4_1(__m128i *u, const __m128i *cospim32, + const __m128i *cospi32, + const __m128i *clamp_lo, + const __m128i *clamp_hi, + const __m128i *rnding, int bit) { + __m128i temp1, temp2, temp3, temp4; + for (int i = 0; i < 16; i++) { + addsub_sse4_1(u[i], u[31 - i], &u[i], &u[31 - i], clamp_lo, clamp_hi); } - for (i = 0; i < 512; ++i) in[512 + i] = zero; + temp1 = half_btf_sse4_1(cospim32, &u[40], cospi32, &u[55], rnding, bit); + temp2 = half_btf_sse4_1(cospim32, &u[41], cospi32, &u[54], rnding, bit); + temp3 = half_btf_sse4_1(cospim32, &u[42], cospi32, &u[53], rnding, bit); + temp4 = half_btf_sse4_1(cospim32, &u[43], cospi32, &u[52], rnding, bit); + u[52] = half_btf_sse4_1(cospi32, &u[43], cospi32, &u[52], rnding, bit); + u[53] = half_btf_sse4_1(cospi32, &u[42], cospi32, &u[53], rnding, bit); + u[54] = half_btf_sse4_1(cospi32, &u[41], cospi32, &u[54], rnding, bit); + u[55] = half_btf_sse4_1(cospi32, &u[40], cospi32, &u[55], rnding, bit); + u[40] = temp1; + u[41] = temp2; + u[42] = temp3; + u[43] = temp4; + + temp1 = half_btf_sse4_1(cospim32, &u[44], cospi32, &u[51], rnding, bit); + temp2 = half_btf_sse4_1(cospim32, &u[45], cospi32, &u[50], rnding, bit); + temp3 = half_btf_sse4_1(cospim32, &u[46], cospi32, &u[49], rnding, bit); + temp4 = half_btf_sse4_1(cospim32, &u[47], cospi32, &u[48], rnding, bit); + u[48] = half_btf_sse4_1(cospi32, &u[47], cospi32, &u[48], rnding, bit); + u[49] = half_btf_sse4_1(cospi32, &u[46], cospi32, &u[49], rnding, bit); + u[50] = half_btf_sse4_1(cospi32, &u[45], cospi32, &u[50], rnding, bit); + u[51] = half_btf_sse4_1(cospi32, &u[44], cospi32, &u[51], rnding, bit); + u[44] = temp1; + u[45] = temp2; + u[46] = temp3; + u[47] = temp4; } -static void transpose_64x64(__m128i *in, __m128i *out, int do_cols) { - int i, j; - for (i = 0; i < (do_cols ? 16 : 8); ++i) { - for (j = 0; j < 8; ++j) { - TRANSPOSE_4X4(in[(4 * i + 0) * 16 + j], in[(4 * i + 1) * 16 + j], - in[(4 * i + 2) * 16 + j], in[(4 * i + 3) * 16 + j], - out[(4 * j + 0) * 16 + i], out[(4 * j + 1) * 16 + i], - out[(4 * j + 2) * 16 + i], out[(4 * j + 3) * 16 + i]); +static INLINE void idct64_stage11_sse4_1(__m128i *u, __m128i *out, int do_cols, + int bd, int out_shift, + const int log_range) { + if (do_cols) { + for (int i = 0; i < 32; i++) { + addsub_no_clamp_sse4_1(u[i], u[63 - i], &out[(i)], &out[(63 - i)]); + } + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + for (int i = 0; i < 32; i++) { + addsub_shift_sse4_1(u[i], u[63 - i], &out[(i)], &out[(63 - i)], + &clamp_lo_out, &clamp_hi_out, out_shift); } } } -static void assign_16x16_input_from_32x32(const __m128i *in, __m128i *in16x16, - int col) { - int i; - for (i = 0; i < 16 * 16 / 4; i += 4) { - in16x16[i] = in[col]; - in16x16[i + 1] = in[col + 1]; - in16x16[i + 2] = in[col + 2]; - in16x16[i + 3] = in[col + 3]; - col += 8; +static void idct64x64_low1_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + + { + __m128i x; + + // stage 1 + // stage 2 + // stage 3 + // stage 4 + // stage 5 + // stage 6 + x = half_btf_0_sse4_1(&cospi32, &in[0], &rnding, bit); + + // stage 8 + // stage 9 + // stage 10 + // stage 11 + if (do_cols) { + x = _mm_max_epi32(x, clamp_lo); + x = _mm_min_epi32(x, clamp_hi); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + __m128i offset = _mm_set1_epi32((1 << out_shift) >> 1); + x = _mm_add_epi32(x, offset); + x = _mm_sra_epi32(x, _mm_cvtsi32_si128(out_shift)); + + x = _mm_max_epi32(x, clamp_lo_out); + x = _mm_min_epi32(x, clamp_hi_out); + } + + out[0] = x; + out[63] = x; + out[1] = x; + out[62] = x; + out[2] = x; + out[61] = x; + out[3] = x; + out[60] = x; + out[4] = x; + out[59] = x; + out[5] = x; + out[58] = x; + out[6] = x; + out[57] = x; + out[7] = x; + out[56] = x; + out[8] = x; + out[55] = x; + out[9] = x; + out[54] = x; + out[10] = x; + out[53] = x; + out[11] = x; + out[52] = x; + out[12] = x; + out[51] = x; + out[13] = x; + out[50] = x; + out[14] = x; + out[49] = x; + out[15] = x; + out[48] = x; + out[16] = x; + out[47] = x; + out[17] = x; + out[46] = x; + out[18] = x; + out[45] = x; + out[19] = x; + out[44] = x; + out[20] = x; + out[43] = x; + out[21] = x; + out[42] = x; + out[22] = x; + out[41] = x; + out[23] = x; + out[40] = x; + out[24] = x; + out[39] = x; + out[25] = x; + out[38] = x; + out[26] = x; + out[37] = x; + out[27] = x; + out[36] = x; + out[28] = x; + out[35] = x; + out[29] = x; + out[34] = x; + out[30] = x; + out[33] = x; + out[31] = x; + out[32] = x; } } -static void write_buffer_32x32(__m128i *in, uint16_t *output, int stride, - int fliplr, int flipud, int shift, int bd) { - __m128i in16x16[16 * 16 / 4]; - uint16_t *leftUp = &output[0]; - uint16_t *rightUp = &output[16]; - uint16_t *leftDown = &output[16 * stride]; - uint16_t *rightDown = &output[16 * stride + 16]; +static void idct64x64_low8_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + int i, j; + const int32_t *cospi = cospi_arr(bit); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); - if (fliplr) { - swap_addr(&leftUp, &rightUp); - swap_addr(&leftDown, &rightDown); - } + const __m128i cospi1 = _mm_set1_epi32(cospi[1]); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospi3 = _mm_set1_epi32(cospi[3]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospim4 = _mm_set1_epi32(-cospi[4]); + const __m128i cospim8 = _mm_set1_epi32(-cospi[8]); + const __m128i cospim12 = _mm_set1_epi32(-cospi[12]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i cospim20 = _mm_set1_epi32(-cospi[20]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospim28 = _mm_set1_epi32(-cospi[28]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospim36 = _mm_set1_epi32(-cospi[36]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospim52 = _mm_set1_epi32(-cospi[52]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospi63 = _mm_set1_epi32(cospi[63]); + const __m128i cospim57 = _mm_set1_epi32(-cospi[57]); + const __m128i cospi7 = _mm_set1_epi32(cospi[7]); + const __m128i cospi5 = _mm_set1_epi32(cospi[5]); + const __m128i cospi59 = _mm_set1_epi32(cospi[59]); + const __m128i cospim61 = _mm_set1_epi32(-cospi[61]); + const __m128i cospim58 = _mm_set1_epi32(-cospi[58]); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); - if (flipud) { - swap_addr(&leftUp, &leftDown); - swap_addr(&rightUp, &rightDown); - } + { + __m128i u[64]; + + // stage 1 + u[0] = in[0]; + u[8] = in[4]; + u[16] = in[2]; + u[24] = in[6]; + u[32] = in[1]; + u[40] = in[5]; + u[48] = in[3]; + u[56] = in[7]; + + // stage 2 + u[63] = half_btf_0_sse4_1(&cospi1, &u[32], &rnding, bit); + u[32] = half_btf_0_sse4_1(&cospi63, &u[32], &rnding, bit); + u[39] = half_btf_0_sse4_1(&cospim57, &u[56], &rnding, bit); + u[56] = half_btf_0_sse4_1(&cospi7, &u[56], &rnding, bit); + u[55] = half_btf_0_sse4_1(&cospi5, &u[40], &rnding, bit); + u[40] = half_btf_0_sse4_1(&cospi59, &u[40], &rnding, bit); + u[47] = half_btf_0_sse4_1(&cospim61, &u[48], &rnding, bit); + u[48] = half_btf_0_sse4_1(&cospi3, &u[48], &rnding, bit); - // Left-up quarter - assign_16x16_input_from_32x32(in, in16x16, 0); - write_buffer_16x16(in16x16, leftUp, stride, fliplr, flipud, shift, bd); + // stage 3 + u[31] = half_btf_0_sse4_1(&cospi2, &u[16], &rnding, bit); + u[16] = half_btf_0_sse4_1(&cospi62, &u[16], &rnding, bit); + u[23] = half_btf_0_sse4_1(&cospim58, &u[24], &rnding, bit); + u[24] = half_btf_0_sse4_1(&cospi6, &u[24], &rnding, bit); + u[33] = u[32]; + u[38] = u[39]; + u[41] = u[40]; + u[46] = u[47]; + u[49] = u[48]; + u[54] = u[55]; + u[57] = u[56]; + u[62] = u[63]; - // Right-up quarter - assign_16x16_input_from_32x32(in, in16x16, 32 / 2 / 4); - write_buffer_16x16(in16x16, rightUp, stride, fliplr, flipud, shift, bd); + // stage 4 + __m128i temp1, temp2; + u[15] = half_btf_0_sse4_1(&cospi4, &u[8], &rnding, bit); + u[8] = half_btf_0_sse4_1(&cospi60, &u[8], &rnding, bit); + u[17] = u[16]; + u[22] = u[23]; + u[25] = u[24]; + u[30] = u[31]; + + temp1 = half_btf_sse4_1(&cospim4, &u[33], &cospi60, &u[62], &rnding, bit); + u[62] = half_btf_sse4_1(&cospi60, &u[33], &cospi4, &u[62], &rnding, bit); + u[33] = temp1; + + temp2 = half_btf_sse4_1(&cospim36, &u[38], &cospi28, &u[57], &rnding, bit); + u[38] = half_btf_sse4_1(&cospim28, &u[38], &cospim36, &u[57], &rnding, bit); + u[57] = temp2; + + temp1 = half_btf_sse4_1(&cospim20, &u[41], &cospi44, &u[54], &rnding, bit); + u[54] = half_btf_sse4_1(&cospi44, &u[41], &cospi20, &u[54], &rnding, bit); + u[41] = temp1; + + temp2 = half_btf_sse4_1(&cospim12, &u[46], &cospim52, &u[49], &rnding, bit); + u[49] = half_btf_sse4_1(&cospim52, &u[46], &cospi12, &u[49], &rnding, bit); + u[46] = temp2; - // Left-down quarter - assign_16x16_input_from_32x32(in, in16x16, 32 * 32 / 2 / 4); - write_buffer_16x16(in16x16, leftDown, stride, fliplr, flipud, shift, bd); + // stage 5 + u[9] = u[8]; + u[14] = u[15]; + + temp1 = half_btf_sse4_1(&cospim8, &u[17], &cospi56, &u[30], &rnding, bit); + u[30] = half_btf_sse4_1(&cospi56, &u[17], &cospi8, &u[30], &rnding, bit); + u[17] = temp1; + + temp2 = half_btf_sse4_1(&cospim24, &u[22], &cospim40, &u[25], &rnding, bit); + u[25] = half_btf_sse4_1(&cospim40, &u[22], &cospi24, &u[25], &rnding, bit); + u[22] = temp2; + + u[35] = u[32]; + u[34] = u[33]; + u[36] = u[39]; + u[37] = u[38]; + u[43] = u[40]; + u[42] = u[41]; + u[44] = u[47]; + u[45] = u[46]; + u[51] = u[48]; + u[50] = u[49]; + u[52] = u[55]; + u[53] = u[54]; + u[59] = u[56]; + u[58] = u[57]; + u[60] = u[63]; + u[61] = u[62]; - // Right-down quarter - assign_16x16_input_from_32x32(in, in16x16, 32 * 32 / 2 / 4 + 32 / 2 / 4); - write_buffer_16x16(in16x16, rightDown, stride, fliplr, flipud, shift, bd); -} + // stage 6 + temp1 = half_btf_0_sse4_1(&cospi32, &u[0], &rnding, bit); + u[1] = half_btf_0_sse4_1(&cospi32, &u[0], &rnding, bit); + u[0] = temp1; + + temp2 = half_btf_sse4_1(&cospim16, &u[9], &cospi48, &u[14], &rnding, bit); + u[14] = half_btf_sse4_1(&cospi48, &u[9], &cospi16, &u[14], &rnding, bit); + u[9] = temp2; + u[19] = u[16]; + u[18] = u[17]; + u[20] = u[23]; + u[21] = u[22]; + u[27] = u[24]; + u[26] = u[25]; + u[28] = u[31]; + u[29] = u[30]; + + temp1 = half_btf_sse4_1(&cospim8, &u[34], &cospi56, &u[61], &rnding, bit); + u[61] = half_btf_sse4_1(&cospi56, &u[34], &cospi8, &u[61], &rnding, bit); + u[34] = temp1; + temp2 = half_btf_sse4_1(&cospim8, &u[35], &cospi56, &u[60], &rnding, bit); + u[60] = half_btf_sse4_1(&cospi56, &u[35], &cospi8, &u[60], &rnding, bit); + u[35] = temp2; + temp1 = half_btf_sse4_1(&cospim56, &u[36], &cospim8, &u[59], &rnding, bit); + u[59] = half_btf_sse4_1(&cospim8, &u[36], &cospi56, &u[59], &rnding, bit); + u[36] = temp1; + temp2 = half_btf_sse4_1(&cospim56, &u[37], &cospim8, &u[58], &rnding, bit); + u[58] = half_btf_sse4_1(&cospim8, &u[37], &cospi56, &u[58], &rnding, bit); + u[37] = temp2; + temp1 = half_btf_sse4_1(&cospim40, &u[42], &cospi24, &u[53], &rnding, bit); + u[53] = half_btf_sse4_1(&cospi24, &u[42], &cospi40, &u[53], &rnding, bit); + u[42] = temp1; + temp2 = half_btf_sse4_1(&cospim40, &u[43], &cospi24, &u[52], &rnding, bit); + u[52] = half_btf_sse4_1(&cospi24, &u[43], &cospi40, &u[52], &rnding, bit); + u[43] = temp2; + temp1 = half_btf_sse4_1(&cospim24, &u[44], &cospim40, &u[51], &rnding, bit); + u[51] = half_btf_sse4_1(&cospim40, &u[44], &cospi24, &u[51], &rnding, bit); + u[44] = temp1; + temp2 = half_btf_sse4_1(&cospim24, &u[45], &cospim40, &u[50], &rnding, bit); + u[50] = half_btf_sse4_1(&cospim40, &u[45], &cospi24, &u[50], &rnding, bit); + u[45] = temp2; -static void assign_32x32_input_from_64x64(const __m128i *in, __m128i *in32x32, - int col) { - int i; - for (i = 0; i < 32 * 32 / 4; i += 8) { - in32x32[i] = in[col]; - in32x32[i + 1] = in[col + 1]; - in32x32[i + 2] = in[col + 2]; - in32x32[i + 3] = in[col + 3]; - in32x32[i + 4] = in[col + 4]; - in32x32[i + 5] = in[col + 5]; - in32x32[i + 6] = in[col + 6]; - in32x32[i + 7] = in[col + 7]; - col += 16; + // stage 7 + u[3] = u[0]; + u[2] = u[1]; + u[11] = u[8]; + u[10] = u[9]; + u[12] = u[15]; + u[13] = u[14]; + + temp1 = half_btf_sse4_1(&cospim16, &u[18], &cospi48, &u[29], &rnding, bit); + u[29] = half_btf_sse4_1(&cospi48, &u[18], &cospi16, &u[29], &rnding, bit); + u[18] = temp1; + temp2 = half_btf_sse4_1(&cospim16, &u[19], &cospi48, &u[28], &rnding, bit); + u[28] = half_btf_sse4_1(&cospi48, &u[19], &cospi16, &u[28], &rnding, bit); + u[19] = temp2; + temp1 = half_btf_sse4_1(&cospim48, &u[20], &cospim16, &u[27], &rnding, bit); + u[27] = half_btf_sse4_1(&cospim16, &u[20], &cospi48, &u[27], &rnding, bit); + u[20] = temp1; + temp2 = half_btf_sse4_1(&cospim48, &u[21], &cospim16, &u[26], &rnding, bit); + u[26] = half_btf_sse4_1(&cospim16, &u[21], &cospi48, &u[26], &rnding, bit); + u[21] = temp2; + for (i = 32; i < 64; i += 16) { + for (j = i; j < i + 4; j++) { + addsub_sse4_1(u[j], u[j ^ 7], &u[j], &u[j ^ 7], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[j ^ 15], u[j ^ 8], &u[j ^ 15], &u[j ^ 8], &clamp_lo, + &clamp_hi); + } + } + + // stage 8 + u[7] = u[0]; + u[6] = u[1]; + u[5] = u[2]; + u[4] = u[3]; + u[9] = u[9]; + + idct64_stage8_sse4_1(u, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16, + &cospim48, &clamp_lo, &clamp_hi, &rnding, bit); + + // stage 9 + idct64_stage9_sse4_1(u, &cospim32, &cospi32, &clamp_lo, &clamp_hi, &rnding, + bit); + + // stage 10 + idct64_stage10_sse4_1(u, &cospim32, &cospi32, &clamp_lo, &clamp_hi, &rnding, + bit); + + // stage 11 + idct64_stage11_sse4_1(u, out, do_cols, bd, out_shift, log_range); } } -static void write_buffer_64x64(__m128i *in, uint16_t *output, int stride, - int fliplr, int flipud, int shift, int bd) { - __m128i in32x32[32 * 32 / 4]; - uint16_t *leftUp = &output[0]; - uint16_t *rightUp = &output[32]; - uint16_t *leftDown = &output[32 * stride]; - uint16_t *rightDown = &output[32 * stride + 32]; +static void idct64x64_low16_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + int i, j; + const int32_t *cospi = cospi_arr(bit); + const __m128i rnding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); - if (fliplr) { - swap_addr(&leftUp, &rightUp); - swap_addr(&leftDown, &rightDown); - } + const __m128i cospi1 = _mm_set1_epi32(cospi[1]); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospi3 = _mm_set1_epi32(cospi[3]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospi5 = _mm_set1_epi32(cospi[5]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi7 = _mm_set1_epi32(cospi[7]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospi9 = _mm_set1_epi32(cospi[9]); + const __m128i cospi10 = _mm_set1_epi32(cospi[10]); + const __m128i cospi11 = _mm_set1_epi32(cospi[11]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi13 = _mm_set1_epi32(cospi[13]); + const __m128i cospi14 = _mm_set1_epi32(cospi[14]); + const __m128i cospi15 = _mm_set1_epi32(cospi[15]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospi36 = _mm_set1_epi32(cospi[36]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospi51 = _mm_set1_epi32(cospi[51]); + const __m128i cospi52 = _mm_set1_epi32(cospi[52]); + const __m128i cospi54 = _mm_set1_epi32(cospi[54]); + const __m128i cospi55 = _mm_set1_epi32(cospi[55]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi59 = _mm_set1_epi32(cospi[59]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi63 = _mm_set1_epi32(cospi[63]); - if (flipud) { - swap_addr(&leftUp, &leftDown); - swap_addr(&rightUp, &rightDown); - } + const __m128i cospim4 = _mm_set1_epi32(-cospi[4]); + const __m128i cospim8 = _mm_set1_epi32(-cospi[8]); + const __m128i cospim12 = _mm_set1_epi32(-cospi[12]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i cospim20 = _mm_set1_epi32(-cospi[20]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospim28 = _mm_set1_epi32(-cospi[28]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospim36 = _mm_set1_epi32(-cospi[36]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospim44 = _mm_set1_epi32(-cospi[44]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospim49 = _mm_set1_epi32(-cospi[49]); + const __m128i cospim50 = _mm_set1_epi32(-cospi[50]); + const __m128i cospim52 = _mm_set1_epi32(-cospi[52]); + const __m128i cospim53 = _mm_set1_epi32(-cospi[53]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospim57 = _mm_set1_epi32(-cospi[57]); + const __m128i cospim58 = _mm_set1_epi32(-cospi[58]); + const __m128i cospim60 = _mm_set1_epi32(-cospi[60]); + const __m128i cospim61 = _mm_set1_epi32(-cospi[61]); + + { + __m128i u[64]; + __m128i tmp1, tmp2, tmp3, tmp4; + // stage 1 + u[0] = in[0]; + u[32] = in[1]; + u[36] = in[9]; + u[40] = in[5]; + u[44] = in[13]; + u[48] = in[3]; + u[52] = in[11]; + u[56] = in[7]; + u[60] = in[15]; + u[16] = in[2]; + u[20] = in[10]; + u[24] = in[6]; + u[28] = in[14]; + u[4] = in[8]; + u[8] = in[4]; + u[12] = in[12]; + + // stage 2 + u[63] = half_btf_0_sse4_1(&cospi1, &u[32], &rnding, bit); + u[32] = half_btf_0_sse4_1(&cospi63, &u[32], &rnding, bit); + u[35] = half_btf_0_sse4_1(&cospim49, &u[60], &rnding, bit); + u[60] = half_btf_0_sse4_1(&cospi15, &u[60], &rnding, bit); + u[59] = half_btf_0_sse4_1(&cospi9, &u[36], &rnding, bit); + u[36] = half_btf_0_sse4_1(&cospi55, &u[36], &rnding, bit); + u[39] = half_btf_0_sse4_1(&cospim57, &u[56], &rnding, bit); + u[56] = half_btf_0_sse4_1(&cospi7, &u[56], &rnding, bit); + u[55] = half_btf_0_sse4_1(&cospi5, &u[40], &rnding, bit); + u[40] = half_btf_0_sse4_1(&cospi59, &u[40], &rnding, bit); + u[43] = half_btf_0_sse4_1(&cospim53, &u[52], &rnding, bit); + u[52] = half_btf_0_sse4_1(&cospi11, &u[52], &rnding, bit); + u[47] = half_btf_0_sse4_1(&cospim61, &u[48], &rnding, bit); + u[48] = half_btf_0_sse4_1(&cospi3, &u[48], &rnding, bit); + u[51] = half_btf_0_sse4_1(&cospi13, &u[44], &rnding, bit); + u[44] = half_btf_0_sse4_1(&cospi51, &u[44], &rnding, bit); - // Left-up quarter - assign_32x32_input_from_64x64(in, in32x32, 0); - write_buffer_32x32(in32x32, leftUp, stride, fliplr, flipud, shift, bd); + // stage 3 + u[31] = half_btf_0_sse4_1(&cospi2, &u[16], &rnding, bit); + u[16] = half_btf_0_sse4_1(&cospi62, &u[16], &rnding, bit); + u[19] = half_btf_0_sse4_1(&cospim50, &u[28], &rnding, bit); + u[28] = half_btf_0_sse4_1(&cospi14, &u[28], &rnding, bit); + u[27] = half_btf_0_sse4_1(&cospi10, &u[20], &rnding, bit); + u[20] = half_btf_0_sse4_1(&cospi54, &u[20], &rnding, bit); + u[23] = half_btf_0_sse4_1(&cospim58, &u[24], &rnding, bit); + u[24] = half_btf_0_sse4_1(&cospi6, &u[24], &rnding, bit); + u[33] = u[32]; + u[34] = u[35]; + u[37] = u[36]; + u[38] = u[39]; + u[41] = u[40]; + u[42] = u[43]; + u[45] = u[44]; + u[46] = u[47]; + u[49] = u[48]; + u[50] = u[51]; + u[53] = u[52]; + u[54] = u[55]; + u[57] = u[56]; + u[58] = u[59]; + u[61] = u[60]; + u[62] = u[63]; - // Right-up quarter - assign_32x32_input_from_64x64(in, in32x32, 64 / 2 / 4); - write_buffer_32x32(in32x32, rightUp, stride, fliplr, flipud, shift, bd); + // stage 4 + u[15] = half_btf_0_sse4_1(&cospi4, &u[8], &rnding, bit); + u[8] = half_btf_0_sse4_1(&cospi60, &u[8], &rnding, bit); + u[11] = half_btf_0_sse4_1(&cospim52, &u[12], &rnding, bit); + u[12] = half_btf_0_sse4_1(&cospi12, &u[12], &rnding, bit); + + u[17] = u[16]; + u[18] = u[19]; + u[21] = u[20]; + u[22] = u[23]; + u[25] = u[24]; + u[26] = u[27]; + u[29] = u[28]; + u[30] = u[31]; + + tmp1 = half_btf_sse4_1(&cospim4, &u[33], &cospi60, &u[62], &rnding, bit); + tmp2 = half_btf_sse4_1(&cospim60, &u[34], &cospim4, &u[61], &rnding, bit); + tmp3 = half_btf_sse4_1(&cospim36, &u[37], &cospi28, &u[58], &rnding, bit); + tmp4 = half_btf_sse4_1(&cospim28, &u[38], &cospim36, &u[57], &rnding, bit); + u[57] = half_btf_sse4_1(&cospim36, &u[38], &cospi28, &u[57], &rnding, bit); + u[58] = half_btf_sse4_1(&cospi28, &u[37], &cospi36, &u[58], &rnding, bit); + u[61] = half_btf_sse4_1(&cospim4, &u[34], &cospi60, &u[61], &rnding, bit); + u[62] = half_btf_sse4_1(&cospi60, &u[33], &cospi4, &u[62], &rnding, bit); + u[33] = tmp1; + u[34] = tmp2; + u[37] = tmp3; + u[38] = tmp4; + + tmp1 = half_btf_sse4_1(&cospim20, &u[41], &cospi44, &u[54], &rnding, bit); + tmp2 = half_btf_sse4_1(&cospim44, &u[42], &cospim20, &u[53], &rnding, bit); + tmp3 = half_btf_sse4_1(&cospim52, &u[45], &cospi12, &u[50], &rnding, bit); + tmp4 = half_btf_sse4_1(&cospim12, &u[46], &cospim52, &u[49], &rnding, bit); + u[49] = half_btf_sse4_1(&cospim52, &u[46], &cospi12, &u[49], &rnding, bit); + u[50] = half_btf_sse4_1(&cospi12, &u[45], &cospi52, &u[50], &rnding, bit); + u[53] = half_btf_sse4_1(&cospim20, &u[42], &cospi44, &u[53], &rnding, bit); + u[54] = half_btf_sse4_1(&cospi44, &u[41], &cospi20, &u[54], &rnding, bit); + u[41] = tmp1; + u[42] = tmp2; + u[45] = tmp3; + u[46] = tmp4; - // Left-down quarter - assign_32x32_input_from_64x64(in, in32x32, 64 * 64 / 2 / 4); - write_buffer_32x32(in32x32, leftDown, stride, fliplr, flipud, shift, bd); + // stage 5 + u[7] = half_btf_0_sse4_1(&cospi8, &u[4], &rnding, bit); + u[4] = half_btf_0_sse4_1(&cospi56, &u[4], &rnding, bit); + + u[9] = u[8]; + u[10] = u[11]; + u[13] = u[12]; + u[14] = u[15]; + + tmp1 = half_btf_sse4_1(&cospim8, &u[17], &cospi56, &u[30], &rnding, bit); + tmp2 = half_btf_sse4_1(&cospim56, &u[18], &cospim8, &u[29], &rnding, bit); + tmp3 = half_btf_sse4_1(&cospim40, &u[21], &cospi24, &u[26], &rnding, bit); + tmp4 = half_btf_sse4_1(&cospim24, &u[22], &cospim40, &u[25], &rnding, bit); + u[25] = half_btf_sse4_1(&cospim40, &u[22], &cospi24, &u[25], &rnding, bit); + u[26] = half_btf_sse4_1(&cospi24, &u[21], &cospi40, &u[26], &rnding, bit); + u[29] = half_btf_sse4_1(&cospim8, &u[18], &cospi56, &u[29], &rnding, bit); + u[30] = half_btf_sse4_1(&cospi56, &u[17], &cospi8, &u[30], &rnding, bit); + u[17] = tmp1; + u[18] = tmp2; + u[21] = tmp3; + u[22] = tmp4; - // Right-down quarter - assign_32x32_input_from_64x64(in, in32x32, 64 * 64 / 2 / 4 + 64 / 2 / 4); - write_buffer_32x32(in32x32, rightDown, stride, fliplr, flipud, shift, bd); + for (i = 32; i < 64; i += 8) { + addsub_sse4_1(u[i + 0], u[i + 3], &u[i + 0], &u[i + 3], &clamp_lo, + &clamp_hi); + addsub_sse4_1(u[i + 1], u[i + 2], &u[i + 1], &u[i + 2], &clamp_lo, + &clamp_hi); + + addsub_sse4_1(u[i + 7], u[i + 4], &u[i + 7], &u[i + 4], &clamp_lo, + &clamp_hi); + addsub_sse4_1(u[i + 6], u[i + 5], &u[i + 6], &u[i + 5], &clamp_lo, + &clamp_hi); + } + + // stage 6 + tmp1 = half_btf_0_sse4_1(&cospi32, &u[0], &rnding, bit); + u[1] = half_btf_0_sse4_1(&cospi32, &u[0], &rnding, bit); + u[0] = tmp1; + u[5] = u[4]; + u[6] = u[7]; + + tmp1 = half_btf_sse4_1(&cospim16, &u[9], &cospi48, &u[14], &rnding, bit); + u[14] = half_btf_sse4_1(&cospi48, &u[9], &cospi16, &u[14], &rnding, bit); + u[9] = tmp1; + tmp2 = half_btf_sse4_1(&cospim48, &u[10], &cospim16, &u[13], &rnding, bit); + u[13] = half_btf_sse4_1(&cospim16, &u[10], &cospi48, &u[13], &rnding, bit); + u[10] = tmp2; + + for (i = 16; i < 32; i += 8) { + addsub_sse4_1(u[i + 0], u[i + 3], &u[i + 0], &u[i + 3], &clamp_lo, + &clamp_hi); + addsub_sse4_1(u[i + 1], u[i + 2], &u[i + 1], &u[i + 2], &clamp_lo, + &clamp_hi); + + addsub_sse4_1(u[i + 7], u[i + 4], &u[i + 7], &u[i + 4], &clamp_lo, + &clamp_hi); + addsub_sse4_1(u[i + 6], u[i + 5], &u[i + 6], &u[i + 5], &clamp_lo, + &clamp_hi); + } + + tmp1 = half_btf_sse4_1(&cospim8, &u[34], &cospi56, &u[61], &rnding, bit); + tmp2 = half_btf_sse4_1(&cospim8, &u[35], &cospi56, &u[60], &rnding, bit); + tmp3 = half_btf_sse4_1(&cospim56, &u[36], &cospim8, &u[59], &rnding, bit); + tmp4 = half_btf_sse4_1(&cospim56, &u[37], &cospim8, &u[58], &rnding, bit); + u[58] = half_btf_sse4_1(&cospim8, &u[37], &cospi56, &u[58], &rnding, bit); + u[59] = half_btf_sse4_1(&cospim8, &u[36], &cospi56, &u[59], &rnding, bit); + u[60] = half_btf_sse4_1(&cospi56, &u[35], &cospi8, &u[60], &rnding, bit); + u[61] = half_btf_sse4_1(&cospi56, &u[34], &cospi8, &u[61], &rnding, bit); + u[34] = tmp1; + u[35] = tmp2; + u[36] = tmp3; + u[37] = tmp4; + + tmp1 = half_btf_sse4_1(&cospim40, &u[42], &cospi24, &u[53], &rnding, bit); + tmp2 = half_btf_sse4_1(&cospim40, &u[43], &cospi24, &u[52], &rnding, bit); + tmp3 = half_btf_sse4_1(&cospim24, &u[44], &cospim40, &u[51], &rnding, bit); + tmp4 = half_btf_sse4_1(&cospim24, &u[45], &cospim40, &u[50], &rnding, bit); + u[50] = half_btf_sse4_1(&cospim40, &u[45], &cospi24, &u[50], &rnding, bit); + u[51] = half_btf_sse4_1(&cospim40, &u[44], &cospi24, &u[51], &rnding, bit); + u[52] = half_btf_sse4_1(&cospi24, &u[43], &cospi40, &u[52], &rnding, bit); + u[53] = half_btf_sse4_1(&cospi24, &u[42], &cospi40, &u[53], &rnding, bit); + u[42] = tmp1; + u[43] = tmp2; + u[44] = tmp3; + u[45] = tmp4; + + // stage 7 + u[3] = u[0]; + u[2] = u[1]; + tmp1 = half_btf_sse4_1(&cospim32, &u[5], &cospi32, &u[6], &rnding, bit); + u[6] = half_btf_sse4_1(&cospi32, &u[5], &cospi32, &u[6], &rnding, bit); + u[5] = tmp1; + addsub_sse4_1(u[8], u[11], &u[8], &u[11], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[9], u[10], &u[9], &u[10], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[15], u[12], &u[15], &u[12], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[14], u[13], &u[14], &u[13], &clamp_lo, &clamp_hi); + + tmp1 = half_btf_sse4_1(&cospim16, &u[18], &cospi48, &u[29], &rnding, bit); + tmp2 = half_btf_sse4_1(&cospim16, &u[19], &cospi48, &u[28], &rnding, bit); + tmp3 = half_btf_sse4_1(&cospim48, &u[20], &cospim16, &u[27], &rnding, bit); + tmp4 = half_btf_sse4_1(&cospim48, &u[21], &cospim16, &u[26], &rnding, bit); + u[26] = half_btf_sse4_1(&cospim16, &u[21], &cospi48, &u[26], &rnding, bit); + u[27] = half_btf_sse4_1(&cospim16, &u[20], &cospi48, &u[27], &rnding, bit); + u[28] = half_btf_sse4_1(&cospi48, &u[19], &cospi16, &u[28], &rnding, bit); + u[29] = half_btf_sse4_1(&cospi48, &u[18], &cospi16, &u[29], &rnding, bit); + u[18] = tmp1; + u[19] = tmp2; + u[20] = tmp3; + u[21] = tmp4; + + for (i = 32; i < 64; i += 16) { + for (j = i; j < i + 4; j++) { + addsub_sse4_1(u[j], u[j ^ 7], &u[j], &u[j ^ 7], &clamp_lo, &clamp_hi); + addsub_sse4_1(u[j ^ 15], u[j ^ 8], &u[j ^ 15], &u[j ^ 8], &clamp_lo, + &clamp_hi); + } + } + + // stage 8 + for (i = 0; i < 4; ++i) { + addsub_sse4_1(u[i], u[7 - i], &u[i], &u[7 - i], &clamp_lo, &clamp_hi); + } + + idct64_stage8_sse4_1(u, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16, + &cospim48, &clamp_lo, &clamp_hi, &rnding, bit); + + // stage 9 + idct64_stage9_sse4_1(u, &cospim32, &cospi32, &clamp_lo, &clamp_hi, &rnding, + bit); + + // stage 10 + idct64_stage10_sse4_1(u, &cospim32, &cospi32, &clamp_lo, &clamp_hi, &rnding, + bit); + + // stage 11 + idct64_stage11_sse4_1(u, out, do_cols, bd, out_shift, log_range); + } } static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, @@ -1847,7 +3770,6 @@ static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); - int col; const __m128i cospi1 = _mm_set1_epi32(cospi[1]); const __m128i cospi2 = _mm_set1_epi32(cospi[2]); @@ -1929,46 +3851,46 @@ static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, const __m128i cospim60 = _mm_set1_epi32(-cospi[60]); const __m128i cospim61 = _mm_set1_epi32(-cospi[61]); - for (col = 0; col < (do_cols ? 64 / 4 : 32 / 4); ++col) { + { __m128i u[64], v[64]; // stage 1 - u[32] = in[1 * 16 + col]; - u[34] = in[17 * 16 + col]; - u[36] = in[9 * 16 + col]; - u[38] = in[25 * 16 + col]; - u[40] = in[5 * 16 + col]; - u[42] = in[21 * 16 + col]; - u[44] = in[13 * 16 + col]; - u[46] = in[29 * 16 + col]; - u[48] = in[3 * 16 + col]; - u[50] = in[19 * 16 + col]; - u[52] = in[11 * 16 + col]; - u[54] = in[27 * 16 + col]; - u[56] = in[7 * 16 + col]; - u[58] = in[23 * 16 + col]; - u[60] = in[15 * 16 + col]; - u[62] = in[31 * 16 + col]; - - v[16] = in[2 * 16 + col]; - v[18] = in[18 * 16 + col]; - v[20] = in[10 * 16 + col]; - v[22] = in[26 * 16 + col]; - v[24] = in[6 * 16 + col]; - v[26] = in[22 * 16 + col]; - v[28] = in[14 * 16 + col]; - v[30] = in[30 * 16 + col]; - - u[8] = in[4 * 16 + col]; - u[10] = in[20 * 16 + col]; - u[12] = in[12 * 16 + col]; - u[14] = in[28 * 16 + col]; - - v[4] = in[8 * 16 + col]; - v[6] = in[24 * 16 + col]; - - u[0] = in[0 * 16 + col]; - u[2] = in[16 * 16 + col]; + u[32] = in[1]; + u[34] = in[17]; + u[36] = in[9]; + u[38] = in[25]; + u[40] = in[5]; + u[42] = in[21]; + u[44] = in[13]; + u[46] = in[29]; + u[48] = in[3]; + u[50] = in[19]; + u[52] = in[11]; + u[54] = in[27]; + u[56] = in[7]; + u[58] = in[23]; + u[60] = in[15]; + u[62] = in[31]; + + v[16] = in[2]; + v[18] = in[18]; + v[20] = in[10]; + v[22] = in[26]; + v[24] = in[6]; + v[26] = in[22]; + v[28] = in[14]; + v[30] = in[30]; + + u[8] = in[4]; + u[10] = in[20]; + u[12] = in[12]; + u[14] = in[28]; + + v[4] = in[8]; + v[6] = in[24]; + + u[0] = in[0]; + u[2] = in[16]; // stage 2 v[32] = half_btf_0_sse4_1(&cospi63, &u[32], &rnding, bit); @@ -2301,39 +4223,1126 @@ static void idct64x64_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, // stage 11 if (do_cols) { for (i = 0; i < 32; i++) { - addsub_no_clamp_sse4_1(v[i], v[63 - i], &out[16 * (i) + col], - &out[16 * (63 - i) + col]); + addsub_no_clamp_sse4_1(v[i], v[63 - i], &out[(i)], &out[(63 - i)]); } } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + for (i = 0; i < 32; i++) { - addsub_shift_sse4_1(v[i], v[63 - i], &out[16 * (i) + col], - &out[16 * (63 - i) + col], &clamp_lo, &clamp_hi, - out_shift); + addsub_shift_sse4_1(v[i], v[63 - i], &out[(i)], &out[(63 - i)], + &clamp_lo_out, &clamp_hi_out, out_shift); } } } } -void av1_inv_txfm2d_add_64x64_sse4_1(const int32_t *coeff, uint16_t *output, - int stride, TX_TYPE tx_type, int bd) { - __m128i in[64 * 64 / 4], out[64 * 64 / 4]; - const int8_t *shift = inv_txfm_shift_ls[TX_64X64]; - const int txw_idx = tx_size_wide_log2[TX_64X64] - tx_size_wide_log2[0]; - const int txh_idx = tx_size_high_log2[TX_64X64] - tx_size_high_log2[0]; +static void idct32x32_low1_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i rounding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i bf1; + + // stage 0 + // stage 1 + bf1 = in[0]; + + // stage 2 + // stage 3 + // stage 4 + // stage 5 + bf1 = half_btf_0_sse4_1(&cospi32, &bf1, &rounding, bit); + + // stage 6 + // stage 7 + // stage 8 + // stage 9 + if (do_cols) { + bf1 = _mm_max_epi32(bf1, clamp_lo); + bf1 = _mm_min_epi32(bf1, clamp_hi); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + __m128i offset = _mm_set1_epi32((1 << out_shift) >> 1); + bf1 = _mm_add_epi32(bf1, offset); + bf1 = _mm_sra_epi32(bf1, _mm_cvtsi32_si128(out_shift)); + bf1 = _mm_max_epi32(bf1, clamp_lo_out); + bf1 = _mm_min_epi32(bf1, clamp_hi_out); + } + out[0] = bf1; + out[1] = bf1; + out[2] = bf1; + out[3] = bf1; + out[4] = bf1; + out[5] = bf1; + out[6] = bf1; + out[7] = bf1; + out[8] = bf1; + out[9] = bf1; + out[10] = bf1; + out[11] = bf1; + out[12] = bf1; + out[13] = bf1; + out[14] = bf1; + out[15] = bf1; + out[16] = bf1; + out[17] = bf1; + out[18] = bf1; + out[19] = bf1; + out[20] = bf1; + out[21] = bf1; + out[22] = bf1; + out[23] = bf1; + out[24] = bf1; + out[25] = bf1; + out[26] = bf1; + out[27] = bf1; + out[28] = bf1; + out[29] = bf1; + out[30] = bf1; + out[31] = bf1; +} + +static void idct32x32_low8_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi14 = _mm_set1_epi32(cospi[14]); + const __m128i cospi54 = _mm_set1_epi32(cospi[54]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi10 = _mm_set1_epi32(cospi[10]); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospim58 = _mm_set1_epi32(-cospi[58]); + const __m128i cospim50 = _mm_set1_epi32(-cospi[50]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospim52 = _mm_set1_epi32(-cospi[52]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospim8 = _mm_set1_epi32(-cospi[8]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i rounding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i bf1[32]; + + // stage 0 + // stage 1 + bf1[0] = in[0]; + bf1[4] = in[4]; + bf1[8] = in[2]; + bf1[12] = in[6]; + bf1[16] = in[1]; + bf1[20] = in[5]; + bf1[24] = in[3]; + bf1[28] = in[7]; + + // stage 2 + bf1[31] = half_btf_0_sse4_1(&cospi2, &bf1[16], &rounding, bit); + bf1[16] = half_btf_0_sse4_1(&cospi62, &bf1[16], &rounding, bit); + bf1[19] = half_btf_0_sse4_1(&cospim50, &bf1[28], &rounding, bit); + bf1[28] = half_btf_0_sse4_1(&cospi14, &bf1[28], &rounding, bit); + bf1[27] = half_btf_0_sse4_1(&cospi10, &bf1[20], &rounding, bit); + bf1[20] = half_btf_0_sse4_1(&cospi54, &bf1[20], &rounding, bit); + bf1[23] = half_btf_0_sse4_1(&cospim58, &bf1[24], &rounding, bit); + bf1[24] = half_btf_0_sse4_1(&cospi6, &bf1[24], &rounding, bit); + + // stage 3 + bf1[15] = half_btf_0_sse4_1(&cospi4, &bf1[8], &rounding, bit); + bf1[8] = half_btf_0_sse4_1(&cospi60, &bf1[8], &rounding, bit); + + bf1[11] = half_btf_0_sse4_1(&cospim52, &bf1[12], &rounding, bit); + bf1[12] = half_btf_0_sse4_1(&cospi12, &bf1[12], &rounding, bit); + bf1[17] = bf1[16]; + bf1[18] = bf1[19]; + bf1[21] = bf1[20]; + bf1[22] = bf1[23]; + bf1[25] = bf1[24]; + bf1[26] = bf1[27]; + bf1[29] = bf1[28]; + bf1[30] = bf1[31]; + + // stage 4 : + bf1[7] = half_btf_0_sse4_1(&cospi8, &bf1[4], &rounding, bit); + bf1[4] = half_btf_0_sse4_1(&cospi56, &bf1[4], &rounding, bit); + + bf1[9] = bf1[8]; + bf1[10] = bf1[11]; + bf1[13] = bf1[12]; + bf1[14] = bf1[15]; + + idct32_stage4_sse4_1(bf1, &cospim8, &cospi56, &cospi8, &cospim56, &cospim40, + &cospi24, &cospi40, &cospim24, &rounding, bit); + + // stage 5 + bf1[0] = half_btf_0_sse4_1(&cospi32, &bf1[0], &rounding, bit); + bf1[1] = bf1[0]; + bf1[5] = bf1[4]; + bf1[6] = bf1[7]; + + idct32_stage5_sse4_1(bf1, &cospim16, &cospi48, &cospi16, &cospim48, &clamp_lo, + &clamp_hi, &rounding, bit); + + // stage 6 + bf1[3] = bf1[0]; + bf1[2] = bf1[1]; + + idct32_stage6_sse4_1(bf1, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16, + &cospim48, &clamp_lo, &clamp_hi, &rounding, bit); + + // stage 7 + idct32_stage7_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 8 + idct32_stage8_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 9 + idct32_stage9_sse4_1(bf1, out, do_cols, bd, out_shift, log_range); +} + +static void idct32x32_low16_sse4_1(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi30 = _mm_set1_epi32(cospi[30]); + const __m128i cospi46 = _mm_set1_epi32(cospi[46]); + const __m128i cospi14 = _mm_set1_epi32(cospi[14]); + const __m128i cospi54 = _mm_set1_epi32(cospi[54]); + const __m128i cospi22 = _mm_set1_epi32(cospi[22]); + const __m128i cospi38 = _mm_set1_epi32(cospi[38]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi26 = _mm_set1_epi32(cospi[26]); + const __m128i cospi10 = _mm_set1_epi32(cospi[10]); + const __m128i cospi18 = _mm_set1_epi32(cospi[18]); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospim58 = _mm_set1_epi32(-cospi[58]); + const __m128i cospim42 = _mm_set1_epi32(-cospi[42]); + const __m128i cospim50 = _mm_set1_epi32(-cospi[50]); + const __m128i cospim34 = _mm_set1_epi32(-cospi[34]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospim52 = _mm_set1_epi32(-cospi[52]); + const __m128i cospim36 = _mm_set1_epi32(-cospi[36]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospim8 = _mm_set1_epi32(-cospi[8]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i rounding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i bf1[32]; + + // stage 0 + // stage 1 + + bf1[0] = in[0]; + bf1[2] = in[8]; + bf1[4] = in[4]; + bf1[6] = in[12]; + bf1[8] = in[2]; + bf1[10] = in[10]; + bf1[12] = in[6]; + bf1[14] = in[14]; + bf1[16] = in[1]; + bf1[18] = in[9]; + bf1[20] = in[5]; + bf1[22] = in[13]; + bf1[24] = in[3]; + bf1[26] = in[11]; + bf1[28] = in[7]; + bf1[30] = in[15]; + + // stage 2 + bf1[31] = half_btf_0_sse4_1(&cospi2, &bf1[16], &rounding, bit); + bf1[16] = half_btf_0_sse4_1(&cospi62, &bf1[16], &rounding, bit); + bf1[17] = half_btf_0_sse4_1(&cospim34, &bf1[30], &rounding, bit); + bf1[30] = half_btf_0_sse4_1(&cospi30, &bf1[30], &rounding, bit); + bf1[29] = half_btf_0_sse4_1(&cospi18, &bf1[18], &rounding, bit); + bf1[18] = half_btf_0_sse4_1(&cospi46, &bf1[18], &rounding, bit); + bf1[19] = half_btf_0_sse4_1(&cospim50, &bf1[28], &rounding, bit); + bf1[28] = half_btf_0_sse4_1(&cospi14, &bf1[28], &rounding, bit); + bf1[27] = half_btf_0_sse4_1(&cospi10, &bf1[20], &rounding, bit); + bf1[20] = half_btf_0_sse4_1(&cospi54, &bf1[20], &rounding, bit); + bf1[21] = half_btf_0_sse4_1(&cospim42, &bf1[26], &rounding, bit); + bf1[26] = half_btf_0_sse4_1(&cospi22, &bf1[26], &rounding, bit); + bf1[25] = half_btf_0_sse4_1(&cospi26, &bf1[22], &rounding, bit); + bf1[22] = half_btf_0_sse4_1(&cospi38, &bf1[22], &rounding, bit); + bf1[23] = half_btf_0_sse4_1(&cospim58, &bf1[24], &rounding, bit); + bf1[24] = half_btf_0_sse4_1(&cospi6, &bf1[24], &rounding, bit); + + // stage 3 + bf1[15] = half_btf_0_sse4_1(&cospi4, &bf1[8], &rounding, bit); + bf1[8] = half_btf_0_sse4_1(&cospi60, &bf1[8], &rounding, bit); + bf1[9] = half_btf_0_sse4_1(&cospim36, &bf1[14], &rounding, bit); + bf1[14] = half_btf_0_sse4_1(&cospi28, &bf1[14], &rounding, bit); + bf1[13] = half_btf_0_sse4_1(&cospi20, &bf1[10], &rounding, bit); + bf1[10] = half_btf_0_sse4_1(&cospi44, &bf1[10], &rounding, bit); + bf1[11] = half_btf_0_sse4_1(&cospim52, &bf1[12], &rounding, bit); + bf1[12] = half_btf_0_sse4_1(&cospi12, &bf1[12], &rounding, bit); + + addsub_sse4_1(bf1[16], bf1[17], bf1 + 16, bf1 + 17, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[19], bf1[18], bf1 + 19, bf1 + 18, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[20], bf1[21], bf1 + 20, bf1 + 21, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[23], bf1[22], bf1 + 23, bf1 + 22, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[24], bf1[25], bf1 + 24, bf1 + 25, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[27], bf1[26], bf1 + 27, bf1 + 26, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[28], bf1[29], bf1 + 28, bf1 + 29, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[31], bf1[30], bf1 + 31, bf1 + 30, &clamp_lo, &clamp_hi); + // stage 4 + bf1[7] = half_btf_0_sse4_1(&cospi8, &bf1[4], &rounding, bit); + bf1[4] = half_btf_0_sse4_1(&cospi56, &bf1[4], &rounding, bit); + bf1[5] = half_btf_0_sse4_1(&cospim40, &bf1[6], &rounding, bit); + bf1[6] = half_btf_0_sse4_1(&cospi24, &bf1[6], &rounding, bit); + + addsub_sse4_1(bf1[8], bf1[9], bf1 + 8, bf1 + 9, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[11], bf1[10], bf1 + 11, bf1 + 10, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[12], bf1[13], bf1 + 12, bf1 + 13, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[15], bf1[14], bf1 + 15, bf1 + 14, &clamp_lo, &clamp_hi); + + idct32_stage4_sse4_1(bf1, &cospim8, &cospi56, &cospi8, &cospim56, &cospim40, + &cospi24, &cospi40, &cospim24, &rounding, bit); + + // stage 5 + bf1[0] = half_btf_0_sse4_1(&cospi32, &bf1[0], &rounding, bit); + bf1[1] = bf1[0]; + bf1[3] = half_btf_0_sse4_1(&cospi16, &bf1[2], &rounding, bit); + bf1[2] = half_btf_0_sse4_1(&cospi48, &bf1[2], &rounding, bit); + + addsub_sse4_1(bf1[4], bf1[5], bf1 + 4, bf1 + 5, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[7], bf1[6], bf1 + 7, bf1 + 6, &clamp_lo, &clamp_hi); + + idct32_stage5_sse4_1(bf1, &cospim16, &cospi48, &cospi16, &cospim48, &clamp_lo, + &clamp_hi, &rounding, bit); + // stage 6 + addsub_sse4_1(bf1[0], bf1[3], bf1 + 0, bf1 + 3, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[1], bf1[2], bf1 + 1, bf1 + 2, &clamp_lo, &clamp_hi); + + idct32_stage6_sse4_1(bf1, &cospim32, &cospi32, &cospim16, &cospi48, &cospi16, + &cospim48, &clamp_lo, &clamp_hi, &rounding, bit); + + // stage 7 + idct32_stage7_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 8 + idct32_stage8_sse4_1(bf1, &cospim32, &cospi32, &clamp_lo, &clamp_hi, + &rounding, bit); + + // stage 9 + idct32_stage9_sse4_1(bf1, out, do_cols, bd, out_shift, log_range); +} + +static void idct32x32_sse4_1(__m128i *in, __m128i *out, int bit, int do_cols, + int bd, int out_shift) { + const int32_t *cospi = cospi_arr(bit); + const __m128i cospi62 = _mm_set1_epi32(cospi[62]); + const __m128i cospi30 = _mm_set1_epi32(cospi[30]); + const __m128i cospi46 = _mm_set1_epi32(cospi[46]); + const __m128i cospi14 = _mm_set1_epi32(cospi[14]); + const __m128i cospi54 = _mm_set1_epi32(cospi[54]); + const __m128i cospi22 = _mm_set1_epi32(cospi[22]); + const __m128i cospi38 = _mm_set1_epi32(cospi[38]); + const __m128i cospi6 = _mm_set1_epi32(cospi[6]); + const __m128i cospi58 = _mm_set1_epi32(cospi[58]); + const __m128i cospi26 = _mm_set1_epi32(cospi[26]); + const __m128i cospi42 = _mm_set1_epi32(cospi[42]); + const __m128i cospi10 = _mm_set1_epi32(cospi[10]); + const __m128i cospi50 = _mm_set1_epi32(cospi[50]); + const __m128i cospi18 = _mm_set1_epi32(cospi[18]); + const __m128i cospi34 = _mm_set1_epi32(cospi[34]); + const __m128i cospi2 = _mm_set1_epi32(cospi[2]); + const __m128i cospim58 = _mm_set1_epi32(-cospi[58]); + const __m128i cospim26 = _mm_set1_epi32(-cospi[26]); + const __m128i cospim42 = _mm_set1_epi32(-cospi[42]); + const __m128i cospim10 = _mm_set1_epi32(-cospi[10]); + const __m128i cospim50 = _mm_set1_epi32(-cospi[50]); + const __m128i cospim18 = _mm_set1_epi32(-cospi[18]); + const __m128i cospim34 = _mm_set1_epi32(-cospi[34]); + const __m128i cospim2 = _mm_set1_epi32(-cospi[2]); + const __m128i cospi60 = _mm_set1_epi32(cospi[60]); + const __m128i cospi28 = _mm_set1_epi32(cospi[28]); + const __m128i cospi44 = _mm_set1_epi32(cospi[44]); + const __m128i cospi12 = _mm_set1_epi32(cospi[12]); + const __m128i cospi52 = _mm_set1_epi32(cospi[52]); + const __m128i cospi20 = _mm_set1_epi32(cospi[20]); + const __m128i cospi36 = _mm_set1_epi32(cospi[36]); + const __m128i cospi4 = _mm_set1_epi32(cospi[4]); + const __m128i cospim52 = _mm_set1_epi32(-cospi[52]); + const __m128i cospim20 = _mm_set1_epi32(-cospi[20]); + const __m128i cospim36 = _mm_set1_epi32(-cospi[36]); + const __m128i cospim4 = _mm_set1_epi32(-cospi[4]); + const __m128i cospi56 = _mm_set1_epi32(cospi[56]); + const __m128i cospi24 = _mm_set1_epi32(cospi[24]); + const __m128i cospi40 = _mm_set1_epi32(cospi[40]); + const __m128i cospi8 = _mm_set1_epi32(cospi[8]); + const __m128i cospim40 = _mm_set1_epi32(-cospi[40]); + const __m128i cospim8 = _mm_set1_epi32(-cospi[8]); + const __m128i cospim56 = _mm_set1_epi32(-cospi[56]); + const __m128i cospim24 = _mm_set1_epi32(-cospi[24]); + const __m128i cospi32 = _mm_set1_epi32(cospi[32]); + const __m128i cospim32 = _mm_set1_epi32(-cospi[32]); + const __m128i cospi48 = _mm_set1_epi32(cospi[48]); + const __m128i cospim48 = _mm_set1_epi32(-cospi[48]); + const __m128i cospi16 = _mm_set1_epi32(cospi[16]); + const __m128i cospim16 = _mm_set1_epi32(-cospi[16]); + const __m128i rounding = _mm_set1_epi32(1 << (bit - 1)); + const int log_range = AOMMAX(16, bd + (do_cols ? 6 : 8)); + const __m128i clamp_lo = _mm_set1_epi32(-(1 << (log_range - 1))); + const __m128i clamp_hi = _mm_set1_epi32((1 << (log_range - 1)) - 1); + __m128i bf1[32], bf0[32]; + + // stage 0 + // stage 1 + bf1[0] = in[0]; + bf1[1] = in[16]; + bf1[2] = in[8]; + bf1[3] = in[24]; + bf1[4] = in[4]; + bf1[5] = in[20]; + bf1[6] = in[12]; + bf1[7] = in[28]; + bf1[8] = in[2]; + bf1[9] = in[18]; + bf1[10] = in[10]; + bf1[11] = in[26]; + bf1[12] = in[6]; + bf1[13] = in[22]; + bf1[14] = in[14]; + bf1[15] = in[30]; + bf1[16] = in[1]; + bf1[17] = in[17]; + bf1[18] = in[9]; + bf1[19] = in[25]; + bf1[20] = in[5]; + bf1[21] = in[21]; + bf1[22] = in[13]; + bf1[23] = in[29]; + bf1[24] = in[3]; + bf1[25] = in[19]; + bf1[26] = in[11]; + bf1[27] = in[27]; + bf1[28] = in[7]; + bf1[29] = in[23]; + bf1[30] = in[15]; + bf1[31] = in[31]; + + // stage 2 + bf0[0] = bf1[0]; + bf0[1] = bf1[1]; + bf0[2] = bf1[2]; + bf0[3] = bf1[3]; + bf0[4] = bf1[4]; + bf0[5] = bf1[5]; + bf0[6] = bf1[6]; + bf0[7] = bf1[7]; + bf0[8] = bf1[8]; + bf0[9] = bf1[9]; + bf0[10] = bf1[10]; + bf0[11] = bf1[11]; + bf0[12] = bf1[12]; + bf0[13] = bf1[13]; + bf0[14] = bf1[14]; + bf0[15] = bf1[15]; + bf0[16] = + half_btf_sse4_1(&cospi62, &bf1[16], &cospim2, &bf1[31], &rounding, bit); + bf0[17] = + half_btf_sse4_1(&cospi30, &bf1[17], &cospim34, &bf1[30], &rounding, bit); + bf0[18] = + half_btf_sse4_1(&cospi46, &bf1[18], &cospim18, &bf1[29], &rounding, bit); + bf0[19] = + half_btf_sse4_1(&cospi14, &bf1[19], &cospim50, &bf1[28], &rounding, bit); + bf0[20] = + half_btf_sse4_1(&cospi54, &bf1[20], &cospim10, &bf1[27], &rounding, bit); + bf0[21] = + half_btf_sse4_1(&cospi22, &bf1[21], &cospim42, &bf1[26], &rounding, bit); + bf0[22] = + half_btf_sse4_1(&cospi38, &bf1[22], &cospim26, &bf1[25], &rounding, bit); + bf0[23] = + half_btf_sse4_1(&cospi6, &bf1[23], &cospim58, &bf1[24], &rounding, bit); + bf0[24] = + half_btf_sse4_1(&cospi58, &bf1[23], &cospi6, &bf1[24], &rounding, bit); + bf0[25] = + half_btf_sse4_1(&cospi26, &bf1[22], &cospi38, &bf1[25], &rounding, bit); + bf0[26] = + half_btf_sse4_1(&cospi42, &bf1[21], &cospi22, &bf1[26], &rounding, bit); + bf0[27] = + half_btf_sse4_1(&cospi10, &bf1[20], &cospi54, &bf1[27], &rounding, bit); + bf0[28] = + half_btf_sse4_1(&cospi50, &bf1[19], &cospi14, &bf1[28], &rounding, bit); + bf0[29] = + half_btf_sse4_1(&cospi18, &bf1[18], &cospi46, &bf1[29], &rounding, bit); + bf0[30] = + half_btf_sse4_1(&cospi34, &bf1[17], &cospi30, &bf1[30], &rounding, bit); + bf0[31] = + half_btf_sse4_1(&cospi2, &bf1[16], &cospi62, &bf1[31], &rounding, bit); + + // stage 3 + bf1[0] = bf0[0]; + bf1[1] = bf0[1]; + bf1[2] = bf0[2]; + bf1[3] = bf0[3]; + bf1[4] = bf0[4]; + bf1[5] = bf0[5]; + bf1[6] = bf0[6]; + bf1[7] = bf0[7]; + bf1[8] = + half_btf_sse4_1(&cospi60, &bf0[8], &cospim4, &bf0[15], &rounding, bit); + bf1[9] = + half_btf_sse4_1(&cospi28, &bf0[9], &cospim36, &bf0[14], &rounding, bit); + bf1[10] = + half_btf_sse4_1(&cospi44, &bf0[10], &cospim20, &bf0[13], &rounding, bit); + bf1[11] = + half_btf_sse4_1(&cospi12, &bf0[11], &cospim52, &bf0[12], &rounding, bit); + bf1[12] = + half_btf_sse4_1(&cospi52, &bf0[11], &cospi12, &bf0[12], &rounding, bit); + bf1[13] = + half_btf_sse4_1(&cospi20, &bf0[10], &cospi44, &bf0[13], &rounding, bit); + bf1[14] = + half_btf_sse4_1(&cospi36, &bf0[9], &cospi28, &bf0[14], &rounding, bit); + bf1[15] = + half_btf_sse4_1(&cospi4, &bf0[8], &cospi60, &bf0[15], &rounding, bit); + + addsub_sse4_1(bf0[16], bf0[17], bf1 + 16, bf1 + 17, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[19], bf0[18], bf1 + 19, bf1 + 18, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[20], bf0[21], bf1 + 20, bf1 + 21, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[23], bf0[22], bf1 + 23, bf1 + 22, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[24], bf0[25], bf1 + 24, bf1 + 25, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[27], bf0[26], bf1 + 27, bf1 + 26, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[28], bf0[29], bf1 + 28, bf1 + 29, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[31], bf0[30], bf1 + 31, bf1 + 30, &clamp_lo, &clamp_hi); + + // stage 4 + bf0[0] = bf1[0]; + bf0[1] = bf1[1]; + bf0[2] = bf1[2]; + bf0[3] = bf1[3]; + bf0[4] = + half_btf_sse4_1(&cospi56, &bf1[4], &cospim8, &bf1[7], &rounding, bit); + bf0[5] = + half_btf_sse4_1(&cospi24, &bf1[5], &cospim40, &bf1[6], &rounding, bit); + bf0[6] = + half_btf_sse4_1(&cospi40, &bf1[5], &cospi24, &bf1[6], &rounding, bit); + bf0[7] = half_btf_sse4_1(&cospi8, &bf1[4], &cospi56, &bf1[7], &rounding, bit); + + addsub_sse4_1(bf1[8], bf1[9], bf0 + 8, bf0 + 9, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[11], bf1[10], bf0 + 11, bf0 + 10, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[12], bf1[13], bf0 + 12, bf0 + 13, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[15], bf1[14], bf0 + 15, bf0 + 14, &clamp_lo, &clamp_hi); + + bf0[16] = bf1[16]; + bf0[17] = + half_btf_sse4_1(&cospim8, &bf1[17], &cospi56, &bf1[30], &rounding, bit); + bf0[18] = + half_btf_sse4_1(&cospim56, &bf1[18], &cospim8, &bf1[29], &rounding, bit); + bf0[19] = bf1[19]; + bf0[20] = bf1[20]; + bf0[21] = + half_btf_sse4_1(&cospim40, &bf1[21], &cospi24, &bf1[26], &rounding, bit); + bf0[22] = + half_btf_sse4_1(&cospim24, &bf1[22], &cospim40, &bf1[25], &rounding, bit); + bf0[23] = bf1[23]; + bf0[24] = bf1[24]; + bf0[25] = + half_btf_sse4_1(&cospim40, &bf1[22], &cospi24, &bf1[25], &rounding, bit); + bf0[26] = + half_btf_sse4_1(&cospi24, &bf1[21], &cospi40, &bf1[26], &rounding, bit); + bf0[27] = bf1[27]; + bf0[28] = bf1[28]; + bf0[29] = + half_btf_sse4_1(&cospim8, &bf1[18], &cospi56, &bf1[29], &rounding, bit); + bf0[30] = + half_btf_sse4_1(&cospi56, &bf1[17], &cospi8, &bf1[30], &rounding, bit); + bf0[31] = bf1[31]; + + // stage 5 + bf1[0] = + half_btf_sse4_1(&cospi32, &bf0[0], &cospi32, &bf0[1], &rounding, bit); + bf1[1] = + half_btf_sse4_1(&cospi32, &bf0[0], &cospim32, &bf0[1], &rounding, bit); + bf1[2] = + half_btf_sse4_1(&cospi48, &bf0[2], &cospim16, &bf0[3], &rounding, bit); + bf1[3] = + half_btf_sse4_1(&cospi16, &bf0[2], &cospi48, &bf0[3], &rounding, bit); + addsub_sse4_1(bf0[4], bf0[5], bf1 + 4, bf1 + 5, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[7], bf0[6], bf1 + 7, bf1 + 6, &clamp_lo, &clamp_hi); + bf1[8] = bf0[8]; + bf1[9] = + half_btf_sse4_1(&cospim16, &bf0[9], &cospi48, &bf0[14], &rounding, bit); + bf1[10] = + half_btf_sse4_1(&cospim48, &bf0[10], &cospim16, &bf0[13], &rounding, bit); + bf1[11] = bf0[11]; + bf1[12] = bf0[12]; + bf1[13] = + half_btf_sse4_1(&cospim16, &bf0[10], &cospi48, &bf0[13], &rounding, bit); + bf1[14] = + half_btf_sse4_1(&cospi48, &bf0[9], &cospi16, &bf0[14], &rounding, bit); + bf1[15] = bf0[15]; + addsub_sse4_1(bf0[16], bf0[19], bf1 + 16, bf1 + 19, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[17], bf0[18], bf1 + 17, bf1 + 18, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[23], bf0[20], bf1 + 23, bf1 + 20, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[22], bf0[21], bf1 + 22, bf1 + 21, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[24], bf0[27], bf1 + 24, bf1 + 27, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[25], bf0[26], bf1 + 25, bf1 + 26, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[31], bf0[28], bf1 + 31, bf1 + 28, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[30], bf0[29], bf1 + 30, bf1 + 29, &clamp_lo, &clamp_hi); + + // stage 6 + addsub_sse4_1(bf1[0], bf1[3], bf0 + 0, bf0 + 3, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[1], bf1[2], bf0 + 1, bf0 + 2, &clamp_lo, &clamp_hi); + bf0[4] = bf1[4]; + bf0[5] = + half_btf_sse4_1(&cospim32, &bf1[5], &cospi32, &bf1[6], &rounding, bit); + bf0[6] = + half_btf_sse4_1(&cospi32, &bf1[5], &cospi32, &bf1[6], &rounding, bit); + bf0[7] = bf1[7]; + addsub_sse4_1(bf1[8], bf1[11], bf0 + 8, bf0 + 11, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[9], bf1[10], bf0 + 9, bf0 + 10, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[15], bf1[12], bf0 + 15, bf0 + 12, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[14], bf1[13], bf0 + 14, bf0 + 13, &clamp_lo, &clamp_hi); + bf0[16] = bf1[16]; + bf0[17] = bf1[17]; + bf0[18] = + half_btf_sse4_1(&cospim16, &bf1[18], &cospi48, &bf1[29], &rounding, bit); + bf0[19] = + half_btf_sse4_1(&cospim16, &bf1[19], &cospi48, &bf1[28], &rounding, bit); + bf0[20] = + half_btf_sse4_1(&cospim48, &bf1[20], &cospim16, &bf1[27], &rounding, bit); + bf0[21] = + half_btf_sse4_1(&cospim48, &bf1[21], &cospim16, &bf1[26], &rounding, bit); + bf0[22] = bf1[22]; + bf0[23] = bf1[23]; + bf0[24] = bf1[24]; + bf0[25] = bf1[25]; + bf0[26] = + half_btf_sse4_1(&cospim16, &bf1[21], &cospi48, &bf1[26], &rounding, bit); + bf0[27] = + half_btf_sse4_1(&cospim16, &bf1[20], &cospi48, &bf1[27], &rounding, bit); + bf0[28] = + half_btf_sse4_1(&cospi48, &bf1[19], &cospi16, &bf1[28], &rounding, bit); + bf0[29] = + half_btf_sse4_1(&cospi48, &bf1[18], &cospi16, &bf1[29], &rounding, bit); + bf0[30] = bf1[30]; + bf0[31] = bf1[31]; + + // stage 7 + addsub_sse4_1(bf0[0], bf0[7], bf1 + 0, bf1 + 7, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[1], bf0[6], bf1 + 1, bf1 + 6, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[2], bf0[5], bf1 + 2, bf1 + 5, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[3], bf0[4], bf1 + 3, bf1 + 4, &clamp_lo, &clamp_hi); + bf1[8] = bf0[8]; + bf1[9] = bf0[9]; + bf1[10] = + half_btf_sse4_1(&cospim32, &bf0[10], &cospi32, &bf0[13], &rounding, bit); + bf1[11] = + half_btf_sse4_1(&cospim32, &bf0[11], &cospi32, &bf0[12], &rounding, bit); + bf1[12] = + half_btf_sse4_1(&cospi32, &bf0[11], &cospi32, &bf0[12], &rounding, bit); + bf1[13] = + half_btf_sse4_1(&cospi32, &bf0[10], &cospi32, &bf0[13], &rounding, bit); + bf1[14] = bf0[14]; + bf1[15] = bf0[15]; + addsub_sse4_1(bf0[16], bf0[23], bf1 + 16, bf1 + 23, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[17], bf0[22], bf1 + 17, bf1 + 22, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[18], bf0[21], bf1 + 18, bf1 + 21, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[19], bf0[20], bf1 + 19, bf1 + 20, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[31], bf0[24], bf1 + 31, bf1 + 24, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[30], bf0[25], bf1 + 30, bf1 + 25, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[29], bf0[26], bf1 + 29, bf1 + 26, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf0[28], bf0[27], bf1 + 28, bf1 + 27, &clamp_lo, &clamp_hi); + + // stage 8 + addsub_sse4_1(bf1[0], bf1[15], bf0 + 0, bf0 + 15, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[1], bf1[14], bf0 + 1, bf0 + 14, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[2], bf1[13], bf0 + 2, bf0 + 13, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[3], bf1[12], bf0 + 3, bf0 + 12, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[4], bf1[11], bf0 + 4, bf0 + 11, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[5], bf1[10], bf0 + 5, bf0 + 10, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[6], bf1[9], bf0 + 6, bf0 + 9, &clamp_lo, &clamp_hi); + addsub_sse4_1(bf1[7], bf1[8], bf0 + 7, bf0 + 8, &clamp_lo, &clamp_hi); + bf0[16] = bf1[16]; + bf0[17] = bf1[17]; + bf0[18] = bf1[18]; + bf0[19] = bf1[19]; + bf0[20] = + half_btf_sse4_1(&cospim32, &bf1[20], &cospi32, &bf1[27], &rounding, bit); + bf0[21] = + half_btf_sse4_1(&cospim32, &bf1[21], &cospi32, &bf1[26], &rounding, bit); + bf0[22] = + half_btf_sse4_1(&cospim32, &bf1[22], &cospi32, &bf1[25], &rounding, bit); + bf0[23] = + half_btf_sse4_1(&cospim32, &bf1[23], &cospi32, &bf1[24], &rounding, bit); + bf0[24] = + half_btf_sse4_1(&cospi32, &bf1[23], &cospi32, &bf1[24], &rounding, bit); + bf0[25] = + half_btf_sse4_1(&cospi32, &bf1[22], &cospi32, &bf1[25], &rounding, bit); + bf0[26] = + half_btf_sse4_1(&cospi32, &bf1[21], &cospi32, &bf1[26], &rounding, bit); + bf0[27] = + half_btf_sse4_1(&cospi32, &bf1[20], &cospi32, &bf1[27], &rounding, bit); + bf0[28] = bf1[28]; + bf0[29] = bf1[29]; + bf0[30] = bf1[30]; + bf0[31] = bf1[31]; + + // stage 9 + if (do_cols) { + addsub_no_clamp_sse4_1(bf0[0], bf0[31], out + 0, out + 31); + addsub_no_clamp_sse4_1(bf0[1], bf0[30], out + 1, out + 30); + addsub_no_clamp_sse4_1(bf0[2], bf0[29], out + 2, out + 29); + addsub_no_clamp_sse4_1(bf0[3], bf0[28], out + 3, out + 28); + addsub_no_clamp_sse4_1(bf0[4], bf0[27], out + 4, out + 27); + addsub_no_clamp_sse4_1(bf0[5], bf0[26], out + 5, out + 26); + addsub_no_clamp_sse4_1(bf0[6], bf0[25], out + 6, out + 25); + addsub_no_clamp_sse4_1(bf0[7], bf0[24], out + 7, out + 24); + addsub_no_clamp_sse4_1(bf0[8], bf0[23], out + 8, out + 23); + addsub_no_clamp_sse4_1(bf0[9], bf0[22], out + 9, out + 22); + addsub_no_clamp_sse4_1(bf0[10], bf0[21], out + 10, out + 21); + addsub_no_clamp_sse4_1(bf0[11], bf0[20], out + 11, out + 20); + addsub_no_clamp_sse4_1(bf0[12], bf0[19], out + 12, out + 19); + addsub_no_clamp_sse4_1(bf0[13], bf0[18], out + 13, out + 18); + addsub_no_clamp_sse4_1(bf0[14], bf0[17], out + 14, out + 17); + addsub_no_clamp_sse4_1(bf0[15], bf0[16], out + 15, out + 16); + } else { + const int log_range_out = AOMMAX(16, bd + 6); + const __m128i clamp_lo_out = _mm_set1_epi32(AOMMAX( + -(1 << (log_range_out - 1)), -(1 << (log_range - 1 - out_shift)))); + const __m128i clamp_hi_out = _mm_set1_epi32(AOMMIN( + (1 << (log_range_out - 1)) - 1, (1 << (log_range - 1 - out_shift)))); + + addsub_shift_sse4_1(bf0[0], bf0[31], out + 0, out + 31, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[1], bf0[30], out + 1, out + 30, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[2], bf0[29], out + 2, out + 29, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[3], bf0[28], out + 3, out + 28, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[4], bf0[27], out + 4, out + 27, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[5], bf0[26], out + 5, out + 26, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[6], bf0[25], out + 6, out + 25, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[7], bf0[24], out + 7, out + 24, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[8], bf0[23], out + 8, out + 23, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[9], bf0[22], out + 9, out + 22, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[10], bf0[21], out + 10, out + 21, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[11], bf0[20], out + 11, out + 20, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[12], bf0[19], out + 12, out + 19, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[13], bf0[18], out + 13, out + 18, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[14], bf0[17], out + 14, out + 17, &clamp_lo_out, + &clamp_hi_out, out_shift); + addsub_shift_sse4_1(bf0[15], bf0[16], out + 15, out + 16, &clamp_lo_out, + &clamp_hi_out, out_shift); + } +} + +void av1_highbd_inv_txfm_add_8x8_sse4_1(const tran_low_t *input, uint8_t *dest, + int stride, + const TxfmParam *txfm_param) { + int bd = txfm_param->bd; + const TX_TYPE tx_type = txfm_param->tx_type; + const int32_t *src = cast_to_int32(input); + switch (tx_type) { + // Assembly version doesn't support some transform types, so use C version + // for those. + case V_DCT: + case H_DCT: + case V_ADST: + case H_ADST: + case V_FLIPADST: + case H_FLIPADST: + case IDTX: + av1_inv_txfm2d_add_8x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type, + bd); + break; + default: + av1_inv_txfm2d_add_8x8_sse4_1(src, CONVERT_TO_SHORTPTR(dest), stride, + tx_type, bd); + break; + } +} + +void av1_highbd_inv_txfm_add_16x8_sse4_1(const tran_low_t *input, uint8_t *dest, + int stride, + const TxfmParam *txfm_param) { + int bd = txfm_param->bd; + const TX_TYPE tx_type = txfm_param->tx_type; + const int32_t *src = cast_to_int32(input); + switch (tx_type) { + // Assembly version doesn't support some transform types, so use C version + // for those. + case V_DCT: + case H_DCT: + case V_ADST: + case H_ADST: + case V_FLIPADST: + case H_FLIPADST: + case IDTX: + av1_inv_txfm2d_add_16x8_c(src, CONVERT_TO_SHORTPTR(dest), stride, + txfm_param->tx_type, txfm_param->bd); + break; + default: + av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type, + txfm_param->tx_size, + txfm_param->eob, bd); + break; + } +} + +void av1_highbd_inv_txfm_add_8x16_sse4_1(const tran_low_t *input, uint8_t *dest, + int stride, + const TxfmParam *txfm_param) { + int bd = txfm_param->bd; + const TX_TYPE tx_type = txfm_param->tx_type; + const int32_t *src = cast_to_int32(input); + switch (tx_type) { + // Assembly version doesn't support some transform types, so use C version + // for those. + case V_DCT: + case H_DCT: + case V_ADST: + case H_ADST: + case V_FLIPADST: + case H_FLIPADST: + case IDTX: + av1_inv_txfm2d_add_8x16_c(src, CONVERT_TO_SHORTPTR(dest), stride, + txfm_param->tx_type, txfm_param->bd); + break; + default: + av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type, + txfm_param->tx_size, + txfm_param->eob, bd); + break; + } +} + +void av1_highbd_inv_txfm_add_16x16_sse4_1(const tran_low_t *input, + uint8_t *dest, int stride, + const TxfmParam *txfm_param) { + int bd = txfm_param->bd; + const TX_TYPE tx_type = txfm_param->tx_type; + const int32_t *src = cast_to_int32(input); + switch (tx_type) { + // Assembly version doesn't support some transform types, so use C version + // for those. + case V_DCT: + case H_DCT: + case V_ADST: + case H_ADST: + case V_FLIPADST: + case H_FLIPADST: + case IDTX: + av1_inv_txfm2d_add_16x16_c(src, CONVERT_TO_SHORTPTR(dest), stride, + tx_type, bd); + break; + default: + av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type, + txfm_param->tx_size, + txfm_param->eob, bd); + break; + } +} + +void av1_highbd_inv_txfm_add_32x32_sse4_1(const tran_low_t *input, + uint8_t *dest, int stride, + const TxfmParam *txfm_param) { + int bd = txfm_param->bd; + const TX_TYPE tx_type = txfm_param->tx_type; + const int32_t *src = cast_to_int32(input); switch (tx_type) { case DCT_DCT: - load_buffer_64x64_lower_32x32(coeff, in); - transpose_64x64(in, out, 0); - idct64x64_sse4_1(out, in, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, - -shift[0]); - transpose_64x64(in, out, 1); - idct64x64_sse4_1(out, in, inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); - write_buffer_64x64(in, output, stride, 0, 0, -shift[1], bd); + av1_highbd_inv_txfm2d_add_universe_sse4_1(input, dest, stride, tx_type, + txfm_param->tx_size, + txfm_param->eob, bd); + break; + // Assembly version doesn't support IDTX, so use C version for it. + case IDTX: + av1_inv_txfm2d_add_32x32_c(src, CONVERT_TO_SHORTPTR(dest), stride, + tx_type, bd); break; + default: assert(0); + } +} +void av1_highbd_inv_txfm_add_4x4_sse4_1(const tran_low_t *input, uint8_t *dest, + int stride, + const TxfmParam *txfm_param) { + assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]); + int eob = txfm_param->eob; + int bd = txfm_param->bd; + int lossless = txfm_param->lossless; + const int32_t *src = cast_to_int32(input); + const TX_TYPE tx_type = txfm_param->tx_type; + if (lossless) { + assert(tx_type == DCT_DCT); + av1_highbd_iwht4x4_add(input, dest, stride, eob, bd); + return; + } + switch (tx_type) { + // Assembly version doesn't support some transform types, so use C version + // for those. + case V_DCT: + case H_DCT: + case V_ADST: + case H_ADST: + case V_FLIPADST: + case H_FLIPADST: + case IDTX: + av1_inv_txfm2d_add_4x4_c(src, CONVERT_TO_SHORTPTR(dest), stride, tx_type, + bd); + break; default: - av1_inv_txfm2d_add_64x64_c(coeff, output, stride, tx_type, bd); + av1_inv_txfm2d_add_4x4_sse4_1(src, CONVERT_TO_SHORTPTR(dest), stride, + tx_type, bd); + break; + } +} + +static const transform_1d_sse4_1 + highbd_txfm_all_1d_zeros_w8_arr[TX_SIZES][ITX_TYPES_1D][4] = { + { + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL }, + }, + { { idct8x8_low1_sse4_1, idct8x8_new_sse4_1, NULL, NULL }, + { iadst8x8_low1_sse4_1, iadst8x8_new_sse4_1, NULL, NULL }, + { NULL, NULL, NULL, NULL } }, + { + { idct16x16_low1_sse4_1, idct16x16_low8_sse4_1, idct16x16_sse4_1, + NULL }, + { iadst16x16_low1_sse4_1, iadst16x16_low8_sse4_1, iadst16x16_sse4_1, + NULL }, + { NULL, NULL, NULL, NULL }, + }, + { { idct32x32_low1_sse4_1, idct32x32_low8_sse4_1, idct32x32_low16_sse4_1, + idct32x32_sse4_1 }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL } }, + { { idct64x64_low1_sse4_1, idct64x64_low8_sse4_1, idct64x64_low16_sse4_1, + idct64x64_sse4_1 }, + { NULL, NULL, NULL, NULL }, + { NULL, NULL, NULL, NULL } } + }; + +static void highbd_inv_txfm2d_add_no_identity_sse41(const int32_t *input, + uint16_t *output, + int stride, TX_TYPE tx_type, + TX_SIZE tx_size, int eob, + const int bd) { + __m128i buf1[64 * 16]; + int eobx, eoby; + get_eobx_eoby_scan_default(&eobx, &eoby, tx_size, eob); + const int8_t *shift = inv_txfm_shift_ls[tx_size]; + const int txw_idx = get_txw_idx(tx_size); + const int txh_idx = get_txh_idx(tx_size); + const int txfm_size_col = tx_size_wide[tx_size]; + const int txfm_size_row = tx_size_high[tx_size]; + const int buf_size_w_div8 = txfm_size_col >> 2; + const int buf_size_nonzero_w_div8 = (eobx + 8) >> 3; + const int buf_size_nonzero_h_div8 = (eoby + 8) >> 3; + const int input_stride = AOMMIN(32, txfm_size_col); + const int rect_type = get_rect_tx_log_ratio(txfm_size_col, txfm_size_row); + + const int fun_idx_x = lowbd_txfm_all_1d_zeros_idx[eobx]; + const int fun_idx_y = lowbd_txfm_all_1d_zeros_idx[eoby]; + const transform_1d_sse4_1 row_txfm = + highbd_txfm_all_1d_zeros_w8_arr[txw_idx][hitx_1d_tab[tx_type]][fun_idx_x]; + const transform_1d_sse4_1 col_txfm = + highbd_txfm_all_1d_zeros_w8_arr[txh_idx][vitx_1d_tab[tx_type]][fun_idx_y]; + + assert(col_txfm != NULL); + assert(row_txfm != NULL); + int ud_flip, lr_flip; + get_flip_cfg(tx_type, &ud_flip, &lr_flip); + + // 1st stage: column transform + for (int i = 0; i < buf_size_nonzero_h_div8 << 1; i++) { + __m128i buf0[64]; + const int32_t *input_row = input + i * input_stride * 4; + for (int j = 0; j < buf_size_nonzero_w_div8 << 1; ++j) { + __m128i *buf0_cur = buf0 + j * 4; + load_buffer_32bit_input(input_row + j * 4, input_stride, buf0_cur, 4); + + TRANSPOSE_4X4(buf0_cur[0], buf0_cur[1], buf0_cur[2], buf0_cur[3], + buf0_cur[0], buf0_cur[1], buf0_cur[2], buf0_cur[3]); + } + if (rect_type == 1 || rect_type == -1) { + av1_round_shift_rect_array_32_sse4_1( + buf0, buf0, buf_size_nonzero_w_div8 << 3, 0, NewInvSqrt2); + } + row_txfm(buf0, buf0, inv_cos_bit_row[txw_idx][txh_idx], 0, bd, -shift[0]); + + __m128i *_buf1 = buf1 + i * 4; + if (lr_flip) { + for (int j = 0; j < buf_size_w_div8; ++j) { + TRANSPOSE_4X4(buf0[4 * j + 3], buf0[4 * j + 2], buf0[4 * j + 1], + buf0[4 * j], + _buf1[txfm_size_row * (buf_size_w_div8 - 1 - j) + 0], + _buf1[txfm_size_row * (buf_size_w_div8 - 1 - j) + 1], + _buf1[txfm_size_row * (buf_size_w_div8 - 1 - j) + 2], + _buf1[txfm_size_row * (buf_size_w_div8 - 1 - j) + 3]); + } + } else { + for (int j = 0; j < buf_size_w_div8; ++j) { + TRANSPOSE_4X4( + buf0[j * 4 + 0], buf0[j * 4 + 1], buf0[j * 4 + 2], buf0[j * 4 + 3], + _buf1[j * txfm_size_row + 0], _buf1[j * txfm_size_row + 1], + _buf1[j * txfm_size_row + 2], _buf1[j * txfm_size_row + 3]); + } + } + } + // 2nd stage: column transform + for (int i = 0; i < buf_size_w_div8; i++) { + col_txfm(buf1 + i * txfm_size_row, buf1 + i * txfm_size_row, + inv_cos_bit_col[txw_idx][txh_idx], 1, bd, 0); + + av1_round_shift_array_32_sse4_1(buf1 + i * txfm_size_row, + buf1 + i * txfm_size_row, txfm_size_row, + -shift[1]); + } + + // write to buffer + { + for (int i = 0; i < (txfm_size_col >> 3); i++) { + highbd_write_buffer_8xn_sse4_1(buf1 + i * txfm_size_row * 2, + output + 8 * i, stride, ud_flip, + txfm_size_row, bd); + } + } +} + +void av1_highbd_inv_txfm2d_add_universe_sse4_1(const int32_t *input, + uint8_t *output, int stride, + TX_TYPE tx_type, TX_SIZE tx_size, + int eob, const int bd) { + switch (tx_type) { + case DCT_DCT: + case ADST_DCT: + case DCT_ADST: + case ADST_ADST: + case FLIPADST_DCT: + case DCT_FLIPADST: + case FLIPADST_FLIPADST: + case ADST_FLIPADST: + case FLIPADST_ADST: + highbd_inv_txfm2d_add_no_identity_sse41( + input, CONVERT_TO_SHORTPTR(output), stride, tx_type, tx_size, eob, + bd); + break; + default: assert(0); break; + } +} + +void av1_highbd_inv_txfm_add_sse4_1(const tran_low_t *input, uint8_t *dest, + int stride, const TxfmParam *txfm_param) { + assert(av1_ext_tx_used[txfm_param->tx_set_type][txfm_param->tx_type]); + const TX_SIZE tx_size = txfm_param->tx_size; + switch (tx_size) { + case TX_32X32: + av1_highbd_inv_txfm_add_32x32_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X16: + av1_highbd_inv_txfm_add_16x16_sse4_1(input, dest, stride, txfm_param); + break; + case TX_8X8: + av1_highbd_inv_txfm_add_8x8_sse4_1(input, dest, stride, txfm_param); + break; + case TX_4X8: + av1_highbd_inv_txfm_add_4x8(input, dest, stride, txfm_param); + break; + case TX_8X4: + av1_highbd_inv_txfm_add_8x4(input, dest, stride, txfm_param); + break; + case TX_8X16: + av1_highbd_inv_txfm_add_8x16_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X8: + av1_highbd_inv_txfm_add_16x8_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X32: + av1_highbd_inv_txfm_add_16x32(input, dest, stride, txfm_param); + break; + case TX_32X16: + av1_highbd_inv_txfm_add_32x16(input, dest, stride, txfm_param); + break; + case TX_32X64: + av1_highbd_inv_txfm_add_32x64(input, dest, stride, txfm_param); + break; + case TX_64X32: + av1_highbd_inv_txfm_add_64x32(input, dest, stride, txfm_param); + break; + case TX_4X4: + av1_highbd_inv_txfm_add_4x4_sse4_1(input, dest, stride, txfm_param); + break; + case TX_16X4: + av1_highbd_inv_txfm_add_16x4(input, dest, stride, txfm_param); + break; + case TX_4X16: + av1_highbd_inv_txfm_add_4x16(input, dest, stride, txfm_param); + break; + case TX_8X32: + av1_highbd_inv_txfm_add_8x32(input, dest, stride, txfm_param); + break; + case TX_32X8: + av1_highbd_inv_txfm_add_32x8(input, dest, stride, txfm_param); + break; + case TX_64X64: + case TX_16X64: + case TX_64X16: + av1_highbd_inv_txfm2d_add_universe_sse4_1( + input, dest, stride, txfm_param->tx_type, txfm_param->tx_size, + txfm_param->eob, txfm_param->bd); break; + default: assert(0 && "Invalid transform size"); break; } } diff --git a/third_party/aom/av1/common/x86/highbd_jnt_convolve_avx2.c b/third_party/aom/av1/common/x86/highbd_jnt_convolve_avx2.c index 608bd88a4..e298cf653 100644 --- a/third_party/aom/av1/common/x86/highbd_jnt_convolve_avx2.c +++ b/third_party/aom/av1/common/x86/highbd_jnt_convolve_avx2.c @@ -14,7 +14,6 @@ #include "config/aom_dsp_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/x86/convolve_avx2.h" #include "aom_dsp/x86/convolve_common_intrin.h" #include "aom_dsp/x86/convolve_sse4_1.h" diff --git a/third_party/aom/av1/common/x86/highbd_txfm_utility_sse4.h b/third_party/aom/av1/common/x86/highbd_txfm_utility_sse4.h index b29bd1d79..6f24e5948 100644 --- a/third_party/aom/av1/common/x86/highbd_txfm_utility_sse4.h +++ b/third_party/aom/av1/common/x86/highbd_txfm_utility_sse4.h @@ -9,8 +9,8 @@ * PATENTS file, you can obtain it at www.aomedia.org/license/patent. */ -#ifndef _HIGHBD_TXFM_UTILITY_SSE4_H -#define _HIGHBD_TXFM_UTILITY_SSE4_H +#ifndef AOM_AV1_COMMON_X86_HIGHBD_TXFM_UTILITY_SSE4_H_ +#define AOM_AV1_COMMON_X86_HIGHBD_TXFM_UTILITY_SSE4_H_ #include <smmintrin.h> /* SSE4.1 */ @@ -75,6 +75,17 @@ static INLINE void transpose_16x16(const __m128i *in, __m128i *out) { out[63]); } +static INLINE void transpose_32x32(const __m128i *input, __m128i *output) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < 8; i++) { + TRANSPOSE_4X4(input[i * 32 + j + 0], input[i * 32 + j + 8], + input[i * 32 + j + 16], input[i * 32 + j + 24], + output[j * 32 + i + 0], output[j * 32 + i + 8], + output[j * 32 + i + 16], output[j * 32 + i + 24]); + } + } +} + // Note: // rounding = 1 << (bit - 1) static INLINE __m128i half_btf_sse4_1(const __m128i *w0, const __m128i *n0, @@ -100,4 +111,15 @@ static INLINE __m128i half_btf_0_sse4_1(const __m128i *w0, const __m128i *n0, return x; } -#endif // _HIGHBD_TXFM_UTILITY_SSE4_H +typedef void (*transform_1d_sse4_1)(__m128i *in, __m128i *out, int bit, + int do_cols, int bd, int out_shift); + +typedef void (*fwd_transform_1d_sse4_1)(__m128i *in, __m128i *out, int bit, + const int num_cols); + +void av1_highbd_inv_txfm2d_add_universe_sse4_1(const int32_t *input, + uint8_t *output, int stride, + TX_TYPE tx_type, TX_SIZE tx_size, + int eob, const int bd); + +#endif // AOM_AV1_COMMON_X86_HIGHBD_TXFM_UTILITY_SSE4_H_ diff --git a/third_party/aom/av1/common/x86/highbd_warp_plane_sse4.c b/third_party/aom/av1/common/x86/highbd_warp_plane_sse4.c index a08beaafd..4bcab0564 100644 --- a/third_party/aom/av1/common/x86/highbd_warp_plane_sse4.c +++ b/third_party/aom/av1/common/x86/highbd_warp_plane_sse4.c @@ -19,10 +19,21 @@ static const uint8_t warp_highbd_arrange_bytes[16] = { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 }; -static INLINE void horizontal_filter(__m128i src, __m128i src2, __m128i *tmp, - int sx, int alpha, int k, - const int offset_bits_horiz, - const int reduce_bits_horiz) { +static const uint8_t highbd_shuffle_alpha0_mask0[16] = { + 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3 +}; +static const uint8_t highbd_shuffle_alpha0_mask1[16] = { + 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7 +}; +static const uint8_t highbd_shuffle_alpha0_mask2[16] = { + 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11 +}; +static const uint8_t highbd_shuffle_alpha0_mask3[16] = { + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 +}; + +static INLINE void highbd_prepare_horizontal_filter_coeff(int alpha, int sx, + __m128i *coeff) { // Filter even-index pixels const __m128i tmp_0 = _mm_loadu_si128( (__m128i *)(warped_filter + ((sx + 0 * alpha) >> WARPEDDIFF_PREC_BITS))); @@ -43,27 +54,13 @@ static INLINE void horizontal_filter(__m128i src, __m128i src2, __m128i *tmp, const __m128i tmp_14 = _mm_unpackhi_epi32(tmp_4, tmp_6); // coeffs 0 1 0 1 0 1 0 1 for pixels 0, 2, 4, 6 - const __m128i coeff_0 = _mm_unpacklo_epi64(tmp_8, tmp_10); + coeff[0] = _mm_unpacklo_epi64(tmp_8, tmp_10); // coeffs 2 3 2 3 2 3 2 3 for pixels 0, 2, 4, 6 - const __m128i coeff_2 = _mm_unpackhi_epi64(tmp_8, tmp_10); + coeff[2] = _mm_unpackhi_epi64(tmp_8, tmp_10); // coeffs 4 5 4 5 4 5 4 5 for pixels 0, 2, 4, 6 - const __m128i coeff_4 = _mm_unpacklo_epi64(tmp_12, tmp_14); + coeff[4] = _mm_unpacklo_epi64(tmp_12, tmp_14); // coeffs 6 7 6 7 6 7 6 7 for pixels 0, 2, 4, 6 - const __m128i coeff_6 = _mm_unpackhi_epi64(tmp_12, tmp_14); - - const __m128i round_const = _mm_set1_epi32((1 << offset_bits_horiz) + - ((1 << reduce_bits_horiz) >> 1)); - - // Calculate filtered results - const __m128i res_0 = _mm_madd_epi16(src, coeff_0); - const __m128i res_2 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 4), coeff_2); - const __m128i res_4 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 8), coeff_4); - const __m128i res_6 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 12), coeff_6); - - __m128i res_even = - _mm_add_epi32(_mm_add_epi32(res_0, res_4), _mm_add_epi32(res_2, res_6)); - res_even = _mm_sra_epi32(_mm_add_epi32(res_even, round_const), - _mm_cvtsi32_si128(reduce_bits_horiz)); + coeff[6] = _mm_unpackhi_epi64(tmp_12, tmp_14); // Filter odd-index pixels const __m128i tmp_1 = _mm_loadu_si128( @@ -80,15 +77,63 @@ static INLINE void horizontal_filter(__m128i src, __m128i src2, __m128i *tmp, const __m128i tmp_13 = _mm_unpackhi_epi32(tmp_1, tmp_3); const __m128i tmp_15 = _mm_unpackhi_epi32(tmp_5, tmp_7); - const __m128i coeff_1 = _mm_unpacklo_epi64(tmp_9, tmp_11); - const __m128i coeff_3 = _mm_unpackhi_epi64(tmp_9, tmp_11); - const __m128i coeff_5 = _mm_unpacklo_epi64(tmp_13, tmp_15); - const __m128i coeff_7 = _mm_unpackhi_epi64(tmp_13, tmp_15); + coeff[1] = _mm_unpacklo_epi64(tmp_9, tmp_11); + coeff[3] = _mm_unpackhi_epi64(tmp_9, tmp_11); + coeff[5] = _mm_unpacklo_epi64(tmp_13, tmp_15); + coeff[7] = _mm_unpackhi_epi64(tmp_13, tmp_15); +} + +static INLINE void highbd_prepare_horizontal_filter_coeff_alpha0( + int sx, __m128i *coeff) { + // Filter coeff + const __m128i tmp_0 = _mm_loadu_si128( + (__m128i *)(warped_filter + (sx >> WARPEDDIFF_PREC_BITS))); + + coeff[0] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)highbd_shuffle_alpha0_mask0)); + coeff[2] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)highbd_shuffle_alpha0_mask1)); + coeff[4] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)highbd_shuffle_alpha0_mask2)); + coeff[6] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)highbd_shuffle_alpha0_mask3)); + + coeff[1] = coeff[0]; + coeff[3] = coeff[2]; + coeff[5] = coeff[4]; + coeff[7] = coeff[6]; +} + +static INLINE void highbd_filter_src_pixels( + const __m128i *src, const __m128i *src2, __m128i *tmp, __m128i *coeff, + const int offset_bits_horiz, const int reduce_bits_horiz, int k) { + const __m128i src_1 = *src; + const __m128i src2_1 = *src2; - const __m128i res_1 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 2), coeff_1); - const __m128i res_3 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 6), coeff_3); - const __m128i res_5 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 10), coeff_5); - const __m128i res_7 = _mm_madd_epi16(_mm_alignr_epi8(src2, src, 14), coeff_7); + const __m128i round_const = _mm_set1_epi32((1 << offset_bits_horiz) + + ((1 << reduce_bits_horiz) >> 1)); + + const __m128i res_0 = _mm_madd_epi16(src_1, coeff[0]); + const __m128i res_2 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 4), coeff[2]); + const __m128i res_4 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 8), coeff[4]); + const __m128i res_6 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 12), coeff[6]); + + __m128i res_even = + _mm_add_epi32(_mm_add_epi32(res_0, res_4), _mm_add_epi32(res_2, res_6)); + res_even = _mm_sra_epi32(_mm_add_epi32(res_even, round_const), + _mm_cvtsi32_si128(reduce_bits_horiz)); + + const __m128i res_1 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 2), coeff[1]); + const __m128i res_3 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 6), coeff[3]); + const __m128i res_5 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 10), coeff[5]); + const __m128i res_7 = + _mm_madd_epi16(_mm_alignr_epi8(src2_1, src_1, 14), coeff[7]); __m128i res_odd = _mm_add_epi32(_mm_add_epi32(res_1, res_5), _mm_add_epi32(res_3, res_7)); @@ -101,6 +146,145 @@ static INLINE void horizontal_filter(__m128i src, __m128i src2, __m128i *tmp, tmp[k + 7] = _mm_packs_epi32(res_even, res_odd); } +static INLINE void highbd_horiz_filter(const __m128i *src, const __m128i *src2, + __m128i *tmp, int sx, int alpha, int k, + const int offset_bits_horiz, + const int reduce_bits_horiz) { + __m128i coeff[8]; + highbd_prepare_horizontal_filter_coeff(alpha, sx, coeff); + highbd_filter_src_pixels(src, src2, tmp, coeff, offset_bits_horiz, + reduce_bits_horiz, k); +} + +static INLINE void highbd_warp_horizontal_filter_alpha0_beta0( + const uint16_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + (void)beta; + (void)alpha; + int k; + + __m128i coeff[8]; + highbd_prepare_horizontal_filter_coeff_alpha0(sx4, coeff); + + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + const __m128i src2 = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 + 1)); + highbd_filter_src_pixels(&src, &src2, tmp, coeff, offset_bits_horiz, + reduce_bits_horiz, k); + } +} + +static INLINE void highbd_warp_horizontal_filter_alpha0( + const uint16_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + (void)alpha; + int k; + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + int sx = sx4 + beta * (k + 4); + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + const __m128i src2 = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 + 1)); + + __m128i coeff[8]; + highbd_prepare_horizontal_filter_coeff_alpha0(sx, coeff); + highbd_filter_src_pixels(&src, &src2, tmp, coeff, offset_bits_horiz, + reduce_bits_horiz, k); + } +} + +static INLINE void highbd_warp_horizontal_filter_beta0( + const uint16_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + (void)beta; + int k; + __m128i coeff[8]; + highbd_prepare_horizontal_filter_coeff(alpha, sx4, coeff); + + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + const __m128i src2 = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 + 1)); + highbd_filter_src_pixels(&src, &src2, tmp, coeff, offset_bits_horiz, + reduce_bits_horiz, k); + } +} + +static INLINE void highbd_warp_horizontal_filter( + const uint16_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + int k; + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + int sx = sx4 + beta * (k + 4); + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + const __m128i src2 = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 + 1)); + + highbd_horiz_filter(&src, &src2, tmp, sx, alpha, k, offset_bits_horiz, + reduce_bits_horiz); + } +} + +static INLINE void highbd_prepare_warp_horizontal_filter( + const uint16_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + if (alpha == 0 && beta == 0) + highbd_warp_horizontal_filter_alpha0_beta0( + ref, tmp, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, + offset_bits_horiz, reduce_bits_horiz); + + else if (alpha == 0 && beta != 0) + highbd_warp_horizontal_filter_alpha0(ref, tmp, stride, ix4, iy4, sx4, alpha, + beta, p_height, height, i, + offset_bits_horiz, reduce_bits_horiz); + + else if (alpha != 0 && beta == 0) + highbd_warp_horizontal_filter_beta0(ref, tmp, stride, ix4, iy4, sx4, alpha, + beta, p_height, height, i, + offset_bits_horiz, reduce_bits_horiz); + else + highbd_warp_horizontal_filter(ref, tmp, stride, ix4, iy4, sx4, alpha, beta, + p_height, height, i, offset_bits_horiz, + reduce_bits_horiz); +} + void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref, int width, int height, int stride, uint16_t *pred, int p_col, int p_row, @@ -247,27 +431,13 @@ void av1_highbd_warp_affine_sse4_1(const int32_t *mat, const uint16_t *ref, const __m128i src_padded = _mm_unpacklo_epi8(src_lo, src_hi); const __m128i src2_padded = _mm_unpackhi_epi8(src_lo, src_hi); - horizontal_filter(src_padded, src2_padded, tmp, sx, alpha, k, - offset_bits_horiz, reduce_bits_horiz); + highbd_horiz_filter(&src_padded, &src2_padded, tmp, sx, alpha, k, + offset_bits_horiz, reduce_bits_horiz); } } else { - for (k = -7; k < AOMMIN(8, p_height - i); ++k) { - int iy = iy4 + k; - if (iy < 0) - iy = 0; - else if (iy > height - 1) - iy = height - 1; - int sx = sx4 + beta * (k + 4); - - // Load source pixels - const __m128i src = - _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); - const __m128i src2 = - _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 + 1)); - - horizontal_filter(src, src2, tmp, sx, alpha, k, offset_bits_horiz, - reduce_bits_horiz); - } + highbd_prepare_warp_horizontal_filter( + ref, tmp, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, + offset_bits_horiz, reduce_bits_horiz); } // Vertical filter diff --git a/third_party/aom/av1/common/x86/jnt_convolve_avx2.c b/third_party/aom/av1/common/x86/jnt_convolve_avx2.c index d1ea26290..9f2e2b457 100644 --- a/third_party/aom/av1/common/x86/jnt_convolve_avx2.c +++ b/third_party/aom/av1/common/x86/jnt_convolve_avx2.c @@ -13,7 +13,6 @@ #include "config/aom_dsp_rtcd.h" -#include "aom_dsp/aom_convolve.h" #include "aom_dsp/x86/convolve_avx2.h" #include "aom_dsp/x86/convolve_common_intrin.h" #include "aom_dsp/x86/convolve_sse4_1.h" @@ -21,6 +20,21 @@ #include "aom_dsp/aom_filter.h" #include "av1/common/convolve.h" +static INLINE __m256i unpack_weights_avx2(ConvolveParams *conv_params) { + const int w0 = conv_params->fwd_offset; + const int w1 = conv_params->bck_offset; + const __m256i wt0 = _mm256_set1_epi16(w0); + const __m256i wt1 = _mm256_set1_epi16(w1); + const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1); + return wt; +} + +static INLINE __m256i load_line2_avx2(const void *a, const void *b) { + return _mm256_permute2x128_si256( + _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)a)), + _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)b)), 0x20); +} + void av1_jnt_convolve_x_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, int dst_stride0, int w, int h, const InterpFilterParams *filter_params_x, @@ -34,11 +48,7 @@ void av1_jnt_convolve_x_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, const int fo_horiz = filter_params_x->taps / 2 - 1; const uint8_t *const src_ptr = src - fo_horiz; const int bits = FILTER_BITS - conv_params->round_1; - const int w0 = conv_params->fwd_offset; - const int w1 = conv_params->bck_offset; - const __m256i wt0 = _mm256_set1_epi16(w0); - const __m256i wt1 = _mm256_set1_epi16(w1); - const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1); + const __m256i wt = unpack_weights_avx2(conv_params); const int do_average = conv_params->do_average; const int use_jnt_comp_avg = conv_params->use_jnt_comp_avg; const int offset_0 = @@ -68,13 +78,11 @@ void av1_jnt_convolve_x_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, (void)subpel_y_q4; for (i = 0; i < h; i += 2) { + const uint8_t *src_data = src_ptr + i * src_stride; + CONV_BUF_TYPE *dst_data = dst + i * dst_stride; for (j = 0; j < w; j += 8) { - const __m256i data = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&src_ptr[i * src_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&src_ptr[i * src_stride + j + src_stride]))), - 0x20); + const __m256i data = + load_line2_avx2(&src_data[j], &src_data[j + src_stride]); __m256i res = convolve_lowbd_x(data, coeffs, filt); @@ -86,13 +94,8 @@ void av1_jnt_convolve_x_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, // Accumulate values into the destination buffer if (do_average) { - const __m256i data_ref_0 = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + dst_stride]))), - 0x20); - + const __m256i data_ref_0 = + load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]); const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned, &wt, use_jnt_comp_avg); @@ -141,11 +144,7 @@ void av1_jnt_convolve_y_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, const __m256i round_const = _mm256_set1_epi32((1 << conv_params->round_1) >> 1); const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1); - const int w0 = conv_params->fwd_offset; - const int w1 = conv_params->bck_offset; - const __m256i wt0 = _mm256_set1_epi16(w0); - const __m256i wt1 = _mm256_set1_epi16(w1); - const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1); + const __m256i wt = unpack_weights_avx2(conv_params); const int do_average = conv_params->do_average; const int use_jnt_comp_avg = conv_params->use_jnt_comp_avg; const int offset_0 = @@ -172,72 +171,35 @@ void av1_jnt_convolve_y_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, for (j = 0; j < w; j += 16) { const uint8_t *data = &src_ptr[j]; __m256i src6; - // Load lines a and b. Line a to lower 128, line b to upper 128 - const __m256i src_01a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 0 * src_stride))), - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 1 * src_stride))), - 0x20); - - const __m256i src_12a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 1 * src_stride))), - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 2 * src_stride))), - 0x20); - - const __m256i src_23a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 2 * src_stride))), - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 3 * src_stride))), - 0x20); - - const __m256i src_34a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 3 * src_stride))), - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 4 * src_stride))), - 0x20); - - const __m256i src_45a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 4 * src_stride))), - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 5 * src_stride))), - 0x20); - - src6 = _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 6 * src_stride))); - const __m256i src_56a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 5 * src_stride))), - src6, 0x20); - - s[0] = _mm256_unpacklo_epi8(src_01a, src_12a); - s[1] = _mm256_unpacklo_epi8(src_23a, src_34a); - s[2] = _mm256_unpacklo_epi8(src_45a, src_56a); - - s[4] = _mm256_unpackhi_epi8(src_01a, src_12a); - s[5] = _mm256_unpackhi_epi8(src_23a, src_34a); - s[6] = _mm256_unpackhi_epi8(src_45a, src_56a); + { + __m256i src_ab[7]; + __m256i src_a[7]; + src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data)); + for (int kk = 0; kk < 6; ++kk) { + data += src_stride; + src_a[kk + 1] = + _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data)); + src_ab[kk] = _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20); + } + src6 = src_a[6]; + s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]); + s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]); + s[2] = _mm256_unpacklo_epi8(src_ab[4], src_ab[5]); + s[4] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]); + s[5] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]); + s[6] = _mm256_unpackhi_epi8(src_ab[4], src_ab[5]); + } for (i = 0; i < h; i += 2) { - data = &src_ptr[i * src_stride + j]; - const __m256i src_67a = _mm256_permute2x128_si256( - src6, - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 7 * src_stride))), - 0x20); + data = &src_ptr[(i + 7) * src_stride + j]; + const __m256i src7 = + _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data)); + const __m256i src_67a = _mm256_permute2x128_si256(src6, src7, 0x20); src6 = _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 8 * src_stride))); - const __m256i src_78a = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(data + 7 * src_stride))), - src6, 0x20); + _mm_loadu_si128((__m128i *)(data + src_stride))); + const __m256i src_78a = _mm256_permute2x128_si256(src7, src6, 0x20); s[3] = _mm256_unpacklo_epi8(src_67a, src_78a); s[7] = _mm256_unpackhi_epi8(src_67a, src_78a); @@ -266,13 +228,8 @@ void av1_jnt_convolve_y_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, if (w - j < 16) { if (do_average) { - const __m256i data_ref_0 = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + dst_stride]))), - 0x20); - + const __m256i data_ref_0 = load_line2_avx2( + &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]); const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned, &wt, use_jnt_comp_avg); @@ -325,19 +282,12 @@ void av1_jnt_convolve_y_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, _mm256_add_epi16(res_hi_round, offset_const_2); if (do_average) { - const __m256i data_ref_0_lo = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + dst_stride]))), - 0x20); - - const __m256i data_ref_0_hi = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j + 8]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]))), - 0x20); + const __m256i data_ref_0_lo = load_line2_avx2( + &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]); + + const __m256i data_ref_0_hi = + load_line2_avx2(&dst[i * dst_stride + j + 8], + &dst[i * dst_stride + j + 8 + dst_stride]); const __m256i comp_avg_res_lo = comp_avg(&data_ref_0_lo, &res_lo_unsigned, &wt, use_jnt_comp_avg); @@ -404,11 +354,7 @@ void av1_jnt_convolve_2d_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, const int fo_vert = filter_params_y->taps / 2 - 1; const int fo_horiz = filter_params_x->taps / 2 - 1; const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz; - const int w0 = conv_params->fwd_offset; - const int w1 = conv_params->bck_offset; - const __m256i wt0 = _mm256_set1_epi16(w0); - const __m256i wt1 = _mm256_set1_epi16(w1); - const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1); + const __m256i wt = unpack_weights_avx2(conv_params); const int do_average = conv_params->do_average; const int use_jnt_comp_avg = conv_params->use_jnt_comp_avg; const int offset_0 = @@ -442,15 +388,14 @@ void av1_jnt_convolve_2d_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, for (j = 0; j < w; j += 8) { /* Horizontal filter */ { + const uint8_t *src_h = src_ptr + j; for (i = 0; i < im_h; i += 2) { - __m256i data = _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)&src_ptr[(i * src_stride) + j])); + __m256i data = + _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)src_h)); if (i + 1 < im_h) data = _mm256_inserti128_si256( - data, - _mm_loadu_si128( - (__m128i *)&src_ptr[(i * src_stride) + j + src_stride]), - 1); + data, _mm_loadu_si128((__m128i *)(src_h + src_stride)), 1); + src_h += (src_stride << 1); __m256i res = convolve_lowbd_x(data, coeffs_x, filt); res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h), @@ -500,13 +445,9 @@ void av1_jnt_convolve_2d_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const); if (do_average) { - const __m256i data_ref_0 = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + dst_stride]))), - 0x20); - + const __m256i data_ref_0 = + load_line2_avx2(&dst[i * dst_stride + j], + &dst[i * dst_stride + j + dst_stride]); const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned, &wt, use_jnt_comp_avg); @@ -534,12 +475,9 @@ void av1_jnt_convolve_2d_avx2(const uint8_t *src, int src_stride, uint8_t *dst0, const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const); if (do_average) { - const __m256i data_ref_0 = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + dst_stride]))), - 0x20); + const __m256i data_ref_0 = + load_line2_avx2(&dst[i * dst_stride + j], + &dst[i * dst_stride + j + dst_stride]); const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned, &wt, use_jnt_comp_avg); @@ -598,11 +536,7 @@ void av1_jnt_convolve_2d_copy_avx2(const uint8_t *src, int src_stride, const __m128i left_shift = _mm_cvtsi32_si128(bits); const int do_average = conv_params->do_average; const int use_jnt_comp_avg = conv_params->use_jnt_comp_avg; - const int w0 = conv_params->fwd_offset; - const int w1 = conv_params->bck_offset; - const __m256i wt0 = _mm256_set1_epi16(w0); - const __m256i wt1 = _mm256_set1_epi16(w1); - const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1); + const __m256i wt = unpack_weights_avx2(conv_params); const __m256i zero = _mm256_setzero_si256(); const int offset_0 = @@ -663,13 +597,8 @@ void av1_jnt_convolve_2d_copy_avx2(const uint8_t *src, int src_stride, // Accumulate values into the destination buffer if (do_average) { - const __m256i data_ref_0 = _mm256_permute2x128_si256( - _mm256_castsi128_si256( - _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j]))), - _mm256_castsi128_si256(_mm_loadu_si128( - (__m128i *)(&dst[i * dst_stride + j + dst_stride]))), - 0x20); - + const __m256i data_ref_0 = load_line2_avx2( + &dst[i * dst_stride + j], &dst[i * dst_stride + j + dst_stride]); const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned, &wt, use_jnt_comp_avg); diff --git a/third_party/aom/av1/common/x86/reconinter_avx2.c b/third_party/aom/av1/common/x86/reconinter_avx2.c index ffbb31849..f645e0454 100644 --- a/third_party/aom/av1/common/x86/reconinter_avx2.c +++ b/third_party/aom/av1/common/x86/reconinter_avx2.c @@ -16,8 +16,504 @@ #include "aom/aom_integer.h" #include "aom_dsp/blend.h" #include "aom_dsp/x86/synonyms.h" +#include "aom_dsp/x86/synonyms_avx2.h" #include "av1/common/blockd.h" +static INLINE __m256i calc_mask_avx2(const __m256i mask_base, const __m256i s0, + const __m256i s1) { + const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(s0, s1)); + return _mm256_abs_epi16( + _mm256_add_epi16(mask_base, _mm256_srli_epi16(diff, 4))); + // clamp(diff, 0, 64) can be skiped for diff is always in the range ( 38, 54) +} +void av1_build_compound_diffwtd_mask_avx2(uint8_t *mask, + DIFFWTD_MASK_TYPE mask_type, + const uint8_t *src0, int stride0, + const uint8_t *src1, int stride1, + int h, int w) { + const int mb = (mask_type == DIFFWTD_38_INV) ? AOM_BLEND_A64_MAX_ALPHA : 0; + const __m256i y_mask_base = _mm256_set1_epi16(38 - mb); + int i = 0; + if (4 == w) { + do { + const __m128i s0A = xx_loadl_32(src0); + const __m128i s0B = xx_loadl_32(src0 + stride0); + const __m128i s0C = xx_loadl_32(src0 + stride0 * 2); + const __m128i s0D = xx_loadl_32(src0 + stride0 * 3); + const __m128i s0AB = _mm_unpacklo_epi32(s0A, s0B); + const __m128i s0CD = _mm_unpacklo_epi32(s0C, s0D); + const __m128i s0ABCD = _mm_unpacklo_epi64(s0AB, s0CD); + const __m256i s0ABCD_w = _mm256_cvtepu8_epi16(s0ABCD); + + const __m128i s1A = xx_loadl_32(src1); + const __m128i s1B = xx_loadl_32(src1 + stride1); + const __m128i s1C = xx_loadl_32(src1 + stride1 * 2); + const __m128i s1D = xx_loadl_32(src1 + stride1 * 3); + const __m128i s1AB = _mm_unpacklo_epi32(s1A, s1B); + const __m128i s1CD = _mm_unpacklo_epi32(s1C, s1D); + const __m128i s1ABCD = _mm_unpacklo_epi64(s1AB, s1CD); + const __m256i s1ABCD_w = _mm256_cvtepu8_epi16(s1ABCD); + const __m256i m16 = calc_mask_avx2(y_mask_base, s0ABCD_w, s1ABCD_w); + const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256()); + const __m128i x_m8 = + _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)); + xx_storeu_128(mask, x_m8); + src0 += (stride0 << 2); + src1 += (stride1 << 2); + mask += 16; + i += 4; + } while (i < h); + } else if (8 == w) { + do { + const __m128i s0A = xx_loadl_64(src0); + const __m128i s0B = xx_loadl_64(src0 + stride0); + const __m128i s0C = xx_loadl_64(src0 + stride0 * 2); + const __m128i s0D = xx_loadl_64(src0 + stride0 * 3); + const __m256i s0AC_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s0A, s0C)); + const __m256i s0BD_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s0B, s0D)); + const __m128i s1A = xx_loadl_64(src1); + const __m128i s1B = xx_loadl_64(src1 + stride1); + const __m128i s1C = xx_loadl_64(src1 + stride1 * 2); + const __m128i s1D = xx_loadl_64(src1 + stride1 * 3); + const __m256i s1AB_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s1A, s1C)); + const __m256i s1CD_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s1B, s1D)); + const __m256i m16AC = calc_mask_avx2(y_mask_base, s0AC_w, s1AB_w); + const __m256i m16BD = calc_mask_avx2(y_mask_base, s0BD_w, s1CD_w); + const __m256i m8 = _mm256_packus_epi16(m16AC, m16BD); + yy_storeu_256(mask, m8); + src0 += stride0 << 2; + src1 += stride1 << 2; + mask += 32; + i += 4; + } while (i < h); + } else if (16 == w) { + do { + const __m128i s0A = xx_load_128(src0); + const __m128i s0B = xx_load_128(src0 + stride0); + const __m128i s1A = xx_load_128(src1); + const __m128i s1B = xx_load_128(src1 + stride1); + const __m256i s0AL = _mm256_cvtepu8_epi16(s0A); + const __m256i s0BL = _mm256_cvtepu8_epi16(s0B); + const __m256i s1AL = _mm256_cvtepu8_epi16(s1A); + const __m256i s1BL = _mm256_cvtepu8_epi16(s1B); + + const __m256i m16AL = calc_mask_avx2(y_mask_base, s0AL, s1AL); + const __m256i m16BL = calc_mask_avx2(y_mask_base, s0BL, s1BL); + + const __m256i m8 = + _mm256_permute4x64_epi64(_mm256_packus_epi16(m16AL, m16BL), 0xd8); + yy_storeu_256(mask, m8); + src0 += stride0 << 1; + src1 += stride1 << 1; + mask += 32; + i += 2; + } while (i < h); + } else { + do { + int j = 0; + do { + const __m256i s0 = yy_loadu_256(src0 + j); + const __m256i s1 = yy_loadu_256(src1 + j); + const __m256i s0L = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(s0)); + const __m256i s1L = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(s1)); + const __m256i s0H = + _mm256_cvtepu8_epi16(_mm256_extracti128_si256(s0, 1)); + const __m256i s1H = + _mm256_cvtepu8_epi16(_mm256_extracti128_si256(s1, 1)); + const __m256i m16L = calc_mask_avx2(y_mask_base, s0L, s1L); + const __m256i m16H = calc_mask_avx2(y_mask_base, s0H, s1H); + const __m256i m8 = + _mm256_permute4x64_epi64(_mm256_packus_epi16(m16L, m16H), 0xd8); + yy_storeu_256(mask + j, m8); + j += 32; + } while (j < w); + src0 += stride0; + src1 += stride1; + mask += w; + i += 1; + } while (i < h); + } +} + +static INLINE __m256i calc_mask_d16_avx2(const __m256i *data_src0, + const __m256i *data_src1, + const __m256i *round_const, + const __m256i *mask_base_16, + const __m256i *clip_diff, int round) { + const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1); + const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0); + const __m256i diff = _mm256_max_epu16(diffa, diffb); + const __m256i diff_round = + _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round); + const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2); + const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16); + const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff); + return diff_clamp; +} + +static INLINE __m256i calc_mask_d16_inv_avx2(const __m256i *data_src0, + const __m256i *data_src1, + const __m256i *round_const, + const __m256i *mask_base_16, + const __m256i *clip_diff, + int round) { + const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1); + const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0); + const __m256i diff = _mm256_max_epu16(diffa, diffb); + const __m256i diff_round = + _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round); + const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2); + const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16); + const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff); + const __m256i diff_const_16 = _mm256_sub_epi16(*clip_diff, diff_clamp); + return diff_const_16; +} + +static INLINE void build_compound_diffwtd_mask_d16_avx2( + uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride, + const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, int shift) { + const int mask_base = 38; + const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1); + const __m256i y38 = _mm256_set1_epi16(mask_base); + const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA); + int i = 0; + if (w == 4) { + do { + const __m128i s0A = xx_loadl_64(src0); + const __m128i s0B = xx_loadl_64(src0 + src0_stride); + const __m128i s0C = xx_loadl_64(src0 + src0_stride * 2); + const __m128i s0D = xx_loadl_64(src0 + src0_stride * 3); + const __m128i s1A = xx_loadl_64(src1); + const __m128i s1B = xx_loadl_64(src1 + src1_stride); + const __m128i s1C = xx_loadl_64(src1 + src1_stride * 2); + const __m128i s1D = xx_loadl_64(src1 + src1_stride * 3); + const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0C, s0D), + _mm_unpacklo_epi64(s0A, s0B)); + const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1C, s1D), + _mm_unpacklo_epi64(s1A, s1B)); + const __m256i m16 = calc_mask_d16_avx2(&s0, &s1, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256()); + xx_storeu_128(mask, + _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8))); + src0 += src0_stride << 2; + src1 += src1_stride << 2; + mask += 16; + i += 4; + } while (i < h); + } else if (w == 8) { + do { + const __m256i s0AB = yy_loadu2_128(src0 + src0_stride, src0); + const __m256i s0CD = + yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2); + const __m256i s1AB = yy_loadu2_128(src1 + src1_stride, src1); + const __m256i s1CD = + yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2); + const __m256i m16AB = + calc_mask_d16_avx2(&s0AB, &s1AB, &_r, &y38, &y64, shift); + const __m256i m16CD = + calc_mask_d16_avx2(&s0CD, &s1CD, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16AB, m16CD); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8)); + src0 += src0_stride << 2; + src1 += src1_stride << 2; + mask += 32; + i += 4; + } while (i < h); + } else if (w == 16) { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + src0_stride); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + src1_stride); + const __m256i m16A = + calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16A, m16B); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8)); + src0 += src0_stride << 1; + src1 += src1_stride << 1; + mask += 32; + i += 2; + } while (i < h); + } else if (w == 32) { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + 16); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + 16); + const __m256i m16A = + calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16A, m16B); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8)); + src0 += src0_stride; + src1 += src1_stride; + mask += 32; + i += 1; + } while (i < h); + } else if (w == 64) { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + 16); + const __m256i s0C = yy_loadu_256(src0 + 32); + const __m256i s0D = yy_loadu_256(src0 + 48); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + 16); + const __m256i s1C = yy_loadu_256(src1 + 32); + const __m256i s1D = yy_loadu_256(src1 + 48); + const __m256i m16A = + calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m16C = + calc_mask_d16_avx2(&s0C, &s1C, &_r, &y38, &y64, shift); + const __m256i m16D = + calc_mask_d16_avx2(&s0D, &s1D, &_r, &y38, &y64, shift); + const __m256i m8AB = _mm256_packus_epi16(m16A, m16B); + const __m256i m8CD = _mm256_packus_epi16(m16C, m16D); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8)); + yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8)); + src0 += src0_stride; + src1 += src1_stride; + mask += 64; + i += 1; + } while (i < h); + } else { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + 16); + const __m256i s0C = yy_loadu_256(src0 + 32); + const __m256i s0D = yy_loadu_256(src0 + 48); + const __m256i s0E = yy_loadu_256(src0 + 64); + const __m256i s0F = yy_loadu_256(src0 + 80); + const __m256i s0G = yy_loadu_256(src0 + 96); + const __m256i s0H = yy_loadu_256(src0 + 112); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + 16); + const __m256i s1C = yy_loadu_256(src1 + 32); + const __m256i s1D = yy_loadu_256(src1 + 48); + const __m256i s1E = yy_loadu_256(src1 + 64); + const __m256i s1F = yy_loadu_256(src1 + 80); + const __m256i s1G = yy_loadu_256(src1 + 96); + const __m256i s1H = yy_loadu_256(src1 + 112); + const __m256i m16A = + calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m16C = + calc_mask_d16_avx2(&s0C, &s1C, &_r, &y38, &y64, shift); + const __m256i m16D = + calc_mask_d16_avx2(&s0D, &s1D, &_r, &y38, &y64, shift); + const __m256i m16E = + calc_mask_d16_avx2(&s0E, &s1E, &_r, &y38, &y64, shift); + const __m256i m16F = + calc_mask_d16_avx2(&s0F, &s1F, &_r, &y38, &y64, shift); + const __m256i m16G = + calc_mask_d16_avx2(&s0G, &s1G, &_r, &y38, &y64, shift); + const __m256i m16H = + calc_mask_d16_avx2(&s0H, &s1H, &_r, &y38, &y64, shift); + const __m256i m8AB = _mm256_packus_epi16(m16A, m16B); + const __m256i m8CD = _mm256_packus_epi16(m16C, m16D); + const __m256i m8EF = _mm256_packus_epi16(m16E, m16F); + const __m256i m8GH = _mm256_packus_epi16(m16G, m16H); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8)); + yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8)); + yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8EF, 0xd8)); + yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8GH, 0xd8)); + src0 += src0_stride; + src1 += src1_stride; + mask += 128; + i += 1; + } while (i < h); + } +} + +static INLINE void build_compound_diffwtd_mask_d16_inv_avx2( + uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride, + const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, int shift) { + const int mask_base = 38; + const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1); + const __m256i y38 = _mm256_set1_epi16(mask_base); + const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA); + int i = 0; + if (w == 4) { + do { + const __m128i s0A = xx_loadl_64(src0); + const __m128i s0B = xx_loadl_64(src0 + src0_stride); + const __m128i s0C = xx_loadl_64(src0 + src0_stride * 2); + const __m128i s0D = xx_loadl_64(src0 + src0_stride * 3); + const __m128i s1A = xx_loadl_64(src1); + const __m128i s1B = xx_loadl_64(src1 + src1_stride); + const __m128i s1C = xx_loadl_64(src1 + src1_stride * 2); + const __m128i s1D = xx_loadl_64(src1 + src1_stride * 3); + const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0C, s0D), + _mm_unpacklo_epi64(s0A, s0B)); + const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1C, s1D), + _mm_unpacklo_epi64(s1A, s1B)); + const __m256i m16 = + calc_mask_d16_inv_avx2(&s0, &s1, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256()); + xx_storeu_128(mask, + _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8))); + src0 += src0_stride << 2; + src1 += src1_stride << 2; + mask += 16; + i += 4; + } while (i < h); + } else if (w == 8) { + do { + const __m256i s0AB = yy_loadu2_128(src0 + src0_stride, src0); + const __m256i s0CD = + yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2); + const __m256i s1AB = yy_loadu2_128(src1 + src1_stride, src1); + const __m256i s1CD = + yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2); + const __m256i m16AB = + calc_mask_d16_inv_avx2(&s0AB, &s1AB, &_r, &y38, &y64, shift); + const __m256i m16CD = + calc_mask_d16_inv_avx2(&s0CD, &s1CD, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16AB, m16CD); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8)); + src0 += src0_stride << 2; + src1 += src1_stride << 2; + mask += 32; + i += 4; + } while (i < h); + } else if (w == 16) { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + src0_stride); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + src1_stride); + const __m256i m16A = + calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16A, m16B); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8)); + src0 += src0_stride << 1; + src1 += src1_stride << 1; + mask += 32; + i += 2; + } while (i < h); + } else if (w == 32) { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + 16); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + 16); + const __m256i m16A = + calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m8 = _mm256_packus_epi16(m16A, m16B); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8)); + src0 += src0_stride; + src1 += src1_stride; + mask += 32; + i += 1; + } while (i < h); + } else if (w == 64) { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + 16); + const __m256i s0C = yy_loadu_256(src0 + 32); + const __m256i s0D = yy_loadu_256(src0 + 48); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + 16); + const __m256i s1C = yy_loadu_256(src1 + 32); + const __m256i s1D = yy_loadu_256(src1 + 48); + const __m256i m16A = + calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m16C = + calc_mask_d16_inv_avx2(&s0C, &s1C, &_r, &y38, &y64, shift); + const __m256i m16D = + calc_mask_d16_inv_avx2(&s0D, &s1D, &_r, &y38, &y64, shift); + const __m256i m8AB = _mm256_packus_epi16(m16A, m16B); + const __m256i m8CD = _mm256_packus_epi16(m16C, m16D); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8)); + yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8)); + src0 += src0_stride; + src1 += src1_stride; + mask += 64; + i += 1; + } while (i < h); + } else { + do { + const __m256i s0A = yy_loadu_256(src0); + const __m256i s0B = yy_loadu_256(src0 + 16); + const __m256i s0C = yy_loadu_256(src0 + 32); + const __m256i s0D = yy_loadu_256(src0 + 48); + const __m256i s0E = yy_loadu_256(src0 + 64); + const __m256i s0F = yy_loadu_256(src0 + 80); + const __m256i s0G = yy_loadu_256(src0 + 96); + const __m256i s0H = yy_loadu_256(src0 + 112); + const __m256i s1A = yy_loadu_256(src1); + const __m256i s1B = yy_loadu_256(src1 + 16); + const __m256i s1C = yy_loadu_256(src1 + 32); + const __m256i s1D = yy_loadu_256(src1 + 48); + const __m256i s1E = yy_loadu_256(src1 + 64); + const __m256i s1F = yy_loadu_256(src1 + 80); + const __m256i s1G = yy_loadu_256(src1 + 96); + const __m256i s1H = yy_loadu_256(src1 + 112); + const __m256i m16A = + calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift); + const __m256i m16B = + calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift); + const __m256i m16C = + calc_mask_d16_inv_avx2(&s0C, &s1C, &_r, &y38, &y64, shift); + const __m256i m16D = + calc_mask_d16_inv_avx2(&s0D, &s1D, &_r, &y38, &y64, shift); + const __m256i m16E = + calc_mask_d16_inv_avx2(&s0E, &s1E, &_r, &y38, &y64, shift); + const __m256i m16F = + calc_mask_d16_inv_avx2(&s0F, &s1F, &_r, &y38, &y64, shift); + const __m256i m16G = + calc_mask_d16_inv_avx2(&s0G, &s1G, &_r, &y38, &y64, shift); + const __m256i m16H = + calc_mask_d16_inv_avx2(&s0H, &s1H, &_r, &y38, &y64, shift); + const __m256i m8AB = _mm256_packus_epi16(m16A, m16B); + const __m256i m8CD = _mm256_packus_epi16(m16C, m16D); + const __m256i m8EF = _mm256_packus_epi16(m16E, m16F); + const __m256i m8GH = _mm256_packus_epi16(m16G, m16H); + yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8)); + yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8)); + yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8EF, 0xd8)); + yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8GH, 0xd8)); + src0 += src0_stride; + src1 += src1_stride; + mask += 128; + i += 1; + } while (i < h); + } +} + +void av1_build_compound_diffwtd_mask_d16_avx2( + uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0, + int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, + ConvolveParams *conv_params, int bd) { + const int shift = + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8); + // When rounding constant is added, there is a possibility of overflow. + // However that much precision is not required. Code should very well work for + // other values of DIFF_FACTOR_LOG2 and AOM_BLEND_A64_MAX_ALPHA as well. But + // there is a possibility of corner case bugs. + assert(DIFF_FACTOR_LOG2 == 4); + assert(AOM_BLEND_A64_MAX_ALPHA == 64); + + if (mask_type == DIFFWTD_38) { + build_compound_diffwtd_mask_d16_avx2(mask, src0, src0_stride, src1, + src1_stride, h, w, shift); + } else { + build_compound_diffwtd_mask_d16_inv_avx2(mask, src0, src0_stride, src1, + src1_stride, h, w, shift); + } +} + void av1_build_compound_diffwtd_mask_highbd_avx2( uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0, int src0_stride, const uint8_t *src1, int src1_stride, int h, int w, diff --git a/third_party/aom/av1/common/x86/selfguided_avx2.c b/third_party/aom/av1/common/x86/selfguided_avx2.c index 375def62e..0aaf1f454 100644 --- a/third_party/aom/av1/common/x86/selfguided_avx2.c +++ b/third_party/aom/av1/common/x86/selfguided_avx2.c @@ -546,17 +546,18 @@ static void final_filter_fast(int32_t *dst, int dst_stride, const int32_t *A, } } -void av1_selfguided_restoration_avx2(const uint8_t *dgd8, int width, int height, - int dgd_stride, int32_t *flt0, - int32_t *flt1, int flt_stride, - int sgr_params_idx, int bit_depth, - int highbd) { +int av1_selfguided_restoration_avx2(const uint8_t *dgd8, int width, int height, + int dgd_stride, int32_t *flt0, + int32_t *flt1, int flt_stride, + int sgr_params_idx, int bit_depth, + int highbd) { // The ALIGN_POWER_OF_TWO macro here ensures that column 1 of Atl, Btl, // Ctl and Dtl is 32-byte aligned. const int buf_elts = ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3); - DECLARE_ALIGNED(32, int32_t, - buf[4 * ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3)]); + int32_t *buf = aom_memalign( + 32, 4 * sizeof(*buf) * ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3)); + if (!buf) return -1; const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ; const int height_ext = height + 2 * SGRPROJ_BORDER_VERT; @@ -625,6 +626,8 @@ void av1_selfguided_restoration_avx2(const uint8_t *dgd8, int width, int height, final_filter(flt1, flt_stride, A, B, buf_stride, dgd8, dgd_stride, width, height, highbd); } + aom_free(buf); + return 0; } void apply_selfguided_restoration_avx2(const uint8_t *dat8, int width, @@ -635,8 +638,10 @@ void apply_selfguided_restoration_avx2(const uint8_t *dat8, int width, int32_t *flt0 = tmpbuf; int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX; assert(width * height <= RESTORATION_UNITPELS_MAX); - av1_selfguided_restoration_avx2(dat8, width, height, stride, flt0, flt1, - width, eps, bit_depth, highbd); + const int ret = av1_selfguided_restoration_avx2( + dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd); + (void)ret; + assert(!ret); const sgr_params_type *const params = &sgr_params[eps]; int xq[2]; decode_xq(xqd, xq, params); diff --git a/third_party/aom/av1/common/x86/selfguided_sse4.c b/third_party/aom/av1/common/x86/selfguided_sse4.c index c64150b9d..ea3f6d942 100644 --- a/third_party/aom/av1/common/x86/selfguided_sse4.c +++ b/third_party/aom/av1/common/x86/selfguided_sse4.c @@ -499,13 +499,15 @@ static void final_filter_fast(int32_t *dst, int dst_stride, const int32_t *A, } } -void av1_selfguided_restoration_sse4_1(const uint8_t *dgd8, int width, - int height, int dgd_stride, - int32_t *flt0, int32_t *flt1, - int flt_stride, int sgr_params_idx, - int bit_depth, int highbd) { - DECLARE_ALIGNED(16, int32_t, buf[4 * RESTORATION_PROC_UNIT_PELS]); - memset(buf, 0, sizeof(buf)); +int av1_selfguided_restoration_sse4_1(const uint8_t *dgd8, int width, + int height, int dgd_stride, int32_t *flt0, + int32_t *flt1, int flt_stride, + int sgr_params_idx, int bit_depth, + int highbd) { + int32_t *buf = (int32_t *)aom_memalign( + 16, 4 * sizeof(*buf) * RESTORATION_PROC_UNIT_PELS); + if (!buf) return -1; + memset(buf, 0, 4 * sizeof(*buf) * RESTORATION_PROC_UNIT_PELS); const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ; const int height_ext = height + 2 * SGRPROJ_BORDER_VERT; @@ -574,6 +576,8 @@ void av1_selfguided_restoration_sse4_1(const uint8_t *dgd8, int width, final_filter(flt1, flt_stride, A, B, buf_stride, dgd8, dgd_stride, width, height, highbd); } + aom_free(buf); + return 0; } void apply_selfguided_restoration_sse4_1(const uint8_t *dat8, int width, @@ -584,8 +588,10 @@ void apply_selfguided_restoration_sse4_1(const uint8_t *dat8, int width, int32_t *flt0 = tmpbuf; int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX; assert(width * height <= RESTORATION_UNITPELS_MAX); - av1_selfguided_restoration_sse4_1(dat8, width, height, stride, flt0, flt1, - width, eps, bit_depth, highbd); + const int ret = av1_selfguided_restoration_sse4_1( + dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd); + (void)ret; + assert(!ret); const sgr_params_type *const params = &sgr_params[eps]; int xq[2]; decode_xq(xqd, xq, params); diff --git a/third_party/aom/av1/common/x86/warp_plane_sse4.c b/third_party/aom/av1/common/x86/warp_plane_sse4.c index efc542cbf..b810cea2e 100644 --- a/third_party/aom/av1/common/x86/warp_plane_sse4.c +++ b/third_party/aom/av1/common/x86/warp_plane_sse4.c @@ -203,15 +203,72 @@ static const uint8_t even_mask[16] = { 0, 2, 2, 4, 4, 6, 6, 8, static const uint8_t odd_mask[16] = { 1, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15, 0 }; -static INLINE void horizontal_filter(__m128i src, __m128i *tmp, int sx, - int alpha, int k, +static const uint8_t shuffle_alpha0_mask01[16] = { 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1 }; + +static const uint8_t shuffle_alpha0_mask23[16] = { 2, 3, 2, 3, 2, 3, 2, 3, + 2, 3, 2, 3, 2, 3, 2, 3 }; + +static const uint8_t shuffle_alpha0_mask45[16] = { 4, 5, 4, 5, 4, 5, 4, 5, + 4, 5, 4, 5, 4, 5, 4, 5 }; + +static const uint8_t shuffle_alpha0_mask67[16] = { 6, 7, 6, 7, 6, 7, 6, 7, + 6, 7, 6, 7, 6, 7, 6, 7 }; + +static const uint8_t shuffle_gamma0_mask0[16] = { 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 0, 1, 2, 3 }; +static const uint8_t shuffle_gamma0_mask1[16] = { 4, 5, 6, 7, 4, 5, 6, 7, + 4, 5, 6, 7, 4, 5, 6, 7 }; +static const uint8_t shuffle_gamma0_mask2[16] = { 8, 9, 10, 11, 8, 9, 10, 11, + 8, 9, 10, 11, 8, 9, 10, 11 }; +static const uint8_t shuffle_gamma0_mask3[16] = { + 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 +}; + +static INLINE void filter_src_pixels(__m128i src, __m128i *tmp, __m128i *coeff, const int offset_bits_horiz, - const int reduce_bits_horiz) { + const int reduce_bits_horiz, int k) { const __m128i src_even = _mm_shuffle_epi8(src, _mm_loadu_si128((__m128i *)even_mask)); const __m128i src_odd = _mm_shuffle_epi8(src, _mm_loadu_si128((__m128i *)odd_mask)); + // The pixel order we need for 'src' is: + // 0 2 2 4 4 6 6 8 1 3 3 5 5 7 7 9 + const __m128i src_02 = _mm_unpacklo_epi64(src_even, src_odd); + const __m128i res_02 = _mm_maddubs_epi16(src_02, coeff[0]); + // 4 6 6 8 8 10 10 12 5 7 7 9 9 11 11 13 + const __m128i src_46 = _mm_unpacklo_epi64(_mm_srli_si128(src_even, 4), + _mm_srli_si128(src_odd, 4)); + const __m128i res_46 = _mm_maddubs_epi16(src_46, coeff[1]); + // 1 3 3 5 5 7 7 9 2 4 4 6 6 8 8 10 + const __m128i src_13 = + _mm_unpacklo_epi64(src_odd, _mm_srli_si128(src_even, 2)); + const __m128i res_13 = _mm_maddubs_epi16(src_13, coeff[2]); + // 5 7 7 9 9 11 11 13 6 8 8 10 10 12 12 14 + const __m128i src_57 = _mm_unpacklo_epi64(_mm_srli_si128(src_odd, 4), + _mm_srli_si128(src_even, 6)); + const __m128i res_57 = _mm_maddubs_epi16(src_57, coeff[3]); + + const __m128i round_const = _mm_set1_epi16((1 << offset_bits_horiz) + + ((1 << reduce_bits_horiz) >> 1)); + // Note: The values res_02 + res_46 and res_13 + res_57 both + // fit into int16s at this point, but their sum may be too wide to fit + // into an int16. However, once we also add round_const, the sum of + // all of these fits into a uint16. + // + // The wrapping behaviour of _mm_add_* is used here to make sure we + // get the correct result despite converting between different + // (implicit) types. + const __m128i res_even = _mm_add_epi16(res_02, res_46); + const __m128i res_odd = _mm_add_epi16(res_13, res_57); + const __m128i res = + _mm_add_epi16(_mm_add_epi16(res_even, res_odd), round_const); + tmp[k + 7] = _mm_srl_epi16(res, _mm_cvtsi32_si128(reduce_bits_horiz)); +} + +static INLINE void prepare_horizontal_filter_coeff(int alpha, int sx, + __m128i *coeff) { // Filter even-index pixels const __m128i tmp_0 = _mm_loadl_epi64( (__m128i *)&filter_8bit[(sx + 0 * alpha) >> WARPEDDIFF_PREC_BITS]); @@ -249,47 +306,504 @@ static INLINE void horizontal_filter(__m128i src, __m128i *tmp, int sx, const __m128i tmp_15 = _mm_unpackhi_epi32(tmp_9, tmp_11); // Coeffs 0 2 for pixels 0 2 4 6 1 3 5 7 - const __m128i coeff_02 = _mm_unpacklo_epi64(tmp_12, tmp_14); + coeff[0] = _mm_unpacklo_epi64(tmp_12, tmp_14); // Coeffs 4 6 for pixels 0 2 4 6 1 3 5 7 - const __m128i coeff_46 = _mm_unpackhi_epi64(tmp_12, tmp_14); + coeff[1] = _mm_unpackhi_epi64(tmp_12, tmp_14); // Coeffs 1 3 for pixels 0 2 4 6 1 3 5 7 - const __m128i coeff_13 = _mm_unpacklo_epi64(tmp_13, tmp_15); + coeff[2] = _mm_unpacklo_epi64(tmp_13, tmp_15); // Coeffs 5 7 for pixels 0 2 4 6 1 3 5 7 - const __m128i coeff_57 = _mm_unpackhi_epi64(tmp_13, tmp_15); + coeff[3] = _mm_unpackhi_epi64(tmp_13, tmp_15); +} - // The pixel order we need for 'src' is: - // 0 2 2 4 4 6 6 8 1 3 3 5 5 7 7 9 - const __m128i src_02 = _mm_unpacklo_epi64(src_even, src_odd); - const __m128i res_02 = _mm_maddubs_epi16(src_02, coeff_02); - // 4 6 6 8 8 10 10 12 5 7 7 9 9 11 11 13 - const __m128i src_46 = _mm_unpacklo_epi64(_mm_srli_si128(src_even, 4), - _mm_srli_si128(src_odd, 4)); - const __m128i res_46 = _mm_maddubs_epi16(src_46, coeff_46); - // 1 3 3 5 5 7 7 9 2 4 4 6 6 8 8 10 - const __m128i src_13 = - _mm_unpacklo_epi64(src_odd, _mm_srli_si128(src_even, 2)); - const __m128i res_13 = _mm_maddubs_epi16(src_13, coeff_13); - // 5 7 7 9 9 11 11 13 6 8 8 10 10 12 12 14 - const __m128i src_57 = _mm_unpacklo_epi64(_mm_srli_si128(src_odd, 4), - _mm_srli_si128(src_even, 6)); - const __m128i res_57 = _mm_maddubs_epi16(src_57, coeff_57); +static INLINE void prepare_horizontal_filter_coeff_alpha0(int sx, + __m128i *coeff) { + // Filter even-index pixels + const __m128i tmp_0 = + _mm_loadl_epi64((__m128i *)&filter_8bit[sx >> WARPEDDIFF_PREC_BITS]); - const __m128i round_const = _mm_set1_epi16((1 << offset_bits_horiz) + - ((1 << reduce_bits_horiz) >> 1)); + // Coeffs 0 2 for pixels 0 2 4 6 1 3 5 7 + coeff[0] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)shuffle_alpha0_mask01)); + // Coeffs 4 6 for pixels 0 2 4 6 1 3 5 7 + coeff[1] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)shuffle_alpha0_mask23)); + // Coeffs 1 3 for pixels 0 2 4 6 1 3 5 7 + coeff[2] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)shuffle_alpha0_mask45)); + // Coeffs 5 7 for pixels 0 2 4 6 1 3 5 7 + coeff[3] = _mm_shuffle_epi8( + tmp_0, _mm_loadu_si128((__m128i *)shuffle_alpha0_mask67)); +} - // Note: The values res_02 + res_46 and res_13 + res_57 both - // fit into int16s at this point, but their sum may be too wide to fit - // into an int16. However, once we also add round_const, the sum of - // all of these fits into a uint16. - // - // The wrapping behaviour of _mm_add_* is used here to make sure we - // get the correct result despite converting between different - // (implicit) types. - const __m128i res_even = _mm_add_epi16(res_02, res_46); - const __m128i res_odd = _mm_add_epi16(res_13, res_57); - const __m128i res = - _mm_add_epi16(_mm_add_epi16(res_even, res_odd), round_const); - tmp[k + 7] = _mm_srl_epi16(res, _mm_cvtsi32_si128(reduce_bits_horiz)); +static INLINE void horizontal_filter(__m128i src, __m128i *tmp, int sx, + int alpha, int k, + const int offset_bits_horiz, + const int reduce_bits_horiz) { + __m128i coeff[4]; + prepare_horizontal_filter_coeff(alpha, sx, coeff); + filter_src_pixels(src, tmp, coeff, offset_bits_horiz, reduce_bits_horiz, k); +} + +static INLINE void warp_horizontal_filter(const uint8_t *ref, __m128i *tmp, + int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, + int p_height, int height, int i, + const int offset_bits_horiz, + const int reduce_bits_horiz) { + int k; + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + int sx = sx4 + beta * (k + 4); + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + horizontal_filter(src, tmp, sx, alpha, k, offset_bits_horiz, + reduce_bits_horiz); + } +} + +static INLINE void warp_horizontal_filter_alpha0( + const uint8_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + (void)alpha; + int k; + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + int sx = sx4 + beta * (k + 4); + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + + __m128i coeff[4]; + prepare_horizontal_filter_coeff_alpha0(sx, coeff); + filter_src_pixels(src, tmp, coeff, offset_bits_horiz, reduce_bits_horiz, k); + } +} + +static INLINE void warp_horizontal_filter_beta0( + const uint8_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + (void)beta; + int k; + __m128i coeff[4]; + prepare_horizontal_filter_coeff(alpha, sx4, coeff); + + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + filter_src_pixels(src, tmp, coeff, offset_bits_horiz, reduce_bits_horiz, k); + } +} + +static INLINE void warp_horizontal_filter_alpha0_beta0( + const uint8_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + (void)beta; + (void)alpha; + int k; + + __m128i coeff[4]; + prepare_horizontal_filter_coeff_alpha0(sx4, coeff); + + for (k = -7; k < AOMMIN(8, p_height - i); ++k) { + int iy = iy4 + k; + if (iy < 0) + iy = 0; + else if (iy > height - 1) + iy = height - 1; + + // Load source pixels + const __m128i src = + _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); + filter_src_pixels(src, tmp, coeff, offset_bits_horiz, reduce_bits_horiz, k); + } +} + +static INLINE void unpack_weights_and_set_round_const( + ConvolveParams *conv_params, const int round_bits, const int offset_bits, + __m128i *res_sub_const, __m128i *round_bits_const, __m128i *wt) { + *res_sub_const = + _mm_set1_epi16(-(1 << (offset_bits - conv_params->round_1)) - + (1 << (offset_bits - conv_params->round_1 - 1))); + *round_bits_const = _mm_set1_epi16(((1 << round_bits) >> 1)); + + const int w0 = conv_params->fwd_offset; + const int w1 = conv_params->bck_offset; + const __m128i wt0 = _mm_set1_epi16(w0); + const __m128i wt1 = _mm_set1_epi16(w1); + *wt = _mm_unpacklo_epi16(wt0, wt1); +} + +static INLINE void prepare_vertical_filter_coeffs(int gamma, int sy, + __m128i *coeffs) { + const __m128i tmp_0 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 0 * gamma) >> WARPEDDIFF_PREC_BITS))); + const __m128i tmp_2 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 2 * gamma) >> WARPEDDIFF_PREC_BITS))); + const __m128i tmp_4 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 4 * gamma) >> WARPEDDIFF_PREC_BITS))); + const __m128i tmp_6 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 6 * gamma) >> WARPEDDIFF_PREC_BITS))); + + const __m128i tmp_8 = _mm_unpacklo_epi32(tmp_0, tmp_2); + const __m128i tmp_10 = _mm_unpacklo_epi32(tmp_4, tmp_6); + const __m128i tmp_12 = _mm_unpackhi_epi32(tmp_0, tmp_2); + const __m128i tmp_14 = _mm_unpackhi_epi32(tmp_4, tmp_6); + + // even coeffs + coeffs[0] = _mm_unpacklo_epi64(tmp_8, tmp_10); + coeffs[1] = _mm_unpackhi_epi64(tmp_8, tmp_10); + coeffs[2] = _mm_unpacklo_epi64(tmp_12, tmp_14); + coeffs[3] = _mm_unpackhi_epi64(tmp_12, tmp_14); + + const __m128i tmp_1 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 1 * gamma) >> WARPEDDIFF_PREC_BITS))); + const __m128i tmp_3 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 3 * gamma) >> WARPEDDIFF_PREC_BITS))); + const __m128i tmp_5 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 5 * gamma) >> WARPEDDIFF_PREC_BITS))); + const __m128i tmp_7 = _mm_loadu_si128( + (__m128i *)(warped_filter + ((sy + 7 * gamma) >> WARPEDDIFF_PREC_BITS))); + + const __m128i tmp_9 = _mm_unpacklo_epi32(tmp_1, tmp_3); + const __m128i tmp_11 = _mm_unpacklo_epi32(tmp_5, tmp_7); + const __m128i tmp_13 = _mm_unpackhi_epi32(tmp_1, tmp_3); + const __m128i tmp_15 = _mm_unpackhi_epi32(tmp_5, tmp_7); + + // odd coeffs + coeffs[4] = _mm_unpacklo_epi64(tmp_9, tmp_11); + coeffs[5] = _mm_unpackhi_epi64(tmp_9, tmp_11); + coeffs[6] = _mm_unpacklo_epi64(tmp_13, tmp_15); + coeffs[7] = _mm_unpackhi_epi64(tmp_13, tmp_15); +} + +static INLINE void prepare_vertical_filter_coeffs_gamma0(int sy, + __m128i *coeffs) { + const __m128i tmp_0 = _mm_loadu_si128( + (__m128i *)(warped_filter + (sy >> WARPEDDIFF_PREC_BITS))); + + // even coeffs + coeffs[0] = + _mm_shuffle_epi8(tmp_0, _mm_loadu_si128((__m128i *)shuffle_gamma0_mask0)); + coeffs[1] = + _mm_shuffle_epi8(tmp_0, _mm_loadu_si128((__m128i *)shuffle_gamma0_mask1)); + coeffs[2] = + _mm_shuffle_epi8(tmp_0, _mm_loadu_si128((__m128i *)shuffle_gamma0_mask2)); + coeffs[3] = + _mm_shuffle_epi8(tmp_0, _mm_loadu_si128((__m128i *)shuffle_gamma0_mask3)); + + // odd coeffs + coeffs[4] = coeffs[0]; + coeffs[5] = coeffs[1]; + coeffs[6] = coeffs[2]; + coeffs[7] = coeffs[3]; +} + +static INLINE void filter_src_pixels_vertical(__m128i *tmp, __m128i *coeffs, + __m128i *res_lo, __m128i *res_hi, + int k) { + // Load from tmp and rearrange pairs of consecutive rows into the + // column order 0 0 2 2 4 4 6 6; 1 1 3 3 5 5 7 7 + const __m128i *src = tmp + (k + 4); + const __m128i src_0 = _mm_unpacklo_epi16(src[0], src[1]); + const __m128i src_2 = _mm_unpacklo_epi16(src[2], src[3]); + const __m128i src_4 = _mm_unpacklo_epi16(src[4], src[5]); + const __m128i src_6 = _mm_unpacklo_epi16(src[6], src[7]); + + const __m128i res_0 = _mm_madd_epi16(src_0, coeffs[0]); + const __m128i res_2 = _mm_madd_epi16(src_2, coeffs[1]); + const __m128i res_4 = _mm_madd_epi16(src_4, coeffs[2]); + const __m128i res_6 = _mm_madd_epi16(src_6, coeffs[3]); + + const __m128i res_even = + _mm_add_epi32(_mm_add_epi32(res_0, res_2), _mm_add_epi32(res_4, res_6)); + + // Filter odd-index pixels + const __m128i src_1 = _mm_unpackhi_epi16(src[0], src[1]); + const __m128i src_3 = _mm_unpackhi_epi16(src[2], src[3]); + const __m128i src_5 = _mm_unpackhi_epi16(src[4], src[5]); + const __m128i src_7 = _mm_unpackhi_epi16(src[6], src[7]); + + const __m128i res_1 = _mm_madd_epi16(src_1, coeffs[4]); + const __m128i res_3 = _mm_madd_epi16(src_3, coeffs[5]); + const __m128i res_5 = _mm_madd_epi16(src_5, coeffs[6]); + const __m128i res_7 = _mm_madd_epi16(src_7, coeffs[7]); + + const __m128i res_odd = + _mm_add_epi32(_mm_add_epi32(res_1, res_3), _mm_add_epi32(res_5, res_7)); + + // Rearrange pixels back into the order 0 ... 7 + *res_lo = _mm_unpacklo_epi32(res_even, res_odd); + *res_hi = _mm_unpackhi_epi32(res_even, res_odd); +} + +static INLINE void store_vertical_filter_output( + __m128i *res_lo, __m128i *res_hi, const __m128i *res_add_const, + const __m128i *wt, const __m128i *res_sub_const, __m128i *round_bits_const, + uint8_t *pred, ConvolveParams *conv_params, int i, int j, int k, + const int reduce_bits_vert, int p_stride, int p_width, + const int round_bits) { + __m128i res_lo_1 = *res_lo; + __m128i res_hi_1 = *res_hi; + + if (conv_params->is_compound) { + __m128i *const p = + (__m128i *)&conv_params->dst[(i + k + 4) * conv_params->dst_stride + j]; + res_lo_1 = _mm_srai_epi32(_mm_add_epi32(res_lo_1, *res_add_const), + reduce_bits_vert); + const __m128i temp_lo_16 = _mm_packus_epi32(res_lo_1, res_lo_1); + __m128i res_lo_16; + if (conv_params->do_average) { + __m128i *const dst8 = (__m128i *)&pred[(i + k + 4) * p_stride + j]; + const __m128i p_16 = _mm_loadl_epi64(p); + + if (conv_params->use_jnt_comp_avg) { + const __m128i p_16_lo = _mm_unpacklo_epi16(p_16, temp_lo_16); + const __m128i wt_res_lo = _mm_madd_epi16(p_16_lo, *wt); + const __m128i shifted_32 = + _mm_srai_epi32(wt_res_lo, DIST_PRECISION_BITS); + res_lo_16 = _mm_packus_epi32(shifted_32, shifted_32); + } else { + res_lo_16 = _mm_srai_epi16(_mm_add_epi16(p_16, temp_lo_16), 1); + } + + res_lo_16 = _mm_add_epi16(res_lo_16, *res_sub_const); + + res_lo_16 = _mm_srai_epi16(_mm_add_epi16(res_lo_16, *round_bits_const), + round_bits); + __m128i res_8_lo = _mm_packus_epi16(res_lo_16, res_lo_16); + *(uint32_t *)dst8 = _mm_cvtsi128_si32(res_8_lo); + } else { + _mm_storel_epi64(p, temp_lo_16); + } + if (p_width > 4) { + __m128i *const p4 = + (__m128i *)&conv_params + ->dst[(i + k + 4) * conv_params->dst_stride + j + 4]; + res_hi_1 = _mm_srai_epi32(_mm_add_epi32(res_hi_1, *res_add_const), + reduce_bits_vert); + const __m128i temp_hi_16 = _mm_packus_epi32(res_hi_1, res_hi_1); + __m128i res_hi_16; + + if (conv_params->do_average) { + __m128i *const dst8_4 = + (__m128i *)&pred[(i + k + 4) * p_stride + j + 4]; + const __m128i p4_16 = _mm_loadl_epi64(p4); + + if (conv_params->use_jnt_comp_avg) { + const __m128i p_16_hi = _mm_unpacklo_epi16(p4_16, temp_hi_16); + const __m128i wt_res_hi = _mm_madd_epi16(p_16_hi, *wt); + const __m128i shifted_32 = + _mm_srai_epi32(wt_res_hi, DIST_PRECISION_BITS); + res_hi_16 = _mm_packus_epi32(shifted_32, shifted_32); + } else { + res_hi_16 = _mm_srai_epi16(_mm_add_epi16(p4_16, temp_hi_16), 1); + } + res_hi_16 = _mm_add_epi16(res_hi_16, *res_sub_const); + + res_hi_16 = _mm_srai_epi16(_mm_add_epi16(res_hi_16, *round_bits_const), + round_bits); + __m128i res_8_hi = _mm_packus_epi16(res_hi_16, res_hi_16); + *(uint32_t *)dst8_4 = _mm_cvtsi128_si32(res_8_hi); + + } else { + _mm_storel_epi64(p4, temp_hi_16); + } + } + } else { + const __m128i res_lo_round = _mm_srai_epi32( + _mm_add_epi32(res_lo_1, *res_add_const), reduce_bits_vert); + const __m128i res_hi_round = _mm_srai_epi32( + _mm_add_epi32(res_hi_1, *res_add_const), reduce_bits_vert); + + const __m128i res_16bit = _mm_packs_epi32(res_lo_round, res_hi_round); + __m128i res_8bit = _mm_packus_epi16(res_16bit, res_16bit); + + // Store, blending with 'pred' if needed + __m128i *const p = (__m128i *)&pred[(i + k + 4) * p_stride + j]; + + // Note: If we're outputting a 4x4 block, we need to be very careful + // to only output 4 pixels at this point, to avoid encode/decode + // mismatches when encoding with multiple threads. + if (p_width == 4) { + *(uint32_t *)p = _mm_cvtsi128_si32(res_8bit); + } else { + _mm_storel_epi64(p, res_8bit); + } + } +} + +static INLINE void warp_vertical_filter( + uint8_t *pred, __m128i *tmp, ConvolveParams *conv_params, int16_t gamma, + int16_t delta, int p_height, int p_stride, int p_width, int i, int j, + int sy4, const int reduce_bits_vert, const __m128i *res_add_const, + const int round_bits, const int offset_bits) { + int k; + __m128i res_sub_const, round_bits_const, wt; + unpack_weights_and_set_round_const(conv_params, round_bits, offset_bits, + &res_sub_const, &round_bits_const, &wt); + // Vertical filter + for (k = -4; k < AOMMIN(4, p_height - i - 4); ++k) { + int sy = sy4 + delta * (k + 4); + + __m128i coeffs[8]; + prepare_vertical_filter_coeffs(gamma, sy, coeffs); + + __m128i res_lo; + __m128i res_hi; + filter_src_pixels_vertical(tmp, coeffs, &res_lo, &res_hi, k); + + store_vertical_filter_output(&res_lo, &res_hi, res_add_const, &wt, + &res_sub_const, &round_bits_const, pred, + conv_params, i, j, k, reduce_bits_vert, + p_stride, p_width, round_bits); + } +} + +static INLINE void warp_vertical_filter_gamma0( + uint8_t *pred, __m128i *tmp, ConvolveParams *conv_params, int16_t gamma, + int16_t delta, int p_height, int p_stride, int p_width, int i, int j, + int sy4, const int reduce_bits_vert, const __m128i *res_add_const, + const int round_bits, const int offset_bits) { + int k; + (void)gamma; + __m128i res_sub_const, round_bits_const, wt; + unpack_weights_and_set_round_const(conv_params, round_bits, offset_bits, + &res_sub_const, &round_bits_const, &wt); + // Vertical filter + for (k = -4; k < AOMMIN(4, p_height - i - 4); ++k) { + int sy = sy4 + delta * (k + 4); + + __m128i coeffs[8]; + prepare_vertical_filter_coeffs_gamma0(sy, coeffs); + + __m128i res_lo; + __m128i res_hi; + filter_src_pixels_vertical(tmp, coeffs, &res_lo, &res_hi, k); + + store_vertical_filter_output(&res_lo, &res_hi, res_add_const, &wt, + &res_sub_const, &round_bits_const, pred, + conv_params, i, j, k, reduce_bits_vert, + p_stride, p_width, round_bits); + } +} + +static INLINE void warp_vertical_filter_delta0( + uint8_t *pred, __m128i *tmp, ConvolveParams *conv_params, int16_t gamma, + int16_t delta, int p_height, int p_stride, int p_width, int i, int j, + int sy4, const int reduce_bits_vert, const __m128i *res_add_const, + const int round_bits, const int offset_bits) { + (void)delta; + int k; + __m128i res_sub_const, round_bits_const, wt; + unpack_weights_and_set_round_const(conv_params, round_bits, offset_bits, + &res_sub_const, &round_bits_const, &wt); + + __m128i coeffs[8]; + prepare_vertical_filter_coeffs(gamma, sy4, coeffs); + // Vertical filter + for (k = -4; k < AOMMIN(4, p_height - i - 4); ++k) { + __m128i res_lo; + __m128i res_hi; + filter_src_pixels_vertical(tmp, coeffs, &res_lo, &res_hi, k); + + store_vertical_filter_output(&res_lo, &res_hi, res_add_const, &wt, + &res_sub_const, &round_bits_const, pred, + conv_params, i, j, k, reduce_bits_vert, + p_stride, p_width, round_bits); + } +} + +static INLINE void warp_vertical_filter_gamma0_delta0( + uint8_t *pred, __m128i *tmp, ConvolveParams *conv_params, int16_t gamma, + int16_t delta, int p_height, int p_stride, int p_width, int i, int j, + int sy4, const int reduce_bits_vert, const __m128i *res_add_const, + const int round_bits, const int offset_bits) { + (void)delta; + (void)gamma; + int k; + __m128i res_sub_const, round_bits_const, wt; + unpack_weights_and_set_round_const(conv_params, round_bits, offset_bits, + &res_sub_const, &round_bits_const, &wt); + + __m128i coeffs[8]; + prepare_vertical_filter_coeffs_gamma0(sy4, coeffs); + // Vertical filter + for (k = -4; k < AOMMIN(4, p_height - i - 4); ++k) { + __m128i res_lo; + __m128i res_hi; + filter_src_pixels_vertical(tmp, coeffs, &res_lo, &res_hi, k); + + store_vertical_filter_output(&res_lo, &res_hi, res_add_const, &wt, + &res_sub_const, &round_bits_const, pred, + conv_params, i, j, k, reduce_bits_vert, + p_stride, p_width, round_bits); + } +} + +static INLINE void prepare_warp_vertical_filter( + uint8_t *pred, __m128i *tmp, ConvolveParams *conv_params, int16_t gamma, + int16_t delta, int p_height, int p_stride, int p_width, int i, int j, + int sy4, const int reduce_bits_vert, const __m128i *res_add_const, + const int round_bits, const int offset_bits) { + if (gamma == 0 && delta == 0) + warp_vertical_filter_gamma0_delta0( + pred, tmp, conv_params, gamma, delta, p_height, p_stride, p_width, i, j, + sy4, reduce_bits_vert, res_add_const, round_bits, offset_bits); + else if (gamma == 0 && delta != 0) + warp_vertical_filter_gamma0(pred, tmp, conv_params, gamma, delta, p_height, + p_stride, p_width, i, j, sy4, reduce_bits_vert, + res_add_const, round_bits, offset_bits); + else if (gamma != 0 && delta == 0) + warp_vertical_filter_delta0(pred, tmp, conv_params, gamma, delta, p_height, + p_stride, p_width, i, j, sy4, reduce_bits_vert, + res_add_const, round_bits, offset_bits); + else + warp_vertical_filter(pred, tmp, conv_params, gamma, delta, p_height, + p_stride, p_width, i, j, sy4, reduce_bits_vert, + res_add_const, round_bits, offset_bits); +} + +static INLINE void prepare_warp_horizontal_filter( + const uint8_t *ref, __m128i *tmp, int stride, int32_t ix4, int32_t iy4, + int32_t sx4, int alpha, int beta, int p_height, int height, int i, + const int offset_bits_horiz, const int reduce_bits_horiz) { + if (alpha == 0 && beta == 0) + warp_horizontal_filter_alpha0_beta0(ref, tmp, stride, ix4, iy4, sx4, alpha, + beta, p_height, height, i, + offset_bits_horiz, reduce_bits_horiz); + else if (alpha == 0 && beta != 0) + warp_horizontal_filter_alpha0(ref, tmp, stride, ix4, iy4, sx4, alpha, beta, + p_height, height, i, offset_bits_horiz, + reduce_bits_horiz); + else if (alpha != 0 && beta == 0) + warp_horizontal_filter_beta0(ref, tmp, stride, ix4, iy4, sx4, alpha, beta, + p_height, height, i, offset_bits_horiz, + reduce_bits_horiz); + else + warp_horizontal_filter(ref, tmp, stride, ix4, iy4, sx4, alpha, beta, + p_height, height, i, offset_bits_horiz, + reduce_bits_horiz); } void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width, @@ -309,24 +823,12 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width, assert(IMPLIES(conv_params->is_compound, conv_params->dst != NULL)); const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz; - const __m128i reduce_bits_vert_shift = _mm_cvtsi32_si128(reduce_bits_vert); const __m128i reduce_bits_vert_const = _mm_set1_epi32(((1 << reduce_bits_vert) >> 1)); const __m128i res_add_const = _mm_set1_epi32(1 << offset_bits_vert); const int round_bits = 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1; const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0; - const __m128i res_sub_const = - _mm_set1_epi16(-(1 << (offset_bits - conv_params->round_1)) - - (1 << (offset_bits - conv_params->round_1 - 1))); - __m128i round_bits_shift = _mm_cvtsi32_si128(round_bits); - __m128i round_bits_const = _mm_set1_epi16(((1 << round_bits) >> 1)); - - const int w0 = conv_params->fwd_offset; - const int w1 = conv_params->bck_offset; - const __m128i wt0 = _mm_set1_epi16(w0); - const __m128i wt1 = _mm_set1_epi16(w1); - const __m128i wt = _mm_unpacklo_epi16(wt0, wt1); assert(IMPLIES(conv_params->do_average, conv_params->is_compound)); /* Note: For this code to work, the left/right frame borders need to be @@ -340,6 +842,13 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width, assert(ref[i * stride + width + j] == ref[i * stride + (width - 1)]); } }*/ + __m128i res_add_const_1; + if (conv_params->is_compound == 1) { + res_add_const_1 = _mm_add_epi32(reduce_bits_vert_const, res_add_const); + } else { + res_add_const_1 = _mm_set1_epi32(-(1 << (bd + reduce_bits_vert - 1)) + + ((1 << reduce_bits_vert) >> 1)); + } for (i = 0; i < p_height; i += 8) { for (j = 0; j < p_width; j += 8) { @@ -419,203 +928,15 @@ void av1_warp_affine_sse4_1(const int32_t *mat, const uint8_t *ref, int width, reduce_bits_horiz); } } else { - for (k = -7; k < AOMMIN(8, p_height - i); ++k) { - int iy = iy4 + k; - if (iy < 0) - iy = 0; - else if (iy > height - 1) - iy = height - 1; - int sx = sx4 + beta * (k + 4); - - // Load source pixels - const __m128i src = - _mm_loadu_si128((__m128i *)(ref + iy * stride + ix4 - 7)); - horizontal_filter(src, tmp, sx, alpha, k, offset_bits_horiz, - reduce_bits_horiz); - } + prepare_warp_horizontal_filter(ref, tmp, stride, ix4, iy4, sx4, alpha, + beta, p_height, height, i, + offset_bits_horiz, reduce_bits_horiz); } // Vertical filter - for (k = -4; k < AOMMIN(4, p_height - i - 4); ++k) { - int sy = sy4 + delta * (k + 4); - - // Load from tmp and rearrange pairs of consecutive rows into the - // column order 0 0 2 2 4 4 6 6; 1 1 3 3 5 5 7 7 - const __m128i *src = tmp + (k + 4); - const __m128i src_0 = _mm_unpacklo_epi16(src[0], src[1]); - const __m128i src_2 = _mm_unpacklo_epi16(src[2], src[3]); - const __m128i src_4 = _mm_unpacklo_epi16(src[4], src[5]); - const __m128i src_6 = _mm_unpacklo_epi16(src[6], src[7]); - - // Filter even-index pixels - const __m128i tmp_0 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 0 * gamma) >> WARPEDDIFF_PREC_BITS))); - const __m128i tmp_2 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 2 * gamma) >> WARPEDDIFF_PREC_BITS))); - const __m128i tmp_4 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 4 * gamma) >> WARPEDDIFF_PREC_BITS))); - const __m128i tmp_6 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 6 * gamma) >> WARPEDDIFF_PREC_BITS))); - - const __m128i tmp_8 = _mm_unpacklo_epi32(tmp_0, tmp_2); - const __m128i tmp_10 = _mm_unpacklo_epi32(tmp_4, tmp_6); - const __m128i tmp_12 = _mm_unpackhi_epi32(tmp_0, tmp_2); - const __m128i tmp_14 = _mm_unpackhi_epi32(tmp_4, tmp_6); - - const __m128i coeff_0 = _mm_unpacklo_epi64(tmp_8, tmp_10); - const __m128i coeff_2 = _mm_unpackhi_epi64(tmp_8, tmp_10); - const __m128i coeff_4 = _mm_unpacklo_epi64(tmp_12, tmp_14); - const __m128i coeff_6 = _mm_unpackhi_epi64(tmp_12, tmp_14); - - const __m128i res_0 = _mm_madd_epi16(src_0, coeff_0); - const __m128i res_2 = _mm_madd_epi16(src_2, coeff_2); - const __m128i res_4 = _mm_madd_epi16(src_4, coeff_4); - const __m128i res_6 = _mm_madd_epi16(src_6, coeff_6); - - const __m128i res_even = _mm_add_epi32(_mm_add_epi32(res_0, res_2), - _mm_add_epi32(res_4, res_6)); - - // Filter odd-index pixels - const __m128i src_1 = _mm_unpackhi_epi16(src[0], src[1]); - const __m128i src_3 = _mm_unpackhi_epi16(src[2], src[3]); - const __m128i src_5 = _mm_unpackhi_epi16(src[4], src[5]); - const __m128i src_7 = _mm_unpackhi_epi16(src[6], src[7]); - - const __m128i tmp_1 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 1 * gamma) >> WARPEDDIFF_PREC_BITS))); - const __m128i tmp_3 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 3 * gamma) >> WARPEDDIFF_PREC_BITS))); - const __m128i tmp_5 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 5 * gamma) >> WARPEDDIFF_PREC_BITS))); - const __m128i tmp_7 = _mm_loadu_si128( - (__m128i *)(warped_filter + - ((sy + 7 * gamma) >> WARPEDDIFF_PREC_BITS))); - - const __m128i tmp_9 = _mm_unpacklo_epi32(tmp_1, tmp_3); - const __m128i tmp_11 = _mm_unpacklo_epi32(tmp_5, tmp_7); - const __m128i tmp_13 = _mm_unpackhi_epi32(tmp_1, tmp_3); - const __m128i tmp_15 = _mm_unpackhi_epi32(tmp_5, tmp_7); - - const __m128i coeff_1 = _mm_unpacklo_epi64(tmp_9, tmp_11); - const __m128i coeff_3 = _mm_unpackhi_epi64(tmp_9, tmp_11); - const __m128i coeff_5 = _mm_unpacklo_epi64(tmp_13, tmp_15); - const __m128i coeff_7 = _mm_unpackhi_epi64(tmp_13, tmp_15); - - const __m128i res_1 = _mm_madd_epi16(src_1, coeff_1); - const __m128i res_3 = _mm_madd_epi16(src_3, coeff_3); - const __m128i res_5 = _mm_madd_epi16(src_5, coeff_5); - const __m128i res_7 = _mm_madd_epi16(src_7, coeff_7); - - const __m128i res_odd = _mm_add_epi32(_mm_add_epi32(res_1, res_3), - _mm_add_epi32(res_5, res_7)); - - // Rearrange pixels back into the order 0 ... 7 - __m128i res_lo = _mm_unpacklo_epi32(res_even, res_odd); - __m128i res_hi = _mm_unpackhi_epi32(res_even, res_odd); - - if (conv_params->is_compound) { - __m128i *const p = - (__m128i *)&conv_params - ->dst[(i + k + 4) * conv_params->dst_stride + j]; - res_lo = _mm_add_epi32(res_lo, res_add_const); - res_lo = _mm_sra_epi32(_mm_add_epi32(res_lo, reduce_bits_vert_const), - reduce_bits_vert_shift); - const __m128i temp_lo_16 = _mm_packus_epi32(res_lo, res_lo); - __m128i res_lo_16; - if (conv_params->do_average) { - __m128i *const dst8 = (__m128i *)&pred[(i + k + 4) * p_stride + j]; - const __m128i p_16 = _mm_loadl_epi64(p); - - if (conv_params->use_jnt_comp_avg) { - const __m128i p_16_lo = _mm_unpacklo_epi16(p_16, temp_lo_16); - const __m128i wt_res_lo = _mm_madd_epi16(p_16_lo, wt); - const __m128i shifted_32 = - _mm_srai_epi32(wt_res_lo, DIST_PRECISION_BITS); - res_lo_16 = _mm_packus_epi32(shifted_32, shifted_32); - } else { - res_lo_16 = _mm_srai_epi16(_mm_add_epi16(p_16, temp_lo_16), 1); - } - - res_lo_16 = _mm_add_epi16(res_lo_16, res_sub_const); - - res_lo_16 = _mm_sra_epi16( - _mm_add_epi16(res_lo_16, round_bits_const), round_bits_shift); - __m128i res_8_lo = _mm_packus_epi16(res_lo_16, res_lo_16); - *(uint32_t *)dst8 = _mm_cvtsi128_si32(res_8_lo); - } else { - _mm_storel_epi64(p, temp_lo_16); - } - if (p_width > 4) { - __m128i *const p4 = - (__m128i *)&conv_params - ->dst[(i + k + 4) * conv_params->dst_stride + j + 4]; - - res_hi = _mm_add_epi32(res_hi, res_add_const); - res_hi = - _mm_sra_epi32(_mm_add_epi32(res_hi, reduce_bits_vert_const), - reduce_bits_vert_shift); - const __m128i temp_hi_16 = _mm_packus_epi32(res_hi, res_hi); - __m128i res_hi_16; - - if (conv_params->do_average) { - __m128i *const dst8_4 = - (__m128i *)&pred[(i + k + 4) * p_stride + j + 4]; - const __m128i p4_16 = _mm_loadl_epi64(p4); - - if (conv_params->use_jnt_comp_avg) { - const __m128i p_16_hi = _mm_unpacklo_epi16(p4_16, temp_hi_16); - const __m128i wt_res_hi = _mm_madd_epi16(p_16_hi, wt); - const __m128i shifted_32 = - _mm_srai_epi32(wt_res_hi, DIST_PRECISION_BITS); - res_hi_16 = _mm_packus_epi32(shifted_32, shifted_32); - } else { - res_hi_16 = _mm_srai_epi16(_mm_add_epi16(p4_16, temp_hi_16), 1); - } - res_hi_16 = _mm_add_epi16(res_hi_16, res_sub_const); - - res_hi_16 = _mm_sra_epi16( - _mm_add_epi16(res_hi_16, round_bits_const), round_bits_shift); - __m128i res_8_hi = _mm_packus_epi16(res_hi_16, res_hi_16); - *(uint32_t *)dst8_4 = _mm_cvtsi128_si32(res_8_hi); - - } else { - _mm_storel_epi64(p4, temp_hi_16); - } - } - } else { - // Round and pack into 8 bits - const __m128i round_const = - _mm_set1_epi32(-(1 << (bd + reduce_bits_vert - 1)) + - ((1 << reduce_bits_vert) >> 1)); - - const __m128i res_lo_round = _mm_srai_epi32( - _mm_add_epi32(res_lo, round_const), reduce_bits_vert); - const __m128i res_hi_round = _mm_srai_epi32( - _mm_add_epi32(res_hi, round_const), reduce_bits_vert); - - const __m128i res_16bit = _mm_packs_epi32(res_lo_round, res_hi_round); - __m128i res_8bit = _mm_packus_epi16(res_16bit, res_16bit); - - // Store, blending with 'pred' if needed - __m128i *const p = (__m128i *)&pred[(i + k + 4) * p_stride + j]; - - // Note: If we're outputting a 4x4 block, we need to be very careful - // to only output 4 pixels at this point, to avoid encode/decode - // mismatches when encoding with multiple threads. - if (p_width == 4) { - *(uint32_t *)p = _mm_cvtsi128_si32(res_8bit); - } else { - _mm_storel_epi64(p, res_8bit); - } - } - } + prepare_warp_vertical_filter( + pred, tmp, conv_params, gamma, delta, p_height, p_stride, p_width, i, + j, sy4, reduce_bits_vert, &res_add_const_1, round_bits, offset_bits); } } } diff --git a/third_party/aom/av1/common/x86/wiener_convolve_avx2.c b/third_party/aom/av1/common/x86/wiener_convolve_avx2.c index e1449fd21..87a6e1239 100644 --- a/third_party/aom/av1/common/x86/wiener_convolve_avx2.c +++ b/third_party/aom/av1/common/x86/wiener_convolve_avx2.c @@ -39,7 +39,8 @@ void av1_wiener_convolve_add_src_avx2(const uint8_t *src, ptrdiff_t src_stride, DECLARE_ALIGNED(32, uint16_t, temp[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]); - int intermediate_height = h + SUBPEL_TAPS - 1; + int intermediate_height = h + SUBPEL_TAPS - 2; + memset(temp + (intermediate_height * MAX_SB_SIZE), 0, MAX_SB_SIZE); const int center_tap = ((SUBPEL_TAPS - 1) / 2); const uint8_t *const src_ptr = src - center_tap * src_stride - center_tap; diff --git a/third_party/aom/av1/common/x86/wiener_convolve_sse2.c b/third_party/aom/av1/common/x86/wiener_convolve_sse2.c index 3083d224b..f9d00b733 100644 --- a/third_party/aom/av1/common/x86/wiener_convolve_sse2.c +++ b/third_party/aom/av1/common/x86/wiener_convolve_sse2.c @@ -32,7 +32,8 @@ void av1_wiener_convolve_add_src_sse2(const uint8_t *src, ptrdiff_t src_stride, DECLARE_ALIGNED(16, uint16_t, temp[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]); - int intermediate_height = h + SUBPEL_TAPS - 1; + int intermediate_height = h + SUBPEL_TAPS - 2; + memset(temp + (intermediate_height * MAX_SB_SIZE), 0, MAX_SB_SIZE); int i, j; const int center_tap = ((SUBPEL_TAPS - 1) / 2); const uint8_t *const src_ptr = src - center_tap * src_stride - center_tap; |