Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 299 additions & 11 deletions parser/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2476,14 +2476,38 @@ func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error)

stmt.Definition = &ast.TableDefinition{}

// Parse column definitions
// Parse column definitions and table constraints
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
colDef, err := p.parseColumnDefinition()
if err != nil {
p.skipToEndOfStatement()
return stmt, nil
upperLit := strings.ToUpper(p.curTok.Literal)

// Check for table-level constraints
if upperLit == "CONSTRAINT" {
constraint, err := p.parseNamedTableConstraint()
if err != nil {
p.skipToEndOfStatement()
return stmt, nil
}
if constraint != nil {
stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint)
}
} else if upperLit == "PRIMARY" || upperLit == "UNIQUE" || upperLit == "FOREIGN" || upperLit == "CHECK" {
constraint, err := p.parseUnnamedTableConstraint()
if err != nil {
p.skipToEndOfStatement()
return stmt, nil
}
if constraint != nil {
stmt.Definition.TableConstraints = append(stmt.Definition.TableConstraints, constraint)
}
} else {
// Parse column definition
colDef, err := p.parseColumnDefinition()
if err != nil {
p.skipToEndOfStatement()
return stmt, nil
}
stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef)
}
stmt.Definition.ColumnDefinitions = append(stmt.Definition.ColumnDefinitions, colDef)

if p.curTok.Type == TokenComma {
p.nextToken()
Expand Down Expand Up @@ -2685,6 +2709,265 @@ func (p *Parser) parseColumnDefinition() (*ast.ColumnDefinition, error) {
return col, nil
}

// parseNamedTableConstraint parses a CONSTRAINT name ... table constraint
func (p *Parser) parseNamedTableConstraint() (ast.TableConstraint, error) {
// Consume CONSTRAINT
p.nextToken()

// Parse constraint name
constraintName := p.parseIdentifier()

// Now parse the actual constraint type
upperLit := strings.ToUpper(p.curTok.Literal)

if upperLit == "PRIMARY" {
constraint, err := p.parsePrimaryKeyConstraint()
if err != nil {
return nil, err
}
constraint.ConstraintIdentifier = constraintName
return constraint, nil
} else if upperLit == "UNIQUE" {
constraint, err := p.parseUniqueConstraint()
if err != nil {
return nil, err
}
constraint.ConstraintIdentifier = constraintName
return constraint, nil
} else if upperLit == "FOREIGN" {
constraint, err := p.parseForeignKeyConstraint()
if err != nil {
return nil, err
}
constraint.ConstraintIdentifier = constraintName
return constraint, nil
} else if upperLit == "CHECK" {
constraint, err := p.parseCheckConstraint()
if err != nil {
return nil, err
}
constraint.ConstraintIdentifier = constraintName
return constraint, nil
}

return nil, nil
}

// parseUnnamedTableConstraint parses an unnamed table constraint (PRIMARY KEY, UNIQUE, FOREIGN KEY, CHECK)
func (p *Parser) parseUnnamedTableConstraint() (ast.TableConstraint, error) {
upperLit := strings.ToUpper(p.curTok.Literal)

if upperLit == "PRIMARY" {
return p.parsePrimaryKeyConstraint()
} else if upperLit == "UNIQUE" {
return p.parseUniqueConstraint()
} else if upperLit == "FOREIGN" {
return p.parseForeignKeyConstraint()
} else if upperLit == "CHECK" {
return p.parseCheckConstraint()
}

return nil, nil
}

// parsePrimaryKeyConstraint parses PRIMARY KEY CLUSTERED/NONCLUSTERED (columns)
func (p *Parser) parsePrimaryKeyConstraint() (*ast.UniqueConstraintDefinition, error) {
// Consume PRIMARY
p.nextToken()
if p.curTok.Type == TokenKey {
p.nextToken() // consume KEY
}

constraint := &ast.UniqueConstraintDefinition{
IsPrimaryKey: true,
}

// Parse optional CLUSTERED/NONCLUSTERED
if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" {
constraint.Clustered = true
constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"}
p.nextToken()
} else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" {
constraint.Clustered = false
constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"}
p.nextToken()
}

// Parse column list
if p.curTok.Type == TokenLParen {
p.nextToken() // consume (
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
col := p.parseColumnWithSortOrder()
constraint.Columns = append(constraint.Columns, col)

if p.curTok.Type == TokenComma {
p.nextToken()
} else {
break
}
}
if p.curTok.Type == TokenRParen {
p.nextToken() // consume )
}
}

return constraint, nil
}

