Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
708e8e0
Test added
ratopi Feb 14, 2018
96f0500
typo : lable -> label
ratopi Feb 15, 2018
aa81925
Refactoring
ratopi Feb 15, 2018
a58fe20
Ignoring IDE-files
ratopi Feb 17, 2018
3fb2ee4
Refactoring
ratopi Feb 17, 2018
bef562c
Mavenized - Project is now a Maven-Project
ratopi Feb 17, 2018
2529e4e
Moved package "edu.hitsz.c102c" to "javacnn"
ratopi Feb 17, 2018
9a38629
refactoring : introducing getOutput()
ratopi Feb 17, 2018
b19844e
Removed console-Listener
ratopi Feb 17, 2018
7859376
Using Output/InputStream in CNNLoader
ratopi Feb 17, 2018
f9ae92c
release 0.1
ratopi Feb 17, 2018
57bdcfb
preparation for 0.2 in pom
ratopi Feb 17, 2018
b0697d4
Introducing interface 'Process' to hide implementation details of Co…
ratopi Feb 19, 2018
d1b6439
Refactoring ConcurenceRunner is now object
ratopi Feb 19, 2018
d81b5ce
Introducing Runner-interface
ratopi Feb 19, 2018
3ca5d5b
- TaskManager removed in ConcurenceRunner : no longer needed
ratopi Feb 19, 2018
513b368
Now possible to specify count of threads in ConcurenceRunner
ratopi Feb 19, 2018
0a731dd
Moved RunCNN and dataset to test
ratopi Feb 19, 2018
59eb0cb
release 0.2
ratopi Feb 19, 2018
e1f7804
preparation for release 0.3
ratopi Feb 19, 2018
70b8ab8
removed System.out from CNN-class
ratopi Feb 19, 2018
d0b56a1
- CNN.propagate with double[]-input introduced
ratopi Feb 20, 2018
0950f6a
Implemented the DirectRunner
ratopi Feb 20, 2018
3a2511f
release 0.3
ratopi Feb 20, 2018
9ad86d7
preparation for next release (0.4)
ratopi Feb 20, 2018
bfe2ca8
A propagation method useful for production ;-)
ratopi Feb 21, 2018
9461325
readme revised
ratopi Feb 21, 2018
e85331e
- ALPHA is now member of CNN
ratopi Feb 21, 2018
40c1fce
Log is dis-/enable-able
ratopi Feb 21, 2018
cc73be0
release 0.4
ratopi Feb 21, 2018
bc0a926
preparation for release 0.5
ratopi Feb 21, 2018
ca3f5b6
removed eclipse settings
ratopi Mar 1, 2018
5a918be
readme.md ...
ratopi Mar 1, 2018
3edc1ed
get it with maven!
ratopi Mar 8, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
- TaskManager removed in ConcurenceRunner : no longer needed
- Runner is transient in CNN : for de-serialization
- removed unused method ConcurenceRunner.isSame
- removed code in comments
- restored some comments (in english)
  • Loading branch information
ratopi committed Feb 19, 2018
commit 3ca5d5bb0dc581a939cfe1986ab2c024f21b460e
46 changes: 26 additions & 20 deletions src/main/java/RunCNN.java
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
import java.io.IOException;

import javacnn.cnn.CNN;
import javacnn.cnn.CNN.LayerBuilder;
import javacnn.cnn.CNNLoader;
import javacnn.cnn.Layer;
import javacnn.cnn.Layer.Size;
import javacnn.dataset.Dataset;
import javacnn.dataset.DatasetLoader;
import javacnn.util.ConcurenceRunner;

