JAL-2403 JAL-1483 push 'reverseRange' inside Matrix
[jalview.git] / src / jalview / analysis / NJTree.java
index fcf208c..db2efba 100644 (file)
@@ -21,7 +21,9 @@
 package jalview.analysis;
 
 import jalview.analysis.scoremodels.ScoreModels;
+import jalview.api.analysis.DistanceScoreModelI;
 import jalview.api.analysis.ScoreModelI;
+import jalview.api.analysis.SimilarityScoreModelI;
 import jalview.datamodel.AlignmentView;
 import jalview.datamodel.BinaryNode;
 import jalview.datamodel.CigarArray;
@@ -31,6 +33,7 @@ import jalview.datamodel.Sequence;
 import jalview.datamodel.SequenceI;
 import jalview.datamodel.SequenceNode;
 import jalview.io.NewickFile;
+import jalview.math.MatrixI;
 
 import java.util.Enumeration;
 import java.util.List;
@@ -44,6 +47,15 @@ import java.util.Vector;
  */
 public class NJTree
 {
+  /*
+   * 'methods'
+   */
+  public static final String AVERAGE_DISTANCE = "AV";
+
+  public static final String NEIGHBOUR_JOINING = "NJ";
+
+  public static final String FROM_FILE = "FromFile";
+
   Vector<Cluster> cluster;
 
   SequenceI[] sequence;
@@ -58,15 +70,15 @@ public class NJTree
 
   int noClus;
 
-  float[][] distance;
+  MatrixI distance;
 
   int mini;
 
   int minj;
 
-  float ri;
+  double ri;
 
-  float rj;
+  double rj;
 
   Vector<SequenceNode> groups = new Vector<SequenceNode>();
 
@@ -74,9 +86,9 @@ public class NJTree
 
   SequenceNode top;
 
-  float maxDistValue;
+  double maxDistValue;
 
-  float maxheight;
+  double maxheight;
 
   int ycount;
 
@@ -205,25 +217,25 @@ public class NJTree
    * 
    * @param sequence
    *          DOCUMENT ME!
-   * @param type
+   * @param treeType
    *          DOCUMENT ME!
-   * @param pwtype
+   * @param modelType
    *          DOCUMENT ME!
    * @param start
    *          DOCUMENT ME!
    * @param end
    *          DOCUMENT ME!
    */
-  public NJTree(SequenceI[] sequence, AlignmentView seqData, String type,
-          String pwtype, ScoreModelI sm, int start, int end)
+  public NJTree(SequenceI[] sqs, AlignmentView seqView, String treeType,
+          String modelType, ScoreModelI sm, int start, int end)
   {
-    this.sequence = sequence;
+    this.sequence = sqs;
     this.node = new Vector<SequenceNode>();
-    this.type = type;
-    this.pwtype = pwtype;
-    if (seqData != null)
+    this.type = treeType;
+    this.pwtype = modelType;
+    if (seqView != null)
     {
-      this.seqData = seqData;
+      this.seqData = seqView;
     }
     else
     {
@@ -237,16 +249,16 @@ public class NJTree
       this.seqData = new AlignmentView(sdata, start);
     }
     // System.err.println("Made seqData");// dbg
-    if (!(type.equals("NJ")))
+    if (!(treeType.equals(NEIGHBOUR_JOINING)))
     {
-      type = "AV";
+      treeType = AVERAGE_DISTANCE;
     }
 
-    if (sm == null && !(pwtype.equals("PID")))
+    if (sm == null && !(modelType.equals("PID")))
     {
-      if (ScoreModels.getInstance().forName(pwtype) == null)
+      if (ScoreModels.getInstance().forName(modelType) == null)
       {
-        pwtype = "BLOSUM62";
+        modelType = "BLOSUM62";
       }
     }
 
@@ -262,7 +274,21 @@ public class NJTree
 
     noseqs = i++;
 
-    distance = findDistances(sm);
+    if (sm instanceof DistanceScoreModelI)
+    {
+      distance = ((DistanceScoreModelI) sm).findDistances(seqData);
+    }
+    else if (sm instanceof SimilarityScoreModelI)
+    {
+      /*
+       * compute similarity and invert it to give a distance measure
+       */
+      MatrixI result = ((SimilarityScoreModelI) sm)
+              .findSimilarities(seqData);
+      result.reverseRange(true);
+      distance = result;
+    }
+
     // System.err.println("Made distances");// dbg
     makeLeaves();
     // System.err.println("Made leaves");// dbg
@@ -384,7 +410,7 @@ public class NJTree
   {
     while (noClus > 2)
     {
-      if (type.equals("NJ"))
+      if (type.equals(NEIGHBOUR_JOINING))
       {
         findMinNJDistance();
       }
@@ -444,7 +470,7 @@ public class NJTree
    */
   public Cluster joinClusters(int i, int j)
   {
-    float dist = distance[i][j];
+    double dist = distance.getValue(i, j);
 
     int noi = cluster.elementAt(i).value.length;
     int noj = cluster.elementAt(j).value.length;
@@ -466,7 +492,7 @@ public class NJTree
     ri = findr(i, j);
     rj = findr(j, i);
 
-    if (type.equals("NJ"))
+    if (type.equals(NEIGHBOUR_JOINING))
     {
       findClusterNJDistance(i, j);
     }
@@ -483,7 +509,7 @@ public class NJTree
     SequenceNode tmpi = (node.elementAt(i));
     SequenceNode tmpj = (node.elementAt(j));
 
-    if (type.equals("NJ"))
+    if (type.equals(NEIGHBOUR_JOINING))
     {
       findNewNJDistances(tmpi, tmpj, dist);
     }
@@ -511,7 +537,7 @@ public class NJTree
    *          DOCUMENT ME!
    */
   public void findNewNJDistances(SequenceNode tmpi, SequenceNode tmpj,
-          float dist)
+          double dist)
   {
 
     tmpi.dist = ((dist + ri) - rj) / 2;
@@ -539,10 +565,10 @@ public class NJTree
    *          DOCUMENT ME!
    */
   public void findNewDistances(SequenceNode tmpi, SequenceNode tmpj,
-          float dist)
+          double dist)
   {
-    float ih = 0;
-    float jh = 0;
+    double ih = 0;
+    double jh = 0;
 
     SequenceNode sni = tmpi;
     SequenceNode snj = tmpj;
@@ -577,13 +603,16 @@ public class NJTree
     int noj = cluster.elementAt(j).value.length;
 
     // New distances from cluster to others
-    float[] newdist = new float[noseqs];
+    double[] newdist = new double[noseqs];
 
     for (int l = 0; l < noseqs; l++)
     {
       if ((l != i) && (l != j))
       {
-        newdist[l] = ((distance[i][l] * noi) + (distance[j][l] * noj))
+        // newdist[l] = ((distance[i][l] * noi) + (distance[j][l] * noj))
+        // / (noi + noj);
+        newdist[l] = ((distance.getValue(i, l) * noi) + (distance.getValue(
+                j, l) * noj))
                 / (noi + noj);
       }
       else
@@ -594,8 +623,10 @@ public class NJTree
 
     for (int ii = 0; ii < noseqs; ii++)
     {
-      distance[i][ii] = newdist[ii];
-      distance[ii][i] = newdist[ii];
+      // distance[i][ii] = newdist[ii];
+      // distance[ii][i] = newdist[ii];
+      distance.setValue(i, ii, newdist[ii]);
+      distance.setValue(ii, i, newdist[ii]);
     }
   }
 
@@ -611,13 +642,16 @@ public class NJTree
   {
 
     // New distances from cluster to others
-    float[] newdist = new float[noseqs];
+    double[] newdist = new double[noseqs];
 
     for (int l = 0; l < noseqs; l++)
     {
       if ((l != i) && (l != j))
       {
-        newdist[l] = ((distance[i][l] + distance[j][l]) - distance[i][j]) / 2;
+        // newdist[l] = ((distance[i][l] + distance[j][l]) - distance[i][j]) /
+        // 2;
+        newdist[l] = (distance.getValue(i, l) + distance.getValue(j, l) - distance
+                .getValue(i, j)) / 2;
       }
       else
       {
@@ -627,8 +661,10 @@ public class NJTree
 
     for (int ii = 0; ii < noseqs; ii++)
     {
-      distance[i][ii] = newdist[ii];
-      distance[ii][i] = newdist[ii];
+      // distance[i][ii] = newdist[ii];
+      // distance[ii][i] = newdist[ii];
+      distance.setValue(i, ii, newdist[ii]);
+      distance.setValue(ii, i, newdist[ii]);
     }
   }
 
@@ -642,15 +678,16 @@ public class NJTree
    * 
    * @return DOCUMENT ME!
    */
-  public float findr(int i, int j)
+  public double findr(int i, int j)
   {
-    float tmp = 1;
+    double tmp = 1;
 
     for (int k = 0; k < noseqs; k++)
     {
       if ((k != i) && (k != j) && (done[k] != 1))
       {
-        tmp = tmp + distance[i][k];
+        // tmp = tmp + distance[i][k];
+        tmp = tmp + distance.getValue(i, k);
       }
     }
 
@@ -667,9 +704,9 @@ public class NJTree
    * 
    * @return DOCUMENT ME!
    */
-  public float findMinNJDistance()
+  public double findMinNJDistance()
   {
-    float min = 100000;
+    double min = Double.MAX_VALUE;
 
     for (int i = 0; i < (noseqs - 1); i++)
     {
@@ -677,7 +714,9 @@ public class NJTree
       {
         if ((done[i] != 1) && (done[j] != 1))
         {
-          float tmp = distance[i][j] - (findr(i, j) + findr(j, i));
+          // float tmp = distance[i][j] - (findr(i, j) + findr(j, i));
+          double tmp = distance.getValue(i, j)
+                  - (findr(i, j) + findr(j, i));
 
           if (tmp < min)
           {
@@ -698,9 +737,9 @@ public class NJTree
    * 
    * @return DOCUMENT ME!
    */
-  public float findMinDistance()
+  public double findMinDistance()
   {
-    float min = 100000;
+    double min = Double.MAX_VALUE;
 
     for (int i = 0; i < (noseqs - 1); i++)
     {
@@ -708,12 +747,14 @@ public class NJTree
       {
         if ((done[i] != 1) && (done[j] != 1))
         {
-          if (distance[i][j] < min)
+          // if (distance[i][j] < min)
+          if (distance.getValue(i, j) < min)
           {
             mini = i;
             minj = j;
 
-            min = distance[i][j];
+            // min = distance[i][j];
+            min = distance.getValue(i, j);
           }
         }
       }
@@ -725,24 +766,41 @@ public class NJTree
   /**
    * Calculate a distance matrix given the sequence input data and score model
    * 
-   * @return similarity matrix used to compute tree
+   * @return
    */
-  public float[][] findDistances(ScoreModelI _pwmatrix)
+  public MatrixI findDistances(ScoreModelI scoreModel)
   {
+    MatrixI result = null;
 
-    float[][] dist = new float[noseqs][noseqs];
-    if (_pwmatrix == null)
+    if (scoreModel == null)
     {
       // Resolve substitution model
-      _pwmatrix = ScoreModels.getInstance().forName(pwtype);
-      if (_pwmatrix == null)
+      scoreModel = ScoreModels.getInstance().forName(pwtype);
+      if (scoreModel == null)
       {
-        _pwmatrix = ScoreModels.getInstance().forName("BLOSUM62");
+        scoreModel = ScoreModels.getInstance().forName("BLOSUM62");
       }
     }
-    dist = _pwmatrix.findDistances(seqData);
-    return dist;
+    if (scoreModel instanceof DistanceScoreModelI)
+    {
+      result = ((DistanceScoreModelI) scoreModel).findDistances(seqData);
+    }
+    else if (scoreModel instanceof SimilarityScoreModelI)
+    {
+      /*
+       * compute similarity and invert it to give a distance measure
+       */
+      result = ((SimilarityScoreModelI) scoreModel)
+              .findSimilarities(seqData);
+      result.reverseRange(true);
+    }
+    else
+    {
+      System.err
+              .println("Unexpected type of score model, can't compute distances");
+    }
 
+    return result;
   }
 
   /**
@@ -905,7 +963,7 @@ public class NJTree
 
     if ((nd.left() == null) && (nd.right() == null))
     {
-      float dist = nd.dist;
+      double dist = nd.dist;
 
       if (dist > maxDistValue)
       {
@@ -935,7 +993,7 @@ public class NJTree
    * 
    * @return DOCUMENT ME!
    */
-  public float getMaxHeight()
+  public double getMaxHeight()
   {
     return maxheight;
   }
@@ -974,7 +1032,7 @@ public class NJTree
    * 
    * @return DOCUMENT ME!
    */
-  public float findHeight(SequenceNode nd)
+  public double findHeight(SequenceNode nd)
   {
     if (nd == null)
     {
@@ -1024,7 +1082,7 @@ public class NJTree
     {
       ycount = 0;
 
-      float tmpdist = maxdist.dist;
+      double tmpdist = maxdist.dist;
 
       // New top
       SequenceNode sn = new SequenceNode();
@@ -1315,6 +1373,7 @@ public class NJTree
  * @author $author$
  * @version $Revision$
  */
+// TODO what does this class have that int[] doesn't have already?
 class Cluster
 {
   int[] value;