diff options
Diffstat (limited to 'third_party/aom/av1/encoder/encodetxb.c')
-rw-r--r-- | third_party/aom/av1/encoder/encodetxb.c | 1149 |
1 files changed, 1124 insertions, 25 deletions
diff --git a/third_party/aom/av1/encoder/encodetxb.c b/third_party/aom/av1/encoder/encodetxb.c index 3f71a4472..731642064 100644 --- a/third_party/aom/av1/encoder/encodetxb.c +++ b/third_party/aom/av1/encoder/encodetxb.c @@ -21,6 +21,8 @@ #include "av1/encoder/subexp.h" #include "av1/encoder/tokenize.h" +#define TEST_OPTIMIZE_TXB 0 + void av1_alloc_txb_buf(AV1_COMP *cpi) { #if 0 AV1_COMMON *cm = &cpi->common; @@ -159,7 +161,7 @@ void av1_write_coeffs_txb(const AV1_COMMON *const cm, MACROBLOCKD *xd, } // level is above 1. - ctx = get_level_ctx(tcoeff, scan[c], bwl); + ctx = get_br_ctx(tcoeff, scan[c], bwl); for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) { if (level == (idx + 1 + NUM_BASE_LEVELS)) { aom_write(w, 1, cm->fc->coeff_lps[tx_size][plane_type][ctx]); @@ -251,6 +253,32 @@ static INLINE void get_base_ctx_set(const tran_low_t *tcoeffs, return; } +static INLINE int get_br_cost(tran_low_t abs_qc, int ctx, + const aom_prob *coeff_lps) { + const tran_low_t min_level = 1 + NUM_BASE_LEVELS; + const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE; + if (abs_qc >= min_level) { + const int cost0 = av1_cost_bit(coeff_lps[ctx], 0); + const int cost1 = av1_cost_bit(coeff_lps[ctx], 1); + if (abs_qc >= max_level) + return COEFF_BASE_RANGE * cost0; + else + return (abs_qc - min_level) * cost0 + cost1; + } else { + return 0; + } +} + +static INLINE int get_base_cost(tran_low_t abs_qc, int ctx, + aom_prob (*coeff_base)[COEFF_BASE_CONTEXTS], + int base_idx) { + const int level = base_idx + 1; + if (abs_qc < level) + return 0; + else + return av1_cost_bit(coeff_base[base_idx][ctx], abs_qc == level); +} + int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane, int block, TXB_CTX *txb_ctx) { const AV1_COMMON *const cm = &cpi->common; @@ -331,7 +359,7 @@ int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane, int idx; int ctx; - ctx = get_level_ctx(qcoeff, scan[c], bwl); + ctx = get_br_ctx(qcoeff, scan[c], bwl); for (idx = 0; 
idx < COEFF_BASE_RANGE; ++idx) { if (level == (idx + 1 + NUM_BASE_LEVELS)) { @@ -373,12 +401,1085 @@ int av1_cost_coeffs_txb(const AV1_COMP *const cpi, MACROBLOCK *x, int plane, return cost; } -typedef struct TxbParams { - const AV1_COMP *cpi; - ThreadData *td; - int rate; -} TxbParams; +static INLINE int has_base(tran_low_t qc, int base_idx) { + const int level = base_idx + 1; + return abs(qc) >= level; +} + +static void gen_base_count_mag_arr(int (*base_count_arr)[MAX_TX_SQUARE], + int (*base_mag_arr)[2], + const tran_low_t *qcoeff, int stride, + int eob, const int16_t *scan) { + for (int c = 0; c < eob; ++c) { + const int coeff_idx = scan[c]; // raster order + if (!has_base(qcoeff[coeff_idx], 0)) continue; + const int row = coeff_idx / stride; + const int col = coeff_idx % stride; + int *mag = base_mag_arr[coeff_idx]; + get_mag(mag, qcoeff, stride, row, col, base_ref_offset, + BASE_CONTEXT_POSITION_NUM); + for (int i = 0; i < NUM_BASE_LEVELS; ++i) { + if (!has_base(qcoeff[coeff_idx], i)) continue; + int *count = base_count_arr[i] + coeff_idx; + *count = get_level_count(qcoeff, stride, row, col, i, base_ref_offset, + BASE_CONTEXT_POSITION_NUM); + } + } +} + +static void gen_nz_count_arr(int(*nz_count_arr), const tran_low_t *qcoeff, + int stride, int eob, + const SCAN_ORDER *scan_order) { + const int16_t *scan = scan_order->scan; + const int16_t *iscan = scan_order->iscan; + for (int c = 0; c < eob; ++c) { + const int coeff_idx = scan[c]; // raster order + const int row = coeff_idx / stride; + const int col = coeff_idx % stride; + nz_count_arr[coeff_idx] = get_nz_count(qcoeff, stride, row, col, iscan); + } +} + +static void gen_nz_ctx_arr(int (*nz_ctx_arr)[2], int(*nz_count_arr), + const tran_low_t *qcoeff, int bwl, int eob, + const SCAN_ORDER *scan_order) { + const int16_t *scan = scan_order->scan; + const int16_t *iscan = scan_order->iscan; + for (int c = 0; c < eob; ++c) { + const int coeff_idx = scan[c]; // raster order + const int count = 
nz_count_arr[coeff_idx]; + nz_ctx_arr[coeff_idx][0] = + get_nz_map_ctx_from_count(count, qcoeff, coeff_idx, bwl, iscan); + } +} + +static void gen_base_ctx_arr(int (*base_ctx_arr)[MAX_TX_SQUARE][2], + int (*base_count_arr)[MAX_TX_SQUARE], + int (*base_mag_arr)[2], const tran_low_t *qcoeff, + int stride, int eob, const int16_t *scan) { + (void)qcoeff; + for (int i = 0; i < NUM_BASE_LEVELS; ++i) { + for (int c = 0; c < eob; ++c) { + const int coeff_idx = scan[c]; // raster order + if (!has_base(qcoeff[coeff_idx], i)) continue; + const int row = coeff_idx / stride; + const int col = coeff_idx % stride; + const int count = base_count_arr[i][coeff_idx]; + const int *mag = base_mag_arr[coeff_idx]; + const int level = i + 1; + base_ctx_arr[i][coeff_idx][0] = + get_base_ctx_from_count_mag(row, col, count, mag[0], level); + } + } +} + +static INLINE int has_br(tran_low_t qc) { + return abs(qc) >= 1 + NUM_BASE_LEVELS; +} + +static void gen_br_count_mag_arr(int *br_count_arr, int (*br_mag_arr)[2], + const tran_low_t *qcoeff, int stride, int eob, + const int16_t *scan) { + for (int c = 0; c < eob; ++c) { + const int coeff_idx = scan[c]; // raster order + if (!has_br(qcoeff[coeff_idx])) continue; + const int row = coeff_idx / stride; + const int col = coeff_idx % stride; + int *count = br_count_arr + coeff_idx; + int *mag = br_mag_arr[coeff_idx]; + *count = get_level_count(qcoeff, stride, row, col, NUM_BASE_LEVELS, + br_ref_offset, BR_CONTEXT_POSITION_NUM); + get_mag(mag, qcoeff, stride, row, col, br_ref_offset, + BR_CONTEXT_POSITION_NUM); + } +} + +static void gen_br_ctx_arr(int (*br_ctx_arr)[2], const int *br_count_arr, + int (*br_mag_arr)[2], const tran_low_t *qcoeff, + int stride, int eob, const int16_t *scan) { + (void)qcoeff; + for (int c = 0; c < eob; ++c) { + const int coeff_idx = scan[c]; // raster order + if (!has_br(qcoeff[coeff_idx])) continue; + const int row = coeff_idx / stride; + const int col = coeff_idx % stride; + const int count = br_count_arr[coeff_idx]; + 
const int *mag = br_mag_arr[coeff_idx]; + br_ctx_arr[coeff_idx][0] = + get_br_ctx_from_count_mag(row, col, count, mag[0]); + } +} + +static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx, + const aom_prob *dc_sign_prob, + int dc_sign_ctx) { + const int sign = (qc < 0) ? 1 : 0; + // sign bit cost + if (coeff_idx == 0) { + return av1_cost_bit(dc_sign_prob[dc_sign_ctx], sign); + } else { + return av1_cost_bit(128, sign); + } +} +static INLINE int get_golomb_cost(int abs_qc) { + if (abs_qc >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) { + // residual cost + int r = abs_qc - COEFF_BASE_RANGE - NUM_BASE_LEVELS; + int ri = r; + int length = 0; + + while (ri) { + ri >>= 1; + ++length; + } + + return av1_cost_literal(2 * length - 1); + } else { + return 0; + } +} + +// TODO(angiebird): add static once this function is called +void gen_txb_cache(TxbCache *txb_cache, TxbInfo *txb_info) { + const int16_t *scan = txb_info->scan_order->scan; + gen_nz_count_arr(txb_cache->nz_count_arr, txb_info->qcoeff, txb_info->stride, + txb_info->eob, txb_info->scan_order); + gen_nz_ctx_arr(txb_cache->nz_ctx_arr, txb_cache->nz_count_arr, + txb_info->qcoeff, txb_info->bwl, txb_info->eob, + txb_info->scan_order); + gen_base_count_mag_arr(txb_cache->base_count_arr, txb_cache->base_mag_arr, + txb_info->qcoeff, txb_info->stride, txb_info->eob, + scan); + gen_base_ctx_arr(txb_cache->base_ctx_arr, txb_cache->base_count_arr, + txb_cache->base_mag_arr, txb_info->qcoeff, txb_info->stride, + txb_info->eob, scan); + gen_br_count_mag_arr(txb_cache->br_count_arr, txb_cache->br_mag_arr, + txb_info->qcoeff, txb_info->stride, txb_info->eob, scan); + gen_br_ctx_arr(txb_cache->br_ctx_arr, txb_cache->br_count_arr, + txb_cache->br_mag_arr, txb_info->qcoeff, txb_info->stride, + txb_info->eob, scan); +} + +static INLINE aom_prob get_level_prob(int level, int coeff_idx, + const TxbCache *txb_cache, + const TxbProbs *txb_probs) { + if (level == 0) { + const int ctx = txb_cache->nz_ctx_arr[coeff_idx][0]; + 
return txb_probs->nz_map[ctx]; + } else if (level >= 1 && level < 1 + NUM_BASE_LEVELS) { + const int idx = level - 1; + const int ctx = txb_cache->base_ctx_arr[idx][coeff_idx][0]; + return txb_probs->coeff_base[idx][ctx]; + } else if (level >= 1 + NUM_BASE_LEVELS && + level < 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) { + const int ctx = txb_cache->br_ctx_arr[coeff_idx][0]; + return txb_probs->coeff_lps[ctx]; + } else if (level >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) { + printf("get_level_prob does not support golomb\n"); + assert(0); + return 0; + } else { + assert(0); + return 0; + } +} + +static INLINE tran_low_t get_lower_coeff(tran_low_t qc) { + if (qc == 0) { + return 0; + } + return qc > 0 ? qc - 1 : qc + 1; +} + +static INLINE void update_mag_arr(int *mag_arr, int abs_qc) { + if (mag_arr[0] == abs_qc) { + mag_arr[1] -= 1; + assert(mag_arr[1] >= 0); + } +} + +static INLINE int get_mag_from_mag_arr(const int *mag_arr) { + int mag; + if (mag_arr[1] > 0) { + mag = mag_arr[0]; + } else if (mag_arr[0] > 0) { + mag = mag_arr[0] - 1; + } else { + // no neighbor + assert(mag_arr[0] == 0 && mag_arr[1] == 0); + mag = 0; + } + return mag; +} + +static int neighbor_level_down_update(int *new_count, int *new_mag, int count, + const int *mag, int coeff_idx, + tran_low_t abs_nb_coeff, int nb_coeff_idx, + int level, const TxbInfo *txb_info) { + *new_count = count; + *new_mag = get_mag_from_mag_arr(mag); + + int update = 0; + // check if br_count changes + if (abs_nb_coeff == level) { + update = 1; + *new_count -= 1; + assert(*new_count >= 0); + } + const int row = coeff_idx >> txb_info->bwl; + const int col = coeff_idx - (row << txb_info->bwl); + const int nb_row = nb_coeff_idx >> txb_info->bwl; + const int nb_col = nb_coeff_idx - (nb_row << txb_info->bwl); + + // check if mag changes + if (nb_row >= row && nb_col >= col) { + if (abs_nb_coeff == mag[0]) { + assert(mag[1] > 0); + if (mag[1] == 1) { + // the nb is the only qc with max mag + *new_mag -= 1; + assert(*new_mag >= 
0); + update = 1; + } + } + } + return update; +} + +static int try_neighbor_level_down_br(int coeff_idx, int nb_coeff_idx, + const TxbCache *txb_cache, + const TxbProbs *txb_probs, + const TxbInfo *txb_info) { + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + const tran_low_t abs_qc = abs(qc); + const int level = NUM_BASE_LEVELS + 1; + if (abs_qc < level) return 0; + + const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + const tran_low_t abs_nb_coeff = abs(nb_coeff); + const int count = txb_cache->br_count_arr[coeff_idx]; + const int *mag = txb_cache->br_mag_arr[coeff_idx]; + int new_count; + int new_mag; + const int update = + neighbor_level_down_update(&new_count, &new_mag, count, mag, coeff_idx, + abs_nb_coeff, nb_coeff_idx, level, txb_info); + if (update) { + const int row = coeff_idx >> txb_info->bwl; + const int col = coeff_idx - (row << txb_info->bwl); + const int ctx = txb_cache->br_ctx_arr[coeff_idx][0]; + const int org_cost = get_br_cost(abs_qc, ctx, txb_probs->coeff_lps); + + const int new_ctx = get_br_ctx_from_count_mag(row, col, new_count, new_mag); + const int new_cost = get_br_cost(abs_qc, new_ctx, txb_probs->coeff_lps); + const int cost_diff = -org_cost + new_cost; + return cost_diff; + } else { + return 0; + } +} + +static int try_neighbor_level_down_base(int coeff_idx, int nb_coeff_idx, + const TxbCache *txb_cache, + const TxbProbs *txb_probs, + const TxbInfo *txb_info) { + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + const tran_low_t abs_qc = abs(qc); + + int cost_diff = 0; + for (int base_idx = 0; base_idx < NUM_BASE_LEVELS; ++base_idx) { + const int level = base_idx + 1; + if (abs_qc < level) continue; + + const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + const tran_low_t abs_nb_coeff = abs(nb_coeff); + + const int count = txb_cache->base_count_arr[base_idx][coeff_idx]; + const int *mag = txb_cache->base_mag_arr[coeff_idx]; + int new_count; + int new_mag; + const int update = + 
neighbor_level_down_update(&new_count, &new_mag, count, mag, coeff_idx, + abs_nb_coeff, nb_coeff_idx, level, txb_info); + if (update) { + const int row = coeff_idx >> txb_info->bwl; + const int col = coeff_idx - (row << txb_info->bwl); + const int ctx = txb_cache->base_ctx_arr[base_idx][coeff_idx][0]; + const int org_cost = + get_base_cost(abs_qc, ctx, txb_probs->coeff_base, base_idx); + + const int new_ctx = + get_base_ctx_from_count_mag(row, col, new_count, new_mag, level); + const int new_cost = + get_base_cost(abs_qc, new_ctx, txb_probs->coeff_base, base_idx); + cost_diff += -org_cost + new_cost; + } + } + return cost_diff; +} + +static int try_neighbor_level_down_nz(int coeff_idx, int nb_coeff_idx, + const TxbCache *txb_cache, + const TxbProbs *txb_probs, + TxbInfo *txb_info) { + // assume eob doesn't change + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + const tran_low_t abs_qc = abs(qc); + const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + const tran_low_t abs_nb_coeff = abs(nb_coeff); + if (abs_nb_coeff != 1) return 0; + const int16_t *iscan = txb_info->scan_order->iscan; + const int scan_idx = iscan[coeff_idx]; + if (scan_idx == txb_info->seg_eob) return 0; + const int nb_scan_idx = iscan[nb_coeff_idx]; + if (nb_scan_idx < scan_idx) { + const int count = txb_cache->nz_count_arr[coeff_idx]; + assert(count > 0); + txb_info->qcoeff[nb_coeff_idx] = get_lower_coeff(nb_coeff); + const int new_ctx = get_nz_map_ctx_from_count( + count - 1, txb_info->qcoeff, coeff_idx, txb_info->bwl, iscan); + txb_info->qcoeff[nb_coeff_idx] = nb_coeff; + const int ctx = txb_cache->nz_ctx_arr[coeff_idx][0]; + const int is_nz = abs_qc > 0; + const int org_cost = av1_cost_bit(txb_probs->nz_map[ctx], is_nz); + const int new_cost = av1_cost_bit(txb_probs->nz_map[new_ctx], is_nz); + const int cost_diff = new_cost - org_cost; + return cost_diff; + } else { + return 0; + } +} + +static int try_self_level_down(tran_low_t *low_coeff, int coeff_idx, + const TxbCache 
*txb_cache, + const TxbProbs *txb_probs, TxbInfo *txb_info) { + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + if (qc == 0) { + *low_coeff = 0; + return 0; + } + const tran_low_t abs_qc = abs(qc); + *low_coeff = get_lower_coeff(qc); + int cost_diff; + if (*low_coeff == 0) { + const int scan_idx = txb_info->scan_order->iscan[coeff_idx]; + const aom_prob level_prob = + get_level_prob(abs_qc, coeff_idx, txb_cache, txb_probs); + const aom_prob low_level_prob = + get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_probs); + if (scan_idx < txb_info->seg_eob) { + // When level-0, we code the binary of abs_qc > level + // but when level-k k > 0 we code the binary of abs_qc == level + // That's why wee need this special treatment for level-0 map + // TODO(angiebird): make leve-0 consistent to other levels + cost_diff = -av1_cost_bit(level_prob, 1) + + av1_cost_bit(low_level_prob, 0) - + av1_cost_bit(low_level_prob, 1); + } else { + cost_diff = -av1_cost_bit(level_prob, 1); + } + + if (scan_idx < txb_info->seg_eob) { + const int eob_ctx = + get_eob_ctx(txb_info->qcoeff, coeff_idx, txb_info->bwl); + cost_diff -= av1_cost_bit(txb_probs->eob_flag[eob_ctx], + scan_idx == (txb_info->eob - 1)); + } + + const int sign_cost = get_sign_bit_cost( + qc, coeff_idx, txb_probs->dc_sign_prob, txb_info->txb_ctx->dc_sign_ctx); + cost_diff -= sign_cost; + } else if (abs_qc < 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) { + const aom_prob level_prob = + get_level_prob(abs_qc, coeff_idx, txb_cache, txb_probs); + const aom_prob low_level_prob = + get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_probs); + cost_diff = -av1_cost_bit(level_prob, 1) + av1_cost_bit(low_level_prob, 1) - + av1_cost_bit(low_level_prob, 0); + } else if (abs_qc == 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) { + const aom_prob low_level_prob = + get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_probs); + cost_diff = -get_golomb_cost(abs_qc) + av1_cost_bit(low_level_prob, 1) - + 
av1_cost_bit(low_level_prob, 0); + } else { + assert(abs_qc > 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE); + const tran_low_t abs_low_coeff = abs(*low_coeff); + cost_diff = -get_golomb_cost(abs_qc) + get_golomb_cost(abs_low_coeff); + } + return cost_diff; +} + +#define COST_MAP_SIZE 5 +#define COST_MAP_OFFSET 2 + +static INLINE int check_nz_neighbor(tran_low_t qc) { return abs(qc) == 1; } + +static INLINE int check_base_neighbor(tran_low_t qc) { + return abs(qc) <= 1 + NUM_BASE_LEVELS; +} + +static INLINE int check_br_neighbor(tran_low_t qc) { + return abs(qc) > BR_MAG_OFFSET; +} + +// TODO(angiebird): add static to this function once it's called +int try_level_down(int coeff_idx, const TxbCache *txb_cache, + const TxbProbs *txb_probs, TxbInfo *txb_info, + int (*cost_map)[COST_MAP_SIZE]) { + if (cost_map) { + for (int i = 0; i < COST_MAP_SIZE; ++i) av1_zero(cost_map[i]); + } + + tran_low_t qc = txb_info->qcoeff[coeff_idx]; + tran_low_t low_coeff; + if (qc == 0) return 0; + int accu_cost_diff = 0; + + const int16_t *iscan = txb_info->scan_order->iscan; + const int eob = txb_info->eob; + const int scan_idx = iscan[coeff_idx]; + if (scan_idx < eob) { + const int cost_diff = try_self_level_down(&low_coeff, coeff_idx, txb_cache, + txb_probs, txb_info); + if (cost_map) + cost_map[0 + COST_MAP_OFFSET][0 + COST_MAP_OFFSET] = cost_diff; + accu_cost_diff += cost_diff; + } + + const int row = coeff_idx >> txb_info->bwl; + const int col = coeff_idx - (row << txb_info->bwl); + if (check_nz_neighbor(qc)) { + for (int i = 0; i < SIG_REF_OFFSET_NUM; ++i) { + const int nb_row = row - sig_ref_offset[i][0]; + const int nb_col = col - sig_ref_offset[i][1]; + const int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + const int nb_scan_idx = iscan[nb_coeff_idx]; + if (nb_scan_idx < eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + const int cost_diff = try_neighbor_level_down_nz( + nb_coeff_idx, coeff_idx, txb_cache, txb_probs, txb_info); 
+ if (cost_map) + cost_map[nb_row - row + COST_MAP_OFFSET] + [nb_col - col + COST_MAP_OFFSET] += cost_diff; + accu_cost_diff += cost_diff; + } + } + } + + if (check_base_neighbor(qc)) { + for (int i = 0; i < BASE_CONTEXT_POSITION_NUM; ++i) { + const int nb_row = row - base_ref_offset[i][0]; + const int nb_col = col - base_ref_offset[i][1]; + const int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + const int nb_scan_idx = iscan[nb_coeff_idx]; + if (nb_scan_idx < eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + const int cost_diff = try_neighbor_level_down_base( + nb_coeff_idx, coeff_idx, txb_cache, txb_probs, txb_info); + if (cost_map) + cost_map[nb_row - row + COST_MAP_OFFSET] + [nb_col - col + COST_MAP_OFFSET] += cost_diff; + accu_cost_diff += cost_diff; + } + } + } + + if (check_br_neighbor(qc)) { + for (int i = 0; i < BR_CONTEXT_POSITION_NUM; ++i) { + const int nb_row = row - br_ref_offset[i][0]; + const int nb_col = col - br_ref_offset[i][1]; + const int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + const int nb_scan_idx = iscan[nb_coeff_idx]; + if (nb_scan_idx < eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + const int cost_diff = try_neighbor_level_down_br( + nb_coeff_idx, coeff_idx, txb_cache, txb_probs, txb_info); + if (cost_map) + cost_map[nb_row - row + COST_MAP_OFFSET] + [nb_col - col + COST_MAP_OFFSET] += cost_diff; + accu_cost_diff += cost_diff; + } + } + } + + return accu_cost_diff; +} + +static int get_low_coeff_cost(int coeff_idx, const TxbCache *txb_cache, + const TxbProbs *txb_probs, + const TxbInfo *txb_info) { + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + const int abs_qc = abs(qc); + assert(abs_qc <= 1); + int cost = 0; + const int scan_idx = txb_info->scan_order->iscan[coeff_idx]; + if (scan_idx < txb_info->seg_eob) { + const aom_prob level_prob = + get_level_prob(0, coeff_idx, txb_cache, txb_probs); + cost += 
av1_cost_bit(level_prob, qc != 0); + } + + if (qc != 0) { + const int base_idx = 0; + const int ctx = txb_cache->base_ctx_arr[base_idx][coeff_idx][0]; + cost += get_base_cost(abs_qc, ctx, txb_probs->coeff_base, base_idx); + if (scan_idx < txb_info->seg_eob) { + const int eob_ctx = + get_eob_ctx(txb_info->qcoeff, coeff_idx, txb_info->bwl); + cost += av1_cost_bit(txb_probs->eob_flag[eob_ctx], + scan_idx == (txb_info->eob - 1)); + } + cost += get_sign_bit_cost(qc, coeff_idx, txb_probs->dc_sign_prob, + txb_info->txb_ctx->dc_sign_ctx); + } + return cost; +} + +static INLINE void set_eob(TxbInfo *txb_info, int eob) { + txb_info->eob = eob; + txb_info->seg_eob = AOMMIN(eob, tx_size_2d[txb_info->tx_size] - 1); +} + +// TODO(angiebird): add static to this function once it's called +int try_change_eob(int *new_eob, int coeff_idx, const TxbCache *txb_cache, + const TxbProbs *txb_probs, TxbInfo *txb_info) { + assert(txb_info->eob > 0); + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + const int abs_qc = abs(qc); + if (abs_qc != 1) { + *new_eob = -1; + return 0; + } + const int16_t *iscan = txb_info->scan_order->iscan; + const int16_t *scan = txb_info->scan_order->scan; + const int scan_idx = iscan[coeff_idx]; + *new_eob = 0; + int cost_diff = 0; + cost_diff -= get_low_coeff_cost(coeff_idx, txb_cache, txb_probs, txb_info); + // int coeff_cost = + // get_coeff_cost(qc, scan_idx, txb_info, txb_probs); + // if (-cost_diff != coeff_cost) { + // printf("-cost_diff %d coeff_cost %d\n", -cost_diff, coeff_cost); + // get_low_coeff_cost(coeff_idx, txb_cache, txb_probs, txb_info); + // get_coeff_cost(qc, scan_idx, txb_info, txb_probs); + // } + for (int si = scan_idx - 1; si >= 0; --si) { + const int ci = scan[si]; + if (txb_info->qcoeff[ci] != 0) { + *new_eob = si + 1; + break; + } else { + cost_diff -= get_low_coeff_cost(ci, txb_cache, txb_probs, txb_info); + } + } + + const int org_eob = txb_info->eob; + set_eob(txb_info, *new_eob); + cost_diff += try_level_down(coeff_idx, 
txb_cache, txb_probs, txb_info, NULL); + set_eob(txb_info, org_eob); + + if (*new_eob > 0) { + // Note that get_eob_ctx does NOT actually account for qcoeff, so we don't + // need to lower down the qcoeff here + const int eob_ctx = + get_eob_ctx(txb_info->qcoeff, scan[*new_eob - 1], txb_info->bwl); + cost_diff -= av1_cost_bit(txb_probs->eob_flag[eob_ctx], 0); + cost_diff += av1_cost_bit(txb_probs->eob_flag[eob_ctx], 1); + } else { + const int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx; + cost_diff -= av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 0); + cost_diff += av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 1); + } + return cost_diff; +} + +static INLINE tran_low_t qcoeff_to_dqcoeff(tran_low_t qc, int dqv, int shift) { + int sgn = qc < 0 ? -1 : 1; + return sgn * ((abs(qc) * dqv) >> shift); +} + +// TODO(angiebird): add static to this function it's called +void update_level_down(int coeff_idx, TxbCache *txb_cache, TxbInfo *txb_info) { + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + const int abs_qc = abs(qc); + if (qc == 0) return; + const tran_low_t low_coeff = get_lower_coeff(qc); + txb_info->qcoeff[coeff_idx] = low_coeff; + const int dqv = txb_info->dequant[coeff_idx != 0]; + txb_info->dqcoeff[coeff_idx] = + qcoeff_to_dqcoeff(low_coeff, dqv, txb_info->shift); + + const int row = coeff_idx >> txb_info->bwl; + const int col = coeff_idx - (row << txb_info->bwl); + const int eob = txb_info->eob; + const int16_t *iscan = txb_info->scan_order->iscan; + for (int i = 0; i < SIG_REF_OFFSET_NUM; ++i) { + const int nb_row = row - sig_ref_offset[i][0]; + const int nb_col = col - sig_ref_offset[i][1]; + const int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + const int nb_scan_idx = iscan[nb_coeff_idx]; + if (nb_scan_idx < eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + const int scan_idx = iscan[coeff_idx]; + if (scan_idx < nb_scan_idx) { + const int level = 1; + if (abs_qc == level) { + 
txb_cache->nz_count_arr[nb_coeff_idx] -= 1; + assert(txb_cache->nz_count_arr[nb_coeff_idx] >= 0); + } + const int count = txb_cache->nz_count_arr[nb_coeff_idx]; + txb_cache->nz_ctx_arr[nb_coeff_idx][0] = get_nz_map_ctx_from_count( + count, txb_info->qcoeff, nb_coeff_idx, txb_info->bwl, iscan); + // int ref_ctx = get_nz_map_ctx2(txb_info->qcoeff, nb_coeff_idx, + // txb_info->bwl, iscan); + // if (ref_ctx != txb_cache->nz_ctx_arr[nb_coeff_idx][0]) + // printf("nz ctx %d ref_ctx %d\n", + // txb_cache->nz_ctx_arr[nb_coeff_idx][0], ref_ctx); + } + } + } + + for (int i = 0; i < BASE_CONTEXT_POSITION_NUM; ++i) { + const int nb_row = row - base_ref_offset[i][0]; + const int nb_col = col - base_ref_offset[i][1]; + const int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + if (!has_base(nb_coeff, 0)) continue; + const int nb_scan_idx = iscan[nb_coeff_idx]; + if (nb_scan_idx < eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + if (row >= nb_row && col >= nb_col) + update_mag_arr(txb_cache->base_mag_arr[nb_coeff_idx], abs_qc); + const int mag = + get_mag_from_mag_arr(txb_cache->base_mag_arr[nb_coeff_idx]); + for (int base_idx = 0; base_idx < NUM_BASE_LEVELS; ++base_idx) { + if (!has_base(nb_coeff, base_idx)) continue; + const int level = base_idx + 1; + if (abs_qc == level) { + txb_cache->base_count_arr[base_idx][nb_coeff_idx] -= 1; + assert(txb_cache->base_count_arr[base_idx][nb_coeff_idx] >= 0); + } + const int count = txb_cache->base_count_arr[base_idx][nb_coeff_idx]; + txb_cache->base_ctx_arr[base_idx][nb_coeff_idx][0] = + get_base_ctx_from_count_mag(nb_row, nb_col, count, mag, level); + // int ref_ctx = get_base_ctx(txb_info->qcoeff, nb_coeff_idx, + // txb_info->bwl, level); + // if (ref_ctx != txb_cache->base_ctx_arr[base_idx][nb_coeff_idx][0]) { + // printf("base ctx %d ref_ctx %d\n", + // txb_cache->base_ctx_arr[base_idx][nb_coeff_idx][0], ref_ctx); + // } 
+ } + } + } + + for (int i = 0; i < BR_CONTEXT_POSITION_NUM; ++i) { + const int nb_row = row - br_ref_offset[i][0]; + const int nb_col = col - br_ref_offset[i][1]; + const int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + const int nb_scan_idx = iscan[nb_coeff_idx]; + const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + if (!has_br(nb_coeff)) continue; + if (nb_scan_idx < eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + const int level = 1 + NUM_BASE_LEVELS; + if (abs_qc == level) { + txb_cache->br_count_arr[nb_coeff_idx] -= 1; + assert(txb_cache->br_count_arr[nb_coeff_idx] >= 0); + } + if (row >= nb_row && col >= nb_col) + update_mag_arr(txb_cache->br_mag_arr[nb_coeff_idx], abs_qc); + const int count = txb_cache->br_count_arr[nb_coeff_idx]; + const int mag = get_mag_from_mag_arr(txb_cache->br_mag_arr[nb_coeff_idx]); + txb_cache->br_ctx_arr[nb_coeff_idx][0] = + get_br_ctx_from_count_mag(nb_row, nb_col, count, mag); + // int ref_ctx = get_level_ctx(txb_info->qcoeff, nb_coeff_idx, + // txb_info->bwl); + // if (ref_ctx != txb_cache->br_ctx_arr[nb_coeff_idx][0]) { + // printf("base ctx %d ref_ctx %d\n", + // txb_cache->br_ctx_arr[nb_coeff_idx][0], ref_ctx); + // } + } + } +} + +static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info, + const TxbProbs *txb_probs) { + const TXB_CTX *txb_ctx = txb_info->txb_ctx; + const int is_nz = (qc != 0); + const tran_low_t abs_qc = abs(qc); + int cost = 0; + const int16_t *scan = txb_info->scan_order->scan; + const int16_t *iscan = txb_info->scan_order->iscan; + + if (scan_idx < txb_info->seg_eob) { + int coeff_ctx = + get_nz_map_ctx2(txb_info->qcoeff, scan[scan_idx], txb_info->bwl, iscan); + cost += av1_cost_bit(txb_probs->nz_map[coeff_ctx], is_nz); + } + + if (is_nz) { + cost += get_sign_bit_cost(qc, scan_idx, txb_probs->dc_sign_prob, + txb_ctx->dc_sign_ctx); + + int ctx_ls[NUM_BASE_LEVELS] = { 0 }; + get_base_ctx_set(txb_info->qcoeff, 
scan[scan_idx], txb_info->bwl, ctx_ls); + + int i; + for (i = 0; i < NUM_BASE_LEVELS; ++i) { + cost += get_base_cost(abs_qc, ctx_ls[i], txb_probs->coeff_base, i); + } + + if (abs_qc > NUM_BASE_LEVELS) { + int ctx = get_br_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl); + cost += get_br_cost(abs_qc, ctx, txb_probs->coeff_lps); + cost += get_golomb_cost(abs_qc); + } + + if (scan_idx < txb_info->seg_eob) { + int eob_ctx = + get_eob_ctx(txb_info->qcoeff, scan[scan_idx], txb_info->bwl); + cost += av1_cost_bit(txb_probs->eob_flag[eob_ctx], + scan_idx == (txb_info->eob - 1)); + } + } + return cost; +} + +#if TEST_OPTIMIZE_TXB +#define ALL_REF_OFFSET_NUM 17 +static int all_ref_offset[ALL_REF_OFFSET_NUM][2] = { + { 0, 0 }, { -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 }, + { -1, 0 }, { -1, 1 }, { 0, -2 }, { 0, -1 }, { 1, -2 }, { 1, -1 }, + { 1, 0 }, { 2, 0 }, { 0, 1 }, { 0, 2 }, { 1, 1 }, +}; + +static int try_level_down_ref(int coeff_idx, const TxbProbs *txb_probs, + TxbInfo *txb_info, + int (*cost_map)[COST_MAP_SIZE]) { + if (cost_map) { + for (int i = 0; i < COST_MAP_SIZE; ++i) av1_zero(cost_map[i]); + } + tran_low_t qc = txb_info->qcoeff[coeff_idx]; + if (qc == 0) return 0; + int row = coeff_idx >> txb_info->bwl; + int col = coeff_idx - (row << txb_info->bwl); + int org_cost = 0; + for (int i = 0; i < ALL_REF_OFFSET_NUM; ++i) { + int nb_row = row - all_ref_offset[i][0]; + int nb_col = col - all_ref_offset[i][1]; + int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + int nb_scan_idx = txb_info->scan_order->iscan[nb_coeff_idx]; + if (nb_scan_idx < txb_info->eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + int cost = get_coeff_cost(nb_coeff, nb_scan_idx, txb_info, txb_probs); + if (cost_map) + cost_map[nb_row - row + COST_MAP_OFFSET] + [nb_col - col + COST_MAP_OFFSET] -= cost; + org_cost += cost; + } + } + txb_info->qcoeff[coeff_idx] = 
get_lower_coeff(qc); + int new_cost = 0; + for (int i = 0; i < ALL_REF_OFFSET_NUM; ++i) { + int nb_row = row - all_ref_offset[i][0]; + int nb_col = col - all_ref_offset[i][1]; + int nb_coeff_idx = nb_row * txb_info->stride + nb_col; + int nb_scan_idx = txb_info->scan_order->iscan[nb_coeff_idx]; + if (nb_scan_idx < txb_info->eob && nb_row >= 0 && nb_col >= 0 && + nb_row < txb_info->stride && nb_col < txb_info->stride) { + tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx]; + int cost = get_coeff_cost(nb_coeff, nb_scan_idx, txb_info, txb_probs); + if (cost_map) + cost_map[nb_row - row + COST_MAP_OFFSET] + [nb_col - col + COST_MAP_OFFSET] += cost; + new_cost += cost; + } + } + txb_info->qcoeff[coeff_idx] = qc; + return new_cost - org_cost; +} +static void test_level_down(int coeff_idx, const TxbCache *txb_cache, + const TxbProbs *txb_probs, TxbInfo *txb_info) { + int cost_map[COST_MAP_SIZE][COST_MAP_SIZE]; + int ref_cost_map[COST_MAP_SIZE][COST_MAP_SIZE]; + const int cost_diff = + try_level_down(coeff_idx, txb_cache, txb_probs, txb_info, cost_map); + const int cost_diff_ref = + try_level_down_ref(coeff_idx, txb_probs, txb_info, ref_cost_map); + if (cost_diff != cost_diff_ref) { + printf("qc %d cost_diff %d cost_diff_ref %d\n", txb_info->qcoeff[coeff_idx], + cost_diff, cost_diff_ref); + for (int r = 0; r < COST_MAP_SIZE; ++r) { + for (int c = 0; c < COST_MAP_SIZE; ++c) { + printf("%d:%d ", cost_map[r][c], ref_cost_map[r][c]); + } + printf("\n"); + } + } +} +#endif + +// TODO(angiebird): make this static once it's called +int get_txb_cost(TxbInfo *txb_info, const TxbProbs *txb_probs) { + int cost = 0; + int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx; + const int16_t *scan = txb_info->scan_order->scan; + if (txb_info->eob == 0) { + cost = av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 1); + return cost; + } + cost = av1_cost_bit(txb_probs->txb_skip[txb_skip_ctx], 0); + for (int c = 0; c < txb_info->eob; ++c) { + tran_low_t qc = txb_info->qcoeff[scan[c]]; + int 
coeff_cost = get_coeff_cost(qc, c, txb_info, txb_probs); + cost += coeff_cost; + } + return cost; +} + +#if TEST_OPTIMIZE_TXB +void test_try_change_eob(TxbInfo *txb_info, TxbProbs *txb_probs, + TxbCache *txb_cache) { + int eob = txb_info->eob; + const int16_t *scan = txb_info->scan_order->scan; + if (eob > 0) { + int last_si = eob - 1; + int last_ci = scan[last_si]; + int last_coeff = txb_info->qcoeff[last_ci]; + if (abs(last_coeff) == 1) { + int new_eob; + int cost_diff = + try_change_eob(&new_eob, last_ci, txb_cache, txb_probs, txb_info); + int org_eob = txb_info->eob; + int cost = get_txb_cost(txb_info, txb_probs); + + txb_info->qcoeff[last_ci] = get_lower_coeff(last_coeff); + set_eob(txb_info, new_eob); + int new_cost = get_txb_cost(txb_info, txb_probs); + set_eob(txb_info, org_eob); + txb_info->qcoeff[last_ci] = last_coeff; + + int ref_cost_diff = -cost + new_cost; + if (cost_diff != ref_cost_diff) + printf("org_eob %d new_eob %d cost_diff %d ref_cost_diff %d\n", org_eob, + new_eob, cost_diff, ref_cost_diff); + } + } +} +#endif + +static INLINE int64_t get_coeff_dist(tran_low_t tcoeff, tran_low_t dqcoeff, + int shift) { + const int64_t diff = (tcoeff - dqcoeff) * (1 << shift); + const int64_t error = diff * diff; + return error; +} + +typedef struct LevelDownStats { + int update; + tran_low_t low_qc; + tran_low_t low_dqc; + int64_t rd_diff; + int cost_diff; + int64_t dist_diff; + int new_eob; +} LevelDownStats; + +void try_level_down_facade(LevelDownStats *stats, int scan_idx, + const TxbCache *txb_cache, const TxbProbs *txb_probs, + TxbInfo *txb_info) { + const int16_t *scan = txb_info->scan_order->scan; + const int coeff_idx = scan[scan_idx]; + const tran_low_t qc = txb_info->qcoeff[coeff_idx]; + stats->new_eob = -1; + stats->update = 0; + if (qc == 0) { + return; + } + + const tran_low_t tqc = txb_info->tcoeff[coeff_idx]; + const int dqv = txb_info->dequant[coeff_idx != 0]; + + const tran_low_t dqc = qcoeff_to_dqcoeff(qc, dqv, txb_info->shift); + const 
int64_t dqc_dist = get_coeff_dist(tqc, dqc, txb_info->shift); + + stats->low_qc = get_lower_coeff(qc); + stats->low_dqc = qcoeff_to_dqcoeff(stats->low_qc, dqv, txb_info->shift); + const int64_t low_dqc_dist = + get_coeff_dist(tqc, stats->low_dqc, txb_info->shift); + + stats->dist_diff = -dqc_dist + low_dqc_dist; + stats->cost_diff = 0; + stats->new_eob = txb_info->eob; + if (scan_idx == txb_info->eob - 1 && abs(qc) == 1) { + stats->cost_diff = try_change_eob(&stats->new_eob, coeff_idx, txb_cache, + txb_probs, txb_info); + } else { + stats->cost_diff = + try_level_down(coeff_idx, txb_cache, txb_probs, txb_info, NULL); +#if TEST_OPTIMIZE_TXB + test_level_down(coeff_idx, txb_cache, txb_probs, txb_info); +#endif + } + stats->rd_diff = RDCOST(txb_info->rdmult, txb_info->rddiv, stats->cost_diff, + stats->dist_diff); + if (stats->rd_diff < 0) stats->update = 1; + return; +} + +static int optimize_txb(TxbInfo *txb_info, const TxbProbs *txb_probs, + TxbCache *txb_cache, int dry_run) { + int update = 0; + if (txb_info->eob == 0) return update; + int cost_diff = 0; + int64_t dist_diff = 0; + int64_t rd_diff = 0; + const int max_eob = tx_size_2d[txb_info->tx_size]; + +#if TEST_OPTIMIZE_TXB + int64_t sse; + int64_t org_dist = + av1_block_error_c(txb_info->tcoeff, txb_info->dqcoeff, max_eob, &sse) * + (1 << (2 * txb_info->shift)); + int org_cost = get_txb_cost(txb_info, txb_probs); +#endif + + tran_low_t *org_qcoeff = txb_info->qcoeff; + tran_low_t *org_dqcoeff = txb_info->dqcoeff; + + tran_low_t tmp_qcoeff[MAX_TX_SQUARE]; + tran_low_t tmp_dqcoeff[MAX_TX_SQUARE]; + const int org_eob = txb_info->eob; + if (dry_run) { + memcpy(tmp_qcoeff, org_qcoeff, sizeof(org_qcoeff[0]) * max_eob); + memcpy(tmp_dqcoeff, org_dqcoeff, sizeof(org_dqcoeff[0]) * max_eob); + txb_info->qcoeff = tmp_qcoeff; + txb_info->dqcoeff = tmp_dqcoeff; + } + + const int16_t *scan = txb_info->scan_order->scan; + + // forward optimize the nz_map + const int cur_eob = txb_info->eob; + for (int si = 0; si < cur_eob; 
++si) { + const int coeff_idx = scan[si]; + tran_low_t qc = txb_info->qcoeff[coeff_idx]; + if (abs(qc) == 1) { + LevelDownStats stats; + try_level_down_facade(&stats, si, txb_cache, txb_probs, txb_info); + if (stats.update) { + update = 1; + cost_diff += stats.cost_diff; + dist_diff += stats.dist_diff; + rd_diff += stats.rd_diff; + update_level_down(coeff_idx, txb_cache, txb_info); + set_eob(txb_info, stats.new_eob); + } + } + } + + // backward optimize the level-k map + for (int si = txb_info->eob - 1; si >= 0; --si) { + LevelDownStats stats; + try_level_down_facade(&stats, si, txb_cache, txb_probs, txb_info); + const int coeff_idx = scan[si]; + if (stats.update) { +#if TEST_OPTIMIZE_TXB +// printf("si %d low_qc %d cost_diff %d dist_diff %ld rd_diff %ld eob %d new_eob +// %d\n", si, stats.low_qc, stats.cost_diff, stats.dist_diff, stats.rd_diff, +// txb_info->eob, stats.new_eob); +#endif + update = 1; + cost_diff += stats.cost_diff; + dist_diff += stats.dist_diff; + rd_diff += stats.rd_diff; + update_level_down(coeff_idx, txb_cache, txb_info); + set_eob(txb_info, stats.new_eob); + } + if (si > txb_info->eob) si = txb_info->eob; + } +#if TEST_OPTIMIZE_TXB + int64_t new_dist = + av1_block_error_c(txb_info->tcoeff, txb_info->dqcoeff, max_eob, &sse) * + (1 << (2 * txb_info->shift)); + int new_cost = get_txb_cost(txb_info, txb_probs); + int64_t ref_dist_diff = new_dist - org_dist; + int ref_cost_diff = new_cost - org_cost; + if (cost_diff != ref_cost_diff || dist_diff != ref_dist_diff) + printf( + "overall rd_diff %ld\ncost_diff %d ref_cost_diff%d\ndist_diff %ld " + "ref_dist_diff %ld\neob %d new_eob %d\n\n", + rd_diff, cost_diff, ref_cost_diff, dist_diff, ref_dist_diff, org_eob, + txb_info->eob); +#endif + if (dry_run) { + txb_info->qcoeff = org_qcoeff; + txb_info->dqcoeff = org_dqcoeff; + set_eob(txb_info, org_eob); + } + return update; +} + +// These numbers are empirically obtained. 
+/* Scale factors applied to x->rdmult by av1_optimize_txb(), indexed as [is_inter][plane_type]. Each entry is used as (rdmult * entry + 2) >> 2, i.e. a multiplier in quarter units with rounding. */ static const int plane_rd_mult[REF_TYPES][PLANE_TYPES] = { +#if CONFIG_EC_ADAPT + { 17, 13 }, { 16, 10 }, +#else + { 20, 12 }, { 16, 12 }, +#endif +}; + +/* Runs optimize_txb() on one transform block. Gathers the block's coefficient buffers, scan order, probability tables and RD multipliers from x / xd->fc into TxbInfo, TxbProbs and TxbCache, then optimizes with dry_run == 0. When any coefficient was changed, p->eobs[block] is overwritten with the resulting eob. Returns txb_info.eob after optimization. */ int av1_optimize_txb(const AV1_COMMON *cm, MACROBLOCK *x, int plane, int block, + TX_SIZE tx_size, TXB_CTX *txb_ctx) { + MACROBLOCKD *const xd = &x->e_mbd; + const PLANE_TYPE plane_type = get_plane_type(plane); + const TX_TYPE tx_type = get_tx_type(plane_type, xd, block, tx_size); + const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi; + const struct macroblock_plane *p = &x->plane[plane]; + struct macroblockd_plane *pd = &xd->plane[plane]; + const int eob = p->eobs[block]; + tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block); + tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block); + const tran_low_t *tcoeff = BLOCK_OFFSET(p->coeff, block); + const int16_t *dequant = pd->dequant; + const int seg_eob = AOMMIN(eob, tx_size_2d[tx_size] - 1); + const aom_prob *nz_map = xd->fc->nz_map[tx_size][plane_type]; + + /* stride is the coefficient row length, 1 << bwl. */ const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2; + const int stride = 1 << bwl; + aom_prob(*coeff_base)[COEFF_BASE_CONTEXTS] = + xd->fc->coeff_base[tx_size][plane_type]; + + const aom_prob *coeff_lps = xd->fc->coeff_lps[tx_size][plane_type]; + + const int is_inter = is_inter_block(mbmi); + const SCAN_ORDER *const scan_order = + get_scan(cm, tx_size, tx_type, is_inter_block(mbmi)); + + /* Positional initializer: the order must match the TxbProbs declaration (not visible here). */ const TxbProbs txb_probs = { xd->fc->dc_sign[plane_type], + nz_map, + coeff_base, + coeff_lps, + xd->fc->eob_flag[tx_size][plane_type], + xd->fc->txb_skip[tx_size] }; + + const int shift = av1_get_tx_scale(tx_size); + /* Scale the block rdmult by the per-plane factor (quarter units, rounded). */ const int64_t rdmult = + (x->rdmult * plane_rd_mult[is_inter][plane_type] + 2) >> 2; + const int64_t rddiv = x->rddiv; + + /* Positional initializer: the order must match the TxbInfo declaration (not visible here). */ TxbInfo txb_info = { qcoeff, dqcoeff, tcoeff, dequant, shift, + tx_size, bwl, stride, eob, seg_eob, + scan_order, txb_ctx, rdmult, rddiv }; + TxbCache txb_cache; + gen_txb_cache(&txb_cache, &txb_info); + + /* dry_run is 0, so optimize_txb() works on the block's own qcoeff/dqcoeff buffers rather than on temporary copies. */ const int update = optimize_txb(&txb_info, &txb_probs, &txb_cache, 0); + /* Store the eob back only when something was changed. */ if 
(update) p->eobs[block] = txb_info.eob; + return txb_info.eob; +} int av1_get_txb_entropy_context(const tran_low_t *qcoeff, const SCAN_ORDER *scan_order, int eob) { const int16_t *scan = scan_order->scan; @@ -394,10 +1495,10 @@ int av1_get_txb_entropy_context(const tran_low_t *qcoeff, return cul_level; } -static void update_txb_context(int plane, int block, int blk_row, int blk_col, - BLOCK_SIZE plane_bsize, TX_SIZE tx_size, - void *arg) { - TxbParams *const args = arg; +void av1_update_txb_context_b(int plane, int block, int blk_row, int blk_col, + BLOCK_SIZE plane_bsize, TX_SIZE tx_size, + void *arg) { + struct tokenize_b_args *const args = arg; const AV1_COMP *cpi = args->cpi; const AV1_COMMON *cm = &cpi->common; ThreadData *const td = args->td; @@ -418,10 +1519,10 @@ static void update_txb_context(int plane, int block, int blk_row, int blk_col, av1_set_contexts(xd, pd, plane, tx_size, cul_level, blk_col, blk_row); } -static void update_and_record_txb_context(int plane, int block, int blk_row, - int blk_col, BLOCK_SIZE plane_bsize, - TX_SIZE tx_size, void *arg) { - TxbParams *const args = arg; +void av1_update_and_record_txb_context(int plane, int block, int blk_row, + int blk_col, BLOCK_SIZE plane_bsize, + TX_SIZE tx_size, void *arg) { + struct tokenize_b_args *const args = arg; const AV1_COMP *cpi = args->cpi; const AV1_COMMON *cm = &cpi->common; ThreadData *const td = args->td; @@ -529,7 +1630,7 @@ static void update_and_record_txb_context(int plane, int block, int blk_row, } // level is above 1. 
- ctx = get_level_ctx(tcoeff, scan[c], bwl); + ctx = get_br_ctx(tcoeff, scan[c], bwl); for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) { if (level == (idx + 1 + NUM_BASE_LEVELS)) { ++td->counts->coeff_lps[tx_size][plane_type][ctx][1]; @@ -568,23 +1669,23 @@ void av1_update_txb_context(const AV1_COMP *cpi, ThreadData *td, const int ctx = av1_get_skip_context(xd); const int skip_inc = !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP); - struct TxbParams arg = { cpi, td, 0 }; + struct tokenize_b_args arg = { cpi, td, NULL, 0 }; (void)rate; (void)mi_row; (void)mi_col; if (mbmi->skip) { if (!dry_run) td->counts->skip[ctx][1] += skip_inc; - reset_skip_context(xd, bsize); + av1_reset_skip_context(xd, mi_row, mi_col, bsize); return; } if (!dry_run) { td->counts->skip[ctx][0] += skip_inc; av1_foreach_transformed_block(xd, bsize, mi_row, mi_col, - update_and_record_txb_context, &arg); + av1_update_and_record_txb_context, &arg); } else if (dry_run == DRY_RUN_NORMAL) { - av1_foreach_transformed_block(xd, bsize, mi_row, mi_col, update_txb_context, - &arg); + av1_foreach_transformed_block(xd, bsize, mi_row, mi_col, + av1_update_txb_context_b, &arg); } else { printf("DRY_RUN_COSTCOEFFS is not supported yet\n"); assert(0); @@ -749,8 +1850,7 @@ int64_t av1_search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, av1_invalid_rd_stats(&this_rd_stats); av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP); - if (x->plane[plane].eobs[block] && !xd->lossless[mbmi->segment_id]) - av1_optimize_b(cm, x, plane, block, tx_size, coeff_ctx); + av1_optimize_b(cm, x, plane, block, plane_bsize, tx_size, a, l); av1_dist_block(cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size, &this_rd_stats.dist, &this_rd_stats.sse, OUTPUT_HAS_PREDICTED_PIXELS); @@ -771,8 +1871,7 @@ int64_t av1_search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane, // copy the best result in the above tx_type search for loop av1_xform_quant(cm, 
x, plane, block, blk_row, blk_col, plane_bsize, tx_size, coeff_ctx, AV1_XFORM_QUANT_FP); - if (x->plane[plane].eobs[block] && !xd->lossless[mbmi->segment_id]) - av1_optimize_b(cm, x, plane, block, tx_size, coeff_ctx); + av1_optimize_b(cm, x, plane, block, plane_bsize, tx_size, a, l); if (!is_inter_block(mbmi)) { // intra mode needs decoded result such that the next transform block // can use it for prediction. |