Skip to content

Cleanup | SNI Native Wrapper #3156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private SNIErrorDetails GetSniErrorDetails()
}
else
{
SniNativeWrapper.SNIGetLastError(out SniError sniError);
SniNativeWrapper.SniGetLastError(out SniError sniError);
details.sniErrorNumber = sniError.sniError;
details.errorMessage = sniError.errorMessage;
details.nativeError = sniError.nativeError;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ internal abstract void CreatePhysicalSNIHandle(

internal abstract void ReleasePacket(PacketHandle syncReadPacket);

protected abstract uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize);
protected abstract uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize);

internal abstract PacketHandle GetResetWritePacket(int dataSize);

Expand Down Expand Up @@ -401,7 +401,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)

private uint GetSniPacket(PacketHandle packet, ref uint dataSize)
{
return SNIPacketGetData(packet, _inBuff, ref dataSize);
return SniPacketGetData(packet, _inBuff, ref dataSize);
}

private void ChangeNetworkPacketTimeout(int dueTime, int period)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
/// <param name="inBuff">Destination byte array where data packets are copied to</param>
/// <param name="dataSize">Length of data packets</param>
/// <returns>SNI error status</returns>
protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize)
protected override uint SniPacketGetData(PacketHandle packet, byte[] inBuff, ref uint dataSize)
{
int dataSizeInt = 0;
packet.ManagedPacket.GetData(inBuff, ref dataSizeInt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache
result = SniNativeWrapper.SniGetConnectionPort(Handle, ref portFromSNI);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort");

result = SniNativeWrapper.SniGetConnectionIPString(Handle, ref IPStringFromSNI);
result = SniNativeWrapper.SniGetConnectionIpString(Handle, ref IPStringFromSNI);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString");

pendingDNSInfo = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString());
Expand Down Expand Up @@ -181,10 +181,10 @@ internal override void CreatePhysicalSNIHandle(
spns = new[] { serverSPN.TrimEnd() };
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
protected override uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
{
Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
return SniNativeWrapper.SNIPacketGetData(packet.NativePointer, _inBuff, ref dataSize);
return SniNativeWrapper.SniPacketGetData(packet.NativePointer, _inBuff, ref dataSize);
}

protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource<object> source)
Expand Down Expand Up @@ -264,7 +264,7 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint
throw ADP.ClosedConnectionError();
}
IntPtr readPacketPtr = IntPtr.Zero;
error = SniNativeWrapper.SNIReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining());
error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining());
return PacketHandle.FromNativePointer(readPacketPtr);
}

Expand All @@ -281,20 +281,20 @@ internal override bool IsPacketEmpty(PacketHandle readPacket)
internal override void ReleasePacket(PacketHandle syncReadPacket)
{
Debug.Assert(syncReadPacket.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
SniNativeWrapper.SNIPacketRelease(syncReadPacket.NativePointer);
SniNativeWrapper.SniPacketRelease(syncReadPacket.NativePointer);
}

internal override uint CheckConnection()
{
SNIHandle handle = Handle;
return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SNICheckConnection(handle);
return handle == null ? TdsEnums.SNI_SUCCESS : SniNativeWrapper.SniCheckConnection(handle);
}

internal override PacketHandle ReadAsync(SessionHandle handle, out uint error)
{
Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer");
IntPtr readPacketPtr = IntPtr.Zero;
error = SniNativeWrapper.SNIReadAsync(handle.NativeHandle, ref readPacketPtr);
error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacketPtr);
return PacketHandle.FromNativePointer(readPacketPtr);
}

Expand All @@ -310,7 +310,7 @@ internal override PacketHandle CreateAndSetAttentionPacket()
internal override uint WritePacket(PacketHandle packet, bool sync)
{
Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
return SniNativeWrapper.SNIWritePacket(Handle, packet.NativePacket, sync);
return SniNativeWrapper.SniWritePacket(Handle, packet.NativePacket, sync);
}

