JAL-4134 allow tree groups to be stored/recovered on contact matrix for groupwise...
[jalview.git] / src / jalview / ws / datamodel / alphafold / PAEContactMatrix.java
1 package jalview.ws.datamodel.alphafold;
2
3 import java.util.ArrayList;
4 import java.util.BitSet;
5 import java.util.Iterator;
6 import java.util.List;
7 import java.util.Map;
8
9 import jalview.analysis.AverageDistanceEngine;
10 import jalview.bin.Console;
11 import jalview.datamodel.BinaryNode;
12 import jalview.datamodel.ContactListI;
13 import jalview.datamodel.ContactListImpl;
14 import jalview.datamodel.ContactListProviderI;
15 import jalview.datamodel.ContactMatrixI;
16 import jalview.datamodel.SequenceI;
17 import jalview.util.MapUtils;
18
19 public class PAEContactMatrix implements ContactMatrixI
20 {
21
22   SequenceI refSeq = null;
23
24   /**
25    * the length that refSeq is expected to be (excluding gaps, of course)
26    */
27   int length;
28
29   int maxrow = 0, maxcol = 0;
30
31   int[] indices1, indices2;
32
33   float[][] elements;
34
35   float maxscore;
36
37   private void setRefSeq(SequenceI _refSeq)
38   {
39     refSeq = _refSeq;
40     while (refSeq.getDatasetSequence() != null)
41     {
42       refSeq = refSeq.getDatasetSequence();
43     }
44     length = _refSeq.getEnd() - _refSeq.getStart() + 1;
45   }
46   @SuppressWarnings("unchecked")
47   public PAEContactMatrix(SequenceI _refSeq, Map<String, Object> pae_obj)
48   {
49     setRefSeq(_refSeq);
50     // convert the lists to primitive arrays and store
51     
52     if (!MapUtils.containsAKey(pae_obj, "predicted_aligned_error", "pae"))
53     {
54       parse_version_1_pAE(pae_obj);
55       return;
56     }
57     else
58     {
59       parse_version_2_pAE(pae_obj);
60     }
61   }
62   /**
63    * construct a sequence associated PAE matrix directly from a float array
64    * @param _refSeq
65    * @param matrix
66    */
67   public PAEContactMatrix(SequenceI _refSeq, float[][] matrix)
68   {
69     setRefSeq(_refSeq);
70     maxcol=0;
71     for (float[] row:matrix)
72     {
73       if (row.length>maxcol)
74       {
75         maxcol=row.length;
76       }
77       maxscore=row[0];
78       for (float f:row)
79       {
80         if (maxscore<f) {
81           maxscore=f;
82         }
83       }
84     }
85     maxrow=matrix.length;
86     elements = matrix;
87     
88   }
89
90   /**
91    * parse a sane JSON representation of the pAE
92    * 
93    * @param pae_obj
94    */
95   @SuppressWarnings("unchecked")
96   private void parse_version_2_pAE(Map<String, Object> pae_obj)
97   {
98     // this is never going to be reached by the integer rounding.. or is it ?
99     maxscore = ((Double) MapUtils.getFirst(pae_obj,
100             "max_predicted_aligned_error", "max_pae")).floatValue();
101     List<List<Long>> scoreRows = ((List<List<Long>>) MapUtils
102             .getFirst(pae_obj, "predicted_aligned_error", "pae"))
103             ;
104     elements = new float[scoreRows.size()][scoreRows.size()];
105     int row = 0, col = 0;
106     for (List<Long> scoreRow:scoreRows)
107     {
108       Iterator<Long> scores = scoreRow.iterator();
109       while (scores.hasNext())
110       {
111         Object d = scores.next();
112         if (d instanceof Double)
113           elements[row][col++] = ((Double) d).longValue();
114         else
115           elements[row][col++] = (float) ((Long)d).longValue();
116       }
117       row++;
118       col = 0;
119     }
120     maxcol = length;
121     maxrow = length;
122   }
123
124   /**
125    * v1 format got ditched 28th July 2022 see
126    * https://alphafold.ebi.ac.uk/faq#:~:text=We%20updated%20the%20PAE%20JSON%20file%20format%20on%2028th%20July%202022
127    * 
128    * @param pae_obj
129    */
130   @SuppressWarnings("unchecked")
131   private void parse_version_1_pAE(Map<String, Object> pae_obj)
132   {
133     // assume indices are with respect to range defined by _refSeq on the
134     // dataset refSeq
135     Iterator<Long> rows = ((List<Long>) pae_obj.get("residue1")).iterator();
136     Iterator<Long> cols = ((List<Long>) pae_obj.get("residue2")).iterator();
137     Iterator<Double> scores = ((List<Double>) pae_obj.get("distance"))
138             .iterator();
139     // assume square matrix
140     elements = new float[length][length];
141     while (scores.hasNext())
142     {
143       float escore = scores.next().floatValue();
144       int row = rows.next().intValue();
145       int col = cols.next().intValue();
146       if (maxrow < row)
147       {
148         maxrow = row;
149       }
150       if (maxcol < col)
151       {
152         maxcol = col;
153       }
154       elements[row - 1][col - 1] = escore;
155     }
156
157     maxscore = ((Double) MapUtils.getFirst(pae_obj,
158             "max_predicted_aligned_error", "max_pae")).floatValue();
159   }
160
161   @Override
162   public ContactListI getContactList(final int _column)
163   {
164     if (_column < 0 || _column >= elements.length)
165     {
166       return null;
167     }
168
169     return new ContactListImpl(new ContactListProviderI()
170     {
171       @Override
172       public int getPosition()
173       {
174         return _column;
175       }
176
177       @Override
178       public int getContactHeight()
179       {
180         return maxcol-1;
181       }
182
183       @Override
184       public double getContactAt(int column)
185       {
186         if (column < 0 || column >= elements[_column].length)
187         {
188           return -1;
189         }
190         return elements[_column][column];
191       }
192     });
193   }
194
195   @Override
196   public float getMin()
197   {
198     return 0;
199   }
200
201   @Override
202   public float getMax()
203   {
204     return maxscore;
205   }
206
207   @Override
208   public boolean hasReferenceSeq()
209   {
210     return (refSeq != null);
211   }
212
213   @Override
214   public SequenceI getReferenceSeq()
215   {
216     return refSeq;
217   }
218
219   @Override
220   public String getAnnotDescr()
221   {
222     return "Predicted Alignment Error for " + refSeq.getName();
223   }
224
225   @Override
226   public String getAnnotLabel()
227   {
228     return "pAE Matrix";
229   }
230
231   public static final String PAEMATRIX="PAE_MATRIX";
232   @Override
233   public String getType()
234   {
235     return PAEMATRIX;
236   }
237   @Override
238   public int getWidth()
239   {
240     return length;
241   }
242   @Override
243   public int getHeight()
244   {
245     return length;
246   }
247   List<BitSet> groups=null;
248   @Override
249   public boolean hasGroups()
250   {
251     return groups!=null;
252   }
253   String newick=null;
254   public String getNewickString()
255   {
256     return newick;
257   }
258   public void makeGroups(float thresh,boolean abs)
259   {
260     AverageDistanceEngine clusterer = new AverageDistanceEngine(null, null, this);
261     double height = clusterer.findHeight(clusterer.getTopNode());
262     newick = new jalview.io.NewickFile(clusterer.getTopNode(),false,true).print();
263
264     Console.trace("Newick string\n"+newick);
265
266     List<BinaryNode> nodegroups;
267     if (abs ? height > thresh : 0 < thresh && thresh < 1)
268     {
269       float cut = abs ? (float) (thresh / height) : thresh;
270       Console.debug("Threshold "+cut+" for height="+height);
271
272       nodegroups = clusterer.groupNodes(cut);
273     }
274     else
275     {
276       nodegroups = new ArrayList<BinaryNode>();
277       nodegroups.add(clusterer.getTopNode());
278     }
279
280     groups = new ArrayList<>();
281     for (BinaryNode root:nodegroups)
282     {
283       BitSet gpset=new BitSet();
284       for (BinaryNode leaf:clusterer.findLeaves(root))
285       {
286         gpset.set((Integer)leaf.element());
287       }
288       groups.add(gpset);
289     }
290   }
291   
292   @Override
293   public BitSet getGroupsFor(int column)
294   {
295     for (BitSet gp:groups) {
296       if (gp.get(column))
297       {
298         return gp;
299       }
300     }
301     return ContactMatrixI.super.getGroupsFor(column);
302   }
303 }