refact: support table, orders by, and conitionals

This commit is contained in:
Yehoshua Sandler 2025-05-11 19:38:51 -05:00
parent 0e6cd88d64
commit 7ec428e374
2 changed files with 278 additions and 16 deletions

View File

@ -83,7 +83,6 @@ func Parse(sql string) (Query, error) {
Type: SELECT, Type: SELECT,
} }
parseSelectStatement(&p) parseSelectStatement(&p)
fmt.Printf("IsWildcard %v", p.query.(*Select).IsWildcard)
default: default:
return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType) return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType)
} }
@ -103,7 +102,7 @@ func findQueryType(p *parser) (string, error) {
return firstCommand.Value, nil return firstCommand.Value, nil
} }
func parseSelectStatement(p *parser) { func parseSelectStatement(p *parser) error {
selectQuery := p.query.(*Select) selectQuery := p.query.(*Select)
foundWildcard, _ := p.findToken(func(t Token) bool { foundWildcard, _ := p.findToken(func(t Token) bool {
@ -113,9 +112,29 @@ func parseSelectStatement(p *parser) {
selectQuery.IsWildcard = foundWildcard != nil selectQuery.IsWildcard = foundWildcard != nil
if !selectQuery.IsWildcard { if !selectQuery.IsWildcard {
parseSelectColumns(p) err := parseSelectColumns(p)
if err != nil {
fmt.Println(err.Error())
return err
}
} }
tableErr := parseSelectTable(p)
if tableErr != nil {
return tableErr
}
conditionalsErr := parseSelectConditionals(p)
if conditionalsErr != nil {
return conditionalsErr
}
ordersByErr := parseOrderBys(p)
if ordersByErr != nil {
return ordersByErr
}
return nil
} }
func parseSelectColumns(p *parser) error { func parseSelectColumns(p *parser) error {
@ -133,39 +152,37 @@ func parseSelectColumns(p *parser) error {
return fmt.Errorf("Could not find range between SELECT and FROM") return fmt.Errorf("Could not find range between SELECT and FROM")
} }
tokensFromSelectToFrom := p.tokens[selectCommandIndex:fromKeywordIndex]
lookBehindBuffer := [10]Token{} lookBehindBuffer := [10]Token{}
var workingColumn Column var workingColumn Column
columns := make([]Column, 0) columns := make([]Column, 0)
for _, token := range tokensFromSelectToFrom { startRange := selectCommandIndex + 1
endRange := fromKeywordIndex - 1
for i := startRange; i <= endRange; i++ {
token := p.tokens[i]
if token.Type == sqllexer.FUNCTION { if token.Type == sqllexer.FUNCTION {
unshiftBuffer(&lookBehindBuffer, token) unshiftBuffer(&lookBehindBuffer, token)
workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value) workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value)
continue continue
} else if token.Type == sqllexer.PUNCTUATION { } else if token.Type == sqllexer.PUNCTUATION && token.Value == "," {
if token.Value == "," {
columns = append(columns, workingColumn) columns = append(columns, workingColumn)
workingColumn = Column{} workingColumn = Column{}
}
continue continue
} else if token.Type == sqllexer.IDENT { } else if token.Type == sqllexer.IDENT {
unshiftBuffer(&lookBehindBuffer, token) unshiftBuffer(&lookBehindBuffer, token)
if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR { if lookBehindBuffer[1].Type == sqllexer.ALIAS_INDICATOR {
workingColumn.Alias = token.Value workingColumn.Alias = token.Value
} else { } else {
workingColumn.Name = token.Value workingColumn.Name = token.Value
} }
continue continue
} else if token.Type == sqllexer.ALIAS_INDICATOR { } else if token.Type == sqllexer.ALIAS_INDICATOR {
unshiftBuffer(&lookBehindBuffer, token) unshiftBuffer(&lookBehindBuffer, token)
continue continue
} else if token.Type == sqllexer.SPACE { } else if i == endRange {
continue
} else {
if workingColumn.Name != "" { if workingColumn.Name != "" {
columns = append(columns, workingColumn) columns = append(columns, workingColumn)
workingColumn = Column{} workingColumn = Column{}
@ -177,3 +194,126 @@ func parseSelectColumns(p *parser) error {
return nil return nil
} }
func parseSelectTable(p *parser) error {
selectQuery := p.query.(*Select)
_, fromKeywordIndex := p.findToken(func(t Token) bool {
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM"
})
if fromKeywordIndex < 0 {
return fmt.Errorf("Could not FROM keyword to look for table name")
}
var foundTable Table
for i := fromKeywordIndex + 1; i < len(p.tokens); i++ {
t := &p.tokens[i]
if foundTable.Name == "" && t.Type == sqllexer.IDENT {
foundTable.Name = p.tokens[i].Value
continue
} else if t.Type == sqllexer.IDENT {
foundTable.Alias = p.tokens[i].Value
break
} else if t.Type == sqllexer.SPACE || t.Type == sqllexer.ALIAS_INDICATOR {
continue
} else {
break
}
}
if foundTable.Name == "" {
return fmt.Errorf("Could not find table name")
}
selectQuery.Table = foundTable
return nil
}
func parseSelectConditionals(p *parser) error {
selectQuery := p.query.(*Select)
_, whereKeywordIndex := p.findToken(func(t Token) bool {
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "WHERE"
})
if whereKeywordIndex < 0 {
return nil // fmt.Errorf("Could not find WHERE to look for conditionals")
}
var workingConditional Conditional
for i := whereKeywordIndex + 1; i < len(p.tokens); i++ {
t := &p.tokens[i]
if t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) != "AND" && strings.ToUpper(t.Value) != "OR" && strings.ToUpper(t.Value) != "NOT" {
break
}
if t.Type == sqllexer.IDENT {
workingConditional.Key = t.Value
} else if t.Type == sqllexer.OPERATOR {
workingConditional.Operator = t.Value
} else if t.Type == sqllexer.BOOLEAN || t.Type == sqllexer.NULL || t.Type == sqllexer.STRING || t.Type == sqllexer.NUMBER {
workingConditional.Value = t.Value
} else if t.Type == sqllexer.KEYWORD {
if strings.ToUpper(t.Value) == "AND" || strings.ToUpper(t.Value) == "OR" {
workingConditional.Extension = strings.ToUpper(t.Value)
}
}
if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" {
selectQuery.Conditionals = append(selectQuery.Conditionals, workingConditional)
workingConditional = Conditional{}
}
}
return nil
}
func parseOrderBys(p *parser) error {
selectQuery := p.query.(*Select)
_, byKeywordIndex := p.findToken(func(t Token) bool {
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "BY"
})
if byKeywordIndex < 0 {
return nil
}
var orderBys []OrderBy
var workingOrderBy OrderBy
for i := byKeywordIndex + 1; i < len(p.tokens); i++ {
t := &p.tokens[i]
if t.Type == sqllexer.SPACE {
continue
} else if t.Type == sqllexer.IDENT && workingOrderBy.Key == "" {
workingOrderBy.Key = t.Value
continue
} else if t.Type == sqllexer.IDENT && workingOrderBy.Key != "" {
orderBys = append(orderBys, workingOrderBy)
workingOrderBy.Key = t.Value
continue
} else if t.Type == sqllexer.KEYWORD {
if t.Value == "DESC" {
workingOrderBy.IsDescend = true
} else if t.Value != "ASC" {
break
}
} else if t.Type == sqllexer.PUNCTUATION {
orderBys = append(orderBys, workingOrderBy)
workingOrderBy = OrderBy{}
continue
}
}
if workingOrderBy.Key != "" {
orderBys = append(orderBys, workingOrderBy)
}
selectQuery.OrderBys = orderBys
return nil
}

View File

@ -27,6 +27,74 @@ func TestParseSelectStatement_StateMachine(t *testing.T) {
}, },
}, },
}, },
{
input: "SELECT CustomerName, City FROM Customers",
expected: &Select{
Type: SELECT,
Table: Table{Name: "Customers"},
IsWildcard: false,
Columns: []Column{
{
Name: "CustomerName",
},
{
Name: "City",
},
},
},
},
{
input: "SELECT CustomerName AS Customer, City AS town FROM Customers AS People",
expected: &Select{
Type: SELECT,
Table: Table{Name: "Customers", Alias: "People"},
IsWildcard: false,
Columns: []Column{
{
Name: "CustomerName",
Alias: "Customer",
},
{
Name: "City",
Alias: "town",
},
},
},
},
{
input: "SELECT * FROM Orders ORDER BY StreetNumber, CountryCode;",
expected: &Select{
Type: SELECT,
Table: Table{Name: "Orders"},
IsWildcard: true,
OrderBys: []OrderBy{
{
Key: "StreetNumber",
},
{
Key: "CountryCode",
},
},
},
},
{
input: "SELECT * FROM ZipCodes ORDER BY Code ASC, StateName DESC",
expected: &Select{
Type: SELECT,
Table: Table{Name: "ZipCodes"},
IsWildcard: true,
OrderBys: []OrderBy{
{
Key: "Code",
IsDescend: false,
},
{
Key: "StateName",
IsDescend: true,
},
},
},
},
} }
for _, sql := range testSqlStatements { for _, sql := range testSqlStatements {
@ -48,6 +116,60 @@ func TestParseSelectStatement_StateMachine(t *testing.T) {
if answerAsSelect.IsWildcard != expected.IsWildcard { if answerAsSelect.IsWildcard != expected.IsWildcard {
t.Errorf("got %#v for Select.IsWildcard, expected %#v", answerAsSelect.IsWildcard, expected.IsWildcard) t.Errorf("got %#v for Select.IsWildcard, expected %#v", answerAsSelect.IsWildcard, expected.IsWildcard)
} }
if answerAsSelect.Table.Name != expected.Table.Name {
t.Errorf("got %s for Select.Table.Name, expected %s", answerAsSelect.Table.Name, expected.Table.Name)
}
if answerAsSelect.Table.Alias != expected.Table.Alias {
t.Errorf("got %s for Select.Table.Alias, expected %s", answerAsSelect.Table.Alias, expected.Table.Alias)
}
if len(answerAsSelect.Columns) != len(expected.Columns) {
t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answerAsSelect.Columns), len(expected.Columns))
} else {
for i, expectedColumn := range expected.Columns {
if expectedColumn.Name != answerAsSelect.Columns[i].Name {
t.Errorf("got %s for Select.Column[%d].Name, expected %s", answerAsSelect.Columns[i].Name, i, expectedColumn.Name)
}
if expectedColumn.Alias != answerAsSelect.Columns[i].Alias {
t.Errorf("got %s for Select.Column[%d].Alias, expected %s", answerAsSelect.Columns[i].Alias, i, expectedColumn.Alias)
}
if expectedColumn.AggregateFunction != answerAsSelect.Columns[i].AggregateFunction {
t.Errorf("got %d for Select.Column[%d].AggregateFunction, expected %d", answerAsSelect.Columns[i].AggregateFunction, i, expectedColumn.AggregateFunction)
}
}
}
if len(answerAsSelect.Conditionals) != len(expected.Conditionals) {
t.Errorf("got %d number of conditionals for Select.Conditionals, expected %d", len(answerAsSelect.Conditionals), len(expected.Conditionals))
} else {
for i, expectedCondition := range expected.Conditionals {
if expectedCondition.Key != answerAsSelect.Conditionals[i].Key {
t.Errorf("got %s for Select.Conditionals[%d].Key, expected %s", answerAsSelect.Conditionals[i].Key, i, expectedCondition.Key)
}
if expectedCondition.Operator != answerAsSelect.Conditionals[i].Operator {
t.Errorf("got %s for Select.Conditionals[%d].Operator, expected %s", answerAsSelect.Conditionals[i].Operator, i, expectedCondition.Operator)
}
if expectedCondition.Value != answerAsSelect.Conditionals[i].Value {
t.Errorf("got %s for Select.Conditionals[%d].Value, expected %s", answerAsSelect.Conditionals[i].Value, i, expectedCondition.Value)
}
if expectedCondition.Extension != answerAsSelect.Conditionals[i].Extension {
t.Errorf("got %s for Select.Conditionals[%d].Extension, expected %s", answerAsSelect.Conditionals[i].Extension, i, expectedCondition.Extension)
}
}
}
if len(answerAsSelect.OrderBys) != len(expected.OrderBys) {
t.Errorf("got %d number of orderBys for Select.OrderBys, expected %d", len(answerAsSelect.OrderBys), len(expected.OrderBys))
} else {
for i, expectedOrderBy := range expected.OrderBys {
if expectedOrderBy.Key != answerAsSelect.OrderBys[i].Key {
t.Errorf("got %s for Select.OrderBys[%d].Key, expected %s", answerAsSelect.OrderBys[i].Key, i, expectedOrderBy.Key)
}
if expectedOrderBy.IsDescend != answerAsSelect.OrderBys[i].IsDescend {
t.Errorf("got %#v for Select.OrderBys[%d].IsDescend, expected %#v", answerAsSelect.OrderBys[i].IsDescend, i, expectedOrderBy.IsDescend)
}
}
}
}) })
} }