Jalview Imported Sources
[jalview.git] / src / jalview / analysis / NJTree.java
1 package jalview.analysis;\r
2 \r
3 import jalview.datamodel.*;\r
4 import jalview.util.*;\r
5 import jalview.schemes.ResidueProperties;\r
6 import java.util.*;\r
7 \r
8 public class NJTree {\r
9 \r
10   Vector cluster;\r
11   SequenceI[] sequence;\r
12 \r
13   int done[];\r
14   int noseqs;\r
15   int noClus;\r
16 \r
17   float distance[][];\r
18 \r
19   int mini;\r
20   int minj;\r
21   float ri;\r
22   float rj;\r
23 \r
24   Vector groups = new Vector();\r
25   SequenceNode maxdist;\r
26   SequenceNode top;\r
27 \r
28   float maxDistValue;\r
29   float maxheight;\r
30 \r
31   int ycount;\r
32 \r
33   Vector node;\r
34 \r
35   String type;\r
36   String pwtype;\r
37 \r
38   Object found = null;\r
39   Object leaves = null;\r
40 \r
41   int start;\r
42   int end;\r
43 \r
44   public NJTree(SequenceNode node) {\r
45     top = node;\r
46     maxheight = findHeight(top);\r
47   }\r
48 \r
49   public NJTree(SequenceI[] sequence,int start, int end) {\r
50     this(sequence,"NJ","BL",start,end);\r
51   }\r
52 \r
53   public NJTree(SequenceI[] sequence,String type,String pwtype,int start, int end ) {\r
54 \r
55     this.sequence = sequence;\r
56     this.node     = new Vector();\r
57     this.type     = type;\r
58     this.pwtype   = pwtype;\r
59     this.start    = start;\r
60     this.end      = end;\r
61 \r
62     if (!(type.equals("NJ"))) {\r
63       type = "AV";\r
64     }\r
65 \r
66     if (!(pwtype.equals("PID"))) {\r
67       type = "BL";\r
68     }\r
69 \r
70     int i=0;\r
71 \r
72     done = new int[sequence.length];\r
73 \r
74 \r
75     while (i < sequence.length  && sequence[i] != null) {\r
76       done[i] = 0;\r
77       i++;\r
78     }\r
79 \r
80     noseqs = i++;\r
81 \r
82     distance = findDistances();\r
83 \r
84     makeLeaves();\r
85 \r
86     noClus = cluster.size();\r
87 \r
88     cluster();\r
89 \r
90   }\r
91 \r
92 \r
93   public void cluster() {\r
94 \r
95     while (noClus > 2) {\r
96       if (type.equals("NJ")) {\r
97         float mind = findMinNJDistance();\r
98       } else {\r
99         float mind = findMinDistance();\r
100       }\r
101 \r
102       Cluster c = joinClusters(mini,minj);\r
103 \r
104 \r
105       done[minj] = 1;\r
106 \r
107       cluster.setElementAt(null,minj);\r
108       cluster.setElementAt(c,mini);\r
109 \r
110       noClus--;\r
111     }\r
112 \r
113     boolean onefound = false;\r
114 \r
115     int one = -1;\r
116     int two = -1;\r
117 \r
118     for (int i=0; i < noseqs; i++) {\r
119       if (done[i] != 1) {\r
120         if (onefound == false) {\r
121           two = i;\r
122           onefound = true;\r
123         } else {\r
124           one = i;\r
125         }\r
126       }\r
127     }\r
128 \r
129     Cluster c = joinClusters(one,two);\r
130     top = (SequenceNode)(node.elementAt(one));\r
131 \r
132     reCount(top);\r
133     findHeight(top);\r
134     findMaxDist(top);\r
135 \r
136   }\r
137 \r
138   public Cluster joinClusters(int i, int j) {\r
139 \r
140     float dist = distance[i][j];\r
141 \r
142     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
143     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
144 \r
145     int[] value = new int[noi + noj];\r
146 \r
147     for (int ii = 0; ii < noi;ii++) {\r
148       value[ii] =  ((Cluster)cluster.elementAt(i)).value[ii];\r
149     }\r
150 \r
151     for (int ii = noi; ii < noi+ noj;ii++) {\r
152       value[ii] =  ((Cluster)cluster.elementAt(j)).value[ii-noi];\r
153     }\r
154 \r
155     Cluster c = new Cluster(value);\r
156 \r
157     ri = findr(i,j);\r
158     rj = findr(j,i);\r
159 \r
160     if (type.equals("NJ")) {\r
161       findClusterNJDistance(i,j);\r
162     } else {\r
163       findClusterDistance(i,j);\r
164     }\r
165 \r
166     SequenceNode sn = new SequenceNode();\r
167 \r
168     sn.setLeft((SequenceNode)(node.elementAt(i)));\r
169     sn.setRight((SequenceNode)(node.elementAt(j)));\r
170 \r
171     SequenceNode tmpi = (SequenceNode)(node.elementAt(i));\r
172     SequenceNode tmpj = (SequenceNode)(node.elementAt(j));\r
173 \r
174     if (type.equals("NJ")) {\r
175       findNewNJDistances(tmpi,tmpj,dist);\r
176     } else {\r
177       findNewDistances(tmpi,tmpj,dist);\r
178     }\r
179 \r
180     tmpi.setParent(sn);\r
181     tmpj.setParent(sn);\r
182 \r
183     node.setElementAt(sn,i);\r
184     return c;\r
185   }\r
186 \r
187   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj, float dist) {\r
188 \r
189     float ih = 0;\r
190     float jh = 0;\r
191 \r
192     SequenceNode sni = tmpi;\r
193     SequenceNode snj = tmpj;\r
194 \r
195     tmpi.dist = (dist + ri - rj)/2;\r
196     tmpj.dist = (dist - tmpi.dist);\r
197 \r
198     if (tmpi.dist < 0) {\r
199       tmpi.dist = 0;\r
200     }\r
201     if (tmpj.dist < 0) {\r
202       tmpj.dist = 0;\r
203     }\r
204   }\r
205 \r
206   public void findNewDistances(SequenceNode tmpi,SequenceNode tmpj,float dist) {\r
207 \r
208     float ih = 0;\r
209     float jh = 0;\r
210 \r
211     SequenceNode sni = tmpi;\r
212     SequenceNode snj = tmpj;\r
213 \r
214     while (sni != null) {\r
215       ih = ih + sni.dist;\r
216       sni = (SequenceNode)sni.left();\r
217     }\r
218 \r
219     while (snj != null) {\r
220       jh = jh + snj.dist;\r
221       snj = (SequenceNode)snj.left();\r
222     }\r
223 \r
224     tmpi.dist = (dist/2 - ih);\r
225     tmpj.dist = (dist/2 - jh);\r
226   }\r
227 \r
228 \r
229 \r
230   public void findClusterDistance(int i, int j) {\r
231 \r
232     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
233     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
234 \r
235     // New distances from cluster to others\r
236     float[] newdist = new float[noseqs];\r
237 \r
238     for (int l = 0; l < noseqs; l++) {\r
239       if ( l != i && l != j) {\r
240         newdist[l] = (distance[i][l] * noi + distance[j][l] * noj)/(noi + noj);\r
241       } else {\r
242         newdist[l] = 0;\r
243       }\r
244     }\r
245 \r
246     for (int ii=0; ii < noseqs;ii++) {\r
247       distance[i][ii] = newdist[ii];\r
248       distance[ii][i] = newdist[ii];\r
249     }\r
250   }\r
251 \r
252   public void findClusterNJDistance(int i, int j) {\r
253 \r
254     int noi = ((Cluster)cluster.elementAt(i)).value.length;\r
255     int noj = ((Cluster)cluster.elementAt(j)).value.length;\r
256 \r
257     // New distances from cluster to others\r
258     float[] newdist = new float[noseqs];\r
259 \r
260     for (int l = 0; l < noseqs; l++) {\r
261       if ( l != i && l != j) {\r
262         newdist[l] = (distance[i][l] + distance[j][l] - distance[i][j])/2;\r
263       } else {\r
264         newdist[l] = 0;\r
265       }\r
266     }\r
267 \r
268     for (int ii=0; ii < noseqs;ii++) {\r
269       distance[i][ii] = newdist[ii];\r
270       distance[ii][i] = newdist[ii];\r
271     }\r
272   }\r
273 \r
274   public float findr(int i, int j) {\r
275 \r
276     float tmp = 1;\r
277     for (int k=0; k < noseqs;k++) {\r
278       if (k!= i && k!= j && done[k] != 1) {\r
279         tmp = tmp + distance[i][k];\r
280       }\r
281     }\r
282 \r
283     if (noClus > 2) {\r
284       tmp = tmp/(noClus - 2);\r
285     }\r
286 \r
287     return tmp;\r
288   }\r
289 \r
290   public float findMinNJDistance() {\r
291 \r
292     float min = 100000;\r
293 \r
294     for (int i=0; i < noseqs-1; i++) {\r
295       for (int j=i+1;j < noseqs;j++) {\r
296         if (done[i] != 1 && done[j] != 1) {\r
297           float tmp = distance[i][j] - (findr(i,j) + findr(j,i));\r
298           if (tmp < min) {\r
299 \r
300             mini = i;\r
301             minj = j;\r
302 \r
303             min = tmp;\r
304 \r
305           }\r
306         }\r
307       }\r
308     }\r
309     return min;\r
310   }\r
311 \r
312   public float findMinDistance() {\r
313 \r
314     float min = 100000;\r
315 \r
316     for (int i=0; i < noseqs-1;i++) {\r
317       for (int j = i+1; j < noseqs;j++) {\r
318         if (done[i] != 1 && done[j] != 1) {\r
319           if (distance[i][j] < min) {\r
320             mini = i;\r
321             minj = j;\r
322 \r
323             min = distance[i][j];\r
324           }\r
325         }\r
326       }\r
327     }\r
328     return min;\r
329   }\r
330 \r
331   public float[][] findDistances() {\r
332 \r
333     float[][] distance = new float[noseqs][noseqs];\r
334 \r
335     if (pwtype.equals("PID")) {\r
336       for (int i = 0; i < noseqs-1; i++) {\r
337         for (int j = i; j < noseqs; j++) {\r
338           if (j==i) {\r
339             distance[i][i] = 0;\r
340           } else {\r
341             distance[i][j] = 100-Comparison.compare(sequence[i],sequence[j],start,end);\r
342             distance[j][i] = distance[i][j];\r
343           }\r
344         }\r
345       }\r
346     } else if (pwtype.equals("BL")) {\r
347       int   maxscore = 0;\r
348 \r
349       for (int i = 0; i < noseqs-1; i++) {\r
350         for (int j = i; j < noseqs; j++) {\r
351           int score = 0;\r
352           for (int k=0; k < sequence[i].getLength(); k++) {\r
353             try{\r
354               score +=\r
355                   ResidueProperties.getBLOSUM62(sequence[i].getSequence(k,\r
356                   k + 1),\r
357                                                 sequence[j].getSequence(k,\r
358                   k + 1));\r
359             }catch(Exception ex){System.out.println("err creating BLOSUM62 tree");}\r
360           }\r
361           distance[i][j] = (float)score;\r
362           if (score > maxscore) {\r
363             maxscore = score;\r
364           }\r
365         }\r
366       }\r
367       for (int i = 0; i < noseqs-1; i++) {\r
368         for (int j = i; j < noseqs; j++) {\r
369           distance[i][j] =  (float)maxscore - distance[i][j];\r
370           distance[j][i] = distance[i][j];\r
371         }\r
372       }\r
373     } else if (pwtype.equals("SW")) {\r
374       float max = -1;\r
375       for (int i = 0; i < noseqs-1; i++) {\r
376         for (int j = i; j < noseqs; j++) {\r
377           AlignSeq as = new AlignSeq(sequence[i],sequence[j],"pep");\r
378           as.calcScoreMatrix();\r
379           as.traceAlignment();\r
380           as.printAlignment();\r
381           distance[i][j] = (float)as.maxscore;\r
382           if (max < distance[i][j]) {\r
383             max = distance[i][j];\r
384           }\r
385         }\r
386       }\r
387       for (int i = 0; i < noseqs-1; i++) {\r
388         for (int j = i; j < noseqs; j++) {\r
389           distance[i][j] =  max - distance[i][j];\r
390           distance[j][i] = distance[i][j];\r
391         }\r
392       }\r
393     }\r
394 \r
395     return distance;\r
396   }\r
397 \r
398   public void makeLeaves() {\r
399     cluster = new Vector();\r
400 \r
401     for (int i=0; i < noseqs; i++) {\r
402       SequenceNode sn = new SequenceNode();\r
403 \r
404       sn.setElement(sequence[i]);\r
405       sn.setName(sequence[i].getName());\r
406       node.addElement(sn);\r
407 \r
408       int[] value = new int[1];\r
409       value[0] = i;\r
410 \r
411       Cluster c = new Cluster(value);\r
412       cluster.addElement(c);\r
413     }\r
414   }\r
415 \r
416   public Vector findLeaves(SequenceNode node, Vector leaves) {\r
417     if (node == null) {\r
418       return leaves;\r
419     }\r
420 \r
421     if (node.left() == null && node.right() == null) {\r
422       leaves.addElement(node);\r
423       return leaves;\r
424     } else {\r
425       findLeaves((SequenceNode)node.left(),leaves);\r
426       findLeaves((SequenceNode)node.right(),leaves);\r
427     }\r
428     return leaves;\r
429   }\r
430 \r
431   public Object findLeaf(SequenceNode node, int count) {\r
432     found = _findLeaf(node,count);\r
433 \r
434     return found;\r
435   }\r
436   public Object _findLeaf(SequenceNode node,int count) {\r
437     if (node == null) {\r
438       return null;\r
439     }\r
440     if (node.ycount == count) {\r
441       found = node.element();\r
442       return found;\r
443     } else {\r
444       _findLeaf((SequenceNode)node.left(),count);\r
445       _findLeaf((SequenceNode)node.right(),count);\r
446     }\r
447 \r
448     return found;\r
449   }\r
450 \r
451   public void printNode(SequenceNode node) {\r
452     if (node == null) {\r
453       return;\r
454     }\r
455     if (node.left() == null && node.right() == null) {\r
456       System.out.println("Leaf = " + ((SequenceI)node.element()).getName());\r
457       System.out.println("Dist " + ((SequenceNode)node).dist);\r
458       System.out.println("Boot " + node.getBootstrap());\r
459     } else {\r
460       System.out.println("Dist " + ((SequenceNode)node).dist);\r
461       printNode((SequenceNode)node.left());\r
462       printNode((SequenceNode)node.right());\r
463     }\r
464   }\r
465   public void findMaxDist(SequenceNode node) {\r
466     if (node == null) {\r
467       return;\r
468     }\r
469     if (node.left() == null && node.right() == null) {\r
470 \r
471       float dist = ((SequenceNode)node).dist;\r
472       if (dist > maxDistValue) {\r
473           maxdist      = (SequenceNode)node;\r
474           maxDistValue = dist;\r
475       }\r
476     } else {\r
477       findMaxDist((SequenceNode)node.left());\r
478       findMaxDist((SequenceNode)node.right());\r
479     }\r
480   }\r
481     public Vector getGroups() {\r
482         return groups;\r
483     }\r
484     public float getMaxHeight() {\r
485         return maxheight;\r
486     }\r
487   public void  groupNodes(SequenceNode node, float threshold) {\r
488     if (node == null) {\r
489       return;\r
490     }\r
491 \r
492     if (node.height/maxheight > threshold) {\r
493       groups.addElement(node);\r
494     } else {\r
495       groupNodes((SequenceNode)node.left(),threshold);\r
496       groupNodes((SequenceNode)node.right(),threshold);\r
497     }\r
498   }\r
499 \r
500   public float findHeight(SequenceNode node) {\r
501 \r
502     if (node == null) {\r
503       return maxheight;\r
504     }\r
505 \r
506     if (node.left() == null && node.right() == null) {\r
507       node.height = ((SequenceNode)node.parent()).height + node.dist;\r
508 \r
509       if (node.height > maxheight) {\r
510         return node.height;\r
511       } else {\r
512         return maxheight;\r
513       }\r
514     } else {\r
515       if (node.parent() != null) {\r
516         node.height = ((SequenceNode)node.parent()).height + node.dist;\r
517       } else {\r
518         maxheight = 0;\r
519         node.height = (float)0.0;\r
520       }\r
521 \r
522       maxheight = findHeight((SequenceNode)(node.left()));\r
523       maxheight = findHeight((SequenceNode)(node.right()));\r
524     }\r
525     return maxheight;\r
526   }\r
527   public SequenceNode reRoot() {\r
528     if (maxdist != null) {\r
529       ycount = 0;\r
530       float tmpdist = maxdist.dist;\r
531 \r
532       // New top\r
533       SequenceNode sn = new SequenceNode();\r
534       sn.setParent(null);\r
535 \r
536       // New right hand of top\r
537       SequenceNode snr = (SequenceNode)maxdist.parent();\r
538       changeDirection(snr,maxdist);\r
539       System.out.println("Printing reversed tree");\r
540       printN(snr);\r
541       snr.dist = tmpdist/2;\r
542       maxdist.dist = tmpdist/2;\r
543 \r
544       snr.setParent(sn);\r
545       maxdist.setParent(sn);\r
546 \r
547       sn.setRight(snr);\r
548       sn.setLeft(maxdist);\r
549 \r
550       top = sn;\r
551 \r
552       ycount = 0;\r
553       reCount(top);\r
554       findHeight(top);\r
555 \r
556     }\r
557     return top;\r
558   }\r
559   public static void printN(SequenceNode node) {\r
560     if (node == null) {\r
561       return;\r
562     }\r
563 \r
564     if (node.left() != null && node.right() != null) {\r
565       printN((SequenceNode)node.left());\r
566       printN((SequenceNode)node.right());\r
567     } else {\r
568       System.out.println(" name = " + ((SequenceI)node.element()).getName());\r
569     }\r
570     System.out.println(" dist = " + ((SequenceNode)node).dist + " " + ((SequenceNode)node).count + " " + ((SequenceNode)node).height);\r
571   }\r
572 \r
573     public void reCount(SequenceNode node) {\r
574         ycount = 0;\r
575         _reCount(node);\r
576     }\r
577   public void _reCount(SequenceNode node) {\r
578     if (node == null) {\r
579       return;\r
580     }\r
581 \r
582     if (node.left() != null && node.right() != null) {\r
583       _reCount((SequenceNode)node.left());\r
584       _reCount((SequenceNode)node.right());\r
585 \r
586       SequenceNode l = (SequenceNode)node.left();\r
587       SequenceNode r = (SequenceNode)node.right();\r
588 \r
589       ((SequenceNode)node).count  = l.count + r.count;\r
590       ((SequenceNode)node).ycount = (l.ycount + r.ycount)/2;\r
591 \r
592     } else {\r
593       ((SequenceNode)node).count = 1;\r
594       ((SequenceNode)node).ycount = ycount++;\r
595     }\r
596 \r
597   }\r
598     public void swapNodes(SequenceNode node) {\r
599         if (node == null) {\r
600             return;\r
601         }\r
602         SequenceNode tmp = (SequenceNode)node.left();\r
603 \r
604         node.setLeft(node.right());\r
605         node.setRight(tmp);\r
606     }\r
607   public void changeDirection(SequenceNode node, SequenceNode dir) {\r
608     if (node == null) {\r
609       return;\r
610     }\r
611     if (node.parent() != top) {\r
612       changeDirection((SequenceNode)node.parent(), node);\r
613 \r
614       SequenceNode tmp = (SequenceNode)node.parent();\r
615 \r
616       if (dir == node.left()) {\r
617         node.setParent(dir);\r
618         node.setLeft(tmp);\r
619       } else if (dir == node.right()) {\r
620         node.setParent(dir);\r
621         node.setRight(tmp);\r
622       }\r
623 \r
624     } else {\r
625       if (dir == node.left()) {\r
626         node.setParent(node.left());\r
627 \r
628         if (top.left() == node) {\r
629           node.setRight(top.right());\r
630         } else {\r
631           node.setRight(top.left());\r
632         }\r
633       } else {\r
634         node.setParent(node.right());\r
635 \r
636         if (top.left() == node) {\r
637           node.setLeft(top.right());\r
638         } else {\r
639           node.setLeft(top.left());\r
640         }\r
641       }\r
642     }\r
643   }\r
644     public void setMaxDist(SequenceNode node) {\r
645         this.maxdist = maxdist;\r
646     }\r
647     public SequenceNode getMaxDist() {\r
648         return maxdist;\r
649     }\r
650     public SequenceNode getTopNode() {\r
651         return top;\r
652     }\r
653 \r
654 }\r
655 \r
656 \r
657 \r
658 class Cluster {\r
659 \r
660   int[] value;\r
661 \r
662   public Cluster(int[] value) {\r
663     this.value = value;\r
664   }\r
665 \r
666 }\r
667 \r