package { import flash.display.Sprite; import flash.events.Event; import flash.events.MouseEvent; import flash.text.TextField; import flash.utils.getTimer; import jp.dip.hael.gameai.nn.NN3; import jp.dip.hael.gameai.nn.NN3Event; import jp.dip.hael.gameai.nn.actfunc.*; /** * @private */ public class NN3Test extends Sprite { private var learningDatum:Array = [ // [[0.1, 0.1], [0.1]], // [[0.9, 0.1], [0.9]], // [[0.1, 0.9], [0.9]], // [[0.9, 0.9], [0.1]] [[0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], [[0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1]], [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1]], [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1]], [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]] ]; private var start:int, time:Array = [], epoch:Array = [], error:Array = []; private var textfield:TextField; public function NN3Test() { textfield = new TextField(); textfield.height = 500; textfield.width = 500; addChild(textfield); stage.addEventListener(MouseEvent.CLICK, run); stage.addEventListener(MouseEvent.MOUSE_OUT, stop); addEventListener(Event.ENTER_FRAME, onEnterFrame_); } private function onEnterFrame_(e:Event):void { trace("hoge"); } private function stop(e:MouseEvent):void { stage.removeEventListener(MouseEvent.CLICK, run); stage.removeEventListener(MouseEvent.MOUSE_OUT, stop); var t:int = 0, epo:int = 0, err:Number = 0; for(var i:int = 0; i < time.length; i++){ t += time[i]; epo += epoch[i]; err += error[i]; } t /= i; epo /= i; err /= i; textfield.appendText("mean: " + t + "ms" + " " + epo + " " + err); } private function run(e:MouseEvent):void { var nn3:NN3 = new NN3(); nn3.init(10, 4, 10); // for(var i:int = 0; i < learningDatum.length; i++){ // trace(nn3.input(learningDatum[i][0])); // } nn3.addEventListener(NN3Event.LEARNING_COMPLETE, onLearningComplete_); start = getTimer(); nn3.learn(learningDatum, 3.7, 0.007, 2000, 0.0); // for(var epoch:int = 0; epoch < 2000; epoch++){ // var err:Number = 0.0; // var s:String = ""; // for(var n:int = 0; n < learningDatum.length; n++){ // var e:Number = nn3.learn1Case(learningDatum[n], 1.7, 0.9); // err += e * e; // var out:Array = nn3.input(learningDatum[n][0]); // s += out[0].toString() + " "; // } // err /= n; // trace(err.toString() + " " + s); // } // trace("finish"); } private function onLearningComplete_(e:NN3Event):void { trace("////////////////////"); var t1:int = getTimer() - start; var nn3:NN3 = e.target as NN3; nn3.removeEventListener(NN3Event.LEARNING_COMPLETE, onLearningComplete_); for(var i:int = 0; i < learningDatum.length; i++){ trace(nn3.input(learningDatum[i][0])); } textfield.appendText(t1 + "ms" + " " + nn3.epoch + " " + nn3.error + "\n"); time.push(t1); epoch.push(nn3.epoch); error.push(nn3.error); var t:int = getTimer(); nn3.input(learningDatum[0][0]); t1 = getTimer() - t; textfield.appendText("[" + t1 + "ms]\n"); } } }