fixed so that alignment percentage ID is used for distance calculation
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 public class NJTree {\r
9 \r
10   Vector cluster;\r
11   SequenceI[] sequence;\r
12 \r
13   int done[];\r
14   int noseqs;\r
15   int noClus;\r
16 \r
17   float distance[][];\r
18 \r
19   int mini;\r
20   int minj;\r
21   float ri;\r
22   float rj;\r
23 \r
24   Vector groups = new Vector();\r
25   SequenceNode maxdist;\r
26   SequenceNode top;\r
27 \r
28   float maxDistValue;\r
29   float maxheight;\r
30 \r
31   int ycount;\r
32 \r
33   Vector node;\r
34 \r
35   String type;\r
36   String pwtype;\r
37 \r
38   Object found = null;\r
39   Object leaves = null;\r
40 \r
41   int start;\r
42   int end;\r
43 \r
44   public NJTree(SequenceNode node) {\r
45     top = node;\r
46     maxheight = findHeight(top);\r
47   }\r
48 \r
49   public NJTree(SequenceI[] sequence,int start, int end) {\r
50     this(sequence,"NJ","BL",start,end);\r
51   }\r
52 \r
53   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
54 \r
55     this.sequence = sequence;\r
56     this.node     = new Vector();\r
57     this.type     = type;\r
58     this.pwtype   = pwtype;\r
59     this.start    = start;\r
60     this.end      = end;\r
61 \r
62     if (!(type.equals("NJ"))) {\r
63       type = "AV";\r
64     }\r
65 \r
66     if (!(pwtype.equals("PID"))) {\r
67       type = "BL";\r
68     }\r
69 \r
70     int i=0;\r
71 \r
72     done = new int[sequence.length];\r
73 \r
74 \r
75     while (i < sequence.length  && sequence[i] != null) {\r
76       done[i] = 0;\r
77       i++;\r
78     }\r
79 \r
80     noseqs = i++;\r
81 \r
82     distance = findDistances();\r
83 \r
84     makeLeaves();\r
85 \r
86     noClus = cluster.size();\r
87 \r
88     cluster();\r
89 \r
90   }\r
91 \r
92 \r
93   public void cluster() {\r
94 \r
95     while (noClus > 2) {\r
96       if (type.equals("NJ")) {\r
97         float mind = findMinNJDistance();\r
98       } else {\r
99         float mind = findMinDistance();\r
100       }\r
101 \r
102       Cluster c = joinClusters(mini,minj);\r
103 \r
104 \r
105       done[minj] = 1;\r
106 \r
107       cluster.setElementAt(null,minj);\r
108       cluster.setElementAt(c,mini);\r
109 \r
110       noClus--;\r
111     }\r
112 \r
113     boolean onefound = false;\r
114 \r
115     int one = -1;\r
116     int two = -1;\r
117 \r
118     for (int i=0; i < noseqs; i++) {\r
119       if (done[i] != 1) {\r
120         if (onefound == false) {\r
121           two = i;\r
122           onefound = true;\r
123         } else {\r
124           one = i;\r
125         }\r
126       }\r
127     }\r
128 \r
129     Cluster c = joinClusters(one,two);\r
130     top = (SequenceNode)(node.elementAt(one));\r
131 \r
132     reCount(top);\r
133     findHeight(top);\r
134     findMaxDist(top);\r
135 \r
136   }\r
137 \r
138   public Cluster joinClusters(int i, int j) {\r
139 \r
140     float dist = distance[i][j];\r
141 \r
142     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
143     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
144 \r
145     int[] value = new int[noi + noj];\r
146 \r
147     for (int ii = 0; ii < noi;ii++) {\r
148       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
149     }\r
150 \r
151     for (int ii = noi; ii < noi+ noj;ii++) {\r
152       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
153     }\r
154 \r
155     Cluster c = new Cluster(value);\r
156 \r
157     ri = findr(i,j);\r
158     rj = findr(j,i);\r
159 \r
160     if (type.equals("NJ")) {\r
161       findClusterNJDistance(i,j);\r
162     } else {\r
163       findClusterDistance(i,j);\r
164     }\r
165 \r
166     SequenceNode sn = new SequenceNode();\r
167 \r
168     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
169     sn.setRight((SequenceNode)(node.elementAt(j)));\r
170 \r
171     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
172     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
173 \r
174     if (type.equals("NJ")) {\r
175       findNewNJDistances(tmpi,tmpj,dist);\r
176     } else {\r
177       findNewDistances(tmpi,tmpj,dist);\r
178     }\r
179 \r
180     tmpi.setParent(sn);\r
181     tmpj.setParent(sn);\r
182 \r
183     node.setElementAt(sn,i);\r
184     return c;\r
185   }\r
186 \r
187   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
188 \r
189     float ih = 0;\r
190     float jh = 0;\r
191 \r
192     SequenceNode sni = tmpi;\r
193     SequenceNode snj = tmpj;\r
194 \r
195     tmpi.dist = (dist + ri - rj)/2;\r
196     tmpj.dist = (dist - tmpi.dist);\r
197 \r
198     if (tmpi.dist < 0) {\r
199       tmpi.dist = 0;\r
200     }\r
201     if (tmpj.dist < 0) {\r
202       tmpj.dist = 0;\r
203     }\r
204   }\r
205 \r
206   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
207 \r
208     float ih = 0;\r
209     float jh = 0;\r
210 \r
211     SequenceNode sni = tmpi;\r
212     SequenceNode snj = tmpj;\r
213 \r
214     while (sni != null) {\r
215       ih = ih + sni.dist;\r
216       sni = (SequenceNode)sni.left();\r
217     }\r
218 \r
219     while (snj != null) {\r
220       jh = jh + snj.dist;\r
221       snj = (SequenceNode)snj.left();\r
222     }\r
223 \r
224     tmpi.dist = (dist/2 - ih);\r
225     tmpj.dist = (dist/2 - jh);\r
226   }\r
227 \r
228 \r
229 \r
230   public void findClusterDistance(int i, int j) {\r
231 \r
232     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
233     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
234 \r
235     // New distances from cluster to others\r
236     float[] newdist = new float[noseqs];\r
237 \r
238     for (int l = 0; l < noseqs; l++) {\r
239       if ( l != i && l != j) {\r
240         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
241       } else {\r
242         newdist[l] = 0;\r
243       }\r
244     }\r
245 \r
246     for (int ii=0; ii < noseqs;ii++) {\r
247       distance[i][ii] = newdist[ii];\r
248       distance[ii][i] = newdist[ii];\r
249     }\r
250   }\r
251 \r
252   public void findClusterNJDistance(int i, int j) {\r
253 \r
254     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
255     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
256 \r
257     // New distances from cluster to others\r
258     float[] newdist = new float[noseqs];\r
259 \r
260     for (int l = 0; l < noseqs; l++) {\r
261       if ( l != i && l != j) {\r
262         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
263       } else {\r
264         newdist[l] = 0;\r
265       }\r
266     }\r
267 \r
268     for (int ii=0; ii < noseqs;ii++) {\r
269       distance[i][ii] = newdist[ii];\r
270       distance[ii][i] = newdist[ii];\r
271     }\r
272   }\r
273 \r
274   public float findr(int i, int j) {\r
275 \r
276     float tmp = 1;\r
277     for (int k=0; k < noseqs;k++) {\r
278       if (k!= i && k!= j && done[k] != 1) {\r
279         tmp = tmp + distance[i][k];\r
280       }\r
281     }\r
282 \r
283     if (noClus > 2) {\r
284       tmp = tmp/(noClus - 2);\r
285     }\r
286 \r
287     return tmp;\r
288   }\r
289 \r
290   public float findMinNJDistance() {\r
291 \r
292     float min = 100000;\r
293 \r
294     for (int i=0; i < noseqs-1; i++) {\r
295       for (int j=i+1;j < noseqs;j++) {\r
296         if (done[i] != 1 && done[j] != 1) {\r
297           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
298           if (tmp < min) {\r
299 \r
300             mini = i;\r
301             minj = j;\r
302 \r
303             min = tmp;\r
304 \r
305           }\r
306         }\r
307       }\r
308     }\r
309     return min;\r
310   }\r
311 \r
312   public float findMinDistance() {\r
313 \r
314     float min = 100000;\r
315 \r
316     for (int i=0; i < noseqs-1;i++) {\r
317       for (int j = i+1; j < noseqs;j++) {\r
318         if (done[i] != 1 && done[j] != 1) {\r
319           if (distance[i][j] < min) {\r
320             mini = i;\r
321             minj = j;\r
322 \r
323             min = distance[i][j];\r
324           }\r
325         }\r
326       }\r
327     }\r
328     return min;\r
329   }\r
330 \r
331   public float[][] findDistances() {\r
332 \r
333     float[][] distance = new float[noseqs][noseqs];\r
334     if (pwtype.equals("PID")) {\r
335       for (int i = 0; i < noseqs-1; i++) {\r
336         for (int j = i; j < noseqs; j++) {\r
337           if (j==i) {\r
338             distance[i][i] = 0;\r
339           } else {\r
340             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
341             distance[j][i] = distance[i][j];\r
342           }\r
343         }\r
344       }\r
345     } else if (pwtype.equals("BL")) {\r
346       int   maxscore = 0;\r
347 \r
348       for (int i = 0; i < noseqs-1; i++) {\r
349         for (int j = i; j < noseqs; j++) {\r
350           int score = 0;\r
351           for (int k=0; k < sequence[i].getLength(); k++) {\r
352             try{\r
353               score +=\r
354                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
355                   k + 1),\r
356                                                 sequence[j].getSequence(k,\r
357                   k + 1));\r
358             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
359           }\r
360           distance[i][j] = (float)score;\r
361           if (score > maxscore) {\r
362             maxscore = score;\r
363           }\r
364         }\r
365       }\r
366       for (int i = 0; i < noseqs-1; i++) {\r
367         for (int j = i; j < noseqs; j++) {\r
368           distance[i][j] =  (float)maxscore - distance[i][j];\r
369           distance[j][i] = distance[i][j];\r
370         }\r
371       }\r
372     } else if (pwtype.equals("SW")) {\r
373       float max = -1;\r
374       for (int i = 0; i < noseqs-1; i++) {\r
375         for (int j = i; j < noseqs; j++) {\r
376           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
377           as.calcScoreMatrix();\r
378           as.traceAlignment();\r
379           as.printAlignment();\r
380           distance[i][j] = (float)as.maxscore;\r
381           if (max < distance[i][j]) {\r
382             max = distance[i][j];\r
383           }\r
384         }\r
385       }\r
386       for (int i = 0; i < noseqs-1; i++) {\r
387         for (int j = i; j < noseqs; j++) {\r
388           distance[i][j] =  max - distance[i][j];\r
389           distance[j][i] = distance[i][j];\r
390         }\r
391       }\r
392     }\r
393 \r
394     return distance;\r
395   }\r
396 \r
397   public void makeLeaves() {\r
398     cluster = new Vector();\r
399 \r
400     for (int i=0; i < noseqs; i++) {\r
401       SequenceNode sn = new SequenceNode();\r
402 \r
403       sn.setElement(sequence[i]);\r
404       sn.setName(sequence[i].getName());\r
405       node.addElement(sn);\r
406 \r
407       int[] value = new int[1];\r
408       value[0] = i;\r
409 \r
410       Cluster c = new Cluster(value);\r
411       cluster.addElement(c);\r
412     }\r
413   }\r
414 \r
415   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
416     if (node == null) {\r
417       return leaves;\r
418     }\r
419 \r
420     if (node.left() == null && node.right() == null) {\r
421       leaves.addElement(node);\r
422       return leaves;\r
423     } else {\r
424       findLeaves((SequenceNode)node.left(),leaves);\r
425       findLeaves((SequenceNode)node.right(),leaves);\r
426     }\r
427     return leaves;\r
428   }\r
429 \r
430   public Object findLeaf(SequenceNode node, int count) {\r
431     found = _findLeaf(node,count);\r
432 \r
433     return found;\r
434   }\r
435   public Object _findLeaf(SequenceNode node,int count) {\r
436     if (node == null) {\r
437       return null;\r
438     }\r
439     if (node.ycount == count) {\r
440       found = node.element();\r
441       return found;\r
442     } else {\r
443       _findLeaf((SequenceNode)node.left(),count);\r
444       _findLeaf((SequenceNode)node.right(),count);\r
445     }\r
446 \r
447     return found;\r
448   }\r
449 \r
450   public void printNode(SequenceNode node) {\r
451     if (node == null) {\r
452       return;\r
453     }\r
454     if (node.left() == null && node.right() == null) {\r
455       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
456       System.out.println("Dist " + ((SequenceNode)node).dist);\r
457       System.out.println("Boot " + node.getBootstrap());\r
458     } else {\r
459       System.out.println("Dist " + ((SequenceNode)node).dist);\r
460       printNode((SequenceNode)node.left());\r
461       printNode((SequenceNode)node.right());\r
462     }\r
463   }\r
464   public void findMaxDist(SequenceNode node) {\r
465     if (node == null) {\r
466       return;\r
467     }\r
468     if (node.left() == null && node.right() == null) {\r
469 \r
470       float dist = ((SequenceNode)node).dist;\r
471       if (dist > maxDistValue) {\r
472           maxdist      = (SequenceNode)node;\r
473           maxDistValue = dist;\r
474       }\r
475     } else {\r
476       findMaxDist((SequenceNode)node.left());\r
477       findMaxDist((SequenceNode)node.right());\r
478     }\r
479   }\r
480     public Vector getGroups() {\r
481         return groups;\r
482     }\r
483     public float getMaxHeight() {\r
484         return maxheight;\r
485     }\r
486   public void  groupNodes(SequenceNode node, float threshold) {\r
487     if (node == null) {\r
488       return;\r
489     }\r
490 \r
491     if (node.height/maxheight > threshold) {\r
492       groups.addElement(node);\r
493     } else {\r
494       groupNodes((SequenceNode)node.left(),threshold);\r
495       groupNodes((SequenceNode)node.right(),threshold);\r
496     }\r
497   }\r
498 \r
499   public float findHeight(SequenceNode node) {\r
500 \r
501     if (node == null) {\r
502       return maxheight;\r
503     }\r
504 \r
505     if (node.left() == null && node.right() == null) {\r
506       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
507 \r
508       if (node.height > maxheight) {\r
509         return node.height;\r
510       } else {\r
511         return maxheight;\r
512       }\r
513     } else {\r
514       if (node.parent() != null) {\r
515         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
516       } else {\r
517         maxheight = 0;\r
518         node.height = (float)0.0;\r
519       }\r
520 \r
521       maxheight = findHeight((SequenceNode)(node.left()));\r
522       maxheight = findHeight((SequenceNode)(node.right()));\r
523     }\r
524     return maxheight;\r
525   }\r
526   public SequenceNode reRoot() {\r
527     if (maxdist != null) {\r
528       ycount = 0;\r
529       float tmpdist = maxdist.dist;\r
530 \r
531       // New top\r
532       SequenceNode sn = new SequenceNode();\r
533       sn.setParent(null);\r
534 \r
535       // New right hand of top\r
536       SequenceNode snr = (SequenceNode)maxdist.parent();\r
537       changeDirection(snr,maxdist);\r
538       System.out.println("Printing reversed tree");\r
539       printN(snr);\r
540       snr.dist = tmpdist/2;\r
541       maxdist.dist = tmpdist/2;\r
542 \r
543       snr.setParent(sn);\r
544       maxdist.setParent(sn);\r
545 \r
546       sn.setRight(snr);\r
547       sn.setLeft(maxdist);\r
548 \r
549       top = sn;\r
550 \r
551       ycount = 0;\r
552       reCount(top);\r
553       findHeight(top);\r
554 \r
555     }\r
556     return top;\r
557   }\r
558   public static void printN(SequenceNode node) {\r
559     if (node == null) {\r
560       return;\r
561     }\r
562 \r
563     if (node.left() != null && node.right() != null) {\r
564       printN((SequenceNode)node.left());\r
565       printN((SequenceNode)node.right());\r
566     } else {\r
567       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
568     }\r
569     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
570   }\r
571 \r
572     public void reCount(SequenceNode node) {\r
573         ycount = 0;\r
574         _reCount(node);\r
575     }\r
576   public void _reCount(SequenceNode node) {\r
577     if (node == null) {\r
578       return;\r
579     }\r
580 \r
581     if (node.left() != null && node.right() != null) {\r
582       _reCount((SequenceNode)node.left());\r
583       _reCount((SequenceNode)node.right());\r
584 \r
585       SequenceNode l = (SequenceNode)node.left();\r
586       SequenceNode r = (SequenceNode)node.right();\r
587 \r
588       ((SequenceNode)node).count  = l.count + r.count;\r
589       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
590 \r
591     } else {\r
592       ((SequenceNode)node).count = 1;\r
593       ((SequenceNode)node).ycount = ycount++;\r
594     }\r
595 \r
596   }\r
597     public void swapNodes(SequenceNode node) {\r
598         if (node == null) {\r
599             return;\r
600         }\r
601         SequenceNode tmp = (SequenceNode)node.left();\r
602 \r
603         node.setLeft(node.right());\r
604         node.setRight(tmp);\r
605     }\r
606   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
607     if (node == null) {\r
608       return;\r
609     }\r
610     if (node.parent() != top) {\r
611       changeDirection((SequenceNode)node.parent(), node);\r
612 \r
613       SequenceNode tmp = (SequenceNode)node.parent();\r
614 \r
615       if (dir == node.left()) {\r
616         node.setParent(dir);\r
617         node.setLeft(tmp);\r
618       } else if (dir == node.right()) {\r
619         node.setParent(dir);\r
620         node.setRight(tmp);\r
621       }\r
622 \r
623     } else {\r
624       if (dir == node.left()) {\r
625         node.setParent(node.left());\r
626 \r
627         if (top.left() == node) {\r
628           node.setRight(top.right());\r
629         } else {\r
630           node.setRight(top.left());\r
631         }\r
632       } else {\r
633         node.setParent(node.right());\r
634 \r
635         if (top.left() == node) {\r
636           node.setLeft(top.right());\r
637         } else {\r
638           node.setLeft(top.left());\r
639         }\r
640       }\r
641     }\r
642   }\r
643     public void setMaxDist(SequenceNode node) {\r
644         this.maxdist = maxdist;\r
645     }\r
646     public SequenceNode getMaxDist() {\r
647         return maxdist;\r
648     }\r
649     public SequenceNode getTopNode() {\r
650         return top;\r
651     }\r
652 \r
653 }\r
654 \r
655 \r
656 \r
657 class Cluster {\r
658 \r
659   int[] value;\r
660 \r
661   public Cluster(int[] value) {\r
662     this.value = value;\r
663   }\r
664 \r
665 }\r
666 \r