JAL-2379 SparseMatrix alternative to Matrix, both MatrixI
authorgmungoc <g.m.carstairs@dundee.ac.uk>
Tue, 17 Jan 2017 13:20:37 +0000 (13:20 +0000)
committergmungoc <g.m.carstairs@dundee.ac.uk>
Tue, 17 Jan 2017 13:20:37 +0000 (13:20 +0000)
src/jalview/math/Matrix.java
src/jalview/math/MatrixI.java [new file with mode: 0644]
src/jalview/math/SparseMatrix.java [new file with mode: 0644]
test/jalview/math/MatrixTest.java
test/jalview/math/SparseMatrixTest.java [new file with mode: 0644]

index 28b9d67..0f93eb5 100755 (executable)
@@ -31,39 +31,52 @@ import java.io.PrintStream;
  * @author $author$
  * @version $Revision$
  */
-public class Matrix
+public class Matrix implements MatrixI
 {
-  /**
-   * SMJSPUBLIC
+  /*
+   * the [row][column] values in the matrix
    */
-  public double[][] value;
+  private double[][] value;
 
-  /** DOCUMENT ME!! */
-  public int rows;
+  /*
+   * the number of rows
+   */
+  protected int rows;
 
-  /** DOCUMENT ME!! */
-  public int cols;
+  /*
+   * the number of columns
+   */
+  protected int cols;
 
   /** DOCUMENT ME!! */
-  public double[] d; // Diagonal
+  protected double[] d; // Diagonal
 
   /** DOCUMENT ME!! */
-  public double[] e; // off diagonal
+  protected double[] e; // off diagonal
 
   /**
    * maximum number of iterations for tqli
    * 
    */
-  int maxIter = 45; // fudge - add 15 iterations, just in case
+  private static final int maxIter = 45; // fudge - add 15 iterations, just in
+                                         // case
 
   /**
+   * Default constructor
+   */
+  public Matrix()
+  {
+
+  }
+  
+  /**
    * Creates a new Matrix object. For example
    * 
    * <pre>
-   *   new Matrix(new double[][] {{2, 3}, {4, 5}, 2, 2)
+   *   new Matrix(new double[][] {{2, 3, 4}, {5, 6, 7})
    * constructs
-   *   (2 3)
-   *   (4 5)
+   *   (2 3 4)
+   *   (5 6 7)
    * </pre>
    * 
    * Note that ragged arrays (with not all rows, or columns, of the same
@@ -72,22 +85,24 @@ public class Matrix
    * 
    * @param values
    *          the matrix values in row-major order
-   * @param rows
-   * @param cols
    */
-  public Matrix(double[][] values, int rows, int cols)
+  public Matrix(double[][] values)
   {
-    this.rows = rows;
-    this.cols = cols;
+    this.rows = values.length;
+    if (rows > 0)
+    {
+      this.cols = values[0].length;
+    }
     this.value = values;
   }
 
   /**
-   * Returns a new matrix which is the transposes of this one
+   * Returns a new matrix which is the transpose of this one
    * 
    * @return DOCUMENT ME!
    */
-  public Matrix transpose()
+  @Override
+  public MatrixI transpose()
   {
     double[][] out = new double[cols][rows];
 
@@ -99,7 +114,7 @@ public class Matrix
       }
     }
 
-    return new Matrix(out, cols, rows);
+    return new Matrix(out);
   }
 
   /**
@@ -107,14 +122,16 @@ public class Matrix
    * 
    * @param ps
    *          DOCUMENT ME!
+   * @param format
    */
-  public void print(PrintStream ps)
+  @Override
+  public void print(PrintStream ps, String format)
   {
     for (int i = 0; i < rows; i++)
     {
       for (int j = 0; j < cols; j++)
       {
-        Format.print(ps, "%8.2f", value[i][j]);
+        Format.print(ps, format, getValue(i, j));
       }
 
       ps.println();
@@ -133,29 +150,32 @@ public class Matrix
    *           if the number of columns in the pre-multiplier is not equal to
    *           the number of rows in the multiplicand (this)
    */
-  public Matrix preMultiply(Matrix in)
+  @Override
+  public MatrixI preMultiply(MatrixI in)
   {
-    if (in.cols != this.rows)
+    if (in.width() != rows)
     {
       throw new IllegalArgumentException("Can't pre-multiply " + this.rows
-              + " rows by " + in.cols + " columns");
+              + " rows by " + in.width() + " columns");
     }
-    double[][] tmp = new double[in.rows][this.cols];
+    double[][] tmp = new double[in.height()][this.cols];
 
-    for (int i = 0; i < in.rows; i++)
+    for (int i = 0; i < in.height(); i++)
     {
       for (int j = 0; j < this.cols; j++)
       {
-        tmp[i][j] = 0.0;
-
-        for (int k = 0; k < in.cols; k++)
+        /*
+         * result[i][j] is the vector product of 
+         * in.row[i] and this.column[j]
+         */
+        for (int k = 0; k < in.width(); k++)
         {
-          tmp[i][j] += (in.value[i][k] * this.value[k][j]);
+          tmp[i][j] += (in.getValue(i, k) * this.value[k][j]);
         }
       }
     }
 
-    return new Matrix(tmp, in.rows, this.cols);
+    return new Matrix(tmp);
   }
 
   /**
@@ -196,12 +216,13 @@ public class Matrix
    *           number of columns in the multiplicand (this)
    * @see #preMultiply(Matrix)
    */
-  public Matrix postMultiply(Matrix in)
+  @Override
+  public MatrixI postMultiply(MatrixI in)
   {
-    if (in.rows != this.cols)
+    if (in.height() != this.cols)
     {
       throw new IllegalArgumentException("Can't post-multiply " + this.cols
-              + " columns by " + in.rows + " rows");
+              + " columns by " + in.height() + " rows");
     }
     return in.preMultiply(this);
   }
