Add missing binary and static library
[jabaws.git] / binaries / src / ViennaRNA / libsvm-2.91 / svm.cpp
1 #include <math.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <ctype.h>
5 #include <float.h>
6 #include <string.h>
7 #include <stdarg.h>
8 #include "svm.h"
// Library version number, initialised from the LIBSVM_VERSION macro.
9 int libsvm_version = LIBSVM_VERSION;
// Element type used for kernel values (Cache buffers, QMatrix columns).
10 typedef float Qfloat;
// Compact signed type; Solver stores the labels y in it and compares them with +1.
11 typedef signed char schar;
#ifndef min
// Smaller of the two arguments; y is returned when neither compares less.
template <class T> static inline T min(T x,T y)
{
	if(x<y)
		return x;
	return y;
}
#endif
#ifndef max
// Larger of the two arguments; y is returned when neither compares greater.
template <class T> static inline T max(T x,T y)
{
	if(x>y)
		return x;
	return y;
}
#endif
// Exchanges the values of x and y through a temporary copy.
template <class T> static inline void swap(T& x, T& y)
{
	T saved = x;
	x = y;
	y = saved;
}
// Allocates a new array of n elements of type T with new[], fills it with a
// raw byte copy of sizeof(T)*n bytes taken from src, and stores it in dst.
// The caller releases the array with delete[].
template <class S, class T> static inline void clone(T*& dst, S* src, int n)
{
	T *fresh = new T[n];
	memcpy((void *)fresh,(void *)src,sizeof(T)*n);
	dst = fresh;
}
// Integer power by repeated squaring: returns base raised to `times`.
// A zero or negative exponent yields 1.0 because the loop never runs.
static inline double powi(double base, int times)
{
	double result = 1.0;
	double factor = base;
	int remaining = times;

	while(remaining > 0)
	{
		// the lowest bit of the exponent decides whether this square is used
		if(remaining%2 == 1)
			result *= factor;
		factor = factor * factor;
		remaining /= 2;
	}
	return result;
}
// "Infinity" used as the starting value when searching for maxima/minima.
35 #define INF HUGE_VAL
// Small positive value substituted for a non-positive quadratic coefficient.
36 #define TAU 1e-12
// malloc wrapper: returns a `type *` to uninitialised room for n objects of `type`.
37 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
38
// Default message sink: writes s to standard output and flushes immediately.
static void print_string_stdout(const char *s)
{
	FILE *const out = stdout;
	fputs(s,out);
	fflush(out);
}
// Function through which info() emits its text; starts out as the stdout sink.
static void (*svm_print_string) (const char *) = print_string_stdout;
45 #if 1
46 static void info(const char *fmt,...)
47 {
48         char buf[BUFSIZ];
49         va_list ap;
50         va_start(ap,fmt);
51         vsprintf(buf,fmt,ap);
52         va_end(ap);
53         (*svm_print_string)(buf);
54 }
55 #else
56 static void info(const char *fmt,...) {}
57 #endif
58
59 //
60 // Kernel Cache
61 //
62 // l is the number of total data items
63 // size is the cache size limit in bytes
64 //
65 class Cache
66 {
67 public:
68         Cache(int l,long int size);
69         ~Cache();
70
71         // request data [0,len)
72         // return some position p where [p,len) need to be filled
73         // (p >= len if nothing needs to be filled)
74         int get_data(const int index, Qfloat **data, int len);
75         void swap_index(int i, int j);  
76 private:
77         int l;
78         long int size;
79         struct head_t
80         {
81                 head_t *prev, *next;    // a circular list
82                 Qfloat *data;
83                 int len;                // data[0,len) is cached in this entry
84         };
85
86         head_t *head;
87         head_t lru_head;
88         void lru_delete(head_t *h);
89         void lru_insert(head_t *h);
90 };
91
92 Cache::Cache(int l_,long int size_):l(l_),size(size_)
93 {
94         head = (head_t *)calloc(l,sizeof(head_t));      // initialized to 0
95         size /= sizeof(Qfloat);
96         size -= l * sizeof(head_t) / sizeof(Qfloat);
97         size = max(size, 2 * (long int) l);     // cache must be large enough for two columns
98         lru_head.next = lru_head.prev = &lru_head;
99 }
100
101 Cache::~Cache()
102 {
103         for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
104                 free(h->data);
105         free(head);
106 }
107
108 void Cache::lru_delete(head_t *h)
109 {
110         // delete from current location
111         h->prev->next = h->next;
112         h->next->prev = h->prev;
113 }
114
115 void Cache::lru_insert(head_t *h)
116 {
117         // insert to last position
118         h->next = &lru_head;
119         h->prev = lru_head.prev;
120         h->prev->next = h;
121         h->next->prev = h;
122 }
123
124 int Cache::get_data(const int index, Qfloat **data, int len)
125 {
126         head_t *h = &head[index];
127         if(h->len) lru_delete(h);
128         int more = len - h->len;
129
130         if(more > 0)
131         {
132                 // free old space
133                 while(size < more)
134                 {
135                         head_t *old = lru_head.next;
136                         lru_delete(old);
137                         free(old->data);
138                         size += old->len;
139                         old->data = 0;
140                         old->len = 0;
141                 }
142
143                 // allocate new space
144                 h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
145                 size -= more;
146                 swap(h->len,len);
147         }
148
149         lru_insert(h);
150         *data = h->data;
151         return len;
152 }
153
154 void Cache::swap_index(int i, int j)
155 {
156         if(i==j) return;
157
158         if(head[i].len) lru_delete(&head[i]);
159         if(head[j].len) lru_delete(&head[j]);
160         swap(head[i].data,head[j].data);
161         swap(head[i].len,head[j].len);
162         if(head[i].len) lru_insert(&head[i]);
163         if(head[j].len) lru_insert(&head[j]);
164
165         if(i>j) swap(i,j);
166         for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
167         {
168                 if(h->len > i)
169                 {
170                         if(h->len > j)
171                                 swap(h->data[i],h->data[j]);
172                         else
173                         {
174                                 // give up
175                                 lru_delete(h);
176                                 free(h->data);
177                                 size += h->len;
178                                 h->data = 0;
179                                 h->len = 0;
180                         }
181                 }
182         }
183 }
184
185 //
186 // Kernel evaluation
187 //
188 // the static method k_function is for doing single kernel evaluation
189 // the constructor of Kernel prepares to calculate the l*l kernel matrix
190 // the member function get_Q is for getting one column from the Q Matrix
191 //
192 class QMatrix {
193 public:
194         virtual Qfloat *get_Q(int column, int len) const = 0;
195         virtual Qfloat *get_QD() const = 0;
196         virtual void swap_index(int i, int j) const = 0;
197         virtual ~QMatrix() {}
198 };
199
200 class Kernel: public QMatrix {
201 public:
202         Kernel(int l, svm_node * const * x, const svm_parameter& param);
203         virtual ~Kernel();
204
205         static double k_function(const svm_node *x, const svm_node *y,
206                                  const svm_parameter& param);
207         virtual Qfloat *get_Q(int column, int len) const = 0;
208         virtual Qfloat *get_QD() const = 0;
209         virtual void swap_index(int i, int j) const     // no so const...
210         {
211                 swap(x[i],x[j]);
212                 if(x_square) swap(x_square[i],x_square[j]);
213         }
214 protected:
215
216         double (Kernel::*kernel_function)(int i, int j) const;
217
218 private:
219         const svm_node **x;
220         double *x_square;
221
222         // svm_parameter
223         const int kernel_type;
224         const int degree;
225         const double gamma;
226         const double coef0;
227
228         static double dot(const svm_node *px, const svm_node *py);
229         double kernel_linear(int i, int j) const
230         {
231                 return dot(x[i],x[j]);
232         }
233         double kernel_poly(int i, int j) const
234         {
235                 return powi(gamma*dot(x[i],x[j])+coef0,degree);
236         }
237         double kernel_rbf(int i, int j) const
238         {
239                 return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
240         }
241         double kernel_sigmoid(int i, int j) const
242         {
243                 return tanh(gamma*dot(x[i],x[j])+coef0);
244         }
245         double kernel_precomputed(int i, int j) const
246         {
247                 return x[i][(int)(x[j][0].value)].value;
248         }
249 };
250
251 Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
252 :kernel_type(param.kernel_type), degree(param.degree),
253  gamma(param.gamma), coef0(param.coef0)
254 {
255         switch(kernel_type)
256         {
257                 case LINEAR:
258                         kernel_function = &Kernel::kernel_linear;
259                         break;
260                 case POLY:
261                         kernel_function = &Kernel::kernel_poly;
262                         break;
263                 case RBF:
264                         kernel_function = &Kernel::kernel_rbf;
265                         break;
266                 case SIGMOID:
267                         kernel_function = &Kernel::kernel_sigmoid;
268                         break;
269                 case PRECOMPUTED:
270                         kernel_function = &Kernel::kernel_precomputed;
271                         break;
272         }
273
274         clone(x,x_,l);
275
276         if(kernel_type == RBF)
277         {
278                 x_square = new double[l];
279                 for(int i=0;i<l;i++)
280                         x_square[i] = dot(x[i],x[i]);
281         }
282         else
283                 x_square = 0;
284 }
285
286 Kernel::~Kernel()
287 {
288         delete[] x;
289         delete[] x_square;
290 }
291
292 double Kernel::dot(const svm_node *px, const svm_node *py)
293 {
294         double sum = 0;
295         while(px->index != -1 && py->index != -1)
296         {
297                 if(px->index == py->index)
298                 {
299                         sum += px->value * py->value;
300                         ++px;
301                         ++py;
302                 }
303                 else
304                 {
305                         if(px->index > py->index)
306                                 ++py;
307                         else
308                                 ++px;
309                 }                       
310         }
311         return sum;
312 }
313
314 double Kernel::k_function(const svm_node *x, const svm_node *y,
315                           const svm_parameter& param)
316 {
317         switch(param.kernel_type)
318         {
319                 case LINEAR:
320                         return dot(x,y);
321                 case POLY:
322                         return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
323                 case RBF:
324                 {
325                         double sum = 0;
326                         while(x->index != -1 && y->index !=-1)
327                         {
328                                 if(x->index == y->index)
329                                 {
330                                         double d = x->value - y->value;
331                                         sum += d*d;
332                                         ++x;
333                                         ++y;
334                                 }
335                                 else
336                                 {
337                                         if(x->index > y->index)
338                                         {       
339                                                 sum += y->value * y->value;
340                                                 ++y;
341                                         }
342                                         else
343                                         {
344                                                 sum += x->value * x->value;
345                                                 ++x;
346                                         }
347                                 }
348                         }
349
350                         while(x->index != -1)
351                         {
352                                 sum += x->value * x->value;
353                                 ++x;
354                         }
355
356                         while(y->index != -1)
357                         {
358                                 sum += y->value * y->value;
359                                 ++y;
360                         }
361                         
362                         return exp(-param.gamma*sum);
363                 }
364                 case SIGMOID:
365                         return tanh(param.gamma*dot(x,y)+param.coef0);
366                 case PRECOMPUTED:  //x: test (validation), y: SV
367                         return x[(int)(y->value)].value;
368                 default:
369                         return 0;  // Unreachable 
370         }
371 }
372
373 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
374 // Solves:
375 //
376 //      min 0.5(\alpha^T Q \alpha) + p^T \alpha
377 //
378 //              y^T \alpha = \delta
379 //              y_i = +1 or -1
380 //              0 <= alpha_i <= Cp for y_i = 1
381 //              0 <= alpha_i <= Cn for y_i = -1
382 //
383 // Given:
384 //
385 //      Q, p, y, Cp, Cn, and an initial feasible point \alpha
386 //      l is the size of vectors and matrices
387 //      eps is the stopping tolerance
388 //
389 // solution will be put in \alpha, objective value will be put in obj
390 //
391 class Solver {
392 public:
393         Solver() {};
394         virtual ~Solver() {};
395
396         struct SolutionInfo {
397                 double obj;
398                 double rho;
399                 double upper_bound_p;
400                 double upper_bound_n;
401                 double r;       // for Solver_NU
402         };
403
404         void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
405                    double *alpha_, double Cp, double Cn, double eps,
406                    SolutionInfo* si, int shrinking);
407 protected:
408         int active_size;
409         schar *y;
410         double *G;              // gradient of objective function
411         enum { LOWER_BOUND, UPPER_BOUND, FREE };
412         char *alpha_status;     // LOWER_BOUND, UPPER_BOUND, FREE
413         double *alpha;
414         const QMatrix *Q;
415         const Qfloat *QD;
416         double eps;
417         double Cp,Cn;
418         double *p;
419         int *active_set;
420         double *G_bar;          // gradient, if we treat free variables as 0
421         int l;
422         bool unshrink;  // XXX
423
424         double get_C(int i)
425         {
426                 return (y[i] > 0)? Cp : Cn;
427         }
428         void update_alpha_status(int i)
429         {
430                 if(alpha[i] >= get_C(i))
431                         alpha_status[i] = UPPER_BOUND;
432                 else if(alpha[i] <= 0)
433                         alpha_status[i] = LOWER_BOUND;
434                 else alpha_status[i] = FREE;
435         }
436         bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
437         bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
438         bool is_free(int i) { return alpha_status[i] == FREE; }
439         void swap_index(int i, int j);
440         void reconstruct_gradient();
441         virtual int select_working_set(int &i, int &j);
442         virtual double calculate_rho();
443         virtual void do_shrinking();
444 private:
445         bool be_shrunk(int i, double Gmax1, double Gmax2);      
446 };
447
448 void Solver::swap_index(int i, int j)
449 {
450         Q->swap_index(i,j);
451         swap(y[i],y[j]);
452         swap(G[i],G[j]);
453         swap(alpha_status[i],alpha_status[j]);
454         swap(alpha[i],alpha[j]);
455         swap(p[i],p[j]);
456         swap(active_set[i],active_set[j]);
457         swap(G_bar[i],G_bar[j]);
458 }
459
460 void Solver::reconstruct_gradient()
461 {
462         // reconstruct inactive elements of G from G_bar and free variables
463
464         if(active_size == l) return;
465
466         int i,j;
467         int nr_free = 0;
468
469         for(j=active_size;j<l;j++)
470                 G[j] = G_bar[j] + p[j];
471
472         for(j=0;j<active_size;j++)
473                 if(is_free(j))
474                         nr_free++;
475
476         if(2*nr_free < active_size)
477                 info("\nWarning: using -h 0 may be faster\n");
478
479         if (nr_free*l > 2*active_size*(l-active_size))
480         {
481                 for(i=active_size;i<l;i++)
482                 {
483                         const Qfloat *Q_i = Q->get_Q(i,active_size);
484                         for(j=0;j<active_size;j++)
485                                 if(is_free(j))
486                                         G[i] += alpha[j] * Q_i[j];
487                 }
488         }
489         else
490         {
491                 for(i=0;i<active_size;i++)
492                         if(is_free(i))
493                         {
494                                 const Qfloat *Q_i = Q->get_Q(i,l);
495                                 double alpha_i = alpha[i];
496                                 for(j=active_size;j<l;j++)
497                                         G[j] += alpha_i * Q_i[j];
498                         }
499         }
500 }
501
// Runs the optimisation loop on the problem described in the class comment.
//
//   l          number of variables
//   Q          matrix accessor; columns are fetched with get_Q
//   p_, y_     linear term and labels; copied, the caller's arrays are not modified
//   alpha_     in: initial point; out: the solution, written back through active_set
//   Cp, Cn     upper bounds used for y_i > 0 and for the other labels
//   eps        stopping tolerance (used by select_working_set and do_shrinking)
//   si         receives rho, obj, upper_bound_p and upper_bound_n
//   shrinking  non-zero: do_shrinking() is called periodically
//
// All working arrays are allocated here and released before returning.
502 void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
503                    double *alpha_, double Cp, double Cn, double eps,
504                    SolutionInfo* si, int shrinking)
505 {
506         this->l = l;
507         this->Q = &Q;
508         QD=Q.get_QD();
509         clone(p, p_,l); // private working copies, deleted at the end of this function
510         clone(y, y_,l);
511         clone(alpha,alpha_,l);
512         this->Cp = Cp;
513         this->Cn = Cn;
514         this->eps = eps;
515         unshrink = false;
516
517         // initialize alpha_status
518         {
519                 alpha_status = new char[l];
520                 for(int i=0;i<l;i++)
521                         update_alpha_status(i);
522         }
523
524         // initialize active set (for shrinking)
525         {
526                 active_set = new int[l];
527                 for(int i=0;i<l;i++)
528                         active_set[i] = i; // identity mapping: nothing has been swapped yet
529                 active_size = l;
530         }
531
532         // initialize gradient
533         {
534                 G = new double[l];
535                 G_bar = new double[l];
536                 int i;
537                 for(i=0;i<l;i++)
538                 {
539                         G[i] = p[i];
540                         G_bar[i] = 0;
541                 }
542                 for(i=0;i<l;i++)
543                         if(!is_lower_bound(i)) // variables at the lower bound are skipped
544                         {
545                                 const Qfloat *Q_i = Q.get_Q(i,l);
546                                 double alpha_i = alpha[i];
547                                 int j;
548                                 for(j=0;j<l;j++)
549                                         G[j] += alpha_i*Q_i[j];
550                                 if(is_upper_bound(i)) // G_bar collects C_i * Q_i of upper-bounded variables
551                                         for(j=0;j<l;j++)
552                                                 G_bar[j] += get_C(i) * Q_i[j];
553                         }
554         }
555
556         // optimization step
557
558         int iter = 0;
559         int counter = min(l,1000)+1; // iterations left until the next progress/shrinking tick
560
561         while(1)
562         {
563                 // show progress and do shrinking
564
565                 if(--counter == 0)
566                 {
567                         counter = min(l,1000);
568                         if(shrinking) do_shrinking();
569                         info(".");
570                 }
571
572                 int i,j;
573                 if(select_working_set(i,j)!=0) // non-zero: no violating pair in the active set
574                 {
575                         // reconstruct the whole gradient
576                         reconstruct_gradient();
577                         // reset active set size and check
578                         active_size = l;
579                         info("*");
580                         if(select_working_set(i,j)!=0) // still none over all l variables: done
581                                 break;
582                         else
583                                 counter = 1;    // do shrinking next iteration
584                 }
585                 
586                 ++iter;
587
588                 // update alpha[i] and alpha[j], handle bounds carefully
589                 
590                 const Qfloat *Q_i = Q.get_Q(i,active_size);
591                 const Qfloat *Q_j = Q.get_Q(j,active_size);
592
593                 double C_i = get_C(i);
594                 double C_j = get_C(j);
595
596                 double old_alpha_i = alpha[i];
597                 double old_alpha_j = alpha[j];
598
599                 if(y[i]!=y[j]) // different labels: both move by +delta, alpha[i]-alpha[j] is preserved
600                 {
601                         double quad_coef = Q_i[i]+Q_j[j]+2*Q_i[j];
602                         if (quad_coef <= 0)
603                                 quad_coef = TAU; // keep the divisor strictly positive
604                         double delta = (-G[i]-G[j])/quad_coef;
605                         double diff = alpha[i] - alpha[j];
606                         alpha[i] += delta;
607                         alpha[j] += delta;
608                         
609                         if(diff > 0) // clip at the lower bound 0, keeping alpha[i]-alpha[j] == diff
610                         {
611                                 if(alpha[j] < 0)
612                                 {
613                                         alpha[j] = 0;
614                                         alpha[i] = diff;
615                                 }
616                         }
617                         else
618                         {
619                                 if(alpha[i] < 0)
620                                 {
621                                         alpha[i] = 0;
622                                         alpha[j] = -diff;
623                                 }
624                         }
625                         if(diff > C_i - C_j) // clip at the upper bounds, keeping alpha[i]-alpha[j] == diff
626                         {
627                                 if(alpha[i] > C_i)
628                                 {
629                                         alpha[i] = C_i;
630                                         alpha[j] = C_i - diff;
631                                 }
632                         }
633                         else
634                         {
635                                 if(alpha[j] > C_j)
636                                 {
637                                         alpha[j] = C_j;
638                                         alpha[i] = C_j + diff;
639                                 }
640                         }
641                 }
642                 else // equal labels: alpha[i] moves by -delta, alpha[j] by +delta, the sum is preserved
643                 {
644                         double quad_coef = Q_i[i]+Q_j[j]-2*Q_i[j];
645                         if (quad_coef <= 0)
646                                 quad_coef = TAU; // keep the divisor strictly positive
647                         double delta = (G[i]-G[j])/quad_coef;
648                         double sum = alpha[i] + alpha[j];
649                         alpha[i] -= delta;
650                         alpha[j] += delta;
651
652                         if(sum > C_i) // clip to [0,C], keeping alpha[i]+alpha[j] == sum
653                         {
654                                 if(alpha[i] > C_i)
655                                 {
656                                         alpha[i] = C_i;
657                                         alpha[j] = sum - C_i;
658                                 }
659                         }
660                         else
661                         {
662                                 if(alpha[j] < 0)
663                                 {
664                                         alpha[j] = 0;
665                                         alpha[i] = sum;
666                                 }
667                         }
668                         if(sum > C_j)
669                         {
670                                 if(alpha[j] > C_j)
671                                 {
672                                         alpha[j] = C_j;
673                                         alpha[i] = sum - C_j;
674                                 }
675                         }
676                         else
677                         {
678                                 if(alpha[i] < 0)
679                                 {
680                                         alpha[i] = 0;
681                                         alpha[j] = sum;
682                                 }
683                         }
684                 }
685
686                 // update G
687
688                 double delta_alpha_i = alpha[i] - old_alpha_i;
689                 double delta_alpha_j = alpha[j] - old_alpha_j;
690                 
691                 for(int k=0;k<active_size;k++) // only the active part of G is kept up to date
692                 {
693                         G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
694                 }
695
696                 // update alpha_status and G_bar
697
698                 {
699                         bool ui = is_upper_bound(i); // status before the update
700                         bool uj = is_upper_bound(j);
701                         update_alpha_status(i);
702                         update_alpha_status(j);
703                         int k;
704                         if(ui != is_upper_bound(i)) // i entered or left the upper bound
705                         {
706                                 Q_i = Q.get_Q(i,l); // G_bar spans all l entries, so the full column is fetched
707                                 if(ui)
708                                         for(k=0;k<l;k++)
709                                                 G_bar[k] -= C_i * Q_i[k];
710                                 else
711                                         for(k=0;k<l;k++)
712                                                 G_bar[k] += C_i * Q_i[k];
713                         }
714
715                         if(uj != is_upper_bound(j)) // same adjustment for j
716                         {
717                                 Q_j = Q.get_Q(j,l);
718                                 if(uj)
719                                         for(k=0;k<l;k++)
720                                                 G_bar[k] -= C_j * Q_j[k];
721                                 else
722                                         for(k=0;k<l;k++)
723                                                 G_bar[k] += C_j * Q_j[k];
724                         }
725                 }
726         }
727
728         // calculate rho
729
730         si->rho = calculate_rho();
731
732         // calculate objective value
733         {
734                 double v = 0;
735                 int i;
736                 for(i=0;i<l;i++)
737                         v += alpha[i] * (G[i] + p[i]);
738
739                 si->obj = v/2;
740         }
741
742         // put back the solution
743         {
744                 for(int i=0;i<l;i++)
745                         alpha_[active_set[i]] = alpha[i]; // active_set[i] is the original index of position i
746         }
747
748         // juggle everything back
749         /*{
750                 for(int i=0;i<l;i++)
751                         while(active_set[i] != i)
752                                 swap_index(i,active_set[i]);
753                                 // or Q.swap_index(i,active_set[i]);
754         }*/
755
756         si->upper_bound_p = Cp;
757         si->upper_bound_n = Cn;
758
759         info("\noptimization finished, #iter = %d\n",iter);
760
761         delete[] p;
762         delete[] y;
763         delete[] alpha;
764         delete[] alpha_status;
765         delete[] active_set;
766         delete[] G;
767         delete[] G_bar;
768 }
769
770 // return 1 if already optimal, return 0 otherwise
771 int Solver::select_working_set(int &out_i, int &out_j)
772 {
773         // return i,j such that
774         // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
775         // j: minimizes the decrease of obj value
776         //    (if quadratic coefficeint <= 0, replace it with tau)
777         //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
778         
779         double Gmax = -INF;
780         double Gmax2 = -INF;
781         int Gmax_idx = -1;
782         int Gmin_idx = -1;
783         double obj_diff_min = INF;
784
785         for(int t=0;t<active_size;t++)
786                 if(y[t]==+1)    
787                 {
788                         if(!is_upper_bound(t))
789                                 if(-G[t] >= Gmax)
790                                 {
791                                         Gmax = -G[t];
792                                         Gmax_idx = t;
793                                 }
794                 }
795                 else
796                 {
797                         if(!is_lower_bound(t))
798                                 if(G[t] >= Gmax)
799                                 {
800                                         Gmax = G[t];
801                                         Gmax_idx = t;
802                                 }
803                 }
804
805         int i = Gmax_idx;
806         const Qfloat *Q_i = NULL;
807         if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
808                 Q_i = Q->get_Q(i,active_size);
809
810         for(int j=0;j<active_size;j++)
811         {
812                 if(y[j]==+1)
813                 {
814                         if (!is_lower_bound(j))
815                         {
816                                 double grad_diff=Gmax+G[j];
817                                 if (G[j] >= Gmax2)
818                                         Gmax2 = G[j];
819                                 if (grad_diff > 0)
820                                 {
821                                         double obj_diff; 
822                                         double quad_coef=Q_i[i]+QD[j]-2.0*y[i]*Q_i[j];
823                                         if (quad_coef > 0)
824                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
825                                         else
826                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
827
828                                         if (obj_diff <= obj_diff_min)
829                                         {
830                                                 Gmin_idx=j;
831                                                 obj_diff_min = obj_diff;
832                                         }
833                                 }
834                         }
835                 }
836                 else
837                 {
838                         if (!is_upper_bound(j))
839                         {
840                                 double grad_diff= Gmax-G[j];
841                                 if (-G[j] >= Gmax2)
842                                         Gmax2 = -G[j];
843                                 if (grad_diff > 0)
844                                 {
845                                         double obj_diff; 
846                                         double quad_coef=Q_i[i]+QD[j]+2.0*y[i]*Q_i[j];
847                                         if (quad_coef > 0)
848                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
849                                         else
850                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
851
852                                         if (obj_diff <= obj_diff_min)
853                                         {
854                                                 Gmin_idx=j;
855                                                 obj_diff_min = obj_diff;
856                                         }
857                                 }
858                         }
859                 }
860         }
861
862         if(Gmax+Gmax2 < eps)
863                 return 1;
864
865         out_i = Gmax_idx;
866         out_j = Gmin_idx;
867         return 0;
868 }
869
870 bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
871 {
872         if(is_upper_bound(i))
873         {
874                 if(y[i]==+1)
875                         return(-G[i] > Gmax1);
876                 else
877                         return(-G[i] > Gmax2);
878         }
879         else if(is_lower_bound(i))
880         {
881                 if(y[i]==+1)
882                         return(G[i] > Gmax2);
883                 else    
884                         return(G[i] > Gmax1);
885         }
886         else
887                 return(false);
888 }
889
890 void Solver::do_shrinking()
891 {
892         int i;
893         double Gmax1 = -INF;            // max { -y_i * grad(f)_i | i in I_up(\alpha) }
894         double Gmax2 = -INF;            // max { y_i * grad(f)_i | i in I_low(\alpha) }
895
896         // find maximal violating pair first
897         for(i=0;i<active_size;i++)
898         {
899                 if(y[i]==+1)    
900                 {
901                         if(!is_upper_bound(i))  
902                         {
903                                 if(-G[i] >= Gmax1)
904                                         Gmax1 = -G[i];
905                         }
906                         if(!is_lower_bound(i))  
907                         {
908                                 if(G[i] >= Gmax2)
909                                         Gmax2 = G[i];
910                         }
911                 }
912                 else    
913                 {
914                         if(!is_upper_bound(i))  
915                         {
916                                 if(-G[i] >= Gmax2)
917                                         Gmax2 = -G[i];
918                         }
919                         if(!is_lower_bound(i))  
920                         {
921                                 if(G[i] >= Gmax1)
922                                         Gmax1 = G[i];
923                         }
924                 }
925         }
926
927         if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 
928         {
929                 unshrink = true;
930                 reconstruct_gradient();
931                 active_size = l;
932                 info("*");
933         }
934
935         for(i=0;i<active_size;i++)
936                 if (be_shrunk(i, Gmax1, Gmax2))
937                 {
938                         active_size--;
939                         while (active_size > i)
940                         {
941                                 if (!be_shrunk(active_size, Gmax1, Gmax2))
942                                 {
943                                         swap_index(i,active_size);
944                                         break;
945                                 }
946                                 active_size--;
947                         }
948                 }
949 }
950
951 double Solver::calculate_rho()
952 {
953         double r;
954         int nr_free = 0;
955         double ub = INF, lb = -INF, sum_free = 0;
956         for(int i=0;i<active_size;i++)
957         {
958                 double yG = y[i]*G[i];
959
960                 if(is_upper_bound(i))
961                 {
962                         if(y[i]==-1)
963                                 ub = min(ub,yG);
964                         else
965                                 lb = max(lb,yG);
966                 }
967                 else if(is_lower_bound(i))
968                 {
969                         if(y[i]==+1)
970                                 ub = min(ub,yG);
971                         else
972                                 lb = max(lb,yG);
973                 }
974                 else
975                 {
976                         ++nr_free;
977                         sum_free += yG;
978                 }
979         }
980
981         if(nr_free>0)
982                 r = sum_free/nr_free;
983         else
984                 r = (ub+lb)/2;
985
986         return r;
987 }
988
989 //
990 // Solver for nu-svm classification and regression
991 //
992 // additional constraint: e^T \alpha = constant
993 //
994 class Solver_NU : public Solver
995 {
996 public:
997         Solver_NU() {}
998         void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
999                    double *alpha, double Cp, double Cn, double eps,
1000                    SolutionInfo* si, int shrinking)
1001         {
1002                 this->si = si;
1003                 Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
1004         }
1005 private:
1006         SolutionInfo *si;
1007         int select_working_set(int &i, int &j);
1008         double calculate_rho();
1009         bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
1010         void do_shrinking();
1011 };
1012
1013 // return 1 if already optimal, return 0 otherwise
1014 int Solver_NU::select_working_set(int &out_i, int &out_j)
1015 {
1016         // return i,j such that y_i = y_j and
1017         // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1018         // j: minimizes the decrease of obj value
1019         //    (if quadratic coefficeint <= 0, replace it with tau)
1020         //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1021
1022         double Gmaxp = -INF;
1023         double Gmaxp2 = -INF;
1024         int Gmaxp_idx = -1;
1025
1026         double Gmaxn = -INF;
1027         double Gmaxn2 = -INF;
1028         int Gmaxn_idx = -1;
1029
1030         int Gmin_idx = -1;
1031         double obj_diff_min = INF;
1032
1033         for(int t=0;t<active_size;t++)
1034                 if(y[t]==+1)
1035                 {
1036                         if(!is_upper_bound(t))
1037                                 if(-G[t] >= Gmaxp)
1038                                 {
1039                                         Gmaxp = -G[t];
1040                                         Gmaxp_idx = t;
1041                                 }
1042                 }
1043                 else
1044                 {
1045                         if(!is_lower_bound(t))
1046                                 if(G[t] >= Gmaxn)
1047                                 {
1048                                         Gmaxn = G[t];
1049                                         Gmaxn_idx = t;
1050                                 }
1051                 }
1052
1053         int ip = Gmaxp_idx;
1054         int in = Gmaxn_idx;
1055         const Qfloat *Q_ip = NULL;
1056         const Qfloat *Q_in = NULL;
1057         if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1058                 Q_ip = Q->get_Q(ip,active_size);
1059         if(in != -1)
1060                 Q_in = Q->get_Q(in,active_size);
1061
1062         for(int j=0;j<active_size;j++)
1063         {
1064                 if(y[j]==+1)
1065                 {
1066                         if (!is_lower_bound(j)) 
1067                         {
1068                                 double grad_diff=Gmaxp+G[j];
1069                                 if (G[j] >= Gmaxp2)
1070                                         Gmaxp2 = G[j];
1071                                 if (grad_diff > 0)
1072                                 {
1073                                         double obj_diff; 
1074                                         double quad_coef = Q_ip[ip]+QD[j]-2*Q_ip[j];
1075                                         if (quad_coef > 0)
1076                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
1077                                         else
1078                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
1079
1080                                         if (obj_diff <= obj_diff_min)
1081                                         {
1082                                                 Gmin_idx=j;
1083                                                 obj_diff_min = obj_diff;
1084                                         }
1085                                 }
1086                         }
1087                 }
1088                 else
1089                 {
1090                         if (!is_upper_bound(j))
1091                         {
1092                                 double grad_diff=Gmaxn-G[j];
1093                                 if (-G[j] >= Gmaxn2)
1094                                         Gmaxn2 = -G[j];
1095                                 if (grad_diff > 0)
1096                                 {
1097                                         double obj_diff; 
1098                                         double quad_coef = Q_in[in]+QD[j]-2*Q_in[j];
1099                                         if (quad_coef > 0)
1100                                                 obj_diff = -(grad_diff*grad_diff)/quad_coef;
1101                                         else
1102                                                 obj_diff = -(grad_diff*grad_diff)/TAU;
1103
1104                                         if (obj_diff <= obj_diff_min)
1105                                         {
1106                                                 Gmin_idx=j;
1107                                                 obj_diff_min = obj_diff;
1108                                         }
1109                                 }
1110                         }
1111                 }
1112         }
1113
1114         if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1115                 return 1;
1116
1117         if (y[Gmin_idx] == +1)
1118                 out_i = Gmaxp_idx;
1119         else
1120                 out_i = Gmaxn_idx;
1121         out_j = Gmin_idx;
1122
1123         return 0;
1124 }
1125
1126 bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1127 {
1128         if(is_upper_bound(i))
1129         {
1130                 if(y[i]==+1)
1131                         return(-G[i] > Gmax1);
1132                 else    
1133                         return(-G[i] > Gmax4);
1134         }
1135         else if(is_lower_bound(i))
1136         {
1137                 if(y[i]==+1)
1138                         return(G[i] > Gmax2);
1139                 else    
1140                         return(G[i] > Gmax3);
1141         }
1142         else
1143                 return(false);
1144 }
1145
1146 void Solver_NU::do_shrinking()
1147 {
1148         double Gmax1 = -INF;    // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1149         double Gmax2 = -INF;    // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1150         double Gmax3 = -INF;    // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1151         double Gmax4 = -INF;    // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1152
1153         // find maximal violating pair first
1154         int i;
1155         for(i=0;i<active_size;i++)
1156         {
1157                 if(!is_upper_bound(i))
1158                 {
1159                         if(y[i]==+1)
1160                         {
1161                                 if(-G[i] > Gmax1) Gmax1 = -G[i];
1162                         }
1163                         else    if(-G[i] > Gmax4) Gmax4 = -G[i];
1164                 }
1165                 if(!is_lower_bound(i))
1166                 {
1167                         if(y[i]==+1)
1168                         {       
1169                                 if(G[i] > Gmax2) Gmax2 = G[i];
1170                         }
1171                         else    if(G[i] > Gmax3) Gmax3 = G[i];
1172                 }
1173         }
1174
1175         if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 
1176         {
1177                 unshrink = true;
1178                 reconstruct_gradient();
1179                 active_size = l;
1180         }
1181
1182         for(i=0;i<active_size;i++)
1183                 if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
1184                 {
1185                         active_size--;
1186                         while (active_size > i)
1187                         {
1188                                 if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1189                                 {
1190                                         swap_index(i,active_size);
1191                                         break;
1192                                 }
1193                                 active_size--;
1194                         }
1195                 }
1196 }
1197
1198 double Solver_NU::calculate_rho()
1199 {
1200         int nr_free1 = 0,nr_free2 = 0;
1201         double ub1 = INF, ub2 = INF;
1202         double lb1 = -INF, lb2 = -INF;
1203         double sum_free1 = 0, sum_free2 = 0;
1204
1205         for(int i=0;i<active_size;i++)
1206         {
1207                 if(y[i]==+1)
1208                 {
1209                         if(is_upper_bound(i))
1210                                 lb1 = max(lb1,G[i]);
1211                         else if(is_lower_bound(i))
1212                                 ub1 = min(ub1,G[i]);
1213                         else
1214                         {
1215                                 ++nr_free1;
1216                                 sum_free1 += G[i];
1217                         }
1218                 }
1219                 else
1220                 {
1221                         if(is_upper_bound(i))
1222                                 lb2 = max(lb2,G[i]);
1223                         else if(is_lower_bound(i))
1224                                 ub2 = min(ub2,G[i]);
1225                         else
1226                         {
1227                                 ++nr_free2;
1228                                 sum_free2 += G[i];
1229                         }
1230                 }
1231         }
1232
1233         double r1,r2;
1234         if(nr_free1 > 0)
1235                 r1 = sum_free1/nr_free1;
1236         else
1237                 r1 = (ub1+lb1)/2;
1238         
1239         if(nr_free2 > 0)
1240                 r2 = sum_free2/nr_free2;
1241         else
1242                 r2 = (ub2+lb2)/2;
1243         
1244         si->r = (r1+r2)/2;
1245         return (r1-r2)/2;
1246 }
1247
1248 //
1249 // Q matrices for various formulations
1250 //
1251 class SVC_Q: public Kernel
1252
1253 public:
1254         SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1255         :Kernel(prob.l, prob.x, param)
1256         {
1257                 clone(y,y_,prob.l);
1258                 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1259                 QD = new Qfloat[prob.l];
1260                 for(int i=0;i<prob.l;i++)
1261                         QD[i]= (Qfloat)(this->*kernel_function)(i,i);
1262         }
1263         
1264         Qfloat *get_Q(int i, int len) const
1265         {
1266                 Qfloat *data;
1267                 int start, j;
1268                 if((start = cache->get_data(i,&data,len)) < len)
1269                 {
1270                         for(j=start;j<len;j++)
1271                                 data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1272                 }
1273                 return data;
1274         }
1275
1276         Qfloat *get_QD() const
1277         {
1278                 return QD;
1279         }
1280
1281         void swap_index(int i, int j) const
1282         {
1283                 cache->swap_index(i,j);
1284                 Kernel::swap_index(i,j);
1285                 swap(y[i],y[j]);
1286                 swap(QD[i],QD[j]);
1287         }
1288
1289         ~SVC_Q()
1290         {
1291                 delete[] y;
1292                 delete cache;
1293                 delete[] QD;
1294         }
1295 private:
1296         schar *y;
1297         Cache *cache;
1298         Qfloat *QD;
1299 };
1300
1301 class ONE_CLASS_Q: public Kernel
1302 {
1303 public:
1304         ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1305         :Kernel(prob.l, prob.x, param)
1306         {
1307                 cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1308                 QD = new Qfloat[prob.l];
1309                 for(int i=0;i<prob.l;i++)
1310                         QD[i]= (Qfloat)(this->*kernel_function)(i,i);
1311         }
1312         
1313         Qfloat *get_Q(int i, int len) const
1314         {
1315                 Qfloat *data;
1316                 int start, j;
1317                 if((start = cache->get_data(i,&data,len)) < len)
1318                 {
1319                         for(j=start;j<len;j++)
1320                                 data[j] = (Qfloat)(this->*kernel_function)(i,j);
1321                 }
1322                 return data;
1323         }
1324
1325         Qfloat *get_QD() const
1326         {
1327                 return QD;
1328         }
1329
1330         void swap_index(int i, int j) const
1331         {
1332                 cache->swap_index(i,j);
1333                 Kernel::swap_index(i,j);
1334                 swap(QD[i],QD[j]);
1335         }
1336
1337         ~ONE_CLASS_Q()
1338         {
1339                 delete cache;
1340                 delete[] QD;
1341         }
1342 private:
1343         Cache *cache;
1344         Qfloat *QD;
1345 };
1346
1347 class SVR_Q: public Kernel
1348
1349 public:
1350         SVR_Q(const svm_problem& prob, const svm_parameter& param)
1351         :Kernel(prob.l, prob.x, param)
1352         {
1353                 l = prob.l;
1354                 cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1355                 QD = new Qfloat[2*l];
1356                 sign = new schar[2*l];
1357                 index = new int[2*l];
1358                 for(int k=0;k<l;k++)
1359                 {
1360                         sign[k] = 1;
1361                         sign[k+l] = -1;
1362                         index[k] = k;
1363                         index[k+l] = k;
1364                         QD[k]= (Qfloat)(this->*kernel_function)(k,k);
1365                         QD[k+l]=QD[k];
1366                 }
1367                 buffer[0] = new Qfloat[2*l];
1368                 buffer[1] = new Qfloat[2*l];
1369                 next_buffer = 0;
1370         }
1371
1372         void swap_index(int i, int j) const
1373         {
1374                 swap(sign[i],sign[j]);
1375                 swap(index[i],index[j]);
1376                 swap(QD[i],QD[j]);
1377         }
1378         
1379         Qfloat *get_Q(int i, int len) const
1380         {
1381                 Qfloat *data;
1382                 int j, real_i = index[i];
1383                 if(cache->get_data(real_i,&data,l) < l)
1384                 {
1385                         for(j=0;j<l;j++)
1386                                 data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1387                 }
1388
1389                 // reorder and copy
1390                 Qfloat *buf = buffer[next_buffer];
1391                 next_buffer = 1 - next_buffer;
1392                 schar si = sign[i];
1393                 for(j=0;j<len;j++)
1394                         buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1395                 return buf;
1396         }
1397
1398         Qfloat *get_QD() const
1399         {
1400                 return QD;
1401         }
1402
1403         ~SVR_Q()
1404         {
1405                 delete cache;
1406                 delete[] sign;
1407                 delete[] index;
1408                 delete[] buffer[0];
1409                 delete[] buffer[1];
1410                 delete[] QD;
1411         }
1412 private:
1413         int l;
1414         Cache *cache;
1415         schar *sign;
1416         int *index;
1417         mutable int next_buffer;
1418         Qfloat *buffer[2];
1419         Qfloat *QD;
1420 };
1421
1422 //
1423 // construct and solve various formulations
1424 //
1425 static void solve_c_svc(
1426         const svm_problem *prob, const svm_parameter* param,
1427         double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1428 {
1429         int l = prob->l;
1430         double *minus_ones = new double[l];
1431         schar *y = new schar[l];
1432
1433         int i;
1434
1435         for(i=0;i<l;i++)
1436         {
1437                 alpha[i] = 0;
1438                 minus_ones[i] = -1;
1439                 if(prob->y[i] > 0) y[i] = +1; else y[i]=-1;
1440         }
1441
1442         Solver s;
1443         s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1444                 alpha, Cp, Cn, param->eps, si, param->shrinking);
1445
1446         double sum_alpha=0;
1447         for(i=0;i<l;i++)
1448                 sum_alpha += alpha[i];
1449
1450         if (Cp==Cn)
1451                 info("nu = %f\n", sum_alpha/(Cp*prob->l));
1452
1453         for(i=0;i<l;i++)
1454                 alpha[i] *= y[i];
1455
1456         delete[] minus_ones;
1457         delete[] y;
1458 }
1459
1460 static void solve_nu_svc(
1461         const svm_problem *prob, const svm_parameter *param,
1462         double *alpha, Solver::SolutionInfo* si)
1463 {
1464         int i;
1465         int l = prob->l;
1466         double nu = param->nu;
1467
1468         schar *y = new schar[l];
1469
1470         for(i=0;i<l;i++)
1471                 if(prob->y[i]>0)
1472                         y[i] = +1;
1473                 else
1474                         y[i] = -1;
1475
1476         double sum_pos = nu*l/2;
1477         double sum_neg = nu*l/2;
1478
1479         for(i=0;i<l;i++)
1480                 if(y[i] == +1)
1481                 {
1482                         alpha[i] = min(1.0,sum_pos);
1483                         sum_pos -= alpha[i];
1484                 }
1485                 else
1486                 {
1487                         alpha[i] = min(1.0,sum_neg);
1488                         sum_neg -= alpha[i];
1489                 }
1490
1491         double *zeros = new double[l];
1492
1493         for(i=0;i<l;i++)
1494                 zeros[i] = 0;
1495
1496         Solver_NU s;
1497         s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1498                 alpha, 1.0, 1.0, param->eps, si,  param->shrinking);
1499         double r = si->r;
1500
1501         info("C = %f\n",1/r);
1502
1503         for(i=0;i<l;i++)
1504                 alpha[i] *= y[i]/r;
1505
1506         si->rho /= r;
1507         si->obj /= (r*r);
1508         si->upper_bound_p = 1/r;
1509         si->upper_bound_n = 1/r;
1510
1511         delete[] y;
1512         delete[] zeros;
1513 }
1514
1515 static void solve_one_class(
1516         const svm_problem *prob, const svm_parameter *param,
1517         double *alpha, Solver::SolutionInfo* si)
1518 {
1519         int l = prob->l;
1520         double *zeros = new double[l];
1521         schar *ones = new schar[l];
1522         int i;
1523
1524         int n = (int)(param->nu*prob->l);       // # of alpha's at upper bound
1525
1526         for(i=0;i<n;i++)
1527                 alpha[i] = 1;
1528         if(n<prob->l)
1529                 alpha[n] = param->nu * prob->l - n;
1530         for(i=n+1;i<l;i++)
1531                 alpha[i] = 0;
1532
1533         for(i=0;i<l;i++)
1534         {
1535                 zeros[i] = 0;
1536                 ones[i] = 1;
1537         }
1538
1539         Solver s;
1540         s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1541                 alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1542
1543         delete[] zeros;
1544         delete[] ones;
1545 }
1546
1547 static void solve_epsilon_svr(
1548         const svm_problem *prob, const svm_parameter *param,
1549         double *alpha, Solver::SolutionInfo* si)
1550 {
1551         int l = prob->l;
1552         double *alpha2 = new double[2*l];
1553         double *linear_term = new double[2*l];
1554         schar *y = new schar[2*l];
1555         int i;
1556
1557         for(i=0;i<l;i++)
1558         {
1559                 alpha2[i] = 0;
1560                 linear_term[i] = param->p - prob->y[i];
1561                 y[i] = 1;
1562
1563                 alpha2[i+l] = 0;
1564                 linear_term[i+l] = param->p + prob->y[i];
1565                 y[i+l] = -1;
1566         }
1567
1568         Solver s;
1569         s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1570                 alpha2, param->C, param->C, param->eps, si, param->shrinking);
1571
1572         double sum_alpha = 0;
1573         for(i=0;i<l;i++)
1574         {
1575                 alpha[i] = alpha2[i] - alpha2[i+l];
1576                 sum_alpha += fabs(alpha[i]);
1577         }
1578         info("nu = %f\n",sum_alpha/(param->C*l));
1579
1580         delete[] alpha2;
1581         delete[] linear_term;
1582         delete[] y;
1583 }
1584
1585 static void solve_nu_svr(
1586         const svm_problem *prob, const svm_parameter *param,
1587         double *alpha, Solver::SolutionInfo* si)
1588 {
1589         int l = prob->l;
1590         double C = param->C;
1591         double *alpha2 = new double[2*l];
1592         double *linear_term = new double[2*l];
1593         schar *y = new schar[2*l];
1594         int i;
1595
1596         double sum = C * param->nu * l / 2;
1597         for(i=0;i<l;i++)
1598         {
1599                 alpha2[i] = alpha2[i+l] = min(sum,C);
1600                 sum -= alpha2[i];
1601
1602                 linear_term[i] = - prob->y[i];
1603                 y[i] = 1;
1604
1605                 linear_term[i+l] = prob->y[i];
1606                 y[i+l] = -1;
1607         }
1608
1609         Solver_NU s;
1610         s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1611                 alpha2, C, C, param->eps, si, param->shrinking);
1612
1613         info("epsilon = %f\n",-si->r);
1614
1615         for(i=0;i<l;i++)
1616                 alpha[i] = alpha2[i] - alpha2[i+l];
1617
1618         delete[] alpha2;
1619         delete[] linear_term;
1620         delete[] y;
1621 }
1622
//
// decision_function
//
// Result of training one problem, as returned by svm_train_one().
//
struct decision_function
{
        double *alpha;  // one coefficient per training point (malloc'ed by svm_train_one)
        double rho;     // offset taken from the solver's SolutionInfo
};
1631
1632 static decision_function svm_train_one(
1633         const svm_problem *prob, const svm_parameter *param,
1634         double Cp, double Cn)
1635 {
1636         double *alpha = Malloc(double,prob->l);
1637         Solver::SolutionInfo si;
1638         switch(param->svm_type)
1639         {
1640                 case C_SVC:
1641                         solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1642                         break;
1643                 case NU_SVC:
1644                         solve_nu_svc(prob,param,alpha,&si);
1645                         break;
1646                 case ONE_CLASS:
1647                         solve_one_class(prob,param,alpha,&si);
1648                         break;
1649                 case EPSILON_SVR:
1650                         solve_epsilon_svr(prob,param,alpha,&si);
1651                         break;
1652                 case NU_SVR:
1653                         solve_nu_svr(prob,param,alpha,&si);
1654                         break;
1655         }
1656
1657         info("obj = %f, rho = %f\n",si.obj,si.rho);
1658
1659         // output SVs
1660
1661         int nSV = 0;
1662         int nBSV = 0;
1663         for(int i=0;i<prob->l;i++)
1664         {
1665                 if(fabs(alpha[i]) > 0)
1666                 {
1667                         ++nSV;
1668                         if(prob->y[i] > 0)
1669                         {
1670                                 if(fabs(alpha[i]) >= si.upper_bound_p)
1671                                         ++nBSV;
1672                         }
1673                         else
1674                         {
1675                                 if(fabs(alpha[i]) >= si.upper_bound_n)
1676                                         ++nBSV;
1677                         }
1678                 }
1679         }
1680
1681         info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1682
1683         decision_function f;
1684         f.alpha = alpha;
1685         f.rho = si.rho;
1686         return f;
1687 }
1688
1689 //
1690 // svm_model
1691 // 
1692 struct svm_model
1693 {
1694         struct svm_parameter param;     /* parameter */
1695         int nr_class;           /* number of classes, = 2 in regression/one class svm */
1696         int l;                  /* total #SV */
1697         struct svm_node **SV;           /* SVs (SV[l]) */
1698         double **sv_coef;       /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
1699         double *rho;            /* constants in decision functions (rho[k*(k-1)/2]) */
1700         double *probA;          /* pariwise probability information */
1701         double *probB;
1702
1703         /* for classification only */
1704
1705         int *label;             /* label of each class (label[k]) */
1706         int *nSV;               /* number of SVs for each class (nSV[k]) */
1707                                 /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
1708         /* XXX */
1709         int free_sv;            /* 1 if svm_model is created by svm_load_model*/
1710                                 /* 0 if svm_model is created by svm_train */
1711 };
1712
1713 // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
1714 static void sigmoid_train(
1715         int l, const double *dec_values, const double *labels, 
1716         double& A, double& B)
1717 {
1718         double prior1=0, prior0 = 0;
1719         int i;
1720
1721         for (i=0;i<l;i++)
1722                 if (labels[i] > 0) prior1+=1;
1723                 else prior0+=1;
1724         
1725         int max_iter=100;       // Maximal number of iterations
1726         double min_step=1e-10;  // Minimal step taken in line search
1727         double sigma=1e-12;     // For numerically strict PD of Hessian
1728         double eps=1e-5;
1729         double hiTarget=(prior1+1.0)/(prior1+2.0);
1730         double loTarget=1/(prior0+2.0);
1731         double *t=Malloc(double,l);
1732         double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
1733         double newA,newB,newf,d1,d2;
1734         int iter; 
1735         
1736         // Initial Point and Initial Fun Value
1737         A=0.0; B=log((prior0+1.0)/(prior1+1.0));
1738         double fval = 0.0;
1739
1740         for (i=0;i<l;i++)
1741         {
1742                 if (labels[i]>0) t[i]=hiTarget;
1743                 else t[i]=loTarget;
1744                 fApB = dec_values[i]*A+B;
1745                 if (fApB>=0)
1746                         fval += t[i]*fApB + log(1+exp(-fApB));
1747                 else
1748                         fval += (t[i] - 1)*fApB +log(1+exp(fApB));
1749         }
1750         for (iter=0;iter<max_iter;iter++)
1751         {
1752                 // Update Gradient and Hessian (use H' = H + sigma I)
1753                 h11=sigma; // numerically ensures strict PD
1754                 h22=sigma;
1755                 h21=0.0;g1=0.0;g2=0.0;
1756                 for (i=0;i<l;i++)
1757                 {
1758                         fApB = dec_values[i]*A+B;
1759                         if (fApB >= 0)
1760                         {
1761                                 p=exp(-fApB)/(1.0+exp(-fApB));
1762                                 q=1.0/(1.0+exp(-fApB));
1763                         }
1764                         else
1765                         {
1766                                 p=1.0/(1.0+exp(fApB));
1767                                 q=exp(fApB)/(1.0+exp(fApB));
1768                         }
1769                         d2=p*q;
1770                         h11+=dec_values[i]*dec_values[i]*d2;
1771                         h22+=d2;
1772                         h21+=dec_values[i]*d2;
1773                         d1=t[i]-p;
1774                         g1+=dec_values[i]*d1;
1775                         g2+=d1;
1776                 }
1777
1778                 // Stopping Criteria
1779                 if (fabs(g1)<eps && fabs(g2)<eps)
1780                         break;
1781
1782                 // Finding Newton direction: -inv(H') * g
1783                 det=h11*h22-h21*h21;
1784                 dA=-(h22*g1 - h21 * g2) / det;
1785                 dB=-(-h21*g1+ h11 * g2) / det;
1786                 gd=g1*dA+g2*dB;
1787
1788
1789                 stepsize = 1;           // Line Search
1790                 while (stepsize >= min_step)
1791                 {
1792                         newA = A + stepsize * dA;
1793                         newB = B + stepsize * dB;
1794
1795                         // New function value
1796                         newf = 0.0;
1797                         for (i=0;i<l;i++)
1798                         {
1799                                 fApB = dec_values[i]*newA+newB;
1800                                 if (fApB >= 0)
1801                                         newf += t[i]*fApB + log(1+exp(-fApB));
1802                                 else
1803                                         newf += (t[i] - 1)*fApB +log(1+exp(fApB));
1804                         }
1805                         // Check sufficient decrease
1806                         if (newf<fval+0.0001*stepsize*gd)
1807                         {
1808                                 A=newA;B=newB;fval=newf;
1809                                 break;
1810                         }
1811                         else
1812                                 stepsize = stepsize / 2.0;
1813                 }
1814
1815                 if (stepsize < min_step)
1816                 {
1817                         info("Line search fails in two-class probability estimates\n");
1818                         break;
1819                 }
1820         }
1821
1822         if (iter>=max_iter)
1823                 info("Reaching maximal iterations in two-class probability estimates\n");
1824         free(t);
1825 }
1826
// Evaluates the fitted sigmoid 1/(1+exp(decision_value*A+B)).
// The two branches compute the same quantity; each one only ever calls
// exp() on a non-positive argument, so the exponential cannot overflow.
static double sigmoid_predict(double decision_value, double A, double B)
{
	const double z = decision_value*A+B;
	if (z < 0)
		return 1.0/(1+exp(z));

	const double e = exp(-z);
	return e/(1.0+e);
}
1835
1836 // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
1837 static void multiclass_probability(int k, double **r, double *p)
1838 {
1839         int t,j;
1840         int iter = 0, max_iter=max(100,k);
1841         double **Q=Malloc(double *,k);
1842         double *Qp=Malloc(double,k);
1843         double pQp, eps=0.005/k;
1844         
1845         for (t=0;t<k;t++)
1846         {
1847                 p[t]=1.0/k;  // Valid if k = 1
1848                 Q[t]=Malloc(double,k);
1849                 Q[t][t]=0;
1850                 for (j=0;j<t;j++)
1851                 {
1852                         Q[t][t]+=r[j][t]*r[j][t];
1853                         Q[t][j]=Q[j][t];
1854                 }
1855                 for (j=t+1;j<k;j++)
1856                 {
1857                         Q[t][t]+=r[j][t]*r[j][t];
1858                         Q[t][j]=-r[j][t]*r[t][j];
1859                 }
1860         }
1861         for (iter=0;iter<max_iter;iter++)
1862         {
1863                 // stopping condition, recalculate QP,pQP for numerical accuracy
1864                 pQp=0;
1865                 for (t=0;t<k;t++)
1866                 {
1867                         Qp[t]=0;
1868                         for (j=0;j<k;j++)
1869                                 Qp[t]+=Q[t][j]*p[j];
1870                         pQp+=p[t]*Qp[t];
1871                 }
1872                 double max_error=0;
1873                 for (t=0;t<k;t++)
1874                 {
1875                         double error=fabs(Qp[t]-pQp);
1876                         if (error>max_error)
1877                                 max_error=error;
1878                 }
1879                 if (max_error<eps) break;
1880                 
1881                 for (t=0;t<k;t++)
1882                 {
1883                         double diff=(-Qp[t]+pQp)/Q[t][t];
1884                         p[t]+=diff;
1885                         pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
1886                         for (j=0;j<k;j++)
1887                         {
1888                                 Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
1889                                 p[j]/=(1+diff);
1890                         }
1891                 }
1892         }
1893         if (iter>=max_iter)
1894                 info("Exceeds max_iter in multiclass_prob\n");
1895         for(t=0;t<k;t++) free(Q[t]);
1896         free(Q);
1897         free(Qp);
1898 }
1899
1900 // Cross-validation decision values for probability estimates
1901 static void svm_binary_svc_probability(
1902         const svm_problem *prob, const svm_parameter *param,
1903         double Cp, double Cn, double& probA, double& probB)
1904 {
1905         int i;
1906         int nr_fold = 5;
1907         int *perm = Malloc(int,prob->l);
1908         double *dec_values = Malloc(double,prob->l);
1909
1910         // random shuffle
1911         for(i=0;i<prob->l;i++) perm[i]=i;
1912         for(i=0;i<prob->l;i++)
1913         {
1914                 int j = i+rand()%(prob->l-i);
1915                 swap(perm[i],perm[j]);
1916         }
1917         for(i=0;i<nr_fold;i++)
1918         {
1919                 int begin = i*prob->l/nr_fold;
1920                 int end = (i+1)*prob->l/nr_fold;
1921                 int j,k;
1922                 struct svm_problem subprob;
1923
1924                 subprob.l = prob->l-(end-begin);
1925                 subprob.x = Malloc(struct svm_node*,subprob.l);
1926                 subprob.y = Malloc(double,subprob.l);
1927                         
1928                 k=0;
1929                 for(j=0;j<begin;j++)
1930                 {
1931                         subprob.x[k] = prob->x[perm[j]];
1932                         subprob.y[k] = prob->y[perm[j]];
1933                         ++k;
1934                 }
1935                 for(j=end;j<prob->l;j++)
1936                 {
1937                         subprob.x[k] = prob->x[perm[j]];
1938                         subprob.y[k] = prob->y[perm[j]];
1939                         ++k;
1940                 }
1941                 int p_count=0,n_count=0;
1942                 for(j=0;j<k;j++)
1943                         if(subprob.y[j]>0)
1944                                 p_count++;
1945                         else
1946                                 n_count++;
1947
1948                 if(p_count==0 && n_count==0)
1949                         for(j=begin;j<end;j++)
1950                                 dec_values[perm[j]] = 0;
1951                 else if(p_count > 0 && n_count == 0)
1952                         for(j=begin;j<end;j++)
1953                                 dec_values[perm[j]] = 1;
1954                 else if(p_count == 0 && n_count > 0)
1955                         for(j=begin;j<end;j++)
1956                                 dec_values[perm[j]] = -1;
1957                 else
1958                 {
1959                         svm_parameter subparam = *param;
1960                         subparam.probability=0;
1961                         subparam.C=1.0;
1962                         subparam.nr_weight=2;
1963                         subparam.weight_label = Malloc(int,2);
1964                         subparam.weight = Malloc(double,2);
1965                         subparam.weight_label[0]=+1;
1966                         subparam.weight_label[1]=-1;
1967                         subparam.weight[0]=Cp;
1968                         subparam.weight[1]=Cn;
1969                         struct svm_model *submodel = svm_train(&subprob,&subparam);
1970                         for(j=begin;j<end;j++)
1971                         {
1972                                 svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); 
1973                                 // ensure +1 -1 order; reason not using CV subroutine
1974                                 dec_values[perm[j]] *= submodel->label[0];
1975                         }               
1976                         svm_destroy_model(submodel);
1977                         svm_destroy_param(&subparam);
1978                 }
1979                 free(subprob.x);
1980                 free(subprob.y);
1981         }               
1982         sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
1983         free(dec_values);
1984         free(perm);
1985 }
1986
1987 // Return parameter of a Laplace distribution 
1988 static double svm_svr_probability(
1989         const svm_problem *prob, const svm_parameter *param)
1990 {
1991         int i;
1992         int nr_fold = 5;
1993         double *ymv = Malloc(double,prob->l);
1994         double mae = 0;
1995
1996         svm_parameter newparam = *param;
1997         newparam.probability = 0;
1998         svm_cross_validation(prob,&newparam,nr_fold,ymv);
1999         for(i=0;i<prob->l;i++)
2000         {
2001                 ymv[i]=prob->y[i]-ymv[i];
2002                 mae += fabs(ymv[i]);
2003         }               
2004         mae /= prob->l;
2005         double std=sqrt(2*mae*mae);
2006         int count=0;
2007         mae=0;
2008         for(i=0;i<prob->l;i++)
2009                 if (fabs(ymv[i]) > 5*std) 
2010                         count=count+1;
2011                 else 
2012                         mae+=fabs(ymv[i]);
2013         mae /= (prob->l-count);
2014         info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
2015         free(ymv);
2016         return mae;
2017 }
2018
2019
2020 // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
2021 // perm, length l, must be allocated before calling this subroutine
2022 static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
2023 {
2024         int l = prob->l;
2025         int max_nr_class = 16;
2026         int nr_class = 0;
2027         int *label = Malloc(int,max_nr_class);
2028         int *count = Malloc(int,max_nr_class);
2029         int *data_label = Malloc(int,l);        
2030         int i;
2031
2032         for(i=0;i<l;i++)
2033         {
2034                 int this_label = (int)prob->y[i];
2035                 int j;
2036                 for(j=0;j<nr_class;j++)
2037                 {
2038                         if(this_label == label[j])
2039                         {
2040                                 ++count[j];
2041                                 break;
2042                         }
2043                 }
2044                 data_label[i] = j;
2045                 if(j == nr_class)
2046                 {
2047                         if(nr_class == max_nr_class)
2048                         {
2049                                 max_nr_class *= 2;
2050                                 label = (int *)realloc(label,max_nr_class*sizeof(int));
2051                                 count = (int *)realloc(count,max_nr_class*sizeof(int));
2052                         }
2053                         label[nr_class] = this_label;
2054                         count[nr_class] = 1;
2055                         ++nr_class;
2056                 }
2057         }
2058
2059         int *start = Malloc(int,nr_class);
2060         start[0] = 0;
2061         for(i=1;i<nr_class;i++)
2062                 start[i] = start[i-1]+count[i-1];
2063         for(i=0;i<l;i++)
2064         {
2065                 perm[start[data_label[i]]] = i;
2066                 ++start[data_label[i]];
2067         }
2068         start[0] = 0;
2069         for(i=1;i<nr_class;i++)
2070                 start[i] = start[i-1]+count[i-1];
2071
2072         *nr_class_ret = nr_class;
2073         *label_ret = label;
2074         *start_ret = start;
2075         *count_ret = count;
2076         free(data_label);
2077 }
2078
2079 //
2080 // Interface functions
2081 //
2082 svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2083 {
2084         svm_model *model = Malloc(svm_model,1);
2085         model->param = *param;
2086         model->free_sv = 0;     // XXX
2087
2088         if(param->svm_type == ONE_CLASS ||
2089            param->svm_type == EPSILON_SVR ||
2090            param->svm_type == NU_SVR)
2091         {
2092                 // regression or one-class-svm
2093                 model->nr_class = 2;
2094                 model->label = NULL;
2095                 model->nSV = NULL;
2096                 model->probA = NULL; model->probB = NULL;
2097                 model->sv_coef = Malloc(double *,1);
2098
2099                 if(param->probability && 
2100                    (param->svm_type == EPSILON_SVR ||
2101                     param->svm_type == NU_SVR))
2102                 {
2103                         model->probA = Malloc(double,1);
2104                         model->probA[0] = svm_svr_probability(prob,param);
2105                 }
2106
2107                 decision_function f = svm_train_one(prob,param,0,0);
2108                 model->rho = Malloc(double,1);
2109                 model->rho[0] = f.rho;
2110
2111                 int nSV = 0;
2112                 int i;
2113                 for(i=0;i<prob->l;i++)
2114                         if(fabs(f.alpha[i]) > 0) ++nSV;
2115                 model->l = nSV;
2116                 model->SV = Malloc(svm_node *,nSV);
2117                 model->sv_coef[0] = Malloc(double,nSV);
2118                 int j = 0;
2119                 for(i=0;i<prob->l;i++)
2120                         if(fabs(f.alpha[i]) > 0)
2121                         {
2122                                 model->SV[j] = prob->x[i];
2123                                 model->sv_coef[0][j] = f.alpha[i];
2124                                 ++j;
2125                         }               
2126
2127                 free(f.alpha);
2128         }
2129         else
2130         {
2131                 // classification
2132                 int l = prob->l;
2133                 int nr_class;
2134                 int *label = NULL;
2135                 int *start = NULL;
2136                 int *count = NULL;
2137                 int *perm = Malloc(int,l);
2138
2139                 // group training data of the same class
2140                 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);            
2141                 svm_node **x = Malloc(svm_node *,l);
2142                 int i;
2143                 for(i=0;i<l;i++)
2144                         x[i] = prob->x[perm[i]];
2145
2146                 // calculate weighted C
2147
2148                 double *weighted_C = Malloc(double, nr_class);
2149                 for(i=0;i<nr_class;i++)
2150                         weighted_C[i] = param->C;
2151                 for(i=0;i<param->nr_weight;i++)
2152                 {       
2153                         int j;
2154                         for(j=0;j<nr_class;j++)
2155                                 if(param->weight_label[i] == label[j])
2156                                         break;
2157                         if(j == nr_class)
2158                                 fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
2159                         else
2160                                 weighted_C[j] *= param->weight[i];
2161                 }
2162
2163                 // train k*(k-1)/2 models
2164                 
2165                 bool *nonzero = Malloc(bool,l);
2166                 for(i=0;i<l;i++)
2167                         nonzero[i] = false;
2168                 decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
2169
2170                 double *probA=NULL,*probB=NULL;
2171                 if (param->probability)
2172                 {
2173                         probA=Malloc(double,nr_class*(nr_class-1)/2);
2174                         probB=Malloc(double,nr_class*(nr_class-1)/2);
2175                 }
2176
2177                 int p = 0;
2178                 for(i=0;i<nr_class;i++)
2179                         for(int j=i+1;j<nr_class;j++)
2180                         {
2181                                 svm_problem sub_prob;
2182                                 int si = start[i], sj = start[j];
2183                                 int ci = count[i], cj = count[j];
2184                                 sub_prob.l = ci+cj;
2185                                 sub_prob.x = Malloc(svm_node *,sub_prob.l);
2186                                 sub_prob.y = Malloc(double,sub_prob.l);
2187                                 int k;
2188                                 for(k=0;k<ci;k++)
2189                                 {
2190                                         sub_prob.x[k] = x[si+k];
2191                                         sub_prob.y[k] = +1;
2192                                 }
2193                                 for(k=0;k<cj;k++)
2194                                 {
2195                                         sub_prob.x[ci+k] = x[sj+k];
2196                                         sub_prob.y[ci+k] = -1;
2197                                 }
2198
2199                                 if(param->probability)
2200                                         svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2201
2202                                 f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2203                                 for(k=0;k<ci;k++)
2204                                         if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2205                                                 nonzero[si+k] = true;
2206                                 for(k=0;k<cj;k++)
2207                                         if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2208                                                 nonzero[sj+k] = true;
2209                                 free(sub_prob.x);
2210                                 free(sub_prob.y);
2211                                 ++p;
2212                         }
2213
2214                 // build output
2215
2216                 model->nr_class = nr_class;
2217                 
2218                 model->label = Malloc(int,nr_class);
2219                 for(i=0;i<nr_class;i++)
2220                         model->label[i] = label[i];
2221                 
2222                 model->rho = Malloc(double,nr_class*(nr_class-1)/2);
2223                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
2224                         model->rho[i] = f[i].rho;
2225
2226                 if(param->probability)
2227                 {
2228                         model->probA = Malloc(double,nr_class*(nr_class-1)/2);
2229                         model->probB = Malloc(double,nr_class*(nr_class-1)/2);
2230                         for(i=0;i<nr_class*(nr_class-1)/2;i++)
2231                         {
2232                                 model->probA[i] = probA[i];
2233                                 model->probB[i] = probB[i];
2234                         }
2235                 }
2236                 else
2237                 {
2238                         model->probA=NULL;
2239                         model->probB=NULL;
2240                 }
2241
2242                 int total_sv = 0;
2243                 int *nz_count = Malloc(int,nr_class);
2244                 model->nSV = Malloc(int,nr_class);
2245                 for(i=0;i<nr_class;i++)
2246                 {
2247                         int nSV = 0;
2248                         for(int j=0;j<count[i];j++)
2249                                 if(nonzero[start[i]+j])
2250                                 {       
2251                                         ++nSV;
2252                                         ++total_sv;
2253                                 }
2254                         model->nSV[i] = nSV;
2255                         nz_count[i] = nSV;
2256                 }
2257                 
2258                 info("Total nSV = %d\n",total_sv);
2259
2260                 model->l = total_sv;
2261                 model->SV = Malloc(svm_node *,total_sv);
2262                 p = 0;
2263                 for(i=0;i<l;i++)
2264                         if(nonzero[i]) model->SV[p++] = x[i];
2265
2266                 int *nz_start = Malloc(int,nr_class);
2267                 nz_start[0] = 0;
2268                 for(i=1;i<nr_class;i++)
2269                         nz_start[i] = nz_start[i-1]+nz_count[i-1];
2270
2271                 model->sv_coef = Malloc(double *,nr_class-1);
2272                 for(i=0;i<nr_class-1;i++)
2273                         model->sv_coef[i] = Malloc(double,total_sv);
2274
2275                 p = 0;
2276                 for(i=0;i<nr_class;i++)
2277                         for(int j=i+1;j<nr_class;j++)
2278                         {
2279                                 // classifier (i,j): coefficients with
2280                                 // i are in sv_coef[j-1][nz_start[i]...],
2281                                 // j are in sv_coef[i][nz_start[j]...]
2282
2283                                 int si = start[i];
2284                                 int sj = start[j];
2285                                 int ci = count[i];
2286                                 int cj = count[j];
2287                                 
2288                                 int q = nz_start[i];
2289                                 int k;
2290                                 for(k=0;k<ci;k++)
2291                                         if(nonzero[si+k])
2292                                                 model->sv_coef[j-1][q++] = f[p].alpha[k];
2293                                 q = nz_start[j];
2294                                 for(k=0;k<cj;k++)
2295                                         if(nonzero[sj+k])
2296                                                 model->sv_coef[i][q++] = f[p].alpha[ci+k];
2297                                 ++p;
2298                         }
2299                 
2300                 free(label);
2301                 free(probA);
2302                 free(probB);
2303                 free(count);
2304                 free(perm);
2305                 free(start);
2306                 free(x);
2307                 free(weighted_C);
2308                 free(nonzero);
2309                 for(i=0;i<nr_class*(nr_class-1)/2;i++)
2310                         free(f[i].alpha);
2311                 free(f);
2312                 free(nz_count);
2313                 free(nz_start);
2314         }
2315         return model;
2316 }
2317
2318 // Stratified cross validation
2319 void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2320 {
2321         int i;
2322         int *fold_start = Malloc(int,nr_fold+1);
2323         int l = prob->l;
2324         int *perm = Malloc(int,l);
2325         int nr_class;
2326
2327         // stratified cv may not give leave-one-out rate
2328         // Each class to l folds -> some folds may have zero elements
2329         if((param->svm_type == C_SVC ||
2330             param->svm_type == NU_SVC) && nr_fold < l)
2331         {
2332                 int *start = NULL;
2333                 int *label = NULL;
2334                 int *count = NULL;
2335                 svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2336
2337                 // random shuffle and then data grouped by fold using the array perm
2338                 int *fold_count = Malloc(int,nr_fold);
2339                 int c;
2340                 int *index = Malloc(int,l);
2341                 for(i=0;i<l;i++)
2342                         index[i]=perm[i];
2343                 for (c=0; c<nr_class; c++) 
2344                         for(i=0;i<count[c];i++)
2345                         {
2346                                 int j = i+rand()%(count[c]-i);
2347                                 swap(index[start[c]+j],index[start[c]+i]);
2348                         }
2349                 for(i=0;i<nr_fold;i++)
2350                 {
2351                         fold_count[i] = 0;
2352                         for (c=0; c<nr_class;c++)
2353                                 fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2354                 }
2355                 fold_start[0]=0;
2356                 for (i=1;i<=nr_fold;i++)
2357                         fold_start[i] = fold_start[i-1]+fold_count[i-1];
2358                 for (c=0; c<nr_class;c++)
2359                         for(i=0;i<nr_fold;i++)
2360                         {
2361                                 int begin = start[c]+i*count[c]/nr_fold;
2362                                 int end = start[c]+(i+1)*count[c]/nr_fold;
2363                                 for(int j=begin;j<end;j++)
2364                                 {
2365                                         perm[fold_start[i]] = index[j];
2366                                         fold_start[i]++;
2367                                 }
2368                         }
2369                 fold_start[0]=0;
2370                 for (i=1;i<=nr_fold;i++)
2371                         fold_start[i] = fold_start[i-1]+fold_count[i-1];
2372                 free(start);    
2373                 free(label);
2374                 free(count);    
2375                 free(index);
2376                 free(fold_count);
2377         }
2378         else
2379         {
2380                 for(i=0;i<l;i++) perm[i]=i;
2381                 for(i=0;i<l;i++)
2382                 {
2383                         int j = i+rand()%(l-i);
2384                         swap(perm[i],perm[j]);
2385                 }
2386                 for(i=0;i<=nr_fold;i++)
2387                         fold_start[i]=i*l/nr_fold;
2388         }
2389
2390         for(i=0;i<nr_fold;i++)
2391         {
2392                 int begin = fold_start[i];
2393                 int end = fold_start[i+1];
2394                 int j,k;
2395                 struct svm_problem subprob;
2396
2397                 subprob.l = l-(end-begin);
2398                 subprob.x = Malloc(struct svm_node*,subprob.l);
2399                 subprob.y = Malloc(double,subprob.l);
2400                         
2401                 k=0;
2402                 for(j=0;j<begin;j++)
2403                 {
2404                         subprob.x[k] = prob->x[perm[j]];
2405                         subprob.y[k] = prob->y[perm[j]];
2406                         ++k;
2407                 }
2408                 for(j=end;j<l;j++)
2409                 {
2410                         subprob.x[k] = prob->x[perm[j]];
2411                         subprob.y[k] = prob->y[perm[j]];
2412                         ++k;
2413                 }
2414                 struct svm_model *submodel = svm_train(&subprob,param);
2415                 if(param->probability && 
2416                    (param->svm_type == C_SVC || param->svm_type == NU_SVC))
2417                 {
2418                         double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
2419                         for(j=begin;j<end;j++)
2420                                 target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
2421                         free(prob_estimates);                   
2422                 }
2423                 else
2424                         for(j=begin;j<end;j++)
2425                                 target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2426                 svm_destroy_model(submodel);
2427                 free(subprob.x);
2428                 free(subprob.y);
2429         }               
2430         free(fold_start);
2431         free(perm);     
2432 }
2433
2434
2435 int svm_get_svm_type(const svm_model *model)
2436 {
2437         return model->param.svm_type;
2438 }
2439
2440 int svm_get_nr_class(const svm_model *model)
2441 {
2442         return model->nr_class;
2443 }
2444
2445 void svm_get_labels(const svm_model *model, int* label)
2446 {
2447         if (model->label != NULL)
2448                 for(int i=0;i<model->nr_class;i++)
2449                         label[i] = model->label[i];
2450 }
2451
2452 double svm_get_svr_probability(const svm_model *model)
2453 {
2454         if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
2455             model->probA!=NULL)
2456                 return model->probA[0];
2457         else
2458         {
2459                 fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
2460                 return 0;
2461         }
2462 }
2463
2464 double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2465 {
2466         if(model->param.svm_type == ONE_CLASS ||
2467            model->param.svm_type == EPSILON_SVR ||
2468            model->param.svm_type == NU_SVR)
2469         {
2470                 double *sv_coef = model->sv_coef[0];
2471                 double sum = 0;
2472                 for(int i=0;i<model->l;i++)
2473                         sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2474                 sum -= model->rho[0];
2475                 *dec_values = sum;
2476
2477                 if(model->param.svm_type == ONE_CLASS)
2478                         return (sum>0)?1:-1;
2479                 else
2480                         return sum;
2481         }
2482         else
2483         {
2484                 int i;
2485                 int nr_class = model->nr_class;
2486                 int l = model->l;
2487                 
2488                 double *kvalue = Malloc(double,l);
2489                 for(i=0;i<l;i++)
2490                         kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
2491
2492                 int *start = Malloc(int,nr_class);
2493                 start[0] = 0;
2494                 for(i=1;i<nr_class;i++)
2495                         start[i] = start[i-1]+model->nSV[i-1];
2496
2497                 int *vote = Malloc(int,nr_class);
2498                 for(i=0;i<nr_class;i++)
2499                         vote[i] = 0;
2500
2501                 int p=0;
2502                 for(i=0;i<nr_class;i++)
2503                         for(int j=i+1;j<nr_class;j++)
2504                         {
2505                                 double sum = 0;
2506                                 int si = start[i];
2507                                 int sj = start[j];
2508                                 int ci = model->nSV[i];
2509                                 int cj = model->nSV[j];
2510                                 
2511                                 int k;
2512                                 double *coef1 = model->sv_coef[j-1];
2513                                 double *coef2 = model->sv_coef[i];
2514                                 for(k=0;k<ci;k++)
2515                                         sum += coef1[si+k] * kvalue[si+k];
2516                                 for(k=0;k<cj;k++)
2517                                         sum += coef2[sj+k] * kvalue[sj+k];
2518                                 sum -= model->rho[p];
2519                                 dec_values[p] = sum;
2520
2521                                 if(dec_values[p] > 0)
2522                                         ++vote[i];
2523                                 else
2524                                         ++vote[j];
2525                                 p++;
2526                         }
2527
2528                 int vote_max_idx = 0;
2529                 for(i=1;i<nr_class;i++)
2530                         if(vote[i] > vote[vote_max_idx])
2531                                 vote_max_idx = i;
2532
2533                 free(kvalue);
2534                 free(start);
2535                 free(vote);
2536                 return model->label[vote_max_idx];
2537         }
2538 }
2539
2540 double svm_predict(const svm_model *model, const svm_node *x)
2541 {
2542         int nr_class = model->nr_class;
2543         double *dec_values;
2544         if(model->param.svm_type == ONE_CLASS ||
2545            model->param.svm_type == EPSILON_SVR ||
2546            model->param.svm_type == NU_SVR)
2547                 dec_values = Malloc(double, 1);
2548         else 
2549                 dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2550         double pred_result = svm_predict_values(model, x, dec_values);
2551         free(dec_values);
2552         return pred_result;
2553 }
2554
2555 double svm_predict_probability(
2556         const svm_model *model, const svm_node *x, double *prob_estimates)
2557 {
2558         if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
2559             model->probA!=NULL && model->probB!=NULL)
2560         {
2561                 int i;
2562                 int nr_class = model->nr_class;
2563                 double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2564                 svm_predict_values(model, x, dec_values);
2565
2566                 double min_prob=1e-7;
2567                 double **pairwise_prob=Malloc(double *,nr_class);
2568                 for(i=0;i<nr_class;i++)
2569                         pairwise_prob[i]=Malloc(double,nr_class);
2570                 int k=0;
2571                 for(i=0;i<nr_class;i++)
2572                         for(int j=i+1;j<nr_class;j++)
2573                         {
2574                                 pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
2575                                 pairwise_prob[j][i]=1-pairwise_prob[i][j];
2576                                 k++;
2577                         }
2578                 multiclass_probability(nr_class,pairwise_prob,prob_estimates);
2579
2580                 int prob_max_idx = 0;
2581                 for(i=1;i<nr_class;i++)
2582                         if(prob_estimates[i] > prob_estimates[prob_max_idx])
2583                                 prob_max_idx = i;
2584                 for(i=0;i<nr_class;i++)
2585                         free(pairwise_prob[i]);
2586                 free(dec_values);
2587                 free(pairwise_prob);         
2588                 return model->label[prob_max_idx];
2589         }
2590         else 
2591                 return svm_predict(model, x);
2592 }
2593
// Textual names of the svm_type values used in model files.
// svm_save_model indexes this table with param.svm_type and svm_load_model
// stores the index of the matching entry back into param.svm_type, so the
// order has to agree with the numeric svm_type constants (presumably the
// enum in svm.h -- confirm before reordering).
// The trailing NULL ends the lookup loop in svm_load_model.
static const char *svm_type_table[] =
{
        "c_svc",
        "nu_svc",
        "one_class",
        "epsilon_svr",
        "nu_svr",
        NULL
};
2598
// Textual names of the kernel_type values used in model files.
// svm_save_model indexes this table with param.kernel_type and
// svm_load_model stores the index of the matching entry back into
// param.kernel_type, so the order has to agree with the numeric kernel_type
// constants (presumably the enum in svm.h -- confirm before reordering).
// The trailing NULL ends the lookup loop in svm_load_model.
static const char *kernel_type_table[]=
{
        "linear",
        "polynomial",
        "rbf",
        "sigmoid",
        "precomputed",
        NULL
};
2603
2604 int svm_save_model(const char *model_file_name, const svm_model *model)
2605 {
2606         FILE *fp = fopen(model_file_name,"w");
2607         if(fp==NULL) return -1;
2608
2609         const svm_parameter& param = model->param;
2610
2611         fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2612         fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2613
2614         if(param.kernel_type == POLY)
2615                 fprintf(fp,"degree %d\n", param.degree);
2616
2617         if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
2618                 fprintf(fp,"gamma %g\n", param.gamma);
2619
2620         if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
2621                 fprintf(fp,"coef0 %g\n", param.coef0);
2622
2623         int nr_class = model->nr_class;
2624         int l = model->l;
2625         fprintf(fp, "nr_class %d\n", nr_class);
2626         fprintf(fp, "total_sv %d\n",l);
2627         
2628         {
2629                 fprintf(fp, "rho");
2630                 for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2631                         fprintf(fp," %g",model->rho[i]);
2632                 fprintf(fp, "\n");
2633         }
2634         
2635         if(model->label)
2636         {
2637                 fprintf(fp, "label");
2638                 for(int i=0;i<nr_class;i++)
2639                         fprintf(fp," %d",model->label[i]);
2640                 fprintf(fp, "\n");
2641         }
2642
2643         if(model->probA) // regression has probA only
2644         {
2645                 fprintf(fp, "probA");
2646                 for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2647                         fprintf(fp," %g",model->probA[i]);
2648                 fprintf(fp, "\n");
2649         }
2650         if(model->probB)
2651         {
2652                 fprintf(fp, "probB");
2653                 for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2654                         fprintf(fp," %g",model->probB[i]);
2655                 fprintf(fp, "\n");
2656         }
2657
2658         if(model->nSV)
2659         {
2660                 fprintf(fp, "nr_sv");
2661                 for(int i=0;i<nr_class;i++)
2662                         fprintf(fp," %d",model->nSV[i]);
2663                 fprintf(fp, "\n");
2664         }
2665
2666         fprintf(fp, "SV\n");
2667         const double * const *sv_coef = model->sv_coef;
2668         const svm_node * const *SV = model->SV;
2669
2670         for(int i=0;i<l;i++)
2671         {
2672                 for(int j=0;j<nr_class-1;j++)
2673                         fprintf(fp, "%.16g ",sv_coef[j][i]);
2674
2675                 const svm_node *p = SV[i];
2676
2677                 if(param.kernel_type == PRECOMPUTED)
2678                         fprintf(fp,"0:%d ",(int)(p->value));
2679                 else
2680                         while(p->index != -1)
2681                         {
2682                                 fprintf(fp,"%d:%.8g ",p->index,p->value);
2683                                 p++;
2684                         }
2685                 fprintf(fp, "\n");
2686         }
2687         if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2688         else return 0;
2689 }
2690
// Shared, growable line buffer used by readline(). The caller must point
// `line` at a malloc'ed buffer of `max_line_len` bytes before the first call
// and free it afterwards.
static char *line = NULL;
static int max_line_len;

