JAL-4134 JAL-3855 store/restore groups, tree and threshold used to cluster a PAE...
[jalview.git] / src / jalview / ws / datamodel / alphafold / PAEContactMatrix.java
index d1a2e9d..87ccab6 100644 (file)
@@ -1,9 +1,14 @@
 package jalview.ws.datamodel.alphafold;
 
+import java.util.ArrayList;
+import java.util.BitSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
+import jalview.analysis.AverageDistanceEngine;
+import jalview.bin.Console;
+import jalview.datamodel.BinaryNode;
 import jalview.datamodel.ContactListI;
 import jalview.datamodel.ContactListImpl;
 import jalview.datamodel.ContactListProviderI;
@@ -223,9 +228,11 @@ public class PAEContactMatrix implements ContactMatrixI
   @Override
   public String getAnnotLabel()
   {
-    StringBuilder label = new StringBuilder("pAE Matrix");
+    StringBuilder label = new StringBuilder("PAE Matrix");
     if (this.getReferenceSeq() != null)
+    {
       label.append(":").append(this.getReferenceSeq().getDisplayId(false));
+    }
     return label.toString();
   }
 
@@ -248,4 +255,88 @@ public class PAEContactMatrix implements ContactMatrixI
   {
     return length;
   }
+  List<BitSet> groups=null;
+  @Override
+  public boolean hasGroups()
+  {
+    return groups!=null;
+  }
+  String newick=null;
+  public String getNewickString()
+  {
+    return newick;
+  }
+  boolean abs;
+  double thresh;
+  String treeType=null;
+  public void makeGroups(float thresh,boolean abs)
+  {
+    AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this);
+    double height = clusterer.findHeight(clusterer.getTopNode());
+    newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print();
+    treeType = "UPGMA";
+    Console.trace("Newick string\n"+newick);
+
+    List<BinaryNode> nodegroups;
+    if (abs ? height > thresh : 0 < thresh && thresh < 1)
+    {
+      float cut = abs ? (float) (thresh / height) : thresh;
+      Console.debug("Threshold "+cut+" for height="+height);
+
+      nodegroups = clusterer.groupNodes(cut);
+    }
+    else
+    {
+      nodegroups = new ArrayList<BinaryNode>();
+      nodegroups.add(clusterer.getTopNode());
+    }
+    this.abs=abs;
+    this.thresh=thresh;
+    groups = new ArrayList<>();
+    for (BinaryNode root:nodegroups)
+    {
+      BitSet gpset=new BitSet();
+      for (BinaryNode leaf:clusterer.findLeaves(root))
+      {
+        gpset.set((Integer)leaf.element());
+      }
+      groups.add(gpset);
+    }
+  }
+  
+  @Override
+  public BitSet getGroupsFor(int column)
+  {
+    for (BitSet gp:groups) {
+      if (gp.get(column))
+      {
+        return gp;
+      }
+    }
+    return ContactMatrixI.super.getGroupsFor(column);
+  }
+
+  public void restoreGroups(List<BitSet> newgroups, String treeMethod,
+          String tree, double thresh2)
+  {
+    treeType=treeMethod;
+    groups = newgroups;
+    thresh=thresh2;
+    newick =tree;
+    
+  }
+  @Override
+  public boolean hasCutHeight() {
+    return groups!=null && thresh!=0;
+  }
+  @Override
+  public double getCutHeight()
+  {
+    return thresh;
+  }
+  @Override
+  public String getTreeMethod()
+  {
+    return treeType;
+  }
 }