Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various updates to the decoding pipeline #3

Merged
merged 7 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/DSharpPlus.VoiceLink/DiscordIpDiscoveryPacket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public static implicit operator byte[](DiscordIpDiscoveryPacket ipDiscovery)
BinaryPrimitives.WriteUInt16BigEndian(dataSpan[2..4], ipDiscovery.Length);
BinaryPrimitives.WriteUInt32BigEndian(dataSpan[4..8], ipDiscovery.Ssrc);
Encoding.UTF8.TryGetBytes(ipDiscovery.Address, dataSpan[8..72], out _);
dataSpan[71] = 0; // Need to null-terminate the IP string
OoLunar marked this conversation as resolved.
Show resolved Hide resolved
BinaryPrimitives.WriteUInt16BigEndian(dataSpan[72..74], ipDiscovery.Port);

return data;
Expand Down
63 changes: 29 additions & 34 deletions src/DSharpPlus.VoiceLink/Opus/OpusDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@

namespace DSharpPlus.VoiceLink.Opus
{
public struct OpusDecoder : IDisposable
public readonly struct OpusDecoder : IDisposable
{
private readonly IntPtr _state;

public OpusDecoder(IntPtr state)
{
_state = state;
}

/// <inheritdoc cref="OpusNativeMethods.DecoderGetSize(int)"/>
public static int GetSize(int channels) => OpusNativeMethods.DecoderGetSize(channels);

Expand All @@ -12,39 +19,36 @@ public struct OpusDecoder : IDisposable
/// <inheritdoc cref="OpusNativeMethods.DecoderCreate(OpusSampleRate, int, out OpusErrorCode*)"/>
public static unsafe OpusDecoder Create(OpusSampleRate sampleRate, int channels)
{
OpusDecoder* decoder = OpusNativeMethods.DecoderCreate(sampleRate, channels, out OpusErrorCode* errorCode);
return (errorCode != default && *errorCode != OpusErrorCode.Ok) ? throw new OpusException(*errorCode) : *decoder;
IntPtr state = OpusNativeMethods.DecoderCreate(sampleRate, channels, out OpusErrorCode* errorCode);
return (errorCode != default && *errorCode != OpusErrorCode.Ok)
? throw new OpusException(*errorCode)
: new OpusDecoder(state);
}

/// <inheritdoc cref="OpusNativeMethods.DecoderInit(OpusDecoder*, OpusSampleRate, int)"/>
public unsafe void Init(OpusSampleRate sampleRate, int channels)
/// <inheritdoc cref="OpusNativeMethods.DecoderInit(IntPtr, OpusSampleRate, int)"/>
public void Init(OpusSampleRate sampleRate, int channels)
{
OpusErrorCode errorCode;
fixed (OpusDecoder* pinned = &this)
{
errorCode = OpusNativeMethods.DecoderInit(pinned, sampleRate, channels);
}
OpusErrorCode errorCode = OpusNativeMethods.DecoderInit(_state, sampleRate, channels);

if (errorCode != OpusErrorCode.Ok)
{
throw new OpusException(errorCode);
}
}

/// <inheritdoc cref="OpusNativeMethods.Decode(OpusDecoder*, byte*, int, byte*, int, int)"/>
public unsafe int Decode(ReadOnlySpan<byte> data, Span<byte> pcm, bool decodeFec)
/// <inheritdoc cref="OpusNativeMethods.Decode(IntPtr, byte*, int, byte*, int, int)"/>
public unsafe int Decode(ReadOnlySpan<byte> data, Span<byte> pcm, int frameSize, bool decodeFec)
{
int decodedLength;
fixed (OpusDecoder* pinned = &this)
fixed (byte* dataPointer = data)
fixed (byte* pcmPointer = pcm)
{
decodedLength = OpusNativeMethods.Decode(
pinned,
_state,
dataPointer,
data.Length,
pcmPointer,
OpusNativeMethods.PacketGetNbFrames(dataPointer, data.Length),
frameSize,
decodeFec ? 1 : 0
);
}
Expand All @@ -59,15 +63,14 @@ public unsafe int Decode(ReadOnlySpan<byte> data, Span<byte> pcm, bool decodeFec
return decodedLength * sizeof(short) * 2;
}

/// <inheritdoc cref="OpusNativeMethods.DecodeFloat(OpusDecoder*, byte*, int, float*, int, int)"/>
/// <inheritdoc cref="OpusNativeMethods.DecodeFloat(IntPtr, byte*, int, byte*, int, int)"/>
public unsafe int DecodeFloat(ReadOnlySpan<byte> data, Span<byte> pcm, int frameSize, bool decodeFec)
{
int decodedLength;
fixed (OpusDecoder* pinned = &this)
fixed (byte* dataPointer = data)
fixed (byte* pcmPointer = pcm)
{
decodedLength = OpusNativeMethods.DecodeFloat(pinned, dataPointer, data.Length, pcmPointer, frameSize, decodeFec ? 1 : 0);
decodedLength = OpusNativeMethods.DecodeFloat(_state, dataPointer, data.Length, pcmPointer, frameSize, decodeFec ? 1 : 0);
}

// Less than zero means an error occurred
Expand All @@ -81,28 +84,21 @@ public unsafe int DecodeFloat(ReadOnlySpan<byte> data, Span<byte> pcm, int frame
return decodedLength;
}

/// <inheritdoc cref="OpusNativeMethods.DecoderControl(OpusDecoder*, OpusControlRequest, int)"/>
public unsafe void Control(OpusControlRequest control, out int value)
/// <inheritdoc cref="OpusNativeMethods.DecoderControl(IntPtr, OpusControlRequest, out int)"/>
public void Control(OpusControlRequest control, out int value)
{
OpusErrorCode errorCode;
fixed (OpusDecoder* pinned = &this)
{
errorCode = OpusNativeMethods.DecoderControl(pinned, control, out value);
}
OpusErrorCode errorCode = OpusNativeMethods.DecoderControl(_state, control, out value);

if (errorCode != OpusErrorCode.Ok)
{
throw new OpusException(errorCode);
}
}

/// <inheritdoc cref="OpusNativeMethods.DecoderDestroy(OpusDecoder*)"/>
public unsafe void Destroy()
/// <inheritdoc cref="OpusNativeMethods.DecoderDestroy(IntPtr)"/>
public void Destroy()
{
fixed (OpusDecoder* pinned = &this)
{
OpusNativeMethods.DecoderDestroy(pinned);
}
OpusNativeMethods.DecoderDestroy(_state);
}

/// <summary>
Expand All @@ -111,14 +107,13 @@ public unsafe void Destroy()
/// <exception cref="ArgumentException">Invalid argument passed to the decoder.</exception>
/// <exception cref="InvalidOperationException">The compressed data passed is corrupted or of an unsupported type or an unknown error occured.</exception>
/// <returns>The number of samples per channel of a packet.</returns>
/// <inheritdoc cref="OpusNativeMethods.DecoderGetNbSamples(OpusDecoder*, byte*, int)"/>
/// <inheritdoc cref="OpusNativeMethods.DecoderGetNbSamples(IntPtr, byte*, int)"/>
public unsafe int GetSampleCount(ReadOnlySpan<byte> data)
{
int sampleCount;
fixed (OpusDecoder* pinned = &this)
fixed (byte* dataPointer = data)
{
sampleCount = OpusNativeMethods.DecoderGetNbSamples(pinned, dataPointer, data.Length);
sampleCount = OpusNativeMethods.DecoderGetNbSamples(_state, dataPointer, data.Length);
}

// Less than zero means an error occurred
Expand Down
19 changes: 10 additions & 9 deletions src/DSharpPlus.VoiceLink/Opus/OpusNativeMethods.Decoder.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Runtime.InteropServices;

namespace DSharpPlus.VoiceLink.Opus
Expand All @@ -20,7 +21,7 @@ internal static partial class OpusNativeMethods
/// <param name="channels">Number of channels (1 or 2) to decode.</param>
/// <returns><see cref="OpusErrorCode.Ok"/> or other error codes.</returns>
[LibraryImport("opus", EntryPoint = "opus_decoder_create")]
public static unsafe partial OpusDecoder* DecoderCreate(OpusSampleRate sampleRate, int channels, out OpusErrorCode* error);
public static unsafe partial IntPtr DecoderCreate(OpusSampleRate sampleRate, int channels, out OpusErrorCode* error);

/// <summary>
/// Initializes a previously allocated decoder state. The state must be at least the size returned by <see cref="DecoderGetSize(int)"/>. This is intended for applications which use their own allocator instead of malloc.
Expand All @@ -30,7 +31,7 @@ internal static partial class OpusNativeMethods
/// <param name="channels">Number of channels (1 or 2) to decode.</param>
/// <returns><see cref="OpusErrorCode.Ok"/> or other error codes.</returns>
[LibraryImport("opus", EntryPoint = "opus_decoder_init")]
public static unsafe partial OpusErrorCode DecoderInit(OpusDecoder* decoder, OpusSampleRate sampleRate, int channels);
public static unsafe partial OpusErrorCode DecoderInit(IntPtr decoder, OpusSampleRate sampleRate, int channels);

/// <summary>
/// Decode an Opus packet.
Expand All @@ -43,11 +44,11 @@ internal static partial class OpusNativeMethods
/// <param name="decodeFec">Flag (0 or 1) to request that any in-band forward error correction data be decoded. If no such data is available, the frame is decoded as if it were lost.</param>
/// <returns>Number of decoded samples or an <see cref="OpusErrorCode"/></returns>
[LibraryImport("opus", EntryPoint = "opus_decode")]
public static unsafe partial int Decode(OpusDecoder* decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec);
public static unsafe partial int Decode(IntPtr decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec);

/// <inheritdoc cref="Decode(OpusDecoder*, byte*, int, byte*, int, int)"/>
/// <inheritdoc cref="Decode(IntPtr, byte*, int, byte*, int, int)"/>
[LibraryImport("opus", EntryPoint = "opus_decode_float")]
public static unsafe partial int DecodeFloat(OpusDecoder* decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec);
public static unsafe partial int DecodeFloat(IntPtr decoder, byte* data, int length, byte* pcm, int frameSize, int decodeFec);

/// <summary>
/// Perform a CTL function on an Opus decoder.
Expand All @@ -56,14 +57,14 @@ internal static partial class OpusNativeMethods
/// <param name="decoder">Decoder state.</param>
/// <param name="request">This and all remaining parameters should be replaced by one of the convenience macros in Generic CTLs or Decoder related CTLs.</param>
[LibraryImport("opus", EntryPoint = "opus_decoder_ctl")]
public static unsafe partial OpusErrorCode DecoderControl(OpusDecoder* decoder, OpusControlRequest request, out int value);
public static unsafe partial OpusErrorCode DecoderControl(IntPtr decoder, OpusControlRequest request, out int value);

/// <summary>
/// Frees an OpusDecoder allocated by <see cref="DecoderCreate(OpusSampleRate, int, out OpusErrorCode)"/>.
/// Frees an OpusDecoder allocated by <see cref="DecoderCreate(OpusSampleRate, int, out OpusErrorCode*)"/>.
/// </summary>
/// <param name="decoder">State to be freed.</param>
[LibraryImport("opus", EntryPoint = "opus_decoder_destroy")]
public static unsafe partial void DecoderDestroy(OpusDecoder* decoder);
public static unsafe partial void DecoderDestroy(IntPtr decoder);

/// <summary>
/// Gets the number of samples of an Opus packet.
Expand All @@ -73,6 +74,6 @@ internal static partial class OpusNativeMethods
/// <param name="length">Length of packet.</param>
/// <returns>Number of samples or <see cref="OpusErrorCode.BadArg"/> or <see cref="OpusErrorCode.InvalidPacket"/>.</returns>
[LibraryImport("opus", EntryPoint = "opus_decoder_get_nb_samples")]
public static unsafe partial int DecoderGetNbSamples(OpusDecoder* decoder, byte* data, int length);
public static unsafe partial int DecoderGetNbSamples(IntPtr decoder, byte* data, int length);
}
}
12 changes: 12 additions & 0 deletions src/DSharpPlus.VoiceLink/Rtp/RtpUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,17 @@ public static RtpHeader DecodeHeader(ReadOnlySpan<byte> source)
Ssrc = BinaryPrimitives.ReadUInt32BigEndian(source[8..12])
};
}

/// <summary>
/// Gets the length in bytes of an RTP header extension. The extension will prefix the RTP payload.
/// Use <see cref="RtpHeader.HasExtension"/> to determined whether an RTP packet includes an extension.
/// </summary>
/// <param name="rtpPayload">The RTP payload that is prefixed by a header extension.</param>
/// <returns>The byte length of the extension.</returns>
public static ushort GetHeaderExtensionLength(ReadOnlySpan<byte> rtpPayload)
{
// offset by two to ignore the profile marker
return BinaryPrimitives.ReadUInt16BigEndian(rtpPayload[2..]);
}
}
}
43 changes: 36 additions & 7 deletions src/DSharpPlus.VoiceLink/VoiceLinkConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ await _webSocket.SendAsync(new DiscordVoiceResumingCommand()
Token = _voiceToken!
}, _cancellationTokenSource.Token);
}
catch (Exception ex)
{
// We need to catch and log all errors here because this task method is not watched
_logger.LogError(ex, "Unexpected failure on voice websocket");
throw;
}
}
}

