JAL-2380 fix Matrix.postMultiply()
[jalview.git] / test / jalview / math / MatrixTest.java
1 package jalview.math;
2
3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.fail;
5
6 import java.util.Arrays;
7
8 import org.testng.annotations.Test;
9
10 public class MatrixTest
11 {
12   @Test(groups = "Timing")
13   public void testPreMultiply_timing()
14   {
15     int rows = 500;
16     int cols = 1000;
17     double[][] d1 = new double[rows][cols];
18     double[][] d2 = new double[cols][rows];
19     Matrix m1 = new Matrix(d1, rows, cols);
20     Matrix m2 = new Matrix(d2, cols, rows);
21     long start = System.currentTimeMillis();
22     m1.preMultiply(m2);
23     long elapsed = System.currentTimeMillis() - start;
24     System.out.println(rows + "x" + cols
25             + " multiplications of double took " + elapsed + "ms");
26   }
27
28   @Test(groups = "Functional")
29   public void testPreMultiply()
30   {
31     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }, 1, 3); // 1x3
32     Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }, 3, 1); // 3x1
33
34     /*
35      * 1x3 times 3x1 is 1x1
36      * 2x5 + 3x6 + 4*7 =  56
37      */
38     Matrix m3 = m2.preMultiply(m1);
39     assertEquals(m3.rows, 1);
40     assertEquals(m3.cols, 1);
41     assertEquals(m3.value[0][0], 56d);
42
43     /*
44      * 3x1 times 1x3 is 3x3
45      */
46     m3 = m1.preMultiply(m2);
47     assertEquals(m3.rows, 3);
48     assertEquals(m3.cols, 3);
49     assertEquals(Arrays.toString(m3.value[0]), "[10.0, 15.0, 20.0]");
50     assertEquals(Arrays.toString(m3.value[1]), "[12.0, 18.0, 24.0]");
51     assertEquals(Arrays.toString(m3.value[2]), "[14.0, 21.0, 28.0]");
52   }
53
54   @Test(
55     groups = "Functional",
56     expectedExceptions = { IllegalArgumentException.class })
57   public void testPreMultiply_tooManyColumns()
58   {
59     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }, 2,
60             3); // 2x3
61
62     /*
63      * 2x3 times 2x3 invalid operation - 
64      * multiplier has more columns than multiplicand has rows
65      */
66     m1.preMultiply(m1);
67     fail("Expected exception");
68   }
69
70   @Test(
71     groups = "Functional",
72     expectedExceptions = { IllegalArgumentException.class })
73   public void testPreMultiply_tooFewColumns()
74   {
75     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }, 2,
76             3); // 2x3
77
78     /*
79      * 3x2 times 3x2 invalid operation - 
80      * multiplier has more columns than multiplicand has row
81      */
82     m1.preMultiply(m1);
83     fail("Expected exception");
84   }
85   
86   
87   private boolean matrixEquals(Matrix m1, Matrix m2) {
88     return Arrays.deepEquals(m1.value, m2.value);
89   }
90
91   @Test(groups = "Functional")
92   public void testPostMultiply()
93   {
94     /*
95      * Square matrices
96      * (2 3) . (10   100)
97      * (4 5)   (1000 10000)
98      * =
99      * (3020 30200)
100      * (5040 50400)
101      */
102     Matrix m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } }, 2, 2);
103     Matrix m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } },
104             2, 2);
105     Matrix m3 = m1.postMultiply(m2);
106     assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
107     assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
108
109     /*
110      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
111      */
112     m3 = m2.preMultiply(m1);
113     assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
114     assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
115
116     /*
117      * m1 has more rows than columns
118      * (2).(10 100 1000) = (20 200 2000)
119      * (3)                 (30 300 3000)
120      */
121     m1 = new Matrix(new double[][] { { 2 }, { 3 } }, 2, 1);
122     m2 = new Matrix(new double[][] { { 10, 100, 1000 } }, 1, 3);
123     m3 = m1.postMultiply(m2);
124     assertEquals(m3.rows, 2);
125     assertEquals(m3.cols, 3);
126     assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
127     assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
128     m3 = m2.preMultiply(m1);
129     assertEquals(m3.rows, 2);
130     assertEquals(m3.cols, 3);
131     assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
132     assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
133
134     /*
135      * m1 has more columns than rows
136      * (2 3 4) . (5 4) = (56 25)
137      *           (6 3) 
138      *           (7 2)
139      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
140      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
141      */
142     m1 = new Matrix(new double[][] { { 2, 3, 4 } }, 1, 3);
143     m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } }, 3, 2);
144     m3 = m1.postMultiply(m2);
145     assertEquals(m3.rows, 1);
146     assertEquals(m3.cols, 2);
147     assertEquals(m3.value[0][0], 56d);
148     assertEquals(m3.value[0][1], 25d);
149
150     /*
151      * and check premultiply equivalent
152      */
153     m3 = m2.preMultiply(m1);
154     assertEquals(m3.rows, 1);
155     assertEquals(m3.cols, 2);
156     assertEquals(m3.value[0][0], 56d);
157     assertEquals(m3.value[0][1], 25d);
158   }
159
160   /**
161    * main method extracted from Matrix
162    * 
163    * @param args
164    */
165   public static void main(String[] args) throws Exception
166   {
167     int n = Integer.parseInt(args[0]);
168     double[][] in = new double[n][n];
169   
170     for (int i = 0; i < n; i++)
171     {
172       for (int j = 0; j < n; j++)
173       {
174         in[i][j] = Math.random();
175       }
176     }
177   
178     Matrix origmat = new Matrix(in, n, n);
179   
180     // System.out.println(" --- Original matrix ---- ");
181     // / origmat.print(System.out);
182     // System.out.println();
183     // System.out.println(" --- transpose matrix ---- ");
184     Matrix trans = origmat.transpose();
185   
186     // trans.print(System.out);
187     // System.out.println();
188     // System.out.println(" --- OrigT * Orig ---- ");
189     Matrix symm = trans.postMultiply(origmat);
190   
191     // symm.print(System.out);
192     // System.out.println();
193     // Copy the symmetric matrix for later
194     // Matrix origsymm = symm.copy();
195   
196     // This produces the tridiagonal transformation matrix
197     // long tstart = System.currentTimeMillis();
198     symm.tred();
199   
200     // long tend = System.currentTimeMillis();
201   
202     // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
203     // System.out.println(" ---Tridiag transform matrix ---");
204     // symm.print(System.out);
205     // System.out.println();
206     // System.out.println(" --- D vector ---");
207     // symm.printD(System.out);
208     // System.out.println();
209     // System.out.println(" --- E vector ---");
210     // symm.printE(System.out);
211     // System.out.println();
212     // Now produce the diagonalization matrix
213     // tstart = System.currentTimeMillis();
214     symm.tqli();
215     // tend = System.currentTimeMillis();
216   
217     // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
218     // System.out.println(" --- New diagonalization matrix ---");
219     // symm.print(System.out);
220     // System.out.println();
221     // System.out.println(" --- D vector ---");
222     // symm.printD(System.out);
223     // System.out.println();
224     // System.out.println(" --- E vector ---");
225     // symm.printE(System.out);
226     // System.out.println();
227     // System.out.println(" --- First eigenvector --- ");
228     // double[] eigenv = symm.getColumn(0);
229     // for (int i=0; i < eigenv.length;i++) {
230     // Format.print(System.out,"%15.4f",eigenv[i]);
231     // }
232     // System.out.println();
233     // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
234     // for (int i=0; i < neigenv.length;i++) {
235     // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
236     // }
237     // System.out.println();
238   }
239 }