// Reads the next complete line of `input` into `line`, doubling the buffer
// until the terminating '\n' (kept in the buffer) or end of input is reached.
//
// Returns `line`, or NULL when nothing could be read or when the buffer
// could not be enlarged. In the latter case `line` and `max_line_len` still
// describe the old, valid buffer holding the part read so far, so the
// caller's final free(line) stays correct.
static char* readline(FILE *input)
{
        if(fgets(line,max_line_len,input) == NULL)
                return NULL;

        while(strrchr(line,'\n') == NULL)
        {
                // grow through a temporary: assigning realloc's result straight
                // to `line` would lose the old buffer when it returns NULL
                const int grown_len = max_line_len * 2;
                char *grown = (char *) realloc(line,grown_len);
                if(grown == NULL)
                        return NULL;
                line = grown;
                max_line_len = grown_len;

                // continue reading right after what is already buffered
                const int len = (int) strlen(line);
                if(fgets(line+len,max_line_len-len,input) == NULL)
                        break;
        }
        return line;
}
2711
2712 svm_model *svm_load_model(const char *model_file_name)
2713 {
2714         FILE *fp = fopen(model_file_name,"rb");
2715         if(fp==NULL) return NULL;
2716         
2717         // read parameters
2718
2719         svm_model *model = Malloc(svm_model,1);
2720         svm_parameter& param = model->param;
2721         model->rho = NULL;
2722         model->probA = NULL;
2723         model->probB = NULL;
2724         model->label = NULL;
2725         model->nSV = NULL;
2726
2727         char cmd[81];
2728         while(1)
2729         {
2730                 fscanf(fp,"%80s",cmd);
2731
2732                 if(strcmp(cmd,"svm_type")==0)
2733                 {
2734                         fscanf(fp,"%80s",cmd);
2735                         int i;
2736                         for(i=0;svm_type_table[i];i++)
2737                         {
2738                                 if(strcmp(svm_type_table[i],cmd)==0)
2739                                 {
2740                                         param.svm_type=i;
2741                                         break;
2742                                 }
2743                         }
2744                         if(svm_type_table[i] == NULL)
2745                         {
2746                                 fprintf(stderr,"unknown svm type.\n");
2747                                 free(model->rho);
2748                                 free(model->label);
2749                                 free(model->nSV);
2750                                 free(model);
2751                                 return NULL;
2752                         }
2753                 }
2754                 else if(strcmp(cmd,"kernel_type")==0)
2755                 {               
2756                         fscanf(fp,"%80s",cmd);
2757                         int i;
2758                         for(i=0;kernel_type_table[i];i++)
2759                         {
2760                                 if(strcmp(kernel_type_table[i],cmd)==0)
2761                                 {
2762                                         param.kernel_type=i;
2763                                         break;
2764                                 }
2765                         }
2766                         if(kernel_type_table[i] == NULL)
2767                         {
2768                                 fprintf(stderr,"unknown kernel function.\n");
2769                                 free(model->rho);
2770                                 free(model->label);
2771                                 free(model->nSV);
2772                                 free(model);
2773                                 return NULL;
2774                         }
2775                 }
2776                 else if(strcmp(cmd,"degree")==0)
2777                         fscanf(fp,"%d",&param.degree);
2778                 else if(strcmp(cmd,"gamma")==0)
2779                         fscanf(fp,"%lf",&param.gamma);
2780                 else if(strcmp(cmd,"coef0")==0)
2781                         fscanf(fp,"%lf",&param.coef0);
2782                 else if(strcmp(cmd,"nr_class")==0)
2783                         fscanf(fp,"%d",&model->nr_class);
2784                 else if(strcmp(cmd,"total_sv")==0)
2785                         fscanf(fp,"%d",&model->l);
2786                 else if(strcmp(cmd,"rho")==0)
2787                 {
2788                         int n = model->nr_class * (model->nr_class-1)/2;
2789                         model->rho = Malloc(double,n);
2790                         for(int i=0;i<n;i++)
2791                                 fscanf(fp,"%lf",&model->rho[i]);
2792                 }
2793                 else if(strcmp(cmd,"label")==0)
2794                 {
2795                         int n = model->nr_class;
2796                         model->label = Malloc(int,n);
2797                         for(int i=0;i<n;i++)
2798                                 fscanf(fp,"%d",&model->label[i]);
2799                 }
2800                 else if(strcmp(cmd,"probA")==0)
2801                 {
2802                         int n = model->nr_class * (model->nr_class-1)/2;
2803                         model->probA = Malloc(double,n);
2804                         for(int i=0;i<n;i++)
2805                                 fscanf(fp,"%lf",&model->probA[i]);
2806                 }
2807                 else if(strcmp(cmd,"probB")==0)
2808                 {
2809                         int n = model->nr_class * (model->nr_class-1)/2;
2810                         model->probB = Malloc(double,n);
2811                         for(int i=0;i<n;i++)
2812                                 fscanf(fp,"%lf",&model->probB[i]);
2813                 }
2814                 else if(strcmp(cmd,"nr_sv")==0)
2815                 {
2816                         int n = model->nr_class;
2817                         model->nSV = Malloc(int,n);
2818                         for(int i=0;i<n;i++)
2819                                 fscanf(fp,"%d",&model->nSV[i]);
2820                 }
2821                 else if(strcmp(cmd,"SV")==0)
2822                 {
2823                         while(1)
2824                         {
2825                                 int c = getc(fp);
2826                                 if(c==EOF || c=='\n') break;    
2827                         }
2828                         break;
2829                 }
2830                 else
2831                 {
2832                         fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
2833                         free(model->rho);
2834                         free(model->label);
2835                         free(model->nSV);
2836                         free(model);
2837                         return NULL;
2838                 }
2839         }
2840
2841         // read sv_coef and SV
2842
2843         int elements = 0;
2844         long pos = ftell(fp);
2845
2846         max_line_len = 1024;
2847         line = Malloc(char,max_line_len);
2848         char *p,*endptr,*idx,*val;
2849
2850         while(readline(fp)!=NULL)
2851         {
2852                 p = strtok(line,":");
2853                 while(1)
2854                 {
2855                         p = strtok(NULL,":");
2856                         if(p == NULL)
2857                                 break;
2858                         ++elements;
2859                 }
2860         }
2861         elements += model->l;
2862
2863         fseek(fp,pos,SEEK_SET);
2864
2865         int m = model->nr_class - 1;
2866         int l = model->l;
2867         model->sv_coef = Malloc(double *,m);
2868         int i;
2869         for(i=0;i<m;i++)
2870                 model->sv_coef[i] = Malloc(double,l);
2871         model->SV = Malloc(svm_node*,l);
2872         svm_node *x_space = NULL;
2873         if(l>0) x_space = Malloc(svm_node,elements);
2874
2875         int j=0;
2876         for(i=0;i<l;i++)
2877         {
2878                 readline(fp);
2879                 model->SV[i] = &x_space[j];
2880
2881                 p = strtok(line, " \t");
2882                 model->sv_coef[0][i] = strtod(p,&endptr);
2883                 for(int k=1;k<m;k++)
2884                 {
2885                         p = strtok(NULL, " \t");
2886                         model->sv_coef[k][i] = strtod(p,&endptr);
2887                 }
2888
2889                 while(1)
2890                 {
2891                         idx = strtok(NULL, ":");
2892                         val = strtok(NULL, " \t");
2893
2894                         if(val == NULL)
2895                                 break;
2896                         x_space[j].index = (int) strtol(idx,&endptr,10);
2897                         x_space[j].value = strtod(val,&endptr);
2898
2899                         ++j;
2900                 }
2901                 x_space[j++].index = -1;
2902         }
2903         free(line);
2904
2905         if (ferror(fp) != 0 || fclose(fp) != 0)
2906                 return NULL;
2907
2908         model->free_sv = 1;     // XXX
2909         return model;
2910 }
2911
2912 void svm_destroy_model(svm_model* model)
2913 {
2914         if(model->free_sv && model->l > 0)
2915                 free((void *)(model->SV[0]));
2916         for(int i=0;i<model->nr_class-1;i++)
2917                 free(model->sv_coef[i]);
2918         free(model->SV);
2919         free(model->sv_coef);
2920         free(model->rho);
2921         free(model->label);
2922         free(model->probA);
2923         free(model->probB);
2924         free(model->nSV);
2925         free(model);
2926 }
2927
2928 void svm_destroy_param(svm_parameter* param)
2929 {
2930         free(param->weight_label);
2931         free(param->weight);
2932 }
2933
2934 const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
2935 {
2936         // svm_type
2937
2938         int svm_type = param->svm_type;
2939         if(svm_type != C_SVC &&
2940            svm_type != NU_SVC &&
2941            svm_type != ONE_CLASS &&
2942            svm_type != EPSILON_SVR &&
2943            svm_type != NU_SVR)
2944                 return "unknown svm type";
2945         
2946         // kernel_type, degree
2947         
2948         int kernel_type = param->kernel_type;
2949         if(kernel_type != LINEAR &&
2950            kernel_type != POLY &&
2951            kernel_type != RBF &&
2952            kernel_type != SIGMOID &&
2953            kernel_type != PRECOMPUTED)
2954                 return "unknown kernel type";
2955
2956         if(param->gamma < 0)
2957                 return "gamma < 0";
2958
2959         if(param->degree < 0)
2960                 return "degree of polynomial kernel < 0";
2961
2962         // cache_size,eps,C,nu,p,shrinking
2963
2964         if(param->cache_size <= 0)
2965                 return "cache_size <= 0";
2966
2967         if(param->eps <= 0)
2968                 return "eps <= 0";
2969
2970         if(svm_type == C_SVC ||
2971            svm_type == EPSILON_SVR ||
2972            svm_type == NU_SVR)
2973                 if(param->C <= 0)
2974                         return "C <= 0";
2975
2976         if(svm_type == NU_SVC ||
2977            svm_type == ONE_CLASS ||
2978            svm_type == NU_SVR)
2979                 if(param->nu <= 0 || param->nu > 1)
2980                         return "nu <= 0 or nu > 1";
2981
2982         if(svm_type == EPSILON_SVR)
2983                 if(param->p < 0)
2984                         return "p < 0";
2985
2986         if(param->shrinking != 0 &&
2987            param->shrinking != 1)
2988                 return "shrinking != 0 and shrinking != 1";
2989
2990         if(param->probability != 0 &&
2991            param->probability != 1)
2992                 return "probability != 0 and probability != 1";
2993
2994         if(param->probability == 1 &&
2995            svm_type == ONE_CLASS)
2996                 return "one-class SVM probability output not supported yet";
2997
2998
2999         // check whether nu-svc is feasible
3000         
3001         if(svm_type == NU_SVC)
3002         {
3003                 int l = prob->l;
3004                 int max_nr_class = 16;
3005                 int nr_class = 0;
3006                 int *label = Malloc(int,max_nr_class);
3007                 int *count = Malloc(int,max_nr_class);
3008
3009                 int i;
3010                 for(i=0;i<l;i++)
3011                 {
3012                         int this_label = (int)prob->y[i];
3013                         int j;
3014                         for(j=0;j<nr_class;j++)
3015                                 if(this_label == label[j])
3016                                 {
3017                                         ++count[j];
3018                                         break;
3019                                 }
3020                         if(j == nr_class)
3021                         {
3022                                 if(nr_class == max_nr_class)
3023                                 {
3024                                         max_nr_class *= 2;
3025                                         label = (int *)realloc(label,max_nr_class*sizeof(int));
3026                                         count = (int *)realloc(count,max_nr_class*sizeof(int));
3027                                 }
3028                                 label[nr_class] = this_label;
3029                                 count[nr_class] = 1;
3030                                 ++nr_class;
3031                         }
3032                 }
3033         
3034                 for(i=0;i<nr_class;i++)
3035                 {
3036                         int n1 = count[i];
3037                         for(int j=i+1;j<nr_class;j++)
3038                         {
3039                                 int n2 = count[j];
3040                                 if(param->nu*(n1+n2)/2 > min(n1,n2))
3041                                 {
3042                                         free(label);
3043                                         free(count);
3044                                         return "specified nu is infeasible";
3045                                 }
3046                         }
3047                 }
3048                 free(label);
3049                 free(count);
3050         }
3051
3052         return NULL;
3053 }
3054
3055 int svm_check_probability_model(const svm_model *model)
3056 {
3057         return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3058                 model->probA!=NULL && model->probB!=NULL) ||
3059                 ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3060                  model->probA!=NULL);
3061 }
3062
3063 void svm_set_print_string_function(void (*print_func)(const char *))
3064 {
3065         if(print_func == NULL)
3066                 svm_print_string = &print_string_stdout;
3067         else
3068                 svm_print_string = print_func;
3069 }