|
| 1 | +(ns gen.distribution |
| 2 | + "Collection of protocols and functions for working with primitive |
| 3 | + distributions." |
| 4 | + (:require [gen.dynamic.choice-map :as cm] |
| 5 | + [gen.generative-function :as gf] |
| 6 | + [gen.dynamic.trace :as trace]) |
| 7 | + #?(:clj |
| 8 | + (:import (clojure.lang IFn)))) |
| 9 | + |
| 10 | +;; ## Protocols |
| 11 | +;; |
| 12 | +;; Any distribution that can implement [[logpdf]] and [[sample]] can implement |
| 13 | +;; Gen.clj's generative function interface. These protocols are the way in to the |
| 14 | + |
| 15 | +(defprotocol LogPDF |
| 16 | + (logpdf [this v] |
| 17 | + "Returns the log-likelihood of observing the value `v` given the |
| 18 | + distribution `this`.")) |
| 19 | + |
| 20 | +(defprotocol Sample |
| 21 | + (sample [this] |
| 22 | + "Returns a single value sampled from the distribution `this`.")) |
| 23 | + |
| 24 | +(defn distribution? |
| 25 | + "Returns true if `t` implements [[LogPDF]] and [[Sample]], false otherwise." |
| 26 | + [t] |
| 27 | + (and (satisfies? LogPDF t) |
| 28 | + (satisfies? Sample t))) |
| 29 | + |
| 30 | +;; ## Primitive Generative Functions |
| 31 | + |
| 32 | +;; The [[gen.distribution/GenerativeFn]] type wraps a constructor `ctor` (a |
| 33 | +;; function of args `xs` that returns a statistical distribution) into an object |
| 34 | +;; that |
| 35 | +;; |
| 36 | +;; - acts as a function from `ctor`'s arguments to a single sample |
| 37 | +;; - implements the generative function interface defined |
| 38 | +;; in [[gen.generative-function]]. |
| 39 | +;; |
| 40 | +;; This type provides support for all primitive distributions. |
| 41 | + |
| 42 | +(defrecord GenerativeFn [ctor] |
| 43 | + gf/Simulate |
| 44 | + (simulate [this args] |
| 45 | + (let [dist (apply ctor args) |
| 46 | + val (sample dist) |
| 47 | + score (logpdf dist val)] |
| 48 | + (trace/->PrimitiveTrace this args val score))) |
| 49 | + |
| 50 | + gf/Generate |
| 51 | + (generate [gf args] |
| 52 | + {:weight 0.0 |
| 53 | + :trace (gf/simulate gf args)}) |
| 54 | + |
| 55 | + (generate [gf args constraint] |
| 56 | + (assert (cm/choice? constraint)) |
| 57 | + (let [dist (apply ctor args) |
| 58 | + val (cm/unwrap constraint) |
| 59 | + weight (logpdf dist val)] |
| 60 | + {:weight weight |
| 61 | + :trace (trace/->PrimitiveTrace gf args val weight)})) |
| 62 | + |
| 63 | + #?@(:clj |
| 64 | + [IFn |
| 65 | + (invoke [_] |
| 66 | + (sample (ctor))) |
| 67 | + (invoke [_ a] |
| 68 | + (sample (ctor a))) |
| 69 | + (invoke [_ a b] |
| 70 | + (sample (ctor a b))) |
| 71 | + (invoke [_ a b c] |
| 72 | + (sample (ctor a b c))) |
| 73 | + (invoke [_ a b c d] |
| 74 | + (sample (ctor a b c d))) |
| 75 | + (invoke [_ a b c d e] |
| 76 | + (sample (ctor a b c d e))) |
| 77 | + (invoke [_ a b c d e f] |
| 78 | + (sample (ctor a b c d e f))) |
| 79 | + (invoke [_ a b c d e f g] |
| 80 | + (sample (ctor a b c d e f g))) |
| 81 | + (invoke [_ a b c d e f g h] |
| 82 | + (sample (ctor a b c d e f g h))) |
| 83 | + (invoke [_ a b c d e f g h i] |
| 84 | + (sample (ctor a b c d e f g h i))) |
| 85 | + (invoke [_ a b c d e f g h i j] |
| 86 | + (sample (ctor a b c d e f g h i j))) |
| 87 | + (invoke [_ a b c d e f g h i j k] |
| 88 | + (sample (ctor a b c d e f g h i j k))) |
| 89 | + (invoke [_ a b c d e f g h i j k l] |
| 90 | + (sample (ctor a b c d e f g h i j k l))) |
| 91 | + (invoke [_ a b c d e f g h i j k l m] |
| 92 | + (sample (ctor a b c d e f g h i j k l m))) |
| 93 | + (invoke [_ a b c d e f g h i j k l m n] |
| 94 | + (sample (ctor a b c d e f g h i j k l m n))) |
| 95 | + (invoke [_ a b c d e f g h i j k l m n o] |
| 96 | + (sample (ctor a b c d e f g h i j k l m n o))) |
| 97 | + (invoke [_ a b c d e f g h i j k l m n o p] |
| 98 | + (sample (ctor a b c d e f g h i j k l m n o p))) |
| 99 | + (invoke [_ a b c d e f g h i j k l m n o p q] |
| 100 | + (sample (ctor a b c d e f g h i j k l m n o p q))) |
| 101 | + (invoke [_ a b c d e f g h i j k l m n o p q r] |
| 102 | + (sample (ctor a b c d e f g h i j k l m n o p q r))) |
| 103 | + (invoke [_ a b c d e f g h i j k l m n o p q r s] |
| 104 | + (sample (ctor a b c d e f g h i j k l m n o p q r s))) |
| 105 | + (invoke [_ a b c d e f g h i j k l m n o p q r s t] |
| 106 | + (sample (ctor a b c d e f g h i j k l m n o p q r s t))) |
| 107 | + (invoke [_ a b c d e f g h i j k l m n o p q r s t rest] |
| 108 | + (sample (apply ctor a b c d e f g h i j k l m n o p q r s t rest))) |
| 109 | + (applyTo [_ xs] |
| 110 | + (sample (apply ctor xs)))] |
| 111 | + |
| 112 | + :cljs |
| 113 | + [IFn |
| 114 | + (-invoke [_] |
| 115 | + (sample (ctor))) |
| 116 | + (-invoke [_ a] |
| 117 | + (sample (ctor a))) |
| 118 | + (-invoke [_ a b] |
| 119 | + (sample (ctor a b))) |
| 120 | + (-invoke [_ a b c] |
| 121 | + (sample (ctor a b c))) |
| 122 | + (-invoke [_ a b c d] |
| 123 | + (sample (ctor a b c d))) |
| 124 | + (-invoke [_ a b c d e] |
| 125 | + (sample (ctor a b c d e))) |
| 126 | + (-invoke [_ a b c d e f] |
| 127 | + (sample (ctor a b c d e f))) |
| 128 | + (-invoke [_ a b c d e f g] |
| 129 | + (sample (ctor a b c d e f g))) |
| 130 | + (-invoke [_ a b c d e f g h] |
| 131 | + (sample (ctor a b c d e f g h))) |
| 132 | + (-invoke [_ a b c d e f g h i] |
| 133 | + (sample (ctor a b c d e f g h i))) |
| 134 | + (-invoke [_ a b c d e f g h i j] |
| 135 | + (sample (ctor a b c d e f g h i j))) |
| 136 | + (-invoke [_ a b c d e f g h i j k] |
| 137 | + (sample (ctor a b c d e f g h i j k))) |
| 138 | + (-invoke [_ a b c d e f g h i j k l] |
| 139 | + (sample (ctor a b c d e f g h i j k l))) |
| 140 | + (-invoke [_ a b c d e f g h i j k l m] |
| 141 | + (sample (ctor a b c d e f g h i j k l m))) |
| 142 | + (-invoke [_ a b c d e f g h i j k l m n] |
| 143 | + (sample (ctor a b c d e f g h i j k l m n))) |
| 144 | + (-invoke [_ a b c d e f g h i j k l m n o] |
| 145 | + (sample (ctor a b c d e f g h i j k l m n o))) |
| 146 | + (-invoke [_ a b c d e f g h i j k l m n o p] |
| 147 | + (sample (ctor a b c d e f g h i j k l m n o p))) |
| 148 | + (-invoke [_ a b c d e f g h i j k l m n o p q] |
| 149 | + (sample (ctor a b c d e f g h i j k l m n o p q))) |
| 150 | + (-invoke [_ a b c d e f g h i j k l m n o p q r] |
| 151 | + (sample (ctor a b c d e f g h i j k l m n o p q r))) |
| 152 | + (-invoke [_ a b c d e f g h i j k l m n o p q r s] |
| 153 | + (sample (ctor a b c d e f g h i j k l m n o p q r s))) |
| 154 | + (-invoke [_ a b c d e f g h i j k l m n o p q r s t] |
| 155 | + (sample (ctor a b c d e f g h i j k l m n o p q r s t))) |
| 156 | + (-invoke [_ a b c d e f g h i j k l m n o p q r s t rest] |
| 157 | + (sample (apply ctor a b c d e f g h i j k l m n o p q r s t rest)))])) |
| 158 | + |
| 159 | +;; ## Combinators |
| 160 | +;; |
| 161 | +;; The [[Encoded]] type creates a new distribution from a base distribution |
| 162 | +;; `dist`. This new distribution transforms values on the way in to `logpdf` |
| 163 | +;; using an `encode` function, and decodes sampled values via `decode`. |
| 164 | +;; |
| 165 | +;; This is useful for building distributions like categorical distributions that |
| 166 | +;; might produce and score arbitrary Clojure values, but lean on some existing |
| 167 | +;; numeric base implementation. |
| 168 | + |
| 169 | +(defrecord Encoded [dist encode decode] |
| 170 | + LogPDF |
| 171 | + (logpdf [_ v] |
| 172 | + (logpdf dist (encode v))) |
| 173 | + |
| 174 | + Sample |
| 175 | + (sample [_] |
| 176 | + (decode (sample dist)))) |
| 177 | + |
| 178 | +(defn encoded |
| 179 | + "Given a distribution-producing function `ctor`, returns a constructor for a new |
| 180 | + distribution that |
| 181 | +
|
| 182 | + - encodes each value `v` into `(encode v)` before passage to [[logpdf]] |
| 183 | + - decodes each value `v` sampled from the base distribution into `(decode |
| 184 | + v)`" |
| 185 | + [ctor encode decode] |
| 186 | + (comp #(->Encoded % encode decode) ctor)) |
0 commit comments