package q

import (
	"fmt"
	"strings"

	"github.com/DataDog/go-sqllexer"
)

// Token is an alias for the lexer's token type so the rest of the package can
// refer to it without the sqllexer qualifier.
type Token = sqllexer.Token

// parser holds the token stream of one SQL statement and the Query being built
// from it. The index field is the cursor read by lookAhead/lookBehind; neither
// index nor lookBehindBuffer is used by the parsing functions in this file.
type parser struct {
	sql              string
	tokens           []Token
	index            int
	query            Query
	lookBehindBuffer []Token
}

// lookAhead returns the token count positions after the cursor. It panics if
// that position lies outside p.tokens.
func (p *parser) lookAhead(count int) Token {
	return p.tokens[p.index+count]
}

// lookBehind returns the token count positions before the cursor. It panics if
// that position lies outside p.tokens.
func (p *parser) lookBehind(count int) Token {
	return p.tokens[p.index-count]
}

// findToken returns a pointer to the first token in p.tokens that satisfies
// condition, together with its index. The pointer refers to the element inside
// p.tokens, not to a copy. It returns (nil, -1) when no token matches.
func (p *parser) findToken(condition func(Token) bool) (*Token, int) {
	for i := range p.tokens {
		if condition(p.tokens[i]) {
			return &p.tokens[i], i
		}
	}
	return nil, -1
}

// FoundToken pairs a token with its index in the parser's token slice.
type FoundToken struct {
	Token Token
	Index int
}

// findAllTokens returns every token that satisfies condition, in stream
// order, along with its index. The result is empty (never nil) when nothing
// matches.
func (p *parser) findAllTokens(condition func(Token) bool) []FoundToken {
	matches := make([]FoundToken, 0)
	for i, token := range p.tokens {
		if condition(token) {
			matches = append(matches, FoundToken{Token: token, Index: i})
		}
	}
	return matches
}

// filter returns the items for which fn reports true, preserving order. The
// result is always a new, non-nil slice.
func filter[T any](items []T, fn func(item T) bool) []T {
	filteredItems := []T{}
	for _, value := range items {
		if fn(value) {
			filteredItems = append(filteredItems, value)
		}
	}
	return filteredItems
}

// unshiftBuffer inserts value at buf[0] and shifts the existing entries one
// slot towards the end. The entry previously in the last slot is dropped.
func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) {
	// copy is specified to handle overlapping source and destination.
	copy(buf[1:], buf[:len(buf)-1])
	buf[0] = value
}

// Parse tokenizes sql and builds a Query from it. Only SELECT statements are
// currently supported; any other statement type yields an error.
//
// Tokens are scanned until IsTokenEndOfStatement reports true for a token, and
// that token is kept. If parsing the statement body fails, the partially
// populated query is returned together with the error.
func Parse(sql string) (Query, error) {
	lexer := sqllexer.New(sql)
	var tokens []Token
	for {
		t := lexer.Scan()
		tokens = append(tokens, *t)
		if IsTokenEndOfStatement(t) {
			break
		}
	}

	p := parser{
		sql:    sql,
		tokens: tokens,
	}

	queryType, err := findQueryType(&p)
	if err != nil {
		return p.query, err
	}

	switch strings.ToUpper(queryType) {
	case "SELECT":
		p.query = &Select{
			Type: SELECT,
		}
		// Surface body-parsing failures so callers can tell a malformed SELECT
		// from a successfully parsed one.
		if err = parseSelectStatement(&p); err != nil {
			return p.query, fmt.Errorf("parsing SELECT statement: %w", err)
		}
	default:
		return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType)
	}

	return p.query, nil
}

// findQueryType returns the value of the first COMMAND token, as written in
// the SQL, or an error if the statement contains no COMMAND token.
func findQueryType(p *parser) (string, error) {
	firstCommand, _ := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.COMMAND
	})
	if firstCommand == nil {
		return "", fmt.Errorf("Could not find query type")
	}
	return firstCommand.Value, nil
}

// parseSelectStatement fills the *Select stored in p.query. It sets the
// wildcard flag, then the columns (only when the select list has no wildcard).
// It then parses the table, WHERE conditionals, ORDER BY entries and joins.
// It stops at, and returns, the first error.
func parseSelectStatement(p *parser) error {
	selectQuery := p.query.(*Select)

	selectQuery.IsWildcard = selectListHasWildcard(p)
	if !selectQuery.IsWildcard {
		if err := parseSelectColumns(p); err != nil {
			return err
		}
	}
	if err := parseSelectTable(p); err != nil {
		return err
	}
	if err := parseSelectConditionals(p); err != nil {
		return err
	}
	if err := parseOrderBys(p); err != nil {
		return err
	}
	return parseJoins(p)
}

