JAL-2393 code tidy and comments, PIDDistanceModel deprecated
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 package jalview.analysis;
2
3 import jalview.api.analysis.DistanceScoreModelI;
4 import jalview.api.analysis.ScoreModelI;
5 import jalview.api.analysis.SimilarityParamsI;
6 import jalview.api.analysis.SimilarityScoreModelI;
7 import jalview.datamodel.AlignmentView;
8 import jalview.datamodel.CigarArray;
9 import jalview.datamodel.SeqCigar;
10 import jalview.datamodel.SequenceI;
11 import jalview.datamodel.SequenceNode;
12 import jalview.math.MatrixI;
13 import jalview.viewmodel.AlignmentViewport;
14
15 import java.util.BitSet;
16 import java.util.Vector;
17
18 public abstract class TreeBuilder
19 {
20   public static final String AVERAGE_DISTANCE = "AV";
21
22   public static final String NEIGHBOUR_JOINING = "NJ";
23
24   protected Vector<BitSet> clusters;
25
26   protected SequenceI[] sequences;
27
28   public AlignmentView seqData;
29
30   protected BitSet done;
31
32   protected int noseqs;
33
34   int noClus;
35
36   protected MatrixI distances;
37
38   protected int mini;
39
40   protected int minj;
41
42   protected double ri;
43
44   protected double rj;
45
46   SequenceNode maxdist;
47
48   SequenceNode top;
49
50   double maxDistValue;
51
52   double maxheight;
53
54   int ycount;
55
56   Vector<SequenceNode> node;
57
58   /**
59    * Constructor
60    * 
61    * @param av
62    * @param sm
63    * @param scoreParameters
64    */
65   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
66           SimilarityParamsI scoreParameters)
67   {
68     int start, end;
69     boolean selview = av.getSelectionGroup() != null
70             && av.getSelectionGroup().getSize() > 1;
71     AlignmentView seqStrings = av.getAlignmentView(selview);
72     if (!selview)
73     {
74       start = 0;
75       end = av.getAlignment().getWidth();
76       this.sequences = av.getAlignment().getSequencesArray();
77     }
78     else
79     {
80       start = av.getSelectionGroup().getStartRes();
81       end = av.getSelectionGroup().getEndRes() + 1;
82       this.sequences = av.getSelectionGroup().getSequencesInOrder(
83               av.getAlignment());
84     }
85
86     init(seqStrings, start, end);
87
88     computeTree(sm, scoreParameters);
89   }
90
91   public SequenceI[] getSequences()
92   {
93     return sequences;
94   }
95
96   /**
97    * DOCUMENT ME!
98    * 
99    * @param nd
100    *          DOCUMENT ME!
101    * 
102    * @return DOCUMENT ME!
103    */
104   double findHeight(SequenceNode nd)
105   {
106     if (nd == null)
107     {
108       return maxheight;
109     }
110   
111     if ((nd.left() == null) && (nd.right() == null))
112     {
113       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
114   
115       if (nd.height > maxheight)
116       {
117         return nd.height;
118       }
119       else
120       {
121         return maxheight;
122       }
123     }
124     else
125     {
126       if (nd.parent() != null)
127       {
128         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
129       }
130       else
131       {
132         maxheight = 0;
133         nd.height = (float) 0.0;
134       }
135   
136       maxheight = findHeight((SequenceNode) (nd.left()));
137       maxheight = findHeight((SequenceNode) (nd.right()));
138     }
139   
140     return maxheight;
141   }
142
143   /**
144    * DOCUMENT ME!
145    * 
146    * @param nd
147    *          DOCUMENT ME!
148    */
149   void reCount(SequenceNode nd)
150   {
151     ycount = 0;
152     // _lycount = 0;
153     // _lylimit = this.node.size();
154     _reCount(nd);
155   }
156
157   /**
158    * DOCUMENT ME!
159    * 
160    * @param nd
161    *          DOCUMENT ME!
162    */
163   void _reCount(SequenceNode nd)
164   {
165     // if (_lycount<_lylimit)
166     // {
167     // System.err.println("Warning: depth of _recount greater than number of nodes.");
168     // }
169     if (nd == null)
170     {
171       return;
172     }
173     // _lycount++;
174   
175     if ((nd.left() != null) && (nd.right() != null))
176     {
177   
178       _reCount((SequenceNode) nd.left());
179       _reCount((SequenceNode) nd.right());
180   
181       SequenceNode l = (SequenceNode) nd.left();
182       SequenceNode r = (SequenceNode) nd.right();
183   
184       nd.count = l.count + r.count;
185       nd.ycount = (l.ycount + r.ycount) / 2;
186     }
187     else
188     {
189       nd.count = 1;
190       nd.ycount = ycount++;
191     }
192     // _lycount--;
193   }
194
195   /**
196    * DOCUMENT ME!
197    * 
198    * @return DOCUMENT ME!
199    */
200   public SequenceNode getTopNode()
201   {
202     return top;
203   }
204
205   /**
206    * 
207    * @return true if tree has real distances
208    */
209   public boolean hasDistances()
210   {
211     return true;
212   }
213
214   /**
215    * 
216    * @return true if tree has real bootstrap values
217    */
218   public boolean hasBootstrap()
219   {
220     return false;
221   }
222
223   public boolean hasRootDistance()
224   {
225     return true;
226   }
227
228   /**
229    * Form clusters by grouping sub-clusters, starting from one sequence per
230    * cluster, and finishing when only two clusters remain
231    */
232   void cluster()
233   {
234     while (noClus > 2)
235     {
236       findMinDistance();
237   
238       joinClusters(mini, minj);
239   
240       noClus--;
241     }
242   
243     int rightChild = done.nextClearBit(0);
244     int leftChild = done.nextClearBit(rightChild + 1);
245   
246     joinClusters(leftChild, rightChild);
247     top = (node.elementAt(leftChild));
248   
249     reCount(top);
250     findHeight(top);
251     findMaxDist(top);
252   }
253
254   protected abstract double findMinDistance();
255
256   /**
257    * Calculates the tree using the given score model and parameters, and the
258    * configured tree type
259    * <p>
260    * If the score model computes pairwise distance scores, then these are used
261    * directly to derive the tree
262    * <p>
263    * If the score model computes similarity scores, then the range of the scores
264    * is reversed to give a distance measure, and this is used to derive the tree
265    * 
266    * @param sm
267    * @param scoreOptions
268    */
269   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
270   {
271     if (sm instanceof DistanceScoreModelI)
272     {
273       distances = ((DistanceScoreModelI) sm).findDistances(seqData,
274               scoreOptions);
275     }
276     else if (sm instanceof SimilarityScoreModelI)
277     {
278       /*
279        * compute similarity and invert it to give a distance measure
280        * reverseRange(true) converts maximum similarity to zero distance
281        */
282       MatrixI result = ((SimilarityScoreModelI) sm).findSimilarities(
283               seqData, scoreOptions);
284       result.reverseRange(true);
285       distances = result;
286     }
287   
288     makeLeaves();
289   
290     noClus = clusters.size();
291   
292     cluster();
293   }
294
295   /**
296    * DOCUMENT ME!
297    * 
298    * @param nd
299    *          DOCUMENT ME!
300    */
301   void findMaxDist(SequenceNode nd)
302   {
303     if (nd == null)
304     {
305       return;
306     }
307   
308     if ((nd.left() == null) && (nd.right() == null))
309     {
310       double dist = nd.dist;
311   
312       if (dist > maxDistValue)
313       {
314         maxdist = nd;
315         maxDistValue = dist;
316       }
317     }
318     else
319     {
320       findMaxDist((SequenceNode) nd.left());
321       findMaxDist((SequenceNode) nd.right());
322     }
323   }
324
325   /**
326    * DOCUMENT ME!
327    * 
328    * @param i
329    *          DOCUMENT ME!
330    * @param j
331    *          DOCUMENT ME!
332    * 
333    * @return DOCUMENT ME!
334    */
335   protected double findr(int i, int j)
336   {
337     double tmp = 1;
338   
339     for (int k = 0; k < noseqs; k++)
340     {
341       if ((k != i) && (k != j) && (!done.get(k)))
342       {
343         tmp = tmp + distances.getValue(i, k);
344       }
345     }
346   
347     if (noClus > 2)
348     {
349       tmp = tmp / (noClus - 2);
350     }
351   
352     return tmp;
353   }
354
355   protected void init(AlignmentView seqView, int start, int end)
356   {
357     this.node = new Vector<SequenceNode>();
358     if (seqView != null)
359     {
360       this.seqData = seqView;
361     }
362     else
363     {
364       SeqCigar[] seqs = new SeqCigar[sequences.length];
365       for (int i = 0; i < sequences.length; i++)
366       {
367         seqs[i] = new SeqCigar(sequences[i], start, end);
368       }
369       CigarArray sdata = new CigarArray(seqs);
370       sdata.addOperation(CigarArray.M, end - start + 1);
371       this.seqData = new AlignmentView(sdata, start);
372     }
373   
374     /*
375      * count the non-null sequences
376      */
377     noseqs = 0;
378   
379     done = new BitSet();
380   
381     for (SequenceI seq : sequences)
382     {
383       if (seq != null)
384       {
385         noseqs++;
386       }
387     }
388   }
389
390   /**
391    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
392    * 
393    * @param i
394    * @param j
395    */
396   void joinClusters(final int i, final int j)
397   {
398     double dist = distances.getValue(i, j);
399   
400     ri = findr(i, j);
401     rj = findr(j, i);
402   
403     findClusterDistance(i, j);
404   
405     SequenceNode sn = new SequenceNode();
406   
407     sn.setLeft((node.elementAt(i)));
408     sn.setRight((node.elementAt(j)));
409   
410     SequenceNode tmpi = (node.elementAt(i));
411     SequenceNode tmpj = (node.elementAt(j));
412   
413     findNewDistances(tmpi, tmpj, dist);
414   
415     tmpi.setParent(sn);
416     tmpj.setParent(sn);
417   
418     node.setElementAt(sn, i);
419   
420     /*
421      * move the members of cluster(j) to cluster(i)
422      * and mark cluster j as out of the game
423      */
424     clusters.get(i).or(clusters.get(j));
425     clusters.get(j).clear();
426     done.set(j);
427   }
428
429   protected abstract void findNewDistances(SequenceNode tmpi, SequenceNode tmpj,
430           double dist);
431
432   /**
433    * Calculates and saves the distance between the combination of cluster(i) and
434    * cluster(j) and all other clusters. The form of the calculation depends on
435    * the tree clustering method being used.
436    * 
437    * @param i
438    * @param j
439    */
440   protected abstract void findClusterDistance(int i, int j);
441
442   /**
443    * Start by making a cluster for each individual sequence
444    */
445   void makeLeaves()
446   {
447     clusters = new Vector<BitSet>();
448   
449     for (int i = 0; i < noseqs; i++)
450     {
451       SequenceNode sn = new SequenceNode();
452   
453       sn.setElement(sequences[i]);
454       sn.setName(sequences[i].getName());
455       node.addElement(sn);
456       BitSet bs = new BitSet();
457       bs.set(i);
458       clusters.addElement(bs);
459     }
460   }
461
462 }