Skip to content

Commit 074fffe

Browse files
committed
wip: continue db refactor
1 parent f3949dd commit 074fffe

File tree

20 files changed

+250
-98
lines changed

20 files changed

+250
-98
lines changed

Cargo.lock

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ colored = "3.0.0"
4747
futures = "0.3.30"
4848
futures-util = "0.3.31"
4949
image = "0.25.5"
50-
native_db = { git = "https://github.com/cilki/native_db" }
50+
native_db = { git = "https://github.com/cilki/native_db", features = ["tokio"] }
5151
native_model = "0.6.1"
5252
os_info = "3.8.2"
5353
pem = "3.0.4"
@@ -66,6 +66,7 @@ tempfile = "3.17.0"
6666
time = { version = "0.3.37" }
6767
tokio-rustls = "0.26.0"
6868
tokio = { version = "1.42.0", features = ["full"] }
69+
tokio-util = "0.7.15"
6970
tokio-stream = "0.1.17"
7071
tower = "0.5.2"
7172
tracing = "0.1.40"

sandpolis-database/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ serde_cbor = { workspace = true }
1717
serde = { workspace = true }
1818
tempfile = { workspace = true }
1919
tracing = { workspace = true }
20+
tokio = { workspace = true }
21+
tokio-util = { workspace = true }
2022
validator = { workspace = true }

sandpolis-database/src/lib.rs

Lines changed: 107 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@ use chrono::{DateTime, Utc};
77
use native_db::db_type::{KeyDefinition, KeyOptions, ToKeyDefinition};
88
use native_db::transaction::RTransaction;
99
use native_db::transaction::query::SecondaryScanIterator;
10+
use native_db::watch::Event;
1011
use native_db::{Key, Models, ToInput, ToKey};
1112
use sandpolis_core::{GroupName, InstanceId};
1213
use serde::{Deserialize, Serialize};
1314
use std::collections::{BTreeMap, HashMap};
1415
use std::ops::Range;
1516
use std::path::Path;
1617
use std::sync::atomic::AtomicU64;
17-
use std::sync::{RwLock, RwLockReadGuard};
1818
use std::{marker::PhantomData, sync::Arc};
19+
use tokio::sync::{RwLock, RwLockReadGuard};
20+
use tokio_util::sync::CancellationToken;
1921
use tracing::instrument::WithSubscriber;
2022
use tracing::{debug, trace};
2123

@@ -48,9 +50,12 @@ impl DatabaseLayer {
4850
}
4951

5052
/// Load or create a new database for the given group.
51-
pub fn add_group(&mut self, name: GroupName) -> Result<Arc<native_db::Database<'static>>> {
53+
pub async fn add_group(
54+
&mut self,
55+
name: GroupName,
56+
) -> Result<Arc<native_db::Database<'static>>> {
5257
// Check for duplicates
53-
let mut inner = self.inner.write().unwrap();
58+
let mut inner = self.inner.write().await;
5459
if inner.contains_key(&Some(name.clone())) {
5560
bail!("Duplicate group");
5661
}
@@ -69,8 +74,8 @@ impl DatabaseLayer {
6974
Ok(db)
7075
}
7176