// selectListRange returns the token indices of the first SELECT command and
// of the first FROM keyword. ok is false if either one is missing.
func selectListRange(p *parser) (selectIndex, fromIndex int, ok bool) {
	_, selectIndex = p.findToken(func(t Token) bool {
		return t.Type == sqllexer.COMMAND && strings.EqualFold(t.Value, "SELECT")
	})
	_, fromIndex = p.findToken(func(t Token) bool {
		return t.Type == sqllexer.KEYWORD && strings.EqualFold(t.Value, "FROM")
	})
	return selectIndex, fromIndex, selectIndex >= 0 && fromIndex >= 0
}

// selectListHasWildcard reports whether a WILDCARD token appears between the
// first SELECT and the first FROM. Scoping the search this way means a "*"
// inside a later subquery or expression does not hide the explicit columns.
// When SELECT or FROM is missing, the whole statement is searched.
func selectListHasWildcard(p *parser) bool {
	start, end := 0, len(p.tokens)
	if selectIndex, fromIndex, ok := selectListRange(p); ok {
		start, end = selectIndex+1, fromIndex
	}
	for i := start; i < end; i++ {
		if p.tokens[i].Type == sqllexer.WILDCARD {
			return true
		}
	}
	return false
}

// parseSelectColumns parses the tokens between the first SELECT and the first
// FROM into selectQuery.Columns.
//
// Every "," PUNCTUATION token ends a column, including commas inside function
// arguments. A FUNCTION token sets the column's AggregateFunction. An IDENT
// sets the column's Alias when the previously buffered FUNCTION/IDENT/
// ALIAS_INDICATOR token was an ALIAS_INDICATOR; otherwise it sets Name. All
// other tokens are ignored. The final column is kept only if it has a Name.
func parseSelectColumns(p *parser) error {
	selectQuery := p.query.(*Select)

	selectIndex, fromIndex, ok := selectListRange(p)
	if !ok {
		return fmt.Errorf("Could not find range between SELECT and FROM")
	}

	// Most recent FUNCTION/IDENT/ALIAS_INDICATOR tokens, newest first. Only
	// slot 1 is read, to decide whether an IDENT is an alias.
	var recent [10]Token
	var working Column
	columns := make([]Column, 0)

	for i := selectIndex + 1; i < fromIndex; i++ {
		token := p.tokens[i]
		switch {
		case token.Type == sqllexer.FUNCTION:
			unshiftBuffer(&recent, token)
			working.AggregateFunction = AggregateFunctionTypeByName(token.Value)
		case token.Type == sqllexer.PUNCTUATION && token.Value == ",":
			columns = append(columns, working)
			working = Column{}
		case token.Type == sqllexer.IDENT:
			unshiftBuffer(&recent, token)
			if recent[1].Type == sqllexer.ALIAS_INDICATOR {
				working.Alias = token.Value
			} else {
				working.Name = token.Value
			}
		case token.Type == sqllexer.ALIAS_INDICATOR:
			unshiftBuffer(&recent, token)
		}
	}

	// Flush the last column after the loop, so it is kept no matter which
	// token type sits immediately before FROM.
	if working.Name != "" {
		columns = append(columns, working)
	}

	selectQuery.Columns = columns
	return nil
}

// parseSelectTable reads the table name and optional alias that follow the
// first FROM keyword into selectQuery.Table. SPACE and ALIAS_INDICATOR tokens
// are skipped. The first IDENT becomes the name and the second IDENT becomes
// the alias. Any other token ends the search.
func parseSelectTable(p *parser) error {
	selectQuery := p.query.(*Select)

	_, fromIndex := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.KEYWORD && strings.EqualFold(t.Value, "FROM")
	})
	if fromIndex < 0 {
		return fmt.Errorf("Could not find FROM keyword to look for table name")
	}

	var table Table
scan:
	for _, t := range p.tokens[fromIndex+1:] {
		switch {
		case t.Type == sqllexer.IDENT && table.Name == "":
			table.Name = t.Value
		case t.Type == sqllexer.IDENT:
			table.Alias = t.Value
			break scan
		case t.Type == sqllexer.SPACE, t.Type == sqllexer.ALIAS_INDICATOR:
			// Filler between the table name and its alias.
		default:
			break scan
		}
	}

	if table.Name == "" {
		return fmt.Errorf("Could not find table name")
	}
	selectQuery.Table = table
	return nil
}

