d81dd44687d286542f4576034ca76307efdd2e63
[jalview.git] / src / jalview / analysis / AverageDistanceEngine.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import java.util.ArrayList;
24 import java.util.BitSet;
25 import java.util.List;
26 import java.util.Vector;
27
28 import jalview.datamodel.AlignmentAnnotation;
29 import jalview.datamodel.BinaryNode;
30 import jalview.datamodel.ContactListI;
31 import jalview.datamodel.ContactMatrixI;
32 import jalview.math.Matrix;
33 import jalview.viewmodel.AlignmentViewport;
34
35 /**
36  * This class implements distance calculations used in constructing a Average
37  * Distance tree (also known as UPGMA)
38  */
39 public class AverageDistanceEngine extends TreeEngine
40 {
41   ContactMatrixI cm;
42
43   AlignmentViewport av;
44
45   AlignmentAnnotation aa;
46
47   // 0 - normalised dot product
48   // 1 - L1 - ie (abs(v_1-v_2)/dim(v))
49   // L1 is more rational - since can reason about value of difference,
50   // normalised dot product might give cleaner clusters, but more difficult to
51   // understand.
52
53   int mode = 1;
54
55   /**
56    * compute cosine distance matrix for a given contact matrix and create a
57    * UPGMA tree
58    * @param cm
59    * @param cosineOrDifference false - dot product : true - L1
60    */
61   public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa,
62           ContactMatrixI cm, boolean cosineOrDifference)
63   {
64     this.av = av;
65     this.aa = aa;
66     this.cm = cm;
67     mode = (cosineOrDifference) ? 1 :0; 
68     calculate(cm);
69
70   }
71
72
73   public void calculate(ContactMatrixI cm)
74   {
75     this.cm = cm;
76     node = new Vector<BinaryNode>();
77     clusters = new Vector<BitSet>();
78     distances = new Matrix(new double[cm.getWidth()][cm.getWidth()]);
79     noseqs = cm.getWidth();
80     done = new BitSet();
81     double moduli[] = new double[cm.getWidth()];
82     double max;
83     if (mode == 0)
84     {
85       max = 1;
86     }
87     else
88     {
89       max = cm.getMax() * cm.getMax();
90     }
91
92     for (int i = 0; i < cm.getWidth(); i++)
93     {
94       // init the tree engine node for this column
95       BinaryNode cnode = new BinaryNode();
96       cnode.setElement(Integer.valueOf(i));
97       cnode.setName("c" + i);
98       node.addElement(cnode);
99       BitSet bs = new BitSet();
100       bs.set(i);
101       clusters.addElement(bs);
102
103       // compute distance matrix element
104       ContactListI ith = cm.getContactList(i);
105       distances.setValue(i, i, 0);
106       if (ith==null)
107       {
108         continue;
109       }
110       for (int j = 0; j < i; j++)
111       {
112         ContactListI jth = cm.getContactList(j);
113         if (jth == null)
114         {
115           break;
116         }
117         double prd = 0;
118         for (int indx = 0; indx < cm.getHeight(); indx++)
119         {
120           if (mode == 0)
121           {
122             if (j == 0)
123             {
124               moduli[i] += ith.getContactAt(indx) * ith.getContactAt(indx);
125             }
126             prd += ith.getContactAt(indx) * jth.getContactAt(indx);
127           }
128           else
129           {
130             prd += Math
131                     .abs(ith.getContactAt(indx) - jth.getContactAt(indx));
132           }
133         }
134         if (mode == 0)
135         {
136           if (j == 0)
137           {
138             moduli[i] = Math.sqrt(moduli[i]);
139           }
140           prd = (moduli[i] != 0 && moduli[j] != 0)
141                   ? prd / (moduli[i] * moduli[j])
142                   : 0;
143           prd = 1 - prd;
144         }
145         else
146         {
147           prd /= cm.getHeight();
148         }
149         distances.setValue(i, j, prd);
150         distances.setValue(j, i, prd);
151       }
152     }
153
154     noClus = clusters.size();
155     cluster();
156   }
157
158   /**
159    * Calculates and saves the distance between the combination of cluster(i) and
160    * cluster(j) and all other clusters. An average of the distances from
161    * cluster(i) and cluster(j) is calculated, weighted by the sizes of each
162    * cluster.
163    * 
164    * @param i
165    * @param j
166    */
167   @Override
168   protected void findClusterDistance(int i, int j)
169   {
170     int noi = clusters.elementAt(i).cardinality();
171     int noj = clusters.elementAt(j).cardinality();
172
173     // New distances from cluster i to others
174     double[] newdist = new double[noseqs];
175
176     for (int l = 0; l < noseqs; l++)
177     {
178       if ((l != i) && (l != j))
179       {
180         newdist[l] = ((distances.getValue(i, l) * noi)
181                 + (distances.getValue(j, l) * noj)) / (noi + noj);
182       }
183       else
184       {
185         newdist[l] = 0;
186       }
187     }
188
189     for (int ii = 0; ii < noseqs; ii++)
190     {
191       distances.setValue(i, ii, newdist[ii]);
192       distances.setValue(ii, i, newdist[ii]);
193     }
194   }
195
196   /**
197    * {@inheritDoc}
198    */
199   @Override
200   protected double findMinDistance()
201   {
202     double min = Double.MAX_VALUE;
203
204     for (int i = 0; i < (noseqs - 1); i++)
205     {
206       for (int j = i + 1; j < noseqs; j++)
207       {
208         if (!done.get(i) && !done.get(j))
209         {
210           if (distances.getValue(i, j) < min)
211           {
212             mini = i;
213             minj = j;
214
215             min = distances.getValue(i, j);
216           }
217         }
218       }
219     }
220     return min;
221   }
222
223   /**
224    * {@inheritDoc}
225    */
226   @Override
227   protected void findNewDistances(BinaryNode nodei, BinaryNode nodej,
228           double dist)
229   {
230     double ih = 0;
231     double jh = 0;
232
233     BinaryNode sni = nodei;
234     BinaryNode snj = nodej;
235
236     while (sni != null)
237     {
238       ih = ih + sni.dist;
239       sni = (BinaryNode) sni.left();
240     }
241
242     while (snj != null)
243     {
244       jh = jh + snj.dist;
245       snj = (BinaryNode) snj.left();
246     }
247
248     nodei.dist = ((dist / 2) - ih);
249     nodej.dist = ((dist / 2) - jh);
250   }
251
252   /***
253    * not the right place - OH WELL!
254    */
255
256   /**
257    * Makes a list of groups, where each group is represented by a node whose
258    * height (distance from the root node), as a fraction of the height of the
259    * whole tree, is greater than the given threshold. This corresponds to
260    * selecting the nodes immediately to the right of a vertical line
261    * partitioning the tree (if the tree is drawn with root to the left). Each
262    * such node represents a group that contains all of the sequences linked to
263    * the child leaf nodes.
264    * 
265    * @param threshold
266    * @see #getGroups()
267    */
268   public List<BinaryNode> groupNodes(float threshold)
269   {
270     List<BinaryNode> groups = new ArrayList<BinaryNode>();
271     _groupNodes(groups, getTopNode(), threshold);
272     return groups;
273   }
274
275   protected void _groupNodes(List<BinaryNode> groups, BinaryNode nd,
276           float threshold)
277   {
278     if (nd == null)
279     {
280       return;
281     }
282
283     if ((nd.height / maxheight) > threshold)
284     {
285       groups.add(nd);
286     }
287     else
288     {
289       _groupNodes(groups, nd.left(), threshold);
290       _groupNodes(groups, nd.right(), threshold);
291     }
292   }
293
294   /**
295    * DOCUMENT ME!
296    * 
297    * @param nd
298    *          DOCUMENT ME!
299    * 
300    * @return DOCUMENT ME!
301    */
302   public double findHeight(BinaryNode nd)
303   {
304     if (nd == null)
305     {
306       return maxheight;
307     }
308
309     if ((nd.left() == null) && (nd.right() == null))
310     {
311       nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
312
313       if (nd.height > maxheight)
314       {
315         return nd.height;
316       }
317       else
318       {
319         return maxheight;
320       }
321     }
322     else
323     {
324       if (nd.parent() != null)
325       {
326         nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
327       }
328       else
329       {
330         maxheight = 0;
331         nd.height = (float) 0.0;
332       }
333
334       maxheight = findHeight((BinaryNode) (nd.left()));
335       maxheight = findHeight((BinaryNode) (nd.right()));
336     }
337
338     return maxheight;
339   }
340
341   /**
342    * Search for leaf nodes below (or at) the given node
343    * 
344    * @param top2
345    *          root node to search from
346    * 
347    * @return
348    */
349   public Vector<BinaryNode> findLeaves(BinaryNode top2)
350   {
351     Vector<BinaryNode> leaves = new Vector<BinaryNode>();
352     findLeaves(top2, leaves);
353     return leaves;
354   }
355
356   /**
357    * Search for leaf nodes.
358    * 
359    * @param nd
360    *          root node to search from
361    * @param leaves
362    *          Vector of leaves to add leaf node objects too.
363    * 
364    * @return Vector of leaf nodes on binary tree
365    */
366   Vector<BinaryNode> findLeaves(BinaryNode nd, Vector<BinaryNode> leaves)
367   {
368     if (nd == null)
369     {
370       return leaves;
371     }
372
373     if ((nd.left() == null) && (nd.right() == null)) // Interior node
374     // detection
375     {
376       leaves.addElement(nd);
377
378       return leaves;
379     }
380     else
381     {
382       /*
383        * TODO: Identify internal nodes... if (node.isSequenceLabel()) {
384        * leaves.addElement(node); }
385        */
386       findLeaves(nd.left(), leaves);
387       findLeaves(nd.right(), leaves);
388     }
389
390     return leaves;
391   }
392
393 }