1 package jalview.ws.datamodel.alphafold;
3 import java.util.ArrayList;
4 import java.util.BitSet;
5 import java.util.Iterator;
9 import jalview.analysis.AverageDistanceEngine;
10 import jalview.bin.Console;
11 import jalview.datamodel.BinaryNode;
12 import jalview.datamodel.ContactListI;
13 import jalview.datamodel.ContactListImpl;
14 import jalview.datamodel.ContactListProviderI;
15 import jalview.datamodel.ContactMatrixI;
16 import jalview.datamodel.SequenceI;
17 import jalview.util.MapUtils;
19 public class PAEContactMatrix implements ContactMatrixI
22 SequenceI refSeq = null;
25 * the length that refSeq is expected to be (excluding gaps, of course)
29 int maxrow = 0, maxcol = 0;
31 int[] indices1, indices2;
37 private void setRefSeq(SequenceI _refSeq)
40 while (refSeq.getDatasetSequence() != null)
42 refSeq = refSeq.getDatasetSequence();
44 length = _refSeq.getEnd() - _refSeq.getStart() + 1;
46 @SuppressWarnings("unchecked")
47 public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
50 // convert the lists to primitive arrays and store
52 if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
54 parse_version_1_pAE(pae_obj);
59 parse_version_2_pAE(pae_obj);
63 * Construct a sequence-associated PAE matrix directly from a float array
67 public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
71 for (float[] row:matrix)
73 if (row.length>maxcol)
91 * Parse the version 2 JSON representation of the pAE, which supplies the matrix as a list of rows
95 @SuppressWarnings("unchecked")
96 private void parse_version_2_pAE(Map<String, Object> pae_obj)
98 // NOTE: it is unclear whether this maximum is ever actually reached, given that the matrix values are rounded to integers
99 maxscore = ((Double) MapUtils.getFirst(pae_obj,
100 "max_predicted_aligned_error", "max_pae")).floatValue();
101 List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
102 .getFirst(pae_obj, "predicted_aligned_error", "pae"))
104 elements = new float[scoreRows.size()][scoreRows.size()];
105 int row = 0, col = 0;
106 for (List<Long> scoreRow:scoreRows)
108 Iterator<Long> scores = scoreRow.iterator();
109 while (scores.hasNext())
111 Object d = scores.next();
112 if (d instanceof Double)
113 elements[row][col++] = ((Double) d).longValue();
115 elements[row][col++] = (float) ((Long)d).longValue();
125 * The v1 format was retired on 28th July 2022; see
126 * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022
130 @SuppressWarnings("unchecked")
131 private void parse_version_1_pAE(Map<String, Object> pae_obj)
133 // assume indices are with respect to range defined by _refSeq on the
135 Iterator<Long> rows = ((List<Long>) pae_obj.get("residue1")).iterator();
136 Iterator<Long> cols = ((List<Long>) pae_obj.get("residue2")).iterator();
137 Iterator<Double> scores = ((List<Double>) pae_obj.get("distance"))
139 // assume square matrix
140 elements = new float[length][length];
141 while (scores.hasNext())
143 float escore = scores.next().floatValue();
144 int row = rows.next().intValue();
145 int col = cols.next().intValue();
154 elements[row - 1][col - 1] = escore;
157 maxscore = ((Double) MapUtils.getFirst(pae_obj,
158 "max_predicted_aligned_error", "max_pae")).floatValue();
162 public ContactListI getContactList(final int _column)
164 if (_column < 0 || _column >= elements.length)
169 return new ContactListImpl(new ContactListProviderI()
172 public int getPosition()
178 public int getContactHeight()
184 public double getContactAt(int column)
186 if (column < 0 || column >= elements[_column].length)
190 return elements[_column][column];
196 public float getMin()
202 public float getMax()
208 public boolean hasReferenceSeq()
210 return (refSeq != null);
214 public SequenceI getReferenceSeq()
220 public String getAnnotDescr()
222 return "Predicted Alignment Error for " + refSeq.getName();
226 public String getAnnotLabel()
231 public static final String PAEMATRIX="PAE_MATRIX";
233 public String getType()
238 public int getWidth()
243 public int getHeight()
247 List<BitSet> groups=null;
249 public boolean hasGroups()
254 public String getNewickString()
258 public void makeGroups(float thresh,boolean abs)
260 AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this);
261 double height = clusterer.findHeight(clusterer.getTopNode());
262 newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print();
264 Console.trace("Newick string\n"+newick);
266 List<BinaryNode> nodegroups;
267 if (abs ? height > thresh : 0 < thresh && thresh < 1)
269 float cut = abs ? (float) (thresh / height) : thresh;
270 Console.debug("Threshold "+cut+" for height="+height);
272 nodegroups = clusterer.groupNodes(cut);
276 nodegroups = new ArrayList<BinaryNode>();
277 nodegroups.add(clusterer.getTopNode());
280 groups = new ArrayList<>();
281 for (BinaryNode root:nodegroups)
283 BitSet gpset=new BitSet();
284 for (BinaryNode leaf:clusterer.findLeaves(root))
286 gpset.set((Integer)leaf.element());
293 public BitSet getGroupsFor(int column)
295 for (BitSet gp:groups) {
301 return ContactMatrixI.super.getGroupsFor(column);