internal override PacketHandle AddPacketToPendingList(PacketHandle packetToAdd)
Expand Down Expand Up @@ -343,7 +343,7 @@ internal override PacketHandle GetResetWritePacket(int dataSize)
{
if (_sniPacket != null)
{
SniNativeWrapper.SNIPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
}
else
{
Expand Down Expand Up @@ -372,17 +372,17 @@ internal override void ClearAllWritePackets()
internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed)
{
Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket");
SniNativeWrapper.SNIPacketSetData(packet.NativePacket, buffer, bytesUsed);
SniNativeWrapper.SniPacketSetData(packet.NativePacket, buffer, bytesUsed);
}

internal override uint SniGetConnectionId(ref Guid clientConnectionId)
=> SniNativeWrapper.SniGetConnectionId(Handle, ref clientConnectionId);

internal override uint DisableSsl()
=> SniNativeWrapper.SNIRemoveProvider(Handle, Provider.SSL_PROV);
=> SniNativeWrapper.SniRemoveProvider(Handle, Provider.SSL_PROV);

internal override uint EnableMars(ref uint info)
=> SniNativeWrapper.SNIAddProvider(Handle, Provider.SMUX_PROV, ref info);
=> SniNativeWrapper.SniAddProvider(Handle, Provider.SMUX_PROV, ref info);

internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename)
{
Expand All @@ -392,15 +392,15 @@ internal override uint EnableSsl(ref uint info, bool tlsFirst, string serverCert
authInfo.serverCertFileName = serverCertificateFilename;

// Add SSL (Encryption) SNI provider.
return SniNativeWrapper.SNIAddProvider(Handle, Provider.SSL_PROV, ref authInfo);
return SniNativeWrapper.SniAddProvider(Handle, Provider.SSL_PROV, ref authInfo);
}

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SniNativeWrapper.SNISetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);
=> SniNativeWrapper.SniSetInfo(Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
uint returnValue = SniNativeWrapper.SNIWaitForSSLHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion);
uint returnValue = SniNativeWrapper.SniWaitForSslHandshakeToComplete(Handle, GetTimeoutRemaining(), out uint nativeProtocolVersion);
var nativeProtocol = (NativeProtocols)nativeProtocolVersion;

#pragma warning disable CA5398 // Avoid hardcoded SslProtocols values
Expand Down Expand Up @@ -469,7 +469,7 @@ public SNIPacket Take(SNIHandle sniHandle)
{
// Success - reset the packet
packet = _packets.Pop();
SniNativeWrapper.SNIPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI);
SniNativeWrapper.SniPacketReset(sniHandle, IoType.WRITE, packet, ConsumerNumber.SNI_Consumer_SNI);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ internal void RemoveEncryption()
uint error = 0;

// Remove SSL (Encryption) SNI provider since we only wanted to encrypt login.
error = SniNativeWrapper.SNIRemoveProvider(_physicalStateObj.Handle, Provider.SSL_PROV);
error = SniNativeWrapper.SniRemoveProvider(_physicalStateObj.Handle, Provider.SSL_PROV);
if (error != TdsEnums.SNI_SUCCESS)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
Expand All @@ -726,7 +726,7 @@ internal void EnableMars()
uint info = 0;

// Add SMUX (MARS) SNI provider.
error = SniNativeWrapper.SNIAddProvider(_pMarsPhysicalConObj.Handle, Provider.SMUX_PROV, ref info);
error = SniNativeWrapper.SniAddProvider(_pMarsPhysicalConObj.Handle, Provider.SMUX_PROV, ref info);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand All @@ -747,12 +747,12 @@ internal void EnableMars()
{
_pMarsPhysicalConObj.IncrementPendingCallbacks();

error = SniNativeWrapper.SNIReadAsync(_pMarsPhysicalConObj.Handle, ref temp);
error = SniNativeWrapper.SniReadAsync(_pMarsPhysicalConObj.Handle, ref temp);

if (temp != IntPtr.Zero)
{
// Be sure to release packet, otherwise it will be leaked by native.
SniNativeWrapper.SNIPacketRelease(temp);
SniNativeWrapper.SniPacketRelease(temp);
}
}
Debug.Assert(IntPtr.Zero == temp, "unexpected syncReadPacket without corresponding SNIPacketRelease");
Expand Down Expand Up @@ -1025,7 +1025,7 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ

Debug.Assert((_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0, "Client certificate authentication support has been removed");

error = SniNativeWrapper.SNIAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo);
error = SniNativeWrapper.SniAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand All @@ -1037,7 +1037,7 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ
// wait for SSL handshake to complete, so that the SSL context is fully negotiated before we try to use its
// Channel Bindings as part of the Windows Authentication context build (SSL handshake must complete
// before calling SNISecGenClientContext).
error = SniNativeWrapper.SNIWaitForSSLHandshakeToComplete(_physicalStateObj.Handle, _physicalStateObj.GetTimeoutRemaining(), out uint protocolVersion);
error = SniNativeWrapper.SniWaitForSslHandshakeToComplete(_physicalStateObj.Handle, _physicalStateObj.GetTimeoutRemaining(), out uint protocolVersion);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand Down Expand Up @@ -1591,7 +1591,7 @@ internal SqlError ProcessSNIError(TdsParserStateObject stateObj)
Debug.Assert(SniContext.Undefined != stateObj.DebugOnlyCopyOfSniContext || ((_fMARS) && ((_state == TdsParserState.Closed) || (_state == TdsParserState.Broken))), "SniContext must not be None");
#endif
SniError sniError = new SniError();
SniNativeWrapper.SNIGetLastError(out sniError);
SniNativeWrapper.SniGetLastError(out sniError);

