JAL-3205 faster calculation with symmetric score matrix
authorgmungoc <g.m.carstairs@dundee.ac.uk>
Sun, 3 Mar 2019 07:25:25 +0000 (07:25 +0000)
committergmungoc <g.m.carstairs@dundee.ac.uk>
Sun, 3 Mar 2019 07:25:25 +0000 (07:25 +0000)
src/jalview/analysis/scoremodels/PIDModel.java
src/jalview/analysis/scoremodels/ScoreMatrix.java
src/jalview/math/Matrix.java
test/jalview/analysis/scoremodels/PIDModelTest.java
test/jalview/analysis/scoremodels/ScoreMatrixTest.java
test/jalview/math/MatrixTest.java

index c1e8b42..ddfe5e4 100644 (file)
@@ -152,15 +152,17 @@ public class PIDModel extends SimilarityScoreModel
   protected MatrixI findSimilarities(String[] seqs,
           SimilarityParamsI options)
   {
-    // TODO reuse code in ScoreMatrix instead somehow
-    double[][] values = new double[seqs.length][];
+    /*
+     * calculation is symmetric so just compute lower diagonal
+     */
+    double[][] values = new double[seqs.length][seqs.length];
     for (int row = 0; row < seqs.length; row++)
     {
-      values[row] = new double[seqs.length];
-      for (int col = 0; col < seqs.length; col++)
+      for (int col = row; col < seqs.length; col++)
       {
         double total = computePID(seqs[row], seqs[col], options);
         values[row][col] = total;
+        values[col][row] = total;
       }
     }
     return new Matrix(values);
index 6cdfacb..b206339 100644 (file)
@@ -99,6 +99,8 @@ public class ScoreMatrix extends SimilarityScoreModel
 
   private float maxValue;
 
+  private boolean symmetric;
+
   /**
    * Constructor given a name, symbol alphabet, and matrix of scores for pairs
    * of symbols. The matrix should be square and of the same size as the
@@ -156,6 +158,8 @@ public class ScoreMatrix extends SimilarityScoreModel
 
     findMinMax();
 
+    symmetric = checkSymmetry();
+
     /*
      * crude heuristic for now...
      */
@@ -163,6 +167,27 @@ public class ScoreMatrix extends SimilarityScoreModel
   }
 
   /**
+   * Answers true if the matrix is symmetric, else false. Usually, substitution
+   * matrices are symmetric, which allows calculations to be short cut.
+   * 
+   * @return
+   */
+  private boolean checkSymmetry()
+  {
+    for (int i = 0; i < matrix.length; i++)
+    {
+      for (int j = i; j < matrix.length; j++)
+      {
+        if (matrix[i][j] != matrix[j][i])
+        {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  /**
    * Record the minimum and maximum score values
    */
   protected void findMinMax()
@@ -457,14 +482,17 @@ public class ScoreMatrix extends SimilarityScoreModel
   protected MatrixI findSimilarities(String[] seqs,
           SimilarityParamsI params)
   {
-    double[][] values = new double[seqs.length][];
+    double[][] values = new double[seqs.length][seqs.length];
     for (int row = 0; row < seqs.length; row++)
     {
-      values[row] = new double[seqs.length];
-      for (int col = 0; col < seqs.length; col++)
+      for (int col = symmetric ? row : 0; col < seqs.length; col++)
       {
         double total = computeSimilarity(seqs[row], seqs[col], params);
         values[row][col] = total;
+        if (symmetric)
+        {
+          values[col][row] = total;
+        }
       }
     }
     return new Matrix(values);
@@ -592,4 +620,9 @@ public class ScoreMatrix extends SimilarityScoreModel
   {
     return this;
   }
+
+  public boolean isSymmetric()
+  {
+    return symmetric;
+  }
 }
index 77862c8..1e8f39d 100755 (executable)
@@ -984,4 +984,48 @@ public class Matrix implements MatrixI
   {
     e = v;
   }
+
+  @Override
+  public int hashCode()
+  {
+    return (int) getTotal();
+  }
+
+  public double getTotal()
+  {
+    double d = 0d;
+    for (int i = 0; i < this.height(); i++)
+    {
+      for (int j = 0; j < this.width(); j++)
+      {
+        d += value[i][j];
+      }
+    }
+    return d;
+  }
+
+  @Override
+  public boolean equals(Object obj)
+  {
+    if (!(obj instanceof MatrixI))
+    {
+      return false;
+    }
+    MatrixI m2 = (MatrixI) obj;
+    if (this.height() != m2.height() || this.width() != m2.width())
+    {
+      return false;
+    }
+    for (int i = 0; i < this.height(); i++)
+    {
+      for (int j = 0; j < this.width(); j++)
+      {
+        if (this.getValue(i, j) != m2.getValue(i, j))
+        {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
 }
index 212f825..e8ffd2f 100644 (file)
@@ -44,6 +44,8 @@ public class PIDModelTest
     double newScore = PIDModel.computePID(s1, s2, params);
     double oldScore = Comparison.PID(s1, s2);
     assertEquals(newScore, oldScore, DELTA);
+    // and verify PIDModel calculation is symmetric
+    assertEquals(newScore, PIDModel.computePID(s2, s1, params));
 
     /*
      * same length, with gaps
@@ -54,6 +56,7 @@ public class PIDModelTest
     newScore = PIDModel.computePID(s1, s2, params);
     oldScore = Comparison.PID(s1, s2);
     assertEquals(newScore, oldScore, DELTA);
+    assertEquals(newScore, PIDModel.computePID(s2, s1, params));
 
     /*
      * s2 longer than s1, with gaps
@@ -64,6 +67,7 @@ public class PIDModelTest
     newScore = PIDModel.computePID(s1, s2, params);
     oldScore = Comparison.PID(s1, s2);
     assertEquals(newScore, oldScore, DELTA);
+    assertEquals(newScore, PIDModel.computePID(s2, s1, params));
 
     /*
      * s1 longer than s2, with gaps
@@ -74,6 +78,7 @@ public class PIDModelTest
     newScore = PIDModel.computePID(s1, s2, params);
     oldScore = Comparison.PID(s1, s2);
     assertEquals(newScore, oldScore, DELTA);
+    assertEquals(newScore, PIDModel.computePID(s2, s1, params));
 
     /*
      * same but now also with gapped columns
@@ -84,6 +89,7 @@ public class PIDModelTest
     newScore = PIDModel.computePID(s1, s2, params);
     oldScore = Comparison.PID(s1, s2);
     assertEquals(newScore, oldScore, DELTA);
+    assertEquals(newScore, PIDModel.computePID(s2, s1, params));
   }
 
   /**
@@ -102,6 +108,7 @@ public class PIDModelTest
      */
     SimilarityParamsI params = new SimilarityParams(true, true, true, true);
     assertEquals(PIDModel.computePID(s1, s2, params), 80d);
+    assertEquals(PIDModel.computePID(s2, s1, params), 80d);
 
     /*
      * match gap-char but not gap-gap
@@ -109,6 +116,7 @@ public class PIDModelTest
      */
     params = new SimilarityParams(false, true, true, true);
     assertEquals(PIDModel.computePID(s1, s2, params), 75d);
+    assertEquals(PIDModel.computePID(s2, s1, params), 75d);
 
     /*
      * include gaps but don't match them
@@ -117,6 +125,7 @@ public class PIDModelTest
      */
     params = new SimilarityParams(true, false, true, true);
     assertEquals(PIDModel.computePID(s1, s2, params), 40d);
+    assertEquals(PIDModel.computePID(s2, s1, params), 40d);
 
     /*
      * include gaps but don't match them
@@ -125,6 +134,7 @@ public class PIDModelTest
      */
     params = new SimilarityParams(false, false, true, true);
     assertEquals(PIDModel.computePID(s1, s2, params), 25d);
+    assertEquals(PIDModel.computePID(s2, s1, params), 25d);
   }
 
   /**
@@ -144,6 +154,7 @@ public class PIDModelTest
      */
     SimilarityParamsI params = new SimilarityParams(true, true, true, false);
     assertEquals(PIDModel.computePID(s1, s2, params), 500d / 6);
+    assertEquals(PIDModel.computePID(s2, s1, params), 500d / 6);
   
     /*
      * match gap-char but not gap-gap
@@ -151,6 +162,7 @@ public class PIDModelTest
      */
     params = new SimilarityParams(false, true, true, false);
     assertEquals(PIDModel.computePID(s1, s2, params), 80d);
+    assertEquals(PIDModel.computePID(s2, s1, params), 80d);
   
     /*
      * include gaps but don't match them
@@ -159,6 +171,7 @@ public class PIDModelTest
      */
     params = new SimilarityParams(true, false, true, false);
     assertEquals(PIDModel.computePID(s1, s2, params), 100d / 3);
+    assertEquals(PIDModel.computePID(s2, s1, params), 100d / 3);
   
     /*
      * include gaps but don't match them
@@ -167,6 +180,7 @@ public class PIDModelTest
      */
     params = new SimilarityParams(false, false, true, false);
     assertEquals(PIDModel.computePID(s1, s2, params), 20d);
+    assertEquals(PIDModel.computePID(s2, s1, params), 20d);
 
     /*
      * no tests for matchGaps=true, includeGaps=false
index 1a5d43c..669c452 100644 (file)
@@ -22,6 +22,8 @@ import java.util.Arrays;
 
 import org.testng.annotations.Test;
 
+import junit.extensions.PA;
+
 public class ScoreMatrixTest
 {
   @Test(groups = "Functional")
@@ -33,6 +35,7 @@ public class ScoreMatrixTest
     scores[1] = new float[] { -4f, 5f, 6f };
     scores[2] = new float[] { 7f, 8f, 9f };
     ScoreMatrix sm = new ScoreMatrix("Test", "ABC".toCharArray(), scores);
+    assertFalse(sm.isSymmetric());
     assertEquals(sm.getSize(), 3);
     assertArrayEquals(scores, sm.getMatrix());
     assertEquals(sm.getPairwiseScore('A', 'a'), 1f);
@@ -585,4 +588,33 @@ public class ScoreMatrixTest
             + "</table>";
     assertEquals(html, expected);
   }
+
+  @Test(groups = "Functional")
+  public void testIsSymmetric()
+  {
+    float[][] scores = new float[2][];
+    scores[0] = new float[] { 1f, -2f };
+    scores[1] = new float[] { -2f, 1f };
+    ScoreMatrix sm = new ScoreMatrix("Test", "AB".toCharArray(), scores);
+    assertTrue(sm.isSymmetric());
+
+    scores[1] = new float[] { 2f, 1f };
+    sm = new ScoreMatrix("Test", "AB".toCharArray(), scores);
+    assertFalse(sm.isSymmetric());
+
+    /*
+     * verify that forcing an asymmetric matrix to use
+     * symmetric calculation gives a different (wrong) result
+     */
+    SimilarityParamsI params = new SimilarityParams(true, true, true,
+            false);
+    String[] seqs = new String[] { "AAABBBAA", "AABBABBA" };
+    MatrixI res1 = sm.findSimilarities(seqs, params);
+    MatrixI res2 = sm.findSimilarities(seqs, params);
+    assertTrue(res1.equals(res2));
+    PA.setValue(sm, "symmetric", true);
+    assertTrue(sm.isSymmetric()); // it's not true!
+    res2 = sm.findSimilarities(seqs, params);
+    assertFalse(res1.equals(res2));
+  }
 }
index 2cde593..8af10b0 100644 (file)
@@ -1,6 +1,8 @@
 package jalview.math;
 
 import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertNotEquals;
 import static org.testng.Assert.assertNotSame;
 import static org.testng.Assert.assertNull;
 import static org.testng.Assert.assertTrue;
@@ -547,4 +549,34 @@ public class MatrixTest
     values[0][0] = -1d;
     assertEquals(m.getValue(0, 0), 1d, DELTA); // unchanged
   }
+
+  @Test(groups = "Functional")
+  public void testEquals_hashCode()
+  {
+    double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    Matrix m1 = new Matrix(values);
+    double[][] values2 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
+    Matrix m2 = new Matrix(values2);
+
+    assertTrue(m1.equals(m1));
+    assertTrue(m1.equals(m2));
+    assertTrue(m2.equals(m1));
+    // equal objects should have same hashCode
+    assertEquals(m1.hashCode(), m2.hashCode());
+
+    double[][] values3 = new double[][] { { 1, 2, 3 }, { 4, 5, 7 } };
+    m2 = new Matrix(values3);
+    assertFalse(m1.equals(m2));
+    assertFalse(m2.equals(m1));
+    assertNotEquals(m1.hashCode(), m2.hashCode());
+
+    // same hashCode doesn't always mean equal
+    values2 = new double[][] { { 1, 2, 3 }, { 4, 6, 5 } };
+    m2 = new Matrix(values2);
+    assertFalse(m2.equals(m1));
+    assertEquals(m1.hashCode(), m2.hashCode());
+
+    assertFalse(m1.equals(null));
+    assertFalse(m1.equals("foo"));
+  }
 }