JAL-3205 added a delta parameter to Matrix.equals()
[jalview.git] / src / jalview / math / Matrix.java
1 /*
2  * Jalview - A Sequence Alignment Editor and Viewer ($$Version-Rel$$)
3  * Copyright (C) $$Year-Rel$$ The Jalview Authors
4  * 
5  * This file is part of Jalview.
6  * 
7  * Jalview is free software: you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License 
9  * as published by the Free Software Foundation, either version 3
10  * of the License, or (at your option) any later version.
11  *  
12  * Jalview is distributed in the hope that it will be useful, but 
13  * WITHOUT ANY WARRANTY; without even the implied warranty 
14  * of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
15  * PURPOSE.  See the GNU General Public License for more details.
16  * 
17  * You should have received a copy of the GNU General Public License
18  * along with Jalview.  If not, see <http://www.gnu.org/licenses/>.
19  * The Jalview Authors are detailed in the 'AUTHORS' file.
20  */
21 package jalview.math;
22
23 import jalview.util.Format;
24 import jalview.util.MessageManager;
25
26 import java.io.PrintStream;
27 import java.util.Arrays;
28
29 /**
30  * A class to model rectangular matrices of double values and operations on them
31  */
32 public class Matrix implements MatrixI
33 {
34   /*
35    * maximum number of iterations for tqli
36    */
37   private static final int MAX_ITER = 45;
38   // fudge - add 15 iterations, just in case
39
40   /*
41    * the number of rows
42    */
43   final protected int rows;
44
45   /*
46    * the number of columns
47    */
48   final protected int cols;
49
50   /*
51    * the cell values in row-major order
52    */
53   private double[][] value;
54
55   protected double[] d; // Diagonal
56
57   protected double[] e; // off diagonal
58
59   /**
60    * Constructor given number of rows and columns
61    * 
62    * @param colCount
63    * @param rowCount
64    */
65   protected Matrix(int rowCount, int colCount)
66   {
67     rows = rowCount;
68     cols = colCount;
69   }
70
71   /**
72    * Creates a new Matrix object containing a copy of the supplied array values.
73    * For example
74    * 
75    * <pre>
76    *   new Matrix(new double[][] {{2, 3, 4}, {5, 6, 7})
77    * constructs
78    *   (2 3 4)
79    *   (5 6 7)
80    * </pre>
81    * 
82    * Note that ragged arrays (with not all rows, or columns, of the same
83    * length), are not supported by this class. They can be constructed, but
84    * results of operations on them are undefined and may throw exceptions.
85    * 
86    * @param values
87    *          the matrix values in row-major order
88    */
89   public Matrix(double[][] values)
90   {
91     this.rows = values.length;
92     this.cols = this.rows == 0 ? 0 : values[0].length;
93
94     /*
95      * make a copy of the values array, for immutability
96      */
97     this.value = new double[rows][];
98     int i = 0;
99     for (double[] row : values)
100     {
101       if (row != null)
102       {
103         value[i] = new double[row.length];
104         System.arraycopy(row, 0, value[i], 0, row.length);
105       }
106       i++;
107     }
108   }
109
110   @Override
111   public MatrixI transpose()
112   {
113     double[][] out = new double[cols][rows];
114
115     for (int i = 0; i < cols; i++)
116     {
117       for (int j = 0; j < rows; j++)
118       {
119         out[i][j] = value[j][i];
120       }
121     }
122
123     return new Matrix(out);
124   }
125
126   /**
127    * DOCUMENT ME!
128    * 
129    * @param ps
130    *          DOCUMENT ME!
131    * @param format
132    */
133   @Override
134   public void print(PrintStream ps, String format)
135   {
136     for (int i = 0; i < rows; i++)
137     {
138       for (int j = 0; j < cols; j++)
139       {
140         Format.print(ps, format, getValue(i, j));
141       }
142
143       ps.println();
144     }
145   }
146
147   @Override
148   public MatrixI preMultiply(MatrixI in)
149   {
150     if (in.width() != rows)
151     {
152       throw new IllegalArgumentException("Can't pre-multiply " + this.rows
153               + " rows by " + in.width() + " columns");
154     }
155     double[][] tmp = new double[in.height()][this.cols];
156
157     for (int i = 0; i < in.height(); i++)
158     {
159       for (int j = 0; j < this.cols; j++)
160       {
161         /*
162          * result[i][j] is the vector product of 
163          * in.row[i] and this.column[j]
164          */
165         for (int k = 0; k < in.width(); k++)
166         {
167           tmp[i][j] += (in.getValue(i, k) * this.value[k][j]);
168         }
169       }
170     }
171
172     return new Matrix(tmp);
173   }
174
175   /**
176    * 
177    * @param in
178    * 
179    * @return
180    */
181   public double[] vectorPostMultiply(double[] in)
182   {
183     double[] out = new double[in.length];
184
185     for (int i = 0; i < in.length; i++)
186     {
187       out[i] = 0.0;
188
189       for (int k = 0; k < in.length; k++)
190       {
191         out[i] += (value[i][k] * in[k]);
192       }
193     }
194
195     return out;
196   }
197
198   @Override
199   public MatrixI postMultiply(MatrixI in)
200   {
201     if (in.height() != this.cols)
202     {
203       throw new IllegalArgumentException("Can't post-multiply " + this.cols
204               + " columns by " + in.height() + " rows");
205     }
206     return in.preMultiply(this);
207   }
208
209   @Override
210   public MatrixI copy()
211   {
212     double[][] newmat = new double[rows][cols];
213
214     for (int i = 0; i < rows; i++)
215     {
216       System.arraycopy(value[i], 0, newmat[i], 0, value[i].length);
217     }
218
219     Matrix m = new Matrix(newmat);
220     if (this.d != null)
221     {
222       m.d = Arrays.copyOf(this.d, this.d.length);
223     }
224     if (this.e != null)
225     {
226       m.e = Arrays.copyOf(this.e, this.e.length);
227     }
228
229     return m;
230   }
231
232   /**
233    * DOCUMENT ME!
234    */
235   @Override
236   public void tred()
237   {
238     int n = rows;
239     int k;
240     int j;
241     int i;
242
243     double scale;
244     double hh;
245     double h;
246     double g;
247     double f;
248
249     this.d = new double[rows];
250     this.e = new double[rows];
251
252     for (i = n; i >= 2; i--)
253     {
254       final int l = i - 1;
255       h = 0.0;
256       scale = 0.0;
257
258       if (l > 1)
259       {
260         for (k = 1; k <= l; k++)
261         {
262           double v = Math.abs(getValue(i - 1, k - 1));
263           scale += v;
264         }
265
266         if (scale == 0.0)
267         {
268           e[i - 1] = getValue(i - 1, l - 1);
269         }
270         else
271         {
272           for (k = 1; k <= l; k++)
273           {
274             double v = divideValue(i - 1, k - 1, scale);
275             h += v * v;
276           }
277
278           f = getValue(i - 1, l - 1);
279
280           if (f > 0)
281           {
282             g = -1.0 * Math.sqrt(h);
283           }
284           else
285           {
286             g = Math.sqrt(h);
287           }
288
289           e[i - 1] = scale * g;
290           h -= (f * g);
291           setValue(i - 1, l - 1, f - g);
292           f = 0.0;
293
294           for (j = 1; j <= l; j++)
295           {
296             double val = getValue(i - 1, j - 1) / h;
297             setValue(j - 1, i - 1, val);
298             g = 0.0;
299
300             for (k = 1; k <= j; k++)
301             {
302               g += (getValue(j - 1, k - 1) * getValue(i - 1, k - 1));
303             }
304
305             for (k = j + 1; k <= l; k++)
306             {
307               g += (getValue(k - 1, j - 1) * getValue(i - 1, k - 1));
308             }
309
310             e[j - 1] = g / h;
311             f += (e[j - 1] * getValue(i - 1, j - 1));
312           }
313
314           hh = f / (h + h);
315
316           for (j = 1; j <= l; j++)
317           {
318             f = getValue(i - 1, j - 1);
319             g = e[j - 1] - (hh * f);
320             e[j - 1] = g;
321
322             for (k = 1; k <= j; k++)
323             {
324               double val = (f * e[k - 1]) + (g * getValue(i - 1, k - 1));
325               addValue(j - 1, k - 1, -val);
326             }
327           }
328         }
329       }
330       else
331       {
332         e[i - 1] = getValue(i - 1, l - 1);
333       }
334
335       d[i - 1] = h;
336     }
337
338     d[0] = 0.0;
339     e[0] = 0.0;
340
341     for (i = 1; i <= n; i++)
342     {
343       final int l = i - 1;
344
345       if (d[i - 1] != 0.0)
346       {
347         for (j = 1; j <= l; j++)
348         {
349           g = 0.0;
350
351           for (k = 1; k <= l; k++)
352           {
353             g += (getValue(i - 1, k - 1) * getValue(k - 1, j - 1));
354           }
355
356           for (k = 1; k <= l; k++)
357           {
358             addValue(k - 1, j - 1, -(g * getValue(k - 1, i - 1)));
359           }
360         }
361       }
362
363       d[i - 1] = getValue(i - 1, i - 1);
364       setValue(i - 1, i - 1, 1.0);
365
366       for (j = 1; j <= l; j++)
367       {
368         setValue(j - 1, i - 1, 0.0);
369         setValue(i - 1, j - 1, 0.0);
370       }
371     }
372   }
373
374   /**
375    * Adds f to the value at [i, j] and returns the new value
376    * 
377    * @param i
378    * @param j
379    * @param f
380    */
381   protected double addValue(int i, int j, double f)
382   {
383     double v = value[i][j] + f;
384     value[i][j] = v;
385     return v;
386   }
387
388   /**
389    * Divides the value at [i, j] by divisor and returns the new value. If d is
390    * zero, returns the unchanged value.
391    * 
392    * @param i
393    * @param j
394    * @param divisor
395    * @return
396    */
397   protected double divideValue(int i, int j, double divisor)
398   {
399     if (divisor == 0d)
400     {
401       return getValue(i, j);
402     }
403     double v = value[i][j];
404     v = v / divisor;
405     value[i][j] = v;
406     return v;
407   }
408
409   /**
410    * DOCUMENT ME!
411    */
412   @Override
413   public void tqli() throws Exception
414   {
415     int n = rows;
416
417     int m;
418     int l;
419     int iter;
420     int i;
421     int k;
422     double s;
423     double r;
424     double p;
425
426     double g;
427     double f;
428     double dd;
429     double c;
430     double b;
431
432     for (i = 2; i <= n; i++)
433     {
434       e[i - 2] = e[i - 1];
435     }
436
437     e[n - 1] = 0.0;
438
439     for (l = 1; l <= n; l++)
440     {
441       iter = 0;
442
443       do
444       {
445         for (m = l; m <= (n - 1); m++)
446         {
447           dd = Math.abs(d[m - 1]) + Math.abs(d[m]);
448
449           if ((Math.abs(e[m - 1]) + dd) == dd)
450           {
451             break;
452           }
453         }
454
455         if (m != l)
456         {
457           iter++;
458
459           if (iter == MAX_ITER)
460           {
461             throw new Exception(MessageManager.formatMessage(
462                     "exception.matrix_too_many_iteration", new String[]
463                     { "tqli", Integer.valueOf(MAX_ITER).toString() }));
464           }
465           else
466           {
467             // System.out.println("Iteration " + iter);
468           }
469
470           g = (d[l] - d[l - 1]) / (2.0 * e[l - 1]);
471           r = Math.sqrt((g * g) + 1.0);
472           g = d[m - 1] - d[l - 1] + (e[l - 1] / (g + sign(r, g)));
473           c = 1.0;
474           s = c;
475           p = 0.0;
476
477           for (i = m - 1; i >= l; i--)
478           {
479             f = s * e[i - 1];
480             b = c * e[i - 1];
481
482             if (Math.abs(f) >= Math.abs(g))
483             {
484               c = g / f;
485               r = Math.sqrt((c * c) + 1.0);
486               e[i] = f * r;
487               s = 1.0 / r;
488               c *= s;
489             }
490             else
491             {
492               s = f / g;
493               r = Math.sqrt((s * s) + 1.0);
494               e[i] = g * r;
495               c = 1.0 / r;
496               s *= c;
497             }
498
499             g = d[i] - p;
500             r = ((d[i - 1] - g) * s) + (2.0 * c * b);
501             p = s * r;
502             d[i] = g + p;
503             g = (c * r) - b;
504
505             for (k = 1; k <= n; k++)
506             {
507               f = getValue(k - 1, i);
508               setValue(k - 1, i, (s * getValue(k - 1, i - 1)) + (c * f));
509               setValue(k - 1, i - 1,
510                       (c * getValue(k - 1, i - 1)) - (s * f));
511             }
512           }
513
514           d[l - 1] = d[l - 1] - p;
515           e[l - 1] = g;
516           e[m - 1] = 0.0;
517         }
518       } while (m != l);
519     }
520   }
521
522   @Override
523   public double getValue(int i, int j)
524   {
525     return value[i][j];
526   }
527
528   @Override
529   public void setValue(int i, int j, double val)
530   {
531     value[i][j] = val;
532   }
533
534   /**
535    * DOCUMENT ME!
536    */
537   public void tred2()
538   {
539     int n = rows;
540     int l;
541     int k;
542     int j;
543     int i;
544
545     double scale;
546     double hh;
547     double h;
548     double g;
549     double f;
550
551     this.d = new double[rows];
552     this.e = new double[rows];
553
554     for (i = n - 1; i >= 1; i--)
555     {
556       l = i - 1;
557       h = 0.0;
558       scale = 0.0;
559
560       if (l > 0)
561       {
562         for (k = 0; k < l; k++)
563         {
564           scale += Math.abs(value[i][k]);
565         }
566
567         if (scale == 0.0)
568         {
569           e[i] = value[i][l];
570         }
571         else
572         {
573           for (k = 0; k < l; k++)
574           {
575             value[i][k] /= scale;
576             h += (value[i][k] * value[i][k]);
577           }
578
579           f = value[i][l];
580
581           if (f > 0)
582           {
583             g = -1.0 * Math.sqrt(h);
584           }
585           else
586           {
587             g = Math.sqrt(h);
588           }
589
590           e[i] = scale * g;
591           h -= (f * g);
592           value[i][l] = f - g;
593           f = 0.0;
594
595           for (j = 0; j < l; j++)
596           {
597             value[j][i] = value[i][j] / h;
598             g = 0.0;
599
600             for (k = 0; k < j; k++)
601             {
602               g += (value[j][k] * value[i][k]);
603             }
604
605             for (k = j; k < l; k++)
606             {
607               g += (value[k][j] * value[i][k]);
608             }
609
610             e[j] = g / h;
611             f += (e[j] * value[i][j]);
612           }
613
614           hh = f / (h + h);
615
616           for (j = 0; j < l; j++)
617           {
618             f = value[i][j];
619             g = e[j] - (hh * f);
620             e[j] = g;
621
622             for (k = 0; k < j; k++)
623             {
624               value[j][k] -= ((f * e[k]) + (g * value[i][k]));
625             }
626           }
627         }
628       }
629       else
630       {
631         e[i] = value[i][l];
632       }
633
634       d[i] = h;
635     }
636
637     d[0] = 0.0;
638     e[0] = 0.0;
639
640     for (i = 0; i < n; i++)
641     {
642       l = i - 1;
643
644       if (d[i] != 0.0)
645       {
646         for (j = 0; j < l; j++)
647         {
648           g = 0.0;
649
650           for (k = 0; k < l; k++)
651           {
652             g += (value[i][k] * value[k][j]);
653           }
654
655           for (k = 0; k < l; k++)
656           {
657             value[k][j] -= (g * value[k][i]);
658           }
659         }
660       }
661
662       d[i] = value[i][i];
663       value[i][i] = 1.0;
664
665       for (j = 0; j < l; j++)
666       {
667         value[j][i] = 0.0;
668         value[i][j] = 0.0;
669       }
670     }
671   }
672
673   /**
674    * DOCUMENT ME!
675    */
676   public void tqli2() throws Exception
677   {
678     int n = rows;
679
680     int m;
681     int l;
682     int iter;
683     int i;
684     int k;
685     double s;
686     double r;
687     double p;
688     ;
689
690     double g;
691     double f;
692     double dd;
693     double c;
694     double b;
695
696     for (i = 2; i <= n; i++)
697     {
698       e[i - 2] = e[i - 1];
699     }
700
701     e[n - 1] = 0.0;
702
703     for (l = 1; l <= n; l++)
704     {
705       iter = 0;
706
707       do
708       {
709         for (m = l; m <= (n - 1); m++)
710         {
711           dd = Math.abs(d[m - 1]) + Math.abs(d[m]);
712
713           if ((Math.abs(e[m - 1]) + dd) == dd)
714           {
715             break;
716           }
717         }
718
719         if (m != l)
720         {
721           iter++;
722
723           if (iter == MAX_ITER)
724           {
725             throw new Exception(MessageManager.formatMessage(
726                     "exception.matrix_too_many_iteration", new String[]
727                     { "tqli2", Integer.valueOf(MAX_ITER).toString() }));
728           }
729           else
730           {
731             // System.out.println("Iteration " + iter);
732           }
733
734           g = (d[l] - d[l - 1]) / (2.0 * e[l - 1]);
735           r = Math.sqrt((g * g) + 1.0);
736           g = d[m - 1] - d[l - 1] + (e[l - 1] / (g + sign(r, g)));
737           c = 1.0;
738           s = c;
739           p = 0.0;
740
741           for (i = m - 1; i >= l; i--)
742           {
743             f = s * e[i - 1];
744             b = c * e[i - 1];
745
746             if (Math.abs(f) >= Math.abs(g))
747             {
748               c = g / f;
749               r = Math.sqrt((c * c) + 1.0);
750               e[i] = f * r;
751               s = 1.0 / r;
752               c *= s;
753             }
754             else
755             {
756               s = f / g;
757               r = Math.sqrt((s * s) + 1.0);
758               e[i] = g * r;
759               c = 1.0 / r;
760               s *= c;
761             }
762
763             g = d[i] - p;
764             r = ((d[i - 1] - g) * s) + (2.0 * c * b);
765             p = s * r;
766             d[i] = g + p;
767             g = (c * r) - b;
768
769             for (k = 1; k <= n; k++)
770             {
771               f = value[k - 1][i];
772               value[k - 1][i] = (s * value[k - 1][i - 1]) + (c * f);
773               value[k - 1][i - 1] = (c * value[k - 1][i - 1]) - (s * f);
774             }
775           }
776
777           d[l - 1] = d[l - 1] - p;
778           e[l - 1] = g;
779           e[m - 1] = 0.0;
780         }
781       } while (m != l);
782     }
783   }
784
785   /**
786    * Answers the first argument with the sign of the second argument
787    * 
788    * @param a
789    * @param b
790    * 
791    * @return
792    */
793   static double sign(double a, double b)
794   {
795     if (b < 0)
796     {
797       return -Math.abs(a);
798     }
799     else
800     {
801       return Math.abs(a);
802     }
803   }
804
805   /**
806    * Returns an array containing the values in the specified column
807    * 
808    * @param col
809    * 
810    * @return
811    */
812   public double[] getColumn(int col)
813   {
814     double[] out = new double[rows];
815
816     for (int i = 0; i < rows; i++)
817     {
818       out[i] = value[i][col];
819     }
820
821     return out;
822   }
823
824   /**
825    * DOCUMENT ME!
826    * 
827    * @param ps
828    *          DOCUMENT ME!
829    * @param format
830    */
831   @Override
832   public void printD(PrintStream ps, String format)
833   {
834     for (int j = 0; j < rows; j++)
835     {
836       Format.print(ps, format, d[j]);
837     }
838   }
839
840   /**
841    * DOCUMENT ME!
842    * 
843    * @param ps
844    *          DOCUMENT ME!
845    * @param format
846    *          TODO
847    */
848   @Override
849   public void printE(PrintStream ps, String format)
850   {
851     for (int j = 0; j < rows; j++)
852     {
853       Format.print(ps, format, e[j]);
854     }
855   }
856
857   @Override
858   public double[] getD()
859   {
860     return d;
861   }
862
863   @Override
864   public double[] getE()
865   {
866     return e;
867   }
868
869   @Override
870   public int height()
871   {
872     return rows;
873   }
874
875   @Override
876   public int width()
877   {
878     return cols;
879   }
880
881   @Override
882   public double[] getRow(int i)
883   {
884     double[] row = new double[cols];
885     System.arraycopy(value[i], 0, row, 0, cols);
886     return row;
887   }
888
889   /**
890    * Returns a length 2 array of {minValue, maxValue} of all values in the
891    * matrix. Returns null if the matrix is null or empty.
892    * 
893    * @return
894    */
895   double[] findMinMax()
896   {
897     if (value == null)
898     {
899       return null;
900     }
901     double min = Double.MAX_VALUE;
902     double max = -Double.MAX_VALUE;
903     boolean empty = true;
904     for (double[] row : value)
905     {
906       if (row != null)
907       {
908         for (double x : row)
909         {
910           empty = false;
911           if (x > max)
912           {
913             max = x;
914           }
915           if (x < min)
916           {
917             min = x;
918           }
919         }
920       }
921     }
922     return empty ? null : new double[] { min, max };
923   }
924
925   /**
926    * {@inheritDoc}
927    */
928   @Override
929   public void reverseRange(boolean maxToZero)
930   {
931     if (value == null)
932     {
933       return;
934     }
935     double[] minMax = findMinMax();
936     if (minMax == null)
937     {
938       return; // empty matrix
939     }
940     double subtractFrom = maxToZero ? minMax[1] : minMax[0] + minMax[1];
941
942     for (double[] row : value)
943     {
944       if (row != null)
945       {
946         int j = 0;
947         for (double x : row)
948         {
949           row[j] = subtractFrom - x;
950           j++;
951         }
952       }
953     }
954   }
955
956   /**
957    * Multiplies every entry in the matrix by the given value.
958    * 
959    * @param
960    */
961   @Override
962   public void multiply(double by)
963   {
964     for (double[] row : value)
965     {
966       if (row != null)
967       {
968         for (int i = 0; i < row.length; i++)
969         {
970           row[i] *= by;
971         }
972       }
973     }
974   }
975
976   @Override
977   public void setD(double[] v)
978   {
979     d = v;
980   }
981
982   @Override
983   public void setE(double[] v)
984   {
985     e = v;
986   }
987
988   public double getTotal()
989   {
990     double d = 0d;
991     for (int i = 0; i < this.height(); i++)
992     {
993       for (int j = 0; j < this.width(); j++)
994       {
995         d += value[i][j];
996       }
997     }
998     return d;
999   }
1000
1001   /**
1002    * {@inheritDoc}
1003    */
1004   @Override
1005   public boolean equals(MatrixI m2, double delta)
1006   {
1007     if (m2 == null || this.height() != m2.height()
1008             || this.width() != m2.width())
1009     {
1010       return false;
1011     }
1012     for (int i = 0; i < this.height(); i++)
1013     {
1014       for (int j = 0; j < this.width(); j++)
1015       {
1016         double diff = this.getValue(i, j) - m2.getValue(i, j);
1017         if (Math.abs(diff) > delta)
1018         {
1019           return false;
1020         }
1021       }
1022     }
1023     return true;
1024   }
1025 }