46d2844cc6e59b122791f07cae5d7c82acb1646e
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 import jalview.io.NewickFile;\r
9 \r
10 public class NJTree {\r
11 \r
12   Vector cluster;\r
13   SequenceI[] sequence;\r
14 \r
15   int done[];\r
16   int noseqs;\r
17   int noClus;\r
18 \r
19   float distance[][];\r
20 \r
21   int mini;\r
22   int minj;\r
23   float ri;\r
24   float rj;\r
25 \r
26   Vector groups = new Vector();\r
27   SequenceNode maxdist;\r
28   SequenceNode top;\r
29 \r
30   float maxDistValue;\r
31   float maxheight;\r
32 \r
33   int ycount;\r
34 \r
35   Vector node;\r
36 \r
37   String type;\r
38   String pwtype;\r
39 \r
40   Object found = null;\r
41   Object leaves = null;\r
42 \r
43   int start;\r
44   int end;\r
45 \r
46   public NJTree(SequenceNode node) {\r
47     top = node;\r
48     maxheight = findHeight(top);\r
49   }\r
50 \r
51   public String toString()\r
52   {\r
53     jalview.io.NewickFile fout = new jalview.io.NewickFile(getTopNode());\r
54     return fout.print(false,true); // distances only\r
55   }\r
56 \r
57   public NJTree(SequenceI[] seqs, NewickFile treefile) {\r
58     top = treefile.getTree();\r
59     maxheight = findHeight(top);\r
60     SequenceIdMatcher algnIds = new SequenceIdMatcher(seqs);\r
61 \r
62     Vector leaves = new Vector();\r
63     findLeaves(top, leaves);\r
64 \r
65     int i = 0;\r
66     int namesleft = seqs.length;\r
67 \r
68     SequenceNode j;\r
69     SequenceI nam;\r
70     String realnam;\r
71     while (i < leaves.size())\r
72     {\r
73       j = (SequenceNode) leaves.elementAt(i++);\r
74       realnam = j.getName();\r
75       nam = null;\r
76       if (namesleft>-1)\r
77         nam = algnIds.findIdMatch(realnam);\r
78       if (nam != null) {\r
79         j.setElement(nam);\r
80         namesleft--;\r
81       } else {\r
82         j.setElement(new Sequence(realnam, "THISISAPLACEHLDER"));\r
83       }\r
84     }\r
85   }\r
86 \r
87   public NJTree(SequenceI[] sequence,int start, int end) {\r
88     this(sequence,"NJ","BL",start,end);\r
89   }\r
90 \r
91   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
92 \r
93     this.sequence = sequence;\r
94     this.node     = new Vector();\r
95     this.type     = type;\r
96     this.pwtype   = pwtype;\r
97     this.start    = start;\r
98     this.end      = end;\r
99 \r
100     if (!(type.equals("NJ"))) {\r
101       type = "AV";\r
102     }\r
103 \r
104     if (!(pwtype.equals("PID"))) {\r
105       type = "BL";\r
106     }\r
107 \r
108     int i=0;\r
109 \r
110     done = new int[sequence.length];\r
111 \r
112 \r
113     while (i < sequence.length  && sequence[i] != null) {\r
114       done[i] = 0;\r
115       i++;\r
116     }\r
117 \r
118     noseqs = i++;\r
119 \r
120     distance = findDistances();\r
121 \r
122     makeLeaves();\r
123 \r
124     noClus = cluster.size();\r
125 \r
126     cluster();\r
127 \r
128   }\r
129 \r
130 \r
131   public void cluster() {\r
132 \r
133     while (noClus > 2) {\r
134       if (type.equals("NJ")) {\r
135         float mind = findMinNJDistance();\r
136       } else {\r
137         float mind = findMinDistance();\r
138       }\r
139 \r
140       Cluster c = joinClusters(mini,minj);\r
141 \r
142 \r
143       done[minj] = 1;\r
144 \r
145       cluster.setElementAt(null,minj);\r
146       cluster.setElementAt(c,mini);\r
147 \r
148       noClus--;\r
149     }\r
150 \r
151     boolean onefound = false;\r
152 \r
153     int one = -1;\r
154     int two = -1;\r
155 \r
156     for (int i=0; i < noseqs; i++) {\r
157       if (done[i] != 1) {\r
158         if (onefound == false) {\r
159           two = i;\r
160           onefound = true;\r
161         } else {\r
162           one = i;\r
163         }\r
164       }\r
165     }\r
166 \r
167     Cluster c = joinClusters(one,two);\r
168     top = (SequenceNode)(node.elementAt(one));\r
169 \r
170     reCount(top);\r
171     findHeight(top);\r
172     findMaxDist(top);\r
173 \r
174   }\r
175 \r
176   public Cluster joinClusters(int i, int j) {\r
177 \r
178     float dist = distance[i][j];\r
179 \r
180     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
181     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
182 \r
183     int[] value = new int[noi + noj];\r
184 \r
185     for (int ii = 0; ii < noi;ii++) {\r
186       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
187     }\r
188 \r
189     for (int ii = noi; ii < noi+ noj;ii++) {\r
190       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
191     }\r
192 \r
193     Cluster c = new Cluster(value);\r
194 \r
195     ri = findr(i,j);\r
196     rj = findr(j,i);\r
197 \r
198     if (type.equals("NJ")) {\r
199       findClusterNJDistance(i,j);\r
200     } else {\r
201       findClusterDistance(i,j);\r
202     }\r
203 \r
204     SequenceNode sn = new SequenceNode();\r
205 \r
206     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
207     sn.setRight((SequenceNode)(node.elementAt(j)));\r
208 \r
209     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
210     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
211 \r
212     if (type.equals("NJ")) {\r
213       findNewNJDistances(tmpi,tmpj,dist);\r
214     } else {\r
215       findNewDistances(tmpi,tmpj,dist);\r
216     }\r
217 \r
218     tmpi.setParent(sn);\r
219     tmpj.setParent(sn);\r
220 \r
221     node.setElementAt(sn,i);\r
222     return c;\r
223   }\r
224 \r
225   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
226 \r
227     float ih = 0;\r
228     float jh = 0;\r
229 \r
230     SequenceNode sni = tmpi;\r
231     SequenceNode snj = tmpj;\r
232 \r
233     tmpi.dist = (dist + ri - rj)/2;\r
234     tmpj.dist = (dist - tmpi.dist);\r
235 \r
236     if (tmpi.dist < 0) {\r
237       tmpi.dist = 0;\r
238     }\r
239     if (tmpj.dist < 0) {\r
240       tmpj.dist = 0;\r
241     }\r
242   }\r
243 \r
244   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
245 \r
246     float ih = 0;\r
247     float jh = 0;\r
248 \r
249     SequenceNode sni = tmpi;\r
250     SequenceNode snj = tmpj;\r
251 \r
252     while (sni != null) {\r
253       ih = ih + sni.dist;\r
254       sni = (SequenceNode)sni.left();\r
255     }\r
256 \r
257     while (snj != null) {\r
258       jh = jh + snj.dist;\r
259       snj = (SequenceNode)snj.left();\r
260     }\r
261 \r
262     tmpi.dist = (dist/2 - ih);\r
263     tmpj.dist = (dist/2 - jh);\r
264   }\r
265 \r
266 \r
267 \r
268   public void findClusterDistance(int i, int j) {\r
269 \r
270     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
271     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
272 \r
273     // New distances from cluster to others\r
274     float[] newdist = new float[noseqs];\r
275 \r
276     for (int l = 0; l < noseqs; l++) {\r
277       if ( l != i && l != j) {\r
278         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
279       } else {\r
280         newdist[l] = 0;\r
281       }\r
282     }\r
283 \r
284     for (int ii=0; ii < noseqs;ii++) {\r
285       distance[i][ii] = newdist[ii];\r
286       distance[ii][i] = newdist[ii];\r
287     }\r
288   }\r
289 \r
290   public void findClusterNJDistance(int i, int j) {\r
291 \r
292     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
293     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
294 \r
295     // New distances from cluster to others\r
296     float[] newdist = new float[noseqs];\r
297 \r
298     for (int l = 0; l < noseqs; l++) {\r
299       if ( l != i && l != j) {\r
300         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
301       } else {\r
302         newdist[l] = 0;\r
303       }\r
304     }\r
305 \r
306     for (int ii=0; ii < noseqs;ii++) {\r
307       distance[i][ii] = newdist[ii];\r
308       distance[ii][i] = newdist[ii];\r
309     }\r
310   }\r
311 \r
312   public float findr(int i, int j) {\r
313 \r
314     float tmp = 1;\r
315     for (int k=0; k < noseqs;k++) {\r
316       if (k!= i && k!= j && done[k] != 1) {\r
317         tmp = tmp + distance[i][k];\r
318       }\r
319     }\r
320 \r
321     if (noClus > 2) {\r
322       tmp = tmp/(noClus - 2);\r
323     }\r
324 \r
325     return tmp;\r
326   }\r
327 \r
328   public float findMinNJDistance() {\r
329 \r
330     float min = 100000;\r
331 \r
332     for (int i=0; i < noseqs-1; i++) {\r
333       for (int j=i+1;j < noseqs;j++) {\r
334         if (done[i] != 1 && done[j] != 1) {\r
335           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
336           if (tmp < min) {\r
337 \r
338             mini = i;\r
339             minj = j;\r
340 \r
341             min = tmp;\r
342 \r
343           }\r
344         }\r
345       }\r
346     }\r
347     return min;\r
348   }\r
349 \r
350   public float findMinDistance() {\r
351 \r
352     float min = 100000;\r
353 \r
354     for (int i=0; i < noseqs-1;i++) {\r
355       for (int j = i+1; j < noseqs;j++) {\r
356         if (done[i] != 1 && done[j] != 1) {\r
357           if (distance[i][j] < min) {\r
358             mini = i;\r
359             minj = j;\r
360 \r
361             min = distance[i][j];\r
362           }\r
363         }\r
364       }\r
365     }\r
366     return min;\r
367   }\r
368 \r
369   public float[][] findDistances() {\r
370 \r
371     float[][] distance = new float[noseqs][noseqs];\r
372     if (pwtype.equals("PID")) {\r
373       for (int i = 0; i < noseqs-1; i++) {\r
374         for (int j = i; j < noseqs; j++) {\r
375           if (j==i) {\r
376             distance[i][i] = 0;\r
377           } else {\r
378             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
379             distance[j][i] = distance[i][j];\r
380           }\r
381         }\r
382       }\r
383     } else if (pwtype.equals("BL")) {\r
384       int   maxscore = 0;\r
385 \r
386       for (int i = 0; i < noseqs-1; i++) {\r
387         for (int j = i; j < noseqs; j++) {\r
388           int score = 0;\r
389           for (int k=0; k < sequence[i].getLength(); k++) {\r
390             try{\r
391               score +=\r
392                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
393                   k + 1),\r
394                                                 sequence[j].getSequence(k,\r
395                   k + 1));\r
396             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
397           }\r
398           distance[i][j] = (float)score;\r
399           if (score > maxscore) {\r
400             maxscore = score;\r
401           }\r
402         }\r
403       }\r
404       for (int i = 0; i < noseqs-1; i++) {\r
405         for (int j = i; j < noseqs; j++) {\r
406           distance[i][j] =  (float)maxscore - distance[i][j];\r
407           distance[j][i] = distance[i][j];\r
408         }\r
409       }\r
410     } else if (pwtype.equals("SW")) {\r
411       float max = -1;\r
412       for (int i = 0; i < noseqs-1; i++) {\r
413         for (int j = i; j < noseqs; j++) {\r
414           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
415           as.calcScoreMatrix();\r
416           as.traceAlignment();\r
417           as.printAlignment();\r
418           distance[i][j] = (float)as.maxscore;\r
419           if (max < distance[i][j]) {\r
420             max = distance[i][j];\r
421           }\r
422         }\r
423       }\r
424       for (int i = 0; i < noseqs-1; i++) {\r
425         for (int j = i; j < noseqs; j++) {\r
426           distance[i][j] =  max - distance[i][j];\r
427           distance[j][i] = distance[i][j];\r
428         }\r
429       }\r
430     }\r
431 \r
432     return distance;\r
433   }\r
434 \r
435   public void makeLeaves() {\r
436     cluster = new Vector();\r
437 \r
438     for (int i=0; i < noseqs; i++) {\r
439       SequenceNode sn = new SequenceNode();\r
440 \r
441       sn.setElement(sequence[i]);\r
442       sn.setName(sequence[i].getName());\r
443       node.addElement(sn);\r
444 \r
445       int[] value = new int[1];\r
446       value[0] = i;\r
447 \r
448       Cluster c = new Cluster(value);\r
449       cluster.addElement(c);\r
450     }\r
451   }\r
452 \r
453   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
454     if (node == null) {\r
455       return leaves;\r
456     }\r
457 \r
458     if (node.left() == null && node.right() == null) {\r
459       leaves.addElement(node);\r
460       return leaves;\r
461     } else {\r
462       findLeaves((SequenceNode)node.left(),leaves);\r
463       findLeaves((SequenceNode)node.right(),leaves);\r
464     }\r
465     return leaves;\r
466   }\r
467 \r
468   public Object findLeaf(SequenceNode node, int count) {\r
469     found = _findLeaf(node,count);\r
470 \r
471     return found;\r
472   }\r
473   public Object _findLeaf(SequenceNode node,int count) {\r
474     if (node == null) {\r
475       return null;\r
476     }\r
477     if (node.ycount == count) {\r
478       found = node.element();\r
479       return found;\r
480     } else {\r
481       _findLeaf((SequenceNode)node.left(),count);\r
482       _findLeaf((SequenceNode)node.right(),count);\r
483     }\r
484 \r
485     return found;\r
486   }\r
487 \r
488   public void printNode(SequenceNode node) {\r
489     if (node == null) {\r
490       return;\r
491     }\r
492     if (node.left() == null && node.right() == null) {\r
493       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
494       System.out.println("Dist " + ((SequenceNode)node).dist);\r
495       System.out.println("Boot " + node.getBootstrap());\r
496     } else {\r
497       System.out.println("Dist " + ((SequenceNode)node).dist);\r
498       printNode((SequenceNode)node.left());\r
499       printNode((SequenceNode)node.right());\r
500     }\r
501   }\r
502   public void findMaxDist(SequenceNode node) {\r
503     if (node == null) {\r
504       return;\r
505     }\r
506     if (node.left() == null && node.right() == null) {\r
507 \r
508       float dist = ((SequenceNode)node).dist;\r
509       if (dist > maxDistValue) {\r
510           maxdist      = (SequenceNode)node;\r
511           maxDistValue = dist;\r
512       }\r
513     } else {\r
514       findMaxDist((SequenceNode)node.left());\r
515       findMaxDist((SequenceNode)node.right());\r
516     }\r
517   }\r
518     public Vector getGroups() {\r
519         return groups;\r
520     }\r
521     public float getMaxHeight() {\r
522         return maxheight;\r
523     }\r
524   public void  groupNodes(SequenceNode node, float threshold) {\r
525     if (node == null) {\r
526       return;\r
527     }\r
528 \r
529     if (node.height/maxheight > threshold) {\r
530       groups.addElement(node);\r
531     } else {\r
532       groupNodes((SequenceNode)node.left(),threshold);\r
533       groupNodes((SequenceNode)node.right(),threshold);\r
534     }\r
535   }\r
536 \r
537   public float findHeight(SequenceNode node) {\r
538 \r
539     if (node == null) {\r
540       return maxheight;\r
541     }\r
542 \r
543     if (node.left() == null && node.right() == null) {\r
544       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
545 \r
546       if (node.height > maxheight) {\r
547         return node.height;\r
548       } else {\r
549         return maxheight;\r
550       }\r
551     } else {\r
552       if (node.parent() != null) {\r
553         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
554       } else {\r
555         maxheight = 0;\r
556         node.height = (float)0.0;\r
557       }\r
558 \r
559       maxheight = findHeight((SequenceNode)(node.left()));\r
560       maxheight = findHeight((SequenceNode)(node.right()));\r
561     }\r
562     return maxheight;\r
563   }\r
564   public SequenceNode reRoot() {\r
565     if (maxdist != null) {\r
566       ycount = 0;\r
567       float tmpdist = maxdist.dist;\r
568 \r
569       // New top\r
570       SequenceNode sn = new SequenceNode();\r
571       sn.setParent(null);\r
572 \r
573       // New right hand of top\r
574       SequenceNode snr = (SequenceNode)maxdist.parent();\r
575       changeDirection(snr,maxdist);\r
576       System.out.println("Printing reversed tree");\r
577       printN(snr);\r
578       snr.dist = tmpdist/2;\r
579       maxdist.dist = tmpdist/2;\r
580 \r
581       snr.setParent(sn);\r
582       maxdist.setParent(sn);\r
583 \r
584       sn.setRight(snr);\r
585       sn.setLeft(maxdist);\r
586 \r
587       top = sn;\r
588 \r
589       ycount = 0;\r
590       reCount(top);\r
591       findHeight(top);\r
592 \r
593     }\r
594     return top;\r
595   }\r
596   public static void printN(SequenceNode node) {\r
597     if (node == null) {\r
598       return;\r
599     }\r
600 \r
601     if (node.left() != null && node.right() != null) {\r
602       printN((SequenceNode)node.left());\r
603       printN((SequenceNode)node.right());\r
604     } else {\r
605       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
606     }\r
607     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
608   }\r
609 \r
610     public void reCount(SequenceNode node) {\r
611         ycount = 0;\r
612         _reCount(node);\r
613     }\r
614   public void _reCount(SequenceNode node) {\r
615     if (node == null) {\r
616       return;\r
617     }\r
618 \r
619     if (node.left() != null && node.right() != null) {\r
620       _reCount((SequenceNode)node.left());\r
621       _reCount((SequenceNode)node.right());\r
622 \r
623       SequenceNode l = (SequenceNode)node.left();\r
624       SequenceNode r = (SequenceNode)node.right();\r
625 \r
626       ((SequenceNode)node).count  = l.count + r.count;\r
627       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
628 \r
629     } else {\r
630       ((SequenceNode)node).count = 1;\r
631       ((SequenceNode)node).ycount = ycount++;\r
632     }\r
633 \r
634   }\r
635     public void swapNodes(SequenceNode node) {\r
636         if (node == null) {\r
637             return;\r
638         }\r
639         SequenceNode tmp = (SequenceNode)node.left();\r
640 \r
641         node.setLeft(node.right());\r
642         node.setRight(tmp);\r
643     }\r
644   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
645     if (node == null) {\r
646       return;\r
647     }\r
648     if (node.parent() != top) {\r
649       changeDirection((SequenceNode)node.parent(), node);\r
650 \r
651       SequenceNode tmp = (SequenceNode)node.parent();\r
652 \r
653       if (dir == node.left()) {\r
654         node.setParent(dir);\r
655         node.setLeft(tmp);\r
656       } else if (dir == node.right()) {\r
657         node.setParent(dir);\r
658         node.setRight(tmp);\r
659       }\r
660 \r
661     } else {\r
662       if (dir == node.left()) {\r
663         node.setParent(node.left());\r
664 \r
665         if (top.left() == node) {\r
666           node.setRight(top.right());\r
667         } else {\r
668           node.setRight(top.left());\r
669         }\r
670       } else {\r
671         node.setParent(node.right());\r
672 \r
673         if (top.left() == node) {\r
674           node.setLeft(top.right());\r
675         } else {\r
676           node.setLeft(top.left());\r
677         }\r
678       }\r
679     }\r
680   }\r
681     public void setMaxDist(SequenceNode node) {\r
682         this.maxdist = maxdist;\r
683     }\r
684     public SequenceNode getMaxDist() {\r
685         return maxdist;\r
686     }\r
687     public SequenceNode getTopNode() {\r
688         return top;\r
689     }\r
690 \r
691 }\r
692 \r
693 \r
694 \r
695 class Cluster {\r
696 \r
697   int[] value;\r
698 \r
699   public Cluster(int[] value) {\r
700     this.value = value;\r
701   }\r
702 \r
703 }\r
704 \r