various loadTree bits for loading and saving trees.
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 import jalview.io.NewickFile;\r
9 \r
10 public class NJTree {\r
11 \r
12   Vector cluster;\r
13   SequenceI[] sequence;\r
14 \r
15   int done[];\r
16   int noseqs;\r
17   int noClus;\r
18 \r
19   float distance[][];\r
20 \r
21   int mini;\r
22   int minj;\r
23   float ri;\r
24   float rj;\r
25 \r
26   Vector groups = new Vector();\r
27   SequenceNode maxdist;\r
28   SequenceNode top;\r
29 \r
30   float maxDistValue;\r
31   float maxheight;\r
32 \r
33   int ycount;\r
34 \r
35   Vector node;\r
36 \r
37   String type;\r
38   String pwtype;\r
39 \r
40   Object found = null;\r
41   Object leaves = null;\r
42 \r
43   int start;\r
44   int end;\r
45 \r
46   public NJTree(SequenceNode node) {\r
47     top = node;\r
48     maxheight = findHeight(top);\r
49 \r
50   }\r
51   // Private SequenceID class to do fuzzy .equals() method for Hashtable.\r
52 \r
53   private class SeqIdname {\r
54     String id;\r
55 \r
56     SeqIdname(String s) {\r
57       id = new String(s);\r
58     }\r
59     public int hashCode() {\r
60       return (id.substring(0,4).hashCode());\r
61     }\r
62     public boolean equals(Object s) {\r
63       if (s instanceof SeqIdname) {\r
64         return this.equals((SeqIdname) s);\r
65       } else {\r
66         if (s instanceof String) {\r
67           return this.equals((String) s);\r
68         }\r
69       }\r
70       return false;\r
71     }\r
72 \r
73 \r
74     public boolean equals(SeqIdname s) {\r
75       if (id.startsWith(s.id) || s.id.startsWith(id)) {\r
76         return true;\r
77       }\r
78       return false;\r
79     }\r
80 \r
81     public boolean equals(String s) {\r
82       if (id.startsWith(s) || s.startsWith(id)) {\r
83         return true;\r
84       }\r
85       return false;\r
86     }\r
87   }\r
88 \r
89   public NJTree(SequenceI[] seqs, NewickFile treefile) {\r
90     top = treefile.getTree();\r
91     maxheight = findHeight(top);\r
92     Hashtable names = new Hashtable();\r
93     for (int i = 0; i < seqs.length; i++)\r
94     {\r
95       names.put(new SeqIdname(seqs[i].getDisplayId()), seqs[i]);\r
96     }\r
97     Vector leaves = new Vector();\r
98     findLeaves(top, leaves);\r
99     int i = 0;\r
100     int namesleft = seqs.length;\r
101     SequenceNode j;\r
102     SeqIdname nam;\r
103     while (i < leaves.size())\r
104     {\r
105       j = (SequenceNode) leaves.elementAt(i++);\r
106       nam = new SeqIdname(j.getName());\r
107       if ((namesleft>-1)\r
108           && names.containsKey(nam))\r
109       {\r
110         j.setElement(names.get(nam));\r
111         namesleft--;\r
112       } else {\r
113         j.setElement(new Sequence(nam.id, "THISISAPLACEHLDER"));\r
114       }\r
115     }\r
116   }\r
117 \r
118   public NJTree(SequenceI[] sequence,int start, int end) {\r
119     this(sequence,"NJ","BL",start,end);\r
120   }\r
121 \r
122   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
123 \r
124     this.sequence = sequence;\r
125     this.node     = new Vector();\r
126     this.type     = type;\r
127     this.pwtype   = pwtype;\r
128     this.start    = start;\r
129     this.end      = end;\r
130 \r
131     if (!(type.equals("NJ"))) {\r
132       type = "AV";\r
133     }\r
134 \r
135     if (!(pwtype.equals("PID"))) {\r
136       type = "BL";\r
137     }\r
138 \r
139     int i=0;\r
140 \r
141     done = new int[sequence.length];\r
142 \r
143 \r
144     while (i < sequence.length  && sequence[i] != null) {\r
145       done[i] = 0;\r
146       i++;\r
147     }\r
148 \r
149     noseqs = i++;\r
150 \r
151     distance = findDistances();\r
152 \r
153     makeLeaves();\r
154 \r
155     noClus = cluster.size();\r
156 \r
157     cluster();\r
158 \r
159   }\r
160 \r
161 \r
162   public void cluster() {\r
163 \r
164     while (noClus > 2) {\r
165       if (type.equals("NJ")) {\r
166         float mind = findMinNJDistance();\r
167       } else {\r
168         float mind = findMinDistance();\r
169       }\r
170 \r
171       Cluster c = joinClusters(mini,minj);\r
172 \r
173 \r
174       done[minj] = 1;\r
175 \r
176       cluster.setElementAt(null,minj);\r
177       cluster.setElementAt(c,mini);\r
178 \r
179       noClus--;\r
180     }\r
181 \r
182     boolean onefound = false;\r
183 \r
184     int one = -1;\r
185     int two = -1;\r
186 \r
187     for (int i=0; i < noseqs; i++) {\r
188       if (done[i] != 1) {\r
189         if (onefound == false) {\r
190           two = i;\r
191           onefound = true;\r
192         } else {\r
193           one = i;\r
194         }\r
195       }\r
196     }\r
197 \r
198     Cluster c = joinClusters(one,two);\r
199     top = (SequenceNode)(node.elementAt(one));\r
200 \r
201     reCount(top);\r
202     findHeight(top);\r
203     findMaxDist(top);\r
204 \r
205   }\r
206 \r
207   public Cluster joinClusters(int i, int j) {\r
208 \r
209     float dist = distance[i][j];\r
210 \r
211     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
212     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
213 \r
214     int[] value = new int[noi + noj];\r
215 \r
216     for (int ii = 0; ii < noi;ii++) {\r
217       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
218     }\r
219 \r
220     for (int ii = noi; ii < noi+ noj;ii++) {\r
221       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
222     }\r
223 \r
224     Cluster c = new Cluster(value);\r
225 \r
226     ri = findr(i,j);\r
227     rj = findr(j,i);\r
228 \r
229     if (type.equals("NJ")) {\r
230       findClusterNJDistance(i,j);\r
231     } else {\r
232       findClusterDistance(i,j);\r
233     }\r
234 \r
235     SequenceNode sn = new SequenceNode();\r
236 \r
237     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
238     sn.setRight((SequenceNode)(node.elementAt(j)));\r
239 \r
240     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
241     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
242 \r
243     if (type.equals("NJ")) {\r
244       findNewNJDistances(tmpi,tmpj,dist);\r
245     } else {\r
246       findNewDistances(tmpi,tmpj,dist);\r
247     }\r
248 \r
249     tmpi.setParent(sn);\r
250     tmpj.setParent(sn);\r
251 \r
252     node.setElementAt(sn,i);\r
253     return c;\r
254   }\r
255 \r
256   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
257 \r
258     float ih = 0;\r
259     float jh = 0;\r
260 \r
261     SequenceNode sni = tmpi;\r
262     SequenceNode snj = tmpj;\r
263 \r
264     tmpi.dist = (dist + ri - rj)/2;\r
265     tmpj.dist = (dist - tmpi.dist);\r
266 \r
267     if (tmpi.dist < 0) {\r
268       tmpi.dist = 0;\r
269     }\r
270     if (tmpj.dist < 0) {\r
271       tmpj.dist = 0;\r
272     }\r
273   }\r
274 \r
275   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
276 \r
277     float ih = 0;\r
278     float jh = 0;\r
279 \r
280     SequenceNode sni = tmpi;\r
281     SequenceNode snj = tmpj;\r
282 \r
283     while (sni != null) {\r
284       ih = ih + sni.dist;\r
285       sni = (SequenceNode)sni.left();\r
286     }\r
287 \r
288     while (snj != null) {\r
289       jh = jh + snj.dist;\r
290       snj = (SequenceNode)snj.left();\r
291     }\r
292 \r
293     tmpi.dist = (dist/2 - ih);\r
294     tmpj.dist = (dist/2 - jh);\r
295   }\r
296 \r
297 \r
298 \r
299   public void findClusterDistance(int i, int j) {\r
300 \r
301     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
302     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
303 \r
304     // New distances from cluster to others\r
305     float[] newdist = new float[noseqs];\r
306 \r
307     for (int l = 0; l < noseqs; l++) {\r
308       if ( l != i && l != j) {\r
309         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
310       } else {\r
311         newdist[l] = 0;\r
312       }\r
313     }\r
314 \r
315     for (int ii=0; ii < noseqs;ii++) {\r
316       distance[i][ii] = newdist[ii];\r
317       distance[ii][i] = newdist[ii];\r
318     }\r
319   }\r
320 \r
321   public void findClusterNJDistance(int i, int j) {\r
322 \r
323     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
324     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
325 \r
326     // New distances from cluster to others\r
327     float[] newdist = new float[noseqs];\r
328 \r
329     for (int l = 0; l < noseqs; l++) {\r
330       if ( l != i && l != j) {\r
331         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
332       } else {\r
333         newdist[l] = 0;\r
334       }\r
335     }\r
336 \r
337     for (int ii=0; ii < noseqs;ii++) {\r
338       distance[i][ii] = newdist[ii];\r
339       distance[ii][i] = newdist[ii];\r
340     }\r
341   }\r
342 \r
343   public float findr(int i, int j) {\r
344 \r
345     float tmp = 1;\r
346     for (int k=0; k < noseqs;k++) {\r
347       if (k!= i && k!= j && done[k] != 1) {\r
348         tmp = tmp + distance[i][k];\r
349       }\r
350     }\r
351 \r
352     if (noClus > 2) {\r
353       tmp = tmp/(noClus - 2);\r
354     }\r
355 \r
356     return tmp;\r
357   }\r
358 \r
359   public float findMinNJDistance() {\r
360 \r
361     float min = 100000;\r
362 \r
363     for (int i=0; i < noseqs-1; i++) {\r
364       for (int j=i+1;j < noseqs;j++) {\r
365         if (done[i] != 1 && done[j] != 1) {\r
366           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
367           if (tmp < min) {\r
368 \r
369             mini = i;\r
370             minj = j;\r
371 \r
372             min = tmp;\r
373 \r
374           }\r
375         }\r
376       }\r
377     }\r
378     return min;\r
379   }\r
380 \r
381   public float findMinDistance() {\r
382 \r
383     float min = 100000;\r
384 \r
385     for (int i=0; i < noseqs-1;i++) {\r
386       for (int j = i+1; j < noseqs;j++) {\r
387         if (done[i] != 1 && done[j] != 1) {\r
388           if (distance[i][j] < min) {\r
389             mini = i;\r
390             minj = j;\r
391 \r
392             min = distance[i][j];\r
393           }\r
394         }\r
395       }\r
396     }\r
397     return min;\r
398   }\r
399 \r
400   public float[][] findDistances() {\r
401 \r
402     float[][] distance = new float[noseqs][noseqs];\r
403     if (pwtype.equals("PID")) {\r
404       for (int i = 0; i < noseqs-1; i++) {\r
405         for (int j = i; j < noseqs; j++) {\r
406           if (j==i) {\r
407             distance[i][i] = 0;\r
408           } else {\r
409             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
410             distance[j][i] = distance[i][j];\r
411           }\r
412         }\r
413       }\r
414     } else if (pwtype.equals("BL")) {\r
415       int   maxscore = 0;\r
416 \r
417       for (int i = 0; i < noseqs-1; i++) {\r
418         for (int j = i; j < noseqs; j++) {\r
419           int score = 0;\r
420           for (int k=0; k < sequence[i].getLength(); k++) {\r
421             try{\r
422               score +=\r
423                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
424                   k + 1),\r
425                                                 sequence[j].getSequence(k,\r
426                   k + 1));\r
427             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
428           }\r
429           distance[i][j] = (float)score;\r
430           if (score > maxscore) {\r
431             maxscore = score;\r
432           }\r
433         }\r
434       }\r
435       for (int i = 0; i < noseqs-1; i++) {\r
436         for (int j = i; j < noseqs; j++) {\r
437           distance[i][j] =  (float)maxscore - distance[i][j];\r
438           distance[j][i] = distance[i][j];\r
439         }\r
440       }\r
441     } else if (pwtype.equals("SW")) {\r
442       float max = -1;\r
443       for (int i = 0; i < noseqs-1; i++) {\r
444         for (int j = i; j < noseqs; j++) {\r
445           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
446           as.calcScoreMatrix();\r
447           as.traceAlignment();\r
448           as.printAlignment();\r
449           distance[i][j] = (float)as.maxscore;\r
450           if (max < distance[i][j]) {\r
451             max = distance[i][j];\r
452           }\r
453         }\r
454       }\r
455       for (int i = 0; i < noseqs-1; i++) {\r
456         for (int j = i; j < noseqs; j++) {\r
457           distance[i][j] =  max - distance[i][j];\r
458           distance[j][i] = distance[i][j];\r
459         }\r
460       }\r
461     }\r
462 \r
463     return distance;\r
464   }\r
465 \r
466   public void makeLeaves() {\r
467     cluster = new Vector();\r
468 \r
469     for (int i=0; i < noseqs; i++) {\r
470       SequenceNode sn = new SequenceNode();\r
471 \r
472       sn.setElement(sequence[i]);\r
473       sn.setName(sequence[i].getName());\r
474       node.addElement(sn);\r
475 \r
476       int[] value = new int[1];\r
477       value[0] = i;\r
478 \r
479       Cluster c = new Cluster(value);\r
480       cluster.addElement(c);\r
481     }\r
482   }\r
483 \r
484   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
485     if (node == null) {\r
486       return leaves;\r
487     }\r
488 \r
489     if (node.left() == null && node.right() == null) {\r
490       leaves.addElement(node);\r
491       return leaves;\r
492     } else {\r
493       findLeaves((SequenceNode)node.left(),leaves);\r
494       findLeaves((SequenceNode)node.right(),leaves);\r
495     }\r
496     return leaves;\r
497   }\r
498 \r
499   public Object findLeaf(SequenceNode node, int count) {\r
500     found = _findLeaf(node,count);\r
501 \r
502     return found;\r
503   }\r
504   public Object _findLeaf(SequenceNode node,int count) {\r
505     if (node == null) {\r
506       return null;\r
507     }\r
508     if (node.ycount == count) {\r
509       found = node.element();\r
510       return found;\r
511     } else {\r
512       _findLeaf((SequenceNode)node.left(),count);\r
513       _findLeaf((SequenceNode)node.right(),count);\r
514     }\r
515 \r
516     return found;\r
517   }\r
518 \r
519   public void printNode(SequenceNode node) {\r
520     if (node == null) {\r
521       return;\r
522     }\r
523     if (node.left() == null && node.right() == null) {\r
524       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
525       System.out.println("Dist " + ((SequenceNode)node).dist);\r
526       System.out.println("Boot " + node.getBootstrap());\r
527     } else {\r
528       System.out.println("Dist " + ((SequenceNode)node).dist);\r
529       printNode((SequenceNode)node.left());\r
530       printNode((SequenceNode)node.right());\r
531     }\r
532   }\r
533   public void findMaxDist(SequenceNode node) {\r
534     if (node == null) {\r
535       return;\r
536     }\r
537     if (node.left() == null && node.right() == null) {\r
538 \r
539       float dist = ((SequenceNode)node).dist;\r
540       if (dist > maxDistValue) {\r
541           maxdist      = (SequenceNode)node;\r
542           maxDistValue = dist;\r
543       }\r
544     } else {\r
545       findMaxDist((SequenceNode)node.left());\r
546       findMaxDist((SequenceNode)node.right());\r
547     }\r
548   }\r
549     public Vector getGroups() {\r
550         return groups;\r
551     }\r
552     public float getMaxHeight() {\r
553         return maxheight;\r
554     }\r
555   public void  groupNodes(SequenceNode node, float threshold) {\r
556     if (node == null) {\r
557       return;\r
558     }\r
559 \r
560     if (node.height/maxheight > threshold) {\r
561       groups.addElement(node);\r
562     } else {\r
563       groupNodes((SequenceNode)node.left(),threshold);\r
564       groupNodes((SequenceNode)node.right(),threshold);\r
565     }\r
566   }\r
567 \r
568   public float findHeight(SequenceNode node) {\r
569 \r
570     if (node == null) {\r
571       return maxheight;\r
572     }\r
573 \r
574     if (node.left() == null && node.right() == null) {\r
575       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
576 \r
577       if (node.height > maxheight) {\r
578         return node.height;\r
579       } else {\r
580         return maxheight;\r
581       }\r
582     } else {\r
583       if (node.parent() != null) {\r
584         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
585       } else {\r
586         maxheight = 0;\r
587         node.height = (float)0.0;\r
588       }\r
589 \r
590       maxheight = findHeight((SequenceNode)(node.left()));\r
591       maxheight = findHeight((SequenceNode)(node.right()));\r
592     }\r
593     return maxheight;\r
594   }\r
595   public SequenceNode reRoot() {\r
596     if (maxdist != null) {\r
597       ycount = 0;\r
598       float tmpdist = maxdist.dist;\r
599 \r
600       // New top\r
601       SequenceNode sn = new SequenceNode();\r
602       sn.setParent(null);\r
603 \r
604       // New right hand of top\r
605       SequenceNode snr = (SequenceNode)maxdist.parent();\r
606       changeDirection(snr,maxdist);\r
607       System.out.println("Printing reversed tree");\r
608       printN(snr);\r
609       snr.dist = tmpdist/2;\r
610       maxdist.dist = tmpdist/2;\r
611 \r
612       snr.setParent(sn);\r
613       maxdist.setParent(sn);\r
614 \r
615       sn.setRight(snr);\r
616       sn.setLeft(maxdist);\r
617 \r
618       top = sn;\r
619 \r
620       ycount = 0;\r
621       reCount(top);\r
622       findHeight(top);\r
623 \r
624     }\r
625     return top;\r
626   }\r
627   public static void printN(SequenceNode node) {\r
628     if (node == null) {\r
629       return;\r
630     }\r
631 \r
632     if (node.left() != null && node.right() != null) {\r
633       printN((SequenceNode)node.left());\r
634       printN((SequenceNode)node.right());\r
635     } else {\r
636       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
637     }\r
638     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
639   }\r
640 \r
641     public void reCount(SequenceNode node) {\r
642         ycount = 0;\r
643         _reCount(node);\r
644     }\r
645   public void _reCount(SequenceNode node) {\r
646     if (node == null) {\r
647       return;\r
648     }\r
649 \r
650     if (node.left() != null && node.right() != null) {\r
651       _reCount((SequenceNode)node.left());\r
652       _reCount((SequenceNode)node.right());\r
653 \r
654       SequenceNode l = (SequenceNode)node.left();\r
655       SequenceNode r = (SequenceNode)node.right();\r
656 \r
657       ((SequenceNode)node).count  = l.count + r.count;\r
658       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
659 \r
660     } else {\r
661       ((SequenceNode)node).count = 1;\r
662       ((SequenceNode)node).ycount = ycount++;\r
663     }\r
664 \r
665   }\r
666     public void swapNodes(SequenceNode node) {\r
667         if (node == null) {\r
668             return;\r
669         }\r
670         SequenceNode tmp = (SequenceNode)node.left();\r
671 \r
672         node.setLeft(node.right());\r
673         node.setRight(tmp);\r
674     }\r
675   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
676     if (node == null) {\r
677       return;\r
678     }\r
679     if (node.parent() != top) {\r
680       changeDirection((SequenceNode)node.parent(), node);\r
681 \r
682       SequenceNode tmp = (SequenceNode)node.parent();\r
683 \r
684       if (dir == node.left()) {\r
685         node.setParent(dir);\r
686         node.setLeft(tmp);\r
687       } else if (dir == node.right()) {\r
688         node.setParent(dir);\r
689         node.setRight(tmp);\r
690       }\r
691 \r
692     } else {\r
693       if (dir == node.left()) {\r
694         node.setParent(node.left());\r
695 \r
696         if (top.left() == node) {\r
697           node.setRight(top.right());\r
698         } else {\r
699           node.setRight(top.left());\r
700         }\r
701       } else {\r
702         node.setParent(node.right());\r
703 \r
704         if (top.left() == node) {\r
705           node.setLeft(top.right());\r
706         } else {\r
707           node.setLeft(top.left());\r
708         }\r
709       }\r
710     }\r
711   }\r
712     public void setMaxDist(SequenceNode node) {\r
713         this.maxdist = maxdist;\r
714     }\r
715     public SequenceNode getMaxDist() {\r
716         return maxdist;\r
717     }\r
718     public SequenceNode getTopNode() {\r
719         return top;\r
720     }\r
721 \r
722 }\r
723 \r
724 \r
725 \r
726 class Cluster {\r
727 \r
728   int[] value;\r
729 \r
730   public Cluster(int[] value) {\r
731     this.value = value;\r
732   }\r
733 \r
734 }\r
735 \r