Merge branch 'develop' into spike/JAL-4047/JAL-4048_columns_in_sequenceID
[jalview.git] / src / jalview / ws / datamodel / alphafold / PAEContactMatrix.java
1 package jalview.ws.datamodel.alphafold;
2
3 import java.io.File;
4 import java.io.FileInputStream;
5 import java.io.IOException;
6 import java.util.Iterator;
7 import java.util.List;
8 import java.util.Map;
9
10 import org.json.simple.JSONObject;
11
12 import jalview.datamodel.ContactListI;
13 import jalview.datamodel.ContactListImpl;
14 import jalview.datamodel.ContactListProviderI;
15 import jalview.datamodel.ContactMatrixI;
16 import jalview.datamodel.GroupSet;
17 import jalview.datamodel.SequenceDummy;
18 import jalview.datamodel.SequenceI;
19 import jalview.io.FileFormatException;
20 import jalview.util.MapList;
21 import jalview.util.MapUtils;
22 import jalview.ws.dbsources.EBIAlfaFold;
23
24 /**
25  * routines and class for holding predicted alignment error matrices as produced
26  * by alphafold et al.
27  * 
28  * getContactList(column) returns the vector of predicted alignment errors for
29  * reference position given by column getElementAt(column, i) returns the
30  * predicted superposition error for the ith position when column is used as
31  * reference
32  * 
33  * Many thanks to Ora Schueler Furman for noticing that earlier development
34  * versions did not show the PAE oriented correctly
35  *
36  * @author jprocter
37  *
38  */
39 public class PAEContactMatrix extends
40         MappableContactMatrix<PAEContactMatrix> implements ContactMatrixI
41 {
42
43
44   int maxrow = 0, maxcol = 0;
45
46
47   float[][] elements;
48
49   float maxscore;
50
51
52   @SuppressWarnings("unchecked")
53   public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
54           throws FileFormatException
55   {
56     setRefSeq(_refSeq);
57     // convert the lists to primitive arrays and store
58
59     if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
60     {
61       parse_version_1_pAE(pae_obj);
62       return;
63     }
64     else
65     {
66       parse_version_2_pAE(pae_obj);
67     }
68   }
69
70   /**
71    * construct a sequence associated PAE matrix directly from a float array
72    * 
73    * @param _refSeq
74    * @param matrix
75    */
76   public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
77   {
78     setRefSeq(_refSeq);
79     maxcol = 0;
80     for (float[] row : matrix)
81     {
82       if (row.length > maxcol)
83       {
84         maxcol = row.length;
85       }
86       maxscore = row[0];
87       for (float f : row)
88       {
89         if (maxscore < f)
90         {
91           maxscore = f;
92         }
93       }
94     }
95     maxrow = matrix.length;
96     elements = matrix;
97
98   }
99
100   /**
101    * new matrix with specific mapping to a reference sequence
102    * 
103    * @param newRefSeq
104    * @param newFromMapList
105    * @param elements2
106    * @param grps2
107    */
108   public PAEContactMatrix(SequenceI newRefSeq, MapList newFromMapList,
109           float[][] elements2, GroupSet grps2)
110   {
111     this(newRefSeq, elements2);
112     toSeq = newFromMapList;
113     grps = grps2;
114   }
115
116   /**
117    * parse a sane JSON representation of the pAE
118    * 
119    * @param pae_obj
120    */
121   @SuppressWarnings("unchecked")
122   private void parse_version_2_pAE(Map<String, Object> pae_obj)
123   {
124     maxscore = -1;
125     // look for a maxscore element - if there is one...
126     try
127     {
128       // this is never going to be reached by the integer rounding.. or is it ?
129       maxscore = ((Double) MapUtils.getFirst(pae_obj,
130               "max_predicted_aligned_error", "max_pae")).floatValue();
131     } catch (Throwable t)
132     {
133       // ignore if a key is not found.
134     }
135     List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
136             .getFirst(pae_obj, "predicted_aligned_error", "pae"));
137     elements = new float[scoreRows.size()][scoreRows.size()];
138     int row = 0, col = 0;
139     for (List<Long> scoreRow : scoreRows)
140     {
141       Iterator<Long> scores = scoreRow.iterator();
142       while (scores.hasNext())
143       {
144         Object d = scores.next();
145         if (d instanceof Double)
146         {
147           elements[col][row] = ((Double) d).longValue();
148         }
149         else
150         {
151           elements[col][row] = (float) ((Long) d).longValue();
152         }
153
154         if (maxscore < elements[col][row])
155         {
156           maxscore = elements[col][row];
157         }
158         col++;
159       }
160       row++;
161       col = 0;
162     }
163     maxcol = length;
164     maxrow = length;
165   }
166
167   /**
168    * v1 format got ditched 28th July 2022 see
169    * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022
170    * 
171    * @param pae_obj
172    */
173   @SuppressWarnings("unchecked")
174   private void parse_version_1_pAE(Map<String, Object> pae_obj)
175   {
176     // assume indices are with respect to range defined by _refSeq on the
177     // dataset refSeq
178     Iterator<Long> rows = ((List<Long>) pae_obj.get("residue1")).iterator();
179     Iterator<Long> cols = ((List<Long>) pae_obj.get("residue2")).iterator();
180     // two pass - to allocate the elements array
181     while (rows.hasNext())
182     {
183       int row = rows.next().intValue();
184       int col = cols.next().intValue();
185       if (maxrow < row)
186       {
187         maxrow = row;
188       }
189       if (maxcol < col)
190       {
191         maxcol = col;
192       }
193
194     }
195     rows = ((List<Long>) pae_obj.get("residue1")).iterator();
196     cols = ((List<Long>) pae_obj.get("residue2")).iterator();
197     Iterator<Double> scores = ((List<Double>) pae_obj.get("distance"))
198             .iterator();
199     elements = new float[maxcol][maxrow];
200     while (scores.hasNext())
201     {
202       float escore = scores.next().floatValue();
203       int row = rows.next().intValue();
204       int col = cols.next().intValue();
205       if (maxrow < row)
206       {
207         maxrow = row;
208       }
209       if (maxcol < col)
210       {
211         maxcol = col;
212       }
213       elements[col - 1][row-1] = escore;
214     }
215
216     maxscore = ((Double) MapUtils.getFirst(pae_obj,
217             "max_predicted_aligned_error", "max_pae")).floatValue();
218   }
219
220   /**
221    * getContactList(column) @returns the vector of predicted alignment errors
222    * for reference position given by column
223    */
224   @Override
225   public ContactListI getContactList(final int column)
226   {
227     if (column < 0 || column >= elements.length)
228     {
229       return null;
230     }
231
232     return new ContactListImpl(new ContactListProviderI()
233     {
234       @Override
235       public int getPosition()
236       {
237         return column;
238       }
239
240       @Override
241       public int getContactHeight()
242       {
243         return maxcol - 1;
244       }
245
246       @Override
247       public double getContactAt(int mcolumn)
248       {
249         if (mcolumn < 0 || mcolumn >= elements[column].length)
250         {
251           return -1;
252         }
253         return elements[column][mcolumn];
254       }
255     });
256   }
257
258   /**
259    * getElementAt(column, i) @returns the predicted superposition error for the
260    * ith position when column is used as reference
261    */
262   @Override
263   protected double getElementAt(int _column, int i)
264   {
265     return elements[_column][i];
266   }
267
268   @Override
269   public float getMin()
270   {
271     return 0;
272   }
273
274   @Override
275   public float getMax()
276   {
277     return maxscore;
278   }
279
280   @Override
281   public String getAnnotDescr()
282   {
283     return "Predicted Alignment Error"
284             + ((refSeq == null) ? "" : (" for " + refSeq.getName()));
285   }
286
287   @Override
288   public String getAnnotLabel()
289   {
290     StringBuilder label = new StringBuilder("PAE Matrix");
291     // if (this.getReferenceSeq() != null)
292     // {
293     // label.append(":").append(this.getReferenceSeq().getDisplayId(false));
294     // }
295     return label.toString();
296   }
297
298   public static final String PAEMATRIX = "PAE_MATRIX";
299
300   @Override
301   public String getType()
302   {
303     return PAEMATRIX;
304   }
305
306   @Override
307   public int getWidth()
308   {
309     return maxcol;
310   }
311
312   @Override
313   public int getHeight()
314   {
315     return maxrow;
316   }
317   public static void validateContactMatrixFile(String fileName)
318           throws FileFormatException, IOException
319   {
320     FileInputStream infile = null;
321     try
322     {
323       infile = new FileInputStream(new File(fileName));
324     } catch (Throwable t)
325     {
326       new IOException("Couldn't open " + fileName, t);
327     }
328     JSONObject paeDict = null;
329     try
330     {
331       paeDict = EBIAlfaFold.parseJSONtoPAEContactMatrix(infile);
332     } catch (Throwable t)
333     {
334       new FileFormatException("Couldn't parse " + fileName
335               + " as a JSON dict or array containing a dict");
336     }
337
338     PAEContactMatrix matrix = new PAEContactMatrix(
339             new SequenceDummy("Predicted"), (Map<String, Object>) paeDict);
340     if (matrix.getWidth() <= 0)
341     {
342       throw new FileFormatException(
343               "No data in PAE matrix read from '" + fileName + "'");
344     }
345   }
346   @Override
347   protected PAEContactMatrix newMappableContactMatrix(SequenceI newRefSeq,
348           MapList newFromMapList)
349   {
350     PAEContactMatrix pae = new PAEContactMatrix(newRefSeq, newFromMapList,
351             elements, new GroupSet(grps));
352     return pae;
353   }
354 }