Skip to content

Commit 2beadb6

Browse files
committed
feat: add kixi for cljc stats
1 parent e62668e commit 2beadb6

File tree

4 files changed

+213
-9
lines changed

4 files changed

+213
-9
lines changed

deps.edn

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
{:paths ["src" "resources"]
22

33
:deps
4-
{org.apache.commons/commons-math3 {:mvn/version "3.6.1"}
4+
{kixi/stats {:mvn/version "0.5.5"}
5+
org.apache.commons/commons-math3 {:mvn/version "3.6.1"}
56
org.clojure/clojure {:mvn/version "1.11.1"}}
67

78
:aliases

src/gen/distribution/commons_math.clj

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,26 @@
5454
0 false
5555
1 true)))))
5656

57-
(defn beta-distribution [^double alpha ^double beta]
58-
(BetaDistribution. (rng) alpha beta))
57+
(defn beta-distribution
58+
([] (beta-distribution 1.0 1.0))
59+
([^double alpha ^double beta]
60+
(BetaDistribution. (rng) alpha beta)))
5961

6062
(defn gamma-distribution [^double shape ^double scale]
6163
(GammaDistribution. (rng) shape scale))
6264

63-
(defn normal-distribution [^double mean ^double sd]
64-
(NormalDistribution. (rng) mean sd))
65+
(defn normal-distribution
66+
([] (normal-distribution 0.0 1.0))
67+
([^double mean ^double sd]
68+
(NormalDistribution. (rng) mean sd)))
6569

66-
(defn uniform-distribution [^double low ^double high]
67-
(UniformRealDistribution. (rng) low high))
70+
(defn uniform-distribution
71+
([] (uniform-distribution 0.0 1.0))
72+
([^double low ^double high]
73+
(UniformRealDistribution. (rng) low high)))
6874

6975
(defn uniform-discrete-distribution [low high]
70-
(UniformIntegerDistribution. low high))
76+
(UniformIntegerDistribution. (rng) low high))
7177

