3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.fail;
6 import java.util.Arrays;
8 import org.testng.annotations.Test;
10 public class MatrixTest
12 @Test(groups = "Timing")
13 public void testPreMultiply_timing()
17 double[][] d1 = new double[rows][cols];
18 double[][] d2 = new double[cols][rows];
19 Matrix m1 = new Matrix(d1);
20 Matrix m2 = new Matrix(d2);
21 long start = System.currentTimeMillis();
23 long elapsed = System.currentTimeMillis() - start;
24 System.out.println(rows + "x" + cols
25 + " multiplications of double took " + elapsed + "ms");
28 @Test(groups = "Functional")
29 public void testPreMultiply()
31 MatrixI m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
32 MatrixI m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
35 * 1x3 times 3x1 is 1x1
36 * 2x5 + 3x6 + 4*7 = 56
38 MatrixI m3 = m2.preMultiply(m1);
39 assertEquals(m3.height(), 1);
40 assertEquals(m3.width(), 1);
41 assertEquals(m3.getValue(0, 0), 56d);
44 * 3x1 times 1x3 is 3x3
46 m3 = m1.preMultiply(m2);
47 assertEquals(m3.height(), 3);
48 assertEquals(m3.width(), 3);
49 assertEquals(Arrays.toString(m3.getRow(0)), "[10.0, 15.0, 20.0]");
50 assertEquals(Arrays.toString(m3.getRow(1)), "[12.0, 18.0, 24.0]");
51 assertEquals(Arrays.toString(m3.getRow(2)), "[14.0, 21.0, 28.0]");
55 groups = "Functional",
56 expectedExceptions = { IllegalArgumentException.class })
57 public void testPreMultiply_tooManyColumns()
59 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
62 * 2x3 times 2x3 invalid operation -
63 * multiplier has more columns than multiplicand has rows
66 fail("Expected exception");
70 groups = "Functional",
71 expectedExceptions = { IllegalArgumentException.class })
72 public void testPreMultiply_tooFewColumns()
74 Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
77 * 3x2 times 3x2 invalid operation -
78 * multiplier has more columns than multiplicand has row
81 fail("Expected exception");
85 private boolean matrixEquals(Matrix m1, Matrix m2) {
86 if (m1.width() != m2.width() || m1.height() != m2.height())
90 for (int i = 0; i < m1.height(); i++)
92 if (!Arrays.equals(m1.getRow(i), m2.getRow(i)))
100 @Test(groups = "Functional")
101 public void testPostMultiply()
111 MatrixI m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
112 MatrixI m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } });
113 MatrixI m3 = m1.postMultiply(m2);
114 assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
115 assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
118 * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2)
120 m3 = m2.preMultiply(m1);
121 assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
122 assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
125 * m1 has more rows than columns
126 * (2).(10 100 1000) = (20 200 2000)
129 m1 = new Matrix(new double[][] { { 2 }, { 3 } });
130 m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
131 m3 = m1.postMultiply(m2);
132 assertEquals(m3.height(), 2);
133 assertEquals(m3.width(), 3);
134 assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
135 assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
136 m3 = m2.preMultiply(m1);
137 assertEquals(m3.height(), 2);
138 assertEquals(m3.width(), 3);
139 assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
140 assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
143 * m1 has more columns than rows
144 * (2 3 4) . (5 4) = (56 25)
147 * [0, 0] = 2*5 + 3*6 + 4*7 = 56
148 * [0, 1] = 2*4 + 3*3 + 4*2 = 25
150 m1 = new Matrix(new double[][] { { 2, 3, 4 } });
151 m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
152 m3 = m1.postMultiply(m2);
153 assertEquals(m3.height(), 1);
154 assertEquals(m3.width(), 2);
155 assertEquals(m3.getRow(0)[0], 56d);
156 assertEquals(m3.getRow(0)[1], 25d);
159 * and check premultiply equivalent
161 m3 = m2.preMultiply(m1);
162 assertEquals(m3.height(), 1);
163 assertEquals(m3.width(), 2);
164 assertEquals(m3.getRow(0)[0], 56d);
165 assertEquals(m3.getRow(0)[1], 25d);
169 * main method extracted from Matrix
173 public static void main(String[] args) throws Exception
175 int n = Integer.parseInt(args[0]);
176 double[][] in = new double[n][n];
178 for (int i = 0; i < n; i++)
180 for (int j = 0; j < n; j++)
182 in[i][j] = Math.random();
186 Matrix origmat = new Matrix(in);
188 // System.out.println(" --- Original matrix ---- ");
189 // / origmat.print(System.out);
190 // System.out.println();
191 // System.out.println(" --- transpose matrix ---- ");
192 MatrixI trans = origmat.transpose();
194 // trans.print(System.out);
195 // System.out.println();
196 // System.out.println(" --- OrigT * Orig ---- ");
197 MatrixI symm = trans.postMultiply(origmat);
199 // symm.print(System.out);
200 // System.out.println();
201 // Copy the symmetric matrix for later
202 // Matrix origsymm = symm.copy();
204 // This produces the tridiagonal transformation matrix
205 // long tstart = System.currentTimeMillis();
208 // long tend = System.currentTimeMillis();
210 // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
211 // System.out.println(" ---Tridiag transform matrix ---");
212 // symm.print(System.out);
213 // System.out.println();
214 // System.out.println(" --- D vector ---");
215 // symm.printD(System.out);
216 // System.out.println();
217 // System.out.println(" --- E vector ---");
218 // symm.printE(System.out);
219 // System.out.println();
220 // Now produce the diagonalization matrix
221 // tstart = System.currentTimeMillis();
223 // tend = System.currentTimeMillis();
225 // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
226 // System.out.println(" --- New diagonalization matrix ---");
227 // symm.print(System.out);
228 // System.out.println();
229 // System.out.println(" --- D vector ---");
230 // symm.printD(System.out);
231 // System.out.println();
232 // System.out.println(" --- E vector ---");
233 // symm.printE(System.out);
234 // System.out.println();
235 // System.out.println(" --- First eigenvector --- ");
236 // double[] eigenv = symm.getColumn(0);
237 // for (int i=0; i < eigenv.length;i++) {
238 // Format.print(System.out,"%15.4f",eigenv[i]);
240 // System.out.println();
241 // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
242 // for (int i=0; i < neigenv.length;i++) {
243 // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
245 // System.out.println();
248 @Test(groups = "Timing")
249 public void testSign()
251 assertEquals(Matrix.sign(-1, -2), -1d);
252 assertEquals(Matrix.sign(-1, 2), 1d);
253 assertEquals(Matrix.sign(-1, 0), 1d);
254 assertEquals(Matrix.sign(1, -2), -1d);
255 assertEquals(Matrix.sign(1, 2), 1d);
256 assertEquals(Matrix.sign(1, 0), 1d);