@@ -211,29 +232,26 @@ public class Matrix
    * 
    * @return
    */
-  public Matrix copy()
+  @Override
+  public MatrixI copy()
   {
     double[][] newmat = new double[rows][cols];
 
     for (int i = 0; i < rows; i++)
     {
       System.arraycopy(value[i], 0, newmat[i], 0, value[i].length);
-      // for (int j = 0; j < cols; j++)
-      // {
-      // newmat[i][j] = value[i][j];
-      // }
     }
 
-    return new Matrix(newmat, rows, cols);
+    return new Matrix(newmat);
   }
 
   /**
    * DOCUMENT ME!
    */
+  @Override
   public void tred()
   {
     int n = rows;
-    int l;
     int k;
     int j;
     int i;
@@ -249,7 +267,7 @@ public class Matrix
 
     for (i = n; i >= 2; i--)
     {
-      l = i - 1;
+      final int l = i - 1;
       h = 0.0;
       scale = 0.0;
 
@@ -257,22 +275,28 @@ public class Matrix
       {
         for (k = 1; k <= l; k++)
         {
-          scale += Math.abs(value[i - 1][k - 1]);
+          // double v = Math.abs(value[i - 1][k - 1]);
+          double v = Math.abs(getValue(i - 1, k - 1));
+          scale += v;
         }
 
         if (scale == 0.0)
         {
-          e[i - 1] = value[i - 1][l - 1];
+          // e[i - 1] = value[i - 1][l - 1];
+          e[i - 1] = getValue(i - 1, l - 1);
         }
         else
         {
           for (k = 1; k <= l; k++)
           {
-            value[i - 1][k - 1] /= scale;
-            h += (value[i - 1][k - 1] * value[i - 1][k - 1]);
+            // value[i - 1][k - 1] /= scale;
+            // h += (value[i - 1][k - 1] * value[i - 1][k - 1]);
+            double v = divideValue(i - 1, k - 1, scale);
+            h += v * v;
           }
 
-          f = value[i - 1][l - 1];
+          // f = value[i - 1][l - 1];
+          f = getValue(i - 1, l - 1);
 
           if (f > 0)
           {
@@ -285,46 +309,60 @@ public class Matrix
 
           e[i - 1] = scale * g;
           h -= (f * g);
-          value[i - 1][l - 1] = f - g;
+          // value[i - 1][l - 1] = f - g;
+          setValue(i - 1, l - 1, f - g);
+          // System.out.println(String.format("%d %d %f %f %f %f %f %.5e", i,
+          // l,
+          // scale, f, g, h, getValue(i - 1, l - 1), checksum()));
           f = 0.0;
 
           for (j = 1; j <= l; j++)
           {
-            value[j - 1][i - 1] = value[i - 1][j - 1] / h;
+            // value[j - 1][i - 1] = value[i - 1][j - 1] / h;
+            double val = getValue(i - 1, j - 1) / h;
+            setValue(j - 1, i - 1, val);
             g = 0.0;
 
             for (k = 1; k <= j; k++)
             {
-              g += (value[j - 1][k - 1] * value[i - 1][k - 1]);
+              // g += (value[j - 1][k - 1] * value[i - 1][k - 1]);
+              g += (getValue(j - 1, k - 1) * getValue(i - 1, k - 1));
             }
 
             for (k = j + 1; k <= l; k++)
             {
-              g += (value[k - 1][j - 1] * value[i - 1][k - 1]);
+              // g += (value[k - 1][j - 1] * value[i - 1][k - 1]);
+              g += (getValue(k - 1, j - 1) * getValue(i - 1, k - 1));
             }
 
             e[j - 1] = g / h;
-            f += (e[j - 1] * value[i - 1][j - 1]);
+            // f += (e[j - 1] * value[i - 1][j - 1]);
+            f += (e[j - 1] * getValue(i - 1, j - 1));
           }
 
           hh = f / (h + h);
 
           for (j = 1; j <= l; j++)
           {
-            f = value[i - 1][j - 1];
+            // f = value[i - 1][j - 1];
+            f = getValue(i - 1, j - 1);
             g = e[j - 1] - (hh * f);
             e[j - 1] = g;
 
             for (k = 1; k <= j; k++)
             {
-              value[j - 1][k - 1] -= ((f * e[k - 1]) + (g * value[i - 1][k - 1]));
+              // value[j - 1][k - 1] -= ((f * e[k - 1]) + (g * value[i - 1][k -
+              // 1]));
+              double val = (f * e[k - 1]) + (g * getValue(i - 1, k - 1));
+              addValue(j - 1, k - 1, -val);
             }
           }
         }
       }
       else
       {
-        e[i - 1] = value[i - 1][l - 1];
+        // e[i - 1] = value[i - 1][l - 1];
+        e[i - 1] = getValue(i - 1, l - 1);
       }
 
       d[i - 1] = h;
@@ -335,7 +373,7 @@ public class Matrix
 
     for (i = 1; i <= n; i++)
     {
-      l = i - 1;
+      final int l = i - 1;
 
       if (d[i - 1] != 0.0)
       {
@@ -345,30 +383,72 @@ public class Matrix
 
           for (k = 1; k <= l; k++)
           {
-            g += (value[i - 1][k - 1] * value[k - 1][j - 1]);
+            // g += (value[i - 1][k - 1] * value[k - 1][j - 1]);
+            g += (getValue(i - 1, k - 1) * getValue(k - 1, j - 1));
           }
 
           for (k = 1; k <= l; k++)
           {
-            value[k - 1][j - 1] -= (g * value[k - 1][i - 1]);
+            // value[k - 1][j - 1] -= (g * value[k - 1][i - 1]);
+            addValue(k - 1, j - 1, -(g * getValue(k - 1, i - 1)));
           }
         }
       }
 
-      d[i - 1] = value[i - 1][i - 1];
-      value[i - 1][i - 1] = 1.0;
+      // d[i - 1] = value[i - 1][i - 1];
+      // value[i - 1][i - 1] = 1.0;
+      d[i - 1] = getValue(i - 1, i - 1);
+      setValue(i - 1, i - 1, 1.0);
 
       for (j = 1; j <= l; j++)
       {
-        value[j - 1][i - 1] = 0.0;
-        value[i - 1][j - 1] = 0.0;
+        // value[j - 1][i - 1] = 0.0;
+        // value[i - 1][j - 1] = 0.0;
+        setValue(j - 1, i - 1, 0.0);
+        setValue(i - 1, j - 1, 0.0);
       }
     }
   }
 
   /**
+   * Adds f to the value at [i, j] and returns the new value
+   * 
+   * @param i
+   * @param j
+   * @param f
+   */
+  protected double addValue(int i, int j, double f)
+  {
+    double v = value[i][j] + f;
+    value[i][j] = v;
+    return v;
+  }
+
+  /**
+   * Divides the value at [i, j] by divisor and returns the new value. If d is
+   * zero, returns the unchanged value.
+   * 
+   * @param i
+   * @param j
+   * @param divisor
+   * @return
+   */
+  protected double divideValue(int i, int j, double divisor)
+  {
+    if (divisor == 0d)
+    {
+      return getValue(i, j);
+    }
+    double v = value[i][j];
+    v = v / divisor;
+    value[i][j] = v;
+    return v;
+  }
+
+  /**
    * DOCUMENT ME!
    */
+  @Override
   public void tqli() throws Exception
   {
     int n = rows;
@@ -381,7 +461,6 @@ public class Matrix
     double s;
     double r;
     double p;
-    ;
 
     double g;
     double f;
@@ -464,9 +543,12 @@ public class Matrix
 
             for (k = 1; k <= n; k++)
             {
-              f = value[k - 1][i];
-              value[k - 1][i] = (s * value[k - 1][i - 1]) + (c * f);
-              value[k - 1][i - 1] = (c * value[k - 1][i - 1]) - (s * f);
+              // f = value[k - 1][i];
+              // value[k - 1][i] = (s * value[k - 1][i - 1]) + (c * f);
+              // value[k - 1][i - 1] = (c * value[k - 1][i - 1]) - (s * f);
+              f = getValue(k - 1, i);
+              setValue(k - 1, i, (s * getValue(k - 1, i - 1)) + (c * f));
+              setValue(k - 1, i - 1, (c * getValue(k - 1, i - 1)) - (s * f));
             }
           }
 
@@ -478,6 +560,17 @@ public class Matrix
     }
   }
 