if (sniError.sniError != 0)
{
Expand Down Expand Up @@ -2906,7 +2906,7 @@ private TdsOperationStatus TryProcessEnvChange(int tokenLength, TdsParserStateOb

// Update SNI ConsumerInfo value to be resulting packet size
uint unsignedPacketSize = (uint)packetSize;
uint bufferSizeResult = SniNativeWrapper.SNISetInfo(_physicalStateObj.Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);
uint bufferSizeResult = SniNativeWrapper.SniSetInfo(_physicalStateObj.Handle, QueryType.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

Debug.Assert(bufferSizeResult == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SNISetInfo");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ internal void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey)
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionPort");


result = SniNativeWrapper.SniGetConnectionIPString(_physicalStateObj.Handle, ref IPStringFromSNI);
result = SniNativeWrapper.SniGetConnectionIpString(_physicalStateObj.Handle, ref IPStringFromSNI);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionIPString");

_connHandler.pendingSQLDNSObject = new SQLDNSInfo(DNSCacheKey, null, null, portFromSNI.ToString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,20 +270,20 @@ internal PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error)
{
SNIHandle handle = Handle ?? throw ADP.ClosedConnectionError();
PacketHandle readPacket = default;
error = SniNativeWrapper.SNIReadSyncOverAsync(handle, ref readPacket, timeoutRemaining);
error = SniNativeWrapper.SniReadSyncOverAsync(handle, ref readPacket, timeoutRemaining);
return readPacket;
}

internal PacketHandle ReadAsync(SessionHandle handle, out uint error)
{
PacketHandle readPacket = default;
error = SniNativeWrapper.SNIReadAsync(handle.NativeHandle, ref readPacket);
error = SniNativeWrapper.SniReadAsync(handle.NativeHandle, ref readPacket);
return readPacket;
}

internal uint CheckConnection() => SniNativeWrapper.SNICheckConnection(Handle);
internal uint CheckConnection() => SniNativeWrapper.SniCheckConnection(Handle);

internal void ReleasePacket(PacketHandle syncReadPacket) => SniNativeWrapper.SNIPacketRelease(syncReadPacket);
internal void ReleasePacket(PacketHandle syncReadPacket) => SniNativeWrapper.SniPacketRelease(syncReadPacket);

[ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
internal int DecrementPendingCallbacks(bool release)
Expand Down Expand Up @@ -401,7 +401,7 @@ internal bool ValidateSNIConnection()
SNIHandle handle = Handle;
if (handle != null)
{
error = SniNativeWrapper.SNICheckConnection(handle);
error = SniNativeWrapper.SniCheckConnection(handle);
}
}
finally
Expand Down Expand Up @@ -518,7 +518,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error)

private uint GetSniPacket(PacketHandle packet, ref uint dataSize)
{
return SniNativeWrapper.SNIPacketGetData(packet, _inBuff, ref dataSize);
return SniNativeWrapper.SniPacketGetData(packet, _inBuff, ref dataSize);
}

private void ChangeNetworkPacketTimeout(int dueTime, int period)
Expand Down Expand Up @@ -1007,7 +1007,7 @@ private Task SNIWritePacket(SNIHandle handle, SNIPacket packet, out uint sniErro
}
finally
{
sniError = SniNativeWrapper.SNIWritePacket(handle, packet, sync);
sniError = SniNativeWrapper.SniWritePacket(handle, packet, sync);
}

if (sniError == TdsEnums.SNI_SUCCESS_IO_PENDING)
Expand Down Expand Up @@ -1119,7 +1119,7 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa
SNIPacket attnPacket = new SNIPacket(Handle);
_sniAsyncAttnPacket = attnPacket;

SniNativeWrapper.SNIPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null);
SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null);

RuntimeHelpers.PrepareConstrainedRegions();
try
Expand Down Expand Up @@ -1183,7 +1183,7 @@ private Task WriteSni(bool canAccumulate)
{
// Prepare packet, and write to packet.
SNIPacket packet = GetResetWritePacket();
SniNativeWrapper.SNIPacketSetData(packet, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer);
SniNativeWrapper.SniPacketSetData(packet, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer);

Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock");
Task task = SNIWritePacket(Handle, packet, out _, canAccumulate, callerHasConnectionLock: true);
Expand Down Expand Up @@ -1238,7 +1238,7 @@ internal SNIPacket GetResetWritePacket()
{
if (_sniPacket != null)
{
SniNativeWrapper.SNIPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
SniNativeWrapper.SniPacketReset(Handle, IoType.WRITE, _sniPacket, ConsumerNumber.SNI_Consumer_SNI);
}
else
{
Expand Down
Loading
Loading