JAL-2428 pass AlignmentView to TreeModel so Show Input Data works
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 package jalview.analysis;
2
3 import jalview.api.analysis.DistanceScoreModelI;
4 import jalview.api.analysis.ScoreModelI;
5 import jalview.api.analysis.SimilarityParamsI;
6 import jalview.api.analysis.SimilarityScoreModelI;
7 import jalview.datamodel.AlignmentView;
8 import jalview.datamodel.CigarArray;
9 import jalview.datamodel.SeqCigar;
10 import jalview.datamodel.SequenceI;
11 import jalview.datamodel.SequenceNode;
12 import jalview.math.MatrixI;
13 import jalview.viewmodel.AlignmentViewport;
14
15 import java.util.BitSet;
16 import java.util.Vector;
17
18 public abstract class TreeBuilder
19 {
20   public static final String AVERAGE_DISTANCE = "AV";
21
22   public static final String NEIGHBOUR_JOINING = "NJ";
23
24   protected Vector<BitSet> clusters;
25
26   protected SequenceI[] sequences;
27
28   public AlignmentView seqData;
29
30   protected BitSet done;
31
32   protected int noseqs;
33
34   int noClus;
35
36   protected MatrixI distances;
37
38   protected int mini;
39
40   protected int minj;
41
42   protected double ri;
43
44   protected double rj;
45
46   SequenceNode maxdist;
47
48   SequenceNode top;
49
50   double maxDistValue;
51
52   double maxheight;
53
54   int ycount;
55
56   Vector<SequenceNode> node;
57
58   private AlignmentView seqStrings;
59
60   /**
61    * Constructor
62    * 
63    * @param av
64    * @param sm
65    * @param scoreParameters
66    */
67   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
68           SimilarityParamsI scoreParameters)
69   {
70     int start, end;
71     boolean selview = av.getSelectionGroup() != null
72             && av.getSelectionGroup().getSize() > 1;
73     seqStrings = av.getAlignmentView(selview);
74     if (!selview)
75     {
76       start = 0;
77       end = av.getAlignment().getWidth();
78       this.sequences = av.getAlignment().getSequencesArray();
79     }
80     else
81     {
82       start = av.getSelectionGroup().getStartRes();
83       end = av.getSelectionGroup().getEndRes() + 1;
84       this.sequences = av.getSelectionGroup().getSequencesInOrder(
85               av.getAlignment());
86     }
87
88     init(seqStrings, start, end);
89
90     computeTree(sm, scoreParameters);
91   }
92
93   public SequenceI[] getSequences()
94   {
95     return sequences;
96   }
97
98   /**
99    * DOCUMENT ME!
100    * 
101    * @param nd
102    *          DOCUMENT ME!
103    * 
104    * @return DOCUMENT ME!
105    */
106   double findHeight(SequenceNode nd)
107   {
108     if (nd == null)
109     {
110       return maxheight;
111     }
112   
113     if ((nd.left() == null) && (nd.right() == null))
114     {
115       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
116   
117       if (nd.height > maxheight)
118       {
119         return nd.height;
120       }
121       else
122       {
123         return maxheight;
124       }
125     }
126     else
127     {
128       if (nd.parent() != null)
129       {
130         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
131       }
132       else
133       {
134         maxheight = 0;
135         nd.height = (float) 0.0;
136       }
137   
138       maxheight = findHeight((SequenceNode) (nd.left()));
139       maxheight = findHeight((SequenceNode) (nd.right()));
140     }
141   
142     return maxheight;
143   }
144
145   /**
146    * DOCUMENT ME!
147    * 
148    * @param nd
149    *          DOCUMENT ME!
150    */
151   void reCount(SequenceNode nd)
152   {
153     ycount = 0;
154     // _lycount = 0;
155     // _lylimit = this.node.size();
156     _reCount(nd);
157   }
158
159   /**
160    * DOCUMENT ME!
161    * 
162    * @param nd
163    *          DOCUMENT ME!
164    */
165   void _reCount(SequenceNode nd)
166   {
167     // if (_lycount<_lylimit)
168     // {
169     // System.err.println("Warning: depth of _recount greater than number of nodes.");
170     // }
171     if (nd == null)
172     {
173       return;
174     }
175     // _lycount++;
176   
177     if ((nd.left() != null) && (nd.right() != null))
178     {
179   
180       _reCount((SequenceNode) nd.left());
181       _reCount((SequenceNode) nd.right());
182   
183       SequenceNode l = (SequenceNode) nd.left();
184       SequenceNode r = (SequenceNode) nd.right();
185   
186       nd.count = l.count + r.count;
187       nd.ycount = (l.ycount + r.ycount) / 2;
188     }
189     else
190     {
191       nd.count = 1;
192       nd.ycount = ycount++;
193     }
194     // _lycount--;
195   }
196
197   /**
198    * DOCUMENT ME!
199    * 
200    * @return DOCUMENT ME!
201    */
202   public SequenceNode getTopNode()
203   {
204     return top;
205   }
206
207   /**
208    * 
209    * @return true if tree has real distances
210    */
211   public boolean hasDistances()
212   {
213     return true;
214   }
215
216   /**
217    * 
218    * @return true if tree has real bootstrap values
219    */
220   public boolean hasBootstrap()
221   {
222     return false;
223   }
224
225   public boolean hasRootDistance()
226   {
227     return true;
228   }
229
230   /**
231    * Form clusters by grouping sub-clusters, starting from one sequence per
232    * cluster, and finishing when only two clusters remain
233    */
234   void cluster()
235   {
236     while (noClus > 2)
237     {
238       findMinDistance();
239   
240       joinClusters(mini, minj);
241   
242       noClus--;
243     }
244   
245     int rightChild = done.nextClearBit(0);
246     int leftChild = done.nextClearBit(rightChild + 1);
247   
248     joinClusters(leftChild, rightChild);
249     top = (node.elementAt(leftChild));
250   
251     reCount(top);
252     findHeight(top);
253     findMaxDist(top);
254   }
255
256   protected abstract double findMinDistance();
257
258   /**
259    * Calculates the tree using the given score model and parameters, and the
260    * configured tree type
261    * <p>
262    * If the score model computes pairwise distance scores, then these are used
263    * directly to derive the tree
264    * <p>
265    * If the score model computes similarity scores, then the range of the scores
266    * is reversed to give a distance measure, and this is used to derive the tree
267    * 
268    * @param sm
269    * @param scoreOptions
270    */
271   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
272   {
273     if (sm instanceof DistanceScoreModelI)
274     {
275       distances = ((DistanceScoreModelI) sm).findDistances(seqData,
276               scoreOptions);
277     }
278     else if (sm instanceof SimilarityScoreModelI)
279     {
280       /*
281        * compute similarity and invert it to give a distance measure
282        * reverseRange(true) converts maximum similarity to zero distance
283        */
284       MatrixI result = ((SimilarityScoreModelI) sm).findSimilarities(
285               seqData, scoreOptions);
286       result.reverseRange(true);
287       distances = result;
288     }
289   
290     makeLeaves();
291   
292     noClus = clusters.size();
293   
294     cluster();
295   }
296
297   /**
298    * DOCUMENT ME!
299    * 
300    * @param nd
301    *          DOCUMENT ME!
302    */
303   void findMaxDist(SequenceNode nd)
304   {
305     if (nd == null)
306     {
307       return;
308     }
309   
310     if ((nd.left() == null) && (nd.right() == null))
311     {
312       double dist = nd.dist;
313   
314       if (dist > maxDistValue)
315       {
316         maxdist = nd;
317         maxDistValue = dist;
318       }
319     }
320     else
321     {
322       findMaxDist((SequenceNode) nd.left());
323       findMaxDist((SequenceNode) nd.right());
324     }
325   }
326
327   /**
328    * DOCUMENT ME!
329    * 
330    * @param i
331    *          DOCUMENT ME!
332    * @param j
333    *          DOCUMENT ME!
334    * 
335    * @return DOCUMENT ME!
336    */
337   protected double findr(int i, int j)
338   {
339     double tmp = 1;
340   
341     for (int k = 0; k < noseqs; k++)
342     {
343       if ((k != i) && (k != j) && (!done.get(k)))
344       {
345         tmp = tmp + distances.getValue(i, k);
346       }
347     }
348   
349     if (noClus > 2)
350     {
351       tmp = tmp / (noClus - 2);
352     }
353   
354     return tmp;
355   }
356
357   protected void init(AlignmentView seqView, int start, int end)
358   {
359     this.node = new Vector<SequenceNode>();
360     if (seqView != null)
361     {
362       this.seqData = seqView;
363     }
364     else
365     {
366       SeqCigar[] seqs = new SeqCigar[sequences.length];
367       for (int i = 0; i < sequences.length; i++)
368       {
369         seqs[i] = new SeqCigar(sequences[i], start, end);
370       }
371       CigarArray sdata = new CigarArray(seqs);
372       sdata.addOperation(CigarArray.M, end - start + 1);
373       this.seqData = new AlignmentView(sdata, start);
374     }
375   
376     /*
377      * count the non-null sequences
378      */
379     noseqs = 0;
380   
381     done = new BitSet();
382   
383     for (SequenceI seq : sequences)
384     {
385       if (seq != null)
386       {
387         noseqs++;
388       }
389     }
390   }
391
392   /**
393    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
394    * 
395    * @param i
396    * @param j
397    */
398   void joinClusters(final int i, final int j)
399   {
400     double dist = distances.getValue(i, j);
401   
402     ri = findr(i, j);
403     rj = findr(j, i);
404   
405     findClusterDistance(i, j);
406   
407     SequenceNode sn = new SequenceNode();
408   
409     sn.setLeft((node.elementAt(i)));
410     sn.setRight((node.elementAt(j)));
411   
412     SequenceNode tmpi = (node.elementAt(i));
413     SequenceNode tmpj = (node.elementAt(j));
414   
415     findNewDistances(tmpi, tmpj, dist);
416   
417     tmpi.setParent(sn);
418     tmpj.setParent(sn);
419   
420     node.setElementAt(sn, i);
421   
422     /*
423      * move the members of cluster(j) to cluster(i)
424      * and mark cluster j as out of the game
425      */
426     clusters.get(i).or(clusters.get(j));
427     clusters.get(j).clear();
428     done.set(j);
429   }
430
431   protected abstract void findNewDistances(SequenceNode tmpi, SequenceNode tmpj,
432           double dist);
433
434   /**
435    * Calculates and saves the distance between the combination of cluster(i) and
436    * cluster(j) and all other clusters. The form of the calculation depends on
437    * the tree clustering method being used.
438    * 
439    * @param i
440    * @param j
441    */
442   protected abstract void findClusterDistance(int i, int j);
443
444   /**
445    * Start by making a cluster for each individual sequence
446    */
447   void makeLeaves()
448   {
449     clusters = new Vector<BitSet>();
450   
451     for (int i = 0; i < noseqs; i++)
452     {
453       SequenceNode sn = new SequenceNode();
454   
455       sn.setElement(sequences[i]);
456       sn.setName(sequences[i].getName());
457       node.addElement(sn);
458       BitSet bs = new BitSet();
459       bs.set(i);
460       clusters.addElement(bs);
461     }
462   }
463
464   public AlignmentView getOriginalData()
465   {
466     return seqStrings;
467   }
468
469 }