8708f1dca50582cb0697e0f7e25afa4eedcca82f
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 import jalview.io.NewickFile;\r
9 \r
10 public class NJTree {\r
11 \r
12   Vector cluster;\r
13   SequenceI[] sequence;\r
14 \r
15   int done[];\r
16   int noseqs;\r
17   int noClus;\r
18 \r
19   float distance[][];\r
20 \r
21   int mini;\r
22   int minj;\r
23   float ri;\r
24   float rj;\r
25 \r
26   Vector groups = new Vector();\r
27   SequenceNode maxdist;\r
28   SequenceNode top;\r
29 \r
30   float maxDistValue;\r
31   float maxheight;\r
32 \r
33   int ycount;\r
34 \r
35   Vector node;\r
36 \r
37   String type;\r
38   String pwtype;\r
39 \r
40   Object found = null;\r
41   Object leaves = null;\r
42 \r
43   int start;\r
44   int end;\r
45 \r
46   public NJTree(SequenceNode node) {\r
47     top = node;\r
48     maxheight = findHeight(top);\r
49   }\r
50 \r
51   public String toString()\r
52   {\r
53     jalview.io.NewickFile fout = new jalview.io.NewickFile(getTopNode());\r
54     return fout.print(false,true); // distances only\r
55   }\r
56 \r
57   public NJTree(SequenceI[] seqs, NewickFile treefile) {\r
58     top = treefile.getTree();\r
59     maxheight = findHeight(top);\r
60     SequenceIdMatcher algnIds = new SequenceIdMatcher(seqs);\r
61 \r
62     Vector leaves = new Vector();\r
63     findLeaves(top, leaves);\r
64 \r
65     int i = 0;\r
66     int namesleft = seqs.length;\r
67 \r
68     SequenceNode j;\r
69     SequenceI nam;\r
70     String realnam;\r
71     while (i < leaves.size())\r
72     {\r
73       j = (SequenceNode) leaves.elementAt(i++);\r
74       realnam = j.getName();\r
75       nam = null;\r
76       if (namesleft>-1)\r
77         nam = algnIds.findIdMatch(realnam);\r
78       if (nam != null) {\r
79         j.setElement(nam);\r
80         namesleft--;\r
81       } else {\r
82         j.setElement(new Sequence(realnam, "THISISAPLACEHLDER"));\r
83         j.setPlaceholder(true);\r
84 \r
85       }\r
86     }\r
87   }\r
88 \r
89   public NJTree(SequenceI[] sequence,int start, int end) {\r
90     this(sequence,"NJ","BL",start,end);\r
91   }\r
92 \r
93   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
94 \r
95     this.sequence = sequence;\r
96     this.node     = new Vector();\r
97     this.type     = type;\r
98     this.pwtype   = pwtype;\r
99     this.start    = start;\r
100     this.end      = end;\r
101 \r
102     if (!(type.equals("NJ"))) {\r
103       type = "AV";\r
104     }\r
105 \r
106     if (!(pwtype.equals("PID"))) {\r
107       type = "BL";\r
108     }\r
109 \r
110     int i=0;\r
111 \r
112     done = new int[sequence.length];\r
113 \r
114 \r
115     while (i < sequence.length  && sequence[i] != null) {\r
116       done[i] = 0;\r
117       i++;\r
118     }\r
119 \r
120     noseqs = i++;\r
121 \r
122     distance = findDistances();\r
123 \r
124     makeLeaves();\r
125 \r
126     noClus = cluster.size();\r
127 \r
128     cluster();\r
129 \r
130   }\r
131 \r
132 \r
133   public void cluster() {\r
134 \r
135     while (noClus > 2) {\r
136       if (type.equals("NJ")) {\r
137         float mind = findMinNJDistance();\r
138       } else {\r
139         float mind = findMinDistance();\r
140       }\r
141 \r
142       Cluster c = joinClusters(mini,minj);\r
143 \r
144 \r
145       done[minj] = 1;\r
146 \r
147       cluster.setElementAt(null,minj);\r
148       cluster.setElementAt(c,mini);\r
149 \r
150       noClus--;\r
151     }\r
152 \r
153     boolean onefound = false;\r
154 \r
155     int one = -1;\r
156     int two = -1;\r
157 \r
158     for (int i=0; i < noseqs; i++) {\r
159       if (done[i] != 1) {\r
160         if (onefound == false) {\r
161           two = i;\r
162           onefound = true;\r
163         } else {\r
164           one = i;\r
165         }\r
166       }\r
167     }\r
168 \r
169     Cluster c = joinClusters(one,two);\r
170     top = (SequenceNode)(node.elementAt(one));\r
171 \r
172     reCount(top);\r
173     findHeight(top);\r
174     findMaxDist(top);\r
175 \r
176   }\r
177 \r
178   public Cluster joinClusters(int i, int j) {\r
179 \r
180     float dist = distance[i][j];\r
181 \r
182     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
183     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
184 \r
185     int[] value = new int[noi + noj];\r
186 \r
187     for (int ii = 0; ii < noi;ii++) {\r
188       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
189     }\r
190 \r
191     for (int ii = noi; ii < noi+ noj;ii++) {\r
192       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
193     }\r
194 \r
195     Cluster c = new Cluster(value);\r
196 \r
197     ri = findr(i,j);\r
198     rj = findr(j,i);\r
199 \r
200     if (type.equals("NJ")) {\r
201       findClusterNJDistance(i,j);\r
202     } else {\r
203       findClusterDistance(i,j);\r
204     }\r
205 \r
206     SequenceNode sn = new SequenceNode();\r
207 \r
208     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
209     sn.setRight((SequenceNode)(node.elementAt(j)));\r
210 \r
211     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
212     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
213 \r
214     if (type.equals("NJ")) {\r
215       findNewNJDistances(tmpi,tmpj,dist);\r
216     } else {\r
217       findNewDistances(tmpi,tmpj,dist);\r
218     }\r
219 \r
220     tmpi.setParent(sn);\r
221     tmpj.setParent(sn);\r
222 \r
223     node.setElementAt(sn,i);\r
224     return c;\r
225   }\r
226 \r
227   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
228 \r
229     float ih = 0;\r
230     float jh = 0;\r
231 \r
232     SequenceNode sni = tmpi;\r
233     SequenceNode snj = tmpj;\r
234 \r
235     tmpi.dist = (dist + ri - rj)/2;\r
236     tmpj.dist = (dist - tmpi.dist);\r
237 \r
238     if (tmpi.dist < 0) {\r
239       tmpi.dist = 0;\r
240     }\r
241     if (tmpj.dist < 0) {\r
242       tmpj.dist = 0;\r
243     }\r
244   }\r
245 \r
246   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
247 \r
248     float ih = 0;\r
249     float jh = 0;\r
250 \r
251     SequenceNode sni = tmpi;\r
252     SequenceNode snj = tmpj;\r
253 \r
254     while (sni != null) {\r
255       ih = ih + sni.dist;\r
256       sni = (SequenceNode)sni.left();\r
257     }\r
258 \r
259     while (snj != null) {\r
260       jh = jh + snj.dist;\r
261       snj = (SequenceNode)snj.left();\r
262     }\r
263 \r
264     tmpi.dist = (dist/2 - ih);\r
265     tmpj.dist = (dist/2 - jh);\r
266   }\r
267 \r
268 \r
269 \r
270   public void findClusterDistance(int i, int j) {\r
271 \r
272     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
273     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
274 \r
275     // New distances from cluster to others\r
276     float[] newdist = new float[noseqs];\r
277 \r
278     for (int l = 0; l < noseqs; l++) {\r
279       if ( l != i && l != j) {\r
280         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
281       } else {\r
282         newdist[l] = 0;\r
283       }\r
284     }\r
285 \r
286     for (int ii=0; ii < noseqs;ii++) {\r
287       distance[i][ii] = newdist[ii];\r
288       distance[ii][i] = newdist[ii];\r
289     }\r
290   }\r
291 \r
292   public void findClusterNJDistance(int i, int j) {\r
293 \r
294     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
295     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
296 \r
297     // New distances from cluster to others\r
298     float[] newdist = new float[noseqs];\r
299 \r
300     for (int l = 0; l < noseqs; l++) {\r
301       if ( l != i && l != j) {\r
302         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
303       } else {\r
304         newdist[l] = 0;\r
305       }\r
306     }\r
307 \r
308     for (int ii=0; ii < noseqs;ii++) {\r
309       distance[i][ii] = newdist[ii];\r
310       distance[ii][i] = newdist[ii];\r
311     }\r
312   }\r
313 \r
314   public float findr(int i, int j) {\r
315 \r
316     float tmp = 1;\r
317     for (int k=0; k < noseqs;k++) {\r
318       if (k!= i && k!= j && done[k] != 1) {\r
319         tmp = tmp + distance[i][k];\r
320       }\r
321     }\r
322 \r
323     if (noClus > 2) {\r
324       tmp = tmp/(noClus - 2);\r
325     }\r
326 \r
327     return tmp;\r
328   }\r
329 \r
330   public float findMinNJDistance() {\r
331 \r
332     float min = 100000;\r
333 \r
334     for (int i=0; i < noseqs-1; i++) {\r
335       for (int j=i+1;j < noseqs;j++) {\r
336         if (done[i] != 1 && done[j] != 1) {\r
337           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
338           if (tmp < min) {\r
339 \r
340             mini = i;\r
341             minj = j;\r
342 \r
343             min = tmp;\r
344 \r
345           }\r
346         }\r
347       }\r
348     }\r
349     return min;\r
350   }\r
351 \r
352   public float findMinDistance() {\r
353 \r
354     float min = 100000;\r
355 \r
356     for (int i=0; i < noseqs-1;i++) {\r
357       for (int j = i+1; j < noseqs;j++) {\r
358         if (done[i] != 1 && done[j] != 1) {\r
359           if (distance[i][j] < min) {\r
360             mini = i;\r
361             minj = j;\r
362 \r
363             min = distance[i][j];\r
364           }\r
365         }\r
366       }\r
367     }\r
368     return min;\r
369   }\r
370 \r
371   public float[][] findDistances() {\r
372 \r
373     float[][] distance = new float[noseqs][noseqs];\r
374     if (pwtype.equals("PID")) {\r
375       for (int i = 0; i < noseqs-1; i++) {\r
376         for (int j = i; j < noseqs; j++) {\r
377           if (j==i) {\r
378             distance[i][i] = 0;\r
379           } else {\r
380             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
381             distance[j][i] = distance[i][j];\r
382           }\r
383         }\r
384       }\r
385     } else if (pwtype.equals("BL")) {\r
386       int   maxscore = 0;\r
387 \r
388       for (int i = 0; i < noseqs-1; i++) {\r
389         for (int j = i; j < noseqs; j++) {\r
390           int score = 0;\r
391           for (int k=0; k < sequence[i].getLength(); k++) {\r
392             try{\r
393               score +=\r
394                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
395                   k + 1),\r
396                                                 sequence[j].getSequence(k,\r
397                   k + 1));\r
398             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
399           }\r
400           distance[i][j] = (float)score;\r
401           if (score > maxscore) {\r
402             maxscore = score;\r
403           }\r
404         }\r
405       }\r
406       for (int i = 0; i < noseqs-1; i++) {\r
407         for (int j = i; j < noseqs; j++) {\r
408           distance[i][j] =  (float)maxscore - distance[i][j];\r
409           distance[j][i] = distance[i][j];\r
410         }\r
411       }\r
412     } else if (pwtype.equals("SW")) {\r
413       float max = -1;\r
414       for (int i = 0; i < noseqs-1; i++) {\r
415         for (int j = i; j < noseqs; j++) {\r
416           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
417           as.calcScoreMatrix();\r
418           as.traceAlignment();\r
419           as.printAlignment();\r
420           distance[i][j] = (float)as.maxscore;\r
421           if (max < distance[i][j]) {\r
422             max = distance[i][j];\r
423           }\r
424         }\r
425       }\r
426       for (int i = 0; i < noseqs-1; i++) {\r
427         for (int j = i; j < noseqs; j++) {\r
428           distance[i][j] =  max - distance[i][j];\r
429           distance[j][i] = distance[i][j];\r
430         }\r
431       }\r
432     }\r
433 \r
434     return distance;\r
435   }\r
436 \r
437   public void makeLeaves() {\r
438     cluster = new Vector();\r
439 \r
440     for (int i=0; i < noseqs; i++) {\r
441       SequenceNode sn = new SequenceNode();\r
442 \r
443       sn.setElement(sequence[i]);\r
444       sn.setName(sequence[i].getName());\r
445       node.addElement(sn);\r
446 \r
447       int[] value = new int[1];\r
448       value[0] = i;\r
449 \r
450       Cluster c = new Cluster(value);\r
451       cluster.addElement(c);\r
452     }\r
453   }\r
454 \r
455   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
456     if (node == null) {\r
457       return leaves;\r
458     }\r
459 \r
460     if (node.left() == null && node.right() == null) {\r
461       leaves.addElement(node);\r
462       return leaves;\r
463     } else {\r
464       findLeaves((SequenceNode)node.left(),leaves);\r
465       findLeaves((SequenceNode)node.right(),leaves);\r
466     }\r
467     return leaves;\r
468   }\r
469 \r
470   public Object findLeaf(SequenceNode node, int count) {\r
471     found = _findLeaf(node,count);\r
472 \r
473     return found;\r
474   }\r
475   public Object _findLeaf(SequenceNode node,int count) {\r
476     if (node == null) {\r
477       return null;\r
478     }\r
479     if (node.ycount == count) {\r
480       found = node.element();\r
481       return found;\r
482     } else {\r
483       _findLeaf((SequenceNode)node.left(),count);\r
484       _findLeaf((SequenceNode)node.right(),count);\r
485     }\r
486 \r
487     return found;\r
488   }\r
489 \r
490   public void printNode(SequenceNode node) {\r
491     if (node == null) {\r
492       return;\r
493     }\r
494     if (node.left() == null && node.right() == null) {\r
495       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
496       System.out.println("Dist " + ((SequenceNode)node).dist);\r
497       System.out.println("Boot " + node.getBootstrap());\r
498     } else {\r
499       System.out.println("Dist " + ((SequenceNode)node).dist);\r
500       printNode((SequenceNode)node.left());\r
501       printNode((SequenceNode)node.right());\r
502     }\r
503   }\r
504   public void findMaxDist(SequenceNode node) {\r
505     if (node == null) {\r
506       return;\r
507     }\r
508     if (node.left() == null && node.right() == null) {\r
509 \r
510       float dist = ((SequenceNode)node).dist;\r
511       if (dist > maxDistValue) {\r
512           maxdist      = (SequenceNode)node;\r
513           maxDistValue = dist;\r
514       }\r
515     } else {\r
516       findMaxDist((SequenceNode)node.left());\r
517       findMaxDist((SequenceNode)node.right());\r
518     }\r
519   }\r
520     public Vector getGroups() {\r
521         return groups;\r
522     }\r
523     public float getMaxHeight() {\r
524         return maxheight;\r
525     }\r
526   public void  groupNodes(SequenceNode node, float threshold) {\r
527     if (node == null) {\r
528       return;\r
529     }\r
530 \r
531     if (node.height/maxheight > threshold) {\r
532       groups.addElement(node);\r
533     } else {\r
534       groupNodes((SequenceNode)node.left(),threshold);\r
535       groupNodes((SequenceNode)node.right(),threshold);\r
536     }\r
537   }\r
538 \r
539   public float findHeight(SequenceNode node) {\r
540 \r
541     if (node == null) {\r
542       return maxheight;\r
543     }\r
544 \r
545     if (node.left() == null && node.right() == null) {\r
546       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
547 \r
548       if (node.height > maxheight) {\r
549         return node.height;\r
550       } else {\r
551         return maxheight;\r
552       }\r
553     } else {\r
554       if (node.parent() != null) {\r
555         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
556       } else {\r
557         maxheight = 0;\r
558         node.height = (float)0.0;\r
559       }\r
560 \r
561       maxheight = findHeight((SequenceNode)(node.left()));\r
562       maxheight = findHeight((SequenceNode)(node.right()));\r
563     }\r
564     return maxheight;\r
565   }\r
566   public SequenceNode reRoot() {\r
567     if (maxdist != null) {\r
568       ycount = 0;\r
569       float tmpdist = maxdist.dist;\r
570 \r
571       // New top\r
572       SequenceNode sn = new SequenceNode();\r
573       sn.setParent(null);\r
574 \r
575       // New right hand of top\r
576       SequenceNode snr = (SequenceNode)maxdist.parent();\r
577       changeDirection(snr,maxdist);\r
578       System.out.println("Printing reversed tree");\r
579       printN(snr);\r
580       snr.dist = tmpdist/2;\r
581       maxdist.dist = tmpdist/2;\r
582 \r
583       snr.setParent(sn);\r
584       maxdist.setParent(sn);\r
585 \r
586       sn.setRight(snr);\r
587       sn.setLeft(maxdist);\r
588 \r
589       top = sn;\r
590 \r
591       ycount = 0;\r
592       reCount(top);\r
593       findHeight(top);\r
594 \r
595     }\r
596     return top;\r
597   }\r
598   public static void printN(SequenceNode node) {\r
599     if (node == null) {\r
600       return;\r
601     }\r
602 \r
603     if (node.left() != null && node.right() != null) {\r
604       printN((SequenceNode)node.left());\r
605       printN((SequenceNode)node.right());\r
606     } else {\r
607       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
608     }\r
609     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
610   }\r
611 \r
612     public void reCount(SequenceNode node) {\r
613         ycount = 0;\r
614         _reCount(node);\r
615     }\r
616   public void _reCount(SequenceNode node) {\r
617     if (node == null) {\r
618       return;\r
619     }\r
620 \r
621     if (node.left() != null && node.right() != null) {\r
622       _reCount((SequenceNode)node.left());\r
623       _reCount((SequenceNode)node.right());\r
624 \r
625       SequenceNode l = (SequenceNode)node.left();\r
626       SequenceNode r = (SequenceNode)node.right();\r
627 \r
628       ((SequenceNode)node).count  = l.count + r.count;\r
629       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
630 \r
631     } else {\r
632       ((SequenceNode)node).count = 1;\r
633       ((SequenceNode)node).ycount = ycount++;\r
634     }\r
635 \r
636   }\r
637     public void swapNodes(SequenceNode node) {\r
638         if (node == null) {\r
639             return;\r
640         }\r
641         SequenceNode tmp = (SequenceNode)node.left();\r
642 \r
643         node.setLeft(node.right());\r
644         node.setRight(tmp);\r
645     }\r
646   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
647     if (node == null) {\r
648       return;\r
649     }\r
650     if (node.parent() != top) {\r
651       changeDirection((SequenceNode)node.parent(), node);\r
652 \r
653       SequenceNode tmp = (SequenceNode)node.parent();\r
654 \r
655       if (dir == node.left()) {\r
656         node.setParent(dir);\r
657         node.setLeft(tmp);\r
658       } else if (dir == node.right()) {\r
659         node.setParent(dir);\r
660         node.setRight(tmp);\r
661       }\r
662 \r
663     } else {\r
664       if (dir == node.left()) {\r
665         node.setParent(node.left());\r
666 \r
667         if (top.left() == node) {\r
668           node.setRight(top.right());\r
669         } else {\r
670           node.setRight(top.left());\r
671         }\r
672       } else {\r
673         node.setParent(node.right());\r
674 \r
675         if (top.left() == node) {\r
676           node.setLeft(top.right());\r
677         } else {\r
678           node.setLeft(top.left());\r
679         }\r
680       }\r
681     }\r
682   }\r
683     public void setMaxDist(SequenceNode node) {\r
684         this.maxdist = maxdist;\r
685     }\r
686     public SequenceNode getMaxDist() {\r
687         return maxdist;\r
688     }\r
689     public SequenceNode getTopNode() {\r
690         return top;\r
691     }\r
692 \r
693 }\r
694 \r
695 \r
696 \r
697 class Cluster {\r
698 \r
699   int[] value;\r
700 \r
701   public Cluster(int[] value) {\r
702     this.value = value;\r
703   }\r
704 \r
705 }\r
706 \r