diff --git a/README.md b/README.md index 7a36186..01f0c13 100644 --- a/README.md +++ b/README.md @@ -3,13 +3,14 @@ Reimplementation of Learning Simple Algorithms From Examples paper on Julia # TODO - Revisit grid tasks data generation -- Make supervised grid tasks working +- ~~Make supervised grid tasks working~~ - ~~GRU controller~~ - ~~Model save~~ -- Visualization/demo +- ~~Visualization/demo~~ (Have some harmless bugs. Still wip) - Q-learning (work in progress) # Q-Learning -- implement make_batches function -- implement loss/train function -- debug run_episodes! and other things +- ~~implement make_batches function~~ +- ~~implement loss/train function~~ +- ~~debug run_episodes! and other things~~ +- correct the objective and training diff --git a/data.jl b/data.jl index 565ac3d..ea61015 100644 --- a/data.jl +++ b/data.jl @@ -36,7 +36,8 @@ function copy_data(seqlen) data = Any[rand(0:9) for i=1:seqlen] actions = [goldacts[:moveright] for i =1:seqlen] ygold = data - return (data, ygold, actions) # x,y,actions + actions = map(ai->(ai,WRITE), actions) + return (data, ygold, actions) end diff --git a/demo/demo.jl b/demo/demo.jl new file mode 100644 index 0000000..2432046 --- /dev/null +++ b/demo/demo.jl @@ -0,0 +1,185 @@ +using JSON,Images,ArgParse + +const IMGH = 90 +const IMGW = 60 +const RIGHT = "mr" +const LEFT = "ml" +const UP = "up" +const DOWN = "down" +const NOOP = -78 +const EOF = -2 # eof for reverse data +const OUTDIR ="outdir/" +function parsedoc(fname) + inf = JSON.parsefile(fname) + global taskname = inf["task"] + global input = inf["input"] + global actions = inf["actions"] + global output = inf["output"] + global spos = inf["startpos"] +end + +function makegrids(input,output) + name(x)=string("img/n",x,".png") + ri = Any[] + if length(input) > 1 + # input side is 2d grid + maxlen = maximum(map(y->length(y),input)) + for i = 1:length(input) + if length(input[i]) != maxlen + input[i] = [[10 for i=1:maxlen-length(input[i])]...,input[i]...] + end + end + end + for i=1:length(input) + singlerow = load(name(input[i][1])) + for j=2:length(input[i]) + if input[i][j] == EOF;break; end # reverse task case + singlerow = hcat(singlerow,load(name(input[i][j]))) + end + push!(ri,singlerow) + end + ri = vcat(ri...) + outs = [zeros(similar(load(name(input[1][1])))) for i=1:length(output)] + ro = hcat(outs...) + ro = ro .+ RGBA(1,1,1,1) + return ri,ro +end + +function makegridadd(input,output) + maxlen = ndigits(maximum(input)) + minlen = ndigits(input[2]) + ri = Any[] + push!(ri,reverse(digits(input[1]))) + push!(ri,[[10 for i=1:maxlen-minlen]...,reverse(digits(input[2]))...]) + + return makegrids(ri,digits(input[1]+input[2])) +end + +# input grid +function modifygrid(ipos,igrid,frame,counter) + println("input grid modify",ipos) + if ipos[1] != -1 + xstart = (ipos[1]-1) * IMGH + 1 + else + xstart = 1 + end + ystart = (ipos[2]-1) * IMGW + 1 + igridtmp = convert(Array{Float32},rawview(channelview(igrid))) + if ystart > size(igridtmp,3) + ystart = (ipos[2] -2) * IMGW + 1 + end + igridtmp[1:3,xstart:(xstart+IMGH-1),ystart:(ystart+IMGW-1)] += frame + # small correction +# igridtmp[2:3,xstart:(xstart+IMGH-1),ystart:(ystart+IMGW-1)] = 0 + igridtmp[igridtmp .> 255] = 255 + igrid = colorview(RGBA,igridtmp./255) + save(string(OUTDIR,"itape",counter,".png"),igrid) +end + +# output grid +function modifygrid(img::String,opos,ogrid,counter,frame) + if img == "NOOP" + if taskname == "reverse" + noopimg = ones(RGBA{N0f8}, 90, 300) + save(string(OUTDIR,"otape",counter,".png"),noopimg) + return noopimg,opos + else + noopimg = convert(Array{Float32},rawview(channelview(ogrid))) + zaa = colorview(RGBA,noopimg./255) + save(string(OUTDIR,"otape",counter,".png"),zaa) + return noopimg,opos + end + end + ystart = (opos[2]-1)*IMGW + 1 + ogridtmp = convert(Array{Float32},rawview(channelview(ogrid))) + outnum = convert(Array{Float32},rawview(channelview(load(img)))) + ogridtmp[:,:,ystart:(ystart+IMGW-1)] = outnum + result = copy(ogridtmp) + ogridtmp[1:3,:,ystart:(ystart+IMGW-1)] += frame + ogridtmp[ogridtmp .> 255] = 255 + ogrid = colorview(RGBA,ogridtmp./255) + save(string(OUTDIR,"otape",counter,".png"),ogrid) + if taskname == "add" + return result,[opos[1],opos[2]-1] + else + return result,[opos[1],opos[2]+1] + end +end + +function run(igrid,ogrid,actions,outputs,spos,frame) + name(x)=(x == NOOP ? "NOOP":string("img/n",x,".png")) + oldi = igrid + oldo = ogrid + ipos = spos + if taskname in ["copy","reverse"] + opos = (-1,1) + else + opos = (-1,length(outputs) - length(find(x->x==NOOP,outputs))) + end + counter = 10 + modifygrid(ipos,igrid,frame,counter) + for (index,action) in enumerate(actions) + # modify output tape + ogrid,opos = modifygrid(name(outputs[index]),opos,ogrid,counter,frame) + counter +=1 + igrid = oldi + ipos = move(action,ipos,Int(size(oldi,1)/IMGH),Int(size(oldi,2)/IMGW)) + modifygrid(ipos,igrid,frame,counter) + end +end + +function main() + opts = parse_commandline() + parsedoc(opts[:jsonfile]) + println("actions:",actions) + println("input:",input) + println("output:",output) + println("spos:",spos) + println("taskname:",taskname) + if taskname == "add" + igrid,ogrid = makegridadd(input,output) + else + igrid,ogrid = makegrids(input,output) + end + + frame = makeframe() + # spos -> copy,reverse (-1,1) + # addition/mul -> (-1,k) + run(igrid,ogrid,actions,output,spos,frame) +end + +function makeframe() + fr = zeros(Float32,3,IMGH,IMGW) + fr[1,1:3,:] = 255 + fr[1,IMGH-2:IMGH,:] = 255 + fr[1,:,1:3] = 255 + fr[1,:,IMGW-2:IMGW] = 255 + return fr +end + +# rn : rownumber, cn:columnnumber +function move(a,p,rn,cn) + if a == RIGHT + p = [p[1],min(p[2]+1,cn)] + elseif a == LEFT + p = [p[1],max(p[2]-1,1)] + elseif a == UP + p = [max(p[1]-1,1),p[2]] + elseif a == DOWN + p = [min(p[1]+1,rn),p[2]] + else + println(" asdada") + end + return p +end + +function parse_commandline() + s = ArgParseSettings() + @add_arg_table s begin + ("--jsonfile";default="testinput1d.json") + ("--outpath";default="outdir/copy/") + end + return parse_args(s;as_symbols=true) +end + +main() diff --git a/demo/img/n0.png b/demo/img/n0.png new file mode 100644 index 0000000..4f8c56e Binary files /dev/null and b/demo/img/n0.png differ diff --git a/demo/img/n1.png b/demo/img/n1.png new file mode 100644 index 0000000..1937446 Binary files /dev/null and b/demo/img/n1.png differ diff --git a/demo/img/n10.png b/demo/img/n10.png new file mode 100644 index 0000000..4b97e33 Binary files /dev/null and b/demo/img/n10.png differ diff --git a/demo/img/n2.png b/demo/img/n2.png new file mode 100644 index 0000000..5fc2e45 Binary files /dev/null and b/demo/img/n2.png differ diff --git a/demo/img/n3.png b/demo/img/n3.png new file mode 100644 index 0000000..73a8f64 Binary files /dev/null and b/demo/img/n3.png differ diff --git a/demo/img/n4.png b/demo/img/n4.png new file mode 100644 index 0000000..5224b61 Binary files /dev/null and b/demo/img/n4.png differ diff --git a/demo/img/n5.png b/demo/img/n5.png new file mode 100644 index 0000000..d1b8891 Binary files /dev/null and b/demo/img/n5.png differ diff --git a/demo/img/n6.png b/demo/img/n6.png new file mode 100644 index 0000000..3d8bb60 Binary files /dev/null and b/demo/img/n6.png differ diff --git a/demo/img/n7.png b/demo/img/n7.png new file mode 100644 index 0000000..671e6b4 Binary files /dev/null and b/demo/img/n7.png differ diff --git a/demo/img/n8.png b/demo/img/n8.png new file mode 100644 index 0000000..90959a1 Binary files /dev/null and b/demo/img/n8.png differ diff --git a/demo/img/n9.png b/demo/img/n9.png new file mode 100644 index 0000000..111d39f Binary files /dev/null and b/demo/img/n9.png differ diff --git a/demo/makegif.sh b/demo/makegif.sh new file mode 100644 index 0000000..59d338f --- /dev/null +++ b/demo/makegif.sh @@ -0,0 +1,26 @@ +#! /usr/bin/env bash +set -e + +WIDTH=60 +HEIGHT=90 +SEQLEN=5 +SLEEP=150 +TOTALWIDTH=`expr ${WIDTH} \* ${SEQLEN}` +TOTALHEIGHT=`expr 2 \* ${HEIGHT} + 10` +OFFSET=100 +echo $TOTALHEIGHT +DATA_TYPE=${DATA_TYPE:-reverse} +echo "Using type=${DATA_TYPE}. To change this set DATA_TYPE to 'copy' or 'reverse' or 'add'" + +INPUT_DIR=${OUTPUT_DIR:-$HOME/gifs/${DATA_TYPE}} +OUTPUT_DIR=${OUTPUT_DIR:-$HOME/gifs/${DATA_TYPE}/outs} +echo "Writing to ${OUTPUT_DIR}. To change this, set the OUTPUT_DIR environment variable." +mkdir -p $OUTPUT_DIR + +convert -delay $SLEEP -loop 0 $INPUT_DIR/itape*.png $OUTPUT_DIR/gifintape.gif +convert -delay $SLEEP -loop 0 $INPUT_DIR/otape*.png $OUTPUT_DIR/gifouttape.gif + +convert $OUTPUT_DIR/gifintape.gif -repage ${TOTALWIDTH}x${TOTALHEIGHT} -coalesce null: \( $OUTPUT_DIR/gifouttape.gif \) -geometry +0+$OFFSET -layers Composite $OUTPUT_DIR/${DATA_TYPE}_final.gif + + + diff --git a/demo/outdir/add/itape10.png b/demo/outdir/add/itape10.png new file mode 100644 index 0000000..13441ea Binary files /dev/null and b/demo/outdir/add/itape10.png differ diff --git a/demo/outdir/add/itape11.png b/demo/outdir/add/itape11.png new file mode 100644 index 0000000..37f8cab Binary files /dev/null and b/demo/outdir/add/itape11.png differ diff --git a/demo/outdir/add/itape12.png b/demo/outdir/add/itape12.png new file mode 100644 index 0000000..7c672ba Binary files /dev/null and b/demo/outdir/add/itape12.png differ diff --git a/demo/outdir/add/itape13.png b/demo/outdir/add/itape13.png new file mode 100644 index 0000000..6a38009 Binary files /dev/null and b/demo/outdir/add/itape13.png differ diff --git a/demo/outdir/add/itape14.png b/demo/outdir/add/itape14.png new file mode 100644 index 0000000..f5e0e2b Binary files /dev/null and b/demo/outdir/add/itape14.png differ diff --git a/demo/outdir/add/itape15.png b/demo/outdir/add/itape15.png new file mode 100644 index 0000000..a6fbbb5 Binary files /dev/null and b/demo/outdir/add/itape15.png differ diff --git a/demo/outdir/add/itape16.png b/demo/outdir/add/itape16.png new file mode 100644 index 0000000..93f8409 Binary files /dev/null and b/demo/outdir/add/itape16.png differ diff --git a/demo/outdir/add/itape17.png b/demo/outdir/add/itape17.png new file mode 100644 index 0000000..887cf1d Binary files /dev/null and b/demo/outdir/add/itape17.png differ diff --git a/demo/outdir/add/itape18.png b/demo/outdir/add/itape18.png new file mode 100644 index 0000000..216654b Binary files /dev/null and b/demo/outdir/add/itape18.png differ diff --git a/demo/outdir/add/itape19.png b/demo/outdir/add/itape19.png new file mode 100644 index 0000000..93f8409 Binary files /dev/null and b/demo/outdir/add/itape19.png differ diff --git a/demo/outdir/add/itape20.png b/demo/outdir/add/itape20.png new file mode 100644 index 0000000..93f8409 Binary files /dev/null and b/demo/outdir/add/itape20.png differ diff --git a/demo/outdir/add/itape21.png b/demo/outdir/add/itape21.png new file mode 100644 index 0000000..216654b Binary files /dev/null and b/demo/outdir/add/itape21.png differ diff --git a/demo/outdir/add/otape10.png b/demo/outdir/add/otape10.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/add/otape10.png differ diff --git a/demo/outdir/add/otape11.png b/demo/outdir/add/otape11.png new file mode 100644 index 0000000..566a85d Binary files /dev/null and b/demo/outdir/add/otape11.png differ diff --git a/demo/outdir/add/otape12.png b/demo/outdir/add/otape12.png new file mode 100644 index 0000000..445e795 Binary files /dev/null and b/demo/outdir/add/otape12.png differ diff --git a/demo/outdir/add/otape13.png b/demo/outdir/add/otape13.png new file mode 100644 index 0000000..d5c7b38 Binary files /dev/null and b/demo/outdir/add/otape13.png differ diff --git a/demo/outdir/add/otape14.png b/demo/outdir/add/otape14.png new file mode 100644 index 0000000..284f724 Binary files /dev/null and b/demo/outdir/add/otape14.png differ diff --git a/demo/outdir/add/otape15.png b/demo/outdir/add/otape15.png new file mode 100644 index 0000000..3f7f7f5 Binary files /dev/null and b/demo/outdir/add/otape15.png differ diff --git a/demo/outdir/add/otape16.png b/demo/outdir/add/otape16.png new file mode 100644 index 0000000..f2cdf1b Binary files /dev/null and b/demo/outdir/add/otape16.png differ diff --git a/demo/outdir/add/otape17.png b/demo/outdir/add/otape17.png new file mode 100644 index 0000000..6362dca Binary files /dev/null and b/demo/outdir/add/otape17.png differ diff --git a/demo/outdir/add/otape18.png b/demo/outdir/add/otape18.png new file mode 100644 index 0000000..703a601 Binary files /dev/null and b/demo/outdir/add/otape18.png differ diff --git a/demo/outdir/add/otape19.png b/demo/outdir/add/otape19.png new file mode 100644 index 0000000..b274ea4 Binary files /dev/null and b/demo/outdir/add/otape19.png differ diff --git a/demo/outdir/add/otape20.png b/demo/outdir/add/otape20.png new file mode 100644 index 0000000..876d9eb Binary files /dev/null and b/demo/outdir/add/otape20.png differ diff --git a/demo/outdir/add/outs/add_final.gif b/demo/outdir/add/outs/add_final.gif new file mode 100644 index 0000000..dd08f92 Binary files /dev/null and b/demo/outdir/add/outs/add_final.gif differ diff --git a/demo/outdir/add/outs/gifintape.gif b/demo/outdir/add/outs/gifintape.gif new file mode 100644 index 0000000..879d97a Binary files /dev/null and b/demo/outdir/add/outs/gifintape.gif differ diff --git a/demo/outdir/add/outs/gifouttape.gif b/demo/outdir/add/outs/gifouttape.gif new file mode 100644 index 0000000..35e275f Binary files /dev/null and b/demo/outdir/add/outs/gifouttape.gif differ diff --git a/demo/outdir/copy/itape10.png b/demo/outdir/copy/itape10.png new file mode 100644 index 0000000..212d169 Binary files /dev/null and b/demo/outdir/copy/itape10.png differ diff --git a/demo/outdir/copy/itape11.png b/demo/outdir/copy/itape11.png new file mode 100644 index 0000000..21fcaab Binary files /dev/null and b/demo/outdir/copy/itape11.png differ diff --git a/demo/outdir/copy/itape12.png b/demo/outdir/copy/itape12.png new file mode 100644 index 0000000..0e03321 Binary files /dev/null and b/demo/outdir/copy/itape12.png differ diff --git a/demo/outdir/copy/itape13.png b/demo/outdir/copy/itape13.png new file mode 100644 index 0000000..37e79e7 Binary files /dev/null and b/demo/outdir/copy/itape13.png differ diff --git a/demo/outdir/copy/itape14.png b/demo/outdir/copy/itape14.png new file mode 100644 index 0000000..d9a20de Binary files /dev/null and b/demo/outdir/copy/itape14.png differ diff --git a/demo/outdir/copy/itape15.png b/demo/outdir/copy/itape15.png new file mode 100644 index 0000000..d9a20de Binary files /dev/null and b/demo/outdir/copy/itape15.png differ diff --git a/demo/outdir/copy/otape10.png b/demo/outdir/copy/otape10.png new file mode 100644 index 0000000..7480c26 Binary files /dev/null and b/demo/outdir/copy/otape10.png differ diff --git a/demo/outdir/copy/otape11.png b/demo/outdir/copy/otape11.png new file mode 100644 index 0000000..09fb534 Binary files /dev/null and b/demo/outdir/copy/otape11.png differ diff --git a/demo/outdir/copy/otape12.png b/demo/outdir/copy/otape12.png new file mode 100644 index 0000000..8ef95b5 Binary files /dev/null and b/demo/outdir/copy/otape12.png differ diff --git a/demo/outdir/copy/otape13.png b/demo/outdir/copy/otape13.png new file mode 100644 index 0000000..e2ae84a Binary files /dev/null and b/demo/outdir/copy/otape13.png differ diff --git a/demo/outdir/copy/otape14.png b/demo/outdir/copy/otape14.png new file mode 100644 index 0000000..d9a20de Binary files /dev/null and b/demo/outdir/copy/otape14.png differ diff --git a/demo/outdir/copy/outs/copy_final.gif b/demo/outdir/copy/outs/copy_final.gif new file mode 100644 index 0000000..9855db7 Binary files /dev/null and b/demo/outdir/copy/outs/copy_final.gif differ diff --git a/demo/outdir/copy/outs/gifintape.gif b/demo/outdir/copy/outs/gifintape.gif new file mode 100644 index 0000000..2809b4f Binary files /dev/null and b/demo/outdir/copy/outs/gifintape.gif differ diff --git a/demo/outdir/copy/outs/gifouttape.gif b/demo/outdir/copy/outs/gifouttape.gif new file mode 100644 index 0000000..ce284f2 Binary files /dev/null and b/demo/outdir/copy/outs/gifouttape.gif differ diff --git a/demo/outdir/reverse/itape10.png b/demo/outdir/reverse/itape10.png new file mode 100644 index 0000000..0335f3d Binary files /dev/null and b/demo/outdir/reverse/itape10.png differ diff --git a/demo/outdir/reverse/itape11.png b/demo/outdir/reverse/itape11.png new file mode 100644 index 0000000..1afa9e9 Binary files /dev/null and b/demo/outdir/reverse/itape11.png differ diff --git a/demo/outdir/reverse/itape12.png b/demo/outdir/reverse/itape12.png new file mode 100644 index 0000000..a16219a Binary files /dev/null and b/demo/outdir/reverse/itape12.png differ diff --git a/demo/outdir/reverse/itape13.png b/demo/outdir/reverse/itape13.png new file mode 100644 index 0000000..c360c73 Binary files /dev/null and b/demo/outdir/reverse/itape13.png differ diff --git a/demo/outdir/reverse/itape14.png b/demo/outdir/reverse/itape14.png new file mode 100644 index 0000000..cdaa7a3 Binary files /dev/null and b/demo/outdir/reverse/itape14.png differ diff --git a/demo/outdir/reverse/itape15.png b/demo/outdir/reverse/itape15.png new file mode 100644 index 0000000..cdaa7a3 Binary files /dev/null and b/demo/outdir/reverse/itape15.png differ diff --git a/demo/outdir/reverse/itape16.png b/demo/outdir/reverse/itape16.png new file mode 100644 index 0000000..c360c73 Binary files /dev/null and b/demo/outdir/reverse/itape16.png differ diff --git a/demo/outdir/reverse/itape17.png b/demo/outdir/reverse/itape17.png new file mode 100644 index 0000000..a16219a Binary files /dev/null and b/demo/outdir/reverse/itape17.png differ diff --git a/demo/outdir/reverse/itape18.png b/demo/outdir/reverse/itape18.png new file mode 100644 index 0000000..1afa9e9 Binary files /dev/null and b/demo/outdir/reverse/itape18.png differ diff --git a/demo/outdir/reverse/itape19.png b/demo/outdir/reverse/itape19.png new file mode 100644 index 0000000..0335f3d Binary files /dev/null and b/demo/outdir/reverse/itape19.png differ diff --git a/demo/outdir/reverse/itape20.png b/demo/outdir/reverse/itape20.png new file mode 100644 index 0000000..0335f3d Binary files /dev/null and b/demo/outdir/reverse/itape20.png differ diff --git a/demo/outdir/reverse/itape21.png b/demo/outdir/reverse/itape21.png new file mode 100644 index 0000000..0335f3d Binary files /dev/null and b/demo/outdir/reverse/itape21.png differ diff --git a/demo/outdir/reverse/otape10.png b/demo/outdir/reverse/otape10.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/reverse/otape10.png differ diff --git a/demo/outdir/reverse/otape11.png b/demo/outdir/reverse/otape11.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/reverse/otape11.png differ diff --git a/demo/outdir/reverse/otape12.png b/demo/outdir/reverse/otape12.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/reverse/otape12.png differ diff --git a/demo/outdir/reverse/otape13.png b/demo/outdir/reverse/otape13.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/reverse/otape13.png differ diff --git a/demo/outdir/reverse/otape14.png b/demo/outdir/reverse/otape14.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/reverse/otape14.png differ diff --git a/demo/outdir/reverse/otape15.png b/demo/outdir/reverse/otape15.png new file mode 100644 index 0000000..c2e397d Binary files /dev/null and b/demo/outdir/reverse/otape15.png differ diff --git a/demo/outdir/reverse/otape16.png b/demo/outdir/reverse/otape16.png new file mode 100644 index 0000000..aea47bb Binary files /dev/null and b/demo/outdir/reverse/otape16.png differ diff --git a/demo/outdir/reverse/otape17.png b/demo/outdir/reverse/otape17.png new file mode 100644 index 0000000..55d8b79 Binary files /dev/null and b/demo/outdir/reverse/otape17.png differ diff --git a/demo/outdir/reverse/otape18.png b/demo/outdir/reverse/otape18.png new file mode 100644 index 0000000..6baf0d4 Binary files /dev/null and b/demo/outdir/reverse/otape18.png differ diff --git a/demo/outdir/reverse/otape19.png b/demo/outdir/reverse/otape19.png new file mode 100644 index 0000000..b470e5d Binary files /dev/null and b/demo/outdir/reverse/otape19.png differ diff --git a/demo/outdir/reverse/otape20.png b/demo/outdir/reverse/otape20.png new file mode 100644 index 0000000..90bd072 Binary files /dev/null and b/demo/outdir/reverse/otape20.png differ diff --git a/demo/outdir/reverse/outs/gifintape.gif b/demo/outdir/reverse/outs/gifintape.gif new file mode 100644 index 0000000..f7065ba Binary files /dev/null and b/demo/outdir/reverse/outs/gifintape.gif differ diff --git a/demo/outdir/reverse/outs/gifouttape.gif b/demo/outdir/reverse/outs/gifouttape.gif new file mode 100644 index 0000000..053d068 Binary files /dev/null and b/demo/outdir/reverse/outs/gifouttape.gif differ diff --git a/demo/outdir/reverse/outs/reverse_final.gif b/demo/outdir/reverse/outs/reverse_final.gif new file mode 100644 index 0000000..76969bf Binary files /dev/null and b/demo/outdir/reverse/outs/reverse_final.gif differ diff --git a/demo/testdata/add_test.json b/demo/testdata/add_test.json new file mode 100644 index 0000000..5e63ad8 --- /dev/null +++ b/demo/testdata/add_test.json @@ -0,0 +1 @@ +{"task":"add","startpos":[1,5],"actions":["down","ml","up","ml","down","ml","up","ml","down","ml","up"],"output":[-78,2,-78,6,-78,0,-78,4,-78,9,-78],"input":[93837,225]} \ No newline at end of file diff --git a/demo/testdata/copy_test.json b/demo/testdata/copy_test.json new file mode 100644 index 0000000..920fd26 --- /dev/null +++ b/demo/testdata/copy_test.json @@ -0,0 +1 @@ +{"task":"copy","startpos":[-1,1],"actions":["mr","mr","mr","mr","mr"],"output":[3,3,8,0,4],"input":[[3,3,8,0,4]]} \ No newline at end of file diff --git a/demo/testdata/reverse_test.json b/demo/testdata/reverse_test.json new file mode 100644 index 0000000..e7e628a --- /dev/null +++ b/demo/testdata/reverse_test.json @@ -0,0 +1 @@ +{"task":"reverse","startpos":[-1,1],"actions":["mr","mr","mr","mr","mr","ml","ml","ml","ml","ml","ml"],"output":[-78,-78,-78,-78,-78,-78,4,4,8,6,3],"input":[[3,6,8,4,4,-2]]} \ No newline at end of file diff --git a/env.jl b/env.jl index 199786c..6881eb3 100644 --- a/env.jl +++ b/env.jl @@ -1,81 +1,54 @@ -import Base: push! -import Base: length -import Base: empty! -import Base: pop! -import Base: getindex - -const ACTIONS = ("mr","ml","up","down", "") -# token stands for start in input, stop in output - type Game - ninstances - input_tapes - output_tapes - gold_tapes - next_actions - prev_actions + input_tape + output_tape + gold_tape task - pointers - symgold + head timestep - # terminated - # mask - - function Game(x,y,actions,task="copy") - N = length(x) # ninstances + prev_actions + is_done - # input tapes - xtapes = [] + function Game(x,y,task="copy") + # make input tpae + input_tape = nothing if in(task, ("copy","reverse")) - for xi0 in x - xi1 = convert(Array{Int64}, xi0) - xi2 = reshape(xi1, 1, length(xi1)) - push!(xtapes, xi2) - end + x = convert(Array{Int64}, x) + input_tape = reshape(x, 1, length(x)) elseif in(task, ("add","mul","radd")) - for xi in x - push!(xtapes, make_grid(xi)) - end + input_tape = make_grid(x) elseif task == "walk" - # error("walk is not implemented yet") - for xi in x - push!(xtapes, xi) - end + input_tape = x end - # output tapes - ytapes = map(i->Any[], 1:N) + # make output tape + output_tape = Int64[] - # gold tapes - gtapes = [] + # make gold tape + gold_tape = nothing if in(task, ("copy","reverse","walk")) - for yi in y - push!(gtapes, yi) - end + gold_tape = y else - for yi in y - push!(gtapes, digits(yi)) - end + gold_tape = digits(y) end - # actions - xactions = map(ai->["", ai...], actions) - yactions = map(ai->[ai..., ""], actions) + # previous action + prev_actions = [STOP_ACTION] - # pointer <=> head - pointers = init_pointers(xtapes,N,task) - symgold = map(i->get_symgold(xtapes[i],gtapes[i],actions[i],task), 1:N) + # head + head = init_head(input_tape,task) timestep = 1 - # terminated = falses(N) - new( - N,xtapes,ytapes,gtapes,yactions,xactions, - task,pointers,symgold,timestep) + new(input_tape, output_tape, gold_tape, + task, head, timestep, prev_actions, false) end end -function init_pointers(grids,ninstances,task) - map(gi->get_origin(gi,task), grids) +function init_head(grid,task) + return get_origin(grid,task) +end + +function init_head(g::Game) + return get_origin(g.input_tape,g.task) end function get_origin(grid,task) @@ -91,84 +64,46 @@ function get_origin(grid,task) error("invalid task: $task") end -# now only just for copy and reverse tasks -function move_timestep!(g::Game, actions::Array) - for k = 1:g.ninstances - action = actions[k] - move_timestep!(g, k, action) - end +function move_timestep!(g::Game, write_symbol, move_action) + write_action = write_symbol == NO_SYMBOL ? NOT_WRITE : WRITE + write_action == WRITE && write!(g, write_symbol) + move!(g, move_action) g.timestep += 1 + push!(g.prev_actions, (move_action, write_action)) + g.is_done = is_done(g) end -function move_timestep!(g::Game) - actions = map(ai->ai[g.timestep], g.next_actions) - move_timestep!(g,actions) -end - -function move_timestep!(g::Game, instance::Int64, action) - k = instance +function move!(g::Game, action) if action == "mr" - g.pointers[k][2] += 1 + g.head[2] += 1 elseif action == "ml" - g.pointers[k][2] -= 1 + g.head[2] -= 1 elseif action == "" - # do nothing + g.is_done = true elseif action == "up" - g.pointers[k][1] -= 1 + g.head[1] -= 1 elseif action == "down" - g.pointers[k][1] += 1 + g.head[1] += 1 else error("invalid action: $action") end end -function make_input(g::Game, s2i, a2i) - x1 = zeros(Float32, length(s2i), g.ninstances) - - # x1 => onehots, x11 => values, x12 => decoded (actions) - x11 = map(i->read_symbol(g.input_tapes[i],g.pointers[i]), 1:g.ninstances) - x12 = map(v->s2i[v], x11) - for k = 1:length(x12); x1[x12[k],k] = 1; end - - # x2 => onehots, x21 => values, x22 => decoded (actions) - x2 = zeros(Float32, length(a2i), g.ninstances) - x21 = map(i->g.prev_actions[i][g.timestep], 1:g.ninstances) - x22 = map(v->a2i[v], x21) - for k = 1:length(x22); x2[x22[k],k] = 1; end - - return x1,x2 -end - -function make_inputs(g::Game, s2i, a2i) - reset!(g) - inputs = [] - for k = 1:length(g.prev_actions[1]) - push!(inputs, make_input(g,s2i,a2i)) - move_timestep!(g) - end - reset!(g) - return inputs +function write!(g::Game, symbol) + (g.task in ("copy","reverse","walk")?push!:unshift!)(g.output_tape, symbol) end -function make_output(g::Game, s2i, a2i) - y10 = map(i->g.symgold[i][g.timestep], 1:g.ninstances) - y11 = map(yi->s2i[yi], y10) - - y20 = map(i->g.next_actions[i][g.timestep], 1:g.ninstances) - y21 = map(yi->a2i[yi], y20) - - return y11, y21 -end - -function make_outputs(g, s2i, a2i) - reset!(g) - outputs = [] - for k = 1:length(g.next_actions[1]) - push!(outputs, make_output(g,s2i,a2i)) - move_timestep!(g) +# FIXME: when it is done? right thing is to check the last action is or not +function is_done(g::Game) + len = length(g.output_tape) + len >= length(g.gold_tape) && return true + gold = nothing + if g.task in ("copy","reverse","walk") + gold = g.gold_tape[1:len] + else + gold = g.gold_tape[end-length(len):end] end - reset!(g) - return outputs + return !(g.output_tape == gold) end function make_grid(x) @@ -185,10 +120,79 @@ function make_grid(x) end function reset!(g::Game) + g.prev_actions = [STOP_ACTION] + g.head = init_head(g.input_tape, g.task) g.timestep = 1 - g.pointers = init_pointers(g.input_tapes,g.ninstances,g.task) + empty!(g.output_tape) + g.is_done = false +end + +function read_symbol(grid, head) + if 0 < head[1] <= size(grid,1) && 0 < head[2] <= size(grid,2) + return grid[head...] + end + return -1 +end + +function read_symbol(g::Game) + return read_symbol(g.input_tape, g.head) +end + +# transitions +type Transition + reward # +1 for true symbol output, 0 for otherwise + input_symbol # read symbol - controller input + prev_action # previous action - controller input + nsteps # remaining steps + h; c # controller states + action # action has been taken + output_symbol # symbol written to output tape + is_done # last step or not +end + +function take_action(w, b, s, steps_done; o=Dict()) + ei = get(o, :epsinit, EPS_INIT) + ef = get(o, :epsfinal, EPS_FINAL) + ed = get(o, :epsdecay, EPS_DECAY) + et = ef + (ei-ef) * exp(-steps_done/ed) + + # I think they did not use GLIE, they just unroll + if rand() > et + # @show w,s + s = ndims(s) == 1 ? reshape(s,length(s),1) : s + y = predict(w,b,s) + return mapslices(indmax, Array(y), 1) + else + return rand(1:size(w,1), size(s)...) + end +end + +function take_action(w,b,s; eps=0.05) + if rand() > eps + s = ndims(s) == 1 ? reshape(s, length(s), 1) : s + y = predict(w,b,s) + return mapslices(indmax, Array(y), 1) + else + return rand(1:size(w,1), size(s,2)) + end +end + +function get_reward(g::Game) + length(g.output_tape) > length(g.gold_tape) && return 0 + if in(g.task, ("copy","reverse","walk")) + is_desired = g.output_tape == g.gold_tape[1:length(g.output_tape)] + else + is_desired = g.output_tape == g.gold_tape[end-length(g.output_tape):end] + end + return Int(is_desired && length(g.output_tape) != 0) end +function get_remaining_steps(g::Game) + return length(g.gold_tape) - length(g.output_tape) +end + +# deprecated old input functions, let me keep them for a while + # x: input tape, y: output tape, a: actions function get_symgold(x,y,a,task) if task == "copy" @@ -207,168 +211,105 @@ function get_symgold(x,y,a,task) end end -function read_symbol(grid, pointer) - if 0 < pointer[1] <= size(grid,1) && 0 < pointer[2] <= size(grid,2) - return grid[pointer...] - end - return -1 -end - -# Environment for Reinforcement Learning -type Transition - # +1 for true symbol output, 0 for otherwise - reward - - # environment state - POMDP - input_symbol - input_action - nsteps # remaining steps +function make_data(games,s2i,a2i,actions) + all_done = false + inputs, outputs, masks = [], [], [] + + t = 1 + while !all_done + input = zeros(Cuchar, length(s2i)+length(a2i), length(games)) + mask = falses(1, length(games)) + symgold = NO_SYMBOL*ones(Int64, length(games)) + actgold = length(a2i)*ones(Int64, length(games)) + + for (i,game) in enumerate(games) + # skip if game is finished + game.is_done && continue + + input_symbol = read_symbol(game.input_tape, game.head) + input[s2i[input_symbol],i] = 1 + input[length(s2i)+a2i[game.prev_actions[end]],i] = 1 + + move_action, write_action = actions[i][game.timestep] + y = NO_SYMBOL + if write_action == WRITE + mask[1,i] = 1 + if game.task in ("copy","reverse","walk") + y = game.gold_tape[length(game.output_tape)+1] + else + y = game.gold_tape[end-length(game.output_tape)] + end + symgold[i] = s2i[y] + end - # controller state - e.g. RNN hidden/cell - h - c + actgold[i] = a2i[(move_action,write_action)] + move_timestep!(game,y,move_action) + end - # next environment state - output_symbol - output_action -end + push!(inputs, input) + push!(outputs, (symgold,actgold)) + push!(masks, mask) -type ReplayMemory - capacity - memory + all_done = true + for k = 1:length(games) + if !games[k].is_done + all_done = false + end + end - function ReplayMemory(capacity) - memory = Transition[] - new(capacity, memory) + # @show t,all_done + t+=1 end -end -function push!(obj::ReplayMemory, t) - push!(obj.memory, t) - length(obj.memory) > obj.capacity && shift!(obj.memory) + for k = 1:length(games); reset!(games[k]); end + return inputs,outputs,masks end -function length(obj::ReplayMemory) - return length(obj.memory) -end +function make_input(g::Game, s2i, a2i) + x1 = zeros(Float32, length(s2i), g.ninstances) -function empty!(obj::ReplayMemory) - empty!(obj.memory) -end + # x1 => onehots, x11 => values, x12 => decoded (actions) + x11 = map(i->read_symbol(g.input_tapes[i],g.pointers[i]), 1:g.ninstances) + x12 = map(v->s2i[v], x11) + for k = 1:length(x12); x1[x12[k],k] = 1; end -function pop!(obj::ReplayMemory) - pop!(obj.memory) -end + # x2 => onehots, x21 => values, x22 => decoded (actions) + x2 = zeros(Float32, length(a2i), g.ninstances) + x21 = map(i->g.prev_actions[i][g.timestep], 1:g.ninstances) + x22 = map(v->a2i[v], x21) + for k = 1:length(x22); x2[x22[k],k] = 1; end -function sample(obj::ReplayMemory, nsamples, nsteps) - samples = [] - indices = randperm(length(obj))[1:min(nsamples,length(obj))] - for ind in indices - push!(samples, obj.memory[ind:min(ind+nsteps-1,length(obj))]) - end - return samples + return x1,x2 end -# currently I don't have an efficient idea to run episodes parallel -# I just leave it in this way for simplicity -function run_episodes!( - g::Game, mem, w, h, c, s2i, i2s, a2i, i2a, steps_done; o=Dict()) +function make_inputs(g::Game, s2i, a2i) reset!(g) - atype = typeof(w[:wcont]) - - for k = 1:g.ninstances - input_action = g.prev_actions[k][1] # no action input - input_symbol = read_symbol(g.input_tapes[k], g.pointers[k]) - predicted = [] - - while true - # make one-hot input vector - input = zeros(length(s2i)+length(a2i), 1) - input[s2i[input_symbol]] = 1 - input[length(s2i)+a2i[input_action]] = 1 - input = convert(atype, input) - - # use controller - cout, h1, c1 = propagate(w[:wcont], w[:bcont], input, h, c) - - # predict symbol - y1pred = predict(w[:wsymb],w[:bsymb], cout) - y1pred = indmax(Array(y1pred)) - push!(predicted, i2s[y1pred]) - - # take action - action = take_action(w[:wact],w[:bact],cout,steps_done; o=o) - action = i2a[action] - - # decide reward, termination, remaining steps - reward, done, nsteps = get_reward(g, k, predicted) - - # transition - this_transition = Transition( - reward, - input_symbol, - input_action, - nsteps, - h != nothing ? Array(h) : nothing, - c != nothing ? Array(c) : nothing, - predicted[end], # output_symbol - action) - - # push to replay memory - push!(mem, this_transition) - - # move head - move_timestep!(g, k, action) - - # change controller state - h = h1; c = c1 - - # change inputs - input_symbol = predicted[end] - input_action = action - - # if done, then break - done && break - end - - # FIXME: when to do steps_done increament? - # after each episode or after each step? - steps_done += 1 + inputs = [] + for k = 1:length(g.prev_actions[1]) + push!(inputs, make_input(g,s2i,a2i)) + move_timestep!(g) end - - return steps_done + reset!(g) + return inputs end -function take_action(w, b, s, steps_done; o=Dict()) - ei = get(o, :epsinit, EPS_INIT) - ef = get(o, :epsfinal, EPS_FINAL) - ed = get(o, :epsdecay, EPS_DECAY) - et = ef + (ei-ef) * exp(-steps_done/ed) +function make_output(g::Game, s2i, a2i) + y10 = map(i->g.symgold[i][g.timestep], 1:g.ninstances) + y11 = map(yi->s2i[yi], y10) - if rand() > et - # @show w,s - s = reshape(s,length(s),1) - y = predict(w,b,s) - y = Array(y) - return indmax(y) - else - return rand(1:2) - end -end + y20 = map(i->g.next_actions[i][g.timestep], 1:g.ninstances) + y21 = map(yi->a2i[yi], y20) -function get_reward(g::Game, instance, predictions) - symgold = g.symgold[instance] - symgold = filter(si->si!=NO_SYMBOL, symgold) - predictions = filter(pi->pi!=NO_SYMBOL, predictions) + return y11, y21 +end - reward = 0 - done = false - nsteps = length(symgold) - length(predictions) - if predictions == symgold[1:length(predictions)] - reward = 1 - else - done = true +function make_outputs(g, s2i, a2i) + reset!(g) + outputs = [] + for k = 1:length(g.next_actions[1]) + push!(outputs, make_output(g,s2i,a2i)) + move_timestep!(g) end - - return reward, done, nsteps + reset!(g) + return outputs end diff --git a/model.jl b/model.jl index 884157f..0f0af92 100644 --- a/model.jl +++ b/model.jl @@ -41,46 +41,57 @@ function predict(w,b,x) return w * x .+ b end -function logprob(output, ypred) +function logprob(output, ypred, mask=nothing) nrows,ncols = size(ypred) index = output + nrows*(0:(length(output)-1)) + # FIXME: this is so dirty + if mask != nothing && length(mask) != 1 + index = index[[mask...]] + elseif mask != nothing && length(mask) == 1 && !mask[1] + return 0 + end o1 = logp(ypred,1) + # @show index o2 = o1[index] o3 = sum(o2) return o3 end -# x1,y1 => input/output for symbols -# x2,y2 => input/output for actions -# weighted loss for soft symbol/action distributions -function sloss(w,x,y,h,c; values=[]) +# loss function for supervised learning +# x: controller input, y: controller output (action+symbol) +# m: masks for loss, h/c: controller states +function sloss(w,x,y,m,h,c; values=[]) batchsize = size(x[1][1],2) atype = typeof(AutoGrad.getval(w[:wcont])) lossval1 = lossval2 = 0 - for (xi,yi) in zip(x,y) + for (i,(xi,yi,mi)) in enumerate(zip(x,y,m)) # concat previous action and symbol from input tape - input = convert(atype,vcat(xi...)) # TODO: CPU/GPU comparison + input = convert(atype,xi) # TODO: CPU/GPU comparison # use the controller cout,h,c = propagate(w[:wcont],w[:bcont],input,h,c) # make predictions - y1pred = predict(w[:wsymb],w[:bsymb],cout) - y2pred = predict(w[:wact],w[:bact],cout) + sympred = predict(w[:wsymb],w[:bsymb],cout) + actpred = predict(w[:wact],w[:bact],cout) # log probabilities - lossval1 += logprob(yi[1],y1pred) - lossval2 += logprob(yi[2],y2pred) + symgold, actgold = yi[1], yi[2] + lossval1 += logprob(symgold,sympred,mi) + lossval2 += logprob(actgold,actpred) end # combined loss - lossval = 0.5*(lossval1+lossval2) - push!(values, AutoGrad.getval(-lossval)) - return -lossval/(batchsize*length(x)) + lossval = -0.5*(lossval1+lossval2) + push!(values, AutoGrad.getval(lossval)) + push!(values, batchsize*length(x)) + + # return -lossval/(batchsize*length(x)) + return lossval end -slgradient = grad(sloss) +slgrad = grad(sloss) function initweights( atype,units,nsymbols,nactions, @@ -123,45 +134,140 @@ function initopts(w,optim) return opts end - # Reinforcement Learning stuff # xs => controller inputs (concat prev_action and read_symbol) # ys => controller symbol outputs written to output tape # as => actions taken by following behaviour policy -# targets => temporal difference learning targets -# TODO: make it sequential? -function rloss(w, targets, xs, ys, as, h, c; values=[]) +# ts => temporal difference learning targets +function rloss(w, ts, xs, ys, as, ms, h, c; values=[]) # propagate controller, same with previous cout, h, c = propagate(w[:wcont], w[:bcont], xs, h, c) # symbol prediction, same sympred = predict(w[:wsymb], w[:bsymb], cout) - # action estimation, same + # compute Q estimate qsa = predict(w[:wact], w[:bact], cout) - - # compute indices nrows, ncols = size(qsa) index = as + nrows*(0:(length(as)-1)) - - # compute estimate - qs = qsa[index] + qs = qsa[index] # divide by nsteps remaining estimate = reshape(qs, 1, length(qs)) + ts = reshape(ts, 1, length(ts)) - # hybrid loss calculation + # hybrid loss calculation, supervised (symbols), q-learning (actions) val = 0 - val += -0.5 * logprob(ys, sympred) # sl loss, output symbols - val += 0.5 * sumabs2(targets-estimate) # rl loss, actions + val -= 0.5 * logprob(ys, sympred, ms) + val += 0.5 * sumabs2(ts-estimate) - push!(values, val) - return val / size(targets,2) + push!(values, val, size(ts,2)) + return val end -rlgradient = grad(rloss) +rlgrad = grad(rloss) + +# FIXME: this is so dirty and inefficient +function make_batches(w,histories,s2i,a2i,discount,nsteps,batchsize; o=Dict()) + atype = get(o, :atype, typeof(w[:wcont])) + + samples = [] + for history in histories + for k = 1:length(history)-1 + this = history[k] + + # input formation + x = (this.input_symbol,this.prev_action) + y = this.output_symbol + a = this.action + m = y != NO_SYMBOL && !this.is_done + ph = this.h + pc = this.c + vs = this.nsteps + + # target formation + T = min(k+nsteps, length(history)) + rs = reduce(+, [0, map(hi->hi.reward, history[k+1:T])...]) + yT = history[T].output_symbol + target = rs + if yT != NO_SYMBOL && !history[T].is_done + # compute target + xT = (history[T].input_symbol, history[T].prev_action) + input = zeros(Cuchar, length(s2i)+length(a2i), 1) + input[s2i[xT[1]]] = 1 + input[length(s2i)+a2i[xT[2]]] = 1 + input = convert(atype, input) + + vT = history[T].nsteps + phT,pcT = history[T].h, history[T].c + cout, hT, cT = propagate(w[:wcont],w[:bcont],input,phT,pcT) + qsa = predict(w[:wact], w[:bact], cout) + qs = maximum(qsa) + target += vT * maximum(qs) + end + + # normalize target + target = target/vs + + sample = (target,x,y,a,m,ph,pc) + push!(samples, sample) + end + + # episode ending + length(history) >= 1 || continue + target = history[end].reward / history[end].nsteps + x = (history[end].input_symbol, history[end].prev_action) + y = history[end].output_symbol + a = history[end].action + ph = history[end].h + pc = history[end].c + m = y != NO_SYMBOL && !history[end].is_done + sample = (target,x,y,a,m,ph,pc) + push!(samples,sample) + end + + batches = [] + for k = 1:batchsize:length(samples) + from = k; to = min(from+k-1,length(samples)) + bsamples = samples[from:to] + + # make target batch + ts = mapreduce(s->s[1], vcat, bsamples) + + # make input batch + xs = falses(length(s2i)+length(a2i), to-from+1) + for j = 1:to-from+1 + xs[s2i[bsamples[j][2][1]],j] = 1 + xs[length(s2i)+a2i[bsamples[j][2][2]]] = 1 + end + + # make output batch + ys = map(si->s2i[si[3]], bsamples) + + # make action batch + as = map(si->a2i[si[4]], bsamples) + + # make mask batch + ms = map(si->si[5], bsamples) + + # make h,c batches + hs = cs = nothing + if bsamples[1][end-1] != nothing + hs = mapreduce(bi->bi[end-1], hcat, bsamples) + cs = mapreduce(bi->bi[end], hcat, bsamples) + end + + batch = (ts,xs,ys,as,ms,hs,cs) + push!(batches, batch) + end + + return batches +end # compute TD targets for objective function compute_targets(samples, w, discount, nsteps, s2i, a2i) # reward calculations + if discount < 0 + discount = 1 + end discounts = map(i->discount^i, 0:nsteps) targets = zeros(1, length(samples)) for k = 1:length(samples) @@ -190,7 +296,7 @@ function compute_targets(samples, w, discount, nsteps, s2i, a2i) end # (2.2) batch environment states - aka controller inputs - sa = map(s->(s[end].input_symbol, s[end].input_action), samples) + sa = map(s->(s[end].input_symbol, s[end].prev_action), samples) inputs = zeros(length(s2i)+length(a2i), length(samples)) for k = 1:length(sa) # symbol-action pairs inputs[s2i[sa[k][1]],k] = 1 @@ -219,34 +325,37 @@ function compute_targets(samples, w, discount, nsteps, s2i, a2i) return targets end -function make_batch( - obj::ReplayMemory, w, discount, nsteps, s2i, a2i, batchsize) - samples = sample(obj, batchsize, nsteps) - targets = compute_targets(samples, w, discount, nsteps, s2i, a2i) - atype = typeof(targets) - - # xs <-> inputs (read symbol+previous action) - onehots - xs = zeros(length(s2i)+length(a2i),length(samples)) - for (i,sample) in enumerate(samples) - xs[s2i[sample[1].input_symbol],i] = 1 - xs[length(s2i)+a2i[sample[1].input_action],i] = 1 - end - xs = convert(atype, xs) - - # ys <-> output symbols - # as <-> actions - ys = map(s->s[1].output_symbol, samples); ys = map(yi->s2i[yi],ys) - as = map(s->s[1].output_action, samples); as = map(ai->a2i[ai],as) - - h = c = nothing - if samples[1][1].h != nothing - h = mapreduce(s->s[1].h, hcat, samples) - h = convert(atype, h) - end - if samples[1][1].c != nothing - c = mapreduce(s->s[1].c, hcat, samples) - c = convert(atype, c) - end - - return targets, xs, ys, as, h, c -end +# function make_batch( +# obj::ReplayMemory, w, discount, nsteps, s2i, a2i, batchsize) +# samples = sample(obj, batchsize, nsteps) +# targets = compute_targets(samples, w, discount, nsteps, s2i, a2i) +# atype = typeof(targets) + +# # xs <-> inputs (read symbol+previous action) - onehots +# xs = zeros(length(s2i)+length(a2i),length(samples)) +# for (i,sample) in enumerate(samples) +# xs[s2i[sample[1].input_symbol],i] = 1 +# xs[length(s2i)+a2i[sample[1].input_action],i] = 1 +# end +# xs = convert(atype, xs) + +# # ys <-> output symbols +# # as <-> actions +# ys = map(s->s[1].output_symbol, samples); ys = map(yi->s2i[yi],ys) +# as = map(s->s[1].output_action, samples); as = map(ai->a2i[ai],as) + +# h = c = nothing +# if samples[1][1].h != nothing +# h = mapreduce(s->s[1].h, hcat, samples) +# h = convert(atype, h) +# end +# if samples[1][1].c != nothing +# c = mapreduce(s->s[1].c, hcat, samples) +# c = convert(atype, c) +# end +# vs = map(si->si[1].nsteps, samples) +# vs = reshape(vs, 1, length(vs)) +# vs = convert(atype, vs) + +# return targets, xs, ys, as, h, c, vs +# end diff --git a/train.jl b/train.jl index b8353d2..3832413 100644 --- a/train.jl +++ b/train.jl @@ -1,13 +1,14 @@ using Knet using ArgParse using JLD +using Combinatorics include("env.jl") include("model.jl") include("data.jl") include("vocab.jl") -const CAPACITY = 50000 +const CAPACITY = 20000 const EPS_INIT = 0.9 const EPS_FINAL = 0.005 const EPS_DECAY = 200 @@ -25,7 +26,7 @@ function main(args) ("--optim"; default="Rmsprop()") ("--units"; default=200) ("--controller"; default="feedforward"; help="feedforward or lstm") - ("--discount"; default=0.95) + ("--discount"; default=-1.; arg_type=Float64) ("--start"; default=6; arg_type=Int64) ("--end"; default=50; arg_type=Int64) ("--step"; default=4; arg_type=Int64) @@ -37,6 +38,10 @@ function main(args) ("--supervised"; action=:store_true; help="if not, q-learning") ("--capacity"; default=CAPACITY; arg_type=Int64) ("--nsteps"; default=20; arg_type=Int64) + ("--update"; default=5000; arg_type=Int64) + ("--nepisodes"; default=5000; arg_type=Int64) + ("--threshold"; default=0.98; arg_type=Float64) + ("--noreward"; default=10; arg_type=Int64) end isa(args, AbstractString) && (args=split(args)) @@ -46,170 +51,294 @@ function main(args) data_generator = get_data_generator(o[:task]) # init model, params etc. - w = wfix = opts = s2i = i2s = nothing - a2i, i2a = initvocab(ACTIONS) + w = opts = s2i = i2s = nothing + a2i, i2a = initvocab(get_actions(o[:task])) + s2i, i2s = initvocab(get_symbols(o[:task])) if o[:loadfile] == nothing - s2i, i2s = initvocab(get_symbols(o[:task])) w = initweights( o[:atype],o[:units],length(s2i),length(a2i),o[:controller],o[:dist]) - opts = initopts(w,o[:optim]) else o[:task] = load(o[:loadfile], "task") w = load(o[:loadfile], "w") + w = Dict(k=>convert(o[:atype],v) for (k,v) in w) # opts - not yet! # opts = load(o[:loadfile]) - o[:start] = load(o[:loadfile], "complexity") - s2i, i2s = initvocab(get_symbols(o[:task])) - end - mem = ReplayMemory(o[:capacity]) - if !o[:supervised] - wfix = Dict(map(k->(k,copy(w[k])), keys(w))) + # o[:start] = load(o[:loadfile], "complexity") end + opts = initopts(w,o[:optim]) # C => complexity # c => controller cell steps_done = 0 for C = o[:start]:o[:step]:o[:end] + # prepare validation data seqlen = div(C, complexities[o[:task]]) - val = map(xi->data_generator(seqlen), [1:o[:nvalid]...]) - iter = 0 - lossval = 0 + data = map(xi->data_generator(seqlen), [1:o[:nvalid]...]) + valid = []; actions = [] + for (input,output,action) in data + push!(valid, Game(input, output, o[:task])) + push!(actions, action) + end + empty!(data) + iter = 1 + lossval = 0 while true - trn = map(xi->data_generator(seqlen), 1:o[:batchsize]) - x = map(xi->xi[1], trn) - y = map(xi->xi[2], trn) - actions = map(xi->xi[3], trn) - game = Game(x,y,actions,o[:task]) - T = length(game.symgold[1]) - - h,c = initstates( - o[:atype],o[:units],o[:batchsize],o[:controller]) - - # FIXME: sl.iter != rl.iter (but how) - if o[:supervised] - inputs = make_inputs(game, s2i, a2i) - outputs = make_outputs(game, s2i, a2i) - timesteps = length(inputs) - batchsize = o[:batchsize] - batchloss = sltrain!(w,inputs,outputs,h,c,opts) - batchloss = batchloss / (batchsize * timesteps) - iter += 1 - lossval = update_lossval(lossval,batchloss,iter) - else # rl train - # run new episodes - steps_done = run_episodes!( - game, mem, w, h, c, s2i, i2s, a2i, i2a, steps_done; o=o) - - # train with batches from memory - for k = 1:o[:period] - batchsize = o[:batchsize] - batch = make_batch(mem, wfix, o[:discount], o[:nsteps], - s2i, a2i, o[:batchsize]) - batchloss = rltrain!(w,batch...,opts) - batchloss = batchloss / batchsize - lossval = update_lossval(lossval,batchloss,iter) - iter += 1 - end + # get examples for training + trn = map(xi->data_generator(seqlen), 1:o[:nepisodes]) + + # build environments + games = []; actions = [] + for (input,output,action) in trn + push!(games, Game(input, output, o[:task])) + push!(actions, [action..., STOP_ACTION]) end - # perform the validation + # train network while running episodes + this_loss = train!( + w,games,s2i,i2s,a2i,i2a,actions,opts,o[:supervised]; o=o) + lossval = update_lossval(lossval, this_loss, iter) + + lossval = update_lossval(lossval,this_loss,iter) if iter % o[:period] == 0 - println("lossval:$lossval") - acc = validate(w,s2i,i2s,a2i,i2a,val,o) - println("(iter:$iter,acc:$acc)") - if acc > 0.98 - println("$C converged in $iter iteration") - # Knet.gc(); gc(); Knet.gc() - - # save model - if o[:savefile] != nothing - save(o[:savefile], - "w", map(Array, w), - # need something like above for opts - # "opts", opts, - "task", o[:task], - "complexity", C) - end - - if !o[:supervised] - wfix = Dict(map(k->(k,copy(w[k])), keys(w))) - empty!(mem) - end + accuracy, average_reward = validate(w,valid,s2i,i2s,a2i,i2a;o=o) + println("(iter:$iter,loss:$lossval,accuracy:$accuracy,reward:$average_reward)") + if accuracy >= o[:threshold] && o[:savefile] != nothing + save(o[:savefile], + "w", Dict(k=>Array(v) for (k,v) in w), + # need something like above for opts + # "opts", opts, + "task", o[:task], + "complexity", C) + end + if accuracy >= o[:threshold] + println("$C converged in $iter iterations") break end - end # validation + end + + iter += 1 end # while true end # one complexity step end -function sltrain!(w,x,y,h,c,opts) - values = [] - gloss = slgradient(w,x,y,h,c; values=values) - update!(w, gloss, opts) - return values[1] +# train while running episodes +function train!(w,games,s2i,i2s,a2i,i2a,actions,opts,supervised; o=Dict()) + train = true + run_episodes!(w,games,s2i,i2s,a2i,i2a,train,supervised; + o=o,opts=opts,actions=actions) end -function rltrain!(w,targets,x,y,a,h,c,opts) - values = [] - gloss = rlgradient(w,targets,x,y,a,h,c; values=values) - update!(w, gloss, opts) - return values[1] +# validate games +function validate(w,games,s2i,i2s,a2i,i2a; o=Dict()) + train = supervised = false + acc, reward = run_episodes!(w,games,s2i,i2s,a2i,i2a,train,supervised; o=o) + return acc, reward end -function validate(w,s2i,i2s,a2i,i2a,data,o) - batches = map(i->data[i:i+o[:batchsize]-1], [1:o[:batchsize]:length(data)...]) +# this is skeleton for all methods +function run_episodes!(w,games,s2i,i2s,a2i,i2a,train,supervised; + o=Dict(), opts=Dict(), actions=[]) + + # init state parameters + atype = get(o, :atype, typeof(w[:wcont])) + hidden = size(w[:wcont], 1) + controller = get(o, :controller, "lstm") + + # unrolled + nsteps = get(o, :nsteps, 20) + + # supervised learning + if train && supervised + isempty(actions) && error("actions must not be empty for supervision") + + # (1) prepare input data + inputs, outputs, masks = make_data(games,s2i,a2i,actions) + + # (2) init controller states + h, c = initstates(atype, hidden, length(games), controller) + + # (3) train network + batchloss = sltrain!(w,h,c,inputs,outputs,masks,opts; o=o) + + # (4) return batchloss + return batchloss + end + + # needed by qwatkins + histories = [] + for game in games; push!(histories, []); end + + # run episodes - for both validation and q-learning ncorrect = 0 - for batch in batches - x = map(xi->xi[1], batch) - y = map(xi->xi[2], batch) - actions = map(xi->xi[3], batch) - game = Game(x,y,actions,o[:task]) - T = length(game.next_actions[1]) - - correctness = trues(length(batch)) - h,c = initstates(o[:atype],o[:units],o[:batchsize],o[:controller]) - for k = 1:T - x1,x2 = make_input(game, s2i, a2i) - y1,y2 = make_output(game, s2i, a2i) - x1 = convert(o[:atype], x1) - x2 = convert(o[:atype], x2) - cout, h, c = propagate(w[:wcont],w[:bcont],vcat(x1,x2),h,c) - y1pred = predict(w[:wsymb],w[:bsymb],cout) - y2pred = predict(w[:wact],w[:bact],cout) - - y1pred = convert(Array, y1pred) - y1pred = mapslices(indmax,y1pred,1) - y1pred = map(yi->i2s[yi], y1pred) - - y2pred = convert(Array, y2pred) - y2pred = mapslices(indmax,y2pred,1) - y2pred = map(yi->i2a[yi], y2pred) - - # check correctness - for i = 1:length(y1pred) - if y1pred[i] != game.symgold[i][k] - correctness[i] = false - end + iter = 1 + + cumulative_reward = 0 + for (i,game) in enumerate(games) + episode_reward = 0 + h, c = initstates(atype, hidden, 1, controller) + while !game.is_done + # (1) prepare input data + input_symbol = read_symbol(game) + input_action = game.prev_actions[end] + input = zeros(Cuchar, length(s2i)+length(a2i), 1) + input[s2i[input_symbol],1] = 1 + input[length(s2i)+a2i[input_action],1] = 1 + + # (2) propage controller + prev_h = prev_c = nothing + if train && h != nothing && c != nothing + prev_h = Array(h) + prev_c = Array(c) end + cout, h, c = propagate(w[:wcont],w[:bcont],convert(atype,input),h,c) - for i = 1:game.ninstances - game.prev_actions[i][k] = y2pred[i] + # (3) take action + eps = train ? 0.0 : 0.20 + next_action = first(take_action(w[:wact],w[:bact],cout; eps=eps)) + next_action = i2a[next_action] + move_action, write_action = next_action + + # (4) predict symbol + y = NO_SYMBOL + if write_action == WRITE + y = predict(w[:wsymb],w[:bsymb],cout) + # symbol = mapslices(indmax, Array(symbol), 1)[1] + y = indmax(Array(y)) + y = i2s[y] end - move_timestep!(game,y2pred) + + move_timestep!(game, y, move_action) + reward = get_reward(game) + episode_reward += reward + + # (5) add transition to history if phase is RL train + if train + remaining_steps = get_remaining_steps(game) + + # new transition + transition = Transition( + reward, + input_symbol, + input_action, + remaining_steps, + prev_h, prev_c, + next_action, + y, + game.is_done + ) + + add_history = true + if game.is_done + add_history = game.gold_tape == game.output_tape + end + + # push it do history + add_history && push!(histories[i], transition) + end + end + + cumulative_reward += episode_reward + if game.output_tape == game.gold_tape + ncorrect += 1 end - ncorrect += sum(correctness) end - return ncorrect / length(data) + + # reset games + for (i,game) in enumerate(games) + reset!(game) + end + + # q-watkins training if phase is RL train + if train && !supervised + discount = get(o, :discount, -1) + batchsize = get(o, :batchsize, length(games)) + + # (1) prepare input data + batches = make_batches( + w,histories,s2i,a2i,discount,nsteps,batchsize) + + # (2) train network + batchloss = rltrain!(w,batches,opts; o=o) + + # (3) return batchloss + return batchloss + end + + return ncorrect/length(games), cumulative_reward/length(games) +end + +function sltrain!(w,h,c,inputs,outputs,masks,opt; o=Dict()) + dw = similar(w) + for k in keys(w); dw[k] = similar(w[k]); fill!(dw[k], 0); end + + total = num_samples = 0 + maxlen = get(o, :maxlen, 50) + for i = 1:maxlen:length(inputs) + lower = i; upper = min(i+maxlen-1,length(inputs)) + x = inputs[lower:upper] + y = outputs[lower:upper] + m = masks[lower:upper] + + values = [] + gloss = slgrad(w,x,y,m,h,c; values=values) + + # track loss and training sample values + total += values[1]; num_samples += values[2] + + # accumulate gradients + for k in keys(dw); dw[k] += gloss[k]; end + + # unbox states + h = AutoGrad.getval(h) + c = AutoGrad.getval(c) + end + + # clip with num_samples + for k in keys(dw); dw[k] = dw[k]/num_samples; end + # FIXME: try a different thing in here + + # update weights + update!(w,dw,opt) + + return total/num_samples +end + +function rltrain!(w,batches,opt; o=Dict()) + dw = similar(w) + for k in keys(w); dw[k] = similar(w[k]); fill!(dw[k], 0); end + atype = get(o, :atype, AutoGrad.getval(typeof(w[:wcont]))) + + total = num_samples = 0 + mapconvert(args...) = map(i->convert(atype,i), args) + asarray(x) = typeof(x) <: Array ? x : typeof(x)[x] + asarrays(xs...) = map(x->asarray(x), xs) + for (ts,xs,ys,as,ms,hs,cs) in batches + ts,ys,as,ms = asarrays(ts,ys,as,ms) + # @show ts,ys,as,ms + ts,xs = mapconvert(ts,xs) + if hs != nothing && cs != nothing + hs,cs = mapconvert(hs,cs) + end + + values = [] + gloss = rlgrad(w,ts,xs,ys,as,ms,hs,cs; values=values) + total += values[1]; num_samples += values[2] + for k in keys(gloss); dw[k] += gloss[k]; end + end + + for k in keys(w); dw[k] = dw[k]/num_samples; end + update!(w,dw,opt) + return total/num_samples end -function update_lossval(lossval,batchloss,iter) +function update_lossval(lossval,batchloss,iter,alpha=0.01) if iter < 100 lossval = (iter-1)*lossval + batchloss lossval = lossval / iter else - lossval = 0.01 * batchloss + 0.99 * lossval + lossval = alpha * batchloss + (1-alpha) * lossval end return lossval end diff --git a/vocab.jl b/vocab.jl index 1fae676..7b99fa9 100644 --- a/vocab.jl +++ b/vocab.jl @@ -1,9 +1,15 @@ const COPY_SYMBOLS = (-1:9...) const REVERSE_SYMBOLS = (-2:9...) const WALK_SYMBOLS = (-4:9...) -const SYMBOLS = (-1:9...) -const NO_OP = -1 +const SYMBOLS = (0:9...) const NO_SYMBOL = -1 +const WRITE = "write" +const NOT_WRITE = "not-write" + +const TAPE_ACTIONS = ("mr","ml") +const GRID_ACTIONS = ("mr","ml") +const WRITE_ACTIONS = (WRITE, NOT_WRITE) +const STOP_ACTION = ("", NOT_WRITE) function get_symbols(task) if in(task,("copy","reverse","walk")) @@ -12,9 +18,24 @@ function get_symbols(task) return SYMBOLS end +function get_actions(task) + actions = nothing + if in(task,("copy","reverse")) + actions = TAPE_ACTIONS + else + actions = GRID_ACTIONS + end + retval = [] + for ma in actions + for wa in WRITE_ACTIONS + push!(retval, (ma,wa)) + end + end + return [retval..., STOP_ACTION] +end + function initvocab(symbols) symbols = collect(symbols) - sort!(symbols) s2i, i2s = Dict(), Dict() c = 1 for sym in symbols