72-
pub fn get(&self, name: Option<GroupName>) -> Result<Arc<native_db::Database<'static>>> {
73-
let inner = self.inner.read().unwrap();
77+
pub async fn get(&self, name: Option<GroupName>) -> Result<Arc<native_db::Database<'static>>> {
78+
let inner = self.inner.read().await;
7479
if let Some(db) = inner.get(&name) {
7580
return Ok(db.clone());
7681
}
@@ -81,7 +86,7 @@ impl DatabaseLayer {
8186
/// `Data` is what's stored in a database!
8287
pub trait Data
8388
where
84-
Self: ToInput + Default + Clone + PartialEq,
89+
Self: ToInput + Default + Clone + PartialEq + Send + Sync,
8590
{
8691
fn id(&self) -> DataIdentifier;
8792
fn set_id(&mut self, id: DataIdentifier);
@@ -160,6 +165,8 @@ impl ToKey for DbTimestamp {
160165
pub type DataIdentifier = u64;
161166
pub type DataExpiration = DateTime<Utc>;
162167

168+
// TODO rename Resident?
169+
163170
/// Maintains a real-time cache of persistent objects in the database of
164171
/// type `T`.
165172
#[derive(Clone)]
@@ -169,11 +176,23 @@ where
169176
{
170177
db: Arc<native_db::Database<'static>>,
171178
cache: Arc<RwLock<T>>,
179+
180+
/// Used to stop the database from sending updates
172181
watch_id: u64,
182+
183+
/// Allows the background update thread to be stopped
184+
cancel_token: CancellationToken,
173185
}
174186

175-
impl<T: Data> Watch<T> {
176-
/// Create a new `DataCache` when there's only one row in the database.
187+
impl<T: Data> Drop for Watch<T> {
188+
fn drop(&mut self) {
189+
self.cancel_token.cancel();
190+
self.db.unwatch(self.watch_id).unwrap();
191+
}
192+
}
193+
194+
impl<T: Data + 'static> Watch<T> {
195+
/// Create a new `Watch` when there's only one row in the database.
177196
pub fn singleton(db: Arc<native_db::Database<'static>>) -> Result<Self> {
178197
let r = db.r_transaction()?;
179198
let mut rows: Vec<T> = r.scan().primary()?.all()?.try_collect()?;
@@ -192,26 +211,61 @@ impl<T: Data> Watch<T> {
192211
T::default()
193212
};
194213

195-
let (channel, watch_id) = db.watch().get().primary::<T>(item.id())?;
214+
let (mut channel, watch_id) = db.watch().get().primary::<T>(item.id())?;
215+
216+
let cancel_token = CancellationToken::new();
217+
let token = cancel_token.clone();
218+
219+
let cache = Arc::new(RwLock::new(item));
220+
221+
tokio::spawn({
222+
let cache_clone = Arc::clone(&cache);
223+
async move {
224+
loop {
225+
tokio::select! {
226+
_ = token.cancelled() => {
227+
break;
228+
}
229+
event = channel.recv() => match event {
230+
Some(event) => match event {
231+
Event::Insert(data) => {}
232+
Event::Update(data) => match data.inner_new() {
233+
Ok(d) => {
234+
let mut c = cache_clone.write().await;
235+
*c = d;
236+
},
237+
Err(_) => {},
238+
}
239+
Event::Delete(data) => {}
240+
}
241+
None => {
242+
break;
243+
}
244+
}
245+
}
246+
}
247+
}
248+
});
196249

197250
Ok(Self {
198-
cache: Arc::new(RwLock::new(item)),
251+
cache,
199252
watch_id,
253+
cancel_token,
200254
db,
201255
})
202256
}
203257

204-
pub fn read(&self) -> RwLockReadGuard<'_, T> {
205-
self.cache.read().unwrap()
258+
pub async fn read(&self) -> RwLockReadGuard<'_, T> {
259+
self.cache.read().await
206260
}
207261
}
208262

209263
impl<T: Data> Watch<T> {
210-
pub fn update<F>(&self, mutator: F) -> Result<()>
264+
pub async fn update<F>(&self, mutator: F) -> Result<()>
211265
where
212266
F: Fn(&mut T) -> Result<()>,
213267
{
214-
let cache = self.cache.read().unwrap();
268+
let cache = self.cache.read().await;
215269
let mut next = cache.clone();
216270
mutator(&mut next)?;
217271

@@ -222,7 +276,8 @@ impl<T: Data> Watch<T> {
222276

223277
drop(cache);
224278

225-
self.cache.set(next).unwrap();
279+
let mut cache = self.cache.write().await;
280+
*cache = next;
226281
}
227282

228283
Ok(())
@@ -248,6 +303,7 @@ mod test_database {
248303
use native_model::{Model, native_model};
249304
use sandpolis_macros::{Data, HistoricalData};
250305
use serde::{Deserialize, Serialize};
306+
use tokio::time::{Duration, sleep};
251307

252308
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Default, Data)]
253309
#[native_model(id = 5, version = 1)]
@@ -262,8 +318,8 @@ mod test_database {
262318
pub b: String,
263319
}
264320

265-
#[test]
266-
fn test_build_database() -> Result<()> {
321+
#[tokio::test]
322+
async fn test_build_database() -> Result<()> {
267323
let models = Box::leak(Box::new(Models::new()));
268324
models.define::<TestData>().unwrap();
269325

@@ -274,17 +330,19 @@ mod test_database {
274330
},
275331
models,
276332
)?;
277-
database.add_group("default".parse()?)?;
333+
database.add_group("default".parse()?).await?;
278334

279-
let db = database.get(Some("default".parse()?))?;
335+
let db = database.get(Some("default".parse()?)).await?;
280336
let watch: Watch<TestData> = Watch::singleton(db.clone())?;
281337

282338
// Update data a bunch of times
283339
for i in 1..10 {
284-
watch.update(|data| {
285-
data.a = format!("test {i}");
286-
Ok(())
287-
})?;
340+
watch
341+
.update(|data| {
342+
data.a = format!("test {i}");
343+
Ok(())
344+
})
345+
.await?;
288346
}
289347

290348
// Database should reflect "test 9"
@@ -299,16 +357,39 @@ mod test_database {
299357
{
300358
let rw = db.rw_transaction()?;
301359
rw.upsert(TestData {
302-
_id: watch.read()._id,
360+
_id: watch.read().await._id,
303361
a: "test 10".into(),
304362
b: "".into(),
305363
})?;
306364
rw.commit()?;
307365
}
308366

309-
// Watch should reflect "test 10"
310-
assert_eq!(watch.read().a, "test 10");
367+
// Watch should reflect "test 10" after a while
368+
sleep(Duration::from_secs(1)).await;
369+
assert_eq!(watch.read().await.a, "test 10");
311370

312371
Ok(())
313372
}
314373
}
374+
375+
#[derive(Clone)]
376+
pub struct TimeResVec<T>
377+
where
378+
T: HistoricalData,
379+
{
380+
db: Arc<native_db::Database<'static>>,
381+
cache: Arc<RwLock<T>>,
382+
383+
/// Used to stop the database from sending updates
384+
watch_id: u64,
385+
386+
/// Allows the background update thread to be stopped
387+
cancel_token: CancellationToken,
388+
}
389+
390+
impl<T: HistoricalData> Drop for TimeResVec<T> {
391+
fn drop(&mut self) {
392+
self.cancel_token.cancel();
393+
self.db.unwatch(self.watch_id).unwrap();
394+
}
395+
}

sandpolis-group/src/lib.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use native_db::*;
77
use native_model::{Model, native_model};
88
use regex::Regex;
99
use sandpolis_core::{ClusterId, GroupName};
10-
use sandpolis_database::{Data, DataIdentifier, DataView, DatabaseLayer};
10+
use sandpolis_database::{Data, DataIdentifier, DatabaseLayer, Watch};
1111
use sandpolis_instance::InstanceLayer;
1212
use sandpolis_macros::Data;
1313
use serde::{Deserialize, Serialize};
@@ -24,7 +24,7 @@ pub mod messages;
2424
#[cfg(feature = "server")]
2525
pub mod server;
2626

27-
#[derive(Serialize, Deserialize, Clone, Default, Data)]
27+
#[derive(Serialize, Deserialize, Clone, Default, PartialEq, Eq, Data)]
2828
#[native_model(id = 17, version = 1)]
2929
#[native_db]
3030
pub struct GroupLayerData {
@@ -36,19 +36,19 @@ pub struct GroupLayerData {
3636

3737
#[derive(Clone)]
3838
pub struct GroupLayer {
39-
database: DatabaseLayer,
39+
data: Watch<GroupLayerData>,
4040
}
4141

4242
impl GroupLayer {
43-
pub fn new(
43+
pub async fn new(
4444
config: GroupConfig,
4545
mut database: DatabaseLayer,
4646
instance: InstanceLayer,
4747
) -> Result<Self> {
4848
// Create default group if it doesn't exist
49-
if database.get(Some("default".parse()?)).is_err() {
49+
if database.get(Some("default".parse()?)).await.is_err() {
5050
debug!("Creating default group");
51-
let db = database.get(None)?;
51+
let db = database.get(None).await?;
5252
let rw = db.rw_transaction()?;
5353
if rw
5454
.get()
@@ -80,7 +80,7 @@ impl GroupLayer {
8080
}
8181

8282
// Load all group databases
83-
let db = database.get(None)?;
83+
let db = database.get(None).await?;
8484
let r = db.r_transaction()?;
8585
for group in r.scan().primary::<GroupData>()?.all()? {
8686
let group = group?;
@@ -103,15 +103,17 @@ impl GroupLayer {
103103
}
104104
}
105105

106-
Ok(Self { database })
106+
Ok(Self {
107+
data: Watch::singleton(db)?,
108+
})
107109
}
108110
}
109111

110112
/// A group is a set of clients and agents that can interact. Each group has a
111113
/// global CA certificate that signs certificates used to connect to the server.
112114
///
113115
/// All servers have a default group called "default".
114-
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Default, Data)]
116+
#[derive(Serialize, Deserialize, Validate, Debug, Clone, Default, PartialEq, Eq, Data)]
115117
#[native_model(id = 18, version = 1)]
116118
#[native_db]
117119
pub struct GroupData {
@@ -165,7 +167,7 @@ impl GroupServerCert {
165167

166168
/// A _client_ certificate (not as in "client" instance) used to authenticate
167169
/// with a server instance.
168-
#[derive(Serialize, Deserialize, Debug, Clone)]
170+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
169171
pub struct GroupClientCert {
170172
pub ca: String,
171173
pub cert: String,

0 commit comments

Comments
 (0)