JAL-2428 use BitSet for clusters and done flags
authorgmungoc <g.m.carstairs@dundee.ac.uk>
Fri, 3 Mar 2017 10:01:34 +0000 (10:01 +0000)
committergmungoc <g.m.carstairs@dundee.ac.uk>
Fri, 3 Mar 2017 10:01:34 +0000 (10:01 +0000)
src/jalview/analysis/NJTree.java

index 255d6df..50c1795 100644 (file)
@@ -36,6 +36,7 @@ import jalview.io.NewickFile;
 import jalview.math.MatrixI;
 import jalview.viewmodel.AlignmentViewport;
 
+import java.util.BitSet;
 import java.util.Enumeration;
 import java.util.List;
 import java.util.Vector;
@@ -54,7 +55,13 @@ public class NJTree
 
   public static final String FROM_FILE = "FromFile";
 
-  Vector<Cluster> cluster;
+  /*
+   * Bit j in each BitSet is set if the cluster includes the j'th sequence.
+   * Clusters are grouped as the tree is built, from an initial state
+   * where each cluster is a single sequence, until only two clusters are left.
+   * These are the children of the root of the tree.
+   */
+  Vector<BitSet> clusters;
 
   SequenceI[] sequences;
 
@@ -64,12 +71,23 @@ public class NJTree
    */
   public AlignmentView seqData = null;
 
-  int[] done;
+  /*
+   * Bit j is set when cluster j has been combined to another cluster.
+   * The last two bits left unset are the indices of the clusters which
+   * are the children of the root node.
+   */
+  BitSet done;
 
   int noseqs;
 
   int noClus;
 
+  /*
+   * Value [i, j] is the distance between cluster[i] and cluster[j].
+   * Initially these are the pairwise distances of all sequences.
+   * As the tree is built, these are updated to be the distances
+   * between the clusters as they are assembled.
+   */
   MatrixI distances;
 
   int mini;
@@ -98,8 +116,6 @@ public class NJTree
 
   String pwtype;
 
-  Object found = null;
-
   boolean hasDistances = true; // normal case for jalview trees
 
   boolean hasBootstrap = false; // normal case for jalview trees
@@ -278,7 +294,7 @@ public class NJTree
      */
     noseqs = 0;
 
-    done = new int[sequences.length];
+    done = new BitSet();
 
     for (SequenceI seq : sequences)
     {
@@ -322,7 +338,7 @@ public class NJTree
 
     makeLeaves();
 
-    noClus = cluster.size();
+    noClus = clusters.size();
 
     cluster();
   }
@@ -431,7 +447,8 @@ public class NJTree
   }
 
   /**
-   * DOCUMENT ME!
+   * Form clusters by grouping sub-clusters, starting from one sequence per
+   * cluster, and finishing when only two clusters remain
    */
   void cluster()
   {
@@ -446,39 +463,21 @@ public class NJTree
         findMinDistance();
       }
 
-      Cluster c = joinClusters(mini, minj);
+      BitSet combined = joinClusters(mini, minj);
 
-      done[minj] = 1;
+      done.set(minj);
 
-      cluster.setElementAt(null, minj);
-      cluster.setElementAt(c, mini);
+      clusters.setElementAt(null, minj);
+      clusters.setElementAt(combined, mini);
 
       noClus--;
     }
 
-    boolean onefound = false;
-
-    int one = -1;
-    int two = -1;
-
-    for (int i = 0; i < noseqs; i++)
-    {
-      if (done[i] != 1)
-      {
-        if (onefound == false)
-        {
-          two = i;
-          onefound = true;
-        }
-        else
-        {
-          one = i;
-        }
-      }
-    }
+    int rightChild = done.nextClearBit(0);
+    int leftChild = done.nextClearBit(rightChild + 1);
 
-    joinClusters(one, two);
-    top = (node.elementAt(one));
+    joinClusters(leftChild, rightChild);
+    top = (node.elementAt(leftChild));
 
     reCount(top);
     findHeight(top);
@@ -495,26 +494,13 @@ public class NJTree
    * 
    * @return DOCUMENT ME!
    */
