diff --git a/main.go b/main.go index cd0c312..d655fe3 100644 --- a/main.go +++ b/main.go @@ -2,31 +2,33 @@ package main import ( "fmt" - "query-inter/q" + //"query-inter/q" + + "github.com/DataDog/go-sqllexer" // "github.com/DataDog/go-sqllexer" ) func main() { - selectQuery := "SELECT id, name, createDate FROM users WHERE name=1;" + selectQuery := "SELECT MIN(Price) AS SmallestPrice, CategoryID FROM Products GROUP BY CategoryID;" - allStatements := q.ExtractSqlStatmentsFromString(selectQuery) - fmt.Println(allStatements) + //allStatements := q.ExtractSqlStatmentsFromString(selectQuery) + //fmt.Println(allStatements) - //lexer := sqllexer.New(selectQuery) - //for { - // token := lexer.Scan() - // fmt.Println(token.Value, token.Type) - // - // if token.Type == sqllexer.EOF { - // break - // } - //} + lexer := sqllexer.New(selectQuery) + for { + token := lexer.Scan() + fmt.Println(token.Value, token.Type) - for _, sql := range allStatements { - query := q.ParseSelectStatement(sql) - //fmt.Print(i) - //fmt.Println(query) - fmt.Println(query.GetFullSql()) + if token.Type == sqllexer.EOF { + break + } } + //for _, sql := range allStatements { + //query := q.ParseSelectStatement(sql) + //fmt.Print(i) + //fmt.Println(query) + //fmt.Println(query.GetFullSql()) + //} + } diff --git a/q/query.go b/q/query.go index 904d0db..ded018f 100644 --- a/q/query.go +++ b/q/query.go @@ -10,11 +10,12 @@ type Query interface { GetFullSql() string } +const NONE = 0 + type QueryType int const ( - NONE QueryType = iota - SELECT + SELECT QueryType = iota + 1 UPDATE INSERT DELETE diff --git a/q/select.go b/q/select.go index b630783..baf2085 100644 --- a/q/select.go +++ b/q/select.go @@ -1,16 +1,55 @@ package q import ( + "fmt" "strings" "github.com/DataDog/go-sqllexer" ) +type AggregateFunctionType int + +const ( + MIN AggregateFunctionType = iota + 1 + MAX + COUNT + SUM + AVG +) + +type JoinType int + +const ( + INNER JoinType = iota + LEFT + RIGHT + FULL + SELF +) + +type Table struct { + Name string + Alias string +} + +type Column struct { + Name string + Alias string + AggregateFunction AggregateFunctionType +} + +type Join struct { + Type JoinType + Table Table + Ons []Conditional +} + type Select struct { Table string - Columns []string + Columns []Column Conditionals []Conditional OrderBys []OrderBy + Joins []Join IsWildcard bool IsDistinct bool } @@ -20,6 +59,68 @@ type OrderBy struct { IsDescend bool // SQL queries with no ASC|DESC on their ORDER BY are ASC by default, hence why this bool for the opposite } +func GetAggregateFunctionTypeByName(name string) AggregateFunctionType { + var functionType AggregateFunctionType + switch strings.ToUpper(name) { + case "MIN": + functionType = MIN + case "MAX": + functionType = MAX + case "COUNT": + functionType = COUNT + case "SUM": + functionType = SUM + case "AVG": + functionType = AVG + default: + functionType = 0 + } + + return functionType +} + +func GetAggregateFunctionNameByType(functionType AggregateFunctionType) string { + var functionName string + switch functionType { + case MIN: + functionName = "MIN" + case MAX: + functionName = "MAX" + case COUNT: + functionName = "COUNT" + case SUM: + functionName = "SUM" + case AVG: + functionName = "AVG" + default: + functionName = "" + } + + return functionName + +} + +func GetFullStringFromColumn(column Column) string { + var workingSlice string + + if column.AggregateFunction > 0 { + workingSlice = fmt.Sprintf( + "%s(%s)", + GetAggregateFunctionNameByType(column.AggregateFunction), + column.Name, + ) + } else { + workingSlice = column.Name + } + + if column.Alias != "" { + workingSlice += fmt.Sprintf(" AS %s", column.Alias) + } + + return workingSlice + +} + func (q *Select) GetFullSql() string { var workingSqlSlice []string @@ -30,9 +131,9 @@ func (q *Select) GetFullSql() string { } else { for i, column := range q.Columns { if i < (len(q.Columns) - 1) { - workingSqlSlice = append(workingSqlSlice, column+",") + workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column)+",") } else { - workingSqlSlice = append(workingSqlSlice, column) + workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column)) } } } @@ -80,8 +181,9 @@ func ParseSelectStatement(sql string) Select { lookBehindBuffer := [10]sqllexer.Token{} var workingConditional = Conditional{} + var workingColumn Column - var columns []string + var columns []Column var orderBys []OrderBy lexer := sqllexer.New(sql) for { @@ -97,26 +199,53 @@ func ParseSelectStatement(sql string) Select { // For any keywords that are before the columns or wildcard if passedSELECT && len(columns) == 0 && !passedColumns { if token.Type == sqllexer.KEYWORD { - mutateSelectFromKeyword(&query, token.Value) - continue + switch strings.ToUpper(token.Value) { + case "DISTINCT": + query.IsDistinct = true + continue + } } } if !passedColumns { if token.Type == sqllexer.WILDCARD { passedColumns = true - columns = make([]string, 0) + columns = make([]Column, 0) query.IsWildcard = true continue - } else if token.Type == sqllexer.IDENT { - columns = append(columns, token.Value) + } else if token.Type == sqllexer.FUNCTION { + unshiftBuffer(&lookBehindBuffer, *token) + workingColumn.AggregateFunction = GetAggregateFunctionTypeByName(token.Value) continue - } else if token.Type == sqllexer.PUNCTUATION || token.Type == sqllexer.SPACE { + } else if token.Type == sqllexer.PUNCTUATION { + if token.Value == "," { + columns = append(columns, workingColumn) + workingColumn = Column{} + } + continue + } else if token.Type == sqllexer.IDENT { + unshiftBuffer(&lookBehindBuffer, *token) + + if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR { + workingColumn.Alias = token.Value + } else { + workingColumn.Name = token.Value + } + + continue + } else if token.Type == sqllexer.ALIAS_INDICATOR { + unshiftBuffer(&lookBehindBuffer, *token) + continue + } else if token.Type == sqllexer.SPACE { continue } else { + if workingColumn.Name != "" { + columns = append(columns, workingColumn) + workingColumn = Column{} + } + passedColumns = true query.Columns = columns - continue } } diff --git a/q/select_test.go b/q/select_test.go index 1d30dc9..ee31297 100644 --- a/q/select_test.go +++ b/q/select_test.go @@ -31,9 +31,13 @@ func TestParseSelectStatement(t *testing.T) { expected: Select{ Table: "Customers", IsWildcard: false, - Columns: []string{ - "CustomerName", - "City", + Columns: []Column{ + { + Name: "CustomerName", + }, + { + Name: "City", + }, }, }, }, @@ -41,8 +45,10 @@ func TestParseSelectStatement(t *testing.T) { input: "SELECT DISTINCT Country FROM Nations;", expected: Select{ Table: "Nations", - Columns: []string{ - "Country", + Columns: []Column{ + { + Name: "Country", + }, }, }, }, @@ -79,11 +85,27 @@ func TestParseSelectStatement(t *testing.T) { }, }, { - input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 ORDER BY zip DESC, streetNumber", + input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 OR zip <= 12000 ORDER BY zip DESC, streetNumber", expected: Select{ Table: "Addresses", IsWildcard: false, - Columns: []string{"id", "streetNumber", "streetName", "city", "state"}, + Columns: []Column{ + { + Name: "id", + }, + { + Name: "streetNumber", + }, + { + Name: "streetName", + }, + { + Name: "city", + }, + { + Name: "state", + }, + }, Conditionals: []Conditional{ { Key: "state", @@ -96,6 +118,12 @@ func TestParseSelectStatement(t *testing.T) { Value: "9000", Extension: "AND", }, + { + Key: "zip", + Operator: "<=", + Value: "12000", + Extension: "OR", + }, }, OrderBys: []OrderBy{ { @@ -128,8 +156,14 @@ func TestParseSelectStatement(t *testing.T) { t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answer.Columns), len(expected.Columns)) } else { for i, expectedColumn := range expected.Columns { - if expectedColumn != answer.Columns[i] { - t.Errorf("got %s for Select.Column[%d], expected %s", answer.Columns[i], i, expectedColumn) + if expectedColumn.Name != answer.Columns[i].Name { + t.Errorf("got %s for Select.Column[%d].Name, expected %s", answer.Columns[i].Name, i, expectedColumn.Name) + } + if expectedColumn.Alias != answer.Columns[i].Alias { + t.Errorf("got %s for Select.Column[%ORDER].Alias, expected %s", answer.Columns[i].Alias, i, expectedColumn.Alias) + } + if expectedColumn.AggregateFunction != answer.Columns[i].AggregateFunction { + t.Errorf("got %d for Select.Column[%d].AggregateFunction, expected %d", answer.Columns[i].AggregateFunction, i, expectedColumn.AggregateFunction) } } }