JAL-1767 port save/load PCA panel to JAXB - JAL-3063
[jalview.git] / src / jalview / project / Jalview2XML.java
index 90c09d7..9f796f2 100644 (file)
  */
 package jalview.project;
 
+import static jalview.math.RotatableMatrix.Axis.X;
+import static jalview.math.RotatableMatrix.Axis.Y;
+import static jalview.math.RotatableMatrix.Axis.Z;
+
 import jalview.analysis.Conservation;
+import jalview.analysis.PCA;
+import jalview.analysis.scoremodels.ScoreModels;
+import jalview.analysis.scoremodels.SimilarityParams;
 import jalview.api.FeatureColourI;
 import jalview.api.ViewStyleI;
+import jalview.api.analysis.ScoreModelI;
+import jalview.api.analysis.SimilarityParamsI;
 import jalview.api.structures.JalviewStructureDisplayI;
 import jalview.bin.Cache;
 import jalview.datamodel.AlignedCodonFrame;
@@ -31,6 +40,7 @@ import jalview.datamodel.AlignmentAnnotation;
 import jalview.datamodel.AlignmentI;
 import jalview.datamodel.GraphLine;
 import jalview.datamodel.PDBEntry;
+import jalview.datamodel.Point;
 import jalview.datamodel.RnaViewerModel;
 import jalview.datamodel.SequenceFeature;
 import jalview.datamodel.SequenceGroup;
@@ -52,6 +62,7 @@ import jalview.gui.FeatureRenderer;
 import jalview.gui.Jalview2XML_V1;
 import jalview.gui.JvOptionPane;
 import jalview.gui.OOMWarning;
+import jalview.gui.PCAPanel;
 import jalview.gui.PaintRefresher;
 import jalview.gui.SplitFrame;
 import jalview.gui.StructureViewer;
@@ -61,6 +72,8 @@ import jalview.gui.TreePanel;
 import jalview.io.DataSourceType;
 import jalview.io.FileFormat;
 import jalview.io.NewickFile;
+import jalview.math.Matrix;
+import jalview.math.MatrixI;
 import jalview.renderer.ResidueShaderI;
 import jalview.schemes.AnnotationColourGradient;
 import jalview.schemes.ColourSchemeI;
@@ -77,6 +90,7 @@ import jalview.util.StringUtils;
 import jalview.util.jarInputStreamProvider;
 import jalview.util.matcher.Condition;
 import jalview.viewmodel.AlignmentViewport;
+import jalview.viewmodel.PCAModel;
 import jalview.viewmodel.ViewportRanges;
 import jalview.viewmodel.seqfeatures.FeatureRendererSettings;
 import jalview.viewmodel.seqfeatures.FeaturesDisplayed;
@@ -92,6 +106,8 @@ import jalview.xml.binding.jalview.Annotation;
 import jalview.xml.binding.jalview.Annotation.ThresholdLine;
 import jalview.xml.binding.jalview.AnnotationColourScheme;
 import jalview.xml.binding.jalview.AnnotationElement;
+import jalview.xml.binding.jalview.DoubleMatrix;
+import jalview.xml.binding.jalview.DoubleVector;
 import jalview.xml.binding.jalview.Feature;
 import jalview.xml.binding.jalview.Feature.OtherData;
 import jalview.xml.binding.jalview.FeatureMatcherSet.CompoundMatcher;
@@ -106,6 +122,11 @@ import jalview.xml.binding.jalview.JalviewModel.JSeq.Pdbids;
 import jalview.xml.binding.jalview.JalviewModel.JSeq.Pdbids.StructureState;
 import jalview.xml.binding.jalview.JalviewModel.JSeq.RnaViewer;
 import jalview.xml.binding.jalview.JalviewModel.JSeq.RnaViewer.SecondaryStructure;
+import jalview.xml.binding.jalview.JalviewModel.PcaViewer;
+import jalview.xml.binding.jalview.JalviewModel.PcaViewer.Axis;
+import jalview.xml.binding.jalview.JalviewModel.PcaViewer.SeqPointMax;
+import jalview.xml.binding.jalview.JalviewModel.PcaViewer.SeqPointMin;
+import jalview.xml.binding.jalview.JalviewModel.PcaViewer.SequencePoint;
 import jalview.xml.binding.jalview.JalviewModel.Tree;
 import jalview.xml.binding.jalview.JalviewModel.UserColours;
 import jalview.xml.binding.jalview.JalviewModel.Viewport;
@@ -118,6 +139,7 @@ import jalview.xml.binding.jalview.MapListType.MapListTo;
 import jalview.xml.binding.jalview.Mapping;
 import jalview.xml.binding.jalview.NoValueColour;
 import jalview.xml.binding.jalview.ObjectFactory;