-  Cluster joinClusters(int i, int j)
+  BitSet joinClusters(int i, int j)
   {
     double dist = distances.getValue(i, j);
 
-    int noi = cluster.elementAt(i).value.length;
-    int noj = cluster.elementAt(j).value.length;
-
-    int[] value = new int[noi + noj];
-
-    for (int ii = 0; ii < noi; ii++)
-    {
-      value[ii] = cluster.elementAt(i).value[ii];
-    }
-
-    for (int ii = noi; ii < (noi + noj); ii++)
-    {
-      value[ii] = cluster.elementAt(j).value[ii - noi];
-    }
-
-    Cluster c = new Cluster(value);
+    BitSet combined = new BitSet();
+    combined.or(clusters.get(i));
+    combined.or(clusters.get(j));
 
     ri = findr(i, j);
     rj = findr(j, i);
@@ -550,7 +536,7 @@ public class NJTree
 
     node.setElementAt(sn, i);
 
-    return c;
+    return combined;
   }
 
   /**
@@ -626,8 +612,8 @@ public class NJTree
    */
   void findClusterDistance(int i, int j)
   {
-    int noi = cluster.elementAt(i).value.length;
-    int noj = cluster.elementAt(j).value.length;
+    int noi = clusters.elementAt(i).cardinality();
+    int noj = clusters.elementAt(j).cardinality();
 
     // New distances from cluster to others
     double[] newdist = new double[noseqs];
@@ -636,8 +622,6 @@ public class NJTree
     {
       if ((l != i) && (l != j))
       {
-        // newdist[l] = ((distance[i][l] * noi) + (distance[j][l] * noj))
-        // / (noi + noj);
         newdist[l] = ((distances.getValue(i, l) * noi) + (distances.getValue(
                 j, l) * noj))
                 / (noi + noj);
@@ -650,8 +634,6 @@ public class NJTree
 
     for (int ii = 0; ii < noseqs; ii++)
     {
-      // distance[i][ii] = newdist[ii];
-      // distance[ii][i] = newdist[ii];
       distances.setValue(i, ii, newdist[ii]);
       distances.setValue(ii, i, newdist[ii]);
     }
@@ -675,8 +657,6 @@ public class NJTree
     {
       if ((l != i) && (l != j))
       {
-        // newdist[l] = ((distance[i][l] + distance[j][l]) - distance[i][j]) /
-        // 2;
         newdist[l] = (distances.getValue(i, l) + distances.getValue(j, l) - distances
                 .getValue(i, j)) / 2;
       }
@@ -688,8 +668,6 @@ public class NJTree
 
     for (int ii = 0; ii < noseqs; ii++)
     {
-      // distance[i][ii] = newdist[ii];
-      // distance[ii][i] = newdist[ii];
       distances.setValue(i, ii, newdist[ii]);
       distances.setValue(ii, i, newdist[ii]);
     }
@@ -711,9 +689,8 @@ public class NJTree
 
     for (int k = 0; k < noseqs; k++)
     {
-      if ((k != i) && (k != j) && (done[k] != 1))
+      if ((k != i) && (k != j) && (!done.get(k)))
       {
-        // tmp = tmp + distance[i][k];
         tmp = tmp + distances.getValue(i, k);
       }
     }
@@ -739,9 +716,8 @@ public class NJTree
     {
       for (int j = i + 1; j < noseqs; j++)
       {
-        if ((done[i] != 1) && (done[j] != 1))
+        if (!done.get(i) && !done.get(j))
         {
-          // float tmp = distance[i][j] - (findr(i, j) + findr(j, i));
           double tmp = distances.getValue(i, j)
                   - (findr(i, j) + findr(j, i));
 
@@ -772,15 +748,13 @@ public class NJTree
     {
       for (int j = i + 1; j < noseqs; j++)
       {
-        if ((done[i] != 1) && (done[j] != 1))
+        if (!done.get(i) && !done.get(j))
         {
-          // if (distance[i][j] < min)
           if (distances.getValue(i, j) < min)
           {
             mini = i;
             minj = j;
 
-            // min = distance[i][j];
             min = distances.getValue(i, j);
           }
         }
@@ -791,11 +765,11 @@ public class NJTree
   }
 
   /**
-   * DOCUMENT ME!
+   * Start by making a cluster for each individual sequence
    */
   void makeLeaves()
   {
-    cluster = new Vector<Cluster>();
+    clusters = new Vector<BitSet>();
 
     for (int i = 0; i < noseqs; i++)
     {
@@ -804,12 +778,9 @@ public class NJTree
       sn.setElement(sequences[i]);
       sn.setName(sequences[i].getName());
       node.addElement(sn);
-
-      int[] value = new int[1];
-      value[0] = i;
-
-      Cluster c = new Cluster(value);
-      cluster.addElement(c);
+      BitSet bs = new BitSet();
+      bs.set(i);
+      clusters.addElement(bs);
     }
   }
 
@@ -1127,12 +1098,12 @@ public class NJTree
   public void reCount(SequenceNode nd)
   {
     ycount = 0;
-    _lycount = 0;
+    // _lycount = 0;
     // _lylimit = this.node.size();
     _reCount(nd);
   }
 
-  private long _lycount = 0, _lylimit = 0;
+  // private long _lycount = 0, _lylimit = 0;
 
   /**
    * DOCUMENT ME!
@@ -1150,7 +1121,7 @@ public class NJTree
     {
       return;
     }
-    _lycount++;
+    // _lycount++;
 
     if ((nd.left() != null) && (nd.right() != null))
     {
@@ -1169,7 +1140,7 @@ public class NJTree
       nd.count = 1;
       nd.ycount = ycount++;
     }
-    _lycount--;
+    // _lycount--;
   }
 
   /**
@@ -1312,26 +1283,3 @@ public class NJTree
     }
   }
 }
-
-/**
- * DOCUMENT ME!
- * 
- * @author $author$
- * @version $Revision$
- */
-// TODO what does this class have that int[] doesn't have already?
-class Cluster
-{
-  int[] value;
-
-  /**
-   * Creates a new Cluster object.
-   * 
-   * @param value
-   *          DOCUMENT ME!
-   */
-  public Cluster(int[] value)
-  {
-    this.value = value;
-  }
-}