Skip to content

Commit

Permalink
feat: fetch matxa voices from backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ccoreilly committed Dec 29, 2024
1 parent 95b01ec commit 5654b4c
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ function App() {
setMediaUrl(TranscriptionAPIService.getMediaUrl(newUuid, revision ?? ''));
setMediaType(videoDataResponse.contentType);
setMediaFileName(videoDataResponse.filename);
setTracks(TranscriptionAPIService.parseTracksFromJSON(tracksDataResponse));
setTracks(await TranscriptionAPIService.parseTracksFromJSON(tracksDataResponse));
}
} catch (error) {
console.error("Error loading media or tracks from UUID:", error);
Expand Down Expand Up @@ -516,7 +516,7 @@ function App() {
try {
if (uuidParam) {
const rawTracks = await DubbingAPIService.loadTracksFromUUID(uuidParam);
const parsedTracks = DubbingAPIService.parseTracksFromJSON(rawTracks);
const parsedTracks = await DubbingAPIService.parseTracksFromJSON(rawTracks);
setTracks(parsedTracks);

setMediaUrl(DubbingAPIService.getSilentVideoUrl(uuidParam));
Expand Down
2 changes: 1 addition & 1 deletion src/services/APIServiceInterface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Track } from "../types/Track";
export interface APIServiceInterface {
getMediaUrl: (uuid: string, revision: string) => string;
loadTracksFromUUID: (uuid: string) => Promise<any>;
parseTracksFromJSON: (json: any) => Track[];
parseTracksFromJSON: (json: any) => Promise<Track[]>;
}

export interface DubbingAPIServiceInterface extends APIServiceInterface {
Expand Down
14 changes: 7 additions & 7 deletions src/services/DubbingAPIService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ describe("DubbingAPIService", () => {
});

describe("parseTracksFromJSON", () => {
it("should parse valid dubbing JSON data", () => {
it("should parse valid dubbing JSON data", async () => {
const mockData: DubbingJSON[] = [
{
id: 1,
Expand All @@ -51,7 +51,7 @@ describe("DubbingAPIService", () => {
},
];

const result = DubbingAPIService.parseTracksFromJSON(mockData);
const result = await DubbingAPIService.parseTracksFromJSON(mockData);

expect(result).toHaveLength(1);
expect(result[0]).toEqual({
Expand Down Expand Up @@ -80,7 +80,7 @@ describe("DubbingAPIService", () => {
});
});

it("should handle missing properties", () => {
it("should handle missing properties", async () => {
const mockData = [
{
id: "1",
Expand All @@ -90,7 +90,7 @@ describe("DubbingAPIService", () => {
},
];

const result = DubbingAPIService.parseTracksFromJSON(mockData);
const result = await DubbingAPIService.parseTracksFromJSON(mockData);

expect(result).toHaveLength(1);
expect(result[0]).toEqual(
Expand All @@ -117,12 +117,12 @@ describe("DubbingAPIService", () => {
});
});

it("should throw an error for invalid input", () => {
it("should throw an error for invalid input", async () => {
const mockData = { invalid: "data" };

expect(() =>
await expect(
DubbingAPIService.parseTracksFromJSON(mockData as any)
).toThrow("utterances is not iterable");
).rejects.toThrow("utterances is not iterable");
});
});

Expand Down
8 changes: 5 additions & 3 deletions src/services/DubbingAPIService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,15 @@ export const loadTracksFromUUID = async (
return response.json();
};

export const parseTracksFromJSON = (utterances: DubbingJSON[]): Track[] => {
export const parseTracksFromJSON = async (
utterances: DubbingJSON[]
): Promise<Track[]> => {
const tracks: Track[] = [];
for (const utterance of utterances) {
speakerService.setSpeaker({
await speakerService.setSpeaker({
id: utterance.speaker_id,
name: `${getI18n().t("speaker")} ${utterance.speaker_id.slice(-2)}`,
voice: matxaSynthesisProvider.getVoice(utterance.assigned_voice),
voice: await matxaSynthesisProvider.getVoice(utterance.assigned_voice),
});

const text = utterance.text || "";
Expand Down
138 changes: 138 additions & 0 deletions src/services/MatxaSynthesisProvider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import { MatxaSynthesisProvider, MatxaVoice } from "./MatxaSynthesisProvider";

describe("MatxaSynthesisProvider", () => {
let provider: MatxaSynthesisProvider;
const mockApiVoices = [
{
name: "test-central",
id: "99",
gender: "male",
language: "cat",
region: "central",
},
];

beforeEach(() => {
// Reset fetch mock before each test
jest.resetAllMocks();
// Create a new instance for each test
provider = new MatxaSynthesisProvider();
});

describe("voices()", () => {
it("should fetch and return voices from API", async () => {
global.fetch = jest.fn().mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockApiVoices),
});

const voices = await provider.voices();

expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining("/voices/")
);
expect(voices).toHaveLength(1);
expect(voices[0]).toEqual({
...mockApiVoices[0],
provider: "matxa",
label: "Home - Central (Test)",
});
});

it("should return default voices when API fails", async () => {
global.fetch = jest.fn().mockRejectedValueOnce(new Error("API Error"));

const voices = await provider.voices();

expect(voices).toHaveLength(8); // Default voice list length
expect(voices[0].name).toBe("quim-balear");
});
});

describe("speak()", () => {
const mockVoice: MatxaVoice = {
id: "1",
provider: "matxa",
language: "cat",
region: "central",
name: "test-central",
gender: "male",
label: "Home - Central (Test)",
};

const mockText = "Hello world";

it("should call API with correct parameters and return ArrayBuffer", async () => {
const mockArrayBuffer = new ArrayBuffer(8);
global.fetch = jest.fn().mockResolvedValueOnce({
ok: true,
arrayBuffer: () => Promise.resolve(mockArrayBuffer),
});

const result = await provider.speak(mockText, mockVoice);

expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining(
`/speak/?text=${encodeURIComponent(mockText)}&voice=${mockVoice.id}`
)
);
expect(result).toBe(mockArrayBuffer);
});

it("should throw error when API call fails", async () => {
global.fetch = jest.fn().mockResolvedValueOnce({
ok: false,
status: 500,
});

await expect(provider.speak(mockText, mockVoice)).rejects.toThrow(
"HTTP error! status: 500"
);
});

it("should throw error when network fails", async () => {
global.fetch = jest
.fn()
.mockRejectedValueOnce(new Error("Network error"));

await expect(provider.speak(mockText, mockVoice)).rejects.toThrow(
"Network error"
);
});
});

describe("getVoice()", () => {
const mockApiVoices = [
{
name: "test-central",
id: "99",
gender: "male",
language: "cat",
region: "central",
},
];

beforeEach(() => {
global.fetch = jest.fn().mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockApiVoices),
});
});

it("should return voice by id", async () => {
const voice = await provider.getVoice("1");
expect(voice.name).toBe("test-central");
});

it("should return first voice when id not found", async () => {
const voice = await provider.getVoice("non-existent");
expect(voice.name).toBe("test-central");
});
});

describe("getProviderName()", () => {
it("should return correct provider name", () => {
expect(provider.getProviderName()).toBe("matxa");
});
});
});
55 changes: 50 additions & 5 deletions src/services/MatxaSynthesisProvider.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import { Voice } from "../types/Voice";
import { SynthesisProvider } from "./SynthesisService";

interface MatxaVoice extends Voice {
export interface MatxaVoice extends Voice {
language: string;
region: string;
name: string;
}

interface MatxaVoiceResponse {
gender: string;
id: string;
language: string;
name: string;
region: string;
}

const API_BASE_URL =
process.env.MATXA_API_BASE_URL ||
"https://api.softcatala.org/dubbing-service/v1";

class MatxaSynthesisProvider implements SynthesisProvider {
export class MatxaSynthesisProvider implements SynthesisProvider {
private providerName = "matxa";
private voiceList: MatxaVoice[] = [
private defaultVoiceList: MatxaVoice[] = [
{
name: "quim-balear",
id: "0",
Expand Down Expand Up @@ -87,8 +95,42 @@ class MatxaSynthesisProvider implements SynthesisProvider {
provider: this.providerName,
},
];
private voiceList: MatxaVoice[] = [];

private async fetchVoices(): Promise<void> {
try {
const response = await fetch(`${API_BASE_URL}/voices/`);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const data: MatxaVoiceResponse[] = await response.json();

async voices(): Promise<Voice[]> {
this.voiceList = data.map((voice) => ({
...voice,
provider: this.providerName,
label: this.createVoiceLabel(voice),
}));
} catch (error) {
console.error("Error fetching voices:", error);
this.voiceList = this.defaultVoiceList;
}
}

private createVoiceLabel(voice: MatxaVoiceResponse): string {
const gender = voice.gender === "male" ? "Home" : "Dona";
const region = this.capitalizeFirstLetter(voice.region);
const name = this.capitalizeFirstLetter(voice.name.split("-")[0]);
return `${gender} - ${region} (${name})`;
}

private capitalizeFirstLetter(str: string): string {
return str.charAt(0).toUpperCase() + str.slice(1);
}

async voices(): Promise<MatxaVoice[]> {
if (this.voiceList.length === 0) {
await this.fetchVoices();
}
return this.voiceList;
}

Expand All @@ -109,7 +151,10 @@ class MatxaSynthesisProvider implements SynthesisProvider {
}
}

getVoice(id: string): Voice {
async getVoice(id: string): Promise<MatxaVoice> {
if (this.voiceList.length === 0) {
await this.fetchVoices();
}
return this.voiceList.find((voice) => voice.id === id) || this.voiceList[0];
}

Expand Down
15 changes: 8 additions & 7 deletions src/services/SpeakerService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,26 @@ const getRandomColor = () => {
class SpeakerService {
private speakers: Speaker[] = [];

setSpeakers(speakersData: Partial<Speaker>[]): void {
this.speakers = speakersData
.map((speaker) => ({
async setSpeakers(speakersData: Partial<Speaker>[]): Promise<void> {
this.speakers = await Promise.all(
speakersData.map(async (speaker) => ({
id: speaker.id || uuidv4(),
name: speaker.name || "",
voice: speaker.voice || matxaSynthesisProvider.getVoice("0"),
voice: speaker.voice || (await matxaSynthesisProvider.getVoice("0")),
color: speaker.color || getRandomColor(),
}))
.sort((a, b) => a.name.localeCompare(b.name));
);
this.sortSpeakers();
}

setSpeaker(speaker: Partial<Speaker> & { id: string }): void {
async setSpeaker(speaker: Partial<Speaker> & { id: string }): Promise<void> {
if (this.speakers.find((s) => s.id === speaker.id)) {
this.updateSpeaker(speaker.id, speaker);
} else {
this.speakers.push({
id: speaker.id,
name: speaker.name || "",
voice: speaker.voice || matxaSynthesisProvider.getVoice("0"),
voice: speaker.voice || (await matxaSynthesisProvider.getVoice("0")),
color: speaker.color || getRandomColor(),
});
}
Expand Down
2 changes: 1 addition & 1 deletion src/services/TranscriptionAPIService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const loadTracksFromUUID = async (uuid: string): Promise<any> => {
return response.json();
};

export const parseTracksFromJSON = (json: any): Track[] => {
export const parseTracksFromJSON = async (json: any): Promise<Track[]> => {
return json.segments.map((segment: any) => ({
id: segment.id,
start: segment.start,
Expand Down

0 comments on commit 5654b4c

Please sign in to comment.