0cd340e88b57dde048bb634f1e244fbce244e131
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import jalview.api.analysis.ScoreModelI;
24 import jalview.api.analysis.SimilarityParamsI;
25 import jalview.datamodel.AlignmentView;
26 import jalview.datamodel.CigarArray;
27 import jalview.datamodel.SeqCigar;
28 import jalview.datamodel.SequenceI;
29 import jalview.datamodel.SequenceNode;
30 import jalview.math.MatrixI;
31 import jalview.viewmodel.AlignmentViewport;
32
33 import java.util.BitSet;
34 import java.util.Vector;
35
36 public abstract class TreeBuilder
37 {
38   public static final String AVERAGE_DISTANCE = "AV";
39
40   public static final String NEIGHBOUR_JOINING = "NJ";
41
42   protected Vector<BitSet> clusters;
43
44   protected SequenceI[] sequences;
45
46   public AlignmentView seqData;
47
48   protected BitSet done;
49
50   protected int noseqs;
51
52   int noClus;
53
54   protected MatrixI distances;
55
56   protected int mini;
57
58   protected int minj;
59
60   protected double ri;
61
62   protected double rj;
63
64   SequenceNode maxdist;
65
66   SequenceNode top;
67
68   double maxDistValue;
69
70   double maxHeight;
71
72   int ycount;
73
74   Vector<SequenceNode> node;
75
76   protected ScoreModelI scoreModel;
77
78   protected SimilarityParamsI scoreParams;
79
80   private AlignmentView seqStrings; // redundant? (see seqData)
81
82   /**
83    * Constructor
84    * 
85    * @param av
86    * @param sm
87    * @param scoreParameters
88    */
89   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
90           SimilarityParamsI scoreParameters)
91   {
92     int start, end;
93     boolean selview = av.getSelectionGroup() != null
94             && av.getSelectionGroup().getSize() > 1;
95     seqStrings = av.getAlignmentView(selview);
96     if (!selview)
97     {
98       start = 0;
99       end = av.getAlignment().getWidth();
100       this.sequences = av.getAlignment().getSequencesArray();
101     }
102     else
103     {
104       start = av.getSelectionGroup().getStartRes();
105       end = av.getSelectionGroup().getEndRes() + 1;
106       this.sequences = av.getSelectionGroup()
107               .getSequencesInOrder(av.getAlignment());
108     }
109
110     init(seqStrings, start, end);
111
112     computeTree(sm, scoreParameters);
113   }
114
115   public SequenceI[] getSequences()
116   {
117     return sequences;
118   }
119
120   /**
121    * DOCUMENT ME!
122    * 
123    * @param nd
124    *          DOCUMENT ME!
125    * 
126    * @return DOCUMENT ME!
127    */
128   double findHeight(SequenceNode nd)
129   {
130     if (nd == null)
131     {
132       return maxHeight;
133     }
134
135     if ((nd.left() == null) && (nd.right() == null))
136     {
137       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
138
139       if (nd.height > maxHeight)
140       {
141         return nd.height;
142       }
143       else
144       {
145         return maxHeight;
146       }
147     }
148     else
149     {
150       if (nd.parent() != null)
151       {
152         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
153       }
154       else
155       {
156         maxHeight = 0;
157         nd.height = (float) 0.0;
158       }
159
160       maxHeight = findHeight((SequenceNode) (nd.left()));
161       maxHeight = findHeight((SequenceNode) (nd.right()));
162     }
163
164     return maxHeight;
165   }
166
167   /**
168    * DOCUMENT ME!
169    * 
170    * @param nd
171    *          DOCUMENT ME!
172    */
173   void reCount(SequenceNode nd)
174   {
175     ycount = 0;
176     // _lycount = 0;
177     // _lylimit = this.node.size();
178     _reCount(nd);
179   }
180
181   /**
182    * DOCUMENT ME!
183    * 
184    * @param nd
185    *          DOCUMENT ME!
186    */
187   void _reCount(SequenceNode nd)
188   {
189     // if (_lycount<_lylimit)
190     // {
191     // System.err.println("Warning: depth of _recount greater than number of
192     // nodes.");
193     // }
194     if (nd == null)
195     {
196       return;
197     }
198     // _lycount++;
199
200     if ((nd.left() != null) && (nd.right() != null))
201     {
202
203       _reCount((SequenceNode) nd.left());
204       _reCount((SequenceNode) nd.right());
205
206       SequenceNode l = (SequenceNode) nd.left();
207       SequenceNode r = (SequenceNode) nd.right();
208
209       nd.count = l.count + r.count;
210       nd.ycount = (l.ycount + r.ycount) / 2;
211     }
212     else
213     {
214       nd.count = 1;
215       nd.ycount = ycount++;
216     }
217     // _lycount--;
218   }
219
220   /**
221    * DOCUMENT ME!
222    * 
223    * @return DOCUMENT ME!
224    */
225   public SequenceNode getTopNode()
226   {
227     return top;
228   }
229
230   /**
231    * 
232    * @return true if tree has real distances
233    */
234   public boolean hasDistances()
235   {
236     return true;
237   }
238
239   /**
240    * 
241    * @return true if tree has real bootstrap values
242    */
243   public boolean hasBootstrap()
244   {
245     return false;
246   }
247
248   public boolean hasRootDistance()
249   {
250     return true;
251   }
252
253   /**
254    * Form clusters by grouping sub-clusters, starting from one sequence per
255    * cluster, and finishing when only two clusters remain
256    */
257   void cluster()
258   {
259     while (noClus > 2)
260     {
261       findMinDistance();
262
263       joinClusters(mini, minj);
264
265       noClus--;
266     }
267
268     int rightChild = done.nextClearBit(0);
269     int leftChild = done.nextClearBit(rightChild + 1);
270
271     joinClusters(leftChild, rightChild);
272     top = (node.elementAt(leftChild));
273
274     reCount(top);
275     findHeight(top);
276     findMaxDist(top);
277   }
278
279   /**
280    * Returns the minimum distance between two clusters, and also sets the
281    * indices of the clusters in fields mini and minj
282    * 
283    * @return
284    */
285   protected abstract double findMinDistance();
286
287   /**
288    * Calculates the tree using the given score model and parameters, and the
289    * configured tree type
290    * <p>
291    * If the score model computes pairwise distance scores, then these are used
292    * directly to derive the tree
293    * <p>
294    * If the score model computes similarity scores, then the range of the scores
295    * is reversed to give a distance measure, and this is used to derive the tree
296    * 
297    * @param sm
298    * @param scoreOptions
299    */
300   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
301   {
302
303     this.scoreModel = sm;
304     this.scoreParams = scoreOptions;
305
306     distances = scoreModel.findDistances(seqData, scoreParams);
307
308     makeLeaves();
309
310     noClus = clusters.size();
311
312     cluster();
313   }
314
315   /**
316    * Finds the node, at or below the given node, with the maximum distance, and
317    * saves the node and the distance value
318    * 
319    * @param nd
320    */
321   void findMaxDist(SequenceNode nd)
322   {
323     if (nd == null)
324     {
325       return;
326     }
327
328     if ((nd.left() == null) && (nd.right() == null))
329     {
330       double dist = nd.dist;
331
332       if (dist > maxDistValue)
333       {
334         maxdist = nd;
335         maxDistValue = dist;
336       }
337     }
338     else
339     {
340       findMaxDist((SequenceNode) nd.left());
341       findMaxDist((SequenceNode) nd.right());
342     }
343   }
344
345   /**
346    * Calculates and returns r, whatever that is
347    * 
348    * @param i
349    * @param j
350    * 
351    * @return
352    */
353   protected double findr(int i, int j)
354   {
355     double tmp = 1;
356
357     for (int k = 0; k < noseqs; k++)
358     {
359       if ((k != i) && (k != j) && (!done.get(k)))
360       {
361         tmp = tmp + distances.getValue(i, k);
362       }
363     }
364
365     if (noClus > 2)
366     {
367       tmp = tmp / (noClus - 2);
368     }
369
370     return tmp;
371   }
372
373   protected void init(AlignmentView seqView, int start, int end)
374   {
375     this.node = new Vector<>();
376     if (seqView != null)
377     {
378       this.seqData = seqView;
379     }
380     else
381     {
382       SeqCigar[] seqs = new SeqCigar[sequences.length];
383       for (int i = 0; i < sequences.length; i++)
384       {
385         seqs[i] = new SeqCigar(sequences[i], start, end);
386       }
387       CigarArray sdata = new CigarArray(seqs);
388       sdata.addOperation(CigarArray.M, end - start + 1);
389       this.seqData = new AlignmentView(sdata, start);
390     }
391
392     /*
393      * count the non-null sequences
394      */
395     noseqs = 0;
396
397     done = new BitSet();
398
399     for (SequenceI seq : sequences)
400     {
401       if (seq != null)
402       {
403         noseqs++;
404       }
405     }
406   }
407
408   /**
409    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
410    * 
411    * @param i
412    * @param j
413    */
414   void joinClusters(final int i, final int j)
415   {
416     double dist = distances.getValue(i, j);
417
418     ri = findr(i, j);
419     rj = findr(j, i);
420
421     findClusterDistance(i, j);
422
423     SequenceNode sn = new SequenceNode();
424
425     sn.setLeft((node.elementAt(i)));
426     sn.setRight((node.elementAt(j)));
427
428     SequenceNode tmpi = (node.elementAt(i));
429     SequenceNode tmpj = (node.elementAt(j));
430
431     findNewDistances(tmpi, tmpj, dist);
432
433     tmpi.setParent(sn);
434     tmpj.setParent(sn);
435
436     node.setElementAt(sn, i);
437
438     /*
439      * move the members of cluster(j) to cluster(i)
440      * and mark cluster j as out of the game
441      */
442     clusters.get(i).or(clusters.get(j));
443     clusters.get(j).clear();
444     done.set(j);
445   }
446
447   /*
448    * Computes and stores new distances for nodei and nodej, given the previous
449    * distance between them
450    */
451   protected abstract void findNewDistances(SequenceNode nodei,
452           SequenceNode nodej, double previousDistance);
453
454   /**
455    * Calculates and saves the distance between the combination of cluster(i) and
456    * cluster(j) and all other clusters. The form of the calculation depends on
457    * the tree clustering method being used.
458    * 
459    * @param i
460    * @param j
461    */
462   protected abstract void findClusterDistance(int i, int j);
463
464   /**
465    * Start by making a cluster for each individual sequence
466    */
467   void makeLeaves()
468   {
469     clusters = new Vector<>();
470
471     for (int i = 0; i < noseqs; i++)
472     {
473       SequenceNode sn = new SequenceNode();
474
475       sn.setElement(sequences[i]);
476       sn.setName(sequences[i].getName());
477       node.addElement(sn);
478
479       BitSet bs = new BitSet();
480       bs.set(i);
481       clusters.addElement(bs);
482     }
483   }
484
485   public AlignmentView getOriginalData()
486   {
487     return seqStrings;
488   }
489
490   public MatrixI getDistances()
491   {
492     return distances;
493   }
494
495   public AlignmentView getSeqData()
496   {
497     return seqData;
498   }
499
500   public ScoreModelI getScoreModel()
501   {
502     return scoreModel;
503   }
504
505   public SimilarityParamsI getScoreParams()
506   {
507     return scoreParams;
508   }
509
510 }