// parseSelectConditionals collects Key/Operator/Value triples that follow the
// first WHERE keyword into selectQuery.Conditionals. A conditional is
// appended as soon as all three are set.
//
// Scanning stops at the first KEYWORD other than AND, OR or NOT. An AND/OR
// keyword is stored as the Extension of the conditional that follows it. A
// query without WHERE is not an error.
func parseSelectConditionals(p *parser) error {
	selectQuery := p.query.(*Select)

	_, whereIndex := p.findToken(func(t Token) bool {
		return t.Type == sqllexer.KEYWORD && strings.EqualFold(t.Value, "WHERE")
	})
	if whereIndex < 0 {
		return nil
	}

	var working Conditional
	for _, t := range p.tokens[whereIndex+1:] {
		switch t.Type {
		case sqllexer.KEYWORD:
			upper := strings.ToUpper(t.Value)
			if upper != "AND" && upper != "OR" && upper != "NOT" {
				// Another clause (ORDER, GROUP, LIMIT, ...) begins here.
				return nil
			}
			if upper == "AND" || upper == "OR" {
				working.Extension = upper
			}
		case sqllexer.IDENT:
			working.Key = t.Value
		case sqllexer.OPERATOR:
			working.Operator = t.Value
		case sqllexer.BOOLEAN, sqllexer.NULL, sqllexer.STRING, sqllexer.NUMBER:
			working.Value = t.Value
		}

		if working.Key != "" && working.Operator != "" && working.Value != "" {
			selectQuery.Conditionals = append(selectQuery.Conditionals, working)
			working = Conditional{}
		}
	}
	return nil
}

// orderByIndex returns the index of the BY keyword of the first ORDER BY
// clause, or -1 if there is none. A BY is accepted only when the closest
// preceding non-SPACE token is ORDER, so the BY of GROUP BY is skipped. The
// ORDER check is case-insensitive and does not depend on its token type.
func orderByIndex(p *parser) int {
	for i, t := range p.tokens {
		if t.Type != sqllexer.KEYWORD || !strings.EqualFold(t.Value, "BY") {
			continue
		}
		for j := i - 1; j >= 0; j-- {
			if p.tokens[j].Type == sqllexer.SPACE {
				continue
			}
			if strings.EqualFold(p.tokens[j].Value, "ORDER") {
				return i
			}
			break
		}
	}
	return -1
}

// parseOrderBys reads the ORDER BY clause into selectQuery.OrderBys.
//
// Each IDENT starts a new entry. DESC marks the current entry as descending,
// and ASC is accepted; both are matched case-insensitively. Any other KEYWORD
// ends the clause. A PUNCTUATION token closes the current entry, and ")" or
// ";" also ends the clause. A query without ORDER BY is not an error.
func parseOrderBys(p *parser) error {
	selectQuery := p.query.(*Select)

	byIndex := orderByIndex(p)
	if byIndex < 0 {
		return nil
	}

	var orderBys []OrderBy
	var working OrderBy
	// flush stores the current entry, if it has a key, and starts a fresh one.
	// This keeps IsDescend from leaking into the next entry.
	flush := func() {
		if working.Key != "" {
			orderBys = append(orderBys, working)
		}
		working = OrderBy{}
	}

scan:
	for _, t := range p.tokens[byIndex+1:] {
		switch t.Type {
		case sqllexer.IDENT:
			flush()
			working.Key = t.Value
		case sqllexer.KEYWORD:
			switch {
			case strings.EqualFold(t.Value, "DESC"):
				working.IsDescend = true
			case strings.EqualFold(t.Value, "ASC"):
				// Ascending is the zero value; nothing to record.
			default:
				break scan
			}
		case sqllexer.PUNCTUATION:
			flush()
			if t.Value == ")" || t.Value == ";" {
				// Presumably the end of an enclosing subquery or of the statement.
				break scan
			}
		}
	}
	flush()

	selectQuery.OrderBys = orderBys
	return nil
}

// joinTokenRanges splits the statement at each JOIN command token. Each range
// starts at a JOIN token and ends just before the next JOIN. The last range
// ends just before the final token of the statement.
func joinTokenRanges(p *parser) [][]Token {
	joins := p.findAllTokens(func(t Token) bool {
		return t.Type == sqllexer.COMMAND && strings.EqualFold(t.Value, "JOIN")
	})
	ranges := make([][]Token, 0, len(joins))
	for i, join := range joins {
		end := len(p.tokens) - 1
		if i+1 < len(joins) {
			end = joins[i+1].Index
		}
		ranges = append(ranges, p.tokens[join.Index:end])
	}
	return ranges
}

// parseJoins is a placeholder for JOIN parsing. It currently records nothing
// on the query and always returns nil.
func parseJoins(p *parser) error {
	// TODO(review): store join information on the *Select once the query model
	// has a field for it. joinTokenRanges(p) yields the tokens of each JOIN
	// clause to build on.
	return nil
}