Skip to content

Commit 806af86

Browse files
Jan Waśnineinchnick
authored andcommitted
Don't fail when decoding empty strings as floats in server responses
1 parent d64a6b6 commit 806af86

File tree

2 files changed

+80
-23
lines changed

2 files changed

+80
-23
lines changed

trino/trino.go

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -713,27 +713,27 @@ type stmtResponse struct {
713713
}
714714

715715
type stmtStats struct {
716-
State string `json:"state"`
717-
Scheduled bool `json:"scheduled"`
718-
Nodes int `json:"nodes"`
719-
TotalSplits int `json:"totalSplits"`
720-
QueuesSplits int `json:"queuedSplits"`
721-
RunningSplits int `json:"runningSplits"`
722-
CompletedSplits int `json:"completedSplits"`
723-
UserTimeMillis int `json:"userTimeMillis"`
724-
CPUTimeMillis int64 `json:"cpuTimeMillis"`
725-
WallTimeMillis int64 `json:"wallTimeMillis"`
726-
QueuedTimeMillis int64 `json:"queuedTimeMillis"`
727-
ElapsedTimeMillis int64 `json:"elapsedTimeMillis"`
728-
ProcessedRows int64 `json:"processedRows"`
729-
ProcessedBytes int64 `json:"processedBytes"`
730-
PhysicalInputBytes int64 `json:"physicalInputBytes"`
731-
PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"`
732-
PeakMemoryBytes int64 `json:"peakMemoryBytes"`
733-
SpilledBytes int64 `json:"spilledBytes"`
734-
RootStage stmtStage `json:"rootStage"`
735-
ProgressPercentage float32 `json:"progressPercentage"`
736-
RunningPercentage float32 `json:"runningPercentage"`
716+
State string `json:"state"`
717+
Scheduled bool `json:"scheduled"`
718+
Nodes int `json:"nodes"`
719+
TotalSplits int `json:"totalSplits"`
720+
QueuesSplits int `json:"queuedSplits"`
721+
RunningSplits int `json:"runningSplits"`
722+
CompletedSplits int `json:"completedSplits"`
723+
UserTimeMillis int `json:"userTimeMillis"`
724+
CPUTimeMillis int64 `json:"cpuTimeMillis"`
725+
WallTimeMillis int64 `json:"wallTimeMillis"`
726+
QueuedTimeMillis int64 `json:"queuedTimeMillis"`
727+
ElapsedTimeMillis int64 `json:"elapsedTimeMillis"`
728+
ProcessedRows int64 `json:"processedRows"`
729+
ProcessedBytes int64 `json:"processedBytes"`
730+
PhysicalInputBytes int64 `json:"physicalInputBytes"`
731+
PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"`
732+
PeakMemoryBytes int64 `json:"peakMemoryBytes"`
733+
SpilledBytes int64 `json:"spilledBytes"`
734+
RootStage stmtStage `json:"rootStage"`
735+
ProgressPercentage jsonFloat64 `json:"progressPercentage"`
736+
RunningPercentage jsonFloat64 `json:"runningPercentage"`
737737
}
738738

739739
type ErrTrino struct {
@@ -792,6 +792,28 @@ type stmtStage struct {
792792
SubStages []stmtStage `json:"subStages"`
793793
}
794794

795+
type jsonFloat64 float64
796+
797+
func (f *jsonFloat64) UnmarshalJSON(data []byte) error {
798+
var v float64
799+
err := json.Unmarshal(data, &v)
800+
if err != nil {
801+
var jsonErr *json.UnmarshalTypeError
802+
if errors.As(err, &jsonErr) {
803+
if f != nil {
804+
*f = 0
805+
}
806+
return nil
807+
}
808+
return err
809+
}
810+
p := (*float64)(f)
811+
*p = v
812+
return nil
813+
}
814+
815+
var _ json.Unmarshaler = new(jsonFloat64)
816+
795817
func (st *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
796818
return nil, driver.ErrSkip
797819
}

trino/trino_test.go

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,34 @@ func TestRoundTripRetryQueryError(t *testing.T) {
250250
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
251251
}
252252

253+
func TestRoundTripBogusData(t *testing.T) {
254+
count := 0
255+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
256+
if count == 0 {
257+
count++
258+
w.WriteHeader(http.StatusServiceUnavailable)
259+
return
260+
}
261+
w.WriteHeader(http.StatusOK)
262+
// some invalid JSON
263+
w.Write([]byte(`{"stats": {"progressPercentage": ""}}`))
264+
}))
265+
266+
t.Cleanup(ts.Close)
267+
268+
db, err := sql.Open("trino", ts.URL)
269+
require.NoError(t, err)
270+
271+
t.Cleanup(func() {
272+
assert.NoError(t, db.Close())
273+
})
274+
275+
rows, err := db.Query("SELECT 1")
276+
require.NoError(t, err)
277+
assert.False(t, rows.Next())
278+
require.NoError(t, rows.Err())
279+
}
280+
253281
func TestRoundTripCancellation(t *testing.T) {
254282
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
255283
w.WriteHeader(http.StatusServiceUnavailable)
@@ -336,10 +364,12 @@ func TestQueryForUsername(t *testing.T) {
336364
}
337365

338366
type TestQueryProgressCallback struct {
339-
statusMap map[time.Time]string
367+
progressMap map[time.Time]float64
368+
statusMap map[time.Time]string
340369
}
341370

342371
func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) {
372+
qpc.progressMap[time.Now()] = float64(qpi.QueryStats.ProgressPercentage)
343373
qpc.statusMap[time.Now()] = qpi.QueryStats.State
344374
}
345375

@@ -387,9 +417,11 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
387417
assert.NoError(t, db.Close())
388418
})
389419

420+
progressMap := make(map[time.Time]float64)
390421
statusMap := make(map[time.Time]string)
391422
progressUpdater := &TestQueryProgressCallback{
392-
statusMap: statusMap,
423+
progressMap: progressMap,
424+
statusMap: statusMap,
393425
}
394426
progressUpdaterPeriod, err := time.ParseDuration("1ms")
395427
require.NoError(t, err)
@@ -416,6 +448,8 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
416448
}
417449

418450
// sort time in order to calculate interval
451+
assert.NotEmpty(t, progressMap)
452+
assert.NotEmpty(t, statusMap)
419453
var keys []time.Time
420454
for k := range statusMap {
421455
keys = append(keys, k)
@@ -428,6 +462,7 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
428462
if i > 0 {
429463
assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod)
430464
}
465+
assert.GreaterOrEqual(t, progressMap[k], 0.0)
431466
}
432467
}
433468

0 commit comments

Comments
 (0)