Skip to content

Commit dca47b7

Browse files
committed
Improve
1 parent cb06bde commit dca47b7

File tree

7 files changed

+160
-14
lines changed

7 files changed

+160
-14
lines changed

etl-destinations/src/bigquery/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ mod tests {
837837
TableName::new("public".to_string(), "test_table".to_string()),
838838
columns,
839839
));
840-
let replication_mask = ReplicationMask::build(&table_schema, &column_names);
840+
let replication_mask = ReplicationMask::build(&table_schema, &column_names).unwrap();
841841

842842
ReplicatedTableSchema::from_mask(table_schema, replication_mask)
843843
}

etl-postgres/src/types/schema.rs

Lines changed: 133 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,17 @@ use std::collections::HashSet;
33
use std::fmt;
44
use std::str::FromStr;
55
use std::sync::Arc;
6+
use thiserror::Error;
67
use tokio_postgres::types::{FromSql, ToSql, Type};
78

9+
/// Errors that can occur during schema operations.
10+
#[derive(Debug, Error)]
11+
pub enum SchemaError {
12+
/// Columns were received during replication that do not exist in the stored table schema.
13+
#[error("received columns during replication that are not in the stored table schema: {0:?}")]
14+
UnknownReplicatedColumns(Vec<String>),
15+
}
16+
817
/// An object identifier in Postgres.
918
type Oid = u32;
1019