+  @Override
+  public double getValue(int i, int j)
+  {
+    return value[i][j];
+  }
+
+  public void setValue(int i, int j, double val)
+  {
+    value[i][j] = val;
+  }
+
   /**
    * DOCUMENT ME!
    */
@@ -730,16 +823,14 @@ public class Matrix
   }
 
   /**
-   * DOCUMENT ME!
+   * Answers the first argument with the sign of the second argument
    * 
    * @param a
-   *          DOCUMENT ME!
    * @param b
-   *          DOCUMENT ME!
    * 
-   * @return DOCUMENT ME!
+   * @return
    */
-  public double sign(double a, double b)
+  static double sign(double a, double b)
   {
     if (b < 0)
     {
@@ -775,12 +866,14 @@ public class Matrix
    * 
    * @param ps
    *          DOCUMENT ME!
+   * @param format
    */
-  public void printD(PrintStream ps)
+  @Override
+  public void printD(PrintStream ps, String format)
   {
     for (int j = 0; j < rows; j++)
     {
-      Format.print(ps, "%15.4e", d[j]);
+      Format.print(ps, format, d[j]);
     }
   }
 
@@ -789,12 +882,45 @@ public class Matrix
    * 
    * @param ps
    *          DOCUMENT ME!
+   * @param format TODO
    */
-  public void printE(PrintStream ps)
+  @Override
+  public void printE(PrintStream ps, String format)
   {
     for (int j = 0; j < rows; j++)
     {
-      Format.print(ps, "%15.4e", e[j]);
+      Format.print(ps, format, e[j]);
     }
   }
+
+  @Override
+  public double[] getD()
+  {
+    return d;
+  }
+
+  @Override
+  public double[] getE()
+  {
+    return e;
+  }
+  
+  @Override
+  public int height() {
+    return rows;
+  }
+
+  @Override
+  public int width()
+  {
+    return cols;
+  }
+
+  @Override
+  public double[] getRow(int i)
+  {
+    double[] row = new double[cols];
+    System.arraycopy(value[i], 0, row, 0, cols);
+    return row;
+  }
 }
diff --git a/src/jalview/math/MatrixI.java b/src/jalview/math/MatrixI.java
new file mode 100644 (file)
index 0000000..d74a98b
--- /dev/null
@@ -0,0 +1,59 @@
+package jalview.math;
+
+import java.io.PrintStream;
+
+public interface MatrixI
+{
+  /**
+   * Answers the number of columns
+   * 
+   * @return
+   */
+  int width();
+
+  /**
+   * Answers the number of rows
+   * 
+   * @return
+   */
+  int height();
+
+  /**
+   * Answers the value at row i, column j
+   * 
+   * @param i
+   * @param j
+   * @return
+   */
+  double getValue(int i, int j);
+
+  /**
+   * Answers a copy of the values in the i'th row
+   * 
+   * @return
+   */
+  double[] getRow(int i);
+  
+  MatrixI copy();
+
+  MatrixI transpose();
+
+  MatrixI preMultiply(MatrixI m);
+
+  MatrixI postMultiply(MatrixI m);
+
+  double[] getD();
+
+  double[] getE();
+
+  void print(PrintStream ps, String format);
+
+  void printD(PrintStream ps, String format);
+
+  void printE(PrintStream ps, String format);
+
+  void tqli() throws Exception;
+
+  void tred();
+
+}
diff --git a/src/jalview/math/SparseMatrix.java b/src/jalview/math/SparseMatrix.java
new file mode 100644 (file)
index 0000000..4a9d427
--- /dev/null
@@ -0,0 +1,218 @@
+package jalview.math;
+
+import jalview.ext.android.SparseDoubleArray;
+
+/**
+ * A variant of Matrix intended for use for sparse (mostly zero) matrices. This
+ * class uses a SparseDoubleArray to hold each row of the matrix. The sparse
+ * array only stores non-zero values. This gives a smaller memory footprint, and
+ * fewer matrix calculation operations, for mostly zero matrices.
+ * 
+ * @author gmcarstairs
+ */
+public class SparseMatrix extends Matrix
+{
+  /*
+   * we choose columns for the sparse arrays as this allows
+   * optimisation of the preMultiply() method used in PCA.run()
+   */
+  SparseDoubleArray[] sparseColumns;
+
+  /**
+   * Constructor given data in [row][column] order
+   * 
+   * @param v
+   */
+  public SparseMatrix(double[][] v)
+  {
+    rows = v.length;
+    if (rows > 0) {
+      cols = v[0].length;
+    }
+    sparseColumns = new SparseDoubleArray[cols];
+
+    /*
+     * transpose v[row][col] into [col][row] order
+     */
+    for (int col = 0; col < cols; col++)
+    {
+      SparseDoubleArray sparseColumn = new SparseDoubleArray();
+      sparseColumns[col] = sparseColumn;
+      for (int row = 0; row < rows; row++)
+      {
+        double value = v[row][col];
+        if (value != 0d)
+        {
+          sparseColumn.put(row, value);
+        }
+      }
+    }
+  }
+
+  /**
+   * Answers the value at row i, column j
+   */
+  @Override
+  public double getValue(int i, int j)
+  {
+    return sparseColumns[j].get(i);
+  }
+
+  /**
+   * Sets the value at row i, column j to val
+   */
+  @Override
+  public void setValue(int i, int j, double val)
+  {
+    if (val == 0d)
+    {
+      sparseColumns[j].delete(i);
+    }
+    else
+    {
+      sparseColumns[j].put(i, val);
+    }
+  }
+
+  @Override
+  public double[] getColumn(int i)
+  {
+    double[] col = new double[height()];
+
+    SparseDoubleArray vals = sparseColumns[i];
+    for (int nonZero = 0; nonZero < vals.size(); nonZero++)
+    {
+      col[vals.keyAt(nonZero)] = vals.valueAt(nonZero);
+    }
+    return col;
+  }
+
+  @Override
+  public MatrixI copy()
+  {
+    double[][] vals = new double[height()][width()];
+    for (int i = 0; i < height(); i++)
+    {
+      vals[i] = getRow(i);
+    }
+    return new SparseMatrix(vals);
+  }
+
+  @Override
+  public MatrixI transpose()
+  {
+    double[][] out = new double[cols][rows];
+
+    /*
+     * for each column...
+     */
+    for (int i = 0; i < cols; i++)
+    {
+      /*
+       * put non-zero values into the corresponding row
+       * of the transposed matrix
+       */
+      SparseDoubleArray vals = sparseColumns[i];
+      for (int nonZero = 0; nonZero < vals.size(); nonZero++)
+      {
+        out[i][vals.keyAt(nonZero)] = vals.valueAt(nonZero);
+      }
+    }
+
+    return new SparseMatrix(out);
+  }
+
+  /**
+   * Answers a new matrix which is the product in.this. If the product contains
+   * less than 20% non-zero values, it is returned as a SparseMatrix, else as a
+   * Matrix.
+   * <p>
+   * This method is optimised for the sparse arrays which store column values
+   * for a SparseMatrix. Note that postMultiply is not so optimised. That would
+   * require redundantly also storing sparse arrays for the rows, which has not
+   * been done. Currently only preMultiply is used in Jalview.
+   */
+  @Override
+  public MatrixI preMultiply(MatrixI in)
+  {
+    if (in.width() != rows)
+    {
+      throw new IllegalArgumentException("Can't pre-multiply " + this.rows
+              + " rows by " + in.width() + " columns");
+    }
+    double[][] tmp = new double[in.height()][this.cols];
+
+    long count = 0L;
+    for (int i = 0; i < in.height(); i++)
+    {
+      for (int j = 0; j < this.cols; j++)
+      {
+        /*
+         * result[i][j] is the vector product of 
+         * in.row[i] and this.column[j]
+         * we only need to use non-zero values from the column
+         */
+        SparseDoubleArray vals = sparseColumns[j];
+        boolean added = false;
+        for (int nonZero = 0; nonZero < vals.size(); nonZero++)
+        {
+          int myRow = vals.keyAt(nonZero);
+          double myValue = vals.valueAt(nonZero);
+          tmp[i][j] += (in.getValue(i, myRow) * myValue);
+          added = true;
+        }
+        if (added && tmp[i][j] != 0d)
+        {
+          count++; // non-zero entry in product
+        }
+      }
+    }
+
+    /*
+     * heuristic rule - if product is more than 80% zero
+     * then construct a SparseMatrix, else a Matrix
+     */
+    if (count * 5 < in.height() * cols)
+    {
+      return new SparseMatrix(tmp);
+    }
+    else
+    {
+      return new Matrix(tmp);
+    }
+  }
+
+  @Override
+  protected double divideValue(int i, int j, double divisor)
+  {
+    if (divisor == 0d)
+    {
+      return getValue(i, j);
+    }
+    double v = sparseColumns[j].multiply(i, 1 / divisor);
+    return v;
+  }
+
+  @Override
+  protected double addValue(int i, int j, double addend)
+  {
+    double v = sparseColumns[j].add(i, addend);
+    return v;
+  }
+
+  /**
+   * Returns the fraction of the whole matrix size that is actually modelled in
+   * sparse arrays (normally, the non-zero values)
+   * 
+   * @return
+   */
+  public float getFillRatio()
+  {
+    long count = 0L;
+    for (SparseDoubleArray col : sparseColumns)
+    {
+      count += col.size();
+    }
+    return count / (float) (height() * width());
+  }
+}
index 795b2fa..1500dc6 100644 (file)
@@ -16,8 +16,8 @@ public class MatrixTest
     int cols = 1000;
     double[][] d1 = new double[rows][cols];
     double[][] d2 = new double[cols][rows];
-    Matrix m1 = new Matrix(d1, rows, cols);
-    Matrix m2 = new Matrix(d2, cols, rows);
+    Matrix m1 = new Matrix(d1);
+    Matrix m2 = new Matrix(d2);
     long start = System.currentTimeMillis();
     m1.preMultiply(m2);
     long elapsed = System.currentTimeMillis() - start;
@@ -28,27 +28,27 @@ public class MatrixTest
   @Test(groups = "Functional")
   public void testPreMultiply()
   {
-    Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }, 1, 3); // 1x3
-    Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }, 3, 1); // 3x1
+    MatrixI m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
+    MatrixI m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
 
     /*
      * 1x3 times 3x1 is 1x1
      * 2x5 + 3x6 + 4*7 =  56
      */