Expand Down Expand Up @@ -277,7 +283,7 @@ private async Task ReceiveAudioLoopAsync()
continue;
}

if (rtpHeader.HasMarker || rtpHeader.HasExtension)
if (rtpHeader.HasMarker)
{
// All clients send a marker bit when they first connect. For now we're just going to ignore this.
continue;
Expand All @@ -294,20 +300,43 @@ private async Task ReceiveAudioLoopAsync()
}

// Decrypt the audio
byte[] decryptedAudio = ArrayPool<byte>.Shared.Rent(_voiceEncrypter.GetDecryptedSize(udpReceiveResult.Buffer.Length));
if (!_voiceEncrypter.TryDecryptOpusPacket(voiceLinkUser, udpReceiveResult.Buffer, _secretKey, decryptedAudio.AsSpan()))
int decryptedBufferSize = _voiceEncrypter.GetDecryptedSize(udpReceiveResult.Buffer.Length);
byte[] decryptedAudioArr = ArrayPool<byte>.Shared.Rent(decryptedBufferSize);
Memory<byte> decryptedAudio = decryptedAudioArr.AsMemory(0, decryptedBufferSize);

if (!_voiceEncrypter.TryDecryptOpusPacket(voiceLinkUser, udpReceiveResult.Buffer, _secretKey, decryptedAudio.Span))
{
_logger.LogWarning("Connection {GuildId}: Failed to decrypt audio from {Ssrc}, skipping.", Guild.Id, rtpHeader.Ssrc);
ArrayPool<byte>.Shared.Return(decryptedAudioArr);
continue;
}

