From 09adb73c57cf254a66b29b1a7a127e8d95b94c2f Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 23 Dec 2025 19:59:19 +0000 Subject: [PATCH] Add table-level constraint parsing for CREATE TABLE Add support for parsing named table constraints (CONSTRAINT name ...) including PRIMARY KEY, UNIQUE, FOREIGN KEY, and CHECK constraints. Changes: - Add parseNamedTableConstraint() for CONSTRAINT name ... syntax - Add parseUnnamedTableConstraint() for constraints without names - Add parsePrimaryKeyConstraint() with column sort order support - Add parseUniqueConstraint() with CLUSTERED/NONCLUSTERED - Add parseForeignKeyConstraint() with REFERENCES clause - Add parseCheckConstraint() for CHECK expressions - Add parseColumnWithSortOrder() for ASC/DESC column specifications - Fix foreignKeyConstraintToJSON() to output ReferencedTableColumns - Always output DeleteAction/UpdateAction with default "NotSpecified" Enables tests: - BaselinesCommon_TSqlParserTestScript3 - TSqlParserTestScript3 --- parser/marshal.go | 310 +++++++++++++++++- .../metadata.json | 2 +- .../TSqlParserTestScript3/metadata.json | 2 +- 3 files changed, 301 insertions(+), 13 deletions(-) diff --git a/parser/marshal.go b/parser/marshal.go index 5bbbe830..287c4b7e 100644 --- a/parser/marshal.go +++ b/parser/marshal.go @@ -2476,14 +2476,38 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) stmt.Definition = &ast.TableDefinition{} - // Parse column definitions + // Parse column definitions and table constraints for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { - colDef, err := p.parseColumnDefinition() - if err != nil { - p.skipToEndOfStatement() - return stmt, nil + upperLit := strings.ToUpper(p.curTok.Literal) + + // Check for table-level constraints + if upperLit == "CONSTRAINT" { + constraint, err := p.parseNamedTableConstraint() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + if constraint != nil { + stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) + } + } else if upperLit == "PRIMARY" || upperLit == "UNIQUE" || upperLit == "FOREIGN" || upperLit == "CHECK" { + constraint, err := p.parseUnnamedTableConstraint() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + if constraint != nil { + stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint) + } + } else { + // Parse column definition + colDef, err := p.parseColumnDefinition() + if err != nil { + p.skipToEndOfStatement() + return stmt, nil + } + stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef) } - stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef) if p.curTok.Type == TokenComma { p.nextToken() @@ -2685,6 +2709,265 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) { return col, nil } +// parseNamedTableConstraint parses a CONSTRAINT name ... table constraint +func (p *Parser) parseNamedTableConstraint() (ast.TableConstraint, error) { + // Consume CONSTRAINT + p.nextToken() + + // Parse constraint name + constraintName := p.parseIdentifier() + + // Now parse the actual constraint type + upperLit := strings.ToUpper(p.curTok.Literal) + + if upperLit == "PRIMARY" { + constraint, err := p.parsePrimaryKeyConstraint() + if err != nil { + return nil, err + } + constraint.ConstraintIdentifier = constraintName + return constraint, nil + } else if upperLit == "UNIQUE" { + constraint, err := p.parseUniqueConstraint() + if err != nil { + return nil, err + } + constraint.ConstraintIdentifier = constraintName + return constraint, nil + } else if upperLit == "FOREIGN" { + constraint, err := p.parseForeignKeyConstraint() + if err != nil { + return nil, err + } + constraint.ConstraintIdentifier = constraintName + return constraint, nil + } else if upperLit == "CHECK" { + constraint, err := p.parseCheckConstraint() + if err != nil { + return nil, err + } + constraint.ConstraintIdentifier = constraintName + return constraint, nil + } + + return nil, nil +} + +// parseUnnamedTableConstraint parses an unnamed table constraint (PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK) +func (p *Parser) parseUnnamedTableConstraint() (ast.TableConstraint, error) { + upperLit := strings.ToUpper(p.curTok.Literal) + + if upperLit == "PRIMARY" { + return p.parsePrimaryKeyConstraint() + } else if upperLit == "UNIQUE" { + return p.parseUniqueConstraint() + } else if upperLit == "FOREIGN" { + return p.parseForeignKeyConstraint() + } else if upperLit == "CHECK" { + return p.parseCheckConstraint() + } + + return nil, nil +} + +// parsePrimaryKeyConstraint parses PRIMARY KEY CLUSTERED/NONCLUSTERED (columns) +func (p *Parser) parsePrimaryKeyConstraint() (*ast.UniqueConstraintDefinition, error) { + // Consume PRIMARY + p.nextToken() + if p.curTok.Type == TokenKey { + p.nextToken() // consume KEY + } + + constraint := &ast.UniqueConstraintDefinition{ + IsPrimaryKey: true, + } + + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + constraint.Clustered = true + constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + constraint.Clustered = false + constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + + // Parse column list + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := p.parseColumnWithSortOrder() + constraint.Columns = append(constraint.Columns, col) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + return constraint, nil +} + +// parseUniqueConstraint parses UNIQUE CLUSTERED/NONCLUSTERED (columns) +func (p *Parser) parseUniqueConstraint() (*ast.UniqueConstraintDefinition, error) { + // Consume UNIQUE + p.nextToken() + + constraint := &ast.UniqueConstraintDefinition{ + IsPrimaryKey: false, + } + + // Parse optional CLUSTERED/NONCLUSTERED + if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" { + constraint.Clustered = true + constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"} + p.nextToken() + } else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" { + constraint.Clustered = false + constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"} + p.nextToken() + } + + // Parse column list + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + col := p.parseColumnWithSortOrder() + constraint.Columns = append(constraint.Columns, col) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + return constraint, nil +} + +// parseForeignKeyConstraint parses FOREIGN KEY (columns) REFERENCES table (columns) +func (p *Parser) parseForeignKeyConstraint() (*ast.ForeignKeyConstraintDefinition, error) { + // Consume FOREIGN + p.nextToken() + if p.curTok.Type == TokenKey { + p.nextToken() // consume KEY + } + + constraint := &ast.ForeignKeyConstraintDefinition{} + + // Parse column list + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + ident := p.parseIdentifier() + constraint.Columns = append(constraint.Columns, ident) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + // Parse REFERENCES + if strings.ToUpper(p.curTok.Literal) == "REFERENCES" { + p.nextToken() // consume REFERENCES + + // Parse reference table name + refTable, err := p.parseSchemaObjectName() + if err != nil { + return nil, err + } + constraint.ReferenceTableName = refTable + + // Parse referenced column list + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF { + ident := p.parseIdentifier() + constraint.ReferencedColumns = append(constraint.ReferencedColumns, ident) + + if p.curTok.Type == TokenComma { + p.nextToken() + } else { + break + } + } + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + } + + return constraint, nil +} + +// parseCheckConstraint parses CHECK (expression) +func (p *Parser) parseCheckConstraint() (*ast.CheckConstraintDefinition, error) { + // Consume CHECK + p.nextToken() + + constraint := &ast.CheckConstraintDefinition{} + + // Parse condition + if p.curTok.Type == TokenLParen { + p.nextToken() // consume ( + cond, err := p.parseBooleanExpression() + if err != nil { + return nil, err + } + constraint.CheckCondition = cond + if p.curTok.Type == TokenRParen { + p.nextToken() // consume ) + } + } + + return constraint, nil +} + +// parseColumnWithSortOrder parses a column name with optional ASC/DESC sort order +func (p *Parser) parseColumnWithSortOrder() *ast.ColumnWithSortOrder { + col := &ast.ColumnWithSortOrder{ + SortOrder: ast.SortOrderNotSpecified, + } + + // Parse column name + ident := p.parseIdentifier() + col.Column = &ast.ColumnReferenceExpression{ + ColumnType: "Regular", + MultiPartIdentifier: &ast.MultiPartIdentifier{ + Count: 1, + Identifiers: []*ast.Identifier{ident}, + }, + } + + // Parse optional ASC/DESC + upperLit := strings.ToUpper(p.curTok.Literal) + if upperLit == "ASC" { + col.SortOrder = ast.SortOrderAscending + p.nextToken() + } else if upperLit == "DESC" { + col.SortOrder = ast.SortOrderDescending + p.nextToken() + } + + return col +} + func (p *Parser) parseGrantStatement() (*ast.GrantStatement, error) { // Consume GRANT p.nextToken() @@ -3427,14 +3710,19 @@ func foreignKeyConstraintToJSON(c *ast.ForeignKeyConstraintDefinition) jsonNode for i, col := range c.ReferencedColumns { cols[i] = identifierToJSON(col) } - node["ReferencedColumns"] = cols + node["ReferencedTableColumns"] = cols } - if c.DeleteAction != "" { - node["DeleteAction"] = c.DeleteAction + // Always include DeleteAction and UpdateAction with default value + deleteAction := c.DeleteAction + if deleteAction == "" { + deleteAction = "NotSpecified" } - if c.UpdateAction != "" { - node["UpdateAction"] = c.UpdateAction + node["DeleteAction"] = deleteAction + updateAction := c.UpdateAction + if updateAction == "" { + updateAction = "NotSpecified" } + node["UpdateAction"] = updateAction return node } diff --git a/parser/testdata/BaselinesCommon_TSqlParserTestScript3/metadata.json b/parser/testdata/BaselinesCommon_TSqlParserTestScript3/metadata.json index ccffb5b9..9e26dfee 100644 --- a/parser/testdata/BaselinesCommon_TSqlParserTestScript3/metadata.json +++ b/parser/testdata/BaselinesCommon_TSqlParserTestScript3/metadata.json @@ -1 +1 @@ -{"todo": true} \ No newline at end of file +{} \ No newline at end of file diff --git a/parser/testdata/TSqlParserTestScript3/metadata.json b/parser/testdata/TSqlParserTestScript3/metadata.json index ccffb5b9..9e26dfee 100644 --- a/parser/testdata/TSqlParserTestScript3/metadata.json +++ b/parser/testdata/TSqlParserTestScript3/metadata.json @@ -1 +1 @@ -{"todo": true} \ No newline at end of file +{} \ No newline at end of file