Moved tree leafname to alignment set matching into SequenceIdMatcher class for
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 import jalview.io.NewickFile;\r
9 \r
10 public class NJTree {\r
11 \r
12   Vector cluster;\r
13   SequenceI[] sequence;\r
14 \r
15   int done[];\r
16   int noseqs;\r
17   int noClus;\r
18 \r
19   float distance[][];\r
20 \r
21   int mini;\r
22   int minj;\r
23   float ri;\r
24   float rj;\r
25 \r
26   Vector groups = new Vector();\r
27   SequenceNode maxdist;\r
28   SequenceNode top;\r
29 \r
30   float maxDistValue;\r
31   float maxheight;\r
32 \r
33   int ycount;\r
34 \r
35   Vector node;\r
36 \r
37   String type;\r
38   String pwtype;\r
39 \r
40   Object found = null;\r
41   Object leaves = null;\r
42 \r
43   int start;\r
44   int end;\r
45 \r
46   public NJTree(SequenceNode node) {\r
47     top = node;\r
48     maxheight = findHeight(top);\r
49 \r
50   }\r
51 \r
52   public NJTree(SequenceI[] seqs, NewickFile treefile) {\r
53     top = treefile.getTree();\r
54     maxheight = findHeight(top);\r
55     SequenceIdMatcher algnIds = new SequenceIdMatcher(seqs);\r
56 \r
57     Vector leaves = new Vector();\r
58     findLeaves(top, leaves);\r
59 \r
60     int i = 0;\r
61     int namesleft = seqs.length;\r
62 \r
63     SequenceNode j;\r
64     SequenceI nam;\r
65     String realnam;\r
66     while (i < leaves.size())\r
67     {\r
68       j = (SequenceNode) leaves.elementAt(i++);\r
69       realnam = j.getName();\r
70       nam = null;\r
71       if (namesleft>-1)\r
72         nam = algnIds.findIdMatch(realnam);\r
73       if (nam != null) {\r
74         j.setElement(nam);\r
75         namesleft--;\r
76       } else {\r
77         j.setElement(new Sequence(realnam, "THISISAPLACEHLDER"));\r
78       }\r
79     }\r
80   }\r
81 \r
82   public NJTree(SequenceI[] sequence,int start, int end) {\r
83     this(sequence,"NJ","BL",start,end);\r
84   }\r
85 \r
86   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
87 \r
88     this.sequence = sequence;\r
89     this.node     = new Vector();\r
90     this.type     = type;\r
91     this.pwtype   = pwtype;\r
92     this.start    = start;\r
93     this.end      = end;\r
94 \r
95     if (!(type.equals("NJ"))) {\r
96       type = "AV";\r
97     }\r
98 \r
99     if (!(pwtype.equals("PID"))) {\r
100       type = "BL";\r
101     }\r
102 \r
103     int i=0;\r
104 \r
105     done = new int[sequence.length];\r
106 \r
107 \r
108     while (i < sequence.length  && sequence[i] != null) {\r
109       done[i] = 0;\r
110       i++;\r
111     }\r
112 \r
113     noseqs = i++;\r
114 \r
115     distance = findDistances();\r
116 \r
117     makeLeaves();\r
118 \r
119     noClus = cluster.size();\r
120 \r
121     cluster();\r
122 \r
123   }\r
124 \r
125 \r
126   public void cluster() {\r
127 \r
128     while (noClus > 2) {\r
129       if (type.equals("NJ")) {\r
130         float mind = findMinNJDistance();\r
131       } else {\r
132         float mind = findMinDistance();\r
133       }\r
134 \r
135       Cluster c = joinClusters(mini,minj);\r
136 \r
137 \r
138       done[minj] = 1;\r
139 \r
140       cluster.setElementAt(null,minj);\r
141       cluster.setElementAt(c,mini);\r
142 \r
143       noClus--;\r
144     }\r
145 \r
146     boolean onefound = false;\r
147 \r
148     int one = -1;\r
149     int two = -1;\r
150 \r
151     for (int i=0; i < noseqs; i++) {\r
152       if (done[i] != 1) {\r
153         if (onefound == false) {\r
154           two = i;\r
155           onefound = true;\r
156         } else {\r
157           one = i;\r
158         }\r
159       }\r
160     }\r
161 \r
162     Cluster c = joinClusters(one,two);\r
163     top = (SequenceNode)(node.elementAt(one));\r
164 \r
165     reCount(top);\r
166     findHeight(top);\r
167     findMaxDist(top);\r
168 \r
169   }\r
170 \r
171   public Cluster joinClusters(int i, int j) {\r
172 \r
173     float dist = distance[i][j];\r
174 \r
175     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
176     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
177 \r
178     int[] value = new int[noi + noj];\r
179 \r
180     for (int ii = 0; ii < noi;ii++) {\r
181       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
182     }\r
183 \r
184     for (int ii = noi; ii < noi+ noj;ii++) {\r
185       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
186     }\r
187 \r
188     Cluster c = new Cluster(value);\r
189 \r
190     ri = findr(i,j);\r
191     rj = findr(j,i);\r
192 \r
193     if (type.equals("NJ")) {\r
194       findClusterNJDistance(i,j);\r
195     } else {\r
196       findClusterDistance(i,j);\r
197     }\r
198 \r
199     SequenceNode sn = new SequenceNode();\r
200 \r
201     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
202     sn.setRight((SequenceNode)(node.elementAt(j)));\r
203 \r
204     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
205     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
206 \r
207     if (type.equals("NJ")) {\r
208       findNewNJDistances(tmpi,tmpj,dist);\r
209     } else {\r
210       findNewDistances(tmpi,tmpj,dist);\r
211     }\r
212 \r
213     tmpi.setParent(sn);\r
214     tmpj.setParent(sn);\r
215 \r
216     node.setElementAt(sn,i);\r
217     return c;\r
218   }\r
219 \r
220   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
221 \r
222     float ih = 0;\r
223     float jh = 0;\r
224 \r
225     SequenceNode sni = tmpi;\r
226     SequenceNode snj = tmpj;\r
227 \r
228     tmpi.dist = (dist + ri - rj)/2;\r
229     tmpj.dist = (dist - tmpi.dist);\r
230 \r
231     if (tmpi.dist < 0) {\r
232       tmpi.dist = 0;\r
233     }\r
234     if (tmpj.dist < 0) {\r
235       tmpj.dist = 0;\r
236     }\r
237   }\r
238 \r
239   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
240 \r
241     float ih = 0;\r
242     float jh = 0;\r
243 \r
244     SequenceNode sni = tmpi;\r
245     SequenceNode snj = tmpj;\r
246 \r
247     while (sni != null) {\r
248       ih = ih + sni.dist;\r
249       sni = (SequenceNode)sni.left();\r
250     }\r
251 \r
252     while (snj != null) {\r
253       jh = jh + snj.dist;\r
254       snj = (SequenceNode)snj.left();\r
255     }\r
256 \r
257     tmpi.dist = (dist/2 - ih);\r
258     tmpj.dist = (dist/2 - jh);\r
259   }\r
260 \r
261 \r
262 \r
263   public void findClusterDistance(int i, int j) {\r
264 \r
265     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
266     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
267 \r
268     // New distances from cluster to others\r
269     float[] newdist = new float[noseqs];\r
270 \r
271     for (int l = 0; l < noseqs; l++) {\r
272       if ( l != i && l != j) {\r
273         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
274       } else {\r
275         newdist[l] = 0;\r
276       }\r
277     }\r
278 \r
279     for (int ii=0; ii < noseqs;ii++) {\r
280       distance[i][ii] = newdist[ii];\r
281       distance[ii][i] = newdist[ii];\r
282     }\r
283   }\r
284 \r
285   public void findClusterNJDistance(int i, int j) {\r
286 \r
287     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
288     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
289 \r
290     // New distances from cluster to others\r
291     float[] newdist = new float[noseqs];\r
292 \r
293     for (int l = 0; l < noseqs; l++) {\r
294       if ( l != i && l != j) {\r
295         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
296       } else {\r
297         newdist[l] = 0;\r
298       }\r
299     }\r
300 \r
301     for (int ii=0; ii < noseqs;ii++) {\r
302       distance[i][ii] = newdist[ii];\r
303       distance[ii][i] = newdist[ii];\r
304     }\r
305   }\r
306 \r
307   public float findr(int i, int j) {\r
308 \r
309     float tmp = 1;\r
310     for (int k=0; k < noseqs;k++) {\r
311       if (k!= i && k!= j && done[k] != 1) {\r
312         tmp = tmp + distance[i][k];\r
313       }\r
314     }\r
315 \r
316     if (noClus > 2) {\r
317       tmp = tmp/(noClus - 2);\r
318     }\r
319 \r
320     return tmp;\r
321   }\r
322 \r
323   public float findMinNJDistance() {\r
324 \r
325     float min = 100000;\r
326 \r
327     for (int i=0; i < noseqs-1; i++) {\r
328       for (int j=i+1;j < noseqs;j++) {\r
329         if (done[i] != 1 && done[j] != 1) {\r
330           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
331           if (tmp < min) {\r
332 \r
333             mini = i;\r
334             minj = j;\r
335 \r
336             min = tmp;\r
337 \r
338           }\r
339         }\r
340       }\r
341     }\r
342     return min;\r
343   }\r
344 \r
345   public float findMinDistance() {\r
346 \r
347     float min = 100000;\r
348 \r
349     for (int i=0; i < noseqs-1;i++) {\r
350       for (int j = i+1; j < noseqs;j++) {\r
351         if (done[i] != 1 && done[j] != 1) {\r
352           if (distance[i][j] < min) {\r
353             mini = i;\r
354             minj = j;\r
355 \r
356             min = distance[i][j];\r
357           }\r
358         }\r
359       }\r
360     }\r
361     return min;\r
362   }\r
363 \r
364   public float[][] findDistances() {\r
365 \r
366     float[][] distance = new float[noseqs][noseqs];\r
367     if (pwtype.equals("PID")) {\r
368       for (int i = 0; i < noseqs-1; i++) {\r
369         for (int j = i; j < noseqs; j++) {\r
370           if (j==i) {\r
371             distance[i][i] = 0;\r
372           } else {\r
373             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
374             distance[j][i] = distance[i][j];\r
375           }\r
376         }\r
377       }\r
378     } else if (pwtype.equals("BL")) {\r
379       int   maxscore = 0;\r
380 \r
381       for (int i = 0; i < noseqs-1; i++) {\r
382         for (int j = i; j < noseqs; j++) {\r
383           int score = 0;\r
384           for (int k=0; k < sequence[i].getLength(); k++) {\r
385             try{\r
386               score +=\r
387                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
388                   k + 1),\r
389                                                 sequence[j].getSequence(k,\r
390                   k + 1));\r
391             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
392           }\r
393           distance[i][j] = (float)score;\r
394           if (score > maxscore) {\r
395             maxscore = score;\r
396           }\r
397         }\r
398       }\r
399       for (int i = 0; i < noseqs-1; i++) {\r
400         for (int j = i; j < noseqs; j++) {\r
401           distance[i][j] =  (float)maxscore - distance[i][j];\r
402           distance[j][i] = distance[i][j];\r
403         }\r
404       }\r
405     } else if (pwtype.equals("SW")) {\r
406       float max = -1;\r
407       for (int i = 0; i < noseqs-1; i++) {\r
408         for (int j = i; j < noseqs; j++) {\r
409           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
410           as.calcScoreMatrix();\r
411           as.traceAlignment();\r
412           as.printAlignment();\r
413           distance[i][j] = (float)as.maxscore;\r
414           if (max < distance[i][j]) {\r
415             max = distance[i][j];\r
416           }\r
417         }\r
418       }\r
419       for (int i = 0; i < noseqs-1; i++) {\r
420         for (int j = i; j < noseqs; j++) {\r
421           distance[i][j] =  max - distance[i][j];\r
422           distance[j][i] = distance[i][j];\r
423         }\r
424       }\r
425     }\r
426 \r
427     return distance;\r
428   }\r
429 \r
430   public void makeLeaves() {\r
431     cluster = new Vector();\r
432 \r
433     for (int i=0; i < noseqs; i++) {\r
434       SequenceNode sn = new SequenceNode();\r
435 \r
436       sn.setElement(sequence[i]);\r
437       sn.setName(sequence[i].getName());\r
438       node.addElement(sn);\r
439 \r
440       int[] value = new int[1];\r
441       value[0] = i;\r
442 \r
443       Cluster c = new Cluster(value);\r
444       cluster.addElement(c);\r
445     }\r
446   }\r
447 \r
448   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
449     if (node == null) {\r
450       return leaves;\r
451     }\r
452 \r
453     if (node.left() == null && node.right() == null) {\r
454       leaves.addElement(node);\r
455       return leaves;\r
456     } else {\r
457       findLeaves((SequenceNode)node.left(),leaves);\r
458       findLeaves((SequenceNode)node.right(),leaves);\r
459     }\r
460     return leaves;\r
461   }\r
462 \r
463   public Object findLeaf(SequenceNode node, int count) {\r
464     found = _findLeaf(node,count);\r
465 \r
466     return found;\r
467   }\r
468   public Object _findLeaf(SequenceNode node,int count) {\r
469     if (node == null) {\r
470       return null;\r
471     }\r
472     if (node.ycount == count) {\r
473       found = node.element();\r
474       return found;\r
475     } else {\r
476       _findLeaf((SequenceNode)node.left(),count);\r
477       _findLeaf((SequenceNode)node.right(),count);\r
478     }\r
479 \r
480     return found;\r
481   }\r
482 \r
483   public void printNode(SequenceNode node) {\r
484     if (node == null) {\r
485       return;\r
486     }\r
487     if (node.left() == null && node.right() == null) {\r
488       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
489       System.out.println("Dist " + ((SequenceNode)node).dist);\r
490       System.out.println("Boot " + node.getBootstrap());\r
491     } else {\r
492       System.out.println("Dist " + ((SequenceNode)node).dist);\r
493       printNode((SequenceNode)node.left());\r
494       printNode((SequenceNode)node.right());\r
495     }\r
496   }\r
497   public void findMaxDist(SequenceNode node) {\r
498     if (node == null) {\r
499       return;\r
500     }\r
501     if (node.left() == null && node.right() == null) {\r
502 \r
503       float dist = ((SequenceNode)node).dist;\r
504       if (dist > maxDistValue) {\r
505           maxdist      = (SequenceNode)node;\r
506           maxDistValue = dist;\r
507       }\r
508     } else {\r
509       findMaxDist((SequenceNode)node.left());\r
510       findMaxDist((SequenceNode)node.right());\r
511     }\r
512   }\r
513     public Vector getGroups() {\r
514         return groups;\r
515     }\r
516     public float getMaxHeight() {\r
517         return maxheight;\r
518     }\r
519   public void  groupNodes(SequenceNode node, float threshold) {\r
520     if (node == null) {\r
521       return;\r
522     }\r
523 \r
524     if (node.height/maxheight > threshold) {\r
525       groups.addElement(node);\r
526     } else {\r
527       groupNodes((SequenceNode)node.left(),threshold);\r
528       groupNodes((SequenceNode)node.right(),threshold);\r
529     }\r
530   }\r
531 \r
532   public float findHeight(SequenceNode node) {\r
533 \r
534     if (node == null) {\r
535       return maxheight;\r
536     }\r
537 \r
538     if (node.left() == null && node.right() == null) {\r
539       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
540 \r
541       if (node.height > maxheight) {\r
542         return node.height;\r
543       } else {\r
544         return maxheight;\r
545       }\r
546     } else {\r
547       if (node.parent() != null) {\r
548         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
549       } else {\r
550         maxheight = 0;\r
551         node.height = (float)0.0;\r
552       }\r
553 \r
554       maxheight = findHeight((SequenceNode)(node.left()));\r
555       maxheight = findHeight((SequenceNode)(node.right()));\r
556     }\r
557     return maxheight;\r
558   }\r
559   public SequenceNode reRoot() {\r
560     if (maxdist != null) {\r
561       ycount = 0;\r
562       float tmpdist = maxdist.dist;\r
563 \r
564       // New top\r
565       SequenceNode sn = new SequenceNode();\r
566       sn.setParent(null);\r
567 \r
568       // New right hand of top\r
569       SequenceNode snr = (SequenceNode)maxdist.parent();\r
570       changeDirection(snr,maxdist);\r
571       System.out.println("Printing reversed tree");\r
572       printN(snr);\r
573       snr.dist = tmpdist/2;\r
574       maxdist.dist = tmpdist/2;\r
575 \r
576       snr.setParent(sn);\r
577       maxdist.setParent(sn);\r
578 \r
579       sn.setRight(snr);\r
580       sn.setLeft(maxdist);\r
581 \r
582       top = sn;\r
583 \r
584       ycount = 0;\r
585       reCount(top);\r
586       findHeight(top);\r
587 \r
588     }\r
589     return top;\r
590   }\r
591   public static void printN(SequenceNode node) {\r
592     if (node == null) {\r
593       return;\r
594     }\r
595 \r
596     if (node.left() != null && node.right() != null) {\r
597       printN((SequenceNode)node.left());\r
598       printN((SequenceNode)node.right());\r
599     } else {\r
600       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
601     }\r
602     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
603   }\r
604 \r
605     public void reCount(SequenceNode node) {\r
606         ycount = 0;\r
607         _reCount(node);\r
608     }\r
609   public void _reCount(SequenceNode node) {\r
610     if (node == null) {\r
611       return;\r
612     }\r
613 \r
614     if (node.left() != null && node.right() != null) {\r
615       _reCount((SequenceNode)node.left());\r
616       _reCount((SequenceNode)node.right());\r
617 \r
618       SequenceNode l = (SequenceNode)node.left();\r
619       SequenceNode r = (SequenceNode)node.right();\r
620 \r
621       ((SequenceNode)node).count  = l.count + r.count;\r
622       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
623 \r
624     } else {\r
625       ((SequenceNode)node).count = 1;\r
626       ((SequenceNode)node).ycount = ycount++;\r
627     }\r
628 \r
629   }\r
630     public void swapNodes(SequenceNode node) {\r
631         if (node == null) {\r
632             return;\r
633         }\r
634         SequenceNode tmp = (SequenceNode)node.left();\r
635 \r
636         node.setLeft(node.right());\r
637         node.setRight(tmp);\r
638     }\r
639   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
640     if (node == null) {\r
641       return;\r
642     }\r
643     if (node.parent() != top) {\r
644       changeDirection((SequenceNode)node.parent(), node);\r
645 \r
646       SequenceNode tmp = (SequenceNode)node.parent();\r
647 \r
648       if (dir == node.left()) {\r
649         node.setParent(dir);\r
650         node.setLeft(tmp);\r
651       } else if (dir == node.right()) {\r
652         node.setParent(dir);\r
653         node.setRight(tmp);\r
654       }\r
655 \r
656     } else {\r
657       if (dir == node.left()) {\r
658         node.setParent(node.left());\r
659 \r
660         if (top.left() == node) {\r
661           node.setRight(top.right());\r
662         } else {\r
663           node.setRight(top.left());\r
664         }\r
665       } else {\r
666         node.setParent(node.right());\r
667 \r
668         if (top.left() == node) {\r
669           node.setLeft(top.right());\r
670         } else {\r
671           node.setLeft(top.left());\r
672         }\r
673       }\r
674     }\r
675   }\r
676     public void setMaxDist(SequenceNode node) {\r
677         this.maxdist = maxdist;\r
678     }\r
679     public SequenceNode getMaxDist() {\r
680         return maxdist;\r
681     }\r
682     public SequenceNode getTopNode() {\r
683         return top;\r
684     }\r
685 \r
686 }\r
687 \r
688 \r
689 \r
690 class Cluster {\r
691 \r
692   int[] value;\r
693 \r
694   public Cluster(int[] value) {\r
695     this.value = value;\r
696   }\r
697 \r
698 }\r
699 \r