3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.fail;
6 import java.util.Arrays;
8 import org.testng.annotations.Test;
10 public class MatrixTest
12 @Test(groups = "Timing")
13 public void testPreMultiply_timing()
// Rough wall-clock timing of one matrix multiplication; the result is printed, not asserted.
// NOTE(review): the declarations of rows/cols and the timed multiplication statement are not
// present in this listing (gaps in the original line numbering) - confirm against the full source.
17 double[][] d1 = new double[rows][cols];
18 double[][] d2 = new double[cols][rows];
// d1 is rows x cols and d2 is cols x rows, so the pair can be multiplied in either order.
// Neither array is filled, so every element is Java's default 0.0 - only the cost is measured.
19 Matrix m1 = new Matrix(d1, rows, cols);
20 Matrix m2 = new Matrix(d2, cols, rows);
21 long start = System.currentTimeMillis();
// NOTE(review): currentTimeMillis is wall-clock with coarse resolution; System.nanoTime is the
// usual choice for measuring an elapsed interval.
23 long elapsed = System.currentTimeMillis() - start;
14 System.out.println(rows + "x" + cols
15 + " multiplications of double took " + elapsed + "ms");
28 @Test(groups = "Functional")
29 public void testPreMultiply()
31 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }, 1, 3); // 1x3
32 Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }, 3, 1); // 3x1
35 * 1x3 times 3x1 is 1x1
36 * 2x5 + 3x6 + 4*7 = 56
38 Matrix m3 = m2.preMultiply(m1);
39 assertEquals(m3.rows, 1);
40 assertEquals(m3.cols, 1);
41 assertEquals(m3.value[0][0], 56d);
44 * 3x1 times 1x3 is 3x3
46 m3 = m1.preMultiply(m2);
47 assertEquals(m3.rows, 3);
48 assertEquals(m3.cols, 3);
49 assertEquals(Arrays.toString(m3.value[0]), "[10.0, 15.0, 20.0]");
50 assertEquals(Arrays.toString(m3.value[1]), "[12.0, 18.0, 24.0]");
51 assertEquals(Arrays.toString(m3.value[2]), "[14.0, 21.0, 28.0]");
55 groups = "Functional",
56 expectedExceptions = { IllegalArgumentException.class })
57 public void testPreMultiply_tooManyColumns()
// Expects a multiplication of dimension-incompatible matrices to throw IllegalArgumentException.
// NOTE(review): the end of the m1 constructor call and the multiplication call itself are not
// present in this listing - confirm the operands against the full source.
// The data below has 2 rows of 3 values each.
59 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }, 2,
63 * 2x3 times 2x3 invalid operation -
64 * multiplier has more columns than multiplicand has rows
// Reached only if no exception was thrown. The AssertionError raised by fail() is not the
// expected IllegalArgumentException, so TestNG then reports the test as failed.
67 fail("Expected exception");
71 groups = "Functional",
72 expectedExceptions = { IllegalArgumentException.class })
73 public void testPreMultiply_tooFewColumns()
// Expects a multiplication of dimension-incompatible matrices to throw IllegalArgumentException.
// NOTE(review): the end of the m1 constructor call and the multiplication call itself are not
// present in this listing - confirm the operands (the comment below says 3x2, while the visible
// data has 2 rows of 3 values) against the full source.
75 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }, 2,
79 * 3x2 times 3x2 invalid operation -
80 * multiplier has fewer columns than multiplicand has rows
// Reached only if no exception was thrown. The AssertionError raised by fail() is not the
// expected IllegalArgumentException, so TestNG then reports the test as failed.
83 fail("Expected exception");
87 private boolean matrixEquals(Matrix m1, Matrix m2) {
88 return Arrays.deepEquals(m1.value, m2.value);
91 @Test(groups = "Functional")
92 public void testPostMultiply()
102 Matrix m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } }, 2, 2);
103 Matrix m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } },
105 Matrix m3 = m1.postMultiply(m2);
106 assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
107 assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
110 * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2)
112 m3 = m2.preMultiply(m1);
113 assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
114 assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
117 * m1 has more rows than columns
118 * (2).(10 100 1000) = (20 200 2000)
121 m1 = new Matrix(new double[][] { { 2 }, { 3 } }, 2, 1);
122 m2 = new Matrix(new double[][] { { 10, 100, 1000 } }, 1, 3);
123 m3 = m1.postMultiply(m2);
124 assertEquals(m3.rows, 2);
125 assertEquals(m3.cols, 3);
126 assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
127 assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
128 m3 = m2.preMultiply(m1);
129 assertEquals(m3.rows, 2);
130 assertEquals(m3.cols, 3);
131 assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
132 assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
135 * m1 has more columns than rows
136 * (2 3 4) . (5 4) = (56 25)
139 * [0, 0] = 2*5 + 3*6 + 4*7 = 56
140 * [0, 1] = 2*4 + 3*3 + 4*2 = 25
142 m1 = new Matrix(new double[][] { { 2, 3, 4 } }, 1, 3);
143 m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } }, 3, 2);
144 m3 = m1.postMultiply(m2);
145 assertEquals(m3.rows, 1);
146 assertEquals(m3.cols, 2);
147 assertEquals(m3.value[0][0], 56d);
148 assertEquals(m3.value[0][1], 25d);
151 * and check premultiply equivalent
153 m3 = m2.preMultiply(m1);
154 assertEquals(m3.rows, 1);
155 assertEquals(m3.cols, 2);
156 assertEquals(m3.value[0][0], 56d);
157 assertEquals(m3.value[0][1], 25d);
161 * main method extracted from Matrix
165 public static void main(String[] args) throws Exception
// args[0] is the dimension n of a random square matrix. A missing argument throws
// ArrayIndexOutOfBoundsException, and a non-integer one throws NumberFormatException.
167 int n = Integer.parseInt(args[0]);
168 double[][] in = new double[n][n];
// Fill the n x n array with pseudo-random values in [0.0, 1.0).
170 for (int i = 0; i < n; i++)
172 for (int j = 0; j < n; j++)
174 in[i][j] = Math.random();
178 Matrix origmat = new Matrix(in, n, n);
180 // System.out.println(" --- Original matrix ---- ");
181 // / origmat.print(System.out);
182 // System.out.println();
183 // System.out.println(" --- transpose matrix ---- ");
184 Matrix trans = origmat.transpose();
186 // trans.print(System.out);
187 // System.out.println();
188 // System.out.println(" --- OrigT * Orig ---- ");
// symm = trans x origmat (postMultiply(m) computes this x m), i.e. OrigT * Orig as labelled
// above. That product is symmetric, hence the name.
189 Matrix symm = trans.postMultiply(origmat);
191 // symm.print(System.out);
192 // System.out.println();
193 // Copy the symmetric matrix for later
194 // Matrix origsymm = symm.copy();
// NOTE(review): what follows is commented-out diagnostic code. However, the listing skips
// original lines between the commented-out timing statements (where the tred/tqli steps named
// in the messages would run), and those lines may contain live calls - confirm against the full source.
196 // This produces the tridiagonal transformation matrix
197 // long tstart = System.currentTimeMillis();
200 // long tend = System.currentTimeMillis();
202 // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
203 // System.out.println(" ---Tridiag transform matrix ---");
204 // symm.print(System.out);
205 // System.out.println();
206 // System.out.println(" --- D vector ---");
207 // symm.printD(System.out);
208 // System.out.println();
209 // System.out.println(" --- E vector ---");
210 // symm.printE(System.out);
211 // System.out.println();
212 // Now produce the diagonalization matrix
213 // tstart = System.currentTimeMillis();
215 // tend = System.currentTimeMillis();
217 // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
218 // System.out.println(" --- New diagonalization matrix ---");
219 // symm.print(System.out);
220 // System.out.println();
221 // System.out.println(" --- D vector ---");
222 // symm.printD(System.out);
223 // System.out.println();
224 // System.out.println(" --- E vector ---");
225 // symm.printE(System.out);
226 // System.out.println();
227 // System.out.println(" --- First eigenvector --- ");
228 // double[] eigenv = symm.getColumn(0);
229 // for (int i=0; i < eigenv.length;i++) {
230 // Format.print(System.out,"%15.4f",eigenv[i]);
232 // System.out.println();
233 // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
234 // for (int i=0; i < neigenv.length;i++) {
235 // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
237 // System.out.println();