// Strip any RTP header extensions. See https://www.rfc-editor.org/rfc/rfc3550#section-5.3.1
// Discord currently uses a generic profile marker of [0xbe, 0xde], see
// https://www.rfc-editor.org/rfc/rfc8285#section-4.2
if (rtpHeader.HasExtension)
{
ushort extensionLength = RtpUtilities.GetHeaderExtensionLength(decryptedAudio.Span);
decryptedAudio = decryptedAudio[(4 + 4 * extensionLength)..];
}

// TODO: Handle FEC (Forward Error Correction) aka packet loss.
// * https://tools.ietf.org/html/rfc5109
bool hasDataLoss = voiceLinkUser.UpdateSequence(rtpHeader.Sequence);

// Decode the audio
DecodeOpusAudio(decryptedAudio, voiceLinkUser, hasDataLoss);
ArrayPool<byte>.Shared.Return(decryptedAudio);
try
{
DecodeOpusAudio(decryptedAudio.Span, voiceLinkUser, hasDataLoss);
}
catch (Exception ex)
{
// TODO: Should this be a reason to terminate the connection?
// definitely should if a few in a row fail, at the very least
_logger.LogError(ex, "Connection {GuildId}: Failed to decode opus audio from {Ssrc}, skipping", Guild.Id, rtpHeader.Ssrc);
}

