Skip to content

Commit b0c7df4

Browse files
committed
Streamlit OLTP implement writeback
1 parent 259cab0 commit b0c7df4

File tree

1 file changed

+127
-67
lines changed

1 file changed

+127
-67
lines changed

streamlit/views/oltp_database_connect.py

Lines changed: 127 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111
st.header("OLTP Database", divider=True)
1212
st.subheader("Connect a table")
1313
st.write(
14-
"This app connects to a [Databricks Lakebase](https://docs.databricks.com/aws/en/oltp/) OLTP database instance. "
15-
"Provide the instance name, database, schema, and table."
14+
"This app connects to a [Databricks Lakebase](https://docs.databricks.com/aws/en/oltp/) OLTP database instance for reads and writes, e.g., of App state. "
15+
"Provide the instance name, database, schema, and state table."
1616
)
1717

1818

1919
w = WorkspaceClient()
2020

21+
session_id = str(uuid.uuid4())
22+
if "session_id" not in st.session_state:
23+
st.session_state["session_id"] = session_id
24+
2125

2226
def generate_token(instance_name: str) -> str:
2327
cred = w.database.generate_database_credential(
@@ -38,7 +42,7 @@ def connect(cls, conninfo: str = "", **kwargs):
3842

3943

4044
@st.cache_resource
41-
def build_pool(*, instance_name: str, host: str, user: str, database: str) -> ConnectionPool:
45+
def build_pool(instance_name: str, host: str, user: str, database: str) -> ConnectionPool:
4246
conninfo = f"host={host} dbname={database} user={user}"
4347
return ConnectionPool(
4448
conninfo=conninfo,
@@ -50,12 +54,30 @@ def build_pool(*, instance_name: str, host: str, user: str, database: str) -> Co
5054
)
5155

5256

57+
def upsert_app_state(pool, session_id: str, state: dict):
58+
with pool.connection() as conn:
59+
with conn.cursor() as cur:
60+
for key, value in state.items():
61+
cur.execute(f"""
62+
INSERT INTO app_state (session_id, key, value, updated_at)
63+
VALUES ('{session_id}', '{key}', '{value}', CURRENT_TIMESTAMP)
64+
ON CONFLICT (session_id, key) DO UPDATE
65+
SET value = EXCLUDED.value,
66+
updated_at = CURRENT_TIMESTAMP
67+
""")
68+
conn.commit()
69+
70+
5371
def query_df(pool: ConnectionPool, sql: str) -> pd.DataFrame:
5472
with pool.connection() as conn:
5573
with conn.cursor() as cur:
5674
cur.execute(sql)
75+
if not cur.description:
76+
return pd.DataFrame()
77+
5778
cols = [d.name for d in cur.description]
5879
rows = cur.fetchall()
80+
5981
return pd.DataFrame(rows, columns=cols)
6082

6183

@@ -64,9 +86,9 @@ def query_df(pool: ConnectionPool, sql: str) -> pd.DataFrame:
6486
with tab_try:
6587
instance_names = [i.name for i in w.database.list_database_instances()]
6688
instance_name = st.selectbox("Database instance:", instance_names)
67-
database = st.text_input("Database:", placeholder="customer_database")
68-
table = st.text_input("Table in a database schema:", placeholder="customer_core.customers_oltp")
69-
limit = st.text_input("Limit:", value=10)
89+
database = st.text_input("Database:", value="databricks_postgres")
90+
schema = st.text_input("Schema:", value="public")
91+
table = st.text_input("Table:", value="app_state")
7092

7193
user = w.current_user.me().user_name
7294
host = ""
@@ -77,75 +99,113 @@ def query_df(pool: ConnectionPool, sql: str) -> pd.DataFrame:
7799
if not all([instance_name, host, database, table]):
78100
st.error("Please provide instance, database, and schema-table.")
79101
else:
80-
pool = build_pool(instance_name=instance_name, host=host, user=user, database=database)
81-
sql = f"SELECT * FROM {table} LIMIT {int(limit)};"
82-
df = query_df(pool, sql)
83-
st.dataframe(df, use_container_width=True)
102+
pool = build_pool(instance_name, host, user, database)
103+
104+
create_table_sql = f"""
105+
CREATE TABLE IF NOT EXISTS {schema}.{table} (
106+
session_id TEXT,
107+
key TEXT,
108+
value TEXT,
109+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
110+
PRIMARY KEY (session_id, key)
111+
)
112+
"""
113+
with pool.connection() as conn:
114+
with conn.cursor() as cur:
115+
cur.execute(create_table_sql)
116+
conn.commit()
117+
118+
state = {"feedback_message": "true"}
119+
upsert_app_state(pool, session_id, state)
120+
121+
df = query_df(pool, f"SELECT * FROM {schema}.{table} WHERE session_id='{session_id}'")
122+
st.dataframe(df)
84123

85124
with tab_code:
86125
st.code(
87126
'''
88-
import uuid
89-
import streamlit as st
90-
import pandas as pd
91-
92-
from databricks.sdk import WorkspaceClient
93-
94-
import psycopg
95-
from psycopg_pool import ConnectionPool
96-
97-
98-
w = WorkspaceClient()
99-
100-
101-
def generate_token(instance_name: str) -> str:
102-
cred = w.database.generate_database_credential(
103-
request_id=str(uuid.uuid4()), instance_names=[instance_name]
104-
)
105-
106-
return cred.token
127+
import uuid
128+
import streamlit as st
129+
import pandas as pd
107130
108-
109-
class RotatingTokenConnection(psycopg.Connection):
110-
@classmethod
111-
def connect(cls, conninfo: str = "", **kwargs):
112-
instance_name = kwargs.pop("_instance_name")
113-
kwargs["password"] = generate_token(instance_name)
114-
kwargs.setdefault("sslmode", "require")
115-
return super().connect(conninfo, **kwargs)
131+
from databricks.sdk import WorkspaceClient
132+
import psycopg
133+
from psycopg_pool import ConnectionPool
116134
117135
118-
@st.cache_resource
119-
def build_pool(instance_name: str, host: str, user: str, database: str) -> ConnectionPool:
120-
return ConnectionPool(
121-
conninfo=f"host={host} dbname={database} user={user}",
122-
connection_class=RotatingTokenConnection,
123-
kwargs={"_instance_name": instance_name},
124-
min_size=1,
125-
max_size=10,
126-
open=True,
127-
)
128-
129-
130-
def query_df(pool: ConnectionPool, sql: str) -> pd.DataFrame:
131-
with pool.connection() as conn:
132-
with conn.cursor() as cur:
133-
cur.execute(sql)
134-
cols = [d.name for d in cur.description]
135-
rows = cur.fetchall()
136+
w = WorkspaceClient()
136137
137-
return pd.DataFrame(rows, columns=cols)
138-
139-
140-
instance_name = "dbase_instance"
141-
database = "customer_database"
142-
table = "customer_core.customers_oltp"
143-
user = w.current_user.me().user_name
144-
host = w.database.get_database_instance(name=instance_name).read_write_dns
138+
139+
class RotatingTokenConnection(psycopg.Connection):
140+
@classmethod
141+
def connect(cls, conninfo: str = "", **kwargs):
142+
kwargs["password"] = w.database.generate_database_credential(
143+
request_id=str(uuid.uuid4()),
144+
instance_names=[kwargs.pop("_instance_name")]
145+
).token
146+
kwargs.setdefault("sslmode", "require")
147+
return super().connect(conninfo, **kwargs)
148+
149+
150+
@st.cache_resource
151+
def build_pool(instance_name: str, host: str, user: str, database: str) -> ConnectionPool:
152+
return ConnectionPool(
153+
conninfo=f"host={host} dbname={database} user={user}",
154+
connection_class=RotatingTokenConnection,
155+
kwargs={"_instance_name": instance_name},
156+
min_size=1,
157+
max_size=5,
158+
open=True,
159+
)
160+
161+
162+
def query_df(pool: ConnectionPool, sql: str) -> pd.DataFrame:
163+
with pool.connection() as conn:
164+
with conn.cursor() as cur:
165+
cur.execute(sql)
166+
if cur.description is None:
167+
return pd.DataFrame()
168+
cols = [d.name for d in cur.description]
169+
rows = cur.fetchall()
170+
return pd.DataFrame(rows, columns=cols)
171+
172+
173+
session_id = str(uuid.uuid4())
174+
if "session_id" not in st.session_state:
175+
st.session_state["session_id"] = session_id
176+
177+
178+
instance_name = "dbase_instance"
179+
database = "databricks_postgres"
180+
schema = "public"
181+
table = "app_state"
182+
user = w.current_user.me().user_name
183+
host = w.database.get_database_instance(name=instance_name).read_write_dns
145184
146-
pool = build_pool(instance_name, host, user, database)
147-
df = query_df(pool, f'SELECT * FROM {table} LIMIT 100')
148-
st.dataframe(df)
185+
pool = build_pool(instance_name, host, user, database)
186+
187+
with pool.connection() as conn:
188+
with conn.cursor() as cur:
189+
cur.execute(f"""
190+
CREATE TABLE IF NOT EXISTS {schema}.{table} (
191+
session_id TEXT,
192+
key TEXT,
193+
value TEXT,
194+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
195+
PRIMARY KEY (session_id, key)
196+
)
197+
""")
198+
199+
cur.execute(f"""
200+
INSERT INTO app_state (session_id, key, value, updated_at)
201+
VALUES ('{session_id}', 'feedback_message', 'true', CURRENT_TIMESTAMP)
202+
ON CONFLICT (session_id, key) DO UPDATE
203+
SET value = EXCLUDED.value,
204+
updated_at = CURRENT_TIMESTAMP
205+
""")
206+
207+
df = query_df(pool, f"SELECT * FROM {schema}.{table} WHERE session_id = '{session_id}'")
208+
st.dataframe(df)
149209
''',
150210
language="python",
151211
)
@@ -167,7 +227,7 @@ def query_df(pool: ConnectionPool, sql: str) -> pd.DataFrame:
167227
'''
168228
GRANT CONNECT ON DATABASE databricks_postgres TO "099f0306-9e29-4a87-84c0-3046e4bcea02";
169229
GRANT USAGE ON SCHEMA public TO "099f0306-9e29-4a87-84c0-3046e4bcea02";
170-
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE quotes_history TO "099f0306-9e29-4a87-84c0-3046e4bcea02";
230+
GRANT SELECT, INSERT, UPDATE, DELETE ON TABLE app_state TO "099f0306-9e29-4a87-84c0-3046e4bcea02";
171231
''',
172232
language="sql",
173233
)

0 commit comments

Comments
 (0)