effef9a405fe3140b0f47f439b81f0ed5d285f6a
[jalview.git] / src / jalview / analysis / TreeBuilder.java
1 package jalview.analysis;
2
3 import jalview.api.analysis.ScoreModelI;
4 import jalview.api.analysis.SimilarityParamsI;
5 import jalview.datamodel.AlignmentView;
6 import jalview.datamodel.CigarArray;
7 import jalview.datamodel.SeqCigar;
8 import jalview.datamodel.SequenceI;
9 import jalview.datamodel.SequenceNode;
10 import jalview.math.MatrixI;
11 import jalview.viewmodel.AlignmentViewport;
12
13 import java.util.BitSet;
14 import java.util.Vector;
15
16 public abstract class TreeBuilder
17 {
18   public static final String AVERAGE_DISTANCE = "AV";
19
20   public static final String NEIGHBOUR_JOINING = "NJ";
21
22   protected Vector<BitSet> clusters;
23
24   protected SequenceI[] sequences;
25
26   public AlignmentView seqData;
27
28   protected BitSet done;
29
30   protected int noseqs;
31
32   int noClus;
33
34   protected MatrixI distances;
35
36   protected int mini;
37
38   protected int minj;
39
40   protected double ri;
41
42   protected double rj;
43
44   SequenceNode maxdist;
45
46   SequenceNode top;
47
48   double maxDistValue;
49
50   double maxheight;
51
52   int ycount;
53
54   Vector<SequenceNode> node;
55
56   private AlignmentView seqStrings;
57
58   /**
59    * Constructor
60    * 
61    * @param av
62    * @param sm
63    * @param scoreParameters
64    */
65   public TreeBuilder(AlignmentViewport av, ScoreModelI sm,
66           SimilarityParamsI scoreParameters)
67   {
68     int start, end;
69     boolean selview = av.getSelectionGroup() != null
70             && av.getSelectionGroup().getSize() > 1;
71     seqStrings = av.getAlignmentView(selview);
72     if (!selview)
73     {
74       start = 0;
75       end = av.getAlignment().getWidth();
76       this.sequences = av.getAlignment().getSequencesArray();
77     }
78     else
79     {
80       start = av.getSelectionGroup().getStartRes();
81       end = av.getSelectionGroup().getEndRes() + 1;
82       this.sequences = av.getSelectionGroup().getSequencesInOrder(
83               av.getAlignment());
84     }
85
86     init(seqStrings, start, end);
87
88     computeTree(sm, scoreParameters);
89   }
90
91   public SequenceI[] getSequences()
92   {
93     return sequences;
94   }
95
96   /**
97    * DOCUMENT ME!
98    * 
99    * @param nd
100    *          DOCUMENT ME!
101    * 
102    * @return DOCUMENT ME!
103    */
104   double findHeight(SequenceNode nd)
105   {
106     if (nd == null)
107     {
108       return maxheight;
109     }
110   
111     if ((nd.left() == null) && (nd.right() == null))
112     {
113       nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
114   
115       if (nd.height > maxheight)
116       {
117         return nd.height;
118       }
119       else
120       {
121         return maxheight;
122       }
123     }
124     else
125     {
126       if (nd.parent() != null)
127       {
128         nd.height = ((SequenceNode) nd.parent()).height + nd.dist;
129       }
130       else
131       {
132         maxheight = 0;
133         nd.height = (float) 0.0;
134       }
135   
136       maxheight = findHeight((SequenceNode) (nd.left()));
137       maxheight = findHeight((SequenceNode) (nd.right()));
138     }
139   
140     return maxheight;
141   }
142
143   /**
144    * DOCUMENT ME!
145    * 
146    * @param nd
147    *          DOCUMENT ME!
148    */
149   void reCount(SequenceNode nd)
150   {
151     ycount = 0;
152     // _lycount = 0;
153     // _lylimit = this.node.size();
154     _reCount(nd);
155   }
156
157   /**
158    * DOCUMENT ME!
159    * 
160    * @param nd
161    *          DOCUMENT ME!
162    */
163   void _reCount(SequenceNode nd)
164   {
165     // if (_lycount<_lylimit)
166     // {
167     // System.err.println("Warning: depth of _recount greater than number of nodes.");
168     // }
169     if (nd == null)
170     {
171       return;
172     }
173     // _lycount++;
174   
175     if ((nd.left() != null) && (nd.right() != null))
176     {
177   
178       _reCount((SequenceNode) nd.left());
179       _reCount((SequenceNode) nd.right());
180   
181       SequenceNode l = (SequenceNode) nd.left();
182       SequenceNode r = (SequenceNode) nd.right();
183   
184       nd.count = l.count + r.count;
185       nd.ycount = (l.ycount + r.ycount) / 2;
186     }
187     else
188     {
189       nd.count = 1;
190       nd.ycount = ycount++;
191     }
192     // _lycount--;
193   }
194
195   /**
196    * DOCUMENT ME!
197    * 
198    * @return DOCUMENT ME!
199    */
200   public SequenceNode getTopNode()
201   {
202     return top;
203   }
204
205   /**
206    * 
207    * @return true if tree has real distances
208    */
209   public boolean hasDistances()
210   {
211     return true;
212   }
213
214   /**
215    * 
216    * @return true if tree has real bootstrap values
217    */
218   public boolean hasBootstrap()
219   {
220     return false;
221   }
222
223   public boolean hasRootDistance()
224   {
225     return true;
226   }
227
228   /**
229    * Form clusters by grouping sub-clusters, starting from one sequence per
230    * cluster, and finishing when only two clusters remain
231    */
232   void cluster()
233   {
234     while (noClus > 2)
235     {
236       findMinDistance();
237   
238       joinClusters(mini, minj);
239   
240       noClus--;
241     }
242   
243     int rightChild = done.nextClearBit(0);
244     int leftChild = done.nextClearBit(rightChild + 1);
245   
246     joinClusters(leftChild, rightChild);
247     top = (node.elementAt(leftChild));
248   
249     reCount(top);
250     findHeight(top);
251     findMaxDist(top);
252   }
253
254   /**
255    * Returns the minimum distance between two clusters, and also sets the
256    * indices of the clusters in fields mini and minj
257    * 
258    * @return
259    */
260   protected abstract double findMinDistance();
261
262   /**
263    * Calculates the tree using the given score model and parameters, and the
264    * configured tree type
265    * <p>
266    * If the score model computes pairwise distance scores, then these are used
267    * directly to derive the tree
268    * <p>
269    * If the score model computes similarity scores, then the range of the scores
270    * is reversed to give a distance measure, and this is used to derive the tree
271    * 
272    * @param sm
273    * @param scoreOptions
274    */
275   protected void computeTree(ScoreModelI sm, SimilarityParamsI scoreOptions)
276   {
277     distances = sm.findDistances(seqData, scoreOptions);
278   
279     makeLeaves();
280   
281     noClus = clusters.size();
282   
283     cluster();
284   }
285
286   /**
287    * Finds the node, at or below the given node, with the maximum distance, and
288    * saves the node and the distance value
289    * 
290    * @param nd
291    */
292   void findMaxDist(SequenceNode nd)
293   {
294     if (nd == null)
295     {
296       return;
297     }
298   
299     if ((nd.left() == null) && (nd.right() == null))
300     {
301       double dist = nd.dist;
302   
303       if (dist > maxDistValue)
304       {
305         maxdist = nd;
306         maxDistValue = dist;
307       }
308     }
309     else
310     {
311       findMaxDist((SequenceNode) nd.left());
312       findMaxDist((SequenceNode) nd.right());
313     }
314   }
315
316   /**
317    * Calculates and returns r, whatever that is
318    * 
319    * @param i
320    * @param j
321    * 
322    * @return
323    */
324   protected double findr(int i, int j)
325   {
326     double tmp = 1;
327   
328     for (int k = 0; k < noseqs; k++)
329     {
330       if ((k != i) && (k != j) && (!done.get(k)))
331       {
332         tmp = tmp + distances.getValue(i, k);
333       }
334     }
335   
336     if (noClus > 2)
337     {
338       tmp = tmp / (noClus - 2);
339     }
340   
341     return tmp;
342   }
343
344   protected void init(AlignmentView seqView, int start, int end)
345   {
346     this.node = new Vector<SequenceNode>();
347     if (seqView != null)
348     {
349       this.seqData = seqView;
350     }
351     else
352     {
353       SeqCigar[] seqs = new SeqCigar[sequences.length];
354       for (int i = 0; i < sequences.length; i++)
355       {
356         seqs[i] = new SeqCigar(sequences[i], start, end);
357       }
358       CigarArray sdata = new CigarArray(seqs);
359       sdata.addOperation(CigarArray.M, end - start + 1);
360       this.seqData = new AlignmentView(sdata, start);
361     }
362   
363     /*
364      * count the non-null sequences
365      */
366     noseqs = 0;
367   
368     done = new BitSet();
369   
370     for (SequenceI seq : sequences)
371     {
372       if (seq != null)
373       {
374         noseqs++;
375       }
376     }
377   }
378
379   /**
380    * Merges cluster(j) to cluster(i) and recalculates cluster and node distances
381    * 
382    * @param i
383    * @param j
384    */
385   void joinClusters(final int i, final int j)
386   {
387     double dist = distances.getValue(i, j);
388   
389     ri = findr(i, j);
390     rj = findr(j, i);
391   
392     findClusterDistance(i, j);
393   
394     SequenceNode sn = new SequenceNode();
395   
396     sn.setLeft((node.elementAt(i)));
397     sn.setRight((node.elementAt(j)));
398   
399     SequenceNode tmpi = (node.elementAt(i));
400     SequenceNode tmpj = (node.elementAt(j));
401   
402     findNewDistances(tmpi, tmpj, dist);
403   
404     tmpi.setParent(sn);
405     tmpj.setParent(sn);
406   
407     node.setElementAt(sn, i);
408   
409     /*
410      * move the members of cluster(j) to cluster(i)
411      * and mark cluster j as out of the game
412      */
413     clusters.get(i).or(clusters.get(j));
414     clusters.get(j).clear();
415     done.set(j);
416   }
417
418   /*
419    * Computes and stores new distances for nodei and nodej, given the previous
420    * distance between them
421    */
422   protected abstract void findNewDistances(SequenceNode nodei,
423           SequenceNode nodej, double previousDistance);
424
425   /**
426    * Calculates and saves the distance between the combination of cluster(i) and
427    * cluster(j) and all other clusters. The form of the calculation depends on
428    * the tree clustering method being used.
429    * 
430    * @param i
431    * @param j
432    */
433   protected abstract void findClusterDistance(int i, int j);
434
435   /**
436    * Start by making a cluster for each individual sequence
437    */
438   void makeLeaves()
439   {
440     clusters = new Vector<BitSet>();
441   
442     for (int i = 0; i < noseqs; i++)
443     {
444       SequenceNode sn = new SequenceNode();
445   
446       sn.setElement(sequences[i]);
447       sn.setName(sequences[i].getName());
448       node.addElement(sn);
449       BitSet bs = new BitSet();
450       bs.set(i);
451       clusters.addElement(bs);
452     }
453   }
454
455   public AlignmentView getOriginalData()
456   {
457     return seqStrings;
458   }
459
460 }