query-interpreter/q/parse.go
2025-05-13 18:27:23 -05:00

495 lines
12 KiB
Go

package q
import (
"fmt"
"strings"
"github.com/DataDog/go-sqllexer"
)
// FindTokenInArray scans tokens from the front and reports the first one for
// which condition returns true.
//
// It returns a pointer to a copy of the matching token together with its
// index in tokens. When no token matches it returns (nil, -1).
func FindTokenInArray(tokens []Token, condition func(Token) bool) (*Token, int) {
	for idx := 0; idx < len(tokens); idx++ {
		candidate := tokens[idx]
		if !condition(candidate) {
			continue
		}
		return &candidate, idx
	}
	return nil, -1
}
// Token is an alias for sqllexer.Token so the rest of the package can refer to
// lexer tokens without the package qualifier.
type Token = sqllexer.Token
// parser carries the state shared by the parsing functions in this file.
type parser struct {
// sql is the statement text handed to Parse.
sql string
// tokens holds the lexer tokens Parse scanned from sql.
tokens []Token
// index is the base position that lookAhead and lookBehind offset from.
index int
// query is the result being built; Parse stores a *Select here.
query Query
// NOTE(review): lookBehindBuffer is not referenced by the code visible in
// this file — confirm whether it is still needed.
lookBehindBuffer []Token
}
// lookAhead returns the token located count positions after the parser's
// current index. The position is not bounds-checked, so an out-of-range
// offset panics.
func (p *parser) lookAhead(count int) Token {
	target := p.index + count
	return p.tokens[target]
}
// lookBehind returns the token located count positions before the parser's
// current index. The position is not bounds-checked, so an out-of-range
// offset panics.
func (p *parser) lookBehind(count int) Token {
	target := p.index - count
	return p.tokens[target]
}
// findToken reports the first token in p.tokens for which condition returns
// true.
//
// It returns a pointer to a copy of that token and its index in p.tokens.
// When no token matches it returns (nil, -1).
func (p *parser) findToken(condition func(Token) bool) (*Token, int) {
	all := p.tokens
	for idx := 0; idx < len(all); idx++ {
		candidate := all[idx]
		if !condition(candidate) {
			continue
		}
		return &candidate, idx
	}
	return nil, -1
}
// FoundToken pairs a matched token with the position it was found at, as
// produced by findAllTokens.
type FoundToken struct {
// Token is a copy of the matching token.
Token Token
// Index is the token's position in the parser's token slice.
Index int
}
// findAllTokens collects every token in p.tokens for which condition returns
// true, in order of appearance, each paired with its index in p.tokens.
// The returned slice is empty but non-nil when nothing matches.
func (p *parser) findAllTokens(condition func(Token) bool) []FoundToken {
	all := p.tokens
	found := make([]FoundToken, 0)
	for idx := 0; idx < len(all); idx++ {
		if candidate := all[idx]; condition(candidate) {
			found = append(found, FoundToken{Token: candidate, Index: idx})
		}
	}
	return found
}
// filter returns a new slice holding, in their original order, the elements
// of items for which fn reports true. The result is never nil, even when no
// element is kept.
func filter[T any](items []T, fn func(item T) bool) []T {
	kept := make([]T, 0)
	for idx := range items {
		if candidate := items[idx]; fn(candidate) {
			kept = append(kept, candidate)
		}
	}
	return kept
}
// unshiftBuffer pushes value onto the front of buf. Every existing entry
// moves one slot towards the end and the entry that was in the last slot is
// discarded.
func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) {
	// copy is safe for overlapping ranges, so this shifts slots 0..8 into 1..9.
	copy(buf[1:], buf[:len(buf)-1])
	buf[0] = value
}
// Parse tokenizes sql with sqllexer and builds the Query it describes.
//
// The query type is taken from the first COMMAND token. Only SELECT is
// handled; any other command yields an error. When an error is returned the
// Query value is whatever had been built up to that point (nil if the query
// type could not be determined) and should not be used.
func Parse(sql string) (Query, error) {
	lexer := sqllexer.New(sql)

	var tokens []Token
	for {
		t := lexer.Scan()
		tokens = append(tokens, *t)
		// Also stop on EOF so the loop terminates even if
		// IsTokenEndOfStatement does not treat end of input as a terminator.
		if IsTokenEndOfStatement(t) || t.Type == sqllexer.EOF {
			break
		}
	}

	p := parser{
		sql:    sql,
		tokens: tokens,
	}

	queryType, err := findQueryType(&p)
	if err != nil {
		return p.query, err
	}

	switch strings.ToUpper(queryType) {
	case "SELECT":
		p.query = &Select{
			Type: SELECT,
		}
		// Surface parse failures instead of reporting success with a
		// half-built query.
		if err := parseSelectStatement(&p); err != nil {
			return p.query, err
		}
	default:
		return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType)
	}

	return p.query, nil
}
// findQueryType returns the value of the first COMMAND token in the parser's
// token list, or an error when the statement contains no COMMAND token.
func findQueryType(p *parser) (string, error) {
	isCommand := func(t Token) bool {
		return sqllexer.COMMAND == t.Type
	}
	if command, _ := p.findToken(isCommand); command != nil {
		return command.Value, nil
	}
	return "", fmt.Errorf("Could not find query type")
}
// parseSelectStatement fills in the *Select stored in p.query by running the
// individual clause parsers in order: DISTINCT, columns (skipped for wildcard
// selects), table, WHERE conditionals, ORDER BY and finally joins.
//
// It returns the first error reported by any of those steps. p.query must
// already hold a *Select; otherwise the type assertion panics.
func parseSelectStatement(p *parser) error {
	selectQuery := p.query.(*Select)

	if err := parseDistinct(p); err != nil {
		return err
	}

	// A WILDCARD token anywhere in the statement marks the select as wildcard.
	foundWildcard, _ := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.WILDCARD
	})
	selectQuery.IsWildcard = foundWildcard != nil

	if !selectQuery.IsWildcard {
		// The error is returned to the caller; it is not also printed here.
		if err := parseSelectColumns(p); err != nil {
			return err
		}
	}

	if err := parseSelectTable(p); err != nil {
		return err
	}
	if err := parseSelectConditionals(p); err != nil {
		return err
	}
	if err := parseOrderBys(p); err != nil {
		return err
	}

	// Joins read selectQuery.Table, so they are parsed after the table.
	return parseJoins(p)
}
// parseDistinct sets IsDistinct on the *Select in p.query when the statement
// contains a DISTINCT keyword (matched case-insensitively via upper-casing).
// It always returns nil.
func parseDistinct(p *parser) error {
	selectQuery := p.query.(*Select)

	isDistinct := func(t Token) bool {
		if t.Type != sqllexer.KEYWORD {
			return false
		}
		return strings.ToUpper(t.Value) == "DISTINCT"
	}

	if keyword, _ := p.findToken(isDistinct); keyword != nil {
		selectQuery.IsDistinct = true
	}
	return nil
}
// parseSelectColumns reads the column list of a SELECT statement and stores
// it in the Columns field of the *Select held by p.query.
//
// It walks the tokens strictly between the first SELECT command and the first
// FROM keyword:
//   - a FUNCTION token sets the column's aggregate function,
//   - an IDENT token sets the column's alias when the previous recorded token
//     was an ALIAS_INDICATOR, and its name otherwise,
//   - a "," finishes the current column.
//
// The column still being built when FROM is reached is added if it has a
// name. An error is returned when SELECT or FROM cannot be found.
func parseSelectColumns(p *parser) error {
	selectQuery := p.query.(*Select)

	_, selectCommandIndex := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "SELECT"
	})
	_, fromKeywordIndex := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM"
	})
	if selectCommandIndex < 0 || fromKeywordIndex < 0 {
		return fmt.Errorf("Could not find range between SELECT and FROM")
	}

	columns := make([]Column, 0)
	var workingColumn Column
	// previous is the last FUNCTION, IDENT or ALIAS_INDICATOR token seen;
	// spaces and punctuation are deliberately not recorded.
	var previous Token

	for i := selectCommandIndex + 1; i < fromKeywordIndex; i++ {
		token := p.tokens[i]
		switch token.Type {
		case sqllexer.FUNCTION:
			workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value)
			previous = token
		case sqllexer.PUNCTUATION:
			if token.Value == "," {
				columns = append(columns, workingColumn)
				workingColumn = Column{}
			}
		case sqllexer.IDENT:
			if previous.Type == sqllexer.ALIAS_INDICATOR {
				workingColumn.Alias = token.Value
			} else {
				workingColumn.Name = token.Value
			}
			previous = token
		case sqllexer.ALIAS_INDICATOR:
			previous = token
		}
	}

	// Flush the trailing column here rather than inside the loop, so it is
	// kept no matter which kind of token happens to precede FROM.
	if workingColumn.Name != "" {
		columns = append(columns, workingColumn)
	}

	selectQuery.Columns = columns
	return nil
}
// parseSelectTable finds the table named after the first FROM keyword and
// stores it in the Table field of the *Select held by p.query.
//
// The first IDENT after FROM is the table name and a second IDENT, if
// present, is its alias. SPACE and ALIAS_INDICATOR tokens are skipped; any
// other token ends the search. An error is returned when FROM is missing or
// no table name follows it.
func parseSelectTable(p *parser) error {
	selectQuery := p.query.(*Select)

	_, fromKeywordIndex := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM"
	})
	if fromKeywordIndex < 0 {
		return fmt.Errorf("Could not find FROM keyword to look for table name")
	}

	var foundTable Table
