f28c6bc4acd381e4db3884d047f031efcf764944
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 package jalview.analysis;
2
3 import jalview.api.analysis.DistanceScoreModelI;
4 import jalview.api.analysis.ScoreModelI;
5 import jalview.api.analysis.SimilarityParamsI;
6 import jalview.api.analysis.SimilarityScoreModelI;
7 import jalview.datamodel.AlignmentView;
8 import jalview.datamodel.CigarArray;
9 import jalview.datamodel.SeqCigar;
10 import jalview.datamodel.SequenceI;
11 import jalview.datamodel.SequenceNode;
12 import jalview.math.MatrixI;
13 import jalview.viewmodel.AlignmentViewport;
14
15 import java.util.BitSet;
16 import java.util.Vector;
17
18 public abstract class TreeBuilder
19 {
20   public static final String AVERAGE_DISTANCE = "AV";
21
22   public static final String NEIGHBOUR_JOINING = "NJ";
23
24   protected Vector<BitSet> clusters;
25
26   protected SequenceI[] sequences;
27
28   public AlignmentView seqData;
29
30   protected BitSet done;
31
32   protected int noseqs;
33
34   int noClus;
35
36   protected MatrixI distances;
37
38   protected int mini;
39
40   protected int minj;
41
42   protected double ri;
43
44   protected double rj;
45
46   SequenceNode maxdist;
47
48   SequenceNode top;
49
50   double maxDistValue;
51
52   double maxheight;
53
54   int ycount;
55
56   Vector<SequenceNode> node;
57
58   /**
59    * Constructor
60    * 
61    * @param av
62    * @param sm
63    * @param scoreParameters
64    */
65   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
66           SimilarityParamsI scoreParameters)
67   {
68     int start, end;
69     boolean selview = av.getSelectionGroup() != null
70             && av.getSelectionGroup().getSize() > 1;
71     AlignmentView seqStrings = av.getAlignmentView(selview);
72     if (!selview)
73     {
74       start = 0;
75       end = av.getAlignment().getWidth();
76       this.sequences = av.getAlignment().getSequencesArray();
77     }
78     else
79     {
80       start = av.getSelectionGroup().getStartRes();
81       end = av.getSelectionGroup().getEndRes() + 1;
82       this.sequences = av.getSelectionGroup().getSequencesInOrder(
83               av.getAlignment());
84     }
85
86     init(seqStrings, start, end);
87
88     computeTree(sm, scoreParameters);
89   }
90
91   public SequenceI[] getSequences()
92   {
93     return sequences;
94   }
95
96   /**
97    * DOCUMENT ME!
98    * 
99    * @param nd
100    *          DOCUMENT ME!
101    * 
102    * @return DOCUMENT ME!
103    */
104   double findHeight(SequenceNode nd)
105   {
106     if (nd == null)
107     {
108       return maxheight;
109     }
110   
111     if ((nd.left() == null) && (nd.right() == null))
112     {
113       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
114   
115       if (nd.height > maxheight)
116       {
117         return nd.height;
118       }
119       else
120       {
121         return maxheight;
122       }
123     }
124     else
125     {
126       if (nd.parent() != null)
127       {
128         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
129       }
130       else
131       {
132         maxheight = 0;
133         nd.height = (float) 0.0;
134       }
135   
136       maxheight = findHeight((SequenceNode) (nd.left()));
137       maxheight = findHeight((SequenceNode) (nd.right()));
138     }
139   
140     return maxheight;
141   }
142
143   /**
144    * DOCUMENT ME!
145    * 
146    * @param nd
147    *          DOCUMENT ME!
148    */
149   void reCount(SequenceNode nd)
150   {
151     ycount = 0;
152     // _lycount = 0;
153     // _lylimit = this.node.size();
154     _reCount(nd);
155   }
156
157   /**
158    * DOCUMENT ME!
159    * 
160    * @param nd
161    *          DOCUMENT ME!
162    */
163   void _reCount(SequenceNode nd)
164   {
165     // if (_lycount<_lylimit)
166     // {
167     // System.err.println("Warning: depth of _recount greater than number of nodes.");
168     // }
169     if (nd == null)
170     {
171       return;
172     }
173     // _lycount++;
174   
175     if ((nd.left() != null) && (nd.right() != null))
176     {
177   
178       _reCount((SequenceNode) nd.left());
179       _reCount((SequenceNode) nd.right());
180   
181       SequenceNode l = (SequenceNode) nd.left();
182       SequenceNode r = (SequenceNode) nd.right();
183   
184       nd.count = l.count + r.count;
185       nd.ycount = (l.ycount + r.ycount) / 2;
186     }
187     else
188     {
189       nd.count = 1;
190       nd.ycount = ycount++;
191     }
192     // _lycount--;
193   }
194
195   /**
196    * DOCUMENT ME!
197    * 
198    * @return DOCUMENT ME!
199    */
200   public SequenceNode getTopNode()
201   {
202     return top;
203   }
204
205   /**
206    * 
207    * @return true if tree has real distances
208    */
209   public boolean hasDistances()
210   {
211     return true;
212   }
213
214   /**
215    * 
216    * @return true if tree has real bootstrap values
217    */
218   public boolean hasBootstrap()
219   {
220     return false;
221   }
222
223   public boolean hasRootDistance()
224   {
225     return true;
226   }
227
228   /**
229    * Form clusters by grouping sub-clusters, starting from one sequence per
230    * cluster, and finishing when only two clusters remain
231    */
232   void cluster()
233   {
234     while (noClus > 2)
235     {
236       findMinDistance();
237   
238       joinClusters(mini, minj);
239   
240       noClus--;
241     }
242   
243     int rightChild = done.nextClearBit(0);
244     int leftChild = done.nextClearBit(rightChild + 1);
245   
246     joinClusters(leftChild, rightChild);
247     top = (node.elementAt(leftChild));
248   
249     reCount(top);
250     findHeight(top);
251     findMaxDist(top);
252   }
253
254   protected abstract double findMinDistance();
255
256   /**
257    * Calculates the tree using the given score model and parameters, and the
258    * configured tree type
259    * <p>
260    * If the score model computes pairwise distance scores, then these are used
261    * directly to derive the tree
262    * <p>
263    * If the score model computes similarity scores, then the range of the scores
264    * is reversed to give a distance measure, and this is used to derive the tree
265    * 
266    * @param sm
267    * @param scoreOptions
268    */
269   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
270   {
271     if (sm instanceof DistanceScoreModelI)
272     {
273       distances = ((DistanceScoreModelI) sm).findDistances(seqData,
274               scoreOptions);
275     }
276     else if (sm instanceof SimilarityScoreModelI)
277     {
278       /*
279        * compute similarity and invert it to give a distance measure
280        */
281       MatrixI result = ((SimilarityScoreModelI) sm).findSimilarities(
282               seqData, scoreOptions);
283       result.reverseRange(true);
284       distances = result;
285     }
286   
287     makeLeaves();
288   
289     noClus = clusters.size();
290   
291     cluster();
292   }
293
294   /**
295    * DOCUMENT ME!
296    * 
297    * @param nd
298    *          DOCUMENT ME!
299    */
300   void findMaxDist(SequenceNode nd)
301   {
302     if (nd == null)
303     {
304       return;
305     }
306   
307     if ((nd.left() == null) && (nd.right() == null))
308     {
309       double dist = nd.dist;
310   
311       if (dist > maxDistValue)
312       {
313         maxdist = nd;
314         maxDistValue = dist;
315       }
316     }
317     else
318     {
319       findMaxDist((SequenceNode) nd.left());
320       findMaxDist((SequenceNode) nd.right());
321     }
322   }
323
324   /**
325    * DOCUMENT ME!
326    * 
327    * @param i
328    *          DOCUMENT ME!
329    * @param j
330    *          DOCUMENT ME!
331    * 
332    * @return DOCUMENT ME!
333    */
334   protected double findr(int i, int j)
335   {
336     double tmp = 1;
337   
338     for (int k = 0; k < noseqs; k++)
339     {
340       if ((k != i) && (k != j) && (!done.get(k)))
341       {
342         tmp = tmp + distances.getValue(i, k);
343       }
344     }
345   
346     if (noClus > 2)
347     {
348       tmp = tmp / (noClus - 2);
349     }
350   
351     return tmp;
352   }
353
354   protected void init(AlignmentView seqView, int start, int end)
355   {
356     this.node = new Vector<SequenceNode>();
357     if (seqView != null)
358     {
359       this.seqData = seqView;
360     }
361     else
362     {
363       SeqCigar[] seqs = new SeqCigar[sequences.length];
364       for (int i = 0; i < sequences.length; i++)
365       {
366         seqs[i] = new SeqCigar(sequences[i], start, end);
367       }
368       CigarArray sdata = new CigarArray(seqs);
369       sdata.addOperation(CigarArray.M, end - start + 1);
370       this.seqData = new AlignmentView(sdata, start);
371     }
372   
373     /*
374      * count the non-null sequences
375      */
376     noseqs = 0;
377   
378     done = new BitSet();
379   
380     for (SequenceI seq : sequences)
381     {
382       if (seq != null)
383       {
384         noseqs++;
385       }
386     }
387   }
388
389   /**
390    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
391    * 
392    * @param i
393    * @param j
394    */
395   void joinClusters(final int i, final int j)
396   {
397     double dist = distances.getValue(i, j);
398   
399     ri = findr(i, j);
400     rj = findr(j, i);
401   
402     findClusterDistance(i, j);
403   
404     SequenceNode sn = new SequenceNode();
405   
406     sn.setLeft((node.elementAt(i)));
407     sn.setRight((node.elementAt(j)));
408   
409     SequenceNode tmpi = (node.elementAt(i));
410     SequenceNode tmpj = (node.elementAt(j));
411   
412     findNewDistances(tmpi, tmpj, dist);
413   
414     tmpi.setParent(sn);
415     tmpj.setParent(sn);
416   
417     node.setElementAt(sn, i);
418   
419     /*
420      * move the members of cluster(j) to cluster(i)
421      * and mark cluster j as out of the game
422      */
423     clusters.get(i).or(clusters.get(j));
424     clusters.get(j).clear();
425     done.set(j);
426   }
427
428   protected abstract void findNewDistances(SequenceNode tmpi, SequenceNode tmpj,
429           double dist);
430
431   /**
432    * Calculates and saves the distance between the combination of cluster(i) and
433    * cluster(j) and all other clusters. The form of the calculation depends on
434    * the tree clustering method being used.
435    * 
436    * @param i
437    * @param j
438    */
439   protected abstract void findClusterDistance(int i, int j);
440
441   /**
442    * Start by making a cluster for each individual sequence
443    */
444   void makeLeaves()
445   {
446     clusters = new Vector<BitSet>();
447   
448     for (int i = 0; i < noseqs; i++)
449     {
450       SequenceNode sn = new SequenceNode();
451   
452       sn.setElement(sequences[i]);
453       sn.setName(sequences[i].getName());
454       node.addElement(sn);
455       BitSet bs = new BitSet();
456       bs.set(i);
457       clusters.addElement(bs);
458     }
459   }
460
461 }