JAL-2795 Forester now NJ-clusters the DistanceMatrix from Jalview
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import jalview.api.analysis.ScoreModelI;
24 import jalview.api.analysis.SimilarityParamsI;
25 import jalview.datamodel.AlignmentView;
26 import jalview.datamodel.CigarArray;
27 import jalview.datamodel.SeqCigar;
28 import jalview.datamodel.SequenceI;
29 import jalview.datamodel.SequenceNode;
30 import jalview.math.MatrixI;
31 import jalview.viewmodel.AlignmentViewport;
32
33 import java.util.BitSet;
34 import java.util.Vector;
35
36 public abstract class TreeBuilder
37 {
38   public static final String AVERAGE_DISTANCE = "AV";
39
40   public static final String NEIGHBOUR_JOINING = "NJ";
41
42   protected Vector<BitSet> clusters;
43
44   protected SequenceI[] sequences;
45
46   public AlignmentView seqData;
47
48   protected BitSet done;
49
50   protected int noseqs;
51
52   int noClus;
53
54   protected MatrixI distances;
55
56   public MatrixI testDistances;
57
58   protected int mini;
59
60   protected int minj;
61
62   protected double ri;
63
64   protected double rj;
65
66   SequenceNode maxdist;
67
68   SequenceNode top;
69
70   double maxDistValue;
71
72   double maxHeight;
73
74   int ycount;
75
76   Vector<SequenceNode> node;
77
78   protected ScoreModelI scoreModel;
79
80   protected SimilarityParamsI scoreParams;
81
82   private AlignmentView seqStrings; // redundant? (see seqData)
83
84   /**
85    * Constructor
86    * 
87    * @param av
88    * @param sm
89    * @param scoreParameters
90    */
91   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
92           SimilarityParamsI scoreParameters)
93   {
94     int start, end;
95     boolean selview = av.getSelectionGroup() != null
96             && av.getSelectionGroup().getSize() > 1;
97     seqStrings = av.getAlignmentView(selview);
98     if (!selview)
99     {
100       start = 0;
101       end = av.getAlignment().getWidth();
102       this.sequences = av.getAlignment().getSequencesArray();
103     }
104     else
105     {
106       start = av.getSelectionGroup().getStartRes();
107       end = av.getSelectionGroup().getEndRes() + 1;
108       this.sequences = av.getSelectionGroup()
109               .getSequencesInOrder(av.getAlignment());
110     }
111
112     init(seqStrings, start, end);
113
114     computeTree(sm, scoreParameters);
115   }
116
117   public SequenceI[] getSequences()
118   {
119     return sequences;
120   }
121
122   /**
123    * DOCUMENT ME!
124    * 
125    * @param nd
126    *          DOCUMENT ME!
127    * 
128    * @return DOCUMENT ME!
129    */
130   double findHeight(SequenceNode nd)
131   {
132     if (nd == null)
133     {
134       return maxHeight;
135     }
136
137     if ((nd.left() == null) && (nd.right() == null))
138     {
139       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
140
141       if (nd.height > maxHeight)
142       {
143         return nd.height;
144       }
145       else
146       {
147         return maxHeight;
148       }
149     }
150     else
151     {
152       if (nd.parent() != null)
153       {
154         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
155       }
156       else
157       {
158         maxHeight = 0;
159         nd.height = (float) 0.0;
160       }
161
162       maxHeight = findHeight((SequenceNode) (nd.left()));
163       maxHeight = findHeight((SequenceNode) (nd.right()));
164     }
165
166     return maxHeight;
167   }
168
169   /**
170    * DOCUMENT ME!
171    * 
172    * @param nd
173    *          DOCUMENT ME!
174    */
175   void reCount(SequenceNode nd)
176   {
177     ycount = 0;
178     // _lycount = 0;
179     // _lylimit = this.node.size();
180     _reCount(nd);
181   }
182
183   /**
184    * DOCUMENT ME!
185    * 
186    * @param nd
187    *          DOCUMENT ME!
188    */
189   void _reCount(SequenceNode nd)
190   {
191     // if (_lycount<_lylimit)
192     // {
193     // System.err.println("Warning: depth of _recount greater than number of
194     // nodes.");
195     // }
196     if (nd == null)
197     {
198       return;
199     }
200     // _lycount++;
201
202     if ((nd.left() != null) && (nd.right() != null))
203     {
204
205       _reCount((SequenceNode) nd.left());
206       _reCount((SequenceNode) nd.right());
207
208       SequenceNode l = (SequenceNode) nd.left();
209       SequenceNode r = (SequenceNode) nd.right();
210
211       nd.count = l.count + r.count;
212       nd.ycount = (l.ycount + r.ycount) / 2;
213     }
214     else
215     {
216       nd.count = 1;
217       nd.ycount = ycount++;
218     }
219     // _lycount--;
220   }
221
222   /**
223    * DOCUMENT ME!
224    * 
225    * @return DOCUMENT ME!
226    */
227   public SequenceNode getTopNode()
228   {
229     return top;
230   }
231
232   /**
233    * 
234    * @return true if tree has real distances
235    */
236   public boolean hasDistances()
237   {
238     return true;
239   }
240
241   /**
242    * 
243    * @return true if tree has real bootstrap values
244    */
245   public boolean hasBootstrap()
246   {
247     return false;
248   }
249
250   public boolean hasRootDistance()
251   {
252     return true;
253   }
254
255   /**
256    * Form clusters by grouping sub-clusters, starting from one sequence per
257    * cluster, and finishing when only two clusters remain
258    */
259   void cluster()
260   {
261     while (noClus > 2)
262     {
263       findMinDistance();
264
265       joinClusters(mini, minj);
266
267       noClus--;
268     }
269
270     int rightChild = done.nextClearBit(0);
271     int leftChild = done.nextClearBit(rightChild + 1);
272
273     joinClusters(leftChild, rightChild);
274     top = (node.elementAt(leftChild));
275
276     reCount(top);
277     findHeight(top);
278     findMaxDist(top);
279   }
280
281   /**
282    * Returns the minimum distance between two clusters, and also sets the
283    * indices of the clusters in fields mini and minj
284    * 
285    * @return
286    */
287   protected abstract double findMinDistance();
288
289   /**
290    * Calculates the tree using the given score model and parameters, and the
291    * configured tree type
292    * <p>
293    * If the score model computes pairwise distance scores, then these are used
294    * directly to derive the tree
295    * <p>
296    * If the score model computes similarity scores, then the range of the scores
297    * is reversed to give a distance measure, and this is used to derive the tree
298    * 
299    * @param sm
300    * @param scoreOptions
301    */
302   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
303   {
304
305     this.scoreModel = sm;
306     this.scoreParams = scoreOptions;
307
308     testDistances = scoreModel.findDistances(seqData, scoreParams);
309     distances = scoreModel.findDistances(seqData, scoreParams);
310     makeLeaves();
311
312     noClus = clusters.size();
313
314     cluster();
315   }
316
317   /**
318    * Finds the node, at or below the given node, with the maximum distance, and
319    * saves the node and the distance value
320    * 
321    * @param nd
322    */
323   void findMaxDist(SequenceNode nd)
324   {
325     if (nd == null)
326     {
327       return;
328     }
329
330     if ((nd.left() == null) && (nd.right() == null))
331     {
332       double dist = nd.dist;
333
334       if (dist > maxDistValue)
335       {
336         maxdist = nd;
337         maxDistValue = dist;
338       }
339     }
340     else
341     {
342       findMaxDist((SequenceNode) nd.left());
343       findMaxDist((SequenceNode) nd.right());
344     }
345   }
346
347   /**
348    * Calculates and returns r, whatever that is
349    * 
350    * @param i
351    * @param j
352    * 
353    * @return
354    */
355   protected double findr(int i, int j)
356   {
357     double tmp = 1;
358
359     for (int k = 0; k < noseqs; k++)
360     {
361       if ((k != i) && (k != j) && (!done.get(k)))
362       {
363         tmp = tmp + distances.getValue(i, k);
364       }
365     }
366
367     if (noClus > 2)
368     {
369       tmp = tmp / (noClus - 2);
370     }
371
372     return tmp;
373   }
374
375   protected void init(AlignmentView seqView, int start, int end)
376   {
377     this.node = new Vector<>();
378     if (seqView != null)
379     {
380       this.seqData = seqView;
381     }
382     else
383     {
384       SeqCigar[] seqs = new SeqCigar[sequences.length];
385       for (int i = 0; i < sequences.length; i++)
386       {
387         seqs[i] = new SeqCigar(sequences[i], start, end);
388       }
389       CigarArray sdata = new CigarArray(seqs);
390       sdata.addOperation(CigarArray.M, end - start + 1);
391       this.seqData = new AlignmentView(sdata, start);
392     }
393
394     /*
395      * count the non-null sequences
396      */
397     noseqs = 0;
398
399     done = new BitSet();
400
401     for (SequenceI seq : sequences)
402     {
403       if (seq != null)
404       {
405         noseqs++;
406       }
407     }
408   }
409
410   /**
411    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
412    * 
413    * @param i
414    * @param j
415    */
416   void joinClusters(final int i, final int j)
417   {
418     double dist = distances.getValue(i, j);
419
420     ri = findr(i, j);
421     rj = findr(j, i);
422
423     findClusterDistance(i, j);
424
425     SequenceNode sn = new SequenceNode();
426
427     sn.setLeft((node.elementAt(i)));
428     sn.setRight((node.elementAt(j)));
429
430     SequenceNode tmpi = (node.elementAt(i));
431     SequenceNode tmpj = (node.elementAt(j));
432
433     findNewDistances(tmpi, tmpj, dist);
434
435     tmpi.setParent(sn);
436     tmpj.setParent(sn);
437
438     node.setElementAt(sn, i);
439
440     /*
441      * move the members of cluster(j) to cluster(i)
442      * and mark cluster j as out of the game
443      */
444     clusters.get(i).or(clusters.get(j));
445     clusters.get(j).clear();
446     done.set(j);
447   }
448
449   /*
450    * Computes and stores new distances for nodei and nodej, given the previous
451    * distance between them
452    */
453   protected abstract void findNewDistances(SequenceNode nodei,
454           SequenceNode nodej, double previousDistance);
455
456   /**
457    * Calculates and saves the distance between the combination of cluster(i) and
458    * cluster(j) and all other clusters. The form of the calculation depends on
459    * the tree clustering method being used.
460    * 
461    * @param i
462    * @param j
463    */
464   protected abstract void findClusterDistance(int i, int j);
465
466   /**
467    * Start by making a cluster for each individual sequence
468    */
469   void makeLeaves()
470   {
471     clusters = new Vector<>();
472
473     for (int i = 0; i < noseqs; i++)
474     {
475       SequenceNode sn = new SequenceNode();
476
477       sn.setElement(sequences[i]);
478       sn.setName(sequences[i].getName());
479       node.addElement(sn);
480
481       BitSet bs = new BitSet();
482       bs.set(i);
483       clusters.addElement(bs);
484     }
485   }
486
487   public AlignmentView getOriginalData()
488   {
489     return seqStrings;
490   }
491
492   public MatrixI getDistances()
493   {
494     return distances;
495   }
496
497   public AlignmentView getSeqData()
498   {
499     return seqData;
500   }
501
502   public ScoreModelI getScoreModel()
503   {
504     return scoreModel;
505   }
506
507   public SimilarityParamsI getScoreParams()
508   {
509     return scoreParams;
510   }
511
512 }