Skip to content

Commit

Permalink
✅ test: Add unit test for src/server/routers/lambda/aiProvider.ts (lo…
Browse files Browse the repository at this point in the history
…behub#5686)

* Add unit tests for aiProviderRouter functionality in aiProvider.test.ts

* Add export for aiProviderProcedure in aiProvider.ts.

* Update aiProvider.ts
  • Loading branch information
gru-agent[bot] authored and AirboZH committed Feb 14, 2025
1 parent d844a97 commit 81c80e1
Showing 1 changed file with 200 additions and 0 deletions.
200 changes: 200 additions & 0 deletions src/server/routers/lambda/aiProvider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';

import { AiInfraRepos } from '@/database/repositories/aiInfra';
import { serverDB } from '@/database/server';
import { AiProviderModel } from '@/database/server/models/aiProvider';
import { UserModel } from '@/database/server/models/user';
import { getServerGlobalConfig } from '@/server/globalConfig';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { AiProviderDetailItem, AiProviderRuntimeState } from '@/types/aiProvider';

import { aiProviderRouter } from './aiProvider';

vi.mock('@/server/globalConfig');
vi.mock('@/server/modules/KeyVaultsEncrypt');
vi.mock('@/database/repositories/aiInfra');
vi.mock('@/database/server/models/aiProvider');
vi.mock('@/database/server/models/user');

describe('aiProviderRouter', () => {
const mockUserId = 'test-user-id';
const mockProviderId = 'test-provider-id';
const mockEncrypt = vi.fn();
const mockDecrypt = vi.fn();

const mockGateKeeper = {
encrypt: mockEncrypt,
decrypt: mockDecrypt,
};

const mockProviderDetail: AiProviderDetailItem = {
id: mockProviderId,
name: 'Test Provider',
enabled: true,
description: 'Test Description',
source: 'custom',
settings: {},
};

const mockRuntimeState: AiProviderRuntimeState = {
enabledAiModels: [],
enabledAiProviders: [],
runtimeConfig: {},
};

beforeEach(() => {
vi.clearAllMocks();

vi.mocked(getServerGlobalConfig).mockReturnValue({
aiProvider: {},
} as any);

vi.mocked(KeyVaultsGateKeeper.initWithEnvKey).mockResolvedValue(mockGateKeeper as any);
});

const createMockContext = () => ({
userId: mockUserId,
});

describe('createAiProvider', () => {
it('should create a new AI provider', async () => {
const mockCreate = vi.fn().mockResolvedValue({ id: mockProviderId });
vi.mocked(AiProviderModel).prototype.create = mockCreate;

const caller = aiProviderRouter.createCaller(createMockContext());
const result = await caller.createAiProvider({
id: mockProviderId,
name: 'Test Provider',
source: 'custom',
});

expect(result).toBe(mockProviderId);
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
id: mockProviderId,
name: 'Test Provider',
}),
mockGateKeeper.encrypt,
);
});
});

describe('getAiProviderById', () => {
it('should get AI provider by id', async () => {
const mockGetDetail = vi.fn().mockResolvedValue(mockProviderDetail);
vi.mocked(AiInfraRepos).prototype.getAiProviderDetail = mockGetDetail;

const caller = aiProviderRouter.createCaller(createMockContext());
const result = await caller.getAiProviderById({ id: mockProviderId });

expect(result).toEqual(mockProviderDetail);
expect(mockGetDetail).toHaveBeenCalledWith(
mockProviderId,
KeyVaultsGateKeeper.getUserKeyVaults,
);
});
});

describe('getAiProviderList', () => {
it('should get AI provider list', async () => {
const mockList = [mockProviderDetail];
const mockGetList = vi.fn().mockResolvedValue(mockList);
vi.mocked(AiInfraRepos).prototype.getAiProviderList = mockGetList;

const caller = aiProviderRouter.createCaller(createMockContext());
const result = await caller.getAiProviderList();

expect(result).toEqual(mockList);
expect(mockGetList).toHaveBeenCalled();
});
});

describe('getAiProviderRuntimeState', () => {
it('should get AI provider runtime state', async () => {
const mockGetState = vi.fn().mockResolvedValue(mockRuntimeState);
vi.mocked(AiInfraRepos).prototype.getAiProviderRuntimeState = mockGetState;

const caller = aiProviderRouter.createCaller(createMockContext());
const result = await caller.getAiProviderRuntimeState({});

expect(result).toEqual(mockRuntimeState);
expect(mockGetState).toHaveBeenCalledWith(KeyVaultsGateKeeper.getUserKeyVaults);
});
});

describe('removeAiProvider', () => {
it('should remove AI provider', async () => {
const mockDelete = vi.fn();
vi.mocked(AiProviderModel).prototype.delete = mockDelete;

const caller = aiProviderRouter.createCaller(createMockContext());
await caller.removeAiProvider({ id: mockProviderId });

expect(mockDelete).toHaveBeenCalledWith(mockProviderId);
});
});

describe('toggleProviderEnabled', () => {
it('should toggle provider enabled state', async () => {
const mockToggle = vi.fn();
vi.mocked(AiProviderModel).prototype.toggleProviderEnabled = mockToggle;

const caller = aiProviderRouter.createCaller(createMockContext());
await caller.toggleProviderEnabled({
id: mockProviderId,
enabled: true,
});

expect(mockToggle).toHaveBeenCalledWith(mockProviderId, true);
});
});

describe('updateAiProvider', () => {
it('should update AI provider', async () => {
const mockUpdate = vi.fn();
vi.mocked(AiProviderModel).prototype.update = mockUpdate;

const caller = aiProviderRouter.createCaller(createMockContext());
await caller.updateAiProvider({
id: mockProviderId,
value: { name: 'Updated Provider' },
});

expect(mockUpdate).toHaveBeenCalledWith(mockProviderId, {
name: 'Updated Provider',
});
});
});

describe('updateAiProviderConfig', () => {
it('should update AI provider config', async () => {
const mockUpdateConfig = vi.fn();
vi.mocked(AiProviderModel).prototype.updateConfig = mockUpdateConfig;

const caller = aiProviderRouter.createCaller(createMockContext());
await caller.updateAiProviderConfig({
id: mockProviderId,
value: { checkModel: 'gpt-4' },
});

expect(mockUpdateConfig).toHaveBeenCalledWith(
mockProviderId,
{ checkModel: 'gpt-4' },
mockGateKeeper.encrypt,
);
});
});

describe('updateAiProviderOrder', () => {
it('should update AI provider order', async () => {
const mockUpdateOrder = vi.fn();
vi.mocked(AiProviderModel).prototype.updateOrder = mockUpdateOrder;

const sortMap = [{ id: mockProviderId, sort: 1 }];
const caller = aiProviderRouter.createCaller(createMockContext());
await caller.updateAiProviderOrder({ sortMap });

expect(mockUpdateOrder).toHaveBeenCalledWith(sortMap);
});
});
});

0 comments on commit 81c80e1

Please sign in to comment.