From 2ad592e5d4f127a15641d1e8591f41db56dbc477 Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 11:43:57 +1300 Subject: [PATCH 1/7] Hold decoder state pointer internally to resolve passing an invalid reference to the Opus native methods. Pass frameSize to OpusDecoder.cs#Decode --- src/DSharpPlus.VoiceLink/Opus/OpusDecoder.cs | 63 +++++++++---------- .../Opus/OpusNativeMethods.Decoder.cs | 19 +++--- .../VoiceLinkConnection.cs | 2 +- 3 files changed, 40 insertions(+), 44 deletions(-) diff --git a/src/DSharpPlus.VoiceLink/Opus/OpusDecoder.cs b/src/DSharpPlus.VoiceLink/Opus/OpusDecoder.cs index 8851b9d..af5c838 100644 --- a/src/DSharpPlus.VoiceLink/Opus/OpusDecoder.cs +++ b/src/DSharpPlus.VoiceLink/Opus/OpusDecoder.cs @@ -2,8 +2,15 @@ namespace DSharpPlus.VoiceLink.Opus { - public struct OpusDecoder : IDisposable + public readonly struct OpusDecoder : IDisposable { + private readonly IntPtr _state; + + public OpusDecoder(IntPtr state) + { + _state = state; + } + /// public static int GetSize(int channels) => OpusNativeMethods.DecoderGetSize(channels); @@ -12,18 +19,16 @@ public struct OpusDecoder : IDisposable /// public static unsafe OpusDecoder Create(OpusSampleRate sampleRate, int channels) { - OpusDecoder* decoder = OpusNativeMethods.DecoderCreate(sampleRate, channels, out OpusErrorCode* errorCode); - return (errorCode != default && *errorCode != OpusErrorCode.Ok) ? throw new OpusException(*errorCode) : *decoder; + IntPtr state = OpusNativeMethods.DecoderCreate(sampleRate, channels, out OpusErrorCode* errorCode); + return (errorCode != default && *errorCode != OpusErrorCode.Ok) + ? throw new OpusException(*errorCode) + : new OpusDecoder(state); } - /// - public unsafe void Init(OpusSampleRate sampleRate, int channels) + /// + public void Init(OpusSampleRate sampleRate, int channels) { - OpusErrorCode errorCode; - fixed (OpusDecoder* pinned = &this) - { - errorCode = OpusNativeMethods.DecoderInit(pinned, sampleRate, channels); - } + OpusErrorCode errorCode = OpusNativeMethods.DecoderInit(_state, sampleRate, channels); if (errorCode != OpusErrorCode.Ok) { @@ -31,20 +36,19 @@ public unsafe void Init(OpusSampleRate sampleRate, int channels) } } - /// - public unsafe int Decode(ReadOnlySpan data, Span pcm, bool decodeFec) + /// + public unsafe int Decode(ReadOnlySpan data, Span pcm, int frameSize, bool decodeFec) { int decodedLength; - fixed (OpusDecoder* pinned = &this) fixed (byte* dataPointer = data) fixed (byte* pcmPointer = pcm) { decodedLength = OpusNativeMethods.Decode( - pinned, + _state, dataPointer, data.Length, pcmPointer, - OpusNativeMethods.PacketGetNbFrames(dataPointer, data.Length), + frameSize, decodeFec ? 1 : 0 ); } @@ -59,15 +63,14 @@ public unsafe int Decode(ReadOnlySpan data, Span pcm, bool decodeFec return decodedLength * sizeof(short) * 2; } - /// + /// public unsafe int DecodeFloat(ReadOnlySpan data, Span pcm, int frameSize, bool decodeFec) { int decodedLength; - fixed (OpusDecoder* pinned = &this) fixed (byte* dataPointer = data) fixed (byte* pcmPointer = pcm) { - decodedLength = OpusNativeMethods.DecodeFloat(pinned, dataPointer, data.Length, pcmPointer, frameSize, decodeFec ? 1 : 0); + decodedLength = OpusNativeMethods.DecodeFloat(_state, dataPointer, data.Length, pcmPointer, frameSize, decodeFec ? 1 : 0); } // Less than zero means an error occurred @@ -81,14 +84,10 @@ public unsafe int DecodeFloat(ReadOnlySpan data, Span pcm, int frame return decodedLength; } - /// - public unsafe void Control(OpusControlRequest control, out int value) + /// + public void Control(OpusControlRequest control, out int value) { - OpusErrorCode errorCode; - fixed (OpusDecoder* pinned = &this) - { - errorCode = OpusNativeMethods.DecoderControl(pinned, control, out value); - } + OpusErrorCode errorCode = OpusNativeMethods.DecoderControl(_state, control, out value); if (errorCode != OpusErrorCode.Ok) { @@ -96,13 +95,10 @@ public unsafe void Control(OpusControlRequest control, out int value) } } - /// - public unsafe void Destroy() + /// + public void Destroy() { - fixed (OpusDecoder* pinned = &this) - { - OpusNativeMethods.DecoderDestroy(pinned); - } + OpusNativeMethods.DecoderDestroy(_state); } /// @@ -111,14 +107,13 @@ public unsafe void Destroy() /// Invalid argument passed to the decoder. /// The compressed data passed is corrupted or of an unsupported type or an unknown error occured. /// The number of samples per channel of a packet. - /// + /// public unsafe int GetSampleCount(ReadOnlySpan data) { int sampleCount; - fixed (OpusDecoder* pinned = &this) fixed (byte* dataPointer = data) { - sampleCount = OpusNativeMethods.DecoderGetNbSamples(pinned, dataPointer, data.Length); + sampleCount = OpusNativeMethods.DecoderGetNbSamples(_state, dataPointer, data.Length); } // Less than zero means an error occurred diff --git a/src/DSharpPlus.VoiceLink/Opus/OpusNativeMethods.Decoder.cs b/src/DSharpPlus.VoiceLink/Opus/OpusNativeMethods.Decoder.cs index 85b7bbe..18bede3 100644 --- a/src/DSharpPlus.VoiceLink/Opus/OpusNativeMethods.Decoder.cs +++ b/src/DSharpPlus.VoiceLink/Opus/OpusNativeMethods.Decoder.cs @@ -1,3 +1,4 @@ +using System; using System.Runtime.InteropServices; namespace DSharpPlus.VoiceLink.Opus @@ -20,7 +21,7 @@ internal static partial class OpusNativeMethods /// Number of channels (1 or 2) to decode. /// or other error codes. [LibraryImport("opus", EntryPoint = "opus_decoder_create")] - public static unsafe partial OpusDecoder* DecoderCreate(OpusSampleRate sampleRate, int channels, out OpusErrorCode* error); + public static unsafe partial IntPtr DecoderCreate(OpusSampleRate sampleRate, int channels, out OpusErrorCode* error); /// /// Initializes a previously allocated decoder state. The state must be at least the size returned by . This is intended for applications which use their own allocator instead of malloc. @@ -30,7 +31,7 @@ internal static partial class OpusNativeMethods /// Number of channels (1 or 2) to decode. /// or other error codes. [LibraryImport("opus", EntryPoint = "opus_decoder_init")] - public static unsafe partial OpusErrorCode DecoderInit(OpusDecoder* decoder, OpusSampleRate sampleRate, int channels); + public static unsafe partial OpusErrorCode DecoderInit(IntPtr decoder, OpusSampleRate sampleRate, int channels); /// /// Decode an Opus packet. @@ -43,11 +44,11 @@ internal static partial class OpusNativeMethods /// Flag (0 or 1) to request that any in-band forward error correction data be decoded. If no such data is available, the frame is decoded as if it were lost. /// Number of decoded samples or an [LibraryImport("opus", EntryPoint = "opus_decode")] - public static unsafe partial int Decode(OpusDecoder* decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec); + public static unsafe partial int Decode(IntPtr decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec); - /// + /// [LibraryImport("opus", EntryPoint = "opus_decode_float")] - public static unsafe partial int DecodeFloat(OpusDecoder* decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec); + public static unsafe partial int DecodeFloat(IntPtr decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec); /// /// Perform a CTL function on an Opus decoder. @@ -56,14 +57,14 @@ internal static partial class OpusNativeMethods /// Decoder state. /// This and all remaining parameters should be replaced by one of the convenience macros in Generic CTLs or Decoder related CTLs. [LibraryImport("opus", EntryPoint = "opus_decoder_ctl")] - public static unsafe partial OpusErrorCode DecoderControl(OpusDecoder* decoder, OpusControlRequest request, out int value); + public static unsafe partial OpusErrorCode DecoderControl(IntPtr decoder, OpusControlRequest request, out int value); /// - /// Frees an OpusDecoder allocated by . + /// Frees an OpusDecoder allocated by . /// /// State to be freed. [LibraryImport("opus", EntryPoint = "opus_decoder_destroy")] - public static unsafe partial void DecoderDestroy(OpusDecoder* decoder); + public static unsafe partial void DecoderDestroy(IntPtr decoder); /// /// Gets the number of samples of an Opus packet. @@ -73,6 +74,6 @@ internal static partial class OpusNativeMethods /// Length of packet. /// Number of samples or or . [LibraryImport("opus", EntryPoint = "opus_decoder_get_nb_samples")] - public static unsafe partial int DecoderGetNbSamples(OpusDecoder* decoder, byte* data, int length); + public static unsafe partial int DecoderGetNbSamples(IntPtr decoder, byte* data, int length); } } diff --git a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs index 8db0baa..690c1d5 100644 --- a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs +++ b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs @@ -322,7 +322,7 @@ static void DecodeOpusAudio(ReadOnlySpan opusPacket, VoiceLinkUser voiceLi Span audioBuffer = voiceLinkUser._audioPipe.Writer.GetSpan(bufferSize); // Decode the Opus packet - voiceLinkUser._opusDecoder.Decode(opusPacket, audioBuffer, hasPacketLoss); + voiceLinkUser._opusDecoder.Decode(opusPacket, audioBuffer, frameSize, hasPacketLoss); // Write the audio to the pipe voiceLinkUser._audioPipe.Writer.Advance(bufferSize); From c3c38e93181ad36109aa5dc74976e4bd20d79bf9 Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 11:44:20 +1300 Subject: [PATCH 2/7] Ensure address string is null terminated --- src/DSharpPlus.VoiceLink/DiscordIpDiscoveryPacket.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DSharpPlus.VoiceLink/DiscordIpDiscoveryPacket.cs b/src/DSharpPlus.VoiceLink/DiscordIpDiscoveryPacket.cs index 6c8af2f..a2e42a7 100644 --- a/src/DSharpPlus.VoiceLink/DiscordIpDiscoveryPacket.cs +++ b/src/DSharpPlus.VoiceLink/DiscordIpDiscoveryPacket.cs @@ -41,6 +41,7 @@ public static implicit operator byte[](DiscordIpDiscoveryPacket ipDiscovery) BinaryPrimitives.WriteUInt16BigEndian(dataSpan[2..4], ipDiscovery.Length); BinaryPrimitives.WriteUInt32BigEndian(dataSpan[4..8], ipDiscovery.Ssrc); Encoding.UTF8.TryGetBytes(ipDiscovery.Address, dataSpan[8..72], out _); + dataSpan[71] = 0; // Need to null-terminate the IP string BinaryPrimitives.WriteUInt16BigEndian(dataSpan[72..74], ipDiscovery.Port); return data; From aec6dcbc654ebd08885ce2ce7ba1057671a1d1c9 Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 11:50:37 +1300 Subject: [PATCH 3/7] Handle RTP header extensions --- src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs | 6 ++++++ src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs b/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs index 6bd696a..39d9806 100644 --- a/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs +++ b/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs @@ -75,5 +75,11 @@ public static RtpHeader DecodeHeader(ReadOnlySpan source) Ssrc = BinaryPrimitives.ReadUInt32BigEndian(source[8..12]) }; } + + public static ushort GetHeaderExtensionLength(ReadOnlySpan rtpPayload) + { + // offset by two to ignore the profile marker + return BinaryPrimitives.ReadUInt16BigEndian(rtpPayload[2..]); + } } } diff --git a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs index 690c1d5..209938a 100644 --- a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs +++ b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs @@ -277,7 +277,7 @@ private async Task ReceiveAudioLoopAsync() continue; } - if (rtpHeader.HasMarker || rtpHeader.HasExtension) + if (rtpHeader.HasMarker) { // All clients send a marker bit when they first connect. For now we're just going to ignore this. continue; @@ -301,6 +301,15 @@ private async Task ReceiveAudioLoopAsync() continue; } + // Strip any RTP header extensions. See https://www.rfc-editor.org/rfc/rfc3550#section-5.3.1 + // Discord currently uses a generic profile marker of [0xbe, 0xde], see + // https://www.rfc-editor.org/rfc/rfc8285#section-4.2 + if (rtpHeader.HasExtension) + { + ushort extensionLength = RtpUtilities.GetHeaderExtensionLength(decryptedAudio.Span); + decryptedAudio = decryptedAudio[(4 + 4 * extensionLength)..]; + } + // TODO: Handle FEC (Forward Error Correction) aka packet loss. // * https://tools.ietf.org/html/rfc5109 bool hasDataLoss = voiceLinkUser.UpdateSequence(rtpHeader.Sequence); From 90db16fe54305f8be5990bd23a4e8219e9454bc9 Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 11:54:46 +1300 Subject: [PATCH 4/7] Only operate on voice data, not the entirety of the rented buffer --- src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs index 209938a..316e5a4 100644 --- a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs +++ b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs @@ -294,10 +294,14 @@ private async Task ReceiveAudioLoopAsync() } // Decrypt the audio - byte[] decryptedAudio = ArrayPool.Shared.Rent(_voiceEncrypter.GetDecryptedSize(udpReceiveResult.Buffer.Length)); - if (!_voiceEncrypter.TryDecryptOpusPacket(voiceLinkUser, udpReceiveResult.Buffer, _secretKey, decryptedAudio.AsSpan())) + int decryptedBufferSize = _voiceEncrypter.GetDecryptedSize(udpReceiveResult.Buffer.Length); + byte[] decryptedAudioArr = ArrayPool.Shared.Rent(decryptedBufferSize); + Memory decryptedAudio = decryptedAudioArr.AsMemory(0, decryptedBufferSize); + + if (!_voiceEncrypter.TryDecryptOpusPacket(voiceLinkUser, udpReceiveResult.Buffer, _secretKey, decryptedAudio.Span)) { _logger.LogWarning("Connection {GuildId}: Failed to decrypt audio from {Ssrc}, skipping.", Guild.Id, rtpHeader.Ssrc); + ArrayPool.Shared.Return(decryptedAudioArr); continue; } From 19a3270ec91aec8af26d1936d1d48e6775563ed3 Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 11:55:02 +1300 Subject: [PATCH 5/7] Account for PCM unit size when allocating PCM buffer --- src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs index 316e5a4..01ebf0e 100644 --- a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs +++ b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs @@ -329,7 +329,7 @@ static void DecodeOpusAudio(ReadOnlySpan opusPacket, VoiceLinkUser voiceLi const int sampleRate = 48000; // 48 kHz const double frameDuration = 0.020; // 20 milliseconds const int frameSize = (int)(sampleRate * frameDuration); // 960 samples - const int bufferSize = frameSize * 2; // Stereo audio + const int bufferSize = frameSize * 2 * sizeof(short); // Stereo audio + opus PCM units are 16 bits // Allocate the buffer for the PCM data Span audioBuffer = voiceLinkUser._audioPipe.Writer.GetSpan(bufferSize); From 3a585fc9a1dc071e3936c98ecd3b8ac0fdb7c3e3 Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 11:57:09 +1300 Subject: [PATCH 6/7] Improved error handling --- .../VoiceLinkConnection.cs | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs index 01ebf0e..8c7c227 100644 --- a/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs +++ b/src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs @@ -204,6 +204,12 @@ await _webSocket.SendAsync(new DiscordVoiceResumingCommand() Token = _voiceToken! }, _cancellationTokenSource.Token); } + catch (Exception ex) + { + // We need to catch and log all errors here because this task method is not watched + _logger.LogError(ex, "Unexpected failure on voice websocket"); + throw; + } } } @@ -319,8 +325,18 @@ private async Task ReceiveAudioLoopAsync() bool hasDataLoss = voiceLinkUser.UpdateSequence(rtpHeader.Sequence); // Decode the audio - DecodeOpusAudio(decryptedAudio, voiceLinkUser, hasDataLoss); - ArrayPool.Shared.Return(decryptedAudio); + try + { + DecodeOpusAudio(decryptedAudio.Span, voiceLinkUser, hasDataLoss); + } + catch (Exception ex) + { + // TODO: Should this be a reason to terminate the connection? + // definitely should if a few in a row fail, at the very least + _logger.LogError(ex, "Connection {GuildId}: Failed to decode opus audio from {Ssrc}, skipping", Guild.Id, rtpHeader.Ssrc); + } + + ArrayPool.Shared.Return(decryptedAudioArr); await voiceLinkUser._audioPipe.Writer.FlushAsync(_cancellationTokenSource.Token); static void DecodeOpusAudio(ReadOnlySpan opusPacket, VoiceLinkUser voiceLinkUser, bool hasPacketLoss = false) From dade935e74ccb41df236f41c2376782596032fda Mon Sep 17 00:00:00 2001 From: Carl Date: Sun, 14 Jan 2024 12:08:38 +1300 Subject: [PATCH 7/7] XML doc for RtpUtilities.cs#GetHeaderExtensionLength --- src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs b/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs index 39d9806..1bad820 100644 --- a/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs +++ b/src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs @@ -76,6 +76,12 @@ public static RtpHeader DecodeHeader(ReadOnlySpan source) }; } + /// + /// Gets the length in bytes of an RTP header extension. The extension will prefix the RTP payload. + /// Use to determined whether an RTP packet includes an extension. + /// + /// The RTP payload that is prefixed by a header extension. + /// The byte length of the extension. public static ushort GetHeaderExtensionLength(ReadOnlySpan rtpPayload) { // offset by two to ignore the profile marker