8af10b0f248c0d1a2c152dbc89105de6f1e88cdb
[jalview.git] / test / jalview / math / MatrixTest.java
1 package jalview.math;
2
3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.assertFalse;
5 import static org.testng.Assert.assertNotEquals;
6 import static org.testng.Assert.assertNotSame;
7 import static org.testng.Assert.assertNull;
8 import static org.testng.Assert.assertTrue;
9 import static org.testng.Assert.fail;
10
11 import java.util.Arrays;
12 import java.util.Random;
13
14 import org.testng.annotations.Test;
15 import org.testng.internal.junit.ArrayAsserts;
16
17 public class MatrixTest
18 {
19   final static double DELTA = 0.000001d;
20
21   @Test(groups = "Timing")
22   public void testPreMultiply_timing()
23   {
24     int rows = 50; // increase to stress test timing
25     int cols = 100;
26     double[][] d1 = new double[rows][cols];
27     double[][] d2 = new double[cols][rows];
28     Matrix m1 = new Matrix(d1);
29     Matrix m2 = new Matrix(d2);
30     long start = System.currentTimeMillis();
31     m1.preMultiply(m2);
32     long elapsed = System.currentTimeMillis() - start;
33     System.out.println(rows + "x" + cols
34             + " multiplications of double took " + elapsed + "ms");
35   }
36
37   @Test(groups = "Functional")
38   public void testPreMultiply()
39   {
40     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
41     Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
42
43     /*
44      * 1x3 times 3x1 is 1x1
45      * 2x5 + 3x6 + 4*7 =  56
46      */
47     MatrixI m3 = m2.preMultiply(m1);
48     assertEquals(m3.height(), 1);
49     assertEquals(m3.width(), 1);
50     assertEquals(m3.getValue(0, 0), 56d);
51
52     /*
53      * 3x1 times 1x3 is 3x3
54      */
55     m3 = m1.preMultiply(m2);
56     assertEquals(m3.height(), 3);
57     assertEquals(m3.width(), 3);
58     assertEquals(Arrays.toString(m3.getRow(0)), "[10.0, 15.0, 20.0]");
59     assertEquals(Arrays.toString(m3.getRow(1)), "[12.0, 18.0, 24.0]");
60     assertEquals(Arrays.toString(m3.getRow(2)), "[14.0, 21.0, 28.0]");
61   }
62
63   @Test(
64     groups = "Functional",
65     expectedExceptions = { IllegalArgumentException.class })
66   public void testPreMultiply_tooManyColumns()
67   {
68     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
69
70     /*
71      * 2x3 times 2x3 invalid operation - 
72      * multiplier has more columns than multiplicand has rows
73      */
74     m1.preMultiply(m1);
75     fail("Expected exception");
76   }
77
78   @Test(
79     groups = "Functional",
80     expectedExceptions = { IllegalArgumentException.class })
81   public void testPreMultiply_tooFewColumns()
82   {
83     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
84
85     /*
86      * 3x2 times 3x2 invalid operation - 
87      * multiplier has more columns than multiplicand has row
88      */
89     m1.preMultiply(m1);
90     fail("Expected exception");
91   }
92   
93   
94   private boolean matrixEquals(Matrix m1, Matrix m2) {
95     if (m1.width() != m2.width() || m1.height() != m2.height())
96     {
97       return false;
98     }
99     for (int i = 0; i < m1.height(); i++)
100     {
101       if (!Arrays.equals(m1.getRow(i), m2.getRow(i)))
102       {
103         return false;
104       }
105     }
106     return true;
107   }
108
109   @Test(groups = "Functional")
110   public void testPostMultiply()
111   {
112     /*
113      * Square matrices
114      * (2 3) . (10   100)
115      * (4 5)   (1000 10000)
116      * =
117      * (3020 30200)
118      * (5040 50400)
119      */
120     MatrixI m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
121     MatrixI m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } });
122     MatrixI m3 = m1.postMultiply(m2);
123     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
124     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
125
126     /*
127      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
128      */
129     m3 = m2.preMultiply(m1);
130     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
131     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
132
133     /*
134      * m1 has more rows than columns
135      * (2).(10 100 1000) = (20 200 2000)
136      * (3)                 (30 300 3000)
137      */
138     m1 = new Matrix(new double[][] { { 2 }, { 3 } });
139     m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
140     m3 = m1.postMultiply(m2);
141     assertEquals(m3.height(), 2);
142     assertEquals(m3.width(), 3);
143     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
144     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
145     m3 = m2.preMultiply(m1);
146     assertEquals(m3.height(), 2);
147     assertEquals(m3.width(), 3);
148     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
149     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
150
151     /*
152      * m1 has more columns than rows
153      * (2 3 4) . (5 4) = (56 25)
154      *           (6 3) 
155      *           (7 2)
156      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
157      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
158      */
159     m1 = new Matrix(new double[][] { { 2, 3, 4 } });
160     m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
161     m3 = m1.postMultiply(m2);
162     assertEquals(m3.height(), 1);
163     assertEquals(m3.width(), 2);
164     assertEquals(m3.getRow(0)[0], 56d);
165     assertEquals(m3.getRow(0)[1], 25d);
166
167     /*
168      * and check premultiply equivalent
169      */
170     m3 = m2.preMultiply(m1);
171     assertEquals(m3.height(), 1);
172     assertEquals(m3.width(), 2);
173     assertEquals(m3.getRow(0)[0], 56d);
174     assertEquals(m3.getRow(0)[1], 25d);
175   }
176
177   @Test(groups = "Functional")
178   public void testCopy()
179   {
180     Random r = new Random();
181     int rows = 5;
182     int cols = 11;
183     double[][] in = new double[rows][cols];
184
185     for (int i = 0; i < rows; i++)
186     {
187       for (int j = 0; j < cols; j++)
188       {
189         in[i][j] = r.nextDouble();
190       }
191     }
192     Matrix m1 = new Matrix(in);
193
194     Matrix m2 = (Matrix) m1.copy();
195     assertNotSame(m1, m2);
196     assertTrue(matrixEquals(m1, m2));
197     assertNull(m2.d);
198     assertNull(m2.e);
199
200     /*
201      * now add d and e vectors and recopy
202      */
203     m1.d = Arrays.copyOf(in[2], in[2].length);
204     m1.e = Arrays.copyOf(in[4], in[4].length);
205     m2 = (Matrix) m1.copy();
206     assertNotSame(m2.d, m1.d);
207     assertNotSame(m2.e, m1.e);
208     assertEquals(m2.d, m1.d);
209     assertEquals(m2.e, m1.e);
210   }
211
212   /**
213    * main method extracted from Matrix
214    * 
215    * @param args
216    */
217   public static void main(String[] args) throws Exception
218   {
219     int n = Integer.parseInt(args[0]);
220     double[][] in = new double[n][n];
221   
222     for (int i = 0; i < n; i++)
223     {
224       for (int j = 0; j < n; j++)
225       {
226         in[i][j] = Math.random();
227       }
228     }
229   
230     Matrix origmat = new Matrix(in);
231   
232     // System.out.println(" --- Original matrix ---- ");
233     // / origmat.print(System.out);
234     // System.out.println();
235     // System.out.println(" --- transpose matrix ---- ");
236     MatrixI trans = origmat.transpose();
237   
238     // trans.print(System.out);
239     // System.out.println();
240     // System.out.println(" --- OrigT * Orig ---- ");
241     MatrixI symm = trans.postMultiply(origmat);
242   
243     // symm.print(System.out);
244     // System.out.println();
245     // Copy the symmetric matrix for later
246     // Matrix origsymm = symm.copy();
247   
248     // This produces the tridiagonal transformation matrix
249     // long tstart = System.currentTimeMillis();
250     symm.tred();
251   
252     // long tend = System.currentTimeMillis();
253   
254     // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
255     // System.out.println(" ---Tridiag transform matrix ---");
256     // symm.print(System.out);
257     // System.out.println();
258     // System.out.println(" --- D vector ---");
259     // symm.printD(System.out);
260     // System.out.println();
261     // System.out.println(" --- E vector ---");
262     // symm.printE(System.out);
263     // System.out.println();
264     // Now produce the diagonalization matrix
265     // tstart = System.currentTimeMillis();
266     symm.tqli();
267     // tend = System.currentTimeMillis();
268   
269     // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
270     // System.out.println(" --- New diagonalization matrix ---");
271     // symm.print(System.out);
272     // System.out.println();
273     // System.out.println(" --- D vector ---");
274     // symm.printD(System.out);
275     // System.out.println();
276     // System.out.println(" --- E vector ---");
277     // symm.printE(System.out);
278     // System.out.println();
279     // System.out.println(" --- First eigenvector --- ");
280     // double[] eigenv = symm.getColumn(0);
281     // for (int i=0; i < eigenv.length;i++) {
282     // Format.print(System.out,"%15.4f",eigenv[i]);
283     // }
284     // System.out.println();
285     // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
286     // for (int i=0; i < neigenv.length;i++) {
287     // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
288     // }
289     // System.out.println();
290   }
291
292   @Test(groups = "Timing")
293   public void testSign()
294   {
295     assertEquals(Matrix.sign(-1, -2), -1d);
296     assertEquals(Matrix.sign(-1, 2), 1d);
297     assertEquals(Matrix.sign(-1, 0), 1d);
298     assertEquals(Matrix.sign(1, -2), -1d);
299     assertEquals(Matrix.sign(1, 2), 1d);
300     assertEquals(Matrix.sign(1, 0), 1d);
301   }
302
303   /**
304    * Helper method to make values for a sparse, pseudo-random symmetric matrix
305    * 
306    * @param rows
307    * @param cols
308    * @param occupancy
309    *          one in 'occupancy' entries will be non-zero
310    * @return
311    */
312   public double[][] getSparseValues(int rows, int cols, int occupancy)
313   {
314     Random r = new Random(1729);
315
316     /*
317      * generate whole number values between -12 and +12
318      * (to mimic score matrices used in Jalview)
319      */
320     double[][] d = new double[rows][cols];
321     int m = 0;
322     for (int i = 0; i < rows; i++)
323     {
324       if (++m % occupancy == 0)
325       {
326         d[i][i] = r.nextInt() % 13; // diagonal
327       }
328       for (int j = 0; j < i; j++)
329       {
330         if (++m % occupancy == 0)
331         {
332           d[i][j] = r.nextInt() % 13;
333           d[j][i] = d[i][j];
334         }
335       }
336     }
337     return d;
338   
339   }
340
341   /**
342    * Verify that the results of method tred() are the same if the calculation is
343    * redone
344    */
345   @Test(groups = "Functional")
346   public void testTred_reproducible()
347   {
348     /*
349      * make a pseudo-random symmetric matrix as required for tred/tqli
350      */
351     int rows = 10;
352     int cols = rows;
353     double[][] d = getSparseValues(rows, cols, 3);
354   
355     /*
356      * make a copy of the values so m1, m2 are not
357      * sharing arrays!
358      */
359     double[][] d1 = new double[rows][cols];
360     for (int row = 0; row < rows; row++)
361     {
362       for (int col = 0; col < cols; col++)
363       {
364         d1[row][col] = d[row][col];
365       }
366     }
367     Matrix m1 = new Matrix(d);
368     Matrix m2 = new Matrix(d1);
369     assertMatricesMatch(m1, m2); // sanity check
370     m1.tred();
371     m2.tred();
372     assertMatricesMatch(m1, m2);
373   }
374
375   public static void assertMatricesMatch(MatrixI m1, MatrixI m2)
376   {
377     if (m1.height() != m2.height())
378     {
379       fail("height mismatch");
380     }
381     if (m1.width() != m2.width())
382     {
383       fail("width mismatch");
384     }
385     for (int row = 0; row < m1.height(); row++)
386     {
387       for (int col = 0; col < m1.width(); col++)
388       {
389         double v2 = m2.getValue(row, col);
390         double v1 = m1.getValue(row, col);
391         if (Math.abs(v1 - v2) > DELTA)
392         {
393           fail(String.format("At [%d, %d] %f != %f", row, col, v1, v2));
394         }
395       }
396     }
397     ArrayAsserts.assertArrayEquals("D vector", m1.getD(), m2.getD(),
398             0.00001d);
399     ArrayAsserts.assertArrayEquals("E vector", m1.getE(), m2.getE(),
400             0.00001d);
401   }
402
403   @Test(groups = "Functional")
404   public void testFindMinMax()
405   {
406     /*
407      * empty matrix case
408      */
409     Matrix m = new Matrix(new double[][] { {} });
410     assertNull(m.findMinMax());
411
412     /*
413      * normal case
414      */
415     double[][] vals = new double[2][];
416     vals[0] = new double[] {7d, 1d, -2.3d};
417     vals[1] = new double[] {-12d, 94.3d, -102.34d};
418     m = new Matrix(vals);
419     double[] minMax = m.findMinMax();
420     assertEquals(minMax[0], -102.34d);
421     assertEquals(minMax[1], 94.3d);
422   }
423
424   @Test(groups = { "Functional", "Timing" })
425   public void testFindMinMax_timing()
426   {
427     Random r = new Random();
428     int size = 1000; // increase to stress test timing
429     double[][] vals = new double[size][size];
430     double max = -Double.MAX_VALUE;
431     double min = Double.MAX_VALUE;
432     for (int i = 0; i < size; i++)
433     {
434       vals[i] = new double[size];
435       for (int j = 0; j < size; j++)
436       {
437         // use nextLong rather than nextDouble to include negative values
438         double d = r.nextLong();
439         if (d > max)
440         {
441           max = d;
442         }
443         if (d < min)
444         {
445           min = d;
446         }
447         vals[i][j] = d;
448       }
449     }
450     Matrix m = new Matrix(vals);
451     long now = System.currentTimeMillis();
452     double[] minMax = m.findMinMax();
453     System.out.println(String.format("findMinMax for %d x %d took %dms",
454             size, size, (System.currentTimeMillis() - now)));
455     assertEquals(minMax[0], min);
456     assertEquals(minMax[1], max);
457   }
458
459   /**
460    * Test range reversal with maximum value becoming zero
461    */
462   @Test(groups = "Functional")
463   public void testReverseRange_maxToZero()
464   {
465     Matrix m1 = new Matrix(
466             new double[][] { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
467
468     /*
469      * subtract all from max: range -3.4 to 15 becomes 18.4 to 0
470      */
471     m1.reverseRange(true);
472     assertEquals(m1.getValue(0, 0), 13d, DELTA);
473     assertEquals(m1.getValue(0, 1), 11.5d, DELTA);
474     assertEquals(m1.getValue(0, 2), 11d, DELTA);
475     assertEquals(m1.getValue(1, 0), 18.4d, DELTA);
476     assertEquals(m1.getValue(1, 1), 11d, DELTA);
477     assertEquals(m1.getValue(1, 2), 0d, DELTA);
478
479     /*
480      * repeat operation - range is now 0 to 18.4
481      */
482     m1.reverseRange(true);
483     assertEquals(m1.getValue(0, 0), 5.4d, DELTA);
484     assertEquals(m1.getValue(0, 1), 6.9d, DELTA);
485     assertEquals(m1.getValue(0, 2), 7.4d, DELTA);
486     assertEquals(m1.getValue(1, 0), 0d, DELTA);
487     assertEquals(m1.getValue(1, 1), 7.4d, DELTA);
488     assertEquals(m1.getValue(1, 2), 18.4d, DELTA);
489   }
490
491   /**
492    * Test range reversal with minimum and maximum values swapped
493    */
494   @Test(groups = "Functional")
495   public void testReverseRange_swapMinMax()
496   {
497     Matrix m1 = new Matrix(
498             new double[][] { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
499   
500     /*
501      * swap all values in min-max range
502      * = subtract from (min + max = 11.6) 
503      * range -3.4 to 15 becomes 18.4 to -3.4
504      */
505     m1.reverseRange(false);
506     assertEquals(m1.getValue(0, 0), 9.6d, DELTA);
507     assertEquals(m1.getValue(0, 1), 8.1d, DELTA);
508     assertEquals(m1.getValue(0, 2), 7.6d, DELTA);
509     assertEquals(m1.getValue(1, 0), 15d, DELTA);
510     assertEquals(m1.getValue(1, 1), 7.6d, DELTA);
511     assertEquals(m1.getValue(1, 2), -3.4d, DELTA);
512   
513     /*
514      * repeat operation - original values restored
515      */
516     m1.reverseRange(false);
517     assertEquals(m1.getValue(0, 0), 2d, DELTA);
518     assertEquals(m1.getValue(0, 1), 3.5d, DELTA);
519     assertEquals(m1.getValue(0, 2), 4d, DELTA);
520     assertEquals(m1.getValue(1, 0), -3.4d, DELTA);
521     assertEquals(m1.getValue(1, 1), 4d, DELTA);
522     assertEquals(m1.getValue(1, 2), 15d, DELTA);
523   }
524
525   @Test(groups = "Functional")
526   public void testMultiply()
527   {
528     Matrix m = new Matrix(new double[][] { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
529     m.multiply(2d);
530     assertEquals(m.getValue(0, 0), 4d, DELTA);
531     assertEquals(m.getValue(0, 1), 7d, DELTA);
532     assertEquals(m.getValue(0, 2), 8d, DELTA);
533     assertEquals(m.getValue(1, 0), -6.8d, DELTA);
534     assertEquals(m.getValue(1, 1), 8d, DELTA);
535     assertEquals(m.getValue(1, 2), 30d, DELTA);
536   }
537
538   @Test(groups = "Functional")
539   public void testConstructor()
540   {
541     double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
542     Matrix m = new Matrix(values);
543     assertEquals(m.getValue(0, 0), 1d, DELTA);
544
545     /*
546      * verify the matrix has a copy of the original array
547      */
548     assertNotSame(values[0], m.getRow(0));
549     values[0][0] = -1d;
550     assertEquals(m.getValue(0, 0), 1d, DELTA); // unchanged
551   }
552
553   @Test(groups = "Functional")
554   public void testEquals_hashCode()
555   {
556     double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
557     Matrix m1 = new Matrix(values);
558     double[][] values2 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
559     Matrix m2 = new Matrix(values2);
560
561     assertTrue(m1.equals(m1));
562     assertTrue(m1.equals(m2));
563     assertTrue(m2.equals(m1));
564     // equal objects should have same hashCode
565     assertEquals(m1.hashCode(), m2.hashCode());
566
567     double[][] values3 = new double[][] { { 1, 2, 3 }, { 4, 5, 7 } };
568     m2 = new Matrix(values3);
569     assertFalse(m1.equals(m2));
570     assertFalse(m2.equals(m1));
571     assertNotEquals(m1.hashCode(), m2.hashCode());
572
573     // same hashCode doesn't always mean equal
574     values2 = new double[][] { { 1, 2, 3 }, { 4, 6, 5 } };
575     m2 = new Matrix(values2);
576     assertFalse(m2.equals(m1));
577     assertEquals(m1.hashCode(), m2.hashCode());
578
579     assertFalse(m1.equals(null));
580     assertFalse(m1.equals("foo"));
581   }
582 }