JAL-4134 JAL-3855 store/restore groups, tree and threshold used to cluster a PAE...
[jalview.git] / src / jalview / ws / datamodel / alphafold / PAEContactMatrix.java
index 30c77d2..87ccab6 100644 (file)
@@ -43,12 +43,13 @@ public class PAEContactMatrix implements ContactMatrixI
     }
     length = _refSeq.getEnd() - _refSeq.getStart() + 1;
   }
+
   @SuppressWarnings("unchecked")
   public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
   {
     setRefSeq(_refSeq);
     // convert the lists to primitive arrays and store
-    
+
     if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
     {
       parse_version_1_pAE(pae_obj);
@@ -59,32 +60,35 @@ public class PAEContactMatrix implements ContactMatrixI
       parse_version_2_pAE(pae_obj);
     }
   }
+
   /**
    * construct a sequence associated PAE matrix directly from a float array
+   * 
    * @param _refSeq
    * @param matrix
    */
   public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
   {
     setRefSeq(_refSeq);
-    maxcol=0;
-    for (float[] row:matrix)
+    maxcol = 0;
+    for (float[] row : matrix)
     {
-      if (row.length>maxcol)
+      if (row.length > maxcol)
       {
-        maxcol=row.length;
+        maxcol = row.length;
       }
-      maxscore=row[0];
-      for (float f:row)
+      maxscore = row[0];
+      for (float f : row)
       {
-        if (maxscore<f) {
-          maxscore=f;
+        if (maxscore < f)
+        {
+          maxscore = f;
         }
       }
     }
-    maxrow=matrix.length;
+    maxrow = matrix.length;
     elements = matrix;
-    
+
   }
 
   /**
@@ -99,11 +103,10 @@ public class PAEContactMatrix implements ContactMatrixI
     maxscore = ((Double) MapUtils.getFirst(pae_obj,
             "max_predicted_aligned_error", "max_pae")).floatValue();
     List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
-            .getFirst(pae_obj, "predicted_aligned_error", "pae"))
-            ;
+            .getFirst(pae_obj, "predicted_aligned_error", "pae"));
     elements = new float[scoreRows.size()][scoreRows.size()];
     int row = 0, col = 0;
-    for (List<Long> scoreRow:scoreRows)
+    for (List<Long> scoreRow : scoreRows)
     {
       Iterator<Long> scores = scoreRow.iterator();
       while (scores.hasNext())
@@ -112,7 +115,7 @@ public class PAEContactMatrix implements ContactMatrixI
         if (d instanceof Double)
           elements[row][col++] = ((Double) d).longValue();
         else
-          elements[row][col++] = (float) ((Long)d).longValue();
+          elements[row][col++] = (float) ((Long) d).longValue();
       }
       row++;
       col = 0;
@@ -177,7 +180,7 @@ public class PAEContactMatrix implements ContactMatrixI
       @Override
       public int getContactHeight()
       {
-        return maxcol-1;
+        return maxcol - 1;
       }
 
       @Override
@@ -225,20 +228,28 @@ public class PAEContactMatrix implements ContactMatrixI
   @Override
   public String getAnnotLabel()
   {
-    return "pAE Matrix";
+    StringBuilder label = new StringBuilder("PAE Matrix");
+    if (this.getReferenceSeq() != null)
+    {
+      label.append(":").append(this.getReferenceSeq().getDisplayId(false));
+    }
+    return label.toString();
   }
 
-  public static final String PAEMATRIX="PAE_MATRIX";
+  public static final String PAEMATRIX = "PAE_MATRIX";
+
   @Override
   public String getType()
   {
     return PAEMATRIX;
   }
+
   @Override
   public int getWidth()
   {
     return length;
   }
+
   @Override
   public int getHeight()
   {
@@ -255,12 +266,15 @@ public class PAEContactMatrix implements ContactMatrixI
   {
     return newick;
   }
+  boolean abs;
+  double thresh;
+  String treeType=null;
   public void makeGroups(float thresh,boolean abs)
   {
     AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this);
     double height = clusterer.findHeight(clusterer.getTopNode());
     newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print();
-
+    treeType = "UPGMA";
     Console.trace("Newick string\n"+newick);
 
     List<BinaryNode> nodegroups;
@@ -276,7 +290,8 @@ public class PAEContactMatrix implements ContactMatrixI
       nodegroups = new ArrayList<BinaryNode>();
       nodegroups.add(clusterer.getTopNode());
     }
-
+    this.abs=abs;
+    this.thresh=thresh;
     groups = new ArrayList<>();
     for (BinaryNode root:nodegroups)
     {
@@ -300,4 +315,28 @@ public class PAEContactMatrix implements ContactMatrixI
     }
     return ContactMatrixI.super.getGroupsFor(column);
   }
+
+  public void restoreGroups(List<BitSet> newgroups, String treeMethod,
+          String tree, double thresh2)
+  {
+    treeType=treeMethod;
+    groups = newgroups;
+    thresh=thresh2;
+    newick =tree;
+    
+  }
+  @Override
+  public boolean hasCutHeight() {
+    return groups!=null && thresh!=0;
+  }
+  @Override
+  public double getCutHeight()
+  {
+    return thresh;
+  }
+  @Override
+  public String getTreeMethod()
+  {
+    return treeType;
+  }
 }