Merge branch 'releases/Release_2_11_3_Branch'
[jalview.git] / src / jalview / analysis / AverageDistanceEngine.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import java.util.ArrayList;
24 import java.util.BitSet;
25 import java.util.List;
26 import java.util.Vector;
27
28 import jalview.datamodel.AlignmentAnnotation;
29 import jalview.datamodel.BinaryNode;
30 import jalview.datamodel.ContactListI;
31 import jalview.datamodel.ContactMatrixI;
32 import jalview.math.Matrix;
33 import jalview.viewmodel.AlignmentViewport;
34
35 /**
36  * This class implements distance calculations used in constructing a Average
37  * Distance tree (also known as UPGMA)
38  */
39 public class AverageDistanceEngine extends TreeEngine
40 {
41   ContactMatrixI cm;
42
43   AlignmentViewport av;
44
45   AlignmentAnnotation aa;
46
47   // 0 - normalised dot product
48   // 1 - L1 - ie (abs(v_1-v_2)/dim(v))
49   // L1 is more rational - since can reason about value of difference,
50   // normalised dot product might give cleaner clusters, but more difficult to
51   // understand.
52
53   int mode = 1;
54
55   /**
56    * compute cosine distance matrix for a given contact matrix and create a
57    * UPGMA tree
58    * 
59    * @param cm
60    * @param cosineOrDifference
61    *          false - dot product : true - L1
62    */
63   public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa,
64           ContactMatrixI cm, boolean cosineOrDifference)
65   {
66     this.av = av;
67     this.aa = aa;
68     this.cm = cm;
69     mode = (cosineOrDifference) ? 1 : 0;
70     calculate(cm);
71
72   }
73
74   public void calculate(ContactMatrixI cm)
75   {
76     this.cm = cm;
77     node = new Vector<BinaryNode>();
78     clusters = new Vector<BitSet>();
79     distances = new Matrix(new double[cm.getWidth()][cm.getWidth()]);
80     noseqs = cm.getWidth();
81     done = new BitSet();
82     double moduli[] = new double[cm.getWidth()];
83     double max;
84     if (mode == 0)
85     {
86       max = 1;
87     }
88     else
89     {
90       max = cm.getMax() * cm.getMax();
91     }
92
93     for (int i = 0; i < cm.getWidth(); i++)
94     {
95       // init the tree engine node for this column
96       BinaryNode cnode = new BinaryNode();
97       cnode.setElement(Integer.valueOf(i));
98       cnode.setName("c" + i);
99       node.addElement(cnode);
100       BitSet bs = new BitSet();
101       bs.set(i);
102       clusters.addElement(bs);
103
104       // compute distance matrix element
105       ContactListI ith = cm.getContactList(i);
106       distances.setValue(i, i, 0);
107       if (ith == null)
108       {
109         continue;
110       }
111       for (int j = 0; j < i; j++)
112       {
113         ContactListI jth = cm.getContactList(j);
114         if (jth == null)
115         {
116           break;
117         }
118         double prd = 0;
119         for (int indx = 0; indx < cm.getHeight(); indx++)
120         {
121           if (mode == 0)
122           {
123             if (j == 0)
124             {
125               moduli[i] += ith.getContactAt(indx) * ith.getContactAt(indx);
126             }
127             prd += ith.getContactAt(indx) * jth.getContactAt(indx);
128           }
129           else
130           {
131             prd += Math
132                     .abs(ith.getContactAt(indx) - jth.getContactAt(indx));
133           }
134         }
135         if (mode == 0)
136         {
137           if (j == 0)
138           {
139             moduli[i] = Math.sqrt(moduli[i]);
140           }
141           prd = (moduli[i] != 0 && moduli[j] != 0)
142                   ? prd / (moduli[i] * moduli[j])
143                   : 0;
144           prd = 1 - prd;
145         }
146         else
147         {
148           prd /= cm.getHeight();
149         }
150         distances.setValue(i, j, prd);
151         distances.setValue(j, i, prd);
152       }
153     }
154
155     noClus = clusters.size();
156     cluster();
157   }
158
159   /**
160    * Calculates and saves the distance between the combination of cluster(i) and
161    * cluster(j) and all other clusters. An average of the distances from
162    * cluster(i) and cluster(j) is calculated, weighted by the sizes of each
163    * cluster.
164    * 
165    * @param i
166    * @param j
167    */
168   @Override
169   protected void findClusterDistance(int i, int j)
170   {
171     int noi = clusters.elementAt(i).cardinality();
172     int noj = clusters.elementAt(j).cardinality();
173
174     // New distances from cluster i to others
175     double[] newdist = new double[noseqs];
176
177     for (int l = 0; l < noseqs; l++)
178     {
179       if ((l != i) && (l != j))
180       {
181         newdist[l] = ((distances.getValue(i, l) * noi)
182                 + (distances.getValue(j, l) * noj)) / (noi + noj);
183       }
184       else
185       {
186         newdist[l] = 0;
187       }
188     }
189
190     for (int ii = 0; ii < noseqs; ii++)
191     {
192       distances.setValue(i, ii, newdist[ii]);
193       distances.setValue(ii, i, newdist[ii]);
194     }
195   }
196
197   /**
198    * {@inheritDoc}
199    */
200   @Override
201   protected double findMinDistance()
202   {
203     double min = Double.MAX_VALUE;
204
205     for (int i = 0; i < (noseqs - 1); i++)
206     {
207       for (int j = i + 1; j < noseqs; j++)
208       {
209         if (!done.get(i) && !done.get(j))
210         {
211           if (distances.getValue(i, j) < min)
212           {
213             mini = i;
214             minj = j;
215
216             min = distances.getValue(i, j);
217           }
218         }
219       }
220     }
221     return min;
222   }
223
224   /**
225    * {@inheritDoc}
226    */
227   @Override
228   protected void findNewDistances(BinaryNode nodei, BinaryNode nodej,
229           double dist)
230   {
231     double ih = 0;
232     double jh = 0;
233
234     BinaryNode sni = nodei;
235     BinaryNode snj = nodej;
236
237     while (sni != null)
238     {
239       ih = ih + sni.dist;
240       sni = (BinaryNode) sni.left();
241     }
242
243     while (snj != null)
244     {
245       jh = jh + snj.dist;
246       snj = (BinaryNode) snj.left();
247     }
248
249     nodei.dist = ((dist / 2) - ih);
250     nodej.dist = ((dist / 2) - jh);
251   }
252
253   /***
254    * not the right place - OH WELL!
255    */
256
257   /**
258    * Makes a list of groups, where each group is represented by a node whose
259    * height (distance from the root node), as a fraction of the height of the
260    * whole tree, is greater than the given threshold. This corresponds to
261    * selecting the nodes immediately to the right of a vertical line
262    * partitioning the tree (if the tree is drawn with root to the left). Each
263    * such node represents a group that contains all of the sequences linked to
264    * the child leaf nodes.
265    * 
266    * @param threshold
267    * @see #getGroups()
268    */
269   public List<BinaryNode> groupNodes(float threshold)
270   {
271     List<BinaryNode> groups = new ArrayList<BinaryNode>();
272     _groupNodes(groups, getTopNode(), threshold);
273     return groups;
274   }
275
276   protected void _groupNodes(List<BinaryNode> groups, BinaryNode nd,
277           float threshold)
278   {
279     if (nd == null)
280     {
281       return;
282     }
283
284     if ((nd.height / maxheight) > threshold)
285     {
286       groups.add(nd);
287     }
288     else
289     {
290       _groupNodes(groups, nd.left(), threshold);
291       _groupNodes(groups, nd.right(), threshold);
292     }
293   }
294
295   /**
296    * DOCUMENT ME!
297    * 
298    * @param nd
299    *          DOCUMENT ME!
300    * 
301    * @return DOCUMENT ME!
302    */
303   public double findHeight(BinaryNode nd)
304   {
305     if (nd == null)
306     {
307       return maxheight;
308     }
309
310     if ((nd.left() == null) && (nd.right() == null))
311     {
312       nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
313
314       if (nd.height > maxheight)
315       {
316         return nd.height;
317       }
318       else
319       {
320         return maxheight;
321       }
322     }
323     else
324     {
325       if (nd.parent() != null)
326       {
327         nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
328       }
329       else
330       {
331         maxheight = 0;
332         nd.height = (float) 0.0;
333       }
334
335       maxheight = findHeight((BinaryNode) (nd.left()));
336       maxheight = findHeight((BinaryNode) (nd.right()));
337     }
338
339     return maxheight;
340   }
341
342   /**
343    * Search for leaf nodes below (or at) the given node
344    * 
345    * @param top2
346    *          root node to search from
347    * 
348    * @return
349    */
350   public Vector<BinaryNode> findLeaves(BinaryNode top2)
351   {
352     Vector<BinaryNode> leaves = new Vector<BinaryNode>();
353     findLeaves(top2, leaves);
354     return leaves;
355   }
356
357   /**
358    * Search for leaf nodes.
359    * 
360    * @param nd
361    *          root node to search from
362    * @param leaves
363    *          Vector of leaves to add leaf node objects too.
364    * 
365    * @return Vector of leaf nodes on binary tree
366    */
367   Vector<BinaryNode> findLeaves(BinaryNode nd, Vector<BinaryNode> leaves)
368   {
369     if (nd == null)
370     {
371       return leaves;
372     }
373
374     if ((nd.left() == null) && (nd.right() == null)) // Interior node
375     // detection
376     {
377       leaves.addElement(nd);
378
379       return leaves;
380     }
381     else
382     {
383       /*
384        * TODO: Identify internal nodes... if (node.isSequenceLabel()) {
385        * leaves.addElement(node); }
386        */
387       findLeaves(nd.left(), leaves);
388       findLeaves(nd.right(), leaves);
389     }
390
391     return leaves;
392   }
393
394 }