Edinburgh Speech Tools  2.4-release
 All Classes Functions Variables Typedefs Enumerations Enumerator Friends Pages
EST_WFST.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : November 1997 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* A class for representing Weighted Finite State Transducers */
38 /* */
39 /* This is based on various papers by Mehryar Mohri but not forgetting */
40 /* the Kay and Kaplan stuff as well as my own Koskenniemi implementation */
41 /* and finite state machine manipulations in my earlier lives */
42 /* */
43 /*=======================================================================*/
44 
45 #include <iostream>
46 #include "EST_Pathname.h"
47 #include "EST_cutils.h"
48 #include "EST_Token.h"
49 #include "EST_FileType.h"
50 #include "EST_WFST.h"
51 
52 #include "EST_TVector.h"
53 
54 Declare_TList_T(EST_WFST_Transition *, EST_WFST_TransitionP)
55 Declare_TVector_Base_T(EST_WFST_State *, NULL, NULL, EST_WFST_StateP)
56 
57 
58 #if defined(INSTANTIATE_TEMPLATES)
59 #include "../base_class/EST_TList.cc"
60 
61 Instantiate_TList_T(EST_WFST_Transition *, EST_WFST_TransitionP)
62 
63 #include "../base_class/EST_TVector.cc"
64 
65 Instantiate_TVector_T(EST_WFST_State *, EST_WFST_StateP)
66 
67 #endif
68 
69 // Used for marking states during traversal functions
70 int EST_WFST::traverse_tag = 0;
71 
72 EST_WFST_State::EST_WFST_State(int name)
73 {
74  p_name = name;
75  p_type = wfst_error;
76  p_tag = 0;
77 }
78 
79 EST_WFST_State::EST_WFST_State(const EST_WFST_State &state)
80 {
81  EST_Litem *p;
82 
83  p_name = state.p_name;
84  p_type = state.p_type;
85  p_tag = state.p_tag;
86  for (p=state.transitions.head(); p != 0; p=p->next())
87  transitions.append(new EST_WFST_Transition(*state.transitions(p)));
88 }
89 
90 EST_WFST_State::~EST_WFST_State()
91 {
92  EST_Litem *p;
93 
94  for (p=transitions.head(); p != 0; p=p->next())
95  delete transitions(p);
96 
97 }
98 
99 EST_WFST_Transition *EST_WFST_State::add_transition(float w,
100  int end,
101  int in,
102  int out)
103 {
104  // Add new transition
105  EST_WFST_Transition *s = new EST_WFST_Transition(w,end,in,out);
106  transitions.append(s);
107  return s;
108 }
109 
110 EST_WFST::~EST_WFST()
111 {
112  clear();
113 }
114 
// NOTE(review): the signature line for this body is missing from this
// listing (the numbering jumps from 114 to 116).  The body is consistent
// with the clear() that the destructor, copy() and init() call -- confirm
// against the original file.
//
// Deletes every state counted by p_num_states, then resets the state count
// and the cumulate flag to zero.
116 {
117 
118  // delete up to p_num_states, rather than p_states.length() as
119  // only up to there is necessarily filled
120  for (int i=0; i < p_num_states; ++i)
121  delete p_states[i];
122  p_num_states = 0;
123  p_cumulate = 0;
124 }
125 
// NOTE(review): the signature line for this body is missing from this
// listing (the numbering jumps from 125 to 127); presumably this is the
// default constructor -- confirm against the original file.
127 {
// init() begins by calling clear(); presumably the count is zeroed first so
// that cleanup never sees an uninitialised state count -- confirm.
128  p_num_states = 0;
129  init(0);
130 }
131 
132 void EST_WFST::copy(const EST_WFST &wfst)
133 {
134  clear();
135  p_in_symbols = wfst.p_in_symbols;
136  p_out_symbols = wfst.p_out_symbols;
137  p_start_state = wfst.p_start_state;
138  current_tag = wfst.current_tag;
139  p_num_states = wfst.p_num_states;
140  p_states.resize(p_num_states);
141  for (int i=0; i < p_num_states; ++i)
142  p_states[i] = new EST_WFST_State(*wfst.state(i));
143 }
144 
145 void EST_WFST::init(int init_num_states)
146 {
147  int i;
148 
149  clear();
150 
151  p_states.resize(init_num_states);
152  for (i=0; i < p_states.length(); i++)
153  p_states[i] = 0;
154  p_num_states = init_num_states;
155 
156 }
157 
158 void EST_WFST::init(LISP in_alphabet,LISP out_alphabet)
159 {
160  EST_StrList in,out;
161  LISP iin,oout;
162 
163  in.append("__epsilon__");
164  in.append("=");
165  for (iin=in_alphabet; iin != NIL; iin=cdr(iin))
166  if ((!streq(get_c_string(car(iin)),"__epsilon__")) &&
167  (!streq(get_c_string(car(iin)),"=")))
168  in.append(get_c_string(car(iin)));
169 
170  out.append("__epsilon__");
171  out.append("=");
172  for (oout=out_alphabet; oout != NIL; oout=cdr(oout))
173  if ((!streq(get_c_string(car(oout)),"__epsilon__")) &&
174  (!streq(get_c_string(car(oout)),"=")))
175  out.append(get_c_string(car(oout)));
176 
177  p_in_symbols.init(in);
178  p_out_symbols.init(out);
179 
180 }
181 
182 int EST_WFST::transduce(int state,const EST_String &in,EST_String &out) const
183 {
184  int nstate;
185  int in_i = p_in_symbols.name(in);
186  int out_i=0;
187 
188  if (in_i == -1)
189  {
190  cerr << "WFST transduce: \"" << in << "\" not in alphabet" << endl;
191  return WFST_ERROR_STATE;
192  }
193 
194  nstate = transduce(state,in_i,out_i);
195 
196  out = p_out_symbols.name(out_i);
197 
198  return nstate;
199 }
200 
201 void EST_WFST::transduce(int state,int in,wfst_translist &out) const
202 {
203  EST_WFST_State *s = p_states(state);
204  EST_Litem *i;
205 
206  for (i=s->transitions.head(); i != 0; i=i->next())
207  {
208  if (in == s->transitions(i)->in_symbol())
209  {
210  if (cumulate())
211  s->transitions(i)->set_weight(1+s->transitions(i)->weight());
212  out.append(s->transitions(i));
213  }
214  }
215 
216  // could return if any transitions were found
217 }
218 
219 int EST_WFST::transduce(int state,int in,int &out) const
220 {
221  EST_WFST_State *s = p_states(state);
222  EST_Litem *i;
223 
224  for (i=s->transitions.head(); i != 0; i=i->next())
225  {
226  if (in == s->transitions(i)->in_symbol())
227  {
228  out = s->transitions(i)->out_symbol();
229  return s->transitions(i)->state();
230  }
231  }
232 
233  return WFST_ERROR_STATE; // no match
234 }
235 
236 int EST_WFST::transition(int state,const EST_String &inout) const
237 {
238  if (inout.contains("/"))
239  return transition(state,inout.before("/"),inout.after("/"));
240  else
241  return transition(state,inout,inout);
242 }
243 
244 int EST_WFST::transition(int state,const EST_String &in,
245  const EST_String &out) const
246 {
247  int in_i = p_in_symbols.name(in);
248  int out_i = p_out_symbols.name(out);
249 
250  if ((in_i == -1) || (out_i == -1))
251  {
252  cerr << "WFST: one of " << in << "/" << out << " not in alphabet"
253  << endl;
254  return WFST_ERROR_STATE;
255  }
256 
257  return transition(state,in_i,out_i);
258 }
259 
260 int EST_WFST::transition(int state,int in, int out) const
261 {
262  // Finds first transition (hopefully deterministic)
263  float prob;
264  return transition(state,in,out,prob);
265 }
266 
267 EST_WFST_Transition *EST_WFST::find_transition(int state,int in, int out) const
268 {
269  // Finds first transition (hopefully deterministic)
270  EST_WFST_State *s = p_states(state);
271  EST_Litem *i;
272 
273  for (i=s->transitions.head(); i != 0; i=i->next())
274  {
275  if ((in == s->transitions(i)->in_symbol()) &&
276  (out == s->transitions(i)->out_symbol()))
277  {
278  if (cumulate())
279  s->transitions(i)->set_weight(1+s->transitions(i)->weight());
280  return s->transitions(i);
281  }
282  }
283 
284  return 0; // no match
285 }
286 
287 int EST_WFST::transition(int state,int in, int out,float &prob) const
288 {
289  // Finds first transition (hopefully deterministic)
290  EST_WFST_Transition *trans = find_transition(state,in,out);
291 
292  if (trans == 0)
293  {
294  prob = 0;
295  return WFST_ERROR_STATE;
296  }
297  else
298  {
299  prob = trans->weight();
300  return trans->state();
301  }
302 }
303 
304 EST_write_status EST_WFST::save_binary(FILE *fd)
305 {
306  int i;
307  EST_Litem *j;
308  int num_transitions, type, in, out, next_state;
309  float weight;
310 
311  for (i=0; i<p_num_states; i++)
312  {
313  num_transitions = p_states[i]->num_transitions();
314  fwrite(&num_transitions,4,1,fd);
315  if (p_states[i]->type() == wfst_final)
316  type = WFST_FINAL;
317  else if (p_states[i]->type() == wfst_nonfinal)
318  type = WFST_NONFINAL;
319  else if (p_states[i]->type() == wfst_licence)
320  type = WFST_LICENCE;
321  else
322  type = WFST_ERROR;
323  fwrite(&type,4,1,fd);
324  for (j=p_states[i]->transitions.head(); j != 0; j=j->next())
325  {
326  in = p_states[i]->transitions(j)->in_symbol();
327  out = p_states[i]->transitions(j)->out_symbol();
328  next_state = p_states[i]->transitions(j)->state();
329  weight = p_states[i]->transitions(j)->weight();
330 
331  if (in == out)
332  {
333  in *= -1;
334  fwrite(&in,4,1,fd);
335  }
336  else
337  {
338  fwrite(&in,4,1,fd);
339  fwrite(&out,4,1,fd);
340  }
341  fwrite(&next_state,4,1,fd);
342  fwrite(&weight,4,1,fd);
343  }
344  }
345 
346  return write_ok;
347 }
348 
349 EST_write_status EST_WFST::save(const EST_String &filename,
350  const EST_String type)
351 {
352  FILE *ofd;
353  int i;
354  static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
355  EST_Litem *j;
356 
357  if (filename == "-")
358  ofd = stdout;
359  else if ((ofd = fopen(filename,"wb")) == NULL)
360  {
361  cerr << "WFST: cannot write to file \"" << filename << "\"" << endl;
362  return misc_write_error;
363  }
364 
365  fprintf(ofd,"EST_File fst\n");
366  fprintf(ofd,"DataType %s\n",(const char *)type);
367  fprintf(ofd,"in %s\n",
368  (const char *)quote_string(EST_String("(")+
369  p_in_symbols.print_to_string(TRUE)+")",
370  "\"","\\",1));
371  fprintf(ofd,"out %s\n",
372  (const char *)quote_string(EST_String("(")+
373  p_out_symbols.print_to_string(TRUE)+")",
374  "\"","\\",1));
375  fprintf(ofd,"NumStates %d\n",p_num_states);
376  fprintf(ofd, "ByteOrder %s\n", ((EST_NATIVE_BO == bo_big) ? "10" : "01"));
377  fprintf(ofd,"EST_Header_End\n");
378 
379  if (type == "binary")
380  save_binary(ofd);
381  else
382  {
383  for (i=0; i < p_num_states; i++)
384  {
385  EST_WFST_State *s=p_states[i];
386  fprintf(ofd,"((%d ",s->name());
387  switch(s->type())
388  {
389  case wfst_final:
390  fprintf(ofd,"final ");
391  break;
392  case wfst_nonfinal:
393  fprintf(ofd,"nonfinal ");
394  break;
395  case wfst_licence:
396  fprintf(ofd,"licence ");
397  break;
398  default:
399  fprintf(ofd,"error ");
400  }
401  fprintf(ofd,"%d)\n",s->num_transitions());
402  for (j=s->transitions.head(); j != 0; j=j->next())
403  {
404  EST_String in = p_in_symbols.name(s->transitions(j)->in_symbol());
405  EST_String out=p_out_symbols.name(s->transitions(j)->out_symbol());
406  if (in.matches(needquotes))
407  fprintf(ofd," (%s ",(const char *)quote_string(in,"\"","\\",1));
408  else
409  fprintf(ofd," (%s ",(const char *)in);
410  if (out.matches(needquotes))
411  fprintf(ofd," %s ",(const char *)quote_string(out,"\"","\\",1));
412  else
413  fprintf(ofd," %s ",(const char *)out);
414  fprintf(ofd,"%d %g)\n",
415  s->transitions(j)->state(),
416  s->transitions(j)->weight());
417  }
418  fprintf(ofd,")\n");
419  }
420  }
421  if (ofd != stdout)
422  fclose(ofd);
423 
424  return write_ok;
425 }
426 
427 static float get_float(FILE *fd,int swap)
428 {
429  float f;
430  fread(&f,4,1,fd);
431  if (swap) swapfloat(&f);
432  return f;
433 }
434 
435 static int get_int(FILE *fd,int swap)
436 {
437  int i;
438  fread(&i,4,1,fd);
439  if (swap)
440  return SWAPINT(i);
441  else
442  return i;
443 }
444 
445 EST_read_status EST_WFST::load_binary(FILE *fd,
446  EST_Option &hinfo,
447  int num_states,
448  int swap)
449 {
450  EST_read_status r;
451  int i,j, s;
452  int num_trans, state_type;
453  int in_sym, out_sym, next_state;
454  float trans_cost;
455 
456  r = format_ok;
457 
458  for (i=0; i < num_states; i++)
459  {
460  num_trans = get_int(fd,swap);
461  state_type = get_int(fd,swap);
462 
463  if (state_type == WFST_FINAL)
464  s = add_state(wfst_final);
465  else if (state_type == WFST_NONFINAL)
466  s = add_state(wfst_nonfinal);
467  else if (state_type == WFST_LICENCE)
468  s = add_state(wfst_licence);
469  else if (state_type == WFST_ERROR)
470  s = add_state(wfst_error);
471  else
472  {
473  cerr << "WFST load: unknown state type \"" <<
474  state_type << "\"" << endl;
475  r = read_format_error;
476  break;
477  }
478 
479  if (s != i)
480  {
481  cerr << "WFST load: internal error: unexpected state misalignment"
482  << endl;
483  r = read_format_error;
484  break;
485  }
486 
487  for (j=0; j < num_trans; j++)
488  {
489  in_sym = get_int(fd,swap);
490  if (in_sym < 0)
491  {
492  in_sym *= -1;
493  out_sym = in_sym;
494  }
495  else
496  out_sym = get_int(fd,swap);
497  next_state = get_int(fd,swap);
498  trans_cost = get_float(fd,swap);
499 
500  p_states[i]->add_transition(trans_cost,next_state,in_sym,out_sym);
501  }
502  }
503 
504  return r;
505 }
506 
507 
508 EST_read_status EST_WFST::load(const EST_String &filename)
509 {
510  // Load a WFST from a file
511  FILE *fd;
512  EST_TokenStream ts;
513  EST_Option hinfo;
514  bool ascii;
515  EST_EstFileType t;
516  EST_read_status r;
517  int i,s;
518  int swap;
519 
520  if ((fd=fopen(filename,"r")) == NULL)
521  {
522  cerr << "WFST load: unable to open \"" << filename
523  << "\" for reading" << endl;
524  return read_error;
525  }
526  ts.open(fd,FALSE);
527  ts.set_quotes('"','\\');
528 
529  if (((r = read_est_header(ts, hinfo, ascii, t)) != format_ok) ||
530  (t != est_file_fst))
531  {
532  cerr << "WFST load: not a WFST file \"" << filename << "\"" <<endl;
533  return misc_read_error;
534  }
535 
536  // Value is a quoted quoted s-expression. Two reads is the
537  // safest way to unquote them
538  LISP inalpha =
539  read_from_string(get_c_string(read_from_string(hinfo.val("in"))));
540  LISP outalpha =
541  read_from_string(get_c_string(read_from_string(hinfo.val("out"))));
542  p_start_state = 0;
543 
544  clear();
545  init(inalpha,outalpha);
546 
547  int num_states = hinfo.ival("NumStates");
548  r = format_ok;
549 
550  if (!ascii)
551  {
552  if (!hinfo.present("ByteOrder"))
553  swap = FALSE; // ascii or not there for some reason
554  else if (((hinfo.val("ByteOrder") == "01") ? bo_little : bo_big)
555  != EST_NATIVE_BO)
556  swap = TRUE;
557  else
558  swap = FALSE;
559  r = load_binary(fd,hinfo,num_states,swap);
560  }
561  else
562  {
563  for (i=0; i < num_states; i++)
564  {
565  LISP sd = lreadf(fd);
566  if (i != get_c_int(car(car(sd))))
567  {
568  cerr << "WFST load: expected description of state " << i <<
569  " but found \"" << siod_sprint(sd) << "\"" << endl;
570  r = read_format_error;
571  break;
572  }
573  if (streq("final",get_c_string(car(cdr(car(sd))))))
574  s = add_state(wfst_final);
575  else if (streq("nonfinal",get_c_string(car(cdr(car(sd))))))
576  s = add_state(wfst_nonfinal);
577  else if (streq("licence",get_c_string(car(cdr(car(sd))))))
578  s = add_state(wfst_licence);
579  else
580  {
581  cerr << "WFST load: unknown state type \"" <<
582  siod_sprint(car(cdr(car(sd)))) << "\"" << endl;
583  r = read_format_error;
584  break;
585  }
586 
587  if (s != i)
588  {
589  cerr << "WFST load: internal error: unexpected state misalignment"
590  << endl;
591  r = read_format_error;
592  break;
593  }
594  if (load_transitions_from_lisp(s,cdr(sd)) != format_ok)
595  {
596  r = read_format_error;
597  break;
598  }
599  }
600  }
601 
602  fclose(fd);
603 
604  return r;
605 }
606 
607 EST_read_status EST_WFST::load_transitions_from_lisp(int s, LISP trans)
608 {
609  LISP t;
610 
611  for (t=trans; t != NIL; t=cdr(t))
612  {
613  float w = get_c_float(siod_nth(3,car(t)));
614  int end = get_c_int(siod_nth(2,car(t)));
615  int in = p_in_symbols.name(get_c_string(siod_nth(0,car(t))));
616  int out = p_out_symbols.name(get_c_string(siod_nth(1,car(t))));
617 
618  if ((in == -1) || (out == -1))
619  {
620  cerr << "WFST load: unknown vocabulary in state transition"
621  << endl;
622  cerr << "WFST load: " << siod_sprint(car(t)) << endl;
623  return read_format_error;
624  }
625  p_states[s]->add_transition(w,end,in,out);
626  }
627  return format_ok;
628 }
629 
630 EST_String EST_WFST::summary() const
631 {
632  int i;
633  int tt=0;
634 
635  for (i=0; i < p_num_states; i++)
636  tt += p_states(i)->transitions.length();
637 
638  return EST_String("WFST ")+itoString(p_num_states)+" states "+
639  itoString(tt)+" transitions ";
640 }
641 
642 
643 void EST_WFST::more_states(int new_max)
644 {
645  int i;
646 
647  p_states.resize(new_max);
648  for (i=p_num_states; i < new_max; i++)
649  p_states[i] = 0;
650 }
651 
652 int EST_WFST::add_state(enum wfst_state_type state_type)
653 {
654  // Add new state of given type
655  EST_WFST_State *s = new EST_WFST_State(p_num_states);
656 
657  if (p_num_states >= p_states.length())
658  {
659  // Need more space for states
660  more_states((int)((float)(p_states.length()+1)*1.5));
661  }
662 
663  p_states[p_num_states] = s;
664  s->set_type(state_type);
665  p_num_states++;
666 
667  return s->name();
668 }
669 
// NOTE(review): the signature line for this body is missing from this
// listing (the numbering jumps from 669 to 671); presumably this is the
// member that switches cumulation on -- confirm against the original file.
//
// Sets the cumulate flag and resets the weight of every transition of
// every state to zero.
671 {
672  // cumulate transitions during recognition
673  EST_Litem *j;
674  int i;
675 
676  p_cumulate = 1;
677  for (i=0; i < p_num_states; i++)
678  {
679  EST_WFST_State *s=p_states[i];
680  for (j=s->transitions.head(); j !=0; j=j->next())
681  s->transitions(j)->set_weight(0);
682  }
683 }
684 
686 {
687  EST_Litem *j;
688  int i;
689  float t;
690 
691  p_cumulate = 0;
692  for (i=0; i < p_num_states; i++)
693  {
694  EST_WFST_State *s=p_states[i];
695  for (t=0,j=s->transitions.head(); j !=0; j=j->next())
696  t += s->transitions(j)->weight();
697  if (t > 0)
698  for (j=s->transitions.head(); j !=0; j=j->next())
699  s->transitions(j)->set_weight(s->transitions(j)->weight()/t);
700  }
701 }