1 package jalview.ws.datamodel.alphafold;
3 import java.util.ArrayList;
4 import java.util.BitSet;
5 import java.util.Iterator;
9 import jalview.analysis.AverageDistanceEngine;
10 import jalview.bin.Console;
11 import jalview.datamodel.BinaryNode;
12 import jalview.datamodel.ContactListI;
13 import jalview.datamodel.ContactListImpl;
14 import jalview.datamodel.ContactListProviderI;
15 import jalview.datamodel.ContactMatrixI;
16 import jalview.datamodel.SequenceI;
17 import jalview.util.MapUtils;
19 public class PAEContactMatrix implements ContactMatrixI
22 SequenceI refSeq = null;
25 * the length that refSeq is expected to be (excluding gaps, of course)
29 int maxrow = 0, maxcol = 0;
31 int[] indices1, indices2;
37 private void setRefSeq(SequenceI _refSeq)
40 while (refSeq.getDatasetSequence() != null)
42 refSeq = refSeq.getDatasetSequence();
44 length = _refSeq.getEnd() - _refSeq.getStart() + 1;
47 @SuppressWarnings("unchecked")
48 public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
51 // convert the lists to primitive arrays and store
53 if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
55 parse_version_1_pAE(pae_obj);
60 parse_version_2_pAE(pae_obj);
65 * construct a sequence associated PAE matrix directly from a float array
70 public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
74 for (float[] row : matrix)
76 if (row.length > maxcol)
89 maxrow = matrix.length;
95 * parse a sane JSON representation of the pAE
99 @SuppressWarnings("unchecked")
100 private void parse_version_2_pAE(Map<String, Object> pae_obj)
102 // this is never going to be reached by the integer rounding.. or is it ?
103 maxscore = ((Double) MapUtils.getFirst(pae_obj,
104 "max_predicted_aligned_error", "max_pae")).floatValue();
105 List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
106 .getFirst(pae_obj, "predicted_aligned_error", "pae"));
107 elements = new float[scoreRows.size()][scoreRows.size()];
108 int row = 0, col = 0;
109 for (List<Long> scoreRow : scoreRows)
111 Iterator<Long> scores = scoreRow.iterator();
112 while (scores.hasNext())
114 Object d = scores.next();
115 if (d instanceof Double)
116 elements[row][col++] = ((Double) d).floatValue(); // keep fractional PAE: longValue() truncated each Double score to a whole number
118 elements[row][col++] = (float) ((Long) d).longValue();
128 * v1 format got ditched 28th July 2022 see
129 * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022
133 @SuppressWarnings("unchecked")
134 private void parse_version_1_pAE(Map<String, Object> pae_obj)
136 // assume indices are with respect to range defined by _refSeq on the
138 Iterator<Long> rows = ((List<Long>) pae_obj.get("residue1")).iterator();
139 Iterator<Long> cols = ((List<Long>) pae_obj.get("residue2")).iterator();
140 Iterator<Double> scores = ((List<Double>) pae_obj.get("distance"))
142 // assume square matrix
143 elements = new float[length][length];
144 while (scores.hasNext())
146 float escore = scores.next().floatValue();
147 int row = rows.next().intValue();
148 int col = cols.next().intValue();
157 elements[row - 1][col - 1] = escore;
160 maxscore = ((Double) MapUtils.getFirst(pae_obj,
161 "max_predicted_aligned_error", "max_pae")).floatValue();
165 public ContactListI getContactList(final int _column)
167 if (_column < 0 || _column >= elements.length)
172 return new ContactListImpl(new ContactListProviderI()
175 public int getPosition()
181 public int getContactHeight()
187 public double getContactAt(int column)
189 if (column < 0 || column >= elements[_column].length)
193 return elements[_column][column];
199 public float getMin()
205 public float getMax()
211 public boolean hasReferenceSeq()
213 return (refSeq != null);
217 public SequenceI getReferenceSeq()
223 public String getAnnotDescr()
225 return "Predicted Alignment Error"+((refSeq==null) ? "" : (" for " + refSeq.getName()));
229 public String getAnnotLabel()
231 StringBuilder label = new StringBuilder("PAE Matrix");
232 //if (this.getReferenceSeq() != null)
234 // label.append(":").append(this.getReferenceSeq().getDisplayId(false));
236 return label.toString();
239 public static final String PAEMATRIX = "PAE_MATRIX";
242 public String getType()
248 public int getWidth()
254 public int getHeight()
258 List<BitSet> groups=null;
260 public boolean hasGroups()
266 public String getNewick()
272 String treeType=null;
273 public void makeGroups(float thresh,boolean abs)
275 AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this);
276 double height = clusterer.findHeight(clusterer.getTopNode());
277 newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print();
279 Console.trace("Newick string\n"+newick);
281 List<BinaryNode> nodegroups;
282 if (abs ? height > thresh : 0 < thresh && thresh < 1)
284 float cut = abs ? (float) (thresh / height) : thresh;
285 Console.debug("Threshold "+cut+" for height="+height);
287 nodegroups = clusterer.groupNodes(cut);
291 nodegroups = new ArrayList<BinaryNode>();
292 nodegroups.add(clusterer.getTopNode());
296 groups = new ArrayList<>();
297 for (BinaryNode root:nodegroups)
299 BitSet gpset=new BitSet();
300 for (BinaryNode leaf:clusterer.findLeaves(root))
302 gpset.set((Integer)leaf.element());
309 public BitSet getGroupsFor(int column)
311 for (BitSet gp:groups) {
317 return ContactMatrixI.super.getGroupsFor(column);
320 public void restoreGroups(List<BitSet> newgroups, String treeMethod,
321 String tree, double thresh2)
330 public boolean hasCutHeight() {
331 return groups!=null && thresh!=0;
334 public double getCutHeight()
339 public String getTreeMethod()