@@ -3,8 +3,17 @@ use std::collections::HashSet;
33use std:: fmt;
44use std:: str:: FromStr ;
55use std:: sync:: Arc ;
6+ use thiserror:: Error ;
67use tokio_postgres:: types:: { FromSql , ToSql , Type } ;
78
9+ /// Errors that can occur during schema operations.
10+ #[ derive( Debug , Error ) ]
11+ pub enum SchemaError {
12+ /// Columns were received during replication that do not exist in the stored table schema.
13+ #[ error( "received columns during replication that are not in the stored table schema: {0:?}" ) ]
14+ UnknownReplicatedColumns ( Vec < String > ) ,
15+ }
16+
817/// An object identifier in Postgres.
918type Oid = u32 ;
1019
@@ -226,8 +235,35 @@ impl ReplicationMask {
226235 ///
227236 /// The mask is constructed by checking which column names from the schema are present
228237 /// in the provided set of replicated column names.
229- pub fn build ( schema : & TableSchema , replicated_column_names : & HashSet < String > ) -> Self {
230- let mask = schema
238+ ///
239+ /// # Errors
240+ ///
241+ /// Returns [`SchemaError::UnknownReplicatedColumns`] if any column in
242+ /// `replicated_column_names` does not exist in the table schema.
243+ ///
244+ /// The column validation occurs because we have to make sure that the stored table schema is always
245+ /// up to date, if not, it's a critical problem.
246+ pub fn build (
247+ table_schema : & TableSchema ,
248+ replicated_column_names : & HashSet < String > ,
249+ ) -> Result < Self , SchemaError > {
250+ let schema_column_names: HashSet < & str > = table_schema
251+ . column_schemas
252+ . iter ( )
253+ . map ( |column_schema| column_schema. name . as_str ( ) )
254+ . collect ( ) ;
255+
256+ let unknown_columns: Vec < String > = replicated_column_names
257+ . iter ( )
258+ . filter ( |name| !schema_column_names. contains ( name. as_str ( ) ) )
259+ . cloned ( )
260+ . collect ( ) ;
261+
262+ if !unknown_columns. is_empty ( ) {
263+ return Err ( SchemaError :: UnknownReplicatedColumns ( unknown_columns) ) ;
264+ }
265+
266+ let mask = table_schema
231267 . column_schemas
232268 . iter ( )
233269 . map ( |cs| {
@@ -239,7 +275,7 @@ impl ReplicationMask {
239275 } )
240276 . collect ( ) ;
241277
242- Self ( Arc :: new ( mask) )
278+ Ok ( Self ( Arc :: new ( mask) ) )
243279 }
244280
245281 /// Returns the underlying mask as a slice.
@@ -272,15 +308,16 @@ pub struct ReplicatedTableSchema {
272308
273309impl ReplicatedTableSchema {
274310 /// Creates a [`ReplicatedTableSchema`] from a schema and a pre-computed mask.
275- pub fn from_mask ( schema : Arc < TableSchema > , mask : ReplicationMask ) -> Self {
311+ pub fn from_mask ( table_schema : Arc < TableSchema > , replication_mask : ReplicationMask ) -> Self {
276312 debug_assert_eq ! (
277- schema . column_schemas. len( ) ,
278- mask . len( ) ,
313+ table_schema . column_schemas. len( ) ,
314+ replication_mask . len( ) ,
279315 "mask length must match column count"
280316 ) ;
317+
281318 Self {
282- table_schema : schema ,
283- replication_mask : mask ,
319+ table_schema,
320+ replication_mask,
284321 }
285322 }
286323
@@ -323,3 +360,91 @@ impl ReplicatedTableSchema {
323360 . filter_map ( |( cs, & m) | if m == 1 { Some ( cs) } else { None } )
324361 }
325362}
363+
364+ #[ cfg( test) ]
365+ mod tests {
366+ use super :: * ;
367+
368+ fn create_test_table_schema ( ) -> TableSchema {
369+ TableSchema :: new (
370+ TableId :: new ( 123 ) ,
371+ TableName :: new ( "public" . to_string ( ) , "test_table" . to_string ( ) ) ,
372+ vec ! [
373+ ColumnSchema :: new( "id" . to_string( ) , Type :: INT4 , -1 , 1 , Some ( 1 ) , false ) ,
374+ ColumnSchema :: new( "name" . to_string( ) , Type :: TEXT , -1 , 2 , None , true ) ,
375+ ColumnSchema :: new( "age" . to_string( ) , Type :: INT4 , -1 , 3 , None , true ) ,
376+ ] ,
377+ )
378+ }
379+
380+ #[ test]
381+ fn test_replication_mask_build_all_columns_replicated ( ) {
382+ let schema = create_test_table_schema ( ) ;
383+ let replicated_columns: HashSet < String > = [ "id" , "name" , "age" ]
384+ . into_iter ( )
385+ . map ( String :: from)
386+ . collect ( ) ;
387+
388+ let mask = ReplicationMask :: build ( & schema, & replicated_columns) . unwrap ( ) ;
389+
390+ assert_eq ! ( mask. as_slice( ) , & [ 1 , 1 , 1 ] ) ;
391+ }
392+
393+ #[ test]
394+ fn test_replication_mask_build_partial_columns_replicated ( ) {
395+ let schema = create_test_table_schema ( ) ;
396+ let replicated_columns: HashSet < String > =
397+ [ "id" , "age" ] . into_iter ( ) . map ( String :: from) . collect ( ) ;
398+
399+ let mask = ReplicationMask :: build ( & schema, & replicated_columns) . unwrap ( ) ;
400+
401+ assert_eq ! ( mask. as_slice( ) , & [ 1 , 0 , 1 ] ) ;
402+ }
403+
404+ #[ test]
405+ fn test_replication_mask_build_no_columns_replicated ( ) {
406+ let schema = create_test_table_schema ( ) ;
407+ let replicated_columns: HashSet < String > = HashSet :: new ( ) ;
408+
409+ let mask = ReplicationMask :: build ( & schema, & replicated_columns) . unwrap ( ) ;
410+
411+ assert_eq ! ( mask. as_slice( ) , & [ 0 , 0 , 0 ] ) ;
412+ }
413+
414+ #[ test]
415+ fn test_replication_mask_build_unknown_column_error ( ) {
416+ let schema = create_test_table_schema ( ) ;
417+ let replicated_columns: HashSet < String > = [ "id" , "unknown_column" ]
418+ . into_iter ( )
419+ . map ( String :: from)
420+ . collect ( ) ;
421+
422+ let result = ReplicationMask :: build ( & schema, & replicated_columns) ;
423+
424+ assert ! ( result. is_err( ) ) ;
425+ let err = result. unwrap_err ( ) ;
426+ match err {
427+ SchemaError :: UnknownReplicatedColumns ( columns) => {
428+ assert_eq ! ( columns, vec![ "unknown_column" . to_string( ) ] ) ;
429+ }
430+ }
431+ }
432+
433+ #[ test]
434+ fn test_replication_mask_build_multiple_unknown_columns_error ( ) {
435+ let schema = create_test_table_schema ( ) ;
436+ let replicated_columns: HashSet < String > =
437+ [ "id" , "foo" , "bar" ] . into_iter ( ) . map ( String :: from) . collect ( ) ;
438+
439+ let result = ReplicationMask :: build ( & schema, & replicated_columns) ;
440+
441+ assert ! ( result. is_err( ) ) ;
442+ let err = result. unwrap_err ( ) ;
443+ match err {
444+ SchemaError :: UnknownReplicatedColumns ( mut columns) => {
445+ columns. sort ( ) ;
446+ assert_eq ! ( columns, vec![ "bar" . to_string( ) , "foo" . to_string( ) ] ) ;
447+ }
448+ }
449+ }
450+ }
0 commit comments