in progress
[jalview.git] / forester / java / src / org / forester / msa / MsaCompactor.java
1
2 package org.forester.msa;
3
4 import java.io.File;
5 import java.io.IOException;
6 import java.io.Writer;
7 import java.util.ArrayList;
8 import java.util.Arrays;
9 import java.util.Comparator;
10 import java.util.List;
11 import java.util.SortedSet;
12 import java.util.TreeSet;
13
14 import org.forester.msa.Msa.MSA_FORMAT;
15 import org.forester.sequence.Sequence;
16 import org.forester.util.BasicDescriptiveStatistics;
17 import org.forester.util.DescriptiveStatistics;
18 import org.forester.util.ForesterUtil;
19
20 public class MsaCompactor {
21
22     private static final boolean VERBOSE = true;
23
24     public static enum SORT_BY {
25         MAX, MEAN, MEDIAN;
26     }
27     private Msa                     _msa;
28     private final SortedSet<String> _removed_seq_ids;
29
30     private MsaCompactor( final Msa msa ) {
31         _msa = msa;
32         _removed_seq_ids = new TreeSet<String>();
33     }
34
35     final public SortedSet<String> getRemovedSeqIds() {
36         return _removed_seq_ids;
37     }
38
39     final public Msa getMsa() {
40         return _msa;
41     }
42
43     public final static MsaCompactor removeWorstOffenders( final Msa msa,
44                                                            final int worst_offenders_to_remove,
45                                                            final boolean realign ) throws IOException,
46             InterruptedException {
47         final MsaCompactor mc = new MsaCompactor( msa );
48         mc.removeWorstOffenders( worst_offenders_to_remove, 1, realign );
49         return mc;
50     }
51
52     public final static MsaCompactor reduceGapAverage( final Msa msa,
53                                                        final double max_gap_average,
54                                                        final int step,
55                                                        final boolean realign,
56                                                        final File out,
57                                                        final int minimal_effective_length ) throws IOException,
58             InterruptedException {
59         final MsaCompactor mc = new MsaCompactor( msa );
60         mc.removeViaGapAverage( max_gap_average, step, realign, out, minimal_effective_length );
61         return mc;
62     }
63
64     public final static MsaCompactor reduceLength( final Msa msa,
65                                                    final int length,
66                                                    final int step,
67                                                    final boolean realign ) throws IOException, InterruptedException {
68         final MsaCompactor mc = new MsaCompactor( msa );
69         mc.removeViaLength( length, step, realign );
70         return mc;
71     }
72
73     final private void removeGapColumns() {
74         _msa = MsaMethods.createInstance().removeGapColumns( 1, 0, _msa );
75     }
76
77     final private void removeWorstOffenders( final int to_remove, final int step, final boolean realign )
78             throws IOException, InterruptedException {
79         final DescriptiveStatistics stats[] = calcStats();
80         final List<String> to_remove_ids = new ArrayList<String>();
81         for( int j = 0; j < to_remove; ++j ) {
82             to_remove_ids.add( stats[ j ].getDescription() );
83             _removed_seq_ids.add( stats[ j ].getDescription() );
84         }
85         _msa = MsaMethods.removeSequences( _msa, to_remove_ids );
86         removeGapColumns();
87         if ( realign ) {
88             mafft();
89         }
90     }
91
92     final private void mafft() throws IOException, InterruptedException {
93         final MsaInferrer mafft = Mafft.createInstance( "/home/czmasek/bin/mafft" );
94         final List<String> opts = new ArrayList<String>();
95         // opts.add( "--maxiterate" );
96         // opts.add( "1000" );
97         // opts.add( "--localpair" );
98         opts.add( "--quiet" );
99         _msa = mafft.infer( _msa.asSequenceList(), opts );
100     }
101
102     final private void removeViaGapAverage( final double mean_gapiness,
103                                             final int step,
104                                             final boolean realign,
105                                             final File outfile,
106                                             final int minimal_effective_length ) throws IOException,
107             InterruptedException {
108         if ( step < 1 ) {
109             throw new IllegalArgumentException( "step cannot be less than 1" );
110         }
111         if ( mean_gapiness < 0 ) {
112             throw new IllegalArgumentException( "target average gap ratio cannot be less than 0" );
113         }
114         if ( VERBOSE ) {
115             System.out.println( "orig: " + msaStatsAsSB() );
116         }
117         if ( minimal_effective_length > 1 ) {
118             _msa = MsaMethods.removeSequencesByMinimalLength( _msa, minimal_effective_length );
119             if ( VERBOSE ) {
120                 System.out.println( "short seq removal: " + msaStatsAsSB() );
121             }
122         }
123         int counter = step;
124         double gr;
125         do {
126             removeWorstOffenders( step, 1, false );
127             if ( realign ) {
128                 mafft();
129             }
130             gr = MsaMethods.calcGapRatio( _msa );
131             if ( VERBOSE ) {
132                 System.out.println( counter + ": " + msaStatsAsSB() );
133             }
134             write( outfile, gr );
135             counter += step;
136         } while ( gr > mean_gapiness );
137         if ( VERBOSE ) {
138             System.out.println( "final: " + msaStatsAsSB() );
139         }
140     }
141
142     final private void write( final File outfile, final double gr ) throws IOException {
143         writeMsa( outfile + "_" + _msa.getNumberOfSequences() + "_" + _msa.getLength() + "_"
144                 + ForesterUtil.roundToInt( gr * 100 ) + ".aln" );
145     }
146
147     final private void writeMsa( final String outfile ) throws IOException {
148         final Writer w = ForesterUtil.createBufferedWriter( outfile );
149         _msa.write( w, MSA_FORMAT.FASTA );
150         w.close();
151     }
152
153     final private StringBuilder msaStatsAsSB() {
154         final StringBuilder sb = new StringBuilder();
155         sb.append( _msa.getLength() );
156         sb.append( "\t" );
157         sb.append( _msa.getNumberOfSequences() );
158         sb.append( "\t" );
159         sb.append( ForesterUtil.round( MsaMethods.calcGapRatio( _msa ), 4 ) );
160         sb.append( "\t" );
161         return sb;
162     }
163
164     final private void removeViaLength( final int length, final int step, final boolean realign ) throws IOException,
165             InterruptedException {
166         if ( step < 1 ) {
167             throw new IllegalArgumentException( "step cannot be less than 1" );
168         }
169         if ( length < 11 ) {
170             throw new IllegalArgumentException( "target length cannot be less than 1" );
171         }
172         if ( VERBOSE ) {
173             System.out.println( "orig: " + msaStatsAsSB() );
174         }
175         int counter = step;
176         while ( _msa.getLength() > length ) {
177             removeWorstOffenders( step, 1, false );
178             if ( realign ) {
179                 mafft();
180             }
181             if ( VERBOSE ) {
182                 System.out.println( counter + ": " + msaStatsAsSB() );
183             }
184             counter += step;
185         }
186     }
187
188     final private DescriptiveStatistics[] calcStats() {
189         final DescriptiveStatistics stats[] = calc();
190         sort( stats );
191         for( int i = 0; i < stats.length; ++i ) {
192             final DescriptiveStatistics s = stats[ i ];
193             //            System.out.print( s.getDescription() );
194             //            System.out.print( "\t" );
195             //            System.out.print( s.arithmeticMean() );
196             //            System.out.print( "\t(" );
197             //            System.out.print( s.arithmeticMean() );
198             //            System.out.print( ")" );
199             //            System.out.print( "\t" );
200             //            System.out.print( s.getMin() );
201             //            System.out.print( "\t" );
202             //            System.out.print( s.getMax() );
203             //            System.out.println();
204         }
205         return stats;
206     }
207
208     private final static void sort( final DescriptiveStatistics stats[] ) {
209         Arrays.sort( stats, new DescriptiveStatisticsComparator( false, SORT_BY.MAX ) );
210     }
211
212     private final DescriptiveStatistics[] calc() {
213         final double gappiness[] = calcGappiness();
214         final DescriptiveStatistics stats[] = new DescriptiveStatistics[ _msa.getNumberOfSequences() ];
215         for( int row = 0; row < _msa.getNumberOfSequences(); ++row ) {
216             stats[ row ] = new BasicDescriptiveStatistics( _msa.getIdentifier( row ) );
217             for( int col = 0; col < _msa.getLength(); ++col ) {
218                 if ( _msa.getResidueAt( row, col ) != Sequence.GAP ) {
219                     stats[ row ].addValue( gappiness[ col ] );
220                 }
221             }
222         }
223         return stats;
224     }
225
226     private final double[] calcGappiness() {
227         final double gappiness[] = new double[ _msa.getLength() ];
228         final int seqs = _msa.getNumberOfSequences();
229         for( int i = 0; i < gappiness.length; ++i ) {
230             gappiness[ i ] = ( double ) MsaMethods.calcGapSumPerColumn( _msa, i ) / seqs;
231         }
232         return gappiness;
233     }
234
235     final static class DescriptiveStatisticsComparator implements Comparator<DescriptiveStatistics> {
236
237         final private boolean _ascending;
238         final private SORT_BY _sort_by;
239
240         public DescriptiveStatisticsComparator( final boolean ascending, final SORT_BY sort_by ) {
241             _ascending = ascending;
242             _sort_by = sort_by;
243         }
244
245         @Override
246         public final int compare( final DescriptiveStatistics s0, final DescriptiveStatistics s1 ) {
247             switch ( _sort_by ) {
248                 case MAX:
249                     if ( s0.getMax() < s1.getMax() ) {
250                         return _ascending ? -1 : 1;
251                     }
252                     else if ( s0.getMax() > s1.getMax() ) {
253                         return _ascending ? 1 : -1;
254                     }
255                     return 0;
256                 case MEAN:
257                     if ( s0.arithmeticMean() < s1.arithmeticMean() ) {
258                         return _ascending ? -1 : 1;
259                     }
260                     else if ( s0.arithmeticMean() > s1.arithmeticMean() ) {
261                         return _ascending ? 1 : -1;
262                     }
263                     return 0;
264                 case MEDIAN:
265                     if ( s0.median() < s1.median() ) {
266                         return _ascending ? -1 : 1;
267                     }
268                     else if ( s0.median() > s1.median() ) {
269                         return _ascending ? 1 : -1;
270                     }
271                     return 0;
272                 default:
273                     return 0;
274             }
275         }
276     }
277 }