/**
 * Interactive explorer for the beta distribution,
 * i.e. Bayesian learning from Bernoulli evidence via a conjugate prior,
 * i.e. will the sun come up tomorrow? Curious minds wish to know.
 * http://anyall.org/blog/2009/07/beta-conjugate-explorer/
 * Brendan O'Connor
 */

int W = 400;
int H = 400;

PFont font1, font2;

/** One-time initialisation: window size, white background, and the two fonts. */
void setup() {
  size(400, 400);
  background(255);
  // These fonts come with the Processing download; plucked out of
  //   $ find /Applications/Processing.app -name '*.vlw'
  font1 = loadFont("AmericanTypewriter-24.vlw");
  font2 = loadFont("TheSans-Plain-12.vlw");
}

/**
 * Per-frame rendering. A translucent black overlay fades the previous frames
 * (leaving trails of earlier curves), then the current Beta(curA, curB)
 * density, the title, the key help and the axes are drawn on top.
 */
void draw() {
  // noStroke() so the stroke colour left over from axes() on the previous
  // frame does not outline the canvas edge with the fade rectangle.
  noStroke();
  fill(0, 20);
  rect(0, 0, width, height);

  stroke(255);
  plotBeta(curA, curB);

  fill(200, 100);
  textFont(font1);
  textAlign(LEFT, TOP);
  // Locale.US keeps '.' as the decimal separator; a comma-decimal locale
  // would otherwise render the ambiguous "Beta(2,0, 2,0)".
  text(String.format(java.util.Locale.US, "Beta(%.1f, %.1f)", curA, curB), 10, 10);

  textFont(font2);
  textAlign(LEFT, TOP);
  text(instructions(), W - 150, 10);

  axes();
}

/** Draw the x axis (labels 0, 0.5, 1) and the y axis (integer labels 1..ymax). */
void axes() {
  int tickPad = 3;
  colorMode(RGB);
  stroke(100, 100, 150);
  textFont(font2);

  // x-axis
  line(xmin_pos, ymin_pos, xmax_pos, ymin_pos);
  textAlign(CENTER, TOP);
  text("0", xmin_pos, ymin_pos + tickPad);
  text("0.5", xmin_pos + xpos_width / 2, ymin_pos + tickPad);
  text("1", xmin_pos + xpos_width, ymin_pos + tickPad);

  // y-axis
  line(xmin_pos, ymin_pos, xmin_pos, ymax_pos);
  textAlign(RIGHT, CENTER);
  for (int y = 1; y <= ymax; y++) {
    text(String.valueOf(y), xmin_pos - tickPad, ymin_pos - yratio * y);
  }
}

/**
 * Build the on-screen key help. Once a parameter is large enough that
 * incrParam/decrParam move it by whole units, a key press reads as
 * learning (or forgetting) one observed item; below those thresholds the
 * keys only nudge the parameter.
 */
String instructions() {
  StringBuilder s = new StringBuilder();
  s.append(curA >= 1 ? "right: learn pos item\n" : "right: increase a\n");
  s.append(curB >= 1 ? "up: learn neg item\n" : "up: increase b\n");
  s.append(curA >= 2 ? "left: forget pos item\n" : "left: decrease a\n");
  s.append(curB >= 2 ? "down: forget neg item\n" : "down: decrease b\n");
  s.append("r: reset");
  return s.toString();
}

// Plot area in screen pixels (screen y grows downward, so ymin_pos > ymax_pos).
int ymin_pos = H - 20;
int ymax_pos = 50;
int xmin_pos = 20;
int xmax_pos = W - 20;
int xpos_width = xmax_pos - xmin_pos;
int ypos_height = ymin_pos - ymax_pos;

// Data ranges and data-to-pixel scale factors.
float ymax = 5;          // density drawn at the top of the y axis
float xratio = xpos_width;
float yratio = ypos_height / ymax;
float xdelta = 0.01;     // sampling step along x in [0, 1]

// Current Beta(a, b) parameters; the 'r' key restores these values.
float curA = 2;
float curB = 2;

/**
 * Draw the Beta(a, b) density over [0, 1] as one polyline in the current
 * stroke colour. Samples whose density is infinite (an endpoint when a < 1
 * or b < 1) or NaN are skipped.
 */
void plotBeta(float a, float b) {
  noFill();
  beginShape();
  // Integer counter instead of `x += xdelta`: repeated float addition drifts,
  // so the samples would not land on x = 1 exactly.
  int steps = Math.round(1 / xdelta);
  for (int i = 0; i <= steps; i++) {
    float x = (float) i / steps;
    float xpos = xmin_pos + xratio * x;
    float ypos = ymin_pos - yratio * dbeta(x, a, b);
    if (Float.isNaN(ypos) || Float.isInfinite(ypos)) continue;
    vertex(xpos, ypos);
  }
  endShape();
}

