/* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */ /* ==================================================================== * Copyright (c) 2013 Carnegie Mellon University. All rights * reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in * the documentation and/or other materials provided with the * distribution. * * * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * ==================================================================== * */ /* * kws_search.c -- Search object for key phrase spotting. */ #include #include #include #include #include #include #include #include #include "pocketsphinx_internal.h" #include "kws_search.h" /** Access macros */ #define hmm_is_active(hmm) ((hmm)->frame > 0) #define kws_nth_hmm(keyword,n) (&((keyword)->hmms[n])) static ps_lattice_t * kws_search_lattice(ps_search_t * search) { return NULL; } static int kws_search_prob(ps_search_t * search) { return 0; } static void kws_seg_free(ps_seg_t *seg) { kws_seg_t *itor = (kws_seg_t *)seg; ckd_free(itor); } static void kws_seg_fill(kws_seg_t *itor) { kws_detection_t* detection = (kws_detection_t*)gnode_ptr(itor->detection); itor->base.word = detection->keyphrase; itor->base.sf = detection->sf; itor->base.ef = detection->ef; itor->base.prob = detection->prob; itor->base.ascr = detection->ascr; itor->base.lscr = 0; } static ps_seg_t * kws_seg_next(ps_seg_t *seg) { kws_seg_t *itor = (kws_seg_t *)seg; itor->detection = gnode_next(itor->detection); if (!itor->detection) { kws_seg_free(seg); return NULL; } kws_seg_fill(itor); return seg; } static ps_segfuncs_t kws_segfuncs = { /* seg_next */ kws_seg_next, /* seg_free */ kws_seg_free }; static ps_seg_t * kws_search_seg_iter(ps_search_t * search, int32 * out_score) { kws_search_t *kwss = (kws_search_t *)search; kws_seg_t *itor; if (!kwss->detections->detect_list) return NULL; if (out_score) *out_score = 0; itor = (kws_seg_t *)ckd_calloc(1, sizeof(*itor)); itor->base.vt = &kws_segfuncs; itor->base.search = search; itor->base.lwf = 1.0; itor->detection = kwss->detections->detect_list; kws_seg_fill(itor); return (ps_seg_t *)itor; } static ps_searchfuncs_t kws_funcs = { /* name: */ "kws", /* start: */ kws_search_start, /* step: */ kws_search_step, /* finish: */ kws_search_finish, /* reinit: */ kws_search_reinit, /* free: */ kws_search_free, /* lattice: */ kws_search_lattice, /* hyp: */ kws_search_hyp, /* prob: */ kws_search_prob, /* seg_iter: */ kws_search_seg_iter, }; /* Scans the dictionary and check if all words are present. */ static int kws_search_check_dict(kws_search_t * kwss) { dict_t *dict; char **wrdptr; char *tmp_keyphrase; int32 nwrds, wid; int keyword_iter, i; uint8 success; success = TRUE; dict = ps_search_dict(kwss); for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { tmp_keyphrase = (char *) ckd_salloc(kwss->keyphrases[keyword_iter].word); nwrds = str2words(tmp_keyphrase, NULL, 0); wrdptr = (char **) ckd_calloc(nwrds, sizeof(*wrdptr)); str2words(tmp_keyphrase, wrdptr, nwrds); for (i = 0; i < nwrds; i++) { wid = dict_wordid(dict, wrdptr[i]); if (wid == BAD_S3WID) { E_ERROR("The word '%s' is missing in the dictionary\n", wrdptr[i]); success = FALSE; break; } } ckd_free(wrdptr); ckd_free(tmp_keyphrase); } return success; } /* Activate senones for scoring */ static void kws_search_sen_active(kws_search_t * kwss) { int i, keyword_iter; acmod_clear_active(ps_search_acmod(kwss)); /* active phone loop hmms */ for (i = 0; i < kwss->n_pl; i++) acmod_activate_hmm(ps_search_acmod(kwss), &kwss->pl_hmms[i]); /* activate hmms in active nodes */ for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t *keyword = &kwss->keyphrases[keyword_iter]; for (i = 0; i < keyword->n_hmms; i++) { if (hmm_is_active(kws_nth_hmm(keyword, i))) acmod_activate_hmm(ps_search_acmod(kwss), kws_nth_hmm(keyword, i)); } } } /* * Evaluate all the active HMMs. * (Executed once per frame.) */ static void kws_search_hmm_eval(kws_search_t * kwss, int16 const *senscr) { int32 i, keyword_iter; int32 bestscore = WORST_SCORE; hmm_context_set_senscore(kwss->hmmctx, senscr); /* evaluate hmms from phone loop */ for (i = 0; i < kwss->n_pl; ++i) { hmm_t *hmm = &kwss->pl_hmms[i]; int32 score; score = hmm_vit_eval(hmm); if (score BETTER_THAN bestscore) bestscore = score; } /* evaluate hmms for active nodes */ for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t *keyword = &kwss->keyphrases[keyword_iter]; for (i = 0; i < keyword->n_hmms; i++) { hmm_t *hmm = kws_nth_hmm(keyword, i); if (hmm_is_active(hmm)) { int32 score; score = hmm_vit_eval(hmm); if (score BETTER_THAN bestscore) bestscore = score; } } } kwss->bestscore = bestscore; } /* * (Beam) prune the just evaluated HMMs, determine which ones remain * active. Executed once per frame. */ static void kws_search_hmm_prune(kws_search_t * kwss) { int32 thresh, i, keyword_iter; thresh = kwss->bestscore + kwss->beam; for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t *keyword = &kwss->keyphrases[keyword_iter]; for (i = 0; i < keyword->n_hmms; i++) { hmm_t *hmm = kws_nth_hmm(keyword, i); if (hmm_is_active(hmm) && hmm_bestscore(hmm) < thresh) hmm_clear(hmm); } } } /** * Do phone transitions */ static void kws_search_trans(kws_search_t * kwss) { hmm_t *pl_best_hmm = NULL; int32 best_out_score = WORST_SCORE; int i, keyword_iter; uint8 to_clear; /* select best hmm in phone-loop to be a predecessor */ for (i = 0; i < kwss->n_pl; i++) if (hmm_out_score(&kwss->pl_hmms[i]) BETTER_THAN best_out_score) { best_out_score = hmm_out_score(&kwss->pl_hmms[i]); pl_best_hmm = &kwss->pl_hmms[i]; } /* out probs are not ready yet */ if (!pl_best_hmm) return; /* Check whether keyword wasn't spotted yet */ to_clear = FALSE; for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t *keyword; hmm_t *last_hmm; keyword = &kwss->keyphrases[keyword_iter]; last_hmm = kws_nth_hmm(keyword, keyword->n_hmms - 1); if (hmm_is_active(last_hmm) && hmm_out_score(pl_best_hmm) BETTER_THAN WORST_SCORE) { if (hmm_out_score(last_hmm) - hmm_out_score(pl_best_hmm) >= keyword->threshold) { int32 prob = hmm_out_score(last_hmm) - hmm_out_score(pl_best_hmm); kws_detections_add(kwss->detections, keyword->word, hmm_out_history(last_hmm), kwss->frame, prob, hmm_out_score(last_hmm)); to_clear = TRUE; } /* keyword is spotted */ } /* last hmm of keyword is active */ } /* keywords loop */ if (to_clear) { for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t* keyword = &kwss->keyphrases[keyword_iter]; for (i = 0; i < keyword->n_hmms; i++) { hmm_clear(kws_nth_hmm(keyword, i)); } } } /* clear all keywords because something was spotted */ /* Make transition for all phone loop hmms */ for (i = 0; i < kwss->n_pl; i++) { if (hmm_out_score(pl_best_hmm) + kwss->plp BETTER_THAN hmm_in_score(&kwss->pl_hmms[i])) { hmm_enter(&kwss->pl_hmms[i], hmm_out_score(pl_best_hmm) + kwss->plp, hmm_out_history(pl_best_hmm), kwss->frame + 1); } } /* Activate new keyword nodes, enter their hmms */ for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t *keyword = &kwss->keyphrases[keyword_iter]; for (i = keyword->n_hmms - 1; i > 0; i--) { hmm_t *pred_hmm = kws_nth_hmm(keyword, i - 1); hmm_t *hmm = kws_nth_hmm(keyword, i); if (hmm_is_active(pred_hmm)) { if (!hmm_is_active(hmm) || hmm_out_score(pred_hmm) BETTER_THAN hmm_in_score(hmm)) hmm_enter(hmm, hmm_out_score(pred_hmm), hmm_out_history(pred_hmm), kwss->frame + 1); } } /* Enter keyword start node from phone loop */ if (hmm_out_score(pl_best_hmm) BETTER_THAN hmm_in_score(kws_nth_hmm(keyword, 0))) hmm_enter(kws_nth_hmm(keyword, 0), hmm_out_score(pl_best_hmm), kwss->frame, kwss->frame + 1); } /* keywords loop */ } static int kws_search_read_list(kws_search_t *kwss, const char* keyfile) { FILE *list_file; lineiter_t *li; int i; if ((list_file = fopen(keyfile, "r")) == NULL) { E_ERROR_SYSTEM("Failed to open keyword file '%s'", keyfile); return -1; } /* count keyphrases amount */ kwss->n_keyphrases = 0; for (li = lineiter_start(list_file); li; li = lineiter_next(li)) if (li->len > 0) kwss->n_keyphrases++; kwss->keyphrases = (kws_keyword_t *)ckd_calloc(kwss->n_keyphrases, sizeof(*kwss->keyphrases)); fseek(list_file, 0L, SEEK_SET); /* read keyphrases */ for (li = lineiter_start(list_file), i=0; li; li = lineiter_next(li), i++) { size_t last_ptr = li->len - 1; kwss->keyphrases[i].threshold = kwss->def_threshold; while (li->buf[last_ptr] == '\n') last_ptr--; if (li->buf[last_ptr] == '/') { size_t digit_len, start; char digit[16]; start = last_ptr - 1; while (li->buf[start] != '/' && start > 0) start--; digit_len = last_ptr - start; memcpy(digit, &li->buf[start+1], digit_len); kwss->keyphrases[i].threshold = (int32) logmath_log(kwss->base.acmod->lmath, atof_c(digit)) >> SENSCR_SHIFT; li->buf[start-1] = '\0'; } li->buf[last_ptr + 1] = '\0'; kwss->keyphrases[i].word = ckd_salloc(li->buf); } fclose(list_file); return 0; } ps_search_t * kws_search_init(const char *keyphrase, const char *keyfile, cmd_ln_t * config, acmod_t * acmod, dict_t * dict, dict2pid_t * d2p) { kws_search_t *kwss = (kws_search_t *) ckd_calloc(1, sizeof(*kwss)); ps_search_init(ps_search_base(kwss), &kws_funcs, config, acmod, dict, d2p); kwss->detections = (kws_detections_t *)ckd_calloc(1, sizeof(*kwss->detections)); kwss->beam = (int32) logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-beam")) >> SENSCR_SHIFT; kwss->plp = (int32) logmath_log(acmod->lmath, cmd_ln_float32_r(config, "-kws_plp")) >> SENSCR_SHIFT; kwss->def_threshold = (int32) logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-kws_threshold")) >> SENSCR_SHIFT; E_INFO("KWS(beam: %d, plp: %d, default threshold %d)\n", kwss->beam, kwss->plp, kwss->def_threshold); if (keyfile) { if (kws_search_read_list(kwss, keyfile) < 0) { E_ERROR("Failed to create kws search\n"); kws_search_free(ps_search_base(kwss)); return NULL; } } else { kwss->n_keyphrases = 1; kwss->keyphrases = (kws_keyword_t *)ckd_calloc(kwss->n_keyphrases, sizeof(*kwss->keyphrases)); kwss->keyphrases[0].threshold = kwss->def_threshold; kwss->keyphrases[0].word = ckd_salloc(keyphrase); } /* Check if all words are in dictionary */ if (!kws_search_check_dict(kwss)) { kws_search_free(ps_search_base(kwss)); return NULL; } /* Reinit for provided keyword */ if (kws_search_reinit(ps_search_base(kwss), ps_search_dict(kwss), ps_search_dict2pid(kwss)) < 0) { ps_search_free(ps_search_base(kwss)); return NULL; } return ps_search_base(kwss); } void kws_search_free(ps_search_t * search) { int i; kws_search_t *kwss; kwss = (kws_search_t *) search; ps_search_deinit(search); hmm_context_free(kwss->hmmctx); kws_detections_reset(kwss->detections); ckd_free(kwss->pl_hmms); for (i = 0; i < kwss->n_keyphrases; i++) { ckd_free(kwss->keyphrases[i].hmms); ckd_free(kwss->keyphrases[i].word); } ckd_free(kwss->keyphrases); ckd_free(kwss); } int kws_search_reinit(ps_search_t * search, dict_t * dict, dict2pid_t * d2p) { char **wrdptr; char *tmp_keyphrase; int32 wid, pronlen; int32 n_hmms, n_wrds; int32 ssid, tmatid; int i, j, p, keyword_iter; kws_search_t *kwss = (kws_search_t *) search; bin_mdef_t *mdef = search->acmod->mdef; int32 silcipid = bin_mdef_silphone(mdef); /* Free old dict2pid, dict */ ps_search_base_reinit(search, dict, d2p); /* Initialize HMM context. */ if (kwss->hmmctx) hmm_context_free(kwss->hmmctx); kwss->hmmctx = hmm_context_init(bin_mdef_n_emit_state(search->acmod->mdef), search->acmod->tmat->tp, NULL, search->acmod->mdef->sseq); if (kwss->hmmctx == NULL) return -1; /* Initialize phone loop HMMs. */ if (kwss->pl_hmms) { for (i = 0; i < kwss->n_pl; ++i) hmm_deinit((hmm_t *) & kwss->pl_hmms[i]); ckd_free(kwss->pl_hmms); } kwss->n_pl = bin_mdef_n_ciphone(search->acmod->mdef); kwss->pl_hmms = (hmm_t *) ckd_calloc(kwss->n_pl, sizeof(*kwss->pl_hmms)); for (i = 0; i < kwss->n_pl; ++i) { hmm_init(kwss->hmmctx, (hmm_t *) & kwss->pl_hmms[i], FALSE, bin_mdef_pid2ssid(search->acmod->mdef, i), bin_mdef_pid2tmatid(search->acmod->mdef, i)); } for (keyword_iter = 0; keyword_iter < kwss->n_keyphrases; keyword_iter++) { kws_keyword_t *keyword = &kwss->keyphrases[keyword_iter]; /* Initialize keyphrase HMMs */ tmp_keyphrase = (char *) ckd_salloc(keyword->word); n_wrds = str2words(tmp_keyphrase, NULL, 0); wrdptr = (char **) ckd_calloc(n_wrds, sizeof(*wrdptr)); str2words(tmp_keyphrase, wrdptr, n_wrds); /* count amount of hmms */ n_hmms = 0; for (i = 0; i < n_wrds; i++) { wid = dict_wordid(dict, wrdptr[i]); pronlen = dict_pronlen(dict, wid); n_hmms += pronlen; } /* allocate node array */ if (keyword->hmms) ckd_free(keyword->hmms); keyword->hmms = (hmm_t *) ckd_calloc(n_hmms, sizeof(hmm_t)); keyword->n_hmms = n_hmms; /* fill node array */ j = 0; for (i = 0; i < n_wrds; i++) { wid = dict_wordid(dict, wrdptr[i]); pronlen = dict_pronlen(dict, wid); for (p = 0; p < pronlen; p++) { int32 ci = dict_pron(dict, wid, p); if (p == 0) { /* first phone of word */ int32 rc = pronlen > 1 ? dict_pron(dict, wid, 1) : silcipid; ssid = dict2pid_ldiph_lc(d2p, ci, rc, silcipid); } else if (p == pronlen - 1) { /* last phone of the word */ int32 lc = dict_pron(dict, wid, p - 1); xwdssid_t *rssid = dict2pid_rssid(d2p, ci, lc); int j = rssid->cimap[silcipid]; ssid = rssid->ssid[j]; } else { /* word internal phone */ ssid = dict2pid_internal(d2p, wid, p); } tmatid = bin_mdef_pid2tmatid(mdef, ci); hmm_init(kwss->hmmctx, &keyword->hmms[j], FALSE, ssid, tmatid); j++; } } ckd_free(wrdptr); ckd_free(tmp_keyphrase); } return 0; } int kws_search_start(ps_search_t * search) { int i; kws_search_t *kwss = (kws_search_t *) search; kwss->frame = 0; kwss->bestscore = 0; kws_detections_reset(kwss->detections); /* Reset and enter all phone-loop HMMs. */ for (i = 0; i < kwss->n_pl; ++i) { hmm_t *hmm = (hmm_t *) & kwss->pl_hmms[i]; hmm_clear(hmm); hmm_enter(hmm, 0, -1, 0); } return 0; } int kws_search_step(ps_search_t * search, int frame_idx) { int16 const *senscr; kws_search_t *kwss = (kws_search_t *) search; acmod_t *acmod = search->acmod; /* Activate senones */ if (!acmod->compallsen) kws_search_sen_active(kwss); /* Calculate senone scores for current frame. */ senscr = acmod_score(acmod, &frame_idx); /* Evaluate hmms in phone loop and in active keyword nodes */ kws_search_hmm_eval(kwss, senscr); /* Prune hmms with low prob */ kws_search_hmm_prune(kwss); /* Do hmms transitions */ kws_search_trans(kwss); ++kwss->frame; return 0; } int kws_search_finish(ps_search_t * search) { /* Nothing here */ return 0; } char const * kws_search_hyp(ps_search_t * search, int32 * out_score, int32 * out_is_final) { kws_search_t *kwss = (kws_search_t *) search; if (out_score) *out_score = 0; if (search->hyp_str) ckd_free(search->hyp_str); kws_detections_hyp_str(kwss->detections, &search->hyp_str); return search->hyp_str; } char * kws_search_get_keywords(ps_search_t * search) { int i, c, len; kws_search_t *kwss; char* line; kwss = (kws_search_t *) search; len = 0; for (i = 0; i < kwss->n_keyphrases; i++) len += strlen(kwss->keyphrases[i].word); len += kwss->n_keyphrases; c = 0; line = (char *)ckd_calloc(len, sizeof(*line)); for (i = 0; i < kwss->n_keyphrases; i++) { char *keyword_str = kwss->keyphrases[i].word; memcpy(&line[c], keyword_str, strlen(keyword_str)); c += strlen(keyword_str); line[c++] = '\n'; } line[--c] = '\0'; return line; }