X-Git-Url: http://source.jalview.org/gitweb/?a=blobdiff_plain;f=src%2Fjalview%2Fws%2Fdatamodel%2Falphafold%2FPAEContactMatrix.java;h=1ec856b67ac49045550cd5f0279008e17b19d4c1;hb=e134764b7eec841cb56a417250f2dd898680f985;hp=e61af44a14f2b05986ea7531d349c66fadd1f097;hpb=b5ea0bbb85bef19c50fb4341bda9e9da9ef09b13;p=jalview.git diff --git a/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java b/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java index e61af44..1ec856b 100644 --- a/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java +++ b/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java @@ -1,80 +1,197 @@ package jalview.ws.datamodel.alphafold; +import java.awt.Color; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import jalview.analysis.AverageDistanceEngine; +import jalview.bin.Console; +import jalview.datamodel.BinaryNode; import jalview.datamodel.ContactListI; import jalview.datamodel.ContactListImpl; import jalview.datamodel.ContactListProviderI; import jalview.datamodel.ContactMatrixI; -import jalview.datamodel.ContactRange; import jalview.datamodel.SequenceI; +import jalview.util.MapUtils; public class PAEContactMatrix implements ContactMatrixI { - SequenceI refSeq=null; - int maxrow=0,maxcol=0; - int[] indices1,indices2; + SequenceI refSeq = null; + + /** + * the length that refSeq is expected to be (excluding gaps, of course) + */ + int length; + + int maxrow = 0, maxcol = 0; + + int[] indices1, indices2; + float[][] elements; + float maxscore; - - @SuppressWarnings("unchecked") - public PAEContactMatrix(SequenceI _refSeq, Map pae_obj) throws Exception + + private void setRefSeq(SequenceI _refSeq) { refSeq = _refSeq; - while (refSeq.getDatasetSequence()!=null) + while (refSeq.getDatasetSequence() != null) { - refSeq=refSeq.getDatasetSequence(); + refSeq = refSeq.getDatasetSequence(); } + length = _refSeq.getEnd() - _refSeq.getStart() + 1; + } + + @SuppressWarnings("unchecked") + public PAEContactMatrix(SequenceI _refSeq, Map pae_obj) + { + setRefSeq(_refSeq); // convert the lists to primitive arrays and store - int length = _refSeq.getEnd()-_refSeq.getStart()+1; - - // assume indices are with respect to range defined by _refSeq on the dataset refSeq - Iterator rows = ((List)pae_obj.get("residue1")).iterator(); - Iterator cols = ((List)pae_obj.get("residue2")).iterator(); - Iterator scores = ((List)pae_obj.get("distance")).iterator(); - - elements=new float[length][length]; - while (scores.hasNext()) { - float escore=scores.next().floatValue(); - int row=rows.next().intValue(); - int col=cols.next().intValue(); - if (maxrow maxcol) + { + maxcol = row.length; + } + maxscore = row[0]; + for (float f : row) + { + if (maxscore < f) + { + maxscore = f; + } + } + } + maxrow = matrix.length; + elements = matrix; + + } + + /** + * parse a sane JSON representation of the pAE + * + * @param pae_obj + */ + @SuppressWarnings("unchecked") + private void parse_version_2_pAE(Map pae_obj) + { + // this is never going to be reached by the integer rounding.. or is it ? + maxscore = ((Double) MapUtils.getFirst(pae_obj, + "max_predicted_aligned_error", "max_pae")).floatValue(); + List> scoreRows = ((List>) MapUtils + .getFirst(pae_obj, "predicted_aligned_error", "pae")); + elements = new float[scoreRows.size()][scoreRows.size()]; + int row = 0, col = 0; + for (List scoreRow : scoreRows) + { + Iterator scores = scoreRow.iterator(); + while (scores.hasNext()) + { + Object d = scores.next(); + if (d instanceof Double) + elements[row][col++] = ((Double) d).longValue(); + else + elements[row][col++] = (float) ((Long) d).longValue(); + } + row++; + col = 0; + } + maxcol = length; + maxrow = length; + } + + /** + * v1 format got ditched 28th July 2022 see + * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022 + * + * @param pae_obj + */ + @SuppressWarnings("unchecked") + private void parse_version_1_pAE(Map pae_obj) + { + // assume indices are with respect to range defined by _refSeq on the + // dataset refSeq + Iterator rows = ((List) pae_obj.get("residue1")).iterator(); + Iterator cols = ((List) pae_obj.get("residue2")).iterator(); + Iterator scores = ((List) pae_obj.get("distance")) + .iterator(); + // assume square matrix + elements = new float[length][length]; + while (scores.hasNext()) + { + float escore = scores.next().floatValue(); + int row = rows.next().intValue(); + int col = cols.next().intValue(); + if (maxrow < row) { - maxrow=row; + maxrow = row; } - if (maxcol= elements.length) + { + return null; + } + + return new ContactListImpl(new ContactListProviderI() { @Override + public int getPosition() + { + return _column; + } + + @Override public int getContactHeight() { - return maxcol-1; + return maxcol - 1; } - + @Override public double getContactAt(int column) { - if (column<0 || column>=elements[_column].length) + if (column < 0 || column >= elements[_column].length) { return -1; } - // TODO Auto-generated method stub return elements[_column][column]; } }); @@ -95,7 +212,7 @@ public class PAEContactMatrix implements ContactMatrixI @Override public boolean hasReferenceSeq() { - return (refSeq!=null); + return (refSeq != null); } @Override @@ -104,4 +221,179 @@ public class PAEContactMatrix implements ContactMatrixI return refSeq; } + @Override + public String getAnnotDescr() + { + return "Predicted Alignment Error" + + ((refSeq == null) ? "" : (" for " + refSeq.getName())); + } + + @Override + public String getAnnotLabel() + { + StringBuilder label = new StringBuilder("PAE Matrix"); + // if (this.getReferenceSeq() != null) + // { + // label.append(":").append(this.getReferenceSeq().getDisplayId(false)); + // } + return label.toString(); + } + + public static final String PAEMATRIX = "PAE_MATRIX"; + + @Override + public String getType() + { + return PAEMATRIX; + } + + @Override + public int getWidth() + { + return length; + } + + @Override + public int getHeight() + { + return length; + } + + List groups = null; + + @Override + public boolean hasGroups() + { + return groups != null; + } + + String newick = null; + + @Override + public String getNewick() + { + return newick; + } + + @Override + public boolean hasTree() + { + return newick != null && newick.length() > 0; + } + + boolean abs; + + double thresh; + + String treeType = null; + + public void makeGroups(float thresh, boolean abs) + { + AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, + this); + double height = clusterer.findHeight(clusterer.getTopNode()); + newick = new jalview.io.NewickFile(clusterer.getTopNode(), false, true) + .print(); + treeType = "UPGMA"; + Console.trace("Newick string\n" + newick); + + List nodegroups; + if (abs ? height > thresh : 0 < thresh && thresh < 1) + { + float cut = abs ? (float) (thresh / height) : thresh; + Console.debug("Threshold " + cut + " for height=" + height); + + nodegroups = clusterer.groupNodes(cut); + } + else + { + nodegroups = new ArrayList(); + nodegroups.add(clusterer.getTopNode()); + } + this.abs = abs; + this.thresh = thresh; + groups = new ArrayList<>(); + for (BinaryNode root : nodegroups) + { + BitSet gpset = new BitSet(); + for (BinaryNode leaf : clusterer.findLeaves(root)) + { + gpset.set((Integer) leaf.element()); + } + groups.add(gpset); + } + } + + @Override + public void updateGroups(List colGroups) + { + if (colGroups != null) + { + groups = colGroups; + } + } + + @Override + public BitSet getGroupsFor(int column) + { + for (BitSet gp : groups) + { + if (gp.get(column)) + { + return gp; + } + } + return ContactMatrixI.super.getGroupsFor(column); + } + + HashMap colorMap = new HashMap<>(); + + @Override + public Color getColourForGroup(BitSet bs) + { + if (bs == null) + { + return Color.white; + } + Color groupCol = colorMap.get(bs); + if (groupCol == null) + { + return Color.white; + } + return groupCol; + } + + @Override + public void setColorForGroup(BitSet bs, Color color) + { + colorMap.put(bs, color); + } + + public void restoreGroups(List newgroups, String treeMethod, + String tree, double thresh2) + { + treeType = treeMethod; + groups = newgroups; + thresh = thresh2; + newick = tree; + + } + + @Override + public boolean hasCutHeight() + { + return groups != null && thresh != 0; + } + + @Override + public double getCutHeight() + { + return thresh; + } + + @Override + public String getTreeMethod() + { + return treeType; + } }