44#include "EST_cutils.h"
53 return impurity.value();
54 else if (question.ask(d))
55 return left->predict(d);
57 return right->predict(d);
64 else if (question.ask(d))
65 return left->predict_node(d);
67 return right->predict_node(d);
74 if ((left == 0) && (right == 0))
76 else if (get_impurity().type() != wnim_class)
82void WNode::prune(
void)
90 if (left != 0) left->prune();
91 if (right != 0) right->prune();
95 if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96 (left->get_impurity().value() == right->get_impurity().value()))
98 delete left; left = 0;
99 delete right; right = 0;
105void WNode::held_out_prune()
116 wgn_score_question(question,get_data());
117 if (question.get_score() < get_impurity().measure())
119 wgn_find_split(question,get_data(),
122 left->held_out_prune();
123 right->held_out_prune();
127 delete left; left = 0;
128 delete right; right = 0;
133void WNode::print_out(ostream &s,
int margin)
138 for (i=0;i<margin;i++) s <<
" ";
145 left->print_out(s,margin+1);
146 right->print_out(s,margin+1);
151ostream & operator <<(ostream &s,
WNode &n)
160void WDataSet::ignore_non_numbers()
165 for (i=0; i<dlength; i++)
167 if ((p_type[i] == wndt_binary) ||
168 (p_type[i] == wndt_float))
179void WDataSet::load_description(
const EST_String &fname, LISP ignores)
186 description = car(vload(fname,1));
187 dlength = siod_llength(description);
193 if (wgn_predictee_name ==
"")
198 for (i=0,d=description; d != NIL; d=cdr(d),i++)
200 p_name[i] = get_c_string(car(car(d)));
201 tname = get_c_string(car(cdr(car(d))));
203 if ((wgn_predictee_name !=
"") && (wgn_predictee_name == p_name[i]))
205 if ((wgn_count_field_name !=
"") &&
206 (wgn_count_field_name == p_name[i]))
208 if ((tname ==
"count") || (i == wgn_count_field))
211 p_type[i] = wndt_ignore;
215 else if ((tname ==
"ignore") || (siod_member_str(p_name[i],ignores)))
217 p_type[i] = wndt_ignore;
219 if (i == wgn_predictee)
220 wagon_error(
EST_String(
"predictee \"")+p_name[i]+
221 "\" can't be ignored \n");
223 else if (siod_llength(car(d)) > 2)
225 LISP rest = cdr(car(d));
227 siod_list_to_strlist(rest,sl);
228 p_type[i] = wgn_discretes.def(sl);
229 if (streq(get_c_string(car(rest)),
"_other_"))
230 wgn_discretes[p_type[i]].def_val(
"_other_");
232 else if (tname ==
"binary")
233 p_type[i] = wndt_binary;
234 else if (tname ==
"cluster")
235 p_type[i] = wndt_cluster;
236 else if (tname ==
"vector")
237 p_type[i] = wndt_vector;
238 else if (tname ==
"trajectory")
239 p_type[i] = wndt_trajectory;
240 else if (tname ==
"ols")
241 p_type[i] = wndt_ols;
242 else if (tname ==
"matrix")
243 p_type[i] = wndt_matrix;
244 else if (tname ==
"float")
245 p_type[i] = wndt_float;
248 wagon_error(
EST_String(
"Unknown type \"")+tname+
249 "\" for field number "+itoString(i)+
250 "/"+p_name[i]+
" in description file \""+fname+
"\"");
254 if (wgn_predictee == -1)
256 wagon_error(
EST_String(
"predictee field \"")+wgn_predictee_name+
257 "\" not found in description ");
261const int WQuestion::ask(
const WVector &w)
const
267 if (w.get_flt_val(feature_pos) == operand1.
Float())
272 if (w.get_int_val(feature_pos) == 1)
276 case wnop_greaterthan:
277 if (w.get_flt_val(feature_pos) > operand1.
Float())
282 if (w.get_flt_val(feature_pos) < operand1.
Float())
287 if (w.get_int_val(feature_pos) == operand1.
Int())
292 if (ilist_member(operandl,w.get_int_val(feature_pos)))
297 wagon_error(
"Unknown test operator");
303ostream& operator<<(ostream& s,
const WQuestion &q)
306 static EST_Regex needquotes(
".*[()'\";., \t\n\r].*");
308 s <<
"(" << wgn_dataset.feat_name(q.get_fp());
312 s <<
" = " << q.get_operand1().
string();
316 case wnop_greaterthan:
317 s <<
" > " << q.get_operand1().
Float();
320 s <<
" < " << q.get_operand1().
Float();
323 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324 name(q.get_operand1().
Int());
327 s << quote_string(name,
"\"",
"\\",1);
332 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333 name(q.get_operand1().
Int());
334 s <<
" matches " << quote_string(name,
"\"",
"\\",1);
338 for (
int l=0; l < q.get_operandl().length(); l++)
340 name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341 name(q.get_operandl().
nth(l));
342 if (name.matches(needquotes))
343 s << quote_string(name,
"\"",
"\\",1);
367 cerr <<
"WImpurity: no value currently set\n";
370 else if (t==wnim_class)
372 else if (t==wnim_cluster)
374 else if (t==wnim_ols)
376 else if (t==wnim_vector)
378 else if (t==wnim_trajectory)
384double WImpurity::samples(
void)
388 else if (t==wnim_class)
390 else if (t==wnim_cluster)
391 return members.length();
392 else if (t==wnim_ols)
393 return members.length();
394 else if (t==wnim_vector)
395 return members.length();
396 else if (t==wnim_trajectory)
397 return members.length();
407 a.
reset(); trajectory=0; l=0; width=0;
409 for (i=0; i < ds.
n(); i++)
413 else if (wgn_count_field == -1)
414 cumulate((*(ds(i)))[wgn_predictee],1);
416 cumulate((*(ds(i)))[wgn_predictee],
417 (*(ds(i)))[wgn_count_field]);
421float WImpurity::measure(
void)
425 else if (t == wnim_vector)
426 return vector_impurity();
427 else if (t == wnim_trajectory)
428 return trajectory_impurity();
429 else if (t == wnim_matrix)
431 else if (t == wnim_class)
433 else if (t == wnim_cluster)
434 return cluster_impurity();
435 else if (t == wnim_ols)
436 return ols_impurity();
439 cerr <<
"WImpurity: can't measure unset object" << endl;
444float WImpurity::vector_impurity()
459 if (wgn_VertexFeats.
a(0,j) > 0.0)
462 for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
464 i = members.
item(pp);
467 b.cumulate(wgn_VertexTrack.
a(i,j), member_counts.
item(countpp)) ;
477 float x, lshift, rshift, ushift;
482 if (wgn_VertexFeats.
a(0,j) > 0.0)
485 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
486 pp=pp->next(), countpp=countpp->next())
488 i = members.
item(pp);
490 c[j].cumulate(wgn_VertexTrack.
a(i,j),member_counts.
item(countpp));
497 for (pp=members.head(), countpp=member_counts.head(); pp != 0;
498 pp=pp->next(), countpp=countpp->next())
501 float bshift, qshift;
503 i = members.
item(pp);
505 lshift = 0; ushift = 0; rshift = 0;
507 for (q=-20; q<=20; q++)
510 for (j=67+q; j<147+q; j++)
512 x = c[j].
mean() - wgn_VertexTrack(i,j);
514 if ((bshift > 0) && (qshift > bshift))
517 if ((bshift == 0) || (qshift < bshift))
536 if (wgn_VertexFeats.
a(0,j) > 0.0)
538 for (pp=members.head(); pp != 0; pp=pp->next())
539 cs[j][j] += wgn_VertexTrack.
a(members.
item(pp),j);
545 if (wgn_VertexFeats.
a(0,j) > 0.0)
547 for (pp=members.head(); pp != 0; pp=pp->next())
549 mmm = members.
item(pp);
550 cs[i][j] += (wgn_VertexTrack.
a(mmm,i)-cs[j][j].
mean())*
551 (wgn_VertexTrack.
a(mmm,j)-cs[j][j].
mean());
558 if (wgn_VertexFeats.
a(0,j) > 0.0)
559 a += cs[i][j].stddev();
570 for (pp=members.head(); pp != 0; pp=pp->next())
572 x = members.
item(pp);
574 for (qq=pp->next(); qq != 0; qq=qq->next())
576 y = members.
item(qq);
578 if (wgn_VertexFeats.
a(0,j) > 0.0)
580 d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
590 return a.
mean() * count;
593WImpurity::~WImpurity()
600 delete [] trajectory[j];
601 delete [] trajectory;
608float WImpurity::trajectory_impurity()
620 double n, m, m1, m2, w;
633 for (pp=members.head(); pp != 0; pp=pp->next())
635 i = members.
item(pp);
636 for (q=0; q<wgn_UnitTrack.
a(i,1); q++)
638 ni = (int)wgn_UnitTrack.
a(i,0)+q;
639 if (wgn_VertexTrack.
a(ni,0) == -1.0)
646 if (q==wgn_UnitTrack.
a(i,1))
652 l2ss += wgn_UnitTrack.
a(i,1) - (q+1) - 1;
653 lss += wgn_UnitTrack.
a(i,1);
654 if (wgn_UnitTrack.
a(i,1) > l)
655 l = (
int)wgn_UnitTrack.
a(i,1);
660 l = ((int)lss.
mean() < 7) ? 7 : (
int)lss.
mean();
668 for (pp=members.head(); pp != 0; pp=pp->next())
670 i = members.
item(pp);
671 m = (float)wgn_UnitTrack.
a(i,1)/(float)l;
672 s = (int)wgn_UnitTrack.
a(i,0);
673 for (ti=0,n=0.0; ti<l; ti++,n+=m)
678 if (wgn_VertexFeats.
a(0,j) > 0.0)
679 trajectory[ti][j] += wgn_VertexTrack.
a(s+ni,j);
686 for (ti=0; ti<l; ti++)
689 if (wgn_VertexFeats.
a(0,j) > 0.0)
690 stdss += trajectory[ti][j].stddev();
694 score = stdss.
mean() * members.length();
698 l1 = (l1ss.
mean() < 10.0) ? 10 : (
int)l1ss.
mean();
699 l2 = (l2ss.
mean() < 10.0) ? 10 : (
int)l2ss.
mean();
707 for (pp=members.head(); pp != 0; pp=pp->next())
709 i = members.
item(pp);
711 s = (int)wgn_UnitTrack.
a(i,0);
712 for (q=0; q<wgn_UnitTrack.
a(i,1); q++)
713 if (wgn_VertexTrack.
a(s+q,0) == -1.0)
718 s2l = (int)wgn_UnitTrack.
a(i,1) - (s1l + 2);
719 m1 = (float)(s1l)/(float)l1;
720 m2 = (float)(s2l)/(float)l2;
722 for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
724 ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
726 if (wgn_VertexFeats.
a(0,j) > 0.0)
727 trajectory[ti][j] += wgn_VertexTrack.
a(ni,j);
731 if (wgn_VertexFeats.
a(0,j) > 0.0)
732 trajectory[ti][j] += -1;
735 for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
737 ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
739 if (wgn_VertexFeats.
a(0,j) > 0.0)
740 trajectory[ti][j] += wgn_VertexTrack.
a(ni,j);
743 if (wgn_VertexFeats.
a(0,j) > 0.0)
744 trajectory[ti][j] += -2;
751 for (w=0.0,ti=0; ti<l1; ti++,w+=m)
753 if (wgn_VertexFeats.
a(0,j) > 0.0)
754 stdss += trajectory[ti][j].stddev() * w;
756 for (w=1.0,ti++; ti<l-1; ti++,w-=m)
758 if (wgn_VertexFeats.
a(0,j) > 0.0)
759 stdss += trajectory[ti][j].stddev() * w;
762 score = stdss.
mean() * members.length();
778 w = wgn_dataset.width();
780 X.resize(members.length(),w);
781 Y.
resize(members.length(),1);
782 feat_names.
append(
"Intercept");
785 for (p=0,pp=members.head(); pp; p++,pp=pp->next())
787 n = members.
item(pp);
795 X.a_no_check(p,0) = 1;
796 for (m=1,xm=1; m < w; m++)
798 if (wgn_dataset.ftype(m) == wndt_float)
802 feat_names.
append(wgn_dataset.feat_name(m));
804 X.a_no_check(p,xm) = (*wv)[m];
817float WImpurity::ols_impurity()
831 part_to_ols_data(X,Y,included,feat_names,members,*data);
840 if (!robust_ols(X,Y,included,coeffsl))
845 ols_apply(X,coeffsl,pred);
846 ols_test(Y,pred,cor,rmse);
849 printf(
"Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
853 if (fabs(coeffsl[0]) > 10000)
859 return (1-best_score) *members.length();
862float WImpurity::cluster_impurity()
873 for (pp=members.head(); pp != 0; pp=pp->next())
875 i = members.
item(pp);
876 for (q=pp->next(); q != 0; q=q->next())
879 dist = (j < i ? wgn_DistMatrix.
a_no_check(i,j) :
893float WImpurity::cluster_distance(
int i)
897 float dist = cluster_member_mean(i);
898 float mdist = dist-a.
mean();
907int WImpurity::in_cluster(
int i)
911 float dist = cluster_member_mean(i);
914 for (pp=members.head(); pp != 0; pp=pp->next())
916 if (dist < cluster_member_mean(members.
item(pp)))
922float WImpurity::cluster_ranking(
int i)
925 float dist = cluster_distance(i);
929 for (pp=members.head(); pp != 0; pp=pp->next())
931 if (dist >= cluster_distance(members.
item(pp)))
938float WImpurity::cluster_member_mean(
int i)
946 for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
951 dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
957 return ( n == 0 ? 0.0 : sum/n );
960void WImpurity::cumulate(
const float pv,
double count)
964 if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
969 else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
974 else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
980 member_counts.
append((
float)count);
982 else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
987 else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
990 p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
992 p.cumulate((
int)pv,count);
994 else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
997 a.cumulate((
int)pv,count);
999 else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1002 a.cumulate(pv,count);
1006 wagon_error(
"WImpurity: cannot cumulate EST_Val type");
1010ostream & operator <<(ostream &s,
WImpurity &imp)
1015 if (imp.t == wnim_float)
1016 s <<
"(" << imp.a.
stddev() <<
" " << imp.a.
mean() <<
")";
1017 else if (imp.t == wnim_vector)
1021 imp.vector_impurity();
1022 if (wgn_vertex_output ==
"mean")
1027 for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
1030 b.cumulate(wgn_VertexTrack.
a(imp.members.
item(p),j), imp.member_counts.
item(countp));
1033 s <<
"(" << b.
mean() <<
" ";
1034 if (isfinite(b.
stddev()))
1037 s <<
"0.001" <<
")";
1046 double best = WGN_HUGE_VAL;
1056 for (p=imp.members.head(); p != 0; p=p->next())
1058 cs[j] += wgn_VertexTrack.
a(imp.members.
item(p),j);
1062 for (p=imp.members.head(); p != 0; p=p->next())
1065 if (wgn_VertexFeats.
a(0,j) > 0.0)
1067 d = (wgn_VertexTrack.
a(imp.members.
item(p),j)-cs[j].
mean())
1075 bestp = imp.members.
item(p);
1082 s << wgn_VertexTrack.
a(bestp,j);
1085 if (isfinite(cs[j].stddev()))
1086 s << cs[j].stddev();
1097 s << imp.a.
mean() <<
")";
1099 else if (imp.t == wnim_trajectory)
1102 imp.trajectory_impurity();
1103 for (i=0; i<imp.l; i++)
1108 s <<
"(" << imp.trajectory[i][j].
mean() <<
" "
1109 << imp.trajectory[i][j].
stddev() <<
" " <<
")";
1115 s << imp.a.
mean() <<
")";
1117 else if (imp.t == wnim_cluster)
1121 for (p=imp.members.head(); p != 0; p=p->next())
1124 s <<
"(" << imp.members.
item(p) <<
" " <<
1125 imp.cluster_member_mean(imp.members.
item(p)) <<
")";
1131 s << imp.a.
mean() <<
")";
1133 else if (imp.t == wnim_ols)
1146 part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
1147 if (!robust_ols(X,Y,included,coeffsl))
1149 printf(
"no robust ols\n");
1154 ols_apply(X,coeffsl,pred);
1155 ols_test(Y,pred,cor,rmse);
1156 for (i=0; i<coeffsl.
num_rows(); i++)
1159 s << feat_names.
nth(i);
1167 s <<
") " << cor <<
")";
1169 else if (imp.t == wnim_class)
1179 s <<
"(" << name <<
" " << prob <<
") ";
1184 s <<
"([WImpurity unset])";
EST_Litem * item_next(EST_Litem *idx) const
Used for iterating through members of the distribution.
EST_Litem * item_start() const
Used for iterating through members of the distribution.
void item_prob(EST_Litem *idx, EST_String &s, double &prob) const
During iteration returns name and probability given index.
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
double samples(void) const
Total number of examples found.
double entropy(void) const
int item_end(EST_Litem *idx) const
Used for iterating through members of the distribution.
int matches(const char *e, int pos=0) const
Exactly match this string?
double stddev(void) const
standard deviation of currently cumulated values
double variance(void) const
variance of currently cumulated values
double mean(void) const
mean of currently cumulated values
void reset(void)
reset internal values
double samples(void)
number of samples in set
T & item(const EST_Litem *p)
T & nth(int n)
return the Nth value
void append(const T &item)
add item onto end of list
int num_columns() const
return number of columns
INLINE const T & a_no_check(int row, int col) const
const access with no bounds check, care recommended
int num_rows() const
return number of rows
void resize(int rows, int cols, int set=1)
resize matrix
void resize(int n, int set=1)
resize vector
void resize(int n, int set=1)
INLINE int n() const
number of items in vector.
INLINE const T & a_no_check(int n) const
read-only const access operator: without bounds checking
float & a(int i, int c=0)
int num_channels() const
return number of channels in track
const EST_String & string(void) const
const int Int(void) const
const float Float(void) const