360 lines
8.1 KiB
Go
360 lines
8.1 KiB
Go
package q
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/DataDog/go-sqllexer"
|
|
)
|
|
|
|
// AggregateFunctionType identifies the SQL aggregate function wrapped around a
// selected column. The zero value means "no aggregate function".
type AggregateFunctionType int

// Supported aggregate functions. Numbering starts at 1 so that the zero value
// of AggregateFunctionType is free to mean "none".
const (
	_     AggregateFunctionType = iota // 0: not an aggregate
	MIN                                // MIN(column)
	MAX                                // MAX(column)
	COUNT                              // COUNT(column)
	SUM                                // SUM(column)
	AVG                                // AVG(column)
)
|
|
|
|
// JoinType identifies the kind of JOIN described by a Join. The zero value is
// INNER.
type JoinType int

// Join kinds, numbered from 0.
const (
	INNER JoinType = iota // INNER JOIN (also the zero value)
	LEFT                  // LEFT JOIN
	RIGHT                 // RIGHT JOIN
	FULL                  // FULL JOIN
	SELF                  // presumably a table joined to itself (not a SQL keyword) — confirm
)
|
|
|
|
// Table is a table reference together with its optional alias.
type Table struct {
	// Name is the table name; Alias is the alias it is referred to by, if any.
	Name, Alias string
}
|
|
|
|
type Column struct {
|
|
Name string
|
|
Alias string
|
|
AggregateFunction AggregateFunctionType
|
|
}
|
|
|
|
type Join struct {
|
|
Type JoinType
|
|
Table Table
|
|
Ons []Conditional
|
|
}
|
|
|
|
type Select struct {
|
|
Table string
|
|
Columns []Column
|
|
Conditionals []Conditional
|
|
OrderBys []OrderBy
|
|
Joins []Join
|
|
IsWildcard bool
|
|
IsDistinct bool
|
|
}
|
|
|
|
// OrderBy is one key of an ORDER BY clause.
type OrderBy struct {
	Key string

	// IsDescend is true for DESC. SQL sorts ascending when no direction is
	// given, so the zero value (false) stands for ASC.
	IsDescend bool
}
|
|
|
|
func GetAggregateFunctionTypeByName(name string) AggregateFunctionType {
|
|
var functionType AggregateFunctionType
|
|
switch strings.ToUpper(name) {
|
|
case "MIN":
|
|
functionType = MIN
|
|
case "MAX":
|
|
functionType = MAX
|
|
case "COUNT":
|
|
functionType = COUNT
|
|
case "SUM":
|
|
functionType = SUM
|
|
case "AVG":
|
|
functionType = AVG
|
|
default:
|
|
functionType = 0
|
|
}
|
|
|
|
return functionType
|
|
}
|
|
|
|
func GetAggregateFunctionNameByType(functionType AggregateFunctionType) string {
|
|
var functionName string
|
|
switch functionType {
|
|
case MIN:
|
|
functionName = "MIN"
|
|
case MAX:
|
|
functionName = "MAX"
|
|
case COUNT:
|
|
functionName = "COUNT"
|
|
case SUM:
|
|
functionName = "SUM"
|
|
case AVG:
|
|
functionName = "AVG"
|
|
default:
|
|
functionName = ""
|
|
}
|
|
|
|
return functionName
|
|
|
|
}
|
|
|
|
func GetFullStringFromColumn(column Column) string {
|
|
var workingSlice string
|
|
|
|
if column.AggregateFunction > 0 {
|
|
workingSlice = fmt.Sprintf(
|
|
"%s(%s)",
|
|
GetAggregateFunctionNameByType(column.AggregateFunction),
|
|
column.Name,
|
|
)
|
|
} else {
|
|
workingSlice = column.Name
|
|
}
|
|
|
|
if column.Alias != "" {
|
|
workingSlice += fmt.Sprintf(" AS %s", column.Alias)
|
|
}
|
|
|
|
return workingSlice
|
|
|
|
}
|
|
|
|
func (q *Select) GetFullSql() string {
|
|
var workingSqlSlice []string
|
|
|
|
workingSqlSlice = append(workingSqlSlice, "SELECT")
|
|
|
|
if q.IsWildcard {
|
|
workingSqlSlice = append(workingSqlSlice, "*")
|
|
} else {
|
|
for i, column := range q.Columns {
|
|
if i < (len(q.Columns) - 1) {
|
|
workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column)+",")
|
|
} else {
|
|
workingSqlSlice = append(workingSqlSlice, GetFullStringFromColumn(column))
|
|
}
|
|
}
|
|
}
|
|
|
|
workingSqlSlice = append(workingSqlSlice, "FROM "+q.Table)
|
|
|
|
// TODO: need to account for `AND` and `OR`s and stuff
|
|
for _, condition := range q.Conditionals {
|
|
workingSqlSlice = append(workingSqlSlice, condition.Key)
|
|
workingSqlSlice = append(workingSqlSlice, condition.Operator)
|
|
workingSqlSlice = append(workingSqlSlice, condition.Value)
|
|
}
|
|
|
|
fullSql := strings.Join(workingSqlSlice, " ")
|
|
|
|
return fullSql
|
|
}
|
|
|
|
func mutateSelectFromKeyword(query *Select, keyword string) {
|
|
switch strings.ToUpper(keyword) {
|
|
case "DISTINCT":
|
|
query.IsDistinct = true
|
|
}
|
|
}
|
|
|
|
func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) {
|
|
for i := 9; i >= 1; i-- {
|
|
buf[i] = buf[i-1]
|
|
}
|
|
|
|
buf[0] = value
|
|
}
|
|
|
|
func ParseSelectStatement(sql string) Select {
|
|
query := Select{}
|
|
|
|
passedSELECT := false
|
|
passedColumns := false
|
|
passedFROM := false
|
|
passedTable := false
|
|
passedWHERE := false
|
|
passedConditionals := false
|
|
passedOrderByKeywords := false
|
|
passesOrderByColumns := false
|
|
|
|
lookBehindBuffer := [10]sqllexer.Token{}
|
|
var workingConditional = Conditional{}
|
|
var workingColumn Column
|
|
|
|
var columns []Column
|
|
var orderBys []OrderBy
|
|
lexer := sqllexer.New(sql)
|
|
for {
|
|
token := lexer.Scan()
|
|
|
|
if !passedSELECT && strings.ToUpper(token.Value) != "SELECT" {
|
|
break
|
|
} else if !passedSELECT {
|
|
passedSELECT = true
|
|
continue
|
|
}
|
|
|
|
// For any keywords that are before the columns or wildcard
|
|
if passedSELECT && len(columns) == 0 && !passedColumns {
|
|
if token.Type == sqllexer.KEYWORD {
|
|
switch strings.ToUpper(token.Value) {
|
|
case "DISTINCT":
|
|
query.IsDistinct = true
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
if !passedColumns {
|
|
if token.Type == sqllexer.WILDCARD {
|
|
passedColumns = true
|
|
columns = make([]Column, 0)
|
|
query.IsWildcard = true
|
|
continue
|
|
} else if token.Type == sqllexer.FUNCTION {
|
|
unshiftBuffer(&lookBehindBuffer, *token)
|
|
workingColumn.AggregateFunction = GetAggregateFunctionTypeByName(token.Value)
|
|
continue
|
|
} else if token.Type == sqllexer.PUNCTUATION {
|
|
if token.Value == "," {
|
|
columns = append(columns, workingColumn)
|
|
workingColumn = Column{}
|
|
}
|
|
continue
|
|
} else if token.Type == sqllexer.IDENT {
|
|
unshiftBuffer(&lookBehindBuffer, *token)
|
|
|
|
if lookBehindBuffer[0].Type == sqllexer.ALIAS_INDICATOR {
|
|
workingColumn.Alias = token.Value
|
|
} else {
|
|
workingColumn.Name = token.Value
|
|
}
|
|
|
|
continue
|
|
} else if token.Type == sqllexer.ALIAS_INDICATOR {
|
|
unshiftBuffer(&lookBehindBuffer, *token)
|
|
continue
|
|
} else if token.Type == sqllexer.SPACE {
|
|
continue
|
|
} else {
|
|
if workingColumn.Name != "" {
|
|
columns = append(columns, workingColumn)
|
|
workingColumn = Column{}
|
|
}
|
|
|
|
passedColumns = true
|
|
query.Columns = columns
|
|
}
|
|
}
|
|
|
|
// TODO: make sure to check for other keywords that are allowed
|
|
if !passedFROM && strings.ToUpper(token.Value) == "FROM" {
|
|
passedFROM = true
|
|
continue
|
|
}
|
|
|
|
if !passedTable && token.Type == sqllexer.IDENT {
|
|
passedTable = true
|
|
query.Table = token.Value
|
|
continue
|
|
} else if !passedTable {
|
|
continue
|
|
}
|
|
|
|
if !passedWHERE && token.Type == sqllexer.KEYWORD && strings.ToUpper(token.Value) == "WHERE" {
|
|
passedWHERE = true
|
|
continue
|
|
} else if !passedWHERE && token.Type == sqllexer.KEYWORD && strings.ToUpper(token.Value) != "WHERE" {
|
|
passedWHERE = true
|
|
}
|
|
|
|
if passedWHERE && !passedConditionals {
|
|
if token.Type == sqllexer.KEYWORD && strings.ToUpper(token.Value) != "AND" && strings.ToUpper(token.Value) != "OR" && strings.ToUpper(token.Value) != "NOT" {
|
|
passedConditionals = true
|
|
}
|
|
}
|
|
|
|
if passedWHERE && !passedConditionals {
|
|
if token.Type == sqllexer.IDENT {
|
|
workingConditional.Key = token.Value
|
|
} else if token.Type == sqllexer.OPERATOR {
|
|
workingConditional.Operator = token.Value
|
|
} else if token.Type == sqllexer.BOOLEAN || token.Type == sqllexer.NULL || token.Type == sqllexer.STRING || token.Type == sqllexer.NUMBER {
|
|
workingConditional.Value = token.Value
|
|
} else if token.Type == sqllexer.KEYWORD {
|
|
if strings.ToUpper(token.Value) == "AND" || strings.ToUpper(token.Value) == "OR" {
|
|
workingConditional.Extension = strings.ToUpper(token.Value)
|
|
}
|
|
}
|
|
|
|
if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" {
|
|
query.Conditionals = append(query.Conditionals, workingConditional)
|
|
workingConditional = Conditional{}
|
|
}
|
|
|
|
if IsTokenEndOfStatement(token) {
|
|
break
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
// Checking For ORDER BY
|
|
if passedConditionals && !passedOrderByKeywords && token.Type == sqllexer.KEYWORD {
|
|
unshiftBuffer(&lookBehindBuffer, *token)
|
|
|
|
if strings.ToUpper(lookBehindBuffer[1].Value) == "ORDER" && strings.ToUpper(lookBehindBuffer[0].Value) == "BY" {
|
|
passedOrderByKeywords = true
|
|
}
|
|
}
|
|
|
|
if passedOrderByKeywords && !passesOrderByColumns {
|
|
|
|
if token.Type == sqllexer.IDENT || token.Type == sqllexer.KEYWORD {
|
|
unshiftBuffer(&lookBehindBuffer, *token)
|
|
continue
|
|
}
|
|
|
|
if token.Type == sqllexer.PUNCTUATION || token.Type == sqllexer.EOF {
|
|
|
|
var orderByColumnName string
|
|
var directionKeyword string
|
|
var isDescend bool = false
|
|
|
|
if lookBehindBuffer[0].Type == sqllexer.KEYWORD {
|
|
orderByColumnName = lookBehindBuffer[1].Value
|
|
directionKeyword = lookBehindBuffer[0].Value
|
|
} else if lookBehindBuffer[0].Type == sqllexer.IDENT {
|
|
orderByColumnName = lookBehindBuffer[0].Value
|
|
}
|
|
|
|
if strings.ToUpper(directionKeyword) == "DESC" {
|
|
isDescend = true
|
|
}
|
|
|
|
orderBys = append(orderBys, OrderBy{
|
|
Key: orderByColumnName,
|
|
IsDescend: isDescend,
|
|
})
|
|
|
|
if IsTokenEndOfStatement(token) {
|
|
query.OrderBys = orderBys
|
|
break
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
query.OrderBys = orderBys
|
|
|
|
if IsTokenEndOfStatement(token) {
|
|
break
|
|
}
|
|
|
|
}
|
|
|
|
return query
|
|
}
|