diff --git a/src/Wallet.sol b/src/Wallet.sol index 7fc7cca..5238665 100644 --- a/src/Wallet.sol +++ b/src/Wallet.sol @@ -11,6 +11,7 @@ contract Wallet { mapping(address => Transaction[]) public transactions; mapping(address => mapping(address => uint256)) token_balance; + mapping(address => bool) public supportedTokens; struct Transaction { uint256 amount; @@ -28,22 +29,43 @@ contract Wallet { _; } + modifier onlySupportedToken(address _token) { + require(supportedTokens[_token], "Token not supported"); + _; + } + constructor(address _owner, address _worldID, address _usdt) { require(msg.sender != address(0), "zero address found"); owner = _owner; worldID = IWorldID(_worldID); usdt = IERC20(_usdt); + supportedTokens[_usdt] = true; // Add USDT as a default supported token } function createWorldId() external onlyOwner {} - function transfer(address _recipient, uint256 _amount) external onlyVerified(msg.sender) { + function transfer(address _recipient, address _token, uint256 _amount) + external + onlyVerified(msg.sender) + onlySupportedToken(_token) + { require(_recipient != address(0), "Zero address detected"); - require(usdt.balanceOf(msg.sender) >= _amount, "Insufficient balance"); + IERC20 token = IERC20(_token); + require(token.balanceOf(msg.sender) >= _amount, "Insufficient balance"); require(_amount > 0, "Transfer amount must be greater than zero"); - require(usdt.transferFrom(msg.sender, _recipient, _amount), "Transfer failed"); + require(token.transferFrom(msg.sender, _recipient, _amount), "Transfer failed"); + + recordTransactionHistory(msg.sender, _amount, _token); + } + + function addSupportedToken(address _token) external onlyOwner { + require(_token != address(0), "Invalid token address"); + supportedTokens[_token] = true; + } - recordTransactionHistory(msg.sender, _amount, address(usdt)); + function removeSupportedToken(address _token) external onlyOwner { + require(supportedTokens[_token], "Token not supported"); + supportedTokens[_token] = false; } ////////////////////////////////////////////// diff --git a/test/Transfer.t.sol b/test/Transfer.t.sol index 2a19af8..a5f5083 100644 --- a/test/Transfer.t.sol +++ b/test/Transfer.t.sol @@ -121,7 +121,7 @@ contract USDTTransferTest is Test { // Perform the transfer vm.prank(user1); - transferContract.transfer(user2, transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount); // Assert balances after transfer assertEq(mockUSDT.balanceOf(user1), user1BalanceBefore - transferAmount); @@ -134,7 +134,7 @@ contract USDTTransferTest is Test { vm.prank(user1); vm.expectRevert("Insufficient balance"); - transferContract.transfer(user2, transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount); } // Test transfer function - Unverified user @@ -146,7 +146,7 @@ contract USDTTransferTest is Test { vm.prank(user1); vm.expectRevert("User not verified"); - transferContract.transfer(user2, transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount); } // Test transfer function - Transfer to unverified recipient @@ -157,7 +157,7 @@ contract USDTTransferTest is Test { vm.expectRevert("User not verified"); - transferContract.transfer(user2, transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount); } // Test transfer function - Multiple consecutive transfers @@ -171,7 +171,7 @@ contract USDTTransferTest is Test { // First transfer vm.prank(user1); - transferContract.transfer(user2, transferAmount1); + transferContract.transfer(user2, address(mockUSDT), transferAmount1); // Assert balances after first transfer assertEq(mockUSDT.balanceOf(user1), user1BalanceBefore - transferAmount1); @@ -182,7 +182,7 @@ contract USDTTransferTest is Test { user2BalanceBefore = mockUSDT.balanceOf(user2); vm.prank(user1); - transferContract.transfer(user2, transferAmount2); + transferContract.transfer(user2, address(mockUSDT), transferAmount2); assertEq(mockUSDT.balanceOf(user1), user1BalanceBefore - transferAmount2); assertEq(mockUSDT.balanceOf(user2), user2BalanceBefore + transferAmount2); @@ -195,6 +195,6 @@ contract USDTTransferTest is Test { vm.prank(user1); vm.expectRevert("Transfer amount must be greater than zero"); - transferContract.transfer(user2, transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount); } } diff --git a/test/Wallet.t.sol b/test/Wallet.t.sol index 7e6f61a..c9f0f09 100644 --- a/test/Wallet.t.sol +++ b/test/Wallet.t.sol @@ -16,7 +16,9 @@ contract WalletTest is Test { address public owner; address public user1; address public user2; + address public nonOwner; MockERC20 public usdt; + MockERC20 public anotherToken; MockWorldID public worldID; /// @notice Set up the test environment before each test @@ -24,7 +26,10 @@ contract WalletTest is Test { owner = address(this); user1 = address(0x1); user2 = address(0x2); + nonOwner = address(0x999); usdt = new MockERC20("USDT", "USDT"); + anotherToken = new MockERC20("Another Token", "ATKN"); + worldID = new MockWorldID(); factory = new WalletFactory(); @@ -48,7 +53,7 @@ contract WalletTest is Test { /// @notice Test recording a single transaction function testRecordSingleTransaction() public { vm.prank(user1); - wallet.transfer(user2, 100); + wallet.transfer(user2, address(usdt), 100); vm.prank(user1); Wallet.Transaction[] memory history = wallet.getTransactionHistory(user1); @@ -60,8 +65,8 @@ contract WalletTest is Test { /// @notice Test recording multiple transactions function testRecordMultipleTransactions() public { vm.startPrank(user1); - wallet.transfer(user2, 100); - wallet.transfer(user2, 200); + wallet.transfer(user2, address(usdt), 100); + wallet.transfer(user2, address(usdt), 200); Wallet.Transaction[] memory history = wallet.getTransactionHistory(user1); vm.stopPrank(); @@ -76,13 +81,13 @@ contract WalletTest is Test { /// @notice Test recording transactions for different users function testRecordTransactionsForDifferentUsers() public { vm.prank(user1); - wallet.transfer(user2, 100); + wallet.transfer(user2, address(usdt), 100); vm.prank(user1); - wallet.transfer(user2, 50); + wallet.transfer(user2, address(usdt), 50); vm.prank(user2); - wallet.transfer(user1, 50); + wallet.transfer(user1, address(usdt), 50); vm.prank(user1); Wallet.Transaction[] memory user1History = wallet.getTransactionHistory(user1); @@ -103,7 +108,7 @@ contract WalletTest is Test { vm.startPrank(user1); usdt.approve(address(wallet), largeAmount); - wallet.transfer(user2, largeAmount); + wallet.transfer(user2, address(usdt), largeAmount); Wallet.Transaction[] memory history = wallet.getTransactionHistory(user1); vm.stopPrank(); @@ -112,4 +117,43 @@ contract WalletTest is Test { assertEq(history[0].amount, largeAmount, "Recorded amount does not match"); assertEq(history[0].token, address(usdt), "Recorded token address does not match"); } + + function testAddSupportedTokenByOwner() public { + vm.prank(owner); // Set the caller as the owner + wallet.addSupportedToken(address(anotherToken)); + + assertTrue(wallet.supportedTokens(address(anotherToken))); + } + + function testAddSupportedTokenByNonOwnerReverts() public { + vm.prank(nonOwner); + vm.expectRevert("not owner"); + wallet.addSupportedToken(address(anotherToken)); + } + + function testRemoveSupportedTokenByOwner() public { + vm.startPrank(owner); // Start a transaction as the owner + wallet.addSupportedToken(address(anotherToken)); + assertTrue(wallet.supportedTokens(address(anotherToken))); + + wallet.removeSupportedToken(address(anotherToken)); + assertFalse(wallet.supportedTokens(address(anotherToken))); + vm.stopPrank(); + } + + function testRemoveSupportedTokenByNonOwnerReverts() public { + vm.startPrank(owner); + wallet.addSupportedToken(address(anotherToken)); + vm.stopPrank(); + + vm.prank(nonOwner); + vm.expectRevert("not owner"); + wallet.removeSupportedToken(address(anotherToken)); + } + + function testRemoveNonSupportedTokenReverts() public { + vm.prank(owner); + vm.expectRevert("Token not supported"); + wallet.removeSupportedToken(address(anotherToken)); + } }