inprogress
[jalview.git] / forester / java / src / org / forester / msa / MsaCompactor.java
1
2 package org.forester.msa;
3
4 import java.io.File;
5 import java.io.IOException;
6 import java.io.Writer;
7 import java.math.RoundingMode;
8 import java.text.DecimalFormat;
9 import java.text.DecimalFormatSymbols;
10 import java.text.NumberFormat;
11 import java.util.ArrayList;
12 import java.util.Arrays;
13 import java.util.Comparator;
14 import java.util.List;
15 import java.util.SortedSet;
16 import java.util.TreeSet;
17
18 import org.forester.msa.Msa.MSA_FORMAT;
19 import org.forester.sequence.Sequence;
20 import org.forester.util.BasicDescriptiveStatistics;
21 import org.forester.util.DescriptiveStatistics;
22 import org.forester.util.ForesterUtil;
23
24 public class MsaCompactor {
25
26     private static final boolean VERBOSE = true;
27
28     public static enum SORT_BY {
29         MAX, MEAN, MEDIAN;
30     }
31     private Msa                     _msa;
32     private final SortedSet<String> _removed_seq_ids;
33
34     private MsaCompactor( final Msa msa ) {
35         _msa = msa;
36         _removed_seq_ids = new TreeSet<String>();
37     }
38
39     final public SortedSet<String> getRemovedSeqIds() {
40         return _removed_seq_ids;
41     }
42
43     final public Msa getMsa() {
44         return _msa;
45     }
46
47     public final static MsaCompactor removeWorstOffenders( final Msa msa,
48                                                            final int worst_offenders_to_remove,
49                                                            final boolean realign ) throws IOException,
50             InterruptedException {
51         final MsaCompactor mc = new MsaCompactor( msa );
52         mc.removeWorstOffenders( worst_offenders_to_remove, 1, realign );
53         return mc;
54     }
55
56     public final static MsaCompactor reduceGapAverage( final Msa msa,
57                                                        final double max_gap_average,
58                                                        final int step,
59                                                        final boolean realign,
60                                                        final File out,
61                                                        final int minimal_effective_length ) throws IOException,
62             InterruptedException {
63         final MsaCompactor mc = new MsaCompactor( msa );
64         mc.removeViaGapAverage( max_gap_average, step, realign, out, minimal_effective_length );
65         return mc;
66     }
67
68     public final static MsaCompactor reduceLength( final Msa msa,
69                                                    final int length,
70                                                    final int step,
71                                                    final boolean realign ) throws IOException, InterruptedException {
72         final MsaCompactor mc = new MsaCompactor( msa );
73         mc.removeViaLength( length, step, realign );
74         return mc;
75     }
76
77     final private void removeGapColumns() {
78         _msa = MsaMethods.createInstance().removeGapColumns( 1, 0, _msa );
79     }
80
81     final private void removeWorstOffenders( final int to_remove, final int step, final boolean realign )
82             throws IOException, InterruptedException {
83         final DescriptiveStatistics stats[] = calcStats();
84         final List<String> to_remove_ids = new ArrayList<String>();
85         for( int j = 0; j < to_remove; ++j ) {
86             to_remove_ids.add( stats[ j ].getDescription() );
87             _removed_seq_ids.add( stats[ j ].getDescription() );
88         }
89         _msa = MsaMethods.removeSequences( _msa, to_remove_ids );
90         removeGapColumns();
91         if ( realign ) {
92             mafft();
93         }
94     }
95
96     final private void mafft() throws IOException, InterruptedException {
97         final MsaInferrer mafft = Mafft.createInstance( "mafft" );
98         final List<String> opts = new ArrayList<String>();
99         // opts.add( "--maxiterate" );
100         // opts.add( "1000" );
101         // opts.add( "--localpair" );
102         opts.add( "--quiet" );
103         _msa = mafft.infer( _msa.asSequenceList(), opts );
104     }
105
106     final private void removeViaGapAverage( final double mean_gapiness,
107                                             final int step,
108                                             final boolean realign,
109                                             final File outfile,
110                                             final int minimal_effective_length ) throws IOException,
111             InterruptedException {
112         if ( step < 1 ) {
113             throw new IllegalArgumentException( "step cannot be less than 1" );
114         }
115         if ( mean_gapiness < 0 ) {
116             throw new IllegalArgumentException( "target average gap ratio cannot be less than 0" );
117         }
118         if ( VERBOSE ) {
119             System.out.println( "orig: " + msaStatsAsSB() );
120         }
121         if ( minimal_effective_length > 1 ) {
122             _msa = MsaMethods.removeSequencesByMinimalLength( _msa, minimal_effective_length );
123             if ( VERBOSE ) {
124                 System.out.println( "short seq removal: " + msaStatsAsSB() );
125             }
126         }
127         int counter = step;
128         double gr;
129         do {
130             removeWorstOffenders( step, 1, false );
131             if ( realign ) {
132                 mafft();
133             }
134             gr = MsaMethods.calcGapRatio( _msa );
135             if ( VERBOSE ) {
136                 System.out.println( counter + ": " + msaStatsAsSB() );
137             }
138             write( outfile, gr );
139             counter += step;
140         } while ( gr > mean_gapiness );
141         if ( VERBOSE ) {
142             System.out.println( "final: " + msaStatsAsSB() );
143         }
144     }
145
146     final private void write( final File outfile, final double gr ) throws IOException {
147         writeMsa( outfile + "_" + _msa.getNumberOfSequences() + "_" + _msa.getLength() + "_"
148                 + ForesterUtil.roundToInt( gr * 100 ) + ".fasta" );
149     }
150
151     final private void writeMsa( final String outfile ) throws IOException {
152         final Writer w = ForesterUtil.createBufferedWriter( outfile );
153         _msa.write( w, MSA_FORMAT.FASTA );
154         w.close();
155     }
156
157     final private StringBuilder msaStatsAsSB() {
158         final StringBuilder sb = new StringBuilder();
159         sb.append( _msa.getLength() );
160         sb.append( "\t" );
161         sb.append( _msa.getNumberOfSequences() );
162         sb.append( "\t" );
163         sb.append( ForesterUtil.round( MsaMethods.calcGapRatio( _msa ), 4 ) );
164         sb.append( "\t" );
165         return sb;
166     }
167
168     final private void removeViaLength( final int length, final int step, final boolean realign ) throws IOException,
169             InterruptedException {
170         if ( step < 1 ) {
171             throw new IllegalArgumentException( "step cannot be less than 1" );
172         }
173         if ( length < 11 ) {
174             throw new IllegalArgumentException( "target length cannot be less than 1" );
175         }
176         if ( VERBOSE ) {
177             System.out.println( "orig: " + msaStatsAsSB() );
178         }
179         int counter = step;
180         while ( _msa.getLength() > length ) {
181             removeWorstOffenders( step, 1, false );
182             if ( realign ) {
183                 mafft();
184             }
185             if ( VERBOSE ) {
186                 System.out.println( counter + ": " + msaStatsAsSB() );
187             }
188             counter += step;
189         }
190     }
191
192     final private DescriptiveStatistics[] calcStats() {
193         final DecimalFormatSymbols dfs = new DecimalFormatSymbols();
194         dfs.setDecimalSeparator( '.' );
195         final NumberFormat f = new DecimalFormat( "#.####", dfs );
196         f.setRoundingMode( RoundingMode.HALF_UP );
197         final DescriptiveStatistics stats[] = calcGapContribtions();
198         Arrays.sort( stats, new DescriptiveStatisticsComparator( false, SORT_BY.MEAN ) );
199         for( final DescriptiveStatistics stat : stats ) {
200             final StringBuilder sb = new StringBuilder();
201             sb.append( stat.getDescription() );
202             sb.append( "\t" );
203             sb.append( f.format( stat.arithmeticMean() ) );
204             sb.append( "\t" );
205             sb.append( f.format( stat.median() ) );
206             sb.append( "\t" );
207             sb.append( f.format( stat.getMin() ) );
208             sb.append( "\t" );
209             sb.append( f.format( stat.getMax() ) );
210             sb.append( "\t" );
211             System.out.println( sb );
212         }
213         return stats;
214     }
215
216     private final DescriptiveStatistics[] calcGapContribtions() {
217         final double gappiness[] = calcGappiness();
218         final DescriptiveStatistics stats[] = new DescriptiveStatistics[ _msa.getNumberOfSequences() ];
219         for( int row = 0; row < _msa.getNumberOfSequences(); ++row ) {
220             stats[ row ] = new BasicDescriptiveStatistics( _msa.getIdentifier( row ) );
221             for( int col = 0; col < _msa.getLength(); ++col ) {
222                 if ( _msa.getResidueAt( row, col ) != Sequence.GAP ) {
223                     stats[ row ].addValue( gappiness[ col ] );
224                 }
225             }
226         }
227         return stats;
228     }
229
230     private final double[] calcGappiness() {
231         final int l = _msa.getLength();
232         final double gappiness[] = new double[ l ];
233         final int seqs = _msa.getNumberOfSequences();
234         for( int i = 0; i < l; ++i ) {
235             gappiness[ i ] = ( double ) MsaMethods.calcGapSumPerColumn( _msa, i ) / seqs;
236         }
237         return gappiness;
238     }
239
240     final static class DescriptiveStatisticsComparator implements Comparator<DescriptiveStatistics> {
241
242         final private boolean _ascending;
243         final private SORT_BY _sort_by;
244
245         public DescriptiveStatisticsComparator( final boolean ascending, final SORT_BY sort_by ) {
246             _ascending = ascending;
247             _sort_by = sort_by;
248         }
249
250         @Override
251         public final int compare( final DescriptiveStatistics s0, final DescriptiveStatistics s1 ) {
252             switch ( _sort_by ) {
253                 case MAX:
254                     if ( s0.getMax() < s1.getMax() ) {
255                         return _ascending ? -1 : 1;
256                     }
257                     else if ( s0.getMax() > s1.getMax() ) {
258                         return _ascending ? 1 : -1;
259                     }
260                     return 0;
261                 case MEAN:
262                     if ( s0.arithmeticMean() < s1.arithmeticMean() ) {
263                         return _ascending ? -1 : 1;
264                     }
265                     else if ( s0.arithmeticMean() > s1.arithmeticMean() ) {
266                         return _ascending ? 1 : -1;
267                     }
268                     return 0;
269                 case MEDIAN:
270                     if ( s0.median() < s1.median() ) {
271                         return _ascending ? -1 : 1;
272                     }
273                     else if ( s0.median() > s1.median() ) {
274                         return _ascending ? 1 : -1;
275                     }
276                     return 0;
277                 default:
278                     return 0;
279             }
280         }
281     }
282 }