1500dc68d90a66b6a0c976beb2a629d50b67808f
[jalview.git] / test / jalview / math / MatrixTest.java
1 package jalview.math;
2
3 import static org.testng.Assert.assertEquals;
4 import static org.testng.Assert.fail;
5
6 import java.util.Arrays;
7
8 import org.testng.annotations.Test;
9
10 public class MatrixTest
11 {
12   @Test(groups = "Timing")
13   public void testPreMultiply_timing()
14   {
15     int rows = 500;
16     int cols = 1000;
17     double[][] d1 = new double[rows][cols];
18     double[][] d2 = new double[cols][rows];
19     Matrix m1 = new Matrix(d1);
20     Matrix m2 = new Matrix(d2);
21     long start = System.currentTimeMillis();
22     m1.preMultiply(m2);
23     long elapsed = System.currentTimeMillis() - start;
24     System.out.println(rows + "x" + cols
25             + " multiplications of double took " + elapsed + "ms");
26   }
27
28   @Test(groups = "Functional")
29   public void testPreMultiply()
30   {
31     MatrixI m1 = new Matrix(new double[][] { { 2, 3, 4 } }); // 1x3
32     MatrixI m2 = new Matrix(new double[][] { { 5 }, { 6 }, { 7 } }); // 3x1
33
34     /*
35      * 1x3 times 3x1 is 1x1
36      * 2x5 + 3x6 + 4*7 =  56
37      */
38     MatrixI m3 = m2.preMultiply(m1);
39     assertEquals(m3.height(), 1);
40     assertEquals(m3.width(), 1);
41     assertEquals(m3.getValue(0, 0), 56d);
42
43     /*
44      * 3x1 times 1x3 is 3x3
45      */
46     m3 = m1.preMultiply(m2);
47     assertEquals(m3.height(), 3);
48     assertEquals(m3.width(), 3);
49     assertEquals(Arrays.toString(m3.getRow(0)), "[10.0, 15.0, 20.0]");
50     assertEquals(Arrays.toString(m3.getRow(1)), "[12.0, 18.0, 24.0]");
51     assertEquals(Arrays.toString(m3.getRow(2)), "[14.0, 21.0, 28.0]");
52   }
53
54   @Test(
55     groups = "Functional",
56     expectedExceptions = { IllegalArgumentException.class })
57   public void testPreMultiply_tooManyColumns()
58   {
59     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
60
61     /*
62      * 2x3 times 2x3 invalid operation - 
63      * multiplier has more columns than multiplicand has rows
64      */
65     m1.preMultiply(m1);
66     fail("Expected exception");
67   }
68
69   @Test(
70     groups = "Functional",
71     expectedExceptions = { IllegalArgumentException.class })
72   public void testPreMultiply_tooFewColumns()
73   {
74     Matrix m1 = new Matrix(new double[][] { { 2, 3, 4 }, { 3, 4, 5 } }); // 2x3
75
76     /*
77      * 3x2 times 3x2 invalid operation - 
78      * multiplier has more columns than multiplicand has row
79      */
80     m1.preMultiply(m1);
81     fail("Expected exception");
82   }
83   
84   
85   private boolean matrixEquals(Matrix m1, Matrix m2) {
86     if (m1.width() != m2.width() || m1.height() != m2.height())
87     {
88       return false;
89     }
90     for (int i = 0; i < m1.height(); i++)
91     {
92       if (!Arrays.equals(m1.getRow(i), m2.getRow(i)))
93       {
94         return false;
95       }
96     }
97     return true;
98   }
99
100   @Test(groups = "Functional")
101   public void testPostMultiply()
102   {
103     /*
104      * Square matrices
105      * (2 3) . (10   100)
106      * (4 5)   (1000 10000)
107      * =
108      * (3020 30200)
109      * (5040 50400)
110      */
111     MatrixI m1 = new Matrix(new double[][] { { 2, 3 }, { 4, 5 } });
112     MatrixI m2 = new Matrix(new double[][] { { 10, 100 }, { 1000, 10000 } });
113     MatrixI m3 = m1.postMultiply(m2);
114     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
115     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
116
117     /*
118      * also check m2.preMultiply(m1) - should be same as m1.postMultiply(m2) 
119      */
120     m3 = m2.preMultiply(m1);
121     assertEquals(Arrays.toString(m3.getRow(0)), "[3020.0, 30200.0]");
122     assertEquals(Arrays.toString(m3.getRow(1)), "[5040.0, 50400.0]");
123
124     /*
125      * m1 has more rows than columns
126      * (2).(10 100 1000) = (20 200 2000)
127      * (3)                 (30 300 3000)
128      */
129     m1 = new Matrix(new double[][] { { 2 }, { 3 } });
130     m2 = new Matrix(new double[][] { { 10, 100, 1000 } });
131     m3 = m1.postMultiply(m2);
132     assertEquals(m3.height(), 2);
133     assertEquals(m3.width(), 3);
134     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
135     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
136     m3 = m2.preMultiply(m1);
137     assertEquals(m3.height(), 2);
138     assertEquals(m3.width(), 3);
139     assertEquals(Arrays.toString(m3.getRow(0)), "[20.0, 200.0, 2000.0]");
140     assertEquals(Arrays.toString(m3.getRow(1)), "[30.0, 300.0, 3000.0]");
141
142     /*
143      * m1 has more columns than rows
144      * (2 3 4) . (5 4) = (56 25)
145      *           (6 3) 
146      *           (7 2)
147      * [0, 0] = 2*5 + 3*6 + 4*7 = 56
148      * [0, 1] = 2*4 + 3*3 + 4*2 = 25  
149      */
150     m1 = new Matrix(new double[][] { { 2, 3, 4 } });
151     m2 = new Matrix(new double[][] { { 5, 4 }, { 6, 3 }, { 7, 2 } });
152     m3 = m1.postMultiply(m2);
153     assertEquals(m3.height(), 1);
154     assertEquals(m3.width(), 2);
155     assertEquals(m3.getRow(0)[0], 56d);
156     assertEquals(m3.getRow(0)[1], 25d);
157
158     /*
159      * and check premultiply equivalent
160      */
161     m3 = m2.preMultiply(m1);
162     assertEquals(m3.height(), 1);
163     assertEquals(m3.width(), 2);
164     assertEquals(m3.getRow(0)[0], 56d);
165     assertEquals(m3.getRow(0)[1], 25d);
166   }
167
168   /**
169    * main method extracted from Matrix
170    * 
171    * @param args
172    */
173   public static void main(String[] args) throws Exception
174   {
175     int n = Integer.parseInt(args[0]);
176     double[][] in = new double[n][n];
177   
178     for (int i = 0; i < n; i++)
179     {
180       for (int j = 0; j < n; j++)
181       {
182         in[i][j] = Math.random();
183       }
184     }
185   
186     Matrix origmat = new Matrix(in);
187   
188     // System.out.println(" --- Original matrix ---- ");
189     // / origmat.print(System.out);
190     // System.out.println();
191     // System.out.println(" --- transpose matrix ---- ");
192     MatrixI trans = origmat.transpose();
193   
194     // trans.print(System.out);
195     // System.out.println();
196     // System.out.println(" --- OrigT * Orig ---- ");
197     MatrixI symm = trans.postMultiply(origmat);
198   
199     // symm.print(System.out);
200     // System.out.println();
201     // Copy the symmetric matrix for later
202     // Matrix origsymm = symm.copy();
203   
204     // This produces the tridiagonal transformation matrix
205     // long tstart = System.currentTimeMillis();
206     symm.tred();
207   
208     // long tend = System.currentTimeMillis();
209   
210     // System.out.println("Time take for tred = " + (tend-tstart) + "ms");
211     // System.out.println(" ---Tridiag transform matrix ---");
212     // symm.print(System.out);
213     // System.out.println();
214     // System.out.println(" --- D vector ---");
215     // symm.printD(System.out);
216     // System.out.println();
217     // System.out.println(" --- E vector ---");
218     // symm.printE(System.out);
219     // System.out.println();
220     // Now produce the diagonalization matrix
221     // tstart = System.currentTimeMillis();
222     symm.tqli();
223     // tend = System.currentTimeMillis();
224   
225     // System.out.println("Time take for tqli = " + (tend-tstart) + " ms");
226     // System.out.println(" --- New diagonalization matrix ---");
227     // symm.print(System.out);
228     // System.out.println();
229     // System.out.println(" --- D vector ---");
230     // symm.printD(System.out);
231     // System.out.println();
232     // System.out.println(" --- E vector ---");
233     // symm.printE(System.out);
234     // System.out.println();
235     // System.out.println(" --- First eigenvector --- ");
236     // double[] eigenv = symm.getColumn(0);
237     // for (int i=0; i < eigenv.length;i++) {
238     // Format.print(System.out,"%15.4f",eigenv[i]);
239     // }
240     // System.out.println();
241     // double[] neigenv = origsymm.vectorPostMultiply(eigenv);
242     // for (int i=0; i < neigenv.length;i++) {
243     // Format.print(System.out,"%15.4f",neigenv[i]/symm.d[0]);
244     // }
245     // System.out.println();
246   }
247
248   @Test(groups = "Timing")
249   public void testSign()
250   {
251     assertEquals(Matrix.sign(-1, -2), -1d);
252     assertEquals(Matrix.sign(-1, 2), 1d);
253     assertEquals(Matrix.sign(-1, 0), 1d);
254     assertEquals(Matrix.sign(1, -2), -1d);
255     assertEquals(Matrix.sign(1, 2), 1d);
256     assertEquals(Matrix.sign(1, 0), 1d);
257   }
258 }