scan:
	for i := fromKeywordIndex + 1; i < len(p.tokens); i++ {
		t := &p.tokens[i]
		switch t.Type {
		case sqllexer.SPACE, sqllexer.ALIAS_INDICATOR:
			// Separators between FROM, the name and the alias.
		case sqllexer.IDENT:
			if foundTable.Name != "" {
				// Second identifier: the alias, which completes the table.
				foundTable.Alias = t.Value
				break scan
			}
			foundTable.Name = t.Value
		default:
			break scan
		}
	}

	if foundTable.Name == "" {
		return fmt.Errorf("Could not find table name")
	}

	selectQuery.Table = foundTable
	return nil
}
// parseSelectConditionals reads the WHERE clause and appends one Conditional
// per "key operator value" triple to the *Select held by p.query.
//
// Scanning starts after the first WHERE keyword and stops at the first
// keyword other than AND, OR or NOT. AND/OR is recorded as the Extension of
// the conditional that follows it; NOT is skipped without being recorded.
// An identifier that appears after the key and operator are known is taken as
// the value, so column-to-column comparisons such as "a = b" are captured the
// same way the ON clause of a join is. A statement without WHERE is not an
// error. The function always returns nil.
func parseSelectConditionals(p *parser) error {
	selectQuery := p.query.(*Select)

	_, whereKeywordIndex := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "WHERE"
	})
	if whereKeywordIndex < 0 {
		return nil
	}

	var workingConditional Conditional
	for i := whereKeywordIndex + 1; i < len(p.tokens); i++ {
		t := &p.tokens[i]

		keyword := ""
		if t.Type == sqllexer.KEYWORD {
			keyword = strings.ToUpper(t.Value)
			if keyword != "AND" && keyword != "OR" && keyword != "NOT" {
				// Any other keyword starts the next clause.
				break
			}
		}

		switch t.Type {
		case sqllexer.IDENT:
			if workingConditional.Key != "" && workingConditional.Operator != "" {
				// Right-hand side of a comparison between two identifiers.
				workingConditional.Value = t.Value
			} else {
				workingConditional.Key = t.Value
			}
		case sqllexer.OPERATOR:
			workingConditional.Operator = t.Value
		case sqllexer.BOOLEAN, sqllexer.NULL, sqllexer.STRING, sqllexer.NUMBER:
			workingConditional.Value = t.Value
		case sqllexer.KEYWORD:
			if keyword == "AND" || keyword == "OR" {
				workingConditional.Extension = keyword
			}
		}

		if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" {
			selectQuery.Conditionals = append(selectQuery.Conditionals, workingConditional)
			workingConditional = Conditional{}
		}
	}

	return nil
}
// parseOrderBys reads the ORDER BY clause and stores the resulting entries in
// the OrderBys field of the *Select held by p.query.
//
// The clause is located by the first BY keyword whose nearest preceding
// non-space token reads ORDER, so the BY of a GROUP BY is not mistaken for
// it. After that, each IDENT starts an entry, DESC marks the current entry as
// descending, ASC is accepted, punctuation ends the current entry, and any
// other keyword ends the clause. ASC and DESC are matched case-insensitively.
// Entries without a key are never emitted. A statement without ORDER BY
// leaves OrderBys untouched. The function always returns nil.
func parseOrderBys(p *parser) error {
	selectQuery := p.query.(*Select)

	byKeywordIndex := -1
	for i := range p.tokens {
		t := &p.tokens[i]
		if t.Type != sqllexer.KEYWORD || strings.ToUpper(t.Value) != "BY" {
			continue
		}
		prev := i - 1
		for prev >= 0 && p.tokens[prev].Type == sqllexer.SPACE {
			prev--
		}
		// Compare by value only, so this does not depend on how the lexer
		// classifies the word ORDER.
		if prev >= 0 && strings.ToUpper(p.tokens[prev].Value) == "ORDER" {
			byKeywordIndex = i
			break
		}
	}
	if byKeywordIndex < 0 {
		return nil
	}

	var orderBys []OrderBy
	var workingOrderBy OrderBy

scan:
	for i := byKeywordIndex + 1; i < len(p.tokens); i++ {
		t := &p.tokens[i]
		switch t.Type {
		case sqllexer.IDENT:
			if workingOrderBy.Key != "" {
				// Two keys in a row: close the previous entry and start a
				// clean one so its direction does not carry over.
				orderBys = append(orderBys, workingOrderBy)
				workingOrderBy = OrderBy{}
			}
			workingOrderBy.Key = t.Value
		case sqllexer.KEYWORD:
			switch strings.ToUpper(t.Value) {
			case "DESC":
				workingOrderBy.IsDescend = true
			case "ASC":
				// Ascending is the zero value; nothing to record.
			default:
				break scan
			}
		case sqllexer.PUNCTUATION:
			if workingOrderBy.Key != "" {
				orderBys = append(orderBys, workingOrderBy)
			}
			workingOrderBy = OrderBy{}
		}
	}

	if workingOrderBy.Key != "" {
		orderBys = append(orderBys, workingOrderBy)
	}

	selectQuery.OrderBys = orderBys
	return nil
}
// parseJoins finds every JOIN command in the statement and appends one Join
// per occurrence to the *Select held by p.query.
//
// For each JOIN the function:
//   - takes the tokens from the JOIN up to (not including) the next JOIN, or
//     up to (not including) the final token for the last JOIN,
//   - determines the join type from the nearest keyword before the JOIN:
//     LEFT, RIGHT, FULL or SELF select that type, OUTER is skipped so that
//     "LEFT OUTER JOIN" resolves to LEFT, and any other keyword means INNER,
//   - uses the first IDENT after JOIN as the joined table's name (aliases are
//     not parsed),
//   - reads "key operator value" triples after ON, stopping at the first
//     keyword other than AND, OR or NOT; AND/OR is recorded as the Extension
//     of the condition that follows it.
//
// MainTable is copied from the select's Table, so the table must have been
// parsed first. The function always returns nil.
func parseJoins(p *parser) error {
	selectQuery := p.query.(*Select)

	foundJoinKeywords := p.findAllTokens(func(t Token) bool {
		return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "JOIN"
	})
	if len(foundJoinKeywords) == 0 {
		return nil
	}

	for i, foundJoin := range foundJoinKeywords {
		// The last join runs to the end of the statement, excluding the
		// final token; earlier joins run up to the next JOIN.
		endRangeIndex := len(p.tokens) - 1
		if i < len(foundJoinKeywords)-1 {
			endRangeIndex = foundJoinKeywords[i+1].Index
		}
		joinTokens := p.tokens[foundJoin.Index:endRangeIndex]

		var workingJoin Join
		workingJoin.MainTable = selectQuery.Table

		// Determine the join type by looking backwards in the full statement.
		for j := foundJoin.Index - 1; j >= 0; j-- {
			if p.tokens[j].Type != sqllexer.KEYWORD {
				continue
			}
			switch strings.ToUpper(p.tokens[j].Value) {
			case "OUTER":
				// OUTER only qualifies LEFT/RIGHT/FULL; keep looking.
				continue
			case "LEFT":
				workingJoin.Type = LEFT
			case "RIGHT":
				workingJoin.Type = RIGHT
			case "FULL":
				workingJoin.Type = FULL
			case "SELF":
				workingJoin.Type = SELF
			default:
				workingJoin.Type = INNER
			}
			break // Stop after the first deciding keyword.
		}

		// Joined table name: first IDENT after the JOIN token itself.
		// TODO: make sure you dont have to check for aliases
		for k := 1; k < len(joinTokens); k++ {
			if joinTokens[k].Type == sqllexer.IDENT {
				workingJoin.JoiningTable.Name = joinTokens[k].Value
				break
			}
		}

		_, foundOnTokenIndex := FindTokenInArray(joinTokens, func(t Token) bool {
			return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "ON"
		})
		if foundOnTokenIndex < 0 {
			// A join without ON is still recorded, just without conditions.
			selectQuery.Joins = append(selectQuery.Joins, workingJoin)
			continue
		}

		var workingOn Conditional
		for k := foundOnTokenIndex + 1; k < len(joinTokens); k++ {
			t := &joinTokens[k]

			keyword := ""
			if t.Type == sqllexer.KEYWORD {
				keyword = strings.ToUpper(t.Value)
				if keyword != "AND" && keyword != "OR" && keyword != "NOT" {
					// Any other keyword ends the ON clause.
					break
				}
			}

			switch t.Type {
			case sqllexer.IDENT:
				// First identifier is the key, the next one the value.
				if workingOn.Key == "" {
					workingOn.Key = t.Value
				} else {
					workingOn.Value = t.Value
				}
			case sqllexer.OPERATOR:
				workingOn.Operator = t.Value
			case sqllexer.BOOLEAN, sqllexer.NULL, sqllexer.STRING, sqllexer.NUMBER:
				workingOn.Value = t.Value
			case sqllexer.KEYWORD:
				if keyword == "AND" || keyword == "OR" {
					workingOn.Extension = keyword
				}
			}

			if workingOn.Key != "" && workingOn.Operator != "" && workingOn.Value != "" {
				workingJoin.Ons = append(workingJoin.Ons, workingOn)
				workingOn = Conditional{}
			}
		}

		selectQuery.Joins = append(selectQuery.Joins, workingJoin)
	}

	return nil
}