@@ -226,8 +235,35 @@ impl ReplicationMask {
226235
///
227236
/// The mask is constructed by checking which column names from the schema are present
228237
/// in the provided set of replicated column names.
229-
pub fn build(schema: &TableSchema, replicated_column_names: &HashSet<String>) -> Self {
230-
let mask = schema
238+
///
239+
/// # Errors
240+
///
241+
/// Returns [`SchemaError::UnknownReplicatedColumns`] if any column in
242+
/// `replicated_column_names` does not exist in the table schema.
243+
///
244+
/// The column validation occurs because we have to make sure that the stored table schema is always
245+
/// up to date, if not, it's a critical problem.
246+
pub fn build(
247+
table_schema: &TableSchema,
248+
replicated_column_names: &HashSet<String>,
249+
) -> Result<Self, SchemaError> {
250+
let schema_column_names: HashSet<&str> = table_schema
251+
.column_schemas
252+
.iter()
253+
.map(|column_schema| column_schema.name.as_str())
254+
.collect();
255+
256+
let unknown_columns: Vec<String> = replicated_column_names
257+
.iter()
258+
.filter(|name| !schema_column_names.contains(name.as_str()))
259+
.cloned()
260+
.collect();
261+
262+
if !unknown_columns.is_empty() {
263+
return Err(SchemaError::UnknownReplicatedColumns(unknown_columns));
264+
}
265+
266+
let mask = table_schema
231267
.column_schemas
232268
.iter()
233269
.map(|cs| {
@@ -239,7 +275,7 @@ impl ReplicationMask {
239275
})
240276
.collect();
241277

242-
Self(Arc::new(mask))
278+
Ok(Self(Arc::new(mask)))
243279
}
244280

245281
/// Returns the underlying mask as a slice.
@@ -272,15 +308,16 @@ pub struct ReplicatedTableSchema {
272308

273309
impl ReplicatedTableSchema {
274310
/// Creates a [`ReplicatedTableSchema`] from a schema and a pre-computed mask.
275-
pub fn from_mask(schema: Arc<TableSchema>, mask: ReplicationMask) -> Self {
311+
pub fn from_mask(table_schema: Arc<TableSchema>, replication_mask: ReplicationMask) -> Self {
276312
debug_assert_eq!(
277-
schema.column_schemas.len(),
278-
mask.len(),
313+
table_schema.column_schemas.len(),
314+
replication_mask.len(),
279315
"mask length must match column count"
280316
);
317+
281318
Self {
282-
table_schema: schema,
283-
replication_mask: mask,
319+
table_schema,
320+
replication_mask,
284321
}
285322
}
286323

@@ -323,3 +360,91 @@ impl ReplicatedTableSchema {
323360
.filter_map(|(cs, &m)| if m == 1 { Some(cs) } else { None })
324361
}
325362
}
363+
364+
#[cfg(test)]
365+
mod tests {
366+
use super::*;
367+
368+
fn create_test_table_schema() -> TableSchema {
369+
TableSchema::new(
370+
TableId::new(123),
371+
TableName::new("public".to_string(), "test_table".to_string()),
372+
vec![
373+
ColumnSchema::new("id".to_string(), Type::INT4, -1, 1, Some(1), false),
374+
ColumnSchema::new("name".to_string(), Type::TEXT, -1, 2, None, true),
375+
ColumnSchema::new("age".to_string(), Type::INT4, -1, 3, None, true),
376+
],
377+
)
378+
}
379+
380+
#[test]
381+
fn test_replication_mask_build_all_columns_replicated() {
382+
let schema = create_test_table_schema();
383+
let replicated_columns: HashSet<String> = ["id", "name", "age"]
384+
.into_iter()
385+
.map(String::from)
386+
.collect();
387+
388+
let mask = ReplicationMask::build(&schema, &replicated_columns).unwrap();
389+
390+
assert_eq!(mask.as_slice(), &[1, 1, 1]);
391+
}
392+
393+
#[test]
394+
fn test_replication_mask_build_partial_columns_replicated() {
395+
let schema = create_test_table_schema();
396+
let replicated_columns: HashSet<String> =
397+
["id", "age"].into_iter().map(String::from).collect();
398+
399+
let mask = ReplicationMask::build(&schema, &replicated_columns).unwrap();
400+
401+
assert_eq!(mask.as_slice(), &[1, 0, 1]);
402+
}
403+
404+
#[test]
405+
fn test_replication_mask_build_no_columns_replicated() {
406+
let schema = create_test_table_schema();
407+
let replicated_columns: HashSet<String> = HashSet::new();
408+
409+
let mask = ReplicationMask::build(&schema, &replicated_columns).unwrap();
410+
411+
assert_eq!(mask.as_slice(), &[0, 0, 0]);
412+
}
413+
414+
#[test]
415+
fn test_replication_mask_build_unknown_column_error() {
416+
let schema = create_test_table_schema();
417+
let replicated_columns: HashSet<String> = ["id", "unknown_column"]
418+
.into_iter()
419+
.map(String::from)
420+
.collect();
421+
422+
let result = ReplicationMask::build(&schema, &replicated_columns);
423+
424+
assert!(result.is_err());
425+
let err = result.unwrap_err();
426+
match err {
427+
SchemaError::UnknownReplicatedColumns(columns) => {
428+
assert_eq!(columns, vec!["unknown_column".to_string()]);
429+
}
430+
}
431+
}
432+
433+
#[test]
434+
fn test_replication_mask_build_multiple_unknown_columns_error() {
435+
let schema = create_test_table_schema();
436+
let replicated_columns: HashSet<String> =
437+
["id", "foo", "bar"].into_iter().map(String::from).collect();
438+
439+
let result = ReplicationMask::build(&schema, &replicated_columns);
440+
441+
assert!(result.is_err());
442+
let err = result.unwrap_err();
443+
match err {
444+
SchemaError::UnknownReplicatedColumns(mut columns) => {
445+
columns.sort();
446+
assert_eq!(columns, vec!["bar".to_string(), "foo".to_string()]);
447+
}
448+
}
449+
}
450+
}

etl/src/error.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,27 @@ impl From<etl_postgres::replication::slots::EtlReplicationSlotError> for EtlErro
10101010
}
10111011
}
10121012

1013+
/// Converts [`etl_postgres::types::SchemaError`] to [`EtlError`] with [`ErrorKind::InvalidState`].
1014+
impl From<etl_postgres::types::SchemaError> for EtlError {
1015+
#[track_caller]
1016+
fn from(err: etl_postgres::types::SchemaError) -> EtlError {
1017+
match err {
1018+
etl_postgres::types::SchemaError::UnknownReplicatedColumns(columns) => {
1019+
EtlError::from_components(
1020+
ErrorKind::InvalidState,
1021+
Cow::Borrowed(
1022+
"Received columns during replication that are not in the stored table schema",
1023+
),
1024+
Some(Cow::Owned(format!(
1025+
"The columns that are not in the table schema are: {columns:?}"
1026+
))),
1027+
None,
1028+
)
1029+
}
1030+
}
1031+
}
1032+
}
1033+
10131034
#[cfg(test)]
10141035
mod tests {
10151036
use super::*;

etl/src/replication/apply.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1236,7 +1236,7 @@ where
12361236
"received relation message, building replication mask"
12371237
);
12381238

1239-
let replication_mask = ReplicationMask::build(&table_schema, &replicated_columns);
1239+
let replication_mask = ReplicationMask::build(&table_schema, &replicated_columns)?;
12401240
replication_masks.set(table_id, replication_mask).await;
12411241

12421242
Ok(HandleMessageResult::no_event())

etl/src/replication/masks.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ mod tests {
7676

7777
let replicated_columns: HashSet<String> =
7878
["id".to_string(), "age".to_string()].into_iter().collect();
79-
ReplicationMask::build(&schema, &replicated_columns)
79+
ReplicationMask::build(&schema, &replicated_columns).unwrap()
8080
}
8181

8282
#[tokio::test]

etl/src/replication/table_sync.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ where
202202
.await?;
203203

204204
// Build and store the replication mask for use during CDC.
205-
let replication_mask = ReplicationMask::build(&table_schema, &replicated_column_names);
205+
let replication_mask = ReplicationMask::build(&table_schema, &replicated_column_names)?;
206206
replication_masks
207207
.set(table_id, replication_mask.clone())
208208
.await;

etl/src/test_utils/test_schema.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ pub fn build_expected_users_inserts(
348348
.collect::<HashSet<_>>();
349349
let replicated_table_schema = ReplicatedTableSchema::from_mask(
350350
Arc::new(users_table_schema.clone()),
351-
ReplicationMask::build(users_table_schema, &users_table_column_names),
351+
ReplicationMask::build(users_table_schema, &users_table_column_names).unwrap(),
352352
);
353353

354354
for (name, age) in expected_rows {
@@ -386,7 +386,7 @@ pub fn build_expected_orders_inserts(
386386
.collect::<HashSet<_>>();
387387
let replicated_table_schema = ReplicatedTableSchema::from_mask(
388388
Arc::new(orders_table_schema.clone()),
389-
ReplicationMask::build(orders_table_schema, &orders_table_column_names),
389+
ReplicationMask::build(orders_table_schema, &orders_table_column_names).unwrap(),
390390
);
391391

392392
for name in expected_rows {

0 commit comments

Comments
 (0)