/**
 * Beta(a, b) probability density at x, computed in log space:
 *   log p(x) = (a-1) log x + (b-1) log(1-x) - log B(a, b).
 */
float dbeta(float x, float a, float b) {
  double logdens = logPow(x, a - 1) + logPow(1 - x, b - 1) - lbeta(a, b);
  return (float) Math.exp(logdens);
}

/**
 * log(base^exponent), defined as 0 when exponent == 0 (base^0 = 1).
 * Without this guard, base == 0 gives 0 * log(0) = 0 * -Inf = NaN, which
 * dropped the x = 0 / x = 1 points whenever a == 1 or b == 1.
 */
double logPow(float base, float exponent) {
  return exponent == 0 ? 0 : exponent * Math.log(base);
}

/** log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b); requires a, b > 0. */
double lbeta(float a, float b) {
  assert a > 0 && b > 0;
  return lgamma(a) + lgamma(b) - lgamma(a + b);
}

// Numeric code hacked up from scalanlp.org, in turn derived from:
/**
 * An approximation of the log of the Gamma function of x.
Lanczos approximation, as in gammln() from Numerical Recipes in C.
 * Intended for x > 0.
 * Reference: Numerical Recipes in C
 * http://www.library.cornell.edu/nr/cbookcpdf.html
 * www.cs.berkeley.edu/~milch/blog/versions/blog-0.1.3/blog/distrib
 */
double lgamma(double x) {
  double y = x;
  double tmp = x + 5.5;
  tmp -= (x + 0.5) * Math.log(tmp);
  double ser = 1.000000000190015;
  for (int j = 0; j < cof.length; j++) {
    y += 1;
    ser += cof[j] / y;
  }
  // 2.5066282746310005 is sqrt(2 * pi).
  return -tmp + Math.log(2.5066282746310005 * ser / x);
}

// Lanczos series coefficients used by lgamma().
final double[] cof = {
  76.18009172947146, -86.50532032941677, 24.01409824083091,
  -1.231739572450155, 0.1208650973866179e-2, -0.5395239384953e-5
};

/** True for the up arrow or the keypad up arrow. */
boolean isUp(KeyEvent e) {
  return e.getKeyCode() == KeyEvent.VK_UP || e.getKeyCode() == KeyEvent.VK_KP_UP;
}

/** True for the down arrow or the keypad down arrow. */
boolean isDown(KeyEvent e) {
  return e.getKeyCode() == KeyEvent.VK_DOWN || e.getKeyCode() == KeyEvent.VK_KP_DOWN;
}

/** True for the left arrow or the keypad left arrow. */
boolean isLeft(KeyEvent e) {
  return e.getKeyCode() == KeyEvent.VK_LEFT || e.getKeyCode() == KeyEvent.VK_KP_LEFT;
}

/** True for the right arrow or the keypad right arrow. */
boolean isRight(KeyEvent e) {
  return e.getKeyCode() == KeyEvent.VK_RIGHT || e.getKeyCode() == KeyEvent.VK_KP_RIGHT;
}

/**
 * Key handler. Up/down step b (curB) up/down, right/left step a (curA)
 * up/down, and 'r' (either case) resets both parameters to 2.
 */
void keyPressed(KeyEvent e) {
  if (isUp(e)) {
    curB = incrParam(curB);
  } else if (isDown(e)) {
    curB = decrParam(curB);
  } else if (isLeft(e)) {
    curA = decrParam(curA);
  } else if (isRight(e)) {
    curA = incrParam(curA);
  } else if (e.getKeyCode() == KeyEvent.VK_R) {
    // Key codes are case-independent: 'r' and 'R' both report VK_R.
    // Comparing against the char 'r' (114) would instead match VK_F3.
    curA = 2;
    curB = 2;
  }
}

/**
 * Step a parameter up: by 1 once it is at least 1, otherwise by 0.1.
 * The value is kept on the 0.1 grid (snapped to the nearest tenth), so
 * repeated float additions of 0.1 cannot drift and make the 1 / 2
 * thresholds misfire. The result is never below 0.1.
 */
float incrParam(float a) {
  int tenths = Math.round(a * 10);
  tenths += (tenths >= 10) ? 10 : 1;
  return Math.max(1, tenths) / 10f;
}

/**
 * Step a parameter down: by 1 while it is at least 2, otherwise by 0.1,
 * never going below 0.1. Like incrParam, the value is snapped to the 0.1
 * grid. Off-grid drift used to leave values such as 1.1 that neither branch
 * of the old code would decrease.
 */
float decrParam(float a) {
  int tenths = Math.round(a * 10);
  if (tenths >= 20) {
    tenths -= 10;
  } else if (tenths > 1) {
    tenths -= 1;
  }
  return Math.max(1, tenths) / 10f;
}