71050c1dc8c8bedd79bafac472ce3fd0c6fd63f7
[jalview.git] / test / jalview / math / MatrixTest.java
1 package jalview.math;
2
3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.assertFalse;
5 import static org.testng.Assert.assertNotSame;
6 import static org.testng.Assert.assertNull;
7 import static org.testng.Assert.assertTrue;
8 import static org.testng.Assert.fail;
9
10 import java.util.Arrays;
11 import java.util.Random;
12
13 import org.testng.annotations.Test;
14 import org.testng.internal.junit.ArrayAsserts;
15
16 public class MatrixTest
17 {
18   final static double DELTA = 0.000001d;
19
20   @Test(groups = "Timing")
21   public void testPreMultiply_timing()
22   {
23     int rows = 50; // increase to stress test timing
24     int cols = 100;
25     double[][] d1 = new double[rows][cols];
26     double[][] d2 = new double[cols][rows];
27     Matrix m1 = new Matrix(d1);
28     Matrix m2 = new Matrix(d2);
29     long start = System.currentTimeMillis();
30     m1.preMultiply(m2);
31     long elapsed = System.currentTimeMillis() - start;
32     System.out.println(rows + "x" + cols
33             + " multiplications of double took " + elapsed + "ms");
34   }
35
36   @Test(groups = "Functional")
37   public void testPreMultiply()
38   {
39     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
40     Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
41
42     /*
43      * 1x3 times 3x1 is 1x1
44      * 2x5 + 3x6 + 4*7 =  56
45      */
46     MatrixI m3 = m2.preMultiply(m1);
47     assertEquals(m3.height(), 1);
48     assertEquals(m3.width(), 1);
49     assertEquals(m3.getValue(0, 0), 56d);
50
51     /*
52      * 3x1 times 1x3 is 3x3
53      */
54     m3 = m1.preMultiply(m2);
55     assertEquals(m3.height(), 3);
56     assertEquals(m3.width(), 3);
57     assertEquals(Arrays.toString(m3.getRow(0)), "[10.0, 15.0, 20.0]");
58     assertEquals(Arrays.toString(m3.getRow(1)), "[12.0, 18.0, 24.0]");
59     assertEquals(Arrays.toString(m3.getRow(2)), "[14.0, 21.0, 28.0]");
60   }
61
62   @Test(
63     groups = "Functional",
64     expectedExceptions =
65     { IllegalArgumentException.class })
66   public void testPreMultiply_tooManyColumns()
67   {
68     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
69
70     /*
71      * 2x3 times 2x3 invalid operation - 
72      * multiplier has more columns than multiplicand has rows
73      */
74     m1.preMultiply(m1);
75     fail("Expected exception");
76   }
77
78   @Test(
79     groups = "Functional",
80     expectedExceptions =
81     { IllegalArgumentException.class })
82   public void testPreMultiply_tooFewColumns()
83   {
84     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
85
86     /*
87      * 3x2 times 3x2 invalid operation - 
88      * multiplier has more columns than multiplicand has row
89      */
90     m1.preMultiply(m1);
91     fail("Expected exception");
92   }
93
94   private boolean matrixEquals(Matrix m1, Matrix m2)
95   {
96     if (m1.width() != m2.width() || m1.height() != m2.height())
97     {
98       return false;
99     }
100     for (int i = 0; i < m1.height(); i++)
101     {
102       if (!Arrays.equals(m1.getRow(i), m2.getRow(i)))
103       {
104         return false;
105       }
106     }
107     return true;
108   }
109
110   @Test(groups = "Functional")
111   public void testPostMultiply()
112   {
113     /*
114      * Square matrices
115      * (2 3) . (10   100)
116      * (4 5)   (1000 10000)
117      * =
118      * (3020 30200)
119      * (5040 50400)
120      */
121     MatrixI m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
122     MatrixI m2 = new Matrix(
123             new double[][]
124             { { 10, 100 }, { 1000, 10000 } });
125     MatrixI m3 = m1.postMultiply(m2);
126     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
127     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
128
129     /*
130      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
131      */
132     m3 = m2.preMultiply(m1);
133     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
134     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
135
136     /*
137      * m1 has more rows than columns
138      * (2).(10 100 1000) = (20 200 2000)
139      * (3)                 (30 300 3000)
140      */
141     m1 = new Matrix(new double[][] { { 2 }, { 3 } });
142     m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
143     m3 = m1.postMultiply(m2);
144     assertEquals(m3.height(), 2);
145     assertEquals(m3.width(), 3);
146     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
147     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
148     m3 = m2.preMultiply(m1);
149     assertEquals(m3.height(), 2);
150     assertEquals(m3.width(), 3);
151     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
152     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
153
154     /*
155      * m1 has more columns than rows
156      * (2 3 4) . (5 4) = (56 25)
157      *           (6 3) 
158      *           (7 2)
159      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
160      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
161      */
162     m1 = new Matrix(new double[][] { { 2, 3, 4 } });
163     m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
164     m3 = m1.postMultiply(m2);
165     assertEquals(m3.height(), 1);
166     assertEquals(m3.width(), 2);
167     assertEquals(m3.getRow(0)[0], 56d);
168     assertEquals(m3.getRow(0)[1], 25d);
169
170     /*
171      * and check premultiply equivalent
172      */
173     m3 = m2.preMultiply(m1);
174     assertEquals(m3.height(), 1);
175     assertEquals(m3.width(), 2);
176     assertEquals(m3.getRow(0)[0], 56d);
177     assertEquals(m3.getRow(0)[1], 25d);
178   }
179
180   @Test(groups = "Functional")
181   public void testCopy()
182   {
183     Random r = new Random();
184     int rows = 5;
185     int cols = 11;
186     double[][] in = new double[rows][cols];
187
188     for (int i = 0; i < rows; i++)
189     {
190       for (int j = 0; j < cols; j++)
191       {
192         in[i][j] = r.nextDouble();
193       }
194     }
195     Matrix m1 = new Matrix(in);
196
197     Matrix m2 = (Matrix) m1.copy();
198     assertNotSame(m1, m2);
199     assertTrue(matrixEquals(m1, m2));
200     assertNull(m2.d);
201     assertNull(m2.e);
202
203     /*
204      * now add d and e vectors and recopy
205      */
206     m1.d = Arrays.copyOf(in[2], in[2].length);
207     m1.e = Arrays.copyOf(in[4], in[4].length);
208     m2 = (Matrix) m1.copy();
209     assertNotSame(m2.d, m1.d);
210     assertNotSame(m2.e, m1.e);
211     assertEquals(m2.d, m1.d);
212     assertEquals(m2.e, m1.e);
213   }
214
215   /**
216    * main method extracted from Matrix
217    * 
218    * @param args
219    */
220   public static void main(String[] args) throws Exception
221   {
222     int n = Integer.parseInt(args[0]);
223     double[][] in = new double[n][n];
224
225     for (int i = 0; i < n; i++)
226     {
227       for (int j = 0; j < n; j++)
228       {
229         in[i][j] = Math.random();
230       }
231     }
232
233     Matrix origmat = new Matrix(in);
234
235     // System.out.println(" --- Original matrix ---- ");
236     // / origmat.print(System.out);
237     // System.out.println();
238     // System.out.println(" --- transpose matrix ---- ");
239     MatrixI trans = origmat.transpose();
240
241     // trans.print(System.out);
242     // System.out.println();
243     // System.out.println(" --- OrigT * Orig ---- ");
244     MatrixI symm = trans.postMultiply(origmat);
245
246     // symm.print(System.out);
247     // System.out.println();
248     // Copy the symmetric matrix for later
249     // Matrix origsymm = symm.copy();
250
251     // This produces the tridiagonal transformation matrix
252     // long tstart = System.currentTimeMillis();
253     symm.tred();
254
255     // long tend = System.currentTimeMillis();
256
257     // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
258     // System.out.println(" ---Tridiag transform matrix ---");
259     // symm.print(System.out);
260     // System.out.println();
261     // System.out.println(" --- D vector ---");
262     // symm.printD(System.out);
263     // System.out.println();
264     // System.out.println(" --- E vector ---");
265     // symm.printE(System.out);
266     // System.out.println();
267     // Now produce the diagonalization matrix
268     // tstart = System.currentTimeMillis();
269     symm.tqli();
270     // tend = System.currentTimeMillis();
271
272     // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
273     // System.out.println(" --- New diagonalization matrix ---");
274     // symm.print(System.out);
275     // System.out.println();
276     // System.out.println(" --- D vector ---");
277     // symm.printD(System.out);
278     // System.out.println();
279     // System.out.println(" --- E vector ---");
280     // symm.printE(System.out);
281     // System.out.println();
282     // System.out.println(" --- First eigenvector --- ");
283     // double[] eigenv = symm.getColumn(0);
284     // for (int i=0; i < eigenv.length;i++) {
285     // Format.print(System.out,"%15.4f",eigenv[i]);
286     // }
287     // System.out.println();
288     // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
289     // for (int i=0; i < neigenv.length;i++) {
290     // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
291     // }
292     // System.out.println();
293   }
294
295   @Test(groups = "Timing")
296   public void testSign()
297   {
298     assertEquals(Matrix.sign(-1, -2), -1d);
299     assertEquals(Matrix.sign(-1, 2), 1d);
300     assertEquals(Matrix.sign(-1, 0), 1d);
301     assertEquals(Matrix.sign(1, -2), -1d);
302     assertEquals(Matrix.sign(1, 2), 1d);
303     assertEquals(Matrix.sign(1, 0), 1d);
304   }
305
306   /**
307    * Helper method to make values for a sparse, pseudo-random symmetric matrix
308    * 
309    * @param rows
310    * @param cols
311    * @param occupancy
312    *          one in 'occupancy' entries will be non-zero
313    * @return
314    */
315   public double[][] getSparseValues(int rows, int cols, int occupancy)
316   {
317     Random r = new Random(1729);
318
319     /*
320      * generate whole number values between -12 and +12
321      * (to mimic score matrices used in Jalview)
322      */
323     double[][] d = new double[rows][cols];
324     int m = 0;
325     for (int i = 0; i < rows; i++)
326     {
327       if (++m % occupancy == 0)
328       {
329         d[i][i] = r.nextInt() % 13; // diagonal
330       }
331       for (int j = 0; j < i; j++)
332       {
333         if (++m % occupancy == 0)
334         {
335           d[i][j] = r.nextInt() % 13;
336           d[j][i] = d[i][j];
337         }
338       }
339     }
340     return d;
341
342   }
343
344   /**
345    * Verify that the results of method tred() are the same if the calculation is
346    * redone
347    */
348   @Test(groups = "Functional")
349   public void testTred_reproducible()
350   {
351     /*
352      * make a pseudo-random symmetric matrix as required for tred/tqli
353      */
354     int rows = 10;
355     int cols = rows;
356     double[][] d = getSparseValues(rows, cols, 3);
357
358     /*
359      * make a copy of the values so m1, m2 are not
360      * sharing arrays!
361      */
362     double[][] d1 = new double[rows][cols];
363     for (int row = 0; row < rows; row++)
364     {
365       for (int col = 0; col < cols; col++)
366       {
367         d1[row][col] = d[row][col];
368       }
369     }
370     Matrix m1 = new Matrix(d);
371     Matrix m2 = new Matrix(d1);
372     assertMatricesMatch(m1, m2); // sanity check
373     m1.tred();
374     m2.tred();
375     assertMatricesMatch(m1, m2);
376   }
377
378   public static void assertMatricesMatch(MatrixI m1, MatrixI m2)
379   {
380     if (m1.height() != m2.height())
381     {
382       fail("height mismatch");
383     }
384     if (m1.width() != m2.width())
385     {
386       fail("width mismatch");
387     }
388     for (int row = 0; row < m1.height(); row++)
389     {
390       for (int col = 0; col < m1.width(); col++)
391       {
392         double v2 = m2.getValue(row, col);
393         double v1 = m1.getValue(row, col);
394         if (Math.abs(v1 - v2) > DELTA)
395         {
396           fail(String.format("At [%d, %d] %f != %f", row, col, v1, v2));
397         }
398       }
399     }
400     ArrayAsserts.assertArrayEquals("D vector", m1.getD(), m2.getD(),
401             0.00001d);
402     ArrayAsserts.assertArrayEquals("E vector", m1.getE(), m2.getE(),
403             0.00001d);
404   }
405
406   @Test(groups = "Functional")
407   public void testFindMinMax()
408   {
409     /*
410      * empty matrix case
411      */
412     Matrix m = new Matrix(new double[][] { {} });
413     assertNull(m.findMinMax());
414
415     /*
416      * normal case
417      */
418     double[][] vals = new double[2][];
419     vals[0] = new double[] { 7d, 1d, -2.3d };
420     vals[1] = new double[] { -12d, 94.3d, -102.34d };
421     m = new Matrix(vals);
422     double[] minMax = m.findMinMax();
423     assertEquals(minMax[0], -102.34d);
424     assertEquals(minMax[1], 94.3d);
425   }
426
427   @Test(groups = { "Functional", "Timing" })
428   public void testFindMinMax_timing()
429   {
430     Random r = new Random();
431     int size = 1000; // increase to stress test timing
432     double[][] vals = new double[size][size];
433     double max = -Double.MAX_VALUE;
434     double min = Double.MAX_VALUE;
435     for (int i = 0; i < size; i++)
436     {
437       vals[i] = new double[size];
438       for (int j = 0; j < size; j++)
439       {
440         // use nextLong rather than nextDouble to include negative values
441         double d = r.nextLong();
442         if (d > max)
443         {
444           max = d;
445         }
446         if (d < min)
447         {
448           min = d;
449         }
450         vals[i][j] = d;
451       }
452     }
453     Matrix m = new Matrix(vals);
454     long now = System.currentTimeMillis();
455     double[] minMax = m.findMinMax();
456     System.out.println(String.format("findMinMax for %d x %d took %dms",
457             size, size, (System.currentTimeMillis() - now)));
458     assertEquals(minMax[0], min);
459     assertEquals(minMax[1], max);
460   }
461
462   /**
463    * Test range reversal with maximum value becoming zero
464    */
465   @Test(groups = "Functional")
466   public void testReverseRange_maxToZero()
467   {
468     Matrix m1 = new Matrix(
469             new double[][]
470             { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
471
472     /*
473      * subtract all from max: range -3.4 to 15 becomes 18.4 to 0
474      */
475     m1.reverseRange(true);
476     assertEquals(m1.getValue(0, 0), 13d, DELTA);
477     assertEquals(m1.getValue(0, 1), 11.5d, DELTA);
478     assertEquals(m1.getValue(0, 2), 11d, DELTA);
479     assertEquals(m1.getValue(1, 0), 18.4d, DELTA);
480     assertEquals(m1.getValue(1, 1), 11d, DELTA);
481     assertEquals(m1.getValue(1, 2), 0d, DELTA);
482
483     /*
484      * repeat operation - range is now 0 to 18.4
485      */
486     m1.reverseRange(true);
487     assertEquals(m1.getValue(0, 0), 5.4d, DELTA);
488     assertEquals(m1.getValue(0, 1), 6.9d, DELTA);
489     assertEquals(m1.getValue(0, 2), 7.4d, DELTA);
490     assertEquals(m1.getValue(1, 0), 0d, DELTA);
491     assertEquals(m1.getValue(1, 1), 7.4d, DELTA);
492     assertEquals(m1.getValue(1, 2), 18.4d, DELTA);
493   }
494
495   /**
496    * Test range reversal with minimum and maximum values swapped
497    */
498   @Test(groups = "Functional")
499   public void testReverseRange_swapMinMax()
500   {
501     Matrix m1 = new Matrix(
502             new double[][]
503             { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
504
505     /*
506      * swap all values in min-max range
507      * = subtract from (min + max = 11.6) 
508      * range -3.4 to 15 becomes 18.4 to -3.4
509      */
510     m1.reverseRange(false);
511     assertEquals(m1.getValue(0, 0), 9.6d, DELTA);
512     assertEquals(m1.getValue(0, 1), 8.1d, DELTA);
513     assertEquals(m1.getValue(0, 2), 7.6d, DELTA);
514     assertEquals(m1.getValue(1, 0), 15d, DELTA);
515     assertEquals(m1.getValue(1, 1), 7.6d, DELTA);
516     assertEquals(m1.getValue(1, 2), -3.4d, DELTA);
517
518     /*
519      * repeat operation - original values restored
520      */
521     m1.reverseRange(false);
522     assertEquals(m1.getValue(0, 0), 2d, DELTA);
523     assertEquals(m1.getValue(0, 1), 3.5d, DELTA);
524     assertEquals(m1.getValue(0, 2), 4d, DELTA);
525     assertEquals(m1.getValue(1, 0), -3.4d, DELTA);
526     assertEquals(m1.getValue(1, 1), 4d, DELTA);
527     assertEquals(m1.getValue(1, 2), 15d, DELTA);
528   }
529
530   @Test(groups = "Functional")
531   public void testMultiply()
532   {
533     Matrix m = new Matrix(
534             new double[][]
535             { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
536     m.multiply(2d);
537     assertEquals(m.getValue(0, 0), 4d, DELTA);
538     assertEquals(m.getValue(0, 1), 7d, DELTA);
539     assertEquals(m.getValue(0, 2), 8d, DELTA);
540     assertEquals(m.getValue(1, 0), -6.8d, DELTA);
541     assertEquals(m.getValue(1, 1), 8d, DELTA);
542     assertEquals(m.getValue(1, 2), 30d, DELTA);
543   }
544
545   @Test(groups = "Functional")
546   public void testConstructor()
547   {
548     double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
549     Matrix m = new Matrix(values);
550     assertEquals(m.getValue(0, 0), 1d, DELTA);
551
552     /*
553      * verify the matrix has a copy of the original array
554      */
555     assertNotSame(values[0], m.getRow(0));
556     values[0][0] = -1d;
557     assertEquals(m.getValue(0, 0), 1d, DELTA); // unchanged
558   }
559
560   @Test(groups = "Functional")
561   public void testEquals()
562   {
563     double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
564     Matrix m1 = new Matrix(values);
565     double[][] values2 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
566     Matrix m2 = new Matrix(values2);
567
568     double delta = 0.0001d;
569     assertTrue(m1.equals(m1, delta));
570     assertTrue(m1.equals(m2, delta));
571     assertTrue(m2.equals(m1, delta));
572
573     double[][] values3 = new double[][] { { 1, 2, 3 }, { 4, 5, 7 } };
574     m2 = new Matrix(values3);
575     assertFalse(m1.equals(m2, delta));
576     assertFalse(m2.equals(m1, delta));
577
578     // must be same shape
579     values2 = new double[][] { { 1, 2, 3 } };
580     m2 = new Matrix(values2);
581     assertFalse(m2.equals(m1, delta));
582
583     assertFalse(m1.equals(null, delta));
584   }
585 }