00001
00020 #include "mfcpch.h"
00021 #include <stdlib.h>
00022 #include "statistc.h"
00023 #include "memry.h"
00024 #include "statistc.h"
00025 #include "lmedsq.h"
00026
00027 #define EXTERN
00028
00031 EXTERN INT_VAR (lms_line_trials, 12, "Number of linew fits to do");
00034
00035 #define SEED1 0x1234
00036 #define SEED2 0x5678
00037 #define SEED3 0x9abc
00038 #define LMS_MAX_FAILURES 3
00039
00040 #ifndef __UNIX__
00041 UINT32 nrand48(
00042 UINT16 *seeds
00043 ) {
00044 static UINT32 seed = 0;
00045
00046 if (seed == 0) {
00047 seed = seeds[0] ^ (seeds[1] << 8) ^ (seeds[2] << 16);
00048 srand(seed);
00049 }
00050
00051 return rand () | (rand () << 16);
00052 }
00053 #endif
00054
00060 LMS::LMS (
00061 INT32 size
00062 ):samplesize (size) {
00063 samplecount = 0;
00064 a = 0;
00065 m = 0.0f;
00066 c = 0.0f;
00067 samples = (FCOORD *) alloc_mem (size * sizeof (FCOORD));
00068 errors = (float *) alloc_mem (size * sizeof (float));
00069 line_error = 0.0f;
00070 fitted = FALSE;
00071 }
00072
00073
00079 LMS::~LMS (
00080 ) {
00081 free_mem(samples);
00082 free_mem(errors);
00083 }
00084
00085
00091 void LMS::clear() {
00092 samplecount = 0;
00093 fitted = FALSE;
00094 }
00095
00096
00102 void LMS::add(
00103 FCOORD sample
00104 ) {
00105 if (samplecount < samplesize)
00106
00107 samples[samplecount++] = sample;
00108 fitted = FALSE;
00109 }
00110
00111
00117 void LMS::fit(
00118 float &out_m,
00119 float &out_c) {
00120 INT32 index;
00121 INT32 trials;
00122 float test_m, test_c;
00123 float test_error;
00124
00125 switch (samplecount) {
00126 case 0:
00127 m = 0.0f;
00128 c = 0.0f;
00129 line_error = 0.0f;
00130 break;
00131
00132 case 1:
00133 m = 0.0f;
00134 c = samples[0].y ();
00135 line_error = 0.0f;
00136 break;
00137
00138 case 2:
00139 if (samples[0].x () != samples[1].x ()) {
00140 m = (samples[1].y () - samples[0].y ())
00141 / (samples[1].x () - samples[0].x ());
00142 c = samples[0].y () - m * samples[0].x ();
00143 }
00144 else {
00145 m = 0.0f;
00146 c = (samples[0].y () + samples[1].y ()) / 2;
00147 }
00148 line_error = 0.0f;
00149 break;
00150
00151 default:
00152 pick_line(m, c);
00153 compute_errors(m, c);
00154 index = choose_nth_item (samplecount / 2, errors, samplecount);
00155 line_error = errors[index];
00156 for (trials = 1; trials < lms_line_trials; trials++) {
00157
00158 pick_line(test_m, test_c);
00159 compute_errors(test_m, test_c);
00160 index = choose_nth_item (samplecount / 2, errors, samplecount);
00161 test_error = errors[index];
00162 if (test_error < line_error) {
00163
00164 line_error = test_error;
00165 m = test_m;
00166 c = test_c;
00167 }
00168 }
00169 }
00170 fitted = TRUE;
00171 out_m = m;
00172 out_c = c;
00173 a = 0;
00174 }
00175
00176
00182 void LMS::fit_quadratic(
00183 float outlier_threshold,
00184 double &out_a,
00185 float &out_b,
00186 float &out_c) {
00187 INT32 trials;
00188 double test_a;
00189 float test_b, test_c;
00190 float test_error;
00191
00192 if (samplecount < 3) {
00193 out_a = 0;
00194 fit(out_b, out_c);
00195 return;
00196 }
00197 pick_quadratic(a, m, c);
00198 line_error = compute_quadratic_errors (outlier_threshold, a, m, c);
00199 for (trials = 1; trials < lms_line_trials * 2; trials++) {
00200 pick_quadratic(test_a, test_b, test_c);
00201 test_error = compute_quadratic_errors (outlier_threshold,
00202 test_a, test_b, test_c);
00203 if (test_error < line_error) {
00204 line_error = test_error;
00205 a = test_a;
00206 m = test_b;
00207 c = test_c;
00208 }
00209 }
00210 fitted = TRUE;
00211 out_a = a;
00212 out_b = m;
00213 out_c = c;
00214 }
00215
00216
00223 void LMS::constrained_fit(
00224 float fixed_m,
00225 float &out_c) {
00226 INT32 index;
00227 INT32 trials;
00228 float test_c;
00229 static UINT16 seeds[3] = { SEED1, SEED2, SEED3 };
00230
00231 float test_error;
00232
00233 m = fixed_m;
00234 switch (samplecount) {
00235 case 0:
00236 c = 0.0f;
00237 line_error = 0.0f;
00238 break;
00239
00240 case 1:
00241
00242 c = samples[0].y () - m * samples[0].x ();
00243 line_error = 0.0f;
00244 break;
00245
00246 case 2:
00247 c = (samples[0].y () + samples[1].y ()
00248 - m * (samples[0].x () + samples[1].x ())) / 2;
00249 line_error = m * samples[0].x () + c - samples[0].y ();
00250 line_error *= line_error;
00251 break;
00252
00253 default:
00254 index = (INT32) nrand48 (seeds) % samplecount;
00255
00256 c = samples[index].y () - m * samples[index].x ();
00257 compute_errors(m, c);
00258 index = choose_nth_item (samplecount / 2, errors, samplecount);
00259 line_error = errors[index];
00260 for (trials = 1; trials < lms_line_trials; trials++) {
00261 index = (INT32) nrand48 (seeds) % samplecount;
00262 test_c = samples[index].y () - m * samples[index].x ();
00263
00264 compute_errors(m, test_c);
00265 index = choose_nth_item (samplecount / 2, errors, samplecount);
00266 test_error = errors[index];
00267 if (test_error < line_error) {
00268
00269 line_error = test_error;
00270 c = test_c;
00271 }
00272 }
00273 }
00274 fitted = TRUE;
00275 out_c = c;
00276 a = 0;
00277 }
00278
00279
00285 void LMS::pick_line(
00286 float &line_m,
00287 float &line_c) {
00288 INT16 trial_count;
00289 static UINT16 seeds[3] = { SEED1, SEED2, SEED3 };
00290
00291 INT32 index1;
00292 INT32 index2;
00293
00294 trial_count = 0;
00295 do {
00296 index1 = (INT32) nrand48 (seeds) % samplecount;
00297 index2 = (INT32) nrand48 (seeds) % samplecount;
00298 line_m = samples[index2].x () - samples[index1].x ();
00299 trial_count++;
00300 }
00301 while (line_m == 0 && trial_count < LMS_MAX_FAILURES);
00302 if (line_m == 0) {
00303 line_c = (samples[index2].y () + samples[index1].y ()) / 2;
00304 }
00305 else {
00306 line_m = (samples[index2].y () - samples[index1].y ()) / line_m;
00307 line_c = samples[index1].y () - samples[index1].x () * line_m;
00308 }
00309 }
00310
00311
00317 void LMS::pick_quadratic(
00318 double &line_a,
00319 float &line_m,
00320 float &line_c) {
00321 INT16 trial_count;
00322 static UINT16 seeds[3] = { SEED1, SEED2, SEED3 };
00323
00324 INT32 index1;
00325 INT32 index2;
00326 INT32 index3;
00327 FCOORD x1x2;
00328 FCOORD x1x3;
00329 FCOORD x3x2;
00330 double bottom;
00331
00332 trial_count = 0;
00333 do {
00334 if (trial_count >= LMS_MAX_FAILURES - 1) {
00335 index1 = 0;
00336 index2 = samplecount / 2;
00337 index3 = samplecount - 1;
00338 }
00339 else {
00340 index1 = (INT32) nrand48 (seeds) % samplecount;
00341 index2 = (INT32) nrand48 (seeds) % samplecount;
00342 index3 = (INT32) nrand48 (seeds) % samplecount;
00343 }
00344 x1x2 = samples[index2] - samples[index1];
00345 x1x3 = samples[index3] - samples[index1];
00346 x3x2 = samples[index2] - samples[index3];
00347 bottom = x1x2.x () * x1x3.x () * x3x2.x ();
00348 trial_count++;
00349 }
00350 while (bottom == 0 && trial_count < LMS_MAX_FAILURES);
00351 if (bottom == 0) {
00352 line_a = 0;
00353 pick_line(line_m, line_c);
00354 }
00355 else {
00356 line_a = x1x3 * x1x2 / bottom;
00357 line_m = x1x2.y () - line_a * x1x2.x ()
00358 * (samples[index2].x () + samples[index1].x ());
00359 line_m /= x1x2.x ();
00360 line_c = samples[index1].y () - samples[index1].x ()
00361 * (samples[index1].x () * line_a + line_m);
00362 }
00363 }
00364
00365
00371 void LMS::compute_errors(
00372 float line_m,
00373 float line_c) {
00374 INT32 index;
00375
00376 for (index = 0; index < samplecount; index++) {
00377 errors[index] =
00378 line_m * samples[index].x () + line_c - samples[index].y ();
00379 errors[index] *= errors[index];
00380 }
00381 }
00382
00383
00389 float LMS::compute_quadratic_errors(
00390 float outlier_threshold,
00391 double line_a,
00392 float line_m,
00393 float line_c) {
00394 INT32 outlier_count;
00395 INT32 index;
00396 INT32 error_count;
00397 double total_error;
00398
00399 total_error = 0;
00400 outlier_count = 0;
00401 error_count = 0;
00402 for (index = 0; index < samplecount; index++) {
00403 errors[error_count] = line_c + samples[index].x ()
00404 * (line_m + samples[index].x () * line_a) - samples[index].y ();
00405 errors[error_count] *= errors[error_count];
00406 if (errors[error_count] > outlier_threshold) {
00407 outlier_count++;
00408 errors[samplecount - outlier_count] = errors[error_count];
00409 }
00410 else {
00411 total_error += errors[error_count++];
00412 }
00413 }
00414 if (outlier_count * 3 < error_count)
00415 return total_error / error_count;
00416 else {
00417 index = choose_nth_item (outlier_count / 2,
00418 errors + samplecount - outlier_count,
00419 outlier_count);
00420
00421 return errors[samplecount - outlier_count + index];
00422 }
00423 }
00424
00425
00431 #ifndef GRAPHICS_DISABLED
00432 void LMS::plot(
00433 WINDOW win,
00434 COLOUR colour
00435 ) {
00436 if (fitted) {
00437 line_color_index(win, colour);
00438 move2d (win, samples[0].x (),
00439 c + samples[0].x () * (m + samples[0].x () * a));
00440 draw2d (win, samples[samplecount - 1].x (),
00441 c + samples[samplecount - 1].x () * (m +
00442 samples[samplecount -
00443 1].x () * a));
00444 }
00445 }
00446 #endif