added PropertyChangeListeners for alignviewport.alignment.sequences
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 import jalview.io.NewickFile;\r
9 \r
10 public class NJTree {\r
11 \r
12   Vector cluster;\r
13   SequenceI[] sequence;\r
14 \r
15   int done[];\r
16   int noseqs;\r
17   int noClus;\r
18 \r
19   float distance[][];\r
20 \r
21   int mini;\r
22   int minj;\r
23   float ri;\r
24   float rj;\r
25 \r
26   Vector groups = new Vector();\r
27   SequenceNode maxdist;\r
28   SequenceNode top;\r
29 \r
30   float maxDistValue;\r
31   float maxheight;\r
32 \r
33   int ycount;\r
34 \r
35   Vector node;\r
36 \r
37   String type;\r
38   String pwtype;\r
39 \r
40   Object found = null;\r
41   Object leaves = null;\r
42 \r
43   int start;\r
44   int end;\r
45 \r
46   public NJTree(SequenceNode node) {\r
47     top = node;\r
48     maxheight = findHeight(top);\r
49   }\r
50 \r
51   public String toString()\r
52   {\r
53     jalview.io.NewickFile fout = new jalview.io.NewickFile(getTopNode());\r
54     return fout.print(false,true); // distances only\r
55   }\r
56 \r
57   public NJTree(SequenceI[] seqs, NewickFile treefile) {\r
58     top = treefile.getTree();\r
59     maxheight = findHeight(top);\r
60     SequenceIdMatcher algnIds = new SequenceIdMatcher(seqs);\r
61 \r
62     Vector leaves = new Vector();\r
63     findLeaves(top, leaves);\r
64 \r
65     int i = 0;\r
66     int namesleft = seqs.length;\r
67 \r
68     SequenceNode j;\r
69     SequenceI nam;\r
70     String realnam;\r
71     while (i < leaves.size())\r
72     {\r
73       j = (SequenceNode) leaves.elementAt(i++);\r
74       realnam = j.getName();\r
75       nam = null;\r
76       if (namesleft>-1)\r
77         nam = algnIds.findIdMatch(realnam);\r
78       if (nam != null) {\r
79         j.setElement(nam);\r
80         namesleft--;\r
81       } else {\r
82         j.setElement(new Sequence(realnam, "THISISAPLACEHLDER"));\r
83         j.setPlaceholder(true);\r
84 \r
85       }\r
86     }\r
87   }\r
88 \r
89   /**\r
90    *\r
91    * used when the alignment associated to a tree has changed.\r
92    *\r
93    * @param alignment Vector\r
94    */\r
95   public void UpdatePlaceHolders(Vector alignment) {\r
96     Vector leaves = new Vector();\r
97     findLeaves(top, leaves);\r
98     int sz = leaves.size();\r
99     int i=0;\r
100     while (i<sz) {\r
101       SequenceNode leaf = (SequenceNode) leaves.elementAt(i++);\r
102       if (alignment.contains(leaf.element()))\r
103         leaf.setPlaceholder(false);\r
104       else\r
105         leaf.setPlaceholder(true);\r
106     }\r
107   }\r
108 \r
109   public NJTree(SequenceI[] sequence,int start, int end) {\r
110     this(sequence,"NJ","BL",start,end);\r
111   }\r
112 \r
113   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
114 \r
115     this.sequence = sequence;\r
116     this.node     = new Vector();\r
117     this.type     = type;\r
118     this.pwtype   = pwtype;\r
119     this.start    = start;\r
120     this.end      = end;\r
121 \r
122     if (!(type.equals("NJ"))) {\r
123       type = "AV";\r
124     }\r
125 \r
126     if (!(pwtype.equals("PID"))) {\r
127       type = "BL";\r
128     }\r
129 \r
130     int i=0;\r
131 \r
132     done = new int[sequence.length];\r
133 \r
134 \r
135     while (i < sequence.length  && sequence[i] != null) {\r
136       done[i] = 0;\r
137       i++;\r
138     }\r
139 \r
140     noseqs = i++;\r
141 \r
142     distance = findDistances();\r
143 \r
144     makeLeaves();\r
145 \r
146     noClus = cluster.size();\r
147 \r
148     cluster();\r
149 \r
150   }\r
151 \r
152 \r
153   public void cluster() {\r
154 \r
155     while (noClus > 2) {\r
156       if (type.equals("NJ")) {\r
157         float mind = findMinNJDistance();\r
158       } else {\r
159         float mind = findMinDistance();\r
160       }\r
161 \r
162       Cluster c = joinClusters(mini,minj);\r
163 \r
164 \r
165       done[minj] = 1;\r
166 \r
167       cluster.setElementAt(null,minj);\r
168       cluster.setElementAt(c,mini);\r
169 \r
170       noClus--;\r
171     }\r
172 \r
173     boolean onefound = false;\r
174 \r
175     int one = -1;\r
176     int two = -1;\r
177 \r
178     for (int i=0; i < noseqs; i++) {\r
179       if (done[i] != 1) {\r
180         if (onefound == false) {\r
181           two = i;\r
182           onefound = true;\r
183         } else {\r
184           one = i;\r
185         }\r
186       }\r
187     }\r
188 \r
189     Cluster c = joinClusters(one,two);\r
190     top = (SequenceNode)(node.elementAt(one));\r
191 \r
192     reCount(top);\r
193     findHeight(top);\r
194     findMaxDist(top);\r
195 \r
196   }\r
197 \r
198   public Cluster joinClusters(int i, int j) {\r
199 \r
200     float dist = distance[i][j];\r
201 \r
202     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
203     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
204 \r
205     int[] value = new int[noi + noj];\r
206 \r
207     for (int ii = 0; ii < noi;ii++) {\r
208       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
209     }\r
210 \r
211     for (int ii = noi; ii < noi+ noj;ii++) {\r
212       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
213     }\r
214 \r
215     Cluster c = new Cluster(value);\r
216 \r
217     ri = findr(i,j);\r
218     rj = findr(j,i);\r
219 \r
220     if (type.equals("NJ")) {\r
221       findClusterNJDistance(i,j);\r
222     } else {\r
223       findClusterDistance(i,j);\r
224     }\r
225 \r
226     SequenceNode sn = new SequenceNode();\r
227 \r
228     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
229     sn.setRight((SequenceNode)(node.elementAt(j)));\r
230 \r
231     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
232     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
233 \r
234     if (type.equals("NJ")) {\r
235       findNewNJDistances(tmpi,tmpj,dist);\r
236     } else {\r
237       findNewDistances(tmpi,tmpj,dist);\r
238     }\r
239 \r
240     tmpi.setParent(sn);\r
241     tmpj.setParent(sn);\r
242 \r
243     node.setElementAt(sn,i);\r
244     return c;\r
245   }\r
246 \r
247   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
248 \r
249     float ih = 0;\r
250     float jh = 0;\r
251 \r
252     SequenceNode sni = tmpi;\r
253     SequenceNode snj = tmpj;\r
254 \r
255     tmpi.dist = (dist + ri - rj)/2;\r
256     tmpj.dist = (dist - tmpi.dist);\r
257 \r
258     if (tmpi.dist < 0) {\r
259       tmpi.dist = 0;\r
260     }\r
261     if (tmpj.dist < 0) {\r
262       tmpj.dist = 0;\r
263     }\r
264   }\r
265 \r
266   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
267 \r
268     float ih = 0;\r
269     float jh = 0;\r
270 \r
271     SequenceNode sni = tmpi;\r
272     SequenceNode snj = tmpj;\r
273 \r
274     while (sni != null) {\r
275       ih = ih + sni.dist;\r
276       sni = (SequenceNode)sni.left();\r
277     }\r
278 \r
279     while (snj != null) {\r
280       jh = jh + snj.dist;\r
281       snj = (SequenceNode)snj.left();\r
282     }\r
283 \r
284     tmpi.dist = (dist/2 - ih);\r
285     tmpj.dist = (dist/2 - jh);\r
286   }\r
287 \r
288 \r
289 \r
290   public void findClusterDistance(int i, int j) {\r
291 \r
292     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
293     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
294 \r
295     // New distances from cluster to others\r
296     float[] newdist = new float[noseqs];\r
297 \r
298     for (int l = 0; l < noseqs; l++) {\r
299       if ( l != i && l != j) {\r
300         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
301       } else {\r
302         newdist[l] = 0;\r
303       }\r
304     }\r
305 \r
306     for (int ii=0; ii < noseqs;ii++) {\r
307       distance[i][ii] = newdist[ii];\r
308       distance[ii][i] = newdist[ii];\r
309     }\r
310   }\r
311 \r
312   public void findClusterNJDistance(int i, int j) {\r
313 \r
314     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
315     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
316 \r
317     // New distances from cluster to others\r
318     float[] newdist = new float[noseqs];\r
319 \r
320     for (int l = 0; l < noseqs; l++) {\r
321       if ( l != i && l != j) {\r
322         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
323       } else {\r
324         newdist[l] = 0;\r
325       }\r
326     }\r
327 \r
328     for (int ii=0; ii < noseqs;ii++) {\r
329       distance[i][ii] = newdist[ii];\r
330       distance[ii][i] = newdist[ii];\r
331     }\r
332   }\r
333 \r
334   public float findr(int i, int j) {\r
335 \r
336     float tmp = 1;\r
337     for (int k=0; k < noseqs;k++) {\r
338       if (k!= i && k!= j && done[k] != 1) {\r
339         tmp = tmp + distance[i][k];\r
340       }\r
341     }\r
342 \r
343     if (noClus > 2) {\r
344       tmp = tmp/(noClus - 2);\r
345     }\r
346 \r
347     return tmp;\r
348   }\r
349 \r
350   public float findMinNJDistance() {\r
351 \r
352     float min = 100000;\r
353 \r
354     for (int i=0; i < noseqs-1; i++) {\r
355       for (int j=i+1;j < noseqs;j++) {\r
356         if (done[i] != 1 && done[j] != 1) {\r
357           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
358           if (tmp < min) {\r
359 \r
360             mini = i;\r
361             minj = j;\r
362 \r
363             min = tmp;\r
364 \r
365           }\r
366         }\r
367       }\r
368     }\r
369     return min;\r
370   }\r
371 \r
372   public float findMinDistance() {\r
373 \r
374     float min = 100000;\r
375 \r
376     for (int i=0; i < noseqs-1;i++) {\r
377       for (int j = i+1; j < noseqs;j++) {\r
378         if (done[i] != 1 && done[j] != 1) {\r
379           if (distance[i][j] < min) {\r
380             mini = i;\r
381             minj = j;\r
382 \r
383             min = distance[i][j];\r
384           }\r
385         }\r
386       }\r
387     }\r
388     return min;\r
389   }\r
390 \r
391   public float[][] findDistances() {\r
392 \r
393     float[][] distance = new float[noseqs][noseqs];\r
394     if (pwtype.equals("PID")) {\r
395       for (int i = 0; i < noseqs-1; i++) {\r
396         for (int j = i; j < noseqs; j++) {\r
397           if (j==i) {\r
398             distance[i][i] = 0;\r
399           } else {\r
400             distance[i][j] = 100-Comparison.PID(sequence[i], sequence[j]);\r
401             distance[j][i] = distance[i][j];\r
402           }\r
403         }\r
404       }\r
405     } else if (pwtype.equals("BL")) {\r
406       int   maxscore = 0;\r
407 \r
408       for (int i = 0; i < noseqs-1; i++) {\r
409         for (int j = i; j < noseqs; j++) {\r
410           int score = 0;\r
411           for (int k=0; k < sequence[i].getLength(); k++) {\r
412             try{\r
413               score +=\r
414                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
415                   k + 1),\r
416                                                 sequence[j].getSequence(k,\r
417                   k + 1));\r
418             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
419           }\r
420           distance[i][j] = (float)score;\r
421           if (score > maxscore) {\r
422             maxscore = score;\r
423           }\r
424         }\r
425       }\r
426       for (int i = 0; i < noseqs-1; i++) {\r
427         for (int j = i; j < noseqs; j++) {\r
428           distance[i][j] =  (float)maxscore - distance[i][j];\r
429           distance[j][i] = distance[i][j];\r
430         }\r
431       }\r
432     } else if (pwtype.equals("SW")) {\r
433       float max = -1;\r
434       for (int i = 0; i < noseqs-1; i++) {\r
435         for (int j = i; j < noseqs; j++) {\r
436           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
437           as.calcScoreMatrix();\r
438           as.traceAlignment();\r
439           as.printAlignment();\r
440           distance[i][j] = (float)as.maxscore;\r
441           if (max < distance[i][j]) {\r
442             max = distance[i][j];\r
443           }\r
444         }\r
445       }\r
446       for (int i = 0; i < noseqs-1; i++) {\r
447         for (int j = i; j < noseqs; j++) {\r
448           distance[i][j] =  max - distance[i][j];\r
449           distance[j][i] = distance[i][j];\r
450         }\r
451       }\r
452     }\r
453 \r
454     return distance;\r
455   }\r
456 \r
457   public void makeLeaves() {\r
458     cluster = new Vector();\r
459 \r
460     for (int i=0; i < noseqs; i++) {\r
461       SequenceNode sn = new SequenceNode();\r
462 \r
463       sn.setElement(sequence[i]);\r
464       sn.setName(sequence[i].getName());\r
465       node.addElement(sn);\r
466 \r
467       int[] value = new int[1];\r
468       value[0] = i;\r
469 \r
470       Cluster c = new Cluster(value);\r
471       cluster.addElement(c);\r
472     }\r
473   }\r
474 \r
475   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
476     if (node == null) {\r
477       return leaves;\r
478     }\r
479 \r
480     if (node.left() == null && node.right() == null) {\r
481       leaves.addElement(node);\r
482       return leaves;\r
483     } else {\r
484       findLeaves((SequenceNode)node.left(),leaves);\r
485       findLeaves((SequenceNode)node.right(),leaves);\r
486     }\r
487     return leaves;\r
488   }\r
489 \r
490   public Object findLeaf(SequenceNode node, int count) {\r
491     found = _findLeaf(node,count);\r
492 \r
493     return found;\r
494   }\r
495   public Object _findLeaf(SequenceNode node,int count) {\r
496     if (node == null) {\r
497       return null;\r
498     }\r
499     if (node.ycount == count) {\r
500       found = node.element();\r
501       return found;\r
502     } else {\r
503       _findLeaf((SequenceNode)node.left(),count);\r
504       _findLeaf((SequenceNode)node.right(),count);\r
505     }\r
506 \r
507     return found;\r
508   }\r
509 \r
510   public void printNode(SequenceNode node) {\r
511     if (node == null) {\r
512       return;\r
513     }\r
514     if (node.left() == null && node.right() == null) {\r
515       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
516       System.out.println("Dist " + ((SequenceNode)node).dist);\r
517       System.out.println("Boot " + node.getBootstrap());\r
518     } else {\r
519       System.out.println("Dist " + ((SequenceNode)node).dist);\r
520       printNode((SequenceNode)node.left());\r
521       printNode((SequenceNode)node.right());\r
522     }\r
523   }\r
524   public void findMaxDist(SequenceNode node) {\r
525     if (node == null) {\r
526       return;\r
527     }\r
528     if (node.left() == null && node.right() == null) {\r
529 \r
530       float dist = ((SequenceNode)node).dist;\r
531       if (dist > maxDistValue) {\r
532           maxdist      = (SequenceNode)node;\r
533           maxDistValue = dist;\r
534       }\r
535     } else {\r
536       findMaxDist((SequenceNode)node.left());\r
537       findMaxDist((SequenceNode)node.right());\r
538     }\r
539   }\r
540     public Vector getGroups() {\r
541         return groups;\r
542     }\r
543     public float getMaxHeight() {\r
544         return maxheight;\r
545     }\r
546   public void  groupNodes(SequenceNode node, float threshold) {\r
547     if (node == null) {\r
548       return;\r
549     }\r
550 \r
551     if (node.height/maxheight > threshold) {\r
552       groups.addElement(node);\r
553     } else {\r
554       groupNodes((SequenceNode)node.left(),threshold);\r
555       groupNodes((SequenceNode)node.right(),threshold);\r
556     }\r
557   }\r
558 \r
559   public float findHeight(SequenceNode node) {\r
560 \r
561     if (node == null) {\r
562       return maxheight;\r
563     }\r
564 \r
565     if (node.left() == null && node.right() == null) {\r
566       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
567 \r
568       if (node.height > maxheight) {\r
569         return node.height;\r
570       } else {\r
571         return maxheight;\r
572       }\r
573     } else {\r
574       if (node.parent() != null) {\r
575         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
576       } else {\r
577         maxheight = 0;\r
578         node.height = (float)0.0;\r
579       }\r
580 \r
581       maxheight = findHeight((SequenceNode)(node.left()));\r
582       maxheight = findHeight((SequenceNode)(node.right()));\r
583     }\r
584     return maxheight;\r
585   }\r
586   public SequenceNode reRoot() {\r
587     if (maxdist != null) {\r
588       ycount = 0;\r
589       float tmpdist = maxdist.dist;\r
590 \r
591       // New top\r
592       SequenceNode sn = new SequenceNode();\r
593       sn.setParent(null);\r
594 \r
595       // New right hand of top\r
596       SequenceNode snr = (SequenceNode)maxdist.parent();\r
597       changeDirection(snr,maxdist);\r
598       System.out.println("Printing reversed tree");\r
599       printN(snr);\r
600       snr.dist = tmpdist/2;\r
601       maxdist.dist = tmpdist/2;\r
602 \r
603       snr.setParent(sn);\r
604       maxdist.setParent(sn);\r
605 \r
606       sn.setRight(snr);\r
607       sn.setLeft(maxdist);\r
608 \r
609       top = sn;\r
610 \r
611       ycount = 0;\r
612       reCount(top);\r
613       findHeight(top);\r
614 \r
615     }\r
616     return top;\r
617   }\r
618   public static void printN(SequenceNode node) {\r
619     if (node == null) {\r
620       return;\r
621     }\r
622 \r
623     if (node.left() != null && node.right() != null) {\r
624       printN((SequenceNode)node.left());\r
625       printN((SequenceNode)node.right());\r
626     } else {\r
627       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
628     }\r
629     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
630   }\r
631 \r
632     public void reCount(SequenceNode node) {\r
633         ycount = 0;\r
634         _reCount(node);\r
635     }\r
636   public void _reCount(SequenceNode node) {\r
637     if (node == null) {\r
638       return;\r
639     }\r
640 \r
641     if (node.left() != null && node.right() != null) {\r
642       _reCount((SequenceNode)node.left());\r
643       _reCount((SequenceNode)node.right());\r
644 \r
645       SequenceNode l = (SequenceNode)node.left();\r
646       SequenceNode r = (SequenceNode)node.right();\r
647 \r
648       ((SequenceNode)node).count  = l.count + r.count;\r
649       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
650 \r
651     } else {\r
652       ((SequenceNode)node).count = 1;\r
653       ((SequenceNode)node).ycount = ycount++;\r
654     }\r
655 \r
656   }\r
657     public void swapNodes(SequenceNode node) {\r
658         if (node == null) {\r
659             return;\r
660         }\r
661         SequenceNode tmp = (SequenceNode)node.left();\r
662 \r
663         node.setLeft(node.right());\r
664         node.setRight(tmp);\r
665     }\r
666   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
667     if (node == null) {\r
668       return;\r
669     }\r
670     if (node.parent() != top) {\r
671       changeDirection((SequenceNode)node.parent(), node);\r
672 \r
673       SequenceNode tmp = (SequenceNode)node.parent();\r
674 \r
675       if (dir == node.left()) {\r
676         node.setParent(dir);\r
677         node.setLeft(tmp);\r
678       } else if (dir == node.right()) {\r
679         node.setParent(dir);\r
680         node.setRight(tmp);\r
681       }\r
682 \r
683     } else {\r
684       if (dir == node.left()) {\r
685         node.setParent(node.left());\r
686 \r
687         if (top.left() == node) {\r
688           node.setRight(top.right());\r
689         } else {\r
690           node.setRight(top.left());\r
691         }\r
692       } else {\r
693         node.setParent(node.right());\r
694 \r
695         if (top.left() == node) {\r
696           node.setLeft(top.right());\r
697         } else {\r
698           node.setLeft(top.left());\r
699         }\r
700       }\r
701     }\r
702   }\r
703     public void setMaxDist(SequenceNode node) {\r
704         this.maxdist = maxdist;\r
705     }\r
706     public SequenceNode getMaxDist() {\r
707         return maxdist;\r
708     }\r
709     public SequenceNode getTopNode() {\r
710         return top;\r
711     }\r
712 \r
713 }\r
714 \r
715 \r
716 \r
717 class Cluster {\r
718 \r
719   int[] value;\r
720 \r
721   public Cluster(int[] value) {\r
722     this.value = value;\r
723   }\r
724 \r
725 }\r
726 \r