tidied up system.out messages and moved many to stderr.
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 import jalview.io.NewickFile;\r
9 \r
10 public class NJTree {\r
11 \r
12   Vector cluster;\r
13   SequenceI[] sequence;\r
14 \r
15   int done[];\r
16   int noseqs;\r
17   int noClus;\r
18 \r
19   float distance[][];\r
20 \r
21   int mini;\r
22   int minj;\r
23   float ri;\r
24   float rj;\r
25 \r
26   Vector groups = new Vector();\r
27   SequenceNode maxdist;\r
28   SequenceNode top;\r
29 \r
30   float maxDistValue;\r
31   float maxheight;\r
32 \r
33   int ycount;\r
34 \r
35   Vector node;\r
36 \r
37   String type;\r
38   String pwtype;\r
39 \r
40   Object found = null;\r
41   Object leaves = null;\r
42 \r
43   int start;\r
44   int end;\r
45 \r
46   public NJTree(SequenceNode node) {\r
47     top = node;\r
48     maxheight = findHeight(top);\r
49   }\r
50 \r
51   public String toString()\r
52   {\r
53     jalview.io.NewickFile fout = new jalview.io.NewickFile(getTopNode());\r
54     return fout.print(false,true); // distances only\r
55   }\r
56 \r
57   public NJTree(SequenceI[] seqs, NewickFile treefile) {\r
58     top = treefile.getTree();\r
59     maxheight = findHeight(top);\r
60     SequenceIdMatcher algnIds = new SequenceIdMatcher(seqs);\r
61 \r
62     Vector leaves = new Vector();\r
63     findLeaves(top, leaves);\r
64 \r
65     int i = 0;\r
66     int namesleft = seqs.length;\r
67 \r
68     SequenceNode j;\r
69     SequenceI nam;\r
70     String realnam;\r
71     while (i < leaves.size())\r
72     {\r
73       j = (SequenceNode) leaves.elementAt(i++);\r
74       realnam = j.getName();\r
75       nam = null;\r
76       if (namesleft>-1)\r
77         nam = algnIds.findIdMatch(realnam);\r
78       if (nam != null) {\r
79         j.setElement(nam);\r
80         namesleft--;\r
81       } else {\r
82         j.setElement(new Sequence(realnam, "THISISAPLACEHLDER"));\r
83         j.setPlaceholder(true);\r
84 \r
85       }\r
86     }\r
87   }\r
88 \r
89   /**\r
90    *\r
91    * used when the alignment associated to a tree has changed.\r
92    *\r
93    * @param alignment Vector\r
94    */\r
95   public void UpdatePlaceHolders(Vector alignment) {\r
96     Vector leaves = new Vector();\r
97     findLeaves(top, leaves);\r
98     int sz = leaves.size();\r
99     SequenceIdMatcher seqmatcher=null;\r
100     int i=0;\r
101     while (i<sz) {\r
102       SequenceNode leaf = (SequenceNode) leaves.elementAt(i++);\r
103       if (alignment.contains(leaf.element()))\r
104         leaf.setPlaceholder(false);\r
105       else {\r
106         if (seqmatcher==null) {\r
107           // Only create this the first time we need it\r
108           SequenceI[] seqs = new SequenceI[alignment.size()];\r
109           for (int j=0; j<seqs.length; j++)\r
110             seqs[j] = (SequenceI) alignment.elementAt(j);\r
111           seqmatcher = new SequenceIdMatcher(seqs);\r
112         }\r
113         SequenceI nam = seqmatcher.findIdMatch(leaf.getName());\r
114         if (nam!=null) {\r
115           leaf.setPlaceholder(false);\r
116           leaf.setElement(nam);\r
117         } else {\r
118           leaf.setPlaceholder(true);\r
119         }\r
120       }\r
121     }\r
122   }\r
123 \r
124   public NJTree(SequenceI[] sequence,int start, int end) {\r
125     this(sequence,"NJ","BL",start,end);\r
126   }\r
127 \r
128   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
129 \r
130     this.sequence = sequence;\r
131     this.node     = new Vector();\r
132     this.type     = type;\r
133     this.pwtype   = pwtype;\r
134     this.start    = start;\r
135     this.end      = end;\r
136 \r
137     if (!(type.equals("NJ"))) {\r
138       type = "AV";\r
139     }\r
140 \r
141     if (!(pwtype.equals("PID"))) {\r
142       type = "BL";\r
143     }\r
144 \r
145     int i=0;\r
146 \r
147     done = new int[sequence.length];\r
148 \r
149 \r
150     while (i < sequence.length  && sequence[i] != null) {\r
151       done[i] = 0;\r
152       i++;\r
153     }\r
154 \r
155     noseqs = i++;\r
156 \r
157     distance = findDistances();\r
158 \r
159     makeLeaves();\r
160 \r
161     noClus = cluster.size();\r
162 \r
163     cluster();\r
164 \r
165   }\r
166 \r
167 \r
168   public void cluster() {\r
169 \r
170     while (noClus > 2) {\r
171       if (type.equals("NJ")) {\r
172         float mind = findMinNJDistance();\r
173       } else {\r
174         float mind = findMinDistance();\r
175       }\r
176 \r
177       Cluster c = joinClusters(mini,minj);\r
178 \r
179 \r
180       done[minj] = 1;\r
181 \r
182       cluster.setElementAt(null,minj);\r
183       cluster.setElementAt(c,mini);\r
184 \r
185       noClus--;\r
186     }\r
187 \r
188     boolean onefound = false;\r
189 \r
190     int one = -1;\r
191     int two = -1;\r
192 \r
193     for (int i=0; i < noseqs; i++) {\r
194       if (done[i] != 1) {\r
195         if (onefound == false) {\r
196           two = i;\r
197           onefound = true;\r
198         } else {\r
199           one = i;\r
200         }\r
201       }\r
202     }\r
203 \r
204     Cluster c = joinClusters(one,two);\r
205     top = (SequenceNode)(node.elementAt(one));\r
206 \r
207     reCount(top);\r
208     findHeight(top);\r
209     findMaxDist(top);\r
210 \r
211   }\r
212 \r
213   public Cluster joinClusters(int i, int j) {\r
214 \r
215     float dist = distance[i][j];\r
216 \r
217     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
218     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
219 \r
220     int[] value = new int[noi + noj];\r
221 \r
222     for (int ii = 0; ii < noi;ii++) {\r
223       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
224     }\r
225 \r
226     for (int ii = noi; ii < noi+ noj;ii++) {\r
227       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
228     }\r
229 \r
230     Cluster c = new Cluster(value);\r
231 \r
232     ri = findr(i,j);\r
233     rj = findr(j,i);\r
234 \r
235     if (type.equals("NJ")) {\r
236       findClusterNJDistance(i,j);\r
237     } else {\r
238       findClusterDistance(i,j);\r
239     }\r
240 \r
241     SequenceNode sn = new SequenceNode();\r
242 \r
243     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
244     sn.setRight((SequenceNode)(node.elementAt(j)));\r
245 \r
246     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
247     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
248 \r
249     if (type.equals("NJ")) {\r
250       findNewNJDistances(tmpi,tmpj,dist);\r
251     } else {\r
252       findNewDistances(tmpi,tmpj,dist);\r
253     }\r
254 \r
255     tmpi.setParent(sn);\r
256     tmpj.setParent(sn);\r
257 \r
258     node.setElementAt(sn,i);\r
259     return c;\r
260   }\r
261 \r
262   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
263 \r
264     float ih = 0;\r
265     float jh = 0;\r
266 \r
267     SequenceNode sni = tmpi;\r
268     SequenceNode snj = tmpj;\r
269 \r
270     tmpi.dist = (dist + ri - rj)/2;\r
271     tmpj.dist = (dist - tmpi.dist);\r
272 \r
273     if (tmpi.dist < 0) {\r
274       tmpi.dist = 0;\r
275     }\r
276     if (tmpj.dist < 0) {\r
277       tmpj.dist = 0;\r
278     }\r
279   }\r
280 \r
281   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
282 \r
283     float ih = 0;\r
284     float jh = 0;\r
285 \r
286     SequenceNode sni = tmpi;\r
287     SequenceNode snj = tmpj;\r
288 \r
289     while (sni != null) {\r
290       ih = ih + sni.dist;\r
291       sni = (SequenceNode)sni.left();\r
292     }\r
293 \r
294     while (snj != null) {\r
295       jh = jh + snj.dist;\r
296       snj = (SequenceNode)snj.left();\r
297     }\r
298 \r
299     tmpi.dist = (dist/2 - ih);\r
300     tmpj.dist = (dist/2 - jh);\r
301   }\r
302 \r
303 \r
304 \r
305   public void findClusterDistance(int i, int j) {\r
306 \r
307     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
308     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
309 \r
310     // New distances from cluster to others\r
311     float[] newdist = new float[noseqs];\r
312 \r
313     for (int l = 0; l < noseqs; l++) {\r
314       if ( l != i && l != j) {\r
315         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
316       } else {\r
317         newdist[l] = 0;\r
318       }\r
319     }\r
320 \r
321     for (int ii=0; ii < noseqs;ii++) {\r
322       distance[i][ii] = newdist[ii];\r
323       distance[ii][i] = newdist[ii];\r
324     }\r
325   }\r
326 \r
327   public void findClusterNJDistance(int i, int j) {\r
328 \r
329     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
330     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
331 \r
332     // New distances from cluster to others\r
333     float[] newdist = new float[noseqs];\r
334 \r
335     for (int l = 0; l < noseqs; l++) {\r
336       if ( l != i && l != j) {\r
337         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
338       } else {\r
339         newdist[l] = 0;\r
340       }\r
341     }\r
342 \r
343     for (int ii=0; ii < noseqs;ii++) {\r
344       distance[i][ii] = newdist[ii];\r
345       distance[ii][i] = newdist[ii];\r
346     }\r
347   }\r
348 \r
349   public float findr(int i, int j) {\r
350 \r
351     float tmp = 1;\r
352     for (int k=0; k < noseqs;k++) {\r
353       if (k!= i && k!= j && done[k] != 1) {\r
354         tmp = tmp + distance[i][k];\r
355       }\r
356     }\r
357 \r
358     if (noClus > 2) {\r
359       tmp = tmp/(noClus - 2);\r
360     }\r
361 \r
362     return tmp;\r
363   }\r
364 \r
365   public float findMinNJDistance() {\r
366 \r
367     float min = 100000;\r
368 \r
369     for (int i=0; i < noseqs-1; i++) {\r
370       for (int j=i+1;j < noseqs;j++) {\r
371         if (done[i] != 1 && done[j] != 1) {\r
372           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
373           if (tmp < min) {\r
374 \r
375             mini = i;\r
376             minj = j;\r
377 \r
378             min = tmp;\r
379 \r
380           }\r
381         }\r
382       }\r
383     }\r
384     return min;\r
385   }\r
386 \r
387   public float findMinDistance() {\r
388 \r
389     float min = 100000;\r
390 \r
391     for (int i=0; i < noseqs-1;i++) {\r
392       for (int j = i+1; j < noseqs;j++) {\r
393         if (done[i] != 1 && done[j] != 1) {\r
394           if (distance[i][j] < min) {\r
395             mini = i;\r
396             minj = j;\r
397 \r
398             min = distance[i][j];\r
399           }\r
400         }\r
401       }\r
402     }\r
403     return min;\r
404   }\r
405 \r
406   public float[][] findDistances() {\r
407 \r
408     float[][] distance = new float[noseqs][noseqs];\r
409     if (pwtype.equals("PID")) {\r
410       for (int i = 0; i < noseqs-1; i++) {\r
411         for (int j = i; j < noseqs; j++) {\r
412           if (j==i) {\r
413             distance[i][i] = 0;\r
414           } else {\r
415             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
416             distance[j][i] = distance[i][j];\r
417           }\r
418         }\r
419       }\r
420     } else if (pwtype.equals("BL")) {\r
421       int   maxscore = 0;\r
422 \r
423       for (int i = 0; i < noseqs-1; i++) {\r
424         for (int j = i; j < noseqs; j++) {\r
425           int score = 0;\r
426           for (int k=0; k < sequence[i].getLength(); k++) {\r
427             try{\r
428               score +=\r
429                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
430                   k + 1),\r
431                                                 sequence[j].getSequence(k,\r
432                   k + 1));\r
433             }catch(Exception ex){System.err.println("err creating BLOSUM62 tree");ex.printStackTrace();}\r
434           }\r
435           distance[i][j] = (float)score;\r
436           if (score > maxscore) {\r
437             maxscore = score;\r
438           }\r
439         }\r
440       }\r
441       for (int i = 0; i < noseqs-1; i++) {\r
442         for (int j = i; j < noseqs; j++) {\r
443           distance[i][j] =  (float)maxscore - distance[i][j];\r
444           distance[j][i] = distance[i][j];\r
445         }\r
446       }\r
447     } else if (pwtype.equals("SW")) {\r
448       float max = -1;\r
449       for (int i = 0; i < noseqs-1; i++) {\r
450         for (int j = i; j < noseqs; j++) {\r
451           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
452           as.calcScoreMatrix();\r
453           as.traceAlignment();\r
454           as.printAlignment();\r
455           distance[i][j] = (float)as.maxscore;\r
456           if (max < distance[i][j]) {\r
457             max = distance[i][j];\r
458           }\r
459         }\r
460       }\r
461       for (int i = 0; i < noseqs-1; i++) {\r
462         for (int j = i; j < noseqs; j++) {\r
463           distance[i][j] =  max - distance[i][j];\r
464           distance[j][i] = distance[i][j];\r
465         }\r
466       }\r
467     }\r
468 \r
469     return distance;\r
470   }\r
471 \r
472   public void makeLeaves() {\r
473     cluster = new Vector();\r
474 \r
475     for (int i=0; i < noseqs; i++) {\r
476       SequenceNode sn = new SequenceNode();\r
477 \r
478       sn.setElement(sequence[i]);\r
479       sn.setName(sequence[i].getName());\r
480       node.addElement(sn);\r
481 \r
482       int[] value = new int[1];\r
483       value[0] = i;\r
484 \r
485       Cluster c = new Cluster(value);\r
486       cluster.addElement(c);\r
487     }\r
488   }\r
489 \r
490   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
491     if (node == null) {\r
492       return leaves;\r
493     }\r
494 \r
495     if (node.left() == null && node.right() == null) {\r
496       leaves.addElement(node);\r
497       return leaves;\r
498     } else {\r
499       findLeaves((SequenceNode)node.left(),leaves);\r
500       findLeaves((SequenceNode)node.right(),leaves);\r
501     }\r
502     return leaves;\r
503   }\r
504 \r
505   public Object findLeaf(SequenceNode node, int count) {\r
506     found = _findLeaf(node,count);\r
507 \r
508     return found;\r
509   }\r
510   public Object _findLeaf(SequenceNode node,int count) {\r
511     if (node == null) {\r
512       return null;\r
513     }\r
514     if (node.ycount == count) {\r
515       found = node.element();\r
516       return found;\r
517     } else {\r
518       _findLeaf((SequenceNode)node.left(),count);\r
519       _findLeaf((SequenceNode)node.right(),count);\r
520     }\r
521 \r
522     return found;\r
523   }\r
524 \r
525   /**\r
526    * printNode is mainly for debugging purposes.\r
527    *\r
528    * @param node SequenceNode\r
529    */\r
530   public void printNode(SequenceNode node) {\r
531     if (node == null) {\r
532       return;\r
533     }\r
534     if (node.left() == null && node.right() == null) {\r
535       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
536       System.out.println("Dist " + ((SequenceNode)node).dist);\r
537       System.out.println("Boot " + node.getBootstrap());\r
538     } else {\r
539       System.out.println("Dist " + ((SequenceNode)node).dist);\r
540       printNode((SequenceNode)node.left());\r
541       printNode((SequenceNode)node.right());\r
542     }\r
543   }\r
544   public void findMaxDist(SequenceNode node) {\r
545     if (node == null) {\r
546       return;\r
547     }\r
548     if (node.left() == null && node.right() == null) {\r
549 \r
550       float dist = ((SequenceNode)node).dist;\r
551       if (dist > maxDistValue) {\r
552           maxdist      = (SequenceNode)node;\r
553           maxDistValue = dist;\r
554       }\r
555     } else {\r
556       findMaxDist((SequenceNode)node.left());\r
557       findMaxDist((SequenceNode)node.right());\r
558     }\r
559   }\r
560     public Vector getGroups() {\r
561         return groups;\r
562     }\r
563     public float getMaxHeight() {\r
564         return maxheight;\r
565     }\r
566   public void  groupNodes(SequenceNode node, float threshold) {\r
567     if (node == null) {\r
568       return;\r
569     }\r
570 \r
571     if (node.height/maxheight > threshold) {\r
572       groups.addElement(node);\r
573     } else {\r
574       groupNodes((SequenceNode)node.left(),threshold);\r
575       groupNodes((SequenceNode)node.right(),threshold);\r
576     }\r
577   }\r
578 \r
579   public float findHeight(SequenceNode node) {\r
580 \r
581     if (node == null) {\r
582       return maxheight;\r
583     }\r
584 \r
585     if (node.left() == null && node.right() == null) {\r
586       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
587 \r
588       if (node.height > maxheight) {\r
589         return node.height;\r
590       } else {\r
591         return maxheight;\r
592       }\r
593     } else {\r
594       if (node.parent() != null) {\r
595         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
596       } else {\r
597         maxheight = 0;\r
598         node.height = (float)0.0;\r
599       }\r
600 \r
601       maxheight = findHeight((SequenceNode)(node.left()));\r
602       maxheight = findHeight((SequenceNode)(node.right()));\r
603     }\r
604     return maxheight;\r
605   }\r
606   public SequenceNode reRoot() {\r
607     if (maxdist != null) {\r
608       ycount = 0;\r
609       float tmpdist = maxdist.dist;\r
610 \r
611       // New top\r
612       SequenceNode sn = new SequenceNode();\r
613       sn.setParent(null);\r
614 \r
615       // New right hand of top\r
616       SequenceNode snr = (SequenceNode)maxdist.parent();\r
617       changeDirection(snr,maxdist);\r
618       System.out.println("Printing reversed tree");\r
619       printN(snr);\r
620       snr.dist = tmpdist/2;\r
621       maxdist.dist = tmpdist/2;\r
622 \r
623       snr.setParent(sn);\r
624       maxdist.setParent(sn);\r
625 \r
626       sn.setRight(snr);\r
627       sn.setLeft(maxdist);\r
628 \r
629       top = sn;\r
630 \r
631       ycount = 0;\r
632       reCount(top);\r
633       findHeight(top);\r
634 \r
635     }\r
636     return top;\r
637   }\r
638   public static void printN(SequenceNode node) {\r
639     if (node == null) {\r
640       return;\r
641     }\r
642 \r
643     if (node.left() != null && node.right() != null) {\r
644       printN((SequenceNode)node.left());\r
645       printN((SequenceNode)node.right());\r
646     } else {\r
647       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
648     }\r
649     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
650   }\r
651 \r
652     public void reCount(SequenceNode node) {\r
653         ycount = 0;\r
654         _reCount(node);\r
655     }\r
656   public void _reCount(SequenceNode node) {\r
657     if (node == null) {\r
658       return;\r
659     }\r
660 \r
661     if (node.left() != null && node.right() != null) {\r
662       _reCount((SequenceNode)node.left());\r
663       _reCount((SequenceNode)node.right());\r
664 \r
665       SequenceNode l = (SequenceNode)node.left();\r
666       SequenceNode r = (SequenceNode)node.right();\r
667 \r
668       ((SequenceNode)node).count  = l.count + r.count;\r
669       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
670 \r
671     } else {\r
672       ((SequenceNode)node).count = 1;\r
673       ((SequenceNode)node).ycount = ycount++;\r
674     }\r
675 \r
676   }\r
677     public void swapNodes(SequenceNode node) {\r
678         if (node == null) {\r
679             return;\r
680         }\r
681         SequenceNode tmp = (SequenceNode)node.left();\r
682 \r
683         node.setLeft(node.right());\r
684         node.setRight(tmp);\r
685     }\r
686   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
687     if (node == null) {\r
688       return;\r
689     }\r
690     if (node.parent() != top) {\r
691       changeDirection((SequenceNode)node.parent(), node);\r
692 \r
693       SequenceNode tmp = (SequenceNode)node.parent();\r
694 \r
695       if (dir == node.left()) {\r
696         node.setParent(dir);\r
697         node.setLeft(tmp);\r
698       } else if (dir == node.right()) {\r
699         node.setParent(dir);\r
700         node.setRight(tmp);\r
701       }\r
702 \r
703     } else {\r
704       if (dir == node.left()) {\r
705         node.setParent(node.left());\r
706 \r
707         if (top.left() == node) {\r
708           node.setRight(top.right());\r
709         } else {\r
710           node.setRight(top.left());\r
711         }\r
712       } else {\r
713         node.setParent(node.right());\r
714 \r
715         if (top.left() == node) {\r
716           node.setLeft(top.right());\r
717         } else {\r
718           node.setLeft(top.left());\r
719         }\r
720       }\r
721     }\r
722   }\r
723     public void setMaxDist(SequenceNode node) {\r
724         this.maxdist = maxdist;\r
725     }\r
726     public SequenceNode getMaxDist() {\r
727         return maxdist;\r
728     }\r
729     public SequenceNode getTopNode() {\r
730         return top;\r
731     }\r
732 \r
733 }\r
734 \r
735 \r
736 \r
737 class Cluster {\r
738 \r
739   int[] value;\r
740 \r
741   public Cluster(int[] value) {\r
742     this.value = value;\r
743   }\r
744 \r
745 }\r
746 \r