diff --git a/.gitignore b/.gitignore
index ea24784..e33f790 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,18 +1,18 @@
-.idea/
-*.iml
-
-*.class
-
-# Mobile Tools for Java (J2ME)
-.mtj.tmp/
+# generated and compiled
+bin/
+gen/
-# Package Files #
-*.war
-*.ear
+# Gradle Build
+.gradle/
+*/build/
+build/
-Thumbs.db
-
-target/
+# IDE and project settings
+.idea/
+local.properties
+.classpath
+.project
-# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
-hs_err_pid*
+# Android Studio
+*.iml
+*.iws
diff --git a/.mvn/wrapper/maven-wrapper.properties b/.mvn/wrapper/maven-wrapper.properties
index 56bb016..b5943ce 100644
--- a/.mvn/wrapper/maven-wrapper.properties
+++ b/.mvn/wrapper/maven-wrapper.properties
@@ -1 +1 @@
-distributionUrl=https://repo1.maven.org/maven2/org/apache/maven/apache-maven/3.5.0/apache-maven-3.5.0-bin.zip
\ No newline at end of file
+#distributionUrl=https://repo1.maven.org/maven2/org/apache/maven/apache-maven/3.5.0/apache-maven-3.5.0-bin.zip
\ No newline at end of file
diff --git a/build.gradle b/build.gradle
new file mode 100644
index 0000000..091f560
--- /dev/null
+++ b/build.gradle
@@ -0,0 +1,79 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+
+buildscript {
+ repositories {
+ google()
+ jcenter()
+ maven {
+ url "https://jitpack.io"
+ }
+ }
+
+ dependencies {
+ classpath 'com.android.tools.build:gradle:3.5.0-beta01'
+
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
+ }
+}
+
+allprojects {
+ repositories {
+ google()
+ jcenter()
+ maven {
+ url "https://jitpack.io"
+ }
+ }
+}
+
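+// The root project doubles as the Android library module itself (sources under src/main/java);
+// the demo-tic-tac-toe app consumes it via "implementation project(':')".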
+apply plugin: 'com.android.library'
+
+android {
+ compileSdkVersion 28
+
+ defaultConfig {
+ minSdkVersion 24
+ targetSdkVersion 28
+ versionCode 1
+ versionName "1.0"
+ }
+
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
+ }
+ debug {
+ testCoverageEnabled false
+ }
+ }
+
+ lintOptions {
+ abortOnError false
+ }
+
+ sourceSets {
+ main {
+ java {
+ // Merge source sets instead of adding the core library as a submodule so that the test coverage report works
+ srcDirs = ['src/main/java']
+ }
+ }
+ }
+ compileOptions {
+ sourceCompatibility JavaVersion.VERSION_1_8
+ targetCompatibility JavaVersion.VERSION_1_8
+ }
+}
+
+dependencies {
+ testImplementation ('junit:junit:4.12') { exclude module: 'hamcrest-core' }
+
+ implementation 'com.google.code.gson:gson:2.8.5'
+ implementation 'org.testng:testng:6.9.6'
+ implementation 'org.assertj:assertj-core:3.12.2'
+}
\ No newline at end of file
diff --git a/demo-tic-tac-toe/build.gradle b/demo-tic-tac-toe/build.gradle
new file mode 100644
index 0000000..80f302e
--- /dev/null
+++ b/demo-tic-tac-toe/build.gradle
@@ -0,0 +1,61 @@
+apply plugin: 'com.android.application'
+
+android {
+ compileSdkVersion 28
+
+ defaultConfig {
+ minSdkVersion 24
+ targetSdkVersion 28
+ versionCode 1
+ versionName "1.0"
+ }
+
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
+ }
+ debug {
+ testCoverageEnabled false
+ }
+ }
+
+ lintOptions {
+ abortOnError false
+ }
+
+ sourceSets {
+ main {
+ java {
+ // Merge source sets instead of adding the core library as a submodule so that the test coverage report works
+ srcDirs = ['src/main/java']
+ }
+ }
+ }
+ compileOptions {
+ sourceCompatibility JavaVersion.VERSION_1_8
+ targetCompatibility JavaVersion.VERSION_1_8
+ }
+}
+
+dependencies {
+// implementation 'androidx.appcompat:appcompat:1.0.2'
+ testImplementation ('junit:junit:4.12') { exclude module: 'hamcrest-core' }
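+ // hamcrest-core is excluded from junit above; the module rule below redirects any remaining
+ // transitive org.hamcrest:hamcrest-core dependency to junit itself, keeping a single copy on the test classpath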
+ modules {
+ module("org.hamcrest:hamcrest-core") {
+ replacedBy("junit:junit", "Vous ")
+ }
+ }
+
+ implementation ('org.slf4j:slf4j-simple:1.8.0-beta4') { exclude module: 'junit' }
+
+ implementation 'org.testng:testng:6.9.6'
+ implementation 'org.assertj:assertj-core:3.12.2'
+ testImplementation project(path: ':')
+
+// implementation project(path: ':java-reinforcement-learning')
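+// ':' is the root project, i.e. the reinforcement-learning library itself (see settings.gradle and the top-level build.gradle)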
+ implementation project(path: ':')
+}
diff --git a/demo-tic-tac-toe/src/main/AndroidManifest.xml b/demo-tic-tac-toe/src/main/AndroidManifest.xml
new file mode 100644
index 0000000..e77dbbd
--- /dev/null
+++ b/demo-tic-tac-toe/src/main/AndroidManifest.xml
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- minimal manifest sketch; the package name below is an assumed placeholder -->
+<manifest package="com.github.chen0040.rl.demo.tictactoe" />
\ No newline at end of file
diff --git a/proguard-rules.pro b/proguard-rules.pro
new file mode 100644
index 0000000..92de838
--- /dev/null
+++ b/proguard-rules.pro
@@ -0,0 +1,18 @@
+# Add project specific ProGuard rules here.
+# By default, the flags in this file are appended to flags specified
+# in /Users/Stuart/Development/sdk/tools/proguard/proguard-android.txt
+# You can edit the include path and order by changing the proguardFiles
+# directive in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# Add any project specific keep options here:
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
diff --git a/settings.gradle b/settings.gradle
new file mode 100644
index 0000000..2a773c9
--- /dev/null
+++ b/settings.gradle
@@ -0,0 +1,3 @@
+//include ':java-reinforcement-learning'
+include ':demo-tic-tac-toe'
+project(':demo-tic-tac-toe').projectDir = new File('demo-tic-tac-toe')
\ No newline at end of file
diff --git a/src/main/AndroidManifest.xml b/src/main/AndroidManifest.xml
new file mode 100644
index 0000000..f71a439
--- /dev/null
+++ b/src/main/AndroidManifest.xml
@@ -0,0 +1,2 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest package="com.github.chen0040.rl" />
\ No newline at end of file
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java
index 7de7f9b..92a5a55 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/AbstractActionSelectionStrategy.java
@@ -8,66 +8,64 @@
import java.util.Map;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
public abstract class AbstractActionSelectionStrategy implements ActionSelectionStrategy {
- private String prototype;
- protected Map attributes = new HashMap();
-
- public String getPrototype(){
- return prototype;
- }
+ private String prototype;
+ protected Map<String, String> attributes = new HashMap<>();
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- return new IndexValue();
- }
+ public String getPrototype() {
+ return prototype;
+ }
- public IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState) {
- return new IndexValue();
- }
+ public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
+ return new IndexValue();
+ }
- public AbstractActionSelectionStrategy(){
- prototype = this.getClass().getCanonicalName();
- }
+ public IndexValue selectAction(int stateId, UtilityModel model, Set<Integer> actionsAtState) {
+ return new IndexValue();
+ }
+ public AbstractActionSelectionStrategy() {
+ prototype = this.getClass().getCanonicalName();
+ }
- public AbstractActionSelectionStrategy(HashMap attributes){
- this.attributes = attributes;
- if(attributes.containsKey("prototype")){
- this.prototype = attributes.get("prototype");
- }
- }
+ public AbstractActionSelectionStrategy(HashMap<String, String> attributes) {
+ this.attributes = attributes;
+ if (attributes.containsKey("prototype")) {
+ this.prototype = attributes.get("prototype");
+ }
+ }
- public Map getAttributes(){
- return attributes;
- }
+ public Map<String, String> getAttributes() {
+ return attributes;
+ }
- @Override
- public boolean equals(Object obj) {
- ActionSelectionStrategy rhs = (ActionSelectionStrategy)obj;
- if(!prototype.equalsIgnoreCase(rhs.getPrototype())) return false;
- for(Map.Entry entry : rhs.getAttributes().entrySet()) {
- if(!attributes.containsKey(entry.getKey())) {
- return false;
- }
- if(!attributes.get(entry.getKey()).equals(entry.getValue())){
- return false;
- }
- }
- for(Map.Entry entry : attributes.entrySet()) {
- if(!rhs.getAttributes().containsKey(entry.getKey())) {
- return false;
- }
- if(!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())){
- return false;
- }
- }
- return true;
- }
+ @Override
+ public boolean equals(Object obj) {
+ ActionSelectionStrategy rhs = (ActionSelectionStrategy) obj;
+ if (!prototype.equalsIgnoreCase(rhs.getPrototype())) return false;
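+ // the attribute maps must agree in both directions; a key present on only one side breaks equality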
+ for (Map.Entry<String, String> entry : rhs.getAttributes().entrySet()) {
+ if (!attributes.containsKey(entry.getKey())) {
+ return false;
+ }
+ if (!attributes.get(entry.getKey()).equals(entry.getValue())) {
+ return false;
+ }
+ }
+ for (Map.Entry<String, String> entry : attributes.entrySet()) {
+ if (!rhs.getAttributes().containsKey(entry.getKey())) {
+ return false;
+ }
+ if (!rhs.getAttributes().get(entry.getKey()).equals(entry.getValue())) {
+ return false;
+ }
+ }
+ return true;
+ }
- @Override
- public abstract Object clone();
+ @Override
+ public abstract Object clone();
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java
index 51b6824..ff92269 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategy.java
@@ -9,13 +9,15 @@
import java.util.Map;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
public interface ActionSelectionStrategy extends Serializable, Cloneable {
- IndexValue selectAction(int stateId, QModel model, Set actionsAtState);
- IndexValue selectAction(int stateId, UtilityModel model, Set actionsAtState);
- String getPrototype();
- Map getAttributes();
+ IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState);
+
+ IndexValue selectAction(int stateId, UtilityModel model, Set<Integer> actionsAtState);
+
+ String getPrototype();
+
+ Map<String, String> getAttributes();
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java
index ce92be0..6159678 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/ActionSelectionStrategyFactory.java
@@ -3,57 +3,55 @@
import java.util.HashMap;
import java.util.Map;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
public class ActionSelectionStrategyFactory {
- public static ActionSelectionStrategy deserialize(String conf){
- String[] comps = conf.split(";");
-
- HashMap attributes = new HashMap();
- for(int i=0; i < comps.length; ++i){
- String comp = comps[i];
- String[] field = comp.split("=");
- if(field.length < 2) continue;
- String fieldname = field[0].trim();
- String fieldvalue = field[1].trim();
-
- attributes.put(fieldname, fieldvalue);
- }
- if(attributes.isEmpty()){
- attributes.put("prototype", conf);
- }
-
- String prototype = attributes.get("prototype");
- if(prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())){
- return new GreedyActionSelectionStrategy();
- } else if(prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())){
- return new SoftMaxActionSelectionStrategy();
- } else if(prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())){
- return new EpsilonGreedyActionSelectionStrategy(attributes);
- } else if(prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())){
- return new GibbsSoftMaxActionSelectionStrategy();
- }
-
- return null;
- }
-
- public static String serialize(ActionSelectionStrategy strategy){
- Map attributes = strategy.getAttributes();
- attributes.put("prototype", strategy.getPrototype());
-
- StringBuilder sb = new StringBuilder();
- boolean first = true;
- for(Map.Entry entry : attributes.entrySet()){
- if(first){
- first = false;
- }
- else{
- sb.append(";");
- }
- sb.append(entry.getKey()+"="+entry.getValue());
- }
- return sb.toString();
- }
+ public static ActionSelectionStrategy deserialize(String conf) {
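+ // conf is either a bare strategy class name or a semicolon-delimited list of key=value pairs,
+ // e.g. "prototype=com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy;epsilon=0.1"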
+ String[] comps = conf.split(";");
+
+ HashMap<String, String> attributes = new HashMap<>();
+ for (String comp : comps) {
+ String[] field = comp.split("=");
+ if (field.length < 2) continue;
+ String fieldname = field[0].trim();
+ String fieldvalue = field[1].trim();
+
+ attributes.put(fieldname, fieldvalue);
+ }
+ if (attributes.isEmpty()) {
+ attributes.put("prototype", conf);
+ }
+
+ String prototype = attributes.get("prototype");
+ if (prototype.equals(GreedyActionSelectionStrategy.class.getCanonicalName())) {
+ return new GreedyActionSelectionStrategy();
+ } else if (prototype.equals(SoftMaxActionSelectionStrategy.class.getCanonicalName())) {
+ return new SoftMaxActionSelectionStrategy();
+ } else if (prototype.equals(EpsilonGreedyActionSelectionStrategy.class.getCanonicalName())) {
+ return new EpsilonGreedyActionSelectionStrategy(attributes);
+ } else if (prototype.equals(GibbsSoftMaxActionSelectionStrategy.class.getCanonicalName())) {
+ return new GibbsSoftMaxActionSelectionStrategy();
+ }
+
+ return null;
+ }
+
+ public static String serialize(ActionSelectionStrategy strategy) {
+ Map<String, String> attributes = strategy.getAttributes();
+ attributes.put("prototype", strategy.getPrototype());
+
+ StringBuilder sb = new StringBuilder();
+ boolean first = true;
+ for (Map.Entry<String, String> entry : attributes.entrySet()) {
+ if (first) {
+ first = false;
+ } else {
+ sb.append(";");
+ }
+ sb.append(entry.getKey() + "=" + entry.getValue());
+ }
+ return sb.toString();
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java
index 5f7db9a..3d2e4e7 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/EpsilonGreedyActionSelectionStrategy.java
@@ -5,75 +5,74 @@
import java.util.*;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
public class EpsilonGreedyActionSelectionStrategy extends AbstractActionSelectionStrategy {
- public static final String EPSILON = "epsilon";
- private Random random = new Random();
+ public static final String EPSILON = "epsilon";
+ private Random random = new Random();
- @Override
- public Object clone(){
- EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy();
- clone.copy(this);
- return clone;
- }
+ @Override
+ public Object clone() {
+ EpsilonGreedyActionSelectionStrategy clone = new EpsilonGreedyActionSelectionStrategy();
+ clone.copy(this);
+ return clone;
+ }
- public void copy(EpsilonGreedyActionSelectionStrategy rhs){
- random = rhs.random;
- for(Map.Entry entry : rhs.attributes.entrySet()){
- attributes.put(entry.getKey(), entry.getValue());
- }
- }
+ public void copy(EpsilonGreedyActionSelectionStrategy rhs) {
+ random = rhs.random;
+ for (Map.Entry<String, String> entry : rhs.attributes.entrySet()) {
+ attributes.put(entry.getKey(), entry.getValue());
+ }
+ }
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy){
- EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy)obj;
- if(epsilon() != rhs.epsilon()) return false;
- // if(!random.equals(rhs.random)) return false;
- return true;
- }
- return false;
- }
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof EpsilonGreedyActionSelectionStrategy) {
+ EpsilonGreedyActionSelectionStrategy rhs = (EpsilonGreedyActionSelectionStrategy) obj;
+ if (epsilon() != rhs.epsilon()) return false;
+ // if(!random.equals(rhs.random)) return false;
+ return true;
+ }
+ return false;
+ }
- private double epsilon(){
- return Double.parseDouble(attributes.get(EPSILON));
- }
+ private double epsilon() {
+ return Double.parseDouble(attributes.get(EPSILON));
+ }
- public EpsilonGreedyActionSelectionStrategy(){
- epsilon(0.1);
- }
+ public EpsilonGreedyActionSelectionStrategy() {
+ epsilon(0.1);
+ }
- public EpsilonGreedyActionSelectionStrategy(HashMap attributes){
- super(attributes);
- }
+ public EpsilonGreedyActionSelectionStrategy(HashMap<String, String> attributes) {
+ super(attributes);
+ }
- private void epsilon(double value){
- attributes.put(EPSILON, "" + value);
- }
+ private void epsilon(double value) {
+ attributes.put(EPSILON, "" + value);
+ }
- public EpsilonGreedyActionSelectionStrategy(Random random){
- this.random = random;
- epsilon(0.1);
- }
+ public EpsilonGreedyActionSelectionStrategy(Random random) {
+ this.random = random;
+ epsilon(0.1);
+ }
- @Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- if(random.nextDouble() < 1- epsilon()){
- return model.actionWithMaxQAtState(stateId, actionsAtState);
- }else{
- int actionId;
- if(actionsAtState != null && !actionsAtState.isEmpty()) {
- List actions = new ArrayList<>(actionsAtState);
- actionId = actions.get(random.nextInt(actions.size()));
- } else {
- actionId = random.nextInt(model.getActionCount());
- }
+ @Override
+ public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
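+ // exploit: with probability 1 - epsilon pick the action with the highest Q at this state;
+ // explore: otherwise pick a uniformly random action from the allowed set (or from all actions when no set is given)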
+ if (random.nextDouble() < 1 - epsilon()) {
+ return model.actionWithMaxQAtState(stateId, actionsAtState);
+ } else {
+ int actionId;
+ if (actionsAtState != null && !actionsAtState.isEmpty()) {
+ List<Integer> actions = new ArrayList<>(actionsAtState);
+ actionId = actions.get(random.nextInt(actions.size()));
+ } else {
+ actionId = random.nextInt(model.getActionCount());
+ }
- double Q = model.getQ(stateId, actionId);
- return new IndexValue(actionId, Q);
- }
- }
+ double Q = model.getQ(stateId, actionId);
+ return new IndexValue(actionId, Q);
+ }
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
index 8b2d8d2..12f1573 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.java
@@ -8,64 +8,62 @@
import java.util.Random;
import java.util.Set;
-
/**
* Created by xschen on 9/28/2015 0028.
*/
public class GibbsSoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy {
- private Random random = null;
- public GibbsSoftMaxActionSelectionStrategy(){
- random = new Random();
- }
+ private Random random = null;
+
+ public GibbsSoftMaxActionSelectionStrategy() {
+ random = new Random();
+ }
- public GibbsSoftMaxActionSelectionStrategy(Random random){
- this.random = random;
- }
+ public GibbsSoftMaxActionSelectionStrategy(Random random) {
+ this.random = random;
+ }
- @Override
- public Object clone() {
- GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy();
- return clone;
- }
+ @Override
+ public Object clone() {
+ GibbsSoftMaxActionSelectionStrategy clone = new GibbsSoftMaxActionSelectionStrategy();
+ return clone;
+ }
- @Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- List actions = new ArrayList();
- if(actionsAtState == null){
- for(int i=0; i < model.getActionCount(); ++i){
- actions.add(i);
- }
- }else{
- for(Integer actionId : actionsAtState){
- actions.add(actionId);
- }
- }
+ @Override
+ public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
+ List<Integer> actions = new ArrayList<>();
+ if (actionsAtState == null) {
+ for (int i = 0; i < model.getActionCount(); ++i) {
+ actions.add(i);
+ }
+ } else {
+ actions.addAll(actionsAtState);
+ }
- double sum = 0;
- List plist = new ArrayList();
- for(int i=0; i < actions.size(); ++i){
- int actionId = actions.get(i);
- double p = Math.exp(model.getQ(stateId, actionId));
- sum += p;
- plist.add(sum);
- }
+ double sum = 0;
+ List<Double> plist = new ArrayList<>();
+ for (int i = 0; i < actions.size(); ++i) {
+ int actionId = actions.get(i);
+ double p = Math.exp(model.getQ(stateId, actionId));
+ sum += p;
+ plist.add(sum);
+ }
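+ // plist now holds running (cumulative) sums of the Boltzmann weights exp(Q(s, a)); drawing r uniformly
+ // from [0, sum) and taking the first action whose cumulative weight reaches r is roulette-wheel sampling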
- IndexValue iv = new IndexValue();
- iv.setIndex(-1);
- iv.setValue(Double.NEGATIVE_INFINITY);
+ IndexValue iv = new IndexValue();
+ iv.setIndex(-1);
+ iv.setValue(Double.NEGATIVE_INFINITY);
- double r = sum * random.nextDouble();
- for(int i=0; i < actions.size(); ++i){
+ double r = sum * random.nextDouble();
+ for (int i = 0; i < actions.size(); ++i) {
- if(plist.get(i) >= r){
- int actionId = actions.get(i);
- iv.setValue(model.getQ(stateId, actionId));
- iv.setIndex(actionId);
- break;
- }
- }
+ if (plist.get(i) >= r) {
+ int actionId = actions.get(i);
+ iv.setValue(model.getQ(stateId, actionId));
+ iv.setIndex(actionId);
+ break;
+ }
+ }
- return iv;
- }
+ return iv;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
index 6b0f350..8d6d7f3 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/GreedyActionSelectionStrategy.java
@@ -5,24 +5,23 @@
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
public class GreedyActionSelectionStrategy extends AbstractActionSelectionStrategy {
- @Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- return model.actionWithMaxQAtState(stateId, actionsAtState);
- }
+ @Override
+ public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
+ return model.actionWithMaxQAtState(stateId, actionsAtState);
+ }
- @Override
- public Object clone(){
- GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy();
- return clone;
- }
+ @Override
+ public Object clone() {
+ GreedyActionSelectionStrategy clone = new GreedyActionSelectionStrategy();
+ return clone;
+ }
- @Override
- public boolean equals(Object obj){
- return obj != null && obj instanceof GreedyActionSelectionStrategy;
- }
+ @Override
+ public boolean equals(Object obj) {
+ return obj != null && obj instanceof GreedyActionSelectionStrategy;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java
index f9735b9..51ef128 100644
--- a/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java
+++ b/src/main/java/com/github/chen0040/rl/actionselection/SoftMaxActionSelectionStrategy.java
@@ -6,34 +6,33 @@
import java.util.Random;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
public class SoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy {
- private Random random = new Random();
+ private Random random = new Random();
- @Override
- public Object clone(){
- SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random);
- return clone;
- }
+ @Override
+ public Object clone() {
+ SoftMaxActionSelectionStrategy clone = new SoftMaxActionSelectionStrategy(random);
+ return clone;
+ }
- @Override
- public boolean equals(Object obj){
- return obj != null && obj instanceof SoftMaxActionSelectionStrategy;
- }
+ @Override
+ public boolean equals(Object obj) {
+ return obj != null && obj instanceof SoftMaxActionSelectionStrategy;
+ }
- public SoftMaxActionSelectionStrategy(){
+ public SoftMaxActionSelectionStrategy() {
- }
+ }
- public SoftMaxActionSelectionStrategy(Random random){
- this.random = random;
- }
+ public SoftMaxActionSelectionStrategy(Random random) {
+ this.random = random;
+ }
- @Override
- public IndexValue selectAction(int stateId, QModel model, Set actionsAtState) {
- return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random);
- }
+ @Override
+ public IndexValue selectAction(int stateId, QModel model, Set<Integer> actionsAtState) {
+ return model.actionWithSoftMaxQAtState(stateId, actionsAtState, random);
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java
index 6e34874..f262afe 100644
--- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgent.java
@@ -7,95 +7,91 @@
import java.util.Set;
import java.util.function.Function;
-
/**
* Created by chen0469 on 9/28/2015 0028.
*/
public class ActorCriticAgent implements Serializable {
- private ActorCriticLearner learner;
- private int currentState;
- private int prevState;
- private int prevAction;
-
- public void enableEligibilityTrace(double lambda){
- ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner);
- acll.setLambda(lambda);
- learner = acll;
- }
-
- public void start(int stateId){
- currentState = stateId;
- prevAction = -1;
- prevState = -1;
- }
-
- public ActorCriticLearner getLearner(){
- return learner;
- }
-
- public void setLearner(ActorCriticLearner learner){
- this.learner = learner;
- }
-
- public ActorCriticAgent(int stateCount, int actionCount){
- learner = new ActorCriticLearner(stateCount, actionCount);
- }
-
- public ActorCriticAgent(){
-
- }
-
- public ActorCriticAgent(ActorCriticLearner learner){
- this.learner = learner;
- }
-
- public ActorCriticAgent makeCopy(){
- ActorCriticAgent clone = new ActorCriticAgent();
- clone.copy(this);
- return clone;
- }
-
- public void copy(ActorCriticAgent rhs){
- learner = (ActorCriticLearner)rhs.learner.makeCopy();
- prevAction = rhs.prevAction;
- prevState = rhs.prevState;
- currentState = rhs.currentState;
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof ActorCriticAgent){
- ActorCriticAgent rhs = (ActorCriticAgent)obj;
- return learner.equals(rhs.learner) && prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState;
-
- }
- return false;
- }
-
- public int selectAction(Set actionsAtState){
- return learner.selectAction(currentState, actionsAtState);
- }
-
- public int selectAction(){
- return learner.selectAction(currentState);
- }
-
- public void update(int actionTaken, int newState, double immediateReward, final Vec V){
- update(actionTaken, newState, null, immediateReward, V);
- }
-
- public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward, final Vec V){
-
- learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, new Function() {
- public Double apply(Integer stateId) {
- return V.get(stateId);
- }
- });
-
- prevAction = actionTaken;
- prevState = currentState;
-
- currentState = newState;
- }
+ private ActorCriticLearner learner;
+ private int currentState;
+ private int prevState;
+ private int prevAction;
+
+ public void enableEligibilityTrace(double lambda) {
+ ActorCriticLambdaLearner acll = new ActorCriticLambdaLearner(learner);
+ acll.setLambda(lambda);
+ learner = acll;
+ }
+
+ public void start(int stateId) {
+ currentState = stateId;
+ prevAction = -1;
+ prevState = -1;
+ }
+
+ public ActorCriticLearner getLearner() {
+ return learner;
+ }
+
+ public void setLearner(ActorCriticLearner learner) {
+ this.learner = learner;
+ }
+
+ public ActorCriticAgent(int stateCount, int actionCount) {
+ learner = new ActorCriticLearner(stateCount, actionCount);
+ }
+
+ public ActorCriticAgent() {
+
+ }
+
+ public ActorCriticAgent(ActorCriticLearner learner) {
+ this.learner = learner;
+ }
+
+ public ActorCriticAgent makeCopy() {
+ ActorCriticAgent clone = new ActorCriticAgent();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(ActorCriticAgent rhs) {
+ learner = (ActorCriticLearner) rhs.learner.makeCopy();
+ prevAction = rhs.prevAction;
+ prevState = rhs.prevState;
+ currentState = rhs.currentState;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof ActorCriticAgent) {
+ ActorCriticAgent rhs = (ActorCriticAgent) obj;
+ return learner.equals(rhs.learner) && prevAction == rhs.prevAction
+ && prevState == rhs.prevState && currentState == rhs.currentState;
+
+ }
+ return false;
+ }
+
+ public int selectAction(Set<Integer> actionsAtState) {
+ return learner.selectAction(currentState, actionsAtState);
+ }
+
+ public int selectAction() {
+ return learner.selectAction(currentState);
+ }
+
+ public void update(int actionTaken, int newState, double immediateReward, final Vec V) {
+ update(actionTaken, newState, null, immediateReward, V);
+ }
+
+ public void update(int actionTaken, int newState, Set<Integer> actionsAtNewState, double immediateReward, final Vec V) {
+
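+ // V maps a state id to the critic's current value estimate; the learner derives the TD error r + V(s') - V(s) from it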
+ learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward, V::get);
+
+ prevAction = actionTaken;
+ prevState = currentState;
+
+ currentState = newState;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java
index d68f978..1ac041d 100644
--- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLambdaLearner.java
@@ -1,129 +1,123 @@
package com.github.chen0040.rl.learning.actorcritic;
-
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.serializer.SerializerFeature;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
import java.util.Set;
import java.util.function.Function;
-
/**
* Created by chen0469 on 9/28/2015 0028.
*/
public class ActorCriticLambdaLearner extends ActorCriticLearner {
- private Matrix e;
- private double lambda = 0.9;
- private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
-
- public ActorCriticLambdaLearner(){
- super();
- }
-
- public ActorCriticLambdaLearner(int stateCount, int actionCount){
- super(stateCount, actionCount);
- e = new Matrix(stateCount, actionCount);
- }
-
-
-
- public ActorCriticLambdaLearner(ActorCriticLearner learner){
- copy(learner);
- e = new Matrix(P.getStateCount(), P.getActionCount());
- }
-
- public ActorCriticLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double lambda, double initialP){
- super(stateCount, actionCount, alpha, gamma, initialP);
- this.lambda = lambda;
- e = new Matrix(stateCount, actionCount);
- }
-
- public EligibilityTraceUpdateMode getTraceUpdateMode() {
- return traceUpdateMode;
- }
-
- public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
- this.traceUpdateMode = traceUpdateMode;
- }
-
- public double getLambda(){
- return lambda;
- }
-
- public void setLambda(double lambda){
- this.lambda = lambda;
- }
-
- public ActorCriticLambdaLearner makeCopy(){
- ActorCriticLambdaLearner clone = new ActorCriticLambdaLearner();
- clone.copy(this);
- return clone;
- }
-
- @Override
- public void copy(ActorCriticLearner rhs){
- super.copy(rhs);
-
- ActorCriticLambdaLearner rhs2 = (ActorCriticLambdaLearner)rhs;
- e = rhs2.e.makeCopy();
- lambda = rhs2.lambda;
- traceUpdateMode = rhs2.traceUpdateMode;
- }
-
- @Override
- public boolean equals(Object obj){
- if(!super.equals(obj)){
- return false;
- }
-
- if(obj instanceof ActorCriticLambdaLearner){
- ActorCriticLambdaLearner rhs = (ActorCriticLambdaLearner)obj;
- return e.equals(rhs.e) && lambda == rhs.lambda && traceUpdateMode == rhs.traceUpdateMode;
- }
-
- return false;
- }
-
- public Matrix getEligibility(){
- return e;
- }
-
- public void setEligibility(Matrix e){
- this.e = e;
- }
-
- @Override
- public void update(int currentStateId, int currentActionId, int newStateId, Set actionsAtNewState, double immediateReward, Function V){
-
- double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId);
-
- int stateCount = P.getStateCount();
- int actionCount = P.getActionCount();
-
- double gamma = P.getGamma();
-
- e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
-
-
- for(int stateId = 0; stateId < stateCount; ++stateId){
- for(int actionId = 0; actionId < actionCount; ++actionId){
-
- double oldP = P.getQ(stateId, actionId);
- double alpha = P.getAlpha(currentStateId, currentActionId);
- double newP = oldP + alpha * td_error * e.get(stateId, actionId);
-
- P.setQ(stateId, actionId, newP);
-
- if (actionId != currentActionId) {
- e.set(currentStateId, actionId, 0);
- } else {
- e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
- }
- }
- }
- }
-
+ private Matrix e;
+ private double lambda = 0.9;
+ private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
+
+ public ActorCriticLambdaLearner() {
+ super();
+ }
+
+ public ActorCriticLambdaLearner(int stateCount, int actionCount) {
+ super(stateCount, actionCount);
+ e = new Matrix(stateCount, actionCount);
+ }
+
+ public ActorCriticLambdaLearner(ActorCriticLearner learner) {
+ copy(learner);
+ e = new Matrix(P.getStateCount(), P.getActionCount());
+ }
+
+ public ActorCriticLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double lambda, double initialP) {
+ super(stateCount, actionCount, alpha, gamma, initialP);
+ this.lambda = lambda;
+ e = new Matrix(stateCount, actionCount);
+ }
+
+ public EligibilityTraceUpdateMode getTraceUpdateMode() {
+ return traceUpdateMode;
+ }
+
+ public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
+ this.traceUpdateMode = traceUpdateMode;
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public void setLambda(double lambda) {
+ this.lambda = lambda;
+ }
+
+ public ActorCriticLambdaLearner makeCopy() {
+ ActorCriticLambdaLearner clone = new ActorCriticLambdaLearner();
+ clone.copy(this);
+ return clone;
+ }
+
+ @Override
+ public void copy(ActorCriticLearner rhs) {
+ super.copy(rhs);
+
+ ActorCriticLambdaLearner rhs2 = (ActorCriticLambdaLearner) rhs;
+ e = rhs2.e.makeCopy();
+ lambda = rhs2.lambda;
+ traceUpdateMode = rhs2.traceUpdateMode;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!super.equals(obj)) {
+ return false;
+ }
+
+ if (obj instanceof ActorCriticLambdaLearner) {
+ ActorCriticLambdaLearner rhs = (ActorCriticLambdaLearner) obj;
+ return e.equals(rhs.e) && lambda == rhs.lambda && traceUpdateMode == rhs.traceUpdateMode;
+ }
+
+ return false;
+ }
+
+ public Matrix getEligibility() {
+ return e;
+ }
+
+ public void setEligibility(Matrix e) {
+ this.e = e;
+ }
+
+ @Override
+ public void update(int currentStateId, int currentActionId, int newStateId, Set<Integer> actionsAtNewState, double immediateReward, Function<Integer, Double> V) {
+
+ double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId);
+
+ int stateCount = P.getStateCount();
+ int actionCount = P.getActionCount();
+
+ double gamma = P.getGamma();
+
+ e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
+
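+ // spread the critic's TD error over every state-action preference in proportion to its eligibility;
+ // traces for the taken action decay by gamma * lambda, while traces of other actions at the current state are reset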
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ for (int actionId = 0; actionId < actionCount; ++actionId) {
+
+ double oldP = P.getQ(stateId, actionId);
+ double alpha = P.getAlpha(currentStateId, currentActionId);
+ double newP = oldP + alpha * td_error * e.get(stateId, actionId);
+
+ P.setQ(stateId, actionId, newP);
+
+ if (actionId != currentActionId) {
+ e.set(currentStateId, actionId, 0);
+ } else {
+ e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
+ }
+ }
+ }
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java
index d106c94..2a25094 100644
--- a/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearner.java
@@ -1,8 +1,7 @@
package com.github.chen0040.rl.learning.actorcritic;
-
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.serializer.SerializerFeature;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory;
@@ -10,100 +9,102 @@
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.utils.Vec;
+import com.google.gson.Gson;
import java.io.Serializable;
import java.util.Random;
import java.util.Set;
import java.util.function.Function;
-
/**
* Created by chen0469 on 9/28/2015 0028.
*/
-public class ActorCriticLearner implements Serializable{
- protected QModel P;
- protected ActionSelectionStrategy actionSelectionStrategy;
-
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
- }
-
- public static ActorCriticLearner fromJson(String json){
- return JSON.parseObject(json, ActorCriticLearner.class);
- }
-
- public Object makeCopy(){
- ActorCriticLearner clone = new ActorCriticLearner();
- clone.copy(this);
- return clone;
- }
-
- public void copy(ActorCriticLearner rhs){
- P = rhs.P.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone();
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof ActorCriticLearner){
- ActorCriticLearner rhs = (ActorCriticLearner)obj;
- return P.equals(rhs.P) && getActionSelection().equals(rhs.getActionSelection());
- }
- return false;
- }
-
- public ActorCriticLearner(){
-
- }
-
- public ActorCriticLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 1, 0.7, 0.01);
- }
-
- public int selectAction(int stateId, Set actionsAtState){
- IndexValue iv = actionSelectionStrategy.selectAction(stateId, P, actionsAtState);
- return iv.getIndex();
- }
-
- public int selectAction(int stateId){
- return selectAction(stateId, null);
- }
-
- public ActorCriticLearner(int stateCount, int actionCount, double beta, double gamma, double initialP){
- P = new QModel(stateCount, actionCount, initialP);
- P.setAlpha(beta);
- P.setGamma(gamma);
-
- actionSelectionStrategy = new GibbsSoftMaxActionSelectionStrategy();
- }
-
- public void update(int currentStateId, int currentActionId, int newStateId, double immediateReward, Function V){
- update(currentStateId, currentActionId, newStateId, null, immediateReward, V);
- }
-
- public void update(int currentStateId, int currentActionId, int newStateId,Set actionsAtNewState, double immediateReward, Function V){
- double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId);
-
- double oldP = P.getQ(currentStateId, currentActionId);
- double beta = P.getAlpha(currentStateId, currentActionId);
- double newP = oldP + beta * td_error;
- P.setQ(currentStateId, currentActionId, newP);
- }
-
- public String getActionSelection() {
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
- }
-
- public void setActionSelection(String conf) {
- this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
- }
-
-
- public QModel getP() {
- return P;
- }
-
- public void setP(QModel p) {
- P = p;
- }
+public class ActorCriticLearner implements Serializable {
+ protected QModel P;
+ protected ActionSelectionStrategy actionSelectionStrategy;
+
+ public String toJson() {
+ return new Gson().toJson(this);
+// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
+ }
+
+ public static ActorCriticLearner fromJson(String json) {
+ return new Gson().fromJson(json, ActorCriticLearner.class);
+// return JSON.parseObject(json, ActorCriticLearner.class);
+ }
+
+ public Object makeCopy() {
+ ActorCriticLearner clone = new ActorCriticLearner();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(ActorCriticLearner rhs) {
+ P = rhs.P.makeCopy();
+ actionSelectionStrategy = (ActionSelectionStrategy)
+ ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof ActorCriticLearner) {
+ ActorCriticLearner rhs = (ActorCriticLearner) obj;
+ return P.equals(rhs.P) && getActionSelection().equals(rhs.getActionSelection());
+ }
+ return false;
+ }
+
+ public ActorCriticLearner() {
+
+ }
+
+ public ActorCriticLearner(int stateCount, int actionCount) {
+ this(stateCount, actionCount, 1, 0.7, 0.01);
+ }
+
+ public int selectAction(int stateId, Set<Integer> actionsAtState) {
+ IndexValue iv = actionSelectionStrategy.selectAction(stateId, P, actionsAtState);
+ return iv.getIndex();
+ }
+
+ public int selectAction(int stateId) {
+ return selectAction(stateId, null);
+ }
+
+ public ActorCriticLearner(int stateCount, int actionCount, double beta, double gamma, double initialP) {
+ P = new QModel(stateCount, actionCount, initialP);
+ P.setAlpha(beta);
+ P.setGamma(gamma);
+
+ actionSelectionStrategy = new GibbsSoftMaxActionSelectionStrategy();
+ }
+
+ public void update(int currentStateId, int currentActionId, int newStateId, double immediateReward, Function<Integer, Double> V) {
+ update(currentStateId, currentActionId, newStateId, null, immediateReward, V);
+ }
+
+ public void update(int currentStateId, int currentActionId, int newStateId, Set<Integer> actionsAtNewState, double immediateReward, Function<Integer, Double> V) {
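+ // the critic's evaluation: td_error = r + V(s') - V(s); the actor preference P(s, a) is then moved by beta * td_error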
+ double td_error = immediateReward + V.apply(newStateId) - V.apply(currentStateId);
+
+ double oldP = P.getQ(currentStateId, currentActionId);
+ double beta = P.getAlpha(currentStateId, currentActionId);
+ double newP = oldP + beta * td_error;
+ P.setQ(currentStateId, currentActionId, newP);
+ }
+
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ }
+
+ public void setActionSelection(String conf) {
+ this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ }
+
+ public QModel getP() {
+ return P;
+ }
+
+ public void setP(QModel p) {
+ P = p;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java
index afdb314..8b1a84e 100644
--- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QAgent.java
@@ -6,107 +6,109 @@
import java.util.Random;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class QAgent implements Serializable{
- private QLearner learner;
- private int currentState;
- private int prevState;
-
- /** action taken at prevState */
- private int prevAction;
-
- public int getCurrentState(){
- return currentState;
- }
-
- public int getPrevState(){
- return prevState;
- }
-
- public int getPrevAction(){
- return prevAction;
- }
-
- public void start(int currentState){
- this.currentState = currentState;
- this.prevAction = -1;
- this.prevState = -1;
- }
-
- public IndexValue selectAction(){
- return learner.selectAction(currentState);
- }
-
- public IndexValue selectAction(Set actionsAtState){
- return learner.selectAction(currentState, actionsAtState);
- }
-
- public void update(int actionTaken, int newState, double immediateReward){
- update(actionTaken, newState, null, immediateReward);
- }
-
- public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){
-
- learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward);
-
- prevState = currentState;
- prevAction = actionTaken;
-
- currentState = newState;
- }
-
- public void enableEligibilityTrace(double lambda){
- QLambdaLearner acll = new QLambdaLearner(learner);
- acll.setLambda(lambda);
- learner = acll;
- }
-
- public QLearner getLearner(){
- return learner;
- }
-
- public void setLearner(QLearner learner){
- this.learner = learner;
- }
-
- public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ);
- }
-
- public QAgent(QLearner learner){
- this.learner = learner;
- }
-
- public QAgent(int stateCount, int actionCount){
- learner = new QLearner(stateCount, actionCount);
- }
-
- public QAgent(){
-
- }
-
- public QAgent makeCopy(){
- QAgent clone = new QAgent();
- clone.copy(this);
- return clone;
- }
-
- public void copy(QAgent rhs){
- learner.copy(rhs.learner);
- prevAction = rhs.prevAction;
- prevState = rhs.prevState;
- currentState = rhs.currentState;
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof QAgent){
- QAgent rhs = (QAgent)obj;
- return prevAction == rhs.prevAction && prevState == rhs.prevState && currentState == rhs.currentState && learner.equals(rhs.learner);
- }
- return false;
- }
+public class QAgent implements Serializable {
+ private QLearner learner;
+ private int currentState;
+ private int prevState;
+
+ /**
+ * action taken at prevState
+ */
+ private int prevAction;
+
+ public int getCurrentState() {
+ return currentState;
+ }
+
+ public int getPrevState() {
+ return prevState;
+ }
+
+ public int getPrevAction() {
+ return prevAction;
+ }
+
+ public void start(int currentState) {
+ this.currentState = currentState;
+ this.prevAction = -1;
+ this.prevState = -1;
+ }
+
+ public IndexValue selectAction() {
+ return learner.selectAction(currentState);
+ }
+
+ public IndexValue selectAction(Set<Integer> actionsAtState) {
+ return learner.selectAction(currentState, actionsAtState);
+ }
+
+ public void update(int actionTaken, int newState, double immediateReward) {
+ update(actionTaken, newState, null, immediateReward);
+ }
+
+ public void update(int actionTaken, int newState, Set<Integer> actionsAtNewState, double immediateReward) {
+
+ learner.update(currentState, actionTaken, newState, actionsAtNewState, immediateReward);
+
+ prevState = currentState;
+ prevAction = actionTaken;
+
+ currentState = newState;
+ }
+
+ public void enableEligibilityTrace(double lambda) {
+ QLambdaLearner acll = new QLambdaLearner(learner);
+ acll.setLambda(lambda);
+ learner = acll;
+ }
+
+ public QLearner getLearner() {
+ return learner;
+ }
+
+ public void setLearner(QLearner learner) {
+ this.learner = learner;
+ }
+
+ public QAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
+ learner = new QLearner(stateCount, actionCount, alpha, gamma, initialQ);
+ }
+
+ public QAgent(QLearner learner) {
+ this.learner = learner;
+ }
+
+ public QAgent(int stateCount, int actionCount) {
+ learner = new QLearner(stateCount, actionCount);
+ }
+
+ public QAgent() {
+
+ }
+
+ public QAgent makeCopy() {
+ QAgent clone = new QAgent();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(QAgent rhs) {
+ learner.copy(rhs.learner);
+ prevAction = rhs.prevAction;
+ prevState = rhs.prevState;
+ currentState = rhs.currentState;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof QAgent) {
+ QAgent rhs = (QAgent) obj;
+ return prevAction == rhs.prevAction && prevState == rhs.prevState
+ && currentState == rhs.currentState && learner.equals(rhs.learner);
+ }
+ return false;
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java
index 875ef3a..df75cd2 100644
--- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLambdaLearner.java
@@ -1,135 +1,129 @@
package com.github.chen0040.rl.learning.qlearn;
-
import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
import java.util.Set;
-
/**
* Created by xschen on 9/28/2015 0028.
*/
public class QLambdaLearner extends QLearner {
- private double lambda = 0.9;
- private Matrix e;
- private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
-
- public EligibilityTraceUpdateMode getTraceUpdateMode() {
- return traceUpdateMode;
- }
-
- public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
- this.traceUpdateMode = traceUpdateMode;
- }
-
- public double getLambda(){
- return lambda;
- }
-
- public void setLambda(double lambda){
- this.lambda = lambda;
- }
-
- public QLambdaLearner makeCopy(){
- QLambdaLearner clone = new QLambdaLearner();
- clone.copy(this);
- return clone;
- }
-
- @Override
- public void copy(QLearner rhs){
- super.copy(rhs);
-
- QLambdaLearner rhs2 = (QLambdaLearner)rhs;
- lambda = rhs2.lambda;
- e = rhs2.e.makeCopy();
- traceUpdateMode = rhs2.traceUpdateMode;
- }
-
- public QLambdaLearner(QLearner learner){
- copy(learner);
- e = new Matrix(model.getStateCount(), model.getActionCount());
- }
-
- @Override
- public boolean equals(Object obj){
- if(!super.equals(obj)){
- return false;
- }
-
- if(obj instanceof QLambdaLearner){
- QLambdaLearner rhs = (QLambdaLearner)obj;
- return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode;
- }
-
- return false;
- }
-
- public QLambdaLearner(){
- super();
- }
-
- public QLambdaLearner(int stateCount, int actionCount){
- super(stateCount, actionCount);
- e = new Matrix(stateCount, actionCount);
- }
-
- public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- super(stateCount, actionCount, alpha, gamma, initialQ);
- e = new Matrix(stateCount, actionCount);
- }
-
- public Matrix getEligibility()
- {
- return e;
- }
-
- public void setEligibility(Matrix e){
- this.e = e;
- }
-
- @Override
- public void update(int currentStateId, int currentActionId, int nextStateId, Set actionsAtNextStateId, double immediateReward)
- {
- // old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(currentStateId, currentActionId);
-
- // learning_rate;
- double alpha = model.getAlpha(currentStateId, currentActionId);
-
- // discount_rate;
- double gamma = model.getGamma();
-
- // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double maxQ = maxQAtState(nextStateId, actionsAtNextStateId);
-
- double td_error = immediateReward + gamma * maxQ - oldQ;
-
- int stateCount = model.getStateCount();
- int actionCount = model.getActionCount();
-
- e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
-
-
- for(int stateId = 0; stateId < stateCount; ++stateId){
- for(int actionId = 0; actionId < actionCount; ++actionId){
- oldQ = model.getQ(stateId, actionId);
- double newQ = oldQ + alpha * td_error * e.get(stateId, actionId);
-
- // new_value is $Q_{t+1}(s_t, a_t)$
- model.setQ(currentStateId, currentActionId, newQ);
-
- if (actionId != currentActionId) {
- e.set(currentStateId, actionId, 0);
- } else {
- e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
- }
- }
- }
-
-
-
- }
+ private double lambda = 0.9;
+ private Matrix e;
+ private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
+
+ public EligibilityTraceUpdateMode getTraceUpdateMode() {
+ return traceUpdateMode;
+ }
+
+ public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
+ this.traceUpdateMode = traceUpdateMode;
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public void setLambda(double lambda) {
+ this.lambda = lambda;
+ }
+
+ public QLambdaLearner makeCopy() {
+ QLambdaLearner clone = new QLambdaLearner();
+ clone.copy(this);
+ return clone;
+ }
+
+ @Override
+ public void copy(QLearner rhs) {
+ super.copy(rhs);
+
+ QLambdaLearner rhs2 = (QLambdaLearner) rhs;
+ lambda = rhs2.lambda;
+ e = rhs2.e.makeCopy();
+ traceUpdateMode = rhs2.traceUpdateMode;
+ }
+
+ public QLambdaLearner(QLearner learner) {
+ copy(learner);
+ e = new Matrix(model.getStateCount(), model.getActionCount());
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!super.equals(obj)) {
+ return false;
+ }
+
+ if (obj instanceof QLambdaLearner) {
+ QLambdaLearner rhs = (QLambdaLearner) obj;
+ return rhs.lambda == lambda && e.equals(rhs.e)
+ && traceUpdateMode == rhs.traceUpdateMode;
+ }
+
+ return false;
+ }
+
+ public QLambdaLearner() {
+ super();
+ }
+
+ public QLambdaLearner(int stateCount, int actionCount) {
+ super(stateCount, actionCount);
+ e = new Matrix(stateCount, actionCount);
+ }
+
+ public QLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
+ super(stateCount, actionCount, alpha, gamma, initialQ);
+ e = new Matrix(stateCount, actionCount);
+ }
+
+ public Matrix getEligibility() {
+ return e;
+ }
+
+ public void setEligibility(Matrix e) {
+ this.e = e;
+ }
+
+ @Override
+ public void update(int currentStateId, int currentActionId, int nextStateId, Set<Integer> actionsAtNextStateId, double immediateReward) {
+ // old_value is $Q_t(s_t, a_t)$
+ double oldQ = model.getQ(currentStateId, currentActionId);
+
+ // learning_rate;
+ double alpha = model.getAlpha(currentStateId, currentActionId);
+
+ // discount_rate;
+ double gamma = model.getGamma();
+
+ // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
+ double maxQ = maxQAtState(nextStateId, actionsAtNextStateId);
+
+ double td_error = immediateReward + gamma * maxQ - oldQ;
+
+ int stateCount = model.getStateCount();
+ int actionCount = model.getActionCount();
+
+ e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
+
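+ // back up the TD error to every (state, action) pair in proportion to its eligibility trace;
+ // traces for the taken action decay by gamma * lambda, traces of other actions at the current state are reset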
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ for (int actionId = 0; actionId < actionCount; ++actionId) {
+ oldQ = model.getQ(stateId, actionId);
+ double newQ = oldQ + alpha * td_error * e.get(stateId, actionId);
+
+ // new_value is $Q_{t+1}(s, a)$, scaled by the eligibility trace of (s, a)
+ model.setQ(stateId, actionId, newQ);
+
+ if (actionId != currentActionId) {
+ e.set(currentStateId, actionId, 0);
+ } else {
+ e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
+ }
+ }
+ }
+
+ }
}
diff --git a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java
index 865abc5..a087957 100644
--- a/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/qlearn/QLearner.java
@@ -1,142 +1,137 @@
package com.github.chen0040.rl.learning.qlearn;
-
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.annotation.JSONField;
-import com.alibaba.fastjson.serializer.SerializerFeature;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.annotation.JSONField;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory;
import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy;
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
+import com.google.gson.Gson;
import java.io.Serializable;
import java.util.Random;
import java.util.Set;
-
/**
- * Created by xschen on 9/27/2015 0027.
- * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
+ * Created by xschen on 9/27/2015 0027. Implements temporal-difference Q-Learning, an
+ * off-policy TD control algorithm. Q is the quality of a state-action combination; note
+ * that it differs from the utility of a state.
*/
-public class QLearner implements Serializable,Cloneable {
- protected QModel model;
-
- private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
-
- public QLearner makeCopy(){
- QLearner clone = new QLearner();
- clone.copy(this);
- return clone;
- }
-
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
- }
-
- public static QLearner fromJson(String json){
- return JSON.parseObject(json, QLearner.class);
- }
-
- public void copy(QLearner rhs){
- model = rhs.model.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj !=null && obj instanceof QLearner){
- QLearner rhs = (QLearner)obj;
- if(!model.equals(rhs.model)) return false;
- return actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
- }
- return false;
- }
-
- public QModel getModel() {
- return model;
- }
-
- public void setModel(QModel model) {
- this.model = model;
- }
-
-
- public String getActionSelection() {
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
- }
-
- public void setActionSelection(String conf) {
- this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
- }
-
- public QLearner(){
-
- }
-
- public QLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1, 0.7, 0.1);
- }
-
- public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){
- this.model = model;
- this.actionSelectionStrategy = actionSelectionStrategy;
- }
-
- public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ)
- {
- model = new QModel(stateCount, actionCount, initialQ);
- model.setAlpha(alpha);
- model.setGamma(gamma);
- actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
- }
-
-
- protected double maxQAtState(int stateId, Set actionsAtState){
- IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState);
- double maxQ = iv.getValue();
- return maxQ;
- }
-
- public IndexValue selectAction(int stateId, Set actionsAtState){
- return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
- }
-
- public IndexValue selectAction(int stateId){
- return selectAction(stateId, null);
- }
-
-
- public void update(int stateId, int actionId, int nextStateId, double immediateReward){
- update(stateId, actionId, nextStateId, null, immediateReward);
- }
-
- public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward)
- {
- // old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(stateId, actionId);
-
- // learning_rate;
- double alpha = model.getAlpha(stateId, actionId);
-
- // discount_rate;
- double gamma = model.getGamma();
-
- // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double maxQ = maxQAtState(nextStateId, actionsAtNextStateId);
-
- // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value
- // old_value = oldQ
- // temporal_difference = learned_value - old_value
- // new_value = old_value + learning_rate * temporal_difference
- double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ);
-
- // new_value is $Q_{t+1}(s_t, a_t)$
- model.setQ(stateId, actionId, newQ);
- }
-
-
+public class QLearner implements Serializable, Cloneable {
+ protected QModel model;
+
+ private ActionSelectionStrategy actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
+
+ public QLearner makeCopy() {
+ QLearner clone = new QLearner();
+ clone.copy(this);
+ return clone;
+ }
+
+ public String toJson() {
+ return new Gson().toJson(this);
+// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
+ }
+
+ public static QLearner fromJson(String json) {
+ return new Gson().fromJson(json, QLearner.class);
+// return JSON.parseObject(json, QLearner.class);
+ }
+
+ public void copy(QLearner rhs) {
+ model = rhs.model.makeCopy();
+ actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy)
+ .clone();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof QLearner) {
+ QLearner rhs = (QLearner) obj;
+ if (!model.equals(rhs.model)) return false;
+ return actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
+ }
+ return false;
+ }
+
+ public QModel getModel() {
+ return model;
+ }
+
+ public void setModel(QModel model) {
+ this.model = model;
+ }
+
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ }
+
+ public void setActionSelection(String conf) {
+ this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ }
+
+ public QLearner() {
+
+ }
+
+ public QLearner(int stateCount, int actionCount) {
+ this(stateCount, actionCount, 0.1, 0.7, 0.1);
+ }
+
+ public QLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy) {
+ this.model = model;
+ this.actionSelectionStrategy = actionSelectionStrategy;
+ }
+
+ public QLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
+ model = new QModel(stateCount, actionCount, initialQ);
+ model.setAlpha(alpha);
+ model.setGamma(gamma);
+ actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
+ }
+
+ protected double maxQAtState(int stateId, Set actionsAtState) {
+ IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState);
+ double maxQ = iv.getValue();
+ return maxQ;
+ }
+
+ public IndexValue selectAction(int stateId, Set actionsAtState) {
+ return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
+ }
+
+ public IndexValue selectAction(int stateId) {
+ return selectAction(stateId, null);
+ }
+
+ public void update(int stateId, int actionId, int nextStateId, double immediateReward) {
+ update(stateId, actionId, nextStateId, null, immediateReward);
+ }
+
+ public void update(int stateId, int actionId, int nextStateId, Set actionsAtNextStateId, double immediateReward) {
+ // old_value is $Q_t(s_t, a_t)$
+ double oldQ = model.getQ(stateId, actionId);
+
+ // learning_rate;
+ double alpha = model.getAlpha(stateId, actionId);
+
+ // discount_rate;
+ double gamma = model.getGamma();
+
+ // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
+ double maxQ = maxQAtState(nextStateId, actionsAtNextStateId);
+
+ // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value
+ // old_value = oldQ
+ // temporal_difference = learned_value - old_value
+ // new_value = old_value + learning_rate * temporal_difference
+ double newQ = oldQ + alpha * (immediateReward + gamma * maxQ - oldQ);
+
+ // new_value is $Q_{t+1}(s_t, a_t)$
+ model.setQ(stateId, actionId, newQ);
+ }
}
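
For reference, the update(...) method above applies the standard off-policy TD(0) rule, $Q_{t+1}(s_t, a_t) = Q_t(s_t, a_t) + \alpha [r_{t+1} + \gamma \max_a Q_t(s_{t+1}, a) - Q_t(s_t, a_t)]$. A minimal usage sketch follows; the class name QLearnerSketch and the random environment dynamics are hypothetical, and only the QLearner API visible in this hunk is exercised:

    import com.github.chen0040.rl.learning.qlearn.QLearner;
    import com.github.chen0040.rl.utils.IndexValue;
    import java.util.Random;

    public class QLearnerSketch {
        public static void main(String[] args) {
            int stateCount = 100;
            int actionCount = 10;
            QLearner learner = new QLearner(stateCount, actionCount);

            Random random = new Random();
            int stateId = random.nextInt(stateCount);
            for (int time = 0; time < 1000; ++time) {
                // epsilon-greedy selection by default; returns the chosen action id and its Q value
                IndexValue action = learner.selectAction(stateId);
                int actionId = action.getIndex();

                // hypothetical environment step: next state and reward are drawn at random here
                int nextStateId = random.nextInt(stateCount);
                double reward = random.nextDouble();

                // off-policy TD update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
                learner.update(stateId, actionId, nextStateId, reward);
                stateId = nextStateId;
            }
        }
    }
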
diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java
index f26f20a..533273c 100644
--- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RAgent.java
@@ -6,96 +6,93 @@
import java.util.Random;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class RAgent implements Serializable{
- private RLearner learner;
- private int currentState;
- private int currentAction;
- private double currentValue;
-
- public int getCurrentState(){
- return currentState;
- }
-
- public int getCurrentAction(){
- return currentAction;
- }
-
- public void start(int currentState){
- this.currentState = currentState;
- }
-
- public RAgent makeCopy(){
- RAgent clone = new RAgent();
- clone.copy(this);
- return clone;
- }
-
- public void copy(RAgent rhs){
- currentState = rhs.currentState;
- currentAction = rhs.currentAction;
- learner.copy(rhs.learner);
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof RAgent){
- RAgent rhs = (RAgent)obj;
- if(!learner.equals(rhs.learner)) return false;
- if(currentAction != rhs.currentAction) return false;
- return currentState == rhs.currentState;
- }
- return false;
- }
-
- public IndexValue selectAction(){
- return selectAction(null);
- }
-
- public IndexValue selectAction(Set actionsAtState){
-
- if(currentAction==-1){
- IndexValue iv = learner.selectAction(currentState, actionsAtState);
- currentAction = iv.getIndex();
- currentValue = iv.getValue();
- }
- return new IndexValue(currentAction, currentValue);
- }
-
- public void update(int newState, double immediateReward){
- update(newState, null, immediateReward);
- }
-
- public void update(int newState, Set actionsAtState, double immediateReward){
- if(currentAction != -1) {
- learner.update(currentState, currentAction, newState, actionsAtState, immediateReward);
- currentState = newState;
- currentAction = -1;
- }
- }
-
- public RAgent(){
-
- }
-
-
-
- public RLearner getLearner(){
- return learner;
- }
-
- public void setLearner(RLearner learner){
- this.learner = learner;
- }
-
- public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ){
- learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ);
- }
-
- public RAgent(int stateCount, int actionCount){
- learner = new RLearner(stateCount, actionCount);
- }
+public class RAgent implements Serializable {
+ private RLearner learner;
+ private int currentState;
+ private int currentAction;
+ private double currentValue;
+
+ public int getCurrentState() {
+ return currentState;
+ }
+
+ public int getCurrentAction() {
+ return currentAction;
+ }
+
+ public void start(int currentState) {
+ this.currentState = currentState;
+ }
+
+ public RAgent makeCopy() {
+ RAgent clone = new RAgent();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(RAgent rhs) {
+ currentState = rhs.currentState;
+ currentAction = rhs.currentAction;
+ learner.copy(rhs.learner);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof RAgent) {
+ RAgent rhs = (RAgent) obj;
+ if (!learner.equals(rhs.learner)) return false;
+ if (currentAction != rhs.currentAction) return false;
+ return currentState == rhs.currentState;
+ }
+ return false;
+ }
+
+ public IndexValue selectAction() {
+ return selectAction(null);
+ }
+
+ public IndexValue selectAction(Set actionsAtState) {
+
+ if (currentAction == -1) {
+ IndexValue iv = learner.selectAction(currentState, actionsAtState);
+ currentAction = iv.getIndex();
+ currentValue = iv.getValue();
+ }
+ return new IndexValue(currentAction, currentValue);
+ }
+
+ public void update(int newState, double immediateReward) {
+ update(newState, null, immediateReward);
+ }
+
+ public void update(int newState, Set actionsAtState, double immediateReward) {
+ if (currentAction != -1) {
+ learner.update(currentState, currentAction, newState, actionsAtState, immediateReward);
+ currentState = newState;
+ currentAction = -1;
+ }
+ }
+
+ public RAgent() {
+
+ }
+
+ public RLearner getLearner() {
+ return learner;
+ }
+
+ public void setLearner(RLearner learner) {
+ this.learner = learner;
+ }
+
+ public RAgent(int stateCount, int actionCount, double alpha, double beta, double rho, double initialQ) {
+ learner = new RLearner(stateCount, actionCount, alpha, beta, rho, initialQ);
+ }
+
+ public RAgent(int stateCount, int actionCount) {
+ learner = new RLearner(stateCount, actionCount);
+ }
}
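
RAgent is a thin stateful wrapper: it caches the last selected action and forwards update(...) calls to its RLearner. A minimal sketch of the intended call pattern, with a hypothetical random environment and the hypothetical class name RAgentSketch; only the RAgent API shown in this hunk is used:

    import com.github.chen0040.rl.learning.rlearn.RAgent;
    import com.github.chen0040.rl.utils.IndexValue;
    import java.util.Random;

    public class RAgentSketch {
        public static void main(String[] args) {
            int stateCount = 100;
            int actionCount = 10;
            RAgent agent = new RAgent(stateCount, actionCount);

            Random random = new Random();
            agent.start(random.nextInt(stateCount));
            for (int time = 0; time < 1000; ++time) {
                // the chosen action is cached until the next call to update(...)
                IndexValue action = agent.selectAction();
                System.out.println("agent does action-" + action.getIndex());

                // hypothetical environment response
                int newState = random.nextInt(stateCount);
                double reward = random.nextDouble();

                // forwards to RLearner.update(...) and clears the cached action
                agent.update(newState, reward);
            }
        }
    }
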
diff --git a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java
index 910d53f..d80edbe 100644
--- a/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/rlearn/RLearner.java
@@ -1,141 +1,141 @@
package com.github.chen0040.rl.learning.rlearn;
-
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.serializer.SerializerFeature;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory;
import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy;
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
-import lombok.Getter;
+import com.google.gson.Gson;
+
+//import lombok.Getter;
import java.io.Serializable;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
-public class RLearner implements Serializable, Cloneable{
-
- private QModel model;
- private ActionSelectionStrategy actionSelectionStrategy;
- private double rho;
- private double beta;
-
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
- }
+public class RLearner implements Serializable, Cloneable {
+
+ private QModel model;
+ private ActionSelectionStrategy actionSelectionStrategy;
+ private double rho;
+ private double beta;
- public static RLearner fromJson(String json){
- return JSON.parseObject(json, RLearner.class);
- }
+ public String toJson() {
+ return new Gson().toJson(this);
+// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
+ }
+
+ public static RLearner fromJson(String json) {
+ return new Gson().fromJson(json, RLearner.class);
+// return JSON.parseObject(json, RLearner.class);
+ }
- public RLearner makeCopy(){
- RLearner clone = new RLearner();
- clone.copy(this);
- return clone;
- }
-
- public void copy(RLearner rhs){
- model = rhs.model.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy)rhs.actionSelectionStrategy).clone();
- rho = rhs.rho;
- beta = rhs.beta;
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof RLearner){
- RLearner rhs = (RLearner)obj;
- if(!model.equals(rhs.model)) return false;
- if(!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false;
- if(rho != rhs.rho) return false;
- return beta == rhs.beta;
- }
- return false;
- }
-
- public RLearner(){
-
- }
-
- public double getRho() {
- return rho;
- }
-
- public void setRho(double rho) {
- this.rho = rho;
- }
-
- public double getBeta() {
- return beta;
- }
-
- public void setBeta(double beta) {
- this.beta = beta;
- }
-
- public QModel getModel(){
- return model;
-
- }
-
- public void setModel(QModel model){
- this.model = model;
- }
-
- public String getActionSelection(){
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
- }
-
- public void setActionSelection(String conf){
- this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
- }
-
- public RLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1);
- }
-
- public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q)
- {
- model = new QModel(state_count, action_count, initial_Q);
- model.setAlpha(alpha);
-
- this.rho = rho;
- this.beta = beta;
-
- actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
- }
-
- private double maxQAtState(int stateId, Set actionsAtState){
- IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState);
- double maxQ = iv.getValue();
- return maxQ;
- }
-
- public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward)
- {
- double oldQ = model.getQ(currentState, actionTaken);
+ public RLearner makeCopy() {
+ RLearner clone = new RLearner();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(RLearner rhs) {
+ model = rhs.model.makeCopy();
+ actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy)
+ .clone();
+ rho = rhs.rho;
+ beta = rhs.beta;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof RLearner) {
+ RLearner rhs = (RLearner) obj;
+ if (!model.equals(rhs.model)) return false;
+ if (!actionSelectionStrategy.equals(rhs.actionSelectionStrategy)) return false;
+ if (rho != rhs.rho) return false;
+ return beta == rhs.beta;
+ }
+ return false;
+ }
+
+ public RLearner() {
+
+ }
+
+ public double getRho() {
+ return rho;
+ }
+
+ public void setRho(double rho) {
+ this.rho = rho;
+ }
+
+ public double getBeta() {
+ return beta;
+ }
+
+ public void setBeta(double beta) {
+ this.beta = beta;
+ }
+
+ public QModel getModel() {
+ return model;
+
+ }
+
+ public void setModel(QModel model) {
+ this.model = model;
+ }
+
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ }
+
+ public void setActionSelection(String conf) {
+ this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ }
+
+ public RLearner(int stateCount, int actionCount) {
+ this(stateCount, actionCount, 0.1, 0.1, 0.7, 0.1);
+ }
+
+ public RLearner(int state_count, int action_count, double alpha, double beta, double rho, double initial_Q) {
+ model = new QModel(state_count, action_count, initial_Q);
+ model.setAlpha(alpha);
+
+ this.rho = rho;
+ this.beta = beta;
+
+ actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
+ }
+
+ private double maxQAtState(int stateId, Set actionsAtState) {
+ IndexValue iv = model.actionWithMaxQAtState(stateId, actionsAtState);
+ double maxQ = iv.getValue();
+ return maxQ;
+ }
+
+ public void update(int currentState, int actionTaken, int newState, Set actionsAtNextStateId, double immediate_reward) {
+ double oldQ = model.getQ(currentState, actionTaken);
+
+ double alpha = model.getAlpha(currentState, actionTaken); // learning rate;
+
+ double maxQ = maxQAtState(newState, actionsAtNextStateId);
- double alpha = model.getAlpha(currentState, actionTaken); // learning rate;
-
- double maxQ = maxQAtState(newState, actionsAtNextStateId);
+ double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ);
+
+ double maxQAtCurrentState = maxQAtState(currentState, null);
+ if (newQ == maxQAtCurrentState) {
+ rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState);
+ }
- double newQ = oldQ + alpha * (immediate_reward - rho + maxQ - oldQ);
-
- double maxQAtCurrentState = maxQAtState(currentState, null);
- if (newQ == maxQAtCurrentState)
- {
- rho = rho + beta * (immediate_reward - rho + maxQ - maxQAtCurrentState);
- }
+ model.setQ(currentState, actionTaken, newQ);
+ }
- model.setQ(currentState, actionTaken, newQ);
- }
-
- public IndexValue selectAction(int stateId, Set actionsAtState){
- return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
- }
+ public IndexValue selectAction(int stateId, Set actionsAtState) {
+ return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
+ }
}
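
The update(...) above is the R-learning rule, an average-reward TD method in which $\rho$ tracks the average reward and $\beta$ is its learning rate. As implemented, each step performs:

    Q_{t+1}(s_t, a_t) = Q_t(s_t, a_t) + \alpha \left[ r_{t+1} - \rho + \max_a Q_t(s_{t+1}, a) - Q_t(s_t, a_t) \right]
    \rho \leftarrow \rho + \beta \left[ r_{t+1} - \rho + \max_a Q_t(s_{t+1}, a) - \max_a Q_t(s_t, a) \right]

and the $\rho$ step is applied only when the freshly updated $Q(s_t, a_t)$ equals $\max_a Q_t(s_t, a)$, matching the equality check in the code.
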
diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java
index c4c8f27..2cfa8ae 100644
--- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java
+++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaAgent.java
@@ -6,125 +6,122 @@
import java.util.Random;
import java.util.Set;
-
/**
- * Created by xschen on 9/27/2015 0027.
- * Implement temporal-difference learning Sarsa, which is an on-policy TD control algorithm
+ * Created by xschen on 9/27/2015 0027. Implements SARSA, an on-policy temporal-difference (TD)
+ * control algorithm.
*/
-public class SarsaAgent implements Serializable{
- private SarsaLearner learner;
- private int currentState;
- private int currentAction;
- private double currentValue;
- private int prevState;
- private int prevAction;
-
- public int getCurrentState(){
- return currentState;
- }
-
- public int getCurrentAction(){
- return currentAction;
- }
-
- public int getPrevState() { return prevState; }
-
- public int getPrevAction() { return prevAction; }
-
- public void start(int currentState){
- this.currentState = currentState;
- this.prevState = -1;
- this.prevAction = -1;
- }
-
- public IndexValue selectAction(){
- return selectAction(null);
- }
-
- public IndexValue selectAction(Set actionsAtState){
- if(currentAction == -1){
- IndexValue iv = learner.selectAction(currentState, actionsAtState);
- currentAction = iv.getIndex();
- currentValue = iv.getValue();
- }
-
- return new IndexValue(currentAction, currentValue);
- }
-
- public void update(int actionTaken, int newState, double immediateReward){
- update(actionTaken, newState, null, immediateReward);
- }
-
- public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward){
-
- IndexValue iv = learner.selectAction(currentState, actionsAtNewState);
- int futureAction = iv.getIndex();
-
- learner.update(currentState, actionTaken, newState, futureAction, immediateReward);
-
- prevState = this.currentState;
- this.prevAction = actionTaken;
-
- currentAction = futureAction;
- currentState = newState;
- }
-
-
-
- public SarsaLearner getLearner(){
- return learner;
- }
-
- public void setLearner(SarsaLearner learner){
- this.learner = learner;
- }
-
- public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ);
- }
-
- public SarsaAgent(int stateCount, int actionCount){
- learner = new SarsaLearner(stateCount, actionCount);
- }
-
- public SarsaAgent(SarsaLearner learner){
- this.learner = learner;
- }
-
- public SarsaAgent(){
-
- }
-
- public void enableEligibilityTrace(double lambda){
- SarsaLambdaLearner acll = new SarsaLambdaLearner(learner);
- acll.setLambda(lambda);
- learner = acll;
- }
-
- public SarsaAgent makeCopy(){
- SarsaAgent clone = new SarsaAgent();
- clone.copy(this);
- return clone;
- }
-
- public void copy(SarsaAgent rhs){
- learner.copy(rhs.learner);
- currentAction = rhs.currentAction;
- currentState = rhs.currentState;
- prevAction = rhs.prevAction;
- prevState = rhs.prevState;
- }
-
- @Override
- public boolean equals(Object obj){
- if(obj != null && obj instanceof SarsaAgent){
- SarsaAgent rhs = (SarsaAgent)obj;
- return prevAction == rhs.prevAction
- && prevState == rhs.prevState
- && currentAction == rhs.currentAction
- && currentState == rhs.currentState
- && learner.equals(rhs.learner);
- }
- return false;
- }
+public class SarsaAgent implements Serializable {
+ private SarsaLearner learner;
+ private int currentState;
+ private int currentAction;
+ private double currentValue;
+ private int prevState;
+ private int prevAction;
+
+ public int getCurrentState() {
+ return currentState;
+ }
+
+ public int getCurrentAction() {
+ return currentAction;
+ }
+
+ public int getPrevState() { return prevState; }
+
+ public int getPrevAction() { return prevAction; }
+
+ public void start(int currentState) {
+ this.currentState = currentState;
+ this.prevState = -1;
+ this.prevAction = -1;
+ }
+
+ public IndexValue selectAction() {
+ return selectAction(null);
+ }
+
+ public IndexValue selectAction(Set actionsAtState) {
+ if (currentAction == -1) {
+ IndexValue iv = learner.selectAction(currentState, actionsAtState);
+ currentAction = iv.getIndex();
+ currentValue = iv.getValue();
+ }
+
+ return new IndexValue(currentAction, currentValue);
+ }
+
+ public void update(int actionTaken, int newState, double immediateReward) {
+ update(actionTaken, newState, null, immediateReward);
+ }
+
+ public void update(int actionTaken, int newState, Set actionsAtNewState, double immediateReward) {
+
+        // the next action must be chosen at the new state for the on-policy SARSA target
+        IndexValue iv = learner.selectAction(newState, actionsAtNewState);
+ int futureAction = iv.getIndex();
+
+ learner.update(currentState, actionTaken, newState, futureAction, immediateReward);
+
+ prevState = this.currentState;
+ this.prevAction = actionTaken;
+
+ currentAction = futureAction;
+ currentState = newState;
+ }
+
+ public SarsaLearner getLearner() {
+ return learner;
+ }
+
+ public void setLearner(SarsaLearner learner) {
+ this.learner = learner;
+ }
+
+ public SarsaAgent(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
+ learner = new SarsaLearner(stateCount, actionCount, alpha, gamma, initialQ);
+ }
+
+ public SarsaAgent(int stateCount, int actionCount) {
+ learner = new SarsaLearner(stateCount, actionCount);
+ }
+
+ public SarsaAgent(SarsaLearner learner) {
+ this.learner = learner;
+ }
+
+ public SarsaAgent() {
+
+ }
+
+ public void enableEligibilityTrace(double lambda) {
+ SarsaLambdaLearner acll = new SarsaLambdaLearner(learner);
+ acll.setLambda(lambda);
+ learner = acll;
+ }
+
+ public SarsaAgent makeCopy() {
+ SarsaAgent clone = new SarsaAgent();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(SarsaAgent rhs) {
+ learner.copy(rhs.learner);
+ currentAction = rhs.currentAction;
+ currentState = rhs.currentState;
+ prevAction = rhs.prevAction;
+ prevState = rhs.prevState;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof SarsaAgent) {
+ SarsaAgent rhs = (SarsaAgent) obj;
+ return prevAction == rhs.prevAction
+ && prevState == rhs.prevState
+ && currentAction == rhs.currentAction
+ && currentState == rhs.currentState
+ && learner.equals(rhs.learner);
+ }
+ return false;
+ }
}
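
SarsaAgent keeps the (state, action) bookkeeping that SARSA needs between steps: selectAction() caches the chosen action and update(...) advances currentState and currentAction after delegating to the learner. A minimal sketch of the loop, with a hypothetical random environment and the hypothetical class name SarsaAgentSketch; only the SarsaAgent API visible in this hunk is used:

    import com.github.chen0040.rl.learning.sarsa.SarsaAgent;
    import com.github.chen0040.rl.utils.IndexValue;
    import java.util.Random;

    public class SarsaAgentSketch {
        public static void main(String[] args) {
            int stateCount = 100;
            int actionCount = 10;
            SarsaAgent agent = new SarsaAgent(stateCount, actionCount);

            Random random = new Random();
            agent.start(random.nextInt(stateCount));
            for (int time = 0; time < 1000; ++time) {
                IndexValue action = agent.selectAction();
                int actionId = action.getIndex();

                // hypothetical environment response
                int newState = random.nextInt(stateCount);
                double reward = random.nextDouble();

                // on-policy SARSA update; the agent also pre-selects the next action internally
                agent.update(actionId, newState, reward);
            }
        }
    }
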
diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java
index e51543e..0a94fe4 100644
--- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLambdaLearner.java
@@ -1,130 +1,127 @@
package com.github.chen0040.rl.learning.sarsa;
-
import com.github.chen0040.rl.models.EligibilityTraceUpdateMode;
import com.github.chen0040.rl.utils.Matrix;
-
/**
* Created by xschen on 9/28/2015 0028.
*/
public class SarsaLambdaLearner extends SarsaLearner {
- private double lambda = 0.9;
- private Matrix e;
- private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
-
- public EligibilityTraceUpdateMode getTraceUpdateMode() {
- return traceUpdateMode;
- }
-
- public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
- this.traceUpdateMode = traceUpdateMode;
- }
-
- public double getLambda(){
- return lambda;
- }
-
- public void setLambda(double lambda){
- this.lambda = lambda;
- }
-
- @Override
- public Object clone(){
- SarsaLambdaLearner clone = new SarsaLambdaLearner();
- clone.copy(this);
- return clone;
- }
-
- @Override
- public void copy(SarsaLearner rhs){
- super.copy(rhs);
-
- SarsaLambdaLearner rhs2 = (SarsaLambdaLearner)rhs;
- lambda = rhs2.lambda;
- e = rhs2.e.makeCopy();
- traceUpdateMode = rhs2.traceUpdateMode;
- }
-
- @Override
- public boolean equals(Object obj){
- if(!super.equals(obj)){
- return false;
- }
-
- if(obj instanceof SarsaLambdaLearner){
- SarsaLambdaLearner rhs = (SarsaLambdaLearner)obj;
- return rhs.lambda == lambda && e.equals(rhs.e) && traceUpdateMode == rhs.traceUpdateMode;
- }
-
- return false;
- }
-
- public SarsaLambdaLearner(){
- super();
- }
-
- public SarsaLambdaLearner(int stateCount, int actionCount){
- super(stateCount, actionCount);
- e = new Matrix(stateCount, actionCount);
- }
-
- public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ){
- super(stateCount, actionCount, alpha, gamma, initialQ);
- e = new Matrix(stateCount, actionCount);
- }
-
- public SarsaLambdaLearner(SarsaLearner learner){
- copy(learner);
- e = new Matrix(model.getStateCount(), model.getActionCount());
- }
-
- public Matrix getEligibility()
- {
- return e;
- }
-
- public void setEligibility(Matrix e){
- this.e = e;
- }
-
- @Override
- public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward)
- {
- // old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(currentStateId, currentActionId);
-
- // learning_rate;
- double alpha = model.getAlpha(currentStateId, currentActionId);
-
- // discount_rate;
- double gamma = model.getGamma();
-
- // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double nextQ = model.getQ(nextStateId, nextActionId);
-
- double td_error = immediateReward + gamma * nextQ - oldQ;
-
- int stateCount = model.getStateCount();
- int actionCount = model.getActionCount();
-
- e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
-
- for(int stateId = 0; stateId < stateCount; ++stateId){
- for(int actionId = 0; actionId < actionCount; ++actionId){
- oldQ = model.getQ(stateId, actionId);
-
- double newQ = oldQ + alpha * td_error * e.get(stateId, actionId);
-
- model.setQ(stateId, actionId, newQ);
-
- if (actionId != currentActionId) {
- e.set(currentStateId, actionId, 0);
- } else {
- e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
- }
- }
- }
- }
+ private double lambda = 0.9;
+ private Matrix e;
+ private EligibilityTraceUpdateMode traceUpdateMode = EligibilityTraceUpdateMode.ReplaceTrace;
+
+ public EligibilityTraceUpdateMode getTraceUpdateMode() {
+ return traceUpdateMode;
+ }
+
+ public void setTraceUpdateMode(EligibilityTraceUpdateMode traceUpdateMode) {
+ this.traceUpdateMode = traceUpdateMode;
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public void setLambda(double lambda) {
+ this.lambda = lambda;
+ }
+
+ @Override
+ public Object clone() {
+ SarsaLambdaLearner clone = new SarsaLambdaLearner();
+ clone.copy(this);
+ return clone;
+ }
+
+ @Override
+ public void copy(SarsaLearner rhs) {
+ super.copy(rhs);
+
+ SarsaLambdaLearner rhs2 = (SarsaLambdaLearner) rhs;
+        // guard the downcast so that copying from a plain SarsaLearner does not throw ClassCastException
+        if (rhs instanceof SarsaLambdaLearner) {
+            SarsaLambdaLearner rhs2 = (SarsaLambdaLearner) rhs;
+            lambda = rhs2.lambda;
+            e = rhs2.e.makeCopy();
+            traceUpdateMode = rhs2.traceUpdateMode;
+        }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!super.equals(obj)) {
+ return false;
+ }
+
+ if (obj instanceof SarsaLambdaLearner) {
+ SarsaLambdaLearner rhs = (SarsaLambdaLearner) obj;
+            return rhs.lambda == lambda && e.equals(rhs.e)
+                    && traceUpdateMode == rhs.traceUpdateMode;
+ }
+
+ return false;
+ }
+
+ public SarsaLambdaLearner() {
+ super();
+ }
+
+ public SarsaLambdaLearner(int stateCount, int actionCount) {
+ super(stateCount, actionCount);
+ e = new Matrix(stateCount, actionCount);
+ }
+
+ public SarsaLambdaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
+ super(stateCount, actionCount, alpha, gamma, initialQ);
+ e = new Matrix(stateCount, actionCount);
+ }
+
+ public SarsaLambdaLearner(SarsaLearner learner) {
+ copy(learner);
+ e = new Matrix(model.getStateCount(), model.getActionCount());
+ }
+
+ public Matrix getEligibility() {
+ return e;
+ }
+
+ public void setEligibility(Matrix e) {
+ this.e = e;
+ }
+
+ @Override
+ public void update(int currentStateId, int currentActionId, int nextStateId, int nextActionId, double immediateReward) {
+ // old_value is $Q_t(s_t, a_t)$
+ double oldQ = model.getQ(currentStateId, currentActionId);
+
+ // learning_rate;
+ double alpha = model.getAlpha(currentStateId, currentActionId);
+
+ // discount_rate;
+ double gamma = model.getGamma();
+
+        // on-policy estimate of future value is $Q_t(s_{t+1}, a_{t+1})$, the next action actually selected
+ double nextQ = model.getQ(nextStateId, nextActionId);
+
+ double td_error = immediateReward + gamma * nextQ - oldQ;
+
+ int stateCount = model.getStateCount();
+ int actionCount = model.getActionCount();
+
+ e.set(currentStateId, currentActionId, e.get(currentStateId, currentActionId) + 1);
+
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ for (int actionId = 0; actionId < actionCount; ++actionId) {
+ oldQ = model.getQ(stateId, actionId);
+
+ double newQ = oldQ + alpha * td_error * e.get(stateId, actionId);
+
+ model.setQ(stateId, actionId, newQ);
+
+ if (actionId != currentActionId) {
+ e.set(currentStateId, actionId, 0);
+ } else {
+ e.set(stateId, actionId, e.get(stateId, actionId) * gamma * lambda);
+ }
+ }
+ }
+ }
}
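
The overridden update(...) above is the eligibility-trace form of SARSA. With the on-policy TD error $\delta_t = r_{t+1} + \gamma Q_t(s_{t+1}, a_{t+1}) - Q_t(s_t, a_t)$, each call performs:

    e(s_t, a_t) \leftarrow e(s_t, a_t) + 1
    Q_{t+1}(s, a) = Q_t(s, a) + \alpha \, \delta_t \, e(s, a) \quad \text{and} \quad e(s, a) \leftarrow \gamma \lambda \, e(s, a) \quad \text{for all } (s, a)

The decay is applied only where the action matches the one just taken; other traces at the current state are cleared inside the loop. The traceUpdateMode field suggests a replace-trace variant is also intended, although the update here always accumulates.
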
diff --git a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java
index 7fef780..fd04dcb 100644
--- a/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java
+++ b/src/main/java/com/github/chen0040/rl/learning/sarsa/SarsaLearner.java
@@ -1,160 +1,157 @@
package com.github.chen0040.rl.learning.sarsa;
-
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.serializer.SerializerFeature;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory;
import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy;
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
+import com.google.gson.Gson;
import java.io.Serializable;
import java.util.Random;
import java.util.Set;
-
/**
- * Created by xschen on 9/27/2015 0027.
- * Implement temporal-difference learning Q-Learning, which is an off-policy TD control algorithm
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
+ * Created by xschen on 9/27/2015 0027. Implements SARSA, an on-policy temporal-difference (TD)
+ * control algorithm. Q is known as the quality of a state-action combination; note that it is
+ * different from the utility of a state.
*/
-public class SarsaLearner implements Serializable,Cloneable {
- protected QModel model;
- private ActionSelectionStrategy actionSelectionStrategy;
-
- public String toJson() {
- return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
- }
-
- public static SarsaLearner fromJson(String json){
- return JSON.parseObject(json, SarsaLearner.class);
- }
-
- public SarsaLearner makeCopy(){
- SarsaLearner clone = new SarsaLearner();
- clone.copy(this);
- return clone;
- }
+public class SarsaLearner implements Serializable, Cloneable {
+ protected QModel model;
+ private ActionSelectionStrategy actionSelectionStrategy;
- public void copy(SarsaLearner rhs){
- model = rhs.model.makeCopy();
- actionSelectionStrategy = (ActionSelectionStrategy)((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy).clone();
- }
+ public String toJson() {
+ return new Gson().toJson(this);
+// return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
+ }
- @Override
- public boolean equals(Object obj){
- if(obj !=null && obj instanceof SarsaLearner){
- SarsaLearner rhs = (SarsaLearner)obj;
- if(!model.equals(rhs.model)) return false;
- return actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
- }
- return false;
- }
+ public static SarsaLearner fromJson(String json) {
+ return new Gson().fromJson(json, SarsaLearner.class);
+// return JSON.parseObject(json, SarsaLearner.class);
+ }
- public QModel getModel() {
- return model;
- }
+ public SarsaLearner makeCopy() {
+ SarsaLearner clone = new SarsaLearner();
+ clone.copy(this);
+ return clone;
+ }
- public void setModel(QModel model) {
- this.model = model;
- }
+ public void copy(SarsaLearner rhs) {
+ model = rhs.model.makeCopy();
+ actionSelectionStrategy = (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) rhs.actionSelectionStrategy)
+ .clone();
+ }
- public String getActionSelection() {
- return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
- }
+ @Override
+ public boolean equals(Object obj) {
+ if (obj != null && obj instanceof SarsaLearner) {
+ SarsaLearner rhs = (SarsaLearner) obj;
+ if (!model.equals(rhs.model)) return false;
+ return actionSelectionStrategy.equals(rhs.actionSelectionStrategy);
+ }
+ return false;
+ }
- public void setActionSelection(String conf) {
- this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
- }
+ public QModel getModel() {
+ return model;
+ }
- public SarsaLearner(){
+ public void setModel(QModel model) {
+ this.model = model;
+ }
- }
+ public String getActionSelection() {
+ return ActionSelectionStrategyFactory.serialize(actionSelectionStrategy);
+ }
- public SarsaLearner(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1, 0.7, 0.1);
- }
+ public void setActionSelection(String conf) {
+ this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(conf);
+ }
- public SarsaLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy){
- this.model = model;
- this.actionSelectionStrategy = actionSelectionStrategy;
- }
+ public SarsaLearner() {
- public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ)
- {
- model = new QModel(stateCount, actionCount, initialQ);
- model.setAlpha(alpha);
- model.setGamma(gamma);
- actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
- }
+ }
- public static void main(String[] args){
- int stateCount = 100;
- int actionCount = 10;
+ public SarsaLearner(int stateCount, int actionCount) {
+ this(stateCount, actionCount, 0.1, 0.7, 0.1);
+ }
- SarsaLearner learner = new SarsaLearner(stateCount, actionCount);
+ public SarsaLearner(QModel model, ActionSelectionStrategy actionSelectionStrategy) {
+ this.model = model;
+ this.actionSelectionStrategy = actionSelectionStrategy;
+ }
- double reward = 0; // reward gained by transiting from prevState to currentState
- Random random = new Random();
- int currentStateId = random.nextInt(stateCount);
- int currentActionId = learner.selectAction(currentStateId).getIndex();
+ public SarsaLearner(int stateCount, int actionCount, double alpha, double gamma, double initialQ) {
+ model = new QModel(stateCount, actionCount, initialQ);
+ model.setAlpha(alpha);
+ model.setGamma(gamma);
+ actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
+ }
- for(int time=0; time < 1000; ++time){
+ public static void main(String[] args) {
+ int stateCount = 100;
+ int actionCount = 10;
- System.out.println("Controller does action-"+currentActionId);
+ SarsaLearner learner = new SarsaLearner(stateCount, actionCount);
- int newStateId = random.nextInt(actionCount);
- reward = random.nextDouble();
+ double reward = 0; // reward gained by transiting from prevState to currentState
+ Random random = new Random();
+ int currentStateId = random.nextInt(stateCount);
+ int currentActionId = learner.selectAction(currentStateId).getIndex();
- System.out.println("Now the new state is " + newStateId);
- System.out.println("Controller receives Reward = " + reward);
+ for (int time = 0; time < 1000; ++time) {
- int futureActionId = learner.selectAction(newStateId).getIndex();
+ System.out.println("Controller does action-" + currentActionId);
- System.out.println("Controller is expected to do action-"+futureActionId);
+ int newStateId = random.nextInt(actionCount);
+ reward = random.nextDouble();
- learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward);
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Controller receives Reward = " + reward);
- currentStateId = newStateId;
- currentActionId = futureActionId;
- }
- }
+ int futureActionId = learner.selectAction(newStateId).getIndex();
+ System.out.println("Controller is expected to do action-" + futureActionId);
- public IndexValue selectAction(int stateId, Set actionsAtState){
- return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
- }
+ learner.update(currentStateId, currentActionId, newStateId, futureActionId, reward);
- public IndexValue selectAction(int stateId){
- return selectAction(stateId, null);
- }
+ currentStateId = newStateId;
+ currentActionId = futureActionId;
+ }
+ }
- public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward)
- {
- // old_value is $Q_t(s_t, a_t)$
- double oldQ = model.getQ(stateId, actionId);
+ public IndexValue selectAction(int stateId, Set actionsAtState) {
+ return actionSelectionStrategy.selectAction(stateId, model, actionsAtState);
+ }
- // learning_rate;
- double alpha = model.getAlpha(stateId, actionId);
+ public IndexValue selectAction(int stateId) {
+ return selectAction(stateId, null);
+ }
- // discount_rate;
- double gamma = model.getGamma();
+ public void update(int stateId, int actionId, int nextStateId, int nextActionId, double immediateReward) {
+ // old_value is $Q_t(s_t, a_t)$
+ double oldQ = model.getQ(stateId, actionId);
- // estimate_of_optimal_future_value is $max_a Q_t(s_{t+1}, a)$
- double nextQ = model.getQ(nextStateId, nextActionId);
+ // learning_rate;
+ double alpha = model.getAlpha(stateId, actionId);
- // learned_value = immediate_reward + gamma * estimate_of_optimal_future_value
- // old_value = oldQ
- // temporal_difference = learned_value - old_value
- // new_value = old_value + learning_rate * temporal_difference
- double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ);
+ // discount_rate;
+ double gamma = model.getGamma();
- // new_value is $Q_{t+1}(s_t, a_t)$
- model.setQ(stateId, actionId, newQ);
- }
+        // estimate_of_future_value is $Q_t(s_{t+1}, a_{t+1})$ (SARSA is on-policy: it uses the next action actually selected, not the max)
+ double nextQ = model.getQ(nextStateId, nextActionId);
+        // learned_value = immediate_reward + gamma * estimate_of_future_value
+ // old_value = oldQ
+ // temporal_difference = learned_value - old_value
+ // new_value = old_value + learning_rate * temporal_difference
+ double newQ = oldQ + alpha * (immediateReward + gamma * nextQ - oldQ);
+ // new_value is $Q_{t+1}(s_t, a_t)$
+ model.setQ(stateId, actionId, newQ);
+ }
}
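
The update(...) above is the plain SARSA(0) rule, the on-policy counterpart of the Q-learning update: the target uses the Q value of the next action actually selected (as in the main(...) example above) rather than the maximum over actions:

    Q_{t+1}(s_t, a_t) = Q_t(s_t, a_t) + \alpha \left[ r_{t+1} + \gamma \, Q_t(s_{t+1}, a_{t+1}) - Q_t(s_t, a_t) \right]
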
diff --git a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java
index e25380f..dd6dc71 100644
--- a/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java
+++ b/src/main/java/com/github/chen0040/rl/models/EligibilityTraceUpdateMode.java
@@ -4,6 +4,6 @@
* Created by xschen on 9/28/2015 0028.
*/
public enum EligibilityTraceUpdateMode {
- ReplaceTrace,
- AccumulateTrace
+ ReplaceTrace,
+ AccumulateTrace
}
diff --git a/src/main/java/com/github/chen0040/rl/models/QModel.java b/src/main/java/com/github/chen0040/rl/models/QModel.java
index 2d314a1..e4a61ce 100644
--- a/src/main/java/com/github/chen0040/rl/models/QModel.java
+++ b/src/main/java/com/github/chen0040/rl/models/QModel.java
@@ -1,158 +1,166 @@
package com.github.chen0040.rl.models;
-
import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.utils.Matrix;
import com.github.chen0040.rl.utils.Vec;
-import lombok.Getter;
-import lombok.Setter;
-import java.util.*;
+//import lombok.Getter;
+//import lombok.Setter;
+import java.util.*;
/**
- * @author xschen
- * 9/27/2015 0027.
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
+ * @author xschen 9/27/2015 0027. Q is known as the quality of a state-action combination; note that
+ * it is different from the utility of a state.
*/
-@Getter
-@Setter
+//@Getter
+//@Setter
public class QModel {
- /**
- * Q value for (state_id, action_id) pair
- * Q is known as the quality of state-action combination, note that it is different from utility of a state
- */
- private Matrix Q;
- /**
- * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id)
- */
- private Matrix alphaMatrix;
-
- /**
- * discount factor
- */
- private double gamma = 0.7;
-
- private int stateCount;
- private int actionCount;
-
- public QModel(int stateCount, int actionCount, double initialQ){
- this.stateCount = stateCount;
- this.actionCount = actionCount;
- Q = new Matrix(stateCount,actionCount);
- alphaMatrix = new Matrix(stateCount, actionCount);
- Q.setAll(initialQ);
- alphaMatrix.setAll(0.1);
- }
-
- public QModel(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1);
- }
-
- public QModel(){
-
- }
-
- @Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof QModel){
- QModel rhs2 = (QModel)rhs;
-
-
- if(gamma != rhs2.gamma) return false;
-
-
- if(stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false;
-
- if((Q!=null && rhs2.Q==null) || (Q==null && rhs2.Q !=null)) return false;
- if((alphaMatrix !=null && rhs2.alphaMatrix ==null) || (alphaMatrix ==null && rhs2.alphaMatrix !=null)) return false;
-
- return !((Q != null && !Q.equals(rhs2.Q)) || (alphaMatrix != null && !alphaMatrix.equals(rhs2.alphaMatrix)));
-
- }
- return false;
- }
-
- public QModel makeCopy(){
- QModel clone = new QModel();
- clone.copy(this);
- return clone;
- }
-
- public void copy(QModel rhs){
- gamma = rhs.gamma;
- stateCount = rhs.stateCount;
- actionCount = rhs.actionCount;
- Q = rhs.Q==null ? null : rhs.Q.makeCopy();
- alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy();
- }
-
-
- public double getQ(int stateId, int actionId){
- return Q.get(stateId, actionId);
- }
-
-
- public void setQ(int stateId, int actionId, double Qij){
- Q.set(stateId, actionId, Qij);
- }
-
-
- public double getAlpha(int stateId, int actionId){
- return alphaMatrix.get(stateId, actionId);
- }
-
-
- public void setAlpha(double defaultAlpha) {
- this.alphaMatrix.setAll(defaultAlpha);
- }
-
-
- public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState){
- Vec rowVector = Q.rowAt(stateId);
- return rowVector.indexWithMaxValue(actionsAtState);
- }
-
- private void reset(double initialQ){
- Q.setAll(initialQ);
- }
-
-
- public IndexValue actionWithSoftMaxQAtState(int stateId,Set actionsAtState, Random random) {
- Vec rowVector = Q.rowAt(stateId);
- double sum = 0;
-
- if(actionsAtState==null){
- actionsAtState = new HashSet<>();
- for(int i=0; i < actionCount; ++i){
- actionsAtState.add(i);
- }
- }
-
- List actions = new ArrayList<>();
- for(Integer actionId : actionsAtState){
- actions.add(actionId);
- }
-
- double[] acc = new double[actions.size()];
- for(int i=0; i < actions.size(); ++i){
- sum += rowVector.get(actions.get(i));
- acc[i] = sum;
- }
-
-
- double r = random.nextDouble() * sum;
-
- IndexValue result = new IndexValue();
- for(int i=0; i < actions.size(); ++i){
- if(acc[i] >= r){
- int actionId = actions.get(i);
- result.setIndex(actionId);
- result.setValue(rowVector.get(actionId));
- break;
- }
- }
-
- return result;
- }
+ /**
+     * Q value for the (state_id, action_id) pair. Q is known as the quality of a state-action combination;
+     * note that it is different from the utility of a state.
+ */
+ private Matrix Q;
+ /**
+ * $\alpha[s, a]$ value for learning rate: alpha(state_id, action_id)
+ */
+ private Matrix alphaMatrix;
+
+ /**
+ * discount factor
+ */
+ private double gamma = 0.7;
+
+ private int stateCount;
+ private int actionCount;
+
+ public QModel(int stateCount, int actionCount, double initialQ) {
+ this.stateCount = stateCount;
+ this.actionCount = actionCount;
+ Q = new Matrix(stateCount, actionCount);
+ alphaMatrix = new Matrix(stateCount, actionCount);
+ Q.setAll(initialQ);
+ alphaMatrix.setAll(0.1);
+ }
+
+ public QModel(int stateCount, int actionCount) {
+ this(stateCount, actionCount, 0.1);
+ }
+
+ public QModel() {
+
+ }
+
+ @Override
+ public boolean equals(Object rhs) {
+ if (rhs != null && rhs instanceof QModel) {
+ QModel rhs2 = (QModel) rhs;
+
+ if (gamma != rhs2.gamma) return false;
+
+ if (stateCount != rhs2.stateCount || actionCount != rhs2.actionCount) return false;
+
+ if ((Q != null && rhs2.Q == null) || (Q == null && rhs2.Q != null)) return false;
+ if ((alphaMatrix != null && rhs2.alphaMatrix == null) || (alphaMatrix == null && rhs2.alphaMatrix != null))
+ return false;
+
+            return !((Q != null && !Q.equals(rhs2.Q))
+                    || (alphaMatrix != null && !alphaMatrix.equals(rhs2.alphaMatrix)));
+
+ }
+ return false;
+ }
+
+ public QModel makeCopy() {
+ QModel clone = new QModel();
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(QModel rhs) {
+ gamma = rhs.gamma;
+ stateCount = rhs.stateCount;
+ actionCount = rhs.actionCount;
+ Q = rhs.Q == null ? null : rhs.Q.makeCopy();
+ alphaMatrix = rhs.alphaMatrix == null ? null : rhs.alphaMatrix.makeCopy();
+ }
+
+ public double getQ(int stateId, int actionId) {
+ return Q.get(stateId, actionId);
+ }
+
+ public void setQ(int stateId, int actionId, double Qij) {
+ Q.set(stateId, actionId, Qij);
+ }
+
+ public double getAlpha(int stateId, int actionId) {
+ return alphaMatrix.get(stateId, actionId);
+ }
+
+ public void setAlpha(double defaultAlpha) {
+ this.alphaMatrix.setAll(defaultAlpha);
+ }
+
+ public IndexValue actionWithMaxQAtState(int stateId, Set actionsAtState) {
+ Vec rowVector = Q.rowAt(stateId);
+ return rowVector.indexWithMaxValue(actionsAtState);
+ }
+
+ private void reset(double initialQ) {
+ Q.setAll(initialQ);
+ }
+
+ public IndexValue actionWithSoftMaxQAtState(int stateId, Set actionsAtState, Random random) {
+ Vec rowVector = Q.rowAt(stateId);
+ double sum = 0;
+
+ if (actionsAtState == null) {
+ actionsAtState = new HashSet<>();
+ for (int i = 0; i < actionCount; ++i) {
+ actionsAtState.add(i);
+ }
+ }
+
+ List actions = new ArrayList<>(actionsAtState);
+
+ double[] acc = new double[actions.size()];
+ for (int i = 0; i < actions.size(); ++i) {
+ sum += rowVector.get(actions.get(i));
+ acc[i] = sum;
+ }
+
+ double r = random.nextDouble() * sum;
+
+ IndexValue result = new IndexValue();
+ for (int i = 0; i < actions.size(); ++i) {
+ if (acc[i] >= r) {
+ int actionId = actions.get(i);
+ result.setIndex(actionId);
+ result.setValue(rowVector.get(actionId));
+ break;
+ }
+ }
+
+ return result;
+ }
+
+ public int getActionCount() {
+ return this.actionCount;
+ }
+
+ public int getStateCount() {
+ return stateCount;
+ }
+
+ public double getGamma() {
+ return this.gamma;
+ }
+
+ public void setGamma(double gamma) {
+ this.gamma = gamma;
+ }
+
+ public Matrix getAlphaMatrix() {
+ return this.alphaMatrix;
+ }
}
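
QModel stores the Q table and the per-entry learning rates, and offers two selection helpers: actionWithMaxQAtState picks the greedy action, while actionWithSoftMaxQAtState samples an action with probability proportional to its Q value (a roulette-wheel draw over the raw Q values rather than a Boltzmann softmax, which implicitly assumes non-negative Q values). A small sketch with hypothetical numbers and the hypothetical class name QModelSketch:

    import com.github.chen0040.rl.models.QModel;
    import com.github.chen0040.rl.utils.IndexValue;
    import java.util.Random;

    public class QModelSketch {
        public static void main(String[] args) {
            QModel model = new QModel(5, 3, 0.1);   // 5 states, 3 actions, initial Q = 0.1
            model.setQ(2, 0, 1.0);
            model.setQ(2, 1, 3.0);
            model.setQ(2, 2, 6.0);

            // greedy: returns action 2, the highest Q at state 2
            IndexValue greedy = model.actionWithMaxQAtState(2, null);

            // proportional sampling: roughly 10% / 30% / 60% for actions 0 / 1 / 2 at state 2
            IndexValue sampled = model.actionWithSoftMaxQAtState(2, null, new Random());

            System.out.println("greedy=" + greedy.getIndex() + " sampled=" + sampled.getIndex());
        }
    }
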
diff --git a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java
index cff1859..e8e2eb4 100644
--- a/src/main/java/com/github/chen0040/rl/models/UtilityModel.java
+++ b/src/main/java/com/github/chen0040/rl/models/UtilityModel.java
@@ -1,91 +1,91 @@
package com.github.chen0040.rl.models;
import com.github.chen0040.rl.utils.Vec;
-import lombok.Getter;
-import lombok.Setter;
-import java.io.Serializable;
+//import lombok.Getter;
+//import lombok.Setter;
+import java.io.Serializable;
/**
- * @author xschen
- * 9/27/2015 0027.
- * Utility value of a state $U(s)$ is the expected long term reward of state $s$ given the sequence of reward and the optimal policy
- * Utility value $U(s)$ at state $s$ can be obtained by the Bellman equation
- * Bellman Equtation states that $U(s) = R(s) + \gamma * max_a \sum_{s'} T(s,a,s')U(s')$
- * where s' is the possible transitioned state given that action $a$ is applied at state $s$
- * where $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$
- * where $\sum_{s'} T(s,a,s')U(s')$ is the expected long term reward given that action $a$ is applied at state $s$
- * where $max_a \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long term reward given that the chosen optimal action $a$ is applied at state $s$
+ * @author xschen 9/27/2015 0027. The utility value of a state, $U(s)$, is the expected long-term reward
+ * of state $s$ given the sequence of rewards and the optimal policy. The utility value $U(s)$ at state
+ * $s$ can be obtained from the Bellman equation, which states that
+ * $U(s) = R(s) + \gamma \max_a \sum_{s'} T(s,a,s')U(s')$, where:
+ * $s'$ is a possible transitioned state given that action $a$ is applied at state $s$;
+ * $T(s,a,s')$ is the transition probability of $s \rightarrow s'$ given that action $a$ is applied at state $s$;
+ * $\sum_{s'} T(s,a,s')U(s')$ is the expected long-term reward given that action $a$ is applied at state $s$;
+ * $\max_a \sum_{s'} T(s,a,s')U(s')$ is the maximum expected long-term reward given that the optimal action $a$
+ * is chosen at state $s$.
*/
-@Getter
-@Setter
+//@Getter
+//@Setter
public class UtilityModel implements Serializable {
- private Vec U;
- private int stateCount;
- private int actionCount;
-
- public void setU(Vec U){
- this.U = U;
- }
-
- public Vec getU() {
- return U;
- }
-
- public double getU(int stateId){
- return U.get(stateId);
- }
-
- public int getStateCount() {
- return stateCount;
- }
-
- public int getActionCount() {
- return actionCount;
- }
-
- public UtilityModel(int stateCount, int actionCount, double initialU){
- this.stateCount = stateCount;
- this.actionCount = actionCount;
- U = new Vec(stateCount);
- U.setAll(initialU);
- }
-
- public UtilityModel(int stateCount, int actionCount){
- this(stateCount, actionCount, 0.1);
- }
-
- public UtilityModel(){
-
- }
-
- public void copy(UtilityModel rhs){
- U = rhs.U==null ? null : rhs.U.makeCopy();
- actionCount = rhs.actionCount;
- stateCount = rhs.stateCount;
- }
-
- public UtilityModel makeCopy(){
- UtilityModel clone = new UtilityModel();
- clone.copy(this);
- return clone;
- }
-
- @Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof UtilityModel){
- UtilityModel rhs2 = (UtilityModel)rhs;
- if(actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false;
-
- if((U==null && rhs2.U!=null) && (U!=null && rhs2.U ==null)) return false;
- return !(U != null && !U.equals(rhs2.U));
-
- }
- return false;
- }
-
- public void reset(double initialU){
- U.setAll(initialU);
- }
+ private Vec U;
+ private int stateCount;
+ private int actionCount;
+
+ public void setU(Vec U) {
+ this.U = U;
+ }
+
+ public Vec getU() {
+ return U;
+ }
+
+ public double getU(int stateId) {
+ return U.get(stateId);
+ }
+
+ public int getStateCount() {
+ return stateCount;
+ }
+
+ public int getActionCount() {
+ return actionCount;
+ }
+
+ public UtilityModel(int stateCount, int actionCount, double initialU) {
+ this.stateCount = stateCount;
+ this.actionCount = actionCount;
+ U = new Vec(stateCount);
+ U.setAll(initialU);
+ }
+
+ public UtilityModel(int stateCount, int actionCount) {
+ this(stateCount, actionCount, 0.1);
+ }
+
+ public UtilityModel() {
+
+ }
+
+ public void copy(UtilityModel rhs) {
+ U = rhs.U == null ? null : rhs.U.makeCopy();
+ actionCount = rhs.actionCount;
+ stateCount = rhs.stateCount;
+ }
+
+ public UtilityModel makeCopy() {
+ UtilityModel clone = new UtilityModel();
+ clone.copy(this);
+ return clone;
+ }
+
+ @Override
+ public boolean equals(Object rhs) {
+ if (rhs != null && rhs instanceof UtilityModel) {
+ UtilityModel rhs2 = (UtilityModel) rhs;
+ if (actionCount != rhs2.actionCount || stateCount != rhs2.stateCount) return false;
+
+ if ((U == null && rhs2.U != null) && (U != null && rhs2.U == null)) return false;
+ return !(U != null && !U.equals(rhs2.U));
+
+ }
+ return false;
+ }
+
+ public void reset(double initialU) {
+ U.setAll(initialU);
+ }
}
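
The javadoc above describes the value-iteration view of this model. In display form, the Bellman equation it references is

    U(s) = R(s) + \gamma \max_a \sum_{s'} T(s, a, s') \, U(s')

where $T(s,a,s')$ is the transition probability of moving from $s$ to $s'$ under action $a$, and the outer maximization picks the optimal action at $s$.
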
diff --git a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java
index e840bc1..145b1e4 100644
--- a/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java
+++ b/src/main/java/com/github/chen0040/rl/utils/DoubleUtils.java
@@ -4,11 +4,11 @@
* Created by xschen on 10/11/2015 0011.
*/
public class DoubleUtils {
- public static boolean equals(double a1, double a2){
- return Math.abs(a1-a2) < 1e-10;
- }
+ public static boolean equals(double a1, double a2) {
+ return Math.abs(a1 - a2) < 1e-10;
+ }
- public static boolean isZero(double a){
- return a < 1e-20;
- }
+    public static boolean isZero(double a) {
+        // compare the magnitude so that large negative values are not reported as zero
+        return Math.abs(a) < 1e-20;
+    }
}
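
DoubleUtils.equals compares doubles against a fixed absolute tolerance of 1e-10 instead of using ==, which avoids false negatives caused by floating-point rounding. A tiny illustration (the class name DoubleUtilsSketch is hypothetical):

    import com.github.chen0040.rl.utils.DoubleUtils;

    public class DoubleUtilsSketch {
        public static void main(String[] args) {
            // 0.1 + 0.2 evaluates to 0.30000000000000004 in IEEE-754 doubles
            System.out.println(0.1 + 0.2 == 0.3);                     // false
            System.out.println(DoubleUtils.equals(0.1 + 0.2, 0.3));   // true: |difference| < 1e-10
        }
    }
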
diff --git a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java
index 66c2bf6..264dcc3 100644
--- a/src/main/java/com/github/chen0040/rl/utils/IndexValue.java
+++ b/src/main/java/com/github/chen0040/rl/utils/IndexValue.java
@@ -1,46 +1,59 @@
package com.github.chen0040.rl.utils;
-
-import lombok.Getter;
-import lombok.Setter;
-
+//import lombok.Getter;
+//import lombok.Setter;
/**
* Created by xschen on 6/5/2017.
*/
-@Getter
-@Setter
+//@Getter
+//@Setter
public class IndexValue {
- private int index;
- private double value;
-
- public IndexValue(){
-
- }
-
- public IndexValue(int index, double value){
- this.index = index;
- this.value = value;
- }
-
- public IndexValue makeCopy(){
- IndexValue clone = new IndexValue();
- clone.setValue(value);
- clone.setIndex(index);
- return clone;
- }
-
- @Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof IndexValue){
- IndexValue rhs2 = (IndexValue)rhs;
- return index == rhs2.index && value == rhs2.value;
- }
- return false;
- }
-
- public boolean isValid(){
- return index != -1;
- }
-
+ private int index;
+ private double value;
+
+ public IndexValue() {
+
+ }
+
+ public IndexValue(int index, double value) {
+ this.index = index;
+ this.value = value;
+ }
+
+ public IndexValue makeCopy() {
+ IndexValue clone = new IndexValue();
+ clone.setValue(value);
+ clone.setIndex(index);
+ return clone;
+ }
+
+ public void setIndex(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public boolean equals(Object rhs) {
+ if (rhs != null && rhs instanceof IndexValue) {
+ IndexValue rhs2 = (IndexValue) rhs;
+ return index == rhs2.index && value == rhs2.value;
+ }
+ return false;
+ }
+
+ public boolean isValid() {
+ return index != -1;
+ }
+
+ public void setValue(double v) {
+ this.value = v;
+ }
+
+ public int getIndex() {
+ return this.index;
+ }
+
+ public double getValue() {
+ return this.value;
+ }
}
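IndexValue is the small (index, value) carrier that Vec.indexWithMaxValue returns, so the hand-written accessors above have to keep working once the Lombok annotations are commented out. A short sketch of the contract, using only methods visible in this file (IndexValueSketch is an illustrative name):

    import com.github.chen0040.rl.utils.IndexValue;

    public class IndexValueSketch {
        public static void main(String[] args) {
            IndexValue best = new IndexValue(3, 0.75);  // e.g. "action 3 has the highest estimated value"
            System.out.println(best.getIndex());        // 3
            System.out.println(best.getValue());        // 0.75
            System.out.println(best.isValid());         // true: index != -1

            IndexValue none = new IndexValue(-1, Double.NEGATIVE_INFINITY);
            System.out.println(none.isValid());         // false: -1 marks "nothing selected"
        }
    }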
diff --git a/src/main/java/com/github/chen0040/rl/utils/Matrix.java b/src/main/java/com/github/chen0040/rl/utils/Matrix.java
index cd42bd5..30793b7 100644
--- a/src/main/java/com/github/chen0040/rl/utils/Matrix.java
+++ b/src/main/java/com/github/chen0040/rl/utils/Matrix.java
@@ -1,8 +1,9 @@
package com.github.chen0040.rl.utils;
-import com.alibaba.fastjson.annotation.JSONField;
-import lombok.Getter;
-import lombok.Setter;
+//import com.alibaba.fastjson.annotation.JSONField;
+
+//import lombok.Getter;
+//import lombok.Setter;
import java.io.Serializable;
import java.util.ArrayList;
@@ -10,234 +11,228 @@
import java.util.List;
import java.util.Map;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
-@Getter
-@Setter
+//@Getter
+//@Setter
public class Matrix implements Serializable {
- private Map<Integer, Vec> rows = new HashMap<>();
- private int rowCount;
- private int columnCount;
- private double defaultValue;
-
- public Matrix(){
-
- }
-
- public Matrix(double[][] A){
- for(int i = 0; i < A.length; ++i){
- double[] B = A[i];
- for(int j=0; j < B.length; ++j){
- set(i, j, B[j]);
- }
- }
- }
-
- public void setRow(int rowIndex, Vec rowVector){
- rowVector.setId(rowIndex);
- rows.put(rowIndex, rowVector);
- }
-
-
- public static Matrix identity(int dimension){
- Matrix m = new Matrix(dimension, dimension);
- for(int i=0; i < m.getRowCount(); ++i){
- m.set(i, i, 1);
- }
- return m;
- }
-
- @Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof Matrix){
- Matrix rhs2 = (Matrix)rhs;
- if(rowCount != rhs2.rowCount || columnCount != rhs2.columnCount){
- return false;
- }
-
- if(defaultValue == rhs2.defaultValue) {
- for (Integer index : rows.keySet()) {
- if (!rhs2.rows.containsKey(index)) return false;
- if (!rows.get(index).equals(rhs2.rows.get(index))) {
- System.out.println("failed!");
- return false;
- }
- }
-
- for (Integer index : rhs2.rows.keySet()) {
- if (!rows.containsKey(index)) return false;
- if (!rhs2.rows.get(index).equals(rows.get(index))) {
- System.out.println("failed! 22");
- return false;
- }
- }
- } else {
-
- for(int i=0; i < rowCount; ++i) {
- for(int j=0; j < columnCount; ++j) {
- if(this.get(i, j) != rhs2.get(i, j)){
- return false;
- }
- }
- }
- }
-
- return true;
- }
-
- return false;
- }
-
- public Matrix makeCopy(){
- Matrix clone = new Matrix(rowCount, columnCount);
- clone.copy(this);
- return clone;
- }
-
- public void copy(Matrix rhs){
- rowCount = rhs.rowCount;
- columnCount = rhs.columnCount;
- defaultValue = rhs.defaultValue;
-
- rows.clear();
-
- for(Map.Entry<Integer, Vec> entry : rhs.rows.entrySet()){
- rows.put(entry.getKey(), entry.getValue().makeCopy());
- }
- }
-
-
-
- public void set(int rowIndex, int columnIndex, double value){
- Vec row = rowAt(rowIndex);
- row.set(columnIndex, value);
- if(rowIndex >= rowCount) { rowCount = rowIndex+1; }
- if(columnIndex >= columnCount) { columnCount = columnIndex + 1; }
- }
-
-
-
- public Matrix(int rowCount, int columnCount){
- this.rowCount = rowCount;
- this.columnCount = columnCount;
- this.defaultValue = 0;
- }
-
- public Vec rowAt(int rowIndex){
- Vec row = rows.get(rowIndex);
- if(row == null){
- row = new Vec(columnCount);
- row.setAll(defaultValue);
- row.setId(rowIndex);
- rows.put(rowIndex, row);
- }
- return row;
- }
-
- public void setAll(double value){
- defaultValue = value;
- for(Vec row : rows.values()){
- row.setAll(value);
- }
- }
-
- public double get(int rowIndex, int columnIndex) {
- Vec row= rowAt(rowIndex);
- return row.get(columnIndex);
- }
-
- public List<Vec> columnVectors()
- {
- Matrix A = this;
- int n = A.getColumnCount();
- int rowCount = A.getRowCount();
-
- List<Vec> Acols = new ArrayList<Vec>();
-
- for (int c = 0; c < n; ++c)
- {
- Vec Acol = new Vec(rowCount);
- Acol.setAll(defaultValue);
- Acol.setId(c);
-
- for (int r = 0; r < rowCount; ++r)
- {
- Acol.set(r, A.get(r, c));
- }
- Acols.add(Acol);
- }
- return Acols;
- }
-
- public Matrix multiply(Matrix rhs)
- {
- if(this.getColumnCount() != rhs.getRowCount()){
- System.err.println("A.columnCount must be equal to B.rowCount in multiplication");
- return null;
- }
-
- Vec row1;
- Vec col2;
-
- Matrix result = new Matrix(getRowCount(), rhs.getColumnCount());
- result.setAll(defaultValue);
-
- List<Vec> rhsColumns = rhs.columnVectors();
-
- for (Map.Entry<Integer, Vec> entry : rows.entrySet())
- {
- int r1 = entry.getKey();
- row1 = entry.getValue();
- for (int c2 = 0; c2 < rhsColumns.size(); ++c2)
- {
- col2 = rhsColumns.get(c2);
- result.set(r1, c2, row1.multiply(col2));
- }
- }
-
- return result;
- }
-
- @JSONField(serialize = false)
- public boolean isSymmetric(){
- if (getRowCount() != getColumnCount()) return false;
-
- for (Map.Entry<Integer, Vec> rowEntry : rows.entrySet())
- {
- int row = rowEntry.getKey();
- Vec rowVec = rowEntry.getValue();
-
- for (Integer col : rowVec.getData().keySet())
- {
- if (row == col.intValue()) continue;
- if(DoubleUtils.equals(rowVec.get(col), this.get(col, row))){
- return false;
- }
- }
- }
-
- return true;
- }
-
- public Vec multiply(Vec rhs)
- {
- if(this.getColumnCount() != rhs.getDimension()){
- System.err.println("columnCount must be equal to the size of the vector for multiplication");
- }
-
- Vec row1;
- Vec result = new Vec(getRowCount());
- for (Map.Entry<Integer, Vec> entry : rows.entrySet())
- {
- row1 = entry.getValue();
- result.set(entry.getKey(), row1.multiply(rhs));
- }
- return result;
- }
-
-
-
+ private Map<Integer, Vec> rows = new HashMap<>();
+ private int rowCount;
+ private int columnCount;
+ private double defaultValue;
+
+ public Matrix() {
+
+ }
+
+ public Matrix(double[][] A) {
+ for (int i = 0; i < A.length; ++i) {
+ double[] B = A[i];
+ for (int j = 0; j < B.length; ++j) {
+ set(i, j, B[j]);
+ }
+ }
+ }
+
+ public void setRow(int rowIndex, Vec rowVector) {
+ rowVector.setId(rowIndex);
+ rows.put(rowIndex, rowVector);
+ }
+
+ public static Matrix identity(int dimension) {
+ Matrix m = new Matrix(dimension, dimension);
+ for (int i = 0; i < m.getRowCount(); ++i) {
+ m.set(i, i, 1);
+ }
+ return m;
+ }
+
+ public int getRowCount() {
+ return this.rowCount;
+ }
+
+ @Override
+ public boolean equals(Object rhs) {
+ if (rhs != null && rhs instanceof Matrix) {
+ Matrix rhs2 = (Matrix) rhs;
+ if (rowCount != rhs2.rowCount || columnCount != rhs2.columnCount) {
+ return false;
+ }
+
+ if (defaultValue == rhs2.defaultValue) {
+ for (Integer index : rows.keySet()) {
+ if (!rhs2.rows.containsKey(index)) return false;
+ if (!rows.get(index).equals(rhs2.rows.get(index))) {
+ System.out.println("failed!");
+ return false;
+ }
+ }
+
+ for (Integer index : rhs2.rows.keySet()) {
+ if (!rows.containsKey(index)) return false;
+ if (!rhs2.rows.get(index).equals(rows.get(index))) {
+ System.out.println("failed! 22");
+ return false;
+ }
+ }
+ } else {
+
+ for (int i = 0; i < rowCount; ++i) {
+ for (int j = 0; j < columnCount; ++j) {
+ if (this.get(i, j) != rhs2.get(i, j)) {
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
+ public Matrix makeCopy() {
+ Matrix clone = new Matrix(rowCount, columnCount);
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(Matrix rhs) {
+ rowCount = rhs.rowCount;
+ columnCount = rhs.columnCount;
+ defaultValue = rhs.defaultValue;
+
+ rows.clear();
+
+ for (Map.Entry<Integer, Vec> entry : rhs.rows.entrySet()) {
+ rows.put(entry.getKey(), entry.getValue().makeCopy());
+ }
+ }
+
+ public void set(int rowIndex, int columnIndex, double value) {
+ Vec row = rowAt(rowIndex);
+ row.set(columnIndex, value);
+ if (rowIndex >= rowCount) {
+ rowCount = rowIndex + 1;
+ }
+ if (columnIndex >= columnCount) {
+ columnCount = columnIndex + 1;
+ }
+ }
+
+ public Matrix(int rowCount, int columnCount) {
+ this.rowCount = rowCount;
+ this.columnCount = columnCount;
+ this.defaultValue = 0;
+ }
+
+ public Vec rowAt(int rowIndex) {
+ Vec row = rows.get(rowIndex);
+ if (row == null) {
+ row = new Vec(columnCount);
+ row.setAll(defaultValue);
+ row.setId(rowIndex);
+ rows.put(rowIndex, row);
+ }
+ return row;
+ }
+
+ public void setAll(double value) {
+ defaultValue = value;
+ for (Vec row : rows.values()) {
+ row.setAll(value);
+ }
+ }
+
+ public double get(int rowIndex, int columnIndex) {
+ Vec row = rowAt(rowIndex);
+ return row.get(columnIndex);
+ }
+
+ public List<Vec> columnVectors() {
+ Matrix A = this;
+ int n = A.getColumnCount();
+ int rowCount = A.getRowCount();
+
+ List<Vec> Acols = new ArrayList<>();
+
+ for (int c = 0; c < n; ++c) {
+ Vec Acol = new Vec(rowCount);
+ Acol.setAll(defaultValue);
+ Acol.setId(c);
+
+ for (int r = 0; r < rowCount; ++r) {
+ Acol.set(r, A.get(r, c));
+ }
+ Acols.add(Acol);
+ }
+ return Acols;
+ }
+
+ public int getColumnCount() {
+ return this.columnCount;
+ }
+
+ public Matrix multiply(Matrix rhs) {
+ if (this.getColumnCount() != rhs.getRowCount()) {
+ System.err.println("A.columnCount must be equal to B.rowCount in multiplication");
+ return null;
+ }
+
+ Vec row1;
+ Vec col2;
+
+ Matrix result = new Matrix(getRowCount(), rhs.getColumnCount());
+ result.setAll(defaultValue);
+
+ List<Vec> rhsColumns = rhs.columnVectors();
+
+ for (Map.Entry<Integer, Vec> entry : rows.entrySet()) {
+ int r1 = entry.getKey();
+ row1 = entry.getValue();
+ for (int c2 = 0; c2 < rhsColumns.size(); ++c2) {
+ col2 = rhsColumns.get(c2);
+ result.set(r1, c2, row1.multiply(col2));
+ }
+ }
+
+ return result;
+ }
+
+// @JSONField(serialize = false)
+ public boolean isSymmetric() {
+ if (getRowCount() != getColumnCount()) return false;
+
+ for (Map.Entry<Integer, Vec> rowEntry : rows.entrySet()) {
+ int row = rowEntry.getKey();
+ Vec rowVec = rowEntry.getValue();
+
+ for (Integer col : rowVec.getData().keySet()) {
+ if (row == col) continue;
+ if (!DoubleUtils.equals(rowVec.get(col), this.get(col, row))) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ public Vec multiply(Vec rhs) {
+ if (this.getColumnCount() != rhs.getDimension()) {
+ System.err
+ .println("columnCount must be equal to the size of the vector for multiplication");
+ }
+
+ Vec row1;
+ Vec result = new Vec(getRowCount());
+ for (Map.Entry<Integer, Vec> entry : rows.entrySet()) {
+ row1 = entry.getValue();
+ result.set(entry.getKey(), row1.multiply(rhs));
+ }
+ return result;
+ }
}
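Matrix stores only the rows that have been written to (a map of sparse Vec rows sharing a default value), which is why set() grows rowCount and columnCount lazily. The following minimal sketch exercises only methods shown in the hunk above; MatrixSketch is an illustrative class name.

    import com.github.chen0040.rl.utils.Matrix;
    import com.github.chen0040.rl.utils.Vec;

    public class MatrixSketch {
        public static void main(String[] args) {
            Matrix a = new Matrix(2, 2);
            a.set(0, 0, 1.0);
            a.set(0, 1, 2.0);
            a.set(1, 1, 3.0);                       // (1,0) is never stored; get() falls back to the default 0

            Matrix identity = Matrix.identity(2);
            Matrix product = a.multiply(identity);  // row-by-column dot products, reproduces a
            System.out.println(product.get(0, 1));  // 2.0

            Vec x = new Vec(new double[]{1.0, 1.0});
            Vec y = a.multiply(x);                  // one dot product per stored row
            System.out.println(y.get(0));           // 1.0*1 + 2.0*1 = 3.0
        }
    }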
diff --git a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java
index e43c28b..2bc8631 100644
--- a/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java
+++ b/src/main/java/com/github/chen0040/rl/utils/MatrixUtils.java
@@ -2,28 +2,24 @@
import java.util.List;
-
/**
* Created by xschen on 10/11/2015 0011.
*/
public class MatrixUtils {
- /**
- * Convert a list of column vectors into a matrix
- */
- public static Matrix matrixFromColumnVectors(List<Vec> R)
- {
- int n = R.size();
- int m = R.get(0).getDimension();
+ /**
+ * Convert a list of column vectors into a matrix
+ */
+ public static Matrix matrixFromColumnVectors(List<Vec> R) {
+ int n = R.size();
+ int m = R.get(0).getDimension();
- Matrix T = new Matrix(m, n);
- for (int c = 0; c < n; ++c)
- {
- Vec Rcol = R.get(c);
- for (int r : Rcol.getData().keySet())
- {
- T.set(r, c, Rcol.get(r));
- }
- }
- return T;
- }
+ Matrix T = new Matrix(m, n);
+ for (int c = 0; c < n; ++c) {
+ Vec Rcol = R.get(c);
+ for (int r : Rcol.getData().keySet()) {
+ T.set(r, c, Rcol.get(r));
+ }
+ }
+ return T;
+ }
}
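MatrixUtils.matrixFromColumnVectors is the inverse of Matrix.columnVectors: it packs a list of sparse column vectors into an m-by-n matrix, where m is taken from the first column's dimension. A short sketch (MatrixUtilsSketch is an illustrative name):

    import com.github.chen0040.rl.utils.Matrix;
    import com.github.chen0040.rl.utils.MatrixUtils;
    import com.github.chen0040.rl.utils.Vec;

    import java.util.Arrays;
    import java.util.List;

    public class MatrixUtilsSketch {
        public static void main(String[] args) {
            Vec col0 = new Vec(new double[]{1.0, 0.0, 2.0});
            Vec col1 = new Vec(new double[]{0.0, 3.0, 0.0});
            List<Vec> columns = Arrays.asList(col0, col1);

            Matrix m = MatrixUtils.matrixFromColumnVectors(columns);
            System.out.println(m.getRowCount() + " x " + m.getColumnCount()); // 3 x 2
            System.out.println(m.get(2, 0));                                  // 2.0
        }
    }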
diff --git a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java
index b4895ea..f959a85 100644
--- a/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java
+++ b/src/main/java/com/github/chen0040/rl/utils/TupleTwo.java
@@ -1,56 +1,58 @@
package com.github.chen0040.rl.utils;
+import java.util.Objects;
+
/**
* Created by xschen on 10/11/2015 0011.
*/
public class TupleTwo<T1, T2> {
- private T1 item1;
- private T2 item2;
-
- public TupleTwo(T1 item1, T2 item2){
- this.item1 = item1;
- this.item2 = item2;
- }
-
- public T1 getItem1() {
- return item1;
- }
-
- public void setItem1(T1 item1) {
- this.item1 = item1;
- }
-
- public T2 getItem2() {
- return item2;
- }
-
- public void setItem2(T2 item2) {
- this.item2 = item2;
- }
-
- public static <U1, U2> TupleTwo<U1, U2> create(U1 item1, U2 item2){
- return new TupleTwo<U1, U2>(item1, item2);
- }
-
-
- @Override public boolean equals(Object o) {
- if (this == o)
- return true;
- if (o == null || getClass() != o.getClass())
- return false;
-
- TupleTwo<?, ?> tupleTwo = (TupleTwo<?, ?>) o;
-
- if (item1 != null ? !item1.equals(tupleTwo.item1) : tupleTwo.item1 != null)
- return false;
- return item2 != null ? item2.equals(tupleTwo.item2) : tupleTwo.item2 == null;
-
- }
-
-
- @Override public int hashCode() {
- int result = item1 != null ? item1.hashCode() : 0;
- result = 31 * result + (item2 != null ? item2.hashCode() : 0);
- return result;
- }
+ private T1 item1;
+ private T2 item2;
+
+ public TupleTwo(T1 item1, T2 item2) {
+ this.item1 = item1;
+ this.item2 = item2;
+ }
+
+ public T1 getItem1() {
+ return item1;
+ }
+
+ public void setItem1(T1 item1) {
+ this.item1 = item1;
+ }
+
+ public T2 getItem2() {
+ return item2;
+ }
+
+ public void setItem2(T2 item2) {
+ this.item2 = item2;
+ }
+
+ public static <U1, U2> TupleTwo<U1, U2> create(U1 item1, U2 item2) {
+ return new TupleTwo<>(item1, item2);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (o == null || getClass() != o.getClass())
+ return false;
+
+ TupleTwo<?, ?> tupleTwo = (TupleTwo<?, ?>) o;
+
+ if (!Objects.equals(item1, tupleTwo.item1))
+ return false;
+ return Objects.equals(item2, tupleTwo.item2);
+
+ }
+
+ @Override
+ public int hashCode() {
+ int result = item1 != null ? item1.hashCode() : 0;
+ result = 31 * result + (item2 != null ? item2.hashCode() : 0);
+ return result;
+ }
}
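TupleTwo is a plain generic pair; the only behavioural change in this file is replacing the hand-rolled null checks with java.util.Objects.equals. A usage sketch using only the methods above (TupleTwoSketch is an illustrative name):

    import com.github.chen0040.rl.utils.TupleTwo;

    public class TupleTwoSketch {
        public static void main(String[] args) {
            TupleTwo<String, Double> entry = TupleTwo.create("gamma", 0.9);
            System.out.println(entry.getItem1());  // gamma
            System.out.println(entry.getItem2());  // 0.9

            // equals() and hashCode() are value-based, so an identical pair compares equal
            System.out.println(entry.equals(TupleTwo.create("gamma", 0.9))); // true
        }
    }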
diff --git a/src/main/java/com/github/chen0040/rl/utils/Vec.java b/src/main/java/com/github/chen0040/rl/utils/Vec.java
index 4699d0e..763005a 100644
--- a/src/main/java/com/github/chen0040/rl/utils/Vec.java
+++ b/src/main/java/com/github/chen0040/rl/utils/Vec.java
@@ -1,7 +1,7 @@
package com.github.chen0040.rl.utils;
-import lombok.Getter;
-import lombok.Setter;
+//import lombok.Getter;
+//import lombok.Setter;
import java.io.Serializable;
import java.util.HashMap;
@@ -9,341 +9,329 @@
import java.util.Map;
import java.util.Set;
-
/**
* Created by xschen on 9/27/2015 0027.
*/
-@Getter
-@Setter
+//@Getter
+//@Setter
public class Vec implements Serializable {
- private Map<Integer, Double> data = new HashMap<Integer, Double>();
- private int dimension;
- private double defaultValue;
- private int id = -1;
-
- public Vec(){
-
- }
-
- public Vec(double[] v){
- for(int i=0; i < v.length; ++i){
- set(i, v[i]);
- }
- }
-
- public Vec(int dimension){
- this.dimension = dimension;
- defaultValue = 0;
- }
-
- public Vec(int dimension, Map<Integer, Double> data){
- this.dimension = dimension;
- defaultValue = 0;
-
- for(Map.Entry<Integer, Double> entry : data.entrySet()){
- set(entry.getKey(), entry.getValue());
- }
- }
-
- public Vec makeCopy(){
- Vec clone = new Vec(dimension);
- clone.copy(this);
- return clone;
- }
-
- public void copy(Vec rhs){
- defaultValue = rhs.defaultValue;
- dimension = rhs.dimension;
- id = rhs.id;
-
- data.clear();
- for(Map.Entry<Integer, Double> entry : rhs.data.entrySet()){
- data.put(entry.getKey(), entry.getValue());
- }
- }
-
- public void set(int i, double value){
- if(value == defaultValue) return;
-
- data.put(i, value);
- if(i >= dimension){
- dimension = i+1;
- }
- }
-
-
- public double get(int i){
- return data.getOrDefault(i, defaultValue);
- }
-
- @Override
- public boolean equals(Object rhs){
- if(rhs != null && rhs instanceof Vec){
- Vec rhs2 = (Vec)rhs;
- if(dimension != rhs2.dimension){
- return false;
- }
-
- if(data.size() != rhs2.data.size()){
- return false;
- }
-
- for(Integer index : data.keySet()){
- if(!rhs2.data.containsKey(index)) return false;
- if(!DoubleUtils.equals(data.get(index), rhs2.data.get(index))){
- return false;
- }
- }
-
- if(defaultValue != rhs2.defaultValue){
- for(int i=0; i < dimension; ++i){
- if(data.containsKey(i)){
- return false;
- }
- }
- }
-
- return true;
- }
-
- return false;
- }
-
- public void setAll(double value){
- defaultValue = value;
- for(Integer index : data.keySet()){
- data.put(index, defaultValue);
- }
- }
-
- public IndexValue indexWithMaxValue(Set<Integer> indices){
- if(indices == null){
- return indexWithMaxValue();
- }else{
- IndexValue iv = new IndexValue();
- iv.setIndex(-1);
- iv.setValue(Double.NEGATIVE_INFINITY);
- for(Integer index : indices){
- double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY);
- if(value > iv.getValue()){
- iv.setIndex(index);
- iv.setValue(value);
- }
- }
- return iv;
- }
- }
-
- public IndexValue indexWithMaxValue(){
- IndexValue iv = new IndexValue();
- iv.setIndex(-1);
- iv.setValue(Double.NEGATIVE_INFINITY);
-
-
- for(Map.Entry<Integer, Double> entry : data.entrySet()){
- if(entry.getKey() >= dimension) continue;
-
- double value = entry.getValue();
- if(value > iv.getValue()){
- iv.setValue(value);
- iv.setIndex(entry.getKey());
- }
- }
-
- if(!iv.isValid()){
- iv.setValue(defaultValue);
- } else{
- if(iv.getValue() < defaultValue){
- for(int i=0; i < dimension; ++i){
- if(!data.containsKey(i)){
- iv.setValue(defaultValue);
- iv.setIndex(i);
- break;
- }
- }
- }
- }
-
- return iv;
- }
-
-
-
- public Vec projectOrthogonal(Iterable<Vec> vlist) {
- Vec b = this;
- for(Vec v : vlist)
- {
- b = b.minus(b.projectAlong(v));
- }
-
- return b;
- }
-
- public Vec projectOrthogonal(List<Vec> vlist, Map<Integer, Double> alpha) {
- Vec b = this;
- for(int i = 0; i < vlist.size(); ++i)
- {
- Vec v = vlist.get(i);
- double norm_a = v.multiply(v);
-
- if (DoubleUtils.isZero(norm_a)) {
- return new Vec(dimension);
- }
- double sigma = multiply(v) / norm_a;
- Vec v_parallel = v.multiply(sigma);
-
- alpha.put(i, sigma);
-
- b = b.minus(v_parallel);
- }
-
- return b;
- }
-
- public Vec projectAlong(Vec rhs)
- {
- double norm_a = rhs.multiply(rhs);
-
- if (DoubleUtils.isZero(norm_a)) {
- return new Vec(dimension);
- }
- double sigma = multiply(rhs) / norm_a;
- return rhs.multiply(sigma);
- }
-
- public Vec multiply(double rhs){
- Vec clone = (Vec)this.makeCopy();
- for(Integer i : data.keySet()){
- clone.data.put(i, rhs * data.get(i));
- }
- return clone;
- }
-
- public double multiply(Vec rhs)
- {
- double productSum = 0;
- if(defaultValue == 0) {
- for (Map.Entry<Integer, Double> entry : data.entrySet()) {
- productSum += entry.getValue() * rhs.get(entry.getKey());
- }
- } else {
- for(int i=0; i < dimension; ++i){
- productSum += get(i) * rhs.get(i);
- }
- }
-
- return productSum;
- }
-
- public Vec pow(double scalar)
- {
- Vec result = new Vec(dimension);
- for (Map.Entry<Integer, Double> entry : data.entrySet())
- {
- result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar));
- }
- return result;
- }
-
- public Vec add(Vec rhs)
- {
- Vec result = new Vec(dimension);
- int index;
- for (Map.Entry<Integer, Double> entry : data.entrySet()) {
- index = entry.getKey();
- result.data.put(index, entry.getValue() + rhs.data.get(index));
- }
- for(Map.Entry<Integer, Double> entry : rhs.data.entrySet()){
- index = entry.getKey();
- if(result.data.containsKey(index)) continue;
- result.data.put(index, entry.getValue() + data.get(index));
- }
-
- return result;
- }
-
- public Vec minus(Vec rhs)
- {
- Vec result = new Vec(dimension);
- int index;
- for (Map.Entry<Integer, Double> entry : data.entrySet()) {
- index = entry.getKey();
- result.data.put(index, entry.getValue() - rhs.data.get(index));
- }
- for(Map.Entry<Integer, Double> entry : rhs.data.entrySet()){
- index = entry.getKey();
- if(result.data.containsKey(index)) continue;
- result.data.put(index, data.get(index) - entry.getValue());
- }
-
- return result;
- }
-
- public double sum(){
- double sum = 0;
-
- for(Map.Entry<Integer, Double> entry : data.entrySet()){
- sum += entry.getValue();
- }
- sum += defaultValue * (dimension - data.size());
-
- return sum;
- }
-
- public boolean isZero(){
- return DoubleUtils.isZero(sum());
- }
-
- public double norm(int level)
- {
- if (level == 1)
- {
- double sum = 0;
- for (Double val : data.values())
- {
- sum += Math.abs(val);
- }
- if(!DoubleUtils.isZero(defaultValue)) {
- sum += Math.abs(defaultValue) * (dimension - data.size());
- }
- return sum;
- }
- else if (level == 2)
- {
- double sum = multiply(this);
- if(!DoubleUtils.isZero(defaultValue)){
- sum += (dimension - data.size()) * (defaultValue * defaultValue);
- }
- return Math.sqrt(sum);
- }
- else
- {
- double sum = 0;
- for (Double val : this.data.values())
- {
- sum += Math.pow(Math.abs(val), level);
- }
- if(!DoubleUtils.isZero(defaultValue)) {
- sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size());
- }
- return Math.pow(sum, 1.0 / level);
- }
- }
-
- public Vec normalize()
- {
- double norm = norm(2); // L2 norm is the cartesian distance
- if (DoubleUtils.isZero(norm))
- {
- return new Vec(dimension);
- }
- Vec clone = new Vec(dimension);
- clone.setAll(defaultValue / norm);
-
- for (Integer k : data.keySet())
- {
- clone.data.put(k, data.get(k) / norm);
- }
- return clone;
- }
+ private Map<Integer, Double> data = new HashMap<>();
+ private int dimension;
+ private double defaultValue;
+ private int id = -1;
+
+ public Vec() {
+
+ }
+
+ public Vec(double[] v) {
+ for (int i = 0; i < v.length; ++i) {
+ set(i, v[i]);
+ }
+ }
+
+ public Vec(int dimension) {
+ this.dimension = dimension;
+ defaultValue = 0;
+ }
+
+ public Vec(int dimension, Map<Integer, Double> data) {
+ this.dimension = dimension;
+ defaultValue = 0;
+
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ set(entry.getKey(), entry.getValue());
+ }
+ }
+
+ public Vec makeCopy() {
+ Vec clone = new Vec(dimension);
+ clone.copy(this);
+ return clone;
+ }
+
+ public void copy(Vec rhs) {
+ defaultValue = rhs.defaultValue;
+ dimension = rhs.dimension;
+ id = rhs.id;
+
+ data.clear();
+ for (Map.Entry<Integer, Double> entry : rhs.data.entrySet()) {
+ data.put(entry.getKey(), entry.getValue());
+ }
+ }
+
+ public void set(int i, double value) {
+ if (value == defaultValue) return;
+
+ data.put(i, value);
+ if (i >= dimension) {
+ dimension = i + 1;
+ }
+ }
+
+ public double get(int i) {
+ return data.getOrDefault(i, defaultValue);
+ }
+
+ @Override
+ public boolean equals(Object rhs) {
+ if (rhs != null && rhs instanceof Vec) {
+ Vec rhs2 = (Vec) rhs;
+ if (dimension != rhs2.dimension) {
+ return false;
+ }
+
+ if (data.size() != rhs2.data.size()) {
+ return false;
+ }
+
+ for (Integer index : data.keySet()) {
+ if (!rhs2.data.containsKey(index)) return false;
+ if (!DoubleUtils.equals(data.get(index), rhs2.data.get(index))) {
+ return false;
+ }
+ }
+
+ if (defaultValue != rhs2.defaultValue) {
+ for (int i = 0; i < dimension; ++i) {
+ if (!data.containsKey(i)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+ }
+
+ public void setAll(double value) {
+ defaultValue = value;
+ for (Integer index : data.keySet()) {
+ data.put(index, defaultValue);
+ }
+ }
+
+ public IndexValue indexWithMaxValue(Set<Integer> indices) {
+ if (indices == null) {
+ return indexWithMaxValue();
+ } else {
+ IndexValue iv = new IndexValue();
+ iv.setIndex(-1);
+ iv.setValue(Double.NEGATIVE_INFINITY);
+ for (Integer index : indices) {
+ double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY);
+ if (value > iv.getValue()) {
+ iv.setIndex(index);
+ iv.setValue(value);
+ }
+ }
+ return iv;
+ }
+ }
+
+ public IndexValue indexWithMaxValue() {
+ IndexValue iv = new IndexValue();
+ iv.setIndex(-1);
+ iv.setValue(Double.NEGATIVE_INFINITY);
+
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ if (entry.getKey() >= dimension) continue;
+
+ double value = entry.getValue();
+ if (value > iv.getValue()) {
+ iv.setValue(value);
+ iv.setIndex(entry.getKey());
+ }
+ }
+
+ if (!iv.isValid()) {
+ iv.setValue(defaultValue);
+ } else {
+ if (iv.getValue() < defaultValue) {
+ for (int i = 0; i < dimension; ++i) {
+ if (!data.containsKey(i)) {
+ iv.setValue(defaultValue);
+ iv.setIndex(i);
+ break;
+ }
+ }
+ }
+ }
+
+ return iv;
+ }
+
+ public Vec projectOrthogonal(Iterable<Vec> vlist) {
+ Vec b = this;
+ for (Vec v : vlist) {
+ b = b.minus(b.projectAlong(v));
+ }
+
+ return b;
+ }
+
+ public Vec projectOrthogonal(List<Vec> vlist, Map<Integer, Double> alpha) {
+ Vec b = this;
+ for (int i = 0; i < vlist.size(); ++i) {
+ Vec v = vlist.get(i);
+ double norm_a = v.multiply(v);
+
+ if (DoubleUtils.isZero(norm_a)) {
+ return new Vec(dimension);
+ }
+ double sigma = multiply(v) / norm_a;
+ Vec v_parallel = v.multiply(sigma);
+
+ alpha.put(i, sigma);
+
+ b = b.minus(v_parallel);
+ }
+
+ return b;
+ }
+
+ public Vec projectAlong(Vec rhs) {
+ double norm_a = rhs.multiply(rhs);
+
+ if (DoubleUtils.isZero(norm_a)) {
+ return new Vec(dimension);
+ }
+ double sigma = multiply(rhs) / norm_a;
+ return rhs.multiply(sigma);
+ }
+
+ public Vec multiply(double rhs) {
+ Vec clone = (Vec) this.makeCopy();
+ for (Integer i : data.keySet()) {
+ clone.data.put(i, rhs * data.get(i));
+ }
+ return clone;
+ }
+
+ public double multiply(Vec rhs) {
+ double productSum = 0;
+ if (defaultValue == 0) {
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ productSum += entry.getValue() * rhs.get(entry.getKey());
+ }
+ } else {
+ for (int i = 0; i < dimension; ++i) {
+ productSum += get(i) * rhs.get(i);
+ }
+ }
+
+ return productSum;
+ }
+
+ public Vec pow(double scalar) {
+ Vec result = new Vec(dimension);
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar));
+ }
+ return result;
+ }
+
+ public Vec add(Vec rhs) {
+ Vec result = new Vec(dimension);
+ int index;
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ index = entry.getKey();
+ result.data.put(index, entry.getValue() + rhs.get(index));
+ }
+ for (Map.Entry<Integer, Double> entry : rhs.data.entrySet()) {
+ index = entry.getKey();
+ if (result.data.containsKey(index)) continue;
+ result.data.put(index, entry.getValue() + get(index));
+ }
+
+ return result;
+ }
+
+ public Vec minus(Vec rhs) {
+ Vec result = new Vec(dimension);
+ int index;
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ index = entry.getKey();
+ result.data.put(index, entry.getValue() - rhs.get(index));
+ }
+ for (Map.Entry<Integer, Double> entry : rhs.data.entrySet()) {
+ index = entry.getKey();
+ if (result.data.containsKey(index)) continue;
+ result.data.put(index, get(index) - entry.getValue());
+ }
+
+ return result;
+ }
+
+ public double sum() {
+ double sum = 0;
+
+ for (Map.Entry<Integer, Double> entry : data.entrySet()) {
+ sum += entry.getValue();
+ }
+ sum += defaultValue * (dimension - data.size());
+
+ return sum;
+ }
+
+ public boolean isZero() {
+ return DoubleUtils.isZero(sum());
+ }
+
+ public double norm(int level) {
+ if (level == 1) {
+ double sum = 0;
+ for (Double val : data.values()) {
+ sum += Math.abs(val);
+ }
+ if (!DoubleUtils.isZero(defaultValue)) {
+ sum += Math.abs(defaultValue) * (dimension - data.size());
+ }
+ return sum;
+ } else if (level == 2) {
+ double sum = multiply(this);
+ if (!DoubleUtils.isZero(defaultValue)) {
+ sum += (dimension - data.size()) * (defaultValue * defaultValue);
+ }
+ return Math.sqrt(sum);
+ } else {
+ double sum = 0;
+ for (Double val : this.data.values()) {
+ sum += Math.pow(Math.abs(val), level);
+ }
+ if (!DoubleUtils.isZero(defaultValue)) {
+ sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size());
+ }
+ return Math.pow(sum, 1.0 / level);
+ }
+ }
+
+ public Vec normalize() {
+ double norm = norm(2); // L2 norm is the cartesian distance
+ if (DoubleUtils.isZero(norm)) {
+ return new Vec(dimension);
+ }
+ Vec clone = new Vec(dimension);
+ clone.setAll(defaultValue / norm);
+
+ for (Integer k : data.keySet()) {
+ clone.data.put(k, data.get(k) / norm);
+ }
+ return clone;
+ }
+
+ public void setId(int rowIndex) {
+ this.id = rowIndex;
+ }
+
+ public Map<Integer, Double> getData() {
+ return this.data;
+ }
+
+ public int getDimension() {
+ return this.dimension;
+ }
}
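Vec is the sparse, default-valued vector that backs both the Matrix rows and the Q/utility tables, which is why this is the largest reformatted file. The sketch below walks through the operations the learners rely on (set/get with a default, argmax via IndexValue, dot product, norm and normalisation), using only methods visible in the hunk; VecSketch is an illustrative name.

    import com.github.chen0040.rl.utils.IndexValue;
    import com.github.chen0040.rl.utils.Vec;

    public class VecSketch {
        public static void main(String[] args) {
            Vec q = new Vec(5);             // dimension 5, unset entries read as the default (0)
            q.set(1, 0.4);
            q.set(3, 0.9);

            System.out.println(q.get(3));   // 0.9
            System.out.println(q.get(2));   // 0.0 (default)

            IndexValue best = q.indexWithMaxValue();
            System.out.println(best.getIndex() + " -> " + best.getValue()); // 3 -> 0.9

            System.out.println(q.multiply(q));         // dot product: 0.4*0.4 + 0.9*0.9 = 0.97
            System.out.println(q.norm(2));             // sqrt(0.97)
            System.out.println(q.normalize().norm(2)); // ~1.0
        }
    }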
diff --git a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java
index 2bbfbaa..da79781 100644
--- a/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java
+++ b/src/main/java/com/github/chen0040/rl/utils/VectorUtils.java
@@ -3,37 +3,30 @@
import java.util.ArrayList;
import java.util.List;
-
/**
* Created by xschen on 10/11/2015 0011.
*/
public class VectorUtils {
- public static List<Vec> removeZeroVectors(Iterable<Vec> vlist)
- {
- List<Vec> vstarlist = new ArrayList<Vec>();
- for (Vec v : vlist)
- {
- if (!v.isZero())
- {
- vstarlist.add(v);
- }
- }
-
- return vstarlist;
- }
-
- public static TupleTwo<List<Vec>, List<Double>> normalize(Iterable<Vec> vlist)
- {
- List<Double> norms = new ArrayList<Double>();
- List<Vec> vstarlist = new ArrayList<Vec>();
- for (Vec v : vlist)
- {
- norms.add(v.norm(2));
- vstarlist.add(v.normalize());
- }
-
- return TupleTwo.create(vstarlist, norms);
- }
-
+ public static List<Vec> removeZeroVectors(Iterable<Vec> vlist) {
+ List<Vec> vstarlist = new ArrayList<>();
+ for (Vec v : vlist) {
+ if (!v.isZero()) {
+ vstarlist.add(v);
+ }
+ }
+
+ return vstarlist;
+ }
+
+ public static TupleTwo<List<Vec>, List<Double>> normalize(Iterable<Vec> vlist) {
+ List<Double> norms = new ArrayList<>();
+ List<Vec> vstarlist = new ArrayList<>();
+ for (Vec v : vlist) {
+ norms.add(v.norm(2));
+ vstarlist.add(v.normalize());
+ }
+
+ return TupleTwo.create(vstarlist, norms);
+ }
}
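VectorUtils lifts the per-vector operations above to whole collections: removeZeroVectors filters out vectors whose sum is (near) zero, and normalize returns the unit vectors together with their original L2 norms as a TupleTwo. A short sketch (VectorUtilsSketch is an illustrative name):

    import com.github.chen0040.rl.utils.TupleTwo;
    import com.github.chen0040.rl.utils.Vec;
    import com.github.chen0040.rl.utils.VectorUtils;

    import java.util.Arrays;
    import java.util.List;

    public class VectorUtilsSketch {
        public static void main(String[] args) {
            Vec v1 = new Vec(new double[]{3.0, 4.0});
            Vec zero = new Vec(2);                        // all entries at the default 0

            List<Vec> nonZero = VectorUtils.removeZeroVectors(Arrays.asList(v1, zero));
            System.out.println(nonZero.size());           // 1

            TupleTwo<List<Vec>, List<Double>> result = VectorUtils.normalize(nonZero);
            System.out.println(result.getItem2().get(0));          // 5.0, the L2 norm of (3, 4)
            System.out.println(result.getItem1().get(0).get(0));   // 0.6 after normalisation
        }
    }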
diff --git a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java
index 7b3d094..3c9d7cc 100644
--- a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticAgentUnitTest.java
@@ -1,46 +1,45 @@
package com.github.chen0040.rl.learning.actorcritic;
-
import com.github.chen0040.rl.utils.Vec;
+
import org.testng.annotations.Test;
import java.util.Random;
import static org.testng.Assert.*;
-
/**
* Created by xschen on 6/5/2017.
*/
public class ActorCriticAgentUnitTest {
- @Test
- public void test_learn(){
- int stateCount = 100;
- int actionCount = 10;
+ @Test
+ public void test_learn() {
+ int stateCount = 100;
+ int actionCount = 10;
- ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount);
- Vec stateValues = new Vec(stateCount);
+ ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount);
+ Vec stateValues = new Vec(stateCount);
- Random random = new Random();
- agent.start(random.nextInt(stateCount));
- for(int time=0; time < 1000; ++time){
+ Random random = new Random();
+ agent.start(random.nextInt(stateCount));
+ for (int time = 0; time < 1000; ++time) {
- int actionId = agent.selectAction();
- System.out.println("Agent does action-"+actionId);
+ int actionId = agent.selectAction();
+ System.out.println("Agent does action-" + actionId);
- int newStateId = random.nextInt(actionCount);
- double reward = random.nextDouble();
+ int newStateId = random.nextInt(actionCount);
+ double reward = random.nextDouble();
- System.out.println("Now the new state is "+newStateId);
- System.out.println("Agent receives Reward = "+reward);
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Agent receives Reward = " + reward);
- System.out.println("World state values changed ...");
- for(int stateId = 0; stateId < stateCount; ++stateId){
- stateValues.set(stateId, random.nextDouble());
- }
+ System.out.println("World state values changed ...");
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ stateValues.set(stateId, random.nextDouble());
+ }
- agent.update(actionId, newStateId, reward, stateValues);
- }
- }
+ agent.update(actionId, newStateId, reward, stateValues);
+ }
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java
index d6bafac..3c0bdf8 100644
--- a/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/actorcritic/ActorCriticLearnerUnitTest.java
@@ -1,6 +1,7 @@
package com.github.chen0040.rl.learning.actorcritic;
import com.github.chen0040.rl.utils.Vec;
+
import org.testng.annotations.Test;
import java.util.Random;
@@ -12,39 +13,39 @@
*/
public class ActorCriticLearnerUnitTest {
- @Test
- public void test_learn(){
- int stateCount = 100;
- int actionCount = 10;
+ @Test
+ public void test_learn() {
+ int stateCount = 100;
+ int actionCount = 10;
- ActorCriticLearner learner = new ActorCriticLearner(stateCount, actionCount);
- final Vec stateValues = new Vec(stateCount);
+ ActorCriticLearner learner = new ActorCriticLearner(stateCount, actionCount);
+ final Vec stateValues = new Vec(stateCount);
- Random random = new Random();
- int currentStateId = random.nextInt(stateCount);
- for(int time=0; time < 1000; ++time){
+ Random random = new Random();
+ int currentStateId = random.nextInt(stateCount);
+ for (int time = 0; time < 1000; ++time) {
- int actionId = learner.selectAction(currentStateId);
- System.out.println("Agent does action-"+actionId);
+ int actionId = learner.selectAction(currentStateId);
+ System.out.println("Agent does action-" + actionId);
- int newStateId = random.nextInt(actionCount);
- double reward = random.nextDouble();
+ int newStateId = random.nextInt(actionCount);
+ double reward = random.nextDouble();
- System.out.println("Now the new state is "+newStateId);
- System.out.println("Agent receives Reward = "+reward);
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Agent receives Reward = " + reward);
- System.out.println("World state values changed ...");
- for(int stateId = 0; stateId < stateCount; ++stateId){
- stateValues.set(stateId, random.nextDouble());
- }
+ System.out.println("World state values changed ...");
+ for (int stateId = 0; stateId < stateCount; ++stateId) {
+ stateValues.set(stateId, random.nextDouble());
+ }
- learner.update(currentStateId, actionId, newStateId, reward, stateValues::get);
- }
+ learner.update(currentStateId, actionId, newStateId, reward, stateValues::get);
+ }
- ActorCriticLearner learner2 = ActorCriticLearner.fromJson(learner.toJson());
+ ActorCriticLearner learner2 = ActorCriticLearner.fromJson(learner.toJson());
- assertThat(learner2.getP()).isEqualTo(learner.getP());
- assertThat(learner2.getActionSelection()).isEqualTo(learner.getActionSelection());
- assertThat(learner2).isEqualTo(learner);
- }
+ assertThat(learner2.getP()).isEqualTo(learner.getP());
+ assertThat(learner2.getActionSelection()).isEqualTo(learner.getActionSelection());
+ assertThat(learner2).isEqualTo(learner);
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java
index ee30dc4..6b5b763 100644
--- a/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/models/QModelUnitTest.java
@@ -1,34 +1,33 @@
package com.github.chen0040.rl.learning.models;
-import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.JSON;
import com.github.chen0040.rl.models.QModel;
+import com.google.gson.Gson;
+
import org.testng.annotations.Test;
import static org.assertj.core.api.Java6Assertions.assertThat;
public class QModelUnitTest {
- @Test
- public void testJsonSerialization() {
- QModel model = new QModel(100, 10);
-
- model.setQ(3, 4, 0.3);
- model.setQ(92, 2, 0.2);
-
- model.setAlpha(0.4);
- model.setGamma(0.3);
-
- String json = JSON.toJSONString(model);
- QModel model2 = JSON.parseObject(json, QModel.class);
+ @Test
+ public void testJsonSerialization() {
+ QModel model = new QModel(100, 10);
- assertThat(model).isEqualTo(model2);
- assertThat(model.getQ()).isEqualTo(model2.getQ());
- assertThat(model.getAlphaMatrix()).isEqualTo(model2.getAlphaMatrix());
- assertThat(model.getStateCount()).isEqualTo(model2.getStateCount());
- assertThat(model.getActionCount()).isEqualTo(model2.getActionCount());
- assertThat(model.getGamma()).isEqualTo(model2.getGamma());
+ model.setQ(3, 4, 0.3);
+ model.setQ(92, 2, 0.2);
+ model.setAlpha(0.4);
+ model.setGamma(0.3);
+ String json = new Gson().toJson(model); //JSON.toJSONString(model);
+ QModel model2 = new Gson().fromJson(json, QModel.class); //JSON.parseObject(json, QModel.class);
+ assertThat(model).isEqualTo(model2);
+ assertThat(model.getQ()).isEqualTo(model2.getQ());
+ assertThat(model.getAlphaMatrix()).isEqualTo(model2.getAlphaMatrix());
+ assertThat(model.getStateCount()).isEqualTo(model2.getStateCount());
+ assertThat(model.getActionCount()).isEqualTo(model2.getActionCount());
+ assertThat(model.getGamma()).isEqualTo(model2.getGamma());
- }
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java
index 627d7d0..c7bed39 100644
--- a/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/qlearn/QAgentUnitTest.java
@@ -1,41 +1,41 @@
package com.github.chen0040.rl.learning.qlearn;
-
import com.github.chen0040.rl.actionselection.SoftMaxActionSelectionStrategy;
+
import org.testng.annotations.Test;
import java.util.Random;
import static org.testng.Assert.*;
-
/**
* Created by xschen on 6/5/2017.
*/
public class QAgentUnitTest {
- @Test
- public void test_q_learn(){
- int stateCount = 100;
- int actionCount = 10;
- QAgent agent = new QAgent(stateCount, actionCount);
+ @Test
+ public void test_q_learn() {
+ int stateCount = 100;
+ int actionCount = 10;
+ QAgent agent = new QAgent(stateCount, actionCount);
- agent.getLearner().setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName());
+ agent.getLearner()
+ .setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName());
- Random random = new Random();
- agent.start(random.nextInt(stateCount));
- for(int time=0; time < 1000; ++time){
+ Random random = new Random();
+ agent.start(random.nextInt(stateCount));
+ for (int time = 0; time < 1000; ++time) {
- int actionId = agent.selectAction().getIndex();
- System.out.println("Agent does action-"+actionId);
+ int actionId = agent.selectAction().getIndex();
+ System.out.println("Agent does action-" + actionId);
- int newStateId = random.nextInt(actionCount);
- double reward = random.nextDouble();
+ int newStateId = random.nextInt(actionCount);
+ double reward = random.nextDouble();
- System.out.println("Now the new state is "+newStateId);
- System.out.println("Agent receives Reward = "+reward);
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Agent receives Reward = " + reward);
- agent.update(actionId, newStateId, reward);
- }
- }
+ agent.update(actionId, newStateId, reward);
+ }
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java
index 6685cac..a0023fb 100644
--- a/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/qlearn/QLearnerUnitTest.java
@@ -1,8 +1,8 @@
package com.github.chen0040.rl.learning.qlearn;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.serializer.SerializerFeature;
import org.testng.annotations.Test;
import java.util.Random;
@@ -10,59 +10,56 @@
import static org.assertj.core.api.Java6Assertions.assertThat;
import static org.testng.Assert.*;
-
/**
* Created by xschen on 6/5/2017.
*/
public class QLearnerUnitTest {
- private static final int stateCount = 100;
- private static final int actionCount = 10;
-
- @Test
- public void testJsonSerialization() {
-
- QLearner learner = new QLearner(stateCount, actionCount);
+ private static final int stateCount = 100;
+ private static final int actionCount = 10;
- run(learner);
+ @Test
+ public void testJsonSerialization() {
- String json = learner.toJson();
+ QLearner learner = new QLearner(stateCount, actionCount);
+ run(learner);
- QLearner learner2 = QLearner.fromJson(json);
+ String json = learner.toJson();
- assertThat(learner.getModel()).isEqualTo(learner2.getModel());
+ QLearner learner2 = QLearner.fromJson(json);
- assertThat(learner.getActionSelection()).isEqualTo(learner2.getActionSelection());
+ assertThat(learner.getModel()).isEqualTo(learner2.getModel());
- }
+ assertThat(learner.getActionSelection()).isEqualTo(learner2.getActionSelection());
- @Test
- public void test_q_learn(){
+ }
+ @Test
+ public void test_q_learn() {
- QLearner learner = new QLearner(stateCount, actionCount);
+ QLearner learner = new QLearner(stateCount, actionCount);
- run(learner);
+ run(learner);
- }
+ }
- private void run(QLearner learner) {
- Random random = new Random();
- int currentStateId = random.nextInt(stateCount);
- for(int time=0; time < 1000; ++time){
+ private void run(QLearner learner) {
+ Random random = new Random();
+ int currentStateId = random.nextInt(stateCount);
+ for (int time = 0; time < 1000; ++time) {
- int actionId = learner.selectAction(currentStateId).getIndex();
- System.out.println("Controller does action-"+actionId);
+ int actionId = learner.selectAction(currentStateId).getIndex();
+ System.out.println("Controller does action-" + actionId);
- int newStateId = random.nextInt(actionCount);
- double reward = random.nextDouble();
+ int newStateId = random.nextInt(actionCount);
+ double reward = random.nextDouble();
- System.out.println("Now the new state is "+newStateId);
- System.out.println("Controller receives Reward = "+reward);
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Controller receives Reward = " + reward);
- learner.update(currentStateId, actionId, newStateId, reward);
- currentStateId = newStateId;
- }
- }
+ learner.update(currentStateId, actionId, newStateId, reward);
+ currentStateId = newStateId;
+ }
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java
index 6110298..c3f668b 100644
--- a/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/rlearn/RAgentUnitTest.java
@@ -1,7 +1,7 @@
package com.github.chen0040.rl.learning.rlearn;
-
import com.github.chen0040.rl.utils.IndexValue;
+
import org.testng.annotations.Test;
import java.util.Random;
@@ -9,41 +9,39 @@
import static org.assertj.core.api.Java6Assertions.assertThat;
import static org.testng.Assert.*;
-
/**
* Created by xschen on 6/5/2017.
*/
public class RAgentUnitTest {
- @Test
- public void test_r_learn(){
-
- int stateCount = 100;
- int actionCount = 10;
- RAgent agent = new RAgent(stateCount, actionCount);
+ @Test
+ public void test_r_learn() {
- Random random = new Random();
- agent.start(random.nextInt(stateCount));
- for(int time=0; time < 1000; ++time){
+ int stateCount = 100;
+ int actionCount = 10;
+ RAgent agent = new RAgent(stateCount, actionCount);
- IndexValue actionValue = agent.selectAction();
- int actionId = actionValue.getIndex();
- System.out.println("Agent does action-"+actionId);
+ Random random = new Random();
+ agent.start(random.nextInt(stateCount));
+ for (int time = 0; time < 1000; ++time) {
- int newStateId = random.nextInt(actionCount);
- double reward = random.nextDouble();
+ IndexValue actionValue = agent.selectAction();
+ int actionId = actionValue.getIndex();
+ System.out.println("Agent does action-" + actionId);
- System.out.println("Now the new state is "+newStateId);
- System.out.println("Agent receives Reward = "+reward);
+ int newStateId = random.nextInt(actionCount);
+ double reward = random.nextDouble();
- agent.update(newStateId, reward);
- }
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Agent receives Reward = " + reward);
- RLearner learner = agent.getLearner();
- RLearner learner2 = RLearner.fromJson(learner.toJson());
+ agent.update(newStateId, reward);
+ }
- assertThat(learner).isEqualTo(learner2);
+ RLearner learner = agent.getLearner();
+ RLearner learner2 = RLearner.fromJson(learner.toJson());
+ assertThat(learner).isEqualTo(learner2);
- }
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java
index 0959b33..fa82642 100644
--- a/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/sarsa/SarsaAgentUnitTest.java
@@ -1,6 +1,5 @@
package com.github.chen0040.rl.learning.sarsa;
-
import org.testng.annotations.Test;
import java.util.Random;
@@ -8,39 +7,38 @@
import static org.assertj.core.api.Java6Assertions.assertThat;
import static org.testng.Assert.*;
-
/**
* Created by xschen on 6/5/2017.
*/
public class SarsaAgentUnitTest {
- @Test
- public void test_sarsa(){
- int stateCount = 100;
- int actionCount = 10;
- SarsaAgent agent = new SarsaAgent(stateCount, actionCount);
+ @Test
+ public void test_sarsa() {
+ int stateCount = 100;
+ int actionCount = 10;
+ SarsaAgent agent = new SarsaAgent(stateCount, actionCount);
- double reward = 0; //immediate reward by transiting from prevState to currentState
- Random random = new Random();
- agent.start(random.nextInt(stateCount));
- int actionTaken = agent.selectAction().getIndex();
- for(int time=0; time < 1000; ++time){
+ double reward = 0; //immediate reward by transiting from prevState to currentState
+ Random random = new Random();
+ agent.start(random.nextInt(stateCount));
+ int actionTaken = agent.selectAction().getIndex();
+ for (int time = 0; time < 1000; ++time) {
- System.out.println("Agent does action-"+actionTaken);
+ System.out.println("Agent does action-" + actionTaken);
- int newStateId = random.nextInt(actionCount);
- reward = random.nextDouble();
+ int newStateId = random.nextInt(actionCount);
+ reward = random.nextDouble();
- System.out.println("Now the new state is "+newStateId);
- System.out.println("Agent receives Reward = "+reward);
+ System.out.println("Now the new state is " + newStateId);
+ System.out.println("Agent receives Reward = " + reward);
- agent.update(actionTaken, newStateId, reward);
- }
+ agent.update(actionTaken, newStateId, reward);
+ }
- SarsaLearner learner = agent.getLearner();
+ SarsaLearner learner = agent.getLearner();
- SarsaLearner learner2 = SarsaLearner.fromJson(learner.toJson());
+ SarsaLearner learner2 = SarsaLearner.fromJson(learner.toJson());
- assertThat(learner2).isEqualTo(learner);
- }
+ assertThat(learner2).isEqualTo(learner);
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java
index 2e9cf52..a613388 100644
--- a/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/utils/MatrixUnitTest.java
@@ -1,8 +1,10 @@
package com.github.chen0040.rl.learning.utils;
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.serializer.SerializerFeature;
+//import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.utils.Matrix;
+import com.google.gson.Gson;
+
import org.testng.annotations.Test;
import java.util.Random;
@@ -11,55 +13,55 @@
public class MatrixUnitTest {
- private static final Random random = new Random(42);
+ private static final Random random = new Random(42);
- @Test
- public void testJsonSerialization() {
- Matrix matrix = new Matrix(10, 10);
- matrix.set(0, 0, 10);
- matrix.set(4, 2, 2);
- matrix.set(3, 3, 2);
+ @Test
+ public void testJsonSerialization() {
+ Matrix matrix = new Matrix(10, 10);
+ matrix.set(0, 0, 10);
+ matrix.set(4, 2, 2);
+ matrix.set(3, 3, 2);
- assertThat(matrix.get(0, 0)).isEqualTo(10);
- assertThat(matrix.get(4, 2)).isEqualTo(2);
- assertThat(matrix.get(3, 3)).isEqualTo(2);
- assertThat(matrix.get(4, 4)).isEqualTo(0);
+ assertThat(matrix.get(0, 0)).isEqualTo(10);
+ assertThat(matrix.get(4, 2)).isEqualTo(2);
+ assertThat(matrix.get(3, 3)).isEqualTo(2);
+ assertThat(matrix.get(4, 4)).isEqualTo(0);
- assertThat(matrix.getRowCount()).isEqualTo(10);
- assertThat(matrix.getColumnCount()).isEqualTo(10);
+ assertThat(matrix.getRowCount()).isEqualTo(10);
+ assertThat(matrix.getColumnCount()).isEqualTo(10);
- String json = JSON.toJSONString(matrix, SerializerFeature.PrettyFormat);
+ String json = new Gson().toJson(matrix); //JSON.toJSONString(matrix, SerializerFeature.PrettyFormat);
- System.out.println(json);
- Matrix matrix2 = JSON.parseObject(json, Matrix.class);
- assertThat(matrix).isEqualTo(matrix2);
+ System.out.println(json);
+ Matrix matrix2 = new Gson().fromJson(json, Matrix.class); //JSON.parseObject(json, Matrix.class);
+ assertThat(matrix).isEqualTo(matrix2);
- for(int i=0; i < matrix.getRowCount(); ++i){
- for(int j=0; j < matrix.getColumnCount(); ++j) {
- assertThat(matrix.get(i, j)).isEqualTo(matrix2.get(i, j));
- }
- }
- }
+ for (int i = 0; i < matrix.getRowCount(); ++i) {
+ for (int j = 0; j < matrix.getColumnCount(); ++j) {
+ assertThat(matrix.get(i, j)).isEqualTo(matrix2.get(i, j));
+ }
+ }
+ }
- @Test
- public void testJsonSerialization_Random() {
- Matrix matrix = new Matrix(10, 10);
- for(int i=0; i < matrix.getRowCount(); ++i){
- for(int j=0; j < matrix.getColumnCount(); ++j){
- matrix.set(i, j, random.nextDouble());
- }
- }
- Matrix matrix2 = matrix.makeCopy();
- assertThat(matrix).isEqualTo(matrix2);
+ @Test
+ public void testJsonSerialization_Random() {
+ Matrix matrix = new Matrix(10, 10);
+ for (int i = 0; i < matrix.getRowCount(); ++i) {
+ for (int j = 0; j < matrix.getColumnCount(); ++j) {
+ matrix.set(i, j, random.nextDouble());
+ }
+ }
+ Matrix matrix2 = matrix.makeCopy();
+ assertThat(matrix).isEqualTo(matrix2);
- String json = JSON.toJSONString(matrix);
- Matrix matrix3 = JSON.parseObject(json, Matrix.class);
- assertThat(matrix2).isEqualTo(matrix3);
+ String json = new Gson().toJson(matrix); //JSON.toJSONString(matrix);
+ Matrix matrix3 = new Gson().fromJson(json, Matrix.class); //JSON.parseObject(json, Matrix.class);
+ assertThat(matrix2).isEqualTo(matrix3);
- for(int i=0; i < matrix.getRowCount(); ++i){
- for(int j=0; j < matrix.getColumnCount(); ++j){
- assertThat(matrix2.get(i, j)).isEqualTo(matrix3.get(i, j));
- }
- }
- }
+ for (int i = 0; i < matrix.getRowCount(); ++i) {
+ for (int j = 0; j < matrix.getColumnCount(); ++j) {
+ assertThat(matrix2.get(i, j)).isEqualTo(matrix3.get(i, j));
+ }
+ }
+ }
}
diff --git a/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java b/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java
index daa4396..40c87bd 100644
--- a/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java
+++ b/src/test/java/com/github/chen0040/rl/learning/utils/VecUnitTest.java
@@ -1,20 +1,22 @@
package com.github.chen0040.rl.learning.utils;
-import com.alibaba.fastjson.JSON;
+//import com.alibaba.fastjson.JSON;
import com.github.chen0040.rl.utils.Vec;
+import com.google.gson.Gson;
+
import org.testng.annotations.Test;
import static org.assertj.core.api.Java6Assertions.assertThat;
public class VecUnitTest {
- @Test
- public void testJsonSerialization() {
- Vec vec = new Vec(100);
- vec.set(9, 100);
- vec.set(11, 2);
- vec.set(0, 1);
- String json = JSON.toJSONString(vec);
- Vec vec2 = JSON.parseObject(json, Vec.class);
- assertThat(vec).isEqualTo(vec2);
- }
+ @Test
+ public void testJsonSerialization() {
+ Vec vec = new Vec(100);
+ vec.set(9, 100);
+ vec.set(11, 2);
+ vec.set(0, 1);
+ String json = new Gson().toJson(vec); //JSON.toJSONString(vec);
+ Vec vec2 = new Gson().fromJson(json, Vec.class); //JSON.parseObject(json, Vec.class);
+ assertThat(vec).isEqualTo(vec2);
+ }
}
diff --git a/src/test/resources/log4j.properties b/src/test/resources/log4j.properties
index ef69b72..98799c4 100644
--- a/src/test/resources/log4j.properties
+++ b/src/test/resources/log4j.properties
@@ -1,9 +1,7 @@
# Set root logger level to DEBUG and its only appender to A1.
log4j.rootLogger=DEBUG, A1
-
# A1 is set to be a ConsoleAppender.
log4j.appender.A1=org.apache.log4j.ConsoleAppender
-
# A1 uses PatternLayout.
log4j.appender.A1.layout=org.apache.log4j.PatternLayout
log4j.appender.A1.layout.ConversionPattern=%-5p %c %x - %m%n