Skip to content

Commit 04a3f65

Browse files
committed
merge bitcoin#23721: Move restorewallet() logic to the wallet section
1 parent e47d5ac commit 04a3f65

File tree

8 files changed

+88
-30
lines changed

8 files changed

+88
-30
lines changed

src/interfaces/wallet.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ class WalletLoader : public ChainClient
346346
//! Return default wallet directory.
347347
virtual std::string getWalletDir() = 0;
348348

349+
//! Restore backup wallet
350+
virtual std::unique_ptr<Wallet> restoreWallet(const std::string& backup_file, const std::string& wallet_name, bilingual_str& error, std::vector<bilingual_str>& warnings) = 0;
351+
349352
//! Return available wallets in wallet directory.
350353
virtual std::vector<std::string> listWalletDir() = 0;
351354

src/rpc/protocol.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ enum RPCErrorCode
8181
RPC_WALLET_NOT_FOUND = -18, //!< Invalid wallet specified
8282
RPC_WALLET_NOT_SPECIFIED = -19, //!< No wallet specified (error when there are multiple wallets loaded)
8383
RPC_WALLET_ALREADY_LOADED = -35, //!< This same wallet is already loaded
84+
RPC_WALLET_ALREADY_EXISTS = -36, //!< There is already a wallet with the same name
8485

8586

8687
//! Backwards compatible aliases

src/wallet/db.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ enum class DatabaseStatus {
222222
FAILED_LOAD,
223223
FAILED_VERIFY,
224224
FAILED_ENCRYPT,
225+
FAILED_INVALID_BACKUP_FILE,
225226
};
226227

227228
/** Recursively list database paths in directory. */

src/wallet/interfaces.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,12 @@ class WalletLoaderImpl : public WalletLoader
607607
assert(m_context.m_coinjoin_loader);
608608
return MakeWallet(LoadWallet(*m_context.chain, *m_context.m_coinjoin_loader, name, true /* load_on_start */, options, status, error, warnings));
609609
}
610+
std::unique_ptr<Wallet> restoreWallet(const std::string& backup_file, const std::string& wallet_name, bilingual_str& error, std::vector<bilingual_str>& warnings) override
611+
{
612+
DatabaseStatus status;
613+
assert(m_context.m_coinjoin_loader);
614+
return MakeWallet(RestoreWallet(*m_context.chain, *m_context.m_coinjoin_loader, backup_file, wallet_name, /*load_on_start=*/true, status, error, warnings));
615+
}
610616
std::string getWalletDir() override
611617
{
612618
return GetWalletDir().string();

src/wallet/rpcwallet.cpp

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2716,16 +2716,8 @@ static RPCHelpMan listwallets()
27162716
};
27172717
}
27182718

2719-
static std::tuple<std::shared_ptr<CWallet>, std::vector<bilingual_str>> LoadWalletHelper(WalletContext& context, UniValue load_on_start_param, const std::string wallet_name)
2719+
void HandleWalletError(const std::shared_ptr<CWallet> wallet, DatabaseStatus& status, bilingual_str& error)
27202720
{
2721-
DatabaseOptions options;
2722-
DatabaseStatus status;
2723-
options.require_existing = true;
2724-
bilingual_str error;
2725-
std::vector<bilingual_str> warnings;
2726-
std::optional<bool> load_on_start = load_on_start_param.isNull() ? std::nullopt : std::optional<bool>(load_on_start_param.get_bool());
2727-
std::shared_ptr<CWallet> const wallet = LoadWallet(*context.chain, *context.m_coinjoin_loader, wallet_name, load_on_start, options, status, error, warnings);
2728-
27292721
if (!wallet) {
27302722
// Map bad format to not found, since bad format is returned when the
27312723
// wallet directory exists, but doesn't contain a data file.
@@ -2738,13 +2730,17 @@ static std::tuple<std::shared_ptr<CWallet>, std::vector<bilingual_str>> LoadWall
27382730
case DatabaseStatus::FAILED_ALREADY_LOADED:
27392731
code = RPC_WALLET_ALREADY_LOADED;
27402732
break;
2733+
case DatabaseStatus::FAILED_ALREADY_EXISTS:
2734+
code = RPC_WALLET_ALREADY_EXISTS;
2735+
break;
2736+
case DatabaseStatus::FAILED_INVALID_BACKUP_FILE:
2737+
code = RPC_INVALID_PARAMETER;
2738+
break;
27412739
default: // RPC_WALLET_ERROR is returned for all other cases.
27422740
break;
27432741
}
27442742
throw JSONRPCError(code, error.original);
27452743
}
2746-
2747-
return { wallet, warnings };
27482744
}
27492745

