5381880e05d8b31c1894a1cbb1fe23084457a1b0
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 package jalview.analysis;
2
3 import jalview.api.analysis.DistanceScoreModelI;
4 import jalview.api.analysis.ScoreModelI;
5 import jalview.api.analysis.SimilarityParamsI;
6 import jalview.api.analysis.SimilarityScoreModelI;
7 import jalview.datamodel.AlignmentView;
8 import jalview.datamodel.CigarArray;
9 import jalview.datamodel.SeqCigar;
10 import jalview.datamodel.SequenceI;
11 import jalview.datamodel.SequenceNode;
12 import jalview.math.MatrixI;
13 import jalview.viewmodel.AlignmentViewport;
14
15 import java.util.BitSet;
16 import java.util.Vector;
17
18 public abstract class TreeBuilder
19 {
20   public static final String AVERAGE_DISTANCE = "AV";
21
22   public static final String NEIGHBOUR_JOINING = "NJ";
23
24   protected Vector<BitSet> clusters;
25
26   protected SequenceI[] sequences;
27
28   public AlignmentView seqData;
29
30   protected BitSet done;
31
32   protected int noseqs;
33
34   int noClus;
35
36   protected MatrixI distances;
37
38   protected int mini;
39
40   protected int minj;
41
42   protected double ri;
43
44   protected double rj;
45
46   SequenceNode maxdist;
47
48   SequenceNode top;
49
50   double maxDistValue;
51
52   double maxheight;
53
54   int ycount;
55
56   Vector<SequenceNode> node;
57
58   private AlignmentView seqStrings;
59
60   /**
61    * Constructor
62    * 
63    * @param av
64    * @param sm
65    * @param scoreParameters
66    */
67   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
68           SimilarityParamsI scoreParameters)
69   {
70     int start, end;
71     boolean selview = av.getSelectionGroup() != null
72             && av.getSelectionGroup().getSize() > 1;
73     seqStrings = av.getAlignmentView(selview);
74     if (!selview)
75     {
76       start = 0;
77       end = av.getAlignment().getWidth();
78       this.sequences = av.getAlignment().getSequencesArray();
79     }
80     else
81     {
82       start = av.getSelectionGroup().getStartRes();
83       end = av.getSelectionGroup().getEndRes() + 1;
84       this.sequences = av.getSelectionGroup().getSequencesInOrder(
85               av.getAlignment());
86     }
87
88     init(seqStrings, start, end);
89
90     computeTree(sm, scoreParameters);
91   }
92
93   public SequenceI[] getSequences()
94   {
95     return sequences;
96   }
97
98   /**
99    * DOCUMENT ME!
100    * 
101    * @param nd
102    *          DOCUMENT ME!
103    * 
104    * @return DOCUMENT ME!
105    */
106   double findHeight(SequenceNode nd)
107   {
108     if (nd == null)
109     {
110       return maxheight;
111     }
112   
113     if ((nd.left() == null) && (nd.right() == null))
114     {
115       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
116   
117       if (nd.height > maxheight)
118       {
119         return nd.height;
120       }
121       else
122       {
123         return maxheight;
124       }
125     }
126     else
127     {
128       if (nd.parent() != null)
129       {
130         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
131       }
132       else
133       {
134         maxheight = 0;
135         nd.height = (float) 0.0;
136       }
137   
138       maxheight = findHeight((SequenceNode) (nd.left()));
139       maxheight = findHeight((SequenceNode) (nd.right()));
140     }
141   
142     return maxheight;
143   }
144
145   /**
146    * DOCUMENT ME!
147    * 
148    * @param nd
149    *          DOCUMENT ME!
150    */
151   void reCount(SequenceNode nd)
152   {
153     ycount = 0;
154     // _lycount = 0;
155     // _lylimit = this.node.size();
156     _reCount(nd);
157   }
158
159   /**
160    * DOCUMENT ME!
161    * 
162    * @param nd
163    *          DOCUMENT ME!
164    */
165   void _reCount(SequenceNode nd)
166   {
167     // if (_lycount<_lylimit)
168     // {
169     // System.err.println("Warning: depth of _recount greater than number of nodes.");
170     // }
171     if (nd == null)
172     {
173       return;
174     }
175     // _lycount++;
176   
177     if ((nd.left() != null) && (nd.right() != null))
178     {
179   
180       _reCount((SequenceNode) nd.left());
181       _reCount((SequenceNode) nd.right());
182   
183       SequenceNode l = (SequenceNode) nd.left();
184       SequenceNode r = (SequenceNode) nd.right();
185   
186       nd.count = l.count + r.count;
187       nd.ycount = (l.ycount + r.ycount) / 2;
188     }
189     else
190     {
191       nd.count = 1;
192       nd.ycount = ycount++;
193     }
194     // _lycount--;
195   }
196
197   /**
198    * DOCUMENT ME!
199    * 
200    * @return DOCUMENT ME!
201    */
202   public SequenceNode getTopNode()
203   {
204     return top;
205   }
206
207   /**
208    * 
209    * @return true if tree has real distances
210    */
211   public boolean hasDistances()
212   {
213     return true;
214   }
215
216   /**
217    * 
218    * @return true if tree has real bootstrap values
219    */
220   public boolean hasBootstrap()
221   {
222     return false;
223   }
224
225   public boolean hasRootDistance()
226   {
227     return true;
228   }
229
230   /**
231    * Form clusters by grouping sub-clusters, starting from one sequence per
232    * cluster, and finishing when only two clusters remain
233    */
234   void cluster()
235   {
236     while (noClus > 2)
237     {
238       findMinDistance();
239   
240       joinClusters(mini, minj);
241   
242       noClus--;
243     }
244   
245     int rightChild = done.nextClearBit(0);
246     int leftChild = done.nextClearBit(rightChild + 1);
247   
248     joinClusters(leftChild, rightChild);
249     top = (node.elementAt(leftChild));
250   
251     reCount(top);
252     findHeight(top);
253     findMaxDist(top);
254   }
255
256   /**
257    * Returns the minimum distance between two clusters, and also sets the
258    * indices of the clusters in fields mini and minj
259    * 
260    * @return
261    */
262   protected abstract double findMinDistance();
263
264   /**
265    * Calculates the tree using the given score model and parameters, and the
266    * configured tree type
267    * <p>
268    * If the score model computes pairwise distance scores, then these are used
269    * directly to derive the tree
270    * <p>
271    * If the score model computes similarity scores, then the range of the scores
272    * is reversed to give a distance measure, and this is used to derive the tree
273    * 
274    * @param sm
275    * @param scoreOptions
276    */
277   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
278   {
279     if (sm instanceof DistanceScoreModelI)
280     {
281       distances = ((DistanceScoreModelI) sm).findDistances(seqData,
282               scoreOptions);
283     }
284     else if (sm instanceof SimilarityScoreModelI)
285     {
286       /*
287        * compute similarity and invert it to give a distance measure
288        * reverseRange(true) converts maximum similarity to zero distance
289        */
290       MatrixI result = ((SimilarityScoreModelI) sm).findSimilarities(
291               seqData, scoreOptions);
292       result.reverseRange(true);
293       distances = result;
294     }
295   
296     makeLeaves();
297   
298     noClus = clusters.size();
299   
300     cluster();
301   }
302
303   /**
304    * Finds the node, at or below the given node, with the maximum distance, and
305    * saves the node and the distance value
306    * 
307    * @param nd
308    */
309   void findMaxDist(SequenceNode nd)
310   {
311     if (nd == null)
312     {
313       return;
314     }
315   
316     if ((nd.left() == null) && (nd.right() == null))
317     {
318       double dist = nd.dist;
319   
320       if (dist > maxDistValue)
321       {
322         maxdist = nd;
323         maxDistValue = dist;
324       }
325     }
326     else
327     {
328       findMaxDist((SequenceNode) nd.left());
329       findMaxDist((SequenceNode) nd.right());
330     }
331   }
332
333   /**
334    * Calculates and returns r, whatever that is
335    * 
336    * @param i
337    * @param j
338    * 
339    * @return
340    */
341   protected double findr(int i, int j)
342   {
343     double tmp = 1;
344   
345     for (int k = 0; k < noseqs; k++)
346     {
347       if ((k != i) && (k != j) && (!done.get(k)))
348       {
349         tmp = tmp + distances.getValue(i, k);
350       }
351     }
352   
353     if (noClus > 2)
354     {
355       tmp = tmp / (noClus - 2);
356     }
357   
358     return tmp;
359   }
360
361   protected void init(AlignmentView seqView, int start, int end)
362   {
363     this.node = new Vector<SequenceNode>();
364     if (seqView != null)
365     {
366       this.seqData = seqView;
367     }
368     else
369     {
370       SeqCigar[] seqs = new SeqCigar[sequences.length];
371       for (int i = 0; i < sequences.length; i++)
372       {
373         seqs[i] = new SeqCigar(sequences[i], start, end);
374       }
375       CigarArray sdata = new CigarArray(seqs);
376       sdata.addOperation(CigarArray.M, end - start + 1);
377       this.seqData = new AlignmentView(sdata, start);
378     }
379   
380     /*
381      * count the non-null sequences
382      */
383     noseqs = 0;
384   
385     done = new BitSet();
386   
387     for (SequenceI seq : sequences)
388     {
389       if (seq != null)
390       {
391         noseqs++;
392       }
393     }
394   }
395
396   /**
397    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
398    * 
399    * @param i
400    * @param j
401    */
402   void joinClusters(final int i, final int j)
403   {
404     double dist = distances.getValue(i, j);
405   
406     ri = findr(i, j);
407     rj = findr(j, i);
408   
409     findClusterDistance(i, j);
410   
411     SequenceNode sn = new SequenceNode();
412   
413     sn.setLeft((node.elementAt(i)));
414     sn.setRight((node.elementAt(j)));
415   
416     SequenceNode tmpi = (node.elementAt(i));
417     SequenceNode tmpj = (node.elementAt(j));
418   
419     findNewDistances(tmpi, tmpj, dist);
420   
421     tmpi.setParent(sn);
422     tmpj.setParent(sn);
423   
424     node.setElementAt(sn, i);
425   
426     /*
427      * move the members of cluster(j) to cluster(i)
428      * and mark cluster j as out of the game
429      */
430     clusters.get(i).or(clusters.get(j));
431     clusters.get(j).clear();
432     done.set(j);
433   }
434
435   /*
436    * Computes and stores new distances for nodei and nodej, given the previous
437    * distance between them
438    */
439   protected abstract void findNewDistances(SequenceNode nodei,
440           SequenceNode nodej, double previousDistance);
441
442   /**
443    * Calculates and saves the distance between the combination of cluster(i) and
444    * cluster(j) and all other clusters. The form of the calculation depends on
445    * the tree clustering method being used.
446    * 
447    * @param i
448    * @param j
449    */
450   protected abstract void findClusterDistance(int i, int j);
451
452   /**
453    * Start by making a cluster for each individual sequence
454    */
455   void makeLeaves()
456   {
457     clusters = new Vector<BitSet>();
458   
459     for (int i = 0; i < noseqs; i++)
460     {
461       SequenceNode sn = new SequenceNode();
462   
463       sn.setElement(sequences[i]);
464       sn.setName(sequences[i].getName());
465       node.addElement(sn);
466       BitSet bs = new BitSet();
467       bs.set(i);
468       clusters.addElement(bs);
469     }
470   }
471
472   public AlignmentView getOriginalData()
473   {
474     return seqStrings;
475   }
476
477 }