JAL-4134 catch NPE - though in principle this should never happen…
[jalview.git] / src / jalview / analysis / AverageDistanceEngine.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import java.util.ArrayList;
24 import java.util.BitSet;
25 import java.util.List;
26 import java.util.Vector;
27
28 import jalview.datamodel.AlignmentAnnotation;
29 import jalview.datamodel.BinaryNode;
30 import jalview.datamodel.ContactListI;
31 import jalview.datamodel.ContactMatrixI;
32 import jalview.math.Matrix;
33 import jalview.viewmodel.AlignmentViewport;
34
35 /**
36  * This class implements distance calculations used in constructing a Average
37  * Distance tree (also known as UPGMA)
38  */
39 public class AverageDistanceEngine extends TreeEngine
40 {
41   ContactMatrixI cm;
42
43   AlignmentViewport av;
44
45   AlignmentAnnotation aa;
46
47   /**
48    * compute cosine distance matrix for a given contact matrix and create a
49    * UPGMA tree
50    * 
51    * @param cm
52    */
53   public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa,
54           ContactMatrixI cm)
55   {
56     this.av = av;
57     this.aa = aa;
58     this.cm = cm;
59     calculate(cm);
60
61   }
62
63   // 0 - normalised dot product
64   // 1 - L1 - ie (abs(v_1-v_2)/dim(v))
65   // L1 is more rational - since can reason about value of difference,
66   // normalised dot product might give cleaner clusters, but more difficult to
67   // understand.
68
69   int mode = 1;
70
71   public void calculate(ContactMatrixI cm)
72   {
73     this.cm = cm;
74     node = new Vector<BinaryNode>();
75     clusters = new Vector<BitSet>();
76     distances = new Matrix(new double[cm.getWidth()][cm.getWidth()]);
77     noseqs = cm.getWidth();
78     done = new BitSet();
79     double moduli[] = new double[cm.getWidth()];
80     double max;
81     if (mode == 0)
82     {
83       max = 1;
84     }
85     else
86     {
87       max = cm.getMax() * cm.getMax();
88     }
89
90     for (int i = 0; i < cm.getWidth(); i++)
91     {
92       // init the tree engine node for this column
93       BinaryNode cnode = new BinaryNode();
94       cnode.setElement(Integer.valueOf(i));
95       cnode.setName("c" + i);
96       node.addElement(cnode);
97       BitSet bs = new BitSet();
98       bs.set(i);
99       clusters.addElement(bs);
100
101       // compute distance matrix element
102       ContactListI ith = cm.getContactList(i);
103       distances.setValue(i, i, 0);
104       if (ith==null)
105       {
106         continue;
107       }
108       for (int j = 0; j < i; j++)
109       {
110         ContactListI jth = cm.getContactList(j);
111         if (jth == null)
112         {
113           break;
114         }
115         double prd = 0;
116         for (int indx = 0; indx < cm.getHeight(); indx++)
117         {
118           if (mode == 0)
119           {
120             if (j == 0)
121             {
122               moduli[i] += ith.getContactAt(indx) * ith.getContactAt(indx);
123             }
124             prd += ith.getContactAt(indx) * jth.getContactAt(indx);
125           }
126           else
127           {
128             prd += Math
129                     .abs(ith.getContactAt(indx) - jth.getContactAt(indx));
130           }
131         }
132         if (mode == 0)
133         {
134           if (j == 0)
135           {
136             moduli[i] = Math.sqrt(moduli[i]);
137           }
138           prd = (moduli[i] != 0 && moduli[j] != 0)
139                   ? prd / (moduli[i] * moduli[j])
140                   : 0;
141           prd = 1 - prd;
142         }
143         else
144         {
145           prd /= cm.getHeight();
146         }
147         distances.setValue(i, j, prd);
148         distances.setValue(j, i, prd);
149       }
150     }
151
152     noClus = clusters.size();
153     cluster();
154   }
155
156   /**
157    * Calculates and saves the distance between the combination of cluster(i) and
158    * cluster(j) and all other clusters. An average of the distances from
159    * cluster(i) and cluster(j) is calculated, weighted by the sizes of each
160    * cluster.
161    * 
162    * @param i
163    * @param j
164    */
165   @Override
166   protected void findClusterDistance(int i, int j)
167   {
168     int noi = clusters.elementAt(i).cardinality();
169     int noj = clusters.elementAt(j).cardinality();
170
171     // New distances from cluster i to others
172     double[] newdist = new double[noseqs];
173
174     for (int l = 0; l < noseqs; l++)
175     {
176       if ((l != i) && (l != j))
177       {
178         newdist[l] = ((distances.getValue(i, l) * noi)
179                 + (distances.getValue(j, l) * noj)) / (noi + noj);
180       }
181       else
182       {
183         newdist[l] = 0;
184       }
185     }
186
187     for (int ii = 0; ii < noseqs; ii++)
188     {
189       distances.setValue(i, ii, newdist[ii]);
190       distances.setValue(ii, i, newdist[ii]);
191     }
192   }
193
194   /**
195    * {@inheritDoc}
196    */
197   @Override
198   protected double findMinDistance()
199   {
200     double min = Double.MAX_VALUE;
201
202     for (int i = 0; i < (noseqs - 1); i++)
203     {
204       for (int j = i + 1; j < noseqs; j++)
205       {
206         if (!done.get(i) && !done.get(j))
207         {
208           if (distances.getValue(i, j) < min)
209           {
210             mini = i;
211             minj = j;
212
213             min = distances.getValue(i, j);
214           }
215         }
216       }
217     }
218     return min;
219   }
220
221   /**
222    * {@inheritDoc}
223    */
224   @Override
225   protected void findNewDistances(BinaryNode nodei, BinaryNode nodej,
226           double dist)
227   {
228     double ih = 0;
229     double jh = 0;
230
231     BinaryNode sni = nodei;
232     BinaryNode snj = nodej;
233
234     while (sni != null)
235     {
236       ih = ih + sni.dist;
237       sni = (BinaryNode) sni.left();
238     }
239
240     while (snj != null)
241     {
242       jh = jh + snj.dist;
243       snj = (BinaryNode) snj.left();
244     }
245
246     nodei.dist = ((dist / 2) - ih);
247     nodej.dist = ((dist / 2) - jh);
248   }
249
250   /***
251    * not the right place - OH WELL!
252    */
253
254   /**
255    * Makes a list of groups, where each group is represented by a node whose
256    * height (distance from the root node), as a fraction of the height of the
257    * whole tree, is greater than the given threshold. This corresponds to
258    * selecting the nodes immediately to the right of a vertical line
259    * partitioning the tree (if the tree is drawn with root to the left). Each
260    * such node represents a group that contains all of the sequences linked to
261    * the child leaf nodes.
262    * 
263    * @param threshold
264    * @see #getGroups()
265    */
266   public List<BinaryNode> groupNodes(float threshold)
267   {
268     List<BinaryNode> groups = new ArrayList<BinaryNode>();
269     _groupNodes(groups, getTopNode(), threshold);
270     return groups;
271   }
272
273   protected void _groupNodes(List<BinaryNode> groups, BinaryNode nd,
274           float threshold)
275   {
276     if (nd == null)
277     {
278       return;
279     }
280
281     if ((nd.height / maxheight) > threshold)
282     {
283       groups.add(nd);
284     }
285     else
286     {
287       _groupNodes(groups, nd.left(), threshold);
288       _groupNodes(groups, nd.right(), threshold);
289     }
290   }
291
292   /**
293    * DOCUMENT ME!
294    * 
295    * @param nd
296    *          DOCUMENT ME!
297    * 
298    * @return DOCUMENT ME!
299    */
300   public double findHeight(BinaryNode nd)
301   {
302     if (nd == null)
303     {
304       return maxheight;
305     }
306
307     if ((nd.left() == null) && (nd.right() == null))
308     {
309       nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
310
311       if (nd.height > maxheight)
312       {
313         return nd.height;
314       }
315       else
316       {
317         return maxheight;
318       }
319     }
320     else
321     {
322       if (nd.parent() != null)
323       {
324         nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
325       }
326       else
327       {
328         maxheight = 0;
329         nd.height = (float) 0.0;
330       }
331
332       maxheight = findHeight((BinaryNode) (nd.left()));
333       maxheight = findHeight((BinaryNode) (nd.right()));
334     }
335
336     return maxheight;
337   }
338
339   /**
340    * Search for leaf nodes below (or at) the given node
341    * 
342    * @param top2
343    *          root node to search from
344    * 
345    * @return
346    */
347   public Vector<BinaryNode> findLeaves(BinaryNode top2)
348   {
349     Vector<BinaryNode> leaves = new Vector<BinaryNode>();
350     findLeaves(top2, leaves);
351     return leaves;
352   }
353
354   /**
355    * Search for leaf nodes.
356    * 
357    * @param nd
358    *          root node to search from
359    * @param leaves
360    *          Vector of leaves to add leaf node objects too.
361    * 
362    * @return Vector of leaf nodes on binary tree
363    */
364   Vector<BinaryNode> findLeaves(BinaryNode nd, Vector<BinaryNode> leaves)
365   {
366     if (nd == null)
367     {
368       return leaves;
369     }
370
371     if ((nd.left() == null) && (nd.right() == null)) // Interior node
372     // detection
373     {
374       leaves.addElement(nd);
375
376       return leaves;
377     }
378     else
379     {
380       /*
381        * TODO: Identify internal nodes... if (node.isSequenceLabel()) {
382        * leaves.addElement(node); }
383        */
384       findLeaves(nd.left(), leaves);
385       findLeaves(nd.right(), leaves);
386     }
387
388     return leaves;
389   }
390
391 }