root/as3/GameAI/trunk/src/NN3Test.as

リビジョン 1204, 4.0 kB (コミッタ: hael, コミット時期: 3 年 前)

--

Line 
1 package
2 {
3         import flash.display.Sprite;
4         import flash.events.Event;
5         import flash.events.MouseEvent;
6         import flash.text.TextField;
7         import flash.utils.getTimer;
8        
9         import jp.dip.hael.gameai.nn.NN3;
10         import jp.dip.hael.gameai.nn.NN3Event;
11         import jp.dip.hael.gameai.nn.actfunc.*;
12        
13         /**
14          * @private
15          */
16         public class NN3Test extends Sprite
17         {
18                 private var learningDatum:Array = [
19 //                      [[0.1, 0.1], [0.1]],
20 //                      [[0.9, 0.1], [0.9]],
21 //                      [[0.1, 0.9], [0.9]],
22 //                      [[0.9, 0.9], [0.1]]
23                         [[0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
24                         [[0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
25                         [[0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
26                         [[0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]],
27                         [[0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1, 0.1]],
28                         [[0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.1]],
29                         [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1, 0.1]],
30                         [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1]],
31                         [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9, 0.1]],
32                         [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.9]]
33                 ];
34                
35                 private var start:int, time:Array = [], epoch:Array = [], error:Array = [];
36                
37                 private var textfield:TextField;
38                
39                 public function NN3Test()
40                 {
41                         textfield = new TextField();
42                         textfield.height = 500;
43                         textfield.width  = 500;
44                         addChild(textfield);
45                         stage.addEventListener(MouseEvent.CLICK, run);
46                         stage.addEventListener(MouseEvent.MOUSE_OUT, stop);
47                         addEventListener(Event.ENTER_FRAME, onEnterFrame_);
48                 }
49                
50                
51                 private function onEnterFrame_(e:Event):void
52                 {
53                         trace("hoge");
54                 }
55                
56                
57                 private function stop(e:MouseEvent):void
58                 {
59                         stage.removeEventListener(MouseEvent.CLICK, run);
60                         stage.removeEventListener(MouseEvent.MOUSE_OUT, stop);
61                        
62                         var t:int = 0, epo:int = 0, err:Number = 0;
63                         for(var i:int = 0; i < time.length; i++){
64                                 t += time[i];
65                                 epo += epoch[i];
66                                 err += error[i];
67                         }
68                         t /= i;
69                         epo /= i;
70                         err /= i;
71                         textfield.appendText("mean: " + t + "ms" + " " + epo + " " + err);
72                 }
73                
74                
75                 private function run(e:MouseEvent):void
76                 {
77                         var nn3:NN3 = new NN3();
78                         nn3.init(10, 4, 10);
79                        
80 //                      for(var i:int = 0; i < learningDatum.length; i++){
81 //                              trace(nn3.input(learningDatum[i][0]));
82 //                      }
83                        
84                         nn3.addEventListener(NN3Event.LEARNING_COMPLETE, onLearningComplete_);
85                         start = getTimer();
86                         nn3.learn(learningDatum, 3.7, 0.007, 2000, 0.0);
87
88 //                      for(var epoch:int = 0; epoch < 2000; epoch++){
89 //                              var err:Number = 0.0;
90 //                              var s:String = "";
91 //                              for(var n:int = 0; n < learningDatum.length; n++){
92 //                                      var e:Number = nn3.learn1Case(learningDatum[n], 1.7, 0.9);
93 //                                      err += e * e;
94 //                                      var out:Array = nn3.input(learningDatum[n][0]);
95 //                                      s += out[0].toString() + " ";
96 //                              }
97 //                              err /= n;
98 //                              trace(err.toString() + " " + s);
99 //                      }
100 //                      trace("finish");
101
102                 }
103                
104                 private function onLearningComplete_(e:NN3Event):void
105                 {
106                         trace("////////////////////");
107                         var t1:int = getTimer() - start;
108                         var nn3:NN3 = e.target as NN3;
109                         nn3.removeEventListener(NN3Event.LEARNING_COMPLETE, onLearningComplete_);
110                        
111                         for(var i:int = 0; i < learningDatum.length; i++){
112                                 trace(nn3.input(learningDatum[i][0]));
113                         }
114                         textfield.appendText(t1 + "ms" + " " + nn3.epoch + " " + nn3.error + "\n");
115                         time.push(t1);
116                         epoch.push(nn3.epoch);
117                         error.push(nn3.error);
118                        
119                         var t:int = getTimer();
120                         nn3.input(learningDatum[0][0]);
121                         t1 = getTimer() - t;
122                         textfield.appendText("[" + t1 + "ms]\n");
123                 }
124
125         }
126 }
Note: リポジトリブラウザについてのヘルプは TracBrowser を参照してください。