Diffstat (limited to 'inputmethods/dasher/PPMLanguageModel.cpp') (more/less context) (ignore whitespace changes)
-rw-r--r-- | inputmethods/dasher/PPMLanguageModel.cpp | 309 |
1 files changed, 309 insertions, 0 deletions
diff --git a/inputmethods/dasher/PPMLanguageModel.cpp b/inputmethods/dasher/PPMLanguageModel.cpp new file mode 100644 index 0000000..b725a2b --- a/dev/null +++ b/inputmethods/dasher/PPMLanguageModel.cpp | |||
@@ -0,0 +1,309 @@ | |||
1 | // PPMLanguageModel.h | ||
2 | // | ||
3 | ///////////////////////////////////////////////////////////////////////////// | ||
4 | // | ||
5 | // Copyright (c) 1999-2002 David Ward | ||
6 | // | ||
7 | ///////////////////////////////////////////////////////////////////////////// | ||
8 | |||
9 | #include <math.h> | ||
10 | #include "PPMLanguageModel.h" | ||
11 | |||
12 | using namespace Dasher; | ||
13 | using namespace std; | ||
14 | |||
15 | // static TCHAR debug[256]; | ||
16 | typedef unsigned long ulong; | ||
17 | |||
18 | //////////////////////////////////////////////////////////////////////// | ||
19 | /// PPMnode definitions | ||
20 | //////////////////////////////////////////////////////////////////////// | ||
21 | |||
22 | CPPMLanguageModel::CPPMnode *CPPMLanguageModel::CPPMnode::find_symbol(int sym) | ||
23 | // see if symbol is a child of node | ||
24 | { | ||
25 | // printf("finding symbol %d at node %d\n",sym,node->id); | ||
26 | CPPMnode *found=child; | ||
27 | while (found) { | ||
28 | if (found->symbol==sym) | ||
29 | return found; | ||
30 | found=found->next; | ||
31 | } | ||
32 | return 0; | ||
33 | } | ||
34 | |||
35 | |||
36 | CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::add_symbol_to_node(int sym,int *update) | ||
37 | { | ||
38 | CPPMnode *born,*search; | ||
39 | search=find_symbol(sym); | ||
40 | if (!search) { | ||
41 | born = new CPPMnode(sym); | ||
42 | born->next=child; | ||
43 | child=born; | ||
44 | // node->count=1; | ||
45 | return born; | ||
46 | } else { | ||
47 | if (*update) { // perform update exclusions | ||
48 | search->count++; | ||
49 | *update=0; | ||
50 | } | ||
51 | return search; | ||
52 | } | ||
53 | |||
54 | } | ||
55 | |||
56 | |||
57 | ///////////////////////////////////////////////////////////////////// | ||
58 | // CPPMLanguageModel defs | ||
59 | ///////////////////////////////////////////////////////////////////// | ||
60 | |||
61 | CPPMLanguageModel::CPPMLanguageModel(CAlphabet *_alphabet,int _normalization) | ||
62 | : CLanguageModel(_alphabet,_normalization) | ||
63 | { | ||
64 | root=new CPPMnode(-1); | ||
65 | m_rootcontext=new CPPMContext(root,0); | ||
66 | } | ||
67 | |||
68 | |||
69 | CPPMLanguageModel::~CPPMLanguageModel() | ||
70 | { | ||
71 | delete root; | ||
72 | } | ||
73 | |||
74 | |||
75 | bool CPPMLanguageModel::GetProbs(CContext *context,vector<unsigned int> &probs,double addprob) | ||
76 | // get the probability distribution at the context | ||
77 | { | ||
78 | // seems like we have to have this hack for VC++ | ||
79 | CPPMContext *ppmcontext=static_cast<CPPMContext *> (context); | ||
80 | |||
81 | |||
82 | int modelchars=GetNumberModelChars(); | ||
83 | int norm=CLanguageModel::normalization(); | ||
84 | probs.resize(modelchars); | ||
85 | CPPMnode *temp,*s; | ||
86 | int loop,total; | ||
87 | int sym; | ||
88 | ulong spent=0; | ||
89 | ulong size_of_slice; | ||
90 | bool *exclusions=new bool [modelchars]; | ||
91 | ulong uniform=modelchars; | ||
92 | ulong tospend=norm-uniform; | ||
93 | temp=ppmcontext->head; | ||
94 | for (loop=0; loop <modelchars; loop++) { /* set up the exclusions array */ | ||
95 | probs[loop]=0; | ||
96 | exclusions[loop]=0; | ||
97 | } | ||
98 | while (temp!=0) { | ||
99 | //Usprintf(debug,TEXT("tospend %u\n"),tospend); | ||
100 | //DebugOutput(TEXT("round\n")); | ||
101 | total=0; | ||
102 | s=temp->child; | ||
103 | while (s) { | ||
104 | sym=s->symbol; | ||
105 | if (!exclusions[s->symbol]) | ||
106 | total=total+s->count; | ||
107 | s=s->next; | ||
108 | } | ||
109 | if (total) { | ||
110 | //Usprintf(debug,TEXT"escape %u\n"),tospend* | ||
111 | size_of_slice=tospend; | ||
112 | s=temp->child; | ||
113 | while (s) { | ||
114 | if (!exclusions[s->symbol]) { | ||
115 | exclusions[s->symbol]=1; | ||
116 | ulong p=size_of_slice*(2*s->count-1)/2/ulong(total); | ||
117 | probs[s->symbol]+=p; | ||
118 | tospend-=p; | ||
119 | } | ||
120 | // Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend); | ||
121 | // DebugOutput(debug); | ||
122 | s=s->next; | ||
123 | } | ||
124 | } | ||
125 | temp = temp->vine; | ||
126 | } | ||
127 | //Usprintf(debug,TEXT("Norm %u tospend %u\n"),Norm,tospend); | ||
128 | //DebugOutput(debug); | ||
129 | |||
130 | size_of_slice=tospend; | ||
131 | int symbolsleft=0; | ||
132 | for (sym=1;sym<modelchars;sym++) | ||
133 | if (!probs[sym]) | ||
134 | symbolsleft++; | ||
135 | for (sym=1;sym<modelchars;sym++) | ||
136 | if (!probs[sym]) { | ||
137 | ulong p=size_of_slice/symbolsleft; | ||
138 | probs[sym]+=p; | ||
139 | tospend-=p; | ||
140 | } | ||
141 | |||
142 | // distribute what's left evenly | ||
143 | tospend+=uniform; | ||
144 | for (sym=1;sym<modelchars;sym++) { | ||
145 | ulong p=tospend/(modelchars-sym); | ||
146 | probs[sym]+=p; | ||
147 | tospend-=p; | ||
148 | } | ||
149 | //Usprintf(debug,TEXT("finaltospend %u\n"),tospend); | ||
150 | //DebugOutput(debug); | ||
151 | |||
152 | // free(exclusions); // !!! | ||
153 | // !!! NB by IAM: p577 Stroustrup 3rd Edition: "Allocating an object using new and deleting it using free() is asking for trouble" | ||
154 | delete[] exclusions; | ||
155 | return true; | ||
156 | } | ||
157 | |||
158 | |||
159 | void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext &context,int symbol) | ||
160 | // add symbol to the context | ||
161 | // creates new nodes, updates counts | ||
162 | // and leaves 'context' at the new context | ||
163 | { | ||
164 | // sanity check | ||
165 | if (symbol==0 || symbol>=GetNumberModelChars()) | ||
166 | return; | ||
167 | |||
168 | CPPMnode *vineptr,*temp; | ||
169 | int updatecnt=1; | ||
170 | |||
171 | temp=context.head->vine; | ||
172 | context.head=context.head->add_symbol_to_node(symbol,&updatecnt); | ||
173 | vineptr=context.head; | ||
174 | context.order++; | ||
175 | |||
176 | while (temp!=0) { | ||
177 | vineptr->vine=temp->add_symbol_to_node(symbol,&updatecnt); | ||
178 | vineptr=vineptr->vine; | ||
179 | temp=temp->vine; | ||
180 | } | ||
181 | vineptr->vine=root; | ||
182 | if (context.order>MAX_ORDER){ | ||
183 | context.head=context.head->vine; | ||
184 | context.order--; | ||
185 | } | ||
186 | } | ||
187 | |||
188 | |||
189 | // update context with symbol 'Symbol' | ||
190 | void CPPMLanguageModel::EnterSymbol(CContext* Context, modelchar Symbol) | ||
191 | { | ||
192 | CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context); | ||
193 | |||
194 | CPPMnode *find; | ||
195 | CPPMnode *temp=context.head; | ||
196 | |||
197 | while (context.head) { | ||
198 | find =context.head->find_symbol(Symbol); | ||
199 | if (find) { | ||
200 | context.order++; | ||
201 | context.head=find; | ||
202 | //Usprintf(debug,TEXT("found context %x order %d\n"),head,order); | ||
203 | //DebugOutput(debug); | ||
204 | return; | ||
205 | } | ||
206 | context.order--; | ||
207 | context.head=context.head->vine; | ||
208 | } | ||
209 | |||
210 | if (context.head==0) { | ||
211 | context.head=root; | ||
212 | context.order=0; | ||
213 | } | ||
214 | |||
215 | } | ||
216 | |||
217 | |||
218 | void CPPMLanguageModel::LearnSymbol(CContext* Context, modelchar Symbol) | ||
219 | { | ||
220 | CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context); | ||
221 | AddSymbol(context, Symbol); | ||
222 | } | ||
223 | |||
224 | |||
225 | void CPPMLanguageModel::dumpSymbol(int symbol) | ||
226 | { | ||
227 | if ((symbol <= 32) || (symbol >= 127)) | ||
228 | printf( "<%d>", symbol ); | ||
229 | else | ||
230 | printf( "%c", symbol ); | ||
231 | } | ||
232 | |||
233 | |||
234 | void CPPMLanguageModel::dumpString( char *str, int pos, int len ) | ||
235 | // Dump the string STR starting at position POS | ||
236 | { | ||
237 | char cc; | ||
238 | int p; | ||
239 | for (p = pos; p<pos+len; p++) { | ||
240 | cc = str [p]; | ||
241 | if ((cc <= 31) || (cc >= 127)) | ||
242 | printf( "<%d>", cc ); | ||
243 | else | ||
244 | printf( "%c", cc ); | ||
245 | } | ||
246 | } | ||
247 | |||
248 | |||
249 | void CPPMLanguageModel::dumpTrie( CPPMLanguageModel::CPPMnode *t, int d ) | ||
250 | // diagnostic display of the PPM trie from node t and deeper | ||
251 | { | ||
252 | //TODO | ||
253 | /* | ||
254 | dchar debug[256]; | ||
255 | int sym; | ||
256 | CPPMnode *s; | ||
257 | Usprintf( debug,TEXT("%5d %7x "), d, t ); | ||
258 | //TODO: Uncomment this when headers sort out | ||
259 | //DebugOutput(debug); | ||
260 | if (t < 0) // pointer to input | ||
261 | printf( " <" ); | ||
262 | else { | ||
263 | Usprintf(debug,TEXT( " %3d %5d %7x %7x %7x <"), t->symbol,t->count, t->vine, t->child, t->next ); | ||
264 | //TODO: Uncomment this when headers sort out | ||
265 | //DebugOutput(debug); | ||
266 | } | ||
267 | |||
268 | dumpString( dumpTrieStr, 0, d ); | ||
269 | Usprintf( debug,TEXT(">\n") ); | ||
270 | //TODO: Uncomment this when headers sort out | ||
271 | //DebugOutput(debug); | ||
272 | if (t != 0) { | ||
273 | s = t->child; | ||
274 | while (s != 0) { | ||
275 | sym =s->symbol; | ||
276 | |||
277 | dumpTrieStr [d] = sym; | ||
278 | dumpTrie( s, d+1 ); | ||
279 | s = s->next; | ||
280 | } | ||
281 | } | ||
282 | */ | ||
283 | } | ||
284 | |||
285 | |||
286 | void CPPMLanguageModel::dump() | ||
287 | // diagnostic display of the whole PPM trie | ||
288 | { | ||
289 | // TODO: | ||
290 | /* | ||
291 | dchar debug[256]; | ||
292 | Usprintf(debug,TEXT( "Dump of Trie : \n" )); | ||
293 | //TODO: Uncomment this when headers sort out | ||
294 | //DebugOutput(debug); | ||
295 | Usprintf(debug,TEXT( "---------------\n" )); | ||
296 | //TODO: Uncomment this when headers sort out | ||
297 | //DebugOutput(debug); | ||
298 | Usprintf( debug,TEXT( "depth node symbol count vine child next context\n") ); | ||
299 | //TODO: Uncomment this when headers sort out | ||
300 | //DebugOutput(debug); | ||
301 | dumpTrie( root, 0 ); | ||
302 | Usprintf( debug,TEXT( "---------------\n" )); | ||
303 | //TODO: Uncomment this when headers sort out | ||
304 | //DebugOutput(debug); | ||
305 | Usprintf(debug,TEXT( "\n" )); | ||
306 | //TODO: Uncomment this when headers sort out | ||
307 | //DebugOutput(debug); | ||
308 | */ | ||
309 | } | ||