public class RunCNN {

public static void main(String[] args) throws IOException {
public static void main(String[] args) throws IOException, ClassNotFoundException {

final ConcurenceRunner concurenceRunner = new ConcurenceRunner();

final LayerBuilder builder = new LayerBuilder();
try {

builder.addLayer(Layer.buildInputLayer(new Size(28, 28)));
builder.addLayer(Layer.buildConvLayer(6, new Size(5, 5)));
builder.addLayer(Layer.buildSampLayer(new Size(2, 2)));
builder.addLayer(Layer.buildConvLayer(12, new Size(5, 5)));
builder.addLayer(Layer.buildSampLayer(new Size(2, 2)));
builder.addLayer(Layer.buildOutputLayer(10));
final CNN.LayerBuilder builder = new CNN.LayerBuilder();

final CNN cnn = new CNN(builder, 50, concurenceRunner);
builder.addLayer(Layer.buildInputLayer(new Layer.Size(28, 28)));
builder.addLayer(Layer.buildConvLayer(6, new Layer.Size(5, 5)));
builder.addLayer(Layer.buildSampLayer(new Layer.Size(2, 2)));
builder.addLayer(Layer.buildConvLayer(12, new Layer.Size(5, 5)));
builder.addLayer(Layer.buildSampLayer(new Layer.Size(2, 2)));
builder.addLayer(Layer.buildOutputLayer(10));

final String fileName = "dataset/train.format";
final Dataset dataset = DatasetLoader.load(fileName, ",", 784);
cnn.train(dataset, 5);
final CNN cnn = new CNN(builder, 50, concurenceRunner);

CNNLoader.saveModel("model.cnn", cnn);
dataset.clear();
final String fileName = "dataset/train.format";
final Dataset dataset = DatasetLoader.load(fileName, ",", 784);
cnn.train(dataset, 5);

// CNN cnn = CNNLoader.loadModel(modelName);
final Dataset testset = DatasetLoader.load("dataset/test.format", ",", -1);
cnn.predict(testset, "dataset/test.predict");
CNNLoader.saveModel("model.cnn", cnn);
dataset.clear();

concurenceRunner.stop();
/*
final CNN cnn = CNNLoader.loadModel("model.cnn");
cnn.setRunner(concurenceRunner);
*/

final Dataset testset = DatasetLoader.load("dataset/test.format", ",", -1);
cnn.predict(testset, "dataset/test.predict");

} finally {
concurenceRunner.shutdown();
}
}

}
86 changes: 32 additions & 54 deletions src/main/java/javacnn/cnn/CNN.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,22 @@

public class CNN implements Serializable {

private static final long serialVersionUID = 337920299147929932L;
private static final long serialVersionUID = 1L;

private static final double LAMBDA = 0;

private static double ALPHA = 0.85;


private final Runner runner;

private final List<Layer> layers;

private final int layerNum;

private final int batchSize;

private final Util.Operator divide_batchSize;

private final Util.Operator multiply_alpha;

private final Util.Operator multiply_lambda;

private transient Runner runner;


public CNN(LayerBuilder layerBuilder, final int batchSize, final Runner runner) {

Expand Down Expand Up @@ -81,14 +76,22 @@ public double process(double value) {
};
}

/**
 * Replaces the runner held by this instance.
 * <p>
 * The {@code runner} field is declared {@code transient}, so presumably an
 * instance restored through deserialization needs this call before it is
 * used again - confirm against the loading code.
 *
 * @param runner the runner to store; it is stored as given, without a null check
 */
public void setRunner(final Runner runner) {
this.runner = runner;
}

/**
 * Hands out the runner stored in this instance and refuses to continue when
 * none has been set.
 *
 * @return the stored runner, never {@code null}
 * @throws NullPointerException if the runner field is still {@code null}
 */
private Runner getRunner() {
if (runner != null) {
return runner;
}
throw new NullPointerException("'runner' is null. Set runner before start training or test!");
}

public void train(final Dataset trainset, final int iterationCount) {
for (int iteration = 0; iteration < iterationCount; iteration++) {

int epochsNum = trainset.size() / batchSize;

if (trainset.size() % batchSize != 0) {
epochsNum++;
epochsNum++; // Extract once, round up
}

Log.info("");
Expand All @@ -111,6 +114,7 @@ public void train(final Dataset trainset, final int iterationCount) {
Layer.prepareForNewRecord();
}

// After finishing a batch update weight
updateParas();

if (epoch % 50 == 0) {
Expand All @@ -124,7 +128,7 @@ public void train(final Dataset trainset, final int iterationCount) {
final double precision = ((double) right) / count;

if (iteration % 10 == 1 && precision > 0.96) {
ALPHA = 0.001 + ALPHA * 0.9;
ALPHA = 0.001 + ALPHA * 0.9; // Adjust the quasi-learning rate dynamically
Log.info("Set alpha = " + ALPHA);
}

Expand Down Expand Up @@ -196,22 +200,10 @@ public void predict(Dataset testset, String fileName) {
Log.info("end predict");
}

private boolean isSame(double[] output, double[] target) {
boolean r = true;
for (int i = 0; i < output.length; i++)
if (Math.abs(output[i] - target[i]) > 0.5) {
r = false;
break;
}

return r;
}

private boolean train(Dataset.Record record) {
forward(record);
boolean result = backPropagation(record);
return result;
// System.exit(0);

return backPropagation(record);
}

private boolean backPropagation(Dataset.Record record) {
Expand All @@ -228,15 +220,15 @@ private void updateParas() {
case conv:
case output:
updateKernels(layer, lastLayer);
updateBias(layer, lastLayer);
updateBias(layer);
break;
default:
break;
}
}
}

private void updateBias(final Layer layer, Layer lastLayer) {
private void updateBias(final Layer layer) {
final double[][][][] errors = layer.getErrors();
int mapNum = layer.getOutMapNum();

Expand All @@ -254,7 +246,7 @@ public void process(int start, int end) {
}
};

runner.startProcess(mapNum, processor);
getRunner().startProcess(mapNum, processor);
}

private void updateKernels(final Layer layer, final Layer lastLayer) {
Expand Down Expand Up @@ -286,7 +278,7 @@ public void process(int start, int end) {
}
};

runner.startProcess(mapNum, process);
getRunner().startProcess(mapNum, process);
}

private void setHiddenLayerErrors() {
Expand Down Expand Up @@ -329,7 +321,7 @@ public void process(int start, int end) {

};

runner.startProcess(mapNum, process);
getRunner().startProcess(mapNum, process);
}

private void setConvErrors(final Layer layer, final Layer nextLayer) {
Expand All @@ -349,29 +341,13 @@ public void process(int start, int end) {
}
};

runner.startProcess(mapNum, process);
getRunner().startProcess(mapNum, process);
}

private boolean setOutLayerErrors(final Dataset.Record record) {

Layer outputLayer = layers.get(layerNum - 1);
int mapNum = outputLayer.getOutMapNum();
// double[] target =
// record.getDoubleEncodeTarget(mapNum);
// double[] outmaps = new double[mapNum];
// for (int m = 0; m < mapNum; m++) {
// double[][] outmap = outputLayer.getMap(m);
// double output = outmap[0][0];
// outmaps[m] = output;
// double errors = output * (1 - output) *
// (target[m] - output);
// outputLayer.setError(m, 0, 0, errors);
// }
// // correct
// if (isSame(outmaps, target))
// return true;
// return false;

final Layer outputLayer = layers.get(layerNum - 1);
final int mapNum = outputLayer.getOutMapNum();
final double[] target = new double[mapNum];
final double[] outmaps = new double[mapNum];

Expand All @@ -384,17 +360,19 @@ private boolean setOutLayerErrors(final Dataset.Record record) {

target[label] = 1;

// Log.i(record.getLable() + "outmaps:" +
// Util.fomart(outmaps)
// + Arrays.toString(target));

for (int m = 0; m < mapNum; m++) {
outputLayer.setError(m, 0, 0, outmaps[m] * (1 - outmaps[m]) * (target[m] - outmaps[m]));
}

return label == Util.getMaxIndex(outmaps);
}

/**
* Propagate given Record through the network.
* Returns the result.
* @param record A Record
* @return The result of the network
*/
public double[] propagate(final Dataset.Record record) {
forward(record);

Expand Down Expand Up @@ -484,7 +462,7 @@ public double process(double value) {
}
};

runner.startProcess(mapNum, process);
getRunner().startProcess(mapNum, process);
}

private void setSampOutput(final Layer layer, final Layer lastLayer) {
Expand All @@ -502,7 +480,7 @@ public void process(int start, int end) {
}
};

runner.startProcess(lastMapNum, process);
getRunner().startProcess(lastMapNum, process);
}

private void setup(final int batchSize) {
Expand Down
64 changes: 23 additions & 41 deletions src/main/java/javacnn/util/ConcurenceRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,41 @@ public ConcurenceRunner() {
exec = Executors.newFixedThreadPool(cpuNum);
}

public void stop() {
public void shutdown() {
exec.shutdown();
}

@Override
public void startProcess(final int mapNum, final Process process) {
new TaskManager(mapNum).start(process);
}

private void run(Runnable task) {
exec.execute(task);
}


private class TaskManager {
private int workLength;
final int runCpu = cpuNum < mapNum ? cpuNum : 1;

private TaskManager(int workLength) {
this.workLength = workLength;
}

private void start(final Process processor) {
int runCpu = cpuNum < workLength ? cpuNum : 1;

// Fragment length rounded up
final CountDownLatch gate = new CountDownLatch(runCpu);
// Fragment length rounded up
final CountDownLatch gate = new CountDownLatch(runCpu);

final int fregLength = (workLength + runCpu - 1) / runCpu;
final int fregLength = (mapNum + runCpu - 1) / runCpu;

for (int cpu = 0; cpu < runCpu; cpu++) {
final int start = cpu * fregLength;
for (int cpu = 0; cpu < runCpu; cpu++) {
final int start = cpu * fregLength;

final int tmp = (cpu + 1) * fregLength;
final int end = tmp <= workLength ? tmp : workLength;
final int tmp = (cpu + 1) * fregLength;
final int end = tmp <= mapNum ? tmp : mapNum;

final Runnable task = new Runnable() {
@Override
public void run() {
processor.process(start, end);
gate.countDown();
}
};
final Runnable task = new Runnable() {
@Override
public void run() {
process.process(start, end);
gate.countDown();
}
};

ConcurenceRunner.this.run(task);
}
try {// Wait for all threads to finish running
gate.await();
} catch (InterruptedException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
exec.execute(task);
}
try {// Wait for all threads to finish running
gate.await();
} catch (InterruptedException e) {
e.printStackTrace();
throw new RuntimeException(e);
}

}

}