feat: aliases and aggregate functions

This commit is contained in:
Yehoshua Sandler 2025-04-24 21:01:39 -05:00
parent 28d041e244
commit 5380215a56
4 changed files with 206 additions and 40 deletions

38
main.go
View File

@ -2,31 +2,33 @@ package main
import (
"fmt"
"query-inter/q"
//"query-inter/q"
"github.com/DataDog/go-sqllexer"
// "github.com/DataDog/go-sqllexer"
)
func main() {
selectQuery := "SELECT id, name, createDate FROM users WHERE name=1;"
selectQuery := "SELECT MIN(Price) AS SmallestPrice, CategoryID FROM Products GROUP BY CategoryID;"
allStatements := q.ExtractSqlStatmentsFromString(selectQuery)
fmt.Println(allStatements)
//allStatements := q.ExtractSqlStatmentsFromString(selectQuery)
//fmt.Println(allStatements)
//lexer := sqllexer.New(selectQuery)
//for {
// token := lexer.Scan()
// fmt.Println(token.Value, token.Type)
//
// if token.Type == sqllexer.EOF {
// break
// }
//}
lexer := sqllexer.New(selectQuery)
for {
token := lexer.Scan()
fmt.Println(token.Value, token.Type)
for _, sql := range allStatements {
query := q.ParseSelectStatement(sql)
//fmt.Print(i)
//fmt.Println(query)
fmt.Println(query.GetFullSql())
if token.Type == sqllexer.EOF {
break
}
}
//for _, sql := range allStatements {
//query := q.ParseSelectStatement(sql)
//fmt.Print(i)
//fmt.Println(query)
//fmt.Println(query.GetFullSql())
//}
}

View File

@ -10,11 +10,12 @@ type Query interface {
GetFullSql() string
}
const NONE = 0
type QueryType int
const (
NONE QueryType = iota
SELECT
SELECT QueryType = iota + 1
UPDATE
INSERT
DELETE

View File

@ -1,16 +1,55 @@
package q
import (
"fmt"
"strings"
"github.com/DataDog/go-sqllexer"
)
type AggregateFunctionType int
const (
MIN AggregateFunctionType = iota + 1
MAX
COUNT
SUM
AVG
)
type JoinType int
const (
INNER JoinType = iota
LEFT
RIGHT
FULL
SELF
)
type Table struct {
Name string
Alias string
}
type Column struct {
Name string
Alias string
AggregateFunction AggregateFunctionType
}
type Join struct {
Type JoinType
Table Table
Ons []Conditional
}
type Select struct {
Table string
Columns []string
Columns []Column
Conditionals []Conditional
OrderBys []OrderBy
Joins []Join
IsWildcard bool
IsDistinct bool
}
@ -20,6 +59,68 @@ type OrderBy struct {
IsDescend bool // SQL queries with no ASC|DESC on their ORDER BY are ASC by default, hence why this bool for the opposite
}
func GetAggregateFunctionTypeByName(name string) AggregateFunctionType {
var functionType AggregateFunctionType
switch strings.ToUpper(name) {
case "MIN":
functionType = MIN
case "MAX":
functionType = MAX
case "COUNT":
functionType = COUNT
case "SUM":
functionType = SUM
case "AVG":
functionType = AVG
default:
functionType = 0
}
return functionType
}
func GetAggregateFunctionNameByType(functionType AggregateFunctionType) string {
var functionName string
switch functionType {
case MIN:
functionName = "MIN"
case MAX:
functionName = "MAX"
case COUNT:
functionName = "COUNT"
case SUM:
functionName = "SUM"
case AVG:
functionName = "AVG"
default:
functionName = ""
}
return functionName
}
func GetFullStringFromColumn(column Column) string {
var workingSlice string
if column.AggregateFunction > 0 {
workingSlice = fmt.Sprintf(
"%s(%s)",
GetAggregateFunctionNameByType(column.AggregateFunction),
column.Name,
)
} else {
workingSlice = column.Name
}
if column.Alias != "" {
workingSlice += fmt.Sprintf(" AS %s", column.Alias)
}
return workingSlice
}
func (q *Select) GetFullSql() string {
var workingSqlSlice []string
@ -30,9 +131,9 @@ func (q *Select) GetFullSql() string {
} else {
for i, column := range q.Columns {
if i < (len(q.Columns) - 1) {
workingSqlSlice = append(workingSqlSlice, column+",")
workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column)+",")
} else {
workingSqlSlice = append(workingSqlSlice, column)
workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column))
}
}
}
@ -80,8 +181,9 @@ func ParseSelectStatement(sql string) Select {
lookBehindBuffer := [10]sqllexer.Token{}
var workingConditional = Conditional{}
var workingColumn Column
var columns []string
var columns []Column
var orderBys []OrderBy
lexer := sqllexer.New(sql)
for {
@ -97,26 +199,53 @@ func ParseSelectStatement(sql string) Select {
// For any keywords that are before the columns or wildcard
if passedSELECT && len(columns) == 0 && !passedColumns {
if token.Type == sqllexer.KEYWORD {
mutateSelectFromKeyword(&query, token.Value)
continue
switch strings.ToUpper(token.Value) {
case "DISTINCT":
query.IsDistinct = true
continue
}
}
}
if !passedColumns {
if token.Type == sqllexer.WILDCARD {
passedColumns = true
columns = make([]string, 0)
columns = make([]Column, 0)
query.IsWildcard = true
continue
} else if token.Type == sqllexer.IDENT {
columns = append(columns, token.Value)
} else if token.Type == sqllexer.FUNCTION {
unshiftBuffer(&lookBehindBuffer, *token)
workingColumn.AggregateFunction = GetAggregateFunctionTypeByName(token.Value)
continue
} else if token.Type == sqllexer.PUNCTUATION || token.Type == sqllexer.SPACE {
} else if token.Type == sqllexer.PUNCTUATION {
if token.Value == "," {
columns = append(columns, workingColumn)
workingColumn = Column{}
}
continue
} else if token.Type == sqllexer.IDENT {
unshiftBuffer(&lookBehindBuffer, *token)
if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR {
workingColumn.Alias = token.Value
} else {
workingColumn.Name = token.Value
}
continue
} else if token.Type == sqllexer.ALIAS_INDICATOR {
unshiftBuffer(&lookBehindBuffer, *token)
continue
} else if token.Type == sqllexer.SPACE {
continue
} else {
if workingColumn.Name != "" {
columns = append(columns, workingColumn)
workingColumn = Column{}
}
passedColumns = true
query.Columns = columns
continue
}
}

View File

@ -31,9 +31,13 @@ func TestParseSelectStatement(t *testing.T) {
expected: Select{
Table: "Customers",
IsWildcard: false,
Columns: []string{
"CustomerName",
"City",
Columns: []Column{
{
Name: "CustomerName",
},
{
Name: "City",
},
},
},
},
@ -41,8 +45,10 @@ func TestParseSelectStatement(t *testing.T) {
input: "SELECT DISTINCT Country FROM Nations;",
expected: Select{
Table: "Nations",
Columns: []string{
"Country",
Columns: []Column{
{
Name: "Country",
},
},
},
},
@ -79,11 +85,27 @@ func TestParseSelectStatement(t *testing.T) {
},
},
{
input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 ORDER BY zip DESC, streetNumber",
input: "SELECT id, streetNumber, streetName, city, state FROM Addresses WHERE state = 'AL' AND zip > 9000 OR zip <= 12000 ORDER BY zip DESC, streetNumber",
expected: Select{
Table: "Addresses",
IsWildcard: false,
Columns: []string{"id", "streetNumber", "streetName", "city", "state"},
Columns: []Column{
{
Name: "id",
},
{
Name: "streetNumber",
},
{
Name: "streetName",
},
{
Name: "city",
},
{
Name: "state",
},
},
Conditionals: []Conditional{
{
Key: "state",
@ -96,6 +118,12 @@ func TestParseSelectStatement(t *testing.T) {
Value: "9000",
Extension: "AND",
},
{
Key: "zip",
Operator: "<=",
Value: "12000",
Extension: "OR",
},
},
OrderBys: []OrderBy{
{
@ -128,8 +156,14 @@ func TestParseSelectStatement(t *testing.T) {
t.Errorf("got %d number of columns for Select.Columns, expected %d", len(answer.Columns), len(expected.Columns))
} else {
for i, expectedColumn := range expected.Columns {
if expectedColumn != answer.Columns[i] {
t.Errorf("got %s for Select.Column[%d], expected %s", answer.Columns[i], i, expectedColumn)
if expectedColumn.Name != answer.Columns[i].Name {
t.Errorf("got %s for Select.Column[%d].Name, expected %s", answer.Columns[i].Name, i, expectedColumn.Name)
}
if expectedColumn.Alias != answer.Columns[i].Alias {
t.Errorf("got %s for Select.Column[%ORDER].Alias, expected %s", answer.Columns[i].Alias, i, expectedColumn.Alias)
}
if expectedColumn.AggregateFunction != answer.Columns[i].AggregateFunction {
t.Errorf("got %d for Select.Column[%d].AggregateFunction, expected %d", answer.Columns[i].AggregateFunction, i, expectedColumn.AggregateFunction)
}
}
}