Next version of JABA
[jabaws.git] / binaries / src / probcons / ProbabilisticModel.h
1 /////////////////////////////////////////////////////////////////
2 // ProbabilisticModel.h
3 //
4 // Routines for (1) posterior probability computations
5 //              (2) chained anchoring
6 //              (3) maximum weight trace alignment
7 /////////////////////////////////////////////////////////////////
8
9 #ifndef PROBABILISTICMODEL_H
10 #define PROBABILISTICMODEL_H
11
12 #include <list>
13 #include <cmath>
14 #include <cstdio>
15 #include "SafeVector.h"
16 #include "ScoreType.h"
17 #include "SparseMatrix.h"
18 #include "MultiSequence.h"
19
20 using namespace std;
21
22 const int NumMatchStates = 1;                                    // note that in this version the number
23                                                                  // of match states is fixed at 1...will
24                                                                  // change in future versions
25 const int NumMatrixTypes = NumMatchStates + NumInsertStates * 2;
26
27 /////////////////////////////////////////////////////////////////
28 // ProbabilisticModel
29 //
30 // Class for storing the parameters of a probabilistic model and
31 // performing different computations based on those parameters.
32 // In particular, this class handles the computation of
33 // posterior probabilities that may be used in alignment.
34 /////////////////////////////////////////////////////////////////
35
36 class ProbabilisticModel {
37
38   float initialDistribution[NumMatrixTypes];               // holds the initial probabilities for each state
39   float transProb[NumMatrixTypes][NumMatrixTypes];         // holds all state-to-state transition probabilities
40   float matchProb[256][256];                               // emission probabilities for match states
41   float insProb[256][NumMatrixTypes];                      // emission probabilities for insert states
42
43  public:
44
45   /////////////////////////////////////////////////////////////////
46   // ProbabilisticModel::ProbabilisticModel()
47   //
48   // Constructor.  Builds a new probabilistic model using the
49   // given parameters.
50   /////////////////////////////////////////////////////////////////
51
52   ProbabilisticModel (const VF &initDistribMat, const VF &gapOpen, const VF &gapExtend,
53                       const VVF &emitPairs, const VF &emitSingle){
54
55     // build transition matrix
56     VVF transMat (NumMatrixTypes, VF (NumMatrixTypes, 0.0f));
57     transMat[0][0] = 1;
58     for (int i = 0; i < NumInsertStates; i++){
59       transMat[0][2*i+1] = gapOpen[2*i];
60       transMat[0][2*i+2] = gapOpen[2*i+1];
61       transMat[0][0] -= (gapOpen[2*i] + gapOpen[2*i+1]);
62       assert (transMat[0][0] > 0);
63       transMat[2*i+1][2*i+1] = gapExtend[2*i];
64       transMat[2*i+2][2*i+2] = gapExtend[2*i+1];
65       transMat[2*i+1][2*i+2] = 0;
66       transMat[2*i+2][2*i+1] = 0;
67       transMat[2*i+1][0] = 1 - gapExtend[2*i];
68       transMat[2*i+2][0] = 1 - gapExtend[2*i+1];
69     }
70
71     // create initial and transition probability matrices
72     for (int i = 0; i < NumMatrixTypes; i++){
73       initialDistribution[i] = LOG (initDistribMat[i]);
74       for (int j = 0; j < NumMatrixTypes; j++)
75         transProb[i][j] = LOG (transMat[i][j]);
76     }
77
78     // create insertion and match probability matrices
79     for (int i = 0; i < 256; i++){
80       for (int j = 0; j < NumMatrixTypes; j++)
81         insProb[i][j] = LOG (emitSingle[i]);
82       for (int j = 0; j < 256; j++)
83         matchProb[i][j] = LOG (emitPairs[i][j]);
84     }
85   }
86
87   /////////////////////////////////////////////////////////////////
88   // ProbabilisticModel::ComputeForwardMatrix()
89   //
90   // Computes a set of forward probability matrices for aligning
91   // seq1 and seq2.
92   //
93   // For efficiency reasons, a single-dimensional floating-point
94   // array is used here, with the following indexing scheme:
95   //
96   //    forward[i + NumMatrixTypes * (j * (seq2Length+1) + k)]
97   //    refers to the probability of aligning through j characters
98   //    of the first sequence, k characters of the second sequence,
99   //    and ending in state i.
100   /////////////////////////////////////////////////////////////////
101
102   VF *ComputeForwardMatrix (Sequence *seq1, Sequence *seq2) const {
103
104     assert (seq1);
105     assert (seq2);
106
107     const int seq1Length = seq1->GetLength();
108     const int seq2Length = seq2->GetLength();
109
110     // retrieve the points to the beginning of each sequence
111     SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
112     SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
113
114     // create matrix
115     VF *forwardPtr = new VF (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), LOG_ZERO);
116     assert (forwardPtr);
117     VF &forward = *forwardPtr;
118
119     // initialization condition
120     forward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)] = 
121       initialDistribution[0] + matchProb[(unsigned char) iter1[1]][(unsigned char) iter2[1]];
122    
123     for (int k = 0; k < NumInsertStates; k++){
124       forward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)] = 
125         initialDistribution[2*k+1] + insProb[(unsigned char) iter1[1]][k];
126       forward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)] = 
127         initialDistribution[2*k+2] + insProb[(unsigned char) iter2[1]][k]; 
128     }
129     
130     // remember offset for each index combination
131     int ij = 0;
132     int i1j = -seq2Length - 1;
133     int ij1 = -1;
134     int i1j1 = -seq2Length - 2;
135
136     ij *= NumMatrixTypes;
137     i1j *= NumMatrixTypes;
138     ij1 *= NumMatrixTypes;
139     i1j1 *= NumMatrixTypes;
140
141     // compute forward scores
142     for (int i = 0; i <= seq1Length; i++){
143       unsigned char c1 = (i == 0) ? '~' : (unsigned char) iter1[i];
144       for (int j = 0; j <= seq2Length; j++){
145         unsigned char c2 = (j == 0) ? '~' : (unsigned char) iter2[j];
146
147         if (i > 1 || j > 1){
148           if (i > 0 && j > 0){
149             forward[0 + ij] = forward[0 + i1j1] + transProb[0][0];
150             for (int k = 1; k < NumMatrixTypes; k++)
151               LOG_PLUS_EQUALS (forward[0 + ij], forward[k + i1j1] + transProb[k][0]);
152             forward[0 + ij] += matchProb[c1][c2];
153           }
154           if (i > 0){
155             for (int k = 0; k < NumInsertStates; k++)
156               forward[2*k+1 + ij] = insProb[c1][k] +
157                 LOG_ADD (forward[0 + i1j] + transProb[0][2*k+1],
158                          forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1]);
159           }
160           if (j > 0){
161             for (int k = 0; k < NumInsertStates; k++)
162               forward[2*k+2 + ij] = insProb[c2][k] +
163                 LOG_ADD (forward[0 + ij1] + transProb[0][2*k+2],
164                          forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2]);
165           }
166         }
167
168         ij += NumMatrixTypes;
169         i1j += NumMatrixTypes;
170         ij1 += NumMatrixTypes;
171         i1j1 += NumMatrixTypes;
172       }
173     }
174
175     return forwardPtr;
176   }
177
178   /////////////////////////////////////////////////////////////////
179   // ProbabilisticModel::ComputeBackwardMatrix()
180   //
181   // Computes a set of backward probability matrices for aligning
182   // seq1 and seq2.
183   //
184   // For efficiency reasons, a single-dimensional floating-point
185   // array is used here, with the following indexing scheme:
186   //
187   //    backward[i + NumMatrixTypes * (j * (seq2Length+1) + k)]
188   //    refers to the probability of starting in state i and
189   //    aligning from character j+1 to the end of the first
190   //    sequence and from character k+1 to the end of the second
191   //    sequence.
192   /////////////////////////////////////////////////////////////////
193
194   VF *ComputeBackwardMatrix (Sequence *seq1, Sequence *seq2) const {
195
196     assert (seq1);
197     assert (seq2);
198
199     const int seq1Length = seq1->GetLength();
200     const int seq2Length = seq2->GetLength();
201     SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
202     SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
203
204     // create matrix
205     VF *backwardPtr = new VF (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), LOG_ZERO);
206     assert (backwardPtr);
207     VF &backward = *backwardPtr;
208
209     // initialization condition
210     for (int k = 0; k < NumMatrixTypes; k++)
211       backward[NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1) + k] = initialDistribution[k];
212
213     // remember offset for each index combination
214     int ij = (seq1Length+1) * (seq2Length+1) - 1;
215     int i1j = ij + seq2Length + 1;
216     int ij1 = ij + 1;
217     int i1j1 = ij + seq2Length + 2;
218
219     ij *= NumMatrixTypes;
220     i1j *= NumMatrixTypes;
221     ij1 *= NumMatrixTypes;
222     i1j1 *= NumMatrixTypes;
223
224     // compute backward scores
225     for (int i = seq1Length; i >= 0; i--){
226       unsigned char c1 = (i == seq1Length) ? '~' : (unsigned char) iter1[i+1];
227       for (int j = seq2Length; j >= 0; j--){
228         unsigned char c2 = (j == seq2Length) ? '~' : (unsigned char) iter2[j+1];
229
230         if (i < seq1Length && j < seq2Length){
231           const float ProbXY = backward[0 + i1j1] + matchProb[c1][c2];
232           for (int k = 0; k < NumMatrixTypes; k++)
233             LOG_PLUS_EQUALS (backward[k + ij], ProbXY + transProb[k][0]);
234         }
235         if (i < seq1Length){
236           for (int k = 0; k < NumInsertStates; k++){
237             LOG_PLUS_EQUALS (backward[0 + ij], backward[2*k+1 + i1j] + insProb[c1][k] + transProb[0][2*k+1]);
238             LOG_PLUS_EQUALS (backward[2*k+1 + ij], backward[2*k+1 + i1j] + insProb[c1][k] + transProb[2*k+1][2*k+1]);
239           }
240         }
241         if (j < seq2Length){
242           for (int k = 0; k < NumInsertStates; k++){
243             LOG_PLUS_EQUALS (backward[0 + ij], backward[2*k+2 + ij1] + insProb[c2][k] + transProb[0][2*k+2]);
244             LOG_PLUS_EQUALS (backward[2*k+2 + ij], backward[2*k+2 + ij1] + insProb[c2][k] + transProb[2*k+2][2*k+2]);
245           }
246         }
247
248         ij -= NumMatrixTypes;
249         i1j -= NumMatrixTypes;
250         ij1 -= NumMatrixTypes;
251         i1j1 -= NumMatrixTypes;
252       }
253     }
254
255     return backwardPtr;
256   }
257
258   /////////////////////////////////////////////////////////////////
259   // ProbabilisticModel::ComputeTotalProbability()
260   //
261   // Computes the total probability of an alignment given
262   // the forward and backward matrices.
263   /////////////////////////////////////////////////////////////////
264
265   float ComputeTotalProbability (int seq1Length, int seq2Length,
266                                  const VF &forward, const VF &backward) const {
267
268     // compute total probability
269     float totalForwardProb = LOG_ZERO;
270     float totalBackwardProb = LOG_ZERO;
271     for (int k = 0; k < NumMatrixTypes; k++){
272       LOG_PLUS_EQUALS (totalForwardProb,
273                        forward[k + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
274                        backward[k + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
275     }
276
277     totalBackwardProb = 
278       forward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)] +
279       backward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)];
280
281     for (int k = 0; k < NumInsertStates; k++){
282       LOG_PLUS_EQUALS (totalBackwardProb,
283                        forward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)] +
284                        backward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)]);
285       LOG_PLUS_EQUALS (totalBackwardProb,
286                        forward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)] +
287                        backward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)]);
288     }
289
290     //    cerr << totalForwardProb << " " << totalBackwardProb << endl;
291     
292     return (totalForwardProb + totalBackwardProb) / 2;
293   }
294
295   /////////////////////////////////////////////////////////////////
296   // ProbabilisticModel::ComputePosteriorMatrix()
297   //
298   // Computes the posterior probability matrix based on
299   // the forward and backward matrices.
300   /////////////////////////////////////////////////////////////////
301
302   VF *ComputePosteriorMatrix (Sequence *seq1, Sequence *seq2,
303                               const VF &forward, const VF &backward) const {
304
305     assert (seq1);
306     assert (seq2);
307
308     const int seq1Length = seq1->GetLength();
309     const int seq2Length = seq2->GetLength();
310
311     float totalProb = ComputeTotalProbability (seq1Length, seq2Length,
312                                                forward, backward);
313
314     // compute posterior matrices
315     VF *posteriorPtr = new VF((seq1Length+1) * (seq2Length+1)); assert (posteriorPtr);
316     VF &posterior = *posteriorPtr;
317
318     int ij = 0;
319     VF::iterator ptr = posterior.begin();
320
321     for (int i = 0; i <= seq1Length; i++){
322       for (int j = 0; j <= seq2Length; j++){
323         *(ptr++) = EXP (min (LOG_ONE, forward[ij] + backward[ij] - totalProb));
324         ij += NumMatrixTypes;
325       }
326     }
327
328     posterior[0] = 0;
329
330     return posteriorPtr;
331   }
332
333   /*
334   /////////////////////////////////////////////////////////////////
335   // ProbabilisticModel::ComputeExpectedCounts()
336   //
337   // Computes the expected counts for the various transitions.
338   /////////////////////////////////////////////////////////////////
339
340   VVF *ComputeExpectedCounts () const {
341
342     assert (seq1);
343     assert (seq2);
344
345     const int seq1Length = seq1->GetLength();
346     const int seq2Length = seq2->GetLength();
347     SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
348     SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
349
350     // compute total probability
351     float totalProb = ComputeTotalProbability (seq1Length, seq2Length,
352                                                forward, backward);
353
354     // initialize expected counts
355     VVF *countsPtr = new VVF(NumMatrixTypes + 1, VF(NumMatrixTypes, LOG_ZERO)); assert (countsPtr);
356     VVF &counts = *countsPtr;
357
358     // remember offset for each index combination
359     int ij = 0;
360     int i1j = -seq2Length - 1;
361     int ij1 = -1;
362     int i1j1 = -seq2Length - 2;
363
364     ij *= NumMatrixTypes;
365     i1j *= NumMatrixTypes;
366     ij1 *= NumMatrixTypes;
367     i1j1 *= NumMatrixTypes;
368
369     // compute expected counts
370     for (int i = 0; i <= seq1Length; i++){
371       unsigned char c1 = (i == 0) ? '~' : (unsigned char) iter1[i];
372       for (int j = 0; j <= seq2Length; j++){
373         unsigned char c2 = (j == 0) ? '~' : (unsigned char) iter2[j];
374
375         if (i > 0 && j > 0){
376           for (int k = 0; k < NumMatrixTypes; k++)
377             LOG_PLUS_EQUALS (counts[k][0],
378                              forward[k + i1j1] + transProb[k][0] +
379                              matchProb[c1][c2] + backward[0 + ij]);
380         }
381         if (i > 0){
382           for (int k = 0; k < NumInsertStates; k++){
383             LOG_PLUS_EQUALS (counts[0][2*k+1],
384                              forward[0 + i1j] + transProb[0][2*k+1] +
385                              insProb[c1][k] + backward[2*k+1 + ij]);
386             LOG_PLUS_EQUALS (counts[2*k+1][2*k+1],
387                              forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1] +
388                              insProb[c1][k] + backward[2*k+1 + ij]);
389           }
390         }
391         if (j > 0){
392           for (int k = 0; k < NumInsertStates; k++){
393             LOG_PLUS_EQUALS (counts[0][2*k+2],
394                              forward[0 + ij1] + transProb[0][2*k+2] +
395                              insProb[c2][k] + backward[2*k+2 + ij]);
396             LOG_PLUS_EQUALS (counts[2*k+2][2*k+2],
397                              forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2] +
398                              insProb[c2][k] + backward[2*k+2 + ij]);
399           }
400         }
401
402         ij += NumMatrixTypes;
403         i1j += NumMatrixTypes;
404         ij1 += NumMatrixTypes;
405         i1j1 += NumMatrixTypes;
406       }
407     }
408
409     // scale all expected counts appropriately
410     for (int i = 0; i < NumMatrixTypes; i++)
411       for (int j = 0; j < NumMatrixTypes; j++)
412         counts[i][j] -= totalProb;
413
414   }
415   */
416
417   /////////////////////////////////////////////////////////////////
418   // ProbabilisticModel::ComputeNewParameters()
419   //
420   // Computes a new parameter set based on the expected counts
421   // given.
422   /////////////////////////////////////////////////////////////////
423
424   void ComputeNewParameters (Sequence *seq1, Sequence *seq2,
425                              const VF &forward, const VF &backward,
426                              VF &initDistribMat, VF &gapOpen,
427                              VF &gapExtend, VVF &emitPairs, VF &emitSingle, bool enableTrainEmissions) const {
428     
429     assert (seq1);
430     assert (seq2);
431
432     const int seq1Length = seq1->GetLength();
433     const int seq2Length = seq2->GetLength();
434     SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
435     SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
436
437     // compute total probability
438     float totalProb = ComputeTotalProbability (seq1Length, seq2Length,
439                                                forward, backward);
440     
441     // initialize expected counts
442     VVF transCounts (NumMatrixTypes, VF (NumMatrixTypes, LOG_ZERO));
443     VF initCounts (NumMatrixTypes, LOG_ZERO);
444     VVF pairCounts (256, VF (256, LOG_ZERO));
445     VF singleCounts (256, LOG_ZERO);
446     
447     // remember offset for each index combination
448     int ij = 0;
449     int i1j = -seq2Length - 1;
450     int ij1 = -1;
451     int i1j1 = -seq2Length - 2;
452
453     ij *= NumMatrixTypes;
454     i1j *= NumMatrixTypes;
455     ij1 *= NumMatrixTypes;
456     i1j1 *= NumMatrixTypes;
457
458     // compute initial distribution posteriors
459     initCounts[0] = LOG_ADD (forward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)] +
460                              backward[0 + NumMatrixTypes * (1 * (seq2Length+1) + 1)],
461                              forward[0 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
462                              backward[0 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
463     for (int k = 0; k < NumInsertStates; k++){
464       initCounts[2*k+1] = LOG_ADD (forward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)] +
465                                    backward[2*k+1 + NumMatrixTypes * (1 * (seq2Length+1) + 0)],
466                                    forward[2*k+1 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
467                                    backward[2*k+1 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
468       initCounts[2*k+2] = LOG_ADD (forward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)] +
469                                    backward[2*k+2 + NumMatrixTypes * (0 * (seq2Length+1) + 1)],
470                                    forward[2*k+2 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)] + 
471                                    backward[2*k+2 + NumMatrixTypes * ((seq1Length+1) * (seq2Length+1) - 1)]);
472     }
473
474     // compute expected counts
475     for (int i = 0; i <= seq1Length; i++){
476       unsigned char c1 = (i == 0) ? '~' : (unsigned char) toupper(iter1[i]);
477       for (int j = 0; j <= seq2Length; j++){
478         unsigned char c2 = (j == 0) ? '~' : (unsigned char) toupper(iter2[j]);
479
480         if (i > 0 && j > 0){
481           if (enableTrainEmissions && i == 1 && j == 1){
482             LOG_PLUS_EQUALS (pairCounts[c1][c2],
483                              initialDistribution[0] + matchProb[c1][c2] + backward[0 + ij]);
484             LOG_PLUS_EQUALS (pairCounts[c2][c1],
485                              initialDistribution[0] + matchProb[c2][c1] + backward[0 + ij]);
486           }
487
488           for (int k = 0; k < NumMatrixTypes; k++){
489             LOG_PLUS_EQUALS (transCounts[k][0],
490                              forward[k + i1j1] + transProb[k][0] +
491                              matchProb[c1][c2] + backward[0 + ij]);
492             if (enableTrainEmissions && i != 1 || j != 1){
493               LOG_PLUS_EQUALS (pairCounts[c1][c2],
494                                forward[k + i1j1] + transProb[k][0] +
495                                matchProb[c1][c2] + backward[0 + ij]);
496               LOG_PLUS_EQUALS (pairCounts[c2][c1],
497                                forward[k + i1j1] + transProb[k][0] +
498                                matchProb[c2][c1] + backward[0 + ij]);
499             }
500           }
501         }
502         if (i > 0){
503           for (int k = 0; k < NumInsertStates; k++){
504             LOG_PLUS_EQUALS (transCounts[0][2*k+1],
505                              forward[0 + i1j] + transProb[0][2*k+1] +
506                              insProb[c1][k] + backward[2*k+1 + ij]);
507             LOG_PLUS_EQUALS (transCounts[2*k+1][2*k+1],
508                              forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1] +
509                              insProb[c1][k] + backward[2*k+1 + ij]);
510             if (enableTrainEmissions){
511               if (i == 1 && j == 0){
512                 LOG_PLUS_EQUALS (singleCounts[c1],
513                                  initialDistribution[2*k+1] + insProb[c1][k] + backward[2*k+1 + ij]);
514               }
515               else {
516                 LOG_PLUS_EQUALS (singleCounts[c1],
517                                  forward[0 + i1j] + transProb[0][2*k+1] +
518                                  insProb[c1][k] + backward[2*k+1 + ij]);
519                 LOG_PLUS_EQUALS (singleCounts[c1],
520                                  forward[2*k+1 + i1j] + transProb[2*k+1][2*k+1] +
521                                  insProb[c1][k] + backward[2*k+1 + ij]);
522               }
523             }
524           }
525         }
526         if (j > 0){
527           for (int k = 0; k < NumInsertStates; k++){
528             LOG_PLUS_EQUALS (transCounts[0][2*k+2],
529                              forward[0 + ij1] + transProb[0][2*k+2] +
530                              insProb[c2][k] + backward[2*k+2 + ij]);
531             LOG_PLUS_EQUALS (transCounts[2*k+2][2*k+2],
532                              forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2] +
533                              insProb[c2][k] + backward[2*k+2 + ij]);
534             if (enableTrainEmissions){
535               if (i == 0 && j == 1){
536                 LOG_PLUS_EQUALS (singleCounts[c2],
537                                  initialDistribution[2*k+2] + insProb[c2][k] + backward[2*k+2 + ij]);
538               }
539               else {
540                 LOG_PLUS_EQUALS (singleCounts[c2],
541                                  forward[0 + ij1] + transProb[0][2*k+2] +
542                                  insProb[c2][k] + backward[2*k+2 + ij]);
543                 LOG_PLUS_EQUALS (singleCounts[c2],
544                                  forward[2*k+2 + ij1] + transProb[2*k+2][2*k+2] +
545                                  insProb[c2][k] + backward[2*k+2 + ij]);
546               }
547             }
548           }
549         }
550       
551         ij += NumMatrixTypes;
552         i1j += NumMatrixTypes;
553         ij1 += NumMatrixTypes;
554         i1j1 += NumMatrixTypes;
555       }
556     }
557
558     // scale all expected counts appropriately
559     for (int i = 0; i < NumMatrixTypes; i++){
560       initCounts[i] -= totalProb;
561       for (int j = 0; j < NumMatrixTypes; j++)
562         transCounts[i][j] -= totalProb;
563     }
564     if (enableTrainEmissions){
565       for (int i = 0; i < 256; i++){
566         for (int j = 0; j < 256; j++)
567           pairCounts[i][j] -= totalProb;
568         singleCounts[i] -= totalProb;
569       }
570     }
571
572     // compute new initial distribution
573     float totalInitDistribCounts = 0;
574     for (int i = 0; i < NumMatrixTypes; i++)
575       totalInitDistribCounts += exp (initCounts[i]); // should be 2
576     initDistribMat[0] = min (1.0f, max (0.0f, (float) exp (initCounts[0]) / totalInitDistribCounts));
577     for (int k = 0; k < NumInsertStates; k++){
578       float val = (exp (initCounts[2*k+1]) + exp (initCounts[2*k+2])) / 2;
579       initDistribMat[2*k+1] = initDistribMat[2*k+2] = min (1.0f, max (0.0f, val / totalInitDistribCounts));
580     }
581
582     // compute total counts for match state
583     float inMatchStateCounts = 0;
584     for (int i = 0; i < NumMatrixTypes; i++)
585       inMatchStateCounts += exp (transCounts[0][i]);
586     for (int i = 0; i < NumInsertStates; i++){
587
588       // compute total counts for gap state
589       float inGapStateCounts =
590         exp (transCounts[2*i+1][0]) +
591         exp (transCounts[2*i+1][2*i+1]) +
592         exp (transCounts[2*i+2][0]) +
593         exp (transCounts[2*i+2][2*i+2]);
594
595       gapOpen[2*i] = gapOpen[2*i+1] =
596         (exp (transCounts[0][2*i+1]) +
597          exp (transCounts[0][2*i+2])) /
598         (2 * inMatchStateCounts);
599
600       gapExtend[2*i] = gapExtend[2*i+1] =
601         (exp (transCounts[2*i+1][2*i+1]) +
602          exp (transCounts[2*i+2][2*i+2])) /
603         inGapStateCounts;
604     }
605
606     if (enableTrainEmissions){
607       float totalPairCounts = 0;
608       float totalSingleCounts = 0;
609       for (int i = 0; i < 256; i++){
610         for (int j = 0; j <= i; j++)
611           totalPairCounts += exp (pairCounts[j][i]);
612         totalSingleCounts += exp (singleCounts[i]);
613       }
614       
615       for (int i = 0; i < 256; i++) if (!islower ((char) i)){
616         int li = (int)((unsigned char) tolower ((char) i));
617         for (int j = 0; j <= i; j++) if (!islower ((char) j)){
618           int lj = (int)((unsigned char) tolower ((char) j));
619           emitPairs[i][j] = emitPairs[i][lj] = emitPairs[li][j] = emitPairs[li][lj] = 
620             emitPairs[j][i] = emitPairs[j][li] = emitPairs[lj][i] = emitPairs[lj][li] = exp(pairCounts[j][i]) / totalPairCounts;
621         }
622         emitSingle[i] = emitSingle[li] = exp(singleCounts[i]) / totalSingleCounts;
623       }
624     }
625   }
626     
627   /////////////////////////////////////////////////////////////////
628   // ProbabilisticModel::ComputeAlignment()
629   //
630   // Computes an alignment based on the given posterior matrix.
631   // This is done by finding the maximum summing path (or
632   // maximum weight trace) through the posterior matrix.  The
633   // final alignment is returned as a pair consisting of:
634   //    (1) a string (e.g., XXXBBXXXBBBBBBYYYYBBB) where X's and
635   //        denote insertions in one of the two sequences and
636   //        B's denote that both sequences are present (i.e.
637   //        matches).
638   //    (2) a float indicating the sum achieved
639   /////////////////////////////////////////////////////////////////
640
641   pair<SafeVector<char> *, float> ComputeAlignment (int seq1Length, int seq2Length,
642                                                     const VF &posterior) const {
643
644     float *twoRows = new float[(seq2Length+1)*2]; assert (twoRows);
645     float *oldRow = twoRows;
646     float *newRow = twoRows + seq2Length + 1;
647
648     char *tracebackMatrix = new char[(seq1Length+1)*(seq2Length+1)]; assert (tracebackMatrix);
649     char *tracebackPtr = tracebackMatrix;
650
651     VF::const_iterator posteriorPtr = posterior.begin() + seq2Length + 1;
652
653     // initialization
654     for (int i = 0; i <= seq2Length; i++){
655       oldRow[i] = 0;
656       *(tracebackPtr++) = 'L';
657     }
658
659     // fill in matrix
660     for (int i = 1; i <= seq1Length; i++){
661
662       // initialize left column
663       newRow[0] = 0;
664       posteriorPtr++;
665       *(tracebackPtr++) = 'U';
666
667       // fill in rest of row
668       for (int j = 1; j <= seq2Length; j++){
669         ChooseBestOfThree (*(posteriorPtr++) + oldRow[j-1], newRow[j-1], oldRow[j],
670                            'D', 'L', 'U', &newRow[j], tracebackPtr++);
671       }
672
673       // swap rows
674       float *temp = oldRow;
675       oldRow = newRow;
676       newRow = temp;
677     }
678
679     // store best score
680     float total = oldRow[seq2Length];
681     delete [] twoRows;
682
683     // compute traceback
684     SafeVector<char> *alignment = new SafeVector<char>; assert (alignment);
685     int r = seq1Length, c = seq2Length;
686     while (r != 0 || c != 0){
687       char ch = tracebackMatrix[r*(seq2Length+1) + c];
688       switch (ch){
689       case 'L': c--; alignment->push_back ('Y'); break;
690       case 'U': r--; alignment->push_back ('X'); break;
691       case 'D': c--; r--; alignment->push_back ('B'); break;
692       default: assert (false);
693       }
694     }
695
696     delete [] tracebackMatrix;
697
698     reverse (alignment->begin(), alignment->end());
699
700     return make_pair(alignment, total);
701   }
702
703   /////////////////////////////////////////////////////////////////
704   // ProbabilisticModel::ComputeAlignmentWithGapPenalties()
705   //
706   // Similar to ComputeAlignment() except with gap penalties.
707   /////////////////////////////////////////////////////////////////
708
709   pair<SafeVector<char> *, float> ComputeAlignmentWithGapPenalties (MultiSequence *align1,
710                                                                     MultiSequence *align2,
711                                                                     const VF &posterior, int numSeqs1,
712                                                                     int numSeqs2,
713                                                                     float gapOpenPenalty,
714                                                                     float gapContinuePenalty) const {
715     int seq1Length = align1->GetSequence(0)->GetLength();
716     int seq2Length = align2->GetSequence(0)->GetLength();
717     SafeVector<SafeVector<char>::iterator > dataPtrs1 (align1->GetNumSequences());
718     SafeVector<SafeVector<char>::iterator > dataPtrs2 (align2->GetNumSequences());
719
720     // grab character data
721     for (int i = 0; i < align1->GetNumSequences(); i++)
722       dataPtrs1[i] = align1->GetSequence(i)->GetDataPtr();
723     for (int i = 0; i < align2->GetNumSequences(); i++)
724       dataPtrs2[i] = align2->GetSequence(i)->GetDataPtr();
725
726     // the number of active sequences at any given column is defined to be the
727     // number of non-gap characters in that column; the number of gap opens at
728     // any given column is defined to be the number of gap characters in that
729     // column where the previous character in the respective sequence was not
730     // a gap
731     SafeVector<int> numActive1 (seq1Length+1), numGapOpens1 (seq1Length+1);
732     SafeVector<int> numActive2 (seq2Length+1), numGapOpens2 (seq2Length+1);
733
734     // compute number of active sequences and gap opens for each group
735     for (int i = 0; i < align1->GetNumSequences(); i++){
736       SafeVector<char>::iterator dataPtr = align1->GetSequence(i)->GetDataPtr();
737       numActive1[0] = numGapOpens1[0] = 0;
738       for (int j = 1; j <= seq1Length; j++){
739         if (dataPtr[j] != '-'){
740           numActive1[j]++;
741           numGapOpens1[j] += (j != 1 && dataPtr[j-1] != '-');
742         }
743       }
744     }
745     for (int i = 0; i < align2->GetNumSequences(); i++){
746       SafeVector<char>::iterator dataPtr = align2->GetSequence(i)->GetDataPtr();
747       numActive2[0] = numGapOpens2[0] = 0;
748       for (int j = 1; j <= seq2Length; j++){
749         if (dataPtr[j] != '-'){
750           numActive2[j]++;
751           numGapOpens2[j] += (j != 1 && dataPtr[j-1] != '-');
752         }
753       }
754     }
755
756     VVF openingPenalty1 (numSeqs1+1, VF (numSeqs2+1));
757     VF continuingPenalty1 (numSeqs1+1);
758     VVF openingPenalty2 (numSeqs1+1, VF (numSeqs2+1));
759     VF continuingPenalty2 (numSeqs2+1);
760
761     // precompute penalties
762     for (int i = 0; i <= numSeqs1; i++)
763       for (int j = 0; j <= numSeqs2; j++)
764         openingPenalty1[i][j] = i * (gapOpenPenalty * j + gapContinuePenalty * (numSeqs2 - j));
765     for (int i = 0; i <= numSeqs1; i++)
766       continuingPenalty1[i] = i * gapContinuePenalty * numSeqs2;
767     for (int i = 0; i <= numSeqs2; i++)
768       for (int j = 0; j <= numSeqs1; j++)
769         openingPenalty2[i][j] = i * (gapOpenPenalty * j + gapContinuePenalty * (numSeqs1 - j));
770     for (int i = 0; i <= numSeqs2; i++)
771       continuingPenalty2[i] = i * gapContinuePenalty * numSeqs1;
772
773     float *twoRows = new float[6*(seq2Length+1)]; assert (twoRows);
774     float *oldRowMatch = twoRows;
775     float *newRowMatch = twoRows + (seq2Length+1);
776     float *oldRowInsertX = twoRows + 2*(seq2Length+1);
777     float *newRowInsertX = twoRows + 3*(seq2Length+1);
778     float *oldRowInsertY = twoRows + 4*(seq2Length+1);
779     float *newRowInsertY = twoRows + 5*(seq2Length+1);
780
781     char *tracebackMatrix = new char[3*(seq1Length+1)*(seq2Length+1)]; assert (tracebackMatrix);
782     char *tracebackPtr = tracebackMatrix;
783
784     VF::const_iterator posteriorPtr = posterior.begin() + seq2Length + 1;
785
786     // initialization
787     for (int i = 0; i <= seq2Length; i++){
788       oldRowMatch[i] = oldRowInsertX[i] = (i == 0) ? 0 : LOG_ZERO;
789       oldRowInsertY[i] = (i == 0) ? 0 : oldRowInsertY[i-1] + continuingPenalty2[numActive2[i]];
790       *(tracebackPtr) = *(tracebackPtr+1) = *(tracebackPtr+2) = 'Y';
791       tracebackPtr += 3;
792     }
793
794     // fill in matrix
795     for (int i = 1; i <= seq1Length; i++){
796
797       // initialize left column
798       newRowMatch[0] = newRowInsertY[0] = LOG_ZERO;
799       newRowInsertX[0] = oldRowInsertX[0] + continuingPenalty1[numActive1[i]];
800       posteriorPtr++;
801       *(tracebackPtr) = *(tracebackPtr+1) = *(tracebackPtr+2) = 'X';
802       tracebackPtr += 3;
803
804       // fill in rest of row
805       for (int j = 1; j <= seq2Length; j++){
806
807         // going to MATCH state
808         ChooseBestOfThree (oldRowMatch[j-1],
809                            oldRowInsertX[j-1],
810                            oldRowInsertY[j-1],
811                            'M', 'X', 'Y', &newRowMatch[j], tracebackPtr++);
812         newRowMatch[j] += *(posteriorPtr++);
813
814         // going to INSERT X state
815         ChooseBestOfThree (oldRowMatch[j] + openingPenalty1[numActive1[i]][numGapOpens2[j]],
816                            oldRowInsertX[j] + continuingPenalty1[numActive1[i]],
817                            oldRowInsertY[j] + openingPenalty1[numActive1[i]][numGapOpens2[j]],
818                            'M', 'X', 'Y', &newRowInsertX[j], tracebackPtr++);
819
820         // going to INSERT Y state
821         ChooseBestOfThree (newRowMatch[j-1] + openingPenalty2[numActive2[j]][numGapOpens1[i]],
822                            newRowInsertX[j-1] + openingPenalty2[numActive2[j]][numGapOpens1[i]],
823                            newRowInsertY[j-1] + continuingPenalty2[numActive2[j]],
824                            'M', 'X', 'Y', &newRowInsertY[j], tracebackPtr++);
825       }
826
827       // swap rows
828       float *temp;
829       temp = oldRowMatch; oldRowMatch = newRowMatch; newRowMatch = temp;
830       temp = oldRowInsertX; oldRowInsertX = newRowInsertX; newRowInsertX = temp;
831       temp = oldRowInsertY; oldRowInsertY = newRowInsertY; newRowInsertY = temp;
832     }
833
834     // store best score
835     float total;
836     char matrix;
837     ChooseBestOfThree (oldRowMatch[seq2Length], oldRowInsertX[seq2Length], oldRowInsertY[seq2Length],
838                        'M', 'X', 'Y', &total, &matrix);
839
840     delete [] twoRows;
841
842     // compute traceback
843     SafeVector<char> *alignment = new SafeVector<char>; assert (alignment);
844     int r = seq1Length, c = seq2Length;
845     while (r != 0 || c != 0){
846
847       int offset = (matrix == 'M') ? 0 : (matrix == 'X') ? 1 : 2;
848       char ch = tracebackMatrix[(r*(seq2Length+1) + c) * 3 + offset];
849       switch (matrix){
850       case 'Y': c--; alignment->push_back ('Y'); break;
851       case 'X': r--; alignment->push_back ('X'); break;
852       case 'M': c--; r--; alignment->push_back ('B'); break;
853       default: assert (false);
854       }
855       matrix = ch;
856     }
857
858     delete [] tracebackMatrix;
859
860     reverse (alignment->begin(), alignment->end());
861
862     return make_pair(alignment, 1.0f);
863   }
864
865   /////////////////////////////////////////////////////////////////
866   // ProbabilisticModel::ComputeViterbiAlignment()
867   //
868   // Computes the highest probability pairwise alignment using the
869   // probabilistic model.  The final alignment is returned as a
870   //  pair consisting of:
871   //    (1) a string (e.g., XXXBBXXXBBBBBBYYYYBBB) where X's and
872   //        denote insertions in one of the two sequences and
873   //        B's denote that both sequences are present (i.e.
874   //        matches).
875   //    (2) a float containing the log probability of the best
876   //        alignment (not used)
877   /////////////////////////////////////////////////////////////////
878
879   pair<SafeVector<char> *, float> ComputeViterbiAlignment (Sequence *seq1, Sequence *seq2) const {
880     
881     assert (seq1);
882     assert (seq2);
883     
884     const int seq1Length = seq1->GetLength();
885     const int seq2Length = seq2->GetLength();
886     
887     // retrieve the points to the beginning of each sequence
888     SafeVector<char>::iterator iter1 = seq1->GetDataPtr();
889     SafeVector<char>::iterator iter2 = seq2->GetDataPtr();
890     
891     // create viterbi matrix
892     VF *viterbiPtr = new VF (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), LOG_ZERO);
893     assert (viterbiPtr);
894     VF &viterbi = *viterbiPtr;
895
896     // create traceback matrix
897     VI *tracebackPtr = new VI (NumMatrixTypes * (seq1Length+1) * (seq2Length+1), -1);
898     assert (tracebackPtr);
899     VI &traceback = *tracebackPtr;
900
901     // initialization condition
902     for (int k = 0; k < NumMatrixTypes; k++)
903       viterbi[k] = initialDistribution[k];
904
905     // remember offset for each index combination
906     int ij = 0;
907     int i1j = -seq2Length - 1;
908     int ij1 = -1;
909     int i1j1 = -seq2Length - 2;
910
911     ij *= NumMatrixTypes;
912     i1j *= NumMatrixTypes;
913     ij1 *= NumMatrixTypes;
914     i1j1 *= NumMatrixTypes;
915
916     // compute viterbi scores
917     for (int i = 0; i <= seq1Length; i++){
918       unsigned char c1 = (i == 0) ? '~' : (unsigned char) iter1[i];
919       for (int j = 0; j <= seq2Length; j++){
920         unsigned char c2 = (j == 0) ? '~' : (unsigned char) iter2[j];
921
922         if (i > 0 && j > 0){
923           for (int k = 0; k < NumMatrixTypes; k++){
924             float newVal = viterbi[k + i1j1] + transProb[k][0] + matchProb[c1][c2];
925             if (viterbi[0 + ij] < newVal){
926               viterbi[0 + ij] = newVal;
927               traceback[0 + ij] = k;
928             }
929           }
930         }
931         if (i > 0){
932           for (int k = 0; k < NumInsertStates; k++){
933             float valFromMatch = insProb[c1][k] + viterbi[0 + i1j] + transProb[0][2*k+1];
934             float valFromIns = insProb[c1][k] + viterbi[2*k+1 + i1j] + transProb[2*k+1][2*k+1];
935             if (valFromMatch >= valFromIns){
936               viterbi[2*k+1 + ij] = valFromMatch;
937               traceback[2*k+1 + ij] = 0;
938             }
939             else {
940               viterbi[2*k+1 + ij] = valFromIns;
941               traceback[2*k+1 + ij] = 2*k+1;
942             }
943           }
944         }
945         if (j > 0){
946           for (int k = 0; k < NumInsertStates; k++){
947             float valFromMatch = insProb[c2][k] + viterbi[0 + ij1] + transProb[0][2*k+2];
948             float valFromIns = insProb[c2][k] + viterbi[2*k+2 + ij1] + transProb[2*k+2][2*k+2];
949             if (valFromMatch >= valFromIns){
950               viterbi[2*k+2 + ij] = valFromMatch;
951               traceback[2*k+2 + ij] = 0;
952             }
953             else {
954               viterbi[2*k+2 + ij] = valFromIns;
955               traceback[2*k+2 + ij] = 2*k+2;
956             }
957           }
958         }
959
960         ij += NumMatrixTypes;
961         i1j += NumMatrixTypes;
962         ij1 += NumMatrixTypes;
963         i1j1 += NumMatrixTypes;
964       }
965     }
966
967     // figure out best terminating cell
968     float bestProb = LOG_ZERO;
969     int state = -1;
970     for (int k = 0; k < NumMatrixTypes; k++){
971       float thisProb = viterbi[k + NumMatrixTypes * ((seq1Length+1)*(seq2Length+1) - 1)] + initialDistribution[k];
972       if (bestProb < thisProb){
973         bestProb = thisProb;
974         state = k;
975       }
976     }
977     assert (state != -1);
978
979     delete viterbiPtr;
980
981     // compute traceback
982     SafeVector<char> *alignment = new SafeVector<char>; assert (alignment);
983     int r = seq1Length, c = seq2Length;
984     while (r != 0 || c != 0){
985       int newState = traceback[state + NumMatrixTypes * (r * (seq2Length+1) + c)];
986       
987       if (state == 0){ c--; r--; alignment->push_back ('B'); }
988       else if (state % 2 == 1){ r--; alignment->push_back ('X'); }
989       else { c--; alignment->push_back ('Y'); }
990       
991       state = newState;
992     }
993
994     delete tracebackPtr;
995
996     reverse (alignment->begin(), alignment->end());
997     
998     return make_pair(alignment, bestProb);
999   }
1000
1001   /////////////////////////////////////////////////////////////////
1002   // ProbabilisticModel::BuildPosterior()
1003   //
1004   // Builds a posterior probability matrix needed to align a pair
1005   // of alignments.  Mathematically, the returned matrix M is
1006   // defined as follows:
1007   //    M[i,j] =     sum          sum      f(s,t,i,j)
1008   //             s in align1  t in align2
1009   // where
1010   //                  [  P(s[i'] <--> t[j'])
1011   //                  [       if s[i'] is a letter in the ith column of align1 and
1012   //                  [          t[j'] it a letter in the jth column of align2
1013   //    f(s,t,i,j) =  [
1014   //                  [  0    otherwise
1015   //
1016   /////////////////////////////////////////////////////////////////
1017
1018   VF *BuildPosterior (MultiSequence *align1, MultiSequence *align2,
1019                       const SafeVector<SafeVector<SparseMatrix *> > &sparseMatrices,
1020                       float cutoff = 0.0f) const {
1021     const int seq1Length = align1->GetSequence(0)->GetLength();
1022     const int seq2Length = align2->GetSequence(0)->GetLength();
1023
1024     VF *posteriorPtr = new VF((seq1Length+1) * (seq2Length+1), 0); assert (posteriorPtr);
1025     VF &posterior = *posteriorPtr;
1026     VF::iterator postPtr = posterior.begin();
1027
1028     // for each s in align1
1029     for (int i = 0; i < align1->GetNumSequences(); i++){
1030       int first = align1->GetSequence(i)->GetLabel();
1031       SafeVector<int> *mapping1 = align1->GetSequence(i)->GetMapping();
1032
1033       // for each t in align2
1034       for (int j = 0; j < align2->GetNumSequences(); j++){
1035         int second = align2->GetSequence(j)->GetLabel();
1036         SafeVector<int> *mapping2 = align2->GetSequence(j)->GetMapping();
1037
1038         if (first < second){
1039
1040           // get the associated sparse matrix
1041           SparseMatrix *matrix = sparseMatrices[first][second];
1042           
1043           for (int ii = 1; ii <= matrix->GetSeq1Length(); ii++){
1044             SafeVector<PIF>::iterator row = matrix->GetRowPtr(ii);
1045             int base = (*mapping1)[ii] * (seq2Length+1);
1046             int rowSize = matrix->GetRowSize(ii);
1047             
1048             // add in all relevant values
1049             for (int jj = 0; jj < rowSize; jj++)
1050               posterior[base + (*mapping2)[row[jj].first]] += row[jj].second;
1051             
1052             // subtract cutoff 
1053             for (int jj = 0; jj < matrix->GetSeq2Length(); jj++)
1054               posterior[base + (*mapping2)[jj]] -= cutoff;
1055           }
1056
1057         } else {
1058
1059           // get the associated sparse matrix
1060           SparseMatrix *matrix = sparseMatrices[second][first];
1061           
1062           for (int jj = 1; jj <= matrix->GetSeq1Length(); jj++){
1063             SafeVector<PIF>::iterator row = matrix->GetRowPtr(jj);
1064             int base = (*mapping2)[jj];
1065             int rowSize = matrix->GetRowSize(jj);
1066             
1067             // add in all relevant values
1068             for (int ii = 0; ii < rowSize; ii++)
1069               posterior[base + (*mapping1)[row[ii].first] * (seq2Length + 1)] += row[ii].second;
1070             
1071             // subtract cutoff 
1072             for (int ii = 0; ii < matrix->GetSeq2Length(); ii++)
1073               posterior[base + (*mapping1)[ii] * (seq2Length + 1)] -= cutoff;
1074           }
1075
1076         }
1077         
1078
1079         delete mapping2;
1080       }
1081
1082       delete mapping1;
1083     }
1084
1085     return posteriorPtr;
1086   }
1087 };
1088
1089 #endif