package jalview.ws.datamodel.alphafold; import java.awt.Color; import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.util.ArrayList; import java.util.BitSet; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.json.simple.JSONObject; import jalview.analysis.AverageDistanceEngine; import jalview.bin.Console; import jalview.datamodel.Annotation; import jalview.datamodel.BinaryNode; import jalview.datamodel.ContactListI; import jalview.datamodel.ContactListImpl; import jalview.datamodel.ContactListProviderI; import jalview.datamodel.ContactMatrixI; import jalview.datamodel.Mapping; import jalview.datamodel.SequenceDummy; import jalview.datamodel.SequenceI; import jalview.io.DataSourceType; import jalview.io.FileFormatException; import jalview.io.FileParse; import jalview.util.MapList; import jalview.util.MapUtils; import jalview.ws.dbsources.EBIAlfaFold; public class PAEContactMatrix extends MappableContactMatrix implements ContactMatrixI { int maxrow = 0, maxcol = 0; float[][] elements; float maxscore; @SuppressWarnings("unchecked") public PAEContactMatrix(SequenceI _refSeq, Map pae_obj) throws FileFormatException { setRefSeq(_refSeq); // convert the lists to primitive arrays and store if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae")) { parse_version_1_pAE(pae_obj); return; } else { parse_version_2_pAE(pae_obj); } } /** * construct a sequence associated PAE matrix directly from a float array * * @param _refSeq * @param matrix */ public PAEContactMatrix(SequenceI _refSeq, float[][] matrix) { setRefSeq(_refSeq); maxcol = 0; for (float[] row : matrix) { if (row.length > maxcol) { maxcol = row.length; } maxscore = row[0]; for (float f : row) { if (maxscore < f) { maxscore = f; } } } maxrow = matrix.length; elements = matrix; } /** * new matrix with specific mapping to a reference sequence * @param newRefSeq * @param newFromMapList * @param elements2 */ public PAEContactMatrix(SequenceI newRefSeq, MapList newFromMapList, float[][] elements2) { this(newRefSeq,elements2); toSeq = newFromMapList; } /** * parse a sane JSON representation of the pAE * * @param pae_obj */ @SuppressWarnings("unchecked") private void parse_version_2_pAE(Map pae_obj) { maxscore = -1; // look for a maxscore element - if there is one... try { // this is never going to be reached by the integer rounding.. or is it ? maxscore = ((Double) MapUtils.getFirst(pae_obj, "max_predicted_aligned_error", "max_pae")).floatValue(); } catch (Throwable t) { // ignore if a key is not found. } List> scoreRows = ((List>) MapUtils .getFirst(pae_obj, "predicted_aligned_error", "pae")); elements = new float[scoreRows.size()][scoreRows.size()]; int row = 0, col = 0; for (List scoreRow : scoreRows) { Iterator scores = scoreRow.iterator(); while (scores.hasNext()) { Object d = scores.next(); if (d instanceof Double) { elements[row][col++] = ((Double) d).longValue(); } else { elements[row][col++] = (float) ((Long) d).longValue(); } if (maxscore < elements[row][col - 1]) { maxscore = elements[row][col - 1]; } } row++; col = 0; } maxcol = length; maxrow = length; } /** * v1 format got ditched 28th July 2022 see * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022 * * @param pae_obj */ @SuppressWarnings("unchecked") private void parse_version_1_pAE(Map pae_obj) { // assume indices are with respect to range defined by _refSeq on the // dataset refSeq Iterator rows = ((List) pae_obj.get("residue1")).iterator(); Iterator cols = ((List) pae_obj.get("residue2")).iterator(); // two pass - to allocate the elements array while (rows.hasNext()) { int row = rows.next().intValue(); int col = cols.next().intValue(); if (maxrow < row) { maxrow = row; } if (maxcol < col) { maxcol = col; } } rows = ((List) pae_obj.get("residue1")).iterator(); cols = ((List) pae_obj.get("residue2")).iterator(); Iterator scores = ((List) pae_obj.get("distance")) .iterator(); elements = new float[maxrow][maxcol]; while (scores.hasNext()) { float escore = scores.next().floatValue(); int row = rows.next().intValue(); int col = cols.next().intValue(); if (maxrow < row) { maxrow = row; } if (maxcol < col) { maxcol = col; } elements[row - 1][col - 1] = escore; } maxscore = ((Double) MapUtils.getFirst(pae_obj, "max_predicted_aligned_error", "max_pae")).floatValue(); } @Override public ContactListI getContactList(final int column) { // final int _column; // if (toSeq != null) // { // int[] word = toSeq.locateInTo(column, column); // if (word == null) // { // return null; // } // _column = word[0]; // } // else // { // _column = column; // } if (column < 0 || column >= elements.length) { return null; } return new ContactListImpl(new ContactListProviderI() { @Override public int getPosition() { return column; } @Override public int getContactHeight() { return maxcol - 1; } @Override public double getContactAt(int mcolumn) { if (mcolumn < 0 || mcolumn >= elements[column].length) { return -1; } return elements[column][mcolumn]; } }); } @Override protected double getElementAt(int _column, int i) { return elements[_column][i]; } @Override public float getMin() { return 0; } @Override public float getMax() { return maxscore; } @Override public String getAnnotDescr() { return "Predicted Alignment Error"+((refSeq==null) ? "" : (" for " + refSeq.getName())); } @Override public String getAnnotLabel() { StringBuilder label = new StringBuilder("PAE Matrix"); //if (this.getReferenceSeq() != null) //{ // label.append(":").append(this.getReferenceSeq().getDisplayId(false)); //} return label.toString(); } public static final String PAEMATRIX = "PAE_MATRIX"; @Override public String getType() { return PAEMATRIX; } @Override public int getWidth() { return length; } @Override public int getHeight() { return length; } List groups=null; @Override public boolean hasGroups() { return groups!=null; } String newick=null; @Override public String getNewick() { return newick; } @Override public boolean hasTree() { return newick!=null && newick.length()>0; } boolean abs; double thresh; String treeType=null; public void makeGroups(float thresh,boolean abs) { AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this); double height = clusterer.findHeight(clusterer.getTopNode()); newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print(); treeType = "UPGMA"; Console.trace("Newick string\n"+newick); List nodegroups; if (abs ? height > thresh : 0 < thresh && thresh < 1) { float cut = abs ? (float) (thresh / height) : thresh; Console.debug("Threshold "+cut+" for height="+height); nodegroups = clusterer.groupNodes(cut); } else { nodegroups = new ArrayList(); nodegroups.add(clusterer.getTopNode()); } this.abs=abs; this.thresh=thresh; groups = new ArrayList<>(); for (BinaryNode root:nodegroups) { BitSet gpset=new BitSet(); for (BinaryNode leaf:clusterer.findLeaves(root)) { gpset.set((Integer)leaf.element()); } groups.add(gpset); } } @Override public void updateGroups(List colGroups) { if (colGroups!=null) { groups=colGroups; } } @Override public BitSet getGroupsFor(int column) { if (groups != null) { for (BitSet gp : groups) { if (gp.get(column)) { return gp; } } } return ContactMatrixI.super.getGroupsFor(column); } HashMap colorMap = new HashMap<>(); @Override public Color getColourForGroup(BitSet bs) { if (bs==null) { return Color.white; } Color groupCol=colorMap.get(bs); if (groupCol==null) { return Color.white; } return groupCol; } @Override public void setColorForGroup(BitSet bs,Color color) { colorMap.put(bs,color); } public void restoreGroups(List newgroups, String treeMethod, String tree, double thresh2) { treeType=treeMethod; groups = newgroups; thresh=thresh2; newick =tree; } @Override public boolean hasCutHeight() { return groups!=null && thresh!=0; } @Override public double getCutHeight() { return thresh; } @Override public String getTreeMethod() { return treeType; } public static void validateContactMatrixFile(String fileName) throws FileFormatException,IOException { FileInputStream infile=null; try { infile = new FileInputStream(new File(fileName)); } catch (Throwable t) { new IOException("Couldn't open "+fileName,t); } JSONObject paeDict=null; try { paeDict = EBIAlfaFold.parseJSONtoPAEContactMatrix(infile); } catch (Throwable t) { new FileFormatException("Couldn't parse "+fileName+" as a JSON dict or array containing a dict"); } PAEContactMatrix matrix = new PAEContactMatrix(new SequenceDummy("Predicted"), (Map)paeDict); if (matrix.getWidth()<=0) { throw new FileFormatException("No data in PAE matrix read from '"+fileName+"'"); } } @Override protected PAEContactMatrix newMappableContactMatrix( SequenceI newRefSeq, MapList newFromMapList) { return new PAEContactMatrix(newRefSeq, newFromMapList, elements); } }