Edinburgh Speech Tools 2.4-release
EST_SCFG_inout.cc
1/*************************************************************************/
2/* */
3/* Centre for Speech Technology Research */
4/* University of Edinburgh, UK */
5/* Copyright (c) 1997 */
6/* All Rights Reserved. */
7/* */
8/* Permission is hereby granted, free of charge, to use and distribute */
9/* this software and its documentation without restriction, including */
10/* without limitation the rights to use, copy, modify, merge, publish, */
11/* distribute, sublicense, and/or sell copies of this work, and to */
12/* permit persons to whom this work is furnished to do so, subject to */
13/* the following conditions: */
14/* 1. The code must retain the above copyright notice, this list of */
15/* conditions and the following disclaimer. */
16/* 2. Any modifications must be clearly marked as such. */
17/* 3. Original authors' names are not deleted. */
18/* 4. The authors' names are not used to endorse or promote products */
19/* derived from this software without specific prior written */
20/* permission. */
21/* */
22/* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23/* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24/* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25/* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26/* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27/* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28/* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29/* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30/* THIS SOFTWARE. */
31/* */
32/*************************************************************************/
33/* Author : Alan W Black */
34/* Date : October 1997 */
35/*-----------------------------------------------------------------------*/
36/* */
37/* Implementation of an inside-outside reestimation procedure for */
38/* building a stochastic CFG seeded with a bracket corpus. */
39/* Based on "Inside-Outside Reestimation from partially bracketed */
40/* corpora", F Pereira and Y. Schabes. pp 128-135, 30th ACL, Newark, */
41/* Delaware 1992. */
42/* */
/* This should really be done in the log domain. Addition in the log     */
/* domain can be done with a formula in Huang, Ariki and Jack:           */
/*    \log(a+b) = \log\left(1 + e^{\log a - \log b}\right) + \log b       */
47/* */
48/*=======================================================================*/
49#include <cstdlib>
50#include "EST_SCFG_Chart.h"
51#include "EST_simplestats.h"
52#include "EST_math.h"
53#include "EST_TVector.h"
54
55static const EST_bracketed_string def_val_s;
56static EST_bracketed_string error_return_s;
59
60
61#if defined(INSTANTIATE_TEMPLATES)
62#include "../base_class/EST_TVector.cc"
63
65#endif
66
67void set_corpus(EST_Bcorpus &b, LISP examples)
68{
69 LISP e;
70 int i;
71
72 b.resize(siod_llength(examples));
73
74 for (i=0,e=examples; e != NIL; e=cdr(e),i++)
75 b.a_no_check(i).set_bracketed_string(car(e));
76}
77
78void EST_bracketed_string::init()
79{
80 bs = NIL;
81 gc_protect(&bs);
82 symbols = 0;
83 valid_spans = 0;
84 p_length = 0;
85}
86
87EST_bracketed_string::EST_bracketed_string()
88{
89 init();
90}
91
92EST_bracketed_string::EST_bracketed_string(LISP string)
93{
94 init();
95
96 set_bracketed_string(string);
97}
98
99EST_bracketed_string::~EST_bracketed_string()
100{
101 int i;
102 bs=NIL;
103 gc_unprotect(&bs);
104 delete [] symbols;
105 for (i=0; i < p_length; i++)
106 delete [] valid_spans[i];
107 delete [] valid_spans;
108}
109
110void EST_bracketed_string::set_bracketed_string(LISP string)
111{
112
113 bs=NIL;
114 delete [] symbols;
115
116 p_length = find_num_nodes(string);
117 symbols = new LISP[p_length];
118
119 set_leaf_indices(string,0,symbols);
120
121 bs = string;
122
123 int i,j;
124 valid_spans = new int*[length()];
125 for (i=0; i < length(); i++)
126 {
127 valid_spans[i] = new int[length()+1];
128 for (j=i+1; j <= length(); j++)
129 valid_spans[i][j] = 0;
130 }
131
132 // fill in valid table
133 if (p_length > 0)
134 find_valid(0,bs);
135
136}
137
138int EST_bracketed_string::find_num_nodes(LISP string)
139{
140 // This wont could nil as an atom
141 if (string == NIL)
142 return 0;
143 else if (CONSP(string))
144 return find_num_nodes(car(string))+
145 find_num_nodes(cdr(string));
146 else
147 return 1;
148}
149
150int EST_bracketed_string::set_leaf_indices(LISP string,int i,LISP *syms)
151{
152 if (string == NIL)
153 return i;
154 else if (!CONSP(car(string)))
155 {
156 syms[i] = string;
157 return set_leaf_indices(cdr(string),i+1,syms);
158 }
159 else // car is a tree
160 {
161 return set_leaf_indices(cdr(string),
162 set_leaf_indices(car(string),i,syms),
163 syms);
164 }
165}
166
167void EST_bracketed_string::find_valid(int s,LISP t) const
168{
169 LISP l;
170 int c;
171
172 if (consp(t))
173 {
174 for (c=s,l=t; l != NIL; l=cdr(l))
175 {
176 c += num_leafs(car(l));
177 valid_spans[s][c] = 1;
178 }
179 find_valid(s,car(t));
180 find_valid(s+num_leafs(car(t)),cdr(t));
181 }
182}
183
184int EST_bracketed_string::num_leafs(LISP t) const
185{
186 if (t == NIL)
187 return 0;
188 else if (!consp(t))
189 return 1;
190 else
191 return num_leafs(car(t)) + num_leafs(cdr(t));
192}
193
194EST_SCFG_traintest::EST_SCFG_traintest(void) : EST_SCFG()
195{
196 inside = 0;
197 outside = 0;
198 n.resize(0);
199 d.resize(0);
200}
201
202EST_SCFG_traintest::~EST_SCFG_traintest(void)
203{
204
205}
206
208{
209 set_corpus(corpus,vload(filename,1));
210}
211
212// From the formula in the paper
213double EST_SCFG_traintest::f_I_cal(int c, int p, int i, int k)
214{
215 // Find Inside probability
216 double res;
217
218 if (i == k-1)
219 {
220 res = prob_U(p,terminal(corpus.a_no_check(c).symbol_at(i)));
221// printf("prob_U p %s (%d) %d m %s (%d) res %g\n",
222// (const char *)nonterminal(p),p,
223// i,
224// (const char *)corpus.a_no_check(c).symbol_at(i),
225// terminal(corpus.a_no_check(c).symbol_at(i)),
226// res);
227 }
228 else if (corpus.a_no_check(c).valid(i,k) == TRUE)
229 {
230 int j;
231 double s=0;
232 int q,r;
233
234 for (q = 0; q < num_nonterminals(); q++)
235 for (r = 0; r < num_nonterminals(); r++)
236 {
237 double pBpqr = prob_B(p,q,r);
238 if (pBpqr > 0)
239 for (j=i+1; j < k; j++)
240 {
241 double in = f_I(c,q,i,j);
242 if (in > 0)
243 s += pBpqr * in * f_I(c,r,j,k);
244 }
245 }
246 res = s;
247 }
248 else
249 res = 0.0;
250
251 inside[p][i][k] = res;
252
253// printf("f_I p %s i %d k %d res %g\n",
254// (const char *)nonterminal(p),i,k,res);
255
256 return res;
257}
258
259double EST_SCFG_traintest::f_O_cal(int c, int p, int i, int k)
260{
261 // Find Outside probability
262 double res;
263
264 if ((i == 0) && (k == corpus.a_no_check(c).length()))
265 {
266 if (p == distinguished_symbol()) // distinguished non-terminal
267 res = 1.0;
268 else
269 res = 0.0;
270 }
271 else if (corpus.a_no_check(c).valid(i,k) == TRUE)
272 {
273 double s1=0.0;
274 double s2,s3;
275 double pBqrp,pBqpr;
276 int j;
277 int q,r;
278
279 for (q = 0; q < num_nonterminals(); q++)
280 for (r = 0; r < num_nonterminals(); r++)
281 {
282 pBqrp = prob_B(q,r,p);
283 s2 = s3 = 0.0;
284 if (pBqrp > 0)
285 {
286 for (j=0;j < i; j++)
287 {
288 double out = f_O(c,q,j,k);
289 if (out > 0)
290 s2 += out * f_I(c,r,j,i);
291 }
292 s2 *= pBqrp;
293 }
294 pBqpr = prob_B(q,p,r);
295 if (pBqpr > 0)
296 {
297 for (j=k+1;j <= corpus.a_no_check(c).length(); j++)
298 {
299 double out = f_O(c,q,i,j);
300 if (out > 0)
301 s3 += out * f_I(c,r,k,j);
302 }
303 s3 *= pBqpr;
304 }
305 s1 += s2 + s3;
306 }
307 res = s1;
308 }
309 else // not a valid bracketing
310 res = 0.0;
311
312 outside[p][i][k] = res;
313
314 return res;
315}
316
317void EST_SCFG_traintest::reestimate_rule_prob_B(int c, int ri, int p, int q, int r)
318{
319 // Re-estimate probability for binary rules
320 int i,j,k;
321 double n2=0;
322
323 double pBpqr = prob_B(p,q,r);
324
325 if (pBpqr > 0)
326 {
327 for (i=0; i <= corpus.a_no_check(c).length()-2; i++)
328 for (j=i+1; j <= corpus.a_no_check(c).length()-1; j++)
329 {
330 double d1 = f_I(c,q,i,j);
331 if (d1 == 0) continue;
332 for (k=j+1; k <= corpus.a_no_check(c).length(); k++)
333 {
334 double d2 = f_I(c,r,j,k);
335 if (d2 == 0) continue;
336 double d3 = f_O(c,p,i,k);
337 if (d3 == 0) continue;
338 n2 += d1 * d2 * d3;
339 }
340 }
341 n2 *= pBpqr;
342 }
343 // f_P(c) is probably redundant
344 double fp = f_P(c);
345 double n1,d1;
346 n1 = n2 / fp;
347 if (fp == 0) n1=0;
348
349 d1 = f_P(c,p) / fp;
350 if (fp == 0) d1=0;
351 // printf("n1 %f d1 %f n2 %f fp %f\n",n1,d1,n2,fp);
352 n[ri] += n1;
353 d[ri] += d1;
354
355}
356
357void EST_SCFG_traintest::reestimate_rule_prob_U(int c,int ri, int p, int m)
358{
359 // Re-estimate probability for unary rules
360 int i;
361
362// printf("reestimate_rule_prob_U: %f p %s m %s\n",
363// prob_U(ip,im),
364// (const char *)p,
365// (const char *)m);
366
367 double n2=0;
368
369 for (i=1; i < corpus.a_no_check(c).length(); i++)
370 if (m == terminal(corpus.a_no_check(c).symbol_at(i-1)))
371 n2 += prob_U(p,m) * f_O(c,p,i-1,i);
372
373 double fP = f_P(c);
374 if (fP != 0)
375 {
376 n[ri] += n2 / fP;
377 d[ri] += f_P(c,p) / fP;
378 }
379}
380
381double EST_SCFG_traintest::f_P(int c)
382{
383 return f_I(c,distinguished_symbol(),0,corpus.a_no_check(c).length());
384}
385
386double EST_SCFG_traintest::f_P(int c,int p)
387{
388 int i,j;
389 double db=0;
390
391 for (i=0; i < corpus.a_no_check(c).length(); i++)
392 for (j=i+1; j <= corpus.a_no_check(c).length(); j++)
393 {
394 double d1 = f_O(c,p,i,j);
395 if (d1 == 0) continue;
396 db += f_I(c,p,i,j)*d1;
397 }
398
399 return db;
400}
401
402void EST_SCFG_traintest::reestimate_grammar_probs(int passes,
403 int startpass,
404 int checkpoint,
405 int spread,
406 const EST_String &outfile)
407{
408 // Iterate over the corpus cummulating factors for each rules
409 // This reduces the space requirements and recalculations of
410 // values for each sentences.
411 // Repeat training passes to number specified
412 int pass = 0;
413 double zero=0;
414 double se;
415 int ri,c;
416
417 n.resize(rules.length());
418 d.resize(rules.length());
419
420 for (pass = startpass; pass < passes; pass++)
421 {
422 EST_Litem *r;
423 double mC, lPc;
424
425 d.fill(zero);
426 n.fill(zero);
428
429 for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
430 {
431 // For skipping some sentences to speed up convergence
432 if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
433 continue;
434 printf(" %d",c); fflush(stdout);
435 if (corpus.a_no_check(c).length() == 0) continue;
436 init_io_cache(c,num_nonterminals());
437 for (ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
438 {
439 if (rules(r).type() == est_scfg_binary_rule)
440 reestimate_rule_prob_B(c,ri,
441 rules(r).mother(),
442 rules(r).daughter1(),
443 rules(r).daughter2());
444 else
445 reestimate_rule_prob_U(c,
446 ri,
447 rules(r).mother(),
448 rules(r).daughter1());
449 }
450 lPc += safe_log(f_P(c));
451 mC += corpus.a_no_check(c).length();
452 clear_io_cache(c);
453 }
454 printf("\n");
455
456 for (se=0.0,ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
457 {
458 double n_prob = n[ri]/d[ri];
459 if (d[ri] == 0)
460 n_prob = 0;
461 se += (n_prob-rules(r).prob())*(n_prob-rules(r).prob());
462 rules(r).set_prob(n_prob);
463 }
464 printf("pass %d cross entropy %g RMSE %f %f %d\n",
465 pass,-(lPc/mC),sqrt(se/rules.length()),
466 se,rules.length());
467
468 if (checkpoint != -1)
469 {
470 if ((pass % checkpoint) == checkpoint-1)
471 {
472 char cp[20];
473 sprintf(cp,".%03d",pass);
474 save(outfile+cp);
475 user_gc(NIL); // just to keep things neat
476 }
477 }
478
479 }
480}
481
483 int startpass,
484 int checkpoint,
485 int spread,
486 const EST_String &outfile)
487{
488 // Train a Stochastic CFG using the inside outside algorithm
489
490 reestimate_grammar_probs(passes, startpass, checkpoint,
491 spread, outfile);
492}
493
494void EST_SCFG_traintest::init_io_cache(int c,int nt)
495{
496 // Build an array to cache the in/out values
497 int i,j,k;
498 int mc = corpus.a_no_check(c).length()+1;
499
500 inside = new double**[nt];
501 outside = new double**[nt];
502 for (i=0; i < nt; i++)
503 {
504 inside[i] = new double*[mc];
505 outside[i] = new double*[mc];
506 for (j=0; j < mc; j++)
507 {
508 inside[i][j] = new double[mc];
509 outside[i][j] = new double[mc];
510 for (k=0; k < mc; k++)
511 {
512 inside[i][j][k] = -1;
513 outside[i][j][k] = -1;
514 }
515 }
516 }
517}
518
519void EST_SCFG_traintest::clear_io_cache(int c)
520{
521 int mc = corpus.a_no_check(c).length()+1;
522 int i,j;
523
524 if (inside == 0)
525 return;
526
527 for (i=0; i < num_nonterminals(); i++)
528 {
529 for (j=0; j < mc; j++)
530 {
531 delete [] inside[i][j];
532 delete [] outside[i][j];
533 }
534 delete [] inside[i];
535 delete [] outside[i];
536 }
537
538 delete [] inside;
539 delete [] outside;
540
541 inside = 0;
542 outside = 0;
543}
544
545double EST_SCFG_traintest::cross_entropy()
546{
547 double lPc=0,mC=0;
548 int c;
549
550 for (c=0; c < corpus.length(); c++)
551 {
552 lPc += log(f_P(c));
553 mC += corpus.a_no_check(c).length();
554 }
555
556 return -(lPc/mC);
557}
558
560{
561 // Test corpus against current grammar.
562 double mC,lPc;
563 int c,i;
564 int failed=0;
565 double fP;
566
567 // Lets try simply finding the cross entropy
568 n.resize(rules.length());
569 d.resize(rules.length());
570 for (i=0; i < rules.length(); i++)
571 d[i] = n[i] = 0.0;
572
573 for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
574 {
575 if (corpus.length() > 50)
576 {
577 printf(" %d",c);
578 fflush(stdout);
579 }
580 init_io_cache(c,num_nonterminals());
581 fP = f_P(c);
582 if (fP == 0)
583 failed++;
584 else
585 {
586 lPc += safe_log(fP);
587 mC += corpus.a_no_check(c).length();
588 }
589 clear_io_cache(c);
590 }
591 if (corpus.length() > 50)
592 printf("\n");
593
594 cout << "cross entropy " << -(lPc/mC) << " (" << failed << " failed out of " <<
595 corpus.length() << " sentences )" << endl;
596
597}
598
void train_inout(int passes, int startpass, int checkpoint, int spread, const EST_String &outfile)
void load_corpus(const EST_String &filename)
int num_nonterminals() const
Number of nonterminals.
Definition: EST_SCFG.h:222
double prob_B(int p, int q, int r) const
The rule probability of given binary rule.
Definition: EST_SCFG.h:226
void set_rule_prob_cache()
(re-)set rule probability caches
Definition: EST_SCFG.cc:256
SCFGRuleList rules
The rules themselves.
Definition: EST_SCFG.h:207
EST_write_status save(const EST_String &filename)
Save current grammar to named file.
Definition: EST_SCFG.cc:204
EST_String terminal(int m) const
Convert terminal index to string form.
Definition: EST_SCFG.h:216
double prob_U(int p, int m) const
The rule probability of given unary rule.
Definition: EST_SCFG.h:228
void resize(int n, int set=1)
resize vector
void resize(int n, int set=1)
Definition: EST_TVector.cc:196
INLINE int length() const
number of items in vector.
Definition: EST_TVector.h:252
void fill(const T &v)
Fill entire array with value v.
Definition: EST_TVector.cc:105
INLINE const T & a_no_check(int n) const
read-only const access operator: without bounds checking
Definition: EST_TVector.h:257
const EST_String symbol_at(int i) const
The nth symbol in the string.
Definition: EST_SCFG.h:82
int valid(int i, int k) const
If a bracketing from i to k is valid in string.
Definition: EST_SCFG.h:85