summaryrefslogtreecommitdiffstats
path: root/third_party/aom/av1/common/txb_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/aom/av1/common/txb_common.h')
-rw-r--r--third_party/aom/av1/common/txb_common.h130
1 files changed, 36 insertions, 94 deletions
diff --git a/third_party/aom/av1/common/txb_common.h b/third_party/aom/av1/common/txb_common.h
index bea162d70..5620a70a9 100644
--- a/third_party/aom/av1/common/txb_common.h
+++ b/third_party/aom/av1/common/txb_common.h
@@ -24,6 +24,10 @@ typedef struct txb_ctx {
int dc_sign_ctx;
} TXB_CTX;
+static INLINE TX_SIZE get_txsize_context(TX_SIZE tx_size) {
+ return txsize_sqr_up_map[tx_size];
+}
+
#define BASE_CONTEXT_POSITION_NUM 12
static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
/* clang-format off*/
@@ -33,14 +37,14 @@ static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
};
static INLINE int get_level_count(const tran_low_t *tcoeffs, int stride,
- int row, int col, int level,
+ int height, int row, int col, int level,
int (*nb_offset)[2], int nb_num) {
int count = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
count += abs_coeff > level;
@@ -49,14 +53,15 @@ static INLINE int get_level_count(const tran_low_t *tcoeffs, int stride,
}
static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
- int row, int col, int (*nb_offset)[2], int nb_num) {
+ int height, int row, int col, int (*nb_offset)[2],
+ int nb_num) {
mag[0] = 0;
mag[1] = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
if (nb_offset[idx][0] >= 0 && nb_offset[idx][1] >= 0) {
@@ -70,15 +75,16 @@ static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
}
}
static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
- int stride, int row, int col, int level,
- int (*nb_offset)[2], int nb_num) {
+ int stride, int height, int row, int col,
+ int level, int (*nb_offset)[2],
+ int nb_num) {
int count = 0;
*mag = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
count += abs_coeff > level;
@@ -111,15 +117,16 @@ static INLINE int get_base_ctx_from_count_mag(int row, int col, int count,
static INLINE int get_base_ctx(const tran_low_t *tcoeffs,
int c, // raster order
- const int bwl, const int level) {
+ const int bwl, const int height,
+ const int level) {
const int stride = 1 << bwl;
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = level - 1;
int mag;
- int count =
- get_level_count_mag(&mag, tcoeffs, stride, row, col, level_minus_1,
- base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+ int count = get_level_count_mag(&mag, tcoeffs, stride, height, row, col,
+ level_minus_1, base_ref_offset,
+ BASE_CONTEXT_POSITION_NUM);
int ctx_idx = get_base_ctx_from_count_mag(row, col, count, mag, level);
return ctx_idx;
}
@@ -169,15 +176,15 @@ static INLINE int get_br_ctx_from_count_mag(int row, int col, int count,
static INLINE int get_br_ctx(const tran_low_t *tcoeffs,
const int c, // raster order
- const int bwl) {
+ const int bwl, const int height) {
const int stride = 1 << bwl;
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = NUM_BASE_LEVELS;
int mag;
- const int count =
- get_level_count_mag(&mag, tcoeffs, stride, row, col, level_minus_1,
- br_ref_offset, BR_CONTEXT_POSITION_NUM);
+ const int count = get_level_count_mag(&mag, tcoeffs, stride, height, row, col,
+ level_minus_1, br_ref_offset,
+ BR_CONTEXT_POSITION_NUM);
const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
return ctx;
}
@@ -188,79 +195,15 @@ static int sig_ref_offset[SIG_REF_OFFSET_NUM][2] = {
{ -1, 1 }, { 0, -2 }, { 0, -1 }, { 1, -2 }, { 1, -1 },
};
-static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
- const uint8_t *txb_mask,
- const int coeff_idx, // raster order
- const int bwl) {
- const int row = coeff_idx >> bwl;
- const int col = coeff_idx - (row << bwl);
- int ctx = 0;
- int idx;
- int stride = 1 << bwl;
-
- if (row == 0 && col == 0) return 0;
-
- if (row == 0 && col == 1) return 1 + (tcoeffs[0] != 0);
-
- if (row == 1 && col == 0) return 3 + (tcoeffs[0] != 0);
-
- if (row == 1 && col == 1) {
- int pos;
- ctx = (tcoeffs[0] != 0);
-
- if (txb_mask[1]) ctx += (tcoeffs[1] != 0);
- pos = 1 << bwl;
- if (txb_mask[pos]) ctx += (tcoeffs[pos] != 0);
-
- ctx = (ctx + 1) >> 1;
-
- assert(5 + ctx <= 7);
-
- return 5 + ctx;
- }
-
- for (idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
- int ref_row = row + sig_ref_offset[idx][0];
- int ref_col = col + sig_ref_offset[idx][1];
- int pos;
-
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
- continue;
-
- pos = (ref_row << bwl) + ref_col;
-
- if (txb_mask[pos]) ctx += (tcoeffs[pos] != 0);
- }
-
- if (row == 0) {
- ctx = (ctx + 1) >> 1;
-
- assert(ctx < 3);
- return 8 + ctx;
- }
-
- if (col == 0) {
- ctx = (ctx + 1) >> 1;
-
- assert(ctx < 3);
- return 11 + ctx;
- }
-
- ctx >>= 1;
-
- assert(14 + ctx < 20);
-
- return 14 + ctx;
-}
-
-static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride, int row,
- int col, const int16_t *iscan) {
+static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride,
+ int height, int row, int col,
+ const int16_t *iscan) {
int count = 0;
const int pos = row * stride + col;
for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
const int ref_row = row + sig_ref_offset[idx][0];
const int ref_col = col + sig_ref_offset[idx][1];
- if (ref_row < 0 || ref_col < 0 || ref_row >= stride || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
const int nb_pos = ref_row * stride + ref_col;
if (iscan[nb_pos] < iscan[pos]) count += (tcoeffs[nb_pos] != 0);
@@ -320,26 +263,25 @@ static INLINE int get_nz_map_ctx_from_count(int count,
return 14 + ctx;
}
-// TODO(angiebird): merge this function with get_nz_map_ctx() after proper
-// testing
-static INLINE int get_nz_map_ctx2(const tran_low_t *tcoeffs,
- const int coeff_idx, // raster order
- const int bwl, const int16_t *iscan) {
+static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
+ const int coeff_idx, // raster order
+ const int bwl, const int height,
+ const int16_t *iscan) {
int stride = 1 << bwl;
const int row = coeff_idx >> bwl;
const int col = coeff_idx - (row << bwl);
- int count = get_nz_count(tcoeffs, stride, row, col, iscan);
+ int count = get_nz_count(tcoeffs, stride, height, row, col, iscan);
return get_nz_map_ctx_from_count(count, tcoeffs, coeff_idx, bwl, iscan);
}
static INLINE int get_eob_ctx(const tran_low_t *tcoeffs,
const int coeff_idx, // raster order
- const int bwl) {
+ const TX_SIZE txs_ctx) {
(void)tcoeffs;
- if (bwl == 2) return av1_coeff_band_4x4[coeff_idx];
- if (bwl == 3) return av1_coeff_band_8x8[coeff_idx];
- if (bwl == 4) return av1_coeff_band_16x16[coeff_idx];
- if (bwl == 5) return av1_coeff_band_32x32[coeff_idx];
+ if (txs_ctx == TX_4X4) return av1_coeff_band_4x4[coeff_idx];
+ if (txs_ctx == TX_8X8) return av1_coeff_band_8x8[coeff_idx];
+ if (txs_ctx == TX_16X16) return av1_coeff_band_16x16[coeff_idx];
+ if (txs_ctx == TX_32X32) return av1_coeff_band_32x32[coeff_idx];
assert(0);
return 0;