-    Matrix m3 = m2.preMultiply(m1);
-    assertEquals(m3.rows, 1);
-    assertEquals(m3.cols, 1);
-    assertEquals(m3.value[0][0], 56d);
+    MatrixI m3 = m2.preMultiply(m1);
+    assertEquals(m3.height(), 1);
+    assertEquals(m3.width(), 1);
+    assertEquals(m3.getValue(0, 0), 56d);
 
     /*
      * 3x1 times 1x3 is 3x3
      */
     m3 = m1.preMultiply(m2);
-    assertEquals(m3.rows, 3);
-    assertEquals(m3.cols, 3);
-    assertEquals(Arrays.toString(m3.value[0]), "[10.0, 15.0, 20.0]");
-    assertEquals(Arrays.toString(m3.value[1]), "[12.0, 18.0, 24.0]");
-    assertEquals(Arrays.toString(m3.value[2]), "[14.0, 21.0, 28.0]");
+    assertEquals(m3.height(), 3);
+    assertEquals(m3.width(), 3);
+    assertEquals(Arrays.toString(m3.getRow(0)), "[10.0, 15.0, 20.0]");
+    assertEquals(Arrays.toString(m3.getRow(1)), "[12.0, 18.0, 24.0]");
+    assertEquals(Arrays.toString(m3.getRow(2)), "[14.0, 21.0, 28.0]");
   }
 
   @Test(
@@ -56,8 +56,7 @@ public class MatrixTest
     expectedExceptions = { IllegalArgumentException.class })
   public void testPreMultiply_tooManyColumns()
   {
-    Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }, 2,
-            3); // 2x3
+    Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
 
     /*
      * 2x3 times 2x3 invalid operation - 
@@ -72,8 +71,7 @@ public class MatrixTest
     expectedExceptions = { IllegalArgumentException.class })
   public void testPreMultiply_tooFewColumns()
   {
-    Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }, 2,
-            3); // 2x3
+    Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
 
     /*
      * 3x2 times 3x2 invalid operation - 
@@ -85,7 +83,18 @@ public class MatrixTest
   
   
   private boolean matrixEquals(Matrix m1, Matrix m2) {
-    return Arrays.deepEquals(m1.value, m2.value);
+    if (m1.width() != m2.width() || m1.height() != m2.height())
+    {
+      return false;
+    }
+    for (int i = 0; i < m1.height(); i++)
+    {
+      if (!Arrays.equals(m1.getRow(i), m2.getRow(i)))
+      {
+        return false;
+      }
+    }
+    return true;
   }
 
   @Test(groups = "Functional")
@@ -99,37 +108,36 @@ public class MatrixTest
      * (3020 30200)
      * (5040 50400)
      */