7278
(defn categorical-distribution [probabilities]
7379
(let [n (count probabilities)
@@ -78,7 +84,6 @@
7884
;; ## Primitive generative functions
7985

8086
(def bernoulli
81-
"Generative function... TODO flesh out."
8287
(d/->GenerativeFn bernoulli-distribution))
8388

8489
(def beta

src/gen/distribution/kixi.clj

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
(ns gen.distribution.kixi
2+
(:require [gen.distribution :as d]
3+
[gen.distribution.math.log-likelihood :as ll]
4+
[kixi.stats.distribution :as k])
5+
(:import (kixi.stats.distribution Bernoulli Cauchy
6+
Exponential Beta
7+
Gamma Normal Uniform)))
8+
9+
;; ## Kixi.stats protocol implementations
10+
;;
11+
;; NOTE: If we want to seed the PRNG here, the right way to do it is to create
12+
;; wrapper types that hold a kixi distribution instance and an RNG. Then,
13+
;; instead of `draw`, we can call `sample-1` with the distribution and RNG.
14+
15+
16+
(extend-type Bernoulli
17+
d/Sample
18+
(sample [this] (k/draw this))
19+
20+
d/LogPDF
21+
(logpdf [this v]
22+
(ll/bernoulli (.-p this) v)))
23+
24+
(extend-type Beta
25+
d/Sample
26+
(sample [this] (k/draw this))
27+
28+
d/LogPDF
29+
(logpdf [this v]
30+
(ll/beta (.-alpha this)
31+
(.-beta this)
32+
v)))
33+
34+
(extend-type Cauchy
35+
d/Sample
36+
(sample [this] (k/draw this))
37+
38+
d/LogPDF
39+
(logpdf [this v]
40+
(ll/cauchy (.-location this)
41+
(.-scale this)
42+
v)))
43+
44+
(extend-type Exponential
45+
d/Sample
46+
(sample [this] (k/draw this))
47+
48+
d/LogPDF
49+
(logpdf [this v]
50+
(ll/exponential (.-rate this) v)))
51+
52+
(extend-type Uniform
53+
d/Sample
54+
(sample [this] (k/draw this))
55+
56+
d/LogPDF
57+
(logpdf [this v]
58+
(let [min (.-a this)
59+
max (.-b this)]
60+
(ll/uniform min max v))))
61+
62+
(extend-type Gamma
63+
d/Sample
64+
(sample [this] (k/draw this))
65+
66+
d/LogPDF
67+
(logpdf [this v]
68+
(ll/gamma (.-shape this)
69+
(.-scale this)
70+
v)))
71+
72+
(extend-type Normal
73+
d/Sample
74+
(sample [this] (k/draw this))
75+
76+
d/LogPDF
77+
(logpdf [this v]
78+
(ll/gaussian (.-mu this)
79+
(.-sd this)
80+
v)))
81+
82+
;; ## Primitive probability distributions
83+
84+
(defn bernoulli-distribution
85+
([] (bernoulli-distribution 0.5))
86+
([p] (k/bernoulli {:p p})))
87+
88+
(defn beta-distribution
89+
([] (beta-distribution 1.0 1.0))
90+
([alpha beta]
91+
(k/beta {:alpha alpha :beta beta})))
92+
93+
(defn cauchy-distribution [location scale]
94+
(k/cauchy {:location location :scale scale}))
95+
96+
(defn exponential-distribution [rate]
97+
(k/exponential {:rate rate}))
98+
99+
(defn uniform-distribution
100+
([] (uniform-distribution 0.0 1.0))
101+
([lo hi]
102+
(k/uniform {:a lo :b hi})))
103+
104+
(defn normal-distribution
105+
([] (normal-distribution 0.0 1.0))
106+
([mu sigma]
107+
(k/normal {:mu mu :sd sigma})))
108+
109+
(defn gamma-distribution [shape scale]
110+
(k/gamma {:shape shape :scale scale}))
111+
112+
;; ## Primitive generative functions
113+
114+
(def bernoulli
115+
(d/->GenerativeFn bernoulli-distribution))
116+
117+
(def beta
118+
(d/->GenerativeFn beta-distribution))
119+
120+
(def cauchy
121+
(d/->GenerativeFn cauchy-distribution))
122+
123+
(def exponential
124+
(d/->GenerativeFn exponential-distribution))
125+
126+
(def uniform
127+
(d/->GenerativeFn uniform-distribution))
128+
129+
(def normal
130+
(d/->GenerativeFn normal-distribution))
131+
132+
(def gamma
133+
(d/->GenerativeFn gamma-distribution))
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
(ns gen.distribution.kixi-test
2+
(:require [clojure.math :as math]
3+
[clojure.test :refer [deftest is]]
4+
[gen]
5+
[gen.choice-map]
6+
[gen.diff :as diff]
7+
[gen.distribution.kixi :as d]
8+
[gen.generative-function :as gf]
9+
[gen.trace :as trace]))
10+
11+
(deftest bernoulli-call-no-args
12+
(is (boolean? (d/bernoulli))))
13+
14+
(deftest bernoulli-call-args
15+
(is (boolean? (d/bernoulli 0.5))))
16+
17+
(deftest bernoulli-gf
18+
(is (= d/bernoulli (trace/gf (gf/simulate d/bernoulli [])))))
19+
20+
(deftest bernoulli-args
21+
(is (= [0.5] (trace/args (gf/simulate d/bernoulli [0.5])))))
22+
23+
(deftest bernoulli-retval
24+
(is (boolean? (trace/retval (gf/simulate d/bernoulli [0.5])))))
25+
26+
(deftest bernoulli-choices-noargs
27+
(trace/choices (gf/simulate d/bernoulli [])))
28+
29+
(deftest bernoulli-update-weight
30+
(is (= 1.0
31+
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
32+
(:trace)
33+
(trace/update #gen/choice true)
34+
(:weight)
35+
(math/exp))))
36+
(is (= (/ 0.7 0.3)
37+
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
38+
(:trace)
39+
(trace/update #gen/choice false)
40+
(:weight)
41+
(math/exp)))))
42+
43+
(deftest bernoulli-update-discard
44+
(is (nil?
45+
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
46+
(:trace)
47+
(trace/update nil)
48+
(:discard))))
49+
(is (= #gen/choice true
50+
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
51+
(:trace)
52+
(trace/update #gen/choice false)
53+
(:discard)))))
54+
55+
(deftest bernoulli-update-change
56+
(is (= diff/unknown-change
57+
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
58+
(:trace)
59+
(trace/update nil)
60+
(:change))))
61+
(is (= diff/unknown-change
62+
(-> (gf/generate d/bernoulli [0.3] #gen/choice true)
63+
(:trace)
64+
(trace/update #gen/choice false)
65+
(:change)))))

0 commit comments

Comments
 (0)