@@ -11,9 +11,16 @@ import (
1111 "gorm.io/gorm"
1212)
1313
14+ // TablePrimaryKey holds the primary key column for a table
15+ type TablePrimaryKey struct {
16+ Table string
17+ PKColumn string // Empty string means no PK found, will use ctid
18+ }
19+
1420// GenerateSQL generates anonymization SQL from rules
1521// Uses PostgreSQL row_number() for deterministic anonymization
16- func GenerateSQL (rules []models.AnonRule ) string {
22+ // primaryKeys maps table names to their primary key columns for consistent ordering
23+ func GenerateSQL (rules []models.AnonRule , primaryKeys map [string ]string ) string {
1724 if len (rules ) == 0 {
1825 return ""
1926 }
@@ -27,43 +34,99 @@ func GenerateSQL(rules []models.AnonRule) string {
2734 var sqlStatements []string
2835
2936 for table , rules := range tableRules {
30- sql := generateTableUpdateSQL (table , rules )
37+ pkColumn := primaryKeys [table ] // Empty string if not found
38+ sql := generateTableUpdateSQL (table , rules , pkColumn )
3139 sqlStatements = append (sqlStatements , sql )
3240 }
3341
3442 return strings .Join (sqlStatements , "\n \n " )
3543}
3644
45+ // generatePrimaryKeyQuerySQL generates SQL to query primary keys for all tables
46+ func generatePrimaryKeyQuerySQL (tables []string ) string {
47+ if len (tables ) == 0 {
48+ return ""
49+ }
50+
51+ // Build SQL to find primary key columns for all tables
52+ // Returns: table_name | column_name (one row per table with single-column PK)
53+ quotedTables := make ([]string , len (tables ))
54+ for i , table := range tables {
55+ quotedTables [i ] = fmt .Sprintf ("'%s'" , strings .ReplaceAll (table , "'" , "''" ))
56+ }
57+
58+ sql := fmt .Sprintf (`
59+ SELECT
60+ t.tablename as table_name,
61+ a.attname as column_name
62+ FROM pg_tables t
63+ JOIN pg_class c ON c.relname = t.tablename
64+ JOIN pg_index i ON i.indrelid = c.oid AND i.indisprimary
65+ JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey)
66+ WHERE t.schemaname = 'public'
67+ AND t.tablename IN (%s)
68+ AND array_length(i.indkey, 1) = 1 -- Only single-column primary keys
69+ ORDER BY t.tablename;
70+ ` , strings .Join (quotedTables , ", " ))
71+
72+ return sql
73+ }
74+
3775// generateTableUpdateSQL generates UPDATE statement for a single table
38- func generateTableUpdateSQL (table string , rules []models.AnonRule ) string {
76+ // pkColumn is the primary key column name (empty string means use ctid)
77+ func generateTableUpdateSQL (table string , rules []models.AnonRule , pkColumn string ) string {
3978 if len (rules ) == 0 {
4079 return ""
4180 }
4281
43- // Build SET clause with row_number replacement
82+ // Determine ordering: use primary key if available, otherwise ctid
83+ var orderBy string
84+ var orderByComment string
85+ if pkColumn != "" {
86+ orderBy = quoteIdentifier (pkColumn )
87+ orderByComment = fmt .Sprintf (" (ordered by PK: %s)" , pkColumn )
88+ } else {
89+ orderBy = "ctid"
90+ orderByComment = " (ordered by ctid - no PK found)"
91+ }
92+
93+ // Build SET clause with row_number replacement and IS DISTINCT FROM for idempotency
4494 var setClauses []string
95+ var whereConditions []string
4596 for _ , rule := range rules {
46- // Replace ${index} with row number in the template
47- // Use row_number() OVER (ORDER BY primary key or ctid for deterministic ordering)
4897 setValue := renderTemplate (rule .Template )
49- setClauses = append (setClauses , fmt .Sprintf ("%s = %s" , quoteIdentifier (rule .Column ), setValue ))
98+ columnQuoted := quoteIdentifier (rule .Column )
99+
100+ // Add SET clause
101+ setClauses = append (setClauses , fmt .Sprintf ("%s = %s" , columnQuoted , setValue ))
102+
103+ // Add condition to skip rows that already have the target value (idempotency)
104+ whereConditions = append (whereConditions , fmt .Sprintf ("%s.%s IS DISTINCT FROM %s" ,
105+ quoteIdentifier (table ), columnQuoted , setValue ))
50106 }
51107
108+ // Combine WHERE conditions with OR (update if ANY column is different)
109+ whereClause := strings .Join (whereConditions , " OR " )
110+
52111 // Use CTE with row numbers for deterministic updates
53- sql := fmt .Sprintf (`-- Anonymize table: %s
112+ sql := fmt .Sprintf (`-- Anonymize table: %s%s
54113WITH numbered_rows AS (
55- SELECT ctid, row_number() OVER (ORDER BY ctid ) as _row_num
114+ SELECT ctid, row_number() OVER (ORDER BY %s ) as _row_num
56115 FROM %s
57116)
58117UPDATE %s
59118SET %s
60119FROM numbered_rows
61- WHERE %s.ctid = numbered_rows.ctid;` ,
120+ WHERE %s.ctid = numbered_rows.ctid
121+ AND (%s);` ,
62122 table ,
123+ orderByComment ,
124+ orderBy ,
63125 quoteIdentifier (table ),
64126 quoteIdentifier (table ),
65127 strings .Join (setClauses , ",\n " ),
66128 quoteIdentifier (table ),
129+ whereClause ,
67130 )
68131
69132 return sql
@@ -129,8 +192,62 @@ func Apply(ctx context.Context, db *gorm.DB, params ApplyParams, logger zerolog.
129192 Int ("rule_count" , len (rules )).
130193 Msg ("Applying anonymization rules" )
131194
132- // Generate SQL from rules
133- sql := GenerateSQL (rules )
195+ // Extract unique table names from rules
196+ tableMap := make (map [string ]bool )
197+ for _ , rule := range rules {
198+ tableMap [rule .Table ] = true
199+ }
200+ var tables []string
201+ for table := range tableMap {
202+ tables = append (tables , table )
203+ }
204+
205+ // Query for primary keys
206+ primaryKeys := make (map [string ]string )
207+ if len (tables ) > 0 {
208+ pkQuerySQL := generatePrimaryKeyQuerySQL (tables )
209+ pkScript := fmt .Sprintf (`#!/bin/bash
210+ set -euo pipefail
211+ DATABASE_NAME="%s"
212+ PG_VERSION="%s"
213+ PG_PORT="%d"
214+ PG_BIN="/usr/lib/postgresql/${PG_VERSION}/bin"
215+
216+ sudo -u postgres ${PG_BIN}/psql -p ${PG_PORT} -d "${DATABASE_NAME}" -t -A -F'|' <<'PK_QUERY'
217+ %s
218+ PK_QUERY
219+ ` , params .DatabaseName , params .PostgresVersion , params .PostgresPort , pkQuerySQL )
220+
221+ cmd := exec .CommandContext (ctx , "bash" , "-c" , pkScript )
222+ outputBytes , err := cmd .CombinedOutput ()
223+ if err != nil {
224+ // Log warning but continue - we'll use ctid as fallback
225+ logger .Warn ().
226+ Err (err ).
227+ Str ("output" , string (outputBytes )).
228+ Msg ("Failed to query primary keys, will use ctid for ordering" )
229+ } else {
230+ // Parse output: table_name|column_name (one per line)
231+ output := strings .TrimSpace (string (outputBytes ))
232+ if output != "" {
233+ for _ , line := range strings .Split (output , "\n " ) {
234+ parts := strings .Split (line , "|" )
235+ if len (parts ) == 2 {
236+ tableName := strings .TrimSpace (parts [0 ])
237+ columnName := strings .TrimSpace (parts [1 ])
238+ primaryKeys [tableName ] = columnName
239+ logger .Debug ().
240+ Str ("table" , tableName ).
241+ Str ("pk_column" , columnName ).
242+ Msg ("Detected primary key" )
243+ }
244+ }
245+ }
246+ }
247+ }
248+
249+ // Generate SQL from rules with primary key information
250+ sql := GenerateSQL (rules , primaryKeys )
134251 if sql == "" {
135252 logger .Warn ().Msg ("Generated empty SQL from rules" )
136253 return 0 , nil
0 commit comments