feat: aliases and aggregate functions
This commit is contained in:
parent
28d041e244
commit
5380215a56
38
main.go
38
main.go
@ -2,31 +2,33 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"query-inter/q"
|
//"query-inter/q"
|
||||||
|
|
||||||
|
"github.com/DataDog/go-sqllexer"
|
||||||
// "github.com/DataDog/go-sqllexer"
|
// "github.com/DataDog/go-sqllexer"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
selectQuery := "SELECT id, name, createDate FROM users WHERE name=1;"
|
selectQuery := "SELECT MIN(Price) AS SmallestPrice, CategoryID FROM Products GROUP BY CategoryID;"
|
||||||
|
|
||||||
allStatements := q.ExtractSqlStatmentsFromString(selectQuery)
|
//allStatements := q.ExtractSqlStatmentsFromString(selectQuery)
|
||||||
fmt.Println(allStatements)
|
//fmt.Println(allStatements)
|
||||||
|
|
||||||
//lexer := sqllexer.New(selectQuery)
|
lexer := sqllexer.New(selectQuery)
|
||||||
//for {
|
for {
|
||||||
// token := lexer.Scan()
|
token := lexer.Scan()
|
||||||
// fmt.Println(token.Value, token.Type)
|
fmt.Println(token.Value, token.Type)
|
||||||
//
|
|
||||||
// if token.Type == sqllexer.EOF {
|
|
||||||
// break
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
|
|
||||||
for _, sql := range allStatements {
|
if token.Type == sqllexer.EOF {
|
||||||
query := q.ParseSelectStatement(sql)
|
break
|
||||||
//fmt.Print(i)
|
}
|
||||||
//fmt.Println(query)
|
|
||||||
fmt.Println(query.GetFullSql())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//for _, sql := range allStatements {
|
||||||
|
//query := q.ParseSelectStatement(sql)
|
||||||
|
//fmt.Print(i)
|
||||||
|
//fmt.Println(query)
|
||||||
|
//fmt.Println(query.GetFullSql())
|
||||||
|
//}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,11 +10,12 @@ type Query interface {
|
|||||||
GetFullSql() string
|
GetFullSql() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const NONE = 0
|
||||||
|
|
||||||
type QueryType int
|
type QueryType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NONE QueryType = iota
|
SELECT QueryType = iota + 1
|
||||||
SELECT
|
|
||||||
UPDATE
|
UPDATE
|
||||||
INSERT
|
INSERT
|
||||||
DELETE
|
DELETE
|
||||||
|
|||||||
151
q/select.go
151
q/select.go
@ -1,16 +1,55 @@
|
|||||||
package q
|
package q
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/DataDog/go-sqllexer"
|
"github.com/DataDog/go-sqllexer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AggregateFunctionType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
MIN AggregateFunctionType = iota + 1
|
||||||
|
MAX
|
||||||
|
COUNT
|
||||||
|
SUM
|
||||||
|
AVG
|
||||||
|
)
|
||||||
|
|
||||||
|
type JoinType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
INNER JoinType = iota
|
||||||
|
LEFT
|
||||||
|
RIGHT
|
||||||
|
FULL
|
||||||
|
SELF
|
||||||
|
)
|
||||||
|
|
||||||
|
type Table struct {
|
||||||
|
Name string
|
||||||
|
Alias string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Column struct {
|
||||||
|
Name string
|
||||||
|
Alias string
|
||||||
|
AggregateFunction AggregateFunctionType
|
||||||
|
}
|
||||||
|
|
||||||
|
type Join struct {
|
||||||
|
Type JoinType
|
||||||
|
Table Table
|
||||||
|
Ons []Conditional
|
||||||
|
}
|
||||||
|
|
||||||
type Select struct {
|
type Select struct {
|
||||||
Table string
|
Table string
|
||||||
Columns []string
|
Columns []Column
|
||||||
Conditionals []Conditional
|
Conditionals []Conditional
|
||||||
OrderBys []OrderBy
|
OrderBys []OrderBy
|
||||||
|
Joins []Join
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
IsDistinct bool
|
IsDistinct bool
|
||||||
}
|
}
|
||||||
@ -20,6 +59,68 @@ type OrderBy struct {
|
|||||||
IsDescend bool // SQL queries with no ASC|DESC on their ORDER BY are ASC by default, hence why this bool for the opposite
|
IsDescend bool // SQL queries with no ASC|DESC on their ORDER BY are ASC by default, hence why this bool for the opposite
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetAggregateFunctionTypeByName(name string) AggregateFunctionType {
|
||||||
|
var functionType AggregateFunctionType
|
||||||
|
switch strings.ToUpper(name) {
|
||||||
|
case "MIN":
|
||||||
|
functionType = MIN
|
||||||
|
case "MAX":
|
||||||
|
functionType = MAX
|
||||||
|
case "COUNT":
|
||||||
|
functionType = COUNT
|
||||||
|
case "SUM":
|
||||||
|
functionType = SUM
|
||||||
|
case "AVG":
|
||||||
|
functionType = AVG
|
||||||
|
default:
|
||||||
|
functionType = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return functionType
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetAggregateFunctionNameByType(functionType AggregateFunctionType) string {
|
||||||
|
var functionName string
|
||||||
|
switch functionType {
|
||||||
|
case MIN:
|
||||||
|
functionName = "MIN"
|
||||||
|
case MAX:
|
||||||
|
functionName = "MAX"
|
||||||
|
case COUNT:
|
||||||
|
functionName = "COUNT"
|
||||||
|
case SUM:
|
||||||
|
functionName = "SUM"
|
||||||
|
case AVG:
|
||||||
|
functionName = "AVG"
|
||||||
|
default:
|
||||||
|
functionName = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return functionName
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetFullStringFromColumn(column Column) string {
|
||||||
|
var workingSlice string
|
||||||
|
|
||||||
|
if column.AggregateFunction > 0 {
|
||||||
|
workingSlice = fmt.Sprintf(
|
||||||
|
"%s(%s)",
|
||||||
|
GetAggregateFunctionNameByType(column.AggregateFunction),
|
||||||
|
column.Name,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
workingSlice = column.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
if column.Alias != "" {
|
||||||
|
workingSlice += fmt.Sprintf(" AS %s", column.Alias)
|
||||||
|
}
|
||||||
|
|
||||||
|
return workingSlice
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (q *Select) GetFullSql() string {
|
func (q *Select) GetFullSql() string {
|
||||||
var workingSqlSlice []string
|
var workingSqlSlice []string
|
||||||
|
|
||||||
@ -30,9 +131,9 @@ func (q *Select) GetFullSql() string {
|
|||||||
} else {
|
} else {
|
||||||
for i, column := range q.Columns {
|
for i, column := range q.Columns {
|
||||||
if i < (len(q.Columns) - 1) {
|
if i < (len(q.Columns) - 1) {
|
||||||
workingSqlSlice = append(workingSqlSlice, column+",")
|
workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column)+",")
|
||||||
} else {
|
} else {
|
||||||
workingSqlSlice = append(workingSqlSlice, column)
|
workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -80,8 +181,9 @@ func ParseSelectStatement(sql string) Select {
|
|||||||
|
|
||||||
lookBehindBuffer := [10]sqllexer.Token{}
|
lookBehindBuffer := [10]sqllexer.Token{}
|
||||||
var workingConditional = Conditional{}
|
var workingConditional = Conditional{}
|
||||||
|
var workingColumn Column
|
||||||
|
|
||||||
var columns []string
|
var columns []Column
|
||||||
var orderBys []OrderBy
|
var orderBys []OrderBy
|
||||||
lexer := sqllexer.New(sql)
|
lexer := sqllexer.New(sql)
|
||||||
for {
|
for {
|
||||||
@ -97,26 +199,53 @@ func ParseSelectStatement(sql string) Select {
|
|||||||
// For any keywords that are before the columns or wildcard
|
// For any keywords that are before the columns or wildcard
|
||||||
if passedSELECT && len(columns) == 0 && !passedColumns {
|
if passedSELECT && len(columns) == 0 && !passedColumns {
|
||||||
if token.Type == sqllexer.KEYWORD {
|
if token.Type == sqllexer.KEYWORD {
|
||||||
mutateSelectFromKeyword(&query, token.Value)
|
switch strings.ToUpper(token.Value) {
|
||||||
continue
|
case "DISTINCT":
|
||||||
|
query.IsDistinct = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !passedColumns {
|
if !passedColumns {
|
||||||
if token.Type == sqllexer.WILDCARD {
|
if token.Type == sqllexer.WILDCARD {
|
||||||
passedColumns = true
|
passedColumns = true
|
||||||
columns = make([]string, 0)
|
columns = make([]Column, 0)
|
||||||
query.IsWildcard = true
|
query.IsWildcard = true
|
||||||
continue
|
continue
|
||||||
} else if token.Type == sqllexer.IDENT {
|
} else if token.Type == sqllexer.FUNCTION {
|
||||||
columns = append(columns, token.Value)
|
unshiftBuffer(&lookBehindBuffer, *token)
|
||||||
|
workingColumn.AggregateFunction = GetAggregateFunctionTypeByName(token.Value)
|
||||||
continue
|
continue
|
||||||
} else if token.Type == sqllexer.PUNCTUATION || token.Type == sqllexer.SPACE {
|
} else if token.Type == sqllexer.PUNCTUATION {
|
||||||
|
if token.Value == "," {
|
||||||
|
columns = append(columns, workingColumn)
|
||||||
|
workingColumn = Column{}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else if token.Type == sqllexer.IDENT {
|
||||||
|
unshiftBuffer(&lookBehindBuffer, *token)
|
||||||
|
|
||||||
|
if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR {
|
||||||
|
workingColumn.Alias = token.Value
|
||||||
|
} else {
|
||||||
|
workingColumn.Name = token.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
} else if token.Type == sqllexer.ALIAS_INDICATOR {
|
||||||
|
unshiftBuffer(&lookBehindBuffer, *token)
|
||||||
|
continue
|
||||||
|
} else if token.Type == sqllexer.SPACE {
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
|
if workingColumn.Name != "" {
|
||||||
|
columns = append(columns, workingColumn)
|
||||||
|
workingColumn = Column{}
|
||||||
|
}
|
||||||
|
|
||||||
passedColumns = true
|
passedColumns = true
|
||||||
query.Columns = columns
|
query.Columns = columns
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -31,9 +31,13 @@ func TestParseSelectStatement(t *testing.T) {
|
|||||||
expected: Select{
|
expected: Select{
|
||||||
Table: "Customers",
|
Table: "Customers",
|
||||||
IsWildcard: false,
|
IsWildcard: false,
|
||||||
Columns: []string{
|
Columns: []Column{
|
||||||
"CustomerName",
|
{
|
||||||
"City",
|
Name: "CustomerName",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "City",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -41,8 +45,10 @@ func TestParseSelectStatement(t *testing.T) {
|
|||||||
input: "SELECT DISTINCT Country FROM Nations;",
|
input: "SELECT DISTINCT Country FROM Nations;",
|
||||||
expected: Select{
|
expected: Select{
|
||||||
Table: "Nations",
|
Table: "Nations",
|
||||||
Columns: []string{
|
Columns: []Column{
|
||||||
"Country",
|
{
|
||||||
|
Name: "Country",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -79,11 +85,27 @@ func TestParseSelectStatement(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 ORDER BY zip DESC, streetNumber",
|
input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 OR zip <= 12000 ORDER BY zip DESC, streetNumber",
|
||||||
expected: Select{
|
expected: Select{
|
||||||
Table: "Addresses",
|
Table: "Addresses",
|
||||||
IsWildcard: false,
|
IsWildcard: false,
|
||||||
Columns: []string{"id", "streetNumber", "streetName", "city", "state"},
|
Columns: []Column{
|
||||||
|
{
|
||||||
|
Name: "id",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "streetNumber",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "streetName",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "city",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "state",
|
||||||
|
},
|
||||||
|
},
|
||||||
Conditionals: []Conditional{
|
Conditionals: []Conditional{
|
||||||
{
|
{
|
||||||
Key: "state",
|
Key: "state",
|
||||||
@ -96,6 +118,12 @@ func TestParseSelectStatement(t *testing.T) {
|
|||||||
Value: "9000",
|
Value: "9000",
|
||||||
Extension: "AND",
|
Extension: "AND",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Key: "zip",
|
||||||
|
Operator: "<=",
|
||||||
|
Value: "12000",
|
||||||
|
Extension: "OR",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
OrderBys: []OrderBy{
|
OrderBys: []OrderBy{
|
||||||
{
|
{
|
||||||
@ -128,8 +156,14 @@ func TestParseSelectStatement(t *testing.T) {
|
|||||||
t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answer.Columns), len(expected.Columns))
|
t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answer.Columns), len(expected.Columns))
|
||||||
} else {
|
} else {
|
||||||
for i, expectedColumn := range expected.Columns {
|
for i, expectedColumn := range expected.Columns {
|
||||||
if expectedColumn != answer.Columns[i] {
|
if expectedColumn.Name != answer.Columns[i].Name {
|
||||||
t.Errorf("got %s for Select.Column[%d], expected %s", answer.Columns[i], i, expectedColumn)
|
t.Errorf("got %s for Select.Column[%d].Name, expected %s", answer.Columns[i].Name, i, expectedColumn.Name)
|
||||||
|
}
|
||||||
|
if expectedColumn.Alias != answer.Columns[i].Alias {
|
||||||
|
t.Errorf("got %s for Select.Column[%ORDER].Alias, expected %s", answer.Columns[i].Alias, i, expectedColumn.Alias)
|
||||||
|
}
|
||||||
|
if expectedColumn.AggregateFunction != answer.Columns[i].AggregateFunction {
|
||||||
|
t.Errorf("got %d for Select.Column[%d].AggregateFunction, expected %d", answer.Columns[i].AggregateFunction, i, expectedColumn.AggregateFunction)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user