00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #include <it/hmmalgo.h>
00026
00027
00028
00029
00030
00031
00032
00033
00034 mat bcjr (mat x, imat next_state, mat next_state_pt, vec alpha_0,
00035 vec beta_end)
00036 {
00037 int m, mm, br;
00038
00039 int nb_states = imat_width (next_state);
00040
00041 int nb_branch = imat_height (next_state);
00042 int K = mat_width (x), t;
00043 double gamma;
00044
00045
00046 mat alpha = mat_new_zeros (K + 1, nb_states);
00047 mat beta = mat_new_zeros (K + 1, nb_states);
00048 mat pbranch = mat_new_zeros (nb_branch, K);
00049 vec pbranch_sum;
00050
00051
00052 vec_copy (alpha[0], alpha_0);
00053 vec_copy (beta[K], beta_end);
00054
00055
00056 for (t = 1; t <= K; t++) {
00057
00058 for (br = 0; br < nb_branch; br++) {
00059 for (mm = 0; mm < nb_states; mm++) {
00060
00061
00062 m = next_state[br][mm];
00063 if (m == DUMM_NODE)
00064 continue;
00065
00066 gamma = x[br][t - 1] * next_state_pt[br][mm];
00067 alpha[t][m] += alpha[t - 1][mm] * gamma;
00068 }
00069 }
00070 vec_normalize (alpha[t], 1);
00071 }
00072
00073
00074 for (t = K - 1; t > 0; t--) {
00075 for (br = 0; br < nb_branch; br++)
00076 for (m = 0; m < nb_states; m++) {
00077 mm = next_state[br][m];
00078 if (mm == DUMM_NODE)
00079 continue;
00080
00081
00082 gamma = x[br][t] * next_state_pt[br][m];
00083 beta[t][m] += beta[t + 1][mm] * gamma;
00084 }
00085 vec_normalize (beta[t], 1);
00086 }
00087
00088
00089 for (t = 1; t <= K; t++) {
00090 for (br = 0; br < nb_branch; br++)
00091 for (mm = 0; mm < nb_states; mm++) {
00092 m = next_state[br][mm];
00093 if (m == DUMM_NODE)
00094 continue;
00095
00096 gamma = x[br][t - 1] * next_state_pt[br][mm];
00097 pbranch[br][t - 1] += alpha[t - 1][mm] * beta[t][m] * gamma;
00098 }
00099 }
00100
00101 pbranch_sum = mat_cols_sum (pbranch);
00102
00103 for (t = 0; t < K; t++)
00104 mat_col_div_by (pbranch, t, pbranch_sum[t]);
00105
00106 mat_delete (alpha);
00107 mat_delete (beta);
00108 vec_delete (pbranch_sum);
00109
00110 return pbranch;
00111 }
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121 ivec viterbi (mat x, imat next_state, mat next_state_pt, vec alpha_0,
00122 vec beta_end)
00123 {
00124 int m, mm, br;
00125
00126 int nb_states = imat_width (next_state);
00127
00128 int nb_branch = imat_height (next_state);
00129 int K = mat_width (x), t, cur_state;
00130 double rt, logalphacur, logp;
00131
00132
00133 vec logalpha_curr = vec_new (nb_states);
00134 vec logalpha_next = vec_new (nb_states);
00135 vec tmp;
00136 double gamma;
00137 imat prev_branch = imat_new_set (DUMM_NODE, K + 1, nb_states);
00138 imat prev_state = imat_new_set (DUMM_NODE, K + 1, nb_states);
00139 ivec seq = ivec_new (K);
00140
00141
00142 vec_copy (logalpha_curr, alpha_0);
00143 vec_apply_function (logalpha_curr, IT_FUNCTION (log), NULL);
00144
00145
00146 for (t = 1; t <= K; t++) {
00147
00148 vec_set (logalpha_next, -HUGE_VAL);
00149
00150
00151 for (mm = 0; mm < nb_states; mm++)
00152 for (br = 0; br < nb_branch; br++) {
00153
00154
00155 m = next_state[br][mm];
00156 if (m == DUMM_NODE)
00157 continue;
00158
00159 rt = x[br][t - 1];
00160
00161 gamma = rt * next_state_pt[br][mm];
00162
00163
00164
00165 logalphacur = logalpha_curr[mm] + log (gamma);
00166
00167 if (logalphacur > logalpha_next[m]) {
00168 logalpha_next[m] = logalphacur;
00169 prev_branch[t][m] = br;
00170 prev_state[t][m] = mm;
00171 }
00172 }
00173
00174 tmp = logalpha_curr;
00175 logalpha_curr = logalpha_next;
00176 logalpha_next = tmp;
00177 }
00178
00179
00180
00181
00182
00183 logp = -HUGE_VAL;
00184 cur_state = -1;
00185
00186 for (m = 0; m < nb_states; m++)
00187 if (logalpha_curr[m] + log (beta_end[m]) > logp) {
00188 logp = logalpha_curr[m] + log (beta_end[m]);
00189 cur_state = m;
00190 }
00191
00192
00193 for (t = K - 1; t >= 0; t--) {
00194 seq[t] = prev_branch[t + 1][cur_state];
00195 cur_state = prev_state[t + 1][cur_state];
00196 }
00197
00198 vec_delete (logalpha_curr);
00199 vec_delete (logalpha_next);
00200 imat_delete (prev_state);
00201 imat_delete (prev_branch);
00202
00203 return seq;
00204 }
00205
00206
00207
00208
00209
00210
00211 ivec logviterbi (mat logx, imat next_state, mat next_state_logpt,
00212 vec logalpha_0, vec logbeta_end)
00213 {
00214 int m, mm, br;
00215
00216 int nb_states = imat_width (next_state);
00217
00218 int nb_branch = imat_height (next_state);
00219 int K = mat_width (logx), t, cur_state;
00220 double rt, logalphacur, logp;
00221
00222
00223 vec logalpha_curr = vec_new (nb_states);
00224 vec logalpha_next = vec_new (nb_states);
00225 vec tmp;
00226 double loggamma;
00227 imat prev_branch = imat_new_set (DUMM_NODE, K + 1, nb_states);
00228 imat prev_state = imat_new_set (DUMM_NODE, K + 1, nb_states);
00229 ivec seq = ivec_new (K);
00230
00231
00232 vec_copy (logalpha_curr, logalpha_0);
00233
00234
00235 for (t = 1; t <= K; t++) {
00236
00237 vec_set (logalpha_next, -HUGE_VAL);
00238
00239
00240 for (mm = 0; mm < nb_states; mm++)
00241 for (br = 0; br < nb_branch; br++) {
00242
00243
00244 m = next_state[br][mm];
00245 if (m == DUMM_NODE)
00246 continue;
00247
00248 rt = logx[br][t - 1];
00249
00250 loggamma = rt + next_state_logpt[br][mm];
00251
00252
00253
00254 logalphacur = logalpha_curr[mm] + loggamma;
00255
00256 if (logalphacur > logalpha_next[m]) {
00257 logalpha_next[m] = logalphacur;
00258 prev_branch[t][m] = br;
00259 prev_state[t][m] = mm;
00260 }
00261 }
00262
00263 tmp = logalpha_curr;
00264 logalpha_curr = logalpha_next;
00265 logalpha_next = tmp;
00266 }
00267
00268
00269
00270
00271
00272 logp = -HUGE_VAL;
00273 cur_state = -1;
00274
00275 for (m = 0; m < nb_states; m++)
00276 if (logalpha_curr[m] + logbeta_end[m] > logp) {
00277 logp = logalpha_curr[m] + logbeta_end[m];
00278 cur_state = m;
00279 }
00280
00281
00282 for (t = K - 1; t >= 0; t--) {
00283 seq[t] = prev_branch[t + 1][cur_state];
00284 cur_state = prev_state[t + 1][cur_state];
00285 }
00286
00287 vec_delete (logalpha_curr);
00288 vec_delete (logalpha_next);
00289 imat_delete (prev_state);
00290 imat_delete (prev_branch);
00291
00292 return seq;
00293 }
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304 ivec viterbi_side (mat x, imat next_state, mat next_state_pt,
00305 vec alpha_0, ivec sideinfo_pos, mat sideinfo)
00306 {
00307 int m, mm, br;
00308
00309 int nb_states = imat_width (next_state);
00310
00311 int nb_branch = imat_height (next_state);
00312 int K = mat_width (x), t, cur_state;
00313 int sideinfo_idx = 0;
00314 double rt, logalphacur, logp;
00315
00316
00317 vec logalpha_curr = vec_new (nb_states);
00318 vec logalpha_next = vec_new (nb_states);
00319 vec tmp;
00320 double gamma;
00321 imat prev_branch = imat_new_set (DUMM_NODE, K + 1, nb_states);
00322 imat prev_state = imat_new_set (DUMM_NODE, K + 1, nb_states);
00323 ivec seq = ivec_new (K);
00324
00325
00326 vec_copy (logalpha_curr, alpha_0);
00327 vec_apply_function (logalpha_curr, IT_FUNCTION (log), NULL);
00328
00329
00330 for (t = 1; t <= K; t++) {
00331 vec_set (logalpha_next, -HUGE_VAL);
00332
00333
00334 for (mm = 0; mm < nb_states; mm++)
00335 for (br = 0; br < nb_branch; br++) {
00336
00337
00338 m = next_state[br][mm];
00339 if (m == DUMM_NODE)
00340 continue;
00341
00342 rt = x[br][t - 1];
00343
00344 gamma = rt * next_state_pt[br][mm];
00345
00346
00347
00348 logalphacur = logalpha_curr[mm] + log (gamma);
00349
00350 if (logalphacur > logalpha_next[m]) {
00351 logalpha_next[m] = logalphacur;
00352 prev_branch[t][m] = br;
00353 prev_state[t][m] = mm;
00354 }
00355 }
00356
00357
00358 if (sideinfo_idx < ivec_length (sideinfo_pos))
00359 while (sideinfo_pos[sideinfo_idx] <= t) {
00360 for (mm = 0; mm < nb_states; mm++)
00361 logalpha_next[mm] += log (sideinfo[sideinfo_idx][mm]);
00362
00363 sideinfo_idx++;
00364 if (sideinfo_idx >= ivec_length (sideinfo_pos))
00365 break;
00366 }
00367
00368
00369 tmp = logalpha_curr;
00370 logalpha_curr = logalpha_next;
00371 logalpha_next = tmp;
00372 }
00373
00374
00375
00376
00377
00378 logp = -HUGE_VAL;
00379 cur_state = -1;
00380
00381
00382 for (m = 0; m < nb_states; m++)
00383 if (logalpha_curr[m] > logp) {
00384 logp = logalpha_curr[m];
00385 cur_state = m;
00386 }
00387
00388
00389 for (t = K - 1; t >= 0; t--) {
00390 seq[t] = prev_branch[t + 1][cur_state];
00391 cur_state = prev_state[t + 1][cur_state];
00392 }
00393
00394 vec_delete (logalpha_curr);
00395 vec_delete (logalpha_next);
00396 imat_delete (prev_state);
00397 imat_delete (prev_branch);
00398
00399 return seq;
00400 }
00401
00402
00403
00404
00405 mat bcjr_side (mat x, imat next_state, mat next_state_pt,
00406 vec alpha_0, ivec sideinfo_pos, mat sideinfo)
00407 {
00408 int m, mm, br;
00409
00410 int nb_states = imat_width (next_state);
00411
00412 int nb_branch = imat_height (next_state);
00413 int K = mat_width (x), t;
00414 int sideinfo_idx;
00415 double gamma;
00416
00417
00418 mat alpha = mat_new_zeros (K + 1, nb_states);
00419 mat beta = mat_new_ones (K + 1, nb_states);
00420 mat pbranch = mat_new_zeros (nb_branch, K);
00421 vec pbranch_sum;
00422
00423
00424 vec_set (beta[K], 1. / nb_states);
00425
00426
00427 vec_copy (alpha[0], alpha_0);
00428
00429
00430 sideinfo_idx = 0;
00431 for (t = 1; t <= K; t++) {
00432 for (br = 0; br < nb_branch; br++) {
00433 for (mm = 0; mm < nb_states; mm++) {
00434
00435 m = next_state[br][mm];
00436 if (m == DUMM_NODE)
00437 continue;
00438
00439 gamma = x[br][t - 1] * next_state_pt[br][mm];
00440 alpha[t][m] += alpha[t - 1][mm] * gamma;
00441 }
00442 }
00443
00444
00445 if (sideinfo_idx < ivec_length (sideinfo_pos))
00446 while (sideinfo_pos[sideinfo_idx] == t) {
00447 for (mm = 0; mm < nb_states; mm++)
00448 alpha[t][mm] *= sideinfo[sideinfo_idx][mm];
00449
00450 sideinfo_idx++;
00451 if (sideinfo_idx >= ivec_length (sideinfo_pos))
00452 break;
00453 }
00454 vec_normalize (alpha[t], 1);
00455 }
00456
00457
00458 for (sideinfo_idx = ivec_length (sideinfo_pos) - 1;
00459 sideinfo_idx >= 0; sideinfo_idx--) {
00460 if (sideinfo_pos[sideinfo_idx] != K)
00461 break;
00462
00463 for (mm = 0; mm < nb_states; mm++)
00464 beta[K][mm] *= sideinfo[sideinfo_idx][mm];
00465 }
00466
00467 for (t = K - 1; t > 0; t--) {
00468
00469 for (br = 0; br < nb_branch; br++)
00470 for (m = 0; m < nb_states; m++) {
00471 mm = next_state[br][m];
00472 if (mm == DUMM_NODE)
00473 continue;
00474
00475
00476 gamma = x[br][t] * next_state_pt[br][m];
00477 beta[t][m] += beta[t + 1][mm] * gamma;
00478 }
00479
00480
00481 if (sideinfo_idx >= 0)
00482 while (sideinfo_pos[sideinfo_idx] == t) {
00483 for (mm = 0; mm < nb_states; mm++)
00484 beta[t][mm] *= sideinfo[sideinfo_idx][mm];
00485
00486 sideinfo_idx--;
00487 if (sideinfo_idx < 0)
00488 break;
00489 }
00490
00491 vec_normalize (beta[t], 1);
00492 }
00493
00494
00495 for (t = 1; t <= K; t++) {
00496 for (br = 0; br < nb_branch; br++)
00497 for (mm = 0; mm < nb_states; mm++) {
00498 m = next_state[br][mm];
00499 if (m == DUMM_NODE)
00500 continue;
00501
00502 gamma = x[br][t - 1] * next_state_pt[br][mm];
00503 pbranch[br][t - 1] += alpha[t - 1][mm] * beta[t][m] * gamma;
00504 }
00505 }
00506
00507 pbranch_sum = mat_cols_sum (pbranch);
00508
00509 for (t = 0; t < K; t++)
00510 mat_col_div_by (pbranch, t, pbranch_sum[t]);
00511
00512 mat_delete (alpha);
00513 mat_delete (beta);
00514 vec_delete (pbranch_sum);
00515
00516 return pbranch;
00517 }