JAL-3878 Add jpred operation and worker to the services.
[jalview.git] / src / jalview / ws2 / operations / JPredWorker.java
1 package jalview.ws2.operations;
2
3 import static java.lang.String.format;
4 import java.io.IOException;
5 import java.util.List;
6 import java.util.Map;
7 import java.util.function.Consumer;
8
9 import jalview.analysis.SeqsetUtils;
10 import jalview.analysis.SeqsetUtils.SequenceInfo;
11 import jalview.api.AlignViewportI;
12 import jalview.bin.Cache;
13 import jalview.commands.RemoveGapsCommand;
14 import jalview.datamodel.Alignment;
15 import jalview.datamodel.AlignmentAnnotation;
16 import jalview.datamodel.AlignmentI;
17 import jalview.datamodel.AlignmentView;
18 import jalview.datamodel.HiddenColumns;
19 import jalview.datamodel.SeqCigar;
20 import jalview.datamodel.SequenceI;
21 import jalview.gui.AlignFrame;
22 import jalview.io.JnetAnnotationMaker;
23 import jalview.util.MessageManager;
24 import jalview.ws.params.ArgumentI;
25 import jalview.ws2.WSJob;
26 import jalview.ws2.WSJobStatus;
27 import jalview.ws2.operations.AlignmentWorker.AlignmentJob;
28
29 public class JPredWorker extends AbstractPollableWorker
30 {
31
32   private class InputFormatParameter implements ArgumentI
33   {
34     String value = "";
35
36     @Override
37     public String getName()
38     {
39       return "format";
40     }
41
42     @Override
43     public String getValue()
44     {
45       return value;
46     }
47
48     @Override
49     public void setValue(String selectedItem)
50     {
51       value = selectedItem;
52     }
53   }
54
55   private static class JobInput
56   {
57     List<SequenceI> msf;
58
59     int[] delMap;
60
61     Map<String, SequenceInfo> sequenceInfo;
62   }
63
64   public class JPredJob extends WSJob
65   {
66     List<SequenceI> msf;
67
68     int[] delMap;
69
70     Map<String, SequenceInfo> sequenceInfo;
71
72     private JPredJob()
73     {
74       super(operation.service.getProviderName(), operation.getName(),
75           operation.getHostName());
76     }
77
78     private void setInput(JobInput input)
79     {
80       msf = input.msf;
81       delMap = input.delMap;
82       sequenceInfo = input.sequenceInfo;
83     }
84   }
85
86   public class PredictionResult
87   {
88     AlignmentI alignment;
89
90     HiddenColumns hiddenCols;
91
92     int firstSeq;
93
94     public AlignmentI getAlignment()
95     {
96       return alignment;
97     }
98
99     public HiddenColumns getHiddenCols()
100     {
101       return hiddenCols;
102     }
103   }
104
105   private JPredOperation operation;
106
107   private Consumer<PredictionResult> resultConsumer;
108
109   private AlignmentView view;
110
111   private WSJobList<JPredJob> jobs = new WSJobList<>();
112
113   private JPredJob job;
114
115   private char gapChar;
116
117   AlignmentI currentView;
118
119   public JPredWorker(JPredOperation operation, AlignmentView alignView,
120       AlignViewportI viewport)
121   {
122     this.operation = operation;
123     this.view = alignView;
124     this.gapChar = viewport.getGapCharacter();
125     this.currentView = viewport.getAlignment();
126   }
127
128   @Override
129   public Operation getOperation()
130   {
131     return operation;
132   }
133
134   @Override
135   public WSJobList<? extends WSJob> getJobs()
136   {
137     return jobs;
138   }
139
140   public void setResultConsumer(Consumer<PredictionResult> consumer)
141   {
142     this.resultConsumer = consumer;
143   }
144
145   @Override
146   public void start() throws IOException
147   {
148     var input = prepareInputData(view, true);
149     job = new JPredJob();
150     job.setInput(input);
151     jobs.add(job);
152     listeners.fireJobCreated(job);
153
154     var formatArg = new InputFormatParameter();
155     formatArg.setValue(input.msf.size() > 1 ? "fasta" : "seq");
156     List<ArgumentI> args = List.of(formatArg);
157     int exceptionCount = MAX_RETRY;
158     String jobId = null;
159     do
160     {
161       try
162       {
163         jobId = operation.getWebService().submit(job.msf, args);
164       } catch (IOException e)
165       {
166         Cache.log.warn(format("%s failed to submit sequences to the server %s.",
167             operation.getName(), operation.getHostName()), e);
168         exceptionCount--;
169       }
170     } while (jobId == null && exceptionCount > 0);
171     if (jobId != null)
172     {
173       job.setJobId(jobId);
174       job.setStatus(WSJobStatus.SUBMITTED);
175       listeners.fireWorkerStarted();
176     }
177     else
178     {
179       job.setStatus(WSJobStatus.SERVER_ERROR);
180       listeners.fireWorkerNotStarted();
181     }
182   }
183
184   private static JobInput prepareInputData(AlignmentView view, boolean viewOnly)
185   {
186     SeqCigar[] msf = view.getSequences();
187     SequenceI seq = msf[0].getSeq('-');
188     int[] delMap = null;
189     if (viewOnly)
190       delMap = view.getVisibleContigMapFor(seq.gapMap());
191     SequenceI[] aln = new SequenceI[msf.length];
192     for (int i = 0; i < msf.length; i++)
193       aln[i] = msf[i].getSeq('-');
194     var sequenceInfo = msf.length > 1 ? SeqsetUtils.uniquify(aln, true)
195         : Map.of("Sequence", SeqsetUtils.SeqCharacterHash(seq));
196     if (viewOnly)
197     {
198       // Remove hidden regions from sequence objects.
199       String seqs[] = view.getSequenceStrings('-');
200       for (int i = 0; i < msf.length; i++)
201         aln[i].setSequence(seqs[i]);
202       seq.setSequence(seqs[0]);
203     }
204     var input = new JobInput();
205     input.msf = List.of(aln);
206     input.delMap = delMap;
207     input.sequenceInfo = sequenceInfo;
208     return input;
209   }
210
211   @Override
212   public void done()
213   {
214     listeners.fireWorkerCompleting();
215     PredictionResult result = null;
216     try
217     {
218       result = (job.msf.size() > 1)
219           ? prepareMultipleSequenceResult(job)
220           : prepareSingleSequenceResult(job);
221     } catch (Exception e)
222     {
223       Cache.log.error("Couldn't retrieve results for job.", e);
224       job.setStatus(WSJobStatus.SERVER_ERROR);
225     }
226     if (result != null)
227     {
228       for (var annot : result.alignment.getAlignmentAnnotation())
229       {
230         if (annot.sequenceRef != null)
231         {
232           replaceAnnotationOnAlignmentWith(annot, annot.label,
233               getClass().getName(), annot.sequenceRef);
234         }
235       }
236     }
237     resultConsumer.accept(result);
238     listeners.fireWorkerCompleted();
239   }
240
241   private PredictionResult prepareMultipleSequenceResult(JPredJob job)
242       throws Exception
243   {
244     AlignmentI alignment;
245     HiddenColumns hiddenCols = null;
246     var prediction = operation.predictionSupplier.getPrediction(job);
247     if (job.delMap != null)
248     {
249       Object[] alandcolsel = view.getAlignmentAndHiddenColumns(gapChar);
250       alignment = new Alignment((SequenceI[]) alandcolsel[0]);
251       hiddenCols = (HiddenColumns) alandcolsel[1];
252     }
253     else
254     {
255       alignment = operation.predictionSupplier.getAlignment(job);
256       var seqs = new SequenceI[alignment.getHeight()];
257       for (int i = 0; i < alignment.getHeight(); i++)
258       {
259         seqs[i] = alignment.getSequenceAt(i);
260       }
261       SeqsetUtils.deuniquify(job.sequenceInfo, seqs);
262     }
263     int firstSeq = 0;
264     alignment.setDataset(currentView.getDataset());
265     JnetAnnotationMaker.add_annotation(prediction, alignment, firstSeq, false,
266         job.delMap);
267     var result = new PredictionResult();
268     result.alignment = alignment;
269     result.hiddenCols = hiddenCols;
270     result.firstSeq = firstSeq;
271     return result;
272   }
273
274   static final int msaIndex = 0;
275
276   private PredictionResult prepareSingleSequenceResult(JPredJob job)
277       throws Exception
278   {
279     var prediction = operation.predictionSupplier.getPrediction(job);
280     AlignmentI alignment = new Alignment(prediction.getSeqsAsArray());
281     HiddenColumns hiddenCols = null;
282     int firstSeq = prediction.getQuerySeqPosition();
283     if (job.delMap != null)
284     {
285       Object[] alanndcolsel = view.getAlignmentAndHiddenColumns(gapChar);
286       SequenceI[] seqs = (SequenceI[]) alanndcolsel[0];
287       new RemoveGapsCommand(MessageManager.getString("label.remove_gaps"),
288           new SequenceI[]
289           { seqs[msaIndex] }, currentView);
290       SequenceI profileSeq = alignment.getSequenceAt(firstSeq);
291       profileSeq.setSequence(seqs[msaIndex].getSequenceAsString());
292     }
293     SeqsetUtils.SeqCharacterUnhash(alignment.getSequenceAt(firstSeq),
294         job.sequenceInfo.get("Sequence"));
295     alignment.setDataset(currentView.getDataset());
296     JnetAnnotationMaker.add_annotation(prediction, alignment, firstSeq, true,
297         job.delMap);
298     SequenceI profileSeq = alignment.getSequenceAt(0);
299     if (job.delMap != null)
300     {
301       hiddenCols = alignment.propagateInsertions(profileSeq, view);
302     }
303     var result = new PredictionResult();
304     result.alignment = alignment;
305     result.hiddenCols = hiddenCols;
306     result.firstSeq = firstSeq;
307     return result;
308   }
309
310   private static void replaceAnnotationOnAlignmentWith(
311       AlignmentAnnotation newAnnot, String typeName, String calcId,
312       SequenceI aSeq)
313   {
314     SequenceI dsseq = aSeq.getDatasetSequence();
315     while (dsseq.getDatasetSequence() != null)
316     {
317       dsseq = dsseq.getDatasetSequence();
318     }
319     // look for same annotation on dataset and lift this one over
320     List<AlignmentAnnotation> dsan = dsseq.getAlignmentAnnotations(calcId,
321         typeName);
322     if (dsan != null && dsan.size() > 0)
323     {
324       for (AlignmentAnnotation dssan : dsan)
325       {
326         dsseq.removeAlignmentAnnotation(dssan);
327       }
328     }
329     AlignmentAnnotation dssan = new AlignmentAnnotation(newAnnot);
330     dsseq.addAlignmentAnnotation(dssan);
331     dssan.adjustForAlignment();
332   }
333
334 }