X-Git-Url: http://source.jalview.org/gitweb/?a=blobdiff_plain;ds=sidebyside;f=src%2Fjalview%2Fws%2Fdatamodel%2Falphafold%2FPAEContactMatrix.java;h=397a84b2e933e79e42287950c4208c22f28c2fc3;hb=eb3e681d6e82ccdd5d312d1981dfb306e7f479f0;hp=30c77d287ac150b74eb65fdbda054cd6c8f68427;hpb=7420ce36f2b43280ef610e3743960207e4c2dbe3;p=jalview.git diff --git a/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java b/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java index 30c77d2..397a84b 100644 --- a/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java +++ b/src/jalview/ws/datamodel/alphafold/PAEContactMatrix.java @@ -1,54 +1,53 @@ package jalview.ws.datamodel.alphafold; -import java.util.ArrayList; -import java.util.BitSet; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; import java.util.Iterator; import java.util.List; import java.util.Map; -import jalview.analysis.AverageDistanceEngine; -import jalview.bin.Console; -import jalview.datamodel.BinaryNode; +import org.json.simple.JSONObject; + import jalview.datamodel.ContactListI; import jalview.datamodel.ContactListImpl; import jalview.datamodel.ContactListProviderI; import jalview.datamodel.ContactMatrixI; +import jalview.datamodel.FloatContactMatrix; +import jalview.datamodel.GroupSet; +import jalview.datamodel.SequenceDummy; import jalview.datamodel.SequenceI; +import jalview.io.FileFormatException; +import jalview.util.MapList; import jalview.util.MapUtils; +import jalview.ws.dbsources.EBIAlfaFold; -public class PAEContactMatrix implements ContactMatrixI +/** + * routines and class for holding predicted alignment error matrices as produced + * by alphafold et al. + * + * getContactList(column) returns the vector of predicted alignment errors for + * reference position given by column getElementAt(column, i) returns the + * predicted superposition error for the ith position when column is used as + * reference + * + * Many thanks to Ora Schueler Furman for noticing that earlier development + * versions did not show the PAE oriented correctly + * + * @author jprocter + * + */ +public class PAEContactMatrix extends + MappableContactMatrix implements ContactMatrixI { - SequenceI refSeq = null; - - /** - * the length that refSeq is expected to be (excluding gaps, of course) - */ - int length; - - int maxrow = 0, maxcol = 0; - - int[] indices1, indices2; - - float[][] elements; - - float maxscore; - - private void setRefSeq(SequenceI _refSeq) - { - refSeq = _refSeq; - while (refSeq.getDatasetSequence() != null) - { - refSeq = refSeq.getDatasetSequence(); - } - length = _refSeq.getEnd() - _refSeq.getStart() + 1; - } @SuppressWarnings("unchecked") public PAEContactMatrix(SequenceI _refSeq, Map pae_obj) + throws FileFormatException { setRefSeq(_refSeq); // convert the lists to primitive arrays and store - + if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae")) { parse_version_1_pAE(pae_obj); @@ -59,66 +58,104 @@ public class PAEContactMatrix implements ContactMatrixI parse_version_2_pAE(pae_obj); } } + /** * construct a sequence associated PAE matrix directly from a float array + * * @param _refSeq * @param matrix */ public PAEContactMatrix(SequenceI _refSeq, float[][] matrix) { + mappedMatrix = new FloatContactMatrix(matrix); setRefSeq(_refSeq); - maxcol=0; - for (float[] row:matrix) - { - if (row.length>maxcol) - { - maxcol=row.length; - } - maxscore=row[0]; - for (float f:row) - { - if (maxscore pae_obj) { - // this is never going to be reached by the integer rounding.. or is it ? - maxscore = ((Double) MapUtils.getFirst(pae_obj, - "max_predicted_aligned_error", "max_pae")).floatValue(); + float maxscore = -1; + // look for a maxscore element - if there is one... + try + { + // this is never going to be reached by the integer rounding.. or is it ? + maxscore = ((Double) MapUtils.getFirst(pae_obj, + "max_predicted_aligned_error", "max_pae")).floatValue(); + } catch (Throwable t) + { + // ignore if a key is not found. + } List> scoreRows = ((List>) MapUtils - .getFirst(pae_obj, "predicted_aligned_error", "pae")) - ; - elements = new float[scoreRows.size()][scoreRows.size()]; + .getFirst(pae_obj, "predicted_aligned_error", "pae")); + float[][] elements = new float[scoreRows.size()][scoreRows.size()]; int row = 0, col = 0; - for (List scoreRow:scoreRows) + for (List scoreRow : scoreRows) { Iterator scores = scoreRow.iterator(); while (scores.hasNext()) { Object d = scores.next(); if (d instanceof Double) - elements[row][col++] = ((Double) d).longValue(); + { + elements[col][row] = ((Double) d).longValue(); + } else - elements[row][col++] = (float) ((Long)d).longValue(); + { + elements[col][row] = (float) ((Long) d).longValue(); + } + + if (maxscore < elements[col][row]) + { + maxscore = elements[col][row]; + } + col++; } row++; col = 0; } - maxcol = length; - maxrow = length; + mappedMatrix = new FloatContactMatrix(elements); } /** @@ -134,13 +171,11 @@ public class PAEContactMatrix implements ContactMatrixI // dataset refSeq Iterator rows = ((List) pae_obj.get("residue1")).iterator(); Iterator cols = ((List) pae_obj.get("residue2")).iterator(); - Iterator scores = ((List) pae_obj.get("distance")) - .iterator(); - // assume square matrix - elements = new float[length][length]; - while (scores.hasNext()) + // two pass - to allocate the elements array + + int maxrow = -1, maxcol = -1; + while (rows.hasNext()) { - float escore = scores.next().floatValue(); int row = rows.next().intValue(); int col = cols.next().intValue(); if (maxrow < row) @@ -151,153 +186,97 @@ public class PAEContactMatrix implements ContactMatrixI { maxcol = col; } - elements[row - 1][col - 1] = escore; - } - - maxscore = ((Double) MapUtils.getFirst(pae_obj, - "max_predicted_aligned_error", "max_pae")).floatValue(); - } - @Override - public ContactListI getContactList(final int _column) - { - if (_column < 0 || _column >= elements.length) - { - return null; } - - return new ContactListImpl(new ContactListProviderI() + rows = ((List) pae_obj.get("residue1")).iterator(); + cols = ((List) pae_obj.get("residue2")).iterator(); + Iterator scores = ((List) pae_obj.get("distance")) + .iterator(); + float[][] elements = new float[maxcol][maxrow]; + while (scores.hasNext()) { - @Override - public int getPosition() - { - return _column; - } - - @Override - public int getContactHeight() + float escore = scores.next().floatValue(); + int row = rows.next().intValue(); + int col = cols.next().intValue(); + if (maxrow < row) { - return maxcol-1; + maxrow = row; } - - @Override - public double getContactAt(int column) + if (maxcol < col) { - if (column < 0 || column >= elements[_column].length) - { - return -1; - } - return elements[_column][column]; + maxcol = col; } - }); - } - - @Override - public float getMin() - { - return 0; - } - - @Override - public float getMax() - { - return maxscore; - } - - @Override - public boolean hasReferenceSeq() - { - return (refSeq != null); - } + elements[col - 1][row - 1] = escore; + } - @Override - public SequenceI getReferenceSeq() - { - return refSeq; + mappedMatrix = new FloatContactMatrix(elements); } @Override public String getAnnotDescr() { - return "Predicted Alignment Error for " + refSeq.getName(); + return "Predicted Alignment Error" + + ((refSeq == null) ? "" : (" for " + refSeq.getName())); } @Override public String getAnnotLabel() { - return "pAE Matrix"; + StringBuilder label = new StringBuilder("PAE Matrix"); + // if (this.getReferenceSeq() != null) + // { + // label.append(":").append(this.getReferenceSeq().getDisplayId(false)); + // } + return label.toString(); } - public static final String PAEMATRIX="PAE_MATRIX"; + public static final String PAEMATRIX = "PAE_MATRIX"; + @Override public String getType() { return PAEMATRIX; } - @Override - public int getWidth() - { - return length; - } - @Override - public int getHeight() - { - return length; - } - List groups=null; - @Override - public boolean hasGroups() - { - return groups!=null; - } - String newick=null; - public String getNewickString() - { - return newick; - } - public void makeGroups(float thresh,boolean abs) - { - AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this); - double height = clusterer.findHeight(clusterer.getTopNode()); - newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print(); - - Console.trace("Newick string\n"+newick); - List nodegroups; - if (abs ? height > thresh : 0 < thresh && thresh < 1) + public static void validateContactMatrixFile(String fileName) + throws FileFormatException, IOException + { + FileInputStream infile = null; + try { - float cut = abs ? (float) (thresh / height) : thresh; - Console.debug("Threshold "+cut+" for height="+height); - - nodegroups = clusterer.groupNodes(cut); + infile = new FileInputStream(new File(fileName)); + } catch (Throwable t) + { + new IOException("Couldn't open " + fileName, t); } - else + JSONObject paeDict = null; + try { - nodegroups = new ArrayList(); - nodegroups.add(clusterer.getTopNode()); + paeDict = EBIAlfaFold.parseJSONtoPAEContactMatrix(infile); + } catch (Throwable t) + { + new FileFormatException("Couldn't parse " + fileName + + " as a JSON dict or array containing a dict"); } - groups = new ArrayList<>(); - for (BinaryNode root:nodegroups) + PAEContactMatrix matrix = new PAEContactMatrix( + new SequenceDummy("Predicted"), (Map) paeDict); + if (matrix.getWidth() <= 0) { - BitSet gpset=new BitSet(); - for (BinaryNode leaf:clusterer.findLeaves(root)) - { - gpset.set((Integer)leaf.element()); - } - groups.add(gpset); + throw new FileFormatException( + "No data in PAE matrix read from '" + fileName + "'"); } } - + @Override - public BitSet getGroupsFor(int column) + public boolean equals(Object obj) { - for (BitSet gp:groups) { - if (gp.get(column)) - { - return gp; - } - } - return ContactMatrixI.super.getGroupsFor(column); + return super.equals(obj); + } + + @Override + public int hashCode() + { + return super.hashCode(); } }