-    Matrix m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } }, 2, 2);
-    Matrix m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } },
-            2, 2);
-    Matrix m3 = m1.postMultiply(m2);
-    assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
-    assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
+    MatrixI m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
+    MatrixI m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } });
+    MatrixI m3 = m1.postMultiply(m2);
+    assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
+    assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
 
     /*
      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
      */
     m3 = m2.preMultiply(m1);
-    assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
-    assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
+    assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
+    assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
 
     /*
      * m1 has more rows than columns
      * (2).(10 100 1000) = (20 200 2000)
      * (3)                 (30 300 3000)
      */
-    m1 = new Matrix(new double[][] { { 2 }, { 3 } }, 2, 1);
-    m2 = new Matrix(new double[][] { { 10, 100, 1000 } }, 1, 3);
+    m1 = new Matrix(new double[][] { { 2 }, { 3 } });
+    m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
     m3 = m1.postMultiply(m2);
-    assertEquals(m3.rows, 2);
-    assertEquals(m3.cols, 3);
-    assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
-    assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
+    assertEquals(m3.height(), 2);
+    assertEquals(m3.width(), 3);
+    assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
+    assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
     m3 = m2.preMultiply(m1);
-    assertEquals(m3.rows, 2);
-    assertEquals(m3.cols, 3);
-    assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
-    assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
+    assertEquals(m3.height(), 2);
+    assertEquals(m3.width(), 3);
+    assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
+    assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
 
     /*
      * m1 has more columns than rows
@@ -139,22 +147,22 @@ public class MatrixTest
      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
      */
