@@ -128,85 +128,74 @@ public open class PostgreSqlFormatter(
128128 return expr
129129 }
130130
131- protected open fun visitBulkInsert (expr : BulkInsertExpression ): BulkInsertExpression {
132- generateMultipleInsertSQL(expr.table.name.quoted, expr.assignments)
133-
134- generateOnConflictSQL(expr.conflictTarget, expr.updateAssignments)
135-
136- return expr
137- }
138-
139131 protected open fun visitInsertOrUpdate (expr : InsertOrUpdateExpression ): InsertOrUpdateExpression {
140- generateMultipleInsertSQL(expr.table.name.quoted, listOf (expr.assignments))
141-
142- generateOnConflictSQL(expr.conflictTarget, expr.updateAssignments)
132+ writeKeyword(" insert into " )
133+ visitTable(expr.table.copy(tableAlias = null ))
134+ writeColumnNames(expr.assignments.map { it.column })
135+ writeKeyword(" values " )
136+ writeValues(expr.assignments)
137+
138+ if (expr.conflictColumns.isNotEmpty()) {
139+ writeKeyword(" on conflict " )
140+ writeColumnNames(expr.conflictColumns)
141+
142+ if (expr.updateAssignments.isNotEmpty()) {
143+ writeKeyword(" do update set " )
144+ visitColumnAssignments(expr.updateAssignments)
145+ } else {
146+ writeKeyword(" do nothing " )
147+ }
148+ }
143149
144150 return expr
145151 }
146152
147- private fun generateMultipleInsertSQL (
148- quotedTableName : String ,
149- assignmentsList : List <List <ColumnAssignmentExpression <* >>>
150- ) {
151- if (assignmentsList.isEmpty()) {
152- throw IllegalStateException (" The insert expression has no values to insert" )
153- }
154-
153+ protected open fun visitBulkInsert (expr : BulkInsertExpression ): BulkInsertExpression {
155154 writeKeyword(" insert into " )
156-
157- write(" $quotedTableName (" )
158- assignmentsList.first().forEachIndexed { i, assignment ->
159- if (i > 0 ) write(" , " )
160- checkColumnName(assignment.column.name)
161- write(assignment.column.name.quoted)
155+ visitTable(expr.table.copy(tableAlias = null ))
156+ writeColumnNames(expr.assignments[0 ].map { it.column })
157+ writeKeyword(" values " )
158+
159+ for ((i, assignments) in expr.assignments.withIndex()) {
160+ if (i > 0 ) {
161+ removeLastBlank()
162+ write(" , " )
163+ }
164+ writeValues(assignments)
162165 }
163166
164- writeKeyword(" )" )
165- writeKeyword(" values " )
167+ if (expr.conflictColumns.isNotEmpty()) {
168+ writeKeyword(" on conflict " )
169+ writeColumnNames(expr.conflictColumns)
166170
167- assignmentsList.forEachIndexed { i, assignments ->
168- if (i > 0 ) write(" , " )
169- writeKeyword(" ( " )
170- visitExpressionList(assignments.map { it.expression as ArgumentExpression })
171- writeKeyword(" )" )
171+ if (expr.updateAssignments.isNotEmpty()) {
172+ writeKeyword(" do update set " )
173+ visitColumnAssignments(expr.updateAssignments)
174+ } else {
175+ writeKeyword(" do nothing " )
176+ }
172177 }
173178
174- removeLastBlank()
179+ return expr
175180 }
176181
177- private fun generateOnConflictSQL (
178- conflictTarget : List <ColumnExpression <* >>,
179- updateAssignments : List <ColumnAssignmentExpression <* >>
180- ) {
181- if (conflictTarget.isEmpty()) {
182- // We are just performing an Insert operation, so any conflict will interrupt the query with an error
183- return
184- }
182+ private fun writeColumnNames (columns : List <ColumnExpression <* >>) {
183+ write(" (" )
185184
186- writeKeyword(" on conflict (" )
187- conflictTarget.forEachIndexed { i, column ->
185+ for ((i, column) in columns.withIndex()) {
188186 if (i > 0 ) write(" , " )
189187 checkColumnName(column.name)
190188 write(column.name.quoted)
191189 }
192190
193- writeKeyword(" ) do " )
194-
195- if (updateAssignments.isNotEmpty()) {
196- writeKeyword(" update set " )
197- updateAssignments.forEachIndexed { i, assignment ->
198- if (i > 0 ) {
199- removeLastBlank()
200- write(" , " )
201- }
202- checkColumnName(assignment.column.name)
203- write(" ${assignment.column.name.quoted} " )
204- write(" = " )
205- visit(assignment.expression)
206- }
207- } else {
208- writeKeyword(" nothing" )
209- }
191+ write(" ) " )
192+ }
193+
194+ private fun writeValues (assignments : List <ColumnAssignmentExpression <* >>) {
195+ write(" (" )
196+ visitExpressionList(assignments.map { it.expression as ArgumentExpression })
197+ removeLastBlank()
198+ write(" ) " )
210199 }
211200}
212201
@@ -260,23 +249,64 @@ public open class PostgreSqlExpressionVisitor : SqlExpressionVisitor() {
260249 protected open fun visitInsertOrUpdate (expr : InsertOrUpdateExpression ): InsertOrUpdateExpression {
261250 val table = visitTable(expr.table)
262251 val assignments = visitColumnAssignments(expr.assignments)
263- val conflictTarget = visitExpressionList(expr.conflictTarget)
252+ val conflictColumns = visitExpressionList(expr.conflictColumns)
253+ val updateAssignments = visitColumnAssignments(expr.updateAssignments)
254+
255+ @Suppress(" ComplexCondition" )
256+ if (table == = expr.table
257+ && assignments == = expr.assignments
258+ && conflictColumns == = expr.conflictColumns
259+ && updateAssignments == = expr.updateAssignments
260+ ) {
261+ return expr
262+ } else {
263+ return expr.copy(
264+ table = table,
265+ assignments = assignments,
266+ conflictColumns = conflictColumns,
267+ updateAssignments = updateAssignments
268+ )
269+ }
270+ }
271+
272+ protected open fun visitBulkInsert (expr : BulkInsertExpression ): BulkInsertExpression {
273+ val table = expr.table
274+ val assignments = visitBulkInsertAssignments(expr.assignments)
275+ val conflictColumns = visitExpressionList(expr.conflictColumns)
264276 val updateAssignments = visitColumnAssignments(expr.updateAssignments)
265277
266278 @Suppress(" ComplexCondition" )
267279 if (table == = expr.table
268280 && assignments == = expr.assignments
269- && conflictTarget == = expr.conflictTarget
281+ && conflictColumns == = expr.conflictColumns
270282 && updateAssignments == = expr.updateAssignments
271283 ) {
272284 return expr
273285 } else {
274286 return expr.copy(
275287 table = table,
276288 assignments = assignments,
277- conflictTarget = conflictTarget ,
289+ conflictColumns = conflictColumns ,
278290 updateAssignments = updateAssignments
279291 )
280292 }
281293 }
294+
295+ protected open fun visitBulkInsertAssignments (
296+ assignments : List <List <ColumnAssignmentExpression <* >>>
297+ ): List <List <ColumnAssignmentExpression <* >>> {
298+ val result = ArrayList <List <ColumnAssignmentExpression <* >>>()
299+ var changed = false
300+
301+ for (row in assignments) {
302+ val visited = visitColumnAssignments(row)
303+ result + = visited
304+
305+ if (visited != = row) {
306+ changed = true
307+ }
308+ }
309+
310+ return if (changed) result else assignments
311+ }
282312}
0 commit comments