ArrayPool<byte>.Shared.Return(decryptedAudioArr);
await voiceLinkUser._audioPipe.Writer.FlushAsync(_cancellationTokenSource.Token);

static void DecodeOpusAudio(ReadOnlySpan<byte> opusPacket, VoiceLinkUser voiceLinkUser, bool hasPacketLoss = false)
Expand All @@ -316,13 +345,13 @@ static void DecodeOpusAudio(ReadOnlySpan<byte> opusPacket, VoiceLinkUser voiceLi
const int sampleRate = 48000; // 48 kHz
const double frameDuration = 0.020; // 20 milliseconds
const int frameSize = (int)(sampleRate * frameDuration); // 960 samples
const int bufferSize = frameSize * 2; // Stereo audio
const int bufferSize = frameSize * 2 * sizeof(short); // Stereo audio + opus PCM units are 16 bits

// Allocate the buffer for the PCM data
Span<byte> audioBuffer = voiceLinkUser._audioPipe.Writer.GetSpan(bufferSize);

// Decode the Opus packet
voiceLinkUser._opusDecoder.Decode(opusPacket, audioBuffer, hasPacketLoss);
voiceLinkUser._opusDecoder.Decode(opusPacket, audioBuffer, frameSize, hasPacketLoss);

// Write the audio to the pipe
voiceLinkUser._audioPipe.Writer.Advance(bufferSize);
Expand Down
Loading