package q import ( "fmt" "strings" "github.com/DataDog/go-sqllexer" ) // Find first token in array that matches condition func FindTokenInArray(tokens []Token, condition func(Token) bool) (*Token, int) { for i, element := range tokens { if condition(element) { return &element, i } } return nil, -1 } type Token = sqllexer.Token type parser struct { sql string tokens []Token index int query Query lookBehindBuffer []Token } func (p *parser) lookAhead(count int) Token { return p.tokens[p.index+count] } func (p *parser) lookBehind(count int) Token { return p.tokens[p.index-count] } // Returns pointer of the first found Token and its index in parser.Tokens // Returns -1 as the index if not found func (p *parser) findToken(condition func(Token) bool) (*Token, int) { for i, element := range p.tokens { if condition(element) { return &element, i } } return nil, -1 } type FoundToken struct { Token Token Index int } // Returns all tokens that match the condition, along with their indices func (p *parser) findAllTokens(condition func(Token) bool) []FoundToken { matches := make([]FoundToken, 0) for i, token := range p.tokens { if condition(token) { matches = append(matches, FoundToken{ Token: token, Index: i, }) } } return matches } func filter[T any](items []T, fn func(item T) bool) []T { filteredItems := []T{} for _, value := range items { if fn(value) { filteredItems = append(filteredItems, value) } } return filteredItems } func unshiftBuffer(buf *[10]sqllexer.Token, value sqllexer.Token) { for i := 9; i >= 1; i-- { buf[i] = buf[i-1] } buf[0] = value } func Parse(sql string) (Query, error) { lexer := sqllexer.New(sql) var tokens []Token for { t := lexer.Scan() tokens = append(tokens, *t) if IsTokenEndOfStatement(t) { break } } p := parser{ sql: sql, tokens: tokens, } queryType, err := findQueryType(&p) if err != nil { return p.query, err } switch strings.ToUpper(queryType) { case "SELECT": p.query = &Select{ Type: SELECT, } parseSelectStatement(&p) default: return p.query, fmt.Errorf("No process defined for determined queryType: %s", queryType) } return p.query, nil } func findQueryType(p *parser) (string, error) { firstCommand, _ := p.findToken(func(t Token) bool { return t.Type == sqllexer.COMMAND }) if firstCommand == nil { return "", fmt.Errorf("Could not find query type") } return firstCommand.Value, nil } func parseSelectStatement(p *parser) error { selectQuery := p.query.(*Select) distinctErr := parseDistinct(p) if distinctErr != nil { return distinctErr } foundWildcard, _ := p.findToken(func(t Token) bool { return t.Type == sqllexer.WILDCARD }) selectQuery.IsWildcard = foundWildcard != nil if !selectQuery.IsWildcard { err := parseSelectColumns(p) if err != nil { fmt.Println(err.Error()) return err } } tableErr := parseSelectTable(p) if tableErr != nil { return tableErr } conditionalsErr := parseSelectConditionals(p) if conditionalsErr != nil { return conditionalsErr } ordersByErr := parseOrderBys(p) if ordersByErr != nil { return ordersByErr } parseJoins(p) return nil } func parseDistinct(p *parser) error { selectQuery := p.query.(*Select) foundDistinctKeyword, _ := p.findToken(func(t Token) bool { return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "DISTINCT" }) if foundDistinctKeyword != nil { selectQuery.IsDistinct = true } return nil } func parseSelectColumns(p *parser) error { selectQuery := p.query.(*Select) _, selectCommandIndex := p.findToken(func(t Token) bool { return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "SELECT" }) _, fromKeywordIndex := p.findToken(func(t Token) bool { return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM" }) if selectCommandIndex < 0 || fromKeywordIndex < 0 { return fmt.Errorf("Could not find range between SELECT and FROM") } lookBehindBuffer := [10]Token{} var workingColumn Column columns := make([]Column, 0) startRange := selectCommandIndex + 1 endRange := fromKeywordIndex - 1 for i := startRange; i <= endRange; i++ { token := p.tokens[i] if token.Type == sqllexer.FUNCTION { unshiftBuffer(&lookBehindBuffer, token) workingColumn.AggregateFunction = AggregateFunctionTypeByName(token.Value) continue } else if token.Type == sqllexer.PUNCTUATION && token.Value == "," { columns = append(columns, workingColumn) workingColumn = Column{} continue } else if token.Type == sqllexer.IDENT { unshiftBuffer(&lookBehindBuffer, token) if lookBehindBuffer[1].Type == sqllexer.ALIAS_INDICATOR { workingColumn.Alias = token.Value } else { workingColumn.Name = token.Value } continue } else if token.Type == sqllexer.ALIAS_INDICATOR { unshiftBuffer(&lookBehindBuffer, token) continue } else if i == endRange { if workingColumn.Name != "" { columns = append(columns, workingColumn) workingColumn = Column{} } } } selectQuery.Columns = columns return nil } func parseSelectTable(p *parser) error { selectQuery := p.query.(*Select) _, fromKeywordIndex := p.findToken(func(t Token) bool { return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "FROM" }) if fromKeywordIndex < 0 { return fmt.Errorf("Could not FROM keyword to look for table name") } var foundTable Table for i := fromKeywordIndex + 1; i < len(p.tokens); i++ { t := &p.tokens[i] if foundTable.Name == "" && t.Type == sqllexer.IDENT { foundTable.Name = p.tokens[i].Value continue } else if t.Type == sqllexer.IDENT { foundTable.Alias = p.tokens[i].Value break } else if t.Type == sqllexer.SPACE || t.Type == sqllexer.ALIAS_INDICATOR { continue } else { break } } if foundTable.Name == "" { return fmt.Errorf("Could not find table name") } selectQuery.Table = foundTable return nil } func parseSelectConditionals(p *parser) error { selectQuery := p.query.(*Select) _, whereKeywordIndex := p.findToken(func(t Token) bool { return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "WHERE" }) if whereKeywordIndex < 0 { return nil // fmt.Errorf("Could not find WHERE to look for conditionals") } var workingConditional Conditional for i := whereKeywordIndex + 1; i < len(p.tokens); i++ { t := &p.tokens[i] if t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) != "AND" && strings.ToUpper(t.Value) != "OR" && strings.ToUpper(t.Value) != "NOT" { break } if t.Type == sqllexer.IDENT { workingConditional.Key = t.Value } else if t.Type == sqllexer.OPERATOR { workingConditional.Operator = t.Value } else if t.Type == sqllexer.BOOLEAN || t.Type == sqllexer.NULL || t.Type == sqllexer.STRING || t.Type == sqllexer.NUMBER { workingConditional.Value = t.Value } else if t.Type == sqllexer.KEYWORD { if strings.ToUpper(t.Value) == "AND" || strings.ToUpper(t.Value) == "OR" { workingConditional.Extension = strings.ToUpper(t.Value) } } if workingConditional.Key != "" && workingConditional.Operator != "" && workingConditional.Value != "" { selectQuery.Conditionals = append(selectQuery.Conditionals, workingConditional) workingConditional = Conditional{} } } return nil } func parseOrderBys(p *parser) error { selectQuery := p.query.(*Select) _, byKeywordIndex := p.findToken(func(t Token) bool { return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "BY" }) if byKeywordIndex < 0 { return nil } var orderBys []OrderBy var workingOrderBy OrderBy for i := byKeywordIndex + 1; i < len(p.tokens); i++ { t := &p.tokens[i] if t.Type == sqllexer.SPACE { continue } else if t.Type == sqllexer.IDENT && workingOrderBy.Key == "" { workingOrderBy.Key = t.Value continue } else if t.Type == sqllexer.IDENT && workingOrderBy.Key != "" { orderBys = append(orderBys, workingOrderBy) workingOrderBy.Key = t.Value continue } else if t.Type == sqllexer.KEYWORD { if t.Value == "DESC" { workingOrderBy.IsDescend = true } else if t.Value != "ASC" { break } } else if t.Type == sqllexer.PUNCTUATION { orderBys = append(orderBys, workingOrderBy) workingOrderBy = OrderBy{} continue } } if workingOrderBy.Key != "" { orderBys = append(orderBys, workingOrderBy) } selectQuery.OrderBys = orderBys return nil } func parseJoins(p *parser) error { selectQuery := p.query.(*Select) foundJoinKeywords := p.findAllTokens(func(t Token) bool { return t.Type == sqllexer.COMMAND && strings.ToUpper(t.Value) == "JOIN" }) if len(foundJoinKeywords) <= 0 { return nil } type FoundJoinSubslices struct { Tokens []Token StartingIndexInGreaterStatement int } var joinTokenRanges []FoundJoinSubslices for i, foundJoin := range foundJoinKeywords { startRangeIndex := foundJoin.Index var endRangeIndex int if i == (len(foundJoinKeywords) - 1) { endRangeIndex = len(p.tokens) - 1 } else { endRangeIndex = foundJoinKeywords[i+1].Index } joinTokenRanges = append(joinTokenRanges, FoundJoinSubslices{ Tokens: p.tokens[startRangeIndex:endRangeIndex], StartingIndexInGreaterStatement: startRangeIndex, }) } for _, joinRange := range joinTokenRanges { var workingJoin Join workingJoin.MainTable = selectQuery.Table // check for the join type by looking backwards in the greater statement joinTypeSearchIndex := joinRange.StartingIndexInGreaterStatement - 1 for ; joinTypeSearchIndex >= 0; joinTypeSearchIndex-- { if p.tokens[joinTypeSearchIndex].Type == sqllexer.KEYWORD { switch strings.ToUpper(p.tokens[joinTypeSearchIndex].Value) { case "LEFT": workingJoin.Type = LEFT break case "RIGHT": workingJoin.Type = RIGHT break case "FULL": workingJoin.Type = FULL break case "SELF": workingJoin.Type = SELF break case "INNER": workingJoin.Type = INNER default: workingJoin.Type = INNER } break // Stop after finding first keyword } } // Find joined table name for i := 1; i < len(joinRange.Tokens); i++ { if joinRange.Tokens[i].Type == sqllexer.IDENT { workingJoin.JoiningTable.Name = joinRange.Tokens[i].Value break // Stop after finding first IDENT // TODO: make sure you dont have to check for aliases } } //var ons []Conditional var workingOn Conditional _, foundOnTokenIndex := FindTokenInArray(joinRange.Tokens, func(t Token) bool { return t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) == "ON" }) if foundOnTokenIndex < 0 { selectQuery.Joins = append(selectQuery.Joins, workingJoin) continue } for i := foundOnTokenIndex + 1; i < len(joinRange.Tokens); i++ { t := &joinRange.Tokens[i] if t.Type == sqllexer.KEYWORD && strings.ToUpper(t.Value) != "AND" && strings.ToUpper(t.Value) != "OR" && strings.ToUpper(t.Value) != "NOT" { break } if t.Type == sqllexer.IDENT { if workingOn.Key == "" { workingOn.Key = t.Value } else { workingOn.Value = t.Value } } else if t.Type == sqllexer.OPERATOR { workingOn.Operator = t.Value } else if t.Type == sqllexer.BOOLEAN || t.Type == sqllexer.NULL || t.Type == sqllexer.STRING || t.Type == sqllexer.NUMBER { workingOn.Value = t.Value } else if t.Type == sqllexer.KEYWORD { if strings.ToUpper(t.Value) == "AND" || strings.ToUpper(t.Value) == "OR" { workingOn.Extension = strings.ToUpper(t.Value) } } if workingOn.Key != "" && workingOn.Operator != "" && workingOn.Value != "" { workingJoin.Ons = append(workingJoin.Ons, workingOn) workingOn = Conditional{} } } selectQuery.Joins = append(selectQuery.Joins, workingJoin) workingJoin = Join{} } return nil }