Skip to content

Commit 5011a00

Browse files
refactor postgresql formatter
1 parent e47a398 commit 5011a00

File tree

1 file changed

+93
-63
lines changed

1 file changed

+93
-63
lines changed

ktorm-support-postgresql/src/main/kotlin/org/ktorm/support/postgresql/PostgreSqlDialect.kt

Lines changed: 93 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -128,85 +128,74 @@ public open class PostgreSqlFormatter(
128128
return expr
129129
}
130130

131-
protected open fun visitBulkInsert(expr: BulkInsertExpression): BulkInsertExpression {
132-
generateMultipleInsertSQL(expr.table.name.quoted, expr.assignments)
133-
134-
generateOnConflictSQL(expr.conflictTarget, expr.updateAssignments)
135-
136-
return expr
137-
}
138-
139131
protected open fun visitInsertOrUpdate(expr: InsertOrUpdateExpression): InsertOrUpdateExpression {
140-
generateMultipleInsertSQL(expr.table.name.quoted, listOf(expr.assignments))
141-
142-
generateOnConflictSQL(expr.conflictTarget, expr.updateAssignments)
132+
writeKeyword("insert into ")
133+
visitTable(expr.table.copy(tableAlias = null))
134+
writeColumnNames(expr.assignments.map { it.column })
135+
writeKeyword("values ")
136+
writeValues(expr.assignments)
137+
138+
if (expr.conflictColumns.isNotEmpty()) {
139+
writeKeyword("on conflict ")
140+
writeColumnNames(expr.conflictColumns)
141+
142+
if (expr.updateAssignments.isNotEmpty()) {
143+
writeKeyword("do update set ")
144+
visitColumnAssignments(expr.updateAssignments)
145+
} else {
146+
writeKeyword("do nothing ")
147+
}
148+
}
143149

144150
return expr
145151
}
146152

