summaryrefslogtreecommitdiffstats
path: root/third_party/aom/av1/common/txb_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/aom/av1/common/txb_common.h')
-rw-r--r--third_party/aom/av1/common/txb_common.h351
1 files changed, 284 insertions, 67 deletions
diff --git a/third_party/aom/av1/common/txb_common.h b/third_party/aom/av1/common/txb_common.h
index 5620a70a9..3bf8f8c61 100644
--- a/third_party/aom/av1/common/txb_common.h
+++ b/third_party/aom/av1/common/txb_common.h
@@ -11,6 +11,10 @@
#ifndef AV1_COMMON_TXB_COMMON_H_
#define AV1_COMMON_TXB_COMMON_H_
+
+#define REDUCE_CONTEXT_DEPENDENCY 0
+#define MIN_SCAN_IDX_REDUCE_CONTEXT_DEPENDENCY 0
+
extern const int16_t av1_coeff_band_4x4[16];
extern const int16_t av1_coeff_band_8x8[64];
@@ -28,7 +32,6 @@ static INLINE TX_SIZE get_txsize_context(TX_SIZE tx_size) {
return txsize_sqr_up_map[tx_size];
}
-#define BASE_CONTEXT_POSITION_NUM 12
static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
/* clang-format off*/
{ -2, 0 }, { -1, -1 }, { -1, 0 }, { -1, 1 }, { 0, -2 }, { 0, -1 }, { 0, 1 },
@@ -36,23 +39,24 @@ static int base_ref_offset[BASE_CONTEXT_POSITION_NUM][2] = {
/* clang-format on*/
};
-static INLINE int get_level_count(const tran_low_t *tcoeffs, int stride,
+static INLINE int get_level_count(const tran_low_t *tcoeffs, int bwl,
int height, int row, int col, int level,
int (*nb_offset)[2], int nb_num) {
int count = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
- const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height ||
+ ref_col >= (1 << bwl))
continue;
+ const int pos = (ref_row << bwl) + ref_col;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
count += abs_coeff > level;
}
return count;
}
-static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
+static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int bwl,
int height, int row, int col, int (*nb_offset)[2],
int nb_num) {
mag[0] = 0;
@@ -60,9 +64,10 @@ static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
- const int pos = ref_row * stride + ref_col;
- if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height ||
+ ref_col >= (1 << bwl))
continue;
+ const int pos = (ref_row << bwl) + ref_col;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
if (nb_offset[idx][0] >= 0 && nb_offset[idx][1] >= 0) {
if (abs_coeff > mag[0]) {
@@ -74,18 +79,50 @@ static INLINE void get_mag(int *mag, const tran_low_t *tcoeffs, int stride,
}
}
}
+
+static INLINE void get_base_count_mag(int *mag, int *count,
+ const tran_low_t *tcoeffs, int bwl,
+ int height, int row, int col) {
+ mag[0] = 0;
+ mag[1] = 0;
+ for (int i = 0; i < NUM_BASE_LEVELS; ++i) count[i] = 0;
+ for (int idx = 0; idx < BASE_CONTEXT_POSITION_NUM; ++idx) {
+ const int ref_row = row + base_ref_offset[idx][0];
+ const int ref_col = col + base_ref_offset[idx][1];
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height ||
+ ref_col >= (1 << bwl))
+ continue;
+ const int pos = (ref_row << bwl) + ref_col;
+ tran_low_t abs_coeff = abs(tcoeffs[pos]);
+ // count
+ for (int i = 0; i < NUM_BASE_LEVELS; ++i) {
+ count[i] += abs_coeff > i;
+ }
+ // mag
+ if (base_ref_offset[idx][0] >= 0 && base_ref_offset[idx][1] >= 0) {
+ if (abs_coeff > mag[0]) {
+ mag[0] = abs_coeff;
+ mag[1] = 1;
+ } else if (abs_coeff == mag[0]) {
+ ++mag[1];
+ }
+ }
+ }
+}
+
static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
- int stride, int height, int row, int col,
+ int bwl, int height, int row, int col,
int level, int (*nb_offset)[2],
int nb_num) {
+ const int stride = 1 << bwl;
int count = 0;
*mag = 0;
for (int idx = 0; idx < nb_num; ++idx) {
const int ref_row = row + nb_offset[idx][0];
const int ref_col = col + nb_offset[idx][1];
- const int pos = ref_row * stride + ref_col;
if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
continue;
+ const int pos = (ref_row << bwl) + ref_col;
tran_low_t abs_coeff = abs(tcoeffs[pos]);
count += abs_coeff > level;
if (nb_offset[idx][0] >= 0 && nb_offset[idx][1] >= 0)
@@ -95,19 +132,21 @@ static INLINE int get_level_count_mag(int *mag, const tran_low_t *tcoeffs,
}
static INLINE int get_base_ctx_from_count_mag(int row, int col, int count,
- int mag, int level) {
+ int sig_mag) {
const int ctx = (count + 1) >> 1;
- const int sig_mag = mag > level;
int ctx_idx = -1;
if (row == 0 && col == 0) {
ctx_idx = (ctx << 1) + sig_mag;
- assert(ctx_idx < 8);
+ // TODO(angiebird): turn this on once the optimization is finalized
+ // assert(ctx_idx < 8);
} else if (row == 0) {
ctx_idx = 8 + (ctx << 1) + sig_mag;
- assert(ctx_idx < 18);
+ // TODO(angiebird): turn this on once the optimization is finalized
+ // assert(ctx_idx < 18);
} else if (col == 0) {
ctx_idx = 8 + 10 + (ctx << 1) + sig_mag;
- assert(ctx_idx < 28);
+ // TODO(angiebird): turn this on once the optimization is finalized
+ // assert(ctx_idx < 28);
} else {
ctx_idx = 8 + 10 + 10 + (ctx << 1) + sig_mag;
assert(ctx_idx < COEFF_BASE_CONTEXTS);
@@ -119,15 +158,14 @@ static INLINE int get_base_ctx(const tran_low_t *tcoeffs,
int c, // raster order
const int bwl, const int height,
const int level) {
- const int stride = 1 << bwl;
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = level - 1;
int mag;
- int count = get_level_count_mag(&mag, tcoeffs, stride, height, row, col,
- level_minus_1, base_ref_offset,
- BASE_CONTEXT_POSITION_NUM);
- int ctx_idx = get_base_ctx_from_count_mag(row, col, count, mag, level);
+ int count =
+ get_level_count_mag(&mag, tcoeffs, bwl, height, row, col, level_minus_1,
+ base_ref_offset, BASE_CONTEXT_POSITION_NUM);
+ int ctx_idx = get_base_ctx_from_count_mag(row, col, count, mag > level);
return ctx_idx;
}
@@ -139,13 +177,52 @@ static int br_ref_offset[BR_CONTEXT_POSITION_NUM][2] = {
/* clang-format on*/
};
-static int br_level_map[9] = {
+static const int br_level_map[9] = {
0, 0, 1, 1, 2, 2, 3, 3, 3,
};
+static const int coeff_to_br_index[COEFF_BASE_RANGE] = {
+ 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
+};
+
+static const int br_index_to_coeff[BASE_RANGE_SETS] = {
+ 0, 2, 6,
+};
+
+static const int br_extra_bits[BASE_RANGE_SETS] = {
+ 1, 2, 3,
+};
+
#define BR_MAG_OFFSET 1
// TODO(angiebird): optimize this function by using a table to map from
// count/mag to ctx
+
+static INLINE int get_br_count_mag(int *mag, const tran_low_t *tcoeffs, int bwl,
+ int height, int row, int col, int level) {
+ mag[0] = 0;
+ mag[1] = 0;
+ int count = 0;
+ for (int idx = 0; idx < BR_CONTEXT_POSITION_NUM; ++idx) {
+ const int ref_row = row + br_ref_offset[idx][0];
+ const int ref_col = col + br_ref_offset[idx][1];
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height ||
+ ref_col >= (1 << bwl))
+ continue;
+ const int pos = (ref_row << bwl) + ref_col;
+ tran_low_t abs_coeff = abs(tcoeffs[pos]);
+ count += abs_coeff > level;
+ if (br_ref_offset[idx][0] >= 0 && br_ref_offset[idx][1] >= 0) {
+ if (abs_coeff > mag[0]) {
+ mag[0] = abs_coeff;
+ mag[1] = 1;
+ } else if (abs_coeff == mag[0]) {
+ ++mag[1];
+ }
+ }
+ }
+ return count;
+}
+
static INLINE int get_br_ctx_from_count_mag(int row, int col, int count,
int mag) {
int offset = 0;
@@ -153,7 +230,7 @@ static INLINE int get_br_ctx_from_count_mag(int row, int col, int count,
offset = 0;
else if (mag <= 3)
offset = 1;
- else if (mag <= 6)
+ else if (mag <= 5)
offset = 2;
else
offset = 3;
@@ -177,111 +254,171 @@ static INLINE int get_br_ctx_from_count_mag(int row, int col, int count,
static INLINE int get_br_ctx(const tran_low_t *tcoeffs,
const int c, // raster order
const int bwl, const int height) {
- const int stride = 1 << bwl;
const int row = c >> bwl;
const int col = c - (row << bwl);
const int level_minus_1 = NUM_BASE_LEVELS;
int mag;
- const int count = get_level_count_mag(&mag, tcoeffs, stride, height, row, col,
- level_minus_1, br_ref_offset,
- BR_CONTEXT_POSITION_NUM);
+ const int count =
+ get_level_count_mag(&mag, tcoeffs, bwl, height, row, col, level_minus_1,
+ br_ref_offset, BR_CONTEXT_POSITION_NUM);
const int ctx = get_br_ctx_from_count_mag(row, col, count, mag);
return ctx;
}
-#define SIG_REF_OFFSET_NUM 11
+#define SIG_REF_OFFSET_NUM 7
static int sig_ref_offset[SIG_REF_OFFSET_NUM][2] = {
- { -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 }, { -1, 0 },
- { -1, 1 }, { 0, -2 }, { 0, -1 }, { 1, -2 }, { 1, -1 },
+ { -2, -1 }, { -2, 0 }, { -1, -2 }, { -1, -1 },
+ { -1, 0 }, { 0, -2 }, { 0, -1 },
};
-static INLINE int get_nz_count(const tran_low_t *tcoeffs, int stride,
- int height, int row, int col,
- const int16_t *iscan) {
+#if REDUCE_CONTEXT_DEPENDENCY
+static INLINE int get_nz_count(const tran_low_t *tcoeffs, int bwl, int height,
+ int row, int col, int prev_row, int prev_col) {
int count = 0;
- const int pos = row * stride + col;
for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
const int ref_row = row + sig_ref_offset[idx][0];
const int ref_col = col + sig_ref_offset[idx][1];
- if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height ||
+ ref_col >= (1 << bwl) || (prev_row == ref_row && prev_col == ref_col))
+ continue;
+ const int nb_pos = (ref_row << bwl) + ref_col;
+ count += (tcoeffs[nb_pos] != 0);
+ }
+ return count;
+}
+#else
+static INLINE int get_nz_count(const tran_low_t *tcoeffs, int bwl, int height,
+ int row, int col) {
+ int count = 0;
+ for (int idx = 0; idx < SIG_REF_OFFSET_NUM; ++idx) {
+ const int ref_row = row + sig_ref_offset[idx][0];
+ const int ref_col = col + sig_ref_offset[idx][1];
+ if (ref_row < 0 || ref_col < 0 || ref_row >= height ||
+ ref_col >= (1 << bwl))
continue;
- const int nb_pos = ref_row * stride + ref_col;
- if (iscan[nb_pos] < iscan[pos]) count += (tcoeffs[nb_pos] != 0);
+ const int nb_pos = (ref_row << bwl) + ref_col;
+ count += (tcoeffs[nb_pos] != 0);
}
return count;
}
+#endif
+
+static INLINE TX_CLASS get_tx_class(TX_TYPE tx_type) {
+ switch (tx_type) {
+#if CONFIG_EXT_TX
+ case V_DCT:
+ case V_ADST:
+ case V_FLIPADST: return TX_CLASS_VERT;
+ case H_DCT:
+ case H_ADST:
+ case H_FLIPADST: return TX_CLASS_HORIZ;
+#endif
+ default: return TX_CLASS_2D;
+ }
+}
// TODO(angiebird): optimize this function by generate a table that maps from
// count to ctx
static INLINE int get_nz_map_ctx_from_count(int count,
- const tran_low_t *tcoeffs,
int coeff_idx, // raster order
- int bwl, const int16_t *iscan) {
+ int bwl, TX_TYPE tx_type) {
+ (void)tx_type;
const int row = coeff_idx >> bwl;
const int col = coeff_idx - (row << bwl);
int ctx = 0;
+#if CONFIG_EXT_TX
+ int tx_class = get_tx_class(tx_type);
+ int offset;
+ if (tx_class == TX_CLASS_2D)
+ offset = 0;
+ else if (tx_class == TX_CLASS_VERT)
+ offset = SIG_COEF_CONTEXTS_2D;
+ else
+ offset = SIG_COEF_CONTEXTS_2D + SIG_COEF_CONTEXTS_1D;
+#else
+ int offset = 0;
+#endif
- if (row == 0 && col == 0) return 0;
+ if (row == 0 && col == 0) return offset + 0;
- if (row == 0 && col == 1) return 1 + (tcoeffs[0] != 0);
+ if (row == 0 && col == 1) return offset + 1 + count;
- if (row == 1 && col == 0) return 3 + (tcoeffs[0] != 0);
+ if (row == 1 && col == 0) return offset + 3 + count;
if (row == 1 && col == 1) {
- int pos;
- ctx = (tcoeffs[0] != 0);
-
- if (iscan[1] < iscan[coeff_idx]) ctx += (tcoeffs[1] != 0);
- pos = 1 << bwl;
- if (iscan[pos] < iscan[coeff_idx]) ctx += (tcoeffs[pos] != 0);
-
- ctx = (ctx + 1) >> 1;
+ ctx = (count + 1) >> 1;
assert(5 + ctx <= 7);
- return 5 + ctx;
+ return offset + 5 + ctx;
}
if (row == 0) {
ctx = (count + 1) >> 1;
- assert(ctx < 3);
- return 8 + ctx;
+ assert(ctx < 2);
+ return offset + 8 + ctx;
}
if (col == 0) {
ctx = (count + 1) >> 1;
- assert(ctx < 3);
- return 11 + ctx;
+ assert(ctx < 2);
+ return offset + 10 + ctx;
}
ctx = count >> 1;
- assert(14 + ctx < 20);
+ assert(12 + ctx < 16);
- return 14 + ctx;
+ return offset + 12 + ctx;
}
-static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs,
- const int coeff_idx, // raster order
- const int bwl, const int height,
- const int16_t *iscan) {
- int stride = 1 << bwl;
+static INLINE int get_nz_map_ctx(const tran_low_t *tcoeffs, const int scan_idx,
+ const int16_t *scan, const int bwl,
+ const int height, TX_TYPE tx_type) {
+ const int coeff_idx = scan[scan_idx];
const int row = coeff_idx >> bwl;
const int col = coeff_idx - (row << bwl);
- int count = get_nz_count(tcoeffs, stride, height, row, col, iscan);
- return get_nz_map_ctx_from_count(count, tcoeffs, coeff_idx, bwl, iscan);
+#if REDUCE_CONTEXT_DEPENDENCY
+ int prev_coeff_idx;
+ int prev_row;
+ int prev_col;
+ if (scan_idx > MIN_SCAN_IDX_REDUCE_CONTEXT_DEPENDENCY) {
+ prev_coeff_idx = scan[scan_idx - 1]; // raster order
+ prev_row = prev_coeff_idx >> bwl;
+ prev_col = prev_coeff_idx - (prev_row << bwl);
+ } else {
+ prev_coeff_idx = -1;
+ prev_row = -1;
+ prev_col = -1;
+ }
+ int count = get_nz_count(tcoeffs, bwl, height, row, col, prev_row, prev_col);
+#else
+ int count = get_nz_count(tcoeffs, bwl, height, row, col);
+#endif
+ return get_nz_map_ctx_from_count(count, coeff_idx, bwl, tx_type);
}
static INLINE int get_eob_ctx(const tran_low_t *tcoeffs,
const int coeff_idx, // raster order
- const TX_SIZE txs_ctx) {
+ const TX_SIZE txs_ctx, TX_TYPE tx_type) {
(void)tcoeffs;
- if (txs_ctx == TX_4X4) return av1_coeff_band_4x4[coeff_idx];
- if (txs_ctx == TX_8X8) return av1_coeff_band_8x8[coeff_idx];
- if (txs_ctx == TX_16X16) return av1_coeff_band_16x16[coeff_idx];
- if (txs_ctx == TX_32X32) return av1_coeff_band_32x32[coeff_idx];
+ int offset = 0;
+#if CONFIG_CTX1D
+ TX_CLASS tx_class = get_tx_class(tx_type);
+ if (tx_class == TX_CLASS_VERT)
+ offset = EOB_COEF_CONTEXTS_2D;
+ else if (tx_class == TX_CLASS_HORIZ)
+ offset = EOB_COEF_CONTEXTS_2D + EOB_COEF_CONTEXTS_1D;
+#else
+ (void)tx_type;
+#endif
+
+ if (txs_ctx == TX_4X4) return offset + av1_coeff_band_4x4[coeff_idx];
+ if (txs_ctx == TX_8X8) return offset + av1_coeff_band_8x8[coeff_idx];
+ if (txs_ctx == TX_16X16) return offset + av1_coeff_band_16x16[coeff_idx];
+ if (txs_ctx == TX_32X32) return offset + av1_coeff_band_32x32[coeff_idx];
assert(0);
return 0;
@@ -369,6 +506,86 @@ static INLINE void get_txb_ctx(BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
}
}
+#if LV_MAP_PROB
+void av1_init_txb_probs(FRAME_CONTEXT *fc);
+#endif // LV_MAP_PROB
+
void av1_adapt_txb_probs(AV1_COMMON *cm, unsigned int count_sat,
unsigned int update_factor);
+
+void av1_init_lv_map(AV1_COMMON *cm);
+
+#if CONFIG_CTX1D
+static INLINE void get_eob_vert(int16_t *eob_ls, const tran_low_t *tcoeff,
+ int w, int h) {
+ for (int c = 0; c < w; ++c) {
+ eob_ls[c] = 0;
+ for (int r = h - 1; r >= 0; --r) {
+ int coeff_idx = r * w + c;
+ if (tcoeff[coeff_idx] != 0) {
+ eob_ls[c] = r + 1;
+ break;
+ }
+ }
+ }
+}
+
+static INLINE void get_eob_horiz(int16_t *eob_ls, const tran_low_t *tcoeff,
+ int w, int h) {
+ for (int r = 0; r < h; ++r) {
+ eob_ls[r] = 0;
+ for (int c = w - 1; c >= 0; --c) {
+ int coeff_idx = r * w + c;
+ if (tcoeff[coeff_idx] != 0) {
+ eob_ls[r] = c + 1;
+ break;
+ }
+ }
+ }
+}
+
+static INLINE int get_empty_line_ctx(int line_idx, int16_t *eob_ls) {
+ if (line_idx > 0) {
+ int prev_eob = eob_ls[line_idx - 1];
+ if (prev_eob == 0) {
+ return 1;
+ } else if (prev_eob < 3) {
+ return 2;
+ } else if (prev_eob < 6) {
+ return 3;
+ } else {
+ return 4;
+ }
+ } else {
+ return 0;
+ }
+}
+
+#define MAX_POS_CTX 8
+static int pos_ctx[MAX_HVTX_SIZE] = {
+ 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+};
+static INLINE int get_hv_eob_ctx(int line_idx, int pos, int16_t *eob_ls) {
+ if (line_idx > 0) {
+ int prev_eob = eob_ls[line_idx - 1];
+ int diff = pos + 1 - prev_eob;
+ int abs_diff = abs(diff);
+ int ctx_idx = pos_ctx[abs_diff];
+ assert(ctx_idx < MAX_POS_CTX);
+ if (diff < 0) {
+ ctx_idx += MAX_POS_CTX;
+ assert(ctx_idx >= MAX_POS_CTX);
+ assert(ctx_idx < 2 * MAX_POS_CTX);
+ }
+ return ctx_idx;
+ } else {
+ int ctx_idx = MAX_POS_CTX + MAX_POS_CTX + pos_ctx[pos];
+ assert(ctx_idx < HV_EOB_CONTEXTS);
+ assert(HV_EOB_CONTEXTS == MAX_POS_CTX * 3);
+ return ctx_idx;
+ }
+}
+#endif // CONFIG_CTX1D
+
#endif // AV1_COMMON_TXB_COMMON_H_