|
1 | | -use outbound_redis::*; |
2 | | -use owning_ref::RwLockReadGuardRef; |
3 | | -use redis::Commands; |
4 | | -use std::{ |
5 | | - collections::HashMap, |
6 | | - sync::{Arc, Mutex, RwLock}, |
7 | | -}; |
| 1 | +use std::{collections::HashMap, sync::Arc}; |
8 | 2 |
|
9 | | -pub use outbound_redis::add_to_linker; |
10 | | -use spin_engine::{ |
11 | | - host_component::{HostComponent, HostComponentsStateHandle}, |
12 | | - RuntimeContext, |
13 | | -}; |
14 | | -use wit_bindgen_wasmtime::{async_trait, wasmtime::Linker}; |
| 3 | +use anyhow::Result; |
| 4 | +use redis::{aio::Connection, AsyncCommands}; |
| 5 | +use spin_core::{HostComponent, Linker}; |
| 6 | +use tokio::sync::{Mutex, RwLock}; |
| 7 | +use wit_bindgen_wasmtime::async_trait; |
15 | 8 |
|
16 | 9 | wit_bindgen_wasmtime::export!({paths: ["../../wit/ephemeral/outbound-redis.wit"], async: *}); |
| 10 | +use outbound_redis::Error; |
17 | 11 |
|
18 | | -/// A simple implementation to support outbound Redis commands. |
| 12 | +#[derive(Clone, Default)] |
19 | 13 | pub struct OutboundRedis { |
20 | | - pub connections: Arc<RwLock<HashMap<String, Mutex<redis::Connection>>>>, |
| 14 | + connections: Arc<RwLock<HashMap<String, Arc<Mutex<Connection>>>>>, |
21 | 15 | } |
22 | 16 |
|
23 | 17 | impl HostComponent for OutboundRedis { |
24 | | - type State = Self; |
| 18 | + type Data = Self; |
25 | 19 |
|
26 | 20 | fn add_to_linker<T: Send>( |
27 | | - linker: &mut Linker<RuntimeContext<T>>, |
28 | | - state_handle: HostComponentsStateHandle<Self::State>, |
| 21 | + linker: &mut Linker<T>, |
| 22 | + get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static, |
29 | 23 | ) -> anyhow::Result<()> { |
30 | | - add_to_linker(linker, move |ctx| state_handle.get_mut(ctx)) |
| 24 | + crate::outbound_redis::add_to_linker(linker, get) |
31 | 25 | } |
32 | 26 |
|
33 | | - fn build_state(&self, component: &spin_manifest::CoreComponent) -> anyhow::Result<Self::State> { |
34 | | - let mut conn_map = HashMap::new(); |
35 | | - if let Some(address) = component.wasm.environment.get("REDIS_ADDRESS") { |
36 | | - let client = redis::Client::open(address.to_string())?; |
37 | | - let conn = client.get_connection()?; |
38 | | - conn_map.insert(address.to_owned(), Mutex::new(conn)); |
39 | | - } |
40 | | - Ok(Self { |
41 | | - connections: Arc::new(RwLock::new(conn_map)), |
42 | | - }) |
| 27 | + fn build_data(&self) -> Self::Data { |
| 28 | + self.clone() |
43 | 29 | } |
44 | 30 | } |
45 | 31 |
|
46 | | -// TODO: use spawn_blocking or async client methods (redis::aio) |
47 | 32 | #[async_trait] |
48 | 33 | impl outbound_redis::OutboundRedis for OutboundRedis { |
49 | 34 | async fn publish(&mut self, address: &str, channel: &str, payload: &[u8]) -> Result<(), Error> { |
50 | | - let conn_map = self.get_reused_conn_map(address)?; |
51 | | - let mut conn = conn_map |
52 | | - .get(address) |
53 | | - .unwrap() |
54 | | - .lock() |
55 | | - .map_err(|_| Error::Error)?; |
56 | | - conn.publish(channel, payload).map_err(|_| Error::Error)?; |
| 35 | + let conn = self.get_conn(address).await.map_err(log_error)?; |
| 36 | + conn.lock() |
| 37 | + .await |
| 38 | + .publish(channel, payload) |
| 39 | + .await |
| 40 | + .map_err(log_error)?; |
57 | 41 | Ok(()) |
58 | 42 | } |
59 | 43 |
|
60 | 44 | async fn get(&mut self, address: &str, key: &str) -> Result<Vec<u8>, Error> { |
61 | | - let conn_map = self.get_reused_conn_map(address)?; |
62 | | - let mut conn = conn_map |
63 | | - .get(address) |
64 | | - .unwrap() |
65 | | - .lock() |
66 | | - .map_err(|_| Error::Error)?; |
67 | | - let value = conn.get(key).map_err(|_| Error::Error)?; |
| 45 | + let conn = self.get_conn(address).await.map_err(log_error)?; |
| 46 | + let value = conn.lock().await.get(key).await.map_err(log_error)?; |
68 | 47 | Ok(value) |
69 | 48 | } |
70 | 49 |
|
71 | 50 | async fn set(&mut self, address: &str, key: &str, value: &[u8]) -> Result<(), Error> { |
72 | | - let conn_map = self.get_reused_conn_map(address)?; |
73 | | - let mut conn = conn_map |
74 | | - .get(address) |
75 | | - .unwrap() |
76 | | - .lock() |
77 | | - .map_err(|_| Error::Error)?; |
78 | | - conn.set(key, value).map_err(|_| Error::Error)?; |
| 51 | + let conn = self.get_conn(address).await.map_err(log_error)?; |
| 52 | + conn.lock().await.set(key, value).await.map_err(log_error)?; |
79 | 53 | Ok(()) |
80 | 54 | } |
81 | 55 |
|
82 | 56 | async fn incr(&mut self, address: &str, key: &str) -> Result<i64, Error> { |
83 | | - let conn_map = self.get_reused_conn_map(address)?; |
84 | | - let mut conn = conn_map |
85 | | - .get(address) |
86 | | - .unwrap() |
87 | | - .lock() |
88 | | - .map_err(|_| Error::Error)?; |
89 | | - let value = conn.incr(key, 1).map_err(|_| Error::Error)?; |
| 57 | + let conn = self.get_conn(address).await.map_err(log_error)?; |
| 58 | + let value = conn.lock().await.incr(key, 1).await.map_err(log_error)?; |
90 | 59 | Ok(value) |
91 | 60 | } |
92 | 61 | } |
93 | 62 |
|
94 | 63 | impl OutboundRedis { |
95 | | - fn get_reused_conn_map<'ret, 'me: 'ret, 'c>( |
96 | | - &'me mut self, |
97 | | - address: &'c str, |
98 | | - ) -> Result<RwLockReadGuardRef<'ret, HashMap<String, Mutex<redis::Connection>>>, Error> { |
99 | | - let conn_map = self.connections.read().map_err(|_| Error::Error)?; |
100 | | - if conn_map.get(address).is_some() { |
101 | | - tracing::debug!("Reuse connection: {:?}", address); |
102 | | - return Ok(RwLockReadGuardRef::new(conn_map)); |
103 | | - } |
104 | | - // Get rid of our read lock |
105 | | - drop(conn_map); |
106 | | - |
107 | | - let mut conn_map = self.connections.write().map_err(|_| Error::Error)?; |
108 | | - let client = redis::Client::open(address).map_err(|_| Error::Error)?; |
109 | | - let conn = client.get_connection().map_err(|_| Error::Error)?; |
110 | | - tracing::debug!("Build new connection: {:?}", address); |
111 | | - conn_map.insert(address.to_string(), Mutex::new(conn)); |
112 | | - // Get rid of our write lock |
113 | | - drop(conn_map); |
114 | | - |
115 | | - let conn_map = self.connections.read().map_err(|_| Error::Error)?; |
116 | | - Ok(RwLockReadGuardRef::new(conn_map)) |
| 64 | + async fn get_conn(&self, address: &str) -> Result<Arc<Mutex<Connection>>> { |
| 65 | + let conn_map = self.connections.read().await; |
| 66 | + let conn = if let Some(conn) = conn_map.get(address) { |
| 67 | + conn.clone() |
| 68 | + } else { |
| 69 | + let conn = redis::Client::open(address)?.get_async_connection().await?; |
| 70 | + let conn = Arc::new(Mutex::new(conn)); |
| 71 | + self.connections |
| 72 | + .write() |
| 73 | + .await |
| 74 | + .insert(address.to_string(), conn.clone()); |
| 75 | + conn |
| 76 | + }; |
| 77 | + Ok(conn) |
117 | 78 | } |
118 | 79 | } |
| 80 | + |
| 81 | +fn log_error(err: impl std::fmt::Debug) -> Error { |
| 82 | + tracing::warn!("Outbound Redis error: {err:?}"); |
| 83 | + Error::Error |
| 84 | +} |
0 commit comments