diff --git a/src/wallet/test/fuzz/scriptpubkeyman.cpp b/src/wallet/test/fuzz/scriptpubkeyman.cpp index 091d42f6cf5287..dc9cc61f8066c2 100644 --- a/src/wallet/test/fuzz/scriptpubkeyman.cpp +++ b/src/wallet/test/fuzz/scriptpubkeyman.cpp @@ -49,6 +49,13 @@ void initialize_spkm() MOCKED_DESC_CONVERTER.Init(); } +void initialize_spkm_migration() +{ + static const auto testing_setup{MakeNoLogFileContext()}; + g_setup = testing_setup.get(); + SelectParams(ChainType::MAIN); +} + /** * Key derivation is expensive. Deriving deep derivation paths take a lot of compute and we'd rather spend time * elsewhere in this target, like on actually fuzzing the DescriptorScriptPubKeyMan. So rule out strings which could @@ -200,5 +207,104 @@ FUZZ_TARGET(scriptpubkeyman, .init = initialize_spkm) (void)spk_manager->GetKeyPoolSize(); } +FUZZ_TARGET(spkm_migration, .init = initialize_spkm_migration) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + const auto& node{g_setup->m_node}; + Chainstate& chainstate{node.chainman->ActiveChainstate()}; + std::unique_ptr wallet_ptr{std::make_unique(node.chain.get(), "", CreateMockableWalletDatabase())}; + CWallet& wallet{*wallet_ptr}; + wallet.m_keypool_size = 1; + { + LOCK(wallet.cs_wallet); + wallet.SetLastBlockProcessed(chainstate.m_chain.Height(), chainstate.m_chain.Tip()->GetBlockHash()); + } + + auto& legacy_data{*wallet.GetOrCreateLegacyDataSPKM()}; + + std::vector keys; + LIMITED_WHILE(fuzzed_data_provider.ConsumeBool(), 30) { + const auto key{ConsumePrivateKey(fuzzed_data_provider)}; + if (!key.IsValid()) return; + auto pub_key{key.GetPubKey()}; + if (!pub_key.IsFullyValid()) return; + if (legacy_data.LoadKey(key, pub_key) && std::find(keys.begin(), keys.end(), key) == keys.end()) keys.push_back(key); + } + + if (keys.empty()) return; + + auto hd_key{PickValue(fuzzed_data_provider, keys)}; + bool add_hd_chain{fuzzed_data_provider.ConsumeBool()}; + CHDChain hd_chain; + if (add_hd_chain) { + hd_chain.nVersion = fuzzed_data_provider.ConsumeBool() ? CHDChain::VERSION_HD_CHAIN_SPLIT : CHDChain::VERSION_HD_BASE; + hd_chain.seed_id = hd_key.GetPubKey().GetID(); + legacy_data.LoadHDChain(hd_chain); + } + + int script_count{0}; + bool good_data{true}; + LIMITED_WHILE(good_data && fuzzed_data_provider.ConsumeBool(), 30) { + CallOneOf( + fuzzed_data_provider, + [&] { + CKey private_key{ConsumePrivateKey(fuzzed_data_provider)}; + if (!private_key.IsValid()) return; + const auto& dest{GetDestinationForKey(private_key.GetPubKey(), OutputType::LEGACY)}; + (void)legacy_data.LoadWatchOnly(GetScriptForDestination(dest)); + }, + [&] { + CScript script; + auto key{PickValue(fuzzed_data_provider, keys)}; + auto pub_key{key.GetPubKey()}; + bool key_hash{false}; + if (fuzzed_data_provider.ConsumeBool()) { + script = GetScriptForDestination(CTxDestination{PKHash(pub_key)}); + } else { + key_hash = true; + script = GetScriptForDestination(WitnessV0KeyHash(pub_key)); + } + if (legacy_data.AddCScript(script) && !key_hash) script_count++; + }, + [&] { + auto key{PickValue(fuzzed_data_provider, keys)}; + const auto num_keys{fuzzed_data_provider.ConsumeIntegralInRange(1, MAX_PUBKEYS_PER_MULTISIG)}; + std::vector pubkeys; + size_t known_keys{0}; + for (size_t i = 0; i < num_keys; i++) { + if (fuzzed_data_provider.ConsumeBool()) { + known_keys++; + assert(!keys.empty()); + pubkeys.emplace_back(key.GetPubKey()); + } else { + CKey private_key{ConsumePrivateKey(fuzzed_data_provider)}; + if (!private_key.IsValid()) return; + pubkeys.emplace_back(private_key.GetPubKey()); + if (std::find(keys.begin(), keys.end(), private_key) != keys.end()) known_keys++; + } + } + if (pubkeys.size() < num_keys) return; + const auto multisig_script{GetScriptForMultisig(num_keys, pubkeys)}; + if (legacy_data.AddCScript(multisig_script) && known_keys == num_keys) script_count++; + } + ); + } + + std::optional res{legacy_data.MigrateToDescriptor()}; + assert(res->desc_spkms.size() == (keys.size() + size_t{add_hd_chain} + script_count)); + std::vector ids; + for (auto& spkm : res->desc_spkms) { + // Duplicate descriptors should not be created during migration. + assert(std::find(ids.begin(), ids.end(), spkm->GetID()) == ids.end()); + ids.push_back(spkm->GetID()); + } + + if (add_hd_chain) { + CExtKey master_key; + master_key.SetSeed(hd_key); + assert(res->master_key == master_key); + } +} + } // namespace } // namespace wallet