query-interpreter/q/select.go

360 lines
8.1 KiB
Go

package q
import (
"fmt"
"strings"
"github.com/DataDog/go-sqllexer"
)
// AggregateFunctionType identifies the SQL aggregate function applied to a
// selected column. The zero value means the column has no aggregate function.
type AggregateFunctionType int

// Supported aggregate functions. Numbering starts at 1 so that the zero value
// of AggregateFunctionType is free to mean "none".
const (
	MIN   AggregateFunctionType = 1
	MAX   AggregateFunctionType = 2
	COUNT AggregateFunctionType = 3
	SUM   AggregateFunctionType = 4
	AVG   AggregateFunctionType = 5
)
// JoinType identifies the kind of a JOIN clause.
type JoinType int

// Supported join kinds. INNER is the zero value of JoinType.
const (
	INNER JoinType = 0
	LEFT  JoinType = 1
	RIGHT JoinType = 2
	FULL  JoinType = 3
	SELF  JoinType = 4
)
// Table names a table referenced by a query, with an optional alias.
type Table struct {
	Name, Alias string
}
// Column is one entry of a SELECT column list.
type Column struct {
	// Name is the column (or function argument) as written; Alias is the
	// name given after AS, empty when there is none.
	Name, Alias string
	// AggregateFunction is the aggregate wrapping the column, or 0 for a plain
	// column reference.
	AggregateFunction AggregateFunctionType
}
// Join describes one JOIN clause of a SELECT statement.
type Join struct {
	Type  JoinType      // kind of join (INNER, LEFT, ...)
	Table Table         // table being joined
	Ons   []Conditional // conditions from the ON clause
}
// Select is the structured form of a SELECT statement.
type Select struct {
	Table        string        // table named after FROM
	Columns      []Column      // selected columns; ParseSelectStatement leaves this empty for SELECT *
	Conditionals []Conditional // WHERE predicates in source order
	OrderBys     []OrderBy     // ORDER BY keys in source order
	Joins        []Join        // JOIN clauses; ParseSelectStatement in this file never fills these
	IsWildcard   bool          // true for SELECT *
	IsDistinct   bool          // true for SELECT DISTINCT
}
// OrderBy is one key of an ORDER BY clause.
type OrderBy struct {
	// Key is the column being sorted on.
	Key string
	// IsDescend is true for DESC. SQL sorts ascending when neither ASC nor
	// DESC is given, so the zero value (false) means ascending.
	IsDescend bool
}
// GetAggregateFunctionTypeByName maps an aggregate function name such as
// "count" or "SUM" to its AggregateFunctionType. Matching is done on the
// upper-cased name; anything that is not a supported aggregate yields 0.
func GetAggregateFunctionTypeByName(name string) AggregateFunctionType {
	switch strings.ToUpper(name) {
	case "MIN":
		return MIN
	case "MAX":
		return MAX
	case "COUNT":
		return COUNT
	case "SUM":
		return SUM
	case "AVG":
		return AVG
	}
	return 0
}
// GetAggregateFunctionNameByType returns the upper-case SQL name of an
// aggregate function ("MIN", "COUNT", ...). It returns "" for 0 and for any
// value that is not a known AggregateFunctionType.
func GetAggregateFunctionNameByType(functionType AggregateFunctionType) string {
	switch functionType {
	case MIN:
		return "MIN"
	case MAX:
		return "MAX"
	case COUNT:
		return "COUNT"
	case SUM:
		return "SUM"
	case AVG:
		return "AVG"
	}
	return ""
}
// GetFullStringFromColumn renders a Column as it would appear in a SELECT
// list: "name", "FUNC(name)" when an aggregate is set, either one followed by
// " AS alias" when an alias is set.
func GetFullStringFromColumn(column Column) string {
	expression := column.Name
	if column.AggregateFunction > 0 {
		expression = fmt.Sprintf("%s(%s)", GetAggregateFunctionNameByType(column.AggregateFunction), column.Name)
	}
	if column.Alias == "" {
		return expression
	}
	return fmt.Sprintf("%s AS %s", expression, column.Alias)
}
// GetFullSql renders q back into a single-line SQL string of the form
//
//	SELECT [DISTINCT] <* | columns> FROM <table> [WHERE ...] [ORDER BY ...]
//
// Columns are rendered with GetFullStringFromColumn and separated by ", ".
// Each conditional's Key, Operator and Value are emitted verbatim, without
// quoting or escaping. Conditionals after the first are preceded by their
// Extension (the AND/OR connector ParseSelectStatement records in front of
// them). Joins are not rendered.
func (q *Select) GetFullSql() string {
	parts := []string{"SELECT"}

	if q.IsDistinct {
		parts = append(parts, "DISTINCT")
	}

	if q.IsWildcard {
		parts = append(parts, "*")
	} else if len(q.Columns) > 0 {
		columns := make([]string, len(q.Columns))
		for i, column := range q.Columns {
			columns[i] = GetFullStringFromColumn(column)
		}
		parts = append(parts, strings.Join(columns, ", "))
	}

	parts = append(parts, "FROM "+q.Table)

	if len(q.Conditionals) > 0 {
		parts = append(parts, "WHERE")
		for i, condition := range q.Conditionals {
			if i > 0 {
				// Adjacent predicates must be joined by a connector to form
				// valid SQL. Fall back to AND when none was recorded.
				connector := condition.Extension
				if connector == "" {
					connector = "AND"
				}
				parts = append(parts, connector)
			}
			parts = append(parts, condition.Key, condition.Operator, condition.Value)
		}
	}

	if len(q.OrderBys) > 0 {
		keys := make([]string, len(q.OrderBys))
		for i, orderBy := range q.OrderBys {
			keys[i] = orderBy.Key
			if orderBy.IsDescend {
				keys[i] += " DESC"
			}
		}
		parts = append(parts, "ORDER BY "+strings.Join(keys, ", "))
	}

	return strings.Join(parts, " ")
}
// mutateSelectFromKeyword applies the effect of a keyword that may appear
// before the column list to query. Only DISTINCT (compared after
// upper-casing) has an effect; it sets IsDistinct. Other keywords are ignored.
func mutateSelectFromKeyword(query *Select, keyword string) {
	if upper := strings.ToUpper(keyword); upper == "DISTINCT" {
		query.IsDistinct = true
	}
}
// unshiftBuffer pushes value onto the front of buf. Every existing element
// moves one slot towards the end, and the element previously at index 9 is
// discarded.
func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) {
	// copy has memmove semantics, so shifting within the same array is safe.
	copy(buf[1:], buf[:len(buf)-1])
	buf[0] = value
}
// ParseSelectStatement parses sql into a Select in a single pass over the
// token stream of github.com/DataDog/go-sqllexer. Progress through the
// statement is tracked by the passed* flags, in this order:
//
//  1. the SELECT keyword. Leading whitespace is skipped; any other first token
//     stops parsing, and the zero Select is returned.
//  2. an optional DISTINCT, then either a bare * or a comma-separated column
//     list with optional aggregate functions and AS aliases.
//  3. FROM and the table name.
//  4. an optional WHERE clause, collected into Conditionals.
//  5. an optional ORDER BY clause, collected into OrderBys.
//
// Parsing stops at the first token for which IsTokenEndOfStatement reports
// true. No error is returned: tokens the parser does not recognise are
// skipped, and Joins is never populated here.
func ParseSelectStatement(sql string) Select {
	query := Select{}
	passedSELECT := false
	passedColumns := false
	passedFROM := false
	passedTable := false
	passedWHERE := false
	passedConditionals := false
	passedOrderByKeywords := false
	// Never set to true below, so ORDER BY parsing runs until the statement ends.
	passesOrderByColumns := false
	// The most recent significant tokens, newest at index 0. Only FUNCTION,
	// IDENT, ALIAS_INDICATOR, KEYWORD and wildcard function-argument tokens are
	// pushed. Spaces and punctuation never are, so index 1 is the previous such
	// token.
	lookBehindBuffer := [10]sqllexer.Token{}
	var workingConditional = Conditional{}
	var workingColumn Column
	var columns []Column
	var orderBys []OrderBy
	lexer := sqllexer.New(sql)
	for {
		token := lexer.Scan()

		// Phase 1: the statement must start with SELECT.
		if !passedSELECT && token.Type == sqllexer.SPACE {
			// Tolerate leading whitespace such as "\n  SELECT ...".
			continue
		}
		if !passedSELECT && strings.ToUpper(token.Value) != "SELECT" {
			break
		} else if !passedSELECT {
			passedSELECT = true
			continue
		}

		// Phase 2a: keywords that may precede the columns or wildcard.
		if passedSELECT && len(columns) == 0 && !passedColumns {
			if token.Type == sqllexer.KEYWORD {
				switch strings.ToUpper(token.Value) {
				case "DISTINCT":
					query.IsDistinct = true
					continue
				}
			}
		}

		// Phase 2b: the column list. workingColumn accumulates the current
		// column until a "," completes it or a token that cannot belong to a
		// column (typically the FROM keyword) ends the list.
		if !passedColumns {
			if token.Type == sqllexer.WILDCARD {
				if lookBehindBuffer[0].Type == sqllexer.FUNCTION {
					// A wildcard right after a function token is its argument,
					// as in COUNT(*), not SELECT *. The "(" in between is
					// punctuation and never enters the buffer. Push the wildcard
					// so that a later bare * is not mistaken for another argument.
					unshiftBuffer(&lookBehindBuffer, *token)
					workingColumn.Name = token.Value
					continue
				}
				passedColumns = true
				columns = make([]Column, 0)
				query.IsWildcard = true
				continue
			} else if token.Type == sqllexer.FUNCTION {
				unshiftBuffer(&lookBehindBuffer, *token)
				workingColumn.AggregateFunction = GetAggregateFunctionTypeByName(token.Value)
				continue
			} else if token.Type == sqllexer.PUNCTUATION {
				if token.Value == "," {
					columns = append(columns, workingColumn)
					workingColumn = Column{}
				}
				continue
			} else if token.Type == sqllexer.IDENT {
				// Decide alias vs. name from the previous token *before* pushing
				// this one. After the push, index 0 holds this IDENT itself.
				followsAlias := lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR
				unshiftBuffer(&lookBehindBuffer, *token)
				if followsAlias {
					workingColumn.Alias = token.Value
				} else {
					workingColumn.Name = token.Value
				}
				continue
			} else if token.Type == sqllexer.ALIAS_INDICATOR {
				unshiftBuffer(&lookBehindBuffer, *token)
				continue
			} else if token.Type == sqllexer.SPACE {
				continue
			} else {
				if workingColumn.Name != "" {
					columns = append(columns, workingColumn)
					workingColumn = Column{}
				}
				passedColumns = true
				query.Columns = columns
			}
		}

		// Phase 3: FROM and the table name.
		// TODO: make sure to check for other keywords that are allowed
		if !passedFROM && strings.ToUpper(token.Value) == "FROM" {
			passedFROM = true
			continue
		}
		if !passedTable && token.Type == sqllexer.IDENT {
			passedTable = true
			query.Table = token.Value
			continue
		} else if !passedTable {
			// Statements such as "SELECT *" or "SELECT a FROM" end before a
			// table name. This branch always continues, so without checking
			// here the end-of-statement check at the bottom of the loop would
			// never be reached and the loop would never exit.
			if IsTokenEndOfStatement(token) {
				break
			}
			continue
		}

		// Phase 4: WHERE. Any other keyword right after the table (e.g. ORDER)
		// also marks this phase as passed, without consuming the keyword.
		if !passedWHERE && token.Type == sqllexer.KEYWORD && strings.ToUpper(token.Value) == "WHERE" {
			passedWHERE = true
			continue
		} else if !passedWHERE && token.Type == sqllexer.KEYWORD && strings.ToUpper(token.Value) != "WHERE" {
			passedWHERE = true
		}
		// Any keyword other than AND/OR/NOT ends the list of conditionals.
		if passedWHERE && !passedConditionals {
			if token.Type == sqllexer.KEYWORD && strings.ToUpper(token.Value) != "AND" && strings.ToUpper(token.Value) != "OR" && strings.ToUpper(token.Value) != "NOT" {
				passedConditionals = true
			}
		}
		if passedWHERE && !passedConditionals {
			// Fill workingConditional piece by piece. Once key, operator and
			// value are all set, it is appended and reset. An AND/OR seen
			// between conditionals is stored as the next conditional's
			// Extension.
			if token.Type == sqllexer.IDENT {
				workingConditional.Key = token.Value
			} else if token.Type == sqllexer.OPERATOR {
				workingConditional.Operator = token.Value
			} else if token.Type == sqllexer.BOOLEAN || token.Type == sqllexer.NULL || token.Type == sqllexer.STRING || token.Type == sqllexer.NUMBER {
				workingConditional.Value = token.Value
			} else if token.Type == sqllexer.KEYWORD {
				if strings.ToUpper(token.Value) == "AND" || strings.ToUpper(token.Value) == "OR" {
					workingConditional.Extension = strings.ToUpper(token.Value)
				}
			}
			if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" {
				query.Conditionals = append(query.Conditionals, workingConditional)
				workingConditional = Conditional{}
			}
			if IsTokenEndOfStatement(token) {
				break
			}
			continue
		}

		// Phase 5: ORDER BY. Detect the two-keyword sequence ORDER, BY via
		// the look-behind buffer.
		if passedConditionals && !passedOrderByKeywords && token.Type == sqllexer.KEYWORD {
			unshiftBuffer(&lookBehindBuffer, *token)
			if strings.ToUpper(lookBehindBuffer[1].Value) == "ORDER" && strings.ToUpper(lookBehindBuffer[0].Value) == "BY" {
				passedOrderByKeywords = true
			}
		}
		if passedOrderByKeywords && !passesOrderByColumns {
			// Buffer the key and optional direction keyword, then emit an
			// OrderBy when a separator (punctuation) or EOF is reached.
			if token.Type == sqllexer.IDENT || token.Type == sqllexer.KEYWORD {
				unshiftBuffer(&lookBehindBuffer, *token)
				continue
			}
			if token.Type == sqllexer.PUNCTUATION || token.Type == sqllexer.EOF {
				var orderByColumnName string
				var directionKeyword string
				var isDescend bool = false
				if lookBehindBuffer[0].Type == sqllexer.KEYWORD {
					// "<key> <ASC|DESC>": the key sits one slot behind.
					orderByColumnName = lookBehindBuffer[1].Value
					directionKeyword = lookBehindBuffer[0].Value
				} else if lookBehindBuffer[0].Type == sqllexer.IDENT {
					orderByColumnName = lookBehindBuffer[0].Value
				}
				if strings.ToUpper(directionKeyword) == "DESC" {
					isDescend = true
				}
				orderBys = append(orderBys, OrderBy{
					Key:       orderByColumnName,
					IsDescend: isDescend,
				})
				if IsTokenEndOfStatement(token) {
					query.OrderBys = orderBys
					break
				}
			}
		}
		query.OrderBys = orderBys
		if IsTokenEndOfStatement(token) {
			break
		}
	}
	return query
}