JAL-4134 - refactor tree calculation code to work with binaryNode base type.
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import jalview.api.analysis.ScoreModelI;
24 import jalview.api.analysis.SimilarityParamsI;
25 import jalview.datamodel.AlignmentView;
26 import jalview.datamodel.BinaryNode;
27 import jalview.datamodel.CigarArray;
28 import jalview.datamodel.SeqCigar;
29 import jalview.datamodel.SequenceI;
30 import jalview.datamodel.SequenceNode;
31 import jalview.math.MatrixI;
32 import jalview.viewmodel.AlignmentViewport;
33
34 import java.util.BitSet;
35 import java.util.Vector;
36
37 public abstract class TreeBuilder
38 {
39   public static final String AVERAGE_DISTANCE = "AV";
40
41   public static final String NEIGHBOUR_JOINING = "NJ";
42
43   protected Vector<BitSet> clusters;
44
45   protected SequenceI[] sequences;
46
47   public AlignmentView seqData;
48
49   protected BitSet done;
50
51   protected int noseqs;
52
53   int noClus;
54
55   protected MatrixI distances;
56
57   protected int mini;
58
59   protected int minj;
60
61   protected double ri;
62
63   protected double rj;
64
65   BinaryNode maxdist;
66
67   BinaryNode top;
68
69   double maxDistValue;
70
71   double maxheight;
72
73   int ycount;
74
75   Vector<SequenceNode> node;
76
77   private AlignmentView seqStrings;
78
79   /**
80    * Constructor
81    * 
82    * @param av
83    * @param sm
84    * @param scoreParameters
85    */
86   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
87           SimilarityParamsI scoreParameters)
88   {
89     int start, end;
90     boolean selview = av.getSelectionGroup() != null
91             && av.getSelectionGroup().getSize() > 1;
92     seqStrings = av.getAlignmentView(selview);
93     if (!selview)
94     {
95       start = 0;
96       end = av.getAlignment().getWidth();
97       this.sequences = av.getAlignment().getSequencesArray();
98     }
99     else
100     {
101       start = av.getSelectionGroup().getStartRes();
102       end = av.getSelectionGroup().getEndRes() + 1;
103       this.sequences = av.getSelectionGroup()
104               .getSequencesInOrder(av.getAlignment());
105     }
106
107     init(seqStrings, start, end);
108
109     computeTree(sm, scoreParameters);
110   }
111
112   public SequenceI[] getSequences()
113   {
114     return sequences;
115   }
116
117   /**
118    * DOCUMENT ME!
119    * 
120    * @param nd
121    *          DOCUMENT ME!
122    * 
123    * @return DOCUMENT ME!
124    */
125   double findHeight(BinaryNode nd)
126   {
127     if (nd == null)
128     {
129       return maxheight;
130     }
131
132     if ((nd.left() == null) && (nd.right() == null))
133     {
134       nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
135
136       if (nd.height > maxheight)
137       {
138         return nd.height;
139       }
140       else
141       {
142         return maxheight;
143       }
144     }
145     else
146     {
147       if (nd.parent() != null)
148       {
149         nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
150       }
151       else
152       {
153         maxheight = 0;
154         nd.height = (float) 0.0;
155       }
156
157       maxheight = findHeight((BinaryNode) (nd.left()));
158       maxheight = findHeight((BinaryNode) (nd.right()));
159     }
160
161     return maxheight;
162   }
163
164   /**
165    * DOCUMENT ME!
166    * 
167    * @param nd
168    *          DOCUMENT ME!
169    */
170   void reCount(BinaryNode nd)
171   {
172     ycount = 0;
173     // _lycount = 0;
174     // _lylimit = this.node.size();
175     _reCount(nd);
176   }
177
178   /**
179    * DOCUMENT ME!
180    * 
181    * @param nd
182    *          DOCUMENT ME!
183    */
184   void _reCount(BinaryNode nd)
185   {
186     // if (_lycount<_lylimit)
187     // {
188     // System.err.println("Warning: depth of _recount greater than number of
189     // nodes.");
190     // }
191     if (nd == null)
192     {
193       return;
194     }
195     // _lycount++;
196
197     if ((nd.left() != null) && (nd.right() != null))
198     {
199
200       _reCount(nd.left());
201       _reCount((BinaryNode) nd.right());
202
203       BinaryNode l = nd.left();
204       BinaryNode r = nd.right();
205
206       nd.count = l.count + r.count;
207       nd.ycount = (l.ycount + r.ycount) / 2;
208     }
209     else
210     {
211       nd.count = 1;
212       nd.ycount = ycount++;
213     }
214     // _lycount--;
215   }
216
217   /**
218    * DOCUMENT ME!
219    * 
220    * @return DOCUMENT ME!
221    */
222   public BinaryNode getTopNode()
223   {
224     return top;
225   }
226
227   /**
228    * 
229    * @return true if tree has real distances
230    */
231   public boolean hasDistances()
232   {
233     return true;
234   }
235
236   /**
237    * 
238    * @return true if tree has real bootstrap values
239    */
240   public boolean hasBootstrap()
241   {
242     return false;
243   }
244
245   public boolean hasRootDistance()
246   {
247     return true;
248   }
249
250   /**
251    * Form clusters by grouping sub-clusters, starting from one sequence per
252    * cluster, and finishing when only two clusters remain
253    */
254   void cluster()
255   {
256     while (noClus > 2)
257     {
258       findMinDistance();
259
260       joinClusters(mini, minj);
261
262       noClus--;
263     }
264
265     int rightChild = done.nextClearBit(0);
266     int leftChild = done.nextClearBit(rightChild + 1);
267
268     joinClusters(leftChild, rightChild);
269     top = (node.elementAt(leftChild));
270
271     reCount(top);
272     findHeight(top);
273     findMaxDist(top);
274   }
275
276   /**
277    * Returns the minimum distance between two clusters, and also sets the
278    * indices of the clusters in fields mini and minj
279    * 
280    * @return
281    */
282   protected abstract double findMinDistance();
283
284   /**
285    * Calculates the tree using the given score model and parameters, and the
286    * configured tree type
287    * <p>
288    * If the score model computes pairwise distance scores, then these are used
289    * directly to derive the tree
290    * <p>
291    * If the score model computes similarity scores, then the range of the scores
292    * is reversed to give a distance measure, and this is used to derive the tree
293    * 
294    * @param sm
295    * @param scoreOptions
296    */
297   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
298   {
299     distances = sm.findDistances(seqData, scoreOptions);
300
301     makeLeaves();
302
303     noClus = clusters.size();
304
305     cluster();
306   }
307
308   /**
309    * Finds the node, at or below the given node, with the maximum distance, and
310    * saves the node and the distance value
311    * 
312    * @param nd
313    */
314   void findMaxDist(BinaryNode nd)
315   {
316     if (nd == null)
317     {
318       return;
319     }
320
321     if ((nd.left() == null) && (nd.right() == null))
322     {
323       double dist = nd.dist;
324
325       if (dist > maxDistValue)
326       {
327         maxdist = nd;
328         maxDistValue = dist;
329       }
330     }
331     else
332     {
333       findMaxDist((BinaryNode) nd.left());
334       findMaxDist((BinaryNode) nd.right());
335     }
336   }
337
338   /**
339    * Calculates and returns r, whatever that is
340    * 
341    * @param i
342    * @param j
343    * 
344    * @return
345    */
346   protected double findr(int i, int j)
347   {
348     double tmp = 1;
349
350     for (int k = 0; k < noseqs; k++)
351     {
352       if ((k != i) && (k != j) && (!done.get(k)))
353       {
354         tmp = tmp + distances.getValue(i, k);
355       }
356     }
357
358     if (noClus > 2)
359     {
360       tmp = tmp / (noClus - 2);
361     }
362
363     return tmp;
364   }
365
366   protected void init(AlignmentView seqView, int start, int end)
367   {
368     this.node = new Vector<SequenceNode>();
369     if (seqView != null)
370     {
371       this.seqData = seqView;
372     }
373     else
374     {
375       SeqCigar[] seqs = new SeqCigar[sequences.length];
376       for (int i = 0; i < sequences.length; i++)
377       {
378         seqs[i] = new SeqCigar(sequences[i], start, end);
379       }
380       CigarArray sdata = new CigarArray(seqs);
381       sdata.addOperation(CigarArray.M, end - start + 1);
382       this.seqData = new AlignmentView(sdata, start);
383     }
384
385     /*
386      * count the non-null sequences
387      */
388     noseqs = 0;
389
390     done = new BitSet();
391
392     for (SequenceI seq : sequences)
393     {
394       if (seq != null)
395       {
396         noseqs++;
397       }
398     }
399   }
400
401   /**
402    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
403    * 
404    * @param i
405    * @param j
406    */
407   void joinClusters(final int i, final int j)
408   {
409     double dist = distances.getValue(i, j);
410
411     ri = findr(i, j);
412     rj = findr(j, i);
413
414     findClusterDistance(i, j);
415
416     SequenceNode sn = new SequenceNode();
417
418     sn.setLeft((node.elementAt(i)));
419     sn.setRight((node.elementAt(j)));
420
421     BinaryNode tmpi = (node.elementAt(i));
422     BinaryNode tmpj = (node.elementAt(j));
423
424     findNewDistances(tmpi, tmpj, dist);
425
426     tmpi.setParent(sn);
427     tmpj.setParent(sn);
428
429     node.setElementAt(sn, i);
430
431     /*
432      * move the members of cluster(j) to cluster(i)
433      * and mark cluster j as out of the game
434      */
435     clusters.get(i).or(clusters.get(j));
436     clusters.get(j).clear();
437     done.set(j);
438   }
439
440   /*
441    * Computes and stores new distances for nodei and nodej, given the previous
442    * distance between them
443    */
444   protected abstract void findNewDistances(BinaryNode nodei,
445           BinaryNode nodej, double previousDistance);
446
447   /**
448    * Calculates and saves the distance between the combination of cluster(i) and
449    * cluster(j) and all other clusters. The form of the calculation depends on
450    * the tree clustering method being used.
451    * 
452    * @param i
453    * @param j
454    */
455   protected abstract void findClusterDistance(int i, int j);
456
457   /**
458    * Start by making a cluster for each individual sequence
459    */
460   void makeLeaves()
461   {
462     clusters = new Vector<BitSet>();
463
464     for (int i = 0; i < noseqs; i++)
465     {
466       SequenceNode sn = new SequenceNode();
467
468       sn.setElement(sequences[i]);
469       sn.setName(sequences[i].getName());
470       node.addElement(sn);
471       BitSet bs = new BitSet();
472       bs.set(i);
473       clusters.addElement(bs);
474     }
475   }
476
477   public AlignmentView getOriginalData()
478   {
479     return seqStrings;
480   }
481
482 }