From 703f5598c859d9fbcfd4651958b57e78ee59538e Mon Sep 17 00:00:00 2001 From: korrrba <88761781+gitcoindev@users.noreply.github.com> Date: Tue, 27 Feb 2024 13:57:03 +0100 Subject: [PATCH] feat: protect protocol upgradeability loss (#904) * feat: protect protocol upgradeability loss diamondCut function cannot be removed. Resolves: https://github.com/sherlock-audit/2023-12-ubiquity-judging/issues/21 * chore: explicit explanation of diamondCut function selector value As proposed during pull request review. --- .../src/dollar/libraries/LibDiamond.sol | 6 ++++++ .../contracts/test/diamond/DiamondTest.t.sol | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/packages/contracts/src/dollar/libraries/LibDiamond.sol b/packages/contracts/src/dollar/libraries/LibDiamond.sol index 9a2546e15..1e093dcd8 100644 --- a/packages/contracts/src/dollar/libraries/LibDiamond.sol +++ b/packages/contracts/src/dollar/libraries/LibDiamond.sol @@ -321,6 +321,12 @@ library LibDiamond { _facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist" ); + // precomputed diamondCut function selector to save gas + // bytes4(keccak256(abi.encodeWithSignature("diamondCut((address,uint8,bytes4[])[],address,bytes)"))) == 0x1f931c1c + require( + _selector != bytes4(0x1f931c1c), + "LibDiamondCut: Can't remove diamondCut function" + ); // an immutable function is a function defined directly in a diamond require( _facetAddress != address(this), diff --git a/packages/contracts/test/diamond/DiamondTest.t.sol b/packages/contracts/test/diamond/DiamondTest.t.sol index 93b20f74c..1732c1552 100644 --- a/packages/contracts/test/diamond/DiamondTest.t.sol +++ b/packages/contracts/test/diamond/DiamondTest.t.sol @@ -287,6 +287,23 @@ contract TestDiamond is DiamondTestSetup { IMockFacet(address(diamondCutFacet)).functionB(); } + function testCutFacetShouldNotRemoveDiamondCutFunction() public { + FacetCut[] memory facetCut = new FacetCut[](1); + bytes4[] memory selectors = new bytes4[](1); + selectors[0] = diamondCutFacet.diamondCut.selector; + + facetCut[0] = FacetCut({ + facetAddress: address(0), + action: FacetCutAction.Remove, + functionSelectors: selectors + }); + + // try to remove diamondCut function + vm.prank(owner); + vm.expectRevert("LibDiamondCut: Can't remove diamondCut function"); + diamondCutFacet.diamondCut(facetCut, address(0x0), ""); + } + function testSelectors_ShouldBeAssociatedWithCorrectFacet() public { for (uint256 i; i < facetAddressList.length; i++) { if (compareStrings(facetNames[i], "DiamondCutFacet")) {