ed6f861bf984bb0edc1ef73c39119162502713f6
[jalview.git] / src / jalview / analysis / AverageDistanceEngine.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import java.util.ArrayList;
24 import java.util.BitSet;
25 import java.util.List;
26 import java.util.Vector;
27
28 import jalview.datamodel.AlignmentAnnotation;
29 import jalview.datamodel.BinaryNode;
30 import jalview.datamodel.ContactListI;
31 import jalview.datamodel.ContactMatrixI;
32 import jalview.math.Matrix;
33 import jalview.viewmodel.AlignmentViewport;
34
35 /**
36  * This class implements distance calculations used in constructing a Average
37  * Distance tree (also known as UPGMA)
38  */
39 public class AverageDistanceEngine extends TreeEngine
40 {
41   ContactMatrixI cm;
42   AlignmentViewport av;
43   AlignmentAnnotation aa;
44   /**
45    * compute cosine distance matrix for a given contact matrix and create a UPGMA tree
46    * @param cm
47    */
48   public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa, ContactMatrixI cm)
49   {
50     this.av =av;
51     this.aa = aa;
52     this.cm=cm;
53     calculate(cm);
54
55   }
56   public void calculate(ContactMatrixI cm)
57   {
58     this.cm = cm;
59     node = new Vector<BinaryNode>();
60     clusters = new Vector<BitSet>();
61     distances = new Matrix(new double[cm.getWidth()][cm.getWidth()]);
62     noseqs=cm.getWidth();
63     done  = new BitSet();
64     double moduli[]=new double[cm.getWidth()];
65     
66     
67     for (int i=0;i<cm.getWidth();i++)
68     {
69       // init the tree engine node for this column
70       BinaryNode cnode = new BinaryNode();
71       cnode.setElement(Integer.valueOf(i));
72       cnode.setName("c"+i);
73       node.addElement(cnode);
74       BitSet bs = new BitSet();
75       bs.set(i);
76       clusters.addElement(bs);
77
78       // compute distance matrix element
79       ContactListI ith=cm.getContactList(i);
80       
81       for (int j=0;j<i;j++)
82       {
83         distances.setValue(i,i,0);
84         ContactListI jth = cm.getContactList(j);
85         double prd=0;
86         for (int indx=0;indx<cm.getHeight();indx++)
87         {
88           if (j==0)
89           {
90             moduli[i]+=ith.getContactAt(indx)*ith.getContactAt(indx);
91           }
92           prd+=ith.getContactAt(indx)*jth.getContactAt(indx);
93         }
94         if (j==0)
95         {
96           moduli[i]=Math.sqrt(moduli[i]);
97         }
98         prd=(moduli[i]!=0 && moduli[j]!=0) ? prd/(moduli[i]*moduli[j]) : 0;
99         distances.setValue(i, j, 1-prd);
100         distances.setValue(j, i, 1-prd);
101       }
102     }
103
104     noClus = clusters.size();
105     cluster();
106   }
107   /**
108    * Calculates and saves the distance between the combination of cluster(i) and
109    * cluster(j) and all other clusters. An average of the distances from
110    * cluster(i) and cluster(j) is calculated, weighted by the sizes of each
111    * cluster.
112    * 
113    * @param i
114    * @param j
115    */
116   @Override
117   protected void findClusterDistance(int i, int j)
118   {
119     int noi = clusters.elementAt(i).cardinality();
120     int noj = clusters.elementAt(j).cardinality();
121
122     // New distances from cluster i to others
123     double[] newdist = new double[noseqs];
124
125     for (int l = 0; l < noseqs; l++)
126     {
127       if ((l != i) && (l != j))
128       {
129         newdist[l] = ((distances.getValue(i, l) * noi)
130                 + (distances.getValue(j, l) * noj)) / (noi + noj);
131       }
132       else
133       {
134         newdist[l] = 0;
135       }
136     }
137
138     for (int ii = 0; ii < noseqs; ii++)
139     {
140       distances.setValue(i, ii, newdist[ii]);
141       distances.setValue(ii, i, newdist[ii]);
142     }
143   }
144
145   /**
146    * {@inheritDoc}
147    */
148   @Override
149   protected double findMinDistance()
150   {
151     double min = Double.MAX_VALUE;
152
153     for (int i = 0; i < (noseqs - 1); i++)
154     {
155       for (int j = i + 1; j < noseqs; j++)
156       {
157         if (!done.get(i) && !done.get(j))
158         {
159           if (distances.getValue(i, j) < min)
160           {
161             mini = i;
162             minj = j;
163
164             min = distances.getValue(i, j);
165           }
166         }
167       }
168     }
169     return min;
170   }
171
172   /**
173    * {@inheritDoc}
174    */
175   @Override
176   protected void findNewDistances(BinaryNode nodei, BinaryNode nodej,
177           double dist)
178   {
179     double ih = 0;
180     double jh = 0;
181
182     BinaryNode sni = nodei;
183     BinaryNode snj = nodej;
184
185     while (sni != null)
186     {
187       ih = ih + sni.dist;
188       sni = (BinaryNode) sni.left();
189     }
190
191     while (snj != null)
192     {
193       jh = jh + snj.dist;
194       snj = (BinaryNode) snj.left();
195     }
196
197     nodei.dist = ((dist / 2) - ih);
198     nodej.dist = ((dist / 2) - jh);
199   }
200   /***
201    * not the right place - OH WELL!
202    */
203
204   /**
205    * Makes a list of groups, where each group is represented by a node whose
206    * height (distance from the root node), as a fraction of the height of the
207    * whole tree, is greater than the given threshold. This corresponds to
208    * selecting the nodes immediately to the right of a vertical line
209    * partitioning the tree (if the tree is drawn with root to the left). Each
210    * such node represents a group that contains all of the sequences linked to
211    * the child leaf nodes.
212    * 
213    * @param threshold
214    * @see #getGroups()
215    */
216   public List<BinaryNode> groupNodes(float threshold)
217   {
218     List<BinaryNode> groups = new ArrayList<BinaryNode>();
219     _groupNodes(groups, getTopNode(), threshold);
220     return groups;
221   }
222
223   protected void _groupNodes(List<BinaryNode> groups, BinaryNode nd,
224           float threshold)
225   {
226     if (nd == null)
227     {
228       return;
229     }
230
231     if ((nd.height / maxheight) > threshold)
232     {
233       groups.add(nd);
234     }
235     else
236     {
237       _groupNodes(groups,  nd.left(), threshold);
238       _groupNodes(groups,  nd.right(), threshold);
239     }
240   }
241
242   /**
243    * DOCUMENT ME!
244    * 
245    * @param nd
246    *          DOCUMENT ME!
247    * 
248    * @return DOCUMENT ME!
249    */
250   public double findHeight(BinaryNode nd)
251   {
252     if (nd == null)
253     {
254       return maxheight;
255     }
256
257     if ((nd.left() == null) && (nd.right() == null))
258     {
259       nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
260
261       if (nd.height > maxheight)
262       {
263         return nd.height;
264       }
265       else
266       {
267         return maxheight;
268       }
269     }
270     else
271     {
272       if (nd.parent() != null)
273       {
274         nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
275       }
276       else
277       {
278         maxheight = 0;
279         nd.height = (float) 0.0;
280       }
281
282       maxheight = findHeight((BinaryNode) (nd.left()));
283       maxheight = findHeight((BinaryNode) (nd.right()));
284     }
285
286     return maxheight;
287   }
288
289
290   /**
291    * Search for leaf nodes below (or at) the given node
292    * 
293    * @param top2
294    *          root node to search from
295    * 
296    * @return
297    */
298   public Vector<BinaryNode> findLeaves(BinaryNode top2)
299   {
300     Vector<BinaryNode> leaves = new Vector<BinaryNode>();
301     findLeaves(top2, leaves);
302     return leaves;
303   }
304
305   /**
306    * Search for leaf nodes.
307    * 
308    * @param nd
309    *          root node to search from
310    * @param leaves
311    *          Vector of leaves to add leaf node objects too.
312    * 
313    * @return Vector of leaf nodes on binary tree
314    */
315   Vector<BinaryNode> findLeaves(BinaryNode nd,
316           Vector<BinaryNode> leaves)
317   {
318     if (nd == null)
319     {
320       return leaves;
321     }
322
323     if ((nd.left() == null) && (nd.right() == null)) // Interior node
324     // detection
325     {
326       leaves.addElement(nd);
327
328       return leaves;
329     }
330     else
331     {
332       /*
333        * TODO: Identify internal nodes... if (node.isSequenceLabel()) {
334        * leaves.addElement(node); }
335        */
336       findLeaves( nd.left(), leaves);
337       findLeaves( nd.right(), leaves);
338     }
339
340     return leaves;
341   }
342
343
344 }