JAL-3746 apply copyright to tests
[jalview.git] / test / jalview / math / SparseMatrixTest.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.math;
22
23 import static org.testng.Assert.assertEquals;
24 import static org.testng.Assert.assertFalse;
25 import static org.testng.Assert.assertTrue;
26 import static org.testng.Assert.fail;
27
28 import java.util.Random;
29
30 import org.testng.annotations.Test;
31 import org.testng.internal.junit.ArrayAsserts;
32
33 public class SparseMatrixTest
34 {
35   final static double DELTA = 0.0001d;
36
37   Random r = new Random(1729);
38
39   @Test(groups = "Functional")
40   public void testConstructor()
41   {
42     MatrixI m1 = new SparseMatrix(
43             new double[][]
44             { { 2, 0, 4 }, { 0, 6, 0 } });
45     assertEquals(m1.getValue(0, 0), 2d);
46     assertEquals(m1.getValue(0, 1), 0d);
47     assertEquals(m1.getValue(0, 2), 4d);
48     assertEquals(m1.getValue(1, 0), 0d);
49     assertEquals(m1.getValue(1, 1), 6d);
50     assertEquals(m1.getValue(1, 2), 0d);
51   }
52
53   @Test(groups = "Functional")
54   public void testTranspose()
55   {
56     MatrixI m1 = new SparseMatrix(
57             new double[][]
58             { { 2, 0, 4 }, { 5, 6, 0 } });
59     MatrixI m2 = m1.transpose();
60     assertTrue(m2 instanceof SparseMatrix);
61     assertEquals(m2.height(), 3);
62     assertEquals(m2.width(), 2);
63     assertEquals(m2.getValue(0, 0), 2d);
64     assertEquals(m2.getValue(0, 1), 5d);
65     assertEquals(m2.getValue(1, 0), 0d);
66     assertEquals(m2.getValue(1, 1), 6d);
67     assertEquals(m2.getValue(2, 0), 4d);
68     assertEquals(m2.getValue(2, 1), 0d);
69   }
70
71   @Test(groups = "Functional")
72   public void testPreMultiply()
73   {
74     MatrixI m1 = new SparseMatrix(new double[][] { { 2, 3, 4 } }); // 1x3
75     MatrixI m2 = new SparseMatrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
76
77     /*
78      * 1x3 times 3x1 is 1x1
79      * 2x5 + 3x6 + 4*7 =  56
80      */
81     MatrixI m3 = m2.preMultiply(m1);
82     assertFalse(m3 instanceof SparseMatrix);
83     assertEquals(m3.height(), 1);
84     assertEquals(m3.width(), 1);
85     assertEquals(m3.getValue(0, 0), 56d);
86
87     /*
88      * 3x1 times 1x3 is 3x3
89      */
90     m3 = m1.preMultiply(m2);
91     assertEquals(m3.height(), 3);
92     assertEquals(m3.width(), 3);
93     assertEquals(m3.getValue(0, 0), 10d);
94     assertEquals(m3.getValue(0, 1), 15d);
95     assertEquals(m3.getValue(0, 2), 20d);
96     assertEquals(m3.getValue(1, 0), 12d);
97     assertEquals(m3.getValue(1, 1), 18d);
98     assertEquals(m3.getValue(1, 2), 24d);
99     assertEquals(m3.getValue(2, 0), 14d);
100     assertEquals(m3.getValue(2, 1), 21d);
101     assertEquals(m3.getValue(2, 2), 28d);
102   }
103
104   @Test(
105     groups = "Functional",
106     expectedExceptions =
107     { IllegalArgumentException.class })
108   public void testPreMultiply_tooManyColumns()
109   {
110     Matrix m1 = new SparseMatrix(
111             new double[][]
112             { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
113
114     /*
115      * 2x3 times 2x3 invalid operation - 
116      * multiplier has more columns than multiplicand has rows
117      */
118     m1.preMultiply(m1);
119     fail("Expected exception");
120   }
121
122   @Test(
123     groups = "Functional",
124     expectedExceptions =
125     { IllegalArgumentException.class })
126   public void testPreMultiply_tooFewColumns()
127   {
128     Matrix m1 = new SparseMatrix(
129             new double[][]
130             { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
131
132     /*
133      * 3x2 times 3x2 invalid operation - 
134      * multiplier has more columns than multiplicand has row
135      */
136     m1.preMultiply(m1);
137     fail("Expected exception");
138   }
139
140   @Test(groups = "Functional")
141   public void testPostMultiply()
142   {
143     /*
144      * Square matrices
145      * (2 3) . (10   100)
146      * (4 5)   (1000 10000)
147      * =
148      * (3020 30200)
149      * (5040 50400)
150      */
151     MatrixI m1 = new SparseMatrix(new double[][] { { 2, 3 }, { 4, 5 } });
152     MatrixI m2 = new SparseMatrix(
153             new double[][]
154             { { 10, 100 }, { 1000, 10000 } });
155     MatrixI m3 = m1.postMultiply(m2);
156     assertEquals(m3.getValue(0, 0), 3020d);
157     assertEquals(m3.getValue(0, 1), 30200d);
158     assertEquals(m3.getValue(1, 0), 5040d);
159     assertEquals(m3.getValue(1, 1), 50400d);
160
161     /*
162      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
163      */
164     MatrixI m4 = m2.preMultiply(m1);
165     assertMatricesMatch(m3, m4, 0.00001d);
166
167     /*
168      * m1 has more rows than columns
169      * (2).(10 100 1000) = (20 200 2000)
170      * (3)                 (30 300 3000)
171      */
172     m1 = new SparseMatrix(new double[][] { { 2 }, { 3 } });
173     m2 = new SparseMatrix(new double[][] { { 10, 100, 1000 } });
174     m3 = m1.postMultiply(m2);
175     assertEquals(m3.height(), 2);
176     assertEquals(m3.width(), 3);
177     assertEquals(m3.getValue(0, 0), 20d);
178     assertEquals(m3.getValue(0, 1), 200d);
179     assertEquals(m3.getValue(0, 2), 2000d);
180     assertEquals(m3.getValue(1, 0), 30d);
181     assertEquals(m3.getValue(1, 1), 300d);
182     assertEquals(m3.getValue(1, 2), 3000d);
183
184     m4 = m2.preMultiply(m1);
185     assertMatricesMatch(m3, m4, 0.00001d);
186
187     /*
188      * m1 has more columns than rows
189      * (2 3 4) . (5 4) = (56 25)
190      *           (6 3) 
191      *           (7 2)
192      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
193      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
194      */
195     m1 = new SparseMatrix(new double[][] { { 2, 3, 4 } });
196     m2 = new SparseMatrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
197     m3 = m1.postMultiply(m2);
198     assertEquals(m3.height(), 1);
199     assertEquals(m3.width(), 2);
200     assertEquals(m3.getValue(0, 0), 56d);
201     assertEquals(m3.getValue(0, 1), 25d);
202
203     /*
204      * and check premultiply equivalent
205      */
206     m4 = m2.preMultiply(m1);
207     assertMatricesMatch(m3, m4, 0.00001d);
208   }
209
210   @Test(groups = "Timing")
211   public void testSign()
212   {
213     assertEquals(Matrix.sign(-1, -2), -1d);
214     assertEquals(Matrix.sign(-1, 2), 1d);
215     assertEquals(Matrix.sign(-1, 0), 1d);
216     assertEquals(Matrix.sign(1, -2), -1d);
217     assertEquals(Matrix.sign(1, 2), 1d);
218     assertEquals(Matrix.sign(1, 0), 1d);
219   }
220
221   /**
222    * Verify that the results of method tred() are the same for SparseMatrix as
223    * they are for Matrix (i.e. a regression test rather than an absolute test of
224    * correctness of results)
225    */
226   @Test(groups = "Functional")
227   public void testTred_matchesMatrix()
228   {
229     /*
230      * make a pseudo-random symmetric matrix as required for tred/tqli
231      */
232     int rows = 10;
233     int cols = rows;
234     double[][] d = getSparseValues(rows, cols, 3);
235
236     /*
237      * make a copy of the values so m1, m2 are not
238      * sharing arrays!
239      */
240     double[][] d1 = new double[rows][cols];
241     for (int row = 0; row < rows; row++)
242     {
243       for (int col = 0; col < cols; col++)
244       {
245         d1[row][col] = d[row][col];
246       }
247     }
248     Matrix m1 = new Matrix(d);
249     Matrix m2 = new SparseMatrix(d1);
250     assertMatricesMatch(m1, m2, 0.00001d); // sanity check
251     m1.tred();
252     m2.tred();
253     assertMatricesMatch(m1, m2, 0.00001d);
254   }
255
256   private void assertMatricesMatch(MatrixI m1, MatrixI m2, double delta)
257   {
258     if (m1.height() != m2.height())
259     {
260       fail("height mismatch");
261     }
262     if (m1.width() != m2.width())
263     {
264       fail("width mismatch");
265     }
266     for (int row = 0; row < m1.height(); row++)
267     {
268       for (int col = 0; col < m1.width(); col++)
269       {
270         double v2 = m2.getValue(row, col);
271         double v1 = m1.getValue(row, col);
272         if (Math.abs(v1 - v2) > DELTA)
273         {
274           fail(String.format("At [%d, %d] %f != %f", row, col, v1, v2));
275         }
276       }
277     }
278     ArrayAsserts.assertArrayEquals(m1.getD(), m2.getD(), delta);
279     ArrayAsserts.assertArrayEquals(m1.getE(), m2.getE(), 0.00001d);
280   }
281
282   @Test
283   public void testGetValue()
284   {
285     double[][] d = new double[][] { { 0, 0, 1, 0, 0 }, { 2, 3, 0, 0, 0 },
286         { 4, 0, 0, 0, 5 } };
287     MatrixI m = new SparseMatrix(d);
288     for (int row = 0; row < 3; row++)
289     {
290       for (int col = 0; col < 5; col++)
291       {
292         assertEquals(m.getValue(row, col), d[row][col],
293                 String.format("At [%d, %d]", row, col));
294       }
295     }
296   }
297
298   /**
299    * Verify that the results of method tqli() are the same for SparseMatrix as
300    * they are for Matrix (i.e. a regression test rather than an absolute test of
301    * correctness of results)
302    * 
303    * @throws Exception
304    */
305   @Test(groups = "Functional")
306   public void testTqli_matchesMatrix() throws Exception
307   {
308     /*
309      * make a pseudo-random symmetric matrix as required for tred
310      */
311     int rows = 6;
312     int cols = rows;
313     double[][] d = getSparseValues(rows, cols, 3);
314
315     /*
316      * make a copy of the values so m1, m2 are not
317      * sharing arrays!
318      */
319     double[][] d1 = new double[rows][cols];
320     for (int row = 0; row < rows; row++)
321     {
322       for (int col = 0; col < cols; col++)
323       {
324         d1[row][col] = d[row][col];
325       }
326     }
327     Matrix m1 = new Matrix(d);
328     Matrix m2 = new SparseMatrix(d1);
329
330     // have to do tred() before doing tqli()
331     m1.tred();
332     m2.tred();
333     assertMatricesMatch(m1, m2, 0.00001d);
334
335     m1.tqli();
336     m2.tqli();
337     assertMatricesMatch(m1, m2, 0.00001d);
338   }
339
340   /**
341    * Helper method to make values for a sparse, pseudo-random symmetric matrix
342    * 
343    * @param rows
344    * @param cols
345    * @param occupancy
346    *          one in 'occupancy' entries will be non-zero
347    * @return
348    */
349   public double[][] getSparseValues(int rows, int cols, int occupancy)
350   {
351     /*
352      * generate whole number values between -12 and +12
353      * (to mimic score matrices used in Jalview)
354      */
355     double[][] d = new double[rows][cols];
356     int m = 0;
357     for (int i = 0; i < rows; i++)
358     {
359       if (++m % occupancy == 0)
360       {
361         d[i][i] = r.nextInt() % 13; // diagonal
362       }
363       for (int j = 0; j < i; j++)
364       {
365         if (++m % occupancy == 0)
366         {
367           d[i][j] = r.nextInt() % 13;
368           d[j][i] = d[i][j];
369         }
370       }
371     }
372     return d;
373
374   }
375
376   /**
377    * Test that verifies that the result of preMultiply is a SparseMatrix if more
378    * than 80% zeroes, else a Matrix
379    */
380   @Test(groups = "Functional")
381   public void testPreMultiply_sparseProduct()
382   {
383     MatrixI m1 = new SparseMatrix(
384             new double[][]
385             { { 1 }, { 0 }, { 0 }, { 0 }, { 0 } }); // 5x1
386     MatrixI m2 = new SparseMatrix(new double[][] { { 1, 1, 1, 1 } }); // 1x4
387
388     /*
389      * m1.m2 makes a row of 4 1's, and 4 rows of zeros
390      * 20% non-zero so not 'sparse'
391      */
392     MatrixI m3 = m2.preMultiply(m1);
393     assertFalse(m3 instanceof SparseMatrix);
394
395     /*
396      * replace a 1 with a 0 in the product:
397      * it is now > 80% zero so 'sparse'
398      */
399     m2 = new SparseMatrix(new double[][] { { 1, 1, 1, 0 } });
400     m3 = m2.preMultiply(m1);
401     assertTrue(m3 instanceof SparseMatrix);
402   }
403
404   @Test(groups = "Functional")
405   public void testFillRatio()
406   {
407     SparseMatrix m1 = new SparseMatrix(
408             new double[][]
409             { { 2, 0, 4, 1, 0 }, { 0, 6, 0, 0, 0 } });
410     assertEquals(m1.getFillRatio(), 0.4f);
411   }
412
413   /**
414    * Verify that the results of method tred() are the same if the calculation is
415    * redone
416    */
417   @Test(groups = "Functional")
418   public void testTred_reproducible()
419   {
420     /*
421      * make a pseudo-random symmetric matrix as required for tred/tqli
422      */
423     int rows = 10;
424     int cols = rows;
425     double[][] d = getSparseValues(rows, cols, 3);
426
427     /*
428      * make a copy of the values so m1, m2 are not
429      * sharing arrays!
430      */
431     double[][] d1 = new double[rows][cols];
432     for (int row = 0; row < rows; row++)
433     {
434       for (int col = 0; col < cols; col++)
435       {
436         d1[row][col] = d[row][col];
437       }
438     }
439     Matrix m1 = new SparseMatrix(d);
440     Matrix m2 = new SparseMatrix(d1);
441     assertMatricesMatch(m1, m2, 1.0e16); // sanity check
442     m1.tred();
443     m2.tred();
444     assertMatricesMatch(m1, m2, 0.00001d);
445   }
446 }