JAL-3746 apply copyright to tests
[jalview.git] / test / jalview / math / MatrixTest.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.math;
22
23 import static org.testng.Assert.assertEquals;
24 import static org.testng.Assert.assertFalse;
25 import static org.testng.Assert.assertNotSame;
26 import static org.testng.Assert.assertNull;
27 import static org.testng.Assert.assertTrue;
28 import static org.testng.Assert.fail;
29
30 import java.util.Arrays;
31 import java.util.Random;
32
33 import org.testng.annotations.Test;
34 import org.testng.internal.junit.ArrayAsserts;
35
36 public class MatrixTest
37 {
38   final static double DELTA = 0.000001d;
39
40   @Test(groups = "Timing")
41   public void testPreMultiply_timing()
42   {
43     int rows = 50; // increase to stress test timing
44     int cols = 100;
45     double[][] d1 = new double[rows][cols];
46     double[][] d2 = new double[cols][rows];
47     Matrix m1 = new Matrix(d1);
48     Matrix m2 = new Matrix(d2);
49     long start = System.currentTimeMillis();
50     m1.preMultiply(m2);
51     long elapsed = System.currentTimeMillis() - start;
52     System.out.println(rows + "x" + cols
53             + " multiplications of double took " + elapsed + "ms");
54   }
55
56   @Test(groups = "Functional")
57   public void testPreMultiply()
58   {
59     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
60     Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
61
62     /*
63      * 1x3 times 3x1 is 1x1
64      * 2x5 + 3x6 + 4*7 =  56
65      */
66     MatrixI m3 = m2.preMultiply(m1);
67     assertEquals(m3.height(), 1);
68     assertEquals(m3.width(), 1);
69     assertEquals(m3.getValue(0, 0), 56d);
70
71     /*
72      * 3x1 times 1x3 is 3x3
73      */
74     m3 = m1.preMultiply(m2);
75     assertEquals(m3.height(), 3);
76     assertEquals(m3.width(), 3);
77     assertEquals(Arrays.toString(m3.getRow(0)), "[10.0, 15.0, 20.0]");
78     assertEquals(Arrays.toString(m3.getRow(1)), "[12.0, 18.0, 24.0]");
79     assertEquals(Arrays.toString(m3.getRow(2)), "[14.0, 21.0, 28.0]");
80   }
81
82   @Test(
83     groups = "Functional",
84     expectedExceptions =
85     { IllegalArgumentException.class })
86   public void testPreMultiply_tooManyColumns()
87   {
88     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
89
90     /*
91      * 2x3 times 2x3 invalid operation - 
92      * multiplier has more columns than multiplicand has rows
93      */
94     m1.preMultiply(m1);
95     fail("Expected exception");
96   }
97
98   @Test(
99     groups = "Functional",
100     expectedExceptions =
101     { IllegalArgumentException.class })
102   public void testPreMultiply_tooFewColumns()
103   {
104     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
105
106     /*
107      * 3x2 times 3x2 invalid operation - 
108      * multiplier has more columns than multiplicand has row
109      */
110     m1.preMultiply(m1);
111     fail("Expected exception");
112   }
113
114   private boolean matrixEquals(Matrix m1, Matrix m2)
115   {
116     if (m1.width() != m2.width() || m1.height() != m2.height())
117     {
118       return false;
119     }
120     for (int i = 0; i < m1.height(); i++)
121     {
122       if (!Arrays.equals(m1.getRow(i), m2.getRow(i)))
123       {
124         return false;
125       }
126     }
127     return true;
128   }
129
130   @Test(groups = "Functional")
131   public void testPostMultiply()
132   {
133     /*
134      * Square matrices
135      * (2 3) . (10   100)
136      * (4 5)   (1000 10000)
137      * =
138      * (3020 30200)
139      * (5040 50400)
140      */
141     MatrixI m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
142     MatrixI m2 = new Matrix(
143             new double[][]
144             { { 10, 100 }, { 1000, 10000 } });
145     MatrixI m3 = m1.postMultiply(m2);
146     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
147     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
148
149     /*
150      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
151      */
152     m3 = m2.preMultiply(m1);
153     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
154     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
155
156     /*
157      * m1 has more rows than columns
158      * (2).(10 100 1000) = (20 200 2000)
159      * (3)                 (30 300 3000)
160      */
161     m1 = new Matrix(new double[][] { { 2 }, { 3 } });
162     m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
163     m3 = m1.postMultiply(m2);
164     assertEquals(m3.height(), 2);
165     assertEquals(m3.width(), 3);
166     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
167     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
168     m3 = m2.preMultiply(m1);
169     assertEquals(m3.height(), 2);
170     assertEquals(m3.width(), 3);
171     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
172     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
173
174     /*
175      * m1 has more columns than rows
176      * (2 3 4) . (5 4) = (56 25)
177      *           (6 3) 
178      *           (7 2)
179      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
180      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
181      */
182     m1 = new Matrix(new double[][] { { 2, 3, 4 } });
183     m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
184     m3 = m1.postMultiply(m2);
185     assertEquals(m3.height(), 1);
186     assertEquals(m3.width(), 2);
187     assertEquals(m3.getRow(0)[0], 56d);
188     assertEquals(m3.getRow(0)[1], 25d);
189
190     /*
191      * and check premultiply equivalent
192      */
193     m3 = m2.preMultiply(m1);
194     assertEquals(m3.height(), 1);
195     assertEquals(m3.width(), 2);
196     assertEquals(m3.getRow(0)[0], 56d);
197     assertEquals(m3.getRow(0)[1], 25d);
198   }
199
200   @Test(groups = "Functional")
201   public void testCopy()
202   {
203     Random r = new Random();
204     int rows = 5;
205     int cols = 11;
206     double[][] in = new double[rows][cols];
207
208     for (int i = 0; i < rows; i++)
209     {
210       for (int j = 0; j < cols; j++)
211       {
212         in[i][j] = r.nextDouble();
213       }
214     }
215     Matrix m1 = new Matrix(in);
216
217     Matrix m2 = (Matrix) m1.copy();
218     assertNotSame(m1, m2);
219     assertTrue(matrixEquals(m1, m2));
220     assertNull(m2.d);
221     assertNull(m2.e);
222
223     /*
224      * now add d and e vectors and recopy
225      */
226     m1.d = Arrays.copyOf(in[2], in[2].length);
227     m1.e = Arrays.copyOf(in[4], in[4].length);
228     m2 = (Matrix) m1.copy();
229     assertNotSame(m2.d, m1.d);
230     assertNotSame(m2.e, m1.e);
231     assertEquals(m2.d, m1.d);
232     assertEquals(m2.e, m1.e);
233   }
234
235   /**
236    * main method extracted from Matrix
237    * 
238    * @param args
239    */
240   public static void main(String[] args) throws Exception
241   {
242     int n = Integer.parseInt(args[0]);
243     double[][] in = new double[n][n];
244
245     for (int i = 0; i < n; i++)
246     {
247       for (int j = 0; j < n; j++)
248       {
249         in[i][j] = Math.random();
250       }
251     }
252
253     Matrix origmat = new Matrix(in);
254
255     // System.out.println(" --- Original matrix ---- ");
256     // / origmat.print(System.out);
257     // System.out.println();
258     // System.out.println(" --- transpose matrix ---- ");
259     MatrixI trans = origmat.transpose();
260
261     // trans.print(System.out);
262     // System.out.println();
263     // System.out.println(" --- OrigT * Orig ---- ");
264     MatrixI symm = trans.postMultiply(origmat);
265
266     // symm.print(System.out);
267     // System.out.println();
268     // Copy the symmetric matrix for later
269     // Matrix origsymm = symm.copy();
270
271     // This produces the tridiagonal transformation matrix
272     // long tstart = System.currentTimeMillis();
273     symm.tred();
274
275     // long tend = System.currentTimeMillis();
276
277     // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
278     // System.out.println(" ---Tridiag transform matrix ---");
279     // symm.print(System.out);
280     // System.out.println();
281     // System.out.println(" --- D vector ---");
282     // symm.printD(System.out);
283     // System.out.println();
284     // System.out.println(" --- E vector ---");
285     // symm.printE(System.out);
286     // System.out.println();
287     // Now produce the diagonalization matrix
288     // tstart = System.currentTimeMillis();
289     symm.tqli();
290     // tend = System.currentTimeMillis();
291
292     // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
293     // System.out.println(" --- New diagonalization matrix ---");
294     // symm.print(System.out);
295     // System.out.println();
296     // System.out.println(" --- D vector ---");
297     // symm.printD(System.out);
298     // System.out.println();
299     // System.out.println(" --- E vector ---");
300     // symm.printE(System.out);
301     // System.out.println();
302     // System.out.println(" --- First eigenvector --- ");
303     // double[] eigenv = symm.getColumn(0);
304     // for (int i=0; i < eigenv.length;i++) {
305     // Format.print(System.out,"%15.4f",eigenv[i]);
306     // }
307     // System.out.println();
308     // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
309     // for (int i=0; i < neigenv.length;i++) {
310     // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
311     // }
312     // System.out.println();
313   }
314
315   @Test(groups = "Timing")
316   public void testSign()
317   {
318     assertEquals(Matrix.sign(-1, -2), -1d);
319     assertEquals(Matrix.sign(-1, 2), 1d);
320     assertEquals(Matrix.sign(-1, 0), 1d);
321     assertEquals(Matrix.sign(1, -2), -1d);
322     assertEquals(Matrix.sign(1, 2), 1d);
323     assertEquals(Matrix.sign(1, 0), 1d);
324   }
325
326   /**
327    * Helper method to make values for a sparse, pseudo-random symmetric matrix
328    * 
329    * @param rows
330    * @param cols
331    * @param occupancy
332    *          one in 'occupancy' entries will be non-zero
333    * @return
334    */
335   public double[][] getSparseValues(int rows, int cols, int occupancy)
336   {
337     Random r = new Random(1729);
338
339     /*
340      * generate whole number values between -12 and +12
341      * (to mimic score matrices used in Jalview)
342      */
343     double[][] d = new double[rows][cols];
344     int m = 0;
345     for (int i = 0; i < rows; i++)
346     {
347       if (++m % occupancy == 0)
348       {
349         d[i][i] = r.nextInt() % 13; // diagonal
350       }
351       for (int j = 0; j < i; j++)
352       {
353         if (++m % occupancy == 0)
354         {
355           d[i][j] = r.nextInt() % 13;
356           d[j][i] = d[i][j];
357         }
358       }
359     }
360     return d;
361
362   }
363
364   /**
365    * Verify that the results of method tred() are the same if the calculation is
366    * redone
367    */
368   @Test(groups = "Functional")
369   public void testTred_reproducible()
370   {
371     /*
372      * make a pseudo-random symmetric matrix as required for tred/tqli
373      */
374     int rows = 10;
375     int cols = rows;
376     double[][] d = getSparseValues(rows, cols, 3);
377
378     /*
379      * make a copy of the values so m1, m2 are not
380      * sharing arrays!
381      */
382     double[][] d1 = new double[rows][cols];
383     for (int row = 0; row < rows; row++)
384     {
385       for (int col = 0; col < cols; col++)
386       {
387         d1[row][col] = d[row][col];
388       }
389     }
390     Matrix m1 = new Matrix(d);
391     Matrix m2 = new Matrix(d1);
392     assertMatricesMatch(m1, m2); // sanity check
393     m1.tred();
394     m2.tred();
395     assertMatricesMatch(m1, m2);
396   }
397
398   public static void assertMatricesMatch(MatrixI m1, MatrixI m2)
399   {
400     if (m1.height() != m2.height())
401     {
402       fail("height mismatch");
403     }
404     if (m1.width() != m2.width())
405     {
406       fail("width mismatch");
407     }
408     for (int row = 0; row < m1.height(); row++)
409     {
410       for (int col = 0; col < m1.width(); col++)
411       {
412         double v2 = m2.getValue(row, col);
413         double v1 = m1.getValue(row, col);
414         if (Math.abs(v1 - v2) > DELTA)
415         {
416           fail(String.format("At [%d, %d] %f != %f", row, col, v1, v2));
417         }
418       }
419     }
420     ArrayAsserts.assertArrayEquals("D vector", m1.getD(), m2.getD(),
421             0.00001d);
422     ArrayAsserts.assertArrayEquals("E vector", m1.getE(), m2.getE(),
423             0.00001d);
424   }
425
426   @Test(groups = "Functional")
427   public void testFindMinMax()
428   {
429     /*
430      * empty matrix case
431      */
432     Matrix m = new Matrix(new double[][] { {} });
433     assertNull(m.findMinMax());
434
435     /*
436      * normal case
437      */
438     double[][] vals = new double[2][];
439     vals[0] = new double[] { 7d, 1d, -2.3d };
440     vals[1] = new double[] { -12d, 94.3d, -102.34d };
441     m = new Matrix(vals);
442     double[] minMax = m.findMinMax();
443     assertEquals(minMax[0], -102.34d);
444     assertEquals(minMax[1], 94.3d);
445   }
446
447   @Test(groups = { "Functional", "Timing" })
448   public void testFindMinMax_timing()
449   {
450     Random r = new Random();
451     int size = 1000; // increase to stress test timing
452     double[][] vals = new double[size][size];
453     double max = -Double.MAX_VALUE;
454     double min = Double.MAX_VALUE;
455     for (int i = 0; i < size; i++)
456     {
457       vals[i] = new double[size];
458       for (int j = 0; j < size; j++)
459       {
460         // use nextLong rather than nextDouble to include negative values
461         double d = r.nextLong();
462         if (d > max)
463         {
464           max = d;
465         }
466         if (d < min)
467         {
468           min = d;
469         }
470         vals[i][j] = d;
471       }
472     }
473     Matrix m = new Matrix(vals);
474     long now = System.currentTimeMillis();
475     double[] minMax = m.findMinMax();
476     System.out.println(String.format("findMinMax for %d x %d took %dms",
477             size, size, (System.currentTimeMillis() - now)));
478     assertEquals(minMax[0], min);
479     assertEquals(minMax[1], max);
480   }
481
482   /**
483    * Test range reversal with maximum value becoming zero
484    */
485   @Test(groups = "Functional")
486   public void testReverseRange_maxToZero()
487   {
488     Matrix m1 = new Matrix(
489             new double[][]
490             { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
491
492     /*
493      * subtract all from max: range -3.4 to 15 becomes 18.4 to 0
494      */
495     m1.reverseRange(true);
496     assertEquals(m1.getValue(0, 0), 13d, DELTA);
497     assertEquals(m1.getValue(0, 1), 11.5d, DELTA);
498     assertEquals(m1.getValue(0, 2), 11d, DELTA);
499     assertEquals(m1.getValue(1, 0), 18.4d, DELTA);
500     assertEquals(m1.getValue(1, 1), 11d, DELTA);
501     assertEquals(m1.getValue(1, 2), 0d, DELTA);
502
503     /*
504      * repeat operation - range is now 0 to 18.4
505      */
506     m1.reverseRange(true);
507     assertEquals(m1.getValue(0, 0), 5.4d, DELTA);
508     assertEquals(m1.getValue(0, 1), 6.9d, DELTA);
509     assertEquals(m1.getValue(0, 2), 7.4d, DELTA);
510     assertEquals(m1.getValue(1, 0), 0d, DELTA);
511     assertEquals(m1.getValue(1, 1), 7.4d, DELTA);
512     assertEquals(m1.getValue(1, 2), 18.4d, DELTA);
513   }
514
515   /**
516    * Test range reversal with minimum and maximum values swapped
517    */
518   @Test(groups = "Functional")
519   public void testReverseRange_swapMinMax()
520   {
521     Matrix m1 = new Matrix(
522             new double[][]
523             { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
524
525     /*
526      * swap all values in min-max range
527      * = subtract from (min + max = 11.6) 
528      * range -3.4 to 15 becomes 18.4 to -3.4
529      */
530     m1.reverseRange(false);
531     assertEquals(m1.getValue(0, 0), 9.6d, DELTA);
532     assertEquals(m1.getValue(0, 1), 8.1d, DELTA);
533     assertEquals(m1.getValue(0, 2), 7.6d, DELTA);
534     assertEquals(m1.getValue(1, 0), 15d, DELTA);
535     assertEquals(m1.getValue(1, 1), 7.6d, DELTA);
536     assertEquals(m1.getValue(1, 2), -3.4d, DELTA);
537
538     /*
539      * repeat operation - original values restored
540      */
541     m1.reverseRange(false);
542     assertEquals(m1.getValue(0, 0), 2d, DELTA);
543     assertEquals(m1.getValue(0, 1), 3.5d, DELTA);
544     assertEquals(m1.getValue(0, 2), 4d, DELTA);
545     assertEquals(m1.getValue(1, 0), -3.4d, DELTA);
546     assertEquals(m1.getValue(1, 1), 4d, DELTA);
547     assertEquals(m1.getValue(1, 2), 15d, DELTA);
548   }
549
550   @Test(groups = "Functional")
551   public void testMultiply()
552   {
553     Matrix m = new Matrix(
554             new double[][]
555             { { 2, 3.5, 4 }, { -3.4, 4, 15 } });
556     m.multiply(2d);
557     assertEquals(m.getValue(0, 0), 4d, DELTA);
558     assertEquals(m.getValue(0, 1), 7d, DELTA);
559     assertEquals(m.getValue(0, 2), 8d, DELTA);
560     assertEquals(m.getValue(1, 0), -6.8d, DELTA);
561     assertEquals(m.getValue(1, 1), 8d, DELTA);
562     assertEquals(m.getValue(1, 2), 30d, DELTA);
563   }
564
565   @Test(groups = "Functional")
566   public void testConstructor()
567   {
568     double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
569     Matrix m = new Matrix(values);
570     assertEquals(m.getValue(0, 0), 1d, DELTA);
571
572     /*
573      * verify the matrix has a copy of the original array
574      */
575     assertNotSame(values[0], m.getRow(0));
576     values[0][0] = -1d;
577     assertEquals(m.getValue(0, 0), 1d, DELTA); // unchanged
578   }
579
580   @Test(groups = "Functional")
581   public void testEquals()
582   {
583     double[][] values = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
584     Matrix m1 = new Matrix(values);
585     double[][] values2 = new double[][] { { 1, 2, 3 }, { 4, 5, 6 } };
586     Matrix m2 = new Matrix(values2);
587
588     double delta = 0.0001d;
589     assertTrue(m1.equals(m1, delta));
590     assertTrue(m1.equals(m2, delta));
591     assertTrue(m2.equals(m1, delta));
592
593     double[][] values3 = new double[][] { { 1, 2, 3 }, { 4, 5, 7 } };
594     m2 = new Matrix(values3);
595     assertFalse(m1.equals(m2, delta));
596     assertFalse(m2.equals(m1, delta));
597
598     // must be same shape
599     values2 = new double[][] { { 1, 2, 3 } };
600     m2 = new Matrix(values2);
601     assertFalse(m2.equals(m1, delta));
602
603     assertFalse(m1.equals(null, delta));
604   }
605 }