1 package jalview.ws.datamodel.alphafold;
4 import java.util.ArrayList;
5 import java.util.BitSet;
6 import java.util.HashMap;
7 import java.util.Iterator;
11 import jalview.analysis.AverageDistanceEngine;
12 import jalview.bin.Console;
13 import jalview.datamodel.BinaryNode;
14 import jalview.datamodel.ContactListI;
15 import jalview.datamodel.ContactListImpl;
16 import jalview.datamodel.ContactListProviderI;
17 import jalview.datamodel.ContactMatrixI;
18 import jalview.datamodel.SequenceI;
19 import jalview.util.MapUtils;
21 public class PAEContactMatrix implements ContactMatrixI
// Sequence this matrix is associated with; null when there is none
// (hasReferenceSeq() tests exactly this).
24 SequenceI refSeq = null;
27 * the length that refSeq is expected to be (excluding gaps, of course)
// Row and column extents recorded for the matrix; both start at 0.
31 int maxrow = 0, maxcol = 0;
// NOTE(review): indices1/indices2 are not referenced in the visible code - confirm they are still needed.
33 int[] indices1, indices2;
// Records the reference sequence and the length expected for it.
// NOTE(review): the initial assignment of refSeq is not visible here -
// presumably it is set from _refSeq before the loop; confirm.
39 private void setRefSeq(SequenceI _refSeq)
// follow getDatasetSequence() links until reaching a sequence that has none
42 while (refSeq.getDatasetSequence() != null)
44 refSeq = refSeq.getDatasetSequence();
// length comes from the start/end range of the sequence passed in (inclusive, hence + 1)
46 length = _refSeq.getEnd() - _refSeq.getStart() + 1;
// Builds a PAE matrix for _refSeq from a parsed JSON object (pae_obj).
// The layout is chosen by checking pae_obj for the keys
// "predicted_aligned_error" / "pae": when the check fails the version 1
// parser is used.
// NOTE(review): presumably parse_version_2_pAE is the else branch - confirm.
49 @SuppressWarnings("unchecked")
50 public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
53 // convert the lists to primitive arrays and store
// presumably containsAKey is true when either key is present - verify against MapUtils
55 if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
57 parse_version_1_pAE(pae_obj);
62 parse_version_2_pAE(pae_obj);
67 * construct a sequence associated PAE matrix directly from a float array
// Builds a PAE matrix for _refSeq from an existing array of score rows.
72 public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
// scan every row, looking for one longer than the current maxcol
76 for (float[] row : matrix)
78 if (row.length > maxcol)
// the row count is the length of the outer array
91 maxrow = matrix.length;
97 * parse a sane JSON representation of the pAE
// Parses the version 2 pAE layout from pae_obj.
// The maximum score is looked up with the keys "max_predicted_aligned_error" /
// "max_pae" and the score rows with "predicted_aligned_error" / "pae".
// Scores are copied row by row into a square float array whose side is the
// number of score rows.
101 @SuppressWarnings("unchecked")
102 private void parse_version_2_pAE(Map<String, Object> pae_obj)
104 // this is never going to be reached by the integer rounding.. or is it ?
// cast via Number: the parser can deliver a whole-number value as a Long
// rather than a Double (the element loop below already sees both types)
105 maxscore = ((Number) MapUtils.getFirst(pae_obj,
106 "max_predicted_aligned_error", "max_pae")).floatValue();
107 List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
108 .getFirst(pae_obj, "predicted_aligned_error", "pae"));
109 elements = new float[scoreRows.size()][scoreRows.size()];
110 int row = 0, col = 0;
111 for (List<Long> scoreRow : scoreRows)
113 Iterator<Long> scores = scoreRow.iterator();
114 while (scores.hasNext())
// values are declared as Long but may arrive as Double, so inspect each one
116 Object d = scores.next();
117 if (d instanceof Double)
// keep the fractional part: longValue() here would truncate e.g. 0.9 to 0
118 elements[row][col++] = ((Double) d).floatValue();
120 elements[row][col++] = (float) ((Long) d).longValue();
130 * v1 format got ditched 28th July 2022 see
131 * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022
// Parses the version 1 pAE layout, which supplies three parallel lists under
// the keys "residue1", "residue2" and "distance".
// Scores are written into a length x length float array.
135 @SuppressWarnings("unchecked")
136 private void parse_version_1_pAE(Map<String, Object> pae_obj)
138 // assume indices are with respect to range defined by _refSeq on the
140 Iterator<Long> rows = ((List<Long>) pae_obj.get("residue1")).iterator();
141 Iterator<Long> cols = ((List<Long>) pae_obj.get("residue2")).iterator();
142 Iterator<Double> scores = ((List<Double>) pae_obj.get("distance"))
144 // assume square matrix
145 elements = new float[length][length];
146 while (scores.hasNext())
// read one (row, col, score) triple by advancing the three iterators together
148 float escore = scores.next().floatValue();
149 int row = rows.next().intValue();
150 int col = cols.next().intValue();
// row and col are used as 1-based positions, hence the - 1 when indexing
159 elements[row - 1][col - 1] = escore;
// NOTE(review): the (Double) cast fails if the parser delivers a Long here - confirm the value is always fractional
162 maxscore = ((Double) MapUtils.getFirst(pae_obj,
163 "max_predicted_aligned_error", "max_pae")).floatValue();
// Returns a contact list backed by elements[_column].
// NOTE(review): the result for an out-of-range _column is not visible here - confirm.
167 public ContactListI getContactList(final int _column)
// reject indices that fall outside the elements array
169 if (_column < 0 || _column >= elements.length)
// wrap the selected row in an anonymous ContactListProviderI
174 return new ContactListImpl(new ContactListProviderI()
177 public int getPosition()
183 public int getContactHeight()
189 public double getContactAt(int column)
// guard against positions outside the selected row
191 if (column < 0 || column >= elements[_column].length)
195 return elements[_column][column];
201 public float getMin()
207 public float getMax()
// True when a reference sequence has been set (refSeq is not null).
213 public boolean hasReferenceSeq()
215 return (refSeq != null);
219 public SequenceI getReferenceSeq()
// Returns the annotation description: a fixed title, followed by
// " for <name>" when a reference sequence is set.
225 public String getAnnotDescr()
227 return "Predicted Alignment Error"
228 + ((refSeq == null) ? "" : (" for " + refSeq.getName()));
// Returns the annotation label. With the reference-sequence suffix commented
// out below, the label is always the fixed text "PAE Matrix".
232 public String getAnnotLabel()
234 StringBuilder label = new StringBuilder("PAE Matrix");
235 // if (this.getReferenceSeq() != null)
237 // label.append(":").append(this.getReferenceSeq().getDisplayId(false));
239 return label.toString();
242 public static final String PAEMATRIX = "PAE_MATRIX";
245 public String getType()
251 public int getWidth()
257 public int getHeight()
// Column groups, one BitSet per group; null until groups have been assigned.
262 List<BitSet> groups = null;
// True once a group list has been assigned (an empty list still counts).
265 public boolean hasGroups()
267 return groups != null;
// Newick string for the clustering tree; null until one has been stored
// (makeGroups assigns it).
270 String newick = null;
273 public String getNewick()
// True only when a non-empty Newick string is held.
279 public boolean hasTree()
281 return newick != null && newick.length() > 0;
288 String treeType = null;
// Clusters the matrix with AverageDistanceEngine, stores the tree as a Newick
// string and rebuilds groups as one BitSet per cluster.
// thresh is the cut threshold. When abs is true it is treated as a height and
// divided by the tree height before cutting; otherwise it is used directly and
// must lie strictly between 0 and 1.
290 public void makeGroups(float thresh, boolean abs)
292 AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null,
294 double height = clusterer.findHeight(clusterer.getTopNode());
295 newick = new jalview.io.NewickFile(clusterer.getTopNode(), false, true)
298 Console.trace("Newick string\n" + newick);
300 List<BinaryNode> nodegroups;
// only cut the tree when the threshold is usable: below the tree height in
// absolute mode, or inside the open interval (0, 1) otherwise
301 if (abs ? height > thresh : 0 < thresh && thresh < 1)
303 float cut = abs ? (float) (thresh / height) : thresh;
304 Console.debug("Threshold " + cut + " for height=" + height);
306 nodegroups = clusterer.groupNodes(cut);
// presumably the fallback branch: a single group rooted at the top node - confirm
310 nodegroups = new ArrayList<BinaryNode>();
311 nodegroups.add(clusterer.getTopNode());
// remember the threshold as it was passed in, not the derived cut
314 this.thresh = thresh;
315 groups = new ArrayList<>();
316 for (BinaryNode root : nodegroups)
318 BitSet gpset = new BitSet();
319 for (BinaryNode leaf : clusterer.findLeaves(root))
// each leaf element is cast to Integer and used as the bit index to set
321 gpset.set((Integer) leaf.element());
// Accepts a replacement list of column groups; a null argument is checked for.
// NOTE(review): what is done with a non-null list is not visible here - confirm.
328 public void updateGroups(List<BitSet> colGroups)
330 if (colGroups != null)
// Iterates over groups for the given column and, when the loop does not
// return, defers to the default implementation in ContactMatrixI.
// NOTE(review): the per-group test inside the loop is not visible here.
337 public BitSet getGroupsFor(int column)
339 for (BitSet gp : groups)
346 return ContactMatrixI.super.getGroupsFor(column);
// Colours assigned to groups, keyed by the group's BitSet.
349 HashMap<BitSet, Color> colorMap = new HashMap<>();
// Looks up the colour stored for bs in colorMap.
// NOTE(review): the handling of a missing entry is not visible here - confirm.
352 public Color getColourForGroup(BitSet bs)
358 Color groupCol = colorMap.get(bs);
359 if (groupCol == null)
// Stores color as the colour for group bs, replacing any earlier entry in colorMap.
367 public void setColorForGroup(BitSet bs, Color color)
369 colorMap.put(bs, color);
// Restores previously computed grouping state; treeMethod is recorded as treeType.
// NOTE(review): the handling of newgroups, tree and thresh2 is not visible here - confirm.
372 public void restoreGroups(List<BitSet> newgroups, String treeMethod,
373 String tree, double thresh2)
375 treeType = treeMethod;
// True when groups have been assigned and the stored threshold is non-zero.
383 public boolean hasCutHeight()
385 return groups != null && thresh != 0;
389 public double getCutHeight()
395 public String getTreeMethod()