Skip to content

Commit 93c4c2a

Browse files
committed
fix: anon rules runs idempotently
1 parent fa58e8b commit 93c4c2a

File tree

2 files changed

+162
-24
lines changed

2 files changed

+162
-24
lines changed

internal/anonymize/anonymize.go

Lines changed: 129 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,16 @@ import (
1111
"gorm.io/gorm"
1212
)
1313

14+
// TablePrimaryKey pairs a table name with the name of its primary key column.
// An empty PKColumn means no primary key was found for the table.
// NOTE(review): the visible code passes primary keys around as a
// map[string]string (table -> column); confirm whether this type is used.
15+
type TablePrimaryKey struct {
16+
Table string
17+
PKColumn string // Empty string means no PK found, will use ctid
18+
}
19+
1420
// GenerateSQL generates anonymization SQL from rules
1521
// Uses PostgreSQL row_number() for deterministic anonymization
16-
func GenerateSQL(rules []models.AnonRule) string {
22+
// primaryKeys maps table names to their primary key columns for consistent ordering
23+
func GenerateSQL(rules []models.AnonRule, primaryKeys map[string]string) string {
1724
if len(rules) == 0 {
1825
return ""
1926
}
@@ -27,43 +34,99 @@ func GenerateSQL(rules []models.AnonRule) string {
2734
var sqlStatements []string
2835

2936
for table, rules := range tableRules {
30-
sql := generateTableUpdateSQL(table, rules)
37+
pkColumn := primaryKeys[table] // Empty string if not found
38+
sql := generateTableUpdateSQL(table, rules, pkColumn)
3139
sqlStatements = append(sqlStatements, sql)
3240
}
3341

3442
return strings.Join(sqlStatements, "\n\n")
3543
}
3644

45+
// generatePrimaryKeyQuerySQL builds a catalog query that looks up the primary
// key column of each table in tables, restricted to the "public" schema.
//
// The query returns one row per table as (table_name, column_name) and only
// reports tables whose primary key consists of a single column; tables with a
// composite key or no key at all produce no row. An empty tables slice yields
// an empty string.
//
// Table names are embedded as SQL string literals with single quotes doubled.
func generatePrimaryKeyQuerySQL(tables []string) string {
	if len(tables) == 0 {
		return ""
	}

	// Render each table name as a quoted SQL literal for the IN (...) list.
	quotedTables := make([]string, len(tables))
	for i, table := range tables {
		quotedTables[i] = "'" + strings.ReplaceAll(table, "'", "''") + "'"
	}

	// pg_class.relname is only unique per namespace, so the pg_class join must
	// be constrained to the schema of the pg_tables row. Joining on relname
	// alone would also pick up primary keys of same-named tables that live in
	// other schemas.
	return fmt.Sprintf(`
SELECT
    t.tablename as table_name,
    a.attname as column_name
FROM pg_tables t
JOIN pg_namespace n ON n.nspname = t.schemaname
JOIN pg_class c ON c.relname = t.tablename AND c.relnamespace = n.oid
JOIN pg_index i ON i.indrelid = c.oid AND i.indisprimary
JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey)
WHERE t.schemaname = 'public'
  AND t.tablename IN (%s)
  AND array_length(i.indkey, 1) = 1  -- Only single-column primary keys
ORDER BY t.tablename;
`, strings.Join(quotedTables, ", "))
}
74+
3775
// generateTableUpdateSQL generates UPDATE statement for a single table
38-
func generateTableUpdateSQL(table string, rules []models.AnonRule) string {
76+
// pkColumn is the primary key column name (empty string means use ctid)
77+
func generateTableUpdateSQL(table string, rules []models.AnonRule, pkColumn string) string {
3978
if len(rules) == 0 {
4079
return ""
4180
}
4281

43-
// Build SET clause with row_number replacement
82+
// Determine ordering: use primary key if available, otherwise ctid
83+
var orderBy string
84+
var orderByComment string
85+
if pkColumn != "" {
86+
orderBy = quoteIdentifier(pkColumn)
87+
orderByComment = fmt.Sprintf(" (ordered by PK: %s)", pkColumn)
88+
} else {
89+
orderBy = "ctid"
90+
orderByComment = " (ordered by ctid - no PK found)"
91+
}
92+
93+
// Build SET clause with row_number replacement and IS DISTINCT FROM for idempotency
4494
var setClauses []string
95+
var whereConditions []string
4596
for _, rule := range rules {
46-
// Replace ${index} with row number in the template
47-
// Use row_number() OVER (ORDER BY primary key or ctid for deterministic ordering)
4897
setValue := renderTemplate(rule.Template)
49-
setClauses = append(setClauses, fmt.Sprintf("%s = %s", quoteIdentifier(rule.Column), setValue))
98+
columnQuoted := quoteIdentifier(rule.Column)
99+
100+
// Add SET clause
101+
setClauses = append(setClauses, fmt.Sprintf("%s = %s", columnQuoted, setValue))
102+
103+
// Add condition to skip rows that already have the target value (idempotency)
104+
whereConditions = append(whereConditions, fmt.Sprintf("%s.%s IS DISTINCT FROM %s",
105+
quoteIdentifier(table), columnQuoted, setValue))
50106
}
51107

108+
// Combine WHERE conditions with OR (update if ANY column is different)
109+
whereClause := strings.Join(whereConditions, " OR ")
110+
52111
// Use CTE with row numbers for deterministic updates
53-
sql := fmt.Sprintf(`-- Anonymize table: %s
112+
sql := fmt.Sprintf(`-- Anonymize table: %s%s
54113
WITH numbered_rows AS (
55-
SELECT ctid, row_number() OVER (ORDER BY ctid) as _row_num
114+
SELECT ctid, row_number() OVER (ORDER BY %s) as _row_num
56115
FROM %s
57116
)
58117
UPDATE %s
59118
SET %s
60119
FROM numbered_rows
61-
WHERE %s.ctid = numbered_rows.ctid;`,
120+
WHERE %s.ctid = numbered_rows.ctid
121+
AND (%s);`,
62122
table,
123+
orderByComment,
124+
orderBy,
63125
quoteIdentifier(table),
64126
quoteIdentifier(table),
65127
strings.Join(setClauses, ",\n "),
66128
quoteIdentifier(table),
129+
whereClause,
67130
)
68131

69132
return sql
@@ -129,8 +192,62 @@ func Apply(ctx context.Context, db *gorm.DB, params ApplyParams, logger zerolog.
129192
Int("rule_count", len(rules)).
130193
Msg("Applying anonymization rules")
131194

132-
// Generate SQL from rules
133-
sql := GenerateSQL(rules)
195+
// Extract unique table names from rules
196+
tableMap := make(map[string]bool)
197+
for _, rule := range rules {
198+
tableMap[rule.Table] = true
199+
}
200+
var tables []string
201+
for table := range tableMap {
202+
tables = append(tables, table)
203+
}
204+
205+
// Query for primary keys
206+
primaryKeys := make(map[string]string)
207+
if len(tables) > 0 {
208+
pkQuerySQL := generatePrimaryKeyQuerySQL(tables)
209+
pkScript := fmt.Sprintf(`#!/bin/bash
210+
set -euo pipefail
211+
DATABASE_NAME="%s"
212+
PG_VERSION="%s"
213+
PG_PORT="%d"
214+
PG_BIN="/usr/lib/postgresql/${PG_VERSION}/bin"
215+
216+
sudo -u postgres ${PG_BIN}/psql -p ${PG_PORT} -d "${DATABASE_NAME}" -t -A -F'|' <<'PK_QUERY'
217+
%s
218+
PK_QUERY
219+
`, params.DatabaseName, params.PostgresVersion, params.PostgresPort, pkQuerySQL)
220+
221+
cmd := exec.CommandContext(ctx, "bash", "-c", pkScript)
222+
outputBytes, err := cmd.CombinedOutput()
223+
if err != nil {
224+
// Log warning but continue - we'll use ctid as fallback
225+
logger.Warn().
226+
Err(err).
227+
Str("output", string(outputBytes)).
228+
Msg("Failed to query primary keys, will use ctid for ordering")
229+
} else {
230+
// Parse output: table_name|column_name (one per line)
231+
output := strings.TrimSpace(string(outputBytes))
232+
if output != "" {
233+
for _, line := range strings.Split(output, "\n") {
234+
parts := strings.Split(line, "|")
235+
if len(parts) == 2 {
236+
tableName := strings.TrimSpace(parts[0])
237+
columnName := strings.TrimSpace(parts[1])
238+
primaryKeys[tableName] = columnName
239+
logger.Debug().
240+
Str("table", tableName).
241+
Str("pk_column", columnName).
242+
Msg("Detected primary key")
243+
}
244+
}
245+
}
246+
}
247+
}
248+
249+
// Generate SQL from rules with primary key information
250+
sql := GenerateSQL(rules, primaryKeys)
134251
if sql == "" {
135252
logger.Warn().Msg("Generated empty SQL from rules")
136253
return 0, nil

internal/anonymize/anonymize_test.go

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,64 @@ import (
99

1010
func TestGenerateSQL(t *testing.T) {
1111
tests := []struct {
12-
name string
13-
rules []models.AnonRule
14-
want string
12+
name string
13+
rules []models.AnonRule
14+
primaryKeys map[string]string
15+
want string
1516
}{
1617
{
17-
name: "empty rules",
18-
rules: []models.AnonRule{},
19-
want: "",
18+
name: "empty rules",
19+
rules: []models.AnonRule{},
20+
primaryKeys: map[string]string{},
21+
want: "",
2022
},
2123
{
22-
name: "single column anonymization",
24+
name: "single column anonymization without PK",
2325
rules: []models.AnonRule{
2426
{Table: "users", Column: "email", Template: "user_${index}@example.com"},
2527
},
26-
want: "UPDATE",
28+
primaryKeys: map[string]string{},
29+
want: "UPDATE",
30+
},
31+
{
32+
name: "single column anonymization with PK",
33+
rules: []models.AnonRule{
34+
{Table: "users", Column: "email", Template: "user_${index}@example.com"},
35+
},
36+
primaryKeys: map[string]string{"users": "id"},
37+
want: "ORDER BY \"id\"",
2738
},
2839
{
2940
name: "multiple columns same table",
3041
rules: []models.AnonRule{
3142
{Table: "users", Column: "email", Template: "user_${index}@example.com"},
3243
{Table: "users", Column: "name", Template: "User ${index}"},
3344
},
34-
want: "numbered_rows._row_num",
45+
primaryKeys: map[string]string{},
46+
want: "numbered_rows._row_num",
3547
},
3648
{
37-
name: "multiple tables",
49+
name: "multiple tables with mixed PKs",
3850
rules: []models.AnonRule{
3951
{Table: "users", Column: "email", Template: "user_${index}@example.com"},
4052
{Table: "orders", Column: "reference", Template: "ORD-${index}"},
4153
},
42-
want: "-- Anonymize table:",
54+
primaryKeys: map[string]string{"users": "id"}, // Only users has PK
55+
want: "-- Anonymize table:",
56+
},
57+
{
58+
name: "idempotency - IS DISTINCT FROM clause",
59+
rules: []models.AnonRule{
60+
{Table: "users", Column: "email", Template: "user_${index}@example.com"},
61+
},
62+
primaryKeys: map[string]string{},
63+
want: "IS DISTINCT FROM",
4364
},
4465
}
4566

4667
for _, tt := range tests {
4768
t.Run(tt.name, func(t *testing.T) {
48-
got := GenerateSQL(tt.rules)
69+
got := GenerateSQL(tt.rules, tt.primaryKeys)
4970
if tt.want != "" && !strings.Contains(got, tt.want) {
5071
t.Errorf("GenerateSQL() output doesn't contain expected string\nwant substring: %v\ngot: %v", tt.want, got)
5172
}

0 commit comments

Comments
 (0)