diff --git a/q/parse.go b/q/parse.go index ac4799e..fefb947 100644 --- a/q/parse.go +++ b/q/parse.go @@ -36,6 +36,27 @@ func (p *parser) findToken(condition func(Token) bool) (*Token, int) { return nil, -1 } +type FoundToken struct { + Token Token + Index int +} + +// Returns all tokens that match the condition, along with their indices +func (p *parser) findAllTokens(condition func(Token) bool) []FoundToken { + matches := make([]FoundToken, 0) + + for i, token := range p.tokens { + if condition(token) { + matches = append(matches, FoundToken{ + Token: token, + Index: i, + }) + } + } + + return matches +} + func filter[T any](items []T, fn func(item T) bool) []T { filteredItems := []T{} for _, value := range items { @@ -134,6 +155,8 @@ func parseSelectStatement(p *parser) error { return ordersByErr } + parseJoins(p) + return nil } @@ -317,3 +340,32 @@ func parseOrderBys(p *parser) error { return nil } + +func parseJoins(p *parser) error { + selectQuery := p.query.(*Select) + + foundJoinKeywords := p.findAllTokens(func(t Token) bool { + return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "JOIN" + }) + + if len(foundJoinKeywords) <= 0 { + return nil + } + + var joinTokenRanges []Token + + for i := 0; i < len(foundJoinKeywords); i++ { + startRangeIndex := foundJoinKeywords[i].Index + var endRangeIndex int + + if i == (len(foundJoinKeywords) - 1) { + endRangeIndex = len(p.tokens) - 1 + } else { + endRangeIndex = foundJoinKeywords[i+1].Index + } + + joinTokenRanges = append(joinTokenRanges, p.tokens[startRangeIndex:endRangeIndex]...) + } + + return nil +} diff --git a/q/parse_test.go b/q/parse_test.go index 951c061..4643fa7 100644 --- a/q/parse_test.go +++ b/q/parse_test.go @@ -95,6 +95,88 @@ func TestParseSelectStatement_StateMachine(t *testing.T) { }, }, }, + { + input: "SELECT id, streetNumber AS streetNum, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 OR zip <= 12000 ORDER BY zip DESC, streetNumber", + expected: &Select{ + Type: SELECT, + Table: Table{Name: "Addresses"}, + IsWildcard: false, + Columns: []Column{ + { + Name: "id", + }, + { + Name: "streetNumber", + Alias: "streetNum", + }, + { + Name: "streetName", + }, + { + Name: "city", + }, + { + Name: "state", + }, + }, + Conditionals: []Conditional{ + { + Key: "state", + Operator: "=", + Value: "'AL'", + }, + { + Key: "zip", + Operator: ">", + Value: "9000", + Extension: "AND", + }, + { + Key: "zip", + Operator: "<=", + Value: "12000", + Extension: "OR", + }, + }, + OrderBys: []OrderBy{ + { + Key: "zip", + IsDescend: true, + }, + { + Key: "streetNumber", + IsDescend: false, + }, + }, + }, + }, + { + input: "SELECT ProductID, ProductName, CategoryName FROM Products INNER JOIN Categories ON Products.CategoryID = Categories.CategoryID; ", + expected: &Select{ + Type: SELECT, + Table: Table{Name: "Products"}, + Columns: []Column{ + {Name: "ProductID"}, + {Name: "ProductName"}, + {Name: "CategoryName"}, + }, + Joins: []Join{ + { + Type: INNER, + MainTable: Table{ + Name: "Categories", + }, + Ons: []Conditional{ + { + Key: "Products.CategoryID", + Operator: "=", + Value: "Categories.CategoryID", + }, + }, + }, + }, + }, + }, } for _, sql := range testSqlStatements { @@ -171,6 +253,47 @@ func TestParseSelectStatement_StateMachine(t *testing.T) { } } + if len(answerAsSelect.Joins) != len(expected.Joins) { + t.Errorf("got %d number of joins for Select.Joinss, expected %d", len(answerAsSelect.Joins), len(expected.Joins)) + } else { + for i, expectedJoin := range expected.Joins { + if answerAsSelect.Joins[i].Type != expectedJoin.Type { + t.Errorf("got %d for Select.Joins[%d].Type, expected %d", answerAsSelect.Joins[i].Type, i, expectedJoin.Type) + } + if answerAsSelect.Joins[i].MainTable.Name != expectedJoin.MainTable.Name { + t.Errorf("got %s for Select.Joins[%d].MainTable.Name, expected %s", answerAsSelect.Joins[i].MainTable.Name, i, expectedJoin.MainTable.Name) + } + if answerAsSelect.Joins[i].MainTable.Alias != expectedJoin.MainTable.Alias { + t.Errorf("got %s for Select.Joins[%d].MainTable.Alias, expected %s", answerAsSelect.Joins[i].MainTable.Alias, i, expectedJoin.MainTable.Alias) + } + if answerAsSelect.Joins[i].JoiningTable.Name != expectedJoin.JoiningTable.Name { + t.Errorf("got %s for Select.Joins[%d].JoiningTable.Name, expected %s", answerAsSelect.Joins[i].JoiningTable.Name, i, expectedJoin.JoiningTable.Name) + } + if answerAsSelect.Joins[i].JoiningTable.Alias != expectedJoin.JoiningTable.Alias { + t.Errorf("got %s for Select.Joins[%d].JoiningTable.Alias, expected %s", answerAsSelect.Joins[i].JoiningTable.Alias, i, expectedJoin.JoiningTable.Alias) + } + + if len(answerAsSelect.Joins[i].Ons) != len(expectedJoin.Ons) { + t.Errorf("got %d number of ons for Select.Joins.Ons, expected %d", len(answerAsSelect.Joins[i].Ons), len(expectedJoin.Ons)) + } else { + for on_i, expectedCondition := range expected.Joins[i].Ons { + if expectedCondition.Key != answerAsSelect.Conditionals[on_i].Key { + t.Errorf("got %s for Select.Conditionals[%d].Key, expected %s", answerAsSelect.Conditionals[on_i].Key, on_i, expectedCondition.Key) + } + if expectedCondition.Operator != answerAsSelect.Conditionals[on_i].Operator { + t.Errorf("got %s for Select.Conditionals[%d].Operator, expected %s", answerAsSelect.Conditionals[on_i].Operator, on_i, expectedCondition.Operator) + } + if expectedCondition.Value != answerAsSelect.Conditionals[on_i].Value { + t.Errorf("got %s for Select.Conditionals[%d].Value, expected %s", answerAsSelect.Conditionals[on_i].Value, on_i, expectedCondition.Value) + } + if expectedCondition.Extension != answerAsSelect.Conditionals[on_i].Extension { + t.Errorf("got %s for Select.Conditionals[%d].Extension, expected %s", answerAsSelect.Conditionals[on_i].Extension, on_i, expectedCondition.Extension) + } + } + + } + } + } }) } }