new mafft v 6.857 with extensions
[jabaws.git] / binaries / src / mafft / extensions / mxscarna_src / AlifoldMEA.cpp
1 #include "AlifoldMEA.h"
2
3 namespace MXSCARNA{
4
5 const int AlifoldMEA::TURN = 3;
6
7 void
8 AlifoldMEA::
9 Run()
10 {
11     makeProfileBPPMatrix(alignment);
12     Initialization();
13     DP();
14     TraceBack();
15 }
16
17 void
18 AlifoldMEA::
19 makeProfileBPPMatrix(const MultiSequence *Sequences)
20 {
21     int length = Sequences->GetSequence(0)->GetLength();
22
23     Trimat<float> *consBppMat = new Trimat<float>(length + 1);
24     fill(consBppMat->begin(), consBppMat->end(), 0);
25
26     for(int i = 1; i <= length; i++) 
27         for (int j = i; j <= length; j++) 
28             bppMat.ref(i, j) = 0;
29
30
31     int number = Sequences->GetNumSequences();
32     for(int seqNum = 0; seqNum < number; seqNum++) {
33         SafeVector<int> *tmpMap = Sequences->GetSequence(seqNum)->GetMappingNumber();
34         int label = Sequences->GetSequence(seqNum)->GetLabel();
35         BPPMatrix *tmpBppMatrix = BPPMatrices[label];
36         
37         for(int i = 1; i <= length ; i++) {
38             int originI = tmpMap->at(i);
39             for(int j = i; j <= length; j++) {
40                 int originJ = tmpMap->at(j);
41                 if(originI != 0 && originJ != 0) {
42                     float tmpProb = tmpBppMatrix->GetProb(originI, originJ);
43                     bppMat.ref(i, j) += tmpProb;
44                 }
45             }
46         }
47     }
48
49         /* compute the mean of base pairing probability  */
50     for(int i = 1; i <= length; i++) {
51         for(int j = i; j <= length; j++) {
52             bppMat.ref(i,j) = bppMat.ref(i,j)/(float)number;
53         }
54     }
55
56     for (int i = 1; i <= length; i++) {
57         float sum = 0;
58         for (int j = i; j <= length; j++) {
59             sum += bppMat.ref(i,j);
60         }
61         Qi[i] = 1 - sum;
62     }
63
64     for (int i = 1; i <= length; i++) {
65         float sum = 0;
66         for (int j = i; j >= 1; j--) {
67             sum += bppMat.ref(j, i);
68         }
69         Qj[i] = 1 - sum;
70     }
71 }
72
73 void
74 AlifoldMEA::
75 Initialization()
76 {
77     int length = alignment->GetSequence(0)->GetLength();
78
79     for (int i = 1; i <= length; i++) {
80         for (int j = i; j <= length; j++) {
81             M.ref(i,j) = 0;
82             traceI.ref(i,j) = 0;
83             traceJ.ref(i,j) = 0;
84         }
85     }
86
87     for (int i = 1; i <= length; i++) {
88         M.ref(i,i)   = Qi[i]; 
89         traceI.ref(i,i) = 0;
90         traceJ.ref(i,i) = 0;
91     }
92
93     for (int i = 1; i <= length - 1; i++) {
94         M.ref(i, i+1) =  Qi[i+1];
95         traceI.ref(i,i + 1) = 0;
96         traceJ.ref(i,i + 1) = 0;
97     }
98
99     for (int i = 0; i <= length; i++) {
100         ssCons[i] = '.';
101     }
102 }
103
104 void
105 AlifoldMEA::
106 DP()
107 {
108     float g    = BasePairConst; // see scarna.hpp
109     int length = alignment->GetSequence(0)->GetLength();
110     
111     for (int i = length - 1; i >= 1; i--) {
112         for (int j = i + TURN + 1; j <= length; j++) {
113             float qi       = Qi[i];
114             float qj       = Qj[j];
115             float p        = bppMat.ref(i,j);
116
117             
118             float maxScore = qi + M.ref(i+1, j);
119             int tmpI = i+1;
120             int tmpJ = j;
121             
122             float tmpScore = qj + M.ref(i, j-1);
123             if (tmpScore > maxScore) {
124                 maxScore = tmpScore;
125                 tmpI     = i;
126                 tmpJ     = j - 1;
127             }
128             
129             tmpScore = g*2*p + M.ref(i+1, j-1);
130             if (tmpScore > maxScore) {
131                 maxScore = tmpScore;
132                 tmpI     = i + 1;
133                 tmpJ     = j - 1;
134             }
135             
136             for (int k = i + 1; k < j - 1; k++) {
137                 tmpScore = M.ref(i,k) + M.ref(k+1,j);
138                 if (tmpScore > maxScore) {
139                     maxScore = tmpScore;
140                     tmpI = i;
141                     tmpJ = j;
142                 }
143             }
144             M.ref(i,j)       = maxScore;
145             traceI.ref(i, j) = tmpI;
146             traceJ.ref(i, j) = tmpJ;
147         }
148     }
149 }
150
151 void
152 AlifoldMEA::
153 TraceBack()
154 {
155
156     int length = alignment->GetSequence(0)->GetLength();
157     SafeVector<int> stackI((length + 1)*(length+1));
158     SafeVector<int> stackJ((length + 1)*(length+1));
159     int pt = 0;
160
161     stackI[pt] = traceI.ref(1, length);
162     stackJ[pt] = traceJ.ref(1, length);
163     ++pt;
164     
165     while(pt != 0) {
166         --pt;
167         int tmpI = stackI[pt];
168         int tmpJ = stackJ[pt];
169         int nextI = traceI.ref(tmpI, tmpJ);
170         int nextJ = traceJ.ref(tmpI, tmpJ);
171
172         if (tmpI < tmpJ) {
173             if (tmpI + 1  == nextI && tmpJ == nextJ) {
174                 stackI[pt] = nextI;
175                 stackJ[pt] = nextJ;
176                 ++pt;
177             }
178             else if (tmpI == nextI && tmpJ - 1 == nextJ) {
179                 stackI[pt] = nextI;
180                 stackJ[pt] = nextJ;
181                 ++pt;
182             }
183             else if (tmpI + 1 == nextI && tmpJ - 1== nextJ) {
184                 stackI[pt] = nextI;
185                 stackJ[pt] = nextJ;
186                 ++pt;
187                 ssCons[tmpI] = '(';
188                 ssCons[tmpJ] = ')';
189             }
190             else if (tmpI == nextI && tmpJ == nextJ) {
191                 float maxScore = IMPOSSIBLE;
192                 int maxK = 0;
193
194                 for (int k = tmpI + 1; k < tmpJ - 1; k++) {
195                     float tmpScore = M.ref(tmpI,k) + M.ref(k+1,tmpJ);
196                     if (tmpScore > maxScore) {
197                         maxScore = tmpScore;
198                         maxK = k;
199                     }
200                 }
201                 stackI[pt] = traceI.ref(tmpI, maxK);
202                 stackJ[pt] = traceJ.ref(tmpI, maxK);
203                 ++pt;
204                 stackI[pt] = traceI.ref(maxK+1, tmpJ);
205                 stackJ[pt] = traceJ.ref(maxK+1, tmpJ);
206                 ++pt;
207             }
208         }
209     }
210 }
211 }