but I am now tired, my wrist hurts, it is late, and I want to work on this on my main machine tomorrow
372 lines
8.3 KiB
Go
372 lines
8.3 KiB
Go
package q
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/DataDog/go-sqllexer"
|
|
)
|
|
|
|
// Token is an alias for sqllexer.Token, the token type the parser stores and
// inspects.
type Token = sqllexer.Token
|
|
|
|
type parser struct {
|
|
sql string
|
|
tokens []Token
|
|
index int
|
|
query Query
|
|
lookBehindBuffer []Token
|
|
}
|
|
|
|
func (p *parser) lookAhead(count int) Token {
|
|
return p.tokens[p.index+count]
|
|
}
|
|
|
|
func (p *parser) lookBehind(count int) Token {
|
|
return p.tokens[p.index-count]
|
|
}
|
|
|
|
// Returns pointer of the first found Token and its index in parser.Tokens
|
|
// Returns -1 as the index if not found
|
|
func (p *parser) findToken(condition func(Token) bool) (*Token, int) {
|
|
for i, element := range p.tokens {
|
|
if condition(element) {
|
|
return &element, i
|
|
}
|
|
}
|
|
return nil, -1
|
|
}
|
|
|
|
type FoundToken struct {
|
|
Token Token
|
|
Index int
|
|
}
|
|
|
|
// Returns all tokens that match the condition, along with their indices
|
|
func (p *parser) findAllTokens(condition func(Token) bool) []FoundToken {
|
|
matches := make([]FoundToken, 0)
|
|
|
|
for i, token := range p.tokens {
|
|
if condition(token) {
|
|
matches = append(matches, FoundToken{
|
|
Token: token,
|
|
Index: i,
|
|
})
|
|
}
|
|
}
|
|
|
|
return matches
|
|
}
|
|
|
|
// filter returns a new slice holding the elements of items for which fn
// reports true, in their original order. The result is always non-nil, even
// when items is nil or nothing is kept.
func filter[T any](items []T, fn func(item T) bool) []T {
	kept := make([]T, 0)
	for idx := 0; idx < len(items); idx++ {
		if candidate := items[idx]; fn(candidate) {
			kept = append(kept, candidate)
		}
	}
	return kept
}
|
|
|
|
func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) {
|
|
for i := 9; i >= 1; i-- {
|
|
buf[i] = buf[i-1]
|
|
}
|
|
|
|
buf[0] = value
|
|
}
|
|
|
|
func Parse(sql string) (Query, error) {
|
|
lexer := sqllexer.New(sql)
|
|
|
|
var tokens []Token
|
|
for {
|
|
t := lexer.Scan()
|
|
tokens = append(tokens, *t)
|
|
|
|
if IsTokenEndOfStatement(t) {
|
|
break
|
|
}
|
|
}
|
|
|
|
p := parser{
|
|
sql: sql,
|
|
tokens: tokens,
|
|
}
|
|
|
|
queryType, err := findQueryType(&p)
|
|
if err != nil {
|
|
return p.query, err
|
|
}
|
|
|
|
switch strings.ToUpper(queryType) {
|
|
case "SELECT":
|
|
p.query = &Select{
|
|
Type: SELECT,
|
|
}
|
|
parseSelectStatement(&p)
|
|
default:
|
|
return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType)
|
|
}
|
|
|
|
return p.query, nil
|
|
}
|
|
|
|
func findQueryType(p *parser) (string, error) {
|
|
firstCommand, _ := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.COMMAND
|
|
})
|
|
|
|
if firstCommand == nil {
|
|
return "", fmt.Errorf("Could not find query type")
|
|
}
|
|
|
|
return firstCommand.Value, nil
|
|
}
|
|
|
|
func parseSelectStatement(p *parser) error {
|
|
selectQuery := p.query.(*Select)
|
|
|
|
foundWildcard, _ := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.WILDCARD
|
|
})
|
|
|
|
selectQuery.IsWildcard = foundWildcard != nil
|
|
|
|
if !selectQuery.IsWildcard {
|
|
err := parseSelectColumns(p)
|
|
if err != nil {
|
|
fmt.Println(err.Error())
|
|
return err
|
|
}
|
|
}
|
|
|
|
tableErr := parseSelectTable(p)
|
|
if tableErr != nil {
|
|
return tableErr
|
|
}
|
|
|
|
conditionalsErr := parseSelectConditionals(p)
|
|
if conditionalsErr != nil {
|
|
return conditionalsErr
|
|
}
|
|
|
|
ordersByErr := parseOrderBys(p)
|
|
if ordersByErr != nil {
|
|
return ordersByErr
|
|
}
|
|
|
|
parseJoins(p)
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseSelectColumns(p *parser) error {
|
|
selectQuery := p.query.(*Select)
|
|
|
|
_, selectCommandIndex := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "SELECT"
|
|
})
|
|
|
|
_, fromKeywordIndex := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM"
|
|
})
|
|
|
|
if selectCommandIndex < 0 || fromKeywordIndex < 0 {
|
|
return fmt.Errorf("Could not find range between SELECT and FROM")
|
|
}
|
|
|
|
lookBehindBuffer := [10]Token{}
|
|
var workingColumn Column
|
|
columns := make([]Column, 0)
|
|
|
|
startRange := selectCommandIndex + 1
|
|
endRange := fromKeywordIndex - 1
|
|
|
|
for i := startRange; i <= endRange; i++ {
|
|
token := p.tokens[i]
|
|
|
|
if token.Type == sqllexer.FUNCTION {
|
|
unshiftBuffer(&lookBehindBuffer, token)
|
|
workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value)
|
|
continue
|
|
} else if token.Type == sqllexer.PUNCTUATION && token.Value == "," {
|
|
columns = append(columns, workingColumn)
|
|
workingColumn = Column{}
|
|
continue
|
|
} else if token.Type == sqllexer.IDENT {
|
|
unshiftBuffer(&lookBehindBuffer, token)
|
|
|
|
if lookBehindBuffer[1].Type == sqllexer.ALIAS_INDICATOR {
|
|
workingColumn.Alias = token.Value
|
|
} else {
|
|
workingColumn.Name = token.Value
|
|
}
|
|
continue
|
|
} else if token.Type == sqllexer.ALIAS_INDICATOR {
|
|
unshiftBuffer(&lookBehindBuffer, token)
|
|
continue
|
|
} else if i == endRange {
|
|
if workingColumn.Name != "" {
|
|
columns = append(columns, workingColumn)
|
|
workingColumn = Column{}
|
|
}
|
|
}
|
|
}
|
|
|
|
selectQuery.Columns = columns
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseSelectTable(p *parser) error {
|
|
selectQuery := p.query.(*Select)
|
|
|
|
_, fromKeywordIndex := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM"
|
|
})
|
|
|
|
if fromKeywordIndex < 0 {
|
|
return fmt.Errorf("Could not FROM keyword to look for table name")
|
|
}
|
|
|
|
var foundTable Table
|
|
|
|
for i := fromKeywordIndex + 1; i < len(p.tokens); i++ {
|
|
t := &p.tokens[i]
|
|
if foundTable.Name == "" && t.Type == sqllexer.IDENT {
|
|
foundTable.Name = p.tokens[i].Value
|
|
continue
|
|
} else if t.Type == sqllexer.IDENT {
|
|
foundTable.Alias = p.tokens[i].Value
|
|
break
|
|
} else if t.Type == sqllexer.SPACE || t.Type == sqllexer.ALIAS_INDICATOR {
|
|
continue
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
if foundTable.Name == "" {
|
|
return fmt.Errorf("Could not find table name")
|
|
}
|
|
|
|
selectQuery.Table = foundTable
|
|
return nil
|
|
}
|
|
|
|
func parseSelectConditionals(p *parser) error {
|
|
selectQuery := p.query.(*Select)
|
|
|
|
_, whereKeywordIndex := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "WHERE"
|
|
})
|
|
|
|
if whereKeywordIndex < 0 {
|
|
return nil // fmt.Errorf("Could not find WHERE to look for conditionals")
|
|
}
|
|
|
|
var workingConditional Conditional
|
|
for i := whereKeywordIndex + 1; i < len(p.tokens); i++ {
|
|
t := &p.tokens[i]
|
|
|
|
if t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) != "AND" && strings.ToUpper(t.Value) != "OR" && strings.ToUpper(t.Value) != "NOT" {
|
|
break
|
|
}
|
|
|
|
if t.Type == sqllexer.IDENT {
|
|
workingConditional.Key = t.Value
|
|
} else if t.Type == sqllexer.OPERATOR {
|
|
workingConditional.Operator = t.Value
|
|
} else if t.Type == sqllexer.BOOLEAN || t.Type == sqllexer.NULL || t.Type == sqllexer.STRING || t.Type == sqllexer.NUMBER {
|
|
workingConditional.Value = t.Value
|
|
} else if t.Type == sqllexer.KEYWORD {
|
|
if strings.ToUpper(t.Value) == "AND" || strings.ToUpper(t.Value) == "OR" {
|
|
workingConditional.Extension = strings.ToUpper(t.Value)
|
|
}
|
|
}
|
|
|
|
if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" {
|
|
selectQuery.Conditionals = append(selectQuery.Conditionals, workingConditional)
|
|
workingConditional = Conditional{}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseOrderBys(p *parser) error {
|
|
selectQuery := p.query.(*Select)
|
|
|
|
_, byKeywordIndex := p.findToken(func(t Token) bool {
|
|
return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "BY"
|
|
})
|
|
|
|
if byKeywordIndex < 0 {
|
|
return nil
|
|
}
|
|
|
|
var orderBys []OrderBy
|
|
|
|
var workingOrderBy OrderBy
|
|
for i := byKeywordIndex + 1; i < len(p.tokens); i++ {
|
|
t := &p.tokens[i]
|
|
if t.Type == sqllexer.SPACE {
|
|
continue
|
|
} else if t.Type == sqllexer.IDENT && workingOrderBy.Key == "" {
|
|
workingOrderBy.Key = t.Value
|
|
continue
|
|
} else if t.Type == sqllexer.IDENT && workingOrderBy.Key != "" {
|
|
orderBys = append(orderBys, workingOrderBy)
|
|
workingOrderBy.Key = t.Value
|
|
continue
|
|
} else if t.Type == sqllexer.KEYWORD {
|
|
if t.Value == "DESC" {
|
|
workingOrderBy.IsDescend = true
|
|
} else if t.Value != "ASC" {
|
|
break
|
|
}
|
|
} else if t.Type == sqllexer.PUNCTUATION {
|
|
orderBys = append(orderBys, workingOrderBy)
|
|
workingOrderBy = OrderBy{}
|
|
continue
|
|
}
|
|
}
|
|
|
|
if workingOrderBy.Key != "" {
|
|
orderBys = append(orderBys, workingOrderBy)
|
|
}
|
|
|
|
selectQuery.OrderBys = orderBys
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseJoins(p *parser) error {
|
|
selectQuery := p.query.(*Select)
|
|
|
|
foundJoinKeywords := p.findAllTokens(func(t Token) bool {
|
|
return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "JOIN"
|
|
})
|
|
|
|
if len(foundJoinKeywords) <= 0 {
|
|
return nil
|
|
}
|
|
|
|
var joinTokenRanges []Token
|
|
|
|
for i := 0; i < len(foundJoinKeywords); i++ {
|
|
startRangeIndex := foundJoinKeywords[i].Index
|
|
var endRangeIndex int
|
|
|
|
if i == (len(foundJoinKeywords) - 1) {
|
|
endRangeIndex = len(p.tokens) - 1
|
|
} else {
|
|
endRangeIndex = foundJoinKeywords[i+1].Index
|
|
}
|
|
|
|
joinTokenRanges = append(joinTokenRanges, p.tokens[startRangeIndex:endRangeIndex]...)
|
|
}
|
|
|
|
return nil
|
|
}
|