27502746
static RPCHelpMan upgradetohd()
@@ -2872,7 +2868,15 @@ static RPCHelpMan loadwallet()
28722868
WalletContext& context = EnsureWalletContext(request.context);
28732869
const std::string name(request.params[0].get_str());
28742870

2875-
auto [wallet, warnings] = LoadWalletHelper(context, request.params[1], name);
2871+
DatabaseOptions options;
2872+
DatabaseStatus status;
2873+
options.require_existing = true;
2874+
bilingual_str error;
2875+
std::vector<bilingual_str> warnings;
2876+
std::optional<bool> load_on_start = request.params[1].isNull() ? std::nullopt : std::optional<bool>(request.params[1].get_bool());
2877+
std::shared_ptr<CWallet> const wallet = LoadWallet(*context.chain, *context.m_coinjoin_loader, name, load_on_start, options, status, error, warnings);
2878+
2879+
HandleWalletError(wallet, status, error);
28762880

28772881
UniValue obj(UniValue::VOBJ);
28782882
obj.pushKV("name", wallet->GetName());
@@ -3072,27 +3076,17 @@ static RPCHelpMan restorewallet()
30723076

30733077
std::string backup_file = request.params[1].get_str();
30743078

3075-
if (!fs::exists(backup_file)) {
3076-
throw JSONRPCError(RPC_INVALID_PARAMETER, "Backup file does not exist");
3077-
}
3078-
30793079
std::string wallet_name = request.params[0].get_str();
30803080

3081-
const fs::path wallet_path = fsbridge::AbsPathJoin(GetWalletDir(), wallet_name);
3082-
3083-
if (fs::exists(wallet_path)) {
3084-
throw JSONRPCError(RPC_INVALID_PARAMETER, "Wallet name already exists.");
3085-
}
3086-
3087-
if (!TryCreateDirectories(wallet_path)) {
3088-
throw JSONRPCError(RPC_WALLET_ERROR, strprintf("Failed to create database path '%s'. Database already exists.", wallet_path.string()));
3089-
}
3081+
std::optional<bool> load_on_start = request.params[2].isNull() ? std::nullopt : std::optional<bool>(request.params[2].get_bool());
30903082

3091-
auto wallet_file = wallet_path / "wallet.dat";
3083+
DatabaseStatus status;
3084+
bilingual_str error;
3085+
std::vector<bilingual_str> warnings;
30923086

3093-
fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists);
3087+
const std::shared_ptr<CWallet> wallet = RestoreWallet(*context.chain, *context.m_coinjoin_loader, backup_file, wallet_name, load_on_start, status, error, warnings);
30943088

3095-
auto [wallet, warnings] = LoadWalletHelper(context, request.params[2], wallet_name);
3089+
HandleWalletError(wallet, status, error);
30963090

30973091
UniValue obj(UniValue::VOBJ);
30983092
obj.pushKV("name", wallet->GetName());

src/wallet/wallet.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,38 @@ std::shared_ptr<CWallet> CreateWallet(interfaces::Chain& chain, interfaces::Coin
365365
return wallet;
366366
}
367367

