Skip to content

Commit 902e93f

Browse files
committed
ISSUE #627: convert typed-nils to untyped-nils on error
1 parent 0e7685c commit 902e93f

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,19 @@ func (cl *Client) hasExtension(ext *sshfx.ExtensionPair) bool {
892892
return cl.exts[ext.Name] == ext.Data
893893
}
894894

895+
// StatVFS retrieves VFS statistics from a remote host.
896+
//
897+
// It implements the [email protected] SSH_FXP_EXTENDED feature from
898+
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL
899+
func (cl *Client) StatVFS(path string) (*openssh.StatVFSExtendedReplyPacket, error) {
900+
resp, err := getPacket[*openssh.StatVFSExtendedReplyPacket](context.Background(), nil, cl,
901+
&openssh.StatVFSExtendedPacket{
902+
Path: path,
903+
},
904+
)
905+
return valOrPathError("statvfs", path, resp, err)
906+
}
907+
895908
// Link creates newname as a hard link to oldname file.
896909
//
897910
// If the server did not announce support for the "[email protected]" extension,

localfs/localfs_integration_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,30 @@ func TestReadFrom(t *testing.T) {
751751
}
752752
}
753753

754+
func TestStatVFS(t *testing.T) {
755+
if !*testServerImpl {
756+
t.Skip("not testing against localfs server implementation")
757+
}
758+
759+
if _, ok := any(handler).(sftp.StatVFSServerHandler); !ok {
760+
t.Skip("handler does not implement statvfs")
761+
}
762+
763+
dir := t.TempDir()
764+
765+
targetNotExist := filepath.Join(dir, "statvfs-does-not-exist")
766+
767+
_, err := cl.StatVFS(toRemotePath(targetNotExist))
768+
if !errors.Is(err, fs.ErrNotExist) {
769+
t.Fatalf("unexpected error, got %v, should be fs.ErrNotFound", err)
770+
}
771+
772+
_, err = cl.StatVFS(toRemotePath(dir))
773+
if err != nil {
774+
t.Fatal(err)
775+
}
776+
}
777+
754778
var benchBuf []byte
755779

756780
func benchHelperWriteTo(b *testing.B, length int) {

server.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,12 @@ func Hijack[REQ sshfx.Packet](srv *Server, fn func(context.Context, REQ) error)
344344
// This is really only useful for supporting newer versions of the SFTP standard.
345345
func HijackWithResponse[REQ, RESP sshfx.Packet](srv *Server, fn func(context.Context, REQ) (RESP, error)) error {
346346
wrap := wrapHandler(func(ctx context.Context, req sshfx.Packet) (sshfx.Packet, error) {
347-
return fn(ctx, req.(REQ))
347+
resp, err := fn(ctx, req.(REQ))
348+
if err != nil {
349+
// We have to convert maybe typed-zero to untyped-nil.
350+
return nil, err
351+
}
352+
return resp, nil
348353
})
349354

350355
var pkt REQ
@@ -515,6 +520,7 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh
515520

516521
if len(srv.hijacks) > 0 {
517522
if fn := srv.hijacks[req.Type()]; fn != nil {
523+
// Hijack takes care of wrapping the getter into an untyped-nil on error.
518524
return get(srv, req, fn)
519525
}
520526
}
@@ -595,7 +601,13 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh
595601

596602
case *openssh.StatVFSExtendedPacket:
597603
if statvfser, ok := srv.Handler.(StatVFSServerHandler); ok {
598-
return get(srv, req, statvfser.StatVFS)
604+
resp, err := get(srv, req, statvfser.StatVFS)
605+
if err != nil {
606+
// We have to convert typed-nil to untyped-nil.
607+
return nil, err
608+
}
609+
610+
return resp, nil
599611
}
600612

601613
case interface{ GetHandle() string }:
@@ -618,7 +630,13 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh
618630
Path: file.Name(),
619631
}
620632

621-
return get(srv, req, statvfser.StatVFS)
633+
resp, err := get(srv, req, statvfser.StatVFS)
634+
if err != nil {
635+
// We have to convert typed-nil to untyped-nil.
636+
return nil, err
637+
}
638+
639+
return resp, nil
622640
}
623641
}
624642
}
@@ -701,7 +719,7 @@ func (srv *Server) handle(req sshfx.Packet, hint []byte, maxDataLen uint32) (ssh
701719
}
702720

703721
hint = slices.Grow(hint[:0], int(req.Length))[:req.Length]
704-
722+
705723
n, err := file.ReadAt(hint, int64(req.Offset))
706724
if err != nil {
707725
// We cannot return results AND a status like SSH_FX_EOF,

0 commit comments

Comments
 (0)