+import jalview.xml.binding.jalview.PcaDataType;
 import jalview.xml.binding.jalview.Pdbentry.Property;
 import jalview.xml.binding.jalview.Sequence;
 import jalview.xml.binding.jalview.Sequence.DBRef;
@@ -1215,6 +1237,9 @@ public class Jalview2XML
               tree.setXpos(tp.getX());
               tree.setYpos(tp.getY());
               tree.setId(makeHashCode(tp, null));
+              tree.setLinkToAllViews(
+                      tp.getTreeCanvas().isApplyToAllViews());
+
               // jms.addTree(tree);
               object.getTree().add(tree);
             }
@@ -1223,6 +1248,24 @@ public class Jalview2XML
       }
     }
 
+    /*
+     * save PCA viewers
+     */
+    if (!storeDS && Desktop.desktop != null)
+    {
+      for (JInternalFrame frame : Desktop.desktop.getAllFrames())
+      {
+        if (frame instanceof PCAPanel)
+        {
+          PCAPanel panel = (PCAPanel) frame;
+          if (panel.getAlignViewport().getAlignment() == jal)
+          {
+            savePCA(panel, object);
+          }
+        }
+      }
+    }
+
     // SAVE ANNOTATIONS
     /**
      * store forward refs from an annotationRow to any groups
@@ -1630,6 +1673,198 @@ public class Jalview2XML
   }
 
   /**
+   * Writes PCA viewer attributes and computed values to an XML model object and
+   * adds it to the JalviewModel. Any exceptions are reported by logging.
+   */
+  protected void savePCA(PCAPanel panel, JalviewModel object)
+  {
+    try
+    {
+      PcaViewer viewer = new PcaViewer();
+      viewer.setHeight(panel.getHeight());
+      viewer.setWidth(panel.getWidth());
+      viewer.setXpos(panel.getX());
+      viewer.setYpos(panel.getY());
+      viewer.setTitle(panel.getTitle());
+      PCAModel pcaModel = panel.getPcaModel();
+      viewer.setScoreModelName(pcaModel.getScoreModelName());
+      viewer.setXDim(panel.getSelectedDimensionIndex(X));
+      viewer.setYDim(panel.getSelectedDimensionIndex(Y));
+      viewer.setZDim(panel.getSelectedDimensionIndex(Z));
+      viewer.setBgColour(
+              panel.getRotatableCanvas().getBackgroundColour().getRGB());
+      viewer.setScaleFactor(panel.getRotatableCanvas().getScaleFactor());
+      float[] spMin = panel.getRotatableCanvas().getSeqMin();
+      SeqPointMin spmin = new SeqPointMin();
+      spmin.setXPos(spMin[0]);
+      spmin.setYPos(spMin[1]);
+      spmin.setZPos(spMin[2]);
+      viewer.setSeqPointMin(spmin);
+      float[] spMax = panel.getRotatableCanvas().getSeqMax();
+      SeqPointMax spmax = new SeqPointMax();
+      spmax.setXPos(spMax[0]);
+      spmax.setYPos(spMax[1]);
+      spmax.setZPos(spMax[2]);
+      viewer.setSeqPointMax(spmax);
+      viewer.setShowLabels(panel.getRotatableCanvas().isShowLabels());
+      viewer.setLinkToAllViews(
+              panel.getRotatableCanvas().isApplyToAllViews());
+      SimilarityParamsI sp = pcaModel.getSimilarityParameters();
+      viewer.setIncludeGaps(sp.includeGaps());
+      viewer.setMatchGaps(sp.matchGaps());
+      viewer.setIncludeGappedColumns(sp.includeGappedColumns());
+      viewer.setDenominateByShortestLength(sp.denominateByShortestLength());
+
+      /*
+       * sequence points on display
+       */
+      for (jalview.datamodel.SequencePoint spt : pcaModel
+              .getSequencePoints())
+      {
+        SequencePoint point = new SequencePoint();
+        point.setSequenceRef(seqHash(spt.getSequence()));
+        point.setXPos(spt.coord.x);
+        point.setYPos(spt.coord.y);
+        point.setZPos(spt.coord.z);
+        viewer.getSequencePoint().add(point);
+      }
+
+      /*
+       * (end points of) axes on display
+       */
+      for (Point p : panel.getRotatableCanvas().getAxisEndPoints())
+      {
+
+        Axis axis = new Axis();
+        axis.setXPos(p.x);
+        axis.setYPos(p.y);
+        axis.setZPos(p.z);
+        viewer.getAxis().add(axis);
+      }
+
+      /*
+       * raw PCA data (note we are not restoring PCA inputs here -
+       * alignment view, score model, similarity parameters)
+       */
+      PcaDataType data = new PcaDataType();
+      viewer.setPcaData(data);
+      PCA pca = pcaModel.getPcaData();
+
+      DoubleMatrix pm = new DoubleMatrix();
+      saveDoubleMatrix(pca.getPairwiseScores(), pm);
+      data.setPairwiseMatrix(pm);
+
+      DoubleMatrix tm = new DoubleMatrix();
+      saveDoubleMatrix(pca.getTridiagonal(), tm);
+      data.setTridiagonalMatrix(tm);
+
+      DoubleMatrix eigenMatrix = new DoubleMatrix();
+      data.setEigenMatrix(eigenMatrix);
+      saveDoubleMatrix(pca.getEigenmatrix(), eigenMatrix);
+
+      object.getPcaViewer().add(viewer);
+    } catch (Throwable t)
+    {
+      Cache.log.error("Error saving PCA: " + t.getMessage());
+    }
+  }
+
+  /**
+   * Stores values from a matrix into an XML element, including (if present) the
+   * D or E vectors
+   * 
+   * @param m
+   * @param xmlMatrix
+   * @see #loadDoubleMatrix(DoubleMatrix)
+   */
+  protected void saveDoubleMatrix(MatrixI m, DoubleMatrix xmlMatrix)
+  {
+    xmlMatrix.setRows(m.height());
+    xmlMatrix.setColumns(m.width());
+    for (int i = 0; i < m.height(); i++)
+    {
+      DoubleVector row = new DoubleVector();
+      for (int j = 0; j < m.width(); j++)
+      {
+        row.getV().add(m.getValue(i, j));
+      }
+      xmlMatrix.getRow().add(row);
+    }
+    if (m.getD() != null)
+    {
+      DoubleVector dVector = new DoubleVector();
+      for (double d : m.getD())
+      {
+        dVector.getV().add(d);
+      }
+      xmlMatrix.setD(dVector);
+    }
+    if (m.getE() != null)
+    {
+      DoubleVector eVector = new DoubleVector();
+      for (double e : m.getE())
+      {
+        eVector.getV().add(e);
+      }
+      xmlMatrix.setE(eVector);
+    }
+  }
+
+  /**
+   * Loads XML matrix data into a new Matrix object, including the D and/or E
+   * vectors (if present)
+   * 
+   * @param mData
+   * @return
+   * @see Jalview2XML#saveDoubleMatrix(MatrixI, DoubleMatrix)
+   */
+  protected MatrixI loadDoubleMatrix(DoubleMatrix mData)
+  {
+    int rows = mData.getRows();
+    double[][] vals = new double[rows][];
+    List<Double> dVector;
+    double[] vec;
+
+    for (int i = 0; i < rows; i++)
+    {
+      dVector = mData.getRow().get(i).getV();
+      vals[i] = new double[dVector.size()];
+      int dvi = 0;
+      for (Double d : dVector)
+      {
+        vals[i][dvi++] = d;
+      }
+    }
+
+    MatrixI m = new Matrix(vals);
+
+    if (mData.getD() != null)
+    {
+      dVector = mData.getD().getV();
+      vec = new double[dVector.size()];
+      int dvi = 0;
+      for (Double d : dVector)
+      {
+        vec[dvi++] = d;
+      }
+      m.setD(vec);
+    }
+    if (mData.getE() != null)
+    {
+      dVector = mData.getE().getV();
+      vec = new double[dVector.size()];
+      int dvi = 0;
+      for (Double d : dVector)
+      {
+        vec[dvi++] = d;
+      }
+      m.setE(vec);
+    }
+
+    return m;
+  }
+
+  /**
    * Save any Varna viewers linked to this sequence. Writes an rnaViewer element
    * for each viewer, with
    * <ul>
@@ -3799,6 +4034,7 @@ public class Jalview2XML
     if (loadTreesAndStructures)
     {
       loadTrees(jalviewModel, view, af, av, ap);
+      loadPCAViewers(jalviewModel, ap);
       loadPDBStructures(jprovider, jseqs, af, ap);
       loadRnaViewers(jprovider, jseqs, ap);
     }
@@ -3947,6 +4183,8 @@ public class Jalview2XML
           // TODO: verify 'associate with all views' works still
           tp.getTreeCanvas().setViewport(av); // af.viewport;
           tp.getTreeCanvas().setAssociatedPanel(ap); // af.alignPanel;
+          // FIXME: should we use safeBoolean here ?
+          tp.getTreeCanvas().setApplyToAllViews(tree.isLinkToAllViews());
 
         }
         if (tp == null)
@@ -5935,6 +6173,128 @@ public class Jalview2XML
   }
 
   /**
+   * Loads any saved PCA viewers
+   * 
+   * @param jms
+   * @param ap
+   */
+  protected void loadPCAViewers(JalviewModel model, AlignmentPanel ap)
+  {
+    try
+    {
+      List<PcaViewer> pcaviewers = model.getPcaViewer();
+      for (PcaViewer viewer : pcaviewers)
+      {
+        String modelName = viewer.getScoreModelName();
+        SimilarityParamsI params = new SimilarityParams(
+                viewer.isIncludeGappedColumns(), viewer.isMatchGaps(),
+                viewer.isIncludeGaps(),
+                viewer.isDenominateByShortestLength());
+
+        /*
+         * create the panel (without computing the PCA)
+         */
+        PCAPanel panel = new PCAPanel(ap, modelName, params);
+
+        panel.setTitle(viewer.getTitle());
+        panel.setBounds(new Rectangle(viewer.getXpos(), viewer.getYpos(),
+                viewer.getWidth(), viewer.getHeight()));
+
+        boolean showLabels = viewer.isShowLabels();
+        panel.setShowLabels(showLabels);
+        panel.getRotatableCanvas().setShowLabels(showLabels);
+        panel.getRotatableCanvas()
+                .setBgColour(new Color(viewer.getBgColour()));
+        panel.getRotatableCanvas()
+                .setApplyToAllViews(viewer.isLinkToAllViews());
+
+        /*
+         * load PCA output data
+         */
+        ScoreModelI scoreModel = ScoreModels.getInstance()
+                .getScoreModel(modelName, ap);
+        PCA pca = new PCA(null, scoreModel, params);
+        PcaDataType pcaData = viewer.getPcaData();
+
+        MatrixI pairwise = loadDoubleMatrix(pcaData.getPairwiseMatrix());
+        pca.setPairwiseScores(pairwise);
+
+        MatrixI triDiag = loadDoubleMatrix(pcaData.getTridiagonalMatrix());
+        pca.setTridiagonal(triDiag);
+
+        MatrixI result = loadDoubleMatrix(pcaData.getEigenMatrix());
+        pca.setEigenmatrix(result);
+
+        panel.getPcaModel().setPCA(pca);
+
+        /*
+         * we haven't saved the input data! (JAL-2647 to do)
+         */
+        panel.setInputData(null);
+
+        /*
+         * add the sequence points for the PCA display
+         */
+        List<jalview.datamodel.SequencePoint> seqPoints = new ArrayList<>();
+        for (SequencePoint sp : viewer.getSequencePoint())
+        {
+          String seqId = sp.getSequenceRef();
+          SequenceI seq = seqRefIds.get(seqId);
+          if (seq == null)
+          {
+            throw new IllegalStateException(
+                    "Unmatched seqref for PCA: " + seqId);
+          }
+          Point pt = new Point(sp.getXPos(), sp.getYPos(), sp.getZPos());
+          jalview.datamodel.SequencePoint seqPoint = new jalview.datamodel.SequencePoint(
+                  seq, pt);
+          seqPoints.add(seqPoint);
+        }
+        panel.getRotatableCanvas().setPoints(seqPoints, seqPoints.size());
+
+        /*
+         * set min-max ranges and scale after setPoints (which recomputes them)
+         */
+        panel.getRotatableCanvas().setScaleFactor(viewer.getScaleFactor());
+        SeqPointMin spMin = viewer.getSeqPointMin();
+        float[] min = new float[] { spMin.getXPos(), spMin.getYPos(),
+            spMin.getZPos() };
+        SeqPointMax spMax = viewer.getSeqPointMax();
+        float[] max = new float[] { spMax.getXPos(), spMax.getYPos(),
+            spMax.getZPos() };
+        panel.getRotatableCanvas().setSeqMinMax(min, max);
+
+        // todo: hold points list in PCAModel only
+        panel.getPcaModel().setSequencePoints(seqPoints);
+
+        panel.setSelectedDimensionIndex(viewer.getXDim(), X);
+        panel.setSelectedDimensionIndex(viewer.getYDim(), Y);
+        panel.setSelectedDimensionIndex(viewer.getZDim(), Z);
+
+        // is this duplication needed?
+        panel.setTop(seqPoints.size() - 1);
+        panel.getPcaModel().setTop(seqPoints.size() - 1);
+
+        /*
+         * add the axes' end points for the display
+         */
+        for (int i = 0; i < 3; i++)
+        {
+          Axis axis = viewer.getAxis().get(i);
+          panel.getRotatableCanvas().getAxisEndPoints()[i] = new Point(
+                  axis.getXPos(), axis.getYPos(), axis.getZPos());
+        }
+
+        Desktop.addInternalFrame(panel, MessageManager.formatMessage(
+                "label.calc_title", "PCA", modelName), 475, 450);
+      }
+    } catch (Exception ex)
+    {
+      Cache.log.error("Error loading PCA: " + ex.toString());
+    }
+  }
+
+  /**
    * Populates an XML model of the feature colour scheme for one feature type
    * 
    * @param featureType