feat: aliases and aggregate functions

This commit is contained in:
Yehoshua Sandler 2025-04-24 21:01:39 -05:00
parent 28d041e244
commit 5380215a56
4 changed files with 206 additions and 40 deletions

38
main.go
View File

@ -2,31 +2,33 @@ package main
import ( import (
"fmt" "fmt"
"query-inter/q" //"query-inter/q"
"github.com/DataDog/go-sqllexer"
// "github.com/DataDog/go-sqllexer" // "github.com/DataDog/go-sqllexer"
) )
func main() { func main() {
selectQuery := "SELECT id, name, createDate FROM users WHERE name=1;" selectQuery := "SELECT MIN(Price) AS SmallestPrice, CategoryID FROM Products GROUP BY CategoryID;"
allStatements := q.ExtractSqlStatmentsFromString(selectQuery) //allStatements := q.ExtractSqlStatmentsFromString(selectQuery)
fmt.Println(allStatements) //fmt.Println(allStatements)
//lexer := sqllexer.New(selectQuery) lexer := sqllexer.New(selectQuery)
//for { for {
// token := lexer.Scan() token := lexer.Scan()
// fmt.Println(token.Value, token.Type) fmt.Println(token.Value, token.Type)
//
// if token.Type == sqllexer.EOF {
// break
// }
//}
for _, sql := range allStatements { if token.Type == sqllexer.EOF {
query := q.ParseSelectStatement(sql) break
//fmt.Print(i) }
//fmt.Println(query)
fmt.Println(query.GetFullSql())
} }
//for _, sql := range allStatements {
//query := q.ParseSelectStatement(sql)
//fmt.Print(i)
//fmt.Println(query)
//fmt.Println(query.GetFullSql())
//}
} }

View File

