diff --git a/q/parse.go b/q/parse.go index edba65a..ac4799e 100644 --- a/q/parse.go +++ b/q/parse.go @@ -83,7 +83,6 @@ func Parse(sql string) (Query, error) { Type: SELECT, } parseSelectStatement(&p) - fmt.Printf("IsWildcard %v", p.query.(*Select).IsWildcard) default: return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType) } @@ -103,7 +102,7 @@ func findQueryType(p *parser) (string, error) { return firstCommand.Value, nil } -func parseSelectStatement(p *parser) { +func parseSelectStatement(p *parser) error { selectQuery := p.query.(*Select) foundWildcard, _ := p.findToken(func(t Token) bool { @@ -113,9 +112,29 @@ func parseSelectStatement(p *parser) { selectQuery.IsWildcard = foundWildcard != nil if !selectQuery.IsWildcard { - parseSelectColumns(p) + err := parseSelectColumns(p) + if err != nil { + fmt.Println(err.Error()) + return err + } } + tableErr := parseSelectTable(p) + if tableErr != nil { + return tableErr + } + + conditionalsErr := parseSelectConditionals(p) + if conditionalsErr != nil { + return conditionalsErr + } + + ordersByErr := parseOrderBys(p) + if ordersByErr != nil { + return ordersByErr + } + + return nil } func parseSelectColumns(p *parser) error { @@ -133,39 +152,37 @@ func parseSelectColumns(p *parser) error { return fmt.Errorf("Could not find range between SELECT and FROM") } - tokensFromSelectToFrom := p.tokens[selectCommandIndex:fromKeywordIndex] - lookBehindBuffer := [10]Token{} var workingColumn Column columns := make([]Column, 0) - for _, token := range tokensFromSelectToFrom { + startRange := selectCommandIndex + 1 + endRange := fromKeywordIndex - 1 + + for i := startRange; i <= endRange; i++ { + token := p.tokens[i] + if token.Type == sqllexer.FUNCTION { unshiftBuffer(&lookBehindBuffer, token) workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value) continue - } else if token.Type == sqllexer.PUNCTUATION { - if token.Value == "," { - columns = append(columns, workingColumn) - workingColumn = Column{} - } + } else if token.Type == sqllexer.PUNCTUATION && token.Value == "," { + columns = append(columns, workingColumn) + workingColumn = Column{} continue } else if token.Type == sqllexer.IDENT { unshiftBuffer(&lookBehindBuffer, token) - if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR { + if lookBehindBuffer[1].Type == sqllexer.ALIAS_INDICATOR { workingColumn.Alias = token.Value } else { workingColumn.Name = token.Value } - continue } else if token.Type == sqllexer.ALIAS_INDICATOR { unshiftBuffer(&lookBehindBuffer, token) continue - } else if token.Type == sqllexer.SPACE { - continue - } else { + } else if i == endRange { if workingColumn.Name != "" { columns = append(columns, workingColumn) workingColumn = Column{} @@ -177,3 +194,126 @@ func parseSelectColumns(p *parser) error { return nil } + +func parseSelectTable(p *parser) error { + selectQuery := p.query.(*Select) + + _, fromKeywordIndex := p.findToken(func(t Token) bool { + return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM" + }) + + if fromKeywordIndex < 0 { + return fmt.Errorf("Could not FROM keyword to look for table name") + } + + var foundTable Table + + for i := fromKeywordIndex + 1; i < len(p.tokens); i++ { + t := &p.tokens[i] + if foundTable.Name == "" && t.Type == sqllexer.IDENT { + foundTable.Name = p.tokens[i].Value + continue + } else if t.Type == sqllexer.IDENT { + foundTable.Alias = p.tokens[i].Value + break + } else if t.Type == sqllexer.SPACE || t.Type == sqllexer.ALIAS_INDICATOR { + continue + } else { + break + } + } + + if foundTable.Name == "" { + return fmt.Errorf("Could not find table name") + } + + selectQuery.Table = foundTable + return nil +} + +func parseSelectConditionals(p *parser) error { + selectQuery := p.query.(*Select) + + _, whereKeywordIndex := p.findToken(func(t Token) bool { + return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "WHERE" + }) + + if whereKeywordIndex < 0 { + return nil // fmt.Errorf("Could not find WHERE to look for conditionals") + } + + var workingConditional Conditional + for i := whereKeywordIndex + 1; i < len(p.tokens); i++ { + t := &p.tokens[i] + + if t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) != "AND" && strings.ToUpper(t.Value) != "OR" && strings.ToUpper(t.Value) != "NOT" { + break + } + + if t.Type == sqllexer.IDENT { + workingConditional.Key = t.Value + } else if t.Type == sqllexer.OPERATOR { + workingConditional.Operator = t.Value + } else if t.Type == sqllexer.BOOLEAN || t.Type == sqllexer.NULL || t.Type == sqllexer.STRING || t.Type == sqllexer.NUMBER { + workingConditional.Value = t.Value + } else if t.Type == sqllexer.KEYWORD { + if strings.ToUpper(t.Value) == "AND" || strings.ToUpper(t.Value) == "OR" { + workingConditional.Extension = strings.ToUpper(t.Value) + } + } + + if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" { + selectQuery.Conditionals = append(selectQuery.Conditionals, workingConditional) + workingConditional = Conditional{} + } + } + + return nil +} + +func parseOrderBys(p *parser) error { + selectQuery := p.query.(*Select) + + _, byKeywordIndex := p.findToken(func(t Token) bool { + return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "BY" + }) + + if byKeywordIndex < 0 { + return nil + } + + var orderBys []OrderBy + + var workingOrderBy OrderBy + for i := byKeywordIndex + 1; i < len(p.tokens); i++ { + t := &p.tokens[i] + if t.Type == sqllexer.SPACE { + continue + } else if t.Type == sqllexer.IDENT && workingOrderBy.Key == "" { + workingOrderBy.Key = t.Value + continue + } else if t.Type == sqllexer.IDENT && workingOrderBy.Key != "" { + orderBys = append(orderBys, workingOrderBy) + workingOrderBy.Key = t.Value + continue + } else if t.Type == sqllexer.KEYWORD { + if t.Value == "DESC" { + workingOrderBy.IsDescend = true + } else if t.Value != "ASC" { + break + } + } else if t.Type == sqllexer.PUNCTUATION { + orderBys = append(orderBys, workingOrderBy) + workingOrderBy = OrderBy{} + continue + } + } + + if workingOrderBy.Key != "" { + orderBys = append(orderBys, workingOrderBy) + } + + selectQuery.OrderBys = orderBys + + return nil +} diff --git a/q/parse_test.go b/q/parse_test.go index 4620322..951c061 100644 --- a/q/parse_test.go +++ b/q/parse_test.go @@ -27,6 +27,74 @@ func TestParseSelectStatement_StateMachine(t *testing.T) { }, }, }, + { + input: "SELECT CustomerName, City FROM Customers", + expected: &Select{ + Type: SELECT, + Table: Table{Name: "Customers"}, + IsWildcard: false, + Columns: []Column{ + { + Name: "CustomerName", + }, + { + Name: "City", + }, + }, + }, + }, + { + input: "SELECT CustomerName AS Customer, City AS town FROM Customers AS People", + expected: &Select{ + Type: SELECT, + Table: Table{Name: "Customers", Alias: "People"}, + IsWildcard: false, + Columns: []Column{ + { + Name: "CustomerName", + Alias: "Customer", + }, + { + Name: "City", + Alias: "town", + }, + }, + }, + }, + { + input: "SELECT * FROM Orders ORDER BY StreetNumber, CountryCode;", + expected: &Select{ + Type: SELECT, + Table: Table{Name: "Orders"}, + IsWildcard: true, + OrderBys: []OrderBy{ + { + Key: "StreetNumber", + }, + { + Key: "CountryCode", + }, + }, + }, + }, + { + input: "SELECT * FROM ZipCodes ORDER BY Code ASC, StateName DESC", + expected: &Select{ + Type: SELECT, + Table: Table{Name: "ZipCodes"}, + IsWildcard: true, + OrderBys: []OrderBy{ + { + Key: "Code", + IsDescend: false, + }, + { + Key: "StateName", + IsDescend: true, + }, + }, + }, + }, } for _, sql := range testSqlStatements { @@ -48,6 +116,60 @@ func TestParseSelectStatement_StateMachine(t *testing.T) { if answerAsSelect.IsWildcard != expected.IsWildcard { t.Errorf("got %#v for Select.IsWildcard, expected %#v", answerAsSelect.IsWildcard, expected.IsWildcard) } + if answerAsSelect.Table.Name != expected.Table.Name { + t.Errorf("got %s for Select.Table.Name, expected %s", answerAsSelect.Table.Name, expected.Table.Name) + } + if answerAsSelect.Table.Alias != expected.Table.Alias { + t.Errorf("got %s for Select.Table.Alias, expected %s", answerAsSelect.Table.Alias, expected.Table.Alias) + } + + if len(answerAsSelect.Columns) != len(expected.Columns) { + t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answerAsSelect.Columns), len(expected.Columns)) + } else { + for i, expectedColumn := range expected.Columns { + if expectedColumn.Name != answerAsSelect.Columns[i].Name { + t.Errorf("got %s for Select.Column[%d].Name, expected %s", answerAsSelect.Columns[i].Name, i, expectedColumn.Name) + } + if expectedColumn.Alias != answerAsSelect.Columns[i].Alias { + t.Errorf("got %s for Select.Column[%d].Alias, expected %s", answerAsSelect.Columns[i].Alias, i, expectedColumn.Alias) + } + if expectedColumn.AggregateFunction != answerAsSelect.Columns[i].AggregateFunction { + t.Errorf("got %d for Select.Column[%d].AggregateFunction, expected %d", answerAsSelect.Columns[i].AggregateFunction, i, expectedColumn.AggregateFunction) + } + } + } + + if len(answerAsSelect.Conditionals) != len(expected.Conditionals) { + t.Errorf("got %d number of conditionals for Select.Conditionals, expected %d", len(answerAsSelect.Conditionals), len(expected.Conditionals)) + } else { + for i, expectedCondition := range expected.Conditionals { + if expectedCondition.Key != answerAsSelect.Conditionals[i].Key { + t.Errorf("got %s for Select.Conditionals[%d].Key, expected %s", answerAsSelect.Conditionals[i].Key, i, expectedCondition.Key) + } + if expectedCondition.Operator != answerAsSelect.Conditionals[i].Operator { + t.Errorf("got %s for Select.Conditionals[%d].Operator, expected %s", answerAsSelect.Conditionals[i].Operator, i, expectedCondition.Operator) + } + if expectedCondition.Value != answerAsSelect.Conditionals[i].Value { + t.Errorf("got %s for Select.Conditionals[%d].Value, expected %s", answerAsSelect.Conditionals[i].Value, i, expectedCondition.Value) + } + if expectedCondition.Extension != answerAsSelect.Conditionals[i].Extension { + t.Errorf("got %s for Select.Conditionals[%d].Extension, expected %s", answerAsSelect.Conditionals[i].Extension, i, expectedCondition.Extension) + } + } + } + + if len(answerAsSelect.OrderBys) != len(expected.OrderBys) { + t.Errorf("got %d number of orderBys for Select.OrderBys, expected %d", len(answerAsSelect.OrderBys), len(expected.OrderBys)) + } else { + for i, expectedOrderBy := range expected.OrderBys { + if expectedOrderBy.Key != answerAsSelect.OrderBys[i].Key { + t.Errorf("got %s for Select.OrderBys[%d].Key, expected %s", answerAsSelect.OrderBys[i].Key, i, expectedOrderBy.Key) + } + if expectedOrderBy.IsDescend != answerAsSelect.OrderBys[i].IsDescend { + t.Errorf("got %#v for Select.OrderBys[%d].IsDescend, expected %#v", answerAsSelect.OrderBys[i].IsDescend, i, expectedOrderBy.IsDescend) + } + } + } }) }