JAL-2380 code tidy after review
[jalview.git] / test / jalview / math / MatrixTest.java
1 package jalview.math;
2
3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.assertTrue;
5 import static org.testng.Assert.fail;
6
7 import java.util.Arrays;
8 import java.util.Random;
9
10 import org.testng.annotations.Test;
11
12 public class MatrixTest
13 {
14   @Test(groups = "Timing")
15   public void testPreMultiply_timing()
16   {
17     int rows = 500;
18     int cols = 1000;
19     double[][] d1 = new double[rows][cols];
20     double[][] d2 = new double[cols][rows];
21     Matrix m1 = new Matrix(d1);
22     Matrix m2 = new Matrix(d2);
23     long start = System.currentTimeMillis();
24     m1.preMultiply(m2);
25     long elapsed = System.currentTimeMillis() - start;
26     System.out.println(rows + "x" + cols
27             + " multiplications of double took " + elapsed + "ms");
28   }
29
30   @Test(groups = "Functional")
31   public void testPreMultiply()
32   {
33     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
34     Matrix m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
35
36     /*
37      * 1x3 times 3x1 is 1x1
38      * 2x5 + 3x6 + 4*7 =  56
39      */
40     Matrix m3 = m2.preMultiply(m1);
41     assertEquals(m3.rows, 1);
42     assertEquals(m3.cols, 1);
43     assertEquals(m3.value[0][0], 56d);
44
45     /*
46      * 3x1 times 1x3 is 3x3
47      */
48     m3 = m1.preMultiply(m2);
49     assertEquals(m3.rows, 3);
50     assertEquals(m3.cols, 3);
51     assertEquals(Arrays.toString(m3.value[0]), "[10.0, 15.0, 20.0]");
52     assertEquals(Arrays.toString(m3.value[1]), "[12.0, 18.0, 24.0]");
53     assertEquals(Arrays.toString(m3.value[2]), "[14.0, 21.0, 28.0]");
54   }
55
56   @Test(
57     groups = "Functional",
58     expectedExceptions = { IllegalArgumentException.class })
59   public void testPreMultiply_tooManyColumns()
60   {
61     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
62
63     /*
64      * 2x3 times 2x3 invalid operation - 
65      * multiplier has more columns than multiplicand has rows
66      */
67     m1.preMultiply(m1);
68     fail("Expected exception");
69   }
70
71   @Test(
72     groups = "Functional",
73     expectedExceptions = { IllegalArgumentException.class })
74   public void testPreMultiply_tooFewColumns()
75   {
76     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
77
78     /*
79      * 3x2 times 3x2 invalid operation - 
80      * multiplier has more columns than multiplicand has row
81      */
82     m1.preMultiply(m1);
83     fail("Expected exception");
84   }
85   
86   
87   private boolean matrixEquals(Matrix m1, Matrix m2) {
88     return Arrays.deepEquals(m1.value, m2.value);
89   }
90
91   @Test(groups = "Functional")
92   public void testPostMultiply()
93   {
94     /*
95      * Square matrices
96      * (2 3) . (10   100)
97      * (4 5)   (1000 10000)
98      * =
99      * (3020 30200)
100      * (5040 50400)
101      */
102     Matrix m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
103     Matrix m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } });
104     Matrix m3 = m1.postMultiply(m2);
105     assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
106     assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
107
108     /*
109      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
110      */
111     m3 = m2.preMultiply(m1);
112     assertEquals(Arrays.toString(m3.value[0]), "[3020.0, 30200.0]");
113     assertEquals(Arrays.toString(m3.value[1]), "[5040.0, 50400.0]");
114
115     /*
116      * m1 has more rows than columns
117      * (2).(10 100 1000) = (20 200 2000)
118      * (3)                 (30 300 3000)
119      */
120     m1 = new Matrix(new double[][] { { 2 }, { 3 } });
121     m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
122     m3 = m1.postMultiply(m2);
123     assertEquals(m3.rows, 2);
124     assertEquals(m3.cols, 3);
125     assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
126     assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
127     m3 = m2.preMultiply(m1);
128     assertEquals(m3.rows, 2);
129     assertEquals(m3.cols, 3);
130     assertEquals(Arrays.toString(m3.value[0]), "[20.0, 200.0, 2000.0]");
131     assertEquals(Arrays.toString(m3.value[1]), "[30.0, 300.0, 3000.0]");
132
133     /*
134      * m1 has more columns than rows
135      * (2 3 4) . (5 4) = (56 25)
136      *           (6 3) 
137      *           (7 2)
138      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
139      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
140      */
141     m1 = new Matrix(new double[][] { { 2, 3, 4 } });
142     m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
143     m3 = m1.postMultiply(m2);
144     assertEquals(m3.rows, 1);
145     assertEquals(m3.cols, 2);
146     assertEquals(m3.value[0][0], 56d);
147     assertEquals(m3.value[0][1], 25d);
148
149     /*
150      * and check premultiply equivalent
151      */
152     m3 = m2.preMultiply(m1);
153     assertEquals(m3.rows, 1);
154     assertEquals(m3.cols, 2);
155     assertEquals(m3.value[0][0], 56d);
156     assertEquals(m3.value[0][1], 25d);
157   }
158
159   @Test(groups = "Functional")
160   public void testCopy()
161   {
162     Random r = new Random();
163     int rows = 5;
164     int cols = 11;
165     double[][] in = new double[rows][cols];
166
167     for (int i = 0; i < rows; i++)
168     {
169       for (int j = 0; j < cols; j++)
170       {
171         in[i][j] = r.nextDouble();
172       }
173     }
174     Matrix m1 = new Matrix(in);
175     Matrix m2 = m1.copy();
176     assertTrue(matrixEquals(m1, m2));
177   }
178
179   /**
180    * main method extracted from Matrix
181    * 
182    * @param args
183    */
184   public static void main(String[] args) throws Exception
185   {
186     int n = Integer.parseInt(args[0]);
187     double[][] in = new double[n][n];
188   
189     for (int i = 0; i < n; i++)
190     {
191       for (int j = 0; j < n; j++)
192       {
193         in[i][j] = Math.random();
194       }
195     }
196   
197     Matrix origmat = new Matrix(in);
198   
199     // System.out.println(" --- Original matrix ---- ");
200     // / origmat.print(System.out);
201     // System.out.println();
202     // System.out.println(" --- transpose matrix ---- ");
203     Matrix trans = origmat.transpose();
204   
205     // trans.print(System.out);
206     // System.out.println();
207     // System.out.println(" --- OrigT * Orig ---- ");
208     Matrix symm = trans.postMultiply(origmat);
209   
210     // symm.print(System.out);
211     // System.out.println();
212     // Copy the symmetric matrix for later
213     // Matrix origsymm = symm.copy();
214   
215     // This produces the tridiagonal transformation matrix
216     // long tstart = System.currentTimeMillis();
217     symm.tred();
218   
219     // long tend = System.currentTimeMillis();
220   
221     // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
222     // System.out.println(" ---Tridiag transform matrix ---");
223     // symm.print(System.out);
224     // System.out.println();
225     // System.out.println(" --- D vector ---");
226     // symm.printD(System.out);
227     // System.out.println();
228     // System.out.println(" --- E vector ---");
229     // symm.printE(System.out);
230     // System.out.println();
231     // Now produce the diagonalization matrix
232     // tstart = System.currentTimeMillis();
233     symm.tqli();
234     // tend = System.currentTimeMillis();
235   
236     // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
237     // System.out.println(" --- New diagonalization matrix ---");
238     // symm.print(System.out);
239     // System.out.println();
240     // System.out.println(" --- D vector ---");
241     // symm.printD(System.out);
242     // System.out.println();
243     // System.out.println(" --- E vector ---");
244     // symm.printE(System.out);
245     // System.out.println();
246     // System.out.println(" --- First eigenvector --- ");
247     // double[] eigenv = symm.getColumn(0);
248     // for (int i=0; i < eigenv.length;i++) {
249     // Format.print(System.out,"%15.4f",eigenv[i]);
250     // }
251     // System.out.println();
252     // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
253     // for (int i=0; i < neigenv.length;i++) {
254     // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
255     // }
256     // System.out.println();
257   }
258 }