diff options
Diffstat (limited to 'third_party/aom/test/fft_test.cc')
-rw-r--r-- | third_party/aom/test/fft_test.cc | 263 |
1 files changed, 263 insertions, 0 deletions
diff --git a/third_party/aom/test/fft_test.cc b/third_party/aom/test/fft_test.cc new file mode 100644 index 000000000..56187cdbb --- /dev/null +++ b/third_party/aom/test/fft_test.cc @@ -0,0 +1,263 @@ +#include <math.h> + +#include <algorithm> +#include <complex> +#include <vector> + +#include "aom_dsp/fft_common.h" +#include "aom_mem/aom_mem.h" +#if ARCH_X86 || ARCH_X86_64 +#include "aom_ports/x86.h" +#endif +#include "av1/common/common.h" +#include "config/aom_dsp_rtcd.h" +#include "test/acm_random.h" +#include "third_party/googletest/src/googletest/include/gtest/gtest.h" + +namespace { + +typedef void (*tform_fun_t)(const float *input, float *temp, float *output); + +// Simple 1D FFT implementation +template <typename InputType> +void fft(const InputType *data, std::complex<float> *result, int n) { + if (n == 1) { + result[0] = data[0]; + return; + } + std::vector<InputType> temp(n); + for (int k = 0; k < n / 2; ++k) { + temp[k] = data[2 * k]; + temp[n / 2 + k] = data[2 * k + 1]; + } + fft(&temp[0], result, n / 2); + fft(&temp[n / 2], result + n / 2, n / 2); + for (int k = 0; k < n / 2; ++k) { + std::complex<float> w = std::complex<float>((float)cos(2. * PI * k / n), + (float)-sin(2. * PI * k / n)); + std::complex<float> a = result[k]; + std::complex<float> b = result[n / 2 + k]; + result[k] = a + w * b; + result[n / 2 + k] = a - w * b; + } +} + +void transpose(std::vector<std::complex<float> > *data, int n) { + for (int y = 0; y < n; ++y) { + for (int x = y + 1; x < n; ++x) { + std::swap((*data)[y * n + x], (*data)[x * n + y]); + } + } +} + +// Simple 2D FFT implementation +template <class InputType> +std::vector<std::complex<float> > fft2d(const InputType *input, int n) { + std::vector<std::complex<float> > rowfft(n * n); + std::vector<std::complex<float> > result(n * n); + for (int y = 0; y < n; ++y) { + fft(input + y * n, &rowfft[y * n], n); + } + transpose(&rowfft, n); + for (int y = 0; y < n; ++y) { + fft(&rowfft[y * n], &result[y * n], n); + } + transpose(&result, n); + return result; +} + +struct FFTTestArg { + int n; + void (*fft)(const float *input, float *temp, float *output); + int flag; + FFTTestArg(int n_in, tform_fun_t fft_in, int flag_in) + : n(n_in), fft(fft_in), flag(flag_in) {} +}; + +std::ostream &operator<<(std::ostream &os, const FFTTestArg &test_arg) { + return os << "fft_arg { n:" << test_arg.n << " fft:" << test_arg.fft + << " flag:" << test_arg.flag << "}"; +} + +class FFT2DTest : public ::testing::TestWithParam<FFTTestArg> { + protected: + void SetUp() { + int n = GetParam().n; + input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n); + temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n); + output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n * 2); + memset(input_, 0, sizeof(*input_) * n * n); + memset(temp_, 0, sizeof(*temp_) * n * n); + memset(output_, 0, sizeof(*output_) * n * n * 2); +#if ARCH_X86 || ARCH_X86_64 + disabled_ = GetParam().flag != 0 && !(x86_simd_caps() & GetParam().flag); +#else + disabled_ = GetParam().flag != 0; +#endif + } + void TearDown() { + aom_free(input_); + aom_free(temp_); + aom_free(output_); + } + int disabled_; + float *input_; + float *temp_; + float *output_; +}; + +TEST_P(FFT2DTest, Correct) { + if (disabled_) return; + + int n = GetParam().n; + for (int i = 0; i < n * n; ++i) { + input_[i] = 1; + std::vector<std::complex<float> > expected = fft2d<float>(&input_[0], n); + GetParam().fft(&input_[0], &temp_[0], &output_[0]); + for (int y = 0; y < n; ++y) { + for (int x = 0; x < (n / 2) + 1; ++x) { + EXPECT_NEAR(expected[y * n + x].real(), output_[2 * (y * n + x)], 1e-5); + EXPECT_NEAR(expected[y * n + x].imag(), output_[2 * (y * n + x) + 1], + 1e-5); + } + } + input_[i] = 0; + } +} + +TEST_P(FFT2DTest, Benchmark) { + if (disabled_) return; + + int n = GetParam().n; + float sum = 0; + for (int i = 0; i < 1000 * (64 - n); ++i) { + input_[i % (n * n)] = 1; + GetParam().fft(&input_[0], &temp_[0], &output_[0]); + sum += output_[0]; + input_[i % (n * n)] = 0; + } +} + +INSTANTIATE_TEST_CASE_P( + FFT2DTestC, FFT2DTest, + ::testing::Values(FFTTestArg(2, aom_fft2x2_float_c, 0), + FFTTestArg(4, aom_fft4x4_float_c, 0), + FFTTestArg(8, aom_fft8x8_float_c, 0), + FFTTestArg(16, aom_fft16x16_float_c, 0), + FFTTestArg(32, aom_fft32x32_float_c, 0))); +#if ARCH_X86 || ARCH_X86_64 +INSTANTIATE_TEST_CASE_P( + FFT2DTestSSE2, FFT2DTest, + ::testing::Values(FFTTestArg(4, aom_fft4x4_float_sse2, HAS_SSE2), + FFTTestArg(8, aom_fft8x8_float_sse2, HAS_SSE2), + FFTTestArg(16, aom_fft16x16_float_sse2, HAS_SSE2), + FFTTestArg(32, aom_fft32x32_float_sse2, HAS_SSE2))); + +INSTANTIATE_TEST_CASE_P( + FFT2DTestAVX2, FFT2DTest, + ::testing::Values(FFTTestArg(8, aom_fft8x8_float_avx2, HAS_AVX2), + FFTTestArg(16, aom_fft16x16_float_avx2, HAS_AVX2), + FFTTestArg(32, aom_fft32x32_float_avx2, HAS_AVX2))); +#endif + +struct IFFTTestArg { + int n; + tform_fun_t ifft; + int flag; + IFFTTestArg(int n_in, tform_fun_t ifft_in, int flag_in) + : n(n_in), ifft(ifft_in), flag(flag_in) {} +}; + +std::ostream &operator<<(std::ostream &os, const IFFTTestArg &test_arg) { + return os << "ifft_arg { n:" << test_arg.n << " fft:" << test_arg.ifft + << " flag:" << test_arg.flag << "}"; +} + +class IFFT2DTest : public ::testing::TestWithParam<IFFTTestArg> { + protected: + void SetUp() { + int n = GetParam().n; + input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n * 2); + temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n * 2); + output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n); + memset(input_, 0, sizeof(*input_) * n * n * 2); + memset(temp_, 0, sizeof(*temp_) * n * n * 2); + memset(output_, 0, sizeof(*output_) * n * n); +#if ARCH_X86 || ARCH_X86_64 + disabled_ = GetParam().flag != 0 && !(x86_simd_caps() & GetParam().flag); +#else + disabled_ = GetParam().flag != 0; +#endif + } + void TearDown() { + aom_free(input_); + aom_free(temp_); + aom_free(output_); + } + int disabled_; + float *input_; + float *temp_; + float *output_; +}; + +TEST_P(IFFT2DTest, Correctness) { + if (disabled_) return; + int n = GetParam().n; + ASSERT_GE(n, 2); + std::vector<float> expected(n * n); + std::vector<float> actual(n * n); + // Do forward transform then invert to make sure we get back expected + for (int y = 0; y < n; ++y) { + for (int x = 0; x < n; ++x) { + expected[y * n + x] = 1; + std::vector<std::complex<float> > input_c = fft2d(&expected[0], n); + for (int i = 0; i < n * n; ++i) { + input_[2 * i + 0] = input_c[i].real(); + input_[2 * i + 1] = input_c[i].imag(); + } + GetParam().ifft(&input_[0], &temp_[0], &output_[0]); + + for (int yy = 0; yy < n; ++yy) { + for (int xx = 0; xx < n; ++xx) { + EXPECT_NEAR(expected[yy * n + xx], output_[yy * n + xx] / (n * n), + 1e-5); + } + } + expected[y * n + x] = 0; + } + } +}; + +TEST_P(IFFT2DTest, Benchmark) { + if (disabled_) return; + int n = GetParam().n; + float sum = 0; + for (int i = 0; i < 1000 * (64 - n); ++i) { + input_[i % (n * n)] = 1; + GetParam().ifft(&input_[0], &temp_[0], &output_[0]); + sum += output_[0]; + input_[i % (n * n)] = 0; + } +} +INSTANTIATE_TEST_CASE_P( + IFFT2DTestC, IFFT2DTest, + ::testing::Values(IFFTTestArg(2, aom_ifft2x2_float_c, 0), + IFFTTestArg(4, aom_ifft4x4_float_c, 0), + IFFTTestArg(8, aom_ifft8x8_float_c, 0), + IFFTTestArg(16, aom_ifft16x16_float_c, 0), + IFFTTestArg(32, aom_ifft32x32_float_c, 0))); +#if ARCH_X86 || ARCH_X86_64 +INSTANTIATE_TEST_CASE_P( + IFFT2DTestSSE2, IFFT2DTest, + ::testing::Values(IFFTTestArg(4, aom_ifft4x4_float_sse2, HAS_SSE2), + IFFTTestArg(8, aom_ifft8x8_float_sse2, HAS_SSE2), + IFFTTestArg(16, aom_ifft16x16_float_sse2, HAS_SSE2), + IFFTTestArg(32, aom_ifft32x32_float_sse2, HAS_SSE2))); + +INSTANTIATE_TEST_CASE_P( + IFFT2DTestAVX2, IFFT2DTest, + ::testing::Values(IFFTTestArg(8, aom_ifft8x8_float_avx2, HAS_AVX2), + IFFTTestArg(16, aom_ifft16x16_float_avx2, HAS_AVX2), + IFFTTestArg(32, aom_ifft32x32_float_avx2, HAS_AVX2))); +#endif +} // namespace |