Skip to content

Fix panic in merge join when using custom Indexes that don't allow range lookups. #1985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ func TestTableFunctions(t *testing.T) {
harness.Setup(setup.MydbData)

databaseProvider := harness.NewDatabaseProvider()
testDatabaseProvider := NewTestProvider(&databaseProvider, SimpleTableFunction{}, memory.IntSequenceTable{})
testDatabaseProvider := NewTestProvider(&databaseProvider, SimpleTableFunction{}, memory.IntSequenceTable{}, memory.PointLookupTable{})

engine := enginetest.NewEngineWithProvider(t, harness, testDatabaseProvider)
engine.EngineAnalyzer().ExecBuilder = rowexec.DefaultBuilder
Expand Down
24 changes: 24 additions & 0 deletions enginetest/queries/table_func_scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,32 @@ var TableFunctionScriptTests = []ScriptTest{
Expected: []sql.Row{{0}, {1}, {2}, {3}, {4}},
},
{
Name: "sequence_table allows point lookups",
Query: "select * from sequence_table('x', 5) where x = 2",
Expected: []sql.Row{{2}},
ExpectedIndexes: []string{"x"},
},
{
Name: "sequence_table allows range lookups",
Query: "select * from sequence_table('x', 5) where x >= 1 and x <= 3",
Expected: []sql.Row{{1}, {2}, {3}},
ExpectedIndexes: []string{"x"},
},
{
Name: "basic behavior of point_lookup_table",
Query: "select seq.x from point_lookup_table('x', 5) seq",
Expected: []sql.Row{{0}, {1}, {2}, {3}, {4}},
},
{
Name: "point_lookup_table allows point lookups",
Query: "select * from point_lookup_table('x', 5) where x = 2",
Expected: []sql.Row{{2}},
ExpectedIndexes: []string{"x"},
},
{
Name: "point_lookup_table disallows range lookups",
Query: "select * from point_lookup_table('x', 5) where x >= 1 and x <= 3",
Expected: []sql.Row{{1}, {2}, {3}},
ExpectedIndexes: []string{},
},
}
104 changes: 104 additions & 0 deletions memory/point_lookup_table.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package memory

import (
"fmt"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
)

var _ sql.TableFunction = PointLookupTable{}
var _ sql.CollationCoercible = PointLookupTable{}
var _ sql.ExecSourceRel = PointLookupTable{}
var _ sql.IndexAddressable = PointLookupTable{}
var _ sql.IndexedTable = PointLookupTable{}
var _ sql.TableNode = PointLookupTable{}

// PointLookupTable is a table whose indexes only support point lookups but not range scans.
// It's used for testing optimizations on indexes.
type PointLookupTable struct {
IntSequenceTable
}

func (s PointLookupTable) UnderlyingTable() sql.Table {
return s
}

func (s PointLookupTable) NewInstance(ctx *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
node, err := s.IntSequenceTable.NewInstance(ctx, db, args)
return PointLookupTable{node.(IntSequenceTable)}, err
}

func (s PointLookupTable) String() string {
return fmt.Sprintf("pointLookup")
}

func (s PointLookupTable) DebugString() string {
return "pointLookup"
}

func (s PointLookupTable) Name() string {
return "point_lookup_table"
}

func (s PointLookupTable) Description() string {
return "point_lookup_table"
}

var _ sql.Partition = (*sequencePartition)(nil)

func (s PointLookupTable) GetIndexes(ctx *sql.Context) (indexes []sql.Index, err error) {
return []sql.Index{
pointLookupIndex{&Index{
DB: "",
DriverName: "",
Tbl: nil,
TableName: s.Name(),
Exprs: []sql.Expression{
expression.NewGetFieldWithTable(0, types.Int64, s.Name(), s.name, false),
},
Name: s.name,
Unique: true,
Spatial: false,
Fulltext: false,
CommentStr: "",
PrefixLens: nil,
fulltextInfo: fulltextInfo{},
}},
}, nil
}

type pointLookupIndex struct {
sql.Index
}

func (i pointLookupIndex) CanSupport(ranges ...sql.Range) bool {
for _, r := range ranges {
if len(r) != 1 {
return false
}
below, ok := r[0].LowerBound.(sql.Below)
if !ok {
return false
}
belowKey, _, err := types.Int64.Convert(below.Key)
if err != nil {
return false
}

above, ok := r[0].UpperBound.(sql.Above)
if !ok {
return false
}
aboveKey, _, err := types.Int64.Convert(above.Key)
if err != nil {
return false
}

if belowKey != aboveKey {
return false
}
}
return true
}
6 changes: 6 additions & 0 deletions sql/analyzer/indexed_joins.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,13 @@ func addMergeJoins(m *memo.Memo) error {
leftColId := getOnlyColumnId(matchedEqFilters[0].filter.Left)
rightColId := getOnlyColumnId(matchedEqFilters[0].filter.Right)
lIndexScan := makeIndexScan(lIndex, leftColId, lFilters)
if lIndexScan == nil {
continue
}
rIndexScan := makeIndexScan(rIndex, rightColId, rFilters)
if rIndexScan == nil {
continue
}
m.MemoizeMergeJoin(e.Group(), join.Left, join.Right, lIndexScan, rIndexScan, jb.Op.AsMerge(), newFilters, false)
}
}
Expand Down