diff options
Diffstat (limited to 'media/libaom/src/test/av1_txfm_test.cc')
-rw-r--r-- | media/libaom/src/test/av1_txfm_test.cc | 371 |
1 files changed, 371 insertions, 0 deletions
diff --git a/media/libaom/src/test/av1_txfm_test.cc b/media/libaom/src/test/av1_txfm_test.cc new file mode 100644 index 000000000..d5b0ce325 --- /dev/null +++ b/media/libaom/src/test/av1_txfm_test.cc @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2016, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <stdio.h> +#include "test/av1_txfm_test.h" + +namespace libaom_test { + +int get_txfm1d_size(TX_SIZE tx_size) { return tx_size_wide[tx_size]; } + +void get_txfm1d_type(TX_TYPE txfm2d_type, TYPE_TXFM *type0, TYPE_TXFM *type1) { + switch (txfm2d_type) { + case DCT_DCT: + *type0 = TYPE_DCT; + *type1 = TYPE_DCT; + break; + case ADST_DCT: + *type0 = TYPE_ADST; + *type1 = TYPE_DCT; + break; + case DCT_ADST: + *type0 = TYPE_DCT; + *type1 = TYPE_ADST; + break; + case ADST_ADST: + *type0 = TYPE_ADST; + *type1 = TYPE_ADST; + break; + case FLIPADST_DCT: + *type0 = TYPE_ADST; + *type1 = TYPE_DCT; + break; + case DCT_FLIPADST: + *type0 = TYPE_DCT; + *type1 = TYPE_ADST; + break; + case FLIPADST_FLIPADST: + *type0 = TYPE_ADST; + *type1 = TYPE_ADST; + break; + case ADST_FLIPADST: + *type0 = TYPE_ADST; + *type1 = TYPE_ADST; + break; + case FLIPADST_ADST: + *type0 = TYPE_ADST; + *type1 = TYPE_ADST; + break; + case IDTX: + *type0 = TYPE_IDTX; + *type1 = TYPE_IDTX; + break; + case H_DCT: + *type0 = TYPE_IDTX; + *type1 = TYPE_DCT; + break; + case V_DCT: + *type0 = TYPE_DCT; + *type1 = TYPE_IDTX; + break; + case H_ADST: + *type0 = TYPE_IDTX; + *type1 = TYPE_ADST; + break; + case V_ADST: + *type0 = TYPE_ADST; + *type1 = TYPE_IDTX; + break; + case H_FLIPADST: + *type0 = TYPE_IDTX; + *type1 = TYPE_ADST; + break; + case V_FLIPADST: + *type0 = TYPE_ADST; + *type1 = TYPE_IDTX; + break; + default: + *type0 = TYPE_DCT; + *type1 = TYPE_DCT; + assert(0); + break; + } +} + +double Sqrt2 = pow(2, 0.5); +double invSqrt2 = 1 / pow(2, 0.5); + +double dct_matrix(double n, double k, int size) { + return cos(M_PI * (2 * n + 1) * k / (2 * size)); +} + +void reference_dct_1d(const double *in, double *out, int size) { + for (int k = 0; k < size; ++k) { + out[k] = 0; + for (int n = 0; n < size; ++n) { + out[k] += in[n] * dct_matrix(n, k, size); + } + if (k == 0) out[k] = out[k] * invSqrt2; + } +} + +void reference_idct_1d(const double *in, double *out, int size) { + for (int k = 0; k < size; ++k) { + out[k] = 0; + for (int n = 0; n < size; ++n) { + if (n == 0) + out[k] += invSqrt2 * in[n] * dct_matrix(k, n, size); + else + out[k] += in[n] * dct_matrix(k, n, size); + } + } +} + +// TODO(any): Copied from the old 'fadst4' (same as the new 'av1_fadst4_new' +// function). Should be replaced by a proper reference function that takes +// 'double' input & output. +static void fadst4_new(const tran_low_t *input, tran_low_t *output) { + tran_high_t x0, x1, x2, x3; + tran_high_t s0, s1, s2, s3, s4, s5, s6, s7; + + x0 = input[0]; + x1 = input[1]; + x2 = input[2]; + x3 = input[3]; + + if (!(x0 | x1 | x2 | x3)) { + output[0] = output[1] = output[2] = output[3] = 0; + return; + } + + s0 = sinpi_1_9 * x0; + s1 = sinpi_4_9 * x0; + s2 = sinpi_2_9 * x1; + s3 = sinpi_1_9 * x1; + s4 = sinpi_3_9 * x2; + s5 = sinpi_4_9 * x3; + s6 = sinpi_2_9 * x3; + s7 = x0 + x1 - x3; + + x0 = s0 + s2 + s5; + x1 = sinpi_3_9 * s7; + x2 = s1 - s3 + s6; + x3 = s4; + + s0 = x0 + x3; + s1 = x1; + s2 = x2 - x3; + s3 = x2 - x0 + x3; + + // 1-D transform scaling factor is sqrt(2). + output[0] = (tran_low_t)fdct_round_shift(s0); + output[1] = (tran_low_t)fdct_round_shift(s1); + output[2] = (tran_low_t)fdct_round_shift(s2); + output[3] = (tran_low_t)fdct_round_shift(s3); +} + +void reference_adst_1d(const double *in, double *out, int size) { + if (size == 4) { // Special case. + tran_low_t int_input[4]; + for (int i = 0; i < 4; ++i) { + int_input[i] = static_cast<tran_low_t>(round(in[i])); + } + tran_low_t int_output[4]; + fadst4_new(int_input, int_output); + for (int i = 0; i < 4; ++i) { + out[i] = int_output[i]; + } + return; + } + + for (int k = 0; k < size; ++k) { + out[k] = 0; + for (int n = 0; n < size; ++n) { + out[k] += in[n] * sin(M_PI * (2 * n + 1) * (2 * k + 1) / (4 * size)); + } + } +} + +void reference_idtx_1d(const double *in, double *out, int size) { + double scale = 0; + if (size == 4) + scale = Sqrt2; + else if (size == 8) + scale = 2; + else if (size == 16) + scale = 2 * Sqrt2; + else if (size == 32) + scale = 4; + else if (size == 64) + scale = 4 * Sqrt2; + for (int k = 0; k < size; ++k) { + out[k] = in[k] * scale; + } +} + +void reference_hybrid_1d(double *in, double *out, int size, int type) { + if (type == TYPE_DCT) + reference_dct_1d(in, out, size); + else if (type == TYPE_ADST) + reference_adst_1d(in, out, size); + else + reference_idtx_1d(in, out, size); +} + +double get_amplification_factor(TX_TYPE tx_type, TX_SIZE tx_size) { + TXFM_2D_FLIP_CFG fwd_txfm_flip_cfg; + av1_get_fwd_txfm_cfg(tx_type, tx_size, &fwd_txfm_flip_cfg); + const int tx_width = tx_size_wide[fwd_txfm_flip_cfg.tx_size]; + const int tx_height = tx_size_high[fwd_txfm_flip_cfg.tx_size]; + const int8_t *shift = fwd_txfm_flip_cfg.shift; + const int amplify_bit = shift[0] + shift[1] + shift[2]; + double amplify_factor = + amplify_bit >= 0 ? (1 << amplify_bit) : (1.0 / (1 << -amplify_bit)); + + // For rectangular transforms, we need to multiply by an extra factor. + const int rect_type = get_rect_tx_log_ratio(tx_width, tx_height); + if (abs(rect_type) == 1) { + amplify_factor *= pow(2, 0.5); + } + return amplify_factor; +} + +void reference_hybrid_2d(double *in, double *out, TX_TYPE tx_type, + TX_SIZE tx_size) { + // Get transform type and size of each dimension. + TYPE_TXFM type0; + TYPE_TXFM type1; + get_txfm1d_type(tx_type, &type0, &type1); + const int tx_width = tx_size_wide[tx_size]; + const int tx_height = tx_size_high[tx_size]; + + double *const temp_in = new double[AOMMAX(tx_width, tx_height)]; + double *const temp_out = new double[AOMMAX(tx_width, tx_height)]; + double *const out_interm = new double[tx_width * tx_height]; + const int stride = tx_width; + + // Transform columns. + for (int c = 0; c < tx_width; ++c) { + for (int r = 0; r < tx_height; ++r) { + temp_in[r] = in[r * stride + c]; + } + reference_hybrid_1d(temp_in, temp_out, tx_height, type0); + for (int r = 0; r < tx_height; ++r) { + out_interm[r * stride + c] = temp_out[r]; + } + } + + // Transform rows. + for (int r = 0; r < tx_height; ++r) { + reference_hybrid_1d(out_interm + r * stride, out + r * stride, tx_width, + type1); + } + + delete[] temp_in; + delete[] temp_out; + delete[] out_interm; + + // These transforms use an approximate 2D DCT transform, by only keeping the + // top-left quarter of the coefficients, and repacking them in the first + // quarter indices. + // TODO(urvang): Refactor this code. + if (tx_width == 64 && tx_height == 64) { // tx_size == TX_64X64 + // Zero out top-right 32x32 area. + for (int row = 0; row < 32; ++row) { + memset(out + row * 64 + 32, 0, 32 * sizeof(*out)); + } + // Zero out the bottom 64x32 area. + memset(out + 32 * 64, 0, 32 * 64 * sizeof(*out)); + // Re-pack non-zero coeffs in the first 32x32 indices. + for (int row = 1; row < 32; ++row) { + memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out)); + } + } else if (tx_width == 32 && tx_height == 64) { // tx_size == TX_32X64 + // Zero out the bottom 32x32 area. + memset(out + 32 * 32, 0, 32 * 32 * sizeof(*out)); + // Note: no repacking needed here. + } else if (tx_width == 64 && tx_height == 32) { // tx_size == TX_64X32 + // Zero out right 32x32 area. + for (int row = 0; row < 32; ++row) { + memset(out + row * 64 + 32, 0, 32 * sizeof(*out)); + } + // Re-pack non-zero coeffs in the first 32x32 indices. + for (int row = 1; row < 32; ++row) { + memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out)); + } + } else if (tx_width == 16 && tx_height == 64) { // tx_size == TX_16X64 + // Zero out the bottom 16x32 area. + memset(out + 16 * 32, 0, 16 * 32 * sizeof(*out)); + // Note: no repacking needed here. + } else if (tx_width == 64 && tx_height == 16) { // tx_size == TX_64X16 + // Zero out right 32x16 area. + for (int row = 0; row < 16; ++row) { + memset(out + row * 64 + 32, 0, 32 * sizeof(*out)); + } + // Re-pack non-zero coeffs in the first 32x16 indices. + for (int row = 1; row < 16; ++row) { + memcpy(out + row * 32, out + row * 64, 32 * sizeof(*out)); + } + } + + // Apply appropriate scale. + const double amplify_factor = get_amplification_factor(tx_type, tx_size); + for (int c = 0; c < tx_width; ++c) { + for (int r = 0; r < tx_height; ++r) { + out[r * stride + c] *= amplify_factor; + } + } +} + +template <typename Type> +void fliplr(Type *dest, int width, int height, int stride) { + for (int r = 0; r < height; ++r) { + for (int c = 0; c < width / 2; ++c) { + const Type tmp = dest[r * stride + c]; + dest[r * stride + c] = dest[r * stride + width - 1 - c]; + dest[r * stride + width - 1 - c] = tmp; + } + } +} + +template <typename Type> +void flipud(Type *dest, int width, int height, int stride) { + for (int c = 0; c < width; ++c) { + for (int r = 0; r < height / 2; ++r) { + const Type tmp = dest[r * stride + c]; + dest[r * stride + c] = dest[(height - 1 - r) * stride + c]; + dest[(height - 1 - r) * stride + c] = tmp; + } + } +} + +template <typename Type> +void fliplrud(Type *dest, int width, int height, int stride) { + for (int r = 0; r < height / 2; ++r) { + for (int c = 0; c < width; ++c) { + const Type tmp = dest[r * stride + c]; + dest[r * stride + c] = dest[(height - 1 - r) * stride + width - 1 - c]; + dest[(height - 1 - r) * stride + width - 1 - c] = tmp; + } + } +} + +template void fliplr<double>(double *dest, int width, int height, int stride); +template void flipud<double>(double *dest, int width, int height, int stride); +template void fliplrud<double>(double *dest, int width, int height, int stride); + +int bd_arr[BD_NUM] = { 8, 10, 12 }; + +int8_t low_range_arr[BD_NUM] = { 18, 32, 32 }; +int8_t high_range_arr[BD_NUM] = { 32, 32, 32 }; + +void txfm_stage_range_check(const int8_t *stage_range, int stage_num, + int8_t cos_bit, int low_range, int high_range) { + for (int i = 0; i < stage_num; ++i) { + EXPECT_LE(stage_range[i], low_range); + ASSERT_LE(stage_range[i] + cos_bit, high_range) << "stage = " << i; + } + for (int i = 0; i < stage_num - 1; ++i) { + // make sure there is no overflow while doing half_btf() + ASSERT_LE(stage_range[i + 1] + cos_bit, high_range) << "stage = " << i; + } +} +} // namespace libaom_test |