// parseUniqueConstraint parses UNIQUE CLUSTERED/NONCLUSTERED (columns)
func (p *Parser) parseUniqueConstraint() (*ast.UniqueConstraintDefinition, error) {
// Consume UNIQUE
p.nextToken()

constraint := &ast.UniqueConstraintDefinition{
IsPrimaryKey: false,
}

// Parse optional CLUSTERED/NONCLUSTERED
if strings.ToUpper(p.curTok.Literal) == "CLUSTERED" {
constraint.Clustered = true
constraint.IndexType = &ast.IndexType{IndexTypeKind: "Clustered"}
p.nextToken()
} else if strings.ToUpper(p.curTok.Literal) == "NONCLUSTERED" {
constraint.Clustered = false
constraint.IndexType = &ast.IndexType{IndexTypeKind: "NonClustered"}
p.nextToken()
}

// Parse column list
if p.curTok.Type == TokenLParen {
p.nextToken() // consume (
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
col := p.parseColumnWithSortOrder()
constraint.Columns = append(constraint.Columns, col)

if p.curTok.Type == TokenComma {
p.nextToken()
} else {
break
}
}
if p.curTok.Type == TokenRParen {
p.nextToken() // consume )
}
}

return constraint, nil
}

// parseForeignKeyConstraint parses FOREIGN KEY (columns) REFERENCES table (columns)
func (p *Parser) parseForeignKeyConstraint() (*ast.ForeignKeyConstraintDefinition, error) {
// Consume FOREIGN
p.nextToken()
if p.curTok.Type == TokenKey {
p.nextToken() // consume KEY
}

constraint := &ast.ForeignKeyConstraintDefinition{}

// Parse column list
if p.curTok.Type == TokenLParen {
p.nextToken() // consume (
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
ident := p.parseIdentifier()
constraint.Columns = append(constraint.Columns, ident)

if p.curTok.Type == TokenComma {
p.nextToken()
} else {
break
}
}
if p.curTok.Type == TokenRParen {
p.nextToken() // consume )
}
}

// Parse REFERENCES
if strings.ToUpper(p.curTok.Literal) == "REFERENCES" {
p.nextToken() // consume REFERENCES

// Parse reference table name
refTable, err := p.parseSchemaObjectName()
if err != nil {
return nil, err
}
constraint.ReferenceTableName = refTable

// Parse referenced column list
if p.curTok.Type == TokenLParen {
p.nextToken() // consume (
for p.curTok.Type != TokenRParen && p.curTok.Type != TokenEOF {
ident := p.parseIdentifier()
constraint.ReferencedColumns = append(constraint.ReferencedColumns, ident)

if p.curTok.Type == TokenComma {
p.nextToken()
} else {
break
}
}
if p.curTok.Type == TokenRParen {
p.nextToken() // consume )
}
}
}

return constraint, nil
}

// parseCheckConstraint parses CHECK (expression)
func (p *Parser) parseCheckConstraint() (*ast.CheckConstraintDefinition, error) {
// Consume CHECK
p.nextToken()

constraint := &ast.CheckConstraintDefinition{}

// Parse condition
if p.curTok.Type == TokenLParen {
p.nextToken() // consume (
cond, err := p.parseBooleanExpression()
if err != nil {
return nil, err
}
constraint.CheckCondition = cond
if p.curTok.Type == TokenRParen {
p.nextToken() // consume )
}
}

return constraint, nil
}

// parseColumnWithSortOrder parses a column name with optional ASC/DESC sort order
func (p *Parser) parseColumnWithSortOrder() *ast.ColumnWithSortOrder {
col := &ast.ColumnWithSortOrder{
SortOrder: ast.SortOrderNotSpecified,
}

// Parse column name
ident := p.parseIdentifier()
col.Column = &ast.ColumnReferenceExpression{
ColumnType: "Regular",
MultiPartIdentifier: &ast.MultiPartIdentifier{
Count: 1,
Identifiers: []*ast.Identifier{ident},
},
}

// Parse optional ASC/DESC
upperLit := strings.ToUpper(p.curTok.Literal)
if upperLit == "ASC" {
col.SortOrder = ast.SortOrderAscending
p.nextToken()
} else if upperLit == "DESC" {
col.SortOrder = ast.SortOrderDescending
p.nextToken()
}

return col
}

func (p *Parser) parseGrantStatement() (*ast.GrantStatement, error) {
// Consume GRANT
p.nextToken()
Expand Down Expand Up @@ -3427,14 +3710,19 @@ func foreignKeyConstraintToJSON(c *ast.ForeignKeyConstraintDefinition) jsonNode
for i, col := range c.ReferencedColumns {
cols[i] = identifierToJSON(col)
}
node["ReferencedColumns"] = cols
node["ReferencedTableColumns"] = cols
}
if c.DeleteAction != "" {
node["DeleteAction"] = c.DeleteAction
// Always include DeleteAction and UpdateAction with default value
deleteAction := c.DeleteAction
if deleteAction == "" {
deleteAction = "NotSpecified"
}
if c.UpdateAction != "" {
node["UpdateAction"] = c.UpdateAction
node["DeleteAction"] = deleteAction
updateAction := c.UpdateAction
if updateAction == "" {
updateAction = "NotSpecified"
}
node["UpdateAction"] = updateAction
return node
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"todo": true}
{}
2 changes: 1 addition & 1 deletion parser/testdata/TSqlParserTestScript3/metadata.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"todo": true}
{}
Loading