Merge branch 'develop' into features/JAL-4134_use_annotation_row_for_colours_and_groups
[jalview.git] / src / jalview / ws / datamodel / alphafold / PAEContactMatrix.java
1 package jalview.ws.datamodel.alphafold;
2
3 import java.awt.Color;
4 import java.io.BufferedInputStream;
5 import java.io.File;
6 import java.io.FileInputStream;
7 import java.io.IOException;
8 import java.util.ArrayList;
9 import java.util.BitSet;
10 import java.util.HashMap;
11 import java.util.Iterator;
12 import java.util.List;
13 import java.util.Map;
14 import java.util.Map.Entry;
15
16 import org.json.simple.JSONObject;
17
18 import jalview.analysis.AverageDistanceEngine;
19 import jalview.bin.Console;
20 import jalview.datamodel.Annotation;
21 import jalview.datamodel.BinaryNode;
22 import jalview.datamodel.ContactListI;
23 import jalview.datamodel.ContactListImpl;
24 import jalview.datamodel.ContactListProviderI;
25 import jalview.datamodel.ContactMatrixI;
26 import jalview.datamodel.GroupSet;
27 import jalview.datamodel.GroupSetI;
28 import jalview.datamodel.Mapping;
29 import jalview.datamodel.SequenceDummy;
30 import jalview.datamodel.SequenceI;
31 import jalview.io.DataSourceType;
32 import jalview.io.FileFormatException;
33 import jalview.io.FileParse;
34 import jalview.util.MapList;
35 import jalview.util.MapUtils;
36 import jalview.ws.dbsources.EBIAlfaFold;
37
38 public class PAEContactMatrix extends
39         MappableContactMatrix<PAEContactMatrix> implements ContactMatrixI
40 {
41
42
43   int maxrow = 0, maxcol = 0;
44
45
46   float[][] elements;
47
48   float maxscore;
49
50
51   @SuppressWarnings("unchecked")
52   public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
53           throws FileFormatException
54   {
55     setRefSeq(_refSeq);
56     // convert the lists to primitive arrays and store
57
58     if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
59     {
60       parse_version_1_pAE(pae_obj);
61       return;
62     }
63     else
64     {
65       parse_version_2_pAE(pae_obj);
66     }
67   }
68
69   /**
70    * construct a sequence associated PAE matrix directly from a float array
71    * 
72    * @param _refSeq
73    * @param matrix
74    */
75   public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
76   {
77     setRefSeq(_refSeq);
78     maxcol = 0;
79     for (float[] row : matrix)
80     {
81       if (row.length > maxcol)
82       {
83         maxcol = row.length;
84       }
85       maxscore = row[0];
86       for (float f : row)
87       {
88         if (maxscore < f)
89         {
90           maxscore = f;
91         }
92       }
93     }
94     maxrow = matrix.length;
95     elements = matrix;
96
97   }
98
99   /**
100    * new matrix with specific mapping to a reference sequence
101    * 
102    * @param newRefSeq
103    * @param newFromMapList
104    * @param elements2
105    * @param grps2
106    */
107   public PAEContactMatrix(SequenceI newRefSeq, MapList newFromMapList,
108           float[][] elements2, GroupSet grps2)
109   {
110     this(newRefSeq, elements2);
111     toSeq = newFromMapList;
112     grps = grps2;
113   }
114
115   /**
116    * parse a sane JSON representation of the pAE
117    * 
118    * @param pae_obj
119    */
120   @SuppressWarnings("unchecked")
121   private void parse_version_2_pAE(Map<String, Object> pae_obj)
122   {
123     maxscore = -1;
124     // look for a maxscore element - if there is one...
125     try
126     {
127       // this is never going to be reached by the integer rounding.. or is it ?
128       maxscore = ((Double) MapUtils.getFirst(pae_obj,
129               "max_predicted_aligned_error", "max_pae")).floatValue();
130     } catch (Throwable t)
131     {
132       // ignore if a key is not found.
133     }
134     List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
135             .getFirst(pae_obj, "predicted_aligned_error", "pae"));
136     elements = new float[scoreRows.size()][scoreRows.size()];
137     int row = 0, col = 0;
138     for (List<Long> scoreRow : scoreRows)
139     {
140       Iterator<Long> scores = scoreRow.iterator();
141       while (scores.hasNext())
142       {
143         Object d = scores.next();
144         if (d instanceof Double)
145         {
146           elements[row][col++] = ((Double) d).longValue();
147         }
148         else
149         {
150           elements[row][col++] = (float) ((Long) d).longValue();
151         }
152
153         if (maxscore < elements[row][col - 1])
154         {
155           maxscore = elements[row][col - 1];
156         }
157       }
158       row++;
159       col = 0;
160     }
161     maxcol = length;
162     maxrow = length;
163   }
164
165   /**
166    * v1 format got ditched 28th July 2022 see
167    * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022
168    * 
169    * @param pae_obj
170    */
171   @SuppressWarnings("unchecked")
172   private void parse_version_1_pAE(Map<String, Object> pae_obj)
173   {
174     // assume indices are with respect to range defined by _refSeq on the
175     // dataset refSeq
176     Iterator<Long> rows = ((List<Long>) pae_obj.get("residue1")).iterator();
177     Iterator<Long> cols = ((List<Long>) pae_obj.get("residue2")).iterator();
178     // two pass - to allocate the elements array
179     while (rows.hasNext())
180     {
181       int row = rows.next().intValue();
182       int col = cols.next().intValue();
183       if (maxrow < row)
184       {
185         maxrow = row;
186       }
187       if (maxcol < col)
188       {
189         maxcol = col;
190       }
191
192     }
193     rows = ((List<Long>) pae_obj.get("residue1")).iterator();
194     cols = ((List<Long>) pae_obj.get("residue2")).iterator();
195     Iterator<Double> scores = ((List<Double>) pae_obj.get("distance"))
196             .iterator();
197     elements = new float[maxrow][maxcol];
198     while (scores.hasNext())
199     {
200       float escore = scores.next().floatValue();
201       int row = rows.next().intValue();
202       int col = cols.next().intValue();
203       if (maxrow < row)
204       {
205         maxrow = row;
206       }
207       if (maxcol < col)
208       {
209         maxcol = col;
210       }
211       elements[row - 1][col - 1] = escore;
212     }
213
214     maxscore = ((Double) MapUtils.getFirst(pae_obj,
215             "max_predicted_aligned_error", "max_pae")).floatValue();
216   }
217
218   @Override
219   public ContactListI getContactList(final int column)
220   {
221     if (column < 0 || column >= elements.length)
222     {
223       return null;
224     }
225
226     return new ContactListImpl(new ContactListProviderI()
227     {
228       @Override
229       public int getPosition()
230       {
231         return column;
232       }
233
234       @Override
235       public int getContactHeight()
236       {
237         return maxcol - 1;
238       }
239
240       @Override
241       public double getContactAt(int mcolumn)
242       {
243         if (mcolumn < 0 || mcolumn >= elements[column].length)
244         {
245           return -1;
246         }
247         return elements[column][mcolumn];
248       }
249     });
250   }
251
252   @Override
253   protected double getElementAt(int _column, int i)
254   {
255     return elements[_column][i];
256   }
257
258   @Override
259   public float getMin()
260   {
261     return 0;
262   }
263
264   @Override
265   public float getMax()
266   {
267     return maxscore;
268   }
269
270   @Override
271   public String getAnnotDescr()
272   {
273     return "Predicted Alignment Error"
274             + ((refSeq == null) ? "" : (" for " + refSeq.getName()));
275   }
276
277   @Override
278   public String getAnnotLabel()
279   {
280     StringBuilder label = new StringBuilder("PAE Matrix");
281     // if (this.getReferenceSeq() != null)
282     // {
283     // label.append(":").append(this.getReferenceSeq().getDisplayId(false));
284     // }
285     return label.toString();
286   }
287
288   public static final String PAEMATRIX = "PAE_MATRIX";
289
290   @Override
291   public String getType()
292   {
293     return PAEMATRIX;
294   }
295
296   @Override
297   public int getWidth()
298   {
299     return length;
300   }
301
302   @Override
303   public int getHeight()
304   {
305     return length;
306   }
307   public static void validateContactMatrixFile(String fileName)
308           throws FileFormatException, IOException
309   {
310     FileInputStream infile = null;
311     try
312     {
313       infile = new FileInputStream(new File(fileName));
314     } catch (Throwable t)
315     {
316       new IOException("Couldn't open " + fileName, t);
317     }
318     JSONObject paeDict = null;
319     try
320     {
321       paeDict = EBIAlfaFold.parseJSONtoPAEContactMatrix(infile);
322     } catch (Throwable t)
323     {
324       new FileFormatException("Couldn't parse " + fileName
325               + " as a JSON dict or array containing a dict");
326     }
327
328     PAEContactMatrix matrix = new PAEContactMatrix(
329             new SequenceDummy("Predicted"), (Map<String, Object>) paeDict);
330     if (matrix.getWidth() <= 0)
331     {
332       throw new FileFormatException(
333               "No data in PAE matrix read from '" + fileName + "'");
334     }
335   }
336   @Override
337   protected PAEContactMatrix newMappableContactMatrix(SequenceI newRefSeq,
338           MapList newFromMapList)
339   {
340     PAEContactMatrix pae = new PAEContactMatrix(newRefSeq, newFromMapList,
341             elements, new GroupSet(grps));
342     return pae;
343   }
344 }