@ -10,11 +10,12 @@ type Query interface {
GetFullSql() string GetFullSql() string
} }
const NONE = 0
type QueryType int type QueryType int
const ( const (
NONE QueryType = iota SELECT QueryType = iota + 1
SELECT
UPDATE UPDATE
INSERT INSERT
DELETE DELETE

View File

@ -1,16 +1,55 @@
package q package q
import ( import (
"fmt"
"strings" "strings"
"github.com/DataDog/go-sqllexer" "github.com/DataDog/go-sqllexer"
) )
type AggregateFunctionType int
const (
MIN AggregateFunctionType = iota + 1
MAX
COUNT
SUM
AVG
)
type JoinType int
const (
INNER JoinType = iota
LEFT
RIGHT
FULL
SELF
)
type Table struct {
Name string
Alias string
}
type Column struct {
Name string
Alias string
AggregateFunction AggregateFunctionType
}
type Join struct {
Type JoinType
Table Table
Ons []Conditional
}
type Select struct { type Select struct {
Table string Table string
Columns []string Columns []Column
Conditionals []Conditional Conditionals []Conditional
OrderBys []OrderBy OrderBys []OrderBy
Joins []Join
IsWildcard bool IsWildcard bool
IsDistinct bool IsDistinct bool
} }
@ -20,6 +59,68 @@ type OrderBy struct {
IsDescend bool // SQL queries with no ASC|DESC on their ORDER BY are ASC by default, hence why this bool for the opposite IsDescend bool // SQL queries with no ASC|DESC on their ORDER BY are ASC by default, hence why this bool for the opposite
} }
func GetAggregateFunctionTypeByName(name string) AggregateFunctionType {
var functionType AggregateFunctionType
switch strings.ToUpper(name) {
case "MIN":
functionType = MIN
case "MAX":
functionType = MAX
case "COUNT":
functionType = COUNT
case "SUM":
functionType = SUM
case "AVG":
functionType = AVG
default:
functionType = 0
}
return functionType
}
func GetAggregateFunctionNameByType(functionType AggregateFunctionType) string {
var functionName string
switch functionType {
case MIN:
functionName = "MIN"
case MAX:
functionName = "MAX"
case COUNT:
functionName = "COUNT"
case SUM:
functionName = "SUM"
case AVG:
functionName = "AVG"
default:
functionName = ""
}
return functionName
}
func GetFullStringFromColumn(column Column) string {
var workingSlice string
if column.AggregateFunction > 0 {
workingSlice = fmt.Sprintf(
"%s(%s)",
GetAggregateFunctionNameByType(column.AggregateFunction),
column.Name,
)
} else {
workingSlice = column.Name
}
if column.Alias != "" {
workingSlice += fmt.Sprintf(" AS %s", column.Alias)
}
return workingSlice
}
func (q *Select) GetFullSql() string { func (q *Select) GetFullSql() string {
var workingSqlSlice []string var workingSqlSlice []string
@ -30,9 +131,9 @@ func (q *Select) GetFullSql() string {
} else { } else {
for i, column := range q.Columns { for i, column := range q.Columns {
if i < (len(q.Columns) - 1) { if i < (len(q.Columns) - 1) {
workingSqlSlice = append(workingSqlSlice, column+",") workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column)+",")
} else { } else {
workingSqlSlice = append(workingSqlSlice, column) workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column))
} }
} }
} }
@ -80,8 +181,9 @@ func ParseSelectStatement(sql string) Select {
lookBehindBuffer := [10]sqllexer.Token{} lookBehindBuffer := [10]sqllexer.Token{}
var workingConditional = Conditional{} var workingConditional = Conditional{}
var workingColumn Column
var columns []string var columns []Column
var orderBys []OrderBy var orderBys []OrderBy
lexer := sqllexer.New(sql) lexer := sqllexer.New(sql)
for { for {
@ -97,26 +199,53 @@ func ParseSelectStatement(sql string) Select {
// For any keywords that are before the columns or wildcard // For any keywords that are before the columns or wildcard
if passedSELECT && len(columns) == 0 && !passedColumns { if passedSELECT && len(columns) == 0 && !passedColumns {
if token.Type == sqllexer.KEYWORD { if token.Type == sqllexer.KEYWORD {
mutateSelectFromKeyword(&query, token.Value) switch strings.ToUpper(token.Value) {
continue case "DISTINCT":
query.IsDistinct = true
continue
}
} }
} }
if !passedColumns { if !passedColumns {
if token.Type == sqllexer.WILDCARD { if token.Type == sqllexer.WILDCARD {
passedColumns = true passedColumns = true
columns = make([]string, 0) columns = make([]Column, 0)
query.IsWildcard = true query.IsWildcard = true
continue continue
} else if token.Type == sqllexer.IDENT { } else if token.Type == sqllexer.FUNCTION {
columns = append(columns, token.Value) unshiftBuffer(&lookBehindBuffer, *token)
workingColumn.AggregateFunction = GetAggregateFunctionTypeByName(token.Value)
continue continue
} else if token.Type == sqllexer.PUNCTUATION || token.Type == sqllexer.SPACE { } else if token.Type == sqllexer.PUNCTUATION {
if token.Value == "," {
columns = append(columns, workingColumn)
workingColumn = Column{}
}
continue
} else if token.Type == sqllexer.IDENT {
unshiftBuffer(&lookBehindBuffer, *token)
if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR {
workingColumn.Alias = token.Value
} else {
workingColumn.Name = token.Value
}
continue
} else if token.Type == sqllexer.ALIAS_INDICATOR {
unshiftBuffer(&lookBehindBuffer, *token)
continue
} else if token.Type == sqllexer.SPACE {
continue continue
} else { } else {
if workingColumn.Name != "" {
columns = append(columns, workingColumn)
workingColumn = Column{}
}
passedColumns = true passedColumns = true
query.Columns = columns query.Columns = columns
continue
} }
} }

View File

@ -31,9 +31,13 @@ func TestParseSelectStatement(t *testing.T) {
expected: Select{ expected: Select{
Table: "Customers", Table: "Customers",
IsWildcard: false, IsWildcard: false,
Columns: []string{ Columns: []Column{
"CustomerName", {
"City", Name: "CustomerName",
},
{
Name: "City",
},
}, },
}, },
}, },
@ -41,8 +45,10 @@ func TestParseSelectStatement(t *testing.T) {
input: "SELECT DISTINCT Country FROM Nations;", input: "SELECT DISTINCT Country FROM Nations;",
expected: Select{ expected: Select{
Table: "Nations", Table: "Nations",
Columns: []string{ Columns: []Column{
"Country", {
Name: "Country",
},
}, },
}, },
}, },
@ -79,11 +85,27 @@ func TestParseSelectStatement(t *testing.T) {
}, },
}, },
{ {
input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 ORDER BY zip DESC, streetNumber", input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 OR zip <= 12000 ORDER BY zip DESC, streetNumber",
expected: Select{ expected: Select{
Table: "Addresses", Table: "Addresses",
IsWildcard: false, IsWildcard: false,
Columns: []string{"id", "streetNumber", "streetName", "city", "state"}, Columns: []Column{
{
Name: "id",
},
{
Name: "streetNumber",
},
{
Name: "streetName",
},
{
Name: "city",
},
{
Name: "state",
},
},
Conditionals: []Conditional{ Conditionals: []Conditional{
{ {
Key: "state", Key: "state",
@ -96,6 +118,12 @@ func TestParseSelectStatement(t *testing.T) {
Value: "9000", Value: "9000",
Extension: "AND", Extension: "AND",
}, },
{
Key: "zip",
Operator: "<=",
Value: "12000",
Extension: "OR",
},
}, },
OrderBys: []OrderBy{ OrderBys: []OrderBy{
{ {
@ -128,8 +156,14 @@ func TestParseSelectStatement(t *testing.T) {
t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answer.Columns), len(expected.Columns)) t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answer.Columns), len(expected.Columns))
} else { } else {
for i, expectedColumn := range expected.Columns { for i, expectedColumn := range expected.Columns {
if expectedColumn != answer.Columns[i] { if expectedColumn.Name != answer.Columns[i].Name {
t.Errorf("got %s for Select.Column[%d], expected %s", answer.Columns[i], i, expectedColumn) t.Errorf("got %s for Select.Column[%d].Name, expected %s", answer.Columns[i].Name, i, expectedColumn.Name)
}
if expectedColumn.Alias != answer.Columns[i].Alias {
t.Errorf("got %s for Select.Column[%ORDER].Alias, expected %s", answer.Columns[i].Alias, i, expectedColumn.Alias)
}
if expectedColumn.AggregateFunction != answer.Columns[i].AggregateFunction {
t.Errorf("got %d for Select.Column[%d].AggregateFunction, expected %d", answer.Columns[i].AggregateFunction, i, expectedColumn.AggregateFunction)
} }
} }
} }