147-
private fun generateMultipleInsertSQL(
148-
quotedTableName: String,
149-
assignmentsList: List<List<ColumnAssignmentExpression<*>>>
150-
) {
151-
if (assignmentsList.isEmpty()) {
152-
throw IllegalStateException("The insert expression has no values to insert")
153-
}
154-
153+
protected open fun visitBulkInsert(expr: BulkInsertExpression): BulkInsertExpression {
155154
writeKeyword("insert into ")
156-
157-
write("$quotedTableName (")
158-
assignmentsList.first().forEachIndexed { i, assignment ->
159-
if (i > 0) write(", ")
160-
checkColumnName(assignment.column.name)
161-
write(assignment.column.name.quoted)
155+
visitTable(expr.table.copy(tableAlias = null))
156+
writeColumnNames(expr.assignments[0].map { it.column })
157+
writeKeyword("values ")
158+
159+
for ((i, assignments) in expr.assignments.withIndex()) {
160+
if (i > 0) {
161+
removeLastBlank()
162+
write(", ")
163+
}
164+
writeValues(assignments)
162165
}
163166

164-
writeKeyword(")")
165-
writeKeyword(" values ")
167+
if (expr.conflictColumns.isNotEmpty()) {
168+
writeKeyword("on conflict ")
169+
writeColumnNames(expr.conflictColumns)
166170

167-
assignmentsList.forEachIndexed { i, assignments ->
168-
if (i > 0) write(", ")
169-
writeKeyword("( ")
170-
visitExpressionList(assignments.map { it.expression as ArgumentExpression })
171-
writeKeyword(")")
171+
if (expr.updateAssignments.isNotEmpty()) {
172+
writeKeyword("do update set ")
173+
visitColumnAssignments(expr.updateAssignments)
174+
} else {
175+
writeKeyword("do nothing ")
176+
}
172177
}
173178

174-
removeLastBlank()
179+
return expr
175180
}
176181

177-
private fun generateOnConflictSQL(
178-
conflictTarget: List<ColumnExpression<*>>,
179-
updateAssignments: List<ColumnAssignmentExpression<*>>
180-
) {
181-
if (conflictTarget.isEmpty()) {
182-
// We are just performing an Insert operation, so any conflict will interrupt the query with an error
183-
return
184-
}
182+
private fun writeColumnNames(columns: List<ColumnExpression<*>>) {
183+
write("(")
185184

186-
writeKeyword(" on conflict (")
187-
conflictTarget.forEachIndexed { i, column ->
185+
for ((i, column) in columns.withIndex()) {
188186
if (i > 0) write(", ")
189187
checkColumnName(column.name)
190188
write(column.name.quoted)
191189
}
192190

193-
writeKeyword(") do ")
194-
195-
if (updateAssignments.isNotEmpty()) {
196-
writeKeyword("update set ")
197-
updateAssignments.forEachIndexed { i, assignment ->
198-
if (i > 0) {
199-
removeLastBlank()
200-
write(", ")
201-
}
202-
checkColumnName(assignment.column.name)
203-
write("${assignment.column.name.quoted} ")
204-
write("= ")
205-
visit(assignment.expression)
206-
}
207-
} else {
208-
writeKeyword("nothing")
209-
}
191+
write(") ")
192+
}
193+
194+
private fun writeValues(assignments: List<ColumnAssignmentExpression<*>>) {
195+
write("(")
196+
visitExpressionList(assignments.map { it.expression as ArgumentExpression })
197+
removeLastBlank()
198+
write(") ")
210199
}
211200
}
212201

@@ -260,23 +249,64 @@ public open class PostgreSqlExpressionVisitor : SqlExpressionVisitor() {
260249
protected open fun visitInsertOrUpdate(expr: InsertOrUpdateExpression): InsertOrUpdateExpression {
261250
val table = visitTable(expr.table)
262251
val assignments = visitColumnAssignments(expr.assignments)
263-
val conflictTarget = visitExpressionList(expr.conflictTarget)
252+
val conflictColumns = visitExpressionList(expr.conflictColumns)
253+
val updateAssignments = visitColumnAssignments(expr.updateAssignments)
254+
255+
@Suppress("ComplexCondition")
256+
if (table === expr.table
257+
&& assignments === expr.assignments
258+
&& conflictColumns === expr.conflictColumns
259+
&& updateAssignments === expr.updateAssignments
260+
) {
261+
return expr
262+
} else {
263+
return expr.copy(
264+
table = table,
265+
assignments = assignments,
266+
conflictColumns = conflictColumns,
267+
updateAssignments = updateAssignments
268+
)
269+
}
270+
}
271+
272+
protected open fun visitBulkInsert(expr: BulkInsertExpression): BulkInsertExpression {
273+
val table = expr.table
274+
val assignments = visitBulkInsertAssignments(expr.assignments)
275+
val conflictColumns = visitExpressionList(expr.conflictColumns)
264276
val updateAssignments = visitColumnAssignments(expr.updateAssignments)
265277

266278
@Suppress("ComplexCondition")
267279
if (table === expr.table
268280
&& assignments === expr.assignments
269-
&& conflictTarget === expr.conflictTarget
281+
&& conflictColumns === expr.conflictColumns
270282
&& updateAssignments === expr.updateAssignments
271283
) {
272284
return expr
273285
} else {
274286
return expr.copy(
275287
table = table,
276288
assignments = assignments,
277-
conflictTarget = conflictTarget,
289+
conflictColumns = conflictColumns,
278290
updateAssignments = updateAssignments
279291
)
280292
}
281293
}
294+
295+
protected open fun visitBulkInsertAssignments(
296+
assignments: List<List<ColumnAssignmentExpression<*>>>
297+
): List<List<ColumnAssignmentExpression<*>>> {
298+
val result = ArrayList<List<ColumnAssignmentExpression<*>>>()
299+
var changed = false
300+
301+
for (row in assignments) {
302+
val visited = visitColumnAssignments(row)
303+
result += visited
304+
305+
if (visited !== row) {
306+
changed = true
307+
}
308+
}
309+
310+
return if (changed) result else assignments
311+
}
282312
}

0 commit comments

Comments
 (0)