Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swan fixes upgrade #19

Merged
merged 5 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions src/SwanAgent.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";
import {LLMOracleTaskParameters} from "@firstbatch/dria-oracle-contracts/LLMOracleTask.sol";
import {Swan, SwanAgentPurchaseOracleProtocol, SwanAgentStateOracleProtocol} from "./Swan.sol";
import {SwanMarketParameters} from "./SwanManager.sol";
import {SwanArtifact} from "./SwanArtifact.sol";

/// @notice Factory contract to deploy Agent contracts.
/// @dev This saves from contract space for Swan.
Expand Down Expand Up @@ -51,6 +52,9 @@ contract SwanAgent is Ownable {
EVENTS
//////////////////////////////////////////////////////////////*/

/// @notice Emitted when an artifact is skipped.
event ItemSkipped(address indexed agent, address indexed artifact);

/// @notice Emitted when a state update request is made.
event StateRequest(uint256 indexed taskId, uint256 indexed round);

Expand Down Expand Up @@ -168,7 +172,7 @@ contract SwanAgent is Ownable {
/// @notice The minimum amount of money that the agent must leave within the contract.
/// @dev minFundAmount should be `amountPerRound + oracleFee` to be able to make requests.
function minFundAmount() public view returns (uint256) {
return amountPerRound + swan.getOracleFee();
return amountPerRound + 2 * swan.getOracleFee();
}

/// @notice Reads the best performing result for a given task id, and parses it as an array of addresses.
Expand Down Expand Up @@ -255,24 +259,37 @@ contract SwanAgent is Ownable {

// read oracle result using the latest task id for this round
bytes memory output = oracleResult(taskId);
// TODO: add try-catch (When solidity supports) to handle more data when revert
address[] memory artifacts = abi.decode(output, (address[]));

// we purchase each artifact returned
for (uint256 i = 0; i < artifacts.length; i++) {
address artifact = artifacts[i];

// must not exceed the roundly buy-limit
uint256 price = swan.getListingPrice(artifact);
spendings[round] += price;
if (spendings[round] > amountPerRound) {
revert BuyLimitExceeded(spendings[round], amountPerRound);

// skip artifacts that exceed budget instead of reverting
if (spendings[round] + price > amountPerRound) {
emit ItemSkipped(address(this), artifact);
continue;
}

// add to inventory
inventory[round].push(artifact);
// check approval
SwanArtifact artifactContract = SwanArtifact(artifact);
address seller = swan.getListing(artifact).seller;

if (!artifactContract.isApprovedForAll(seller, address(swan))) {
emit ItemSkipped(address(this), artifact);
continue;
}

// make the actual purchase
swan.purchase(artifact);
// try purchase for other potential failures
try swan.purchase(artifact) {
spendings[round] += price;
inventory[round].push(artifact);
} catch {
emit ItemSkipped(address(this), artifact);
continue;
}
}

// update taskId as completed
Expand Down Expand Up @@ -420,4 +437,28 @@ contract SwanAgent is Ownable {

amountPerRound = _amountPerRound;
}

/// @notice Withdraws all available funds within allowable limits
/// @dev Withdraws maximum possible amount while respecting minFundAmount requirements
function withdrawAll() external onlyAuthorized {
(, Phase phase,) = getRoundPhase();
uint256 balance = treasury();

if (phase != Phase.Withdraw) {
// Must leave minFundAmount in non-withdraw phase
if (balance > minFundAmount()) {
uint256 withdrawable = balance - minFundAmount();
swan.token().transfer(owner(), withdrawable);
}
} else {
// Can withdraw everything in withdraw phase
swan.token().transfer(owner(), balance);
}
}

/// @notice Get the inventory for a specific round
/// @param round The queried round
function getInventory(uint256 round) public view returns (address[] memory) {
return inventory[round];
}
}
105 changes: 103 additions & 2 deletions test/SwanTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ contract SwanTest is Helper {
}

/// @notice Agent cannot spend more than amountPerRound per round
function test_RevertWhen_PurchaseMoreThanAmountPerRound()
function test_PurchaseOnlyWithinAmountPerRound()
external
fund
createAgents
Expand Down Expand Up @@ -172,8 +172,109 @@ contract SwanTest is Helper {
safeValidate(validators[0], 1);

vm.prank(_agentOwnerToFail);
vm.expectRevert(abi.encodeWithSelector(SwanAgent.BuyLimitExceeded.selector, artifactPrice * 2, amountPerRound));
_agentToFail.purchase();

// Verify spending didn't exceed amountPerRound
assertLe(_agentToFail.spendings(currRound), _agentToFail.amountPerRound());

// Get purchased artifacts
address[] memory purchasedArtifacts = _agentToFail.getInventory(currRound);
assertLt(purchasedArtifacts.length, output.length, "Should not purchase all artifacts if over budget");
}

function test_PurchaseWithSkippedItems()
external
fund
createAgents
sellersApproveToSwan
addValidatorsToWhitelist
registerOracles
listArtifacts(sellers[0], marketParameters.maxArtifactCount, address(agents[0]))
{
address _agentOwnerToFail = agentOwners[0];
SwanAgent _agentToFail = agents[0];

// Get artifacts and encode output early
address[] memory output = swan.getListedArtifacts(address(_agentToFail), currRound);
bytes memory encodedOutput = abi.encode(output);
address artifact1Addr = output[0];
address artifact3Addr = output[2];

// Fund agent with WETH for purchases
deal(address(token), address(_agentToFail), _agentToFail.amountPerRound() * 3);

// Approve WETH transfers
vm.startPrank(address(_agentToFail));
token.approve(address(swan), type(uint256).max);
vm.stopPrank();

vm.recordLogs();

// Revoke approval for artifact 1
vm.prank(sellers[0]);
SwanArtifact(artifact1Addr).setApprovalForAll(address(swan), false);

// Move to next round
uint256 nextRoundTime = _agentToFail.createdAt() + marketParameters.listingInterval
+ marketParameters.buyInterval + marketParameters.withdrawInterval;
vm.warp(nextRoundTime);

// Relist with higher price
uint256 overPrice = _agentToFail.amountPerRound() - 1;
vm.prank(sellers[0]);
swan.relist(output[1], address(_agentToFail), overPrice);

// Move to buy phase
vm.warp(nextRoundTime + marketParameters.listingInterval + 1);

// Make purchase request
vm.prank(_agentOwnerToFail);
_agentToFail.oraclePurchaseRequest(input, models);

safeRespond(generators[0], encodedOutput, 1);
safeRespond(generators[1], encodedOutput, 1);
safeValidate(validators[0], 1);

vm.prank(_agentOwnerToFail);
_agentToFail.purchase();

address[] memory purchasedArtifacts = _agentToFail.getInventory(1); // Round 1
assertTrue(purchasedArtifacts.length == 1, "Should purchase exactly one artifact");
assertTrue(purchasedArtifacts[0] == output[1], "Should have purchased artifact 2");

// Record logs and execute purchase
Vm.Log[] memory logs = vm.getRecordedLogs();
bool foundArtifact1Skip = false;
bool foundArtifact3Skip = false;

bytes32 skipEventSig = 0x3c44a811ea05c98efb27db6d3cbc9d4e7b0eb204b81047d92adfa387d3b0e818;
bytes32 soldEventSig = 0x7b1dae0d1aa5992cbf93242e4c807f1f27f69b51255335200caa21c7a6e5ab61;
bool foundArtifact2Sold = false;

for (uint256 i = 0; i < logs.length; i++) {
if (logs[i].topics.length > 0 && logs[i].topics[0] == skipEventSig) {
address artifact = address(uint160(uint256(logs[i].topics[2])));

if (artifact == artifact1Addr) {
foundArtifact1Skip = true;
} else if (artifact == artifact3Addr) {
foundArtifact3Skip = true;
}
}

if (logs[i].topics.length > 0 && logs[i].topics[0] == soldEventSig) {
// ArtifactSold(address owner, address agent, address artifact, uint256 price)
address artifact = address(uint160(uint256(logs[i].topics[3])));
if (artifact == output[1]) {
// artifact2 address
foundArtifact2Sold = true;
}
}
}

assertTrue(foundArtifact1Skip, "artifact 1 should be skipped");
assertTrue(foundArtifact2Sold, "artifact 2 should be purchased");
assertTrue(foundArtifact3Skip, "artifact 3 should be skipped");
}

/// @notice Agent can purchase
Expand Down