3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.assertTrue;
5 import static org.testng.Assert.fail;
7 import java.util.Arrays;
8 import java.util.Random;
10 import org.testng.annotations.Test;
12 public class MatrixTest
// Timing test (group "Timing"): allocates a rows x cols and a cols x rows
// double array (Java zero-initialises them), wraps each in a Matrix, times
// the statement between the two System.currentTimeMillis() calls, and
// prints the elapsed milliseconds. No assertions are visible here.
// NOTE(review): this listing is missing the declarations of rows and cols
// (gap between embedded lines 15 and 19) and the timed statement itself
// (embedded line 24) -- presumably m1.preMultiply(m2), going by the method
// name; confirm against the real source before editing this method.
14 @Test(groups = "Timing")
15 public void testPreMultiply_timing()
19 double[][] d1 = new double[rows][cols];
20 double[][] d2 = new double[cols][rows];
21 Matrix m1 = new Matrix(d1);
22 Matrix m2 = new Matrix(d2);
23 long start = System.currentTimeMillis();
25 long elapsed = System.currentTimeMillis() - start;
26 System.out.println(rows + "x" + cols
27 + " multiplications of double took " + elapsed + "ms");
30 @Test(groups = "Functional")
31 public void testPreMultiply()
33 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
34 Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
37 * 1x3 times 3x1 is 1x1
38 * 2x5 + 3x6 + 4*7 = 56
40 Matrix m3 = m2.preMultiply(m1);
41 assertEquals(m3.rows, 1);
42 assertEquals(m3.cols, 1);
43 assertEquals(m3.value[0][0], 56d);
46 * 3x1 times 1x3 is 3x3
48 m3 = m1.preMultiply(m2);
49 assertEquals(m3.rows, 3);
50 assertEquals(m3.cols, 3);
51 assertEquals(Arrays.toString(m3.value[0]), "[10.0, 15.0, 20.0]");
52 assertEquals(Arrays.toString(m3.value[1]), "[12.0, 18.0, 24.0]");
53 assertEquals(Arrays.toString(m3.value[2]), "[14.0, 21.0, 28.0]");
// Expects IllegalArgumentException (see expectedExceptions): per the
// comment below, a 2x3 times 2x3 product is invalid because the multiplier
// has more columns (3) than the multiplicand has rows (2). The trailing
// fail(...) only runs if no exception was thrown.
// NOTE(review): this listing is missing the "@Test(" opener (embedded line
// 56) and embedded lines 62-63 and 66-67, which presumably held the second
// operand, the comment delimiters and the multiplication under test;
// restore them from the real source before editing.
57 groups = "Functional",
58 expectedExceptions = { IllegalArgumentException.class })
59 public void testPreMultiply_tooManyColumns()
61 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
64 * 2x3 times 2x3 invalid operation -
65 * multiplier has more columns than multiplicand has rows
68 fail("Expected exception");
// Expects IllegalArgumentException (see expectedExceptions) for a 3x2 times
// 3x2 product. The trailing fail(...) only runs if no exception was thrown.
// NOTE(review): the explanatory comment below says "more columns than
// multiplicand has row"; for 3x2 times 3x2 the multiplier has 2 columns and
// the multiplicand 3 rows, so it should read "fewer columns than
// multiplicand has rows" (matching the method name).
// NOTE(review): m1 as shown is 2x3, while the comment talks about 3x2
// operands; the missing embedded lines 77-78 and 81-82 presumably build the
// 3x2 operand(s) and make the call under test. The "@Test(" opener
// (embedded line 71) is also missing. Confirm against the real source.
72 groups = "Functional",
73 expectedExceptions = { IllegalArgumentException.class })
74 public void testPreMultiply_tooFewColumns()
76 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
79 * 3x2 times 3x2 invalid operation -
80 * multiplier has more columns than multiplicand has row
83 fail("Expected exception");
87 private boolean matrixEquals(Matrix m1, Matrix m2) {
88 return Arrays.deepEquals(m1.value, m2.value);
91 @Test(groups = "Functional")
92 public void testPostMultiply()
102 Matrix m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
103 Matrix m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } });
104 Matrix m3 = m1.postMultiply(m2);
105 assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
106 assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
109 * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2)
111 m3 = m2.preMultiply(m1);
112 assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
113 assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
116 * m1 has more rows than columns
117 * (2).(10 100 1000) = (20 200 2000)
120 m1 = new Matrix(new double[][] { { 2 }, { 3 } });
121 m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
122 m3 = m1.postMultiply(m2);
123 assertEquals(m3.rows, 2);
124 assertEquals(m3.cols, 3);
125 assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
126 assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
127 m3 = m2.preMultiply(m1);
128 assertEquals(m3.rows, 2);
129 assertEquals(m3.cols, 3);
130 assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
131 assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
134 * m1 has more columns than rows
135 * (2 3 4) . (5 4) = (56 25)
138 * [0, 0] = 2*5 + 3*6 + 4*7 = 56
139 * [0, 1] = 2*4 + 3*3 + 4*2 = 25
141 m1 = new Matrix(new double[][] { { 2, 3, 4 } });
142 m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
143 m3 = m1.postMultiply(m2);
144 assertEquals(m3.rows, 1);
145 assertEquals(m3.cols, 2);
146 assertEquals(m3.value[0][0], 56d);
147 assertEquals(m3.value[0][1], 25d);
150 * and check premultiply equivalent
152 m3 = m2.preMultiply(m1);
153 assertEquals(m3.rows, 1);
154 assertEquals(m3.cols, 2);
155 assertEquals(m3.value[0][0], 56d);
156 assertEquals(m3.value[0][1], 25d);
// Fills a rows x cols array with java.util.Random doubles, wraps it in a
// Matrix, and asserts that copy() produces a matrix whose value array is
// deeply equal to the original's (via matrixEquals).
// NOTE(review): the declarations of rows and cols (embedded lines 163-164)
// are missing from this listing; restore them before editing.
// NOTE(review): matrixEquals only compares contents, so this would also
// pass if copy() returned the same instance or shared row arrays; consider
// asserting m1 != m2 and/or mutating one and re-comparing.
159 @Test(groups = "Functional")
160 public void testCopy()
162 Random r = new Random();
165 double[][] in = new double[rows][cols];
167 for (int i = 0; i < rows; i++)
169 for (int j = 0; j < cols; j++)
171 in[i][j] = r.nextDouble();
174 Matrix m1 = new Matrix(in);
175 Matrix m2 = m1.copy();
176 assertTrue(matrixEquals(m1, m2));
180 * main method extracted from Matrix
184 public static void main(String[] args) throws Exception
186 int n = Integer.parseInt(args[0]);
187 double[][] in = new double[n][n];
189 for (int i = 0; i < n; i++)
191 for (int j = 0; j < n; j++)
193 in[i][j] = Math.random();
197 Matrix origmat = new Matrix(in);
199 // System.out.println(" --- Original matrix ---- ");
200 // / origmat.print(System.out);
201 // System.out.println();
202 // System.out.println(" --- transpose matrix ---- ");
203 Matrix trans = origmat.transpose();
205 // trans.print(System.out);
206 // System.out.println();
207 // System.out.println(" --- OrigT * Orig ---- ");
208 Matrix symm = trans.postMultiply(origmat);
210 // symm.print(System.out);
211 // System.out.println();
212 // Copy the symmetric matrix for later
213 // Matrix origsymm = symm.copy();
215 // This produces the tridiagonal transformation matrix
216 // long tstart = System.currentTimeMillis();
219 // long tend = System.currentTimeMillis();
221 // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
222 // System.out.println(" ---Tridiag transform matrix ---");
223 // symm.print(System.out);
224 // System.out.println();
225 // System.out.println(" --- D vector ---");
226 // symm.printD(System.out);
227 // System.out.println();
228 // System.out.println(" --- E vector ---");
229 // symm.printE(System.out);
230 // System.out.println();
231 // Now produce the diagonalization matrix
232 // tstart = System.currentTimeMillis();
234 // tend = System.currentTimeMillis();
236 // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
237 // System.out.println(" --- New diagonalization matrix ---");
238 // symm.print(System.out);
239 // System.out.println();
240 // System.out.println(" --- D vector ---");
241 // symm.printD(System.out);
242 // System.out.println();
243 // System.out.println(" --- E vector ---");
244 // symm.printE(System.out);
245 // System.out.println();
246 // System.out.println(" --- First eigenvector --- ");
247 // double[] eigenv = symm.getColumn(0);
248 // for (int i=0; i < eigenv.length;i++) {
249 // Format.print(System.out,"%15.4f",eigenv[i]);
251 // System.out.println();
252 // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
253 // for (int i=0; i < neigenv.length;i++) {
254 // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
256 // System.out.println();