JAL-4134 group columns using a tree cut - seems to work after more egregiousness
[jalview.git] / src / jalview / analysis / AverageDistanceEngine.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.analysis;
22
23 import java.util.ArrayList;
24 import java.util.BitSet;
25 import java.util.List;
26 import java.util.Vector;
27
28 import jalview.datamodel.AlignmentAnnotation;
29 import jalview.datamodel.BinaryNode;
30 import jalview.datamodel.ContactListI;
31 import jalview.datamodel.ContactMatrixI;
32 import jalview.math.Matrix;
33 import jalview.viewmodel.AlignmentViewport;
34
35 /**
36  * This class implements distance calculations used in constructing a Average
37  * Distance tree (also known as UPGMA)
38  */
39 public class AverageDistanceEngine extends TreeEngine
40 {
41   ContactMatrixI cm;
42   AlignmentViewport av;
43   AlignmentAnnotation aa;
44   /**
45    * compute cosine distance matrix for a given contact matrix and create a UPGMA tree
46    * @param cm
47    */
48   public AverageDistanceEngine(AlignmentViewport av, AlignmentAnnotation aa, ContactMatrixI cm)
49   {
50     this.av =av;
51     this.aa = aa;
52     this.cm=cm;
53     node = new Vector<BinaryNode>();
54     clusters = new Vector<BitSet>();
55     distances = new Matrix(new double[cm.getWidth()][cm.getWidth()]);
56     noseqs=cm.getWidth();
57     done  = new BitSet();
58     for (int i=0;i<cm.getWidth();i++)
59     {
60       // init the tree engine node for this column
61       BinaryNode cnode = new BinaryNode();
62       cnode.setElement(Integer.valueOf(i));
63       cnode.setName("c"+i);
64       node.addElement(cnode);
65       BitSet bs = new BitSet();
66       bs.set(i);
67       clusters.addElement(bs);
68
69       // compute distance matrix element
70       ContactListI ith=cm.getContactList(i);
71       for (int j=0;j<i;j++)
72       {
73         ContactListI jth = cm.getContactList(j);
74         double prd=0;
75         for (int indx=0;indx<cm.getHeight();indx++)
76         {
77           prd+=ith.getContactAt(indx)*jth.getContactAt(indx);
78         }
79         distances.setValue(i, j, prd);
80         distances.setValue(j, i, prd);
81       }
82     }
83
84     noClus = clusters.size();
85     cluster();
86     
87   }
88   /**
89    * Calculates and saves the distance between the combination of cluster(i) and
90    * cluster(j) and all other clusters. An average of the distances from
91    * cluster(i) and cluster(j) is calculated, weighted by the sizes of each
92    * cluster.
93    * 
94    * @param i
95    * @param j
96    */
97   @Override
98   protected void findClusterDistance(int i, int j)
99   {
100     int noi = clusters.elementAt(i).cardinality();
101     int noj = clusters.elementAt(j).cardinality();
102
103     // New distances from cluster i to others
104     double[] newdist = new double[noseqs];
105
106     for (int l = 0; l < noseqs; l++)
107     {
108       if ((l != i) && (l != j))
109       {
110         newdist[l] = ((distances.getValue(i, l) * noi)
111                 + (distances.getValue(j, l) * noj)) / (noi + noj);
112       }
113       else
114       {
115         newdist[l] = 0;
116       }
117     }
118
119     for (int ii = 0; ii < noseqs; ii++)
120     {
121       distances.setValue(i, ii, newdist[ii]);
122       distances.setValue(ii, i, newdist[ii]);
123     }
124   }
125
126   /**
127    * {@inheritDoc}
128    */
129   @Override
130   protected double findMinDistance()
131   {
132     double min = Double.MAX_VALUE;
133
134     for (int i = 0; i < (noseqs - 1); i++)
135     {
136       for (int j = i + 1; j < noseqs; j++)
137       {
138         if (!done.get(i) && !done.get(j))
139         {
140           if (distances.getValue(i, j) < min)
141           {
142             mini = i;
143             minj = j;
144
145             min = distances.getValue(i, j);
146           }
147         }
148       }
149     }
150     return min;
151   }
152
153   /**
154    * {@inheritDoc}
155    */
156   @Override
157   protected void findNewDistances(BinaryNode nodei, BinaryNode nodej,
158           double dist)
159   {
160     double ih = 0;
161     double jh = 0;
162
163     BinaryNode sni = nodei;
164     BinaryNode snj = nodej;
165
166     while (sni != null)
167     {
168       ih = ih + sni.dist;
169       sni = (BinaryNode) sni.left();
170     }
171
172     while (snj != null)
173     {
174       jh = jh + snj.dist;
175       snj = (BinaryNode) snj.left();
176     }
177
178     nodei.dist = ((dist / 2) - ih);
179     nodej.dist = ((dist / 2) - jh);
180   }
181   /***
182    * not the right place - OH WELL!
183    */
184
185   /**
186    * Makes a list of groups, where each group is represented by a node whose
187    * height (distance from the root node), as a fraction of the height of the
188    * whole tree, is greater than the given threshold. This corresponds to
189    * selecting the nodes immediately to the right of a vertical line
190    * partitioning the tree (if the tree is drawn with root to the left). Each
191    * such node represents a group that contains all of the sequences linked to
192    * the child leaf nodes.
193    * 
194    * @param threshold
195    * @see #getGroups()
196    */
197   public List<BinaryNode> groupNodes(float threshold)
198   {
199     List<BinaryNode> groups = new ArrayList<BinaryNode>();
200     _groupNodes(groups, getTopNode(), threshold);
201     return groups;
202   }
203
204   protected void _groupNodes(List<BinaryNode> groups, BinaryNode nd,
205           float threshold)
206   {
207     if (nd == null)
208     {
209       return;
210     }
211
212     if ((nd.height / maxheight) > threshold)
213     {
214       groups.add(nd);
215     }
216     else
217     {
218       _groupNodes(groups,  nd.left(), threshold);
219       _groupNodes(groups,  nd.right(), threshold);
220     }
221   }
222
223   /**
224    * DOCUMENT ME!
225    * 
226    * @param nd
227    *          DOCUMENT ME!
228    * 
229    * @return DOCUMENT ME!
230    */
231   public double findHeight(BinaryNode nd)
232   {
233     if (nd == null)
234     {
235       return maxheight;
236     }
237
238     if ((nd.left() == null) && (nd.right() == null))
239     {
240       nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
241
242       if (nd.height > maxheight)
243       {
244         return nd.height;
245       }
246       else
247       {
248         return maxheight;
249       }
250     }
251     else
252     {
253       if (nd.parent() != null)
254       {
255         nd.height = ((BinaryNode) nd.parent()).height + nd.dist;
256       }
257       else
258       {
259         maxheight = 0;
260         nd.height = (float) 0.0;
261       }
262
263       maxheight = findHeight((BinaryNode) (nd.left()));
264       maxheight = findHeight((BinaryNode) (nd.right()));
265     }
266
267     return maxheight;
268   }
269
270
271   /**
272    * Search for leaf nodes below (or at) the given node
273    * 
274    * @param top2
275    *          root node to search from
276    * 
277    * @return
278    */
279   public Vector<BinaryNode> findLeaves(BinaryNode top2)
280   {
281     Vector<BinaryNode> leaves = new Vector<BinaryNode>();
282     findLeaves(top2, leaves);
283     return leaves;
284   }
285
286   /**
287    * Search for leaf nodes.
288    * 
289    * @param nd
290    *          root node to search from
291    * @param leaves
292    *          Vector of leaves to add leaf node objects too.
293    * 
294    * @return Vector of leaf nodes on binary tree
295    */
296   Vector<BinaryNode> findLeaves(BinaryNode nd,
297           Vector<BinaryNode> leaves)
298   {
299     if (nd == null)
300     {
301       return leaves;
302     }
303
304     if ((nd.left() == null) && (nd.right() == null)) // Interior node
305     // detection
306     {
307       leaves.addElement(nd);
308
309       return leaves;
310     }
311     else
312     {
313       /*
314        * TODO: Identify internal nodes... if (node.isSequenceLabel()) {
315        * leaves.addElement(node); }
316        */
317       findLeaves( nd.left(), leaves);
318       findLeaves( nd.right(), leaves);
319     }
320
321     return leaves;
322   }
323
324
325 }