diff --git a/q/dto.go b/q/dto.go
index 7dbc991..7e781e3 100644
--- a/q/dto.go
+++ b/q/dto.go
@@ -134,7 +134,13 @@ type Join struct {
 	Ons []Conditional `json:"ons"`
 }
 
+type OrderBy struct {
+	Key       string `json:"key"`
+	IsDescend bool   `json:"isDescend"`
+}
+
 type Select struct {
+	Type         QueryType     `json:"type"`
 	Table        Table         `json:"table"`
 	Columns      []Column      `json:"columns"`
 	Conditionals []Conditional `json:"conditionals"`
@@ -144,11 +150,6 @@ type Select struct {
 	IsDistinct   bool          `json:"is_distinct"`
 }
 
-type OrderBy struct {
-	Key       string `json:"key"`
-	IsDescend bool   `json:"isDescend"`
-}
-
 func MarshalSelect(selectStatement Select) ([]byte, error) {
 	jsonData, err := json.MarshalIndent(selectStatement, "", " ")
 	return jsonData, err
diff --git a/q/parse.go b/q/parse.go
new file mode 100644
index 0000000..edba65a
--- /dev/null
+++ b/q/parse.go
@@ -0,0 +1,216 @@
+package q
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/DataDog/go-sqllexer"
+)
+
+// Token is an alias for the token type produced by go-sqllexer.
+type Token = sqllexer.Token
+
+// parser carries the token stream of a single SQL statement together with the
+// Query being built from it.
+type parser struct {
+	sql              string
+	tokens           []Token
+	index            int
+	query            Query
+	lookBehindBuffer []Token
+}
+
+// tokenAt returns the token at position i, or the zero Token when i lies
+// outside the stream, so peeking past either end cannot panic.
+func (p *parser) tokenAt(i int) Token {
+	if i < 0 || i >= len(p.tokens) {
+		return Token{}
+	}
+	return p.tokens[i]
+}
+
+// lookAhead returns the token count positions after the current index.
+func (p *parser) lookAhead(count int) Token {
+	return p.tokenAt(p.index + count)
+}
+
+// lookBehind returns the token count positions before the current index.
+func (p *parser) lookBehind(count int) Token {
+	return p.tokenAt(p.index - count)
+}
+
+// findToken returns a pointer to the first token in p.tokens that satisfies
+// condition, along with its index. It returns (nil, -1) when nothing matches.
+func (p *parser) findToken(condition func(Token) bool) (*Token, int) {
+	for i := range p.tokens {
+		if condition(p.tokens[i]) {
+			// Point at the slice element itself, not at a per-iteration copy.
+			return &p.tokens[i], i
+		}
+	}
+	return nil, -1
+}
+
+// filter returns the items for which fn reports true, in their original order.
+// The result is an empty, non-nil slice when nothing matches.
+func filter[T any](items []T, fn func(item T) bool) []T {
+	filteredItems := []T{}
+	for _, value := range items {
+		if fn(value) {
+			filteredItems = append(filteredItems, value)
+		}
+	}
+	return filteredItems
+}
+
+// unshiftBuffer pushes value onto the front of buf, moving every existing entry
+// one slot back and dropping the oldest one.
+func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) {
+	copy(buf[1:], buf[:len(buf)-1])
+	buf[0] = value
+}
+
+// Parse tokenizes sql and builds the Query it describes. Only SELECT statements
+// are handled; any other statement type results in an error.
+func Parse(sql string) (Query, error) {
+	lexer := sqllexer.New(sql)
+
+	var tokens []Token
+	for {
+		t := lexer.Scan()
+		tokens = append(tokens, *t)
+
+		if IsTokenEndOfStatement(t) {
+			break
+		}
+	}
+
+	p := parser{
+		sql:    sql,
+		tokens: tokens,
+	}
+
+	queryType, err := findQueryType(&p)
+	if err != nil {
+		return p.query, err
+	}
+
+	switch strings.ToUpper(queryType) {
+	case "SELECT":
+		p.query = &Select{
+			Type: SELECT,
+		}
+		// Surface column-list failures instead of silently dropping them.
+		if err := parseSelectStatement(&p); err != nil {
+			return nil, err
+		}
+	default:
+		return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType)
+	}
+
+	return p.query, nil
+}
+
+// findQueryType returns the value of the first COMMAND token in the statement,
+// or an error when the statement contains none.
+func findQueryType(p *parser) (string, error) {
+	firstCommand, _ := p.findToken(func(t Token) bool {
+		return t.Type == sqllexer.COMMAND
+	})
+
+	if firstCommand == nil {
+		return "", fmt.Errorf("Could not find query type")
+	}
+
+	return firstCommand.Value, nil
+}
+
+// parseSelectStatement fills in the *Select held in p.query from p.tokens and
+// returns any error reported while reading the column list.
+func parseSelectStatement(p *parser) error {
+	selectQuery := p.query.(*Select)
+
+	// NOTE(review): the whole statement is scanned, so a WILDCARD token outside
+	// the column list also marks the query as a wildcard select; confirm that
+	// this is intended.
+	foundWildcard, _ := p.findToken(func(t Token) bool {
+		return t.Type == sqllexer.WILDCARD
+	})
+
+	selectQuery.IsWildcard = foundWildcard != nil
+	if selectQuery.IsWildcard {
+		return nil
+	}
+
+	return parseSelectColumns(p)
+}
+
+// parseSelectColumns collects the columns listed between the SELECT command and
+// the FROM keyword into the *Select held in p.query. It returns an error when
+// that range cannot be located.
+func parseSelectColumns(p *parser) error {
+	selectQuery := p.query.(*Select)
+
+	_, selectCommandIndex := p.findToken(func(t Token) bool {
+		return t.Type == sqllexer.COMMAND && strings.EqualFold(t.Value, "SELECT")
+	})
+
+	_, fromKeywordIndex := p.findToken(func(t Token) bool {
+		return t.Type == sqllexer.KEYWORD && strings.EqualFold(t.Value, "FROM")
+	})
+
+	// Also reject a FROM that precedes SELECT, which would make the slice
+	// expression below panic.
+	if selectCommandIndex < 0 || fromKeywordIndex < selectCommandIndex {
+		return fmt.Errorf("Could not find range between SELECT and FROM")
+	}
+
+	var (
+		lookBehindBuffer [10]Token
+		workingColumn    Column
+	)
+	columns := make([]Column, 0)
+
+	// Start after the SELECT token itself; it contributes nothing to a column.
+	for _, token := range p.tokens[selectCommandIndex+1 : fromKeywordIndex] {
+		switch token.Type {
+		case sqllexer.FUNCTION:
+			unshiftBuffer(&lookBehindBuffer, token)
+			workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value)
+		case sqllexer.PUNCTUATION:
+			if token.Value == "," {
+				columns = append(columns, workingColumn)
+				workingColumn = Column{}
+			}
+		case sqllexer.IDENT:
+			// Check the previous recorded token before recording this one;
+			// otherwise the buffer head is always the identifier itself and an
+			// alias would overwrite the column name.
+			if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR {
+				workingColumn.Alias = token.Value
+			} else {
+				workingColumn.Name = token.Value
+			}
+			unshiftBuffer(&lookBehindBuffer, token)
+		case sqllexer.ALIAS_INDICATOR:
+			unshiftBuffer(&lookBehindBuffer, token)
+		case sqllexer.SPACE:
+			// Whitespace carries no information for the column list.
+		default:
+			if workingColumn.Name != "" {
+				columns = append(columns, workingColumn)
+				workingColumn = Column{}
+			}
+		}
+	}
+
+	// The last column is followed by FROM rather than a comma, so it is still
+	// pending once the range is exhausted.
+	if workingColumn.Name != "" {
+		columns = append(columns, workingColumn)
+	}
+
+	selectQuery.Columns = columns
+
+	return nil
+}
diff --git a/q/parse_test.go b/q/parse_test.go
new file mode 100644
index 0000000..4620322
--- /dev/null
+++ b/q/parse_test.go
@@ -0,0 +1,84 @@
+package q
+
+import (
+	"testing"
+)
+
+// ParsingTest pairs a SQL statement with the Query that parsing it should yield.
+type ParsingTest struct {
+	input    string
+	expected Query
+}
+
+// TestParseSelectStatement_StateMachine checks the statement type, wildcard flag
+// and column list that Parse produces for SELECT statements.
+func TestParseSelectStatement_StateMachine(t *testing.T) {
+	var testSqlStatements = []ParsingTest{
+		{
+			input: "SELECT * FROM users WHERE age >= 30",
+			expected: &Select{
+				Type:       SELECT,
+				Table:      Table{Name: "users"},
+				IsWildcard: true,
+				Conditionals: []Conditional{
+					{
+						Key:      "age",
+						Operator: ">=",
+						Value:    "30",
+					},
+				},
+			},
+		},
+		{
+			input: "SELECT CustomerName, City FROM Customers",
+			expected: &Select{
+				Type: SELECT,
+				Columns: []Column{
+					{Name: "CustomerName"},
+					{Name: "City"},
+				},
+			},
+		},
+		{
+			input: "SELECT CustomerName AS customer FROM Customers",
+			expected: &Select{
+				Type: SELECT,
+				Columns: []Column{
+					{Name: "CustomerName", Alias: "customer"},
+				},
+			},
+		},
+	}
+
+	for _, sql := range testSqlStatements {
+		expected := sql.expected.(*Select)
+
+		t.Run(sql.input, func(t *testing.T) {
+			answer, err := Parse(sql.input)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			answerAsSelect, ok := answer.(*Select)
+			if !ok {
+				t.Fatalf("got %T from Parse, expected *Select", answer)
+			}
+
+			if answerAsSelect.Type != expected.Type {
+				t.Errorf("got %d for Select.Type, expected %d", answerAsSelect.Type, expected.Type)
+			}
+			if answerAsSelect.IsWildcard != expected.IsWildcard {
+				t.Errorf("got %#v for Select.IsWildcard, expected %#v", answerAsSelect.IsWildcard, expected.IsWildcard)
+			}
+			if len(answerAsSelect.Columns) != len(expected.Columns) {
+				t.Fatalf("got %d columns, expected %d", len(answerAsSelect.Columns), len(expected.Columns))
+			}
+			for i, want := range expected.Columns {
+				got := answerAsSelect.Columns[i]
+				if got.Name != want.Name || got.Alias != want.Alias {
+					t.Errorf("got column %d as %q alias %q, expected %q alias %q", i, got.Name, got.Alias, want.Name, want.Alias)
+				}
+			}
+		})
+	}
+}
diff --git a/q/select.go b/q/select.go index a530c98..2ff58e5 100644 --- a/q/select.go +++ b/q/select.go @@ -58,14 +58,6 @@ func (q *Select) GetFullSql() string { return fullSql } -func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) { - for i := 9; i >= 1; i-- { - buf[i] = buf[i-1] - } - - buf[0] = value -} - func ParseSelectStatement(sql string) Select { query := Select{} diff --git a/q/select_test.go b/q/select_test.go index da3f2fb..a685aa6 100644 --- a/q/select_test.go +++ b/q/select_test.go @@ -5,16 +5,11 @@ import ( "testing" ) -type ParsingTest struct { - input string - expected Select -} - func TestParseSelectStatement(t *testing.T) { var testSqlStatements = []ParsingTest{ { input: "SELECT * FROM users WHERE age >= 30", - expected: Select{ + expected: &Select{ Table: Table{Name: "users"}, IsWildcard: true, Conditionals: []Conditional{ @@ -28,7 +23,7 @@ func TestParseSelectStatement(t *testing.T) { }, { input: "SELECT CustomerName, City FROM Customers", - expected: Select{ + expected: &Select{ Table: Table{Name: "Customers"}, IsWildcard: false, Columns: []Column{ @@ -43,7 +38,7 @@ func TestParseSelectStatement(t *testing.T) { }, { input: "SELECT DISTINCT Country FROM Nations;", - expected: Select{ + expected: &Select{ Table: Table{Name: "Nations"}, Columns: []Column{ { @@ -54,7 +49,7 @@ func TestParseSelectStatement(t *testing.T) { }, { input: "SELECT * FROM Orders ORDER BY StreetNumber, CountryCode;", - expected: Select{ + expected: &Select{ Table: Table{Name: "Orders"}, IsWildcard: true, OrderBys: []OrderBy{ @@ -69,7 +64,7 @@ func TestParseSelectStatement(t *testing.T) { }, { input: "SELECT * FROM ZipCodes ORDER BY Code ASC, StateName DESC", - expected: Select{ + expected: &Select{ Table: Table{Name: "ZipCodes"}, IsWildcard: true, OrderBys: []OrderBy{ @@ -86,7 +81,7 @@ func TestParseSelectStatement(t *testing.T) { }, { input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 OR zip <= 12000 
ORDER BY zip DESC, streetNumber", - expected: Select{ + expected: &Select{ Table: Table{Name: "Addresses"}, IsWildcard: false, Columns: []Column{ @@ -167,7 +162,7 @@ func TestParseSelectStatement(t *testing.T) { for _, sql := range testSqlStatements { testName := fmt.Sprintf("%s", sql.input) - expected := sql.expected + expected := sql.expected.(*Select) t.Run(testName, func(t *testing.T) { answer := ParseSelectStatement(sql.input)