368+
std::shared_ptr<CWallet> RestoreWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& backup_file, const std::string& wallet_name, std::optional<bool> load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings)
369+
{
370+
DatabaseOptions options;
371+
options.require_existing = true;
372+
373+
if (!fs::exists(backup_file)) {
374+
error = Untranslated("Backup file does not exist");
375+
status = DatabaseStatus::FAILED_INVALID_BACKUP_FILE;
376+
return nullptr;
377+
}
378+
379+
const fs::path wallet_path = fsbridge::AbsPathJoin(GetWalletDir(), wallet_name);
380+
381+
if (fs::exists(wallet_path) || !TryCreateDirectories(wallet_path)) {
382+
error = Untranslated(strprintf("Failed to create database path '%s'. Database already exists.", wallet_path.string()));
383+
status = DatabaseStatus::FAILED_ALREADY_EXISTS;
384+
return nullptr;
385+
}
386+
387+
auto wallet_file = wallet_path / "wallet.dat";
388+
fs::copy_file(backup_file, wallet_file, fs::copy_option::fail_if_exists);
389+
390+
auto wallet = LoadWallet(chain, coinjoin_loader, wallet_name, load_on_start, options, status, error, warnings);
391+
392+
if (!wallet) {
393+
fs::remove(wallet_file);
394+
fs::remove(wallet_path);
395+
}
396+
397+
return wallet;
398+
}
399+
368400
/** @defgroup mapWallet
369401
*
370402
* @{

src/wallet/wallet.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ std::vector<std::shared_ptr<CWallet>> GetWallets();
6262
std::shared_ptr<CWallet> GetWallet(const std::string& name);
6363
std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional<bool> load_on_start, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings);
6464
std::shared_ptr<CWallet> CreateWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& name, std::optional<bool> load_on_start, DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings);
65+
std::shared_ptr<CWallet> RestoreWallet(interfaces::Chain& chain, interfaces::CoinJoin::Loader& coinjoin_loader, const std::string& backup_file, const std::string& wallet_name, std::optional<bool> load_on_start, DatabaseStatus& status, bilingual_str& error, std::vector<bilingual_str>& warnings);
6566
std::unique_ptr<interfaces::Handler> HandleLoadWallet(LoadWalletFn load_wallet);
6667
std::unique_ptr<WalletDatabase> MakeWalletDatabase(const std::string& name, const DatabaseOptions& options, DatabaseStatus& status, bilingual_str& error);
6768

test/functional/wallet_backup.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,32 @@ def erase_three(self):
107107
os.remove(os.path.join(self.nodes[1].datadir, self.chain, 'wallets', self.default_wallet_name, self.wallet_data_filename))
108108
os.remove(os.path.join(self.nodes[2].datadir, self.chain, 'wallets', self.default_wallet_name, self.wallet_data_filename))
109109

110+
def restore_invalid_wallet(self):
111+
node = self.nodes[3]
112+
invalid_wallet_file = os.path.join(self.nodes[0].datadir, 'invalid_wallet_file.bak')
113+
open(invalid_wallet_file, 'a', encoding="utf8").write('invald wallet')
114+
wallet_name = "res0"
115+
not_created_wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name)
116+
error_message = "Wallet file verification failed. Failed to load database path '{}'. Data is not in recognized format.".format(not_created_wallet_file)
117+
assert_raises_rpc_error(-18, error_message, node.restorewallet, wallet_name, invalid_wallet_file)
118+
assert not os.path.exists(not_created_wallet_file)
119+
110120
def restore_nonexistent_wallet(self):
111121
node = self.nodes[3]
112122
nonexistent_wallet_file = os.path.join(self.nodes[0].datadir, 'nonexistent_wallet.bak')
113123
wallet_name = "res0"
114124
assert_raises_rpc_error(-8, "Backup file does not exist", node.restorewallet, wallet_name, nonexistent_wallet_file)
125+
not_created_wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name)
126+
assert not os.path.exists(not_created_wallet_file)
115127

116128
def restore_wallet_existent_name(self):
117129
node = self.nodes[3]
118-
wallet_file = os.path.join(self.nodes[0].datadir, 'wallet.bak')
130+
backup_file = os.path.join(self.nodes[0].datadir, 'wallet.bak')
119131
wallet_name = "res0"
120-
assert_raises_rpc_error(-8, "Wallet name already exists.", node.restorewallet, wallet_name, wallet_file)
132+
wallet_file = os.path.join(node.datadir, self.chain, 'wallets', wallet_name)
133+
error_message = "Failed to create database path '{}'. Database already exists.".format(wallet_file)
134+
assert_raises_rpc_error(-36, error_message, node.restorewallet, wallet_name, backup_file)
135+
assert os.path.exists(wallet_file)
121136

122137
def init_three(self):
123138
self.init_wallet(0)
@@ -179,6 +194,7 @@ def run_test(self):
179194
##
180195
self.log.info("Restoring wallets on node 3 using backup files")
181196

197+
self.restore_invalid_wallet()
182198
self.restore_nonexistent_wallet()
183199

184200
backup_file_0 = os.path.join(self.nodes[0].datadir, 'wallet.bak')
@@ -189,6 +205,10 @@ def run_test(self):
189205
self.nodes[3].restorewallet("res1", backup_file_1)
190206
self.nodes[3].restorewallet("res2", backup_file_2)
191207

208+
assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res0"))
209+
assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res1"))
210+
assert os.path.exists(os.path.join(self.nodes[3].datadir, self.chain, 'wallets', "res2"))
211+
192212
res0_rpc = self.nodes[3].get_wallet_rpc("res0")
193213
res1_rpc = self.nodes[3].get_wallet_rpc("res1")
194214
res2_rpc = self.nodes[3].get_wallet_rpc("res2")

0 commit comments

Comments
 (0)