7295fade14e85e034a1a3f5436292ccb7a42fb50
[jalview.git] / test / jalview / math / SparseMatrixTest.java
1 package jalview.math;
2
3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.assertFalse;
5 import static org.testng.Assert.assertTrue;
6 import static org.testng.Assert.fail;
7
8 import java.util.Random;
9
10 import org.testng.annotations.Test;
11 import org.testng.internal.junit.ArrayAsserts;
12
13 public class SparseMatrixTest
14 {
15   final static double DELTA = 0.0001d;
16
17   Random r = new Random(1729);
18
19   @Test(groups = "Functional")
20   public void testConstructor()
21   {
22     MatrixI m1 = new SparseMatrix(
23             new double[][]
24             { { 2, 0, 4 }, { 0, 6, 0 } });
25     assertEquals(m1.getValue(0, 0), 2d);
26     assertEquals(m1.getValue(0, 1), 0d);
27     assertEquals(m1.getValue(0, 2), 4d);
28     assertEquals(m1.getValue(1, 0), 0d);
29     assertEquals(m1.getValue(1, 1), 6d);
30     assertEquals(m1.getValue(1, 2), 0d);
31   }
32
33   @Test(groups = "Functional")
34   public void testTranspose()
35   {
36     MatrixI m1 = new SparseMatrix(
37             new double[][]
38             { { 2, 0, 4 }, { 5, 6, 0 } });
39     MatrixI m2 = m1.transpose();
40     assertTrue(m2 instanceof SparseMatrix);
41     assertEquals(m2.height(), 3);
42     assertEquals(m2.width(), 2);
43     assertEquals(m2.getValue(0, 0), 2d);
44     assertEquals(m2.getValue(0, 1), 5d);
45     assertEquals(m2.getValue(1, 0), 0d);
46     assertEquals(m2.getValue(1, 1), 6d);
47     assertEquals(m2.getValue(2, 0), 4d);
48     assertEquals(m2.getValue(2, 1), 0d);
49   }
50
51   @Test(groups = "Functional")
52   public void testPreMultiply()
53   {
54     MatrixI m1 = new SparseMatrix(new double[][] { { 2, 3, 4 } }); // 1x3
55     MatrixI m2 = new SparseMatrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
56
57     /*
58      * 1x3 times 3x1 is 1x1
59      * 2x5 + 3x6 + 4*7 =  56
60      */
61     MatrixI m3 = m2.preMultiply(m1);
62     assertFalse(m3 instanceof SparseMatrix);
63     assertEquals(m3.height(), 1);
64     assertEquals(m3.width(), 1);
65     assertEquals(m3.getValue(0, 0), 56d);
66
67     /*
68      * 3x1 times 1x3 is 3x3
69      */
70     m3 = m1.preMultiply(m2);
71     assertEquals(m3.height(), 3);
72     assertEquals(m3.width(), 3);
73     assertEquals(m3.getValue(0, 0), 10d);
74     assertEquals(m3.getValue(0, 1), 15d);
75     assertEquals(m3.getValue(0, 2), 20d);
76     assertEquals(m3.getValue(1, 0), 12d);
77     assertEquals(m3.getValue(1, 1), 18d);
78     assertEquals(m3.getValue(1, 2), 24d);
79     assertEquals(m3.getValue(2, 0), 14d);
80     assertEquals(m3.getValue(2, 1), 21d);
81     assertEquals(m3.getValue(2, 2), 28d);
82   }
83
84   @Test(
85     groups = "Functional",
86     expectedExceptions =
87     { IllegalArgumentException.class })
88   public void testPreMultiply_tooManyColumns()
89   {
90     Matrix m1 = new SparseMatrix(
91             new double[][]
92             { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
93
94     /*
95      * 2x3 times 2x3 invalid operation - 
96      * multiplier has more columns than multiplicand has rows
97      */
98     m1.preMultiply(m1);
99     fail("Expected exception");
100   }
101
102   @Test(
103     groups = "Functional",
104     expectedExceptions =
105     { IllegalArgumentException.class })
106   public void testPreMultiply_tooFewColumns()
107   {
108     Matrix m1 = new SparseMatrix(
109             new double[][]
110             { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
111
112     /*
113      * 3x2 times 3x2 invalid operation - 
114      * multiplier has more columns than multiplicand has row
115      */
116     m1.preMultiply(m1);
117     fail("Expected exception");
118   }
119
120   @Test(groups = "Functional")
121   public void testPostMultiply()
122   {
123     /*
124      * Square matrices
125      * (2 3) . (10   100)
126      * (4 5)   (1000 10000)
127      * =
128      * (3020 30200)
129      * (5040 50400)
130      */
131     MatrixI m1 = new SparseMatrix(new double[][] { { 2, 3 }, { 4, 5 } });
132     MatrixI m2 = new SparseMatrix(
133             new double[][]
134             { { 10, 100 }, { 1000, 10000 } });
135     MatrixI m3 = m1.postMultiply(m2);
136     assertEquals(m3.getValue(0, 0), 3020d);
137     assertEquals(m3.getValue(0, 1), 30200d);
138     assertEquals(m3.getValue(1, 0), 5040d);
139     assertEquals(m3.getValue(1, 1), 50400d);
140
141     /*
142      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
143      */
144     MatrixI m4 = m2.preMultiply(m1);
145     assertMatricesMatch(m3, m4, 0.00001d);
146
147     /*
148      * m1 has more rows than columns
149      * (2).(10 100 1000) = (20 200 2000)
150      * (3)                 (30 300 3000)
151      */
152     m1 = new SparseMatrix(new double[][] { { 2 }, { 3 } });
153     m2 = new SparseMatrix(new double[][] { { 10, 100, 1000 } });
154     m3 = m1.postMultiply(m2);
155     assertEquals(m3.height(), 2);
156     assertEquals(m3.width(), 3);
157     assertEquals(m3.getValue(0, 0), 20d);
158     assertEquals(m3.getValue(0, 1), 200d);
159     assertEquals(m3.getValue(0, 2), 2000d);
160     assertEquals(m3.getValue(1, 0), 30d);
161     assertEquals(m3.getValue(1, 1), 300d);
162     assertEquals(m3.getValue(1, 2), 3000d);
163
164     m4 = m2.preMultiply(m1);
165     assertMatricesMatch(m3, m4, 0.00001d);
166
167     /*
168      * m1 has more columns than rows
169      * (2 3 4) . (5 4) = (56 25)
170      *           (6 3) 
171      *           (7 2)
172      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
173      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
174      */
175     m1 = new SparseMatrix(new double[][] { { 2, 3, 4 } });
176     m2 = new SparseMatrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
177     m3 = m1.postMultiply(m2);
178     assertEquals(m3.height(), 1);
179     assertEquals(m3.width(), 2);
180     assertEquals(m3.getValue(0, 0), 56d);
181     assertEquals(m3.getValue(0, 1), 25d);
182
183     /*
184      * and check premultiply equivalent
185      */
186     m4 = m2.preMultiply(m1);
187     assertMatricesMatch(m3, m4, 0.00001d);
188   }
189
190   @Test(groups = "Timing")
191   public void testSign()
192   {
193     assertEquals(Matrix.sign(-1, -2), -1d);
194     assertEquals(Matrix.sign(-1, 2), 1d);
195     assertEquals(Matrix.sign(-1, 0), 1d);
196     assertEquals(Matrix.sign(1, -2), -1d);
197     assertEquals(Matrix.sign(1, 2), 1d);
198     assertEquals(Matrix.sign(1, 0), 1d);
199   }
200
201   /**
202    * Verify that the results of method tred() are the same for SparseMatrix as
203    * they are for Matrix (i.e. a regression test rather than an absolute test of
204    * correctness of results)
205    */
206   @Test(groups = "Functional")
207   public void testTred_matchesMatrix()
208   {
209     /*
210      * make a pseudo-random symmetric matrix as required for tred/tqli
211      */
212     int rows = 10;
213     int cols = rows;
214     double[][] d = getSparseValues(rows, cols, 3);
215
216     /*
217      * make a copy of the values so m1, m2 are not
218      * sharing arrays!
219      */
220     double[][] d1 = new double[rows][cols];
221     for (int row = 0; row < rows; row++)
222     {
223       for (int col = 0; col < cols; col++)
224       {
225         d1[row][col] = d[row][col];
226       }
227     }
228     Matrix m1 = new Matrix(d);
229     Matrix m2 = new SparseMatrix(d1);
230     assertMatricesMatch(m1, m2, 0.00001d); // sanity check
231     m1.tred();
232     m2.tred();
233     assertMatricesMatch(m1, m2, 0.00001d);
234   }
235
236   private void assertMatricesMatch(MatrixI m1, MatrixI m2, double delta)
237   {
238     if (m1.height() != m2.height())
239     {
240       fail("height mismatch");
241     }
242     if (m1.width() != m2.width())
243     {
244       fail("width mismatch");
245     }
246     for (int row = 0; row < m1.height(); row++)
247     {
248       for (int col = 0; col < m1.width(); col++)
249       {
250         double v2 = m2.getValue(row, col);
251         double v1 = m1.getValue(row, col);
252         if (Math.abs(v1 - v2) > DELTA)
253         {
254           fail(String.format("At [%d, %d] %f != %f", row, col, v1, v2));
255         }
256       }
257     }
258     ArrayAsserts.assertArrayEquals(m1.getD(), m2.getD(), delta);
259     ArrayAsserts.assertArrayEquals(m1.getE(), m2.getE(), 0.00001d);
260   }
261
262   @Test
263   public void testGetValue()
264   {
265     double[][] d = new double[][] { { 0, 0, 1, 0, 0 }, { 2, 3, 0, 0, 0 },
266         { 4, 0, 0, 0, 5 } };
267     MatrixI m = new SparseMatrix(d);
268     for (int row = 0; row < 3; row++)
269     {
270       for (int col = 0; col < 5; col++)
271       {
272         assertEquals(m.getValue(row, col), d[row][col],
273                 String.format("At [%d, %d]", row, col));
274       }
275     }
276   }
277
278   /**
279    * Verify that the results of method tqli() are the same for SparseMatrix as
280    * they are for Matrix (i.e. a regression test rather than an absolute test of
281    * correctness of results)
282    * 
283    * @throws Exception
284    */
285   @Test(groups = "Functional")
286   public void testTqli_matchesMatrix() throws Exception
287   {
288     /*
289      * make a pseudo-random symmetric matrix as required for tred
290      */
291     int rows = 6;
292     int cols = rows;
293     double[][] d = getSparseValues(rows, cols, 3);
294
295     /*
296      * make a copy of the values so m1, m2 are not
297      * sharing arrays!
298      */
299     double[][] d1 = new double[rows][cols];
300     for (int row = 0; row < rows; row++)
301     {
302       for (int col = 0; col < cols; col++)
303       {
304         d1[row][col] = d[row][col];
305       }
306     }
307     Matrix m1 = new Matrix(d);
308     Matrix m2 = new SparseMatrix(d1);
309
310     // have to do tred() before doing tqli()
311     m1.tred();
312     m2.tred();
313     assertMatricesMatch(m1, m2, 0.00001d);
314
315     m1.tqli();
316     m2.tqli();
317     assertMatricesMatch(m1, m2, 0.00001d);
318   }
319
320   /**
321    * Helper method to make values for a sparse, pseudo-random symmetric matrix
322    * 
323    * @param rows
324    * @param cols
325    * @param occupancy
326    *          one in 'occupancy' entries will be non-zero
327    * @return
328    */
329   public double[][] getSparseValues(int rows, int cols, int occupancy)
330   {
331     /*
332      * generate whole number values between -12 and +12
333      * (to mimic score matrices used in Jalview)
334      */
335     double[][] d = new double[rows][cols];
336     int m = 0;
337     for (int i = 0; i < rows; i++)
338     {
339       if (++m % occupancy == 0)
340       {
341         d[i][i] = r.nextInt() % 13; // diagonal
342       }
343       for (int j = 0; j < i; j++)
344       {
345         if (++m % occupancy == 0)
346         {
347           d[i][j] = r.nextInt() % 13;
348           d[j][i] = d[i][j];
349         }
350       }
351     }
352     return d;
353
354   }
355
356   /**
357    * Test that verifies that the result of preMultiply is a SparseMatrix if more
358    * than 80% zeroes, else a Matrix
359    */
360   @Test(groups = "Functional")
361   public void testPreMultiply_sparseProduct()
362   {
363     MatrixI m1 = new SparseMatrix(
364             new double[][]
365             { { 1 }, { 0 }, { 0 }, { 0 }, { 0 } }); // 5x1
366     MatrixI m2 = new SparseMatrix(new double[][] { { 1, 1, 1, 1 } }); // 1x4
367
368     /*
369      * m1.m2 makes a row of 4 1's, and 4 rows of zeros
370      * 20% non-zero so not 'sparse'
371      */
372     MatrixI m3 = m2.preMultiply(m1);
373     assertFalse(m3 instanceof SparseMatrix);
374
375     /*
376      * replace a 1 with a 0 in the product:
377      * it is now > 80% zero so 'sparse'
378      */
379     m2 = new SparseMatrix(new double[][] { { 1, 1, 1, 0 } });
380     m3 = m2.preMultiply(m1);
381     assertTrue(m3 instanceof SparseMatrix);
382   }
383
384   @Test(groups = "Functional")
385   public void testFillRatio()
386   {
387     SparseMatrix m1 = new SparseMatrix(
388             new double[][]
389             { { 2, 0, 4, 1, 0 }, { 0, 6, 0, 0, 0 } });
390     assertEquals(m1.getFillRatio(), 0.4f);
391   }
392
393   /**
394    * Verify that the results of method tred() are the same if the calculation is
395    * redone
396    */
397   @Test(groups = "Functional")
398   public void testTred_reproducible()
399   {
400     /*
401      * make a pseudo-random symmetric matrix as required for tred/tqli
402      */
403     int rows = 10;
404     int cols = rows;
405     double[][] d = getSparseValues(rows, cols, 3);
406
407     /*
408      * make a copy of the values so m1, m2 are not
409      * sharing arrays!
410      */
411     double[][] d1 = new double[rows][cols];
412     for (int row = 0; row < rows; row++)
413     {
414       for (int col = 0; col < cols; col++)
415       {
416         d1[row][col] = d[row][col];
417       }
418     }
419     Matrix m1 = new SparseMatrix(d);
420     Matrix m2 = new SparseMatrix(d1);
421     assertMatricesMatch(m1, m2, 1.0e16); // sanity check
422     m1.tred();
423     m2.tred();
424     assertMatricesMatch(m1, m2, 0.00001d);
425   }
426 }