50#include "EST_cmd_line.h"
// Strategy used to build the model: a decision list or a decision tree.
52enum wn_strategy_type {wn_decision_list, wn_decision_tree};
// Build strategy in effect; initialised to wn_decision_tree. wagon_main may
// switch it to wn_decision_list (presumably for -dlist; the guarding
// condition is not visible in this listing).
54static wn_strategy_type wagon_type = wn_decision_tree;
// Forward declaration of the real driver; main() delegates to it.
56static int wagon_main(
int argc,
char **argv);
// Program entry point: forwards the command-line arguments to wagon_main().
// The call is a bare statement, so wagon_main()'s int result is discarded.
// NOTE(review): the rest of main is not visible here; confirm how the
// process exit status is set.
107int main(
int argc,
char **argv)
110 wagon_main(argc,argv);
// Parses a -track_feats specification and marks the selected channels in
// row 0 of wgn_VertexFeats with 1.0. Per the -track_feats help text, the
// specification is a comma-separated list of channel numbers and/or
// "start-end" ranges, starting at 0.
// NOTE(review): the second parameter is not visible here; the call site
// passes the -track_feats string (wagon_track_features). Confirm.
116static int set_Vertex_Feats(
EST_Track &wgn_VertexFeats,
// Clear row 0 (presumably every channel; the enclosing loop is not visible).
123 wgn_VertexFeats.
a(0,i) = 0.0;
// The whitespace preceding each token decides how the token is interpreted.
134 const EST_String ws = (
const char *)token.whitespace();
// Branch condition not visible in this listing; selects channel i.
138 wgn_VertexFeats.
a(0,i) = 1.0;
140 }
else if ((ws ==
",") || (ws ==
""))
// A token after "," (or with no preceding separator) is a single channel
// number. It is kept in s because a following "-" token uses it as the
// start of a range.
142 s = atoi(token.string());
// NOTE(review): s comes straight from atoi() and is used as a channel index
// without a visible range check. A negative value, or one >=
// num_channels(), would index outside the track. atoi() also maps
// non-numeric text to 0.
143 wgn_VertexFeats.
a(0,s) = 1.0;
144 }
else if (ws ==
"-")
// A token after "-" is the inclusive end e of a range starting at the
// previous number s. Channels s..e are selected, clipped at num_channels().
// If s > e, nothing is selected and no error is reported.
149 e = atoi(token.string());
// NOTE(review): only the upper end is clipped; a negative s would still be
// used as an index.
150 for (i=s; i<=e && i<wgn_VertexFeats.
num_channels(); i++)
151 wgn_VertexFeats.
a(0,i) = 1.0;
// Presumably reached for any other separator: the specification is reported
// as invalid (the enclosing else and the remaining printf arguments are not
// visible).
154 printf(
"wagon: track_feats invalid: %s at position %d\n",
155 (
const char *)wagon_track_features,
// Driver for the wagon CART-building program. It parses the command line,
// loads the field description and data (plus an optional distance matrix,
// vertex track and unit track), builds a stepwise model, a decision tree or
// a decision list, and writes the result to the output stream.
164static int wagon_main(
int argc,
char **argv)
// Output stream for the built model. Assigned later (a new ofstream for
// -o/-output; possibly &cout otherwise, given the later != &cout check).
170 ostream *wgn_coutput = 0;
// Percentage improvement required by -stepwise (from -swlimit; may be negative).
171 float stepwise_limit = 0;
// Inclusive channel range selected by -track_start/-track_end.
172 int feats_start=0, feats_end=0;
// Usage text for wagon's options: one line per option, with indented
// continuation lines. This continues a string expression that begins
// outside this listing.
178 "Summary: CART building program\n"+
179 "-desc <ifile> Field description file\n"+
180 "-data <ifile> Datafile, one vector per line\n"+
181 "-stop <int> {50} Minimum number of examples for leaf nodes\n"+
182 "-test <ifile> Datafile to test tree on\n"+
183 "-frs <float> {10} Float range split, number of partitions to\n"+
184 " split a float feature range into\n"+
185 "-dlist Build a decision list (rather than tree)\n"+
186 "-dtree Build a decision tree (rather than list) default\n"+
187 "-output <ofile> \n"+
188 "-o <ofile> File to save output tree in\n"+
189 "-distmatrix <ifile>\n"+
190 " A distance matrix for clustering\n"+
192 " track for vertex indices\n"+
193 "-track_start <int>\n"+
194 " start channel vertex indices\n"+
195 "-track_end <int>\n"+
196 " end (inclusive) channel for vertex indices\n"+
197 "-track_feats <string>\n"+
198 " Track features to use, comma separated list\n"+
199 " with feature numbers and/or ranges, 0 start\n"+
200 "-unittrack <ifile>\n"+
201 " track for unit start and length in vertex track\n"+
202 "-quiet No questions printed during building\n"+
203 "-verbose Lots of information printed during build\n"+
204 "-predictee <string>\n"+
205 " name of field to predict (default is first field)\n"+
206 "-ignore <string>\n"+
207 " Filename or bracket list of fields to ignore\n"+
208 "-count_field <string>\n"+
209 " Name of field containing count weight for samples\n"+
210 "-stepwise Incrementally find best features\n"+
211 "-swlimit <float> {0.0}\n"+
212 " Percentage necessary improvement for stepwise,\n"+
213 " may be negative.\n"+
214 "-swopt <string> Parameter to optimize for stepwise, for \n"+
215 " classification options are correct or entropy\n"+
216 " for regression options are rmse or correlation\n"+
217 " correct and correlation are the defaults\n"+
218 "-balance <float> For derived stop size, if dataset at node, divided\n"+
219 " by balance is greater than stop it is used as stop\n"+
220 " if balance is 0 (default) always use stop as is.\n"+
221 "-cos Use mean cosine distance rather than Gaussian (TBD).\n"+
222 "-dof <float> Randomly dropout feats in training (prob).\n"+
223 "-dos <float> Randomly dropout samples in training (prob).\n"+
224 "-vertex_output <string> Output <mean> or <best> of cluster\n"+
225 "-held_out <int> Percent to hold out for pruning\n"+
226 "-max_questions <int> Maximum number of questions in tree\n"+
227 "-heap <int> {210000}\n"+
228 " Set size of Lisp heap, should not normally need\n"+
229 " to be changed from its default, only with *very*\n"+
230 " large description files (> 1M)\n"+
231 "-omp_nthreads <int> {1}\n"+
232 " Set number of OMP threads to run wagon in\n"+
233 " tree building; this overrides $OMP_NUM_THREADS\n"+
234 " (ignored if not supported)\n"+
235 "-noprune No (same class) pruning required\n",
// Copy training controls from the parsed option list (al) into wagon's
// globals:
//   -held_out  percentage held out for pruning
//   -dof, -dos feature and sample dropout probabilities
//   -cos       cosine-distance switch
//   -balance   derived-stop-size balance factor
// Any present() guards on these are not visible in this listing.
239 wgn_held_out = al.
ival(
"-held_out");
241 wgn_dropout_feats = al.
fval(
"-dof");
243 wgn_dropout_samples = al.
fval(
"-dos");
245 wgn_cos = al.
ival(
"-cos");
247 wgn_balance = al.
fval(
"-balance");
// Reported when the description and/or data file is missing. The guarding
// condition and the subsequent return/exit are not visible in this listing.
250 cerr << argv[0] <<
": missing description and/or datafile" << endl;
251 cerr <<
"use -h for description of arguments" << endl;
// Tree-building parameters from the command line. Options parsed with
// atoi()/atof() silently turn non-numeric text into 0.
// NOTE(review): consider strtol/strtod with validation.
// Minimum number of examples per leaf (-stop).
260 wgn_min_cluster_size = atoi(al.
val(
"-stop"));
// Optional cap on the number of questions in the tree.
261 if (al.
present(
"-max_questions"))
262 wgn_max_questions = atoi(al.
val(
"-max_questions"));
// Field to predict; the help text says the default is the first field
// (any present() guard is not visible here).
266 wgn_predictee_name = al.
val(
"-predictee");
// Optional field holding per-sample count weights.
267 if (al.
present(
"-count_field"))
268 wgn_count_field_name = al.
val(
"-count_field");
// Stepwise improvement threshold, float-range split count, and the
// parameter optimised by -stepwise.
270 stepwise_limit = al.
fval(
"-swlimit");
272 wgn_float_range_split = atof(al.
val(
"-frs"));
274 wgn_opt_param = al.
val(
"-swopt");
// How a cluster is output (help text: <mean> or <best>).
275 if (al.
present(
"-vertex_output"))
276 wgn_vertex_output = al.
val(
"-vertex_output");
// Output file for the model: -o, or -output (apparently a synonym; the
// conditions choosing between them are not visible). The file is opened
// with new ofstream, and a failure to open it is reported on cerr.
280 wgn_oname = al.
val(
"-o");
282 wgn_oname = al.
val(
"-output");
283 wgn_coutput =
new ofstream(wgn_oname);
286 cerr <<
"Wagon: can't open file \"" << wgn_oname <<
287 "\" for output " << endl;
// Optional distance matrix used for clustering (-distmatrix). Any load()
// status other than 0 is treated as failure (presumably 0 is read_ok).
295 if (wgn_DistMatrix.
load(al.
val(
"-distmatrix")) != 0)
297 cerr <<
"Wagon: failed to load Distance Matrix from \"" <<
298 al.
val(
"-distmatrix") <<
"\"\n" << endl;
// Switch from the default decision tree to a decision list (presumably
// when -dlist is given; the condition is not visible).
303 wagon_type = wn_decision_list;
// Initialise the SIOD Lisp interpreter with the -heap size. Per the help
// text, the default is 210000 and only very large description files need
// more.
309 siod_init(al.
ival(
"-heap"));
// Fields to ignore: -ignore accepts either a bracketed list or a filename
// (per the help text). Presumably read_from_string handles the former and
// vload the latter; the selecting condition is not visible.
315 ignores = read_from_string(ig);
317 ignores = vload(ig,1);
// Load the field description (passing the ignore list) and the training data.
320 wgn_load_datadescription(al.
val(
"-desc"),ignores);
321 wgn_load_dataset(wgn_dataset,al.
val(
"-data"));
// A distance matrix must have at least as many rows as there are training
// examples.
322 if (al.
present(
"-distmatrix") &&
323 (wgn_DistMatrix.
num_rows() < wgn_dataset.length()))
325 cerr <<
"wagon: distance matrix is smaller than number of training elements\n";
// Optional vertex track (-track) and selection of which of its channels are
// used as features.
// NOTE(review): no status check on load() is visible in this listing.
330 wgn_VertexTrack.
load(al.
val(
"-track"));
// Start with channels selected (loop header not visible).
333 wgn_VertexFeats.
a(0,i) = 1.0;
// -track_start: range-check the start channel, then deselect every channel
// below it. Only the "< 0" part of the check is visible.
336 if (al.
present(
"-track_start"))
338 feats_start = al.
ival(
"-track_start");
339 if ((feats_start < 0) ||
342 printf(
"wagon: track_start invalid: %d out of %d channels\n",
347 for (i=0; i<feats_start; i++)
348 wgn_VertexFeats.
a(0,i) = 0.0;
// -track_end (inclusive): it must not be below feats_start (the rest of the
// check is not visible); channels after it are then deselected.
// NOTE(review): the deselect loop is bounded by wgn_VertexTrack's channel
// count while indexing wgn_VertexFeats, so it assumes both have the same
// number of channels.
354 feats_end = al.
ival(
"-track_end");
355 if ((feats_end < feats_start) ||
358 printf(
"wagon: track_end invalid: %d between start %d out of %d channels\n",
364 for (i=feats_end+1; i<wgn_VertexTrack.
num_channels(); i++)
365 wgn_VertexFeats.
a(0,i) = 0.0;
// -track_feats: explicit comma/range list of channels, parsed by
// set_Vertex_Feats().
367 if (al.
present(
"-track_feats"))
369 EST_String wagon_track_features = (
const char *)al.
val(
"-track_feats");
370 set_Vertex_Feats(wgn_VertexFeats,wagon_track_features);
// Optional unit track. Per the help text, it holds unit start and length in
// the vertex track.
382 wgn_UnitTrack.
load(al.
val(
"-unittrack"));
// OpenMP thread count: -omp_nthreads if given, presumably otherwise a single
// thread (help text default {1}).
// NOTE(review): the value goes through atoi() unvalidated, and OpenMP leaves
// non-positive omp_set_num_threads arguments implementation-defined.
386 if (al.
present (
"-omp_nthreads"))
388 omp_set_num_threads(atoi(al.
val(
"-omp_nthreads")));
390 omp_set_num_threads(1);
// Build without OpenMP support (presumably the other branch of a
// preprocessor conditional not visible here): the option is only reported.
393 if (al.
present (
"-omp_nthreads"))
395 printf(
"wagon: -omp_nthreads ignored: not supported in this build.\n");
// Optional dataset (-test) to test the built tree on.
400 wgn_load_dataset(wgn_test_dataset,al.
val(
"-test"));
// Build the model. Stepwise feature selection comes first (presumably when
// -stepwise is given); otherwise a decision tree or decision list is built
// according to wagon_type. Any other value is reported as an unknown
// operation. The decision-list builder also receives the output stream.
404 tree = wagon_stepwise(stepwise_limit);
405 else if (wagon_type == wn_decision_tree)
406 tree = wgn_build_tree(score);
407 else if (wagon_type == wn_decision_list)
409 tree = wgn_build_dlist(score,wgn_coutput);
412 cerr <<
"Wagon: unknown operation, not tree or list" << endl;
// Write the model to the output stream and call summary_results() with the
// same stream. The cleanup that follows is skipped when the stream is cout
// (its body is not visible).
418 *wgn_coutput << *tree;
419 summary_results(*tree,wgn_coutput);
422 if (wgn_coutput != &cout)
EST_read_status load(const EST_String &filename)
Load from file (ascii or binary as defined in file)
float fval(const EST_String &rkey, int m=1) const
const EST_String & sval(const EST_String &rkey, int m=1) const
int ival(const EST_String &rkey, int m=1) const
const V & val(const K &rkey, bool m=0) const
return value according to key (const)
const int present(const K &rkey) const
Returns true if key is present.
int num_rows() const
return number of rows
int filepos(void) const
current file position in \Ref{EST_TokenStream}
void set_SingleCharSymbols(const EST_String &sc)
set which characters are to be treated as single character symbols
void set_PrePunctuationSymbols(const EST_String &ps)
set which characters are to be treated as (pre) punctuation
int open_string(const EST_String &newbuffer)
open a \Ref{EST_TokenStream} for string rather than a file
void set_PunctuationSymbols(const EST_String &ps)
set which characters are to be treated as (post) punctuation
void set_WhiteSpaceChars(const EST_String &ws)
set which characters are to be treated as whitespace
EST_TokenStream & get(EST_Token &t)
get next token in stream
EST_read_status load(const EST_String name, float ishift=0.0, float startt=0.0)
float & a(int i, int c=0)
int num_channels() const
return number of channels in track
void resize(int num_frames, int num_channels, bool preserve=1)