-    m1 = new Matrix(new double[][] { { 2, 3, 4 } }, 1, 3);
-    m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } }, 3, 2);
+    m1 = new Matrix(new double[][] { { 2, 3, 4 } });
+    m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
     m3 = m1.postMultiply(m2);
-    assertEquals(m3.rows, 1);
-    assertEquals(m3.cols, 2);
-    assertEquals(m3.value[0][0], 56d);
-    assertEquals(m3.value[0][1], 25d);
+    assertEquals(m3.height(), 1);
+    assertEquals(m3.width(), 2);
+    assertEquals(m3.getRow(0)[0], 56d);
+    assertEquals(m3.getRow(0)[1], 25d);
 
     /*
      * and check premultiply equivalent
      */
     m3 = m2.preMultiply(m1);
-    assertEquals(m3.rows, 1);
-    assertEquals(m3.cols, 2);
-    assertEquals(m3.value[0][0], 56d);
-    assertEquals(m3.value[0][1], 25d);
+    assertEquals(m3.height(), 1);
+    assertEquals(m3.width(), 2);
+    assertEquals(m3.getRow(0)[0], 56d);
+    assertEquals(m3.getRow(0)[1], 25d);
   }
 
   /**
@@ -175,18 +183,18 @@ public class MatrixTest
       }
     }
   
-    Matrix origmat = new Matrix(in, n, n);
+    Matrix origmat = new Matrix(in);
   
     // System.out.println(" --- Original matrix ---- ");
     // / origmat.print(System.out);
     // System.out.println();
     // System.out.println(" --- transpose matrix ---- ");
-    Matrix trans = origmat.transpose();
+    MatrixI trans = origmat.transpose();
   
     // trans.print(System.out);
     // System.out.println();
     // System.out.println(" --- OrigT * Orig ---- ");
-    Matrix symm = trans.postMultiply(origmat);
+    MatrixI symm = trans.postMultiply(origmat);
   
     // symm.print(System.out);
     // System.out.println();
@@ -236,4 +244,15 @@ public class MatrixTest
     // }
     // System.out.println();
   }
+
+  @Test(groups = "Timing")
+  public void testSign()
+  {
+    assertEquals(Matrix.sign(-1, -2), -1d);
+    assertEquals(Matrix.sign(-1, 2), 1d);
+    assertEquals(Matrix.sign(-1, 0), 1d);
+    assertEquals(Matrix.sign(1, -2), -1d);
+    assertEquals(Matrix.sign(1, 2), 1d);
+    assertEquals(Matrix.sign(1, 0), 1d);
+  }
 }
diff --git a/test/jalview/math/SparseMatrixTest.java b/test/jalview/math/SparseMatrixTest.java
new file mode 100644 (file)
index 0000000..607d415
--- /dev/null
@@ -0,0 +1,380 @@
+package jalview.math;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
+
+import java.util.Random;
+
+import org.testng.annotations.Test;
+import org.testng.internal.junit.ArrayAsserts;
+
+public class SparseMatrixTest
+{
+  final static double DELTA = 0.0001d;
+
+  Random r = new Random(1729);
+
+  @Test(groups = "Functional")
+  public void testConstructor()
+  {
+    MatrixI m1 = new SparseMatrix(
+            new double[][] { { 2, 0, 4 }, { 0, 6, 0 } });
+    assertEquals(m1.getValue(0, 0), 2d);
+    assertEquals(m1.getValue(0, 1), 0d);
+    assertEquals(m1.getValue(0, 2), 4d);
+    assertEquals(m1.getValue(1, 0), 0d);
+    assertEquals(m1.getValue(1, 1), 6d);
+    assertEquals(m1.getValue(1, 2), 0d);
+  }
+
+  @Test(groups = "Functional")
+  public void testTranspose()
+  {
+    MatrixI m1 = new SparseMatrix(
+            new double[][] { { 2, 0, 4 }, { 5, 6, 0 } });
+    MatrixI m2 = m1.transpose();
+    assertTrue(m2 instanceof SparseMatrix);
+    assertEquals(m2.height(), 3);
+    assertEquals(m2.width(), 2);
+    assertEquals(m2.getValue(0, 0), 2d);
+    assertEquals(m2.getValue(0, 1), 5d);
+    assertEquals(m2.getValue(1, 0), 0d);
+    assertEquals(m2.getValue(1, 1), 6d);
+    assertEquals(m2.getValue(2, 0), 4d);
+    assertEquals(m2.getValue(2, 1), 0d);
+  }
+  @Test(groups = "Functional")
+  public void testPreMultiply()
+  {
+    MatrixI m1 = new SparseMatrix(new double[][] { { 2, 3, 4 } }); // 1x3
+    MatrixI m2 = new SparseMatrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
+
+    /*
+     * 1x3 times 3x1 is 1x1
+     * 2x5 + 3x6 + 4*7 =  56
+     */
+    MatrixI m3 = m2.preMultiply(m1);
+    assertFalse(m3 instanceof SparseMatrix);
+    assertEquals(m3.height(), 1);
+    assertEquals(m3.width(), 1);
+    assertEquals(m3.getValue(0, 0), 56d);
+
+    /*
+     * 3x1 times 1x3 is 3x3
+     */
+    m3 = m1.preMultiply(m2);
+    assertEquals(m3.height(), 3);
+    assertEquals(m3.width(), 3);
+    assertEquals(m3.getValue(0, 0), 10d);
+    assertEquals(m3.getValue(0, 1), 15d);
+    assertEquals(m3.getValue(0, 2), 20d);
+    assertEquals(m3.getValue(1, 0), 12d);
+    assertEquals(m3.getValue(1, 1), 18d);
+    assertEquals(m3.getValue(1, 2), 24d);
+    assertEquals(m3.getValue(2, 0), 14d);
+    assertEquals(m3.getValue(2, 1), 21d);
+    assertEquals(m3.getValue(2, 2), 28d);
+  }
+
+  @Test(
+    groups = "Functional",
+    expectedExceptions = { IllegalArgumentException.class })
+  public void testPreMultiply_tooManyColumns()
+  {
+    Matrix m1 = new SparseMatrix(
+            new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
+
+    /*
+     * 2x3 times 2x3 invalid operation - 
+     * multiplier has more columns than multiplicand has rows
+     */
+    m1.preMultiply(m1);
+    fail("Expected exception");
+  }
+
+  @Test(
+    groups = "Functional",
+    expectedExceptions = { IllegalArgumentException.class })
+  public void testPreMultiply_tooFewColumns()
+  {
+    Matrix m1 = new SparseMatrix(
+            new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
+
+    /*
+     * 3x2 times 3x2 invalid operation - 
+     * multiplier has more columns than multiplicand has row
+     */
+    m1.preMultiply(m1);
+    fail("Expected exception");
+  }
+  
+  @Test(groups = "Functional")
+  public void testPostMultiply()
+  {
+    /*
+     * Square matrices
+     * (2 3) . (10   100)
+     * (4 5)   (1000 10000)
+     * =
+     * (3020 30200)
+     * (5040 50400)
+     */
+    MatrixI m1 = new SparseMatrix(new double[][] { { 2, 3 }, { 4, 5 } });
+    MatrixI m2 = new SparseMatrix(new double[][] { { 10, 100 },
+        { 1000, 10000 } });
+    MatrixI m3 = m1.postMultiply(m2);
+    assertEquals(m3.getValue(0, 0), 3020d);
+    assertEquals(m3.getValue(0, 1), 30200d);
+    assertEquals(m3.getValue(1, 0), 5040d);
+    assertEquals(m3.getValue(1, 1), 50400d);
+
+    /*
+     * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
+     */
+    MatrixI m4 = m2.preMultiply(m1);
+    assertMatricesMatch(m3, m4);
+
+    /*
+     * m1 has more rows than columns
+     * (2).(10 100 1000) = (20 200 2000)
+     * (3)                 (30 300 3000)
+     */
+    m1 = new SparseMatrix(new double[][] { { 2 }, { 3 } });
+    m2 = new SparseMatrix(new double[][] { { 10, 100, 1000 } });
+    m3 = m1.postMultiply(m2);
+    assertEquals(m3.height(), 2);
+    assertEquals(m3.width(), 3);
+    assertEquals(m3.getValue(0, 0), 20d);
+    assertEquals(m3.getValue(0, 1), 200d);
+    assertEquals(m3.getValue(0, 2), 2000d);
+    assertEquals(m3.getValue(1, 0), 30d);
+    assertEquals(m3.getValue(1, 1), 300d);
+    assertEquals(m3.getValue(1, 2), 3000d);
+
+    m4 = m2.preMultiply(m1);
+    assertMatricesMatch(m3, m4);
+
+    /*
+     * m1 has more columns than rows
+     * (2 3 4) . (5 4) = (56 25)
+     *           (6 3) 
+     *           (7 2)
+     * [0, 0] = 2*5 + 3*6 + 4*7 = 56
+     * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
+     */
+    m1 = new SparseMatrix(new double[][] { { 2, 3, 4 } });
+    m2 = new SparseMatrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
+    m3 = m1.postMultiply(m2);
+    assertEquals(m3.height(), 1);
+    assertEquals(m3.width(), 2);
+    assertEquals(m3.getValue(0, 0), 56d);
+    assertEquals(m3.getValue(0, 1), 25d);
+
+    /*
+     * and check premultiply equivalent
+     */
+    m4 = m2.preMultiply(m1);
+    assertMatricesMatch(m3, m4);
+  }
+
+  @Test(groups = "Timing")
+  public void testSign()
+  {
+    assertEquals(Matrix.sign(-1, -2), -1d);
+    assertEquals(Matrix.sign(-1, 2), 1d);
+    assertEquals(Matrix.sign(-1, 0), 1d);
+    assertEquals(Matrix.sign(1, -2), -1d);
+    assertEquals(Matrix.sign(1, 2), 1d);
+    assertEquals(Matrix.sign(1, 0), 1d);
+  }
+
+  /**
+   * Verify that the results of method tred() are the same for SparseMatrix as
+   * they are for Matrix (i.e. a regression test rather than an absolute test of
+   * correctness of results)
+   */
+  @Test(groups = "Functional")
+  public void testTred_matchesMatrix()
+  {
+    /*
+     * make a pseudo-random symmetric matrix as required for tred/tqli
+     * note: test fails for matrices larger than 6x6 due to double value
+     * rounding only (random values result in very small values)
+     */
+    int rows = 6;
+    int cols = rows;
+    double[][] d = getSparseValues(rows, cols, 3);
+
+    /*
+     * make a copy of the values so m1, m2 are not
+     * sharing arrays!
+     */
+    double[][] d1 = new double[rows][cols];
+    for (int row = 0; row < rows; row++)
+    {
+      for (int col = 0; col < cols; col++)
+      {
+        d1[row][col] = d[row][col];
+      }
+    }
+    Matrix m1 = new Matrix(d);
+    Matrix m2 = new SparseMatrix(d1);
+    assertMatricesMatch(m1, m2); // sanity check
+    m1.tred();
+    m2.tred();
+    assertMatricesMatch(m1, m2);
+  }
+
+  private void assertMatricesMatch(MatrixI m1, MatrixI m2)
+  {
+    if (m1.height() != m2.height())
+    {
+      fail("height mismatch");
+    }
+    if (m1.width() != m2.width())
+    {
+      fail("width mismatch");
+    }
+    for (int row = 0; row < m1.height(); row++)
+    {
+      for (int col = 0; col < m1.width(); col++)
+      {
+        double v2 = m2.getValue(row, col);
+        double v1 = m1.getValue(row, col);
+        if (Math.abs(v1 - v2) > DELTA)
+        {
+          fail(String.format("At [%d, %d] %f != %f", row, col, v1, v2));
+        }
+      }
+    }
+    ArrayAsserts.assertArrayEquals(m1.getD(), m2.getD(), 0.00001d);
+    ArrayAsserts.assertArrayEquals(m1.getE(), m2.getE(), 0.00001d);
+  }
+
+  @Test
+  public void testGetValue()
+  {
+    double[][] d = new double[][] { { 0, 0, 1, 0, 0 }, { 2, 3, 0, 0, 0 },
+        { 4, 0, 0, 0, 5 } };
+    MatrixI m = new SparseMatrix(d);
+    for (int row = 0; row < 3; row++)
+    {
+      for (int col = 0; col < 5; col++)
+      {
+        assertEquals(m.getValue(row, col), d[row][col],
+                String.format("At [%d, %d]", row, col));
+      }
+    }
+  }
+
+  /**
+   * Verify that the results of method tqli() are the same for SparseMatrix as
+   * they are for Matrix (i.e. a regression test rather than an absolute test of
+   * correctness of results)
+   * 
+   * @throws Exception
+   */
+  @Test(groups = "Functional")
+  public void testTqli_matchesMatrix() throws Exception
+  {
+    /*
+     * make a pseudo-random symmetric matrix as required for tred
+     */
+    int rows = 6;
+    int cols = rows;
+    double[][] d = getSparseValues(rows, cols, 3);
+  
+    /*
+     * make a copy of the values so m1, m2 are not
+     * sharing arrays!
+     */
+    double[][] d1 = new double[rows][cols];
+    for (int row = 0; row < rows; row++)
+    {
+      for (int col = 0; col < cols; col++)
+      {
+        d1[row][col] = d[row][col];
+      }
+    }
+    Matrix m1 = new Matrix(d);
+    Matrix m2 = new SparseMatrix(d1);
+
+    // have to do tred() before doing tqli()
+    m1.tred();
+    m2.tred();
+    assertMatricesMatch(m1, m2);
+
+    m1.tqli();
+    m2.tqli();
+    assertMatricesMatch(m1, m2);
+  }
+
+  /**
+   * Helper method to make values for a sparse, pseudo-random symmetric matrix
+   * 
+   * @param rows
+   * @param cols
+   * @param fraction
+   *          one n fraction entries will be non-zero
+   * @return
+   */
+  public double[][] getSparseValues(int rows, int cols, int fraction)
+  {
+    double[][] d = new double[rows][cols];
+    int m = 0;
+    for (int i = 0; i < rows; i++)
+    {
+      if (++m % fraction == 0)
+      {
+        d[i][i] = r.nextDouble(); // diagonal
+      }
+      for (int j = 0; j < i; j++)
+      {
+        if (++m % fraction == 0)
+        {
+          d[i][j] = r.nextDouble();
+          d[j][i] = d[i][j];
+        }
+      }
+    }
+    return d;
+
+  }
+
+  /**
+   * Test that verifies that the result of preMultiply is a SparseMatrix if more
+   * than 80% zeroes, else a Matrix
+   */
+  @Test(groups = "Functional")
+  public void testPreMultiply_sparseProduct()
+  {
+    MatrixI m1 = new SparseMatrix(new double[][] { { 1 }, { 0 }, { 0 },
+        { 0 }, { 0 } }); // 5x1
+    MatrixI m2 = new SparseMatrix(new double[][] { { 1, 1, 1, 1 } }); // 1x4
+  
+    /*
+     * m1.m2 makes a row of 4 1's, and 4 rows of zeros
+     * 20% non-zero so not 'sparse'
+     */
+    MatrixI m3 = m2.preMultiply(m1);
+    assertFalse(m3 instanceof SparseMatrix);
+
+    /*
+     * replace a 1 with a 0 in the product:
+     * it is now > 80% zero so 'sparse'
+     */
+    m2 = new SparseMatrix(new double[][] { { 1, 1, 1, 0 } });
+    m3 = m2.preMultiply(m1);
+    assertTrue(m3 instanceof SparseMatrix);
+  }
+
+  @Test(groups = "Functional")
+  public void testFillRatio()
+  {
+    SparseMatrix m1 = new SparseMatrix(new double[][] { { 2, 0, 4, 1, 0 },
+    { 0, 6, 0, 0, 0 } });
+    assertEquals(m1.getFillRatio(), 0.4f);
+  }
+}
\ No newline at end of file