diff --git a/.gemini/config.yaml b/.gemini/config.yaml
index 307f783561..0494774758 100644
--- a/.gemini/config.yaml
+++ b/.gemini/config.yaml
@@ -33,4 +33,4 @@ code_review:
# List of glob patterns to ignore (files and directories).
# Type: array of string, default: [].
-ignore_patterns: []
+ignore_patterns: ["deprecated.go"]
diff --git a/.gemini/styleguide.md b/.gemini/styleguide.md
index f0d52a1e1e..83bf86ec2a 100644
--- a/.gemini/styleguide.md
+++ b/.gemini/styleguide.md
@@ -1,4 +1,4 @@
-# LND Style Guide
+# Btcwallet Style Guide
## Code Documentation and Commenting
@@ -9,7 +9,7 @@
- Unit tests must always use the `require` library. Either table driven unit
tests or tests using the `rapid` library are preferred.
- The line length MUST NOT exceed 80 characters, this is very important.
- You must count the Golang indentation (tabulator character) as 8 spaces when
+ You must count the Golang indentation (tabulator character) as 4 spaces when
determining the line length. Use creative approaches or the wrapping rules
specified below to make sure the line length isn't exceeded.
- Every function must be commented with its purpose and assumptions.
@@ -151,7 +151,7 @@ if amt < 546 {
### 80 character line length
- Wrap columns at 80 characters.
-- Tabs are 8 spaces.
+- Tabs are 4 spaces.
**WRONG**
```go
diff --git a/.gitignore b/.gitignore
index 8dce5946b3..25a01ef540 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,7 @@ coverage.txt
.vscode
.DS_Store
.aider*
+coverage.out
+*.prof
+*.test
+*cpu.out
diff --git a/.golangci.yml b/.golangci.yml
index ccd520f9b1..c457c439d3 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -10,6 +10,9 @@ run:
linters:
default: all
disable:
+ # TODO(yy): Re-enable this linter once the refactoring series is done.
+ - ireturn
+
# Global variables are used in many places throughout the code base.
- gochecknoglobals
@@ -91,6 +94,19 @@ linters:
multi-func: true
multi-if: true
+ wrapcheck:
+ ignore-sig-regexps:
+ # Allow returning .Err() from context.Context without wrapping it.
+ - context\.Context.*\.Err\(\)
+
+ gomoddirectives:
+ replace-local: true
+ replace-allow-list:
+ # This package will be downgrade to internal so we will import it
+ # directly here.
+ - github.com/btcsuite/btcwallet/wtxmgr
+
+
# Defines a set of rules to ignore issues.
# It does not skip the analysis, and so does not ignore "typecheck" errors.
exclusions:
@@ -112,6 +128,7 @@ linters:
paths:
- rpc/legacyrpc/
- wallet/deprecated.go
+ - wallet/deprecated_test.go
rules:
# Exclude gosec from running for tests so that tests with weak randomness
diff --git a/btcwallet.go b/btcwallet.go
index 898ab90e15..c709ac8d5e 100644
--- a/btcwallet.go
+++ b/btcwallet.go
@@ -237,10 +237,16 @@ func rpcClientConnectLoop(legacyRPCServer *legacyrpc.Server, loader *wallet.Load
loadedWallet.SetChainSynced(false)
// TODO: Rework the wallet so changing the RPC client
- // does not require stopping and restarting everything.
- loadedWallet.Stop()
+ //nolint:staticcheck // This should be fixed once
+ // the interface refactor is finished, and new wallet
+ // RPC is built.
+ loadedWallet.StopDeprecated()
loadedWallet.WaitForShutdown()
- loadedWallet.Start()
+
+ //nolint:staticcheck // This should be fixed once
+ // the interface refactor is finished, and new wallet
+ // RPC is built.
+ loadedWallet.StartDeprecated()
}
}
}
diff --git a/chain/bitcoind_client.go b/chain/bitcoind_client.go
index 0c397c8810..da874b6216 100644
--- a/chain/bitcoind_client.go
+++ b/chain/bitcoind_client.go
@@ -4,6 +4,7 @@ import (
"container/list"
"context"
"encoding/hex"
+ "encoding/json"
"errors"
"fmt"
"sync"
@@ -13,7 +14,10 @@ import (
"github.com/btcsuite/btcd/address/v2"
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs/builder"
"github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/txscript/v2"
"github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/waddrmgr"
@@ -26,6 +30,10 @@ var (
// to receive a notification for a specific item and the bitcoind client
// is in the middle of shutting down.
ErrBitcoindClientShuttingDown = errors.New("client is shutting down")
+
+ // ErrOnlyBasicFilters is an error returned when a filter type other
+ // than basic is requested.
+ ErrOnlyBasicFilters = errors.New("only basic filters are supported")
)
// BitcoindClient represents a persistent client connection to a bitcoind server
@@ -50,6 +58,15 @@ type BitcoindClient struct {
// the RPC and ZMQ connections to a bitcoind node.
chainConn *BitcoindConn
+ // batchClient is a secondary RPC client dedicated for batch requests.
+ // This client is created specifically for batch operations because the
+ // rpcclient.Client in batch mode is stateful, accumulating requests
+ // until `Send()` is called. Using a dedicated instance avoids race
+ // conditions and ensures isolation from other concurrent RPC calls
+ // made by the main `chainConn.client` or other `BitcoindClient`
+ // instances.
+ batchClient *rpcclient.Client
+
// bestBlock keeps track of the tip of the current best chain.
bestBlockMtx sync.RWMutex
bestBlock waddrmgr.BlockStamp
@@ -114,6 +131,56 @@ func (c *BitcoindClient) BackEnd() string {
return "bitcoind"
}
+// GetCFilter returns a compact filter for the given block hash and filter
+// type.
+//
+// NOTE: This is part of the chain.Interface interface.
+func (c *BitcoindClient) GetCFilter(hash *chainhash.Hash,
+ filterType wire.FilterType) (*gcs.Filter, error) {
+
+ if filterType != wire.GCSFilterRegular {
+ return nil, ErrOnlyBasicFilters
+ }
+
+ // The getblockfilter RPC takes the block hash and the filter type.
+ // Filter type defaults to "basic" if omitted, but we specify it for
+ // clarity.
+ params := []json.RawMessage{
+ json.RawMessage(fmt.Sprintf("%q", hash.String())),
+ json.RawMessage(fmt.Sprintf("%q", "basic")),
+ }
+
+ resp, err := c.chainConn.client.RawRequest("getblockfilter", params)
+ if err != nil {
+ return nil, c.MapRPCErr(err)
+ }
+
+ var res struct {
+ Filter string `json:"filter"`
+ Header string `json:"header"`
+ }
+
+ err = json.Unmarshal(resp, &res)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal filter: %w", err)
+ }
+
+ filterBytes, err := hex.DecodeString(res.Filter)
+ if err != nil {
+ return nil, fmt.Errorf("failed to decode filter: %w", err)
+ }
+
+ filter, err := gcs.FromNBytes(
+ builder.DefaultP, builder.DefaultM, filterBytes,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("failed to create filter from bytes: %w",
+ err)
+ }
+
+ return filter, nil
+}
+
// GetBestBlock returns the highest block known to bitcoind.
func (c *BitcoindClient) GetBestBlock() (*chainhash.Hash, int32, error) {
bcinfo, err := c.chainConn.client.GetBlockChainInfo()
@@ -573,6 +640,9 @@ func (c *BitcoindClient) Stop() {
// prevent sending notifications to it after it's been stopped.
c.chainConn.RemoveClient(c.id)
+ c.batchClient.Shutdown()
+ c.batchClient.WaitForShutdown()
+
c.notificationQueue.Stop()
}
@@ -1185,8 +1255,13 @@ func (c *BitcoindClient) filterBlock(block *wire.MsgBlock, height int32,
// transaction.
blockDetails.Index = i
txDetails := btcutil.NewTx(tx)
+
+ // We disable individual transaction notifications here because
+ // the full set of relevant transactions will be dispatched
+ // atomically via FilteredBlockConnected at the end of block
+ // processing.
isRelevant, rec, err := c.filterTx(
- txDetails, blockDetails, notify,
+ txDetails, blockDetails, false,
)
if err != nil {
log.Warnf("Unable to filter transaction %v: %v",
@@ -1349,8 +1424,9 @@ func (c *BitcoindClient) filterTx(txDetails *btcutil.Tx,
c.mempool[*txDetails.Hash()] = struct{}{}
}
- c.onRelevantTx(rec, blockDetails)
-
+ if notify {
+ c.onRelevantTx(rec, blockDetails)
+ }
return true, rec, nil
}
@@ -1416,3 +1492,157 @@ func (c *BitcoindClient) updateWatchedFilters(update any) {
}
}
}
+
+// GetBlockHashes returns a slice of block hashes for the given height range.
+func (c *BitcoindClient) GetBlockHashes(startHeight,
+ endHeight int64) ([]chainhash.Hash, error) {
+
+ if startHeight > endHeight {
+ return nil, fmt.Errorf("%w: start height %d, end height %d",
+ ErrInvalidParam, startHeight, endHeight)
+ }
+
+ client := c.batchClient
+ count := endHeight - startHeight + 1
+ hashes := make([]chainhash.Hash, 0, count)
+ futures := make([]rpcclient.FutureGetBlockHashResult, 0, count)
+
+ for h := startHeight; h <= endHeight; h++ {
+ futures = append(futures, client.GetBlockHashAsync(h))
+ }
+
+ err := client.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ hash, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive block hash: %w", err)
+ }
+
+ hashes = append(hashes, *hash)
+ }
+
+ return hashes, nil
+}
+
+// GetCFilters returns a slice of filters for the given block hashes.
+func (c *BitcoindClient) GetCFilters(hashes []chainhash.Hash,
+ filterType wire.FilterType) ([]*gcs.Filter, error) {
+
+ if filterType != wire.GCSFilterRegular {
+ return nil, ErrOnlyBasicFilters
+ }
+
+ client := c.batchClient
+ filters := make([]*gcs.Filter, 0, len(hashes))
+ futures := make([]rpcclient.FutureRawResult, 0, len(hashes))
+
+ for _, hash := range hashes {
+ params := []json.RawMessage{
+ json.RawMessage(fmt.Sprintf("%q", hash.String())),
+ json.RawMessage(fmt.Sprintf("%q", "basic")),
+ }
+ futures = append(futures, client.RawRequestAsync(
+ "getblockfilter", params,
+ ))
+ }
+
+ err := client.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ resp, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive cfilter: %w", err)
+ }
+
+ var res struct {
+ Filter string `json:"filter"`
+ Header string `json:"header"`
+ }
+
+ err = json.Unmarshal(resp, &res)
+ if err != nil {
+ return nil, fmt.Errorf("unmarshal cfilter: %w", err)
+ }
+
+ filterBytes, err := hex.DecodeString(res.Filter)
+ if err != nil {
+ return nil, fmt.Errorf("decode cfilter: %w", err)
+ }
+
+ filter, err := gcs.FromNBytes(
+ builder.DefaultP, builder.DefaultM, filterBytes,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("parse cfilter: %w", err)
+ }
+
+ filters = append(filters, filter)
+ }
+
+ return filters, nil
+}
+
+// GetBlocks returns a slice of full blocks for the given block hashes.
+func (c *BitcoindClient) GetBlocks(hashes []chainhash.Hash) (
+ []*wire.MsgBlock, error) {
+
+ client := c.batchClient
+ blocks := make([]*wire.MsgBlock, 0, len(hashes))
+ futures := make([]rpcclient.FutureGetBlockResult, 0, len(hashes))
+
+ for _, hash := range hashes {
+ futures = append(futures, client.GetBlockAsync(&hash))
+ }
+
+ err := client.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ block, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive block: %w", err)
+ }
+
+ blocks = append(blocks, block)
+ }
+
+ return blocks, nil
+}
+
+// GetBlockHeaders returns a slice of block headers for the given block hashes.
+func (c *BitcoindClient) GetBlockHeaders(hashes []chainhash.Hash) (
+ []*wire.BlockHeader, error) {
+
+ client := c.batchClient
+ headers := make([]*wire.BlockHeader, 0, len(hashes))
+ futures := make([]rpcclient.FutureGetBlockHeaderResult, 0, len(hashes))
+
+ for _, hash := range hashes {
+ futures = append(futures, client.GetBlockHeaderAsync(&hash))
+ }
+
+ err := client.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ header, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive header: %w", err)
+ }
+
+ headers = append(headers, header)
+ }
+
+ return headers, nil
+}
diff --git a/chain/bitcoind_conn.go b/chain/bitcoind_conn.go
index 85b85b5e98..67d84a374e 100644
--- a/chain/bitcoind_conn.go
+++ b/chain/bitcoind_conn.go
@@ -401,13 +401,29 @@ func getCurrentNet(client *rpcclient.Client) (wire.BitcoinNet, error) {
// NewBitcoindClient returns a bitcoind client using the current bitcoind
// connection. This allows us to share the same connection using multiple
// clients.
-func (c *BitcoindConn) NewBitcoindClient() *BitcoindClient {
+func (c *BitcoindConn) NewBitcoindClient() (*BitcoindClient, error) {
+ clientCfg := &rpcclient.ConnConfig{
+ Host: c.cfg.Host,
+ User: c.cfg.User,
+ Pass: c.cfg.Pass,
+ DisableAutoReconnect: false,
+ DisableConnectOnNew: true,
+ DisableTLS: true,
+ HTTPPostMode: true,
+ }
+
+ batchClient, err := rpcclient.NewBatch(clientCfg)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create batch client: %w", err)
+ }
+
return &BitcoindClient{
quit: make(chan struct{}),
id: atomic.AddUint64(&c.rescanClientCounter, 1),
- chainConn: c,
+ chainConn: c,
+ batchClient: batchClient,
watchedAddresses: make(map[string]struct{}),
watchedOutPoints: make(map[wire.OutPoint]struct{}),
@@ -419,7 +435,7 @@ func (c *BitcoindConn) NewBitcoindClient() *BitcoindClient {
mempool: make(map[chainhash.Hash]struct{}),
expiredMempool: make(map[int32]map[chainhash.Hash]struct{}),
- }
+ }, nil
}
// AddClient adds a client to the set of active rescan clients of the current
diff --git a/chain/bitcoind_events_test.go b/chain/bitcoind_events_test.go
index d7aafcfdd8..fe174c212b 100644
--- a/chain/bitcoind_events_test.go
+++ b/chain/bitcoind_events_test.go
@@ -2,73 +2,110 @@ package chain
import (
"fmt"
- "math/rand"
"os/exec"
"testing"
"time"
"github.com/btcsuite/btcd/address/v2"
"github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs/builder"
"github.com/btcsuite/btcd/chaincfg/v2"
"github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/integration/rpctest"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/txscript/v2"
"github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain/port"
"github.com/stretchr/testify/require"
)
-// TestBitcoindEvents ensures that the BitcoindClient correctly delivers tx and
-// block notifications for both the case where a ZMQ subscription is used and
-// for the case where RPC polling is used.
-func TestBitcoindEvents(t *testing.T) {
+const (
+ // defaultTestTimeout is the default timeout used for tests in this
+ // file. It is set to 30 seconds to allow for slow test environments.
+ defaultTestTimeout = 30 * time.Second
+)
+
+// TestBitcoindEventsZMQ runs all bitcoind event tests using ZMQ subscriptions.
+//
+// We cannot run these tests in parallel as it involves running multiple
+// bitcoind servers and btcd servers in the background. While running multiple
+// bitcoind servers is fine, the current integration test setup in `btcd`
+// doesn't allow it as the created RPC client will share the same ports.
+//
+//nolint:paralleltest
+func TestBitcoindEventsZMQ(t *testing.T) {
+ runBitcoindEventsTests(t, false)
+}
+
+// TestBitcoindEventsRPC runs all bitcoind event tests using RPC polling.
+//
+// We cannot run these tests in parallel as it involves running multiple
+// bitcoind servers and btcd servers in the background. While running multiple
+// bitcoind servers is fine, the current integration test setup in `btcd`
+// doesn't allow it as the created RPC client will share the same ports.
+//
+//nolint:paralleltest
+func TestBitcoindEventsRPC(t *testing.T) {
+ runBitcoindEventsTests(t, true)
+}
+
+// runBitcoindEventsTests runs the suite of bitcoind event tests with the
+// specified polling mode.
+func runBitcoindEventsTests(t *testing.T, rpcPolling bool) {
+ t.Helper()
+
tests := []struct {
- name string
- rpcPolling bool
+ name string
+ testFn func(*testing.T, *rpctest.Harness, *BitcoindClient)
}{
{
- name: "Events via ZMQ subscriptions",
- rpcPolling: false,
+ name: "Reorg",
+ testFn: testReorg,
+ },
+ {
+ name: "NotifyBlocks",
+ testFn: testNotifyBlocks,
+ },
+ {
+ name: "NotifyTx",
+ testFn: testNotifyTx,
+ },
+ {
+ name: "NotifySpentMempool",
+ testFn: testNotifySpentMempool,
},
{
- name: "Events via RPC Polling",
- rpcPolling: true,
+ name: "LookupInputMempoolSpend",
+ testFn: testLookupInputMempoolSpend,
+ },
+ {
+ name: "GetCFilter",
+ testFn: testBitcoindClientGetCFilter,
+ },
+ {
+ name: "Batch RPCs",
+ testFn: func(t *testing.T, h *rpctest.Harness,
+ bc *BitcoindClient) {
+
+ t.Helper()
+ testInterfaceBatchMethods(t, h, bc)
+ },
},
}
for _, test := range tests {
test := test
+ t.Run(test.name, func(t *testing.T) {
+ // Initialize a fresh miner for the test case.
+ miner1 := setupMiner(t)
+ addr := miner1.P2PAddress()
- // Set up 2 btcd miners.
- miner1, miner2 := setupMiners(t)
- addr := miner1.P2PAddress()
+ // Initialize a fresh bitcoind client for EVERY test
+ // case.
+ btcClient := setupBitcoind(t, addr, rpcPolling)
- t.Run(test.name, func(t *testing.T) {
- // Set up a bitcoind node and connect it to miner 1.
- btcClient := setupBitcoind(t, addr, test.rpcPolling)
-
- // Test that the correct block `Connect` and
- // `Disconnect` notifications are received during a
- // re-org.
- testReorg(t, miner1, miner2, btcClient)
-
- // Test that the expected block notifications are
- // received.
- btcClient = setupBitcoind(t, addr, test.rpcPolling)
- testNotifyBlocks(t, miner1, btcClient)
-
- // Test that the expected tx notifications are
- // received.
- btcClient = setupBitcoind(t, addr, test.rpcPolling)
- testNotifyTx(t, miner1, btcClient)
-
- // Test notifications for inputs already found in
- // mempool.
- btcClient = setupBitcoind(t, addr, test.rpcPolling)
- testNotifySpentMempool(t, miner1, btcClient)
-
- // Test looking up mempool for input spent.
- testLookupInputMempoolSpend(t, miner1, btcClient)
+ test.testFn(t, miner1, btcClient)
})
}
}
@@ -91,31 +128,21 @@ func testNotifyTx(t *testing.T, miner *rpctest.Harness, client *BitcoindClient)
err = client.NotifyTx([]chainhash.Hash{hash})
require.NoError(err)
- _, err = client.SendRawTransaction(tx, true)
- require.NoError(err)
+ // Send the transaction. This might fail if the bitcoind node hasn't
+ // synced the inputs yet, so we'll retry until it succeeds.
+ require.Eventually(func() bool {
+ _, err = client.SendRawTransaction(tx, true)
+ return err == nil
+ }, defaultTestTimeout, 100*time.Millisecond,
+ "SendRawTransaction failed")
ntfns := client.Notifications()
// We expect to get a ClientConnected notification.
- select {
- case ntfn := <-ntfns:
- _, ok := ntfn.(ClientConnected)
- require.Truef(ok, "Expected type ClientConnected, got %T", ntfn)
-
- case <-time.After(time.Second):
- require.Fail("timed out for ClientConnected notification")
- }
+ waitForClientConnected(t, ntfns)
// We expect to get a RelevantTx notification.
- select {
- case ntfn := <-ntfns:
- tx, ok := ntfn.(RelevantTx)
- require.Truef(ok, "Expected type RelevantTx, got %T", ntfn)
- require.True(tx.TxRecord.Hash.IsEqual(&hash))
-
- case <-time.After(time.Second):
- require.Fail("timed out waiting for RelevantTx notification")
- }
+ waitForRelevantTx(t, ntfns, &hash)
}
// testNotifyBlocks tests that the correct notifications are received for
@@ -134,14 +161,7 @@ func testNotifyBlocks(t *testing.T, miner *rpctest.Harness,
miner.Client.Generate(1)
// We expect to get a ClientConnected notification.
- select {
- case ntfn := <-ntfns:
- _, ok := ntfn.(ClientConnected)
- require.Truef(ok, "Expected type ClientConnected, got %T", ntfn)
-
- case <-time.After(time.Second):
- require.Fail("timed out for ClientConnected notification")
- }
+ waitForClientConnected(t, ntfns)
// We expect to get a FilteredBlockConnected notification.
select {
@@ -150,7 +170,7 @@ func testNotifyBlocks(t *testing.T, miner *rpctest.Harness,
require.Truef(ok, "Expected type FilteredBlockConnected, "+
"got %T", ntfn)
- case <-time.After(time.Second):
+ case <-time.After(defaultTestTimeout):
require.Fail("timed out for FilteredBlockConnected " +
"notification")
}
@@ -161,7 +181,7 @@ func testNotifyBlocks(t *testing.T, miner *rpctest.Harness,
_, ok := ntfn.(BlockConnected)
require.Truef(ok, "Expected type BlockConnected, got %T", ntfn)
- case <-time.After(time.Second):
+ case <-time.After(defaultTestTimeout):
require.Fail("timed out for BlockConnected notification")
}
}
@@ -195,25 +215,10 @@ func testNotifySpentMempool(t *testing.T, miner *rpctest.Harness,
ntfns := client.Notifications()
// We expect to get a ClientConnected notification.
- select {
- case ntfn := <-ntfns:
- _, ok := ntfn.(ClientConnected)
- require.Truef(ok, "Expected type ClientConnected, got %T", ntfn)
-
- case <-time.After(time.Second):
- require.Fail("timed out for ClientConnected notification")
- }
+ waitForClientConnected(t, ntfns)
// We expect to get a RelevantTx notification.
- select {
- case ntfn := <-ntfns:
- tx, ok := ntfn.(RelevantTx)
- require.Truef(ok, "Expected type RelevantTx, got %T", ntfn)
- require.True(tx.TxRecord.Hash.IsEqual(&txid))
-
- case <-time.After(time.Second):
- require.Fail("timed out waiting for RelevantTx notification")
- }
+ waitForRelevantTx(t, ntfns, &txid)
}
// testLookupInputMempoolSpend tests that LookupInputMempoolSpend returns the
@@ -250,7 +255,7 @@ func testLookupInputMempoolSpend(t *testing.T, miner *rpctest.Harness,
rt.Eventually(func() bool {
txid, found = client.LookupInputMempoolSpend(op)
return found
- }, 5*time.Second, 100*time.Millisecond)
+ }, defaultTestTimeout, 100*time.Millisecond)
// Check the expected txid is returned.
rt.Equal(tx.TxHash(), txid)
@@ -258,8 +263,10 @@ func testLookupInputMempoolSpend(t *testing.T, miner *rpctest.Harness,
// testReorg tests that the given BitcoindClient correctly responds to a chain
// re-org.
-func testReorg(t *testing.T, miner1, miner2 *rpctest.Harness,
- client *BitcoindClient) {
+func testReorg(t *testing.T, miner1 *rpctest.Harness, client *BitcoindClient) {
+ t.Helper()
+
+ miner2 := setupReorgMiner(t, miner1)
require := require.New(t)
@@ -281,7 +288,7 @@ func testReorg(t *testing.T, miner1, miner2 *rpctest.Harness,
_, ok := ntfn.(ClientConnected)
require.Truef(ok, "Expected type ClientConnected, got %T", ntfn)
- case <-time.After(time.Second):
+ case <-time.After(defaultTestTimeout):
require.Fail("timed out for ClientConnected notification")
}
@@ -370,7 +377,7 @@ func testReorg(t *testing.T, miner1, miner2 *rpctest.Harness,
func waitForBlockNtfn(t *testing.T, ntfns <-chan interface{},
expectedHeight int32, connected bool) chainhash.Hash {
- timer := time.NewTimer(2 * time.Second)
+ timer := time.NewTimer(defaultTestTimeout)
for {
select {
case nftn := <-ntfns:
@@ -414,21 +421,47 @@ func waitForBlockNtfn(t *testing.T, ntfns <-chan interface{},
}
}
-// setUpMiners sets up two miners that can be used for a re-org test.
-func setupMiners(t *testing.T) (*rpctest.Harness, *rpctest.Harness) {
- trickle := fmt.Sprintf("--trickleinterval=%v", 10*time.Millisecond)
- args := []string{trickle}
+// setUpMiner sets up a single miner.
+func setupMiner(t *testing.T) *rpctest.Harness {
+ t.Helper()
+
+ args := []string{
+ fmt.Sprintf("--trickleinterval=%v", 10*time.Millisecond),
+ // TODO(yy): We should uncomment the following to allow setting
+ // up ports here in the test. However, this cannot work without
+ // modifying the rpcclient in the `btcd` first, as the ports
+ // are overwritten there.
+ //
+ // fmt.Sprintf("--listen=%v", port.NextAvailablePort()),
+ // fmt.Sprintf("--rpclisten=%v", port.NextAvailablePort()),
+ }
- miner1, err := rpctest.New(
- &chaincfg.RegressionNetParams, nil, args, "",
- )
+ miner, err := rpctest.New(&chaincfg.RegressionNetParams, nil, args, "")
require.NoError(t, err)
t.Cleanup(func() {
- miner1.TearDown()
+ require.NoError(t, miner.TearDown())
})
- require.NoError(t, miner1.SetUp(true, 1))
+ require.NoError(t, miner.SetUp(true, 101))
+
+ return miner
+}
+
+// setupReorgMiner sets up a second miner that can be used for a re-org test.
+func setupReorgMiner(t *testing.T, miner1 *rpctest.Harness) *rpctest.Harness {
+ t.Helper()
+
+ args := []string{
+ fmt.Sprintf("--trickleinterval=%v", 10*time.Millisecond),
+ // TODO(yy): We should uncomment the following to allow setting
+ // up ports here in the test. However, this cannot work without
+ // modifying the rpcclient in the `btcd` first, as the ports
+ // are overwritten there.
+ //
+ // fmt.Sprintf("--listen=%v", port.NextAvailablePort()),
+ // fmt.Sprintf("--rpclisten=%v", port.NextAvailablePort()),
+ }
miner2, err := rpctest.New(
&chaincfg.RegressionNetParams, nil, args, "",
@@ -449,7 +482,7 @@ func setupMiners(t *testing.T) (*rpctest.Harness, *rpctest.Harness) {
)
require.NoError(t, err)
- return miner1, miner2
+ return miner2
}
// setupBitcoind starts up a bitcoind node with either a zmq connection or
@@ -460,10 +493,14 @@ func setupBitcoind(t *testing.T, minerAddr string,
// Start a bitcoind instance and connect it to miner1.
tempBitcoindDir := t.TempDir()
- zmqBlockHost := "ipc:///" + tempBitcoindDir + "/blocks.socket"
- zmqTxHost := "ipc:///" + tempBitcoindDir + "/tx.socket"
+ zmqBlockPort := port.NextAvailablePort()
+ zmqTxPort := port.NextAvailablePort()
+
+ zmqBlockHost := fmt.Sprintf("tcp://127.0.0.1:%d", zmqBlockPort)
+ zmqTxHost := fmt.Sprintf("tcp://127.0.0.1:%d", zmqTxPort)
- rpcPort := rand.Int()%(65536-1024) + 1024
+ rpcPort := port.NextAvailablePort()
+ p2pPort := port.NextAvailablePort()
bitcoind := exec.Command(
"bitcoind",
"-datadir="+tempBitcoindDir,
@@ -474,9 +511,11 @@ func setupBitcoind(t *testing.T, minerAddr string,
"d$507c670e800a95284294edb5773b05544b"+
"220110063096c221be9933c82d38e1",
fmt.Sprintf("-rpcport=%d", rpcPort),
+ fmt.Sprintf("-port=%d", p2pPort),
"-disablewallet",
"-zmqpubrawblock="+zmqBlockHost,
"-zmqpubrawtx="+zmqTxHost,
+ "-blockfilterindex=1",
)
require.NoError(t, bitcoind.Start())
@@ -523,13 +562,20 @@ func setupBitcoind(t *testing.T, minerAddr string,
})
// Create a bitcoind client.
- btcClient := chainConn.NewBitcoindClient()
+ btcClient, err := chainConn.NewBitcoindClient()
+ require.NoError(t, err)
require.NoError(t, btcClient.Start(t.Context()))
t.Cleanup(func() {
btcClient.Stop()
})
+ // Wait for bitcoind to sync with the miner.
+ require.Eventually(t, func() bool {
+ _, height, err := btcClient.GetBestBlock()
+ return err == nil && height >= 101
+ }, defaultTestTimeout, 100*time.Millisecond)
+
return btcClient
}
@@ -557,3 +603,106 @@ func randPubKeyHashScript() ([]byte, *btcec.PrivateKey, error) {
return pkScript, privKey, nil
}
+
+// testBitcoindClientGetCFilter verifies the BitcoindClient's GetCFilter
+// implementation by interacting with a live bitcoind node.
+func testBitcoindClientGetCFilter(t *testing.T, miner *rpctest.Harness,
+ client *BitcoindClient) {
+
+ t.Helper()
+
+ require := require.New(t)
+
+ // Generate a block to have something to query a filter for.
+ hashes, err := miner.Client.Generate(1)
+ require.NoError(err)
+
+ blockHash := hashes[0]
+
+ // Get the CFilter using the BitcoindClient. This might take a few
+ // attempts as the filter index might not be immediately available.
+ var gcsFilter *gcs.Filter
+ require.Eventually(func() bool {
+ gcsFilter, err = client.GetCFilter(
+ blockHash, wire.GCSFilterRegular,
+ )
+
+ return err == nil
+ }, defaultTestTimeout, 100*time.Millisecond,
+ "GetCFilter should succeed")
+ require.NotNil(gcsFilter, "GCS filter should not be nil")
+ require.IsType(&gcs.Filter{}, gcsFilter)
+
+ // Verify the filter matches the block data.
+ block, err := client.GetBlock(blockHash)
+ require.NoError(err)
+
+ // Use the first transaction's first output script.
+ script := block.Transactions[0].TxOut[0].PkScript
+
+ // Derive the filter key.
+ key := builder.DeriveKey(blockHash)
+
+ // Check match.
+ matched, err := gcsFilter.Match(key, script)
+ require.NoError(err)
+ require.True(matched, "Filter should match script from block")
+
+ // Test with an unsupported filter type.
+ _, err = client.GetCFilter(blockHash, wire.FilterType(99))
+ require.ErrorContains(err, "only basic filters are supported",
+ "Unsupported filter type should return an error")
+
+ // Test GetCFilter for a non-existent block.
+ dummyHash := &chainhash.Hash{0x01, 0x02, 0x03}
+ _, err = client.GetCFilter(dummyHash, wire.GCSFilterRegular)
+ require.ErrorContains(err, "Block not found",
+ "Non-existent block should return an error")
+}
+
+// waitForClientConnected waits for a ClientConnected notification on the passed
+// channel. Any other notifications received while waiting are ignored.
+func waitForClientConnected(t *testing.T, ntfns <-chan any) {
+ t.Helper()
+
+ timer := time.NewTimer(defaultTestTimeout)
+ defer timer.Stop()
+
+ for {
+ select {
+ case ntfn := <-ntfns:
+ if _, ok := ntfn.(ClientConnected); ok {
+ return
+ }
+
+ case <-timer.C:
+ require.FailNow(t, "timed out for ClientConnected "+
+ "notification")
+ }
+ }
+}
+
+// waitForRelevantTx waits for a RelevantTx notification for the passed tx
+// hash on the passed channel. Any other notifications received while waiting
+// are ignored.
+func waitForRelevantTx(t *testing.T, ntfns <-chan any, hash *chainhash.Hash) {
+ t.Helper()
+
+ timer := time.NewTimer(defaultTestTimeout)
+ defer timer.Stop()
+
+ for {
+ select {
+ case ntfn := <-ntfns:
+ if tx, ok := ntfn.(RelevantTx); ok {
+ if tx.TxRecord.Hash.IsEqual(hash) {
+ return
+ }
+ }
+
+ case <-timer.C:
+ require.FailNow(t, "timed out waiting for RelevantTx "+
+ "notification")
+ }
+ }
+}
diff --git a/chain/btcd.go b/chain/btcd.go
index f266dee29b..f940af149d 100644
--- a/chain/btcd.go
+++ b/chain/btcd.go
@@ -40,6 +40,8 @@ type RPCClient struct {
wg sync.WaitGroup
started bool
quitMtx sync.Mutex
+
+ batchClient *rpcclient.Client
}
// A compile-time check to ensure that RPCClient satisfies the chain.Interface
@@ -92,7 +94,25 @@ func NewRPCClient(chainParams *chaincfg.Params, connect, user, pass string, cert
if err != nil {
return nil, err
}
+
+ batchConfig := *client.connConfig
+
+ // The batch client is exclusively used for batch RPC calls, which
+ // require HTTP POST mode. Therefore, we explicitly set HTTPPostMode to
+ // true and clear the Endpoint field to ensure the batch client is
+ // correctly configured, regardless of the main client's WebSocket (ws)
+ // or HTTP POST configuration.
+ batchConfig.HTTPPostMode = true
+ batchConfig.Endpoint = ""
+
+ batchClient, err := rpcclient.NewBatch(&batchConfig)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create batch client: %w", err)
+ }
+
+ client.batchClient = batchClient
client.Client = rpcClient
+
return client, nil
}
@@ -195,7 +215,24 @@ func NewRPCClientWithConfig(cfg *RPCClientConfig) (*RPCClient, error) {
return nil, err
}
+ batchConfig := *cfg.Conn
+
+ // The batch client is exclusively used for batch RPC calls, which
+ // require HTTP POST mode. Therefore, we explicitly set HTTPPostMode to
+ // true and clear the Endpoint field to ensure the batch client is
+ // correctly configured, regardless of the main client's WebSocket (ws)
+ // or HTTP POST configuration.
+ batchConfig.HTTPPostMode = true
+ batchConfig.Endpoint = ""
+
+ batchClient, err := rpcclient.NewBatch(&batchConfig)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create batch client: %w", err)
+ }
+
+ client.batchClient = batchClient
client.Client = rpcClient
+
return client, nil
}
@@ -258,6 +295,8 @@ func (c *RPCClient) Stop() {
close(c.quit)
c.Client.Shutdown()
c.Client.WaitForShutdown()
+ c.batchClient.Shutdown()
+ c.batchClient.WaitForShutdown()
if !c.started {
close(c.dequeueNotification)
@@ -297,6 +336,28 @@ func (c *RPCClient) Rescan(startHash *chainhash.Hash, addrs []address.Address,
return c.Client.Rescan(startHash, addrs, flatOutpoints) // nolint:staticcheck
}
+// GetCFilter returns a compact filter for the given block hash and filter
+// type. It wraps the underlying rpcclient method and converts the result to a
+// *gcs.Filter.
+func (c *RPCClient) GetCFilter(hash *chainhash.Hash,
+ filterType wire.FilterType) (*gcs.Filter, error) {
+
+ rawFilter, err := c.Client.GetCFilter(hash, filterType)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get filter: %w", err)
+ }
+
+ filter, err := gcs.FromNBytes(
+ builder.DefaultP, builder.DefaultM, rawFilter.Data,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("failed to create filter from bytes: %w",
+ err)
+ }
+
+ return filter, nil
+}
+
// WaitForShutdown blocks until both the client has finished disconnecting
// and all handlers have exited.
func (c *RPCClient) WaitForShutdown() {
@@ -347,7 +408,9 @@ func (c *RPCClient) FilterBlocks(
// the filter returns a positive match, the full block is then requested
// and scanned for addresses using the block filterer.
for i, blk := range req.Blocks {
- rawFilter, err := c.GetCFilter(&blk.Hash, wire.GCSFilterRegular)
+ rawFilter, err := c.Client.GetCFilter(
+ &blk.Hash, wire.GCSFilterRegular,
+ )
if err != nil {
return nil, err
}
@@ -654,3 +717,133 @@ func (c *RPCClient) SendRawTransaction(tx *wire.MsgTx,
return txid, nil
}
+
+// GetBlockHashes returns a slice of block hashes for the given height range.
+func (c *RPCClient) GetBlockHashes(startHeight,
+ endHeight int64) ([]chainhash.Hash, error) {
+
+ if startHeight > endHeight {
+ return nil, fmt.Errorf("%w: start height %d, end height %d",
+ ErrInvalidParam, startHeight, endHeight)
+ }
+
+ count := endHeight - startHeight + 1
+ hashes := make([]chainhash.Hash, 0, count)
+ futures := make([]rpcclient.FutureGetBlockHashResult, 0, count)
+
+ for h := startHeight; h <= endHeight; h++ {
+ futures = append(futures, c.batchClient.GetBlockHashAsync(h))
+ }
+
+ err := c.batchClient.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ hash, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive block hash: %w", err)
+ }
+
+ hashes = append(hashes, *hash)
+ }
+
+ return hashes, nil
+}
+
+// GetCFilters returns a slice of filters for the given block hashes.
+func (c *RPCClient) GetCFilters(hashes []chainhash.Hash,
+ filterType wire.FilterType) ([]*gcs.Filter, error) {
+
+ filters := make([]*gcs.Filter, 0, len(hashes))
+ futures := make([]rpcclient.FutureGetCFilterResult, 0, len(hashes))
+
+ for _, hash := range hashes {
+ futures = append(
+ futures,
+ c.batchClient.GetCFilterAsync(&hash, filterType),
+ )
+ }
+
+ err := c.batchClient.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ msgFilter, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive cfilter: %w", err)
+ }
+
+ filter, err := gcs.FromNBytes(
+ builder.DefaultP, builder.DefaultM, msgFilter.Data,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("parse cfilter: %w", err)
+ }
+
+ filters = append(filters, filter)
+ }
+
+ return filters, nil
+}
+
+// GetBlocks returns a slice of full blocks for the given block hashes.
+func (c *RPCClient) GetBlocks(hashes []chainhash.Hash) (
+ []*wire.MsgBlock, error) {
+
+ blocks := make([]*wire.MsgBlock, 0, len(hashes))
+ futures := make([]rpcclient.FutureGetBlockResult, 0, len(hashes))
+
+ for _, hash := range hashes {
+ futures = append(futures, c.batchClient.GetBlockAsync(&hash))
+ }
+
+ err := c.batchClient.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ block, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive block: %w", err)
+ }
+
+ blocks = append(blocks, block)
+ }
+
+ return blocks, nil
+}
+
+// GetBlockHeaders returns a slice of block headers for the given block hashes.
+func (c *RPCClient) GetBlockHeaders(hashes []chainhash.Hash) (
+ []*wire.BlockHeader, error) {
+
+ headers := make([]*wire.BlockHeader, 0, len(hashes))
+ futures := make([]rpcclient.FutureGetBlockHeaderResult, 0, len(hashes))
+
+ for _, hash := range hashes {
+ futures = append(
+ futures, c.batchClient.GetBlockHeaderAsync(&hash),
+ )
+ }
+
+ err := c.batchClient.Send()
+ if err != nil {
+ return nil, fmt.Errorf("batch send: %w", err)
+ }
+
+ for _, f := range futures {
+ header, err := f.Receive()
+ if err != nil {
+ return nil, fmt.Errorf("receive header: %w", err)
+ }
+
+ headers = append(headers, header)
+ }
+
+ return headers, nil
+}
diff --git a/chain/btcd_test.go b/chain/btcd_test.go
index 9e75daff0a..2cdf6c633d 100644
--- a/chain/btcd_test.go
+++ b/chain/btcd_test.go
@@ -1,13 +1,64 @@
package chain
import (
+ "fmt"
"testing"
+ "time"
"github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/integration/rpctest"
"github.com/btcsuite/btcd/rpcclient"
+ "github.com/btcsuite/btcd/wire/v2"
"github.com/stretchr/testify/require"
)
+// setupBtcd starts up a btcd node with cfilters enabled and returns a client
+// wrapper of this connection.
+func setupBtcd(t *testing.T) (*rpctest.Harness, *RPCClient) {
+ t.Helper()
+
+ trickle := fmt.Sprintf("--trickleinterval=%v", 10*time.Millisecond)
+ args := []string{trickle}
+
+ miner, err := rpctest.New(
+ &chaincfg.RegressionNetParams, nil, args, "",
+ )
+ require.NoError(t, err)
+
+ require.NoError(t, miner.SetUp(true, 1))
+
+ t.Cleanup(func() {
+ require.NoError(t, miner.TearDown())
+ })
+
+ rpcConf := miner.RPCConfig()
+ client, err := NewRPCClientWithConfig(&RPCClientConfig{
+ ReconnectAttempts: 1,
+ Chain: &chaincfg.RegressionNetParams,
+ Conn: &rpcclient.ConnConfig{
+ Host: rpcConf.Host,
+ User: rpcConf.User,
+ Pass: rpcConf.Pass,
+ Certificates: rpcConf.Certificates,
+ DisableTLS: false,
+ DisableAutoReconnect: false,
+ DisableConnectOnNew: true,
+ HTTPPostMode: false,
+ Endpoint: "ws",
+ },
+ })
+ require.NoError(t, err)
+
+ err = client.Start(t.Context())
+ require.NoError(t, err)
+
+ t.Cleanup(func() {
+ client.Stop()
+ })
+
+ return miner, client
+}
+
// TestValidateConfig checks the `validate` method on the RPCClientConfig
// behaves as expected.
func TestValidateConfig(t *testing.T) {
@@ -56,3 +107,90 @@ func TestValidateConfig(t *testing.T) {
_, err := NewRPCClientWithConfig(nil)
rt.ErrorContains(err, "missing rpc config")
}
+
+// testInterfaceBatchMethods verifies the batch fetching methods implementation
+// for a given chain.Interface client.
+func testInterfaceBatchMethods(t *testing.T, miner *rpctest.Harness,
+ client Interface) {
+
+ t.Helper()
+
+ require := require.New(t)
+
+ // Generate blocks to have a chain to query.
+ const numBlocks = 5
+
+ _, err := miner.Client.Generate(numBlocks)
+ require.NoError(err)
+
+ // Test GetBlockHashes.
+ // Query from height 1 to 3.
+ startHeight := int64(1)
+ endHeight := int64(3)
+ hashes, err := client.GetBlockHashes(startHeight, endHeight)
+ require.NoError(err, "GetBlockHashes failed")
+ require.Len(hashes, 3)
+
+ // Verify hashes match miner.
+ for i, hash := range hashes {
+ minerHash, err := miner.Client.GetBlockHash(int64(i) + 1)
+ require.NoError(err)
+ require.Equal(*minerHash, hash)
+ }
+
+ // Test GetBlocks.
+ blocks, err := client.GetBlocks(hashes)
+ require.NoError(err, "GetBlocks failed")
+ require.Len(blocks, 3)
+
+ for i, block := range blocks {
+ require.Equal(hashes[i], block.BlockHash())
+ }
+
+ // Test GetBlockHeaders.
+ headers, err := client.GetBlockHeaders(hashes)
+ require.NoError(err, "GetBlockHeaders failed")
+ require.Len(headers, 3)
+
+ for i, header := range headers {
+ require.Equal(hashes[i], header.BlockHash())
+ }
+
+ // Test GetCFilters.
+ // Note: bitcoind needs -blockfilterindex=1 for this to work, which is
+ // set in setupBitcoind.
+ // We use Eventually because filter indexing is asynchronous.
+ require.Eventually(func() bool {
+ filters, err := client.GetCFilters(
+ hashes, wire.GCSFilterRegular,
+ )
+ if err != nil {
+ return false
+ }
+
+ if len(filters) != 3 {
+ return false
+ }
+ // Verify filters are not empty/nil.
+ for _, f := range filters {
+ if f == nil || f.N() == 0 {
+ return false
+ }
+ }
+
+ return true
+ }, defaultTestTimeout, 100*time.Millisecond,
+ "GetCFilters failed or timed out")
+}
+
+// TestRPCClientBatchMethods verifies the RPCClient's batch fetching methods
+// implementation against a live btcd node.
+func TestRPCClientBatchMethods(t *testing.T) {
+ t.Parallel()
+
+ // Set up a miner (btcd node) and client.
+ miner, client := setupBtcd(t)
+
+ // Run batch method tests.
+ testInterfaceBatchMethods(t, miner, client)
+}
diff --git a/chain/interface.go b/chain/interface.go
index e0f29e78f9..88d95eeadd 100644
--- a/chain/interface.go
+++ b/chain/interface.go
@@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/address/v2"
"github.com/btcsuite/btcd/btcjson"
"github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
"github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/rpcclient"
"github.com/btcsuite/btcd/wire/v2"
@@ -44,6 +45,8 @@ type Interface interface {
GetBlockHash(int64) (*chainhash.Hash, error)
GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error)
IsCurrent() bool
+ GetCFilter(hash *chainhash.Hash,
+ filterType wire.FilterType) (*gcs.Filter, error)
FilterBlocks(*FilterBlocksRequest) (*FilterBlocksResponse, error)
BlockStamp() (*waddrmgr.BlockStamp, error)
SendRawTransaction(*wire.MsgTx, bool) (*chainhash.Hash, error)
@@ -57,6 +60,32 @@ type Interface interface {
SubmitPackage(txns []*wire.MsgTx,
maxFeeRate *float64) (*btcjson.SubmitPackageResult, error)
MapRPCErr(err error) error
+
+ // Batching methods for optimized scanning.
+ //
+ // GetBlockHashes returns a slice of block hashes for the given height
+ // range (inclusive).
+ //
+ // NOTE: This is a batching method, designed for optimized scanning.
+ GetBlockHashes(startHeight, endHeight int64) ([]chainhash.Hash, error)
+
+ // GetCFilters returns a slice of compact filters for the given block
+ // hashes and filter type.
+ //
+ // NOTE: This is a batching method, designed for optimized scanning.
+ GetCFilters(hashes []chainhash.Hash,
+ filterType wire.FilterType) ([]*gcs.Filter, error)
+
+ // GetBlocks returns a slice of full blocks for the given block hashes.
+ //
+ // NOTE: This is a batching method, designed for optimized scanning.
+ GetBlocks(hashes []chainhash.Hash) ([]*wire.MsgBlock, error)
+
+ // GetBlockHeaders returns a slice of block headers for the given block
+ // hashes.
+ //
+ // NOTE: This is a batching method, designed for optimized scanning.
+ GetBlockHeaders(hashes []chainhash.Hash) ([]*wire.BlockHeader, error)
}
// Notification types. These are defined here and processed from from reading
diff --git a/chain/jitter_test.go b/chain/jitter_test.go
index 9d62eab13f..16d1d5804b 100644
--- a/chain/jitter_test.go
+++ b/chain/jitter_test.go
@@ -91,8 +91,8 @@ func TestJitterTicker(t *testing.T) {
// Tick duration should be between 80ms and 120ms.
require.True(t, diff >= 80*time.Millisecond, "diff: %v", diff)
- // We give 1ms more to account for the time it takes to run the
+ // We give 5ms more to account for the time it takes to run the
// code.
- require.True(t, diff < 121*time.Millisecond, "diff: %v", diff)
+ require.Less(t, diff, 125*time.Millisecond, "diff: %v", diff)
}
}
diff --git a/chain/mocks_test.go b/chain/mocks_test.go
index 12efcef811..b70fc42203 100644
--- a/chain/mocks_test.go
+++ b/chain/mocks_test.go
@@ -67,96 +67,115 @@ func (m *mockRescanner) WaitForShutdown() {
// mockChainService is a mock implementation of a chain service for use in
// tests. Only the Start, GetBlockHeader and BestBlock methods are implemented.
type mockChainService struct {
+ mock.Mock
}
func (m *mockChainService) Start(_ context.Context) error {
- return nil
+ args := m.Called()
+ return args.Error(0)
}
func (m *mockChainService) BestBlock() (*headerfs.BlockStamp, error) {
- return testBestBlock, nil
+ args := m.Called()
+ return args.Get(0).(*headerfs.BlockStamp), args.Error(1)
}
func (m *mockChainService) GetBlockHeader(
- *chainhash.Hash) (*wire.BlockHeader, error) {
+ hash *chainhash.Hash) (*wire.BlockHeader, error) {
- return &wire.BlockHeader{}, nil
+ args := m.Called(hash)
+ return args.Get(0).(*wire.BlockHeader), args.Error(1)
}
-func (m *mockChainService) GetBlock(chainhash.Hash,
- ...neutrino.QueryOption) (*btcutil.Block, error) {
+func (m *mockChainService) GetBlock(
+ hash chainhash.Hash,
+ options ...neutrino.QueryOption) (*btcutil.Block, error) {
- return nil, errNotImplemented
+ args := m.Called(hash, options)
+ return args.Get(0).(*btcutil.Block), args.Error(1)
}
-func (m *mockChainService) GetBlockHeight(*chainhash.Hash) (int32, error) {
- return 0, errNotImplemented
+func (m *mockChainService) GetBlockHeight(hash *chainhash.Hash) (int32, error) {
+ args := m.Called(hash)
+ return args.Get(0).(int32), args.Error(1)
}
-func (m *mockChainService) GetBlockHash(int64) (*chainhash.Hash, error) {
- return nil, errNotImplemented
+func (m *mockChainService) GetBlockHash(height int64) (*chainhash.Hash, error) {
+ args := m.Called(height)
+ return args.Get(0).(*chainhash.Hash), args.Error(1)
}
func (m *mockChainService) IsCurrent() bool {
- return false
+ args := m.Called()
+ return args.Bool(0)
}
-func (m *mockChainService) SendTransaction(*wire.MsgTx) error {
- return errNotImplemented
+func (m *mockChainService) SendTransaction(tx *wire.MsgTx) error {
+ args := m.Called(tx)
+ return args.Error(0)
}
-func (m *mockChainService) GetCFilter(chainhash.Hash,
- wire.FilterType, ...neutrino.QueryOption) (*gcs.Filter, error) {
+func (m *mockChainService) GetCFilter(
+ hash chainhash.Hash, filterType wire.FilterType,
+ options ...neutrino.QueryOption) (*gcs.Filter, error) {
- return nil, errNotImplemented
+ args := m.Called(hash, filterType, options)
+ return args.Get(0).(*gcs.Filter), args.Error(1)
}
func (m *mockChainService) GetUtxo(
- _ ...neutrino.RescanOption) (*neutrino.SpendReport, error) {
+ opts ...neutrino.RescanOption) (*neutrino.SpendReport, error) {
- return nil, errNotImplemented
+ args := m.Called(opts)
+ return args.Get(0).(*neutrino.SpendReport), args.Error(1)
}
-func (m *mockChainService) BanPeer(string, banman.Reason) error {
- return errNotImplemented
+func (m *mockChainService) BanPeer(addr string, reason banman.Reason) error {
+ args := m.Called(addr, reason)
+ return args.Error(0)
}
func (m *mockChainService) IsBanned(addr string) bool {
- panic(errNotImplemented)
+ args := m.Called(addr)
+ return args.Bool(0)
}
-func (m *mockChainService) AddPeer(*neutrino.ServerPeer) {
- panic(errNotImplemented)
+func (m *mockChainService) AddPeer(peer *neutrino.ServerPeer) {
+ m.Called(peer)
}
-func (m *mockChainService) AddBytesSent(uint64) {
- panic(errNotImplemented)
+func (m *mockChainService) AddBytesSent(bytes uint64) {
+ m.Called(bytes)
}
-func (m *mockChainService) AddBytesReceived(uint64) {
- panic(errNotImplemented)
+func (m *mockChainService) AddBytesReceived(bytes uint64) {
+ m.Called(bytes)
}
func (m *mockChainService) NetTotals() (uint64, uint64) {
- panic(errNotImplemented)
+ args := m.Called()
+ return args.Get(0).(uint64), args.Get(1).(uint64)
}
-func (m *mockChainService) UpdatePeerHeights(*chainhash.Hash,
- int32, *neutrino.ServerPeer,
-) {
- panic(errNotImplemented)
+func (m *mockChainService) UpdatePeerHeights(hash *chainhash.Hash,
+ height int32, peer *neutrino.ServerPeer) {
+
+ m.Called(hash, height, peer)
}
func (m *mockChainService) ChainParams() chaincfg.Params {
- panic(errNotImplemented)
+ args := m.Called()
+ return args.Get(0).(chaincfg.Params)
}
func (m *mockChainService) Stop() error {
- panic(errNotImplemented)
+ args := m.Called()
+ return args.Error(0)
}
-func (m *mockChainService) PeerByAddr(string) *neutrino.ServerPeer {
- panic(errNotImplemented)
+func (m *mockChainService) PeerByAddr(addr string) *neutrino.ServerPeer {
+ args := m.Called(addr)
+ return args.Get(0).(*neutrino.ServerPeer)
}
// mockRPCClient mocks the rpcClient interface.
diff --git a/chain/neutrino.go b/chain/neutrino.go
index 9c3530923e..226a543075 100644
--- a/chain/neutrino.go
+++ b/chain/neutrino.go
@@ -218,6 +218,22 @@ func (s *NeutrinoClient) IsCurrent() bool {
return s.CS.IsCurrent()
}
+// GetCFilter returns a compact filter for the given block hash and filter
+// type.
+//
+// NOTE: This is part of the chain.Interface interface.
+func (s *NeutrinoClient) GetCFilter(blockHash *chainhash.Hash,
+ filterType wire.FilterType) (*gcs.Filter, error) {
+
+ filter, err := s.CS.GetCFilter(*blockHash, filterType)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get filter from "+
+ "neutrino: %w", err)
+ }
+
+ return filter, nil
+}
+
// SendRawTransaction replicates the RPC client's SendRawTransaction command.
func (s *NeutrinoClient) SendRawTransaction(tx *wire.MsgTx, allowHighFees bool) (
*chainhash.Hash, error) {
@@ -588,6 +604,85 @@ func (s *NeutrinoClient) SetStartTime(startTime time.Time) {
s.startTime = startTime
}
+// GetBlockHashes returns a slice of block hashes for the given height range.
+func (s *NeutrinoClient) GetBlockHashes(startHeight, endHeight int64) (
+ []chainhash.Hash, error) {
+
+ if startHeight > endHeight {
+ return nil, fmt.Errorf("%w: start height %d, end height %d",
+ ErrInvalidParam, startHeight, endHeight)
+ }
+
+ count := endHeight - startHeight + 1
+
+ hashes := make([]chainhash.Hash, 0, count)
+
+ for h := startHeight; h <= endHeight; h++ {
+ hash, err := s.CS.GetBlockHash(h)
+ if err != nil {
+ return nil, fmt.Errorf("get block hash: %w", err)
+ }
+
+ hashes = append(hashes, *hash)
+ }
+
+ return hashes, nil
+}
+
+// GetCFilters returns a slice of filters for the given block hashes.
+func (s *NeutrinoClient) GetCFilters(hashes []chainhash.Hash,
+ filterType wire.FilterType) ([]*gcs.Filter, error) {
+
+ filters := make([]*gcs.Filter, 0, len(hashes))
+
+ for _, hash := range hashes {
+ filter, err := s.CS.GetCFilter(hash, filterType)
+ if err != nil {
+ return nil, fmt.Errorf("get cfilter: %w", err)
+ }
+
+ filters = append(filters, filter)
+ }
+
+ return filters, nil
+}
+
+// GetBlocks returns a slice of full blocks for the given block hashes.
+func (s *NeutrinoClient) GetBlocks(hashes []chainhash.Hash) (
+ []*wire.MsgBlock, error) {
+
+ blocks := make([]*wire.MsgBlock, 0, len(hashes))
+
+ for _, hash := range hashes {
+ block, err := s.CS.GetBlock(hash)
+ if err != nil {
+ return nil, fmt.Errorf("get block: %w", err)
+ }
+
+ blocks = append(blocks, block.MsgBlock())
+ }
+
+ return blocks, nil
+}
+
+// GetBlockHeaders returns a slice of block headers for the given block hashes.
+func (s *NeutrinoClient) GetBlockHeaders(hashes []chainhash.Hash) (
+ []*wire.BlockHeader, error) {
+
+ headers := make([]*wire.BlockHeader, 0, len(hashes))
+
+ for _, hash := range hashes {
+ header, err := s.CS.GetBlockHeader(&hash)
+ if err != nil {
+ return nil, fmt.Errorf("get block header: %w", err)
+ }
+
+ headers = append(headers, header)
+ }
+
+ return headers, nil
+}
+
// onFilteredBlockConnected sends appropriate notifications to the notification
// channel.
func (s *NeutrinoClient) onFilteredBlockConnected(height int32,
diff --git a/chain/neutrino_test.go b/chain/neutrino_test.go
index 78a1907e53..5d2cf98e8a 100644
--- a/chain/neutrino_test.go
+++ b/chain/neutrino_test.go
@@ -7,14 +7,86 @@ import (
"time"
"github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/wire/v2"
+ "github.com/lightninglabs/neutrino/headerfs"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// maxDur is the max duration a test has to execute successfully.
var maxDur = 5 * time.Second
+// TestNeutrinoClientBatchFetch verifies that the batch fetching methods
+// correctly loop over the range/list and call the underlying service.
+func TestNeutrinoClientBatchFetch(t *testing.T) {
+ t.Parallel()
+
+ nc := newMockNeutrinoClient()
+ mockCS, ok := nc.CS.(*mockChainService)
+ require.True(t, ok)
+
+ // Clear default expectations set in newMockNeutrinoClient so we can
+ // set strict expectations for this test.
+
+ // Test GetBlockHashes
+ startHeight := int64(100)
+ endHeight := int64(102)
+ hash1 := chainhash.Hash{1}
+ hash2 := chainhash.Hash{2}
+ hash3 := chainhash.Hash{3}
+
+ mockCS.On("GetBlockHash", int64(100)).Return(&hash1, nil).Once()
+ mockCS.On("GetBlockHash", int64(101)).Return(&hash2, nil).Once()
+ mockCS.On("GetBlockHash", int64(102)).Return(&hash3, nil).Once()
+
+ hashes, err := nc.GetBlockHashes(startHeight, endHeight)
+ require.NoError(t, err)
+ require.Len(t, hashes, 3)
+ require.Equal(t, hash1, hashes[0])
+ require.Equal(t, hash2, hashes[1])
+ require.Equal(t, hash3, hashes[2])
+
+ // Test GetCFilters
+ filterType := wire.GCSFilterRegular
+ filter1 := &gcs.Filter{} // Empty filter
+ mockCS.On("GetCFilter", hash1, filterType, mock.Anything).
+ Return(filter1, nil).Once()
+ mockCS.On("GetCFilter", hash2, filterType, mock.Anything).
+ Return(filter1, nil).Once()
+ mockCS.On("GetCFilter", hash3, filterType, mock.Anything).
+ Return(filter1, nil).Once()
+
+ filters, err := nc.GetCFilters(hashes, filterType)
+ require.NoError(t, err)
+ require.Len(t, filters, 3)
+
+ // Test GetBlocks
+ block1 := btcutil.NewBlock(&wire.MsgBlock{})
+ mockCS.On("GetBlock", hash1, mock.Anything).Return(block1, nil).Once()
+ mockCS.On("GetBlock", hash2, mock.Anything).Return(block1, nil).Once()
+ mockCS.On("GetBlock", hash3, mock.Anything).Return(block1, nil).Once()
+
+ blocks, err := nc.GetBlocks(hashes)
+ require.NoError(t, err)
+ require.Len(t, blocks, 3)
+
+ // Test GetBlockHeaders
+ header1 := &wire.BlockHeader{}
+ mockCS.On("GetBlockHeader", &hash1).Return(header1, nil).Once()
+ mockCS.On("GetBlockHeader", &hash2).Return(header1, nil).Once()
+ mockCS.On("GetBlockHeader", &hash3).Return(header1, nil).Once()
+
+ headers, err := nc.GetBlockHeaders(hashes)
+ require.NoError(t, err)
+ require.Len(t, headers, 3)
+
+ mockCS.AssertExpectations(t)
+}
+
// TestNeutrinoClientSequentialStartStop ensures that the client
// can sequentially Start and Stop without errors or races.
func TestNeutrinoClientSequentialStartStop(t *testing.T) {
@@ -23,6 +95,18 @@ func TestNeutrinoClientSequentialStartStop(t *testing.T) {
wantRestarts = 50
)
+ mockCS, ok := nc.CS.(*mockChainService)
+ require.True(t, ok)
+
+ testBestBlock := &headerfs.BlockStamp{
+ Hash: chainhash.Hash(make([]byte, 32)),
+ Height: 1,
+ }
+
+ mockCS.On("Start").Return(nil).Times(wantRestarts)
+ mockCS.On("Stop").Return(nil).Times(wantRestarts)
+ mockCS.On("BestBlock").Return(testBestBlock, nil).Maybe()
+
// callStartStop starts the neutrino client, requires no error on
// startup, immediately stops the client and waits for shutdown.
// The returned channel is closed once shutdown is complete.
@@ -118,13 +202,29 @@ func TestNeutrinoClientNotifyReceivedRescan(t *testing.T) {
gotMsgs = 0
msgCh = make(chan string, wantMsgs)
msgPrefix = "successfully called"
-
- // sendMsg writes a message to the buffered message channel.
- sendMsg = func(s string) {
- msgCh <- fmt.Sprintf("%s %s", msgPrefix, s)
- }
)
+ mockCS, ok := nc.CS.(*mockChainService)
+ require.True(t, ok)
+
+ testBestBlock := &headerfs.BlockStamp{
+ Hash: chainhash.Hash(make([]byte, 32)),
+ Height: 1,
+ }
+
+ testBlockHeader := &wire.BlockHeader{Timestamp: time.Unix(1, 0)}
+
+ mockCS.On("Start").Return(nil).Once()
+ mockCS.On("Stop").Return(nil).Once()
+ mockCS.On("BestBlock").Return(testBestBlock, nil).Maybe()
+ mockCS.On("GetBlockHeader", mock.Anything).
+ Return(testBlockHeader, nil).Maybe()
+
+ // sendMsg writes a message to the buffered message channel.
+ sendMsg := func(s string) {
+ msgCh <- fmt.Sprintf("%s %s", msgPrefix, s)
+ }
+
// Define closures to wrap desired neutrino client method calls.
// cleanup is the shared cleanup function for a closure executing
diff --git a/chain/port/port.go b/chain/port/port.go
new file mode 100644
index 0000000000..59eab24adf
--- /dev/null
+++ b/chain/port/port.go
@@ -0,0 +1,198 @@
+// Package port provides functionality for managing network ports, including
+// finding available ports and ensuring exclusive access using lock files.
+package port
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "os"
+ "path/filepath"
+ "strconv"
+ "sync"
+ "time"
+)
+
+const (
+ // defaultTimeout is the default timeout that is used for the wait
+ // package.
+ defaultTimeout = 30 * time.Second
+
+ // ListenerFormat is the format string that is used to generate local
+ // listener addresses.
+ ListenerFormat = "127.0.0.1:%d"
+
+ // defaultNodePort is the start of the range for listening ports of
+ // harness nodes. Ports are monotonically increasing starting from this
+ // number and are determined by the results of NextAvailablePort().
+ defaultNodePort int = 10000
+
+ // uniquePortFile is the name of the file that is used to store the
+ // last port that was used by a node. This is used to make sure that
+ // the same port is not used by multiple nodes at the same time. The
+ // file is located in the temp directory of a system.
+ uniquePortFile = "rpctest-port"
+
+ // filePerms is the file permission used for the lock file and port
+ // file.
+ filePerms = 0600
+
+ // retryInterval is the interval to wait before retrying to acquire the
+ // lock file.
+ retryInterval = 10 * time.Millisecond
+
+ // maxPort is the maximum valid port number.
+ maxPort = 65535
+)
+
+var (
+ // portFileMutex is a mutex that is used to make sure that the port file
+ // is not accessed by multiple goroutines of the same process at the
+ // same time. This is used in conjunction with the lock file to make
+ // sure that the port file is not accessed by multiple processes at the
+ // same time either. So the lock file is to guard between processes and
+ // the mutex is to guard between goroutines of the same process.
+ portFileMutex sync.Mutex
+)
+
+// NextAvailablePort returns the first port that is available for listening by a
+// new node, using a lock file to make sure concurrent access for parallel tasks
+// on the same system don't re-use the same port.
+func NextAvailablePort() int {
+ portFileMutex.Lock()
+ defer portFileMutex.Unlock()
+
+ lockFile := filepath.Join(os.TempDir(), uniquePortFile+".lock")
+ lockFile = filepath.Clean(lockFile)
+ lockFileHandle := acquireLockFile(lockFile)
+
+ // Release the lock file when we're done.
+ defer func() {
+ // Always close file first, Windows won't allow us to remove it
+ // otherwise.
+ _ = lockFileHandle.Close()
+
+ err := os.Remove(lockFile)
+ if err != nil {
+ panic(fmt.Errorf("couldn't remove lock file: %w", err))
+ }
+ }()
+
+ portFile := filepath.Join(os.TempDir(), uniquePortFile)
+ portFile = filepath.Clean(portFile)
+
+ port, err := os.ReadFile(portFile)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ panic(fmt.Errorf("error reading port file: %w", err))
+ }
+
+ port = []byte(strconv.Itoa(defaultNodePort))
+ }
+
+ lastPort, err := strconv.Atoi(string(port))
+ if err != nil {
+ panic(fmt.Errorf("error parsing port: %w", err))
+ }
+
+ // lastPort has reached the max allowed port, we start with the default
+ // node port.
+ if lastPort >= maxPort {
+ lastPort = defaultNodePort
+ }
+
+ // Determine the first port to try.
+ nextPort := lastPort + 1
+
+ availablePort := findAvailablePort(nextPort)
+
+ err = os.WriteFile(
+ portFile, []byte(strconv.Itoa(availablePort)), filePerms,
+ )
+ if err != nil {
+ panic(fmt.Errorf("error updating port file: %w", err))
+ }
+
+ return availablePort
+}
+
+// findAvailablePort searches for an available port starting from the given
+// port. If it reaches the maximum port number, it wraps around to the default
+// node port and continues searching until it has checked the entire range.
+func findAvailablePort(startPort int) int {
+ currentPort := startPort
+ for {
+ // If there are no errors while attempting to listen on this
+ // port, close the socket and return it as available. While it
+ // could be the case that some other process picks up this port
+ // between the time the socket is closed, and it's reopened in
+ // the harness node, in practice in CI servers this seems much
+ // less likely than simply some other process already being
+ // bound at the start of the tests.
+ addr := fmt.Sprintf(ListenerFormat, currentPort)
+
+ lc := &net.ListenConfig{}
+
+ l, err := lc.Listen(context.Background(), "tcp4", addr)
+ if err == nil {
+ _ = l.Close()
+ return currentPort
+ }
+
+ currentPort++
+
+ // Start from the beginning if we reached the end of the port
+ // range. We need to do this because the lock file now is
+ // persistent across runs on the same machine during the same
+ // boot/uptime cycle. So in order to make this work on
+ // developer's machines, we need to reset the port to the
+ // default value when we reach the end of the range.
+ if currentPort > maxPort {
+ currentPort = defaultNodePort
+ }
+
+ // If we reached the start port again, it means no ports are
+ // available.
+ if currentPort == startPort {
+ break
+ }
+ }
+
+ // No ports available? Must be a mistake.
+ panic("no ports available for listening")
+}
+
+// acquireLockFile attempts to acquire the lock file. If it already exists, it
+// waits for a bit and retries until the timeout is reached. If the process is
+// killed before the lock file is removed, this function will timeout and panic.
+// In that case, the lock file must be manually removed.
+func acquireLockFile(lockFile string) *os.File {
+ timeout := time.After(defaultTimeout)
+
+ var (
+ lockFileHandle *os.File
+ err error
+ )
+ for {
+ // Attempt to acquire the lock file. If it already exists, wait
+ // for a bit and retry.
+ //
+ //nolint:gosec // lockFile is constructed from os.TempDir() and
+ // a constant, not from user input.
+ lockFileHandle, err = os.OpenFile(
+ lockFile, os.O_CREATE|os.O_EXCL, filePerms,
+ )
+ if err == nil {
+ // Lock acquired.
+ return lockFileHandle
+ }
+
+ // Wait for a bit and retry.
+ select {
+ case <-timeout:
+ str := "timeout waiting for lock file: " + lockFile
+ panic(str)
+ case <-time.After(retryInterval):
+ }
+ }
+}
diff --git a/chain/pruned_block_dispatcher_test.go b/chain/pruned_block_dispatcher_test.go
index 6519fb594b..d6782d3510 100644
--- a/chain/pruned_block_dispatcher_test.go
+++ b/chain/pruned_block_dispatcher_test.go
@@ -67,7 +67,7 @@ func newNetworkBlockTestHarness(t *testing.T, numBlocks,
localConns: make(map[string]net.Conn, numPeers),
remoteConns: make(map[string]net.Conn, numPeers),
dialedPeer: make(chan string),
- queriedPeer: make(chan struct{}),
+ queriedPeer: make(chan struct{}, numBlocks*numPeers),
blocksQueried: make(map[chainhash.Hash]int),
shouldReply: 0,
}
@@ -327,7 +327,7 @@ func (h *prunedBlockDispatcherHarness) disconnectPeer(addr string, fallback bool
h.dispatcher.peerMtx.Lock()
defer h.dispatcher.peerMtx.Unlock()
return len(h.dispatcher.currentPeers) == numPeers-1
- }, time.Second, 200*time.Millisecond)
+ }, defaultTestTimeout, 200*time.Millisecond)
// Reset the peer connection state to allow connections to them again.
h.resetPeer(addr, fallback)
@@ -339,7 +339,7 @@ func (h *prunedBlockDispatcherHarness) assertPeerDialed() {
select {
case <-h.dialedPeer:
- case <-time.After(5 * time.Second):
+ case <-time.After(defaultTestTimeout):
h.t.Fatalf("expected peer to be dialed")
}
}
@@ -351,7 +351,7 @@ func (h *prunedBlockDispatcherHarness) assertPeerDialedWithAddr(addr string) {
select {
case dialedAddr := <-h.dialedPeer:
require.Equal(h.t, addr, dialedAddr)
- case <-time.After(5 * time.Second):
+ case <-time.After(defaultTestTimeout):
h.t.Fatalf("expected peer to be dialed")
}
}
@@ -362,7 +362,7 @@ func (h *prunedBlockDispatcherHarness) assertPeerQueried() {
select {
case <-h.queriedPeer:
- case <-time.After(5 * time.Second):
+ case <-time.After(defaultTestTimeout):
h.t.Fatalf("expected a peer to be queried")
}
}
@@ -395,7 +395,7 @@ func (h *prunedBlockDispatcherHarness) assertPeerReplied(
// We need to check the errChan after a timeout because when a request
// was successful a nil error is signaled via the errChan and this
// might happen even before the block is received.
- case <-time.After(5 * time.Second):
+ case <-time.After(defaultTestTimeout):
select {
case err := <-errChan:
h.t.Fatalf("received unexpected error send: %v", err)
@@ -415,7 +415,7 @@ func (h *prunedBlockDispatcherHarness) assertPeerReplied(
select {
case err := <-errChan:
require.NoError(h.t, err)
- case <-time.After(5 * time.Second):
+ case <-time.After(defaultTestTimeout):
h.t.Fatal("expected nil err to signal completion")
}
}
@@ -446,7 +446,7 @@ func (h *prunedBlockDispatcherHarness) assertPeerFailed(
case err := <-cancelChan:
require.ErrorIs(h.t, err, expectedErr)
- case <-time.After(5 * time.Second):
+ case <-time.After(defaultTestTimeout):
h.t.Fatalf("expected the error for the block request: %v",
expectedErr)
}
@@ -461,7 +461,7 @@ func (h *prunedBlockDispatcherHarness) assertNoPeerDialed() {
select {
case peer := <-h.dialedPeer:
h.t.Fatalf("unexpected connection established with peer %v", peer)
- case <-time.After(2 * time.Second):
+ case <-time.After(1 * time.Second):
}
}
@@ -482,7 +482,7 @@ func (h *prunedBlockDispatcherHarness) assertNoReply(
h.t.Fatalf("received unexpected cancel request with error: %v",
err)
- case <-time.After(2 * time.Second):
+ case <-time.After(1 * time.Second):
}
}
diff --git a/docs/developer/README.md b/docs/developer/README.md
index 1c3dd1f863..38ed43f839 100644
--- a/docs/developer/README.md
+++ b/docs/developer/README.md
@@ -40,4 +40,12 @@ A deep dive into the core design philosophy, architectural patterns, and Go impl
Formal documentation of significant architectural decisions, their context, and consequences.
-**[➡️ View Architecture Decision Records](./adr/README.md)**
\ No newline at end of file
+**[➡️ View Architecture Decision Records](./adr/README.md)**
+
+---
+
+## 📜 PSBT Workflows Guide
+
+A detailed guide to creating Bitcoin transactions using the `PsbtManager` interface, covering various scenarios and best practices.
+
+**[➡️ Read the PSBT Workflows Guide](./psbt_workflows.md)**
\ No newline at end of file
diff --git a/docs/developer/adr/0002-controller-syncer-architecture.md b/docs/developer/adr/0002-controller-syncer-architecture.md
new file mode 100644
index 0000000000..46ab623e3a
--- /dev/null
+++ b/docs/developer/adr/0002-controller-syncer-architecture.md
@@ -0,0 +1,51 @@
+# ADR 0002: Controller-Syncer-State Architecture
+
+## 1. Context
+
+The legacy `btcwallet` architecture tightly coupled lifecycle management, synchronization logic, and state tracking within a single `Wallet` struct. This monolithic design led to several issues:
+* **Race Conditions:** Ambiguity between "Started" and "Syncing" states made it difficult to safely manage concurrent access.
+* **Blocking Operations:** Long-running sync operations would block control-plane requests (like `Stop` or `Info`).
+* **Testing Difficulty:** The tight coupling made it nearly impossible to unit test synchronization logic in isolation from the full wallet stack.
+
+We need a robust, testable, and concurrent architecture to support modern features like multi-wallet management and targeted rescans.
+
+## 2. Decision
+
+We will adopt a **Controller-Syncer-State** pattern with an **Orthogonal State Model**.
+
+### 2.1 The Components
+
+1. **Controller (`Controller` interface / `Wallet` struct):**
+ * **Role:** The public API surface and lifecycle manager.
+ * **Responsibility:** Validates requests, manages the `Start/Stop` lifecycle, and delegates long-running tasks. It never blocks on chain operations.
+
+2. **Syncer (`chainSyncer` interface / `syncer` struct):**
+ * **Role:** The background worker.
+ * **Responsibility:** Executes the chain loop, communicates with the backend, and manages the database state for synchronization. It is isolated and testable.
+
+3. **State (`walletState` struct):**
+ * **Role:** The source of truth for the wallet's status.
+ * **Responsibility:** Maintains state across three independent dimensions (Lifecycle, Sync, Auth) using atomic operations.
+
+### 2.2 Orthogonal State Model
+
+Instead of a single status enum, we track three separate dimensions:
+* **Lifecycle:** `Stopped` -> `Starting` -> `Started` -> `Stopping`
+* **Synchronization:** `BackendSyncing` -> `Syncing` -> `Synced` | `Rescanning`
+* **Authentication:** `Locked` | `Unlocked`
+
+## 3. Consequences
+
+### Pros
+* **Concurrency Safety:** State transitions are atomic and explicitly managed, eliminating race conditions.
+* **Responsiveness:** The Controller remains responsive to user requests even while the Syncer is performing heavy I/O.
+* **Testability:** The `Syncer` can be tested with a mock `Chain` and `Store` without instantiating a full `Wallet`. The `Controller` can be tested with a mock `Syncer`.
+* **Clarity:** The separation of concerns makes the codebase easier to navigate and reason about.
+
+### Cons
+* **Complexity:** Increases the number of distinct types and files.
+* **Indirection:** Calls to sync functionality now go through a channel-based request mechanism rather than direct method calls.
+
+## 4. Status
+
+Accepted and Implemented.
diff --git a/docs/developer/adr/0003-optimistic-cfilter-batching.md b/docs/developer/adr/0003-optimistic-cfilter-batching.md
new file mode 100644
index 0000000000..fb179b2e65
--- /dev/null
+++ b/docs/developer/adr/0003-optimistic-cfilter-batching.md
@@ -0,0 +1,54 @@
+# ADR 0003: Optimistic CFilter Batch Scanning
+
+## 1. Context
+
+Synchronizing a wallet using BIP 157/158 Compact Filters (CFilters) presents a performance challenge.
+* **Latency:** Fetching filters and blocks sequentially (Header -> Filter -> Block) incurs significant network round-trip time (RTT), especially for high-latency backends like Neutrino.
+* **The Horizon Problem:** BIP 32 wallets must expand their "lookahead window" (derive new addresses) when used addresses are discovered. If a block contains a transaction to the last address in the window, the wallet must immediately derive more addresses and re-scan subsequent blocks to ensure no funds are missed.
+
+We need a scanning algorithm that maximizes throughput (minimizing RTT) while guaranteeing correctness (respecting the gap limit).
+
+## 2. Decision
+
+We will implement an **Optimistic Batching strategy with In-Place Resume**.
+
+### 2.1 The Strategy
+
+1. **Optimistic Fetch:** The wallet fetches headers, CFilters, and (if matched) blocks for a large batch (e.g., 100 blocks) in parallel, assuming the current address lookahead window is sufficient.
+2. **Sequential Process:** The downloaded blocks are processed sequentially in memory.
+3. **In-Place Resume:** If processing Block `N` triggers a horizon expansion (new addresses derived):
+ * The processing loop pauses.
+ * The wallet updates its internal watchlist with the new addresses.
+ * The wallet **re-scans** the remaining blocks in the *current batch* (Blocks `N+1` to `End`) using the updated watchlist.
+ * If necessary, it fetches missing blocks that now match the new filters.
+
+### 2.2 Logic Flow
+
+```
+Batch Loop:
+ 1. Fetch Filters for Batch [Start, End]
+ 2. Match Filters against Current Watchlist
+ 3. Fetch Matched Blocks
+ 4. Block Loop (i from Start to End):
+ a. Process Block(i)
+ b. If Horizon Expanded:
+ i. Update Watchlist
+ ii. Re-Match Filters for [i+1, End]
+ iii. Fetch Newly Matched Blocks
+ iv. Continue Loop
+```
+
+## 3. Consequences
+
+### Pros
+* **High Throughput:** In the common case (no sequential expansion), the wallet fetches data in large, efficient batches, saturating the network connection.
+* **Correctness:** The "In-Place Resume" logic guarantees that even if a user receives a chain of payments to sequential addresses in a single batch, the wallet will discover all of them.
+* **Efficiency:** It avoids the naive "Stop-and-Go" approach of processing one block at a time, which is prohibitively slow.
+
+### Cons
+* **Complexity:** The resumption logic adds complexity to the scan loop implementation.
+* **Redundant Work (Edge Case):** In the worst-case scenario (sequential expansion in every block), the algorithm degrades to re-matching filters repeatedly. However, this is rare in practice.
+
+## 4. Status
+
+Accepted and Implemented.
diff --git a/docs/developer/adr/0004-targeted-rescan-vs-rewind.md b/docs/developer/adr/0004-targeted-rescan-vs-rewind.md
new file mode 100644
index 0000000000..bd4da5944b
--- /dev/null
+++ b/docs/developer/adr/0004-targeted-rescan-vs-rewind.md
@@ -0,0 +1,58 @@
+# ADR 0004: Targeted Rescan vs. Global Rewind
+
+## 1. Context
+
+In `btcwallet`, discovering missing transactions has historically required a "Rescan." The legacy implementation treated all rescans as a "Rewind":
+1. Set the wallet's global `SyncedTo` height back to the start block.
+2. Force the wallet into a `Syncing` state.
+3. Re-process all blocks from that height forward.
+
+This "Global Rewind" approach is problematic for modern use cases like importing a single private key or account.
+* **Disruption:** It forces the entire wallet to be "unsynced" for minutes or hours, blocking critical operations like creating transactions, even though the existing keys are perfectly up-to-date.
+* **Inefficiency:** It re-scans the chain for *all* wallet addresses, not just the imported ones.
+
+We need a mechanism to scan for specific keys without disrupting the global wallet state.
+
+## 2. Decision
+
+We will implement two distinct types of history recovery, managed by the `Syncer` but differentiated by their effect on the global state.
+
+### 2.1 Global Rewind (Manual Rescan)
+* **Trigger:** Explicit user request via `Resync(...)`.
+* **Behavior:**
+ * **Rewinds** the global `SyncedTo` watermark in the database.
+ * Sets state to `Syncing`.
+ * Re-scans for **all** known wallet addresses.
+* **Use Case:** Recovering from a corrupted database, a chain reorganization deep in history, or a user explicitly wanting to "reset" the wallet's view.
+
+### 2.2 Targeted Rescan (Import Scan)
+* **Trigger:** Importing keys/accounts (e.g., `ImportPrivateKey`, `ImportAccount`), or a user request with specific targets.
+* **Behavior:**
+ * **Does NOT** rewind the global `SyncedTo` watermark.
+ * Sets state to a new `Rescanning` sub-state.
+ * Constructs a **Partial Recovery State** containing *only* the specific targets (addresses/scripts).
+ * Scans the requested block range for these targets.
+ * Inserts found transactions into the database.
+* **Use Case:** Adding a new key to an existing, synced wallet.
+
+## 3. Concurrency and Safety
+
+To prevent race conditions during these operations, we enforce strict access control based on the Orthogonal State Model.
+
+* **`CreateTransaction` / `FundPsbt`**: Blocked if state is `Syncing` or `Rescanning`. The UTXO set is considered unstable during any scan.
+* **`Balance` / `ListUnspent`**: Allowed during `Rescanning`. They return the state of the *existing* (synced) keys, which is safe because the targeted rescan only *adds* new data; it doesn't invalidate existing confirmed history.
+
+## 4. Consequences
+
+### Pros
+* **User Experience:** Importing a key is a background task. The user can continue to use their existing funds immediately.
+* **Performance:** Scanning for 1 key is significantly faster than scanning for 10,000 keys (especially with CFilters).
+* **Safety:** Explicitly differentiating the states prevents the "accidental rewind" that scares users.
+
+### Cons
+* **Complexity:** The `Syncer` logic must handle two different "modes" of operation (Global Loop vs. Ad-hoc Job).
+* **Database Complexity:** We must ensure that inserting transactions during a targeted rescan doesn't conflict with the global sync loop if they happen to overlap (though the design serializes them in the `chainLoop`).
+
+## 5. Status
+
+Accepted and Implemented.
diff --git a/docs/developer/adr/0005-no-auto-rescan-on-import.md b/docs/developer/adr/0005-no-auto-rescan-on-import.md
new file mode 100644
index 0000000000..52039a3cc8
--- /dev/null
+++ b/docs/developer/adr/0005-no-auto-rescan-on-import.md
@@ -0,0 +1,45 @@
+# ADR 0005: Explicit Rescan on Import
+
+## 1. Context
+
+When importing new keys, addresses, or accounts into a wallet (e.g., via `ImportPrivateKey` or `ImportAccount`), the wallet needs to scan the blockchain history to discover any existing funds associated with these new credentials.
+
+A common pattern in some wallet implementations is to automatically trigger a rescan immediately upon import. However, this approach introduces several issues:
+* **Performance Storms:** If a user or application imports a batch of 100 keys sequentially, an automatic trigger would launch 100 overlapping, redundant rescan jobs.
+* **Blocking Behavior:** If the import method waits for the scan, a simple database insertion becomes a potentially hour-long operation.
+* **API Ambiguity:** It blurs the line between "State Management" (adding a key) and "Network Operation" (scanning the chain).
+
+## 2. Decision
+
+`btcwallet` will **not** automatically trigger a blockchain rescan when keys, addresses, or accounts are imported.
+
+* **Import Methods are Purely Database Operations:** Methods like `ImportPrivateKey`, `ImportAccount`, and `ImportScript` will only persist the data to the wallet database and return immediately.
+* **Rescans Must Be Explicit:** The caller is responsible for explicitly requesting a rescan (via `Rescan(...)`) after the import is complete.
+
+## 3. Rationale
+
+### 3.1 Batch Efficiency
+This design allows downstream applications (like `lnd` or custom scripts) to batch imports efficiently. An application can import 1,000 keys in a loop and then trigger a **single** targeted rescan for the aggregate birthday of those keys. This is orders of magnitude more efficient than 1,000 individual scans.
+
+### 3.2 API Clarity
+Separating the concerns of "Storage" and "Synchronization" makes the API predictable.
+* `ImportXXX`: "I want to save this key." (Fast, Atomic, Synchronous)
+* `Rescan`: "I want to look for money." (Slow, Asynchronous, Cancellable)
+
+### 3.3 User Control
+The user (or calling software) retains control over system resources. They may choose to import keys now but defer the heavy scanning operation until a maintenance window or when bandwidth is available.
+
+## 4. Consequences
+
+### Pros
+* **Performance:** Eliminates redundant scanning during bulk imports.
+* **Responsiveness:** Import RPCs remain consistently fast.
+* **Flexibility:** Allows advanced import workflows (e.g., offline imports).
+
+### Cons
+* **Usability Pitfall:** A naive user might import a key and be confused why their balance shows `0`. Documentation and RPC output must clearly indicate that a rescan is required to see funds.
+* **Client Burden:** Clients must implement the "Import -> Rescan" logic themselves.
+
+## 5. Status
+
+Accepted and Implemented.
diff --git a/docs/developer/psbt_workflows.md b/docs/developer/psbt_workflows.md
new file mode 100644
index 0000000000..b77d635017
--- /dev/null
+++ b/docs/developer/psbt_workflows.md
@@ -0,0 +1,328 @@
+# PSBT Workflows Guide
+
+This document provides a guide to creating Bitcoin transactions using the
+`PsbtManager` interface. We will explore several scenarios, from a simple
+single-person payment to a more complex, multi-party collaborative transaction,
+highlighting best practices for security and efficiency.
+
+Our actors:
+- **Alice**: A user of `btcwallet`.
+- **Bob**: Another user of `btcwallet`.
+- **Carol**: The recipient of the payments.
+
+---
+
+## Scenario 1: Simple Single-Signer Transaction (Alice Pays Carol)
+
+This is the most common use case: a single user creating a transaction from their
+own wallet. The workflow is linear and straightforward.
+
+**Goal:** Alice wants to pay 1 BTC to Carol.
+
+```mermaid
+flowchart LR
+ Start([Start]) --> Create["Create Bare PSBT"]
+ Create --> Fund["Fund PSBT
(Coin Selection)"]
+ Fund --> Sign["Sign PSBT"]
+ Sign --> Finalize["Finalize PSBT"]
+ Finalize --> Broadcast["Broadcast TX"]
+ Broadcast --> End([Done])
+```
+
+### Workflow Steps
+
+1. **Create a Bare PSBT:** Alice's application first creates a bare PSBT that
+ describes the intended output.
+
+ ```go
+ import "github.com/btcsuite/btcwallet/wallet"
+
+ carolOutput := &wire.TxOut{Value: 100_000_000, PkScript: carolPkScript}
+ packet, err := wallet.CreatePsbt(nil, []*wire.TxOut{carolOutput})
+ ```
+
+2. **Fund the PSBT:** Alice's wallet performs coin selection to add inputs and a
+ change output.
+
+ ```go
+ fundIntent := &wallet.FundIntent{
+ Packet: packet,
+ Policy: &wallet.InputsPolicy{
+ Source: &wallet.ScopedAccount{
+ AccountName: "default",
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ MinConfs: 1,
+ },
+ FeeRate: btcunit.NewSatPerKVByte(1000), // e.g., 1 sat/vb
+ }
+
+ fundedPacket, _, err := aliceWallet.FundPsbt(ctx, fundIntent)
+
+ ```
+
+ The `fundedPacket` now contains the necessary inputs (fully decorated) and a change output.
+
+3. **Sign the PSBT:** The wallet signs all inputs it has the keys for.
+
+ ```go
+ signParams := &wallet.SignPsbtParams{Packet: packet}
+ _, err = aliceWallet.SignPsbt(ctx, signParams)
+ ```
+
+4. **Finalize and Broadcast:** Alice finalizes the PSBT to produce a complete,
+ valid transaction and broadcasts it.
+
+ ```go
+ err = aliceWallet.FinalizePsbt(ctx, packet)
+ finalTx, err := psbt.Extract(packet)
+ err = aliceWallet.Broadcast(ctx, finalTx, "Payment to Carol")
+ ```
+
+### Analysis
+
+- **Round Trips:** 0 (all operations are local to Alice's wallet).
+- **Security:** High. Alice controls the entire process, so there is no risk
+ of external manipulation.
+
+---
+
+## Scenario 2: Collaborative Transaction (Alice and Bob Pay Carol)
+
+This is a more advanced workflow where multiple parties contribute inputs to a
+single transaction. This requires careful coordination to ensure security.
+
+**Goal:** Alice and Bob want to jointly pay Carol.
+
+We will explore two models for this: a naive (and insecure) model, and the
+recommended, secure Coordinator Model.
+
+### The Naive (and Insecure) Independent Funding Model - **DO NOT USE**
+
+In this model, participants create their contributions independently and a
+coordinator merges them.
+
+1. **Alice Funds:** Alice creates a PSBT that pays Carol her portion.
+ `aliceWallet.FundPsbt(...)` -> `packet_alice`
+2. **Bob Funds:** Bob does the same. `bobWallet.FundPsbt(...)` -> `packet_bob`
+3. **Coordinator Combines:** A coordinator merges these.
+ `combinedPacket, _ := wallet.CombinePsbt(ctx, packet_alice, packet_bob)`
+4. **Signing:** The `combinedPacket` is passed around for signatures.
+
+#### Security Concerns: Critical Flaw
+
+This model is **dangerously insecure** in a trustless environment.
+
+Imagine a malicious Bob. When creating `packet_bob`, he could add an extra,
+unexpected output that pays some of the transaction's value to himself.
+
+When the coordinator calls `CombinePsbt`, this malicious output is merged into
+the final transaction. If Alice's application logic does not manually parse and
+validate every single input and output from Bob's PSBT fragment, she will
+unknowingly sign a transaction that steals funds. **The API makes the insecure
+path easy.**
+
+### The Recommended (and Secure) Coordinator Model
+
+This model ensures security by having all participants agree on the final
+transaction structure *before* any signatures are created.
+
+**Principle:** Verify the whole transaction, then sign.
+
+```mermaid
+sequenceDiagram
+ participant A as Alice (Coordinator)
+ participant B as Bob
+ participant N as Bitcoin Network
+
+ Note over A,B: 1. Off-chain Agreement on Terms
+
+ A->>A: Create Bare PSBT (Template)
+ A->>B: Send Bare PSBT
+
+ par Parallel Signing
+ A->>A: Verify, Decorate, & Sign Input
+ B->>B: Verify, Decorate, & Sign Input
+ end
+
+ B->>A: Send Partially Signed PSBT
+ A->>A: Combine Signatures
+ A->>A: Finalize Transaction
+ A->>N: Broadcast Transaction
+```
+
+#### Workflow Steps
+
+**1. Agreement (Off-chain)**
+Alice and Bob first communicate and agree on the exact transaction:
+- Which UTXO Alice will contribute (`alice_utxo_1`).
+- Which UTXO Bob will contribute (`bob_utxo_1`).
+- The final, combined output for Carol.
+- The exact change output for Alice.
+- The exact change output for Bob.
+- The agreed-upon fee rate.
+
+**2. Coordinator Creates the Template (Alice)**
+Alice, acting as coordinator, creates a single, bare PSBT that represents the
+**entire, final transaction**. This is the "single source of truth".
+
+```go
+// Alice's code (as coordinator)
+allInputs := []*wire.OutPoint{&alice_utxo_1, &bob_utxo_1}
+allOutputs := []*wire.TxOut{carol_output, alice_change_output, bob_change_output}
+
+// Create a single PSBT template for the entire transaction.
+barePacket, err := wallet.CreatePsbt(allInputs, allOutputs)
+```
+
+**3. Participants Verify, Decorate, and Sign (Parallel)**
+The coordinator sends the `barePacket` to all participants (including herself).
+Each participant now performs the same set of actions independently.
+
+```go
+// Bob's code (Alice does the same with her wallet)
+
+// CRITICAL STEP: Verify the transaction structure.
+// Bob's application logic MUST inspect barePacket to ensure it exactly
+// matches the off-chain agreement. It checks that only the expected inputs
+// and outputs are present, with the correct values.
+if !isValid(barePacket) {
+ return errors.New("transaction proposal is invalid")
+}
+
+// If valid, Bob's wallet decorates its own input.
+err := bobWallet.DecorateInputs(ctx, barePacket, true)
+// The wallet finds bob_utxo_1 and adds its UTXO/derivation info.
+
+// Bob's wallet now signs its input.
+signParams := &wallet.SignPsbtParams{Packet: barePacket}
+_, err = bobWallet.SignPsbt(ctx, signParams)
+
+// Bob sends the partially signed PSBT back to the coordinator.
+```
+
+**4. Coordinator Combines Signatures (Alice)**
+Alice collects the signed PSBTs from all participants. Each PSBT is a copy of
+the original `barePacket` but now contains a different partial signature. She
+uses `CombinePsbt` to merge these signatures into a single, fully-signed PSBT.
+
+```go
+// Alice's code (as coordinator)
+// (Alice has already signed her own copy, `my_signed_packet`)
+fullySignedPacket, err := aliceWallet.CombinePsbt(
+ ctx, my_signed_packet, signed_packet_from_bob,
+)
+```
+
+**5. Finalize and Broadcast (Alice)**
+The coordinator now has a complete PSBT and can finalize it to produce the
+broadcastable transaction.
+
+```go
+err = aliceWallet.FinalizePsbt(ctx, fullySignedPacket)
+finalTx, err := psbt.Extract(fullySignedPacket)
+err = aliceWallet.Broadcast(ctx, finalTx, "Collaborative payment to Carol")
+```
+
+#### Analysis
+
+- **Round Trips:** 2.
+ 1. Coordinator distributes the `barePacket` to all participants.
+ 2. Participants return their signed PSBTs to the coordinator.
+- **Security:** High. The security comes from the "verify-then-sign" workflow.
+ Each participant validates the entire, final transaction structure *before*
+ creating a signature. A signature becomes a cryptographic commitment to the
+ complete, agreed-upon transaction, preventing any party from maliciously
+ altering it after the fact.
+
+---
+
+## Scenario 3: Taproot Script Path Multisig (Signing Multiple Times)
+
+Taproot introduces powerful new capabilities, such as Script Path spends where
+multiple parties can sign via different leaf scripts or a single script requiring
+multiple signatures (e.g., a 2-of-2 multisig script leaf).
+
+The `SignPsbt` method enforces a **strict single-derivation-path policy** per
+input call. This means if a wallet holds multiple keys involved in a multisig
+input, it must call `SignPsbt` multiple times—once for each key it intends to
+sign with.
+
+**Goal:** A 2-of-2 multisig input (Alice + Bob) is being spent via a Taproot
+Script Path. Alice holds both Key A1 and Key A2 (e.g., for redundancy or testing)
+and needs to provide two signatures for the same input.
+
+### Workflow Steps
+
+1. **Prepare the PSBT:** The coordinator constructs the PSBT with the Taproot
+ input. The input MUST include the `TaprootLeafScript` and `ControlBlock` to
+ identify the script path being spent.
+
+2. **First Signing Pass (Key A1):** Alice's wallet inspects the PSBT. To sign
+ with Key A1, the application must ensure the PSBT input contains the
+ `TaprootBip32Derivation` for **Key A1 only**.
+
+ ```go
+ // Populate derivation info for Key A1 ONLY.
+ packet.Inputs[0].TaprootBip32Derivation = []*psbt.TaprootBip32Derivation{
+ derivInfoForKeyA1,
+ }
+
+ // Sign. The wallet sees one derivation path and generates one signature.
+ // It appends this signature to the `TaprootScriptSpendSig` list.
+ signedResult, err := aliceWallet.SignPsbt(ctx, &wallet.SignPsbtParams{
+ Packet: packet,
+ })
+ ```
+
+3. **Second Signing Pass (Key A2):** Now Alice needs to sign with Key A2. The
+ application updates the PSBT input to show the derivation for **Key A2**.
+
+ ```go
+ // Replace derivation info with Key A2.
+ packet.Inputs[0].TaprootBip32Derivation = []*psbt.TaprootBip32Derivation{
+ derivInfoForKeyA2,
+ }
+
+ // Sign again. The wallet sees a new, single derivation path.
+ // It generates the second signature and appends it to the list.
+ // The previous signature for Key A1 is preserved.
+ signedResult2, err := aliceWallet.SignPsbt(ctx, &wallet.SignPsbtParams{
+ Packet: packet,
+ })
+ ```
+
+### Why this Restriction?
+
+Enforcing a single derivation path per call eliminates ambiguity.
+- If `SignPsbt` received multiple derivation paths for one input, it would be
+ unclear if the caller intended to sign *all* of them, *one* of them, or if
+ some were just metadata.
+- By requiring explicit, singular intent, the API ensures deterministic
+ behavior: "Here is the key I want you to use; sign with it."
+
+---
+
+## Advanced Topics & Best Practices
+
+### Why `SIGHASH_ALL` is Essential
+In collaborative transactions, all signatures should use `SIGHASH_ALL` (the
+default). This flag ensures that the signature commits to *all* inputs and *all*
+outputs in the transaction. If a participant were to use a different flag like
+`SIGHASH_SINGLE`, a malicious coordinator could modify the parts of the
+transaction not covered by the signature, leading to fund loss or unexpected
+behavior.
+
+### The Role of `DecorateInputs`
+`DecorateInputs` is the bridge between a transaction's structure and its
+signability. In the Coordinator Model, it's a crucial step that allows each
+participant's wallet to add the private information (UTXO value, script,
+derivation path) needed for its own hardware or software to produce a valid
+signature.
+
+### Coin Control
+A user can choose a specific UTXO to spend by creating a PSBT with that input
+already included before calling `FundPsbt`. The `FundPsbt` method will detect
+the existing input and enter a "completion" mode, where it simply calculates
+fees and adds a change output, rather than performing automatic coin selection.
+This is how Alice specified her input in the Coordinator Model example.
\ No newline at end of file
diff --git a/docs/developer/scanning_sync_architecture.md b/docs/developer/scanning_sync_architecture.md
new file mode 100644
index 0000000000..ca6a5111e5
--- /dev/null
+++ b/docs/developer/scanning_sync_architecture.md
@@ -0,0 +1,167 @@
+# Wallet Synchronization and Scanning Architecture
+
+This document details the architecture of the `btcwallet` synchronization subsystem. It explains how the wallet maintains consensus with the blockchain, discovers relevant transactions, and manages the recovery of funds.
+
+## 1. High-Level Architecture
+
+The synchronization system is designed around a **Controller-Worker-State** pattern, separating the public API from the background work and the core logic.
+
+```mermaid
+graph TD
+ User[User / RPC] -->|Calls Start/Rescan/Unlock| Controller
+
+ subgraph "Wallet Package"
+ Controller[Controller] -->|Manages| State[Wallet State]
+ Controller -->|Sends Req| Syncer[Syncer]
+
+ Syncer -->|Maintains| RecoveryState[Recovery State]
+ Syncer -->|Reads/Writes| DB[(Wallet DB)]
+ Syncer -->|Fetches Data| Chain[Chain Backend]
+ end
+```
+
+### 1.1 Key Components
+
+* **Controller (`wallet/controller.go`)**: The public face of the wallet. It manages the wallet's lifecycle (`Start`, `Stop`), handles authentication (`Lock`, `Unlock`), and acts as the gatekeeper for state transitions. It does *not* perform blocking chain operations directly.
+* **Syncer (`wallet/syncer.go`)**: A dedicated background worker responsible for the main synchronization loop. It communicates with the chain backend (e.g., `bitcoind`, `neutrino`), orchestrates batch scanning, and handles blockchain reorganizations (rollbacks).
+* **RecoveryState (`wallet/recovery.go`)**: A specialized state machine that encapsulates the logic for *what* to scan for. It manages BIP32 derivation horizons, address lookahead windows, and the set of watched outpoints. It is purely logic and memory-based, decoupled from the I/O mechanisms of the Syncer.
+
+---
+
+## 2. State Management: The Orthogonal Model
+
+To manage concurrency and API availability safely, the wallet employs an **Orthogonal State Model**. Instead of a single monolithic status (e.g., "Syncing"), we track three independent dimensions of state. This decoupling allows for precise representation of complex conditions (e.g., a wallet can be "Started" AND "Syncing" AND "Locked") without state explosion.
+
+### 2.1 Lifecycle (System State)
+Tracks the runtime status of the wallet's main event loop and background processes.
+* **Stopped**: The wallet is idle. No background routines are running.
+* **Starting**: The wallet is in the middle of its synchronous startup sequence (e.g., loading accounts, verifying birthday).
+* **Started**: The wallet is fully operational. `mainLoop` and `chainLoop` are running.
+* **Stopping**: A shutdown signal has been sent; the wallet is waiting for background routines to exit.
+
+### 2.2 Synchronization (Chain State)
+Tracks data freshness relative to the blockchain backend.
+* **BackendSyncing**: Waiting for the chain backend (e.g., bitcoind) to finish its own synchronization.
+* **Syncing**: The wallet is actively downloading blocks or filters to catch up to the chain tip.
+* **Synced**: The wallet is fully caught up with the current chain tip.
+* **Rescanning**: The wallet is performing a targeted historical scan for specific accounts or addresses. This is a sub-state that does not rewind the global sync watermark.
+
+### 2.3 Authentication (Security State)
+Tracks the accessibility of sensitive private key material.
+* **Locked**: Private keys are encrypted and inaccessible in memory.
+* **Unlocked**: Private keys have been decrypted and are available for signing.
+* **Security Note**: The system tracks the `unlocked` flag such that the zero-value (false) defaults to the secure **Locked** state. The wallet is forcefully locked during any `Stop` or `Stopping` transition.
+
+---
+
+## 3. Synchronization Modes
+
+The `Syncer` operates in two primary modes:
+
+### 3.1 Chain Synchronization (Global Sync)
+This is the default background process that ensures the wallet maintains consensus with the blockchain.
+
+* **Goal**: Advance the global `SyncedTo` pointer to the current chain tip.
+* **Mechanism**: Sequential forward scanning of block batches.
+* **Persistence**: Upon successful completion of a batch, the wallet updates its global "sync tip" in the database.
+
+### 3.2 Targeted Rescan (Import Scanning)
+Triggered by user actions like importing a new account, a private key, or an XPUB.
+
+* **Goal**: Discover historical transactions for the *newly added* keys without affecting the synchronization status of existing keys.
+* **Mechanism**: Ad-hoc scanning of a specific block range (typically from the birthday of the imported key to the current tip).
+* **Persistence**: Found transactions are inserted into the database, but the global `SyncedTo` watermark is **not** altered. This allows the wallet to remain "Synced" for the rest of its keys while processing the import in the background.
+
+---
+
+## 4. Data Preparation
+
+Before scanning can begin, the Syncer must prepare a `RecoveryState` object. This object acts as the "Checklist" of things to look for in the blocks. The source of this data depends on the sync mode.
+
+### 4.1 Loading for Global Sync
+When performing the standard chain sync, the wallet loads **all** active data from the database:
+1. **Accounts**: Iterates through all active BIP32 accounts in the `waddrmgr`.
+2. **Horizons**: For each account, retrieves the current external and internal branch horizons (the index of the last used address).
+3. **Historical Addresses**: Loads every address that has ever received funds.
+4. **UTXOs**: Loads all unspent transaction outputs to detect spends.
+
+The `RecoveryState` is initialized with this data and immediately derives `N` new lookahead addresses (based on the `RecoveryWindow`) for every account branch.
+
+### 4.2 Loading for Targeted Rescan
+When performing a targeted rescan (e.g., after `ImportAccount`), the caller provides specific targets. The wallet constructs a **Partial Recovery State**:
+1. **Targets**: Only the specific accounts or addresses requested by the caller are loaded.
+2. **Isolation**: Existing, fully-synced accounts are **excluded** from this state.
+3. **Efficiency**: This ensures the scanner only spends CPU cycles matching the new keys, ignoring the thousands of keys that are already up-to-date.
+
+---
+
+## 5. Full Block Scanning Algorithm
+
+This is the traditional scanning method, used when bandwidth is abundant or when privacy filters are not supported by the backend.
+
+### 5.1 The Algorithm
+1. **Fetch Batch**: The Syncer requests a batch of full blocks (e.g., 20 blocks) directly from the backend (RPC `getblock` or P2P `MSG_BLOCK`).
+2. **Process Sequentially**: It iterates through each block in memory.
+3. **Transaction Matching**:
+ * **Inputs**: Checked against the `watchedOutPoints` map to detect spends.
+ * **Outputs**: Checked against the `addrFilters` map to detect receives.
+4. **Horizon Expansion**: If a transaction pays to a lookahead address:
+ * Mark address as used.
+ * Derive new lookahead addresses.
+ * **No Restart Needed**: Since we have the full block data, we simply add the new addresses to the map and continue processing. Future blocks in the batch will be checked against the updated map.
+
+---
+
+## 6. CFilter Scanning Algorithm
+
+This method uses BIP 157/158 Compact Filters to minimize bandwidth usage. It is complex because filters are probabilistic and abstract; we don't have the transaction data until we fetch the block.
+
+### 6.1 Optimistic Batch Processing with In-Place Resume
+To overcome the latency of fetching headers -> filters -> blocks sequentially, we use an **Optimistic** strategy.
+
+1. **Parallel Fetch**:
+ * Assume the current lookahead window is sufficient.
+ * Fetch a large batch (e.g., 250 blocks) of **Headers** and **CFilters** in parallel.
+
+2. **Local Filtering**:
+ * Match the CFilters against the `RecoveryState`'s watchlist (Addresses + Outpoints).
+ * Queue only the *matching* blocks for download.
+
+3. **Sequential Process & Resume Loop**:
+ * Iterate through the batch.
+ * **Horizon Expansion Event**: If Block `N` contains a payment to a lookahead address, we must expand the window.
+ * **The Problem**: The filters for blocks `N+1` to `End` were checked against the *old* watchlist. They might contain payments to the *new* addresses we just derived.
+ * **The Fix (In-Place Resume)**:
+ * Pause processing.
+ * Update the watchlist with new addresses.
+ * **Re-Match** the filters for the remainder of the batch (`N+1`...`End`) against the new watchlist.
+ * Fetch any *newly* matched blocks.
+ * Resume processing from `N+1`.
+
+---
+
+## 7. Strategy Comparison & Selection
+
+The wallet automatically selects the best strategy based on the environment and state.
+
+| Feature | Full Block Scanning | CFilter Scanning |
+| :--- | :--- | :--- |
+| **Bandwidth** | High (All data) | Low (Headers + Filters + Matched Blocks) |
+| **CPU Usage** | Low (Hash map lookups) | High (Elliptic Curve ops + SIPHash matching) |
+| **Latency** | Low (Local) / High (Remote) | Low (Parallel Fetch) |
+| **Privacy** | High (Indistinguishable) | Medium (Leaks Block Interest) |
+| **Best For** | Local Bitcoind, Huge Wallets | Mobile, Light Clients, Bandwidth Cap |
+
+### 7.1 Selection Logic (`SyncMethodAuto`)
+The wallet uses a heuristic to choose:
+1. **Backend Capability**: If the backend doesn't support BIP 157 (CFilters), fall back to Full Blocks.
+2. **Watchlist Size**: If the wallet is watching > 100,000 items (addresses + UTXOs), CFilter matching becomes CPU-prohibitive. The wallet switches to **Full Block** scanning, as checking a map is O(1) regardless of size.
+3. **Default**: Use **CFilters** for efficiency and privacy.
+
+---
+
+## 8. Performance and Efficiency
+
+* **Write Batching**: Database writes are the single biggest bottleneck. The Syncer aggregates all findings (transactions, state updates) from a batch and commits them in a **single database transaction**. This reduces disk I/O by orders of magnitude compared to per-block commits.
+* **Lookahead Derivation**: Address derivation is cached. The `RecoveryState` ensures we don't re-derive keys we've already generated, even if the scan is restarted.
+* **Non-Blocking**: All scanning happens in a dedicated goroutine. The Wallet Controller remains responsive to `Info` and `Balance` requests even during a massive re-sync.
\ No newline at end of file
diff --git a/docs/user/README.md b/docs/user/README.md
new file mode 100644
index 0000000000..906a2cf6ee
--- /dev/null
+++ b/docs/user/README.md
@@ -0,0 +1,19 @@
+# User Documentation
+
+This section contains documentation for end-users and operators of `btcwallet`.
+
+---
+
+## 🔄 Synchronization Modes
+
+A detailed guide on choosing between Compact Filters (`SyncMethodCFilters`) and Full Blocks (`SyncMethodFullBlocks`) based on your transaction density. Includes performance benchmarks and strategy recommendations.
+
+**[➡️ Read the Synchronization Modes Guide](./synchronization_modes.md)**
+
+---
+
+## 🛠️ Rebuilding Transaction History
+
+Instructions on how to force a full wallet rescan using the `dropwtxmgr` tool. Useful for recovering from database corruption or missing history.
+
+**[➡️ Learn about Forced Rescans](./force_rescans.md)**
diff --git a/docs/user/synchronization_modes.md b/docs/user/synchronization_modes.md
new file mode 100644
index 0000000000..351779a53c
--- /dev/null
+++ b/docs/user/synchronization_modes.md
@@ -0,0 +1,65 @@
+# Synchronization Modes: Compact Filters vs. Full Blocks
+
+When connecting to a chain backend, `btcwallet` offers two distinct synchronization strategies, configured via the `SyncMethod` parameter:
+
+1. **`SyncMethodCFilters`**: Uses lightweight Compact Filters (Neutrino) to scan for relevant transactions.
+2. **`SyncMethodFullBlocks`**: Downloads full blocks (or batches of blocks) to scan locally.
+
+Choosing the right mode can significantly impact your wallet's startup time, bandwidth usage, and CPU load. This guide explains the differences and provides performance benchmarks to help you decide which mode fits your use case.
+
+## Summary Recommendation
+
+| Transaction Density | Recommended Mode | Why? |
+| :--- | :--- | :--- |
+| **Low Frequency** (< 1 hit per 100 blocks) | **`SyncMethodCFilters`** | **~4x Faster.** Optimized for sparse histories by avoiding block data entirely. |
+| **Moderate Frequency** (1 hit per 10-100 blocks) | **`SyncMethodCFilters`** | **~2x Faster.** Still faster than full blocks due to reduced data transfer. |
+| **High Frequency** (> 1 hit per 10 blocks) | **`SyncMethodFullBlocks`** | **High Throughput.** Most efficient when hits are dense, avoiding matching overhead. |
+
+---
+
+## 1. Compact Filters (CFilter)
+**Default & Recommended for 99% of Users.**
+
+In this mode (`SyncMethodCFilters`), the wallet downloads lightweight **Neutrino Compact Filters** (Golomb-Rice filters) for each block to check if the block contains any relevant transactions.
+* **Process**: Fetch Filter -> Match Filter -> (Only if Match) Fetch Block.
+* **Pros**:
+ * Extremely fast for "empty" blocks.
+ * Minimal bandwidth usage (filters are ~15KB vs blocks ~1-4MB).
+* **Cons**:
+ * Slower if every block is a "hit" (requires double round-trips: get filter, then get block).
+
+## 2. Full Blocks
+**Recommended for High-Density Wallets.**
+
+In this mode (`SyncMethodFullBlocks`), the wallet indiscriminately downloads every full block (or batches of blocks) and scans them locally.
+* **Process**: Fetch Batch of Blocks -> Scan All.
+* **Pros**:
+ * Linear scaling for high-traffic wallets.
+ * Eliminates the "Match Filter -> Fetch Block" latency penalty when hits are frequent.
+* **Cons**:
+ * High bandwidth (downloads the entire blockchain history during rescan).
+ * High CPU/Memory usage (parsing gigabytes of JSON/Hex block data).
+
+---
+
+## Performance Analysis & UTXO Density
+
+We benchmarked both modes against a standard `bitcoind` node over a range of 1,000 blocks. The performance depends heavily on **UTXO Density**: how often your wallet sends or receives a transaction relative to the number of blocks.
+
+### Benchmark Results (1,000 Blocks)
+
+| Wallet Activity (Density) | Real-World Equivalent | `SyncMethodFullBlocks` | `SyncMethodCFilters` | Winner |
+| :--- | :--- | :--- | :--- | :--- |
+| **0.001 (0.1%)** | **1 Tx / Week** | ~1.9x Faster | **~3.8x Faster** | **`SyncMethodCFilters`** |
+| **0.01 (1.0%)** | **1 Tx / 16 Hours** | ~1.5x Faster | **~2.1x Faster** | **`SyncMethodCFilters`** |
+| **0.1 (10%)** | **10 Txs / Hour** | **~1.8x Faster** | Slower | **`SyncMethodFullBlocks`** |
+
+* *Speedups are compared against the legacy synchronization API.*
+
+### Interpreting Density
+* **Low to Moderate Frequency (Density < 0.1)**: This represents the vast majority of wallet usage. If your wallet receives or sends funds a few times a day or less, your transaction history is "sparse" relative to the blockchain (~144 blocks/day). **Use `SyncMethodCFilters`** to save bandwidth and reduce sync time.
+* **High Frequency (Density > 0.1)**: This represents "dense" wallets that are active in **1 out of every 10 blocks** (or more). For these high-traffic scenarios, the overhead of double-fetching (filter match then block fetch) exceeds the cost of just fetching everything. **Use `SyncMethodFullBlocks`** for maximum throughput.
+
+## Conclusion
+
+The new synchronization architecture in `btcwallet` is designed to be adaptive. By default, **`SyncMethodCFilters`** provides a superior experience for typical users, offering massive speedups and bandwidth savings. However, for high-workload scenarios where the wallet effectively indexes a significant portion of the chain, **`SyncMethodFullBlocks`** mode provides a robust, high-throughput alternative that outperforms the legacy implementation.
diff --git a/go.mod b/go.mod
index 54a3beb083..5ed9404034 100644
--- a/go.mod
+++ b/go.mod
@@ -64,3 +64,7 @@ require (
// If you change this please run `make lint` to see where else it needs to be
// updated as well.
go 1.25.11
+
+replace github.com/btcsuite/btcwallet/wtxmgr => ./wtxmgr
+
+replace github.com/btcsuite/btcwallet/wallet/txsizes => ./wallet/txsizes
diff --git a/go.sum b/go.sum
index c4715582e1..c2d93e042b 100644
--- a/go.sum
+++ b/go.sum
@@ -26,12 +26,8 @@ github.com/btcsuite/btcwallet/wallet/txauthor v1.4.0 h1:oIkGj32YK1CvWaJGlVwZA1f+
github.com/btcsuite/btcwallet/wallet/txauthor v1.4.0/go.mod h1:sGrBjcqQ8UPexuRajFs72+o544CJn3Pavv/5H0VAWVk=
github.com/btcsuite/btcwallet/wallet/txrules v1.3.0 h1:D5aGMwWIxdqek3xEJs4eOdMoh6iga2EI2xSlaXCdnNo=
github.com/btcsuite/btcwallet/wallet/txrules v1.3.0/go.mod h1:ZzSdn2XrsUDPa193Q/su1sJY+716rlFK2H1mYwbY/18=
-github.com/btcsuite/btcwallet/wallet/txsizes v1.3.0 h1:2W9qt0edMoX8crx0Wm4Cv+eAj4B3jlbn0N/5fckLHSU=
-github.com/btcsuite/btcwallet/wallet/txsizes v1.3.0/go.mod h1:42aE6+LMZSSEisQAa15Xml25ncuJFfhCrkcpB5OmkZk=
github.com/btcsuite/btcwallet/walletdb v1.6.0 h1:Yund5XbdqFxNW7+R2Sxs02bMC5fMrmORj4GN8MV55no=
github.com/btcsuite/btcwallet/walletdb v1.6.0/go.mod h1:q9xif0Csp52GVb3l252BbHCuyiCnuEbrPWu/HAsvaYc=
-github.com/btcsuite/btcwallet/wtxmgr v1.6.0 h1:ivSSnYCD4Kb5yAMZVyBA1VMYABIFcopPEcmHCrRZXcE=
-github.com/btcsuite/btcwallet/wtxmgr v1.6.0/go.mod h1:Raor7IBIwHSIKE9Lr5o+R9rwX7sRMHU1zjxgEQgn9h8=
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd h1:R/opQEbFEy9JGkIguV40SvRY1uliPX8ifOvi6ICsFCw=
github.com/btcsuite/go-socks v0.0.0-20170105172521-4720035b7bfd/go.mod h1:HHNXQzUsZCxOoE+CPiyCTO6x34Zs86zZUiwtpXoGdtg=
github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 h1:R8vQdOQdZ9Y3SkEwmHoWBmX1DNXhXZqlTpq6s4tyJGc=
diff --git a/pkg/btcunit/README.md b/pkg/btcunit/README.md
new file mode 100644
index 0000000000..ed4abef300
--- /dev/null
+++ b/pkg/btcunit/README.md
@@ -0,0 +1,25 @@
+# btcwallet/btcunit
+
+This package provides a set of idiomatic, type-safe units for handling common
+Bitcoin quantities like transaction sizes, weights, and fee rates.
+
+## Purpose
+
+In complex Bitcoin applications, it is crucial to handle different units of
+measurement safely and consistently. Raw integer types can lead to ambiguity and
+errors (e.g., is a fee rate in sat/byte, sat/vbyte, or sat/kw?).
+
+This package establishes a canonical set of types to be used within `btcwallet`
+and by any application that consumes it. By using these types, developers can
+avoid conversion errors and make their code more readable and self-documenting.
+
+## Provided Units
+
+- **Transaction Size**: `WeightUnit` and `VByte` for handling transaction
+ weight and virtual size according to SegWit (BIP-141) standards.
+- **Fee Rates**: `SatPerVByte`, `SatPerKVByte`, and `SatPerKWeight` for
+ expressing fee rates in the most common industry formats. These types use
+ `math/big.Rat` internally to allow for fractional (sub-satoshi) values,
+ ensuring precision in fee calculations. These types use
+ `math/big.Rat` internally to allow for fractional (sub-satoshi) values,
+ ensuring precision in fee calculations.
diff --git a/pkg/btcunit/rates.go b/pkg/btcunit/rates.go
new file mode 100644
index 0000000000..8c28f6b640
--- /dev/null
+++ b/pkg/btcunit/rates.go
@@ -0,0 +1,482 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// Package btcunit provides a set of types for dealing with bitcoin units.
+package btcunit
+
+import (
+ "log/slog"
+ "math"
+ "math/big"
+
+ "github.com/btcsuite/btcd/blockchain"
+ "github.com/btcsuite/btcd/btcutil/v2"
+)
+
+const (
+ // kilo is a generic multiplier for kilo units.
+ kilo = 1000
+
+ // floatStringPrecision is the number of decimal places to use when
+ // converting a fee rate to a string. We use 3 decimal places to ensure
+ // that low fee rates (e.g., 1 sat/kvb = 0.001 sat/vbyte) are displayed
+ // with sufficient precision and not rounded to zero.
+ floatStringPrecision = 3
+)
+
+var (
+ // ZeroSatPerVByte is a fee rate of 0 sat/vb.
+ ZeroSatPerVByte = NewSatPerVByte(0)
+
+ // ZeroSatPerKVByte is a fee rate of 0 sat/kvb.
+ ZeroSatPerKVByte = NewSatPerKVByte(0)
+
+ // ZeroSatPerKWeight is a fee rate of 0 sat/kw.
+ ZeroSatPerKWeight = NewSatPerKWeight(0)
+
+ // ZeroSatPerWeight is a fee rate of 0 sat/wu.
+ ZeroSatPerWeight = NewSatPerWeight(0)
+)
+
+// baseFeeRate stores the canonical representation of a fee rate, which is
+// satoshis per kilo-weight-unit (sat/kwu). All other fee rate units are
+// derived from this.
+type baseFeeRate struct {
+ // satsPerKWU is the fee rate in satoshis per kilo-weight-unit. This is
+ // the canonical representation for all fee rates within this package,
+ // chosen for its direct alignment with Bitcoin's weight unit for fee
+ // calculations and to minimize rounding errors.
+ satsPerKWU *big.Rat
+}
+
+// newBaseFeeRate creates a new baseFeeRate with the given numerator and
+// denominator. It panics if the denominator is zero.
+func newBaseFeeRate(numerator btcutil.Amount, denominator uint64) baseFeeRate {
+ if denominator == 0 {
+ panic("fee rate calculation: denominator cannot be zero")
+ }
+
+ return baseFeeRate{satsPerKWU: big.NewRat(
+ int64(numerator),
+ safeUint64ToInt64(denominator),
+ )}
+}
+
+// ToSatPerVByte converts the fee rate to sat/vb.
+func (f baseFeeRate) ToSatPerVByte() SatPerVByte {
+ return SatPerVByte{f}
+}
+
+// ToSatPerKVByte converts the fee rate to sat/kvb.
+func (f baseFeeRate) ToSatPerKVByte() SatPerKVByte {
+ return SatPerKVByte{f}
+}
+
+// ToSatPerKWeight converts the fee rate to sat/kw.
+func (f baseFeeRate) ToSatPerKWeight() SatPerKWeight {
+ return SatPerKWeight{f}
+}
+
+// ToSatPerWeight converts the fee rate to sat/wu.
+func (f baseFeeRate) ToSatPerWeight() SatPerWeight {
+ return SatPerWeight{f}
+}
+
+// FeeForWeight calculates the fee resulting from this fee rate and the given
+// weight in weight units (wu).
+func (f baseFeeRate) FeeForWeight(weightUnit WeightUnit) btcutil.Amount {
+ // The fee rate is stored as satoshis per kilo-weight-unit (sat/kwu).
+ // To calculate the fee for a given weight, we need to multiply the
+ // rate by the weight expressed in kilo-weight-units. We do this by
+ // creating a rational number of weightUnit.wu / kilo.
+ //
+ // The resulting fee is rounded down (truncated).
+ feeRateRational := big.NewRat(0, 1)
+ feeRateRational.Mul(
+ f.satsPerKWU,
+ big.NewRat(safeUint64ToInt64(weightUnit.wu), kilo),
+ )
+
+ // Extract the numerator and denominator for integer division.
+ numerator := feeRateRational.Num()
+ denominator := feeRateRational.Denom()
+
+ // Perform integer division to truncate the result (round down).
+ quotient := big.NewInt(0)
+ quotient.Div(numerator, denominator)
+
+ return btcutil.Amount(quotient.Int64())
+}
+
+// FeeForWeightRoundUp calculates the fee resulting from this fee rate and the
+// given weight in weight units (wu), rounding up to the nearest satoshi.
+func (f baseFeeRate) FeeForWeightRoundUp(weightUnit WeightUnit) btcutil.Amount {
+ // The rounding logic for ceiling division is based on the formula:
+ // (numerator + denominator - 1) / denominator
+ // This ensures that any fractional part of the fee is rounded up to
+ // the next whole satoshi.
+ //
+ // Calculate the fee rate as a rational number.
+ feeRateRational := big.NewRat(0, 1)
+ feeRateRational.Mul(
+ f.satsPerKWU, big.NewRat(
+ safeUint64ToInt64(weightUnit.wu), kilo,
+ ),
+ )
+
+ // Get the numerator and denominator of the calculated fee.
+ numerator := feeRateRational.Num()
+ denominator := feeRateRational.Denom()
+
+ // Initialize a new big.Int to store the result of the ceiling division.
+ result := big.NewInt(0)
+
+ // Apply the ceiling division formula:
+ // (numerator + denominator - 1) / denominator.
+ result.Add(numerator, denominator)
+ result.Sub(result, big.NewInt(1))
+ result.Div(result, denominator)
+
+ return btcutil.Amount(result.Int64())
+}
+
+// FeeForVByte calculates the fee resulting from this fee rate and the given
+// size in vbytes (vb).
+func (f baseFeeRate) FeeForVByte(vb VByte) btcutil.Amount {
+ return f.FeeForWeight(vb.ToWU())
+}
+
+// FeeForKVByte calculates the fee resulting from this fee rate and the given
+// vsize in kilo-vbytes.
+func (f baseFeeRate) FeeForKVByte(kvb KVByte) btcutil.Amount {
+ // Directly convert kilo-virtual-bytes to weight units for fee
+ // calculation to maintain precision and avoid intermediate rounding
+ // effects.
+ return f.FeeForWeight(kvb.ToWU())
+}
+
+// FeeForKWeight calculates the fee resulting from this fee rate and the given
+// weight in kilo-weight-units (kwu).
+func (f baseFeeRate) FeeForKWeight(kwu KWeightUnit) btcutil.Amount {
+ return f.FeeForWeight(kwu.ToWU())
+}
+
+// equal returns true if the fee rate is equal to the other fee rate.
+func (f baseFeeRate) equal(other baseFeeRate) bool {
+ return f.satsPerKWU.Cmp(other.satsPerKWU) == 0
+}
+
+// greaterThan returns true if the fee rate is greater than the other fee rate.
+func (f baseFeeRate) greaterThan(other baseFeeRate) bool {
+ return f.satsPerKWU.Cmp(other.satsPerKWU) > 0
+}
+
+// lessThan returns true if the fee rate is less than the other fee rate.
+func (f baseFeeRate) lessThan(other baseFeeRate) bool {
+ return f.satsPerKWU.Cmp(other.satsPerKWU) < 0
+}
+
+// greaterThanOrEqual returns true if the fee rate is greater than or equal to
+// the other fee rate.
+func (f baseFeeRate) greaterThanOrEqual(other baseFeeRate) bool {
+ return f.satsPerKWU.Cmp(other.satsPerKWU) >= 0
+}
+
+// lessThanOrEqual returns true if the fee rate is less than or equal to the
+// other fee rate.
+func (f baseFeeRate) lessThanOrEqual(other baseFeeRate) bool {
+ return f.satsPerKWU.Cmp(other.satsPerKWU) <= 0
+}
+
+// SatPerVByte represents a fee rate in sat/vbyte. Internally, all fee rates
+// are stored and operated on as satoshis per kilo-weight-unit (sat/kw).
+// Conversions to other units and fee calculations are performed using this
+// canonical internal representation. The `String()` method is the only one
+// that presents the fee rate in its specific sat/vbyte unit.
+type SatPerVByte struct {
+ baseFeeRate
+}
+
+// NewSatPerVByte creates a new fee rate in sat/vb.
+func NewSatPerVByte(rate btcutil.Amount) SatPerVByte {
+ return CalcSatPerVByte(rate, NewVByte(1))
+}
+
+// CalcSatPerVByte calculates the fee rate in sat/vb for a given fee and size.
+func CalcSatPerVByte(fee btcutil.Amount, vb VByte) SatPerVByte {
+ // To convert the rate to the canonical sat/kwu unit, we use the
+ // formula: (fee * 1000) / size_in_wu.
+ //
+ // vb.wu provides the size in weight units (wu), implicitly accounting
+ // for the WitnessScaleFactor.
+ numerator := fee * kilo
+ denominator := vb.wu
+
+ return SatPerVByte{newBaseFeeRate(numerator, denominator)}
+}
+
+// String returns a human-readable string of the fee rate.
+func (s SatPerVByte) String() string {
+ // Calculate the fee rate in sat/vb from the canonical sat/kwu.
+ // The WitnessScaleFactor (4) is used to convert weight units to vbytes.
+ // The `kilo` constant is used to scale kilo-weight-units.
+ kwToVbRate := big.NewRat(0, 1)
+ kwToVbRate.Mul(s.satsPerKWU,
+ big.NewRat(blockchain.WitnessScaleFactor, kilo),
+ )
+
+ // Format the rational number to a string with the specified precision.
+ return kwToVbRate.FloatString(floatStringPrecision) + " sat/vb"
+}
+
+// Equal returns true if the fee rate is equal to the other fee rate.
+func (s SatPerVByte) Equal(other SatPerVByte) bool {
+ return s.equal(other.baseFeeRate)
+}
+
+// GreaterThan returns true if the fee rate is greater than the other fee rate.
+func (s SatPerVByte) GreaterThan(other SatPerVByte) bool {
+ return s.greaterThan(other.baseFeeRate)
+}
+
+// LessThan returns true if the fee rate is less than the other fee rate.
+func (s SatPerVByte) LessThan(other SatPerVByte) bool {
+ return s.lessThan(other.baseFeeRate)
+}
+
+// GreaterThanOrEqual returns true if the fee rate is greater than or equal to
+// the other fee rate.
+func (s SatPerVByte) GreaterThanOrEqual(other SatPerVByte) bool {
+ return s.greaterThanOrEqual(other.baseFeeRate)
+}
+
+// LessThanOrEqual returns true if the fee rate is less than or equal to the
+// other fee rate.
+func (s SatPerVByte) LessThanOrEqual(other SatPerVByte) bool {
+ return s.lessThanOrEqual(other.baseFeeRate)
+}
+
+// SatPerKVByte represents a fee rate in sat/kvb. Internally, all fee rates
+// are stored and operated on as satoshis per kilo-weight-unit (sat/kw).
+// Conversions to other units and fee calculations are performed using this
+// canonical internal representation. The `String()` method is the only one
+// that presents the fee rate in its specific sat/kvb unit.
+type SatPerKVByte struct {
+ baseFeeRate
+}
+
+// NewSatPerKVByte creates a new fee rate in sat/kvb.
+func NewSatPerKVByte(rate btcutil.Amount) SatPerKVByte {
+ return CalcSatPerKVByte(rate, NewKVByte(1))
+}
+
+// CalcSatPerKVByte calculates the fee rate in sat/kvb for a given fee and size.
+func CalcSatPerKVByte(fee btcutil.Amount, kvb KVByte) SatPerKVByte {
+ // To convert the rate to the canonical sat/kwu unit, we use the
+ // formula: (fee * 1000) / size_in_wu.
+ //
+ // kvb.wu provides the size in weight units (wu), implicitly accounting
+ // for the WitnessScaleFactor and kilo scaling.
+ numerator := fee * kilo
+ denominator := kvb.wu
+
+ return SatPerKVByte{newBaseFeeRate(numerator, denominator)}
+}
+
+// Val returns the fee rate in sat/kvb.
+//
+// NOTE: This method is provided for backward compatibility with legacy APIs
+// that expect a raw integer fee rate. New code should use the btcunit types
+// directly.
+func (s SatPerKVByte) Val() btcutil.Amount {
+ return s.FeeForKVByte(NewKVByte(1))
+}
+
+// String returns a human-readable string of the fee rate.
+func (s SatPerKVByte) String() string {
+ // Calculate the fee rate in sat/kvb from the canonical sat/kwu.
+ // The WitnessScaleFactor (4) is used to convert weight units to vbytes.
+ // No `kilo` division here as we are converting to *kilo*-vbytes.
+ kwToKvbRate := big.NewRat(0, 1)
+ kwToKvbRate.Mul(s.satsPerKWU,
+ big.NewRat(blockchain.WitnessScaleFactor, 1),
+ )
+
+ // Format the rational number to a string with the specified precision.
+ return kwToKvbRate.FloatString(floatStringPrecision) +
+ " sat/kvb"
+}
+
+// Equal returns true if the fee rate is equal to the other fee rate.
+func (s SatPerKVByte) Equal(other SatPerKVByte) bool {
+ return s.equal(other.baseFeeRate)
+}
+
+// GreaterThan returns true if the fee rate is greater than the other fee rate.
+func (s SatPerKVByte) GreaterThan(other SatPerKVByte) bool {
+ return s.greaterThan(other.baseFeeRate)
+}
+
+// LessThan returns true if the fee rate is less than the other fee rate.
+func (s SatPerKVByte) LessThan(other SatPerKVByte) bool {
+ return s.lessThan(other.baseFeeRate)
+}
+
+// GreaterThanOrEqual returns true if the fee rate is greater than or equal to
+// the other fee rate.
+func (s SatPerKVByte) GreaterThanOrEqual(other SatPerKVByte) bool {
+ return s.greaterThanOrEqual(other.baseFeeRate)
+}
+
+// LessThanOrEqual returns true if the fee rate is less than or equal to the
+// other fee rate.
+func (s SatPerKVByte) LessThanOrEqual(other SatPerKVByte) bool {
+ return s.lessThanOrEqual(other.baseFeeRate)
+}
+
+// SatPerKWeight represents a fee rate in sat/kw. Internally, all fee rates
+// are stored and operated on as satoshis per kilo-weight-unit (sat/kw).
+// Conversions to other units and fee calculations are performed using this
+// canonical internal representation. The `String()` method is the only one
+// that presents the fee rate in its specific sat/kw unit.
+type SatPerKWeight struct {
+ baseFeeRate
+}
+
+// NewSatPerKWeight creates a new fee rate in sat/kw.
+func NewSatPerKWeight(rate btcutil.Amount) SatPerKWeight {
+ return CalcSatPerKWeight(rate, NewKWeightUnit(1))
+}
+
+// CalcSatPerKWeight calculates the fee rate in sat/kw for a given fee and size.
+func CalcSatPerKWeight(fee btcutil.Amount, kwu KWeightUnit) SatPerKWeight {
+ // To convert the rate to the canonical sat/kwu unit, we use the
+ // formula: (fee * 1000) / size_in_wu.
+ //
+ // kwu.wu provides the size in weight units (wu), implicitly accounting
+ // for the kilo scaling.
+ numerator := fee * kilo
+ denominator := kwu.wu
+
+ return SatPerKWeight{newBaseFeeRate(numerator, denominator)}
+}
+
+// Val returns the fee rate in sat/kw.
+//
+// NOTE: This method is provided for backward compatibility with legacy APIs
+// that expect a raw integer fee rate. New code should use the btcunit types
+// directly.
+func (s SatPerKWeight) Val() btcutil.Amount {
+ return s.FeeForKWeight(NewKWeightUnit(1))
+}
+
+// String returns a human-readable string of the fee rate.
+func (s SatPerKWeight) String() string {
+ return s.satsPerKWU.FloatString(floatStringPrecision) + " sat/kw"
+}
+
+// Equal returns true if the fee rate is equal to the other fee rate.
+func (s SatPerKWeight) Equal(other SatPerKWeight) bool {
+ return s.equal(other.baseFeeRate)
+}
+
+// GreaterThan returns true if the fee rate is greater than the other fee rate.
+func (s SatPerKWeight) GreaterThan(other SatPerKWeight) bool {
+ return s.greaterThan(other.baseFeeRate)
+}
+
+// LessThan returns true if the fee rate is less than the other fee rate.
+func (s SatPerKWeight) LessThan(other SatPerKWeight) bool {
+ return s.lessThan(other.baseFeeRate)
+}
+
+// GreaterThanOrEqual returns true if the fee rate is greater than or equal to
+// the other fee rate.
+func (s SatPerKWeight) GreaterThanOrEqual(other SatPerKWeight) bool {
+ return s.greaterThanOrEqual(other.baseFeeRate)
+}
+
+// LessThanOrEqual returns true if the fee rate is less than or equal to the
+// other fee rate.
+func (s SatPerKWeight) LessThanOrEqual(other SatPerKWeight) bool {
+ return s.lessThanOrEqual(other.baseFeeRate)
+}
+
+// SatPerWeight represents a fee rate in sat/wu. Internally, all fee rates
+// are stored and operated on as satoshis per kilo-weight-unit (sat/kw).
+// Conversions to other units and fee calculations are performed using this
+// canonical internal representation. The `String()` method is the only one
+// that presents the fee rate in its specific sat/wu unit.
+type SatPerWeight struct {
+ baseFeeRate
+}
+
+// NewSatPerWeight creates a new fee rate in sat/wu.
+func NewSatPerWeight(rate btcutil.Amount) SatPerWeight {
+ return CalcSatPerWeight(rate, NewWeightUnit(1))
+}
+
+// CalcSatPerWeight calculates the fee rate in sat/wu for a given fee and size.
+func CalcSatPerWeight(fee btcutil.Amount, wu WeightUnit) SatPerWeight {
+ // To convert the rate to the canonical sat/kwu unit, we use the
+ // formula: (fee * 1000) / size_in_wu.
+ //
+ // wu.wu provides the size in weight units (wu).
+ numerator := fee * kilo
+ denominator := wu.wu
+
+ return SatPerWeight{newBaseFeeRate(numerator, denominator)}
+}
+
+// String returns a human-readable string of the fee rate.
+func (s SatPerWeight) String() string {
+ // Calculate the fee rate in sat/wu from the canonical sat/kwu.
+ // 1 sat/wu = 1000 sat/kwu. So we need to divide by kilo.
+ wuRate := big.NewRat(0, 1)
+ wuRate.Mul(s.satsPerKWU, big.NewRat(1, kilo))
+
+ return wuRate.FloatString(floatStringPrecision) + " sat/wu"
+}
+
+// Equal returns true if the fee rate is equal to the other fee rate.
+func (s SatPerWeight) Equal(other SatPerWeight) bool {
+ return s.equal(other.baseFeeRate)
+}
+
+// GreaterThan returns true if the fee rate is greater than the other fee rate.
+func (s SatPerWeight) GreaterThan(other SatPerWeight) bool {
+ return s.greaterThan(other.baseFeeRate)
+}
+
+// LessThan returns true if the fee rate is less than the other fee rate.
+func (s SatPerWeight) LessThan(other SatPerWeight) bool {
+ return s.lessThan(other.baseFeeRate)
+}
+
+// GreaterThanOrEqual returns true if the fee rate is greater than or equal to
+// the other fee rate.
+func (s SatPerWeight) GreaterThanOrEqual(other SatPerWeight) bool {
+ return s.greaterThanOrEqual(other.baseFeeRate)
+}
+
+// LessThanOrEqual returns true if the fee rate is less than or equal to the
+// other fee rate.
+func (s SatPerWeight) LessThanOrEqual(other SatPerWeight) bool {
+ return s.lessThanOrEqual(other.baseFeeRate)
+}
+
+// safeUint64ToInt64 converts a uint64 to an int64, capping at math.MaxInt64.
+// This is used to silence gosec warnings about integer overflows. In practice,
+// the values being converted are transaction weights or sizes, which are
+// limited by consensus rules and are not expected to overflow an int64.
+func safeUint64ToInt64(u uint64) int64 {
+ if u > math.MaxInt64 {
+ slog.Warn("Capping uint64 value to math.MaxInt64",
+ slog.Uint64("old", u), slog.Int64("new", math.MaxInt64))
+
+ return math.MaxInt64
+ }
+
+ return int64(u)
+}
diff --git a/pkg/btcunit/rates_test.go b/pkg/btcunit/rates_test.go
new file mode 100644
index 0000000000..8c3568cdc7
--- /dev/null
+++ b/pkg/btcunit/rates_test.go
@@ -0,0 +1,737 @@
+package btcunit
+
+import (
+ "math"
+ "math/big"
+ "testing"
+
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/stretchr/testify/require"
+)
+
+// TestFeeRateConversions checks that the conversion between the different fee
+// rate units is correct.
+func TestFeeRateConversions(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ rate any
+ expectedVB SatPerVByte
+ expectedKVB SatPerKVByte
+ expectedKW SatPerKWeight
+ expectedW SatPerWeight
+ expectedSats btcutil.Amount
+ }{
+ {
+ name: "1 sat/vb",
+ rate: NewSatPerVByte(1),
+ expectedVB: NewSatPerVByte(1),
+ expectedKVB: NewSatPerKVByte(1000),
+ expectedKW: NewSatPerKWeight(250),
+ expectedW: CalcSatPerWeight(1, NewWeightUnit(4)),
+ expectedSats: 250,
+ },
+ {
+ name: "1000 sat/kvb",
+ rate: NewSatPerKVByte(1000),
+ expectedVB: NewSatPerVByte(1),
+ expectedKVB: NewSatPerKVByte(1000),
+ expectedKW: NewSatPerKWeight(250),
+ expectedW: CalcSatPerWeight(1, NewWeightUnit(4)),
+ expectedSats: 250,
+ },
+ {
+ name: "250 sat/kw",
+ rate: NewSatPerKWeight(250),
+ expectedVB: NewSatPerVByte(1),
+ expectedKVB: NewSatPerKVByte(1000),
+ expectedKW: NewSatPerKWeight(250),
+ expectedW: CalcSatPerWeight(1, NewWeightUnit(4)),
+ expectedSats: 250,
+ },
+ {
+ name: "0.25 sat/wu",
+ rate: CalcSatPerWeight(1, NewWeightUnit(4)),
+ expectedVB: NewSatPerVByte(1),
+ expectedKVB: NewSatPerKVByte(1000),
+ expectedKW: NewSatPerKWeight(250),
+ expectedW: CalcSatPerWeight(1, NewWeightUnit(4)),
+ expectedSats: 250,
+ },
+ {
+ name: "0.11 sat/vb",
+ rate: CalcSatPerVByte(11, NewVByte(100)),
+ expectedVB: CalcSatPerVByte(11, NewVByte(100)),
+ expectedKVB: NewSatPerKVByte(110),
+ expectedKW: CalcSatPerKWeight(55, NewKWeightUnit(2)),
+ expectedW: CalcSatPerWeight(11, NewWeightUnit(400)),
+ expectedSats: 27,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ switch r := tc.rate.(type) {
+ case SatPerVByte:
+ require.True(t, tc.expectedVB.equal(
+ r.ToSatPerVByte().baseFeeRate,
+ ))
+ require.True(t, tc.expectedKVB.equal(
+ r.ToSatPerKVByte().baseFeeRate,
+ ))
+ require.True(t, tc.expectedKW.equal(
+ r.ToSatPerKWeight().baseFeeRate,
+ ))
+ require.True(t, tc.expectedW.equal(
+ r.ToSatPerWeight().baseFeeRate,
+ ))
+
+ // Calculate the floor of the fee rate.
+ floor := big.NewInt(0)
+ floor.Div(
+ r.satsPerKWU.Num(),
+ r.satsPerKWU.Denom(),
+ )
+ require.Equal(
+ t, tc.expectedSats,
+ btcutil.Amount(floor.Int64()),
+ )
+
+ case SatPerKVByte:
+ require.True(t, tc.expectedVB.equal(
+ r.ToSatPerVByte().baseFeeRate,
+ ))
+ require.True(
+ t, tc.expectedKVB.equal(r.baseFeeRate),
+ )
+ require.True(t, tc.expectedKW.equal(
+ r.ToSatPerKWeight().baseFeeRate,
+ ))
+ require.True(t, tc.expectedW.equal(
+ r.ToSatPerWeight().baseFeeRate,
+ ))
+
+ // Calculate the floor of the fee rate.
+ floor := big.NewInt(0)
+ floor.Div(
+ r.satsPerKWU.Num(),
+ r.satsPerKWU.Denom(),
+ )
+ require.Equal(
+ t, tc.expectedSats,
+ btcutil.Amount(floor.Int64()),
+ )
+
+ case SatPerKWeight:
+ require.True(t,
+ tc.expectedVB.equal(
+ r.ToSatPerVByte().baseFeeRate,
+ ),
+ )
+ require.True(t, tc.expectedKVB.equal(
+ r.ToSatPerKVByte().baseFeeRate,
+ ))
+ require.True(
+ t, tc.expectedKW.equal(r.baseFeeRate),
+ )
+ require.True(t, tc.expectedW.equal(
+ r.ToSatPerWeight().baseFeeRate,
+ ))
+
+ // Calculate the floor of the fee rate.
+ floor := big.NewInt(0)
+ floor.Div(
+ r.satsPerKWU.Num(),
+ r.satsPerKWU.Denom(),
+ )
+ require.Equal(
+ t, tc.expectedSats,
+ btcutil.Amount(floor.Int64()),
+ )
+
+ case SatPerWeight:
+ require.True(t, tc.expectedVB.equal(
+ r.ToSatPerVByte().baseFeeRate,
+ ))
+ require.True(t, tc.expectedKVB.equal(
+ r.ToSatPerKVByte().baseFeeRate,
+ ))
+ require.True(t, tc.expectedKW.equal(
+ r.ToSatPerKWeight().baseFeeRate,
+ ))
+ require.True(
+ t, tc.expectedW.equal(r.baseFeeRate),
+ )
+
+ // Calculate the floor of the fee rate.
+ floor := big.NewInt(0)
+ floor.Div(
+ r.satsPerKWU.Num(),
+ r.satsPerKWU.Denom(),
+ )
+ require.Equal(
+ t, tc.expectedSats,
+ btcutil.Amount(floor.Int64()),
+ )
+ }
+ })
+ }
+}
+
+// TestFeeRateComparisonsVB tests the comparison methods of the SatPerVByte
+// type.
+func TestFeeRateComparisonsVB(t *testing.T) {
+ t.Parallel()
+
+ // Create a set of fee rates to compare.
+ r1 := NewSatPerVByte(1)
+ r2 := NewSatPerVByte(2)
+ r3 := NewSatPerVByte(1)
+
+ // Test Equal.
+ require.True(t, r1.Equal(r3))
+ require.False(t, r1.Equal(r2))
+
+ // Test GreaterThan.
+ require.True(t, r2.GreaterThan(r1))
+ require.False(t, r1.GreaterThan(r2))
+ require.False(t, r1.GreaterThan(r3))
+
+ // Test LessThan.
+ require.True(t, r1.LessThan(r2))
+ require.False(t, r2.LessThan(r1))
+ require.False(t, r1.LessThan(r3))
+
+ // Test GreaterThanOrEqual.
+ require.True(t, r2.GreaterThanOrEqual(r1))
+ require.True(t, r1.GreaterThanOrEqual(r3))
+ require.False(t, r1.GreaterThanOrEqual(r2))
+
+ // Test LessThanOrEqual.
+ require.True(t, r1.LessThanOrEqual(r2))
+ require.True(t, r1.LessThanOrEqual(r3))
+ require.False(t, r2.LessThanOrEqual(r1))
+}
+
+// TestFeeForWeightRoundUp checks that the FeeForWeightRoundUp method correctly
+// rounds up the fee for a given weight.
+func TestFeeForWeightRoundUp(t *testing.T) {
+ t.Parallel()
+
+ feeRate := NewSatPerVByte(1).ToSatPerKWeight()
+ txWeight := NewWeightUnit(674) // 674 weight units is 168.5 vb.
+
+ require.EqualValues(t, 168, feeRate.FeeForWeight(txWeight))
+ require.EqualValues(t, 169, feeRate.FeeForWeightRoundUp(txWeight))
+}
+
+// TestNewFeeRateConstructors checks that the New* and Calc* fee rate
+// constructors work as expected.
+func TestNewFeeRateConstructors(t *testing.T) {
+ t.Parallel()
+
+ // Test CalcSatPerKWeight.
+ fee := btcutil.Amount(1000)
+ wu := NewWeightUnit(1000)
+ expectedRate := NewSatPerKWeight(1000)
+ require.Zero(
+ t, expectedRate.satsPerKWU.Cmp(
+ CalcSatPerKWeight(fee, wu.ToKWU()).satsPerKWU,
+ ),
+ )
+
+ // Test CalcSatPerWeight.
+ expectedRateW := NewSatPerWeight(1000)
+ require.Zero(
+ t, expectedRateW.satsPerKWU.Cmp(
+ CalcSatPerWeight(fee, NewWeightUnit(1)).satsPerKWU,
+ ),
+ )
+
+ // Test CalcSatPerVByte.
+ vb := NewVByte(250)
+ expectedRateVB := NewSatPerVByte(4)
+ require.Zero(
+ t, expectedRateVB.satsPerKWU.Cmp(
+ CalcSatPerVByte(fee, vb).satsPerKWU,
+ ),
+ )
+
+ // Test CalcSatPerKVByte.
+ kvb := NewKVByte(1)
+ expectedRateKVB := NewSatPerKVByte(1000)
+ require.Zero(
+ t, expectedRateKVB.satsPerKWU.Cmp(
+ CalcSatPerKVByte(fee, kvb).satsPerKWU,
+ ),
+ )
+}
+
+// TestStringer tests the stringer methods of the fee rate types.
+func TestStringer(t *testing.T) {
+ t.Parallel()
+
+ // Create a set of fee rates to test.
+ r1 := NewSatPerVByte(1)
+ r2 := NewSatPerKVByte(1000)
+ r3 := NewSatPerKWeight(250)
+ r4 := CalcSatPerWeight(1, NewWeightUnit(4)) // 0.25 sat/wu
+
+ // Test String.
+ require.Equal(t, "1.000 sat/vb", r1.String())
+ require.Equal(t, "1000.000 sat/kvb", r2.String())
+ require.Equal(t, "250.000 sat/kw", r3.String())
+ require.Equal(t, "0.250 sat/wu", r4.String())
+}
+
+// TestFeeRateComparisonsKVB tests the comparison methods of the SatPerKVByte
+// type.
+func TestFeeRateComparisonsKVB(t *testing.T) {
+ t.Parallel()
+
+ // Create a set of fee rates to compare.
+ r1 := NewSatPerKVByte(1)
+ r2 := NewSatPerKVByte(2)
+ r3 := NewSatPerKVByte(1)
+
+ // Test Equal.
+ require.True(t, r1.Equal(r3))
+ require.False(t, r1.Equal(r2))
+
+ // Test GreaterThan.
+ require.True(t, r2.GreaterThan(r1))
+ require.False(t, r1.GreaterThan(r2))
+ require.False(t, r1.GreaterThan(r3))
+
+ // Test LessThan.
+ require.True(t, r1.LessThan(r2))
+ require.False(t, r2.LessThan(r1))
+ require.False(t, r1.LessThan(r3))
+
+ // Test GreaterThanOrEqual.
+ require.True(t, r2.GreaterThanOrEqual(r1))
+ require.True(t, r1.GreaterThanOrEqual(r3))
+ require.False(t, r1.GreaterThanOrEqual(r2))
+
+ // Test LessThanOrEqual.
+ require.True(t, r1.LessThanOrEqual(r2))
+ require.True(t, r1.LessThanOrEqual(r3))
+ require.False(t, r2.LessThanOrEqual(r1))
+}
+
+// TestFeeRateComparisonsKW tests the comparison methods of the SatPerKWeight
+// type.
+func TestFeeRateComparisonsKW(t *testing.T) {
+ t.Parallel()
+
+ // Create a set of fee rates to compare.
+ r1 := NewSatPerKWeight(1)
+ r2 := NewSatPerKWeight(2)
+ r3 := NewSatPerKWeight(1)
+
+ // Test Equal.
+ require.True(t, r1.Equal(r3))
+ require.False(t, r1.Equal(r2))
+
+ // Test GreaterThan.
+ require.True(t, r2.GreaterThan(r1))
+ require.False(t, r1.GreaterThan(r2))
+ require.False(t, r1.GreaterThan(r3))
+
+ // Test LessThan.
+ require.True(t, r1.LessThan(r2))
+ require.False(t, r2.LessThan(r1))
+ require.False(t, r1.LessThan(r3))
+
+ // Test GreaterThanOrEqual.
+ require.True(t, r2.GreaterThanOrEqual(r1))
+ require.True(t, r1.GreaterThanOrEqual(r3))
+ require.False(t, r1.GreaterThanOrEqual(r2))
+
+ // Test LessThanOrEqual.
+ require.True(t, r1.LessThanOrEqual(r2))
+ require.True(t, r1.LessThanOrEqual(r3))
+ require.False(t, r2.LessThanOrEqual(r1))
+}
+
+// TestFeeRateComparisonsW tests the comparison methods of the SatPerWeight
+// type.
+func TestFeeRateComparisonsW(t *testing.T) {
+ t.Parallel()
+
+ // Create a set of fee rates to compare.
+ r1 := NewSatPerWeight(1)
+ r2 := NewSatPerWeight(2)
+ r3 := NewSatPerWeight(1)
+
+ // Test Equal.
+ require.True(t, r1.Equal(r3))
+ require.False(t, r1.Equal(r2))
+
+ // Test GreaterThan.
+ require.True(t, r2.GreaterThan(r1))
+ require.False(t, r1.GreaterThan(r2))
+ require.False(t, r1.GreaterThan(r3))
+
+ // Test LessThan.
+ require.True(t, r1.LessThan(r2))
+ require.False(t, r2.LessThan(r1))
+ require.False(t, r1.LessThan(r3))
+
+ // Test GreaterThanOrEqual.
+ require.True(t, r2.GreaterThanOrEqual(r1))
+ require.True(t, r1.GreaterThanOrEqual(r3))
+ require.False(t, r1.GreaterThanOrEqual(r2))
+
+ // Test LessThanOrEqual.
+ require.True(t, r1.LessThanOrEqual(r2))
+ require.True(t, r1.LessThanOrEqual(r3))
+ require.False(t, r2.LessThanOrEqual(r1))
+}
+
+// TestFeeForSize tests the FeeForVSize and FeeForVByte methods.
+func TestFeeForSize(t *testing.T) {
+ t.Parallel()
+
+ // Create a set of fee rates to test.
+ // r1: 1000 sat/kvb = 1000 sat / 1000 vbyte = 1 sat/vbyte.
+ r1 := NewSatPerKVByte(1000)
+
+ // r2: 250 sat/kwu. This matches r1.
+ r2 := NewSatPerKWeight(250)
+
+ // r3: 1 sat/vbyte.
+ r3 := NewSatPerVByte(1)
+
+ // r4: 0.25 sat/wu.
+ // 0.25 sat/wu * 1000 = 250 sat/kwu.
+ r4 := CalcSatPerWeight(1, NewWeightUnit(4))
+
+ // Test FeeForVByte with r1 (1000 sat/kvb).
+ // Size: 250 vbytes.
+ // Fee: 250 vbytes * 1 sat/vbyte = 250 sats.
+ require.Equal(t, btcutil.Amount(250), r1.FeeForVByte(NewVByte(250)))
+
+ // Test FeeForVByte with r2 (250 sat/kwu).
+ // Size: 250 vbytes = 1000 weight units.
+ // Rate: 250 sat/1000 wu = 0.25 sat/wu.
+ // Fee: 1000 wu * 0.25 sat/wu = 250 sats.
+ require.Equal(t, btcutil.Amount(250), r2.FeeForVByte(NewVByte(250)))
+
+ // Test FeeForVByte with SatPerVByte.
+ // Size: 1000 vbytes.
+ // Rate: 1 sat/vbyte.
+ // Fee: 1000 sats.
+ require.Equal(t, btcutil.Amount(1000), r3.FeeForVByte(NewVByte(1000)))
+
+ // Test FeeForKVByte with SatPerVByte.
+ // Size: 1 kvb = 1000 vbytes.
+ // Rate: 1 sat/vbyte.
+ // Fee: 1000 sats.
+ require.Equal(t, btcutil.Amount(1000), r3.FeeForKVByte(NewKVByte(1)))
+
+ // Test FeeForWeight with SatPerVByte.
+ // Size: 1000 weight units.
+ // Rate: 1 sat/vbyte = 0.25 sat/wu.
+ // Fee: 1000 * 0.25 = 250 sats.
+ require.Equal(t, btcutil.Amount(250),
+ r3.FeeForWeight(NewWeightUnit(1000)))
+
+ // Test ToSatPerVByte with SatPerKVByte.
+ // 1000 sat/kvb should equal 1 sat/vbyte.
+ require.True(t, r3.Equal(r1.ToSatPerVByte()))
+
+ // Test FeeForKVByte with SatPerKVByte.
+ // Size: 1 kvb.
+ // Rate: 1000 sat/kvb.
+ // Fee: 1000 sats.
+ require.Equal(t, btcutil.Amount(1000), r1.FeeForKVByte(NewKVByte(1)))
+
+ // Test FeeForWeight with SatPerKVByte.
+ // Size: 1000 weight units.
+ // Rate: 1000 sat/kvb = 0.25 sat/wu.
+ // Fee: 1000 * 0.25 = 250 sats.
+ require.Equal(t, btcutil.Amount(250),
+ r1.FeeForWeight(NewWeightUnit(1000)))
+
+ // Test FeeForKVByte with SatPerKWeight.
+ // Size: 1 kvb = 1000 vbytes = 4000 weight units.
+ // Rate: 250 sat/kwu = 0.25 sat/wu.
+ // Fee: 4000 * 0.25 = 1000 sats.
+ require.Equal(t, btcutil.Amount(1000), r2.FeeForKVByte(NewKVByte(1)))
+
+ // Test FeeForKWeight with SatPerKWeight.
+ // Size: 1 kwu = 1000 weight units.
+ // Rate: 250 sat/kwu = 0.25 sat/wu.
+ // Fee: 1000 * 0.25 = 250 sats.
+ require.Equal(t, btcutil.Amount(250),
+ r2.FeeForKWeight(NewKWeightUnit(1)))
+
+ // Test FeeForWeight with SatPerWeight.
+ // Size: 1000 weight units.
+ // Rate: 0.25 sat/wu.
+ // Fee: 1000 * 0.25 = 250 sats.
+ require.Equal(t, btcutil.Amount(250),
+ r4.FeeForWeight(NewWeightUnit(1000)))
+
+ // Test ToSatPerWeight with SatPerVByte.
+ // 1 sat/vbyte should equal 0.25 sat/wu.
+ require.True(t, r4.Equal(r3.ToSatPerWeight()))
+}
+
+// TestNewFeeRateConstructorsZero tests the New* fee rate constructors with
+// zero values.
+func TestNewFeeRateConstructorsZero(t *testing.T) {
+ t.Parallel()
+
+ // Test CalcSatPerKWeight with zero weight should panic.
+ fee := btcutil.Amount(1000)
+ require.Panics(t, func() {
+ kwu := NewKWeightUnit(0)
+ _ = CalcSatPerKWeight(fee, kwu)
+ })
+
+ // Test CalcSatPerVByte with zero vbytes should panic.
+ require.Panics(t, func() {
+ vb := NewVByte(0)
+ _ = CalcSatPerVByte(fee, vb)
+ })
+
+ // Test CalcSatPerKVByte with zero kvbytes should panic.
+ require.Panics(t, func() {
+ kvb := NewKVByte(0)
+ _ = CalcSatPerKVByte(fee, kvb)
+ })
+
+ // Test CalcSatPerWeight with zero weight units should panic.
+ require.Panics(t, func() {
+ wu := NewWeightUnit(0)
+ _ = CalcSatPerWeight(fee, wu)
+ })
+
+ // Test zero constants.
+ // NewSatPerVByte(0) -> Rate 0 sats / 1 vb. Valid.
+ require.True(t, ZeroSatPerVByte.Equal(NewSatPerVByte(0)))
+ require.True(t, ZeroSatPerKVByte.Equal(NewSatPerKVByte(0)))
+ require.True(t, ZeroSatPerKWeight.Equal(NewSatPerKWeight(0)))
+ require.True(t, ZeroSatPerWeight.Equal(NewSatPerWeight(0)))
+
+ require.Equal(t, "0.000 sat/vb", ZeroSatPerVByte.String())
+ require.Equal(t, "0.000 sat/kvb", ZeroSatPerKVByte.String())
+ require.Equal(t, "0.000 sat/kw", ZeroSatPerKWeight.String())
+ require.Equal(t, "0.000 sat/wu", ZeroSatPerWeight.String())
+}
+
+// TestSafeUint64ToInt64Overflow tests the overflow condition in
+// safeUint64ToInt64 through the New* constructors.
+func TestSafeUint64ToInt64Overflow(t *testing.T) {
+ t.Parallel()
+
+ fee := btcutil.Amount(1)
+
+ // Test CalcSatPerVByte with an overflowing vbyte value.
+ // The denominator should be capped at math.MaxInt64.
+ // We manually construct the VByte to ensure wu > MaxInt64 without
+ // overflowing the constructor's internal multiplication.
+ overflowVByte := VByte{baseUnit{wu: math.MaxInt64 + 1}}
+ expectedDenom := big.NewInt(math.MaxInt64)
+
+ rateVB := CalcSatPerVByte(fee, overflowVByte)
+ require.Zero(t, expectedDenom.Cmp(rateVB.satsPerKWU.Denom()))
+
+ // Test CalcSatPerKVByte with an overflowing kvb value.
+ // The denominator should be capped at math.MaxInt64.
+ overflowKVByte := KVByte{baseUnit{wu: math.MaxInt64 + 1}}
+ rateKVB := CalcSatPerKVByte(fee, overflowKVByte)
+ require.Zero(t, expectedDenom.Cmp(rateKVB.satsPerKWU.Denom()))
+
+ // Test CalcSatPerKWeight with an overflowing weight unit value.
+ overflowWU := KWeightUnit{baseUnit{wu: math.MaxInt64 + 1}}
+ rateKW := CalcSatPerKWeight(fee, overflowWU)
+ require.Zero(t, expectedDenom.Cmp(rateKW.satsPerKWU.Denom()))
+
+ // Test CalcSatPerWeight with an overflowing weight unit value.
+ overflowWeight := WeightUnit{baseUnit{wu: math.MaxInt64 + 1}}
+ rateW := CalcSatPerWeight(fee, overflowWeight)
+ require.Zero(t, expectedDenom.Cmp(rateW.satsPerKWU.Denom()))
+}
+
+// TestVal checks that the Val method returns the correct integer fee rate.
+func TestVal(t *testing.T) {
+ t.Parallel()
+
+ // Test SatPerKVByte.Val().
+ rateKVB := NewSatPerKVByte(1000)
+ require.Equal(t, btcutil.Amount(1000), rateKVB.Val())
+
+ // Test SatPerKWeight.Val().
+ rateKW := NewSatPerKWeight(250)
+ require.Equal(t, btcutil.Amount(250), rateKW.Val())
+}
+
+// TestRatePrecision checks that baseFeeRate preserves precision for
+// non-integer rates (e.g., repeating decimals) during conversions and fee
+// calculations for all rate units.
+func TestRatePrecision(t *testing.T) {
+ t.Parallel()
+
+ // We choose a test payload size of 12,000 weight units.
+ // This specific number is chosen because it is cleanly divisible by
+ // all unit factors, allowing us to pass exact integer amounts to all
+ // FeeFor... methods.
+ //
+ // 12,000 wu = 12 kwu
+ // 12,000 wu = 3,000 vb
+ // 12,000 wu = 3 kvb
+ const (
+ payloadWU = 12000
+ payloadKWU = 12
+ payloadVB = 3000
+ payloadKVB = 3
+ )
+
+ // expectedFee is always 1 satoshi because we define the rate in each
+ // test case as (1 sat / payload_size).
+ const expectedFee = btcutil.Amount(1)
+
+ // 1. Test SatPerWeight.
+ // Rate: 1 sat / 12,000 wu = 0.0000833... sat/wu.
+ t.Run("SatPerWeight", func(t *testing.T) {
+ t.Parallel()
+
+ rate := CalcSatPerWeight(1, NewWeightUnit(payloadWU))
+
+ // The rate 0.0000833... rounds to 0.000 when displayed with 3
+ // decimal places, but the internal precision is preserved.
+ require.Equal(t, "0.000 sat/wu", rate.String())
+ require.Equal(t, expectedFee,
+ rate.FeeForWeight(NewWeightUnit(payloadWU)))
+
+ // Convert to SatPerKWeight.
+ // Rate: 1 sat / 12 kwu = 0.0833... sat/kw.
+ kw := rate.ToSatPerKWeight()
+ require.Equal(t, "0.083 sat/kw", kw.String())
+ require.Equal(t, expectedFee,
+ kw.FeeForKWeight(NewKWeightUnit(payloadKWU)))
+
+ // Convert to SatPerVByte.
+ // Rate: 1 sat / 3,000 vb = 0.00033... sat/vb.
+ // This rounds to 0.000 at 3 decimals.
+ vb := rate.ToSatPerVByte()
+ require.Equal(t, "0.000 sat/vb", vb.String())
+ require.Equal(t, expectedFee,
+ vb.FeeForVByte(NewVByte(payloadVB)))
+
+ // Convert to SatPerKVByte.
+ // Rate: 1 sat / 3 kvb = 0.333... sat/kvb.
+ kvb := rate.ToSatPerKVByte()
+ require.Equal(t, "0.333 sat/kvb", kvb.String())
+ require.Equal(t, expectedFee,
+ kvb.FeeForKVByte(NewKVByte(payloadKVB)))
+ })
+
+ // 2. Test SatPerKWeight.
+ // Rate: 1 sat / 12 kwu = 0.0833... sat/kw.
+ t.Run("SatPerKWeight", func(t *testing.T) {
+ t.Parallel()
+
+ rate := CalcSatPerKWeight(1, NewKWeightUnit(payloadKWU))
+ require.Equal(t, "0.083 sat/kw", rate.String())
+ require.Equal(t, expectedFee,
+ rate.FeeForKWeight(NewKWeightUnit(payloadKWU)))
+
+ // Convert to SatPerWeight.
+ // Rate: 1 sat / 12,000 wu = 0.0000833... sat/wu.
+ // Rounds to 0.000.
+ w := rate.ToSatPerWeight()
+ require.Equal(t, "0.000 sat/wu", w.String())
+ require.Equal(t, expectedFee,
+ w.FeeForWeight(NewWeightUnit(payloadWU)))
+
+ // Convert to SatPerVByte.
+ // Rate: 1 sat / 3,000 vb = 0.00033... sat/vb.
+ // Rounds to 0.000.
+ vb := rate.ToSatPerVByte()
+ require.Equal(t, "0.000 sat/vb", vb.String())
+ require.Equal(t, expectedFee,
+ vb.FeeForVByte(NewVByte(payloadVB)))
+
+ // Convert to SatPerKVByte.
+ // Rate: 1 sat / 3 kvb = 0.333... sat/kvb.
+ kvb := rate.ToSatPerKVByte()
+ require.Equal(t, "0.333 sat/kvb", kvb.String())
+ require.Equal(t, expectedFee,
+ kvb.FeeForKVByte(NewKVByte(payloadKVB)))
+ })
+
+ // 3. Test SatPerVByte.
+ // Rate: 1 sat / 3,000 vb = 0.00033... sat/vb.
+ t.Run("SatPerVByte", func(t *testing.T) {
+ t.Parallel()
+
+ rate := CalcSatPerVByte(1, NewVByte(payloadVB))
+ // Rounds to 0.000 at 3 decimals.
+ require.Equal(t, "0.000 sat/vb", rate.String())
+ require.Equal(t, expectedFee,
+ rate.FeeForVByte(NewVByte(payloadVB)))
+
+ // Convert to SatPerKVByte.
+ // Rate: 1 sat / 3 kvb = 0.333... sat/kvb.
+ kvb := rate.ToSatPerKVByte()
+ require.Equal(t, "0.333 sat/kvb", kvb.String())
+ require.Equal(t, expectedFee,
+ kvb.FeeForKVByte(NewKVByte(payloadKVB)))
+
+ // Convert to SatPerKWeight.
+ // Rate: 1 sat / 12 kwu = 0.0833... sat/kw.
+ kw := rate.ToSatPerKWeight()
+ require.Equal(t, "0.083 sat/kw", kw.String())
+ require.Equal(t, expectedFee,
+ kw.FeeForKWeight(NewKWeightUnit(payloadKWU)))
+
+ // Convert to SatPerWeight.
+ // Rate: 1 sat / 12,000 wu = 0.0000833... sat/wu.
+ // Rounds to 0.000.
+ w := rate.ToSatPerWeight()
+ require.Equal(t, "0.000 sat/wu", w.String())
+ require.Equal(t, expectedFee,
+ w.FeeForWeight(NewWeightUnit(payloadWU)))
+ })
+
+ // 4. Test SatPerKVByte.
+ // Rate: 1 sat / 3 kvb = 0.333... sat/kvb.
+ t.Run("SatPerKVByte", func(t *testing.T) {
+ t.Parallel()
+
+ rate := CalcSatPerKVByte(1, NewKVByte(payloadKVB))
+ require.Equal(t, "0.333 sat/kvb", rate.String())
+ require.Equal(t, expectedFee,
+ rate.FeeForKVByte(NewKVByte(payloadKVB)))
+
+ // Convert to SatPerVByte.
+ // Rate: 1 sat / 3,000 vb = 0.00033... sat/vb.
+ // Rounds to 0.000.
+ vb := rate.ToSatPerVByte()
+ require.Equal(t, "0.000 sat/vb", vb.String())
+ require.Equal(t, expectedFee,
+ vb.FeeForVByte(NewVByte(payloadVB)))
+
+ // Convert to SatPerKWeight.
+ // Rate: 1 sat / 12 kwu = 0.0833... sat/kw.
+ kw := rate.ToSatPerKWeight()
+ require.Equal(t, "0.083 sat/kw", kw.String())
+ require.Equal(t, expectedFee,
+ kw.FeeForKWeight(NewKWeightUnit(payloadKWU)))
+
+ // Convert to SatPerWeight.
+ // Rate: 1 sat / 12,000 wu = 0.0000833... sat/wu.
+ // Rounds to 0.000.
+ w := rate.ToSatPerWeight()
+ require.Equal(t, "0.000 sat/wu", w.String())
+ require.Equal(t, expectedFee,
+ w.FeeForWeight(NewWeightUnit(payloadWU)))
+ })
+}
diff --git a/pkg/btcunit/txsize.go b/pkg/btcunit/txsize.go
new file mode 100644
index 0000000000..bc61cbe57f
--- /dev/null
+++ b/pkg/btcunit/txsize.go
@@ -0,0 +1,112 @@
+package btcunit
+
+import (
+ "fmt"
+
+ "github.com/btcsuite/btcd/blockchain"
+)
+
+// baseUnit stores the canonical representation of a transaction size, which is
+// weight units (wu). All other size units are derived from this.
+type baseUnit struct {
+ wu uint64
+}
+
+// ToWU converts the unit to a WeightUnit.
+func (b baseUnit) ToWU() WeightUnit {
+ return WeightUnit{b}
+}
+
+// ToVB converts the unit to a VByte.
+func (b baseUnit) ToVB() VByte {
+ return VByte{b}
+}
+
+// ToKVB converts the unit to a KVByte.
+func (b baseUnit) ToKVB() KVByte {
+ return KVByte{b}
+}
+
+// ToKWU converts the unit to a KWeightUnit.
+func (b baseUnit) ToKWU() KWeightUnit {
+ return KWeightUnit{b}
+}
+
+// WeightUnit defines a unit to express the transaction size. One weight unit
+// is 1/4_000_000 of the max block size. The tx weight is calculated using
+// `Base tx size * 3 + Total tx size`.
+// - Base tx size is size of the transaction serialized without the witness
+// data.
+// - Total tx size is the transaction size in bytes serialized according
+// #BIP144.
+type WeightUnit struct {
+ // The internal size is recorded in weight units.
+ baseUnit
+}
+
+// NewWeightUnit creates a new WeightUnit from a uint64 value.
+func NewWeightUnit(val uint64) WeightUnit {
+ return WeightUnit{baseUnit{wu: val}}
+}
+
+// String returns the string representation of the weight unit.
+func (w WeightUnit) String() string {
+ return fmt.Sprintf("%d wu", w.wu)
+}
+
+// VByte defines a unit to express the transaction size. One virtual byte is
+// 1/4th of a weight unit. The tx virtual bytes is calculated using `TxWeight /
+// 4`.
+type VByte struct {
+ // The internal size is recorded in weight units.
+ baseUnit
+}
+
+// NewVByte creates a new VByte from a uint64 value.
+func NewVByte(val uint64) VByte {
+ return VByte{baseUnit{wu: val * blockchain.WitnessScaleFactor}}
+}
+
+// String returns the string representation of the virtual byte.
+func (v VByte) String() string {
+ vbytes := (v.wu + blockchain.WitnessScaleFactor - 1) /
+ blockchain.WitnessScaleFactor
+
+ return fmt.Sprintf("%d vb", vbytes)
+}
+
+// KVByte defines a unit to express the transaction size in kilo-virtual-bytes.
+type KVByte struct {
+ // The internal size is recorded in weight units.
+ baseUnit
+}
+
+// NewKVByte creates a new KVByte from a uint64.
+func NewKVByte(val uint64) KVByte {
+ return KVByte{baseUnit{wu: val * kilo * blockchain.WitnessScaleFactor}}
+}
+
+// String returns the string representation of the kilo-virtual-byte.
+func (k KVByte) String() string {
+ vbytes := (k.wu + blockchain.WitnessScaleFactor - 1) /
+ blockchain.WitnessScaleFactor
+
+ return fmt.Sprintf("%d kvb", vbytes/kilo)
+}
+
+// KWeightUnit defines a unit to express the transaction size in
+// kilo-weight-units.
+type KWeightUnit struct {
+ // The internal size is recorded in weight units.
+ baseUnit
+}
+
+// NewKWeightUnit creates a new KWeightUnit from a uint64.
+func NewKWeightUnit(val uint64) KWeightUnit {
+ return KWeightUnit{baseUnit{wu: val * kilo}}
+}
+
+// String returns the string representation of the kilo-weight-unit.
+func (k KWeightUnit) String() string {
+ return fmt.Sprintf("%d kwu", k.wu/kilo)
+}
diff --git a/pkg/btcunit/txsize_test.go b/pkg/btcunit/txsize_test.go
new file mode 100644
index 0000000000..63835dcdcb
--- /dev/null
+++ b/pkg/btcunit/txsize_test.go
@@ -0,0 +1,129 @@
+package btcunit
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestBaseUnitConversions checks that the conversion methods of baseUnit are
+// correct.
+func TestBaseUnitConversions(t *testing.T) {
+ t.Parallel()
+
+ // Test data: 1000 weight units.
+ base := baseUnit{wu: 1000}
+
+ // Test ToWU: 1000 wu.
+ wu := base.ToWU()
+ require.Equal(t, uint64(1000), wu.wu)
+
+ // Test ToVByte: 1000 wu (250 vb).
+ vb := base.ToVB()
+ require.Equal(t, uint64(1000), vb.wu)
+
+ // Test ToKVByte: 1000 wu (0.25 kvb).
+ kvb := base.ToKVB()
+ require.Equal(t, uint64(1000), kvb.wu)
+
+ // Test ToKWeightUnit: 1000 wu (1 kwu).
+ kwu := base.ToKWU()
+ require.Equal(t, uint64(1000), kwu.wu)
+}
+
+// TestTxSizeConversion checks that the conversion between weight units and
+// virtual bytes is correct.
+func TestTxSizeConversion(t *testing.T) {
+ t.Parallel()
+
+ // We'll use 4000 weight units (wu) as our base for testing. This is
+ // equivalent to 1000 virtual bytes (vb), 1 kilo-virtual-byte (kvb),
+ // and 4 kilo-weight-units (kwu).
+ //
+ // Initialize the same size in different units.
+ wu := NewWeightUnit(4000)
+ vb := NewVByte(1000)
+ kvb := NewKVByte(1)
+ kwu := NewKWeightUnit(4)
+
+ // Check that the internal 'wu' values are consistent across different
+ // unit types representing the same size.
+ require.Equal(t, uint64(4000), wu.wu)
+ require.Equal(t, uint64(4000), vb.wu)
+ require.Equal(t, uint64(4000), kvb.wu)
+ require.Equal(t, uint64(4000), kwu.wu)
+
+ // Test conversions from WeightUnit. After conversion, the underlying
+ // weight units (wu) should remain 4000.
+ require.Equal(t, uint64(4000), wu.ToWU().wu)
+ require.Equal(t, uint64(4000), wu.ToVB().wu)
+ require.Equal(t, uint64(4000), wu.ToKVB().wu)
+ require.Equal(t, uint64(4000), wu.ToKWU().wu)
+ require.Equal(t, "4000 wu", wu.String())
+
+ // Test conversions from VByte. After conversion, the underlying weight
+ // units (wu) should remain 4000.
+ require.Equal(t, uint64(4000), vb.ToWU().wu)
+ require.Equal(t, uint64(4000), vb.ToVB().wu)
+ require.Equal(t, uint64(4000), vb.ToKVB().wu)
+ require.Equal(t, uint64(4000), vb.ToKWU().wu)
+ require.Equal(t, "1000 vb", vb.String())
+
+ // Test conversions from KVByte. After conversion, the underlying
+ // weight units (wu) should remain 4000.
+ require.Equal(t, uint64(4000), kvb.ToWU().wu)
+ require.Equal(t, uint64(4000), kvb.ToVB().wu)
+ require.Equal(t, uint64(4000), kvb.ToKVB().wu)
+ require.Equal(t, uint64(4000), kvb.ToKWU().wu)
+ require.Equal(t, "1 kvb", kvb.String())
+
+ // Test conversions from KWeightUnit. After conversion, the underlying
+ // weight units (wu) should remain 4000.
+ require.Equal(t, uint64(4000), kwu.ToWU().wu)
+ require.Equal(t, uint64(4000), kwu.ToVB().wu)
+ require.Equal(t, uint64(4000), kwu.ToKVB().wu)
+ require.Equal(t, uint64(4000), kwu.ToKWU().wu)
+ require.Equal(t, "4 kwu", kwu.String())
+}
+
+// TestTxSizePrecision checks that precision is preserved when converting
+// between units for values that are not perfectly divisible by the witness
+// scale factor.
+func TestTxSizePrecision(t *testing.T) {
+ t.Parallel()
+
+ // Use a weight unit value that is not divisible by 4
+ // (WitnessScaleFactor).
+ // 3999 % 4 = 3.
+ wu := NewWeightUnit(3999)
+
+ // Convert to VByte. This should wrap the same underlying wu value.
+ vb := wu.ToVB()
+ require.Equal(t, uint64(3999), vb.wu)
+
+ // Convert back to WeightUnit. Should still be 3999.
+ wu2 := vb.ToWU()
+ require.Equal(t, uint64(3999), wu2.wu)
+
+ // The string representation should still perform the rounding for
+ // display.
+ // ceil(3999 / 4) = 1000.
+ require.Equal(t, "1000 vb", vb.String())
+}
+
+// TestTxSizeStringer tests the stringer methods of the tx size types.
+func TestTxSizeStringer(t *testing.T) {
+ t.Parallel()
+
+ // Create a test weight of 1000 wu.
+ wu := NewWeightUnit(1000)
+ vb := NewVByte(250)
+ kvb := NewKVByte(1)
+ kwu := NewKWeightUnit(1)
+
+ // Test String.
+ require.Equal(t, "1000 wu", wu.String())
+ require.Equal(t, "250 vb", vb.String())
+ require.Equal(t, "1 kvb", kvb.String())
+ require.Equal(t, "1 kwu", kwu.String())
+}
diff --git a/rpc/legacyrpc/methods.go b/rpc/legacyrpc/methods.go
index 21827a2a41..17858f1406 100644
--- a/rpc/legacyrpc/methods.go
+++ b/rpc/legacyrpc/methods.go
@@ -451,7 +451,7 @@ func getBalance(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
// getBestBlock handles a getbestblock request by returning a JSON object
// with the height and hash of the most recently processed block.
func getBestBlock(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
- blk := w.Manager.SyncedTo()
+ blk := w.AddrManager().SyncedTo()
result := &btcjson.GetBestBlockResult{
Hash: blk.Hash.String(),
Height: blk.Height,
@@ -462,14 +462,14 @@ func getBestBlock(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
// getBestBlockHash handles a getbestblockhash request by returning the hash
// of the most recently processed block.
func getBestBlockHash(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
- blk := w.Manager.SyncedTo()
+ blk := w.AddrManager().SyncedTo()
return blk.Hash.String(), nil
}
// getBlockCount handles a getblockcount request by returning the chain height
// of the most recently processed block.
func getBlockCount(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
- blk := w.Manager.SyncedTo()
+ blk := w.AddrManager().SyncedTo()
return blk.Height, nil
}
@@ -671,7 +671,8 @@ func renameAccount(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
if err != nil {
return nil, err
}
- return nil, w.RenameAccount(waddrmgr.KeyScopeBIP0044, account, cmd.NewAccount)
+
+ return nil, w.RenameAccountDeprecated(waddrmgr.KeyScopeBIP0044, account, cmd.NewAccount)
}
// getNewAddress handles a getnewaddress request by returning a new
@@ -702,7 +703,7 @@ func getNewAddress(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
if err != nil {
return nil, err
}
- addr, err := w.NewAddress(account, keyScope)
+ addr, err := w.NewAddressDeprecated(account, keyScope)
if err != nil {
return nil, err
}
@@ -812,7 +813,7 @@ func getTransaction(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
return nil, &ErrNoTransactionInfo
}
- syncBlock := w.Manager.SyncedTo()
+ syncBlock := w.AddrManager().SyncedTo()
// TODO: The serialized transaction is already in the DB, so
// reserializing can be avoided here.
@@ -1135,7 +1136,7 @@ func listReceivedByAddress(icmd interface{}, w *wallet.Wallet) (interface{}, err
account string
}
- syncBlock := w.Manager.SyncedTo()
+ syncBlock := w.AddrManager().SyncedTo()
// Intermediate data for all addresses.
allAddrData := make(map[string]AddrData)
@@ -1214,7 +1215,7 @@ func listReceivedByAddress(icmd interface{}, w *wallet.Wallet) (interface{}, err
func listSinceBlock(icmd interface{}, w *wallet.Wallet, chainClient *chain.RPCClient) (interface{}, error) {
cmd := icmd.(*btcjson.ListSinceBlockCmd)
- syncBlock := w.Manager.SyncedTo()
+ syncBlock := w.AddrManager().SyncedTo()
targetConf := int64(*cmd.TargetConfirmations)
// For the result we need the block hash for the last block counted
@@ -1331,7 +1332,7 @@ func listUnspent(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
}
}
- return w.ListUnspent(int32(*cmd.MinConf), int32(*cmd.MaxConf), "")
+ return w.ListUnspentDeprecated(int32(*cmd.MinConf), int32(*cmd.MaxConf), "") //nolint:gosec,staticcheck
}
// lockUnspent handles the lockunspent command.
@@ -1778,7 +1779,7 @@ func validateAddress(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
result.Address = addr.EncodeAddress()
result.IsValid = true
- ainfo, err := w.AddressInfo(addr)
+ ainfo, err := w.AddressInfoDeprecated(addr)
if err != nil {
if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
// No additional information available about the address.
@@ -1897,7 +1898,7 @@ func walletIsLocked(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
// wallets, returning an error if any wallet is not encrypted (for example,
// a watching-only wallet).
func walletLock(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
- w.Lock()
+ w.LockDeprecated()
return nil, nil
}
@@ -1912,7 +1913,7 @@ func walletPassphrase(icmd interface{}, w *wallet.Wallet) (interface{}, error) {
if timeout != 0 {
unlockAfter = time.After(timeout)
}
- err := w.Unlock([]byte(cmd.Passphrase), unlockAfter)
+ err := w.UnlockDeprecated([]byte(cmd.Passphrase), unlockAfter)
return nil, err
}
diff --git a/rpc/legacyrpc/server.go b/rpc/legacyrpc/server.go
index 63c87bcdc5..fe7f43a7e2 100644
--- a/rpc/legacyrpc/server.go
+++ b/rpc/legacyrpc/server.go
@@ -219,7 +219,7 @@ func (s *Server) Stop() {
chainClient := s.chainClient
s.handlerMu.Unlock()
if wallet != nil {
- wallet.Stop()
+ wallet.StopDeprecated()
}
if chainClient != nil {
chainClient.Stop()
diff --git a/rpc/rpcserver/server.go b/rpc/rpcserver/server.go
index c66be61b6b..9593aee26a 100644
--- a/rpc/rpcserver/server.go
+++ b/rpc/rpcserver/server.go
@@ -189,7 +189,10 @@ func (s *walletServer) Accounts(ctx context.Context, req *pb.AccountsRequest) (
func (s *walletServer) RenameAccount(ctx context.Context, req *pb.RenameAccountRequest) (
*pb.RenameAccountResponse, error) {
- err := s.wallet.RenameAccount(waddrmgr.KeyScopeBIP0044, req.AccountNumber, req.NewName)
+ err := s.wallet.RenameAccountDeprecated(
+ waddrmgr.KeyScopeBIP0044, req.GetAccountNumber(),
+ req.GetNewName(),
+ )
if err != nil {
return nil, translateError(err)
}
@@ -210,7 +213,11 @@ func (s *walletServer) NextAccount(ctx context.Context, req *pb.NextAccountReque
defer func() {
lock <- time.Time{} // send matters, not the value
}()
- err := s.wallet.Unlock(req.Passphrase, lock)
+
+ //nolint:staticcheck // This should be fixed once the interface
+ // refactor is finished, and new wallet RPC is built.
+ err := s.wallet.UnlockDeprecated(req.GetPassphrase(),
+ lock)
if err != nil {
return nil, translateError(err)
}
@@ -232,7 +239,9 @@ func (s *walletServer) NextAddress(ctx context.Context, req *pb.NextAddressReque
)
switch req.Kind {
case pb.NextAddressRequest_BIP0044_EXTERNAL:
- addr, err = s.wallet.NewAddress(req.Account, waddrmgr.KeyScopeBIP0044)
+ addr, err = s.wallet.NewAddressDeprecated(
+ req.GetAccount(), waddrmgr.KeyScopeBIP0044,
+ )
case pb.NextAddressRequest_BIP0044_INTERNAL:
addr, err = s.wallet.NewChangeAddress(req.Account, waddrmgr.KeyScopeBIP0044)
default:
@@ -260,7 +269,11 @@ func (s *walletServer) ImportPrivateKey(ctx context.Context, req *pb.ImportPriva
defer func() {
lock <- time.Time{} // send matters, not the value
}()
- err = s.wallet.Unlock(req.Passphrase, lock)
+
+ //nolint:staticcheck // This should be fixed once the interface
+ // refactor is finished, and new wallet RPC is built.
+ err = s.wallet.UnlockDeprecated(req.GetPassphrase(),
+ lock)
if err != nil {
return nil, translateError(err)
}
@@ -456,7 +469,11 @@ func (s *walletServer) SignTransaction(ctx context.Context, req *pb.SignTransact
defer func() {
lock <- time.Time{} // send matters, not the value
}()
- err = s.wallet.Unlock(req.Passphrase, lock)
+
+ //nolint:staticcheck // This should be fixed once the interface
+ // refactor is finished, and new wallet RPC is built.
+ err = s.wallet.UnlockDeprecated(req.GetPassphrase(),
+ lock)
if err != nil {
return nil, translateError(err)
}
diff --git a/waddrmgr/interface.go b/waddrmgr/interface.go
new file mode 100644
index 0000000000..0e8ad1d446
--- /dev/null
+++ b/waddrmgr/interface.go
@@ -0,0 +1,323 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package waddrmgr
+
+import (
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcwallet/walletdb"
+)
+
+// TODO(yy) This file provides a set of interfaces that abstract the
+// functionality of the waddrmgr package. The interfaces are designed to be
+// composable, allowing for a clean separation of concerns and making it easier
+// to test and maintain the codebase.
+//
+// The AddrStore interface is the top-level interface that composes all the
+// other interfaces. It is responsible for managing its own database
+// transactions, which means that the walletdb.ReadWriteBucket and
+// walletdb.ReadBucket arguments are not present in the interface methods.
+//
+// The breakdown of the interfaces is as follows:
+//
+// ChainState: Manages the wallet's sync state with the blockchain.
+// KeyScopeManager: Manages key scopes.
+// AddressManager: Manages addresses.
+// AccountManager: Manages accounts.
+// CryptoManager: Manages the encrypted state of the wallet.
+// WatchOnlyManager: Manages watch-only functionality.
+//
+// The current AddrStore interface has several design flaws that should be
+// addressed in a future refactoring:
+//
+// 1. Leaky Abstraction & Lack of Encapsulation:
+// - Problem: Nearly every method in the interface requires the caller (the
+// `wallet` package) to pass in a `walletdb.ReadWriteBucket` or
+// `walletdb.ReadBucket`.
+// - Why it's an issue: This is a classic "leaky abstraction." The
+// `AddrStore` is supposed to abstract away the details of address
+// management, but it's forcing its consumer to know about and manage its
+// internal database structure and transactions. The `wallet` package
+// should not be responsible for starting a database transaction just to
+// call a method on the `AddrStore`. The `AddrStore` should manage its own
+// persistence internally.
+//
+// 2. Violation of the Interface Segregation Principle (ISP):
+// - Problem: The `AddrStore` is a "fat" interface. It includes dozens of
+// methods covering many distinct areas of responsibility: cryptographic
+// operations (`Lock`, `Unlock`), chain synchronization (`SyncedTo`), key
+// management (`NewScopedKeyManager`), and address lookups (`Address`).
+// - Why it's an issue: Consumers of the interface are forced to depend on
+// methods they don't use. For example, a component that only needs to
+// look up an address (`Address`) is also coupled to methods for changing
+// passphrases. This leads to unnecessary dependencies and makes the code
+// harder to test, as mocks become massive and unwieldy.
+//
+// 3. Violation of the Single Responsibility Principle (SRP):
+// - Problem: The interface combines multiple, distinct responsibilities
+// into one unit. It acts as a key manager, an address book, a crypto
+// manager, and a chain state tracker all at once.
+// - Why it's an issue: This makes the `AddrStore` difficult to reason about
+// and maintain. A change in how we manage chain state, for example, could
+// require modifying an interface that is also responsible for
+// cryptography. These concerns should be separate.
+
+// AddrStore is an interface that describes a wallet address store.
+//
+//nolint:interfacebloat
+type AddrStore interface {
+ // Birthday returns the birthday of the address store.
+ Birthday() time.Time
+
+ // SetSyncedTo marks the address manager to be in sync with the
+ // recently-seen block described by the blockstamp.
+ SetSyncedTo(ns walletdb.ReadWriteBucket, bs *BlockStamp) error
+
+ // SetBirthdayBlock sets the birthday block, or earliest time a key
+ // could have been used, for the manager.
+ SetBirthdayBlock(ns walletdb.ReadWriteBucket, block BlockStamp,
+ verified bool) error
+
+ // SyncedTo returns details about the block height and hash that the
+ // address manager is synced through at the very least.
+ SyncedTo() BlockStamp
+
+ // BlockHash returns the block hash at a particular block height.
+ BlockHash(ns walletdb.ReadBucket, height int32) (*chainhash.Hash, error)
+
+ // ActiveScopedKeyManagers returns a slice of all the active scoped key
+ // managers currently known by the root key manager.
+ ActiveScopedKeyManagers() []AccountStore
+
+ // FetchScopedKeyManager attempts to fetch an active scoped manager
+ // according to its registered scope.
+ FetchScopedKeyManager(scope KeyScope) (AccountStore, error)
+
+ // Address returns a managed address given the passed address if it is
+ // known to the address manager.
+ Address(ns walletdb.ReadBucket,
+ address address.Address) (ManagedAddress, error)
+
+ // AddrAccount returns the account to which the given address belongs.
+ AddrAccount(ns walletdb.ReadBucket,
+ address address.Address) (AccountStore, uint32, error)
+
+ // AddressDetails determines whether the wallet has access to the
+ // private keys required to sign for a given address, and returns other
+ // address details.
+ AddressDetails(ns walletdb.ReadBucket,
+ addr address.Address) (bool, string, AddressType)
+
+ // ForEachRelevantActiveAddress invokes the given closure on each active
+ // address relevant to the wallet.
+ ForEachRelevantActiveAddress(ns walletdb.ReadBucket,
+ fn func(addr address.Address) error) error
+
+ // Unlock derives the master private key from the specified passphrase.
+ Unlock(ns walletdb.ReadBucket, passphrase []byte) error
+
+ // Lock performs a best try effort to remove and zero all secret keys
+ // associated with the address manager.
+ Lock() error
+
+ // IsLocked returns whether or not the address managed is locked.
+ IsLocked() bool
+
+ // ChangePassphrase changes either the public or private passphrase to
+ // the provided value depending on the private flag.
+ ChangePassphrase(ns walletdb.ReadWriteBucket, oldPass, newPass []byte,
+ private bool, scryptOptions *ScryptOptions) error
+
+ // WatchOnly returns true if the root manager is in watch only mode, and
+ // false otherwise.
+ WatchOnly() bool
+
+ // MarkUsed updates the used flag for the provided address.
+ MarkUsed(ns walletdb.ReadWriteBucket, address address.Address) error
+
+ // BirthdayBlock returns the birthday block of the address store.
+ BirthdayBlock(ns walletdb.ReadBucket) (BlockStamp, bool, error)
+
+ // IsWatchOnlyAccount determines if the account with the given key scope
+ // is set up as watch-only.
+ IsWatchOnlyAccount(ns walletdb.ReadBucket, keyScope KeyScope,
+ account uint32) (bool, error)
+
+ // NewScopedKeyManager creates a new scoped key manager from the root
+ // manager.
+ NewScopedKeyManager(ns walletdb.ReadWriteBucket,
+ scope KeyScope,
+ addrSchema ScopeAddrSchema) (AccountStore, error)
+
+ // SetBirthday sets the birthday of the address store.
+ SetBirthday(ns walletdb.ReadWriteBucket, birthday time.Time) error
+
+ // ForEachAccountAddress calls the given function with each address of
+ // the given account stored in the manager, breaking early on error.
+ ForEachAccountAddress(ns walletdb.ReadBucket, account uint32,
+ fn func(maddr ManagedAddress) error) error
+
+ // LookupAccount returns the corresponding key scope and account number
+ // for the account with the given name.
+ LookupAccount(ns walletdb.ReadBucket,
+ name string) (KeyScope, uint32, error)
+
+ // ForEachActiveAddress calls the given function with each active
+ // address stored in the manager, breaking early on error.
+ ForEachActiveAddress(ns walletdb.ReadBucket,
+ fn func(addr address.Address) error) error
+
+ // ConvertToWatchingOnly converts the current address manager to a
+ // locked watching-only address manager.
+ ConvertToWatchingOnly(ns walletdb.ReadWriteBucket) error
+
+ // ChainParams returns the chain parameters for this address manager.
+ ChainParams() *chaincfg.Params
+
+ // Close cleanly shuts down the manager.
+ Close()
+}
+
+// AccountStore is an interface that describes a scoped key manager.
+//
+// TODO(yy): remove this interface and hide the details inside AddrStore.
+//
+//nolint:interfacebloat
+type AccountStore interface {
+ // Scope returns the key scope of the manager.
+ Scope() KeyScope
+
+ // AccountProperties returns the properties of an account, including
+ // address indexes and name.
+ AccountProperties(ns walletdb.ReadBucket,
+ account uint32) (*AccountProperties, error)
+
+ // LastExternalAddress returns the last external address for an account.
+ LastExternalAddress(ns walletdb.ReadBucket,
+ account uint32) (ManagedAddress, error)
+
+ // LastInternalAddress returns the last internal address for an account.
+ LastInternalAddress(ns walletdb.ReadBucket,
+ account uint32) (ManagedAddress, error)
+
+ // ForEachAccountAddress calls the given function with each address of
+ // the given account stored in the manager, breaking early on error.
+ ForEachAccountAddress(ns walletdb.ReadBucket, account uint32,
+ fn func(maddr ManagedAddress) error) error
+
+ // LookupAccount returns the account number for the given account name.
+ LookupAccount(ns walletdb.ReadBucket, name string) (uint32, error)
+
+ // AccountName returns the name of an account.
+ AccountName(ns walletdb.ReadBucket, account uint32) (string, error)
+
+ // ExtendExternalAddresses extends the external addresses for an
+ // account.
+ ExtendExternalAddresses(ns walletdb.ReadWriteBucket, account uint32,
+ count uint32) error
+
+ // ExtendInternalAddresses extends the internal addresses for an
+ // account.
+ ExtendInternalAddresses(ns walletdb.ReadWriteBucket, account uint32,
+ count uint32) error
+
+ // ExtendAddresses ensures that all valid keys through lastIndex are
+ // derived and stored in the wallet for the specified branch.
+ ExtendAddresses(ns walletdb.ReadWriteBucket, account uint32,
+ lastIndex uint32, branch uint32) error
+
+ // MarkUsed updates the used flag for the provided address.
+ MarkUsed(ns walletdb.ReadWriteBucket, address address.Address) error
+
+ // DeriveFromKeyPath derives a key from the given key path.
+ DeriveFromKeyPath(ns walletdb.ReadBucket,
+ path DerivationPath) (ManagedAddress, error)
+
+ // CanAddAccount returns an error if a new account cannot be created.
+ CanAddAccount() error
+
+ // NewAccount creates a new account.
+ NewAccount(ns walletdb.ReadWriteBucket, name string) (uint32, error)
+
+ // LastAccount returns the last account number.
+ LastAccount(ns walletdb.ReadBucket) (uint32, error)
+
+ // RenameAccount renames an account.
+ RenameAccount(ns walletdb.ReadWriteBucket, account uint32,
+ name string) error
+
+ // NextExternalAddresses returns the next external addresses for an
+ // account.
+ NextExternalAddresses(ns walletdb.ReadWriteBucket, account uint32,
+ count uint32) ([]ManagedAddress, error)
+
+ // NextInternalAddresses returns the next internal addresses for an
+ // account.
+ NextInternalAddresses(ns walletdb.ReadWriteBucket, account uint32,
+ count uint32) ([]ManagedAddress, error)
+
+ // NewAddress creates a new address for an account.
+ NewAddress(ns walletdb.ReadWriteBucket, account string,
+ internal bool) (address.Address, error)
+
+ // ImportPublicKey imports a public key.
+ ImportPublicKey(ns walletdb.ReadWriteBucket, pubKey *btcec.PublicKey,
+ bs *BlockStamp) (ManagedAddress, error)
+
+ // ImportTaprootScript imports a taproot script.
+ ImportTaprootScript(ns walletdb.ReadWriteBucket,
+ script *Tapscript, bs *BlockStamp, privKeyType byte,
+ isInternal bool) (ManagedTaprootScriptAddress, error)
+
+ // ForEachAccount calls the given function with each account stored in
+ // the manager, breaking early on error.
+ ForEachAccount(ns walletdb.ReadBucket,
+ fn func(account uint32) error) error
+
+ // IsWatchOnlyAccount determines if the account is watch-only.
+ IsWatchOnlyAccount(ns walletdb.ReadBucket, account uint32) (bool, error)
+
+ // NewAccountWatchingOnly creates a new watch-only account.
+ NewAccountWatchingOnly(ns walletdb.ReadWriteBucket, name string,
+ pubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
+ addrSchema *ScopeAddrSchema) (uint32, error)
+
+ // InvalidateAccountCache invalidates the account cache.
+ InvalidateAccountCache(account uint32)
+
+ // ImportPrivateKey imports a private key.
+ ImportPrivateKey(ns walletdb.ReadWriteBucket, wif *btcutil.WIF,
+ bs *BlockStamp) (ManagedPubKeyAddress, error)
+
+ // AddrAccount returns the account for a given address.
+ AddrAccount(ns walletdb.ReadBucket,
+ address address.Address) (uint32, error)
+
+ // DeriveFromKeyPathCache derives a key from the given key path, using
+ // the cache.
+ DeriveFromKeyPathCache(kp DerivationPath) (*btcec.PrivateKey, error)
+
+ // NewRawAccount creates a new account with a raw account number.
+ NewRawAccount(ns walletdb.ReadWriteBucket, number uint32) error
+
+ // ImportScript imports a script.
+ ImportScript(ns walletdb.ReadWriteBucket, script []byte,
+ bs *BlockStamp) (ManagedScriptAddress, error)
+
+ // ActiveAccounts returns the account numbers of all accounts currently
+ // loaded in memory.
+ ActiveAccounts() []uint32
+
+ // DeriveAddr derives a single address for the given account, branch,
+ // and index.
+ DeriveAddr(account uint32, branch uint32, index uint32) (
+ address.Address, []byte, error)
+}
diff --git a/waddrmgr/manager.go b/waddrmgr/manager.go
index 9fc40086eb..56178c6dc6 100644
--- a/waddrmgr/manager.go
+++ b/waddrmgr/manager.go
@@ -46,14 +46,17 @@ const (
// DefaultAccountNum is the number of the default account.
DefaultAccountNum = 0
- // defaultAccountName is the initial name of the default account. Note
+ // DefaultAccountName is the initial name of the default account. Note
// that the default account may be renamed and is not a reserved name,
// so the default account might not be named "default" and non-default
// accounts may be named "default".
//
// Account numbers never change, so the DefaultAccountNum should be
// used to refer to (and only to) the default account.
- defaultAccountName = "default"
+ DefaultAccountName = "default"
+
+ // unknownAccountName is the string returned when an account is unknown.
+ unknownAccountName = "unknown"
// The hierarchy described by BIP0043 is:
// m/'/*
@@ -428,17 +431,16 @@ func (m *Manager) IsWatchOnlyAccount(ns walletdb.ReadBucket, keyScope KeyScope,
func (m *Manager) lock() {
for _, manager := range m.scopedManagers {
// Clear all of the account private keys.
- for _, acctInfo := range manager.acctInfo {
+ for _, acctInfo := range manager.accountInfo() {
if acctInfo.acctKeyPriv != nil {
acctInfo.acctKeyPriv.Zero()
}
acctInfo.acctKeyPriv = nil
}
- }
- // Remove clear text private keys and scripts from all address entries.
- for _, manager := range m.scopedManagers {
- for _, ma := range manager.addrs {
+ // Remove clear text private keys and scripts from all address
+ // entries.
+ for _, ma := range manager.addresses() {
switch addr := ma.(type) {
case *managedAddress:
addr.lock()
@@ -503,7 +505,7 @@ func (m *Manager) Close() {
// TODO(roasbeef): addrtype of raw key means it'll look in scripts to possibly
// mark as gucci?
func (m *Manager) NewScopedKeyManager(ns walletdb.ReadWriteBucket,
- scope KeyScope, addrSchema ScopeAddrSchema) (*ScopedKeyManager, error) {
+ scope KeyScope, addrSchema ScopeAddrSchema) (AccountStore, error) {
m.mtx.Lock()
defer m.mtx.Unlock()
@@ -626,7 +628,7 @@ func (m *Manager) NewScopedKeyManager(ns walletdb.ReadWriteBucket,
// its registered scope. If the manger is found, then a nil error is returned
// along with the active scoped manager. Otherwise, a nil manager and a non-nil
// error will be returned.
-func (m *Manager) FetchScopedKeyManager(scope KeyScope) (*ScopedKeyManager, error) {
+func (m *Manager) FetchScopedKeyManager(scope KeyScope) (AccountStore, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
@@ -641,11 +643,11 @@ func (m *Manager) FetchScopedKeyManager(scope KeyScope) (*ScopedKeyManager, erro
// ActiveScopedKeyManagers returns a slice of all the active scoped key
// managers currently known by the root key manager.
-func (m *Manager) ActiveScopedKeyManagers() []*ScopedKeyManager {
+func (m *Manager) ActiveScopedKeyManagers() []AccountStore {
m.mtx.RLock()
defer m.mtx.RUnlock()
- scopedManagers := make([]*ScopedKeyManager, 0, len(m.scopedManagers))
+ scopedManagers := make([]AccountStore, 0, len(m.scopedManagers))
for _, smgr := range m.scopedManagers {
scopedManagers = append(scopedManagers, smgr)
}
@@ -753,7 +755,7 @@ func (m *Manager) MarkUsed(ns walletdb.ReadWriteBucket,
// AddrAccount returns the account to which the given address belongs. We also
// return the scoped manager that owns the addr+account combo.
func (m *Manager) AddrAccount(ns walletdb.ReadBucket,
- address address.Address) (*ScopedKeyManager, uint32, error) {
+ address address.Address) (AccountStore, uint32, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
@@ -780,6 +782,53 @@ func (m *Manager) AddrAccount(ns walletdb.ReadBucket,
return nil, 0, managerError(ErrAddressNotFound, str, nil)
}
+// AddressDetails determines whether the wallet has access to the private keys
+// required to sign for a given address, and returns other address details.
+func (m *Manager) AddressDetails(ns walletdb.ReadBucket,
+ addr address.Address) (bool, string, AddressType) {
+
+ managedAddr, err := m.Address(ns, addr)
+ if err != nil {
+ // If we don't know the address, we can't spend it.
+ return false, unknownAccountName, 0
+ }
+
+ addrType := managedAddr.AddrType()
+
+ // A global watch-only wallet can't spend anything.
+ if m.WatchOnly() {
+ return false, unknownAccountName, addrType
+ }
+
+ // Imported addresses are considered unspendable by policy.
+ if managedAddr.Imported() {
+ return false, ImportedAddrAccountName, addrType
+ }
+
+ // Check if the specific account for this address is watch-only.
+ scopedMgr, account, err := m.AddrAccount(ns, addr)
+ if err != nil {
+ return false, unknownAccountName, addrType
+ }
+
+ accountName, err := scopedMgr.AccountName(ns, account)
+ if err != nil {
+ return false, unknownAccountName, addrType
+ }
+
+ isWatchOnlyAccount, err := scopedMgr.IsWatchOnlyAccount(ns, account)
+ if err != nil {
+ return false, accountName, addrType
+ }
+
+ if isWatchOnlyAccount {
+ return false, accountName, addrType
+ }
+
+ // If all checks pass, the address is spendable.
+ return true, accountName, addrType
+}
+
// ForEachActiveAccountAddress calls the given function with each active
// address of the given account stored in the manager, across all active
// scopes, breaking early on error.
@@ -1765,7 +1814,7 @@ func createManagerKeyScope(ns walletdb.ReadWriteBucket,
// Save the information for the default account to the database.
err = putDefaultAccountInfo(
ns, &scope, DefaultAccountNum, acctPubEnc, acctPrivEnc, 0, 0,
- defaultAccountName,
+ DefaultAccountName,
)
if err != nil {
return err
diff --git a/waddrmgr/manager_test.go b/waddrmgr/manager_test.go
index 92847e656f..bb6de160e3 100644
--- a/waddrmgr/manager_test.go
+++ b/waddrmgr/manager_test.go
@@ -1407,11 +1407,7 @@ func testChangePassphrase(tc *testContext) bool {
func testNewAccount(tc *testContext) bool {
if tc.watchingOnly {
// Creating new accounts in watching-only mode should return ErrWatchingOnly
- err := walletdb.Update(tc.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- _, err := tc.manager.NewAccount(ns, "test")
- return err
- })
+ err := tc.manager.CanAddAccount()
if !checkManagerError(
tc.t, "Create account in watching-only mode", err,
ErrWatchingOnly,
@@ -1422,11 +1418,7 @@ func testNewAccount(tc *testContext) bool {
return true
}
// Creating new accounts when wallet is locked should return ErrLocked
- err := walletdb.Update(tc.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- _, err := tc.manager.NewAccount(ns, "test")
- return err
- })
+ err := tc.manager.CanAddAccount()
if !checkManagerError(
tc.t, "Create account when wallet is locked", err, ErrLocked,
) {
@@ -1507,7 +1499,7 @@ func testNewAccount(tc *testContext) bool {
func testLookupAccount(tc *testContext) bool {
// Lookup accounts created earlier in testNewAccount
expectedAccounts := map[string]uint32{
- defaultAccountName: DefaultAccountNum,
+ DefaultAccountName: DefaultAccountNum,
ImportedAddrAccountName: ImportedAddrAccount,
}
@@ -1778,7 +1770,7 @@ func testManagerAPI(tc *testContext, caseCreatedWatchingOnly bool) {
testNewAccount(tc)
expectedAccounts := map[string]uint32{
- defaultAccountName: DefaultAccountNum,
+ DefaultAccountName: DefaultAccountNum,
}
testLookupExpectedAccount(tc, expectedAccounts, 0)
//testForEachAccount(tc)
@@ -1844,11 +1836,15 @@ func testConvertWatchingOnly(tc *testContext) bool {
// Run all of the manager API tests against the converted manager and
// close it. We'll also retrieve the default scope (BIP0044) from the
// manager in order to use.
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
if err != nil {
tc.t.Errorf("unable to fetch bip 44 scope %v", err)
return false
}
+
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(tc.t, ok)
+
testManagerAPI(&testContext{
t: tc.t,
caseName: tc.caseName,
@@ -1874,12 +1870,15 @@ func testConvertWatchingOnly(tc *testContext) bool {
}
defer mgr.Close()
- scopedMgr, err = mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err = mgr.FetchScopedKeyManager(KeyScopeBIP0044)
if err != nil {
tc.t.Errorf("unable to fetch bip 44 scope %v", err)
return false
}
+ scopedMgr, ok = sMgr.(*ScopedKeyManager)
+ require.True(tc.t, ok)
+
testManagerAPI(&testContext{
t: tc.t,
caseName: tc.caseName,
@@ -2049,11 +2048,14 @@ func testManagerCase(t *testing.T, caseName string,
return
}
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
if err != nil {
t.Fatalf("(%s) unable to fetch default scope: %v", caseName, err)
}
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
if caseCreatedWatchingOnly {
accountKey := deriveTestAccountKey(t)
if accountKey == nil {
@@ -2070,7 +2072,7 @@ func testManagerCase(t *testing.T, caseName string,
err = walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
_, err = scopedMgr.NewAccountWatchingOnly(
- ns, defaultAccountName, acctKeyPub, 0, nil,
+ ns, DefaultAccountName, acctKeyPub, 0, nil,
)
return err
})
@@ -2108,10 +2110,13 @@ func testManagerCase(t *testing.T, caseName string,
}
defer mgr.Close()
- scopedMgr, err = mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err = mgr.FetchScopedKeyManager(KeyScopeBIP0044)
if err != nil {
t.Fatalf("(%s) unable to fetch default scope: %v", caseName, err)
}
+
+ scopedMgr, ok = sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
tc := &testContext{
t: t,
caseName: caseName,
@@ -2374,7 +2379,10 @@ func TestScopedKeyManagerManagement(t *testing.T) {
t.Fatalf("unable to fetch scope %v: %v", scope, err)
}
- externalAddr, err := sMgr.NextExternalAddresses(
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
+ externalAddr, err := scopedMgr.NextExternalAddresses(
ns, DefaultAccountNum, 1,
)
if err != nil {
@@ -2389,7 +2397,7 @@ func TestScopedKeyManagerManagement(t *testing.T) {
ScopeAddrMap[scope].ExternalAddrType)
}
- internalAddr, err := sMgr.NextInternalAddresses(
+ internalAddr, err := scopedMgr.NextInternalAddresses(
ns, DefaultAccountNum, 1,
)
if err != nil {
@@ -2421,11 +2429,12 @@ func TestScopedKeyManagerManagement(t *testing.T) {
ExternalAddrType: NestedWitnessPubKey,
InternalAddrType: WitnessPubKey,
}
- var scopedMgr *ScopedKeyManager
+
+ var sMgr AccountStore
err = walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- scopedMgr, err = mgr.NewScopedKeyManager(ns, testScope, addrSchema)
+ sMgr, err = mgr.NewScopedKeyManager(ns, testScope, addrSchema)
if err != nil {
return err
}
@@ -2436,6 +2445,9 @@ func TestScopedKeyManagerManagement(t *testing.T) {
t.Fatalf("unable to read db: %v", err)
}
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
// The manager was just created, we should be able to look it up within
// the root manager.
if _, err := mgr.FetchScopedKeyManager(testScope); err != nil {
@@ -2481,7 +2493,7 @@ func TestScopedKeyManagerManagement(t *testing.T) {
NestedWitnessPubKey, externalAddr[0].AddrType())
}
- _, ok := externalAddr[0].Address().(*address.AddressScriptHash)
+ _, ok = externalAddr[0].Address().(*address.AddressScriptHash)
if !ok {
t.Fatalf("wrong type: %T", externalAddr[0].Address())
}
@@ -2518,11 +2530,14 @@ func TestScopedKeyManagerManagement(t *testing.T) {
// We should be able to retrieve the new scoped manager that we just
// created.
- scopedMgr, err = mgr.FetchScopedKeyManager(testScope)
+ sMgr, err = mgr.FetchScopedKeyManager(testScope)
if err != nil {
t.Fatalf("attempt to read created mgr failed: %v", err)
}
+ scopedMgr, ok = sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
// If we fetch the last generated external address, it should map
// exactly to the address that we just generated.
var lastAddr ManagedAddress
@@ -2707,11 +2722,14 @@ func TestNewRawAccount(t *testing.T) {
// Now that we have the manager created, we'll fetch one of the default
// scopes for usage within this test.
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0084)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0084)
if err != nil {
t.Fatalf("unable to fetch scope %v: %v", KeyScopeBIP0084, err)
}
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
// With the scoped manager retrieved, we'll attempt to create a new raw
// account by number.
const accountNum = 1000
@@ -2767,11 +2785,14 @@ func TestNewRawAccountWatchingOnly(t *testing.T) {
// Now that we have the manager created, we'll fetch one of the default
// scopes for usage within this test.
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
if err != nil {
t.Fatalf("unable to fetch scope %v: %v", KeyScopeBIP0044, err)
}
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
accountKey := deriveTestAccountKey(t)
if accountKey == nil {
return
@@ -2828,11 +2849,14 @@ func TestNewRawAccountHybrid(t *testing.T) {
// Now that we have the manager created, we'll fetch one of the default
// scopes for usage within this test.
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
if err != nil {
t.Fatalf("unable to fetch scope %v: %v", KeyScopeBIP0044, err)
}
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
accountKey := deriveTestAccountKey(t)
if accountKey == nil {
return
@@ -2951,11 +2975,14 @@ func TestDeriveFromKeyPathCache(t *testing.T) {
// Now that we have the manager created, we'll fetch one of the default
// scopes for usage within this test.
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0044)
require.NoError(
t, err, "unable to fetch scope %v: %v", KeyScopeBIP0044, err,
)
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
keyPath := DerivationPath{
InternalAccount: 0,
Account: hdkeychain.HardenedKeyStart,
@@ -3049,11 +3076,14 @@ func TestTaprootPubKeyDerivation(t *testing.T) {
// Now that we have the manager created, we'll fetch one of the default
// scopes for usage within this test.
- scopedMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0086)
+ sMgr, err := mgr.FetchScopedKeyManager(KeyScopeBIP0086)
require.NoError(
t, err, "unable to fetch scope %v: %v", KeyScopeBIP0086, err,
)
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
externalPath := DerivationPath{
InternalAccount: 0,
Account: hdkeychain.HardenedKeyStart,
@@ -3256,12 +3286,15 @@ func TestManagedAddressValidation(t *testing.T) {
scope)
t.Run(testName, func(t *testing.T) {
- scopedMgr, err := mgr.FetchScopedKeyManager(scope)
+ sMgr, err := mgr.FetchScopedKeyManager(scope)
require.NoError(
t, err, "unable to fetch scope %v: %v",
KeyScopeBIP0086, err,
)
+ scopedMgr, ok := sMgr.(*ScopedKeyManager)
+ require.True(t, ok)
+
var addr ManagedAddress
// With the scoped managed we created above,
diff --git a/waddrmgr/scoped_manager.go b/waddrmgr/scoped_manager.go
index 4e4dad7c27..09d7ac2682 100644
--- a/waddrmgr/scoped_manager.go
+++ b/waddrmgr/scoped_manager.go
@@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
+ "maps"
"sync"
"github.com/btcsuite/btcd/address/v2"
@@ -131,6 +132,60 @@ func (k KeyScope) String() string {
return fmt.Sprintf("m/%v'/%v'", k.Purpose, k.Coin)
}
+// AccountScope uniquely identifies a specific account within a key scope.
+type AccountScope struct {
+ // Scope is the BIP44 account' used to derive the child key.
+ Scope KeyScope
+
+ // Account is the account number.
+ Account uint32
+}
+
+// String returns a human readable version describing the account scope.
+func (as AccountScope) String() string {
+ return fmt.Sprintf("%s/%d'", as.Scope, as.Account)
+}
+
+// BranchScope uniquely identifies a specific derivation branch within an
+// account.
+type BranchScope struct {
+ // Scope is the key scope of the branch.
+ Scope KeyScope
+
+ // Account is the account number of the branch.
+ Account uint32
+
+ // Branch is the branch number (e.g. waddrmgr.ExternalBranch or
+ // waddrmgr.InternalBranch).
+ Branch uint32
+}
+
+// String returns a human readable version describing the branch scope.
+func (bs BranchScope) String() string {
+ return fmt.Sprintf("%s/%d/%d'", bs.Scope, bs.Account, bs.Branch)
+}
+
+// IsChange returns true if the branch matches the internal (change) branch.
+func (bs BranchScope) IsChange() bool {
+ return bs.Branch == InternalBranch
+}
+
+// AddrScope uniquely identifies a specific address within a derivation branch.
+type AddrScope struct {
+ BranchScope BranchScope
+ Index uint32
+}
+
+// String returns a human readable version describing the address scope.
+func (as AddrScope) String() string {
+ return fmt.Sprintf("%s/%d", as.BranchScope, as.Index)
+}
+
+// IsChange returns true if the address belongs to an internal (change) branch.
+func (as AddrScope) IsChange() bool {
+ return as.BranchScope.IsChange()
+}
+
// Identity is a closure that returns the identifier of an address.
type Identity func() []byte
@@ -303,6 +358,24 @@ type ScopedKeyManager struct {
mtx sync.RWMutex
}
+// A compile-time assertion to ensure that ScopedKeyManager implements the
+// AccountStore interface.
+var _ AccountStore = (*ScopedKeyManager)(nil)
+
+// ActiveAccounts returns the account numbers of all accounts currently loaded
+// in memory.
+func (s *ScopedKeyManager) ActiveAccounts() []uint32 {
+ s.mtx.RLock()
+ defer s.mtx.RUnlock()
+
+ accounts := make([]uint32, 0, len(s.acctInfo))
+ for account := range s.acctInfo {
+ accounts = append(accounts, account)
+ }
+
+ return accounts
+}
+
// Scope returns the exact KeyScope of this scoped key manager.
func (s *ScopedKeyManager) Scope() KeyScope {
return s.scope
@@ -1002,6 +1075,41 @@ func (s *ScopedKeyManager) accountAddrType(acctInfo *accountInfo,
return addrSchema.ExternalAddrType
}
+// TODO(yy): This method is a "God method" that does too much, leading to
+// several issues. It should be refactored to improve separation of concerns,
+// reduce complexity, and increase robustness.
+//
+// Issues:
+// 1. **Excessive Responsibility:** The method handles account validation, key
+// derivation, address creation, database persistence, and in-memory state
+// management, violating the Single Responsibility Principle.
+// 2. **Fragile Concurrency Model:** The use of an `onCommit` closure to sync
+// the in-memory cache with the database state is prone to race
+// conditions and deadlocks, as it requires re-acquiring a mutex outside
+// of the original database transaction's scope.
+// 3. **Inefficiency and Redundancy:** The method performs a read-after-write
+// validation by loading addresses from the DB immediately after saving
+// them, which indicates a lack of confidence in the persistence logic.
+// 4. **Complex API:** The method returns a slice of addresses but is almost
+// always called with a request for a single address, making the API and
+// its usage unnecessarily complex.
+//
+// Refactoring Tasks:
+// - **Separate DB Logic:** Encapsulate all database read/write operations
+// within this package. The method should handle its own transaction
+// instead of accepting a `walletdb.ReadWriteBucket`.
+// - **Simplify State Management:** Refactor the in-memory cache (`s.addrs`)
+// and the next address index (`acctInfo.next...Index`) to be updated
+// atomically with the database write, removing the fragile `onCommit`
+// closure.
+// - **Decompose the Method:** Break this function into smaller, private
+// helpers for each distinct responsibility: deriving keys, creating
+// managed addresses, and persisting them.
+// - **Simplify the API:** Create a new, simpler method that returns a single
+// address, as this is the most common use case. The batch generation
+// logic can be deprecated or refactored into a separate method if still
+// needed.
+//
// nextAddresses returns the specified number of next chained address from the
// branch indicated by the internal flag.
//
@@ -1010,8 +1118,10 @@ func (s *ScopedKeyManager) nextAddresses(ns walletdb.ReadWriteBucket,
account uint32, numAddresses uint32, internal bool) ([]ManagedAddress,
error) {
- // The next address can only be generated for accounts that have
- // already been created.
+ // The next address can only be generated for accounts that have already
+ // been created. We load the account info to retrieve the decrypted keys and
+ // other cached metadata. This ensures we don't perform expensive crypto
+ // operations for every address generation.
acctInfo, err := s.loadAccountInfo(ns, account)
if err != nil {
return nil, err
@@ -1020,24 +1130,37 @@ func (s *ScopedKeyManager) nextAddresses(ns walletdb.ReadWriteBucket,
// Choose the account key to used based on whether the address manager
// is locked.
acctKey := acctInfo.acctKeyPub
- watchOnly := s.rootManager.WatchOnly() || len(acctInfo.acctKeyEncrypted) == 0
+ watchOnly := s.rootManager.WatchOnly() ||
+ len(acctInfo.acctKeyEncrypted) == 0
+
if !s.rootManager.IsLocked() && !watchOnly {
acctKey = acctInfo.acctKeyPriv
}
+ // Choose the appropriate type of address to derive since it's possible
+ // for a watch-only account to have a different schema from the
+ // manager's.
+ addrType := s.accountAddrType(acctInfo, internal)
+
+ // We also fetch the raw account row directly from the database to
+ // ensure we have the most up-to-date address indices, bypassing any
+ // potentially stale cached state. This is critical for preventing
+ // race conditions during concurrent address generation.
+ //
+ // NOTE: We must do this *after* loadAccountInfo to ensure the account
+ // exists and is cached, but we override the indices with the DB values.
+ nextIndex, err := s.getNextIndex(ns, account, internal)
+ if err != nil {
+ return nil, err
+ }
+
// Choose the branch key and index depending on whether or not this is
// an internal address.
- branchNum, nextIndex := ExternalBranch, acctInfo.nextExternalIndex
+ branchNum := ExternalBranch
if internal {
branchNum = InternalBranch
- nextIndex = acctInfo.nextInternalIndex
}
- // Choose the appropriate type of address to derive since it's possible
- // for a watch-only account to have a different schema from the
- // manager's.
- addrType := s.accountAddrType(acctInfo, internal)
-
// Ensure the requested number of addresses doesn't exceed the maximum
// allowed for this account.
if numAddresses > MaxAddressesPerAccount || nextIndex+numAddresses >
@@ -1209,13 +1332,22 @@ func (s *ScopedKeyManager) nextAddresses(ns walletdb.ReadWriteBucket,
}
// Set the last address and next address for tracking.
+ //
+ // NOTE: We only update the cache if the new index is strictly greater
+ // than the current cached index. This protects against a race condition
+ // where a slower transaction commits *after* a faster one, which could
+ // otherwise cause the cache to regress.
ma := addressInfo[len(addressInfo)-1].managedAddr
if internal {
- acctInfo.nextInternalIndex = nextIndex
- acctInfo.lastInternalAddr = ma
+ if nextIndex > acctInfo.nextInternalIndex {
+ acctInfo.nextInternalIndex = nextIndex
+ acctInfo.lastInternalAddr = ma
+ }
} else {
- acctInfo.nextExternalIndex = nextIndex
- acctInfo.lastExternalAddr = ma
+ if nextIndex > acctInfo.nextExternalIndex {
+ acctInfo.nextExternalIndex = nextIndex
+ acctInfo.lastExternalAddr = ma
+ }
}
}
ns.Tx().OnCommit(onCommit)
@@ -1248,19 +1380,30 @@ func (s *ScopedKeyManager) extendAddresses(ns walletdb.ReadWriteBucket,
acctKey = acctInfo.acctKeyPriv
}
+ // Choose the appropriate type of address to derive since it's possible
+ // for a watch-only account to have a different schema from the
+ // manager's.
+ addrType := s.accountAddrType(acctInfo, internal)
+
+ // We also fetch the raw account row directly from the database to
+ // ensure we have the most up-to-date address indices, bypassing any
+ // potentially stale cached state. This is critical for preventing
+ // race conditions during concurrent address generation.
+ //
+ // NOTE: We must do this *after* loadAccountInfo to ensure the account
+ // exists and is cached, but we override the indices with the DB values.
+ nextIndex, err := s.getNextIndex(ns, account, internal)
+ if err != nil {
+ return err
+ }
+
// Choose the branch key and index depending on whether or not this is
// an internal address.
- branchNum, nextIndex := ExternalBranch, acctInfo.nextExternalIndex
+ branchNum := ExternalBranch
if internal {
branchNum = InternalBranch
- nextIndex = acctInfo.nextInternalIndex
}
- // Choose the appropriate type of address to derive since it's possible
- // for a watch-only account to have a different schema from the
- // manager's.
- addrType := s.accountAddrType(acctInfo, internal)
-
// If the last index requested is already lower than the next index, we
// can return early.
if lastIndex < nextIndex {
@@ -1399,18 +1542,45 @@ func (s *ScopedKeyManager) extendAddresses(ns walletdb.ReadWriteBucket,
}
// Set the last address and next address for tracking.
+ //
+ // NOTE: We only update the cache if the new index is strictly greater
+ // than the current cached index. This protects against a race condition
+ // where a slower transaction commits *after* a faster one, which could
+ // otherwise cause the cache to regress.
ma := addressInfo[len(addressInfo)-1].managedAddr
if internal {
- acctInfo.nextInternalIndex = nextIndex
- acctInfo.lastInternalAddr = ma
+ if nextIndex > acctInfo.nextInternalIndex {
+ acctInfo.nextInternalIndex = nextIndex
+ acctInfo.lastInternalAddr = ma
+ }
} else {
- acctInfo.nextExternalIndex = nextIndex
- acctInfo.lastExternalAddr = ma
+ if nextIndex > acctInfo.nextExternalIndex {
+ acctInfo.nextExternalIndex = nextIndex
+ acctInfo.lastExternalAddr = ma
+ }
}
return nil
}
+// ExtendAddresses ensures that all valid keys through lastIndex are
+// derived and stored in the wallet for the specified branch.
+func (s *ScopedKeyManager) ExtendAddresses(ns walletdb.ReadWriteBucket,
+ account uint32, lastIndex uint32, branch uint32) error {
+
+ if account > MaxAccountNum {
+ err := managerError(ErrAccountNumTooHigh, errAcctTooHigh, nil)
+ return err
+ }
+
+ s.mtx.Lock()
+ defer s.mtx.Unlock()
+
+ return s.extendAddresses(
+ ns, account, lastIndex, branch == InternalBranch,
+ )
+}
+
// NextExternalAddresses returns the specified number of next chained addresses
// that are intended for external use from the address manager.
func (s *ScopedKeyManager) NextExternalAddresses(ns walletdb.ReadWriteBucket,
@@ -1549,22 +1719,31 @@ func (s *ScopedKeyManager) LastInternalAddress(ns walletdb.ReadBucket,
return nil, managerError(ErrAddressNotFound, "no previous internal address", nil)
}
-// NewRawAccount creates a new account for the scoped manager. This method
-// differs from the NewAccount method in that this method takes the account
-// number *directly*, rather than taking a string name for the account, then
-// mapping that to the next highest account number.
-func (s *ScopedKeyManager) NewRawAccount(ns walletdb.ReadWriteBucket, number uint32) error {
+// CanAddAccount returns an error if a new account cannot be created.
+// This is the case if the manager is watch-only or is locked. A descriptive
+// error is returned in these cases.
+func (s *ScopedKeyManager) CanAddAccount() error {
if s.rootManager.WatchOnly() {
return managerError(ErrWatchingOnly, errWatchingOnly, nil)
}
- s.mtx.Lock()
- defer s.mtx.Unlock()
-
if s.rootManager.IsLocked() {
return managerError(ErrLocked, errLocked, nil)
}
+ return nil
+}
+
+// NewRawAccount creates a new account for the scoped manager. This method
+// differs from the NewAccount method in that this method takes the account
+// number *directly*, rather than taking a string name for the account, then
+// mapping that to the next highest account number.
+func (s *ScopedKeyManager) NewRawAccount(
+ ns walletdb.ReadWriteBucket, number uint32) error {
+
+ s.mtx.Lock()
+ defer s.mtx.Unlock()
+
// As this is an ad hoc account that may not follow our normal linear
// derivation, we'll create a new name for this account based off of
// the account number.
@@ -1608,17 +1787,9 @@ func (s *ScopedKeyManager) NewRawAccountWatchingOnly(
// access to the cointype keys (from which extended account keys are derived),
// it requires the manager to be unlocked.
func (s *ScopedKeyManager) NewAccount(ns walletdb.ReadWriteBucket, name string) (uint32, error) {
- if s.rootManager.WatchOnly() {
- return 0, managerError(ErrWatchingOnly, errWatchingOnly, nil)
- }
-
s.mtx.Lock()
defer s.mtx.Unlock()
- if s.rootManager.IsLocked() {
- return 0, managerError(ErrLocked, errLocked, nil)
- }
-
// Fetch latest account, and create a new account in the same
// transaction Fetch the latest account number to generate the next
// account number
@@ -1830,11 +2001,6 @@ func (s *ScopedKeyManager) RenameAccount(ns walletdb.ReadWriteBucket,
return managerError(ErrDuplicateAccount, str, err)
}
- // Validate account name
- if err := ValidateAccountName(name); err != nil {
- return err
- }
-
rowInterface, err := fetchAccountInfo(ns, &s.scope, account)
if err != nil {
return err
@@ -1980,6 +2146,82 @@ func (s *ScopedKeyManager) ImportPublicKey(ns walletdb.ReadWriteBucket,
return s.toImportedPublicManagedAddress(pubKey, true)
}
+// DeriveAddr derives a single address and its corresponding pkScript for the
+// given account, branch, and index. This method relies on the in-memory
+// account state and extended public keys, avoiding database access.
+func (s *ScopedKeyManager) DeriveAddr(account uint32, branch uint32,
+ index uint32) (address.Address, []byte, error) {
+
+ s.mtx.RLock()
+ defer s.mtx.RUnlock()
+
+ acctInfo, ok := s.acctInfo[account]
+ if !ok {
+ return nil, nil, managerError(ErrAccountNotCached,
+ "account not cached", nil)
+ }
+
+ return s.deriveAddr(acctInfo, account, branch, index)
+}
+
+// DeriveAddrs derives a range of addresses and their corresponding pkScripts
+// for the given account and branch. It generates `count` addresses starting
+// from `startIndex`. This method relies on the in-memory account state and
+// extended public keys, avoiding database access.
+//
+// It returns:
+// - A slice of derived addresses.
+// - A slice of corresponding pkScripts.
+// - An error if the account is not cached or derivation fails.
+func (s *ScopedKeyManager) DeriveAddrs(account uint32, branch uint32,
+ startIndex uint32, count uint32) ([]address.Address, [][]byte, error) {
+
+ // Make sure the index is sane.
+ if startIndex+count < startIndex {
+ str := fmt.Sprintf("child index overflow: %d + %d",
+ startIndex, count)
+
+ return nil, nil, managerError(ErrTooManyAddresses, str, nil)
+ }
+
+ s.mtx.RLock()
+ defer s.mtx.RUnlock()
+
+ // Ensure the account information is cached. If not, we cannot proceed
+ // without a DB transaction, so we return an error. The caller is
+ // expected to ensure the account is loaded (e.g. via AccountProperties)
+ // before calling this method.
+ acctInfo, ok := s.acctInfo[account]
+ if !ok {
+ return nil, nil, managerError(ErrAccountNotCached,
+ "account not cached", nil)
+ }
+
+ addrs := make([]address.Address, 0, count)
+ scripts := make([][]byte, 0, count)
+
+ // Iterate through the requested range of child indexes (startIndex to
+ // startIndex+count). For each index, we derive the corresponding
+ // extended key and convert it into a payment address and script.
+ //
+ // TODO(yy): Optimize by deriving the branch key once outside the loop
+ // instead of re-deriving it for every index via s.deriveKey.
+ endIndex := startIndex + count
+ for index := startIndex; index < endIndex; index++ {
+ addr, script, err := s.deriveAddr(
+ acctInfo, account, branch, index,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ addrs = append(addrs, addr)
+ scripts = append(scripts, script)
+ }
+
+ return addrs, scripts, nil
+}
+
// importPublicKey imports a public key into the address manager and updates the
// wallet's start block if necessary. An error is returned if the public key
// already exists.
@@ -2579,3 +2821,163 @@ func (s *ScopedKeyManager) InvalidateAccountCache(account uint32) {
defer s.mtx.Unlock()
delete(s.acctInfo, account)
}
+
+// NewAddress returns a new address for the given account. The `change`
+// parameter dictates whether a change address (internal) or a receiving
+// address (external) should be generated. The caller is responsible for
+// providing a database transaction. The method first looks up the account
+// number from the provided account name. It then uses the appropriate
+// method (`NextInternalAddresses` or `NextExternalAddresses`) to derive the
+// next chained address for that account.
+func (s *ScopedKeyManager) NewAddress(addrmgrNs walletdb.ReadWriteBucket,
+ account string, change bool) (address.Address, error) {
+
+ accountNum, err := s.LookupAccount(addrmgrNs, account)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO(yy): get rid of the list, we should always return one address
+ // here.
+ var addrs []ManagedAddress
+
+ if change {
+ // Get next chained change address from wallet for account.
+ addrs, err = s.NextInternalAddresses(addrmgrNs, accountNum, 1)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ // Get next address from wallet.
+ addrs, err = s.NextExternalAddresses(addrmgrNs, accountNum, 1)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if len(addrs) == 0 {
+ return nil, managerError(
+ ErrAddressNotFound, "no addresses were generated", nil,
+ )
+ }
+
+ addr := addrs[0].Address()
+
+ return addr, nil
+}
+
+// getNextIndex fetches the current next address index for the specified branch
+// (internal/external) of an account directly from the database. This bypasses
+// the in-memory cache to ensure the most up-to-date state is used, avoiding
+// potential race conditions.
+func (s *ScopedKeyManager) getNextIndex(ns walletdb.ReadBucket,
+ account uint32, internal bool) (uint32, error) {
+
+ rowInterface, err := fetchAccountInfo(ns, &s.scope, account)
+ if err != nil {
+ return 0, maybeConvertDbError(err)
+ }
+
+ var nextInternal, nextExternal uint32
+ switch row := rowInterface.(type) {
+ case *dbDefaultAccountRow:
+ nextInternal = row.nextInternalIndex
+ nextExternal = row.nextExternalIndex
+
+ case *dbWatchOnlyAccountRow:
+ nextInternal = row.nextInternalIndex
+ nextExternal = row.nextExternalIndex
+
+ default:
+ str := fmt.Sprintf("unsupported account type %T", row)
+ return 0, managerError(ErrDatabase, str, nil)
+ }
+
+ if internal {
+ return nextInternal, nil
+ }
+
+ return nextExternal, nil
+}
+
+// deriveAddr performs the actual derivation logic for a single address using
+// the provided account info. It assumes the manager lock is held.
+func (s *ScopedKeyManager) deriveAddr(acctInfo *accountInfo, account, branch,
+ index uint32) (address.Address, []byte, error) {
+
+ // Determine the address type (schema) for this account and branch.
+ // This tells us whether to generate P2PKH (BIP44), P2WPKH (BIP84),
+ // Nested P2WPKH (BIP49), or Taproot (BIP86) addresses.
+ // Internal branch usually implies change addresses.
+ internal := branch == InternalBranch
+ addrType := s.accountAddrType(acctInfo, internal)
+
+ // Derive the extended key for this index.
+ // We pass 'false' for the private flag because we only need the
+ // public key to derive the address. This allows operation even
+ // if the wallet is locked (encrypted private keys unavailable).
+ key, err := s.deriveKey(acctInfo, branch, index, false)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ pubKey, err := key.ECPubKey()
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to parse public key: %w",
+ err)
+ }
+
+ // Construct the derivation path metadata. This is required by
+ // newManagedAddressWithoutPrivKey to properly tag the address.
+ derivationPath := DerivationPath{
+ InternalAccount: account,
+ Account: acctInfo.acctKeyPub.ChildIndex(),
+ Branch: branch,
+ Index: index,
+ MasterKeyFingerprint: acctInfo.masterKeyFingerprint,
+ }
+
+ // Create a temporary managed address. We use this helper because
+ // it encapsulates the complex logic for converting a public key
+ // into the correct address format (e.g. P2SH wrapping for nested
+ // SegWit) based on the addrType.
+ ma, err := newManagedAddressWithoutPrivKey(
+ s, derivationPath, pubKey, true, addrType,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ addr := ma.Address()
+
+ script, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to create script: %w", err)
+ }
+
+ return addr, script, nil
+}
+
+// accountInfo returns a copy of the account info map.
+func (s *ScopedKeyManager) accountInfo() map[uint32]*accountInfo {
+ s.mtx.RLock()
+ defer s.mtx.RUnlock()
+
+ acctInfoCopy := make(map[uint32]*accountInfo, len(s.acctInfo))
+ maps.Copy(acctInfoCopy, s.acctInfo)
+
+ return acctInfoCopy
+}
+
+// addresses returns a slice of all managed addresses.
+func (s *ScopedKeyManager) addresses() []ManagedAddress {
+ s.mtx.RLock()
+ defer s.mtx.RUnlock()
+
+ addrs := make([]ManagedAddress, 0, len(s.addrs))
+ for _, ma := range s.addrs {
+ addrs = append(addrs, ma)
+ }
+
+ return addrs
+}
diff --git a/waddrmgr/scoped_manager_test.go b/waddrmgr/scoped_manager_test.go
new file mode 100644
index 0000000000..323d130e35
--- /dev/null
+++ b/waddrmgr/scoped_manager_test.go
@@ -0,0 +1,303 @@
+package waddrmgr
+
+import (
+ "math"
+ "testing"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/stretchr/testify/require"
+)
+
+// TestDeriveAddrs verifies that DeriveAddrs correctly derives addresses using
+// in-memory state, producing the same results as database-backed derivation.
+func TestDeriveAddrs(t *testing.T) {
+ t.Parallel()
+
+ // Initialize a new address manager with a clean database for testing.
+ teardown, db, mgr := setupManager(t)
+ t.Cleanup(teardown)
+
+ // Unlock the manager to allow full functionality, although DeriveAddrs
+ // works without unlocking (tested separately). We unlock here to
+ // ensure DeriveFromKeyPath (the baseline) has access to private keys
+ // if needed by its internal logic.
+ err := walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ return mgr.Unlock(ns, privPassphrase)
+ })
+ require.NoError(t, err)
+
+ // Fetch the default BIP0044 scoped manager.
+ scope := KeyScopeBIP0044
+ acctStore, err := mgr.FetchScopedKeyManager(scope)
+ require.NoError(t, err)
+
+ // Cast to the concrete type to access the method under test.
+ scopedMgr, ok := acctStore.(*ScopedKeyManager)
+ require.True(t, ok, "expected *ScopedKeyManager")
+
+ account := uint32(DefaultAccountNum)
+
+ // Pre-load account into cache (required for DeriveAddrs).
+ err = walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ _, err := scopedMgr.AccountProperties(ns, account)
+
+ return err
+ })
+ require.NoError(t, err)
+
+ // NOTE: We define it here instead of using anonymous struct as this
+ // struct is needed for the `assertDBCorrectness`.
+ type testCase struct {
+ name string
+ branch uint32
+ startIndex uint32
+ count uint32
+ }
+
+ // We define a set of test cases covering different branches
+ // (internal/external) and index ranges to ensure robust derivation. We
+ // also test large batches to verify performance and correctness at
+ // scale.
+ tests := []testCase{
+ {
+ name: "External Branch, Index 0-4",
+ branch: ExternalBranch,
+ startIndex: 0,
+ count: 5,
+ },
+ {
+ name: "Internal Branch, Index 10-14",
+ branch: InternalBranch,
+ startIndex: 10,
+ count: 5,
+ },
+ {
+ name: "Large Batch",
+ branch: ExternalBranch,
+ startIndex: 100,
+ count: 50,
+ },
+ {
+ name: "Single Address",
+ branch: ExternalBranch,
+ startIndex: 1000,
+ count: 1,
+ },
+ {
+ name: "Zero Addresses",
+ branch: ExternalBranch,
+ startIndex: 0,
+ count: 0,
+ },
+ }
+
+ accountNum := hdkeychain.HardenedKeyStart + account
+
+ // assertDBCorrectness is a helper closure that verifies the results
+ // returned by DeriveAddrs against the baseline DeriveFromKeyPath
+ // method. This ensures that the in-memory derivation logic produces
+ // identical addresses and scripts as the database-backed logic.
+ assertDBCorrectness := func(t *testing.T, tc testCase,
+ addrs []address.Address) {
+
+ t.Helper()
+
+ err := walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ for i := range tc.count {
+ index := tc.startIndex + i
+
+ // Construct the derivation path for the
+ // baseline check.
+ path := DerivationPath{
+ InternalAccount: account,
+ Account: accountNum,
+ Branch: tc.branch,
+ Index: index,
+ }
+
+ // Derive using the standard DB-backed method.
+ managedAddr, err := scopedMgr.DeriveFromKeyPath(
+ ns, path,
+ )
+ require.NoError(t, err)
+
+ // Compare the resulting address string.
+ expectedAddr := managedAddr.Address()
+ require.Equal(t, expectedAddr.String(),
+ addrs[i].String(), "Address mismatch "+
+ "at index %d", index)
+
+ // Compare the resulting script.
+ require.Equal(t, expectedAddr.ScriptAddress(),
+ addrs[i].ScriptAddress())
+ }
+
+ return nil
+ })
+ require.NoError(t, err)
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Call the new in-memory derivation method. This
+ // should return the derived addresses and scripts
+ // without further DB access.
+ addrs, scripts, err := scopedMgr.DeriveAddrs(
+ account, tc.branch, tc.startIndex, tc.count,
+ )
+ require.NoError(t, err)
+ require.Len(t, addrs, int(tc.count))
+ require.Len(t, scripts, int(tc.count))
+
+ // Verify the results against the established,
+ // database-backed DeriveFromKeyPath method to ensure
+ // correctness.
+ assertDBCorrectness(t, tc, addrs)
+ })
+ }
+}
+
+// TestDeriveAddrsLocked verifies that DeriveAddrs works even when the wallet
+// is locked (using extended public keys).
+func TestDeriveAddrsLocked(t *testing.T) {
+ t.Parallel()
+
+ // Initialize the manager. By default, it is locked.
+ teardown, db, mgr := setupManager(t)
+ t.Cleanup(teardown)
+
+ // Confirm the manager is indeed locked.
+ require.True(t, mgr.IsLocked())
+
+ scope := KeyScopeBIP0044
+ acctStore, err := mgr.FetchScopedKeyManager(scope)
+ require.NoError(t, err)
+
+ scopedMgr, ok := acctStore.(*ScopedKeyManager)
+ require.True(t, ok)
+
+ // Pre-load the account into the cache using a read-only transaction.
+ // AccountProperties only needs public keys, so it works while locked.
+ err = walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ _, err := scopedMgr.AccountProperties(ns, DefaultAccountNum)
+
+ return err
+ })
+ require.NoError(t, err)
+
+ // Attempt to derive addresses while locked. This should succeed
+ // because it uses the cached extended public keys.
+ addrs, _, err := scopedMgr.DeriveAddrs(
+ DefaultAccountNum, ExternalBranch, 0, 5,
+ )
+
+ // Verify success and result count.
+ require.NoError(t, err, "DeriveAddrs should succeed when locked")
+ require.Len(t, addrs, 5)
+}
+
+// TestDeriveAddrsOverflow verifies that DeriveAddrs returns an error when the
+// requested range of child indexes overflows uint32.
+func TestDeriveAddrsOverflow(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup the environment with a zero-valued ScopedKeyManager,
+ // as the overflow check is performed before any other logic or state
+ // access.
+ var s ScopedKeyManager
+
+ // Act: Execute the function under test with a range that triggers a
+ // uint32 overflow (startIndex + count > math.MaxUint32).
+ startIndex := uint32(math.MaxUint32 - 5)
+ count := uint32(10)
+ _, _, err := s.DeriveAddrs(0, 0, startIndex, count)
+
+ // Assert: Verify that the expected overflow error is returned.
+ require.Error(t, err)
+ require.True(t, IsError(err, ErrTooManyAddresses))
+ require.Contains(t, err.Error(), "child index overflow")
+}
+
+// TestDeriveAddr verifies that DeriveAddr correctly derives a single address
+// using in-memory state.
+func TestDeriveAddr(t *testing.T) {
+ t.Parallel()
+
+ // Initialize manager.
+ teardown, db, mgr := setupManager(t)
+ t.Cleanup(teardown)
+
+ // Unlock manager to allow full functionality.
+ err := walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ return mgr.Unlock(ns, privPassphrase)
+ })
+ require.NoError(t, err)
+
+ // Fetch scoped manager.
+ scope := KeyScopeBIP0044
+ acctStore, err := mgr.FetchScopedKeyManager(scope)
+ require.NoError(t, err)
+
+ scopedMgr, ok := acctStore.(*ScopedKeyManager)
+ require.True(t, ok)
+
+ account := uint32(DefaultAccountNum)
+
+ // Pre-load account into cache.
+ err = walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ _, err := scopedMgr.AccountProperties(ns, account)
+
+ return err
+ })
+ require.NoError(t, err)
+
+ // Define test parameters.
+ branch := ExternalBranch
+ index := uint32(0)
+
+ // Call DeriveAddr (In-Memory).
+ addr, script, err := scopedMgr.DeriveAddr(account, branch, index)
+ require.NoError(t, err)
+
+ // Verify against Baseline (DeriveFromKeyPath via DB).
+ err = walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ path := DerivationPath{
+ InternalAccount: account,
+ Account: hdkeychain.HardenedKeyStart + account,
+ Branch: branch,
+ Index: index,
+ }
+
+ managedAddr, err := scopedMgr.DeriveFromKeyPath(ns, path)
+ require.NoError(t, err)
+
+ // Compare address string and script hash.
+ require.Equal(t, managedAddr.Address().String(),
+ addr.String())
+ require.Equal(t, managedAddr.Address().ScriptAddress(),
+ addr.ScriptAddress())
+
+ // Verify returned script matches expected P2PKH script.
+ expectedScript, _ := txscript.PayToAddrScript(
+ managedAddr.Address(),
+ )
+ require.Equal(t, expectedScript, script)
+
+ return nil
+ })
+ require.NoError(t, err)
+}
diff --git a/wallet/account_manager.go b/wallet/account_manager.go
new file mode 100644
index 0000000000..7c32d88d8f
--- /dev/null
+++ b/wallet/account_manager.go
@@ -0,0 +1,1059 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// Package wallet implements the account management for the wallet.
+//
+// TODO(yy): bring wrapcheck back when implementing the `Store` interface.
+//
+//nolint:wrapcheck
+package wallet
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/netparams"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+// AccountManager provides a high-level interface for managing wallet
+// accounts.
+//
+// # Account Derivation
+//
+// The wallet uses a hierarchical deterministic (HD) key generation scheme based
+// on BIP-44. Addresses are derived from a path with the following structure:
+//
+// m / purpose' / coin_type' / account' / change / address_index
+//
+// The AccountManager abstracts this complexity by mapping a human-readable
+// name to the cryptographic `account'` index within a given KeyScope.
+//
+// # Key Scopes
+//
+// The `purpose'` and `coin_type'` fields of the derivation path are defined by
+// a waddrmgr.KeyScope. This allows the wallet to manage different kinds of
+// accounts (and address types) simultaneously. The wallet initializes a set of
+// default scopes upon creation:
+// - KeyScopeBIP0044: For legacy P2PKH addresses.
+// - KeyScopeBIP0049Plus: For P2WPKH addresses nested in P2SH (NP2WKH).
+// - KeyScopeBIP0084: For native SegWit v0 P2WPKH addresses.
+// - KeyScopeBIP0086: For native Taproot v1 P2TR addresses.
+//
+// # Account Names and Reserved Accounts
+//
+// An account name is a human-readable identifier that is unique *within its
+// KeyScope*. The wallet initializes two special, reserved accounts:
+// - "default": The first user-created account (account number 0). This
+// account is created for each of the default key scopes and CAN be renamed.
+// - "imported": A special account that holds all individually imported keys.
+// This account is global and CANNOT be renamed.
+type AccountManager interface {
+ // NewAccount creates a new account for a given key scope and name. The
+ // provided name must be unique within that key scope.
+ NewAccount(ctx context.Context, scope waddrmgr.KeyScope, name string) (
+ *waddrmgr.AccountProperties, error)
+
+ // ListAccounts returns a list of all accounts managed by the wallet.
+ ListAccounts(ctx context.Context) (*AccountsResult, error)
+
+ // ListAccountsByScope returns a list of all accounts for a given key
+ // scope.
+ ListAccountsByScope(ctx context.Context, scope waddrmgr.KeyScope) (
+ *AccountsResult, error)
+
+ // ListAccountsByName searches for accounts with the given name across
+ // all key scopes. Because names are not globally unique, this may
+ // return multiple results.
+ ListAccountsByName(ctx context.Context, name string) (
+ *AccountsResult, error)
+
+ // GetAccount returns the properties for a specific account, looked up
+ // by its key scope and unique name within that scope.
+ GetAccount(ctx context.Context, scope waddrmgr.KeyScope, name string) (
+ *AccountResult, error)
+
+ // RenameAccount renames an existing account. To uniquely identify the
+ // account, the key scope must be provided. The new name must be unique
+ // within that same key scope. The reserved "imported" account cannot
+ // be renamed.
+ RenameAccount(ctx context.Context, scope waddrmgr.KeyScope,
+ oldName string, newName string) error
+
+ // Balance returns the balance for a specific account, identified by its
+ // scope and name, for a given number of required confirmations.
+ Balance(ctx context.Context, conf uint32, scope waddrmgr.KeyScope,
+ name string) (btcutil.Amount, error)
+
+ // ImportAccount imports an account from an extended public or private
+ // key. The key scope is derived from the version bytes of the
+ // extended key. The account name must be unique within the derived
+ // scope. If dryRun is true, the import is validated but not persisted.
+ ImportAccount(ctx context.Context, name string,
+ accountKey *hdkeychain.ExtendedKey,
+ masterKeyFingerprint uint32, addrType waddrmgr.AddressType,
+ dryRun bool) (*waddrmgr.AccountProperties, error)
+}
+
+// A compile time check to ensure that Wallet implements the interface.
+var _ AccountManager = (*Wallet)(nil)
+
+// NewAccount creates the next account and returns its account number. The name
+// must be unique under the kep scope. In order to support automatic seed
+// restoring, new accounts may not be created when all of the previous 100
+// accounts have no transaction history (this is a deviation from the BIP0044
+// spec, which allows no unused account gaps).
+func (w *Wallet) NewAccount(_ context.Context, scope waddrmgr.KeyScope,
+ name string) (*waddrmgr.AccountProperties, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ // Validate that the scope manager can add this new account.
+ err = manager.CanAddAccount()
+ if err != nil {
+ return nil, err
+ }
+
+ var props *waddrmgr.AccountProperties
+
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // Create a new account under the current key scope.
+ accNum, err := manager.NewAccount(addrmgrNs, name)
+ if err != nil {
+ return err
+ }
+
+ // Get the account's properties.
+ props, err = manager.AccountProperties(addrmgrNs, accNum)
+
+ return err
+ })
+
+ return props, err
+}
+
+// AccountResult is the result of a ListAccounts query.
+type AccountResult struct {
+ // AccountProperties is the account's properties.
+ waddrmgr.AccountProperties
+
+ // TotalBalance is the total balance of the account.
+ TotalBalance btcutil.Amount
+}
+
+// AccountsResult is the result of a ListAccounts query. It contains a list of
+// accounts and the current block height and hash.
+type AccountsResult struct {
+ // Accounts is a list of accounts.
+ Accounts []AccountResult
+
+ // CurrentBlockHash is the hash of the current block.
+ CurrentBlockHash chainhash.Hash
+
+ // CurrentBlockHeight is the height of the current block.
+ CurrentBlockHeight int32
+}
+
+// ListAccounts returns a list of all accounts for the wallet, including those
+// with a zero balance. The current chain tip is included in the result for
+// reference.
+//
+// The function calculates balances by first creating a comprehensive map of
+// balances for all accounts that currently own UTXOs. It then iterates through
+// all known accounts across all key scopes, retrieving their properties and
+// assigning the pre-calculated balance. Accounts with no UTXOs will correctly
+// be assigned a zero balance.
+//
+// The time complexity of this method is O(U*logA + A), where U is the number of
+// UTXOs and A is the number of accounts in the wallet. A potential future
+// improvement is to make the balance calculation optional.
+func (w *Wallet) ListAccounts(_ context.Context) (*AccountsResult, error) {
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // Get all active key scope managers to iterate through all available
+ // scopes.
+ scopes := w.addrStore.ActiveScopedKeyManagers()
+
+ var accounts []AccountResult
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ // First, build a map of balances for all accounts that own at
+ // least one UTXO. This is done by iterating through the UTXO
+ // set and aggregating the values by account.
+ scopedBalances, err := w.fetchAccountBalances(tx)
+ if err != nil {
+ return err
+ }
+
+ // Now, iterate through all key scopes to assemble the final
+ // list of accounts with their properties and balances.
+ for _, scopeMgr := range scopes {
+ scope := scopeMgr.Scope()
+ accountBalances := scopedBalances[scope]
+
+ // For the current scope, retrieve the properties for
+ // each account and combine them with the
+ // pre-calculated balances.
+ scopedAccounts, err := listAccountsWithBalances(
+ scopeMgr, addrmgrNs, accountBalances,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Append the accounts from this scope to the final
+ // list.
+ accounts = append(accounts, scopedAccounts...)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Include the wallet's current sync state in the result to provide a
+ // point-in-time reference for the balances.
+ syncBlock := w.addrStore.SyncedTo()
+
+ return &AccountsResult{
+ Accounts: accounts,
+ CurrentBlockHash: syncBlock.Hash,
+ CurrentBlockHeight: syncBlock.Height,
+ }, nil
+}
+
+// ListAccountsByScope returns a list of all accounts for a given key scope,
+// including those with a zero balance. The current chain tip is included for
+// reference.
+//
+// The function first fetches the balances for all accounts within the given
+// scope by iterating over the wallet's UTXO set. It then retrieves the
+// properties for each account in that scope and combines them with the
+// pre-calculated balances.
+//
+// The time complexity of this method is O(U*logA + A), where U is the number of
+// UTXOs and A is the number of accounts in the wallet.
+func (w *Wallet) ListAccountsByScope(_ context.Context,
+ scope waddrmgr.KeyScope) (*AccountsResult, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // First, we'll fetch the scoped key manager for the given scope. This
+ // manager will be used to list the accounts.
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var accounts []AccountResult
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ // Calculate the balances for all accounts, but only for the
+ // key scope we are interested in.
+ scopedBalances, err := w.fetchAccountBalances(
+ tx, withScope(scope),
+ )
+ if err != nil {
+ return err
+ }
+
+ // Now, retrieve the properties for each account in the scope
+ // and combine them with the balances calculated above.
+ accounts, err = listAccountsWithBalances(
+ manager, addrmgrNs, scopedBalances[scope],
+ )
+
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Include the wallet's current sync state in the result.
+ syncBlock := w.addrStore.SyncedTo()
+
+ return &AccountsResult{
+ Accounts: accounts,
+ CurrentBlockHash: syncBlock.Hash,
+ CurrentBlockHeight: syncBlock.Height,
+ }, nil
+}
+
+// ListAccountsByName returns a list of all accounts that have a given name.
+// Since account names are only unique within a key scope, this can return
+// multiple accounts. The current chain tip is included for reference.
+//
+// The function first calculates the balances for any accounts matching the
+// given name, and then iterates through all key scopes to find and retrieve
+// the properties of those accounts.
+//
+// The time complexity of this method is O(U*logA), where U is the number of
+// UTXOs and logA is the cost of an account lookup.
+func (w *Wallet) ListAccountsByName(_ context.Context,
+ name string) (*AccountsResult, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ scopes := w.addrStore.ActiveScopedKeyManagers()
+
+ var accounts []AccountResult
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ // First, calculate the balances for any accounts that match the
+ // given name. This is efficient as it iterates over the UTXO
+ // set, not accounts.
+ scopedBalances, err := w.fetchAccountBalances(tx)
+ if err != nil {
+ return err
+ }
+
+ // Now, find all accounts that match the given name by iterating
+ // through all active scopes.
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ for _, scopeMgr := range scopes {
+ // Look up the account number for the given name in the
+ // current scope.
+ accNum, err := scopeMgr.LookupAccount(addrmgrNs, name)
+ if err != nil {
+ // If the account is not found in this scope,
+ // we can safely continue to the next one.
+ if waddrmgr.IsError(
+ err, waddrmgr.ErrAccountNotFound) {
+
+ continue
+ }
+
+ return err
+ }
+
+ // Retrieve the account's properties.
+ props, err := scopeMgr.AccountProperties(
+ addrmgrNs, accNum,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Get the pre-calculated balance for this account. If
+ // the account has no balance, it will be zero.
+ var balance btcutil.Amount
+
+ balances, ok := scopedBalances[scopeMgr.Scope()]
+ if ok {
+ balance = balances[accNum]
+ }
+
+ accounts = append(accounts, AccountResult{
+ AccountProperties: *props,
+ TotalBalance: balance,
+ })
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ syncBlock := w.addrStore.SyncedTo()
+
+ return &AccountsResult{
+ Accounts: accounts,
+ CurrentBlockHash: syncBlock.Hash,
+ CurrentBlockHeight: syncBlock.Height,
+ }, nil
+}
+
+// GetAccount returns the account for a given account name and key scope.
+//
+// The function first looks up the account's properties and then calculates its
+// balance by iterating over the wallet's UTXO set.
+//
+// The time complexity of this method is O(U*logA), where U is the number of
+// UTXOs and logA is the cost of an account lookup.
+func (w *Wallet) GetAccount(_ context.Context, scope waddrmgr.KeyScope,
+ name string) (*AccountResult, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var account *AccountResult
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ // Look up the account number for the given name and scope. This
+ // is a fast, indexed lookup.
+ accNum, err := manager.LookupAccount(addrmgrNs, name)
+ if err != nil {
+ return err
+ }
+
+ // Retrieve the static properties for the account.
+ props, err := manager.AccountProperties(addrmgrNs, accNum)
+ if err != nil {
+ return err
+ }
+
+ account = &AccountResult{
+ AccountProperties: *props,
+ }
+
+ // Calculate the balance for this specific account by fetching
+ // the UTXOs that belong to it.
+ scopedBalances, err := w.fetchAccountBalances(
+ tx, withScope(scope),
+ )
+ if err != nil {
+ return err
+ }
+
+ // Assign the balance to the account result. If the account has
+ // no UTXOs, the balance will be zero.
+ if balances, ok := scopedBalances[scope]; ok {
+ if balance, ok := balances[accNum]; ok {
+ account.TotalBalance = balance
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return account, nil
+}
+
+// RenameAccount renames an existing account. The new name must be unique within
+// the same key scope. The reserved "imported" account cannot be renamed.
+//
+// The time complexity of this method is dominated by the database lookup for
+// the old account name.
+func (w *Wallet) RenameAccount(_ context.Context, scope waddrmgr.KeyScope,
+ oldName, newName string) error {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return err
+ }
+
+ // Validate the new account name to ensure it meets the required
+ // criteria.
+ err = waddrmgr.ValidateAccountName(newName)
+ if err != nil {
+ return err
+ }
+
+ return walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // Look up the account number for the given name. This is
+ // required to perform the rename operation.
+ accNum, err := manager.LookupAccount(addrmgrNs, oldName)
+ if err != nil {
+ return err
+ }
+
+ // Perform the rename operation in the address manager.
+ return manager.RenameAccount(addrmgrNs, accNum, newName)
+ })
+}
+
+// Balance returns the balance for a specific account, identified by its scope
+// and name, for a given number of required confirmations.
+//
+// The function first looks up the account number and then iterates through all
+// unspent transaction outputs (UTXOs), summing the values of those that belong
+// to the account and meet the required number of confirmations.
+//
+// The time complexity of this method is O(U*logA), where U is the number of
+// UTXOs and logA is the cost of an account lookup.
+func (w *Wallet) Balance(_ context.Context, conf uint32,
+ scope waddrmgr.KeyScope, name string) (btcutil.Amount, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return 0, err
+ }
+
+ var balance btcutil.Amount
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Look up the account number for the given name and scope.
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return err
+ }
+
+ accNum, err := manager.LookupAccount(addrmgrNs, name)
+ if err != nil {
+ return err
+ }
+
+ // Iterate through all unspent outputs and sum the balances for
+ // the addresses that belong to the target account.
+ syncBlock := w.addrStore.SyncedTo()
+
+ utxos, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+
+ for _, utxo := range utxos {
+ // Skip any UTXOs that have not yet reached the required
+ // number of confirmations.
+ if !hasMinConfs(conf, utxo.Height, syncBlock.Height) {
+ continue
+ }
+
+ balance += w.balanceForUTXO(
+ addrmgrNs, scope, accNum, utxo,
+ )
+ }
+
+ return nil
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ return balance, nil
+}
+
+// balanceForUTXO is a helper function for Balance that calculates the balance
+// of a single UTXO if it belongs to the target account.
+func (w *Wallet) balanceForUTXO(addrmgrNs walletdb.ReadBucket,
+ scope waddrmgr.KeyScope, accNum uint32,
+ utxo wtxmgr.Credit) btcutil.Amount {
+
+ // Extract the address from the UTXO's public key script.
+ addr := extractAddrFromPKScript(
+ utxo.PkScript, w.cfg.ChainParams,
+ )
+ if addr == nil {
+ return 0
+ }
+
+ // Look up the account that owns the address.
+ addrScope, addrAcc, err := w.addrStore.AddrAccount(addrmgrNs, addr)
+ if err != nil {
+ // Ignore addresses that are not found in the wallet.
+ return 0
+ }
+
+ // If the address belongs to the target account, add the UTXO's value
+ // to the total balance.
+ if addrScope.Scope() == scope && addrAcc == accNum {
+ return utxo.Amount
+ }
+
+ return 0
+}
+
+// ImportAccount imports an account from an extended public or private key. The
+// key scope is derived from the version bytes of the extended key. The account
+// name must be unique within the derived scope. If dryRun is true, the import
+// is validated but not persisted.
+//
+// The time complexity of this method is dominated by the database lookup to
+// ensure the account name is unique within the scope.
+func (w *Wallet) ImportAccount(ctx context.Context,
+ name string, accountKey *hdkeychain.ExtendedKey,
+ masterKeyFingerprint uint32, addrType waddrmgr.AddressType,
+ dryRun bool) (*waddrmgr.AccountProperties, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ return w.importAccountInternal(
+ ctx, name, accountKey, masterKeyFingerprint, addrType, dryRun,
+ )
+}
+
+// importAccountInternal is the internal implementation of ImportAccount,
+// allowing callers (like Manager.Create) to bypass the started check.
+//
+// TODO(yy): we will move the db operation to a dedicated method, so we can
+// ignore cyclop for now.
+//
+//nolint:cyclop
+func (w *Wallet) importAccountInternal(_ context.Context,
+ name string, accountKey *hdkeychain.ExtendedKey,
+ masterKeyFingerprint uint32, addrType waddrmgr.AddressType,
+ dryRun bool) (*waddrmgr.AccountProperties, error) {
+
+ // Ensure we have a valid account public key. We require an account-level
+ // key (depth 3) to properly manage the derivation path.
+ err := validateExtendedPubKey(accountKey, true, w.cfg.ChainParams)
+ if err != nil {
+ return nil, err
+ }
+
+ // Determine what key scope the account public key should belong to and
+ // whether it should use a custom address schema. This is inferred from
+ // the key's HD version bytes.
+ keyScope, addrSchema, err := keyScopeFromPubKey(accountKey, &addrType)
+ if err != nil {
+ return nil, err
+ }
+
+ var props *waddrmgr.AccountProperties
+
+ // We'll perform the import within a database update transaction to ensure
+ // atomicity. If dryRun is enabled, we'll return a special error at the end
+ // to trigger a rollback.
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // Check if a manager for this key scope already exists. If not, we'll
+ // create a new one using the inferred schema.
+ scopedMgr, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ scopedMgr, err = w.addrStore.NewScopedKeyManager(
+ ns, keyScope, *addrSchema,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ // Create the new watching-only account using the provided key. Since we
+ // only have the public key, the wallet won't be able to sign for this
+ // account unless the private key is also provided later.
+ account, err := scopedMgr.NewAccountWatchingOnly(
+ ns, name, accountKey, masterKeyFingerprint, addrSchema,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Retrieve the properties for the newly created account.
+ props, err = scopedMgr.AccountProperties(ns, account)
+ if !dryRun {
+ return err
+ }
+
+ // If this is a dry-run, we'll generate a few addresses to simulate the
+ // import process and then roll back.
+ props, err = importAccountDryRun(ns, props, scopedMgr)
+ if err != nil {
+ return err
+ }
+
+ // Make sure we always roll back the dry-run transaction by returning an
+ // error here.
+ return walletdb.ErrDryRunRollBack
+ })
+
+ // If this was a dry-run, we ignore the rollback error.
+ if err != nil && !dryRun && !errors.Is(err, walletdb.ErrDryRunRollBack) {
+ return nil, err
+ }
+
+ return props, nil
+}
+
+// importAccountDryRun simulates an account import by generating a single
+// address for both the internal and external derivation branches. This ensures
+// that the provided account key is valid and can be used to derive addresses.
+// The changes made during this simulation are rolled back by the caller.
+func importAccountDryRun(ns walletdb.ReadWriteBucket,
+ props *waddrmgr.AccountProperties, scopedMgr waddrmgr.AccountStore) (
+ *waddrmgr.AccountProperties, error) {
+
+ // The importAccount method above will cache the imported account within the
+ // scoped manager. Since this is a dry-run attempt, we'll want to invalidate
+ // the cache for it.
+ defer scopedMgr.InvalidateAccountCache(props.AccountNumber)
+
+ _, err := scopedMgr.NextExternalAddresses(ns, props.AccountNumber, 1)
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = scopedMgr.NextInternalAddresses(ns, props.AccountNumber, 1)
+ if err != nil {
+ return nil, err
+ }
+
+ // Refresh the account's properties after generating the addresses.
+ props, err = scopedMgr.AccountProperties(ns, props.AccountNumber)
+ if err != nil {
+ return nil, err
+ }
+
+ return props, nil
+}
+
+// validateExtendedPubKey ensures a sane derived public key is provided.
+func validateExtendedPubKey(pubKey *hdkeychain.ExtendedKey,
+ isAccountKey bool, chainParams *chaincfg.Params) error {
+
+ // Private keys are not allowed.
+ if pubKey.IsPrivate() {
+ return fmt.Errorf("%w: private keys cannot be imported",
+ ErrInvalidAccountKey)
+ }
+
+ // The public key must have a version corresponding to the current
+ // chain.
+ if !isPubKeyForNet(pubKey, chainParams) {
+ return fmt.Errorf("%w: expected extended public key for current "+
+ "network %v", ErrInvalidAccountKey, chainParams.Name)
+ }
+
+ // Verify the extended public key's depth and child index based on
+ // whether it's an account key or not.
+ if isAccountKey {
+ if pubKey.Depth() != accountPubKeyDepth {
+ return fmt.Errorf("%w: must be of the form "+
+ "m/purpose'/coin_type'/account'", ErrInvalidAccountKey)
+ }
+
+ if pubKey.ChildIndex() < hdkeychain.HardenedKeyStart {
+ return fmt.Errorf("%w: must be hardened", ErrInvalidAccountKey)
+ }
+
+ return nil
+ }
+
+ if pubKey.Depth() != pubKeyDepth {
+ return fmt.Errorf("%w: must be of the form "+
+ "m/purpose'/coin_type'/account'/change/address_index",
+ ErrInvalidAccountKey)
+ }
+
+ if pubKey.ChildIndex() >= hdkeychain.HardenedKeyStart {
+ return fmt.Errorf("%w: must not be hardened", ErrInvalidAccountKey)
+ }
+
+ return nil
+}
+
+// isPubKeyForNet determines if the given public key is for the current network
+// the wallet is operating under.
+//
+// Ignore exhaustive linter as the `wire.SigNet` is covered by `SigNetWire`.
+//
+//nolint:exhaustive,cyclop
+func isPubKeyForNet(pubKey *hdkeychain.ExtendedKey,
+ chainParams *chaincfg.Params) bool {
+
+ version := waddrmgr.HDVersion(binary.BigEndian.Uint32(pubKey.Version()))
+ switch chainParams.Net {
+ case wire.MainNet:
+ return version == waddrmgr.HDVersionMainNetBIP0044 ||
+ version == waddrmgr.HDVersionMainNetBIP0049 ||
+ version == waddrmgr.HDVersionMainNetBIP0084
+
+ case wire.TestNet, wire.TestNet3, wire.TestNet4,
+ netparams.SigNetWire(chainParams):
+
+ return version == waddrmgr.HDVersionTestNetBIP0044 ||
+ version == waddrmgr.HDVersionTestNetBIP0049 ||
+ version == waddrmgr.HDVersionTestNetBIP0084
+
+ // For simnet, we'll also allow the mainnet versions since simnet
+ // doesn't have defined versions for some of our key scopes, and the
+ // mainnet versions are usually used as the default regardless of the
+ // network/key scope.
+ case wire.SimNet:
+ return version == waddrmgr.HDVersionSimNetBIP0044 ||
+ version == waddrmgr.HDVersionMainNetBIP0049 ||
+ version == waddrmgr.HDVersionMainNetBIP0084
+
+ default:
+ return false
+ }
+}
+
+// extractAddrFromPKScript extracts an address from a public key script. If the
+// script cannot be parsed or does not contain any addresses, it returns nil.
+//
+// The address.Address is an interface that abstracts over different address
+// types. Returning the interface is idiomatic in this context.
+//
+//nolint:ireturn
+func extractAddrFromPKScript(pkScript []byte,
+ chainParams *chaincfg.Params) address.Address {
+
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ pkScript, chainParams,
+ )
+ if err != nil {
+ // We'll log the error and return nil to prevent a single
+ // un-parsable script from failing a larger operation.
+ log.Errorf("Unable to parse pkscript: %v", err)
+ return nil
+ }
+
+ // This can happen for scripts that don't resolve to a standard address,
+ // such as OP_RETURN outputs. We can safely ignore these.
+ if len(addrs) == 0 {
+ return nil
+ }
+
+ // TODO(yy): For bare multisig outputs, ExtractPkScriptAddrs can
+ // return more than one address. Currently, we are only considering
+ // the first address, which could lead to incorrect balance
+ // attribution. However, since bare multisig is rare and modern
+ // wallets almost exclusively use P2SH or P2WSH for multisig (which
+ // are correctly handled as a single address), this is a low-priority
+ // issue.
+ return addrs[0]
+}
+
+// accountFilter is an internal struct used to specify filters for account
+// balance queries.
+type accountFilter struct {
+ scope *waddrmgr.KeyScope
+}
+
+// filterOption is a functional option type for account filtering.
+type filterOption func(*accountFilter)
+
+// withScope is a filter option to limit account queries to a specific key
+// scope.
+func withScope(scope waddrmgr.KeyScope) filterOption {
+ return func(f *accountFilter) {
+ f.scope = &scope
+ }
+}
+
+// scopedBalances is a type alias for a map of key scopes to a map of account
+// numbers to their total balance.
+type scopedBalances map[waddrmgr.KeyScope]map[uint32]btcutil.Amount
+
+// fetchAccountBalances creates a nested map of account balances, keyed by scope
+// and account number.
+//
+// This function is a core component of the wallet's balance calculation
+// logic. It is designed to be efficient, especially for wallets with a large
+// number of addresses.
+//
+// Design Rationale:
+// The primary performance consideration is the trade-off between iterating
+// through all Unspent Transaction Outputs (UTXOs) versus iterating through all
+// derived addresses for all accounts. A mature wallet may have millions of used
+// addresses, but a relatively small set of UTXOs. Therefore, this function is
+// optimized for this common case.
+//
+// The algorithm works as follows:
+// 1. Make a single pass over all UTXOs in the wallet.
+// 2. For each UTXO, look up the address and its corresponding account.
+// 3. Aggregate the UTXO values into a map of balances per account.
+//
+// This approach avoids iterating through a potentially massive number of
+// addresses and performing a database lookup for each one to check for a
+// balance. Instead, it starts with the smaller, known set of UTXOs and works
+// backward to the accounts.
+//
+// Filters:
+// The function's behavior can be customized by passing one or more filterOption
+// functions. This allows the caller to restrict the balance calculation to:
+// - A specific key scope (withScope).
+//
+// If no filters are provided, balances for all accounts across all scopes will
+// be fetched.
+//
+// TODO(yy): With a future SQL backend, this entire function could be
+// replaced by a single, more efficient query. By adding `account_id` and
+// `key_scope` columns to the `outputs` table, we could perform a direct
+// aggregation in the database, like:
+// `SELECT key_scope, account_id, SUM(value) FROM outputs
+// WHERE is_spent = false GROUP BY key_scope, account_id;`.
+// This would be significantly faster as the database is optimized for
+// these types of operations.
+//
+// TODO(yy): The current UTXO-first approach is optimal for mature wallets where
+// the number of addresses greatly exceeds the number of UTXOs. For new wallets
+// or accounts, an address-first approach might be more efficient. A future
+// improvement could be to dynamically choose the strategy based on the relative
+// counts of addresses and UTXOs for the accounts in question.
+func (w *Wallet) fetchAccountBalances(tx walletdb.ReadTx,
+ opts ...filterOption) (scopedBalances, error) {
+
+ // Apply the filter options.
+ filter := &accountFilter{}
+ for _, opt := range opts {
+ opt(filter)
+ }
+
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // First, fetch all unspent outputs.
+ utxos, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return nil, err
+ }
+
+ // Now, create the nested map to hold the balances.
+ scopedBalances := make(scopedBalances)
+
+ // Iterate through all UTXOs, mapping them back to their owning account
+ // to aggregate the total balance for each.
+ for _, utxo := range utxos {
+ addr := extractAddrFromPKScript(
+ utxo.PkScript, w.cfg.ChainParams,
+ )
+ if addr == nil {
+ // This can happen for non-standard script types.
+ continue
+ }
+
+ // Now that we have the address, we'll look up which account it
+ // belongs to.
+ scope, accNum, err := w.addrStore.AddrAccount(addrmgrNs, addr)
+ if err != nil {
+ log.Errorf("Unable to query account using address %v: "+
+ "%v", addr, err)
+
+ continue
+ }
+
+ // If a scope filter was provided, apply it now.
+ if filter.scope != nil {
+ if scope.Scope() != *filter.scope {
+ continue
+ }
+ }
+
+ // We'll use a nested map to store balances. If this is the
+ // first time we've seen this key scope, we'll need to
+ // initialize the inner map.
+ keyScope := scope.Scope()
+ if _, ok := scopedBalances[keyScope]; !ok {
+ scopedBalances[keyScope] = make(
+ map[uint32]btcutil.Amount,
+ )
+ }
+
+ // Finally, we'll add the UTXO's value to the account's
+ // balance.
+ scopedBalances[keyScope][accNum] += utxo.Amount
+ }
+
+ return scopedBalances, nil
+}
+
+// listAccountsWithBalances is a helper function that iterates through all
+// accounts in a given scope, fetches their properties, and combines them with
+// the provided account balances.
+//
+// This function is designed to be called after the balances for all relevant
+// accounts have already been computed by a function like fetchAccountBalances.
+// It serves as the final step to assemble the complete AccountResult objects.
+//
+// The function operates as follows:
+// 1. It determines the last account number for the given scope.
+// 2. It iterates from account number 0 to the last account.
+// 3. For each account, it retrieves its properties from the database.
+// 4. It looks up the pre-calculated balance from the accountBalances map.
+// 5. It constructs an AccountResult object with both the properties and the
+// balance.
+//
+// This separation of concerns (first calculating all balances, then assembling
+// the results) is a key part of the overall optimization strategy. It ensures
+// that we can efficiently gather all necessary data in distinct phases, rather
+// than mixing database reads and balance calculations in a less efficient
+// manner.
+func listAccountsWithBalances(scopeMgr waddrmgr.AccountStore,
+ addrmgrNs walletdb.ReadBucket,
+ accountBalances map[uint32]btcutil.Amount) ([]AccountResult, error) {
+
+ var accounts []AccountResult
+
+ lastAccount, err := scopeMgr.LastAccount(addrmgrNs)
+ if err != nil {
+ // If the scope has no accounts, we can just return an empty
+ // slice. This is a normal condition and not an error.
+ if waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound) {
+ return nil, nil
+ }
+
+ return nil, err
+ }
+
+ // Iterate through all accounts from 0 to the last known account
+ // number for this scope.
+ for accNum := uint32(0); accNum <= lastAccount; accNum++ {
+ // For each account number, we'll fetch its full set of
+ // properties from the database.
+ props, err := scopeMgr.AccountProperties(addrmgrNs, accNum)
+ if err != nil {
+ return nil, err
+ }
+
+ // We'll look up the pre-calculated balance for this account.
+ // If the account has no UTXOs, it won't be in the map, so
+ // we'll default to a balance of 0.
+ balance, ok := accountBalances[accNum]
+ if !ok {
+ balance = 0
+ }
+
+ // Finally, we'll construct the full account result and add it
+ // to our list.
+ accounts = append(accounts, AccountResult{
+ AccountProperties: *props,
+ TotalBalance: balance,
+ })
+ }
+
+ return accounts, nil
+}
diff --git a/wallet/account_manager_benchmark_test.go b/wallet/account_manager_benchmark_test.go
new file mode 100644
index 0000000000..b7c1d4dea8
--- /dev/null
+++ b/wallet/account_manager_benchmark_test.go
@@ -0,0 +1,804 @@
+package wallet
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkListAccountsByScopeAPI benchmarks ListAccountsByScope API and a
+// deprecated variant of it using same key scope and identical test data across
+// multiple dataset sizes. Test names start with dataset size to group API
+// comparisons for benchstat analysis.
+func BenchmarkListAccountsByScopeAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.Accounts(scopes[0])
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.ListAccountsByScope(
+ b.Context(), scopes[0],
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkListAccountsAPI benchmarks ListAccounts API and a deprecated variant
+// of it using same key scopes and identical test data across multiple dataset
+// sizes. Test names start with dataset size to group API comparisons for
+// benchstat analysis.
+func BenchmarkListAccountsAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = waddrmgr.DefaultKeyScopes
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := listAccountsDeprecated(bw.Wallet)
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.ListAccounts(b.Context())
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkListAccountsByNameAPI benchmarks ListAccountsByName API and a
+// deprecated variant of it using same key scopes and identical test data across
+// multiple dataset sizes. Test names start with dataset size to group API
+// comparisons for benchstat analysis.
+func BenchmarkListAccountsByNameAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = waddrmgr.DefaultKeyScopes
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, _ := generateAccountName(accountGrowth[i], scopes)
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := listAccountsByNameDeprecated(
+ bw.Wallet, accountName,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.ListAccountsByName(
+ b.Context(), accountName,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkNewAccountAPI benchmarks NewAccount API and NextAccount API using
+// identical account creation operations across multiple dataset sizes. Test
+// names start with dataset size to group API comparisons for benchstat
+// analysis.
+func BenchmarkNewAccountAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 10
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ count := 0
+ for b.Loop() {
+ // Generate a unique account name for each
+ // iteration to ensure the idempotent nature of
+ // the benchmark.
+ accountName := fmt.Sprintf("new-account-%d",
+ count)
+
+ _, err := bw.NextAccount(
+ scopes[0], accountName,
+ )
+ require.NoError(b, err)
+
+ count++
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ count := 0
+ for b.Loop() {
+ // Generate a unique account name for each
+ // iteration to ensure the idempotent nature of
+ // the benchmark.
+ accountName := fmt.Sprintf("new-account-%d",
+ count)
+
+ _, err := w.NewAccount(
+ b.Context(), scopes[0], accountName,
+ )
+ require.NoError(b, err)
+
+ count++
+ }
+ })
+ }
+}
+
+// BenchmarkGetAccountAPI benchmarks GetAccount API and a deprecated wrapper API
+// using identical account lookups across multiple dataset sizes. Test names
+// start with dataset size to group API comparisons for benchstat analysis.
+func BenchmarkGetAccountAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, _ := generateAccountName(accountGrowth[i], scopes)
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := getAccountDeprecated(
+ bw.Wallet, scopes[0], accountName,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.GetAccount(
+ b.Context(), scopes[0], accountName,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkRenameAccountAPI benchmarks RenameAccount API and
+// RenameAccountDeprecated API using identical rename operations across multiple
+// dataset sizes. Test names start with dataset size to group API comparisons
+// for benchstat analysis.
+func BenchmarkRenameAccountAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 11
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, accountNumber := generateAccountName(
+ accountGrowth[i], scopes,
+ )
+ newAccountName := accountName + "-renamed"
+
+ name := fmt.Sprintf("%0*d-Accounts", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ count := 0
+ for b.Loop() {
+ newAccountName2 := fmt.Sprintf("%s-%d",
+ newAccountName, count)
+
+ err := w.RenameAccountDeprecated(
+ scopes[0], accountNumber,
+ newAccountName2,
+ )
+ require.NoError(b, err)
+
+ // Rename back to original to keep the benchmark
+ // idempotent.
+ err = w.RenameAccountDeprecated(
+ scopes[0], accountNumber, accountName,
+ )
+ require.NoError(b, err)
+
+ count++
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ newAccountName := accountName + "-renamed"
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ count := 0
+ for b.Loop() {
+ newAccountName2 := fmt.Sprintf("%s-%d",
+ newAccountName, count)
+
+ err := w.RenameAccount(
+ b.Context(), scopes[0], accountName,
+ newAccountName2,
+ )
+ require.NoError(b, err)
+
+ // Rename back to original to keep the benchmark
+ // idempotent.
+ err = w.RenameAccount(
+ b.Context(), scopes[0], newAccountName2,
+ accountName,
+ )
+ require.NoError(b, err)
+
+ count++
+ }
+ })
+ }
+}
+
+// BenchmarkGetBalanceAPI benchmarks Balance API and a deprecated wrapper API
+// using identical balance lookups across multiple dataset sizes. Test names
+// start with dataset size to group API comparisons for benchstat analysis.
+func BenchmarkGetBalanceAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+
+ confirmations = int32(0)
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, _ := generateAccountName(accountGrowth[i], scopes)
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := getBalanceDeprecated(
+ bw.Wallet, scopes[0], accountName,
+ confirmations,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.Balance(
+ b.Context(), uint32(confirmations),
+ scopes[0], accountName,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkImportAccountAPI benchmarks ImportAccount API and
+// ImportAccountDeprecated API using identical account import operations
+// across multiple dataset sizes. Test names start with dataset size to group
+// API comparisons for benchstat analysis.
+func BenchmarkImportAccountAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 10
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ dryRun = false
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountKey, masterFingerprint, addrT := generateTestExtendedKey(
+ b, accountGrowth[i],
+ )
+
+ name := fmt.Sprintf("%0*d-Accounts", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ count := 0
+ for b.Loop() {
+ // Generate a unique account name for each
+ // iteration to ensure the idempotent nature of
+ // the benchmark.
+ accountName := fmt.Sprintf("import-account-%d",
+ count)
+
+ _, err := w.ImportAccountDeprecated(
+ accountName, accountKey,
+ masterFingerprint, &addrT,
+ )
+ require.NoError(b, err)
+
+ count++
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ count := 0
+ for b.Loop() {
+ // Generate a unique account name for each
+ // iteration to ensure the idempotent nature of
+ // the benchmark.
+ accountName := fmt.Sprintf("import-account-%d",
+ count)
+
+ _, err := w.ImportAccount(
+ b.Context(), accountName, accountKey,
+ masterFingerprint, addrT, dryRun,
+ )
+ require.NoError(b, err)
+
+ count++
+ }
+ })
+ }
+}
diff --git a/wallet/account_manager_test.go b/wallet/account_manager_test.go
new file mode 100644
index 0000000000..685835b5a0
--- /dev/null
+++ b/wallet/account_manager_test.go
@@ -0,0 +1,1321 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "encoding/binary"
+ "strings"
+ "testing"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+func hardenedKey(key uint32) uint32 {
+ return key + hdkeychain.HardenedKeyStart
+}
+
+func deriveAcctPubKey(t *testing.T, root *hdkeychain.ExtendedKey,
+ scope waddrmgr.KeyScope, paths ...uint32) *hdkeychain.ExtendedKey {
+
+ t.Helper()
+
+ path := []uint32{hardenedKey(scope.Purpose), hardenedKey(scope.Coin)}
+ path = append(path, paths...)
+
+ var (
+ currentKey = root
+ err error
+ )
+ for _, pathPart := range path {
+ currentKey, err = currentKey.Derive(pathPart)
+ require.NoError(t, err)
+ }
+
+ // The Neuter() method checks the version and doesn't know any
+ // non-standard methods. We need to convert them to standard, neuter,
+ // then convert them back with the target extended public key version.
+ pubVersionBytes := make([]byte, 4)
+ copy(pubVersionBytes, chainParams.HDPublicKeyID[:])
+
+ switch {
+ case strings.HasPrefix(root.String(), "uprv"):
+ binary.BigEndian.PutUint32(pubVersionBytes, uint32(
+ waddrmgr.HDVersionTestNetBIP0049,
+ ))
+
+ case strings.HasPrefix(root.String(), "vprv"):
+ binary.BigEndian.PutUint32(pubVersionBytes, uint32(
+ waddrmgr.HDVersionTestNetBIP0084,
+ ))
+ }
+
+ currentKey, err = currentKey.CloneWithVersion(
+ chainParams.HDPrivateKeyID[:],
+ )
+ require.NoError(t, err)
+ currentKey, err = currentKey.Neuter()
+ require.NoError(t, err)
+ currentKey, err = currentKey.CloneWithVersion(pubVersionBytes)
+ require.NoError(t, err)
+
+ return currentKey
+}
+
+const (
+ // testAccountName is a constant for the account name used in the tests.
+ testAccountName = "test"
+)
+
+// TestNewAccount tests that the NewAccount method works as expected.
+func TestNewAccount(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll start by creating a new account under the BIP0084 scope. We
+ // expect this to succeed.
+ scope := waddrmgr.KeyScopeBIP0084
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ account, err := w.NewAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err, "unable to create new account")
+
+ // The new account should be the first account created, so it should
+ // have an index of 1.
+ require.Equal(t, uint32(1), account.AccountNumber, "expected account 1")
+
+ // We should be able to retrieve the account by its name.
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ account2, err := w.GetAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err, "unable to retrieve account")
+ require.Equal(t, uint32(1), account2.AccountNumber)
+ require.Equal(t, testAccountName, account2.AccountName)
+
+ // We should not be able to create a new account with the same name.
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, testAccountName).
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrDuplicateAccount,
+ }).Once()
+
+ _, err = w.NewAccount(t.Context(), scope, testAccountName)
+ require.Error(t, err, "expected error when creating duplicate account")
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrDuplicateAccount),
+ "expected ErrDuplicateAccount",
+ )
+
+ // We should not be able to create a new account when the wallet is
+ // locked.
+ deps.addrStore.On("Lock").Return(nil).Once()
+
+ err = w.Lock(t.Context())
+ require.NoError(t, err)
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").
+ Return(waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrLocked,
+ }).Once()
+
+ _, err = w.NewAccount(t.Context(), scope, "test2")
+ require.Error(
+ t, err, "expected error when creating account while wallet is "+
+ "locked",
+ )
+}
+
+// TestListAccounts tests that the ListAccounts method works as expected.
+func TestListAccounts(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll start by creating a new account under the BIP0084 scope.
+ scope := waddrmgr.KeyScopeBIP0084
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err, "unable to create new account")
+
+ // Setup expectations for ListAccounts.
+ deps.addrStore.On("ActiveScopedKeyManagers").
+ Return([]waddrmgr.AccountStore{deps.accountManager}).Once()
+
+ deps.accountManager.On("Scope").Return(scope).Once()
+ deps.accountManager.On("LastAccount", mock.Anything).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(0)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }, nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // Now, we'll list all accounts and check that we have the default
+ // account and the new account.
+ accounts, err := w.ListAccounts(t.Context())
+ require.NoError(t, err, "unable to list accounts")
+
+ // We should have two accounts.
+ require.Len(t, accounts.Accounts, 2, "expected two accounts")
+
+ // The first account should be the default account.
+ require.Equal(
+ t, "default", accounts.Accounts[0].AccountName,
+ "expected default account",
+ )
+ require.Equal(
+ t, uint32(0), accounts.Accounts[0].AccountNumber,
+ "expected default account number",
+ )
+ require.Equal(
+ t, btcutil.Amount(0), accounts.Accounts[0].TotalBalance,
+ "expected zero balance for default account",
+ )
+
+ // The new account should also be present.
+ require.Equal(
+ t, testAccountName, accounts.Accounts[1].AccountName,
+ "expected new account",
+ )
+ require.Equal(
+ t, uint32(1), accounts.Accounts[1].AccountNumber,
+ "expected new account number",
+ )
+}
+
+// TestListAccountsByScope tests that the ListAccountsByScope method works as
+// expected.
+func TestListAccountsByScope(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll create two new accounts, one under the BIP0084 scope and one
+ // under the BIP0049 scope.
+ scopeBIP84 := waddrmgr.KeyScopeBIP0084
+ accBIP84Name := "test bip84"
+
+ deps.addrStore.On("FetchScopedKeyManager", scopeBIP84).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, accBIP84Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP84Name,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scopeBIP84, accBIP84Name)
+ require.NoError(t, err)
+
+ scopeBIP49 := waddrmgr.KeyScopeBIP0049Plus
+ accBIP49Name := "test bip49"
+
+ deps.addrStore.On("FetchScopedKeyManager", scopeBIP49).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, accBIP49Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP49Name,
+ }, nil).Once()
+
+ _, err = w.NewAccount(t.Context(), scopeBIP49, accBIP49Name)
+ require.NoError(t, err)
+
+ // Mock expectations for ListAccountsByScope (BIP84).
+ deps.addrStore.On("FetchScopedKeyManager", scopeBIP84).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LastAccount", mock.Anything).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(0)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }, nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP84Name,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // Now, we'll list the accounts for the BIP0084 scope and check that
+ // we only get the default account for that scope and the new account we
+ // created.
+ accountsBIP84, err := w.ListAccountsByScope(t.Context(), scopeBIP84)
+ require.NoError(t, err)
+
+ // We should have two accounts, the default account and the new account.
+ require.Len(t, accountsBIP84.Accounts, 2)
+
+ // The first account should be the default account.
+ require.Equal(t, "default", accountsBIP84.Accounts[0].AccountName)
+ require.Equal(t, uint32(0), accountsBIP84.Accounts[0].AccountNumber)
+
+ // The second account should be the new account.
+ require.Equal(t, accBIP84Name, accountsBIP84.Accounts[1].AccountName)
+ require.Equal(t, uint32(1), accountsBIP84.Accounts[1].AccountNumber)
+
+ // Mock expectations for ListAccountsByScope (BIP49).
+ deps.addrStore.On("FetchScopedKeyManager", scopeBIP49).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LastAccount", mock.Anything).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(0)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }, nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP49Name,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // Now, we'll do the same for the BIP0049 scope.
+ accountsBIP49, err := w.ListAccountsByScope(t.Context(), scopeBIP49)
+ require.NoError(t, err)
+
+ // We should have two accounts, the default account and the new account.
+ require.Len(t, accountsBIP49.Accounts, 2)
+
+ // The first account should be the default account.
+ require.Equal(t, "default", accountsBIP49.Accounts[0].AccountName)
+ require.Equal(t, uint32(0), accountsBIP49.Accounts[0].AccountNumber)
+
+ // The second account should be the new account.
+ require.Equal(t, accBIP49Name, accountsBIP49.Accounts[1].AccountName)
+ require.Equal(t, uint32(1), accountsBIP49.Accounts[1].AccountNumber)
+}
+
+// TestListAccountsByName tests that the ListAccountsByName method works as
+// expected.
+func TestListAccountsByName(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll create two new accounts, one under the BIP0084 scope and one
+ // under the BIP0049 scope.
+ scopeBIP84 := waddrmgr.KeyScopeBIP0084
+ accBIP84Name := "test bip84"
+
+ deps.addrStore.On("FetchScopedKeyManager", scopeBIP84).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, accBIP84Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP84Name,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scopeBIP84, accBIP84Name)
+ require.NoError(t, err)
+
+ scopeBIP49 := waddrmgr.KeyScopeBIP0049Plus
+ accBIP49Name := "test bip49"
+
+ deps.addrStore.On("FetchScopedKeyManager", scopeBIP49).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, accBIP49Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP49Name,
+ }, nil).Once()
+
+ _, err = w.NewAccount(t.Context(), scopeBIP49, accBIP49Name)
+ require.NoError(t, err)
+
+ // Mock expectations for ListAccountsByName (BIP84 name).
+ deps.addrStore.On("ActiveScopedKeyManagers").
+ Return([]waddrmgr.AccountStore{deps.accountManager}).Once()
+ deps.accountManager.On("Scope").Return(scopeBIP84).Maybe()
+ deps.accountManager.On("LookupAccount", mock.Anything, accBIP84Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP84Name,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Times(3)
+
+ // Now, we'll list the accounts for the BIP0084 scope and check that
+ // we only get the default account for that scope and the new account we
+ // created.
+ accountsBIP84, err := w.ListAccountsByName(t.Context(), accBIP84Name)
+ require.NoError(t, err)
+
+ // We should have one account.
+ require.Len(t, accountsBIP84.Accounts, 1)
+
+ // The first account should be the new account.
+ require.Equal(t, accBIP84Name, accountsBIP84.Accounts[0].AccountName)
+ require.Equal(t, uint32(1), accountsBIP84.Accounts[0].AccountNumber)
+
+ // Mock expectations for ListAccountsByName (BIP49 name).
+ deps.addrStore.On("ActiveScopedKeyManagers").
+ Return([]waddrmgr.AccountStore{deps.accountManager}).Once()
+ deps.accountManager.On("Scope").Return(scopeBIP49).Maybe()
+ deps.accountManager.On("LookupAccount", mock.Anything, accBIP49Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: accBIP49Name,
+ }, nil).Once()
+
+ // Now, we'll do the same for the BIP0049 scope.
+ accountsBIP49, err := w.ListAccountsByName(t.Context(), accBIP49Name)
+ require.NoError(t, err)
+
+ // We should have one account.
+ require.Len(t, accountsBIP49.Accounts, 1)
+
+ // The first account should be the new account.
+ require.Equal(t, accBIP49Name, accountsBIP49.Accounts[0].AccountName)
+ require.Equal(t, uint32(1), accountsBIP49.Accounts[0].AccountNumber)
+
+ // Mock expectations for non-existent account.
+ deps.addrStore.On("ActiveScopedKeyManagers").
+ Return([]waddrmgr.AccountStore{deps.accountManager}).Once()
+ deps.accountManager.On("Scope").Return(scopeBIP84).Maybe()
+ deps.accountManager.On("LookupAccount", mock.Anything, "non-existent").
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }).Once()
+
+ // We should get an empty result if we query for a non-existent
+ // account.
+ accounts, err := w.ListAccountsByName(t.Context(), "non-existent")
+ require.NoError(t, err)
+ require.Empty(t, accounts.Accounts)
+}
+
+// TestGetAccount tests that the GetAccount method works as expected.
+func TestGetAccount(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll create a new account under the BIP0084 scope.
+ scope := waddrmgr.KeyScopeBIP0084
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err)
+
+ // Mock expectations for GetAccount (success).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Twice()
+
+ // We should be able to get the new account.
+ account, err := w.GetAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err)
+ require.Equal(t, testAccountName, account.AccountName)
+ require.Equal(t, uint32(1), account.AccountNumber)
+ require.Equal(t, btcutil.Amount(0), account.TotalBalance)
+
+ // Mock expectations for GetAccount (default account).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(0), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(0)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }, nil).Once()
+
+ // We should also be able to get the default account.
+ account, err = w.GetAccount(t.Context(), scope, "default")
+ require.NoError(t, err)
+ require.Equal(t, "default", account.AccountName)
+ require.Equal(t, uint32(0), account.AccountNumber)
+ require.Equal(t, btcutil.Amount(0), account.TotalBalance)
+
+ // Mock expectations for GetAccount (error path).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, "non-existent").
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }).Once()
+
+ // We should get an error when trying to get a non-existent account.
+ _, err = w.GetAccount(t.Context(), scope, "non-existent")
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound),
+ "expected ErrAccountNotFound",
+ )
+}
+
+// TestRenameAccount tests that the RenameAccount method works as expected.
+func TestRenameAccount(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll create a new account under the BIP0084 scope.
+ scope := waddrmgr.KeyScopeBIP0084
+ oldName := "old name"
+ newName := "new name"
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, oldName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: oldName,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scope, oldName)
+ require.NoError(t, err)
+
+ // Mock expectations for RenameAccount.
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, oldName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("RenameAccount", mock.Anything, uint32(1), newName).
+ Return(nil).Once()
+
+ // We should be able to rename the account.
+ err = w.RenameAccount(t.Context(), scope, oldName, newName)
+ require.NoError(t, err)
+
+ // Mock expectations for GetAccount (new name).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, newName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: newName,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // We should be able to get the account by its new name.
+ account, err := w.GetAccount(t.Context(), scope, newName)
+ require.NoError(t, err)
+ require.Equal(t, newName, account.AccountName)
+
+ // Mock expectations for GetAccount (old name - fail).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, oldName).
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }).Once()
+
+ // We should not be able to get the account by its old name.
+ _, err = w.GetAccount(t.Context(), scope, oldName)
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound),
+ "expected ErrAccountNotFound",
+ )
+
+ // Mock expectations for RenameAccount (duplicate name).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, newName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On(
+ "RenameAccount", mock.Anything, uint32(1), "default",
+ ).Return(waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrDuplicateAccount,
+ }).Once()
+
+ // We should not be able to rename an account to an existing name.
+ err = w.RenameAccount(t.Context(), scope, newName, "default")
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrDuplicateAccount),
+ "expected ErrDuplicateAccount",
+ )
+
+ // Mock expectations for RenameAccount (non-existent account).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, "non-existent").
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }).Once()
+
+ // We should not be able to rename a non-existent account.
+ err = w.RenameAccount(t.Context(), scope, "non-existent", "new name 2")
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound),
+ "expected ErrAccountNotFound",
+ )
+}
+
+// TestBalance tests that the Balance method works as expected.
+func TestBalance(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll create a new account under the BIP0084 scope.
+ scope := waddrmgr.KeyScopeBIP0084
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).
+ Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err)
+
+ // Mock expectations for initial balance (0).
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+
+ // The balance should be zero initially.
+ balance, err := w.Balance(t.Context(), 1, scope, testAccountName)
+ require.NoError(t, err)
+ require.Equal(t, btcutil.Amount(0), balance)
+
+ // Now, we'll add a UTXO to the account.
+ mockAddr, _ := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+ pkScript, err := txscript.PayToAddrScript(mockAddr)
+ require.NoError(t, err)
+
+ // Mock expectations for balance with UTXO.
+ deps.txStore.On("UnspentOutputs", mock.Anything).Return([]wtxmgr.Credit{
+ {
+ Amount: 100,
+ PkScript: pkScript,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: 1,
+ },
+ },
+ },
+ }, nil).Once()
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.addrStore.On("AddrAccount", mock.Anything, mockAddr).
+ Return(deps.accountManager, uint32(1), nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("Scope").Return(scope).Once()
+
+ // The balance should now be 100.
+ balance, err = w.Balance(t.Context(), 1, scope, testAccountName)
+ require.NoError(t, err)
+ require.Equal(t, btcutil.Amount(100), balance)
+
+ // Mock expectations for balance of non-existent account.
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, "non-existent").
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }).Once()
+
+ // We should get an error when trying to get the balance of a
+ // non-existent account.
+ _, err = w.Balance(t.Context(), 1, scope, "non-existent")
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound),
+ "expected ErrAccountNotFound",
+ )
+}
+
+// TestImportAccount tests that the ImportAccount works as expected.
+func TestImportAccount(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll start by creating a new account under the BIP0084 scope.
+ scope := waddrmgr.KeyScopeBIP0084
+ addrType := waddrmgr.WitnessPubKey
+ masterPriv := "tprv8ZgxMBicQKsPeWwrFuNjEGTTDSY4mRLwd2KDJAPGa1AY" +
+ "quw38bZqNMSuB3V1Va3hqJBo9Pt8Sx7kBQer5cNMrb8SYquoWPt9" +
+ "Y3BZdhdtUcw"
+ root, err := hdkeychain.NewKeyFromString(masterPriv)
+ require.NoError(t, err)
+ acctPubKey := deriveAcctPubKey(t, root, scope, hardenedKey(0))
+
+ // Mock expectations for ImportAccount.
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("NewAccountWatchingOnly", mock.Anything,
+ testAccountName, acctPubKey, root.ParentFingerprint(),
+ mock.Anything).Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ // We should be able to import the account.
+ props, err := w.ImportAccount(
+ t.Context(), testAccountName, acctPubKey,
+ root.ParentFingerprint(), addrType, false,
+ )
+ require.NoError(t, err)
+ require.Equal(t, testAccountName, props.AccountName)
+
+ // Mock expectations for GetAccount.
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, testAccountName).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: testAccountName,
+ }, nil).Once()
+
+ deps.txStore.On("UnspentOutputs", mock.Anything).
+ Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // We should be able to get the account by its name.
+ _, err = w.GetAccount(t.Context(), scope, testAccountName)
+ require.NoError(t, err)
+
+ // Mock expectations for duplicate ImportAccount.
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAccountWatchingOnly", mock.Anything,
+ testAccountName, acctPubKey, root.ParentFingerprint(),
+ mock.Anything).Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrDuplicateAccount,
+ }).Once()
+
+ // We should not be able to import an account with the same name.
+ _, err = w.ImportAccount(
+ t.Context(), testAccountName, acctPubKey,
+ root.ParentFingerprint(), addrType, false,
+ )
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrDuplicateAccount),
+ "expected ErrDuplicateAccount",
+ )
+
+ // Mock expectations for dry-run.
+ dryRunName := "dry run"
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("NewAccountWatchingOnly", mock.Anything,
+ dryRunName, acctPubKey, root.ParentFingerprint(),
+ mock.Anything).Return(uint32(2), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(2)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 2,
+ AccountName: dryRunName,
+ }, nil).Twice()
+ deps.accountManager.On("InvalidateAccountCache", uint32(2)).Return().Once()
+ deps.accountManager.On("NextExternalAddresses", mock.Anything, uint32(2),
+ uint32(1)).Return([]waddrmgr.ManagedAddress(nil), nil).Once()
+ deps.accountManager.On("NextInternalAddresses", mock.Anything, uint32(2),
+ uint32(1)).Return([]waddrmgr.ManagedAddress(nil), nil).Once()
+
+ // We should be able to do a dry run of the import.
+ _, err = w.ImportAccount(
+ t.Context(), dryRunName, acctPubKey,
+ root.ParentFingerprint(), addrType, true,
+ )
+ require.NoError(t, err)
+
+ // Mock expectations for GetAccount (fail).
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("LookupAccount", mock.Anything, dryRunName).
+ Return(uint32(0), waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }).Once()
+
+ // The account should not have been imported.
+ _, err = w.GetAccount(t.Context(), scope, dryRunName)
+ require.Error(t, err)
+ require.True(
+ t, waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound),
+ "expected ErrAccountNotFound",
+ )
+}
+
+// TestExtractAddrFromPKScript tests that the extractAddrFromPKScript
+// helper function works as expected.
+func TestExtractAddrFromPKScript(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ w.cfg.ChainParams = &chaincfg.MainNetParams
+
+ p2pkhAddr, err := address.DecodeAddress(
+ "17VZNX1SN5NtKa8UQFxwQbFeFc3iqRYhem", w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ p2shAddr, err := address.DecodeAddress(
+ "347N1Thc213QqfYCz3PZkjoJpNv5b14kBd", w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ p2wpkhAddr, err := address.DecodeAddress(
+ "bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4", w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ testCases := []struct {
+ name string
+ script func() []byte
+ addr string
+ }{
+ {
+ name: "p2pkh",
+ script: func() []byte {
+ pkScript, err := txscript.PayToAddrScript(
+ p2pkhAddr,
+ )
+ require.NoError(t, err)
+
+ return pkScript
+ },
+ addr: p2pkhAddr.String(),
+ },
+ {
+ name: "p2sh",
+ script: func() []byte {
+ pkScript, err := txscript.PayToAddrScript(
+ p2shAddr,
+ )
+ require.NoError(t, err)
+
+ return pkScript
+ },
+ addr: p2shAddr.String(),
+ },
+ {
+ name: "p2wpkh",
+ script: func() []byte {
+ pkScript, err := txscript.PayToAddrScript(
+ p2wpkhAddr,
+ )
+ require.NoError(t, err)
+
+ return pkScript
+ },
+ addr: p2wpkhAddr.String(),
+ },
+ {
+ name: "op_return",
+ script: func() []byte {
+ pkScript, err := txscript.NewScriptBuilder().
+ AddOp(txscript.OP_RETURN).
+ AddData([]byte("test")).
+ Script()
+ require.NoError(t, err)
+
+ return pkScript
+ },
+ addr: "",
+ },
+ {
+ name: "invalid script",
+ script: func() []byte { return []byte("invalid") },
+ addr: "",
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ t.Parallel()
+
+ addr := extractAddrFromPKScript(
+ testCase.script(), w.cfg.ChainParams,
+ )
+ if addr == nil {
+ require.Empty(t, testCase.addr)
+ } else {
+ require.Equal(t, testCase.addr, addr.String())
+ }
+ })
+ }
+}
+
+// TestFetchAccountBalances tests that the fetchAccountBalances helper function
+// works as expected.
+func TestFetchAccountBalances(t *testing.T) {
+ t.Parallel()
+
+ // setupTestCase is a helper closure to set up a test case.
+ setupTestCase := func(t *testing.T) (*Wallet, *mockWalletDeps) {
+ t.Helper()
+
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Create accounts.
+ deps.addrStore.On("FetchScopedKeyManager", waddrmgr.KeyScopeBIP0084).
+ Return(deps.accountManager, nil).
+ Once()
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, "acc1-bip84").
+ Return(uint32(1), nil).
+ Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: "acc1-bip84",
+ }, nil).
+ Once()
+
+ _, err := w.NewAccount(
+ t.Context(), waddrmgr.KeyScopeBIP0084, "acc1-bip84",
+ )
+ require.NoError(t, err)
+
+ deps.addrStore.On(
+ "FetchScopedKeyManager", waddrmgr.KeyScopeBIP0049Plus,
+ ).Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, "acc1-bip49").
+ Return(uint32(1), nil).
+ Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: "acc1-bip49",
+ }, nil).
+ Once()
+
+ _, err = w.NewAccount(
+ t.Context(), waddrmgr.KeyScopeBIP0049Plus, "acc1-bip49",
+ )
+ require.NoError(t, err)
+
+ // Create mock addresses for balance mapping.
+ addr84def, _ := address.NewAddressWitnessPubKeyHash(
+ []byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ w.cfg.ChainParams,
+ )
+ addr84acc1, _ := address.NewAddressWitnessPubKeyHash(
+ []byte{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ w.cfg.ChainParams,
+ )
+ addr49acc1, _ := address.NewAddressWitnessPubKeyHash(
+ []byte{3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ w.cfg.ChainParams,
+ )
+
+ // Setup persistent mocks for balance calculation.
+ deps.txStore.On("UnspentOutputs", mock.Anything).Return([]wtxmgr.Credit{
+ {
+ Amount: 100,
+ PkScript: mustPayToAddr(addr84def),
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: 1,
+ },
+ },
+ },
+ {
+ Amount: 200,
+ PkScript: mustPayToAddr(addr84acc1),
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: 1,
+ },
+ },
+ },
+ {
+ Amount: 300,
+ PkScript: mustPayToAddr(addr49acc1),
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: 1,
+ },
+ },
+ },
+ }, nil).Once()
+
+ // addr84def -> Default Account (0)
+ deps.addrStore.On("AddrAccount", mock.Anything, addr84def).
+ Return(deps.accountManager, uint32(0), nil).
+ Once()
+
+ // addr84acc1 -> Account 1 (BIP84)
+ deps.addrStore.On("AddrAccount", mock.Anything, addr84acc1).
+ Return(deps.accountManager, uint32(1), nil).
+ Once()
+
+ // addr49acc1 -> Account 1 (BIP49)
+ // We use a different mock account manager to simulate the
+ // different scope.
+ mockAccountStore49 := &mockAccountStore{}
+ mockAccountStore49.On("Scope").
+ Return(waddrmgr.KeyScopeBIP0049Plus).
+ Maybe() // Called varying times depending on filter
+
+ deps.addrStore.On("AddrAccount", mock.Anything, addr49acc1).
+ Return(mockAccountStore49, uint32(1), nil).
+ Once()
+
+ return w, deps
+ }
+
+ testCases := []struct {
+ name string
+ setup func(t *testing.T, w *Wallet, deps *mockWalletDeps)
+ filters []filterOption
+ expectedBalances scopedBalances
+ }{
+ {
+ name: "no filters",
+ setup: func(t *testing.T, w *Wallet, deps *mockWalletDeps) {
+ t.Helper()
+ // Called twice: once for default, once for acc1.
+ deps.accountManager.On("Scope").
+ Return(waddrmgr.KeyScopeBIP0084).
+ Times(2)
+ },
+ filters: nil,
+ expectedBalances: scopedBalances{
+ waddrmgr.KeyScopeBIP0084: {0: 100, 1: 200},
+ waddrmgr.KeyScopeBIP0049Plus: {1: 300},
+ },
+ },
+ {
+ name: "filter by scope",
+ setup: func(t *testing.T, w *Wallet, deps *mockWalletDeps) {
+ t.Helper()
+ // Called 4 times:
+ // 1. Filter check (def) -> Match
+ // 2. Map key (def)
+ // 3. Filter check (acc1) -> Match
+ // 4. Map key (acc1)
+ deps.accountManager.On("Scope").
+ Return(waddrmgr.KeyScopeBIP0084).
+ Times(4)
+ },
+ filters: []filterOption{
+ withScope(waddrmgr.KeyScopeBIP0084),
+ },
+ expectedBalances: scopedBalances{
+ waddrmgr.KeyScopeBIP0084: {0: 100, 1: 200},
+ },
+ },
+ {
+ name: "account with no balance",
+ setup: func(t *testing.T, w *Wallet, deps *mockWalletDeps) {
+ t.Helper()
+
+ // Expect 2 Scope calls for the existing accounts with UTXOs.
+ deps.accountManager.On("Scope").
+ Return(waddrmgr.KeyScopeBIP0084).
+ Times(2)
+
+ deps.addrStore.On(
+ "FetchScopedKeyManager", waddrmgr.KeyScopeBIP0084,
+ ).Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On(
+ "NewAccount", mock.Anything, "no-balance",
+ ).Return(uint32(2), nil).Once()
+ deps.accountManager.On(
+ "AccountProperties", mock.Anything, uint32(2),
+ ).Return(&waddrmgr.AccountProperties{
+ AccountNumber: 2,
+ AccountName: "no-balance",
+ }, nil).Once()
+
+ _, err := w.NewAccount(
+ t.Context(),
+ waddrmgr.KeyScopeBIP0084, "no-balance",
+ )
+ require.NoError(t, err)
+ },
+ filters: nil,
+ expectedBalances: scopedBalances{
+ waddrmgr.KeyScopeBIP0084: {0: 100, 1: 200},
+ waddrmgr.KeyScopeBIP0049Plus: {1: 300},
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, deps := setupTestCase(t)
+
+ if tc.setup != nil {
+ tc.setup(t, w, deps)
+ }
+
+ var balances scopedBalances
+
+ err := walletdb.View(
+ w.cfg.DB, func(tx walletdb.ReadTx) error {
+ var err error
+
+ balances, err = w.fetchAccountBalances(
+ tx, tc.filters...,
+ )
+
+ return err
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedBalances, balances)
+ })
+ }
+}
+
+func mustPayToAddr(addr address.Address) []byte {
+ script, _ := txscript.PayToAddrScript(addr)
+ return script
+}
+
+// TestListAccountsWithBalances tests that the listAccountsWithBalances helper
+// function works as expected.
+func TestListAccountsWithBalances(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // We'll create two new accounts under the BIP0084 scope to have a
+ // predictable state.
+ scope := waddrmgr.KeyScopeBIP0084
+ acc1Name := "test account"
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, acc1Name).
+ Return(uint32(1), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: acc1Name,
+ }, nil).Once()
+
+ _, err := w.NewAccount(t.Context(), scope, acc1Name)
+ require.NoError(t, err)
+
+ acc2Name := "no balance account"
+
+ deps.addrStore.On("FetchScopedKeyManager", scope).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("CanAddAccount").Return(nil).Once()
+ deps.accountManager.On("NewAccount", mock.Anything, acc2Name).
+ Return(uint32(2), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(2)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 2,
+ AccountName: acc2Name,
+ }, nil).Once()
+
+ _, err = w.NewAccount(t.Context(), scope, acc2Name)
+ require.NoError(t, err)
+
+ // We'll now create a balance map for some of the accounts. We
+ // intentionally leave out the second new account to test the zero
+ // balance case.
+ balances := map[uint32]btcutil.Amount{
+ 0: 100, // Default account
+ 1: 200, // "test account"
+ }
+
+ // Now, we'll call listAccountsWithBalances within a read transaction
+ // and verify the results.
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ // Setup mock expectations for listAccountsWithBalances.
+ deps.accountManager.On("LastAccount", mock.Anything).
+ Return(uint32(2), nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(0)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }, nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(1)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: acc1Name,
+ }, nil).Once()
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(2)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 2,
+ AccountName: acc2Name,
+ }, nil).Once()
+
+ // Call the function under test.
+ results, err := listAccountsWithBalances(
+ deps.accountManager, addrmgrNs, balances,
+ )
+ require.NoError(t, err)
+
+ // The BIP0084 scope should have three accounts: the default
+ // one and the two we just created.
+ require.Len(t, results, 3, "expected three accounts for scope")
+
+ // Check the default account's result.
+ require.Equal(t, "default", results[0].AccountName)
+ require.Equal(t, uint32(0), results[0].AccountNumber)
+ require.Equal(t, btcutil.Amount(100), results[0].TotalBalance)
+
+ // Check the first new account's result.
+ require.Equal(t, acc1Name, results[1].AccountName)
+ require.Equal(t, uint32(1), results[1].AccountNumber)
+ require.Equal(t, btcutil.Amount(200), results[1].TotalBalance)
+
+ // Check the second new account's result (zero balance).
+ require.Equal(t, acc2Name, results[2].AccountName)
+ require.Equal(t, uint32(2), results[2].AccountNumber)
+ require.Equal(t, btcutil.Amount(0), results[2].TotalBalance)
+
+ return nil
+ })
+ require.NoError(t, err)
+}
diff --git a/wallet/address_manager.go b/wallet/address_manager.go
new file mode 100644
index 0000000000..e3df863355
--- /dev/null
+++ b/wallet/address_manager.go
@@ -0,0 +1,893 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// Package wallet provides the AddressManager interface for generating and
+// inspecting wallet addresses and scripts.
+//
+// TODO(yy): bring wrapcheck back when implementing the `Store` interface.
+//
+//nolint:wrapcheck
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/psbt/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+)
+
+var (
+ // ErrDerivationPathNotFound is returned when the derivation path for a
+ // given script cannot be found. This may be because the script does
+ // not belong to the wallet, is imported, or is not a pubkey-based
+ // script.
+ ErrDerivationPathNotFound = errors.New("derivation path not found")
+
+ // ErrUnknownAddrType is an error returned when a wallet function is
+ // called with an unknown address type.
+ ErrUnknownAddrType = errors.New("unknown address type")
+
+ // ErrImportedAccountNoAddrGen is an error returned when a new address
+ // is requested for the default imported account within the wallet.
+ ErrImportedAccountNoAddrGen = errors.New("addresses cannot be " +
+ "generated for the default imported account")
+
+ // ErrNotPubKeyAddress is an error returned when a function requires a
+ // public key address, but a different type of address is provided.
+ ErrNotPubKeyAddress = errors.New(
+ "address is not a p2wkh or np2wkh address",
+ )
+
+ // ErrUnableToExtractAddress is returned when an address cannot be
+ // extracted from a pkscript.
+ ErrUnableToExtractAddress = errors.New("unable to extract address")
+
+ // errStopIteration is a special error used to stop the iteration in
+ // ForEachAccountAddress.
+ errStopIteration = errors.New("stop iteration")
+)
+
+// AddressProperty represents an address and its balance.
+type AddressProperty struct {
+ // Address is the address.
+ Address address.Address
+
+ // Balance is the total unspent balance of the address, including both
+ // confirmed and unconfirmed funds.
+ Balance btcutil.Amount
+}
+
+// Script represents the script information required to spend a UTXO.
+type Script struct {
+ // Addr is the managed address of the UTXO.
+ Addr waddrmgr.ManagedAddress
+
+ // WitnessProgram is the witness program of the UTXO.
+ WitnessProgram []byte
+
+ // RedeemScript is the redeem script of the UTXO.
+ RedeemScript []byte
+}
+
+// AddressManager provides an interface for generating and inspecting wallet
+// addresses and scripts.
+type AddressManager interface {
+ // NewAddress returns a new address for the given account and address
+ // type.
+ //
+ // NOTE: This method should be used with caution. Unlike
+ // GetUnusedAddress, it does not scan for previously derived but unused
+ // addresses. Using this method repeatedly can create gaps in the
+ // address chain, which may negatively impact wallet recovery under
+ // BIP44. It is primarily intended for advanced use cases such as bulk
+ // address generation.
+ NewAddress(ctx context.Context, accountName string,
+ addrType waddrmgr.AddressType,
+ change bool) (address.Address, error)
+
+ // GetUnusedAddress returns the first, oldest, unused address by
+ // scanning forward from the start of the derivation path. This method
+ // is the recommended default for obtaining a new receiving address, as
+ // it prevents address reuse and avoids creating gaps in the address
+ // chain that could impact wallet recovery.
+ GetUnusedAddress(ctx context.Context, accountName string,
+ addrType waddrmgr.AddressType, change bool) (
+ address.Address, error)
+
+ // AddressInfo returns detailed information about a managed address. If
+ // the address is not known to the wallet, an error is returned.
+ AddressInfo(ctx context.Context,
+ a address.Address) (waddrmgr.ManagedAddress, error)
+
+ // ListAddresses lists all addresses for a given account, including
+ // their balances.
+ ListAddresses(ctx context.Context, accountName string,
+ addrType waddrmgr.AddressType) ([]AddressProperty, error)
+
+ // ImportPublicKey imports a single public key as a watch-only address.
+ ImportPublicKey(ctx context.Context, pubKey *btcec.PublicKey,
+ addrType waddrmgr.AddressType) error
+
+ // ImportTaprootScript imports a taproot script for tracking and
+ // spending.
+ ImportTaprootScript(ctx context.Context,
+ tapscript waddrmgr.Tapscript) (waddrmgr.ManagedAddress, error)
+
+ // ScriptForOutput returns the address, witness program, and redeem
+ // script for a given UTXO.
+ ScriptForOutput(ctx context.Context, output wire.TxOut) (Script, error)
+
+ // GetDerivationInfo returns the BIP-32 derivation path for a given
+ // address.
+ GetDerivationInfo(ctx context.Context,
+ addr address.Address) (*psbt.Bip32Derivation, error)
+}
+
+// A compile time check to ensure that Wallet implements the interface.
+var _ AddressManager = (*Wallet)(nil)
+
+// NewAddress returns a new address for the given account and address type.
+// This method is a low-level primitive that will always derive a new, unused
+// address from the end of the address chain.
+//
+// It returns the next external or internal address for the wallet dictated by
+// the value of the `change` parameter. If change is true, then an internal
+// address will be returned, otherwise an external address should be returned.
+// The account parameter is the name of the account from which the address
+// should be generated. The addrType parameter specifies the type of address to
+// be generated.
+//
+// NOTE: This method should be used with caution. Unlike GetUnusedAddress, it
+// does not scan for previously derived but unused addresses. Using this method
+// repeatedly can create gaps in the address chain. If a gap of 20 consecutive
+// unused addresses is created, wallet recovery from seed may fail under BIP44.
+// It is primarily intended for advanced use cases such as bulk address
+// generation. For most applications, GetUnusedAddress is the recommended
+// method for obtaining a receiving address.
+//
+// TODO(yy): The current implementation of NewAddress has several architectural
+// issues that should be addressed:
+//
+// 1. **Lack of Separation of Concerns:** The method tightly couples the
+// database logic with the address generation and chain backend
+// notification logic. The `waddrmgr` package currently handles both
+// derivation and persistence within a single database transaction, which
+// makes the transaction larger and longer than necessary.
+//
+// 2. **Incorrect Ordering of Operations:** The current flow is:
+// 1. Create DB transaction.
+// 2. Derive address.
+// 3. Save address to DB.
+// 4. Commit DB transaction.
+// 5. Notify the chain backend to watch the new address.
+// This creates a potential race condition. If the program crashes after
+// committing the address to the database but before successfully
+// notifying the chain backend, the wallet will own an address that the
+// backend is not aware of. This could lead to a permanent loss of funds
+// if coins are sent to that address.
+//
+// Refactoring Plan:
+// - **Decouple `waddrmgr`:** The `waddrmgr` package should be refactored to
+// separate its concerns. It should provide:
+// - A pure, stateless function to derive an address from account info.
+// - A simple method to persist a newly derived address to the database.
+// - **Improve Operation Ordering in `wallet`:** The `NewAddress` method in
+// the `wallet` package should be updated to follow a more robust
+// sequence:
+// 1. Start a DB transaction to read the required account information.
+// 2. Use the pure derivation function from `waddrmgr` to generate the
+// new address *outside* of any DB transaction.
+// 3. Notify the chain backend to watch the new address.
+// 4. If the notification is successful, start a *second*, short-lived DB
+// transaction to persist the new address.
+// This ensures that we only save an address after we are confident that
+// it is being watched by the backend, preventing fund loss.
+func (w *Wallet) NewAddress(_ context.Context, accountName string,
+ addrType waddrmgr.AddressType, change bool) (address.Address, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // Addresses cannot be derived from the catch-all imported accounts.
+ if accountName == waddrmgr.ImportedAddrAccountName {
+ return nil, ErrImportedAccountNoAddrGen
+ }
+
+ keyScope, err := w.keyScopeFromAddrType(addrType)
+ if err != nil {
+ return nil, err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ return nil, err
+ }
+
+ addr, err := w.newAddress(manager, accountName, change)
+ if err != nil {
+ return nil, err
+ }
+
+ // Notify the rpc server about the newly created address.
+ err = w.cfg.Chain.NotifyReceived([]address.Address{addr})
+ if err != nil {
+ return nil, err
+ }
+
+ return addr, nil
+}
+
+// keyScopeFromAddrType determines the appropriate key scope for a given
+// address type.
+//
+// NOTE: While it may seem intuitive to iterate over the waddrmgr.ScopeAddrMap
+// to act as a single source of truth, doing so is unsafe. The map contains
+// ambiguities where a single address type, such as waddrmgr.WitnessPubKey, can
+// map to multiple key scopes (e.g., KeyScopeBIP0084 and
+// KeyScopeBIP0049Plus). Because map iteration in Go is non-deterministic, this
+// would lead to unpredictable behavior. The switch statement is used here
+// intentionally to enforce a clear, deterministic policy, ensuring that
+// ambiguous types always resolve to their preferred, modern key scope.
+func (w *Wallet) keyScopeFromAddrType(
+ addrType waddrmgr.AddressType) (waddrmgr.KeyScope, error) {
+
+ // Map the requested address type to its key scope.
+ var addrKeyScope waddrmgr.KeyScope
+ switch addrType {
+ case waddrmgr.PubKeyHash:
+ addrKeyScope = waddrmgr.KeyScopeBIP0044
+
+ case waddrmgr.WitnessPubKey:
+ addrKeyScope = waddrmgr.KeyScopeBIP0084
+
+ case waddrmgr.NestedWitnessPubKey:
+ addrKeyScope = waddrmgr.KeyScopeBIP0049Plus
+
+ case waddrmgr.TaprootPubKey:
+ addrKeyScope = waddrmgr.KeyScopeBIP0086
+
+ // The following address types are not supported by this function as
+ // they are not derived from a single public key using a key scope.
+ // They are typically imported or involve more complex script-based
+ // constructions.
+ case waddrmgr.Script, waddrmgr.RawPubKey,
+ waddrmgr.WitnessScript, waddrmgr.TaprootScript:
+ return waddrmgr.KeyScope{}, fmt.Errorf("%w: %v",
+ ErrUnknownAddrType, addrType)
+ default:
+ return waddrmgr.KeyScope{}, fmt.Errorf("%w: %v",
+ ErrUnknownAddrType, addrType)
+ }
+
+ return addrKeyScope, nil
+}
+
+// newAddress returns the next external chained address for a wallet. It
+// wraps the database transaction and the call to the scoped key manager's
+// NewAddress method. The underlying address manager handles its own
+// synchronization to ensure that in-memory state remains consistent with the
+// database, preventing race conditions during address creation.
+func (w *Wallet) newAddress(manager waddrmgr.AccountStore,
+ accountName string, change bool) (address.Address, error) {
+
+ var (
+ addr address.Address
+ err error
+ )
+
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ addr, err = manager.NewAddress(addrmgrNs, accountName, change)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return addr, nil
+}
+
+// GetUnusedAddress returns the first, oldest, unused address by scanning
+// forward from the start of the derivation path. The address is considered
+// "unused" if it has never appeared in a transaction. This method is the
+// recommended default for obtaining a new receiving address. It prevents
+// address reuse and avoids creating gaps in the address chain, which is
+// critical for reliable wallet recovery under standards like BIP44 that
+// enforce a gap limit of 20 unused addresses. If all previously derived
+// addresses have been used, this method will delegate to NewAddress to
+// generate a new one.
+//
+// TODO(yy): The current implementation of GetUnusedAddress is inefficient for
+// wallets with a large number of used addresses. It iterates from the first
+// address (index 0) forward until it finds an unused one, resulting in an O(n)
+// complexity where n is the number of used addresses.
+//
+// A potential optimization of scanning backwards from the last derived address
+// is UNSAFE. While faster in the common case, it can create gaps in the
+// address chain. For example, if addresses [0, 1, 3] are used but [2] is not,
+// a backward scan would return a new address after 3, leaving 2 as a gap.
+// This violates the BIP44 gap limit (typically 20) and can lead to fund loss
+// upon wallet recovery from seed, as the recovery process would stop scanning
+// at the gap.
+//
+// The correct optimization is to persist a "first unused address pointer"
+// (e.g., `firstUnusedExternalIndex`) for each account in the database.
+//
+// This would change the logic to:
+// 1. `GetUnusedAddress`: Becomes an O(1) lookup. It reads the index from the
+// database and derives the address at that index.
+// 2. `MarkUsed`: When an address is marked as used, if its index matches the
+// stored pointer, a one-time forward scan is performed to find the next
+// unused address, and the pointer is updated in the database.
+//
+// This moves the expensive scan from the frequent "read" operation to the less
+// frequent "write" operation, providing both performance and safety.
+func (w *Wallet) GetUnusedAddress(ctx context.Context, accountName string,
+ addrType waddrmgr.AddressType, change bool) (address.Address, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ if accountName == waddrmgr.ImportedAddrAccountName {
+ return nil, ErrImportedAccountNoAddrGen
+ }
+
+ keyScope, err := w.keyScopeFromAddrType(addrType)
+ if err != nil {
+ return nil, err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ return nil, err
+ }
+
+ unusedAddr, err := w.findUnusedAddress(manager, accountName, change)
+ // We'll ignore the special error that we use to stop the iteration.
+ if err != nil && !errors.Is(err, errStopIteration) {
+ return nil, err
+ }
+
+ // If we found an unused address, we can return it now.
+ if unusedAddr != nil {
+ return unusedAddr, nil
+ }
+
+ // Otherwise, we'll generate a new one.
+ return w.NewAddress(ctx, accountName, addrType, change)
+}
+
+// findUnusedAddress scans for an unused address for the given account.
+func (w *Wallet) findUnusedAddress(manager waddrmgr.AccountStore,
+ accountName string, change bool) (address.Address, error) {
+
+ var unusedAddr address.Address
+
+ err := walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ // First, look up the account number for the passed account
+ // name.
+ acctNum, err := manager.LookupAccount(addrmgrNs, accountName)
+ if err != nil {
+ return err
+ }
+
+ // Now, iterate through all addresses for the account and
+ // return the first one that is unused.
+ return manager.ForEachAccountAddress(
+ addrmgrNs, acctNum,
+ func(maddr waddrmgr.ManagedAddress) error {
+ // We only want to consider addresses that match
+ // the change parameter.
+ if maddr.Internal() != change {
+ return nil
+ }
+
+ if !maddr.Used(addrmgrNs) {
+ unusedAddr = maddr.Address()
+
+ // Return a special error to signal
+ // that the iteration should be
+ // stopped. This is the idiomatic way
+ // to halt a ForEach* loop in this
+ // codebase.
+ return errStopIteration
+ }
+
+ return nil
+ },
+ )
+ })
+
+ return unusedAddr, err
+}
+
+// AddressInfo returns detailed information regarding a wallet address.
+//
+// This method provides metadata about a managed address, such as its type,
+// derivation path, and whether it's internal or compressed.
+//
+// How it works:
+// The method performs a direct lookup in the address manager to find the
+// requested address.
+//
+// Logical Steps:
+// 1. Initiate a read-only database transaction.
+// 2. Call the underlying address manager's `Address` method to look up the
+// address.
+// 3. Return the managed address information.
+//
+// Database Actions:
+// - This method performs a single read-only database transaction
+// (`walletdb.View`).
+// - It reads from the `waddrmgr` namespace to find the address.
+//
+// Time Complexity:
+// - The operation is a direct database lookup, making its complexity roughly
+// O(1) or O(log N) depending on the database backend's indexing strategy
+// for addresses. It is a very fast operation.
+func (w *Wallet) AddressInfo(_ context.Context,
+ a address.Address) (waddrmgr.ManagedAddress, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ var managedAddress waddrmgr.ManagedAddress
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ managedAddress, err = w.addrStore.Address(addrmgrNs, a)
+
+ return err
+ })
+
+ return managedAddress, err
+}
+
+// ListAddresses lists all addresses for a given account, including their
+// balances.
+//
+// This method provides a comprehensive view of all addresses within a
+// specific account, along with their current confirmed balances.
+//
+// How it works:
+// The method first calculates the balances of all UTXOs in the wallet and
+// stores them in a map. It then iterates through all addresses of the
+// specified account and looks up their balance in the map.
+//
+// Logical Steps:
+// 1. Initiate a read-only database transaction.
+// 2. Create a map to store address balances.
+// 3. Iterate through all unspent transaction outputs (UTXOs) in the
+// wallet's `wtxmgr` namespace.
+// 4. For each UTXO, extract the address and add the output's value to the
+// address's balance in the map.
+// 5. Fetch the scoped key manager for the given address type.
+// 6. Look up the account number for the given account name.
+// 7. Iterate through all addresses in that account.
+// 8. For each address, create an `AddressProperty` with the address and its
+// balance from the map.
+// 9. Return the list of `AddressProperty` objects.
+//
+// Database Actions:
+// - This method performs a single read-only database transaction
+// (`walletdb.View`).
+// - It reads from both the `wtxmgr` and `waddrmgr` namespaces.
+//
+// Time Complexity:
+// - The complexity is O(U + A), where U is the number of unspent
+// transaction outputs in the wallet and A is the number of addresses in
+// the specified account. This is because it iterates through all UTXOs to
+// build the balance map and then iterates through all account addresses.
+func (w *Wallet) ListAddresses(_ context.Context, accountName string,
+ addrType waddrmgr.AddressType) ([]AddressProperty, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ var properties []AddressProperty
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // First, we'll create a map of address to balance by iterating
+ // through all the unspent outputs.
+ addrToBalance := make(map[string]btcutil.Amount)
+
+ utxos, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+
+ for _, utxo := range utxos {
+ addr := extractAddrFromPKScript(
+ utxo.PkScript, w.cfg.ChainParams,
+ )
+ if addr == nil {
+ continue
+ }
+
+ addrToBalance[addr.String()] += utxo.Amount
+ }
+
+ keyScope, err := w.keyScopeFromAddrType(addrType)
+ if err != nil {
+ return err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ return err
+ }
+
+ acctNum, err := manager.LookupAccount(addrmgrNs, accountName)
+ if err != nil {
+ return err
+ }
+
+ return manager.ForEachAccountAddress(addrmgrNs, acctNum,
+ func(maddr waddrmgr.ManagedAddress) error {
+ addr := maddr.Address()
+ properties = append(properties, AddressProperty{
+ Address: addr,
+ Balance: addrToBalance[addr.String()],
+ })
+
+ return nil
+ })
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return properties, nil
+}
+
+// ImportPublicKey imports a single public key as a watch-only address.
+//
+// This method allows the wallet to track transactions related to a specific
+// public key without having access to the corresponding private key. This is
+// useful for monitoring addresses without compromising their security.
+//
+// How it works:
+// The method determines the appropriate key scope based on the provided
+// address type and then uses the corresponding scoped key manager to import
+// the public key.
+//
+// Logical Steps:
+// 1. Determine the key scope from the address type (e.g., P2WKH, NP2WKH).
+// 2. Fetch the scoped key manager for that scope.
+// 3. Initiate a database transaction.
+// 4. Within the transaction, call the underlying address manager's
+// ImportPublicKey method to store the key.
+// 5. Commit the transaction.
+//
+// Database Actions:
+// - This method performs a single database write transaction
+// (`walletdb.Update`).
+// - It stores the public key and its associated address information within
+// the `waddrmgr` namespace.
+//
+// Time Complexity:
+// - The operation is dominated by the database write, making its complexity
+// roughly O(1) or O(log N) depending on the database backend's indexing
+// strategy for keys. It is generally a fast operation.
+func (w *Wallet) ImportPublicKey(_ context.Context, pubKey *btcec.PublicKey,
+ addrType waddrmgr.AddressType) error {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return err
+ }
+
+ keyScope, err := w.keyScopeFromAddrType(addrType)
+ if err != nil {
+ return err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ return err
+ }
+
+ var addr address.Address
+
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ ma, err := manager.ImportPublicKey(addrmgrNs, pubKey, nil)
+ if err != nil {
+ return err
+ }
+
+ addr = ma.Address()
+
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ return w.cfg.Chain.NotifyReceived([]address.Address{addr})
+}
+
+// ImportTaprootScript imports a taproot script for tracking and spending.
+//
+// This method allows the wallet to import a taproot script, which is
+// necessary for spending from or tracking a taproot address.
+//
+// How it works:
+// The method uses the BIP-0086 key scope to fetch the taproot-specific
+// scoped key manager. It then calls the underlying manager's
+// ImportTaprootScript method to store the script information.
+//
+// Logical Steps:
+// 1. Fetch the scoped key manager for the taproot key scope (BIP-0086).
+// 2. Initiate a database transaction.
+// 3. Within the transaction, get the wallet's current sync state to use as
+// the "birthday" for the new script.
+// 4. Call the underlying address manager's ImportTaprootScript method.
+// 5. Commit the transaction.
+//
+// Database Actions:
+// - This method performs a single database write transaction
+// (`walletdb.Update`).
+// - It stores the taproot script and its derived address information within
+// the `waddrmgr` namespace.
+//
+// Time Complexity:
+// - Similar to ImportPublicKey, this operation is dominated by a database
+// write, making it a fast operation with a complexity of roughly O(1).
+func (w *Wallet) ImportTaprootScript(_ context.Context,
+ tapscript waddrmgr.Tapscript) (waddrmgr.ManagedAddress, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(
+ waddrmgr.KeyScopeBIP0086,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ var addr waddrmgr.ManagedAddress
+
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ syncedTo := w.addrStore.SyncedTo()
+ addr, err = manager.ImportTaprootScript(
+ ns, &tapscript, &syncedTo, 1, false,
+ )
+
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ err = w.cfg.Chain.NotifyReceived([]address.Address{addr.Address()})
+ if err != nil {
+ return nil, err
+ }
+
+ return addr, nil
+}
+
+// ScriptForOutput returns the address, witness program, and redeem script
+// for a given UTXO.
+//
+// This method is essential for constructing the necessary scripts to spend a
+// transaction output. It provides the components required to build the
+// scriptSig and witness fields of a transaction input.
+//
+// How it works:
+// The method first identifies which of the wallet's addresses corresponds to
+// the output's script. It then determines the correct script format (redeem
+// script, witness program) based on the address type.
+//
+// Logical Steps:
+// 1. Look up the output's pkScript in the database to find the
+// corresponding managed address.
+// 2. Verify that the address is a public key address that the wallet can
+// sign for (e.g., P2WKH, NP2WKH, P2TR).
+// 3. Based on the address type, construct the appropriate scripts:
+// - For nested P2WKH (NP2WKH), it creates a redeem script
+// (`sigScript`) that contains the P2WKH witness program.
+// - For native SegWit outputs (P2WKH, P2TR), the `witnessProgram` is
+// the output's `pkScript`, and the `sigScript` is nil.
+//
+// Database Actions:
+// - This method performs a read-only database access to fetch address
+// details from the `waddrmgr` namespace.
+//
+// Time Complexity:
+// - The operation is dominated by the database lookup for the address, which
+// is typically fast (O(log N) or O(1) with indexing). The script
+// generation is a constant-time operation.
+func (w *Wallet) ScriptForOutput(ctx context.Context, output wire.TxOut) (
+ Script, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return Script{}, err
+ }
+
+ // First, we'll extract the address from the output's pkScript.
+ addr := extractAddrFromPKScript(output.PkScript, w.cfg.ChainParams)
+ if addr == nil {
+ return Script{}, fmt.Errorf("%w: from pkscript %x",
+ ErrUnableToExtractAddress, output.PkScript)
+ }
+
+ // We'll then use the address to look up the managed address from the
+ // database.
+ managedAddr, err := w.AddressInfo(ctx, addr)
+ if err != nil {
+ return Script{}, fmt.Errorf("unable to get address info "+
+ "for %s: %w", addr.String(), err)
+ }
+
+ pubKeyAddr, ok := managedAddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return Script{}, fmt.Errorf("%w: addr %s",
+ ErrNotPubKeyAddress, managedAddr.Address())
+ }
+
+ witnessProgram, redeemScript, err := buildScriptsForManagedAddress(
+ pubKeyAddr, output.PkScript, w.cfg.ChainParams,
+ )
+ if err != nil {
+ return Script{}, err
+ }
+
+ return Script{
+ Addr: managedAddr,
+ WitnessProgram: witnessProgram,
+ RedeemScript: redeemScript,
+ }, nil
+}
+
+// buildScriptsForManagedAddress constructs the witness and redeem scripts for a
+// given managed public key address and its corresponding pkScript.
+func buildScriptsForManagedAddress(pubKeyAddr waddrmgr.ManagedPubKeyAddress,
+ pkScript []byte, chainParams *chaincfg.Params) ([]byte, []byte, error) {
+
+ var (
+ witnessProgram []byte
+ redeemScript []byte
+ )
+
+ switch {
+ // If we're spending p2wkh output nested within a p2sh output, then
+ // we'll need to attach a sigScript in addition to witness data.
+ case pubKeyAddr.AddrType() == waddrmgr.NestedWitnessPubKey:
+ pubKey := pubKeyAddr.PubKey()
+ pubKeyHash := address.Hash160(pubKey.SerializeCompressed())
+
+ // Next, we'll generate a valid sigScript that will allow us to
+ // spend the p2sh output. The sigScript will contain only a
+ // single push of the p2wkh witness program corresponding to
+ // the matching public key of this address.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ pubKeyHash, chainParams,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ witnessProgram, err = txscript.PayToAddrScript(p2wkhAddr)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ bldr := txscript.NewScriptBuilder()
+ bldr.AddData(witnessProgram)
+
+ redeemScript, err = bldr.Script()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Otherwise, this is a regular p2wkh or p2tr output, so we include the
+ // witness program itself as the subscript to generate the proper
+ // sighash digest. As part of the new sighash digest algorithm, the
+ // p2wkh witness program will be expanded into a regular p2kh
+ // script.
+ default:
+ witnessProgram = pkScript
+ }
+
+ return witnessProgram, redeemScript, nil
+}
+
+// GetDerivationInfo returns the BIP-32 derivation path for a given address.
+func (w *Wallet) GetDerivationInfo(ctx context.Context,
+ addr address.Address) (*psbt.Bip32Derivation, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // We'll use the address to look up the derivation path.
+ managedAddr, err := w.AddressInfo(ctx, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ // We only care about pubkey addresses, as they are the only
+ // ones with derivation paths.
+ pubKeyAddr, ok := managedAddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil, fmt.Errorf("%w: addr=%v not found",
+ ErrDerivationPathNotFound, addr)
+ }
+
+ return derivationForManagedAddress(pubKeyAddr)
+}
+
+// derivationForManagedAddress constructs a PSBT Bip32Derivation struct from a
+// managed public key address.
+func derivationForManagedAddress(pubKeyAddr waddrmgr.ManagedPubKeyAddress) (
+ *psbt.Bip32Derivation, error) {
+
+ // Imported addresses don't have derivation paths.
+ if pubKeyAddr.Imported() {
+ return nil, fmt.Errorf("%w: addr=%v is imported",
+ ErrDerivationPathNotFound, pubKeyAddr.Address())
+ }
+
+ // Get the derivation info.
+ keyScope, derivPath, ok := pubKeyAddr.DerivationInfo()
+ if !ok {
+ return nil, fmt.Errorf("%w: derivation info not found for %v",
+ ErrDerivationPathNotFound, pubKeyAddr.Address())
+ }
+
+ // Get the public key.
+ pubKey := pubKeyAddr.PubKey()
+
+ derivationInfo := &psbt.Bip32Derivation{
+ PubKey: pubKey.SerializeCompressed(),
+ MasterKeyFingerprint: derivPath.MasterKeyFingerprint,
+ Bip32Path: []uint32{
+ keyScope.Purpose + hdkeychain.HardenedKeyStart,
+ keyScope.Coin + hdkeychain.HardenedKeyStart,
+ derivPath.Account,
+ derivPath.Branch,
+ derivPath.Index,
+ },
+ }
+
+ return derivationInfo, nil
+}
diff --git a/wallet/address_manager_benchmark_test.go b/wallet/address_manager_benchmark_test.go
new file mode 100644
index 0000000000..8032ea38af
--- /dev/null
+++ b/wallet/address_manager_benchmark_test.go
@@ -0,0 +1,770 @@
+package wallet
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkListAddressesAPI benchmarks ListAddresses API and a deprecated
+// variant of it using same key scope and identical test data across multiple
+// dataset sizes. Test names start with dataset size to group API comparisons
+// for benchstat analysis.
+func BenchmarkListAddressesAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+
+ addrType = waddrmgr.PubKeyHash
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, accountNumber := generateAccountName(
+ accountGrowth[i], scopes,
+ )
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := listAddressesDeprecated(
+ bw.Wallet, accountNumber,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.ListAddresses(
+ b.Context(), accountName, addrType,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkAddressInfoAPI benchmarks AddressInfo API and its deprecated
+// variant using same key scope and identical test data across multiple
+// dataset sizes. Test names start with dataset size to group API comparisons
+// for benchstat analysis.
+func BenchmarkAddressInfoAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.AddressInfoDeprecated(testAddr)
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.AddressInfo(b.Context(), testAddr)
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkGetUnusedAddressAPI benchmarks GetUnusedAddress API and its
+// deprecated variant NewAddressDeprecated using same key scope and identical
+// address datasets across multiple dataset sizes. Test names start with dataset
+// size to group API comparisons for benchstat analysis. The benchmark
+// demonstrates the trade-off between performance (O(1) vs O(n)) and safety
+// (preventing address reuse and BIP44 gap limit violations).
+func BenchmarkGetUnusedAddressAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+
+ addrType = waddrmgr.PubKeyHash
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, accountNumber := generateAccountName(
+ accountGrowth[i], scopes,
+ )
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ addr, err := bw.NewAddressDeprecated(
+ accountNumber, scopes[0],
+ )
+ require.NoError(b, err)
+
+ // Mark the address as used to make the
+ // benchmark iteration idempotent.
+ markAddressAsUsed(b, bw.Wallet, addr)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ addr, err := bw.GetUnusedAddress(
+ b.Context(), accountName, addrType,
+ false,
+ )
+ require.NoError(b, err)
+
+ // Mark the address as used to make the
+ // benchmark iteration idempotent.
+ markAddressAsUsed(b, bw.Wallet, addr)
+ }
+ })
+ }
+}
+
+// BenchmarkNewAddressAPI benchmarks NewAddress API and its deprecated variant
+// NewAddressDeprecated using same key scope and identical address datasets
+// across multiple dataset sizes. Test names start with dataset size to group
+// API comparisons for benchstat analysis. The benchmark demonstrates that the
+// new API maintains performance parity with the deprecated API.
+func BenchmarkNewAddressAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0044}
+
+ addrType = waddrmgr.PubKeyHash
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, accountNumber := generateAccountName(
+ accountGrowth[i], scopes,
+ )
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.NewAddressDeprecated(
+ accountNumber, scopes[0],
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.NewAddress(
+ b.Context(), accountName, addrType,
+ false,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkImportPublicKeyAPI benchmarks ImportPublicKey API and its deprecated
+// variant ImportPublicKeyDeprecated using identical public key datasets across
+// multiple dataset sizes. Test names start with dataset size to group API
+// comparisons for benchstat analysis. The benchmark demonstrates that the new
+// API maintains performance parity with the deprecated API.
+func BenchmarkImportPublicKeyAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ addrType = waddrmgr.WitnessPubKey
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ iterCount := 0
+ for b.Loop() {
+ // Generate a unique key for each iteration to
+ // avoid in-memory cache collision and for an
+ // idempotent benchmark iteration test.
+ seedIndex := accountGrowth[i] + iterCount
+ key, _, _ := generateTestExtendedKey(
+ b, seedIndex,
+ )
+ pubKey, err := key.ECPubKey()
+ require.NoError(b, err)
+
+ err = w.ImportPublicKeyDeprecated(
+ pubKey, addrType,
+ )
+ require.NoError(b, err)
+
+ iterCount++
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ iterCount := 0
+ for b.Loop() {
+ // Generate a unique key for each iteration to
+ // avoid in-memory cache collision and for an
+ // idempotent benchmark iteration test.
+ seedIndex := accountGrowth[i] + iterCount
+ key, _, _ := generateTestExtendedKey(
+ b, seedIndex,
+ )
+ pubKey, err := key.ECPubKey()
+ require.NoError(b, err)
+
+ err = w.ImportPublicKey(
+ b.Context(), pubKey, addrType,
+ )
+ require.NoError(b, err)
+
+ iterCount++
+ }
+ })
+ }
+}
+
+// BenchmarkImportTaprootScriptAPI benchmarks ImportTaprootScript API and its
+// deprecated variant ImportTaprootScriptDeprecated using identical tapscript
+// datasets across multiple dataset sizes. Test names start with dataset size
+// to group API comparisons for benchstat analysis. The benchmark demonstrates
+// that the new API maintains performance parity with the deprecated API.
+func BenchmarkImportTaprootScriptAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 10
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0086}
+
+ witnessVersion = 1
+
+ isSecretScript = false
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ iterCount := 0
+ for b.Loop() {
+ // Generate a unique tapscript for each
+ // iteration to avoid in-memory cache collision
+ // and for an idempotent benchmark iteration
+ // test.
+ seedIndex := accountGrowth[i] + iterCount
+ key, _, _ := generateTestExtendedKey(
+ b, seedIndex,
+ )
+ pubKey, err := key.ECPubKey()
+ require.NoError(b, err)
+
+ tapscript := generateTestTapscript(b, pubKey)
+
+ syncedTo := w.addrStore.SyncedTo()
+ _, err = w.ImportTaprootScriptDeprecated(
+ scopes[0], &tapscript, &syncedTo,
+ byte(witnessVersion), isSecretScript,
+ )
+ require.NoError(b, err)
+
+ iterCount++
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ iterCount := 0
+ for b.Loop() {
+ // Generate a unique tapscript for each
+ // iteration to avoid in-memory cache collision
+ // and for an idempotent benchmark iteration
+ // test.
+ seedIndex := accountGrowth[i] + iterCount
+ key, _, _ := generateTestExtendedKey(
+ b, seedIndex,
+ )
+ pubKey, err := key.ECPubKey()
+ require.NoError(b, err)
+
+ tapscript := generateTestTapscript(b, pubKey)
+
+ _, err = w.ImportTaprootScript(
+ b.Context(), tapscript,
+ )
+ require.NoError(b, err)
+
+ iterCount++
+ }
+ })
+ }
+}
+
+// BenchmarkScriptForOutputAPI benchmarks ScriptForOutput API and its deprecated
+// variant ScriptForOutputDeprecated using identical TxOut datasets across
+// multiple dataset sizes. Test names start with dataset size to group API
+// comparisons for benchstat analysis. The benchmark demonstrates that the new
+// API maintains performance parity with the deprecated API.
+func BenchmarkScriptForOutputAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 10
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ testTxOut := generateTestTxOut(b, testAddr)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, _, _, err := bw.ScriptForOutputDeprecated(
+ &testTxOut,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ testTxOut := generateTestTxOut(b, testAddr)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.ScriptForOutput(
+ b.Context(), testTxOut,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
diff --git a/wallet/address_manager_test.go b/wallet/address_manager_test.go
new file mode 100644
index 0000000000..ab857342bf
--- /dev/null
+++ b/wallet/address_manager_test.go
@@ -0,0 +1,910 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcec/v2/schnorr"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestKeyScopeFromAddrType tests the keyScopeFromAddrType method to ensure
+// it correctly maps address types to their corresponding key scopes.
+func TestKeyScopeFromAddrType(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ addrType waddrmgr.AddressType
+ expectedScope waddrmgr.KeyScope
+ expectedErr error
+ }{
+ {
+ name: "pubkey hash",
+ addrType: waddrmgr.PubKeyHash,
+ expectedScope: waddrmgr.KeyScopeBIP0044,
+ expectedErr: nil,
+ },
+ {
+ name: "witness pubkey",
+ addrType: waddrmgr.WitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0084,
+ expectedErr: nil,
+ },
+ {
+ name: "nested witness pubkey",
+ addrType: waddrmgr.NestedWitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0049Plus,
+ expectedErr: nil,
+ },
+ {
+ name: "taproot pubkey",
+ addrType: waddrmgr.TaprootPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0086,
+ expectedErr: nil,
+ },
+ {
+ name: "unknown address type",
+ addrType: waddrmgr.WitnessScript,
+ expectedErr: ErrUnknownAddrType,
+ },
+ }
+
+ w := &Wallet{
+ cfg: Config{},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ scope, err := w.keyScopeFromAddrType(tc.addrType)
+ require.ErrorIs(t, err, tc.expectedErr)
+ require.Equal(t, tc.expectedScope, scope)
+ })
+ }
+}
+
+// TestNewAddress tests the NewAddress method, ensuring it can generate
+// various address types for different accounts and correctly handles both
+// internal and external address generation.
+func TestNewAddress(t *testing.T) {
+ t.Parallel()
+
+ // Define a set of test cases to cover different address types and
+ // scenarios.
+ testCases := []struct {
+ name string
+ accountName string
+ addrType waddrmgr.AddressType
+ change bool
+ expectErr bool
+ expectedAddrType address.Address
+ }{
+ {
+ name: "default account p2wkh",
+ accountName: "default",
+ addrType: waddrmgr.WitnessPubKey,
+ change: false,
+ expectedAddrType: &address.AddressWitnessPubKeyHash{},
+ },
+ {
+ name: "p2wkh change address",
+ accountName: "default",
+ addrType: waddrmgr.WitnessPubKey,
+ change: true,
+ expectedAddrType: &address.AddressWitnessPubKeyHash{},
+ },
+ {
+ name: "default account np2wkh",
+ accountName: "default",
+ addrType: waddrmgr.NestedWitnessPubKey,
+ change: false,
+ expectedAddrType: &address.AddressScriptHash{},
+ },
+ {
+ name: "default account p2tr",
+ accountName: "default",
+ addrType: waddrmgr.TaprootPubKey,
+ change: false,
+ expectedAddrType: &address.AddressTaproot{},
+ },
+ {
+ name: "unknown address type",
+ accountName: "default",
+ addrType: waddrmgr.WitnessScript,
+ expectErr: true,
+ },
+ {
+ name: "imported account",
+ accountName: waddrmgr.ImportedAddrAccountName,
+ addrType: waddrmgr.WitnessPubKey,
+ expectErr: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet for each test case.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Setup mock expectations.
+ if tc.expectErr {
+ // Attempt to generate a new address with the specified
+ // parameters.
+ _, err := w.NewAddress(
+ t.Context(), tc.accountName,
+ tc.addrType, tc.change,
+ )
+ require.Error(t, err)
+
+ return
+ }
+
+ var addr address.Address
+ switch tc.addrType {
+ case waddrmgr.WitnessPubKey:
+ addr, _ = address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+ case waddrmgr.NestedWitnessPubKey:
+ addr, _ = address.NewAddressScriptHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+ case waddrmgr.TaprootPubKey:
+ addr, _ = address.NewAddressTaproot(
+ make([]byte, 32), w.cfg.ChainParams,
+ )
+ case waddrmgr.PubKeyHash, waddrmgr.Script,
+ waddrmgr.RawPubKey, waddrmgr.WitnessScript,
+ waddrmgr.TaprootScript:
+
+ require.FailNow(t, "unhandled address type", tc.addrType)
+
+ default:
+ require.FailNow(t, "unknown address type", tc.addrType)
+ }
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).
+ Once()
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.addr, nil).
+ Once()
+
+ deps.accountManager.On(
+ "NewAddress", mock.Anything, tc.accountName, tc.change,
+ ).Return(addr, nil).Once()
+
+ deps.chain.On("NotifyReceived", []address.Address{addr}).
+ Return(nil).
+ Once()
+
+ deps.addr.On("Internal").Return(tc.change).Once()
+
+ // Attempt to generate a new address with the specified
+ // parameters.
+ addr, err := w.NewAddress(
+ t.Context(), tc.accountName,
+ tc.addrType, tc.change,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, addr)
+
+ // Verify that the address is of the correct type.
+ require.IsType(t, tc.expectedAddrType, addr)
+
+ // Verify that the address is correctly marked as
+ // internal or external.
+ addrInfo, err := w.AddressInfo(t.Context(), addr)
+ require.NoError(t, err)
+ require.Equal(t, tc.change, addrInfo.Internal())
+ })
+ }
+}
+
+// TestGetUnusedAddress tests the GetUnusedAddress method to ensure it
+// correctly returns the earliest unused address.
+func TestGetUnusedAddress(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Get a new address to start with.
+ mockAddr, _ := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", false).
+ Return(mockAddr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{mockAddr}).
+ Return(nil).Once()
+
+ addr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+
+ // The first unused address should be the one we just created.
+ // GetUnusedAddress calls:
+ // - w.keyScopeFromAddrType
+ // - w.addrStore.FetchScopedKeyManager
+ // - w.findUnusedAddress (calls manager.LookupAccount and
+ // ForEachAccountAddress)
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(0), nil).Once()
+
+ deps.accountManager.On("ForEachAccountAddress", mock.Anything, uint32(0),
+ mock.Anything).Run(func(args mock.Arguments) {
+ f, ok := args.Get(2).(func(waddrmgr.ManagedAddress) error)
+ require.True(t, ok)
+ mockAddr1 := &mockManagedAddress{}
+ mockAddr1.On("Internal").Return(false).Once()
+ mockAddr1.On("Used", mock.Anything).Return(false).Once()
+ mockAddr1.On("Address").Return(addr).Once()
+ _ = f(mockAddr1)
+ }).Return(errStopIteration).Once()
+
+ unusedAddr, err := w.GetUnusedAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+ require.Equal(t, addr.String(), unusedAddr.String())
+
+ // "Use" the address by creating a fake UTXO for it.
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Setup expectations for using the address.
+ deps.txStore.On("InsertTx", mock.Anything, mock.Anything,
+ mock.Anything).Return(nil).Once()
+ deps.txStore.On("AddCredit", mock.Anything, mock.Anything,
+ mock.Anything, uint32(0), false).Return(nil).Once()
+ deps.addrStore.On("MarkUsed", mock.Anything, addr).Return(nil).Once()
+
+ // We need to create a realistic transaction that has at least one
+ // input.
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ // Create a new transaction and set the output to the address
+ // we want to mark as used.
+ msgTx := TstTx.MsgTx()
+ msgTx.TxOut = []*wire.TxOut{{
+ PkScript: pkScript,
+ Value: 1000,
+ }}
+
+ rec, err := wtxmgr.NewTxRecordFromMsgTx(msgTx, time.Now())
+ if err != nil {
+ return err
+ }
+
+ err = w.txStore.InsertTx(txmgrNs, rec, nil)
+ if err != nil {
+ return err
+ }
+
+ err = w.txStore.AddCredit(txmgrNs, rec, nil, 0, false)
+ if err != nil {
+ return err
+ }
+
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.MarkUsed(addrmgrNs, addr)
+ })
+ require.NoError(t, err)
+
+ // Get the next unused address.
+ // This time findUnusedAddress will find the first address as used, and
+ // then we mock it returning nil for any more existing addresses,
+ // triggering a NewAddress call.
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Twice()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(0), nil).Once()
+
+ deps.accountManager.On("ForEachAccountAddress", mock.Anything, uint32(0),
+ mock.Anything).Run(func(args mock.Arguments) {
+ f, ok := args.Get(2).(func(waddrmgr.ManagedAddress) error)
+ require.True(t, ok)
+
+ // First addr is used.
+ mockAddr1 := &mockManagedAddress{}
+ mockAddr1.On("Internal").Return(false).Once()
+ mockAddr1.On("Used", mock.Anything).Return(true).Once()
+ _ = f(mockAddr1)
+ }).Return(nil).Once()
+
+ nextAddrVal, _ := address.NewAddressWitnessPubKeyHash(
+ []byte{
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
+ 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
+ }, w.cfg.ChainParams,
+ )
+ deps.accountManager.On("NewAddress", mock.Anything, "default", false).
+ Return(nextAddrVal, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{nextAddrVal}).
+ Return(nil).Once()
+
+ nextAddr, err := w.GetUnusedAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+
+ // The next unused address should not be the same as the first one.
+ require.NotEqual(t, addr.String(), nextAddr.String())
+
+ // Now, let's test the change address.
+ changeAddrVal, _ := address.NewAddressWitnessPubKeyHash(
+ []byte{
+ 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
+ }, w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", true).
+ Return(changeAddrVal, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{changeAddrVal}).
+ Return(nil).Once()
+
+ changeAddr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, true,
+ )
+ require.NoError(t, err)
+
+ // The first unused change address should be the one we just created.
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(0), nil).Once()
+
+ deps.accountManager.On("ForEachAccountAddress", mock.Anything, uint32(0),
+ mock.Anything).Run(func(args mock.Arguments) {
+ f, ok := args.Get(2).(func(waddrmgr.ManagedAddress) error)
+ require.True(t, ok)
+
+ // First external addr (used).
+ deps.addr.On("Internal").Return(false).Once()
+ _ = f(deps.addr)
+
+ // First internal addr (unused).
+ mockAddr2 := &mockManagedAddress{}
+ mockAddr2.On("Internal").Return(true).Once()
+ mockAddr2.On("Used", mock.Anything).Return(false).Once()
+ mockAddr2.On("Address").Return(changeAddrVal).Once()
+ _ = f(mockAddr2)
+ }).Return(errStopIteration).Once()
+
+ unusedChangeAddr, err := w.GetUnusedAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, true,
+ )
+ require.NoError(t, err)
+ require.Equal(t, changeAddr.String(), unusedChangeAddr.String())
+}
+
+// TestAddressInfo tests the AddressInfo method to ensure it returns correct
+// information for both internal and external addresses.
+func TestAddressInfo(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Get a new external address to test with.
+ var addr address.Address
+
+ addr, _ = address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", false).
+ Return(addr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{addr}).
+ Return(nil).Once()
+
+ extAddr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+
+ // Get the address info for the external address.
+ deps.addrStore.On("Address", mock.Anything, extAddr).
+ Return(deps.addr, nil).Once()
+ deps.addr.On("Address").Return(extAddr).Once()
+ deps.addr.On("Internal").Return(false).Once()
+ deps.addr.On("Compressed").Return(true).Once()
+ deps.addr.On("Imported").Return(false).Once()
+ deps.addr.On("AddrType").Return(waddrmgr.WitnessPubKey).Once()
+
+ extInfo, err := w.AddressInfo(t.Context(), extAddr)
+ require.NoError(t, err)
+
+ // Check the external address info.
+ require.Equal(t, extAddr.String(), extInfo.Address().String())
+ require.False(t, extInfo.Internal())
+ require.True(t, extInfo.Compressed())
+ require.False(t, extInfo.Imported())
+ require.Equal(t, waddrmgr.WitnessPubKey, extInfo.AddrType())
+
+ // Get a new internal address to test with.
+ addr, _ = address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", true).
+ Return(addr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{addr}).
+ Return(nil).Once()
+
+ intAddr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, true,
+ )
+ require.NoError(t, err)
+
+ // Get the address info for the internal address.
+ deps.addrStore.On("Address", mock.Anything, intAddr).
+ Return(deps.addr, nil).Once()
+ deps.addr.On("Address").Return(intAddr).Once()
+ deps.addr.On("Internal").Return(true).Once()
+ deps.addr.On("Compressed").Return(true).Once()
+ deps.addr.On("Imported").Return(false).Once()
+ deps.addr.On("AddrType").Return(waddrmgr.WitnessPubKey).Once()
+
+ intInfo, err := w.AddressInfo(t.Context(), intAddr)
+ require.NoError(t, err)
+
+ // Check the internal address info.
+ require.Equal(t, intAddr.String(), intInfo.Address().String())
+ require.True(t, intInfo.Internal())
+ require.True(t, intInfo.Compressed())
+ require.False(t, intInfo.Imported())
+ require.Equal(t, waddrmgr.WitnessPubKey, intInfo.AddrType())
+}
+
+// TestGetDerivationInfoExternalAddressSuccess tests that we can successfully
+// get the derivation info for an external address.
+func TestGetDerivationInfoExternalAddressSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a new test wallet and a new p2wkh address to test
+ // with.
+ w, deps := createStartedWalletWithMocks(t)
+ mockAddr, _ := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", false).
+ Return(mockAddr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{mockAddr}).
+ Return(nil).Once()
+
+ addr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+
+ // Act: Get the derivation info for the address.
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.pubKeyAddr, nil).Once()
+ deps.pubKeyAddr.On("Imported").Return(false).Once()
+
+ privKey, _ := btcec.NewPrivateKey()
+ pubKey := privKey.PubKey()
+ deps.pubKeyAddr.On("PubKey").Return(pubKey).Once()
+
+ scope := waddrmgr.KeyScopeBIP0084
+ path := waddrmgr.DerivationPath{
+ Account: 0,
+ Branch: 0,
+ Index: 0,
+ MasterKeyFingerprint: 123,
+ }
+ deps.pubKeyAddr.On("DerivationInfo").Return(scope, path, true).Once()
+
+ derivationInfo, err := w.GetDerivationInfo(t.Context(), addr)
+
+ // Assert: Check that the correct derivation info is returned.
+ require.NoError(t, err)
+ require.NotNil(t, derivationInfo)
+
+ expectedPath := []uint32{
+ scope.Purpose + hdkeychain.HardenedKeyStart,
+ scope.Coin + hdkeychain.HardenedKeyStart,
+ path.Account,
+ path.Branch,
+ path.Index,
+ }
+
+ require.Equal(t, pubKey.SerializeCompressed(), derivationInfo.PubKey)
+ require.Equal(t, path.MasterKeyFingerprint,
+ derivationInfo.MasterKeyFingerprint)
+ require.Equal(t, expectedPath, derivationInfo.Bip32Path)
+}
+
+// TestGetDerivationInfoInternalAddressSuccess tests that we can successfully
+// get the derivation info for an internal address.
+func TestGetDerivationInfoInternalAddressSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a new test wallet and a new p2wkh change address to
+ // test with.
+ w, deps := createStartedWalletWithMocks(t)
+ mockAddr, _ := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", true).
+ Return(mockAddr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{mockAddr}).
+ Return(nil).Once()
+
+ addr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, true,
+ )
+ require.NoError(t, err)
+
+ // Act: Get the derivation info for the address.
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.pubKeyAddr, nil).Once()
+ deps.pubKeyAddr.On("Imported").Return(false).Once()
+
+ privKey, _ := btcec.NewPrivateKey()
+ pubKey := privKey.PubKey()
+ deps.pubKeyAddr.On("PubKey").Return(pubKey).Once()
+
+ scope := waddrmgr.KeyScopeBIP0084
+ path := waddrmgr.DerivationPath{
+ Account: 0,
+ Branch: 1,
+ Index: 0,
+ MasterKeyFingerprint: 123,
+ }
+ deps.pubKeyAddr.On("DerivationInfo").Return(scope, path, true).Once()
+
+ derivationInfo, err := w.GetDerivationInfo(t.Context(), addr)
+
+ // Assert: Check that the correct derivation info is returned.
+ require.NoError(t, err)
+ require.NotNil(t, derivationInfo)
+
+ expectedPath := []uint32{
+ scope.Purpose + hdkeychain.HardenedKeyStart,
+ scope.Coin + hdkeychain.HardenedKeyStart,
+ path.Account,
+ path.Branch,
+ path.Index,
+ }
+ require.Equal(t, expectedPath, derivationInfo.Bip32Path)
+ require.Equal(t, uint32(1), path.Branch)
+}
+
+// TestGetDerivationInfoNoDerivationInfo tests that we get an error when trying
+// to get the derivation info for an address that is not in the wallet or is
+// imported.
+func TestGetDerivationInfoNoDerivationInfo(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a new test wallet and a key and address that is not
+ // in the wallet.
+ w, deps := createStartedWalletWithMocks(t)
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ addr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()),
+ w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Act & Assert: Check that we get an error for an address not in the
+ // wallet.
+ deps.addrStore.On("Address", mock.Anything, addr).Return(
+ nil, errDBMock).Once()
+
+ _, err = w.GetDerivationInfo(t.Context(), addr)
+ require.Error(t, err)
+
+ // Arrange: Import the key as a watch-only address.
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("ImportPublicKey", mock.Anything, pubKey,
+ mock.Anything).Return(deps.pubKeyAddr, nil).Once()
+ deps.pubKeyAddr.On("Address").Return(addr).Maybe()
+ deps.chain.On("NotifyReceived", []address.Address{addr}).
+ Return(nil).Once()
+
+ err = w.ImportPublicKey(t.Context(), pubKey, waddrmgr.WitnessPubKey)
+ require.NoError(t, err)
+
+ // Act & Assert: Check that we still get an error because it's an
+ // imported key.
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.pubKeyAddr, nil).Once()
+ deps.pubKeyAddr.On("Imported").Return(true).Once()
+
+ _, err = w.GetDerivationInfo(t.Context(), addr)
+ require.ErrorIs(t, err, ErrDerivationPathNotFound)
+}
+
+// TestListAddresses tests the ListAddresses method to ensure it returns the
+// correct addresses and balances for a given account.
+func TestListAddresses(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Get a new address and give it a balance.
+ mockAddr, _ := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", false).
+ Return(mockAddr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{mockAddr}).
+ Return(nil).Once()
+
+ addr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // We need to create a realistic transaction that has at least one
+ // input.
+ deps.txStore.On("InsertTx", mock.Anything, mock.Anything,
+ mock.Anything).Return(nil).Once()
+ deps.txStore.On("AddCredit", mock.Anything, mock.Anything,
+ mock.Anything, uint32(0), false).Return(nil).Once()
+ deps.addrStore.On("MarkUsed", mock.Anything, addr).Return(nil).Once()
+
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ // Create a new transaction and set the output to the address
+ // we want to mark as used.
+ msgTx := TstTx.MsgTx()
+ msgTx.TxOut = []*wire.TxOut{{
+ PkScript: pkScript,
+ Value: 1000,
+ }}
+
+ rec, err := wtxmgr.NewTxRecordFromMsgTx(msgTx, time.Now())
+ if err != nil {
+ return err
+ }
+
+ err = w.txStore.InsertTx(txmgrNs, rec, nil)
+ if err != nil {
+ return err
+ }
+
+ err = w.txStore.AddCredit(txmgrNs, rec, nil, 0, false)
+ if err != nil {
+ return err
+ }
+
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.MarkUsed(addrmgrNs, addr)
+ })
+ require.NoError(t, err)
+
+ // List the addresses for the default account.
+ deps.txStore.On("UnspentOutputs", mock.Anything).Return([]wtxmgr.Credit{
+ {
+ Amount: 1000,
+ PkScript: pkScript,
+ },
+ }, nil).Once()
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+
+ deps.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(0), nil).Once()
+
+ deps.accountManager.On("ForEachAccountAddress", mock.Anything, uint32(0),
+ mock.Anything).Run(func(args mock.Arguments) {
+ f, ok := args.Get(2).(func(waddrmgr.ManagedAddress) error)
+ require.True(t, ok)
+ deps.addr.On("Address").Return(addr).Once()
+ _ = f(deps.addr)
+ }).Return(nil).Once()
+
+ addrs, err := w.ListAddresses(
+ t.Context(), "default", waddrmgr.WitnessPubKey,
+ )
+ require.NoError(t, err)
+
+ // We should have one address with a balance of 1000.
+ require.Len(t, addrs, 1)
+ require.Equal(t, addr.String(), addrs[0].Address.String())
+ require.Equal(t, btcutil.Amount(1000), addrs[0].Balance)
+}
+
+// TestImportPublicKey tests the ImportPublicKey method to ensure it can
+// import a public key as a watch-only address.
+func TestImportPublicKey(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Create a new public key to import.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ // Import the public key.
+ addr, _ := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()),
+ w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("ImportPublicKey", mock.Anything, pubKey,
+ mock.Anything).Return(deps.pubKeyAddr, nil).Once()
+ deps.pubKeyAddr.On("Address").Return(addr).Once()
+ deps.chain.On("NotifyReceived", []address.Address{addr}).
+ Return(nil).Once()
+
+ err = w.ImportPublicKey(t.Context(), pubKey, waddrmgr.WitnessPubKey)
+ require.NoError(t, err)
+
+ // Check that the address is now managed by the wallet.
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.pubKeyAddr, nil).Once()
+
+ info, err := w.AddressInfo(t.Context(), addr)
+ require.NoError(t, err)
+ require.NotNil(t, info)
+}
+
+// TestImportTaprootScript tests the ImportTaprootScript method to ensure it can
+// import a taproot script as a watch-only address.
+func TestImportTaprootScript(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Create a new tapscript to import.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ script, err := txscript.NewScriptBuilder().
+ AddData(pubKey.SerializeCompressed()).
+ AddOp(txscript.OP_CHECKSIG).
+ Script()
+ require.NoError(t, err)
+
+ leaf := txscript.NewTapLeaf(txscript.BaseLeafVersion, script)
+ tree := txscript.AssembleTaprootScriptTree(leaf)
+ rootHash := tree.RootNode.TapHash()
+ tapscript := waddrmgr.Tapscript{
+ Type: waddrmgr.TapscriptTypeFullTree,
+ ControlBlock: &txscript.ControlBlock{
+ InternalKey: pubKey,
+ },
+ Leaves: []txscript.TapLeaf{leaf},
+ }
+
+ // Import the tapscript.
+ addr, _ := address.NewAddressTaproot(
+ schnorr.SerializePubKey(txscript.ComputeTaprootOutputKey(
+ pubKey, rootHash[:],
+ )), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", waddrmgr.KeyScopeBIP0086).
+ Return(deps.accountManager, nil).Once()
+
+ // SyncedTo is mocked in createStartedWalletWithMocks (height 1).
+ deps.accountManager.On("ImportTaprootScript", mock.Anything,
+ mock.Anything, mock.Anything, uint8(1), false).
+ Return(deps.taprootAddr, nil).Once()
+ deps.taprootAddr.On("Address").Return(addr).Once()
+ deps.chain.On("NotifyReceived", []address.Address{addr}).
+ Return(nil).Once()
+
+ _, err = w.ImportTaprootScript(t.Context(), tapscript)
+ require.NoError(t, err)
+
+ // Check that the address is now managed by the wallet.
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.taprootAddr, nil).Once()
+
+ info, err := w.AddressInfo(t.Context(), addr)
+ require.NoError(t, err)
+ require.NotNil(t, info)
+}
+
+// TestScriptForOutput tests the ScriptForOutput method to ensure it returns the
+// correct script for a given output.
+func TestScriptForOutput(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Create a new p2wkh address and output.
+ mockAddr, _ := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+
+ deps.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(deps.accountManager, nil).Once()
+ deps.accountManager.On("NewAddress", mock.Anything, "default", false).
+ Return(mockAddr, nil).Once()
+ deps.chain.On("NotifyReceived", []address.Address{mockAddr}).
+ Return(nil).Once()
+
+ addr, err := w.NewAddress(
+ t.Context(), "default", waddrmgr.WitnessPubKey, false,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ output := wire.TxOut{
+ Value: 1000,
+ PkScript: pkScript,
+ }
+
+ // Get the script for the output.
+ deps.addrStore.On("Address", mock.Anything, addr).
+ Return(deps.pubKeyAddr, nil).Once()
+ deps.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey).Once()
+
+ script, err := w.ScriptForOutput(t.Context(), output)
+ require.NoError(t, err)
+
+ // Check that the script is correct.
+ require.Equal(t, pkScript, script.WitnessProgram)
+ require.Nil(t, script.RedeemScript)
+}
diff --git a/wallet/benchmark_helpers_test.go b/wallet/benchmark_helpers_test.go
new file mode 100644
index 0000000000..2869484c35
--- /dev/null
+++ b/wallet/benchmark_helpers_test.go
@@ -0,0 +1,1129 @@
+package wallet
+
+import (
+ "errors"
+ "fmt"
+ "math/rand"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/require"
+)
+
+var errAccountNotFound = errors.New("account not found")
+
+// growthFunc defines how a benchmark parameter should scale with iteration
+// index. It takes an iteration index i (0-based) and returns the parameter
+// value for that iteration. This allows flexible configuration of benchmark
+// data sizes with different growth patterns (linear, exponential, logarithmic,
+// etc.).
+type growthFunc func(i int) int
+
+// constantGrowth returns a constant value regardless of iteration.
+//
+// Use when: The parameter is a control variable not under test and should
+// remain fixed across all iterations.
+//
+// Example: accountGrowth when testing transaction complexity (not account
+// scaling).
+//
+// Note: Ideal for CI as it produces predictable, stable results for regression
+// detection.
+//
+// Result: 5, 5, 5, 5, 5...
+func constantGrowth(i int) int {
+ return 5
+}
+
+// linearGrowth scales the parameter value linearly with arithmetic progression.
+//
+// Use when: Testing gradual scaling behavior, O(n) or O(n²) algorithms, or
+// when detailed granularity is needed across a moderate range.
+//
+// Example: Transaction I/O counts, address counts, or database record counts
+// where you want to see how performance degrades proportionally.
+//
+// Note: Safe for CI when used with limited range (e.g., i = 0..9 yields 5..50)
+// for regression detection. Avoid for O(log n) algorithms as x grows linearly
+// while y grows logarithmically, making regressions harder to detect.
+//
+// Result: 5, 10, 15, 20, 25, 30, 35...
+func linearGrowth(i int) int {
+ return 5 + (i * 5)
+}
+
+// exponentialGrowth scales the parameter value exponentially (powers of 2).
+//
+// Use when: Stress testing scalability limits, testing concurrency levels, or
+// quickly covering a wide range from small to large values. Works well for
+// algorithms with O(log n) complexity as it creates a linear relationship when
+// plotted (e.g., y = log₂(x) when x grows exponentially, y grows linearly).
+//
+// Example: Concurrent worker counts, cache sizes, or finding performance
+// breaking points.
+//
+// Note: Avoid running in CI due to large values and long execution times. Use
+// for local performance analysis only.
+//
+// Result: 1, 2, 4, 8, 16, 32, 64, 128, 256...
+func exponentialGrowth(i int) int {
+ return 1 << i
+}
+
+// InterleavePattern defines how to interleave two slices of any type.
+type InterleavePattern int
+
+const (
+
+ // Alternating pattern interleaves elements one-by-one from each slice.
+ //
+ // Use when: Testing iteration logic that needs to handle mixed element
+ // types or simulating varied element ordering.
+ //
+ // Pattern: A B A B A B ...
+ //
+ // Example: Testing APIs that filter or group elements by type during
+ // iteration.
+ Alternating InterleavePattern = iota
+
+ // Sequential pattern concatenates all elements from first slice
+ // followed by all elements from second slice.
+ //
+ // Use when: Testing best/worst case scenarios where elements are
+ // perfectly sorted by type, or when element order doesn't affect the
+ // algorithm being tested.
+ //
+ // Pattern: A ... B ...
+ Sequential
+
+ // Grouped pattern interleaves elements in batches of specified size.
+ //
+ // Use when: Testing batch processing scenarios or simulating elements
+ // that arrive in groups.
+ //
+ // Pattern (groupSize=2): A B A B ...
+ //
+ // Example: Testing batch transaction processing or cache locality
+ // effects.
+ Grouped
+
+ // Random pattern shuffles elements from both slices randomly.
+ //
+ // Use when: Testing performance with unpredictable ordering or
+ // verifying that algorithms handle arbitrary element arrangements.
+ //
+ // Pattern: B A B A ...
+ //
+ // Note: Uses simple pseudo-random shuffle, not cryptographically
+ // secure.
+ Random
+)
+
+// Interleave combines two slices according to the specified pattern.
+func Interleave[T any](pattern InterleavePattern, groupSize int, a, b []T) []T {
+ switch pattern {
+ case Sequential:
+ return sequentialInterleave(a, b)
+ case Alternating:
+ return alternatingInterleave(a, b)
+ case Grouped:
+ return groupedInterleave(groupSize, a, b)
+ case Random:
+ return randomInterleave(a, b)
+ default:
+ return alternatingInterleave(a, b)
+ }
+}
+
+// sequentialInterleave concatenates all elements from a followed by all
+// elements from b.
+func sequentialInterleave[T any](a, b []T) []T {
+ result := make([]T, 0, len(a)+len(b))
+ result = append(result, a...)
+ result = append(result, b...)
+
+ return result
+}
+
+// alternatingInterleave interleaves elements one-by-one from a and b.
+func alternatingInterleave[T any](a, b []T) []T {
+ result := make([]T, 0, len(a)+len(b))
+ aIdx, bIdx := 0, 0
+
+ for aIdx < len(a) || bIdx < len(b) {
+ if aIdx < len(a) {
+ result = append(result, a[aIdx])
+ aIdx++
+ }
+
+ if bIdx < len(b) {
+ result = append(result, b[bIdx])
+ bIdx++
+ }
+ }
+
+ return result
+}
+
+// groupedInterleave interleaves elements in batches of groupSize from a and b.
+func groupedInterleave[T any](groupSize int, a, b []T) []T {
+ if groupSize <= 0 {
+ groupSize = 1
+ }
+
+ result := make([]T, 0, len(a)+len(b))
+ aIdx, bIdx := 0, 0
+
+ for aIdx < len(a) || bIdx < len(b) {
+ // Add batch from a.
+ for i := 0; i < groupSize && aIdx < len(a); i++ {
+ result = append(result, a[aIdx])
+ aIdx++
+ }
+
+ // Add batch from b.
+ for i := 0; i < groupSize && bIdx < len(b); i++ {
+ result = append(result, b[bIdx])
+ bIdx++
+ }
+ }
+
+ return result
+}
+
+// randomInterleave combines elements from a and b in pseudo-random order.
+func randomInterleave[T any](a, b []T) []T {
+ result := make([]T, 0, len(a)+len(b))
+ result = append(result, a...)
+ result = append(result, b...)
+
+ rand.Shuffle(len(result), func(i, j int) {
+ result[i], result[j] = result[j], result[i]
+ })
+
+ return result
+}
+
+// mapRange maps fn over indices [start..end] (inclusive) and returns the
+// results. This provides functional-style array generation for benchmarks.
+//
+//nolint:unparam // Different benchmarks may intentionally use different values
+func mapRange(start, end int, fn growthFunc) []int {
+ result := make([]int, end-start+1)
+ for i := range result {
+ result[i] = fn(start + i)
+ }
+
+ return result
+}
+
+// decimalWidth returns the number of characters in the decimal representation
+// of given value.
+func decimalWidth(value int) int {
+ return len(strconv.Itoa(value))
+}
+
+// benchmarkWalletConfig holds configuration for benchmark wallet setup.
+type benchmarkWalletConfig struct {
+ // scopes is the key scopes to create accounts in.
+ scopes []waddrmgr.KeyScope
+
+ // numAccounts is the number of accounts to create.
+ numAccounts int
+
+ // numWalletTxs is the number of wallet transactions to create.
+ numWalletTxs int
+
+ // numAddresses is the number of addresses to create.
+ numAddresses int
+
+ // numTxInputs is the number of inputs per transaction. If 0, defaults
+ // to 1 input per transaction.
+ numTxInputs int
+
+ // numTxOutputs is the number of outputs per transaction. If 0,
+ // defaults to 1 output per transaction.
+ numTxOutputs int
+
+ // txInterleavePattern specifies the interleaving pattern for organizing
+ // transactions with different states (confirmed vs unconfirmed).
+ txInterleavePattern InterleavePattern
+}
+
+// benchmarkWallet holds a wallet and its created wallet transactions.
+type benchmarkWallet struct {
+ *Wallet
+
+ // confirmedTxs contains confirmed wallet transactions created during
+ // benchmark setup. These are spending transactions with both debits
+ // (inputs) and credits (outputs) that have been mined in blocks.
+ confirmedTxs []*wire.MsgTx
+
+ // unconfirmedTxs contains unconfirmed wallet transactions created
+ // during benchmark setup. These are spending transactions with both
+ // debits (inputs) and credits (outputs) that are in the mempool.
+ unconfirmedTxs []*wire.MsgTx
+
+ // allTxs contains all wallet transactions (both confirmed and
+ // unconfirmed) combined.
+ allTxs []*wire.MsgTx
+}
+
+// setupBenchmarkWallet creates a wallet with test data based on the provided
+// configuration. It distributes accounts evenly across the specified scopes
+// and returns the wallet along with the outpoints of all created UTXOs. If
+// config.miner is provided, the wallet is connected to the btcd node via RPC.
+func setupBenchmarkWallet(tb testing.TB,
+ cfg benchmarkWalletConfig) *benchmarkWallet {
+
+ tb.Helper()
+
+ // Since testWallet requires a *testing.T, we can't pass the benchmark's
+ // *testing.B. Instead, we create a setup *testing.T and manually fail
+ // the benchmark if the setup fails.
+ setupT := &testing.T{}
+ w := testWallet(setupT)
+ require.False(tb, setupT.Failed(), "testWallet setup failed")
+
+ // Backfill Config and State for benchmarks comparing new APIs.
+ // Legacy testWallet does not populate these.
+ if w.cfg.DB == nil {
+ w.cfg = Config{
+ DB: w.db,
+ ChainParams: w.chainParams,
+ Chain: w.chainClient,
+ }
+ }
+
+ if w.sync == nil {
+ w.sync = newSyncer(w.cfg, w.addrStore, w.txStore, w)
+ }
+
+ // Initialize controller channels and timer.
+ if w.requestChan == nil {
+ w.requestChan = make(chan any)
+ }
+
+ if w.lockTimer == nil {
+ w.lockTimer = time.NewTimer(0)
+ if !w.lockTimer.Stop() {
+ <-w.lockTimer.C
+ }
+ }
+
+ // Force state to Started to satisfy validateStarted() in new APIs.
+ err := w.state.toStarting()
+ if err == nil {
+ err = w.state.toStarted()
+ require.NoError(tb, err)
+ }
+ // Transition to Unlocked so signing operations are permitted.
+ w.state.toUnlocked()
+
+ addresses := createTestAccounts(
+ tb, w, cfg.scopes, cfg.numAccounts, cfg.numAddresses,
+ )
+
+ var txsResult *testWalletTxsResult
+ if cfg.numWalletTxs > 0 {
+ txsResult = createTestWalletTxs(
+ tb, w, addresses, cfg.numWalletTxs, cfg.numTxInputs,
+ cfg.numTxOutputs,
+ )
+ } else {
+ // Return empty result if no transactions requested.
+ txsResult = &testWalletTxsResult{
+ confirmed: []*wire.MsgTx{},
+ unconfirmed: []*wire.MsgTx{},
+ }
+ }
+
+ // Combine confirmed and unconfirmed transactions using the specified
+ // pattern. This ensures diverse transaction states in allTxs even with
+ // small dataset sizes (e.g., when using constantGrowth which defaults
+ // to 5 elements), allowing benchmarks to test iteration logic that
+ // handles mixed confirmation states.
+ allTxs := Interleave(
+ cfg.txInterleavePattern, 0, txsResult.confirmed,
+ txsResult.unconfirmed,
+ )
+
+ return &benchmarkWallet{
+ Wallet: w,
+ confirmedTxs: txsResult.confirmed,
+ unconfirmedTxs: txsResult.unconfirmed,
+ allTxs: allTxs,
+ }
+}
+
+// setSyncedToHeight updates the wallet's synced block height. This is useful
+// for benchmark tests to ensure confirmation calculations work correctly.
+func setSyncedToHeight(tb testing.TB, w *Wallet, height int32,
+ hash chainhash.Hash) {
+
+ tb.Helper()
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.SetSyncedTo(addrmgrNs, &waddrmgr.BlockStamp{
+ Height: height,
+ Hash: hash,
+ })
+ })
+ require.NoError(tb, err, "failed to set synced height to %d", height)
+}
+
+// createTestAccounts creates test accounts across the specified key scopes
+// and returns all generated addresses.
+func createTestAccounts(tb testing.TB, w *Wallet, scopes []waddrmgr.KeyScope,
+ numAccounts, numAddresses int) []waddrmgr.ManagedAddress {
+
+ tb.Helper()
+
+ var allAddresses []waddrmgr.ManagedAddress
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ // Distribute accounts across the specified key scopes.
+ accountsPerScope := numAccounts / len(scopes)
+ remainder := numAccounts % len(scopes)
+
+ for i, scope := range scopes {
+ scopeAccounts := accountsPerScope
+ if i < remainder {
+ // Distribute remainder accounts.
+ scopeAccounts++
+ }
+
+ err := createAccountsInScope(
+ w, tx, scope, scopeAccounts, numAddresses,
+ i*accountsPerScope, &allAddresses,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+
+ require.NoError(tb, err, "failed to create test accounts: %v", err)
+
+ return allAddresses
+}
+
+// createAccountsInScope creates accounts within a specific scope with unique
+// naming across scopes.
+func createAccountsInScope(w *Wallet, tx walletdb.ReadWriteTx,
+ scope waddrmgr.KeyScope, numAccounts, numAddresses, offset int,
+ allAddresses *[]waddrmgr.ManagedAddress) error {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return err
+ }
+
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ for i := range numAccounts {
+ name := fmt.Sprintf("bench-scope-%d-%d-account-%d",
+ scope.Purpose, scope.Coin, offset+i)
+
+ account, err := manager.NewAccount(addrmgrNs, name)
+ if err != nil {
+ return err
+ }
+
+ addrs, err := manager.NextExternalAddresses(
+ addrmgrNs, account, uint32(numAddresses),
+ )
+ if err != nil {
+ return err
+ }
+
+ *allAddresses = append(*allAddresses, addrs...)
+ }
+
+ return nil
+}
+
+// testWalletTxsResult holds the result of creating test wallet transactions.
+type testWalletTxsResult struct {
+ // confirmed contains confirmed spending transactions.
+ confirmed []*wire.MsgTx
+
+ // unconfirmed contains unconfirmed spending transactions.
+ unconfirmed []*wire.MsgTx
+
+ // highestBlockMeta is the metadata for the highest block containing
+ // confirmed transactions.
+ highestBlockMeta wtxmgr.BlockMeta
+}
+
+// createTestWalletTxs creates diverse test wallet transactions with both
+// confirmed and unconfirmed transaction history. The goal is diversity for more
+// comprehensive benchmark testing. The function creates four passes of
+// transactions:
+// 1. Initial confirmed UTXOs for confirmed spending txs (credits only)
+// 2. Confirmed spending transactions (debits + credits, mined in blocks)
+// 3. Initial confirmed UTXOs for unconfirmed spending txs (credits only)
+// 4. Unconfirmed spending transactions (debits + credits, unmined/mempool)
+//
+// Each set of spending transactions uses separate UTXOs to avoid double-spend
+// conflicts. numInputs and numOutputs control transaction complexity. Returns
+// both confirmed and unconfirmed spending transactions.
+func createTestWalletTxs(tb testing.TB, w *Wallet,
+ addresses []waddrmgr.ManagedAddress, numTxs,
+ numInputs, numOutputs int) *testWalletTxsResult {
+
+ tb.Helper()
+
+ var (
+ txsConfirmed []*wire.MsgTx
+ txsUnconfirmed []*wire.MsgTx
+ highestBlockMeta wtxmgr.BlockMeta
+ )
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ msgTx := TstTx.MsgTx()
+
+ var (
+ initialCreditsConfirmed []*wire.MsgTx
+ prevOutpointsConfirmed []wire.OutPoint
+ )
+
+ var (
+ baseBlockHeight int32 = 1
+ mined = true
+ )
+
+ // First pass: Create initial UTXOs (credits only, no debits).
+ initialCreditsConfirmed, highestBlockMeta = createTxBatch(
+ tb, w, txmgrNs, addrmgrNs, addresses, numTxs,
+ msgTx.Version, baseBlockHeight, 200000, nil, 0,
+ numOutputs, mined,
+ )
+
+ prevOutpointsConfirmed = txsToOutpoints(
+ initialCreditsConfirmed,
+ )
+
+ // Second pass: Create confirmed spending transactions
+ baseBlockHeight = highestBlockMeta.Height + 1
+ txsConfirmed, highestBlockMeta = createTxBatch(
+ tb, w, txmgrNs, addrmgrNs, addresses, numTxs,
+ msgTx.Version, baseBlockHeight, 100000,
+ prevOutpointsConfirmed, numInputs, numOutputs, mined,
+ )
+
+ // Third pass: Create initial UTXOs for unconfirmed spending
+ // txs.
+ baseBlockHeight = highestBlockMeta.Height + 1
+ initialCreditsConfirmed, highestBlockMeta = createTxBatch(
+ tb, w, txmgrNs, addrmgrNs, addresses, numTxs,
+ msgTx.Version, baseBlockHeight, 200000, nil, 0,
+ numOutputs, mined,
+ )
+
+ prevOutpointsConfirmed = txsToOutpoints(initialCreditsConfirmed)
+
+ // Fourth pass: Create unconfirmed spending transactions.
+ baseBlockHeight = -1
+ mined = false
+ txsUnconfirmed, _ = createTxBatch(
+ tb, w, txmgrNs, addrmgrNs, addresses, numTxs,
+ msgTx.Version, baseBlockHeight, 110000,
+ prevOutpointsConfirmed, numInputs, numOutputs, mined,
+ )
+
+ return nil
+ })
+
+ require.NoError(tb, err, "failed to create test wallet txs: %v", err)
+
+ // Sync wallet to the highest block containing confirmed transactions.
+ setSyncedToHeight(
+ tb, w, highestBlockMeta.Height,
+ highestBlockMeta.Hash,
+ )
+
+ return &testWalletTxsResult{
+ confirmed: txsConfirmed,
+ unconfirmed: txsUnconfirmed,
+ highestBlockMeta: highestBlockMeta,
+ }
+}
+
+// txsToOutpoints converts all transaction outputs to outpoints. For X txs with
+// Y outputs per tx outputs, returns X*Y outpoints.
+func txsToOutpoints(txs []*wire.MsgTx) []wire.OutPoint {
+ var outpoints []wire.OutPoint
+ for _, tx := range txs {
+ txHash := tx.TxHash()
+ for j := range tx.TxOut {
+ outpoints = append(
+ outpoints, wire.OutPoint{
+ Hash: txHash,
+ Index: uint32(j),
+ },
+ )
+ }
+ }
+
+ return outpoints
+}
+
+// createTxBatch is a helper that creates a batch of transactions.
+// If prevOutpoints is nil, creates receiving transactions (credits only).
+// If prevOutpoints is provided, creates spending transactions
+// (debits + credits). If mined is true, each transaction is placed in its own
+// block (blockHeight + i). If mined is false, transactions are unmined
+// (unconfirmed). numInputs and numOutputs control transaction complexity; if 0,
+// defaults to 1 input and 1 output per transaction. Returns the created
+// transactions and the block metadata for the highest block (only meaningful if
+// mined is true).
+func createTxBatch(tb testing.TB, w *Wallet, txmgrNs,
+ addrmgrNs walletdb.ReadWriteBucket, addresses []waddrmgr.ManagedAddress,
+ count int, txVersion int32, startBlockHeight int32, baseAmount int64,
+ prevOutpoints []wire.OutPoint, numInputs, numOutputs int,
+ mined bool) ([]*wire.MsgTx, wtxmgr.BlockMeta) {
+
+ tb.Helper()
+
+ // Default to 1 input and 1 output if not specified.
+ if numInputs == 0 {
+ numInputs = 1
+ }
+
+ if numOutputs == 0 {
+ numOutputs = 1
+ }
+
+ var (
+ transactions []*wire.MsgTx
+ lastBlockMeta wtxmgr.BlockMeta
+ )
+
+ for i := 0; i < count && i < len(addresses); i++ {
+ var blockMeta *wtxmgr.BlockMeta
+ if mined {
+ // Each transaction goes in its own block with unique
+ // hash.
+ blockHash := chainhash.Hash{}
+ blockHash[0] = byte(startBlockHeight + int32(i))
+ blockHash[1] = byte((startBlockHeight + int32(i)) >> 8)
+
+ blockMeta = &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Hash: blockHash,
+ Height: startBlockHeight + int32(i),
+ },
+ Time: time.Now(),
+ }
+ lastBlockMeta = *blockMeta
+ }
+
+ tx := buildTxForBatch(
+ tb, addresses, txVersion, i, baseAmount,
+ prevOutpoints, numInputs, numOutputs,
+ )
+
+ rec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ require.NoError(tb, err)
+
+ err = w.txStore.InsertTx(txmgrNs, rec, blockMeta)
+ require.NoError(tb, err)
+
+ // Add credits for all outputs belonging to our wallet.
+ for j := range numOutputs {
+ err = w.txStore.AddCredit(
+ txmgrNs, rec, blockMeta, uint32(j), false,
+ )
+ require.NoError(tb, err)
+ }
+
+ // Mark all addresses as used.
+ for j := range numOutputs {
+ addr := addresses[(i+j)%len(addresses)]
+ err = w.addrStore.MarkUsed(addrmgrNs, addr.Address())
+ require.NoError(tb, err)
+ }
+
+ transactions = append(transactions, tx)
+ }
+
+ return transactions, lastBlockMeta
+}
+
+// buildTxForBatch creates a single transaction with the specified inputs and
+// outputs.
+func buildTxForBatch(tb testing.TB, addresses []waddrmgr.ManagedAddress,
+ txVersion int32, i int, baseAmount int64, prevOutpoints []wire.OutPoint,
+ numInputs, numOutputs int) *wire.MsgTx {
+
+ tb.Helper()
+
+ tx := wire.NewMsgTx(txVersion)
+
+ // Add multiple outputs to our wallet (creates credits).
+ for j := range numOutputs {
+ addr := addresses[(i+j)%len(addresses)]
+ pkScript, err := txscript.PayToAddrScript(addr.Address())
+ require.NoError(tb, err)
+
+ // Add random jitter based on timestamp to ensure unique
+ // transaction hashes across benchmark runs, preventing
+ // duplicate transaction errors when the same test data is
+ // created multiple times. This is necessary for
+ // representative benchmarking.
+ randomJitter := time.Now().UnixNano() % 1000
+ amount := btcutil.Amount(
+ baseAmount + int64(i*1000+j*100) + randomJitter,
+ )
+ txOut := wire.NewTxOut(int64(amount), pkScript)
+ tx.AddTxOut(txOut)
+ }
+
+ // Add multiple inputs - either external or from our wallet.
+ for j := range numInputs {
+ outpointIdx := i*numInputs + j
+ if prevOutpoints != nil && outpointIdx < len(prevOutpoints) {
+ // Spend from our previous UTXO (creates debit).
+ txIn := wire.NewTxIn(
+ &prevOutpoints[outpointIdx], nil, nil,
+ )
+ tx.AddTxIn(txIn)
+ } else {
+ // External input (no debit). Needed for tx to be
+ // syntactically valid.
+ prevHash := chainhash.Hash{}
+ prevHash[0] = byte(i)
+ prevHash[1] = byte(j)
+ txIn := wire.NewTxIn(
+ wire.NewOutPoint(&prevHash, uint32(j)), nil,
+ nil,
+ )
+ tx.AddTxIn(txIn)
+ }
+ }
+
+ return tx
+}
+
+// generateAccountName generates a consistent account name and number for
+// benchmarking based on the given number of accounts and scopes. It returns
+// the first account name and number in the last scope, which provides a good
+// heuristic case for evaluating search performance.
+func generateAccountName(numAccounts int,
+ scopes []waddrmgr.KeyScope) (string, uint32) {
+
+ accountsPerScope := numAccounts / len(scopes)
+
+ lastScopeIndex := len(scopes) - 1
+ lastScope := scopes[lastScopeIndex]
+ lastScopeOffset := lastScopeIndex * accountsPerScope
+
+ accountName := fmt.Sprintf("bench-scope-%d-%d-account-%d",
+ lastScope.Purpose, lastScope.Coin, lastScopeOffset)
+
+ // Account numbers start from 1, not 0. Account 0 is reserved for
+ // "default".
+ accountNumber := uint32(lastScopeOffset + 1)
+
+ return accountName, accountNumber
+}
+
+// generateTestExtendedKey generates a test extended public key for benchmarking
+// ImportAccount operations. It uses a deterministic seed based on the
+// seed index to ensure consistent and unique results across benchmark runs.
+func generateTestExtendedKey(tb testing.TB,
+ seedIndex int) (*hdkeychain.ExtendedKey, uint32, waddrmgr.AddressType) {
+
+ tb.Helper()
+
+ // Use a simple deterministic seed based on seed index.
+ seed := make([]byte, 32)
+ for j := range seed {
+ seed[j] = byte(seedIndex + j)
+ }
+
+ // Create master key from seed.
+ masterKey, err := hdkeychain.NewMaster(seed, &chaincfg.TestNet3Params)
+ require.NoError(tb, err)
+
+ // Derive account key for BIP0084 (m/84'/1'/seedIndex').
+ purpose, err := masterKey.Derive(hdkeychain.HardenedKeyStart + 84)
+ require.NoError(tb, err)
+
+ coin, err := purpose.Derive(hdkeychain.HardenedKeyStart + 1)
+ require.NoError(tb, err)
+
+ account, err := coin.Derive(
+ hdkeychain.HardenedKeyStart + uint32(seedIndex),
+ )
+ require.NoError(tb, err)
+
+ accountPubKey, err := account.Neuter()
+ require.NoError(tb, err)
+
+ return accountPubKey, uint32(seedIndex), waddrmgr.WitnessPubKey
+}
+
+// getMedianTestAddress returns a median address from a median account for
+// benchmarking purposes.
+func getTestAddress(tb testing.TB, w *Wallet, numAccounts int) address.Address {
+ tb.Helper()
+
+ medianAccount := uint32(numAccounts / 2)
+ addresses, err := w.AccountAddresses(medianAccount)
+ require.NoError(tb, err)
+
+ return addresses[len(addresses)/2]
+}
+
+// markAddressAsUsed marks an address as used in the wallet database. This is
+// useful for making benchmark iterations idempotent.
+func markAddressAsUsed(b *testing.B, w *Wallet, addr address.Address) {
+ b.Helper()
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ manager, err := w.addrStore.FetchScopedKeyManager(
+ waddrmgr.KeyScopeBIP0044,
+ )
+ if err != nil {
+ return err
+ }
+
+ return manager.MarkUsed(addrmgrNs, addr)
+ })
+ require.NoError(b, err)
+}
+
+// getTestUtxoOutpoint returns a median UTXO outpoint from the provided list
+// for benchmarking purposes. It returns the outpoint from the middle of the
+// list to provide a representative test case.
+func getTestUtxoOutpoint(outpoints []wire.OutPoint) wire.OutPoint {
+ medianIndex := len(outpoints) / 2
+ return outpoints[medianIndex]
+}
+
+// generateTestTapscript generates a test tapscript for benchmarking purposes.
+// It creates a simple script that checks a signature against the provided
+// public key, wraps it in a tap leaf, and returns a complete Tapscript
+// structure ready for import.
+func generateTestTapscript(tb testing.TB,
+ pubKey *btcec.PublicKey) waddrmgr.Tapscript {
+
+ tb.Helper()
+
+ script, err := txscript.NewScriptBuilder().
+ AddData(pubKey.SerializeCompressed()).
+ AddOp(txscript.OP_CHECKSIG).
+ Script()
+ require.NoError(tb, err)
+
+ leaf := txscript.NewTapLeaf(txscript.BaseLeafVersion, script)
+
+ return waddrmgr.Tapscript{
+ Type: waddrmgr.TapscriptTypeFullTree,
+ ControlBlock: &txscript.ControlBlock{
+ InternalKey: pubKey,
+ },
+ Leaves: []txscript.TapLeaf{leaf},
+ }
+}
+
+// generateTestTxOut generates a test TxOut for benchmarking purposes.
+// It creates a TxOut with the provided address as the PkScript.
+func generateTestTxOut(tb testing.TB, addr address.Address) wire.TxOut {
+ tb.Helper()
+
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(tb, err)
+
+ return wire.TxOut{
+ Value: 1e8,
+ PkScript: pkScript,
+ }
+}
+
+// leaseAllOutputs leases all outputs in the wallet with unique lock IDs. This
+// is used to set up benchmarks for ListLeasedOutputs where we want to maximize
+// the N+1 query impact when comparing the new vs deprecated ListLeasedOutputs
+// APIs.
+func leaseAllOutputs(tb testing.TB, w *Wallet, outpoints []wire.OutPoint,
+ duration time.Duration) {
+
+ tb.Helper()
+
+ for i, outpoint := range outpoints {
+ lockID := wtxmgr.LockID{byte(i)}
+ _, err := w.LeaseOutput(
+ tb.Context(), lockID, outpoint, duration,
+ )
+ require.NoError(tb, err, "failed to lease output %v", outpoint)
+ }
+}
+
+// signMultipleInputs is a helper function that signs multiple transaction
+// inputs using ComputeUnlockingScript. This is useful for benchmarks that test
+// multi-input transaction signing performance.
+//
+// IMPORTANT: sigHashes should be pre-computed outside the benchmark loop to
+// avoid measuring setup time. The caller should create a single sigHashes
+// instance using all prevOuts before the benchmark loop, then pass it here.
+func signMultipleInputs(tb testing.TB, w *Wallet, tx *wire.MsgTx,
+ prevOuts []*wire.TxOut, sigHashes *txscript.TxSigHashes,
+ hashType txscript.SigHashType) {
+
+ tb.Helper()
+
+ signMultipleInputsWithTweaker(
+ tb, w, tx, prevOuts, sigHashes, hashType, nil,
+ )
+}
+
+// signMultipleInputsWithTweaker is a helper function that signs multiple
+// transaction inputs using ComputeUnlockingScript with an optional tweaker
+// function. This is useful for benchmarks that test multi-input transaction
+// signing performance with custom key tweaking.
+//
+// IMPORTANT: sigHashes should be pre-computed outside the benchmark loop to
+// avoid measuring setup time. The caller should create a single sigHashes
+// instance using all prevOuts before the benchmark loop, then pass it here.
+func signMultipleInputsWithTweaker(tb testing.TB, w *Wallet, tx *wire.MsgTx,
+ prevOuts []*wire.TxOut, sigHashes *txscript.TxSigHashes,
+ hashType txscript.SigHashType, tweaker PrivKeyTweaker) {
+
+ tb.Helper()
+
+ for j := range prevOuts {
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: j,
+ Output: prevOuts[j],
+ SigHashes: sigHashes,
+ HashType: hashType,
+ Tweaker: tweaker,
+ }
+
+ unlockingScript, err := w.ComputeUnlockingScript(
+ tb.Context(), params,
+ )
+ require.NoError(tb, err)
+ require.NotNil(tb, unlockingScript)
+ }
+}
+
+// listAccountsDeprecated wraps the deprecated Accounts API to satisfy the same
+// contract as ListAccounts by calling Accounts API across all active key scopes
+// and aggregating the results.
+func listAccountsDeprecated(w *Wallet) (*AccountsResult, error) {
+ var (
+ allAccounts []AccountResult
+ finalBlockHash chainhash.Hash
+ finalBlockHeight int32
+ scopeManagers = w.addrStore.ActiveScopedKeyManagers()
+ )
+
+ for _, scopeMgr := range scopeManagers {
+ scope := scopeMgr.Scope()
+
+ result, err := w.Accounts(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ allAccounts = append(allAccounts, result.Accounts...)
+
+ finalBlockHash = result.CurrentBlockHash
+ finalBlockHeight = result.CurrentBlockHeight
+ }
+
+ return &AccountsResult{
+ Accounts: allAccounts,
+ CurrentBlockHash: finalBlockHash,
+ CurrentBlockHeight: finalBlockHeight,
+ }, nil
+}
+
+// listAccountsByNameDeprecated wraps the deprecated Accounts API to satisfy the
+// same contract as ListAccountsByName by calling Accounts API across all active
+// key scopes, filtering by account name, and aggregating the results.
+func listAccountsByNameDeprecated(w *Wallet,
+ name string) (*AccountsResult, error) {
+
+ var (
+ matchingAccounts []AccountResult
+ finalBlockHash chainhash.Hash
+ finalBlockHeight int32
+ scopeManagers = w.addrStore.ActiveScopedKeyManagers()
+ )
+
+ for _, scopeMgr := range scopeManagers {
+ scope := scopeMgr.Scope()
+
+ result, err := w.Accounts(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ // Filter accounts by name from this scope's results.
+ for _, account := range result.Accounts {
+ if account.AccountName == name {
+ matchingAccounts = append(
+ matchingAccounts, account,
+ )
+ }
+ }
+
+ finalBlockHash = result.CurrentBlockHash
+ finalBlockHeight = result.CurrentBlockHeight
+ }
+
+ return &AccountsResult{
+ Accounts: matchingAccounts,
+ CurrentBlockHash: finalBlockHash,
+ CurrentBlockHeight: finalBlockHeight,
+ }, nil
+}
+
+// getAccountDeprecated wraps the deprecated Accounts API to satisfy the same
+// contract as GetAccount by calling Accounts API across all active key scopes
+// and filtering by account name.
+func getAccountDeprecated(w *Wallet, scope waddrmgr.KeyScope,
+ accountName string) (*AccountResult, error) {
+
+ result, err := w.Accounts(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, account := range result.Accounts {
+ if account.AccountName == accountName {
+ return &account, nil
+ }
+ }
+
+ return nil, fmt.Errorf("%w: %s", errAccountNotFound, accountName)
+}
+
+// getBalanceDeprecated wraps the deprecated Accounts API to satisfy the same
+// contract as GetBalance by calling Accounts API across all active key scopes
+// and filtering by account name.
+func getBalanceDeprecated(w *Wallet, scope waddrmgr.KeyScope,
+ accountName string, _ int32) (btcutil.Amount, error) {
+
+ result, err := w.Accounts(scope)
+ if err != nil {
+ return 0, err
+ }
+
+ for _, account := range result.Accounts {
+ if account.AccountName == accountName {
+ // The deprecated Accounts API doesn't support
+ // confirmation filtering. It always returns total
+ // balance.
+ return account.TotalBalance, nil
+ }
+ }
+
+ return 0, fmt.Errorf("%w: %s", errAccountNotFound, accountName)
+}
+
+// listAddressesDeprecated wraps the deprecated AccountAddresses and
+// TotalReceivedForAddr APIs to satisfy the same contract as ListAddresses by
+// calling the old APIs and aggregating the results with balances.
+func listAddressesDeprecated(w *Wallet,
+ accountID uint32) ([]AddressProperty, error) {
+
+ addresses, err := w.AccountAddresses(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ allProperties := make([]AddressProperty, 0, len(addresses))
+
+ for _, addr := range addresses {
+ balance, err := w.TotalReceivedForAddr(addr, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ allProperties = append(allProperties, AddressProperty{
+ Address: addr,
+ Balance: balance,
+ })
+ }
+
+ return allProperties, nil
+}
+
+// getUtxoDeprecated wraps the deprecated FetchOutpointInfo API to satisfy the
+// same contract as GetUtxo by calling FetchOutpointInfo and performing
+// additional lookups to construct a complete Utxo struct. This demonstrates
+// the inefficiency of the old API which returns raw data requiring the caller
+// to perform multiple additional lookups.
+func getUtxoDeprecated(w *Wallet, prevOut wire.OutPoint) (*Utxo, error) {
+ _, txOut, confs, err := w.FetchOutpointInfo(&prevOut)
+ if err != nil {
+ return nil, err
+ }
+
+ // Additional lookup 1: Extract address from pkScript.
+ addr := extractAddrFromPKScript(txOut.PkScript, w.cfg.ChainParams)
+ if addr == nil {
+ return nil, ErrNotMine
+ }
+
+ // Additional lookup 2: Get address details (spendability, account,
+ // address type) from the address manager.
+ var (
+ spendable bool
+ account string
+ addrType waddrmgr.AddressType
+ )
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ spendable, account, addrType = w.addrStore.AddressDetails(
+ addrmgrNs, addr,
+ )
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Additional lookup 3: Check if the output is locked.
+ locked := w.LockedOutpoint(prevOut)
+
+ return &Utxo{
+ OutPoint: prevOut,
+ Amount: btcutil.Amount(txOut.Value),
+ PkScript: txOut.PkScript,
+ Confirmations: int32(confs),
+ Spendable: spendable,
+ Address: addr,
+ Account: account,
+ AddressType: addrType,
+ Locked: locked,
+ }, nil
+}
diff --git a/wallet/chain_mock_test.go b/wallet/chain_mock_test.go
new file mode 100644
index 0000000000..6eaa8d147c
--- /dev/null
+++ b/wallet/chain_mock_test.go
@@ -0,0 +1,254 @@
+package wallet
+
+import (
+ "context"
+ "errors"
+ "maps"
+ "sync"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcjson"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+)
+
+var (
+ // errTxnAlreadyInMempool is returned when a transaction already exists
+ // in the mempool.
+ errTxnAlreadyInMempool = "txn-already-in-mempool"
+
+ // ErrNotImplemented is returned when a mock method is not implemented.
+ ErrNotImplemented = errors.New("not implemented")
+)
+
+type mockChainClient struct {
+ getBestBlockHeight int32
+ getBlockHashFunc func() (*chainhash.Hash, error)
+ getBlockHeader *wire.BlockHeader
+
+ // mempool tracks transactions that have been broadcast to simulate
+ // mempool behavior for benchmarks.
+ mempool map[chainhash.Hash]*wire.MsgTx
+
+ // mu protects concurrent reads and writes to mempool.
+ mu sync.RWMutex
+}
+
+var _ chain.Interface = (*mockChainClient)(nil)
+
+func (m *mockChainClient) Start(_ context.Context) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.mempool == nil {
+ m.mempool = make(map[chainhash.Hash]*wire.MsgTx)
+ }
+
+ return nil
+}
+
+func (m *mockChainClient) Stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.mempool = nil
+}
+
+func (m *mockChainClient) WaitForShutdown() {}
+
+// ResetMempool clears all transactions from the mock mempool.
+func (m *mockChainClient) ResetMempool() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.mempool = make(map[chainhash.Hash]*wire.MsgTx)
+}
+
+func (m *mockChainClient) GetBestBlock() (*chainhash.Hash, int32, error) {
+ return nil, m.getBestBlockHeight, nil
+}
+
+func (m *mockChainClient) GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) {
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) GetBlockHash(int64) (*chainhash.Hash, error) {
+ if m.getBlockHashFunc != nil {
+ return m.getBlockHashFunc()
+ }
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader,
+ error) {
+
+ return m.getBlockHeader, nil
+}
+
+func (m *mockChainClient) GetBlockHashes(int64,
+ int64) ([]chainhash.Hash, error) {
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) GetBlockHeaders(
+ []chainhash.Hash) ([]*wire.BlockHeader, error) {
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) GetCFilters([]chainhash.Hash, wire.FilterType) (
+ []*gcs.Filter, error) {
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) GetBlocks(
+ []chainhash.Hash) ([]*wire.MsgBlock, error) {
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) GetMempool() (map[chainhash.Hash]*wire.MsgTx, error) {
+ // Acquire read lock non-exclusively. It allows concurrent readers and
+ // blocks writers.
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ // Return a shallow copy of the map to avoid TOCTOU
+ // (time-of-check-to-time-of-use) races. Returning m.mempool directly
+ // would share the map reference - after RUnlock(), concurrent writes
+ // could modify the map structure during caller's iteration causing:
+ // "fatal error: concurrent map iteration and map write".
+ // Note: This is a shallow copy - the *wire.MsgTx pointers are shared.
+ // We assume transactions are not mutated after creation.
+ result := make(map[chainhash.Hash]*wire.MsgTx, len(m.mempool))
+ maps.Copy(result, m.mempool)
+
+ return result, nil
+}
+
+func (m *mockChainClient) IsCurrent() bool {
+ return false
+}
+
+func (m *mockChainClient) GetCFilter(hash *chainhash.Hash,
+ filterType wire.FilterType) (*gcs.Filter, error) {
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) FilterBlocks(*chain.FilterBlocksRequest) (
+ *chain.FilterBlocksResponse, error) {
+
+ return nil, ErrNotImplemented
+}
+
+func (m *mockChainClient) BlockStamp() (*waddrmgr.BlockStamp, error) {
+ return &waddrmgr.BlockStamp{
+ Height: 500000,
+ Hash: chainhash.Hash{},
+ Timestamp: time.Unix(1234, 0),
+ }, nil
+}
+
+func (m *mockChainClient) SendRawTransaction(tx *wire.MsgTx,
+ allowHighFees bool) (*chainhash.Hash, error) {
+
+ // Acquire write lock exclusively. It blocks all readers and writers.
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ txHash := tx.TxHash()
+
+ // Reject duplicate transactions to isolate the external behavior of
+ // real chain backends. This is important for reliable testing and
+ // benchmarking handling in broadcast APIs.
+ if _, exists := m.mempool[txHash]; exists {
+ return nil, chain.ErrTxAlreadyInMempool
+ }
+
+ m.mempool[txHash] = tx
+
+ return &txHash, nil
+}
+
+func (m *mockChainClient) Rescan(*chainhash.Hash, []address.Address,
+ map[wire.OutPoint]address.Address) error {
+
+ return nil
+}
+
+func (m *mockChainClient) NotifyReceived([]address.Address) error {
+ return nil
+}
+
+func (m *mockChainClient) NotifyBlocks() error {
+ return nil
+}
+
+func (m *mockChainClient) Notifications() <-chan interface{} {
+ return nil
+}
+
+func (m *mockChainClient) BackEnd() string {
+ return "mock"
+}
+
+// TestMempoolAcceptCmd returns result of mempool acceptance tests indicating
+// if raw transaction(s) would be accepted by mempool.
+//
+// NOTE: This is part of the chain.Interface interface.
+func (m *mockChainClient) TestMempoolAccept(txns []*wire.MsgTx,
+ maxFeeRate float64) ([]*btcjson.TestMempoolAcceptResult, error) {
+
+ // Acquire read lock non-exclusively. It allows concurrent readers and
+ // blocks writers.
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+
+ // Return acceptance result for each transaction.
+ results := make([]*btcjson.TestMempoolAcceptResult, len(txns))
+ for i := range txns {
+ txHash := txns[i].TxHash()
+ result := &btcjson.TestMempoolAcceptResult{
+ Txid: txHash.String(),
+ }
+
+ // Check if transaction already exists in mempool.
+ if _, exists := m.mempool[txHash]; exists {
+ result.Allowed = false
+ result.RejectReason = errTxnAlreadyInMempool
+ } else {
+ result.Allowed = true
+ }
+
+ results[i] = result
+ }
+
+ return results, nil
+}
+
+// SubmitPackage is part of the chain.Interface interface.
+func (m *mockChainClient) SubmitPackage(txns []*wire.MsgTx,
+ maxFeeRate *float64) (*btcjson.SubmitPackageResult, error) {
+
+ return &btcjson.SubmitPackageResult{}, nil
+}
+
+func (m *mockChainClient) MapRPCErr(err error) error {
+ if err == nil {
+ return nil
+ }
+
+ if err.Error() == errTxnAlreadyInMempool {
+ return chain.ErrTxAlreadyInMempool
+ }
+
+ return err
+}
diff --git a/wallet/chainntfns.go b/wallet/chainntfns.go
deleted file mode 100644
index bdd99c60f6..0000000000
--- a/wallet/chainntfns.go
+++ /dev/null
@@ -1,540 +0,0 @@
-// Copyright (c) 2013-2015 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "bytes"
- "time"
-
- "github.com/btcsuite/btcd/chainhash/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/chain"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
- "github.com/btcsuite/btcwallet/wtxmgr"
-)
-
-const (
- // birthdayBlockDelta is the maximum time delta allowed between our
- // birthday timestamp and our birthday block's timestamp when searching
- // for a better birthday block candidate (if possible).
- birthdayBlockDelta = 2 * time.Hour
-)
-
-func (w *Wallet) handleChainNotifications() {
- defer w.wg.Done()
-
- chainClient, err := w.requireChainClient()
- if err != nil {
- log.Errorf("handleChainNotifications called without RPC client")
- return
- }
-
- catchUpHashes := func(w *Wallet, client chain.Interface,
- height int32) error {
- // TODO(aakselrod): There's a race condition here, which
- // happens when a reorg occurs between the
- // rescanProgress notification and the last GetBlockHash
- // call. The solution when using btcd is to make btcd
- // send blockconnected notifications with each block
- // the way Neutrino does, and get rid of the loop. The
- // other alternative is to check the final hash and,
- // if it doesn't match the original hash returned by
- // the notification, to roll back and restart the
- // rescan.
- log.Infof("Catching up block hashes to height %d, this"+
- " might take a while", height)
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- startBlock := w.Manager.SyncedTo()
-
- for i := startBlock.Height + 1; i <= height; i++ {
- hash, err := client.GetBlockHash(int64(i))
- if err != nil {
- return err
- }
- header, err := chainClient.GetBlockHeader(hash)
- if err != nil {
- return err
- }
-
- bs := waddrmgr.BlockStamp{
- Height: i,
- Hash: *hash,
- Timestamp: header.Timestamp,
- }
- err = w.Manager.SetSyncedTo(ns, &bs)
- if err != nil {
- return err
- }
- }
- return nil
- })
- if err != nil {
- log.Errorf("Failed to update address manager "+
- "sync state for height %d: %v", height, err)
- }
-
- log.Info("Done catching up block hashes")
- return err
- }
-
- waitForSync := func(birthdayBlock *waddrmgr.BlockStamp) error {
- // We start with a retry delay of 0 to execute the first attempt
- // immediately.
- var retryDelay time.Duration
- for {
- select {
- case <-time.After(retryDelay):
- // Set the delay to the configured value in case
- // we actually need to re-try.
- retryDelay = w.syncRetryInterval
-
- // Sync may be interrupted by actions such as
- // locking the wallet. Try again after waiting a
- // bit.
- err = w.syncWithChain(birthdayBlock)
- if err != nil {
- if w.ShuttingDown() {
- return ErrWalletShuttingDown
- }
-
- log.Errorf("Unable to synchronize "+
- "wallet to chain, trying "+
- "again in %s: %v",
- w.syncRetryInterval, err)
-
- continue
- }
-
- return nil
-
- case <-w.quitChan():
- return ErrWalletShuttingDown
- }
- }
- }
-
- for {
- select {
- case n, ok := <-chainClient.Notifications():
- if !ok {
- return
- }
-
- var notificationName string
- var err error
- switch n := n.(type) {
- case chain.ClientConnected:
- // Before attempting to sync with our backend,
- // we'll make sure that our birthday block has
- // been set correctly to potentially prevent
- // missing relevant events.
- birthdayStore := &walletBirthdayStore{
- db: w.db,
- manager: w.Manager,
- }
- birthdayBlock, err := birthdaySanityCheck(
- chainClient, birthdayStore,
- )
- if err != nil && !waddrmgr.IsError(
- err, waddrmgr.ErrBirthdayBlockNotSet,
- ) {
-
- log.Errorf("Unable to sanity check "+
- "wallet birthday block: %v",
- err)
- }
-
- err = waitForSync(birthdayBlock)
- if err != nil {
- log.Infof("Stopped waiting for wallet "+
- "sync due to error: %v", err)
-
- return
- }
-
- case chain.BlockConnected:
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- return w.connectBlock(tx, wtxmgr.BlockMeta(n))
- })
- notificationName = "block connected"
- case chain.BlockDisconnected:
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- return w.disconnectBlock(tx, wtxmgr.BlockMeta(n))
- })
- notificationName = "block disconnected"
- case chain.RelevantTx:
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- return w.addRelevantTx(tx, n.TxRecord, n.Block)
- })
- notificationName = "relevant transaction"
- case chain.FilteredBlockConnected:
- // Atomically update for the whole block.
- if len(n.RelevantTxs) > 0 {
- err = walletdb.Update(w.db, func(
- tx walletdb.ReadWriteTx) error {
- var err error
- for _, rec := range n.RelevantTxs {
- err = w.addRelevantTx(tx, rec,
- n.Block)
- if err != nil {
- return err
- }
- }
- return nil
- })
- }
- notificationName = "filtered block connected"
-
- // The following require some database maintenance, but also
- // need to be reported to the wallet's rescan goroutine.
- case *chain.RescanProgress:
- err = catchUpHashes(w, chainClient, n.Height)
- notificationName = "rescan progress"
- select {
- case w.rescanNotifications <- n:
- case <-w.quitChan():
- return
- }
- case *chain.RescanFinished:
- err = catchUpHashes(w, chainClient, n.Height)
- notificationName = "rescan finished"
- w.SetChainSynced(true)
- select {
- case w.rescanNotifications <- n:
- case <-w.quitChan():
- return
- }
- }
- if err != nil {
- // If we received a block connected notification
- // while rescanning, then we can ignore logging
- // the error as we'll properly catch up once we
- // process the RescanFinished notification.
- if notificationName == "block connected" &&
- waddrmgr.IsError(err, waddrmgr.ErrBlockNotFound) &&
- !w.ChainSynced() {
-
- log.Debugf("Received block connected "+
- "notification for height %v "+
- "while rescanning",
- n.(chain.BlockConnected).Height)
- continue
- }
-
- log.Errorf("Unable to process chain backend "+
- "%v notification: %v", notificationName,
- err)
- }
- case <-w.quit:
- return
- }
- }
-}
-
-// connectBlock handles a chain server notification by marking a wallet
-// that's currently in-sync with the chain server as being synced up to
-// the passed block.
-func (w *Wallet) connectBlock(dbtx walletdb.ReadWriteTx, b wtxmgr.BlockMeta) error {
- addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- bs := waddrmgr.BlockStamp{
- Height: b.Height,
- Hash: b.Hash,
- Timestamp: b.Time,
- }
- err := w.Manager.SetSyncedTo(addrmgrNs, &bs)
- if err != nil {
- return err
- }
-
- // Notify interested clients of the connected block.
- //
- // TODO: move all notifications outside of the database transaction.
- w.NtfnServer.notifyAttachedBlock(dbtx, &b)
- return nil
-}
-
-// disconnectBlock handles a chain server reorganize by rolling back all
-// block history from the reorged block for a wallet in-sync with the chain
-// server.
-func (w *Wallet) disconnectBlock(dbtx walletdb.ReadWriteTx, b wtxmgr.BlockMeta) error {
- addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
- txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
-
- if !w.ChainSynced() {
- return nil
- }
-
- // Disconnect the removed block and all blocks after it if we know about
- // the disconnected block. Otherwise, the block is in the future.
- if b.Height <= w.Manager.SyncedTo().Height {
- hash, err := w.Manager.BlockHash(addrmgrNs, b.Height)
- if err != nil {
- return err
- }
- if bytes.Equal(hash[:], b.Hash[:]) {
- bs := waddrmgr.BlockStamp{
- Height: b.Height - 1,
- }
- hash, err = w.Manager.BlockHash(addrmgrNs, bs.Height)
- if err != nil {
- return err
- }
- b.Hash = *hash
-
- client := w.ChainClient()
- header, err := client.GetBlockHeader(hash)
- if err != nil {
- return err
- }
-
- bs.Timestamp = header.Timestamp
- err = w.Manager.SetSyncedTo(addrmgrNs, &bs)
- if err != nil {
- return err
- }
-
- err = w.TxStore.Rollback(txmgrNs, b.Height)
- if err != nil {
- return err
- }
- }
- }
-
- // Notify interested clients of the disconnected block.
- w.NtfnServer.notifyDetachedBlock(&b.Hash)
-
- return nil
-}
-
-func (w *Wallet) addRelevantTx(dbtx walletdb.ReadWriteTx, rec *wtxmgr.TxRecord,
- block *wtxmgr.BlockMeta) error {
-
- addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
- txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
-
- // At the moment all notified transactions are assumed to actually be
- // relevant. This assumption will not hold true when SPV support is
- // added, but until then, simply insert the transaction because there
- // should either be one or more relevant inputs or outputs.
- exists, err := w.TxStore.InsertTxCheckIfExists(txmgrNs, rec, block)
- if err != nil {
- return err
- }
-
- // If the transaction has already been recorded, we can return early.
- // Note: Returning here is safe as we're within the context of an atomic
- // database transaction, so we don't need to worry about the MarkUsed
- // calls below.
- if exists {
- return nil
- }
-
- // Check every output to determine whether it is controlled by a wallet
- // key. If so, mark the output as a credit.
- for i, output := range rec.MsgTx.TxOut {
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript,
- w.chainParams)
- if err != nil {
- // Non-standard outputs are skipped.
- log.Warnf("Cannot extract non-std pkScript=%x",
- output.PkScript)
-
- continue
- }
-
- for _, addr := range addrs {
- ma, err := w.Manager.Address(addrmgrNs, addr)
-
- switch {
- // Missing addresses are skipped.
- case waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound):
- continue
-
- // Other errors should be propagated.
- case err != nil:
- return err
- }
-
- // Prevent addresses from non-default scopes to be
- // detected here. We don't watch funds sent to
- // non-default scopes in other places either, so
- // detecting them here would mean we'd also not properly
- // detect them as spent later.
- scopedManager, _, err := w.Manager.AddrAccount(
- addrmgrNs, addr,
- )
- if err != nil {
- return err
- }
- if !waddrmgr.IsDefaultScope(scopedManager.Scope()) {
- log.Debugf("Skipping non-default scope "+
- "address %v", addr)
-
- continue
- }
-
- // TODO: Credits should be added with the
- // account they belong to, so wtxmgr is able to
- // track per-account balances.
- err = w.TxStore.AddCredit(
- txmgrNs, rec, block, uint32(i), ma.Internal(),
- )
- if err != nil {
- return err
- }
- err = w.Manager.MarkUsed(addrmgrNs, addr)
- if err != nil {
- return err
- }
- log.Debugf("Marked address %v used", addr)
- }
- }
-
- // Send notification of mined or unmined transaction to any interested
- // clients.
- //
- // TODO: Avoid the extra db hits.
- if block == nil {
- w.NtfnServer.notifyUnminedTransaction(dbtx, txmgrNs, rec.Hash)
- } else {
- w.NtfnServer.notifyMinedTransaction(
- dbtx, txmgrNs, rec.Hash, block,
- )
- }
-
- return nil
-}
-
-// chainConn is an interface that abstracts the chain connection logic required
-// to perform a wallet's birthday block sanity check.
-type chainConn interface {
- // GetBestBlock returns the hash and height of the best block known to
- // the backend.
- GetBestBlock() (*chainhash.Hash, int32, error)
-
- // GetBlockHash returns the hash of the block with the given height.
- GetBlockHash(int64) (*chainhash.Hash, error)
-
- // GetBlockHeader returns the header for the block with the given hash.
- GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error)
-}
-
-// birthdayStore is an interface that abstracts the wallet's sync-related
-// information required to perform a birthday block sanity check.
-type birthdayStore interface {
- // Birthday returns the birthday timestamp of the wallet.
- Birthday() time.Time
-
- // BirthdayBlock returns the birthday block of the wallet. The boolean
- // returned should signal whether the wallet has already verified the
- // correctness of its birthday block.
- BirthdayBlock() (waddrmgr.BlockStamp, bool, error)
-
- // SetBirthdayBlock updates the birthday block of the wallet to the
- // given block. The boolean can be used to signal whether this block
- // should be sanity checked the next time the wallet starts.
- //
- // NOTE: This should also set the wallet's synced tip to reflect the new
- // birthday block. This will allow the wallet to rescan from this point
- // to detect any potentially missed events.
- SetBirthdayBlock(waddrmgr.BlockStamp) error
-}
-
-// walletBirthdayStore is a wrapper around the wallet's database and address
-// manager that satisfies the birthdayStore interface.
-type walletBirthdayStore struct {
- db walletdb.DB
- manager *waddrmgr.Manager
-}
-
-var _ birthdayStore = (*walletBirthdayStore)(nil)
-
-// Birthday returns the birthday timestamp of the wallet.
-func (s *walletBirthdayStore) Birthday() time.Time {
- return s.manager.Birthday()
-}
-
-// BirthdayBlock returns the birthday block of the wallet.
-func (s *walletBirthdayStore) BirthdayBlock() (waddrmgr.BlockStamp, bool, error) {
- var (
- birthdayBlock waddrmgr.BlockStamp
- birthdayBlockVerified bool
- )
-
- err := walletdb.View(s.db, func(tx walletdb.ReadTx) error {
- var err error
- ns := tx.ReadBucket(waddrmgrNamespaceKey)
- birthdayBlock, birthdayBlockVerified, err = s.manager.BirthdayBlock(ns)
- return err
- })
-
- return birthdayBlock, birthdayBlockVerified, err
-}
-
-// SetBirthdayBlock updates the birthday block of the wallet to the
-// given block. The boolean can be used to signal whether this block
-// should be sanity checked the next time the wallet starts.
-//
-// NOTE: This should also set the wallet's synced tip to reflect the new
-// birthday block. This will allow the wallet to rescan from this point
-// to detect any potentially missed events.
-func (s *walletBirthdayStore) SetBirthdayBlock(block waddrmgr.BlockStamp) error {
- return walletdb.Update(s.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- err := s.manager.SetBirthdayBlock(ns, block, true)
- if err != nil {
- return err
- }
- return s.manager.SetSyncedTo(ns, &block)
- })
-}
-
-// birthdaySanityCheck is a helper function that ensures a birthday block
-// correctly reflects the birthday timestamp within a reasonable timestamp
-// delta. It's intended to be run after the wallet establishes its connection
-// with the backend, but before it begins syncing. This is done as the second
-// part to the wallet's address manager migration where we populate the birthday
-// block to ensure we do not miss any relevant events throughout rescans.
-// waddrmgr.ErrBirthdayBlockNotSet is returned if the birthday block has not
-// been set yet.
-func birthdaySanityCheck(chainConn chainConn,
- birthdayStore birthdayStore) (*waddrmgr.BlockStamp, error) {
-
- // We'll start by fetching our wallet's birthday timestamp and block.
- birthdayTimestamp := birthdayStore.Birthday()
- birthdayBlock, birthdayBlockVerified, err := birthdayStore.BirthdayBlock()
- if err != nil {
- return nil, err
- }
-
- // If the birthday block has already been verified to be correct, we can
- // exit our sanity check to prevent potentially fetching a better
- // candidate.
- if birthdayBlockVerified {
- log.Debugf("Birthday block has already been verified: "+
- "height=%d, hash=%v", birthdayBlock.Height,
- birthdayBlock.Hash)
-
- return &birthdayBlock, nil
- }
-
- // Otherwise, we'll attempt to locate a better one now that we have
- // access to the chain.
- newBirthdayBlock, err := locateBirthdayBlock(chainConn, birthdayTimestamp)
- if err != nil {
- return nil, err
- }
-
- if err := birthdayStore.SetBirthdayBlock(*newBirthdayBlock); err != nil {
- return nil, err
- }
-
- return newBirthdayBlock, nil
-}
diff --git a/wallet/chainntfns_test.go b/wallet/chainntfns_test.go
deleted file mode 100644
index 1d44ad28a1..0000000000
--- a/wallet/chainntfns_test.go
+++ /dev/null
@@ -1,310 +0,0 @@
-package wallet
-
-import (
- "fmt"
- "reflect"
- "testing"
- "time"
-
- "github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcd/chainhash/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- _ "github.com/btcsuite/btcwallet/walletdb/bdb"
-)
-
-const (
- // defaultBlockInterval is the default time interval between any two
- // blocks in a mocked chain.
- defaultBlockInterval = 10 * time.Minute
-)
-
-var (
- // chainParams are the chain parameters used throughout the wallet
- // tests.
- chainParams = chaincfg.MainNetParams
-)
-
-// mockChainConn is a mock in-memory implementation of the chainConn interface
-// that will be used for the birthday block sanity check tests. The struct is
-// capable of being backed by a chain in order to reproduce real-world
-// scenarios.
-type mockChainConn struct {
- chainTip uint32
- blockHashes map[uint32]chainhash.Hash
- blocks map[chainhash.Hash]*wire.MsgBlock
-}
-
-var _ chainConn = (*mockChainConn)(nil)
-
-// createMockChainConn creates a new mock chain connection backed by a chain
-// with N blocks. Each block has a timestamp that is exactly blockInterval after
-// the previous block's timestamp.
-func createMockChainConn(genesis *wire.MsgBlock, n uint32,
- blockInterval time.Duration) *mockChainConn {
-
- c := &mockChainConn{
- chainTip: n,
- blockHashes: make(map[uint32]chainhash.Hash),
- blocks: make(map[chainhash.Hash]*wire.MsgBlock),
- }
-
- genesisHash := genesis.BlockHash()
- c.blockHashes[0] = genesisHash
- c.blocks[genesisHash] = genesis
-
- for i := uint32(1); i <= n; i++ {
- prevTimestamp := c.blocks[c.blockHashes[i-1]].Header.Timestamp
- block := &wire.MsgBlock{
- Header: wire.BlockHeader{
- Timestamp: prevTimestamp.Add(blockInterval),
- },
- }
-
- blockHash := block.BlockHash()
- c.blockHashes[i] = blockHash
- c.blocks[blockHash] = block
- }
-
- return c
-}
-
-// GetBestBlock returns the hash and height of the best block known to the
-// backend.
-func (c *mockChainConn) GetBestBlock() (*chainhash.Hash, int32, error) {
- bestHash, ok := c.blockHashes[c.chainTip]
- if !ok {
- return nil, 0, fmt.Errorf("block with height %d not found",
- c.chainTip)
- }
-
- return &bestHash, int32(c.chainTip), nil
-}
-
-// GetBlockHash returns the hash of the block with the given height.
-func (c *mockChainConn) GetBlockHash(height int64) (*chainhash.Hash, error) {
- hash, ok := c.blockHashes[uint32(height)]
- if !ok {
- return nil, fmt.Errorf("block with height %d not found", height)
- }
-
- return &hash, nil
-}
-
-// GetBlockHeader returns the header for the block with the given hash.
-func (c *mockChainConn) GetBlockHeader(hash *chainhash.Hash) (*wire.BlockHeader, error) {
- block, ok := c.blocks[*hash]
- if !ok {
- return nil, fmt.Errorf("header for block %v not found", hash)
- }
-
- return &block.Header, nil
-}
-
-// mockBirthdayStore is a mock in-memory implementation of the birthdayStore interface
-// that will be used for the birthday block sanity check tests.
-type mockBirthdayStore struct {
- birthday time.Time
- birthdayBlock *waddrmgr.BlockStamp
- birthdayBlockVerified bool
- syncedTo waddrmgr.BlockStamp
-}
-
-var _ birthdayStore = (*mockBirthdayStore)(nil)
-
-// Birthday returns the birthday timestamp of the wallet.
-func (s *mockBirthdayStore) Birthday() time.Time {
- return s.birthday
-}
-
-// BirthdayBlock returns the birthday block of the wallet.
-func (s *mockBirthdayStore) BirthdayBlock() (waddrmgr.BlockStamp, bool, error) {
- if s.birthdayBlock == nil {
- err := waddrmgr.ManagerError{
- ErrorCode: waddrmgr.ErrBirthdayBlockNotSet,
- }
- return waddrmgr.BlockStamp{}, false, err
- }
-
- return *s.birthdayBlock, s.birthdayBlockVerified, nil
-}
-
-// SetBirthdayBlock updates the birthday block of the wallet to the given block.
-// The boolean can be used to signal whether this block should be sanity checked
-// the next time the wallet starts.
-func (s *mockBirthdayStore) SetBirthdayBlock(block waddrmgr.BlockStamp) error {
- s.birthdayBlock = &block
- s.birthdayBlockVerified = true
- s.syncedTo = block
- return nil
-}
-
-// TestBirthdaySanityCheckEmptyBirthdayBlock ensures that a sanity check is not
-// done if the birthday block does not exist in the first place.
-func TestBirthdaySanityCheckEmptyBirthdayBlock(t *testing.T) {
- t.Parallel()
-
- chainConn := &mockChainConn{}
-
- // Our birthday store will reflect that we don't have a birthday block
- // set, so we should not attempt a sanity check.
- birthdayStore := &mockBirthdayStore{}
-
- birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
- if !waddrmgr.IsError(err, waddrmgr.ErrBirthdayBlockNotSet) {
- t.Fatalf("expected ErrBirthdayBlockNotSet, got %v", err)
- }
-
- if birthdayBlock != nil {
- t.Fatalf("expected birthday block to be nil due to not being "+
- "set, got %v", *birthdayBlock)
- }
-}
-
-// TestBirthdaySanityCheckVerifiedBirthdayBlock ensures that a sanity check is
-// not performed if the birthday block has already been verified.
-func TestBirthdaySanityCheckVerifiedBirthdayBlock(t *testing.T) {
- t.Parallel()
-
- const chainTip = 5000
- chainConn := createMockChainConn(
- chainParams.GenesisBlock, chainTip, defaultBlockInterval,
- )
- expectedBirthdayBlock := waddrmgr.BlockStamp{Height: 1337}
-
- // Our birthday store reflects that our birthday block has already been
- // verified and should not require a sanity check.
- birthdayStore := &mockBirthdayStore{
- birthdayBlock: &expectedBirthdayBlock,
- birthdayBlockVerified: true,
- syncedTo: waddrmgr.BlockStamp{
- Height: chainTip,
- },
- }
-
- // Now, we'll run the sanity check. We should see that the birthday
- // block hasn't changed.
- birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
- if err != nil {
- t.Fatalf("unable to sanity check birthday block: %v", err)
- }
- if !reflect.DeepEqual(*birthdayBlock, expectedBirthdayBlock) {
- t.Fatalf("expected birthday block %v, got %v",
- expectedBirthdayBlock, birthdayBlock)
- }
-
- // To ensure the sanity check didn't proceed, we'll check our synced to
- // height, as this value should have been modified if a new candidate
- // was found.
- if birthdayStore.syncedTo.Height != chainTip {
- t.Fatalf("expected synced height remain the same (%d), got %d",
- chainTip, birthdayStore.syncedTo.Height)
- }
-}
-
-// TestBirthdaySanityCheckLowerEstimate ensures that we can properly locate a
-// better birthday block candidate if our estimate happens to be too far back in
-// the chain.
-func TestBirthdaySanityCheckLowerEstimate(t *testing.T) {
- t.Parallel()
-
- // We'll start by defining our birthday timestamp to be around the
- // timestamp of the 1337th block.
- genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp
- birthday := genesisTimestamp.Add(1337 * defaultBlockInterval)
-
- // We'll establish a connection to a mock chain of 5000 blocks.
- chainConn := createMockChainConn(
- chainParams.GenesisBlock, 5000, defaultBlockInterval,
- )
-
- // Our birthday store will reflect that our birthday block is currently
- // set as the genesis block. This value is too low and should be
- // adjusted by the sanity check.
- birthdayStore := &mockBirthdayStore{
- birthday: birthday,
- birthdayBlock: &waddrmgr.BlockStamp{
- Hash: *chainParams.GenesisHash,
- Height: 0,
- Timestamp: genesisTimestamp,
- },
- birthdayBlockVerified: false,
- syncedTo: waddrmgr.BlockStamp{
- Height: 5000,
- },
- }
-
- // We'll perform the sanity check and determine whether we were able to
- // find a better birthday block candidate.
- birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
- if err != nil {
- t.Fatalf("unable to sanity check birthday block: %v", err)
- }
- if birthday.Sub(birthdayBlock.Timestamp) >= birthdayBlockDelta {
- t.Fatalf("expected birthday block timestamp=%v to be within "+
- "%v of birthday timestamp=%v", birthdayBlock.Timestamp,
- birthdayBlockDelta, birthday)
- }
-
- // Finally, our synced to height should now reflect our new birthday
- // block to ensure the wallet doesn't miss any events from this point
- // forward.
- if !reflect.DeepEqual(birthdayStore.syncedTo, *birthdayBlock) {
- t.Fatalf("expected syncedTo and birthday block to match: "+
- "%v vs %v", birthdayStore.syncedTo, birthdayBlock)
- }
-}
-
-// TestBirthdaySanityCheckHigherEstimate ensures that we can properly locate a
-// better birthday block candidate if our estimate happens to be too far in the
-// chain.
-func TestBirthdaySanityCheckHigherEstimate(t *testing.T) {
- t.Parallel()
-
- // We'll start by defining our birthday timestamp to be around the
- // timestamp of the 1337th block.
- genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp
- birthday := genesisTimestamp.Add(1337 * defaultBlockInterval)
-
- // We'll establish a connection to a mock chain of 5000 blocks.
- chainConn := createMockChainConn(
- chainParams.GenesisBlock, 5000, defaultBlockInterval,
- )
-
- // Our birthday store will reflect that our birthday block is currently
- // set as the chain tip. This value is too high and should be adjusted
- // by the sanity check.
- bestBlock := chainConn.blocks[chainConn.blockHashes[5000]]
- birthdayStore := &mockBirthdayStore{
- birthday: birthday,
- birthdayBlock: &waddrmgr.BlockStamp{
- Hash: bestBlock.BlockHash(),
- Height: 5000,
- Timestamp: bestBlock.Header.Timestamp,
- },
- birthdayBlockVerified: false,
- syncedTo: waddrmgr.BlockStamp{
- Height: 5000,
- },
- }
-
- // We'll perform the sanity check and determine whether we were able to
- // find a better birthday block candidate.
- birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
- if err != nil {
- t.Fatalf("unable to sanity check birthday block: %v", err)
- }
- if birthday.Sub(birthdayBlock.Timestamp) >= birthdayBlockDelta {
- t.Fatalf("expected birthday block timestamp=%v to be within "+
- "%v of birthday timestamp=%v", birthdayBlock.Timestamp,
- birthdayBlockDelta, birthday)
- }
-
- // Finally, our synced to height should now reflect our new birthday
- // block to ensure the wallet doesn't miss any events from this point
- // forward.
- if !reflect.DeepEqual(birthdayStore.syncedTo, *birthdayBlock) {
- t.Fatalf("expected syncedTo and birthday block to match: "+
- "%v vs %v", birthdayStore.syncedTo, birthdayBlock)
- }
-}
diff --git a/wallet/common_test.go b/wallet/common_test.go
new file mode 100644
index 0000000000..07d528ad46
--- /dev/null
+++ b/wallet/common_test.go
@@ -0,0 +1,246 @@
+package wallet
+
+import (
+ "context"
+ "errors"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ _ "github.com/btcsuite/btcwallet/walletdb/bdb"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ errDBMock = errors.New("db error")
+ errMock = errors.New("mock error")
+ errChainMock = errors.New("chain error")
+ errPutMock = errors.New("put error")
+ errLockMock = errors.New("lock fail")
+ errDBFail = errors.New("db fail")
+ errDeriveFail = errors.New("derive fail")
+ errLoadStateFail = errors.New("load state fail")
+ errRollbackFail = errors.New("rollback fail")
+ errFetchFail = errors.New("fetch fail")
+ errCFilterFail = errors.New("cfilter fail")
+ errActiveMgrsFail = errors.New("active managers fail")
+
+ errSetFail = errors.New("set fail")
+ errOther = errors.New("other error")
+ errBroadcast = errors.New("broadcast fail")
+ errScan = errors.New("scan fail")
+ errBlocks = errors.New("blocks fail")
+ errDBInsert = errors.New("db insert fail")
+ errBestBlock = errors.New("best block fail")
+ errAddr = errors.New("addr fail")
+ errInsert = errors.New("insert fail")
+ errManager = errors.New("manager fail")
+ errUtxo = errors.New("utxo fail")
+ errGetBlocks = errors.New("get blocks fail")
+ errBlockHash = errors.New("block hash fail")
+ errSetSync = errors.New("set sync fail")
+ errRemote = errors.New("remote fail")
+ errNotify = errors.New("notify fail")
+ errHashes = errors.New("hashes fail")
+ errHeaders = errors.New("headers fail")
+ errHeader = errors.New("header fail")
+)
+
+var (
+ // chainParams are the chain parameters used throughout the wallet
+ // tests.
+ chainParams = chaincfg.RegressionNetParams
+)
+
+// setupTestDB creates a temporary database for testing.
+func setupTestDB(t *testing.T) (walletdb.DB, func()) {
+ t.Helper()
+
+ f, err := os.CreateTemp(t.TempDir(), "wallet-test-*.db")
+ require.NoError(t, err)
+
+ dbPath := f.Name()
+ require.NoError(t, f.Close())
+ require.NoError(t, os.Remove(dbPath))
+
+ db, err := walletdb.Create("bdb", dbPath, true, time.Second*10, false)
+ require.NoError(t, err)
+
+ cleanup := func() {
+ _ = db.Close()
+ _ = os.Remove(dbPath)
+ }
+
+ // Create buckets.
+ err = walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
+ _, err := tx.CreateTopLevelBucket(waddrmgrNamespaceKey)
+ if err != nil {
+ return err
+ }
+
+ _, err = tx.CreateTopLevelBucket(wtxmgrNamespaceKey)
+
+ return err
+ })
+ require.NoError(t, err)
+
+ return db, cleanup
+}
+
+// mockWalletDeps holds the mocked dependencies for the Wallet.
+type mockWalletDeps struct {
+ addrStore *mockAddrStore
+ txStore *mockTxStore
+ syncer *mockChainSyncer
+ chain *mockChain
+ addr *mockManagedAddress
+ accountManager *mockAccountStore
+ pubKeyAddr *mockManagedPubKeyAddr
+ taprootAddr *mockManagedTaprootScriptAddress
+}
+
+// createTestWalletWithMocks creates a Wallet instance with mocked
+// dependencies. It returns the wallet and the struct holding the mocks for
+// assertion.
+func createTestWalletWithMocks(t *testing.T) (*Wallet, *mockWalletDeps) {
+ t.Helper()
+
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockSyncer := &mockChainSyncer{}
+ mockChain := &mockChain{}
+ mockAddr := &mockManagedAddress{}
+ mockAccountManager := &mockAccountStore{}
+ mockPubKeyAddr := &mockManagedPubKeyAddr{}
+ mockTaprootAddr := &mockManagedTaprootScriptAddress{}
+
+ ctx, cancel := context.WithCancel(t.Context())
+
+ w := &Wallet{
+ addrStore: mockAddrStore,
+ txStore: mockTxStore,
+ sync: mockSyncer,
+ state: newWalletState(mockSyncer),
+ lifetimeCtx: ctx,
+ cancel: cancel,
+ requestChan: make(chan any, 1),
+ lockTimer: time.NewTimer(time.Hour),
+ birthdayBlock: waddrmgr.BlockStamp{
+ Height: 100,
+ },
+ cfg: Config{
+ DB: db,
+ Chain: mockChain,
+ ChainParams: &chainParams,
+ },
+ }
+
+ // Stop the timer immediately to avoid leaks.
+ w.lockTimer.Stop()
+
+ deps := &mockWalletDeps{
+ addrStore: mockAddrStore,
+ txStore: mockTxStore,
+ syncer: mockSyncer,
+ chain: mockChain,
+ addr: mockAddr,
+ accountManager: mockAccountManager,
+ pubKeyAddr: mockPubKeyAddr,
+ taprootAddr: mockTaprootAddr,
+ }
+
+ t.Cleanup(func() {
+ mockAddrStore.AssertExpectations(t)
+ mockTxStore.AssertExpectations(t)
+ mockSyncer.AssertExpectations(t)
+ mockChain.AssertExpectations(t)
+ mockAddr.AssertExpectations(t)
+ mockAccountManager.AssertExpectations(t)
+ mockPubKeyAddr.AssertExpectations(t)
+ mockTaprootAddr.AssertExpectations(t)
+ })
+
+ return w, deps
+}
+
+// createStartedWalletWithMocks creates a fully started and unlocked Wallet
+// instance with mocked dependencies.
+func createStartedWalletWithMocks(t *testing.T) (*Wallet, *mockWalletDeps) {
+ t.Helper()
+
+ w, deps := createTestWalletWithMocks(t)
+
+ // Mock the birthday block to be present.
+ deps.addrStore.On("BirthdayBlock", mock.Anything).
+ Return(waddrmgr.BlockStamp{}, true, nil).
+ Once()
+
+ // Allow SyncedTo to be called any number of times (background sync).
+ deps.addrStore.On("SyncedTo").
+ Return(waddrmgr.BlockStamp{Height: 1}).
+ Maybe()
+
+ // Mock account loading.
+ deps.addrStore.On("ActiveScopedKeyManagers").
+ Return([]waddrmgr.AccountStore{deps.accountManager}).
+ Once()
+
+ deps.accountManager.On("LastAccount", mock.Anything).
+ Return(uint32(0), nil).
+ Once()
+
+ deps.accountManager.On("AccountProperties", mock.Anything, uint32(0)).
+ Return(&waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }, nil).
+ Once()
+
+ // Mock expired lock deletion.
+ deps.txStore.On("DeleteExpiredLockedOutputs", mock.Anything).
+ Return(nil).
+ Once()
+
+ // Mock the syncer run.
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ // Start the wallet.
+ require.NoError(t, w.Start(t.Context()))
+
+ t.Cleanup(func() {
+ ctx, cancel := context.WithTimeout(
+ context.Background(), 5*time.Second,
+ )
+ defer cancel()
+
+ require.NoError(t, w.Stop(ctx))
+ })
+
+ return w, deps
+}
+
+// createUnlockedWalletWithMocks creates a fully started and unlocked Wallet
+// instance with mocked dependencies.
+func createUnlockedWalletWithMocks(t *testing.T) (*Wallet, *mockWalletDeps) {
+ t.Helper()
+
+ w, deps := createStartedWalletWithMocks(t)
+
+ // Transition to Unlocked.
+ w.state.toUnlocked()
+
+ return w, deps
+}
+
+func init() {
+ // Use fast scrypt options for tests to avoid CPU exhaustion and
+ // timeouts, especially when running with -race.
+ waddrmgr.DefaultScryptOptions = waddrmgr.FastScryptOptions
+}
diff --git a/wallet/controller.go b/wallet/controller.go
new file mode 100644
index 0000000000..441ed229e6
--- /dev/null
+++ b/wallet/controller.go
@@ -0,0 +1,797 @@
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "time"
+
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+)
+
+const (
+ // initialBackoff is the initial delay between synchronization retry
+ // attempts.
+ initialBackoff = 1 * time.Second
+
+ // maxBackoff is the maximum delay allowed between synchronization retry
+ // attempts.
+ maxBackoff = 5 * time.Minute
+
+ // stableRunTime is the minimum amount of time the syncer must run
+ // without error to be considered "stable", at which point the retry
+ // backoff is reset to initialBackoff.
+ stableRunTime = 10 * time.Minute
+)
+
+var (
+ // ErrWalletNotStopped is returned when an attempt is made to start the
+ // wallet when it is not in the stopped state.
+ ErrWalletNotStopped = errors.New("wallet not in stopped state")
+
+ // ErrWalletAlreadyStarted is returned when an attempt is made to start
+ // the wallet when it is already started.
+ ErrWalletAlreadyStarted = errors.New("wallet already started")
+
+ // ErrStateChanged is returned when the wallet state changes
+ // unexpectedly during an operation, such as a rescan setup.
+ ErrStateChanged = errors.New("wallet state changed unexpectedly")
+)
+
+// UnlockRequest contains the parameters for unlocking the wallet.
+type UnlockRequest struct {
+ // Passphrase is the private passphrase to unlock the wallet.
+ Passphrase []byte
+
+ // Timeout defines the duration after which the wallet should
+ // automatically lock. If zero, it defaults to the wallet's configured
+ // AutoLockDuration. If negative, the wallet remains unlocked until
+ // explicitly locked or stopped.
+ Timeout time.Duration
+}
+
+// Info provides a comprehensive snapshot of the wallet's static configuration
+// and dynamic synchronization state.
+type Info struct {
+ // BirthdayBlock is the block from which the wallet started scanning.
+ BirthdayBlock waddrmgr.BlockStamp
+
+ // Backend is the name of the chain backend (e.g. "neutrino",
+ // "bitcoind").
+ Backend string
+
+ // ChainParams are the parameters of the chain the wallet is connected
+ // to.
+ ChainParams *chaincfg.Params
+
+ // Locked indicates if the wallet is currently locked.
+ Locked bool
+
+ // Synced indicates if the wallet is synced to the chain tip.
+ Synced bool
+
+ // SyncedTo is the block to which the wallet is currently synced.
+ SyncedTo waddrmgr.BlockStamp
+
+ // IsRecoveryMode indicates if the wallet is currently in recovery
+ // mode.
+ IsRecoveryMode bool
+
+ // RecoveryProgress is the progress of the recovery (0.0 - 1.0).
+ RecoveryProgress float64
+}
+
+// ChangePassphraseRequest contains the parameters for changing wallet
+// passphrases. It supports changing the public passphrase, the private
+// passphrase, or both simultaneously.
+type ChangePassphraseRequest struct {
+ // ChangePublic indicates whether the public passphrase should be
+ // changed.
+ ChangePublic bool
+ PublicOld []byte
+ PublicNew []byte
+
+ // ChangePrivate indicates whether the private passphrase should be
+ // changed.
+ ChangePrivate bool
+ PrivateOld []byte
+ PrivateNew []byte
+}
+
+// Controller provides an interface for managing the wallet's lifecycle and
+// state.
+type Controller interface {
+ // Unlock unlocks the wallet with a passphrase. The wallet will remain
+ // unlocked until explicitly locked or the provided lock duration
+ // expires.
+ Unlock(ctx context.Context, req UnlockRequest) error
+
+ // Lock locks the wallet, clearing any cached private key material.
+ Lock(ctx context.Context) error
+
+ // ChangePassphrase changes the wallet's passphrases according to the
+ // request.
+ ChangePassphrase(ctx context.Context, req ChangePassphraseRequest) error
+
+ // Info returns a comprehensive snapshot of the wallet's static
+ // configuration and dynamic synchronization state.
+ Info(ctx context.Context) (*Info, error)
+
+ // Start starts the background processes necessary to manage the wallet.
+ // It returns an error if the wallet is already started.
+ Start(ctx context.Context) error
+
+ // Stop signals all wallet background processes to shutdown and blocks
+ // until they have all exited. It returns an error if the context is
+ // canceled before the shutdown is complete.
+ Stop(ctx context.Context) error
+
+ // Resync rewinds the wallet's synchronization state to a specific
+ // block height.
+ Resync(ctx context.Context, startHeight uint32) error
+
+ // Rescan initiates a targeted rescan for specific accounts or addresses
+ // starting from the given block height. This operation scans for
+ // relevant transactions without rewinding the wallet's global
+ // synchronization state.
+ Rescan(ctx context.Context, startHeight uint32,
+ targets []waddrmgr.AccountScope) error
+}
+
+// Start starts the background processes necessary to manage the wallet.
+//
+// This is part of the Controller interface.
+func (w *Wallet) Start(startCtx context.Context) error {
+ // 1. Attempt to transition from Stopped to Starting.
+ err := w.state.toStarting()
+ if err != nil {
+ return err
+ }
+
+ // 2. Setup background resources.
+ //
+ // w.lifetimeCtx governs the lifecycle of all background goroutines.
+ // It is canceled when stop() is called.
+ w.lifetimeCtx, w.cancel = context.WithCancel(context.Background())
+
+ // 3. Perform runtime setup.
+ //
+ // We use startCtx here because these operations must complete
+ // synchronously before the wallet is considered "started". If
+ // startCtx is canceled, the startup sequence aborts.
+ err = w.performRuntimeSetup(startCtx)
+ if err != nil {
+ // Cleanup resources.
+ w.cancel()
+
+ // Revert state if setup fails.
+ stopErr := w.state.toStopped()
+ if stopErr != nil {
+ log.Warnf("Failed to revert state to stopped: %v",
+ stopErr)
+ }
+
+ return err
+ }
+
+ // 4. Start background goroutines.
+ w.wg.Add(1)
+
+ go w.mainLoop()
+
+ w.wg.Add(1)
+
+ go func() {
+ defer w.wg.Done()
+
+ w.runSyncLoop()
+ }()
+
+ // 5. Mark the wallet as fully started.
+ err = w.state.toStarted()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// runSyncLoop executes the main chain synchronization loop with automatic
+// retries and exponential backoff. It ensures the wallet attempts to stay
+// synced even if the backend connection is flaky.
+func (w *Wallet) runSyncLoop() {
+ backoff := initialBackoff
+
+ for {
+ startTime := time.Now()
+
+ // Block until the syncer exits.
+ err := w.sync.run(w.lifetimeCtx)
+
+ // If the wallet is shutting down, we can exit immediately.
+ if w.lifetimeCtx.Err() != nil {
+ log.Info("Chain sync loop exiting due to wallet shutdown")
+
+ return
+ }
+
+ // If the syncer exited cleanly (nil error), it generally means it was
+ // requested to stop, so we shouldn't restart.
+ if err == nil {
+ log.Info("Chain sync loop exited normally")
+ return
+ }
+
+ log.Errorf("Chain sync loop exited with error: %v", err)
+
+ var shouldContinue bool
+
+ backoff, shouldContinue = w.waitForBackoff(
+ startTime, backoff, time.After,
+ )
+ if !shouldContinue {
+ return
+ }
+ }
+}
+
+// waitForBackoff handles the delay between synchronization retry attempts. It
+// resets the backoff if the previous run was stable, waits for the calculated
+// delay, and then returns the updated backoff duration for the next attempt.
+// It returns false if the wallet is shutting down.
+func (w *Wallet) waitForBackoff(startTime time.Time, backoff time.Duration,
+ timerFn func(time.Duration) <-chan time.Time) (time.Duration, bool) {
+
+ // If the syncer ran for a significant amount of time, we consider it a
+ // "stable" run and reset the backoff.
+ if time.Since(startTime) > stableRunTime {
+ backoff = initialBackoff
+ }
+
+ log.Infof("Restarting sync loop in %v...", backoff)
+
+ // Wait for the backoff period or a shutdown signal.
+ select {
+ case <-timerFn(backoff):
+ // Increase backoff for the next attempt, capping it.
+ backoff *= 2
+ if backoff > maxBackoff {
+ backoff = maxBackoff
+ }
+
+ return backoff, true
+
+ case <-w.lifetimeCtx.Done():
+ log.Debug("Backoff interrupted by wallet shutdown")
+
+ return 0, false
+ }
+}
+
+// performRuntimeSetup executes the synchronous initialization tasks required
+// before the wallet's main loops can start. This includes sanity checking the
+// birthday block, loading accounts into memory, and cleaning up expired locks.
+func (w *Wallet) performRuntimeSetup(startCtx context.Context) error {
+ // Perform the birthday sanity check synchronously to ensure we are
+ // connected and our status is valid before starting the main loop.
+ //
+ // This also initializes the birthday block cache used by the Info
+ // method.
+ err := w.verifyBirthday(startCtx)
+ if err != nil {
+ return err
+ }
+
+ // Ensure all accounts are loaded into memory so we can efficiently
+ // access them during the scan loop without database lookups.
+ err = w.DBGetAllAccounts(startCtx)
+ if err != nil {
+ return err
+ }
+
+ // Cleanup any expired output locks.
+ return w.DBDeleteExpiredLockedOutputs(startCtx)
+}
+
+// Stop signals all wallet background processes to shutdown and blocks until
+// they have all exited. It returns an error if the context is canceled before
+// the shutdown is complete.
+//
+// This is part of the Controller interface.
+func (w *Wallet) Stop(stopCtx context.Context) error {
+ // Attempt to transition from Started to Stopping.
+ err := w.state.toStopping()
+ if err != nil {
+ // If the wallet is not started, we can consider it stopped.
+ log.Warnf("Wallet already stopped: %v", err)
+ return nil
+ }
+
+ // Signal all background processes to stop.
+ //
+ // It is safe to call w.cancel() here because the successful transition
+ // to Stopping guarantees that we were previously in the Started state,
+ // which in turn guarantees that start() has completed initialization
+ // of w.lifetimeCtx and w.cancel.
+ //
+ // Additionally, w.cancel() is idempotent, so it is safe to call even
+ // if it has effectively already been called (though the state machine
+ // guarantees we only reach this point once).
+ w.cancel()
+
+ // Wait for all goroutines to finish.
+ done := make(chan struct{})
+ go func() {
+ w.wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-stopCtx.Done():
+ return fmt.Errorf("stop request cancelled: %w", stopCtx.Err())
+ }
+
+ // Mark the wallet as stopped.
+ err = w.state.toStopped()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// Unlock unlocks the wallet with a passphrase.
+//
+// This is part of the Controller interface.
+func (w *Wallet) Unlock(ctx context.Context, req UnlockRequest) error {
+ // Ensure the wallet is in a state that allows unlocking.
+ err := w.state.canUnlock()
+ if err != nil {
+ return err
+ }
+
+ // Apply default timeout if none specified.
+ if req.Timeout == 0 {
+ req.Timeout = w.cfg.AutoLockDuration
+ log.Infof("Using default auto-lock timeout of %v", req.Timeout)
+ }
+
+ r := newUnlockReq(req)
+
+ // Submit the request.
+ err = w.sendReq(ctx, r)
+ if err != nil {
+ return err
+ }
+
+ // Wait for the result from the mainLoop.
+ return w.waitForResp(ctx, r.resp)
+}
+
+// Lock locks the wallet.
+//
+// This is part of the Controller interface.
+func (w *Wallet) Lock(ctx context.Context) error {
+ // Ensure the wallet is in a state that allows locking.
+ err := w.state.canLock()
+ if err != nil {
+ return err
+ }
+
+ r := newLockReq()
+
+ err = w.sendReq(ctx, r)
+ if err != nil {
+ return err
+ }
+
+ // Wait for the result.
+ return w.waitForResp(ctx, r.resp)
+}
+
+// ChangePassphrase changes the wallet's passphrases according to the request.
+//
+// This is part of the Controller interface.
+func (w *Wallet) ChangePassphrase(ctx context.Context,
+ req ChangePassphraseRequest) error {
+
+ // Ensure the wallet is in a state that allows changing the passphrase.
+ err := w.state.canChangePassphrase()
+ if err != nil {
+ return err
+ }
+
+ r := newChangePassphraseReq(req)
+
+ err = w.sendReq(ctx, r)
+ if err != nil {
+ return err
+ }
+
+ // Wait for the result.
+ return w.waitForResp(ctx, r.resp)
+}
+
+// Info returns a comprehensive snapshot of the wallet's static configuration
+// and dynamic synchronization state.
+//
+// This is part of the Controller interface.
+func (w *Wallet) Info(_ context.Context) (*Info, error) {
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ info := &Info{
+ BirthdayBlock: w.birthdayBlock,
+ Backend: w.cfg.Chain.BackEnd(),
+ ChainParams: w.cfg.ChainParams,
+ Locked: !w.state.isUnlocked(),
+ Synced: w.state.isSynced(),
+ SyncedTo: w.SyncedTo(),
+ IsRecoveryMode: w.state.isRecoveryMode(),
+ RecoveryProgress: 0,
+ }
+
+ return info, nil
+}
+
+// Resync rewinds the wallet's synchronization state to a specific block
+// height.
+//
+// This is part of the Controller interface.
+func (w *Wallet) Resync(ctx context.Context, startHeight uint32) error {
+ return w.submitRescanRequest(
+ ctx, scanTypeRewind, startHeight, nil,
+ )
+}
+
+// Rescan initiates a targeted rescan for specific accounts or addresses
+// starting from the given block height. This operation scans for
+// relevant transactions without rewinding the wallet's global
+// synchronization state.
+func (w *Wallet) Rescan(ctx context.Context, startHeight uint32,
+ targets []waddrmgr.AccountScope) error {
+
+ if len(targets) == 0 {
+ return ErrNoScanTargets
+ }
+
+ return w.submitRescanRequest(
+ ctx, scanTypeTargeted, startHeight, targets,
+ )
+}
+
+// submitRescanRequest validates the rescan request and submits it to the
+// syncer.
+func (w *Wallet) submitRescanRequest(ctx context.Context, typ scanType,
+ startHeight uint32, targets []waddrmgr.AccountScope) error {
+
+ // Ensure the wallet is running and synced.
+ err := w.state.validateSynced()
+ if err != nil {
+ return err
+ }
+
+ // BlockStamp.Height is int32, so we need to ensure the requested
+ // startHeight does not exceed math.MaxInt32.
+ if startHeight > math.MaxInt32 {
+ return fmt.Errorf("%w: %d", ErrStartHeightTooLarge, startHeight)
+ }
+
+ startHeightInt32 := int32(startHeight)
+
+ // Fetch the current best block to ensure we don't resync past the tip.
+ _, bestHeightInt32, err := w.cfg.Chain.GetBestBlock()
+ if err != nil {
+ return fmt.Errorf("unable to get chain tip: %w", err)
+ }
+
+ if startHeightInt32 > bestHeightInt32 {
+ return fmt.Errorf("%w: start height %d is greater than "+
+ "current chain tip %d", ErrStartHeightTooHigh,
+ startHeight, bestHeightInt32)
+ }
+
+ // Submit the rescan request to the syncer.
+ req := &scanReq{
+ typ: typ,
+ startBlock: waddrmgr.BlockStamp{
+ Height: startHeightInt32,
+ },
+ targets: targets,
+ }
+
+ return w.sync.requestScan(ctx, req)
+}
+
+// mainLoop is the central event loop for the wallet, responsible for
+// coordinating and serializing all lifecycle and authentication requests. It
+// manages the transition between locked and unlocked states and handles the
+// automatic locking of the wallet after a specified duration.
+func (w *Wallet) mainLoop() {
+ defer w.wg.Done()
+
+ for {
+ select {
+ case req := <-w.requestChan:
+ // Process incoming serialized requests.
+ switch r := req.(type) {
+ // Perform the unlock.
+ case unlockReq:
+ w.handleUnlockReq(r)
+
+ // Perform an explicit lock and stop the timer.
+ case lockReq:
+ w.handleLockReq(r)
+
+ // Rotate wallet passphrases.
+ case changePassphraseReq:
+ w.handleChangePassphraseReq(r)
+
+ default:
+ log.Errorf("Wallet received unknown request "+
+ "type: %T", req)
+ }
+
+ // The auto-lock timer has expired. We trigger a lock with a
+ // dummy response channel to avoid nil checks in the handler.
+ case <-w.lockTimer.C:
+ log.Infof("Auto-lock timeout fired, locking wallet")
+ w.handleLockReq(newLockReq())
+
+ // The wallet is shutting down. We exit the main loop.
+ case <-w.lifetimeCtx.Done():
+ w.lockTimer.Stop()
+
+ return
+ }
+ }
+}
+
+// verifyBirthday performs a sanity check on the wallet's birthday block to
+// ensure it is set and valid.
+//
+// Logical Steps:
+// 1. Fetch the current birthday block from the database.
+// 2. If the block is already verified, initialize the memory cache and
+// return.
+// 3. If the block is missing or unverified, fetch the wallet's birthday
+// timestamp.
+// 4. Use the chain backend to locate a suitable block matching the
+// birthday timestamp.
+// 5. Persist the new birthday block, mark it as verified, and update the
+// wallet's sync tip to this point to ensure a clean rescan range.
+// 6. Update the memory cache.
+func (w *Wallet) verifyBirthday(ctx context.Context) error {
+ // We'll start by fetching our wallet's birthday block.
+ birthdayBlock, verified, err := w.DBGetBirthdayBlock(ctx)
+ if err != nil {
+ var mgrErr waddrmgr.ManagerError
+ if !errors.As(err, &mgrErr) ||
+ mgrErr.ErrorCode != waddrmgr.ErrBirthdayBlockNotSet {
+
+ log.Errorf("Unable to sanity check wallet birthday "+
+ "block: %v", err)
+
+ return err
+ }
+ // If not set, we proceed to locate it.
+ }
+
+ // If the birthday block has already been verified, we initialize the
+ // cache and exit our sanity check to avoid redundant lookups.
+ if verified {
+ log.Infof("Birthday block verified: height=%d, hash=%v",
+ birthdayBlock.Height, birthdayBlock.Hash)
+ w.birthdayBlock = birthdayBlock
+
+ return nil
+ }
+ // Otherwise, we'll attempt to locate a better one now that we have
+ // access to the chain.
+ timestamp := w.addrStore.Birthday()
+
+ newBirthdayBlock, err := locateBirthdayBlock(w.cfg.Chain, timestamp)
+ if err != nil {
+ log.Errorf("Unable to sanity check wallet birthday "+
+ "block: %v", err)
+
+ return err
+ }
+
+ err = w.DBPutBirthdayBlock(ctx, *newBirthdayBlock)
+ if err != nil {
+ log.Errorf("Unable to sanity check wallet birthday "+
+ "block: %v", err)
+
+ return err
+ }
+
+ w.birthdayBlock = *newBirthdayBlock
+
+ return nil
+}
+
+// resultChan is a generic channel for returning errors to callers.
+type resultChan chan error
+
+// unlockReq requests the wallet to be unlocked.
+type unlockReq struct {
+ req UnlockRequest
+ resp resultChan
+}
+
+// lockReq requests the wallet to be locked.
+type lockReq struct {
+ resp resultChan
+}
+
+// changePassphraseReq requests a change of the wallet's passphrases.
+type changePassphraseReq struct {
+ req ChangePassphraseRequest
+ resp resultChan
+}
+
+// newUnlockReq creates a new unlock request with a buffered response channel.
+// We use this constructor to ensure that the response channel is always
+// correctly initialized and buffered, preventing the main loop from blocking
+// when reporting the result.
+func newUnlockReq(req UnlockRequest) unlockReq {
+ return unlockReq{
+ req: req,
+ resp: make(resultChan, 1),
+ }
+}
+
+// newLockReq creates a new lock request with a buffered response channel.
+func newLockReq() lockReq {
+ return lockReq{
+ resp: make(resultChan, 1),
+ }
+}
+
+// newChangePassphraseReq creates a new change passphrase request with a
+// buffered response channel.
+func newChangePassphraseReq(req ChangePassphraseRequest) changePassphraseReq {
+ return changePassphraseReq{
+ req: req,
+ resp: make(resultChan, 1),
+ }
+}
+
+// handleUnlockReq processes an incoming request to unlock the wallet. It
+// authenticates the provided passphrase against the database and, on success,
+// transitions the wallet to the unlocked state.
+func (w *Wallet) handleUnlockReq(req unlockReq) {
+ // First, validate that the wallet is in a state that allows unlocking.
+ err := w.state.canUnlock()
+ if err != nil {
+ req.resp <- err
+ return
+ }
+
+ // Attempt to unlock the underlying address manager.
+ err = w.DBUnlock(w.lifetimeCtx, req.req.Passphrase)
+ if err != nil {
+ req.resp <- err
+ return
+ }
+
+ // On success, update the atomic wallet state to reflect that we are
+ // now unlocked.
+ w.state.toUnlocked()
+
+ // Handle auto-lock timer. If a timeout is specified, we reset the
+ // timer to fire in the future. Otherwise, we stop the timer to disable
+ // auto-locking.
+ duration := req.req.Timeout
+ if duration > 0 {
+ w.lockTimer.Reset(duration)
+ } else if !w.lockTimer.Stop() {
+ // If the timer has already fired, we drain its channel to
+ // prevent a stale signal from being processed by the main
+ // loop, which would cause an immediate, unexpected lock.
+ select {
+ case <-w.lockTimer.C:
+ default:
+ }
+ }
+
+ // Always report the result back to the caller.
+ req.resp <- nil
+}
+
+// handleLockReq processes an incoming request to lock the wallet. It clears
+// any cached private key material from memory and transitions the wallet to
+// the locked state.
+func (w *Wallet) handleLockReq(req lockReq) {
+ // First, validate that the wallet is in a state that allows locking.
+ err := w.state.canLock()
+ if err != nil {
+ req.resp <- err
+ return
+ }
+
+ // Stop the auto-lock timer since the wallet is now explicitly locked.
+ if !w.lockTimer.Stop() {
+ // Drain the channel if the timer has already fired to ensure
+ // we don't process a stale lock signal in the next iteration.
+ select {
+ case <-w.lockTimer.C:
+ default:
+ }
+ }
+
+ // Signal the address manager to lock, clearing sensitive data.
+ err = w.addrStore.Lock()
+ if err != nil {
+ log.Errorf("Could not lock wallet: %v", err)
+
+ // If the wallet is already locked, we consider this a success
+ // (idempotency) and proceed to ensure our state is consistent.
+ if !waddrmgr.IsError(err, waddrmgr.ErrLocked) {
+ req.resp <- err
+
+ return
+ }
+ }
+
+ // Even if an error occurred (e.g. already locked), we ensure the
+ // wallet's high-level state is synchronized to 'locked'.
+ w.state.toLocked()
+
+ // Report the result back to the caller.
+ req.resp <- nil
+}
+
+// handleChangePassphraseReq processes a request to rotate the wallet's
+// passphrases. It can change either the public passphrase, the private
+// passphrase, or both in a single atomic database update.
+func (w *Wallet) handleChangePassphraseReq(req changePassphraseReq) {
+ // First, validate that the wallet is in a state that allows changing
+ // the passphrase.
+ err := w.state.canChangePassphrase()
+ if err != nil {
+ req.resp <- err
+ return
+ }
+
+ // Delegate the cryptographic rotation to the database layer.
+ err = w.DBPutPassphrase(w.lifetimeCtx, req.req)
+
+ // Report the result back to the caller.
+ req.resp <- err
+}
+
+// sendReq sends an operation request to the main loop or handles cancellation.
+func (w *Wallet) sendReq(ctx context.Context, req any) error {
+ select {
+ case w.requestChan <- req:
+ return nil
+
+ case <-w.lifetimeCtx.Done():
+ return ErrWalletShuttingDown
+
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// waitForResp waits for the response from an operation request or handles
+// cancellation.
+func (w *Wallet) waitForResp(ctx context.Context, resp <-chan error) error {
+ select {
+ case err := <-resp:
+ return err
+
+ case <-w.lifetimeCtx.Done():
+ return ErrWalletShuttingDown
+
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
diff --git a/wallet/controller_benchmark_test.go b/wallet/controller_benchmark_test.go
new file mode 100644
index 0000000000..df857a654d
--- /dev/null
+++ b/wallet/controller_benchmark_test.go
@@ -0,0 +1,839 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime/pprof"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/integration/rpctest"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/chain/port"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ _ "github.com/btcsuite/btcwallet/walletdb/bdb"
+ "github.com/stretchr/testify/require"
+)
+
+const (
+ // benchBlocksToSync is the number of blocks to generate and sync during
+ // the benchmark.
+ benchBlocksToSync = 1000
+
+ // benchTxsPerBlock is the total number of transactions per block
+ // (hits + noise).
+ benchTxsPerBlock = 100
+
+ // testRecoveryWindow is the address lookahead window used during
+ // benchmarks. A larger window ensures that the wallet can discover
+ // addresses even if there are gaps in the address chain, which is essential
+ // for realistic synchronization benchmarks.
+ testRecoveryWindow = 100
+)
+
+// BenchmarkSyncEmpty benchmarks the wallet synchronization performance against
+// empty blocks and an empty wallet by comparing the legacy SynchronizeRPC with
+// the new Controller.Start API across different block depths.
+func BenchmarkSyncEmpty(b *testing.B) {
+ scenarios := []struct {
+ blocks int
+ }{
+ {10},
+ {100},
+ {1000},
+ }
+
+ for _, s := range scenarios {
+ name := fmt.Sprintf("Blocks-%d", s.blocks)
+ b.Run(name, func(b *testing.B) {
+ // Initialize a common miner for all sub-benchmarks in this
+ // scenario.
+ miner := setupChain(b, s.blocks)
+
+ b.Run("Legacy", func(b *testing.B) {
+ runLegacySync(b, miner)
+ })
+
+ b.Run("NewWithFullBlock", func(b *testing.B) {
+ runNewSync(b, miner, SyncMethodFullBlocks)
+ })
+
+ b.Run("NewWithCFilter", func(b *testing.B) {
+ runNewSync(b, miner, SyncMethodCFilters)
+ })
+ })
+ }
+}
+
+// BenchmarkSyncData benchmarks the wallet synchronization performance against
+// blocks with data and a populated wallet. It compares Legacy vs New APIs
+// across different wallet sizes (number of accounts/addresses).
+func BenchmarkSyncData(b *testing.B) {
+ scenarios := []struct {
+ addrs int
+ numUTXOs int
+ }{
+ // Case 1: Sparse Discovery Stress Test.
+ // 100 total addresses, but only 1 UTXO sent to the 100th address
+ // (index 99). The hit is delivered at block 500.
+ // Density: 1 hit per 1000 blocks (0.1%).
+ // Intent: Tests how the wallet handles a nearly empty history where
+ // it must scan far into the chain and its lookahead window to find
+ // a single isolated transaction.
+ {addrs: 100, numUTXOs: 1},
+
+ // Case 2: Periodic Discovery Stress Test.
+ // 100 total addresses, with 10 UTXOs sent to every 10th address
+ // (index 9, 19, ... 99). Hits are delivered every 100 blocks.
+ // Density: 1 hit per 100 blocks (1%).
+ // Intent: Tests incremental discovery and sequential rescan logic
+ // as the wallet regularly finds hits and must decide whether to
+ // expand its search window.
+ {addrs: 100, numUTXOs: 10},
+
+ // Case 3: Dense Sync Throughput Test.
+ // 100 total addresses, with 100 UTXOs sent to every address.
+ // Hits are delivered every 10 blocks.
+ // Density: 1 hit per 10 blocks (10%).
+ // Intent: Tests the raw throughput of the transaction indexing and
+ // block processing logic when relevant data appears frequently.
+ {addrs: 100, numUTXOs: 100},
+ }
+
+ for _, s := range scenarios {
+ density := float64(s.numUTXOs) / float64(benchBlocksToSync)
+ name := fmt.Sprintf("UTXODensity-%.3f", density)
+
+ b.Run(name, func(b *testing.B) {
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.MinSeedBytes)
+ require.NoError(b, err)
+
+ // Setup common miner and populate it with wallet-destined data.
+ // Always use 1 account.
+ miner := setupChainWithWalletData(
+ b, seed, s.addrs, s.numUTXOs,
+ )
+
+ b.ResetTimer()
+
+ b.Run("Legacy", func(b *testing.B) {
+ runLegacySyncData(b, miner, seed, s.numUTXOs)
+ })
+
+ b.Run("NewWithFullBlock", func(b *testing.B) {
+ runNewSyncData(
+ b, miner, seed, SyncMethodFullBlocks, s.numUTXOs,
+ )
+ })
+
+ b.Run("NewWithCFilter", func(b *testing.B) {
+ runNewSyncData(
+ b, miner, seed, SyncMethodCFilters, s.numUTXOs,
+ )
+ })
+ })
+ }
+}
+
+// startProfiling begins CPU profiling if the profileName is not empty. It
+// returns a cleanup function that must be called to stop profiling.
+func startProfiling(tb testing.TB) func() {
+ tb.Helper()
+
+ // We use the test name to generate a unique profile filename for each
+ // benchmark case. Slashes are replaced with underscores to ensure a
+ // valid filename.
+ name := strings.ReplaceAll(tb.Name(), "/", "_") + ".prof"
+
+ f, err := os.Create(name)
+ require.NoError(tb, err)
+
+ err = pprof.StartCPUProfile(f)
+ require.NoError(tb, err)
+
+ return func() {
+ pprof.StopCPUProfile()
+ require.NoError(tb, f.Close())
+ }
+}
+
+// runLegacySync executes the legacy synchronization benchmark loop.
+func runLegacySync(b *testing.B, miner *rpctest.Harness) {
+ b.Helper()
+
+ for b.Loop() {
+ b.StopTimer()
+
+ // Setup a fresh legacy wallet for each iteration.
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.MinSeedBytes)
+ require.NoError(b, err)
+ w := setupLegacyWallet(b, seed)
+
+ // Connect a fresh chain client.
+ chainClient := setupChainClient(b, miner)
+
+ stopProfile := startProfiling(b)
+
+ b.StartTimer()
+
+ // Start legacy sync process.
+ w.StartDeprecated()
+ w.SynchronizeRPC(chainClient)
+
+ // Poll until the wallet reports it is synced.
+ for !w.ChainSynced() {
+ time.Sleep(5 * time.Millisecond)
+ }
+
+ stopProfile()
+ }
+}
+
+// runLegacySyncData executes the legacy synchronization benchmark loop
+// with data.
+func runLegacySyncData(b *testing.B, miner *rpctest.Harness, seed []byte,
+ expectedUTXOs int) {
+
+ b.Helper()
+
+ for b.Loop() {
+ // Stop the timer to exclude expensive setup operations (like
+ // creating the wallet database and accounts) from the measured
+ // sync time.
+ b.StopTimer()
+
+ // Setup a fresh legacy wallet.
+ w := setupLegacyWallet(b, seed)
+
+ // Connect a fresh chain client (Bitcoind).
+ chainClient := setupChainClient(b, miner)
+
+ stopProfile := startProfiling(b)
+
+ // Start the timer for the actual synchronization phase.
+ b.StartTimer()
+
+ // Start legacy sync process.
+ w.StartDeprecated()
+
+ w.SynchronizeRPC(chainClient)
+
+ // Poll until the wallet reports it is synced.
+ for !w.ChainSynced() {
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ // Stop the timer to exclude verification.
+ b.StopTimer()
+
+ stopProfile()
+
+ // Verify UTXO count using high-level method.
+ assertUTXOCountDeprecated(b, w, expectedUTXOs)
+
+ // Restart timer for loop accounting.
+ b.StartTimer()
+ }
+}
+
+// runNewSync executes the modern Controller synchronization benchmark loop.
+func runNewSync(b *testing.B, miner *rpctest.Harness, method SyncMethod) {
+ b.Helper()
+
+ for b.Loop() {
+ b.StopTimer()
+
+ // Connect a fresh chain client.
+ chainClient := setupChainClient(b, miner)
+
+ // Configure for the specified sync mode.
+ cfg := defaultWalletConfig(b)
+ cfg.Chain = chainClient
+ cfg.SyncMethod = method
+
+ // Setup a fresh modern wallet.
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.MinSeedBytes)
+ require.NoError(b, err)
+ w := setupNewWallet(b, seed, cfg)
+
+ stopProfile := startProfiling(b)
+
+ b.StartTimer()
+
+ // Start modern controller and syncing.
+ err = w.Start(b.Context())
+ require.NoError(b, err)
+
+ // Poll until the controller reports it is synced.
+ //
+ // NOTE: We use w.Info() here to poll for the synced status. This is a
+ // heavier operation than the legacy mutex-protected boolean read,
+ // making the observed performance gains even more significant as they
+ // include this additional status-check overhead.
+ for {
+ info, err := w.Info(b.Context())
+ require.NoError(b, err)
+
+ if info.Synced {
+ break
+ }
+
+ time.Sleep(5 * time.Millisecond)
+ }
+
+ stopProfile()
+ }
+}
+
+// runNewSyncData executes the modern Controller synchronization benchmark loop
+// with data.
+func runNewSyncData(b *testing.B, miner *rpctest.Harness, seed []byte,
+ method SyncMethod, expectedUTXOs int) {
+
+ b.Helper()
+
+ for b.Loop() {
+ // Stop the timer to exclude expensive setup operations (like
+ // creating the wallet database and accounts) from the measured
+ // sync time.
+ b.StopTimer()
+
+ chainClient := setupChainClient(b, miner)
+ cfg := defaultWalletConfig(b)
+ cfg.Chain = chainClient
+ cfg.SyncMethod = method
+
+ w := setupNewWallet(b, seed, cfg)
+
+ stopProfile := startProfiling(b)
+
+ // Start the timer for the actual synchronization phase.
+ b.StartTimer()
+
+ // Start modern controller and syncing.
+ err := w.Start(b.Context())
+ require.NoError(b, err)
+
+ // Poll until the controller reports it is synced.
+ for {
+ info, err := w.Info(b.Context())
+ require.NoError(b, err)
+
+ if info.Synced {
+ break
+ }
+
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ // Stop the timer to exclude verification.
+ b.StopTimer()
+
+ stopProfile()
+
+ // Verify UTXO count using high-level method.
+ assertUTXOCount(b, w, expectedUTXOs)
+
+ // Restart timer for loop accounting.
+ b.StartTimer()
+ }
+}
+
+// setupLegacyWallet initializes a legacy wallet for benchmarking. It
+// automatically registers resource cleanup.
+func setupLegacyWallet(tb testing.TB, seed []byte) *Wallet {
+ tb.Helper()
+
+ // Initialize temporary database directory and standard test credentials.
+ dir := tb.TempDir()
+ pubPass := []byte("public")
+ privPass := []byte("private")
+
+ // Create the wallet using the legacy Loader.
+ loader := NewLoader(
+ &chaincfg.RegressionNetParams, dir, true, 10*time.Second,
+ testRecoveryWindow,
+ WithWalletSyncRetryInterval(10*time.Millisecond),
+ )
+
+ // Use an old birthday to ensure the wallet rescans past blocks.
+ birthday := time.Now().Add(-48 * time.Hour)
+ w, err := loader.CreateNewWallet(pubPass, privPass, seed, birthday)
+ require.NoError(tb, err)
+
+ // Register cleanup function to ensure all legacy background processes are
+ // stopped and the database is correctly closed after the benchmark subtest.
+ tb.Cleanup(func() {
+ w.StopDeprecated()
+ w.WaitForShutdown()
+
+ if val := w.Database(); val != nil {
+ require.NoError(tb, val.Close())
+ }
+ })
+
+ return w
+}
+
+// setupNewWallet initializes a modern wallet using the Manager API. It accepts
+// a Config which should at least have the Chain client populated. It
+// automatically registers resource cleanup.
+func setupNewWallet(tb testing.TB, seed []byte, cfg Config) *Wallet {
+ tb.Helper()
+
+ privPass := []byte("private")
+ params := CreateWalletParams{
+ Mode: ModeImportSeed,
+ Seed: seed,
+ PrivatePassphrase: privPass,
+ PubPassphrase: cfg.PubPassphrase,
+ Birthday: time.Now().Add(-48 * time.Hour),
+ }
+
+ // Create the wallet using the new Manager API. This returns a loaded
+ // but unstarted wallet instance.
+ manager := NewManager()
+ w, err := manager.Create(cfg, params)
+ require.NoError(tb, err)
+
+ // Register cleanup function to handle the Controller shutdown and close
+ // the database handle after the benchmark subtest.
+ tb.Cleanup(func() {
+ _ = w.Stop(tb.Context())
+ require.NoError(tb, w.cfg.DB.Close())
+ })
+
+ return w
+}
+
+// setupChain prepares a btcd node and generates the required blocks.
+func setupChain(tb testing.TB, blocks int) *rpctest.Harness {
+ tb.Helper()
+
+ args := []string{
+ "--txindex",
+ "--minrelaytxfee=0.00000001", // 1 sat/kb
+ }
+ miner, err := rpctest.New(&chaincfg.RegressionNetParams, nil, args, "")
+ require.NoError(tb, err)
+ require.NoError(tb, miner.SetUp(true, uint32(blocks)))
+
+ // Generate the requested number of empty blocks.
+ if blocks > 0 {
+ _, err := miner.Client.Generate(uint32(blocks))
+ require.NoError(tb, err)
+ }
+
+ tb.Cleanup(func() {
+ require.NoError(tb, miner.TearDown())
+ })
+
+ return miner
+}
+
+// setupChainWithWalletData prepares a miner and populates the blockchain with
+// transactions destined for a wallet derived from the provided seed. This
+// establishes a realistic environment for benchmarking synchronization.
+func setupChainWithWalletData(tb testing.TB, seed []byte,
+ addrsPerAccount, numUTXOs int) *rpctest.Harness {
+
+ tb.Helper()
+
+ // Initialize common miner.
+ miner := setupChain(tb, 0)
+
+ // 1. Setup a template wallet to extract addresses for the chain.
+ cfg := defaultWalletConfig(tb)
+ cfg.Chain = setupChainClient(tb, miner)
+
+ templateW := setupNewWallet(tb, seed, cfg)
+
+ err := templateW.Start(tb.Context())
+ require.NoError(tb, err)
+
+ // Unlock template wallet to derive addresses.
+ err = templateW.Unlock(tb.Context(), UnlockRequest{
+ Passphrase: []byte("private"),
+ })
+ require.NoError(tb, err)
+
+ // Manually derive addresses for the template wallet.
+ var targetAddrs []address.Address
+
+ // Calculate chunk for selecting target addresses.
+ // Total addresses = addrsPerAccount (since 1 account).
+ // We want to pick 'numUTXOs' targets.
+ // Stride = Total / numUTXOs.
+ chunk := addrsPerAccount / numUTXOs
+
+ accountName := waddrmgr.DefaultAccountName
+ for i := range addrsPerAccount {
+ addr, err := templateW.NewAddress(tb.Context(),
+ accountName, waddrmgr.WitnessPubKey, false)
+ require.NoError(tb, err)
+
+ // Select target addresses based on the calculated chunk.
+ // For example, if we have 100 addresses and need 10 UTXOs, we pick
+ // every 10th address (index 9, 19, ... 99). This ensures the wallet
+ // must scan through gaps to find all hits, stressing the recovery
+ // and discovery logic.
+ targetIdx := (len(targetAddrs)+1)*chunk - 1
+ if i == targetIdx {
+ targetAddrs = append(targetAddrs, addr)
+ }
+ }
+
+ // Close the template wallet now that we are done with it. This releases
+ // the database lock and resources.
+ _ = templateW.Stop(tb.Context())
+ require.NoError(tb, templateW.cfg.DB.Close())
+
+ // Ensure we selected the correct number of targets.
+ require.Len(tb, targetAddrs, numUTXOs,
+ "failed to select target addresses")
+
+ // Pre-mine 200 blocks to ensure the miner has a sufficient balance of
+ // mature coinbase outputs. In Bitcoin, coinbase outputs cannot be spent
+ // until they have reached a depth of 100 blocks (maturity). Pre-mining 200
+ // blocks ensures that the miner can immediately begin sending transactions
+ // to the wallet during the setup phase.
+ _, err = miner.Client.Generate(200)
+ require.NoError(tb, err)
+
+ // 2. Setup the chain based on the scenario (numUTXOs).
+ switch numUTXOs {
+ case 1:
+ setupChainCase1(tb, miner, targetAddrs)
+
+ case 10:
+ setupChainCase2(tb, miner, targetAddrs)
+
+ case 100:
+ setupChainCase3(tb, miner, targetAddrs)
+
+ default:
+ tb.Fatalf("Unsupported numUTXOs: %d", numUTXOs)
+ }
+
+ return miner
+}
+
+// setupChainCase1 mines 1000 blocks and sends exactly 1 UTXO to the wallet at
+// the 500th block. This scenario tests how the wallet handles a sparse history
+// and whether it can correctly recover from a birthday that predates a single,
+// isolated transaction.
+func setupChainCase1(tb testing.TB, miner *rpctest.Harness,
+ targetAddrs []address.Address) {
+
+ tb.Helper()
+ require.Len(tb, targetAddrs, 1)
+
+ var err error
+
+ // Iterate through 1000 blocks. We want the wallet to find its single UTXO
+ // midway through the scan.
+ for i := range benchBlocksToSync {
+ // Send the single UTXO to the wallet at block 500.
+ if i == 500 {
+ var pkScript []byte
+
+ pkScript, err = txscript.PayToAddrScript(targetAddrs[0])
+ require.NoError(tb, err)
+
+ _, err = miner.SendOutputs([]*wire.TxOut{
+ {Value: 1000, PkScript: pkScript},
+ }, 1)
+ require.NoError(tb, err)
+ }
+
+ // Fill the block with 100 noise transactions to simulate a realistic
+ // mainnet-like environment where the wallet must filter through
+ // many irrelevant transactions.
+ noiseCount := benchTxsPerBlock
+ if i == 500 {
+ noiseCount--
+ }
+
+ generateNonWalletTxns(tb, miner, noiseCount)
+
+ // Mine the block.
+ _, err = miner.Client.Generate(1)
+ require.NoError(tb, err)
+ }
+}
+
+// setupChainCase2 mines 1000 blocks and sends 1 UTXO to the wallet every 100
+// blocks, for a total of 10 UTXOs. This scenario tests the wallet's ability
+// to handle incremental discovery and sequential rescans as it finds hits
+// distributed regularly across the chain.
+func setupChainCase2(tb testing.TB, miner *rpctest.Harness,
+ targetAddrs []address.Address) {
+
+ tb.Helper()
+ require.Len(tb, targetAddrs, 10)
+
+ var err error
+
+ targetIdx := 0
+ for i := range benchBlocksToSync {
+ // Determine if this block is a hit (every 100th block). This creates
+ // a periodic matching pattern that triggers incremental discovery
+ // as the wallet reaches the end of its lookahead window.
+ isHit := (i+1)%100 == 0
+
+ if isHit {
+ // Pop the next target address and generate a script that pays
+ // to it.
+ var pkScript []byte
+
+ pkScript, err = txscript.PayToAddrScript(targetAddrs[targetIdx])
+ require.NoError(tb, err)
+
+ targetIdx++
+
+ // Send the wallet payment.
+ _, err = miner.SendOutputs([]*wire.TxOut{
+ {Value: 1000, PkScript: pkScript},
+ }, 1)
+ require.NoError(tb, err)
+ }
+
+ // Add noise transactions to fill the block to 100 total outputs. This
+ // ensures the wallet must filter through a realistic amount of
+ // irrelevant data in every block.
+ noiseCount := benchTxsPerBlock
+ if isHit {
+ noiseCount--
+ }
+
+ generateNonWalletTxns(tb, miner, noiseCount)
+
+ // Mine the block.
+ _, err = miner.Client.Generate(1)
+ require.NoError(tb, err)
+ }
+}
+
+// setupChainCase3 mines 1000 blocks and sends 1 UTXO to the wallet every 10
+// blocks, for a total of 100 UTXOs. This is a "dense" scenario that
+// stresses the wallet's block processing and transaction indexing performance
+// when relevant data appears frequently.
+func setupChainCase3(tb testing.TB, miner *rpctest.Harness,
+ targetAddrs []address.Address) {
+
+ tb.Helper()
+ require.Len(tb, targetAddrs, 100)
+
+ var err error
+
+ targetIdx := 0
+ for i := range benchBlocksToSync {
+ // Determine if this block is a hit (every 10th block).
+ isHit := (i+1)%10 == 0
+
+ if isHit {
+ // Pop the next target address and generate a script that pays
+ // to it.
+ var pkScript []byte
+
+ pkScript, err = txscript.PayToAddrScript(targetAddrs[targetIdx])
+ require.NoError(tb, err)
+
+ targetIdx++
+
+ // Send the wallet payment.
+ _, err = miner.SendOutputs([]*wire.TxOut{
+ {Value: 1000, PkScript: pkScript},
+ }, 1)
+ require.NoError(tb, err)
+ }
+
+ // Add noise transactions to reach the 100 txs/block target. This
+ // maintains a constant transaction density across all scenarios.
+ noiseCount := benchTxsPerBlock
+ if isHit {
+ noiseCount--
+ }
+
+ generateNonWalletTxns(tb, miner, noiseCount)
+
+ // Mine the block.
+ _, err = miner.Client.Generate(1)
+ require.NoError(tb, err)
+ }
+}
+
+// generateNonWalletTxns creates 'count' random transactions.
+func generateNonWalletTxns(tb testing.TB, miner *rpctest.Harness, count int) {
+ tb.Helper()
+
+ outputs := make([]*wire.TxOut, 0, count)
+ for range count {
+ mAddr, err := miner.NewAddress()
+ require.NoError(tb, err)
+ pkScript, err := txscript.PayToAddrScript(mAddr)
+ require.NoError(tb, err)
+
+ outputs = append(outputs, &wire.TxOut{Value: 1000, PkScript: pkScript})
+ }
+
+ _, err := miner.SendOutputs(outputs, 1)
+ require.NoError(tb, err)
+}
+
+// setupChainClient initializes and starts a new bitcoind client connection to
+// the provided chain backend. It automatically registers resource cleanup.
+func setupChainClient(tb testing.TB, miner *rpctest.Harness) chain.Interface {
+ tb.Helper()
+
+ // Start a bitcoind instance and connect it to miner.
+ tempBitcoindDir := tb.TempDir()
+
+ zmqBlockPort := port.NextAvailablePort()
+ zmqTxPort := port.NextAvailablePort()
+
+ zmqBlockHost := fmt.Sprintf("tcp://127.0.0.1:%d", zmqBlockPort)
+ zmqTxHost := fmt.Sprintf("tcp://127.0.0.1:%d", zmqTxPort)
+
+ rpcPort := port.NextAvailablePort()
+ p2pPort := port.NextAvailablePort()
+ minerAddr := miner.P2PAddress()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ tb.Cleanup(cancel)
+
+ bitcoind := exec.CommandContext(
+ ctx,
+ "bitcoind",
+ "-datadir="+tempBitcoindDir,
+ "-regtest",
+ "-connect="+minerAddr,
+ "-txindex",
+ "-rpcauth=weks:469e9bb14ab2360f8e226efed5ca6f"+
+ "d$507c670e800a95284294edb5773b05544b"+
+ "220110063096c221be9933c82d38e1",
+ fmt.Sprintf("-rpcport=%d", rpcPort),
+ fmt.Sprintf("-port=%d", p2pPort),
+ "-disablewallet",
+ "-zmqpubrawblock="+zmqBlockHost,
+ "-zmqpubrawtx="+zmqTxHost,
+ "-blockfilterindex=1",
+ )
+ require.NoError(tb, bitcoind.Start())
+
+ tb.Cleanup(func() {
+ _ = bitcoind.Process.Kill()
+ _ = bitcoind.Wait()
+ })
+
+ // Wait for the bitcoind instance to start up.
+ time.Sleep(time.Second)
+
+ host := fmt.Sprintf("127.0.0.1:%d", rpcPort)
+ cfg := &chain.BitcoindConfig{
+ ChainParams: &chaincfg.RegressionNetParams,
+ Host: host,
+ User: "weks",
+ Pass: "weks",
+ ZMQConfig: &chain.ZMQConfig{
+ ZMQBlockHost: zmqBlockHost,
+ ZMQTxHost: zmqTxHost,
+ ZMQReadDeadline: 5 * time.Second,
+ MempoolPollingInterval: time.Millisecond * 100,
+ },
+ }
+
+ chainConn, err := chain.NewBitcoindConn(cfg)
+ require.NoError(tb, err)
+ require.NoError(tb, chainConn.Start())
+
+ tb.Cleanup(func() {
+ chainConn.Stop()
+ })
+
+ // Create a bitcoind client.
+ btcClient, err := chainConn.NewBitcoindClient()
+ require.NoError(tb, err)
+ require.NoError(tb, btcClient.Start(tb.Context()))
+
+ tb.Cleanup(func() {
+ btcClient.Stop()
+ })
+
+ // Wait for bitcoind to sync with the miner.
+ // We want to ensure it has synced at least to the miner's tip.
+ require.Eventually(tb, func() bool {
+ _, height, err := btcClient.GetBestBlock()
+ if err != nil {
+ return false
+ }
+
+ _, minerHeight, _ := miner.Client.GetBestBlock()
+
+ return height >= minerHeight
+ }, 30*time.Second, 100*time.Millisecond)
+
+ return btcClient
+}
+
+// defaultWalletConfig returns a Config with standard benchmark settings.
+func defaultWalletConfig(tb testing.TB) Config {
+ tb.Helper()
+
+ dir := tb.TempDir()
+ dbPath := filepath.Join(dir, "wallet.db")
+ db, err := walletdb.Create("bdb", dbPath, true, 10*time.Second, false)
+ require.NoError(tb, err)
+
+ return Config{
+ DB: db,
+ ChainParams: &chaincfg.RegressionNetParams,
+ Name: "bench-wallet",
+ PubPassphrase: []byte("public"),
+ WalletSyncRetryInterval: 10 * time.Millisecond,
+ RecoveryWindow: testRecoveryWindow,
+ }
+}
+
+// assertUTXOCount verifies the number of unspent outputs in a modern wallet.
+func assertUTXOCount(b *testing.B, w *Wallet, expected int) {
+ b.Helper()
+
+ require.Eventually(b, func() bool {
+ utxos, err := w.ListUnspent(b.Context(), UtxoQuery{
+ MinConfs: 0,
+ MaxConfs: 99999,
+ })
+ require.NoError(b, err)
+
+ return len(utxos) == expected
+ }, 20*time.Second, 100*time.Millisecond, "new wallet utxo count mismatch")
+}
+
+// assertUTXOCountDeprecated verifies the number of unspent outputs in a legacy
+// wallet.
+func assertUTXOCountDeprecated(b *testing.B, w *Wallet, expected int) {
+ b.Helper()
+
+ require.Eventually(b, func() bool {
+ utxos, err := w.ListUnspentDeprecated(0, 999999, "")
+ require.NoError(b, err)
+
+ return len(utxos) == expected
+ }, 20*time.Second, 100*time.Millisecond,
+ "legacy wallet utxo count mismatch")
+}
diff --git a/wallet/controller_test.go b/wallet/controller_test.go
new file mode 100644
index 0000000000..937ade30a1
--- /dev/null
+++ b/wallet/controller_test.go
@@ -0,0 +1,1775 @@
+package wallet
+
+import (
+ "context"
+ "math"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestHandleUnlockReq verifies that the handleUnlockReq method correctly
+// processes an unlock request by invoking the address manager's Unlock method
+// and updating the wallet state.
+func TestHandleUnlockReq(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and mock its dependencies.
+ w, deps := createTestWalletWithMocks(t)
+
+ // Simulate the wallet being in the 'Started' state, which is a
+ // prerequisite for unlocking.
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ pass := []byte("password")
+ req := newUnlockReq(UnlockRequest{Passphrase: pass})
+
+ // Setup the expected call to the address manager's Unlock method.
+ deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
+
+ // Act: Dispatch the unlock request to the handler.
+ w.handleUnlockReq(req)
+
+ // Assert: Verify that the response indicates success and the wallet
+ // state has transitioned to 'Unlocked'.
+ resp := <-req.resp
+ require.NoError(t, resp)
+ require.True(t, w.state.isUnlocked())
+}
+
+// TestHandleUnlockReq_Errors verifies that handleUnlockReq correctly handles
+// error conditions, such as attempting to unlock a stopped wallet or a failure
+// from the underlying storage.
+func TestHandleUnlockReq_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("ErrStateForbidden", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet. By default, it is in the 'Stopped'
+ // state.
+ w, _ := createTestWalletWithMocks(t)
+
+ pass := []byte("password")
+ req := newUnlockReq(UnlockRequest{Passphrase: pass})
+
+ // Act: Attempt to unlock the wallet while it is stopped.
+ w.handleUnlockReq(req)
+
+ // Assert: Verify that the request fails with ErrStateForbidden.
+ err := <-req.resp
+ require.ErrorIs(t, err, ErrStateForbidden)
+ })
+
+ t.Run("DBUnlock_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and transition to 'Started'.
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ pass := []byte("password")
+ req := newUnlockReq(UnlockRequest{Passphrase: pass})
+ deps.addrStore.On("Unlock", mock.Anything, pass).Return(
+ errDBMock,
+ ).Once()
+
+ // Act: Attempt to unlock the wallet.
+ w.handleUnlockReq(req)
+
+ // Assert: Verify that the database error is propagated.
+ err := <-req.resp
+ require.ErrorContains(t, err, "db error")
+ })
+}
+
+// TestHandleLockReq verifies that the handleLockReq method correctly processes
+// a lock request by invoking the address manager's Lock method and updating
+// the wallet state.
+func TestHandleLockReq(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and transition it to 'Started' and
+ // then 'Unlocked'.
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+ w.state.toUnlocked()
+
+ req := newLockReq()
+
+ // Setup the expected call to the address manager's Lock method.
+ deps.addrStore.On("Lock").Return(nil).Once()
+
+ // Act: Dispatch the lock request to the handler.
+ w.handleLockReq(req)
+
+ // Assert: Verify that the response indicates success and the wallet
+ // state is no longer 'Unlocked'.
+ resp := <-req.resp
+ require.NoError(t, resp)
+ require.False(t, w.state.isUnlocked())
+}
+
+// TestHandleLockReq_Idempotency verifies that if the wallet is already locked
+// (indicated by waddrmgr.ErrLocked), the lock request treats it as a success
+// and ensures the state is consistent.
+func TestHandleLockReq_Idempotency(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and transition it to 'Started'.
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ // Transition the wallet to the 'Unlocked' state for testing.
+ w.state.toUnlocked()
+
+ req := newLockReq()
+
+ // Setup the expected call to the address manager's Lock method
+ // returning ErrLocked.
+ errLocked := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrLocked,
+ Description: "address manager is locked",
+ }
+ deps.addrStore.On("Lock").Return(errLocked).Once()
+
+ // Act: Dispatch the lock request to the handler.
+ w.handleLockReq(req)
+
+ // Assert: Verify that the response indicates success and the wallet
+ // state is 'Locked'.
+ resp := <-req.resp
+ require.NoError(t, resp)
+ require.False(t, w.state.isUnlocked())
+}
+
+// TestHandleLockReq_Errors verifies that handleLockReq correctly handles error
+// conditions, such as attempting to lock a stopped wallet.
+func TestHandleLockReq_Errors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet in the default 'Stopped' state.
+ w, _ := createTestWalletWithMocks(t)
+
+ req := newLockReq()
+
+ // Act: Attempt to lock the wallet.
+ w.handleLockReq(req)
+
+ // Assert: Verify that the request fails with ErrStateForbidden.
+ err := <-req.resp
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestMainLoop verifies that the wallet's main event loop can start and stop
+// correctly in response to context cancellation.
+func TestMainLoop(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and setup a cancellable context to
+ // control the main loop's lifecycle.
+ w, _ := createTestWalletWithMocks(t)
+ ctx, cancel := context.WithCancel(t.Context())
+ w.lifetimeCtx = ctx
+ w.cancel = cancel
+
+ var testWg sync.WaitGroup
+ testWg.Add(1)
+ w.wg.Add(1)
+
+ // Act: Start the main loop in a background goroutine.
+ go func() {
+ defer testWg.Done()
+
+ w.mainLoop()
+ }()
+
+ // Act: Cancel the context to signal the main loop to exit.
+ cancel()
+
+ // Assert: Wait for the main loop to exit, ensuring it respects the
+ // context cancellation.
+ testWg.Wait()
+}
+
+// TestHandleChangePassphraseReq verifies the change passphrase request handler.
+func TestHandleChangePassphraseReq(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and a dummy change passphrase request.
+ w, deps := createTestWalletWithMocks(t)
+
+ // Transition the wallet to 'Started' so the state check passes.
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ reqStruct := ChangePassphraseRequest{
+ ChangePrivate: true,
+ PrivateOld: []byte("old"),
+ PrivateNew: []byte("new"),
+ }
+ req := newChangePassphraseReq(reqStruct)
+
+ // Setup the expected call to the address manager's ChangePassphrase
+ // method.
+ deps.addrStore.On(
+ "ChangePassphrase", mock.Anything, []byte("old"),
+ []byte("new"), true, mock.Anything,
+ ).Return(nil).Once()
+
+ // Act: Call the handler.
+ w.handleChangePassphraseReq(req)
+
+ // Assert: Verify that the response indicates success.
+ resp := <-req.resp
+ require.NoError(t, resp)
+}
+
+// TestControllerStart verifies that the Start method correctly initializes the
+// wallet, verifying the birthday block, loading accounts, cleaning up locks,
+// and starting the syncer.
+func TestControllerStart(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and mock all dependencies required for
+ // startup.
+ w, deps := createTestWalletWithMocks(t)
+
+ // 1. Mock verifyBirthday: Expect a call to retrieve the birthday
+ // block.
+ bs := waddrmgr.BlockStamp{Height: 100}
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(bs, true, nil).Once()
+
+ // 2. Mock DBGetAllAccounts: Expect a call to load active account
+ // managers.
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+
+ // 3. Mock deleteExpiredLockedOutputs: Expect a call to cleanup expired
+ // locks in the transaction store.
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+
+ // 4. Mock syncer.run: Expect the syncer to be started.
+ deps.syncer.On(
+ "run", mock.Anything,
+ ).Return(nil).Once()
+
+ // Act: Start the wallet.
+ err := w.Start(t.Context())
+
+ // Assert: Verify that Start returned no error and the wallet state is
+ // 'Started'.
+ require.NoError(t, err)
+ require.True(t, w.state.isStarted())
+
+ // Cleanup: Stop the wallet to release resources.
+ err = w.Stop(t.Context())
+ require.NoError(t, err)
+ w.wg.Wait()
+}
+
+// TestControllerUnlock_Interrupted_SendCancelled verifies Unlock when the
+// request send is interrupted by context cancellation.
+func TestControllerUnlock_Interrupted_SendCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and a cancelled context during Unlock.
+ w1, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w1.state.toStarting())
+ require.NoError(t, w1.state.toStarted())
+
+ w1.requestChan = make(chan any) // Unbuffered to block send.
+ ctx1, cancel1 := context.WithCancel(t.Context())
+
+ errChan1 := make(chan error, 1)
+ go func() {
+ errChan1 <- w1.Unlock(ctx1,
+ UnlockRequest{Passphrase: []byte("pw")})
+ }()
+
+ // Act: Cancel context to interrupt send.
+ cancel1()
+
+ // Assert: Verify cancellation error.
+ select {
+ case err := <-errChan1:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerUnlock_Interrupted_SendShutdown verifies Unlock when the
+// request send is interrupted by wallet shutdown.
+func TestControllerUnlock_Interrupted_SendShutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and trigger shutdown during Unlock.
+ w2, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w2.state.toStarting())
+ require.NoError(t, w2.state.toStarted())
+
+ w2.requestChan = make(chan any)
+
+ errChan2 := make(chan error, 1)
+ go func() {
+ errChan2 <- w2.Unlock(t.Context(),
+ UnlockRequest{Passphrase: []byte("pw")})
+ }()
+
+ // Act: Stop wallet.
+ w2.cancel()
+
+ // Assert: Verify shutdown error.
+ select {
+ case err := <-errChan2:
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerUnlock_Interrupted_WaitCancelled verifies Unlock when the
+// response wait is interrupted by context cancellation.
+func TestControllerUnlock_Interrupted_WaitCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet with a buffered channel to allow send but
+ // block on response.
+ w3, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w3.state.toStarting())
+ require.NoError(t, w3.state.toStarted())
+
+ ctx3, cancel3 := context.WithCancel(t.Context())
+
+ errChan3 := make(chan error, 1)
+ go func() {
+ errChan3 <- w3.Unlock(ctx3,
+ UnlockRequest{Passphrase: []byte("pw")})
+ }()
+
+ // Wait for request to be sent.
+ select {
+ case <-w3.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for request")
+ }
+
+ // Act: Cancel context during response wait.
+ cancel3()
+
+ // Assert: Verify cancellation error.
+ select {
+ case err := <-errChan3:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerUnlock_Interrupted_WaitShutdown verifies Unlock when the
+// response wait is interrupted by wallet shutdown.
+func TestControllerUnlock_Interrupted_WaitShutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and trigger shutdown during response wait.
+ w4, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w4.state.toStarting())
+ require.NoError(t, w4.state.toStarted())
+
+ errChan4 := make(chan error, 1)
+ go func() {
+ errChan4 <- w4.Unlock(t.Context(),
+ UnlockRequest{Passphrase: []byte("pw")})
+ }()
+
+ select {
+ case <-w4.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for request")
+ }
+
+ // Act: Stop wallet.
+ w4.cancel()
+
+ // Assert: Verify shutdown error.
+ select {
+ case err := <-errChan4:
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerVerifyBirthday_Verified verifies that verifyBirthday
+// returns early if the birthday block is already verified.
+func TestControllerVerifyBirthday_Verified(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet where the birthday block is already verified.
+ w, deps := createTestWalletWithMocks(t)
+ bs := waddrmgr.BlockStamp{Height: 123, Hash: chainhash.Hash{0x01}}
+ deps.addrStore.On("BirthdayBlock", mock.Anything).Return(
+ bs, true, nil).Once()
+
+ // Act: Verify birthday.
+ err := w.verifyBirthday(t.Context())
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.Equal(t, bs, w.birthdayBlock)
+}
+
+// TestControllerVerifyBirthday_LocateFail verifies verifyBirthday failure
+// when locateBirthdayBlock fails.
+func TestControllerVerifyBirthday_LocateFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where the birthday block is not set
+ // and chain lookup fails.
+ w, deps := createTestWalletWithMocks(t)
+
+ deps.addrStore.On("BirthdayBlock", mock.Anything).Return(
+ waddrmgr.BlockStamp{}, false, waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrBirthdayBlockNotSet,
+ }).Once()
+ deps.addrStore.On("Birthday").Return(time.Now()).Once()
+ deps.chain.On("GetBestBlock").Return(nil, int32(0), errChainMock).Once()
+
+ // Act: Attempt to verify birthday.
+ err := w.verifyBirthday(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "chain error")
+}
+
+// TestControllerVerifyBirthday_PutFail verifies verifyBirthday failure
+// when DBPutBirthdayBlock fails.
+func TestControllerVerifyBirthday_PutFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where block location succeeds but
+ // persisting the birthday block fails.
+ w, deps := createTestWalletWithMocks(t)
+
+ deps.addrStore.On("BirthdayBlock", mock.Anything).Return(
+ waddrmgr.BlockStamp{}, false, waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrBirthdayBlockNotSet,
+ }).Once()
+ deps.addrStore.On("Birthday").Return(time.Now()).Once()
+ deps.chain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(100), nil).Once()
+ deps.chain.On("GetBlockHash", mock.Anything).Return(
+ &chainhash.Hash{}, nil).Maybe()
+ deps.chain.On("GetBlockHeader", mock.Anything).Return(
+ &wire.BlockHeader{}, nil).Maybe()
+ deps.addrStore.On("SetBirthdayBlock", mock.Anything, mock.Anything,
+ true).Return(errPutMock).Once()
+
+ // Act: Attempt to verify birthday.
+ err := w.verifyBirthday(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "put error")
+}
+
+// TestSubmitRescanRequest_Errors verifies submitRescanRequest error paths.
+func TestSubmitRescanRequest_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("ErrStateForbidden", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a stopped wallet.
+ w, _ := createTestWalletWithMocks(t)
+
+ // Act: Attempt to submit rescan.
+ err := w.submitRescanRequest(t.Context(), scanTypeRewind, 0, nil)
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, ErrStateForbidden)
+ })
+
+ t.Run("GetBestBlock_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a started wallet where best block lookup fails.
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ deps.syncer.On("syncState").Return(syncStateSynced)
+ deps.chain.On("GetBestBlock").Return(
+ nil, int32(0), errBestBlock).Once()
+
+ // Act: Attempt to submit rescan.
+ err := w.submitRescanRequest(t.Context(), scanTypeRewind,
+ 0, nil)
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "best block fail")
+ })
+}
+
+// TestControllerStop verifies that the Stop method correctly shuts down the
+// wallet, waiting for the syncer and other background processes to exit.
+func TestControllerStop(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create and start a test wallet.
+ w, deps := createTestWalletWithMocks(t)
+
+ // Setup mocks for the startup sequence.
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(waddrmgr.BlockStamp{}, true, nil).Once()
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+
+ // Mock syncer.run to simulate a long-running process that exits when
+ // the context is cancelled.
+ deps.syncer.On("run", mock.Anything).Run(func(args mock.Arguments) {
+ ctx, ok := args.Get(0).(context.Context)
+ if !ok {
+ return
+ }
+ <-ctx.Done()
+ }).Return(nil).Once()
+
+ require.NoError(t, w.Start(t.Context()))
+ require.True(t, w.state.isStarted())
+
+ // Act: Stop the wallet.
+ err := w.Stop(t.Context())
+
+ // Assert: Verify that Stop returned no error and the wallet state is
+ // no longer 'Started'.
+ require.NoError(t, err)
+ require.False(t, w.state.isStarted())
+
+ // Act: Call Stop again to verify idempotency.
+ err = w.Stop(t.Context())
+
+ // Assert: Verify that subsequent Stop calls are safe and return no
+ // error.
+ require.NoError(t, err)
+}
+
+// TestControllerLock verifies the Lock method. It ensures that the wallet
+// can only be locked when it is started and currently unlocked.
+func TestControllerLock(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create and start a test wallet.
+ w, deps := createTestWalletWithMocks(t)
+
+ // Setup mocks for startup.
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(waddrmgr.BlockStamp{}, true, nil).Once()
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ require.NoError(t, w.Start(t.Context()))
+
+ // Transition the wallet to the 'Unlocked' state for testing.
+ w.state.toUnlocked()
+ require.True(t, w.state.isUnlocked())
+
+ // Expect a call to the address manager's Lock method.
+ deps.addrStore.On("Lock").Return(nil).Once()
+
+ // Act: Call the Lock method.
+ err := w.Lock(t.Context())
+
+ // Assert: Verify success and that the wallet state is locked.
+ require.NoError(t, err)
+ require.False(t, w.state.isUnlocked())
+
+ // Cleanup: Stop the wallet to release resources.
+ err = w.Stop(t.Context())
+ require.NoError(t, err)
+ w.wg.Wait()
+}
+
+// TestControllerUnlock verifies the Unlock method. It ensures that the wallet
+// can be unlocked by providing the correct passphrase.
+func TestControllerUnlock(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create and start a test wallet.
+ w, deps := createTestWalletWithMocks(t)
+
+ // Setup mocks for startup.
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(waddrmgr.BlockStamp{}, true, nil).Once()
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ require.NoError(t, w.Start(t.Context()))
+ require.False(t, w.state.isUnlocked())
+
+ pass := []byte("password")
+
+ // Expect a call to the address manager's Unlock method.
+ deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
+
+ // Act: Call the Unlock method.
+ err := w.Unlock(t.Context(), UnlockRequest{Passphrase: pass})
+
+ // Assert: Verify success and that the wallet state is unlocked.
+ require.NoError(t, err)
+ require.True(t, w.state.isUnlocked())
+
+ // Cleanup: Stop the wallet to release resources.
+ err = w.Stop(t.Context())
+ require.NoError(t, err)
+ w.wg.Wait()
+}
+
+// TestControllerChangePassphrase verifies the ChangePassphrase method. It
+// ensures that the wallet forwards the request to the address manager to
+// update the passphrases.
+func TestControllerChangePassphrase(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create and start a test wallet.
+ w, deps := createTestWalletWithMocks(t)
+
+ // Setup mocks for startup.
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(waddrmgr.BlockStamp{}, true, nil).Once()
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ require.NoError(t, w.Start(t.Context()))
+
+ req := ChangePassphraseRequest{
+ ChangePrivate: true,
+ PrivateOld: []byte("old"),
+ PrivateNew: []byte("new"),
+ }
+
+ // Expect a call to ChangePassphrase in the address store.
+ deps.addrStore.On(
+ "ChangePassphrase", mock.Anything, []byte("old"), []byte("new"),
+ true, mock.Anything,
+ ).Return(nil).Once()
+
+ // Act: Call ChangePassphrase.
+ err := w.ChangePassphrase(t.Context(), req)
+
+ // Assert: Verify that the operation completed without error.
+ require.NoError(t, err)
+
+ // Cleanup: Stop the wallet to release resources.
+ err = w.Stop(t.Context())
+ require.NoError(t, err)
+ w.wg.Wait()
+}
+
+// TestControllerLock_Errors verifies Lock failures.
+func TestControllerLock_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("ContextCanceled", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a started wallet with an unbuffered request
+ // channel and a cancelled context.
+ w, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.requestChan = make(chan any) // Unbuffered to block send.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Act: Attempt to lock.
+ err := w.Lock(ctx)
+
+ // Assert: Verify cancellation error.
+ require.ErrorIs(t, err, context.Canceled)
+ })
+
+ t.Run("WalletStopped", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a stopped wallet.
+ w, _ := createTestWalletWithMocks(t)
+ ctx, cancel := context.WithCancel(t.Context())
+ w.lifetimeCtx = ctx
+ w.cancel = cancel
+ w.cancel() // Stop wallet.
+
+ // Act: Attempt to lock.
+ err := w.Lock(t.Context())
+
+ // Assert: Verify forbidden error.
+ require.ErrorIs(t, err, ErrStateForbidden)
+ })
+}
+
+// TestControllerChangePassphrase_Errors verifies ChangePassphrase failures.
+func TestControllerChangePassphrase_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("ContextCanceled", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a started wallet and a cancelled context.
+ w, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.requestChan = make(chan any)
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Act: Attempt to change passphrase.
+ err := w.ChangePassphrase(ctx, ChangePassphraseRequest{})
+
+ // Assert: Verify cancellation error.
+ require.ErrorIs(t, err, context.Canceled)
+ })
+
+ t.Run("WalletStopped", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a stopped wallet.
+ w, _ := createTestWalletWithMocks(t)
+ ctx, cancel := context.WithCancel(t.Context())
+ w.lifetimeCtx = ctx
+ w.cancel = cancel
+ w.cancel()
+
+ // Act: Attempt to change passphrase.
+ err := w.ChangePassphrase(
+ t.Context(), ChangePassphraseRequest{},
+ )
+
+ // Assert: Verify forbidden error.
+ require.ErrorIs(t, err, ErrStateForbidden)
+ })
+}
+
+// TestControllerStart_WithAccounts verifies Start with existing accounts.
+func TestControllerStart_WithAccounts(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet with existing accounts in the address store.
+ w, deps := createTestWalletWithMocks(t)
+
+ bs := waddrmgr.BlockStamp{Height: 100}
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(bs, true, nil).Once()
+
+ scopedMgr := &mockAccountStore{}
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore{scopedMgr}).Once()
+
+ scopedMgr.On("LastAccount", mock.Anything).Return(uint32(1), nil).Once()
+ scopedMgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Maybe()
+ scopedMgr.On(
+ "AccountProperties", mock.Anything, uint32(0),
+ ).Return(&waddrmgr.AccountProperties{AccountNumber: 0}, nil).Once()
+ scopedMgr.On(
+ "AccountProperties", mock.Anything, uint32(1),
+ ).Return(&waddrmgr.AccountProperties{AccountNumber: 1}, nil).Once()
+
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ // Act: Start the wallet.
+ err := w.Start(t.Context())
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.True(t, w.state.isStarted())
+
+ // Cleanup.
+ require.NoError(t, w.Stop(t.Context()))
+ w.wg.Wait()
+}
+
+// TestMainLoop_AutoLock verifies that the main loop handles auto-lock
+// timeouts.
+func TestMainLoop_AutoLock(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup an unlocked wallet with a short lock timer.
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.state.toUnlocked()
+
+ ctx, cancel := context.WithCancel(t.Context())
+ w.lifetimeCtx = ctx
+ w.cancel = cancel
+ w.lockTimer = time.NewTimer(time.Millisecond * 10)
+
+ lockCalled := make(chan struct{})
+ deps.addrStore.On("Lock").Run(func(args mock.Arguments) {
+ close(lockCalled)
+ }).Return(nil).Once()
+
+ // Act: Start main loop.
+ w.wg.Add(1)
+
+ go w.mainLoop()
+
+ // Assert: Verify that the auto-lock was triggered.
+ select {
+ case <-lockCalled:
+ case <-time.After(time.Second):
+ t.Fatal("Auto-lock not triggered")
+ }
+
+ // Clean up.
+ cancel()
+ w.wg.Wait()
+}
+
+// TestMainLoop_UnknownRequest verifies main loop handles unknown requests
+// gracefully.
+func TestMainLoop_UnknownRequest(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and start the main loop.
+ w, _ := createTestWalletWithMocks(t)
+
+ ctx, cancel := context.WithCancel(t.Context())
+ w.lifetimeCtx = ctx
+ w.cancel = cancel
+
+ w.wg.Add(1)
+
+ go w.mainLoop()
+
+ // Act: Send an unknown request type.
+ w.requestChan <- "unknown"
+
+ // Assert: Ensure it doesn't crash and can be stopped cleanly.
+ cancel()
+ w.wg.Wait()
+}
+
+// TestControllerLock_Interrupted_SendShutdown verifies Lock when request
+// send is interrupted by wallet shutdown.
+func TestControllerLock_Interrupted_SendShutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Trigger shutdown during Lock.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.requestChan = make(chan any)
+ w.cancel() // Stop wallet.
+
+ // Act: Attempt to lock.
+ err := w.Lock(t.Context())
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+}
+
+// TestControllerLock_Interrupted_WaitCancelled verifies Lock when response
+// wait is interrupted by context cancellation.
+func TestControllerLock_Interrupted_WaitCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Block during response wait and cancel context.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ ctx, cancel := context.WithCancel(t.Context())
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- w.Lock(ctx)
+ }()
+
+ // Wait for request.
+ select {
+ case <-w.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for request")
+ }
+
+ // Act: Cancel context.
+ cancel()
+
+ // Assert: Verify error.
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerLock_Interrupted_WaitShutdown verifies Lock when response
+// wait is interrupted by wallet shutdown.
+func TestControllerLock_Interrupted_WaitShutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Block during response wait and trigger shutdown.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- w.Lock(t.Context())
+ }()
+
+ select {
+ case <-w.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for request")
+ }
+
+ // Act: Stop wallet.
+ w.cancel()
+
+ // Assert: Verify error.
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerLock_Interrupted_WaitTimeout verifies Lock when response
+// wait times out.
+func TestControllerLock_Interrupted_WaitTimeout(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Block during response wait and allow timeout to occur.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ ctx, cancel := context.WithTimeout(
+ t.Context(), 10*time.Millisecond,
+ )
+ defer cancel()
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- w.Lock(ctx)
+ }()
+
+ select {
+ case <-w.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ // Assert: Verify timeout error.
+ select {
+ case err := <-errChan:
+ require.ErrorContains(t, err,
+ "context deadline exceeded")
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerChangePassphrase_Interrupted_SendShutdown verifies
+// ChangePassphrase when request send is interrupted by wallet shutdown.
+func TestControllerChangePassphrase_Interrupted_SendShutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Trigger shutdown during send.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.requestChan = make(chan any)
+ w.cancel() // Stop wallet.
+
+ // Act: Attempt change.
+ err := w.ChangePassphrase(t.Context(),
+ ChangePassphraseRequest{})
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+}
+
+// TestControllerChangePassphrase_Interrupted_WaitCancelled verifies
+// ChangePassphrase when response wait is interrupted by context cancellation.
+func TestControllerChangePassphrase_Interrupted_WaitCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Block during response wait and cancel context.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ ctx, cancel := context.WithCancel(t.Context())
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- w.ChangePassphrase(ctx,
+ ChangePassphraseRequest{})
+ }()
+
+ select {
+ case <-w.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for request")
+ }
+
+ // Act: Cancel context.
+ cancel()
+
+ // Assert: Verify error.
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerChangePassphrase_Interrupted_WaitShutdown verifies
+// ChangePassphrase when response wait is interrupted by wallet shutdown.
+func TestControllerChangePassphrase_Interrupted_WaitShutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Block during response wait and trigger shutdown.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- w.ChangePassphrase(t.Context(),
+ ChangePassphraseRequest{})
+ }()
+
+ select {
+ case <-w.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for request")
+ }
+
+ // Act: Stop wallet.
+ w.cancel()
+
+ // Assert: Verify error.
+ select {
+ case err := <-errChan:
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestControllerChangePassphrase_Interrupted_WaitTimeout verifies
+// ChangePassphrase when response wait times out.
+func TestControllerChangePassphrase_Interrupted_WaitTimeout(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Block during response wait and allow timeout.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ ctx, cancel := context.WithTimeout(t.Context(),
+ 10*time.Millisecond)
+ defer cancel()
+
+ errChan := make(chan error, 1)
+ go func() {
+ errChan <- w.ChangePassphrase(ctx,
+ ChangePassphraseRequest{})
+ }()
+
+ select {
+ case <-w.requestChan:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ // Assert: Verify timeout.
+ select {
+ case err := <-errChan:
+ require.ErrorContains(t, err,
+ "context deadline exceeded")
+ case <-time.After(time.Second):
+ t.Fatal("timeout waiting for response")
+ }
+}
+
+// TestHandleChangePassphraseReq_Errors verifies error handling for the
+// internal change passphrase request handler.
+func TestHandleChangePassphraseReq_Errors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet in the default 'Stopped' state.
+ w, _ := createTestWalletWithMocks(t)
+
+ req := changePassphraseReq{
+ req: ChangePassphraseRequest{},
+ resp: make(chan error, 1),
+ }
+
+ // Act: Call the internal handler while the wallet is stopped.
+ w.handleChangePassphraseReq(req)
+
+ // Assert: Verify that the request fails with ErrStateForbidden.
+ err := <-req.resp
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestControllerInfo verifies the Info method. It checks that the wallet
+// correctly aggregates information from its subsystems (chain backend,
+// address manager, and syncer).
+func TestControllerInfo(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create and start a test wallet with mocked subsystems.
+ w, deps := createTestWalletWithMocks(t)
+
+ bs := waddrmgr.BlockStamp{Height: 100}
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(bs, true, nil).Once()
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ // Mock the chain backend to return a specific name.
+ deps.chain.On("BackEnd").Return("mock")
+
+ // Mock SyncedTo to return a known block stamp.
+ deps.addrStore.On("SyncedTo").Return(bs)
+
+ // Mock syncState to indicate the wallet is fully synced.
+ deps.syncer.On("syncState").Return(syncStateSynced)
+
+ require.NoError(t, w.Start(t.Context()))
+
+ // Act: Call the Info method.
+ info, err := w.Info(t.Context())
+
+ // Assert: Verify that the returned information matches the mocked
+ // values and current wallet state.
+ require.NoError(t, err)
+ require.Equal(t, "mock", info.Backend)
+ require.Equal(t, int32(100), info.BirthdayBlock.Height)
+ require.True(t, info.Synced)
+ require.True(t, info.Locked)
+
+ // Cleanup: Stop the wallet to release resources.
+ err = w.Stop(t.Context())
+ require.NoError(t, err)
+ w.wg.Wait()
+}
+
+// TestControllerResync verifies the Resync method.
+func TestControllerResync(t *testing.T) {
+ t.Parallel()
+
+ t.Run("StartHeightTooHigh", func(t *testing.T) {
+ t.Parallel()
+
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ deps.syncer.On("syncState").Return(syncStateSynced)
+ deps.chain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(100), nil,
+ ).Once()
+
+ err := w.Resync(t.Context(), 101)
+ require.ErrorIs(t, err, ErrStartHeightTooHigh)
+ })
+
+ t.Run("Success", func(t *testing.T) {
+ t.Parallel()
+
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ deps.syncer.On("syncState").Return(syncStateSynced)
+ deps.chain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(100), nil,
+ ).Once()
+ deps.syncer.On("requestScan", mock.Anything, mock.MatchedBy(
+ func(req *scanReq) bool {
+ return req.typ == scanTypeRewind &&
+ req.startBlock.Height == 50
+ },
+ )).Return(nil).Once()
+
+ err := w.Resync(t.Context(), 50)
+ require.NoError(t, err)
+ })
+}
+
+// TestControllerRescan verifies the Rescan method.
+func TestControllerRescan(t *testing.T) {
+ t.Parallel()
+
+ t.Run("NoTargets", func(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ err := w.Rescan(t.Context(), 50, nil)
+ require.ErrorIs(t, err, ErrNoScanTargets)
+ })
+
+ t.Run("Success", func(t *testing.T) {
+ t.Parallel()
+
+ w, deps := createTestWalletWithMocks(t)
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ deps.syncer.On("syncState").Return(syncStateSynced)
+
+ targets := []waddrmgr.AccountScope{{Account: 1}}
+
+ deps.chain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(100), nil,
+ ).Once()
+ deps.syncer.On("requestScan", mock.Anything, mock.MatchedBy(
+ func(req *scanReq) bool {
+ return req.typ == scanTypeTargeted &&
+ req.startBlock.Height == 50 &&
+ len(req.targets) == 1
+ },
+ )).Return(nil).Once()
+
+ err := w.Rescan(t.Context(), 50, targets)
+ require.NoError(t, err)
+ })
+}
+
+// TestControllerStart_VerifyBirthdayFail verifies Start fails when
+// verifyBirthday fails.
+func TestControllerStart_VerifyBirthdayFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where birthday block lookup fails.
+ w, deps := createTestWalletWithMocks(t)
+
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(waddrmgr.BlockStamp{}, false, errDBMock).Once()
+
+ // Act: Attempt to start the wallet.
+ err := w.Start(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errDBMock)
+ require.False(t, w.state.isStarted())
+}
+
+// TestControllerStart_DBGetAllAccountsFail verifies Start fails when
+// DBGetAllAccounts fails.
+func TestControllerStart_DBGetAllAccountsFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where account lookup fails during
+ // startup.
+ w, deps := createTestWalletWithMocks(t)
+
+ bs := waddrmgr.BlockStamp{Height: 100}
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(bs, true, nil).Once()
+
+ mockScopedMgr := &mockAccountStore{}
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore{mockScopedMgr}).Once()
+
+ mockScopedMgr.On(
+ "LastAccount", mock.Anything,
+ ).Return(uint32(0), errDBMock).Once()
+
+ // Act: Attempt to start the wallet.
+ err := w.Start(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errDBMock)
+ require.False(t, w.state.isStarted())
+}
+
+// TestControllerStart_BirthdayNotSet verifies the flow when birthday block is
+// not set in DB.
+func TestControllerStart_BirthdayNotSet(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where the birthday block is not set
+ // and must be located from the chain.
+ w, deps := createTestWalletWithMocks(t)
+
+ deps.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(waddrmgr.BlockStamp{}, false, waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrBirthdayBlockNotSet,
+ }).Once()
+
+ birthday := time.Now()
+ deps.addrStore.On("Birthday").Return(birthday).Once()
+
+ deps.chain.On(
+ "GetBestBlock",
+ ).Return(&chainhash.Hash{}, int32(100), nil).Once()
+ deps.chain.On(
+ "GetBlockHash", int64(50),
+ ).Return(&chainhash.Hash{}, nil).Once()
+
+ header := &wire.BlockHeader{Timestamp: birthday}
+ deps.chain.On(
+ "GetBlockHeader", mock.Anything,
+ ).Return(header, nil).Once()
+
+ deps.addrStore.On(
+ "SetBirthdayBlock", mock.Anything,
+ mock.MatchedBy(func(bs waddrmgr.BlockStamp) bool {
+ return bs.Height == 50
+ }), true,
+ ).Return(nil).Once()
+ deps.addrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ deps.addrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore(nil)).Once()
+ deps.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+ deps.syncer.On("run", mock.Anything).Return(nil).Once()
+
+ // Act: Start the wallet.
+ err := w.Start(t.Context())
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.True(t, w.state.isStarted())
+
+ // Clean up.
+ require.NoError(t, w.Stop(t.Context()))
+ w.wg.Wait()
+}
+
+// TestControllerUnlock_DefaultTimeout verifies default timeout usage.
+func TestControllerUnlock_DefaultTimeout(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet with an auto-lock duration and start the
+ // main loop.
+ w, deps := createTestWalletWithMocks(t)
+
+ w.cfg.AutoLockDuration = time.Minute
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.wg.Add(1)
+
+ go w.mainLoop()
+
+ pass := []byte("pass")
+ req := UnlockRequest{Passphrase: pass}
+ deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
+ // Auto-lock might trigger if the test runs slowly, but it's not
+ // guaranteed.
+ deps.addrStore.On("Lock").Return(nil).Maybe()
+
+ // Act: Perform Unlock with default timeout.
+ err := w.Unlock(t.Context(), req)
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+
+ // Clean up.
+ w.cancel()
+ w.wg.Wait()
+}
+
+// TestControllerStart_DeleteExpiredFail verifies Start fails when
+// deleteExpiredLockedOutputs fails.
+func TestControllerStart_DeleteExpiredFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where cleanup of expired locks
+ // fails.
+ w, deps := createTestWalletWithMocks(t)
+
+ bs := waddrmgr.BlockStamp{Height: 100}
+ deps.addrStore.On("BirthdayBlock", mock.Anything).Return(
+ bs, true, nil).Once()
+ deps.addrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore(nil)).Once()
+
+ deps.txStore.On("DeleteExpiredLockedOutputs", mock.Anything).Return(
+ errDBMock).Once()
+
+ // Act: Attempt to start.
+ err := w.Start(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errDBMock)
+ require.False(t, w.state.isStarted())
+}
+
+// TestControllerUnlock_NegativeTimeout verifies Unlock with negative
+// timeout.
+func TestControllerUnlock_NegativeTimeout(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and start the main loop.
+ w, deps := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.wg.Add(1)
+
+ go w.mainLoop()
+
+ pass := []byte("pass")
+ req := UnlockRequest{Passphrase: pass, Timeout: -1}
+ deps.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
+
+ // Act: Perform Unlock with negative timeout (no auto-lock).
+ err := w.Unlock(t.Context(), req)
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+
+ // Clean up.
+ w.cancel()
+ w.wg.Wait()
+}
+
+// TestControllerUnlock_DBUnlockFail verifies Unlock failure when
+// DBUnlock fails.
+func TestControllerUnlock_DBUnlockFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and mock an unlock failure.
+ w, deps := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ w.wg.Add(1)
+
+ go w.mainLoop()
+
+ pass := []byte("pass")
+ deps.addrStore.On("Unlock", mock.Anything, pass).Return(
+ errDBMock).Once()
+
+ // Act: Attempt Unlock.
+ err := w.Unlock(t.Context(), UnlockRequest{Passphrase: pass})
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errDBMock)
+
+ // Clean up.
+ w.cancel()
+ w.wg.Wait()
+}
+
+// TestHandleLockReq_LockError verifies error handling when Lock fails.
+func TestHandleLockReq_LockError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where internal lock fails.
+ w, deps := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ req := lockReq{resp: make(chan error, 1)}
+
+ deps.addrStore.On("Lock").Return(errLockMock).Once()
+
+ // Act: Handle lock request.
+ w.handleLockReq(req)
+ err := <-req.resp
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "lock fail")
+}
+
+// TestSubmitRescanRequest_HeightOverflow verifies large start height rejection.
+func TestSubmitRescanRequest_HeightOverflow(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a wallet and attempt a rescan with an invalid height.
+ w, deps := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ deps.syncer.On("syncState").Return(syncStateSynced).Maybe()
+
+ height := uint32(math.MaxInt32 + 1)
+
+ // Act: Attempt to submit rescan request.
+ err := w.submitRescanRequest(t.Context(), scanTypeTargeted,
+ height, nil)
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrStartHeightTooLarge)
+}
+
+// TestChangePassphrase_StateError verifies early failure when state forbids
+// change.
+func TestChangePassphrase_StateError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a stopped wallet.
+ w, _ := createTestWalletWithMocks(t)
+
+ // Act: Attempt change.
+ err := w.ChangePassphrase(t.Context(),
+ ChangePassphraseRequest{})
+
+ // Assert: Verify forbidden error.
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestControllerStart_AlreadyStarted verifies Start fails if already started.
+func TestControllerStart_AlreadyStarted(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a started wallet.
+ w, _ := createTestWalletWithMocks(t)
+
+ require.NoError(t, w.state.toStarting())
+ require.NoError(t, w.state.toStarted())
+
+ // Act: Attempt to start again.
+ err := w.Start(t.Context())
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrWalletAlreadyStarted)
+}
+
+// TestControllerUnlock_StateError verifies Unlock fails if not started.
+func TestControllerUnlock_StateError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a stopped wallet.
+ w, _ := createTestWalletWithMocks(t)
+
+ // Act: Attempt Unlock.
+ err := w.Unlock(t.Context(), UnlockRequest{})
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestControllerLock_StateError verifies Lock fails if not started.
+func TestControllerLock_StateError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a stopped wallet.
+ w, _ := createTestWalletWithMocks(t)
+
+ // Act: Attempt Lock.
+ err := w.Lock(t.Context())
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestWaitForBackoff_StableRun verifies that the backoff is reset to the
+// initial value if the syncer has been running for a stable amount of time.
+func TestWaitForBackoff_StableRun(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a wallet with a canceled context to avoid waiting.
+ w := &Wallet{
+ lifetimeCtx: context.Background(),
+ }
+
+ // Mock a start time that exceeds the stable run time.
+ startTime := time.Now().Add(-stableRunTime - time.Minute)
+ currentBackoff := maxBackoff
+
+ // Mock the timer function to fire immediately.
+ timerFn := func(d time.Duration) <-chan time.Time {
+ // Verify that the backoff was reset to initial before waiting.
+ require.Equal(t, initialBackoff, d)
+
+ c := make(chan time.Time, 1)
+ c <- time.Now()
+
+ return c
+ }
+
+ // Act: Wait for backoff.
+ nextBackoff, ok := w.waitForBackoff(startTime, currentBackoff, timerFn)
+
+ // Assert: Verify that the operation continued and backoff doubled.
+ require.True(t, ok)
+ require.Equal(t, initialBackoff*2, nextBackoff)
+}
+
+// TestWaitForBackoff_UnstableRun verifies that the backoff duration doubles
+// when the syncer fails quickly (unstable run).
+func TestWaitForBackoff_UnstableRun(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a wallet.
+ w := &Wallet{
+ lifetimeCtx: context.Background(),
+ }
+
+ // Mock a start time that is recent (unstable).
+ startTime := time.Now()
+ currentBackoff := time.Second
+
+ // Mock the timer function.
+ timerFn := func(d time.Duration) <-chan time.Time {
+ // Verify that the backoff was NOT reset.
+ require.Equal(t, currentBackoff, d)
+
+ c := make(chan time.Time, 1)
+ c <- time.Now()
+
+ return c
+ }
+
+ // Act: Wait for backoff.
+ nextBackoff, ok := w.waitForBackoff(startTime, currentBackoff, timerFn)
+
+ // Assert: Verify that the operation continued and backoff doubled.
+ require.True(t, ok)
+ require.Equal(t, currentBackoff*2, nextBackoff)
+}
+
+// TestWaitForBackoff_MaxBackoffCap verifies that the backoff duration is
+// capped at maxBackoff.
+func TestWaitForBackoff_MaxBackoffCap(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a wallet.
+ w := &Wallet{
+ lifetimeCtx: context.Background(),
+ }
+
+ startTime := time.Now()
+ // Current backoff is already high enough that doubling it would exceed
+ // maxBackoff.
+ currentBackoff := maxBackoff
+
+ timerFn := func(d time.Duration) <-chan time.Time {
+ require.Equal(t, currentBackoff, d)
+
+ c := make(chan time.Time, 1)
+ c <- time.Now()
+
+ return c
+ }
+
+ // Act: Wait for backoff.
+ nextBackoff, ok := w.waitForBackoff(startTime, currentBackoff, timerFn)
+
+ // Assert: Verify that the backoff is capped.
+ require.True(t, ok)
+ require.Equal(t, maxBackoff, nextBackoff)
+}
+
+// TestWaitForBackoff_Shutdown verifies that waitForBackoff returns early if
+// the wallet is shutting down.
+func TestWaitForBackoff_Shutdown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a wallet with a canceled context.
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ w := &Wallet{
+ lifetimeCtx: ctx,
+ }
+
+ startTime := time.Now()
+ currentBackoff := time.Second
+
+ // Mock a timer that never fires, ensuring we select on the context.
+ timerFn := func(d time.Duration) <-chan time.Time {
+ return make(chan time.Time)
+ }
+
+ // Act: Wait for backoff.
+ nextBackoff, ok := w.waitForBackoff(startTime, currentBackoff, timerFn)
+
+ // Assert: Verify that the operation was aborted.
+ require.False(t, ok)
+ require.Equal(t, time.Duration(0), nextBackoff)
+}
diff --git a/wallet/createtx.go b/wallet/createtx.go
deleted file mode 100644
index 8d9d340eaf..0000000000
--- a/wallet/createtx.go
+++ /dev/null
@@ -1,592 +0,0 @@
-// Copyright (c) 2013-2017 The btcsuite developers
-// Copyright (c) 2015-2016 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "errors"
- "fmt"
- "math/rand"
- "sort"
-
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/btcec/v2"
- "github.com/btcsuite/btcd/btcutil/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/wallet/txauthor"
- "github.com/btcsuite/btcwallet/wallet/txsizes"
- "github.com/btcsuite/btcwallet/walletdb"
- "github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/lightningnetwork/lnd/fn/v2"
-)
-
-func makeInputSource(eligible []Coin) txauthor.InputSource {
- // Current inputs and their total value. These are closed over by the
- // returned input source and reused across multiple calls.
- currentTotal := btcutil.Amount(0)
- currentInputs := make([]*wire.TxIn, 0, len(eligible))
- currentScripts := make([][]byte, 0, len(eligible))
- currentInputValues := make([]btcutil.Amount, 0, len(eligible))
-
- return func(target btcutil.Amount) (btcutil.Amount, []*wire.TxIn,
- []btcutil.Amount, [][]byte, error) {
-
- for currentTotal < target && len(eligible) != 0 {
- nextCredit := eligible[0]
- prevOut := nextCredit.TxOut
- outpoint := nextCredit.OutPoint
- eligible = eligible[1:]
-
- nextInput := wire.NewTxIn(&outpoint, nil, nil)
- currentTotal += btcutil.Amount(prevOut.Value)
- currentInputs = append(currentInputs, nextInput)
- currentScripts = append(
- currentScripts, prevOut.PkScript,
- )
- currentInputValues = append(
- currentInputValues,
- btcutil.Amount(prevOut.Value),
- )
- }
-
- return currentTotal, currentInputs, currentInputValues,
- currentScripts, nil
- }
-}
-
-// constantInputSource creates an input source function that always returns the
-// static set of user-selected UTXOs.
-func constantInputSource(eligible []wtxmgr.Credit) txauthor.InputSource {
- // Current inputs and their total value. These won't change over
- // different invocations as we want our inputs to remain static since
- // they're selected by the user.
- currentTotal := btcutil.Amount(0)
- currentInputs := make([]*wire.TxIn, 0, len(eligible))
- currentScripts := make([][]byte, 0, len(eligible))
- currentInputValues := make([]btcutil.Amount, 0, len(eligible))
-
- for _, credit := range eligible {
- nextInput := wire.NewTxIn(&credit.OutPoint, nil, nil)
- currentTotal += credit.Amount
- currentInputs = append(currentInputs, nextInput)
- currentScripts = append(currentScripts, credit.PkScript)
- currentInputValues = append(currentInputValues, credit.Amount)
- }
-
- return func(target btcutil.Amount) (btcutil.Amount, []*wire.TxIn,
- []btcutil.Amount, [][]byte, error) {
-
- return currentTotal, currentInputs, currentInputValues,
- currentScripts, nil
- }
-}
-
-// secretSource is an implementation of txauthor.SecretSource for the wallet's
-// address manager.
-type secretSource struct {
- *waddrmgr.Manager
- addrmgrNs walletdb.ReadBucket
-}
-
-func (s secretSource) GetKey(
- addr address.Address) (*btcec.PrivateKey, bool, error) {
-
- ma, err := s.Address(s.addrmgrNs, addr)
- if err != nil {
- return nil, false, err
- }
-
- mpka, ok := ma.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- e := fmt.Errorf("managed address type for %v is `%T` but "+
- "want waddrmgr.ManagedPubKeyAddress", addr, ma)
- return nil, false, e
- }
- privKey, err := mpka.PrivKey()
- if err != nil {
- return nil, false, err
- }
- return privKey, ma.Compressed(), nil
-}
-
-func (s secretSource) GetScript(addr address.Address) ([]byte, error) {
- ma, err := s.Address(s.addrmgrNs, addr)
- if err != nil {
- return nil, err
- }
-
- msa, ok := ma.(waddrmgr.ManagedScriptAddress)
- if !ok {
- e := fmt.Errorf("managed address type for %v is `%T` but "+
- "want waddrmgr.ManagedScriptAddress", addr, ma)
- return nil, e
- }
- return msa.Script()
-}
-
-// txToOutputs creates a signed transaction which includes each output from
-// outputs. Previous outputs to redeem are chosen from the passed account's
-// UTXO set and minconf policy. An additional output may be added to return
-// change to the wallet. This output will have an address generated from the
-// given key scope and account. If a key scope is not specified, the address
-// will always be generated from the P2WKH key scope. An appropriate fee is
-// included based on the wallet's current relay fee. The wallet must be
-// unlocked to create the transaction.
-//
-// NOTE: The dryRun argument can be set true to create a tx that doesn't alter
-// the database. A tx created with this set to true will intentionally have no
-// input scripts added and SHOULD NOT be broadcasted.
-func (w *Wallet) txToOutputs(outputs []*wire.TxOut,
- coinSelectKeyScope, changeKeyScope *waddrmgr.KeyScope,
- account uint32, minconf int32, feeSatPerKb btcutil.Amount,
- strategy CoinSelectionStrategy, dryRun bool,
- selectedUtxos []wire.OutPoint,
- allowUtxo func(utxo wtxmgr.Credit) bool) (
- *txauthor.AuthoredTx, error) {
-
- chainClient, err := w.requireChainClient()
- if err != nil {
- return nil, err
- }
-
- // Get current block's height and hash.
- bs, err := chainClient.BlockStamp()
- if err != nil {
- return nil, err
- }
-
- // Fall back to default coin selection strategy if none is supplied.
- if strategy == nil {
- strategy = CoinSelectionLargest
- }
-
- // The addrMgrWithChangeSource function of the wallet creates a
- // new change address. The address manager uses OnCommit on the
- // walletdb tx to update the in-memory state of the account
- // state. But because the commit happens _after_ the account
- // manager internal lock has been released, there is a chance
- // for the address index to be accessed concurrently, even
- // though the closure in OnCommit re-acquires the lock. To avoid
- // this issue, we surround the whole address creation process
- // with a lock.
- w.newAddrMtx.Lock()
- defer w.newAddrMtx.Unlock()
-
- var tx *txauthor.AuthoredTx
- err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
- addrmgrNs, changeSource, err := w.addrMgrWithChangeSource(
- dbtx, changeKeyScope, account,
- )
- if err != nil {
- return err
- }
-
- eligible, err := w.findEligibleOutputs(
- dbtx, coinSelectKeyScope, account, minconf,
- bs, allowUtxo,
- )
- if err != nil {
- return err
- }
-
- var inputSource txauthor.InputSource
- if len(selectedUtxos) > 0 {
- dedupUtxos := fn.NewSet(selectedUtxos...)
- if len(dedupUtxos) != len(selectedUtxos) {
- return errors.New("selected UTXOs contain " +
- "duplicate values")
- }
-
- eligibleByOutpoint := make(
- map[wire.OutPoint]wtxmgr.Credit,
- )
-
- for _, e := range eligible {
- eligibleByOutpoint[e.OutPoint] = e
- }
-
- var eligibleSelectedUtxo []wtxmgr.Credit
- for _, outpoint := range selectedUtxos {
- e, ok := eligibleByOutpoint[outpoint]
-
- if !ok {
- return fmt.Errorf("selected outpoint "+
- "not eligible for "+
- "spending: %v", outpoint)
- }
- eligibleSelectedUtxo = append(
- eligibleSelectedUtxo, e,
- )
- }
-
- inputSource = constantInputSource(eligibleSelectedUtxo)
-
- } else {
- // Wrap our coins in a type that implements the
- // SelectableCoin interface, so we can arrange them
- // according to the selected coin selection strategy.
- wrappedEligible := make([]Coin, len(eligible))
- for i := range eligible {
- wrappedEligible[i] = Coin{
- TxOut: wire.TxOut{
- Value: int64(
- eligible[i].Amount,
- ),
- PkScript: eligible[i].PkScript,
- },
- OutPoint: eligible[i].OutPoint,
- }
- }
-
- arrangedCoins, err := strategy.ArrangeCoins(
- wrappedEligible, feeSatPerKb,
- )
- if err != nil {
- return err
- }
- inputSource = makeInputSource(arrangedCoins)
- }
-
- tx, err = txauthor.NewUnsignedTransaction(
- outputs, feeSatPerKb, inputSource, changeSource,
- )
- if err != nil {
- return err
- }
-
- // Randomize change position, if change exists, before signing.
- // This doesn't affect the serialize size, so the change amount
- // will still be valid.
- if tx.ChangeIndex >= 0 {
- tx.RandomizeChangePosition()
- }
-
- // If a dry run was requested, we return now before adding the
- // input scripts, and don't commit the database transaction.
- // By returning an error, we make sure the walletdb.Update call
- // rolls back the transaction. But we'll react to this specific
- // error outside of the DB transaction so we can still return
- // the produced chain TX.
- if dryRun {
- return walletdb.ErrDryRunRollBack
- }
-
- // Before committing the transaction, we'll sign our inputs. If
- // the inputs are part of a watch-only account, there's no
- // private key information stored, so we'll skip signing such.
- var watchOnly bool
- if coinSelectKeyScope == nil {
- // If a key scope wasn't specified, then coin selection
- // was performed from the default wallet accounts
- // (NP2WKH, P2WKH, P2TR), so any key scope provided
- // doesn't impact the result of this call.
- watchOnly, err = w.Manager.IsWatchOnlyAccount(
- addrmgrNs, waddrmgr.KeyScopeBIP0086, account,
- )
- } else {
- watchOnly, err = w.Manager.IsWatchOnlyAccount(
- addrmgrNs, *coinSelectKeyScope, account,
- )
- }
- if err != nil {
- return err
- }
- if !watchOnly {
- err = tx.AddAllInputScripts(
- secretSource{w.Manager, addrmgrNs},
- )
- if err != nil {
- return err
- }
-
- err = validateMsgTx(
- tx.Tx, tx.PrevScripts, tx.PrevInputValues,
- )
- if err != nil {
- return err
- }
- }
-
- if tx.ChangeIndex >= 0 && account == waddrmgr.ImportedAddrAccount {
- changeAmount := btcutil.Amount(
- tx.Tx.TxOut[tx.ChangeIndex].Value,
- )
- log.Warnf("Spend from imported account produced "+
- "change: moving %v from imported account into "+
- "default account.", changeAmount)
- }
-
- // Finally, we'll request the backend to notify us of the
- // transaction that pays to the change address, if there is one,
- // when it confirms.
- if tx.ChangeIndex >= 0 {
- changePkScript := tx.Tx.TxOut[tx.ChangeIndex].PkScript
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- changePkScript, w.chainParams,
- )
- if err != nil {
- return err
- }
- if err := chainClient.NotifyReceived(addrs); err != nil {
- return err
- }
- }
-
- return nil
- })
- if err != nil && !errors.Is(err, walletdb.ErrDryRunRollBack) {
- return nil, err
- }
-
- return tx, nil
-}
-
-func (w *Wallet) findEligibleOutputs(dbtx walletdb.ReadTx,
- keyScope *waddrmgr.KeyScope, account uint32, minconf int32,
- bs *waddrmgr.BlockStamp,
- allowUtxo func(utxo wtxmgr.Credit) bool) ([]wtxmgr.Credit, error) {
-
- addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
-
- unspent, err := w.TxStore.UnspentOutputs(txmgrNs)
- if err != nil {
- return nil, err
- }
-
- // TODO: Eventually all of these filters (except perhaps output locking)
- // should be handled by the call to UnspentOutputs (or similar).
- // Because one of these filters requires matching the output script to
- // the desired account, this change depends on making wtxmgr a waddrmgr
- // dependency and requesting unspent outputs for a single account.
- eligible := make([]wtxmgr.Credit, 0, len(unspent))
- for i := range unspent {
- output := &unspent[i]
-
- // Restrict the selected utxos if a filter function is provided.
- if allowUtxo != nil &&
- !allowUtxo(*output) {
-
- continue
- }
-
- // Only include this output if it meets the required number of
- // confirmations. Coinbase transactions must have reached
- // maturity before their outputs may be spent.
- if !hasMinConfs(minconf, output.Height, bs.Height) {
- continue
- }
- if output.FromCoinBase {
- target := int32(w.chainParams.CoinbaseMaturity)
- if !hasMinConfs(target, output.Height, bs.Height) {
- continue
- }
- }
-
- // Locked unspent outputs are skipped.
- if w.LockedOutpoint(output.OutPoint) {
- continue
- }
-
- // Only include the output if it is associated with the passed
- // account.
- //
- // TODO: Handle multisig outputs by determining if enough of the
- // addresses are controlled.
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- output.PkScript, w.chainParams)
- if err != nil || len(addrs) != 1 {
- continue
- }
- scopedMgr, addrAcct, err := w.Manager.AddrAccount(addrmgrNs, addrs[0])
- if err != nil {
- continue
- }
- if keyScope != nil && scopedMgr.Scope() != *keyScope {
- continue
- }
- if addrAcct != account {
- continue
- }
- eligible = append(eligible, *output)
- }
- return eligible, nil
-}
-
-// inputYieldsPositively returns a boolean indicating whether this input yields
-// positively if added to a transaction. This determination is based on the
-// best-case added virtual size. For edge cases this function can return true
-// while the input is yielding slightly negative as part of the final
-// transaction.
-func inputYieldsPositively(credit *wire.TxOut,
- feeRatePerKb btcutil.Amount) bool {
-
- inputSize := txsizes.GetMinInputVirtualSize(credit.PkScript)
- inputFee := feeRatePerKb * btcutil.Amount(inputSize) / 1000
-
- return inputFee < btcutil.Amount(credit.Value)
-}
-
-// addrMgrWithChangeSource returns the address manager bucket and a change
-// source that returns change addresses from said address manager. The change
-// addresses will come from the specified key scope and account, unless a key
-// scope is not specified. In that case, change addresses will always come from
-// the P2WKH key scope.
-func (w *Wallet) addrMgrWithChangeSource(dbtx walletdb.ReadWriteTx,
- changeKeyScope *waddrmgr.KeyScope, account uint32) (
- walletdb.ReadWriteBucket, *txauthor.ChangeSource, error) {
-
- // Determine the address type for change addresses of the given
- // account.
- if changeKeyScope == nil {
- changeKeyScope = &waddrmgr.KeyScopeBIP0086
- }
- addrType := waddrmgr.ScopeAddrMap[*changeKeyScope].InternalAddrType
-
- // It's possible for the account to have an address schema override, so
- // prefer that if it exists.
- addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
- scopeMgr, err := w.Manager.FetchScopedKeyManager(*changeKeyScope)
- if err != nil {
- return nil, nil, err
- }
- accountInfo, err := scopeMgr.AccountProperties(addrmgrNs, account)
- if err != nil {
- return nil, nil, err
- }
- if accountInfo.AddrSchema != nil {
- addrType = accountInfo.AddrSchema.InternalAddrType
- }
-
- // Compute the expected size of the script for the change address type.
- var scriptSize int
- switch addrType {
- case waddrmgr.PubKeyHash:
- scriptSize = txsizes.P2PKHPkScriptSize
- case waddrmgr.NestedWitnessPubKey:
- scriptSize = txsizes.NestedP2WPKHPkScriptSize
- case waddrmgr.WitnessPubKey:
- scriptSize = txsizes.P2WPKHPkScriptSize
- case waddrmgr.TaprootPubKey:
- scriptSize = txsizes.P2TRPkScriptSize
- default:
- return nil, nil, fmt.Errorf("unsupported address type: %v",
- addrType)
- }
-
- newChangeScript := func() ([]byte, error) {
- // Derive the change output script. As a hack to allow spending
- // from the imported account, change addresses are created from
- // account 0.
- var (
- changeAddr address.Address
- err error
- )
- if account == waddrmgr.ImportedAddrAccount {
- changeAddr, err = w.newChangeAddress(
- addrmgrNs, 0, *changeKeyScope,
- )
- } else {
- changeAddr, err = w.newChangeAddress(
- addrmgrNs, account, *changeKeyScope,
- )
- }
- if err != nil {
- return nil, err
- }
- return txscript.PayToAddrScript(changeAddr)
- }
-
- return addrmgrNs, &txauthor.ChangeSource{
- ScriptSize: scriptSize,
- NewScript: newChangeScript,
- }, nil
-}
-
-// validateMsgTx verifies transaction input scripts for tx. All previous output
-// scripts from outputs redeemed by the transaction, in the same order they are
-// spent, must be passed in the prevScripts slice.
-func validateMsgTx(tx *wire.MsgTx, prevScripts [][]byte,
- inputValues []btcutil.Amount) error {
-
- inputFetcher, err := txauthor.TXPrevOutFetcher(
- tx, prevScripts, inputValues,
- )
- if err != nil {
- return err
- }
-
- hashCache := txscript.NewTxSigHashes(tx, inputFetcher)
- for i, prevScript := range prevScripts {
- vm, err := txscript.NewEngine(
- prevScript, tx, i, txscript.StandardVerifyFlags, nil,
- hashCache, int64(inputValues[i]), inputFetcher,
- )
- if err != nil {
- return fmt.Errorf("cannot create script engine: %w", err)
- }
- err = vm.Execute()
- if err != nil {
- return fmt.Errorf("cannot validate transaction: %w", err)
- }
- }
- return nil
-}
-
-// sortByAmount is a generic sortable type for sorting coins by their amount.
-type sortByAmount []Coin
-
-func (s sortByAmount) Len() int { return len(s) }
-func (s sortByAmount) Less(i, j int) bool {
- return s[i].Value < s[j].Value
-}
-func (s sortByAmount) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-
-// LargestFirstCoinSelector is an implementation of the CoinSelectionStrategy
-// that always selects the largest coins first.
-type LargestFirstCoinSelector struct{}
-
-// ArrangeCoins takes a list of coins and arranges them according to the
-// specified coin selection strategy and fee rate.
-func (*LargestFirstCoinSelector) ArrangeCoins(eligible []Coin,
- _ btcutil.Amount) ([]Coin, error) {
-
- sort.Sort(sort.Reverse(sortByAmount(eligible)))
-
- return eligible, nil
-}
-
-// RandomCoinSelector is an implementation of the CoinSelectionStrategy that
-// selects coins at random. This prevents the creation of ever smaller UTXOs
-// over time that may never become economical to spend.
-type RandomCoinSelector struct{}
-
-// ArrangeCoins takes a list of coins and arranges them according to the
-// specified coin selection strategy and fee rate.
-func (*RandomCoinSelector) ArrangeCoins(eligible []Coin,
- feeSatPerKb btcutil.Amount) ([]Coin, error) {
-
- // Skip inputs that do not raise the total transaction output
- // value at the requested fee rate.
- positivelyYielding := make([]Coin, 0, len(eligible))
- for _, output := range eligible {
- output := output
-
- if !inputYieldsPositively(&output.TxOut, feeSatPerKb) {
- continue
- }
-
- positivelyYielding = append(positivelyYielding, output)
- }
-
- rand.Shuffle(len(positivelyYielding), func(i, j int) {
- positivelyYielding[i], positivelyYielding[j] =
- positivelyYielding[j], positivelyYielding[i]
- })
-
- return positivelyYielding, nil
-}
diff --git a/wallet/createtx_test.go b/wallet/createtx_test.go
deleted file mode 100644
index 5bb3dd6a1a..0000000000
--- a/wallet/createtx_test.go
+++ /dev/null
@@ -1,599 +0,0 @@
-// Copyright (c) 2018 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "bytes"
- "testing"
- "time"
-
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/btcutil/v2"
- "github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcd/chainhash/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/wallet/txauthor"
- "github.com/btcsuite/btcwallet/walletdb"
- _ "github.com/btcsuite/btcwallet/walletdb/bdb"
- "github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/stretchr/testify/require"
-)
-
-var (
- testBlockHash, _ = chainhash.NewHashFromStr(
- "00000000000000017188b968a371bab95aa43522665353b646e41865abae" +
- "02a4",
- )
- testBlockHeight int32 = 276425
-
- alwaysAllowUtxo = func(utxo wtxmgr.Credit) bool { return true }
-)
-
-// TestTxToOutput checks that no new address is added to he database if we
-// request a dry run of the txToOutputs call. It also makes sure a subsequent
-// non-dry run call produces a similar transaction to the dry-run.
-func TestTxToOutputsDryRun(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create an address we can use to send some coins to.
- keyScope := waddrmgr.KeyScopeBIP0049Plus
- addr, err := w.CurrentAddress(0, keyScope)
- if err != nil {
- t.Fatalf("unable to get current address: %v", addr)
- }
- p2shAddr, err := txscript.PayToAddrScript(addr)
- if err != nil {
- t.Fatalf("unable to convert wallet address to p2sh: %v", err)
- }
-
- // Add an output paying to the wallet's address to the database.
- txOut := wire.NewTxOut(100000, p2shAddr)
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{
- {},
- },
- TxOut: []*wire.TxOut{
- txOut,
- },
- }
- addUtxo(t, w, incomingTx)
-
- // Now tell the wallet to create a transaction paying to the specified
- // outputs.
- txOuts := []*wire.TxOut{
- {
- PkScript: p2shAddr,
- Value: 10000,
- },
- {
- PkScript: p2shAddr,
- Value: 20000,
- },
- }
-
- // First do a few dry-runs, making sure the number of addresses in the
- // database us not inflated.
- dryRunTx, err := w.txToOutputs(
- txOuts, nil, nil, 0, 1, 1000, CoinSelectionLargest, true,
- nil, alwaysAllowUtxo,
- )
- if err != nil {
- t.Fatalf("unable to author tx: %v", err)
- }
- change := dryRunTx.Tx.TxOut[dryRunTx.ChangeIndex]
-
- addresses, err := w.AccountAddresses(0)
- if err != nil {
- t.Fatalf("unable to get addresses: %v", err)
- }
-
- if len(addresses) != 1 {
- t.Fatalf("expected 1 address, found %v", len(addresses))
- }
-
- dryRunTx2, err := w.txToOutputs(
- txOuts, nil, nil, 0, 1, 1000, CoinSelectionLargest, true,
- nil, alwaysAllowUtxo,
- )
- if err != nil {
- t.Fatalf("unable to author tx: %v", err)
- }
- change2 := dryRunTx2.Tx.TxOut[dryRunTx2.ChangeIndex]
-
- addresses, err = w.AccountAddresses(0)
- if err != nil {
- t.Fatalf("unable to get addresses: %v", err)
- }
-
- if len(addresses) != 1 {
- t.Fatalf("expected 1 address, found %v", len(addresses))
- }
-
- // The two dry-run TXs should be invalid, since they don't have
- // signatures.
- err = validateMsgTx(
- dryRunTx.Tx, dryRunTx.PrevScripts, dryRunTx.PrevInputValues,
- )
- if err == nil {
- t.Fatalf("Expected tx to be invalid")
- }
-
- err = validateMsgTx(
- dryRunTx2.Tx, dryRunTx2.PrevScripts, dryRunTx2.PrevInputValues,
- )
- if err == nil {
- t.Fatalf("Expected tx to be invalid")
- }
-
- // Now we do a proper, non-dry run. This should add a change address
- // to the database.
- tx, err := w.txToOutputs(
- txOuts, nil, nil, 0, 1, 1000, CoinSelectionLargest, false,
- nil, alwaysAllowUtxo,
- )
- if err != nil {
- t.Fatalf("unable to author tx: %v", err)
- }
- change3 := tx.Tx.TxOut[tx.ChangeIndex]
-
- addresses, err = w.AccountAddresses(0)
- if err != nil {
- t.Fatalf("unable to get addresses: %v", err)
- }
-
- if len(addresses) != 2 {
- t.Fatalf("expected 2 addresses, found %v", len(addresses))
- }
-
- err = validateMsgTx(tx.Tx, tx.PrevScripts, tx.PrevInputValues)
- if err != nil {
- t.Fatalf("Expected tx to be valid: %v", err)
- }
-
- // Finally, we check that all the transaction were using the same
- // change address.
- if !bytes.Equal(change.PkScript, change2.PkScript) {
- t.Fatalf("first dry-run using different change address " +
- "than second")
- }
- if !bytes.Equal(change2.PkScript, change3.PkScript) {
- t.Fatalf("dry-run using different change address " +
- "than wet run")
- }
-}
-
-// addUtxo add the given transaction to the wallet's database marked as a
-// confirmed UTXO .
-func addUtxo(t *testing.T, w *Wallet, incomingTx *wire.MsgTx) {
- var b bytes.Buffer
- if err := incomingTx.Serialize(&b); err != nil {
- t.Fatalf("unable to serialize tx: %v", err)
- }
- txBytes := b.Bytes()
-
- rec, err := wtxmgr.NewTxRecord(txBytes, time.Now())
- if err != nil {
- t.Fatalf("unable to create tx record: %v", err)
- }
-
- // The block meta will be inserted to tell the wallet this is a
- // confirmed transaction.
- block := &wtxmgr.BlockMeta{
- Block: wtxmgr.Block{
- Hash: *testBlockHash,
- Height: testBlockHeight,
- },
- Time: time.Unix(1387737310, 0),
- }
-
- if err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(wtxmgrNamespaceKey)
- err = w.TxStore.InsertTx(ns, rec, block)
- if err != nil {
- return err
- }
- // Add all tx outputs as credits.
- for i := 0; i < len(incomingTx.TxOut); i++ {
- err = w.TxStore.AddCredit(
- ns, rec, block, uint32(i), false,
- )
- if err != nil {
- return err
- }
- }
- return nil
- }); err != nil {
- t.Fatalf("failed inserting tx: %v", err)
- }
-}
-
-// addTxAndCredit adds the given transaction to the wallet's database marked as
-// a confirmed UTXO specified by the creditIndex.
-func addTxAndCredit(t *testing.T, w *Wallet, tx *wire.MsgTx,
- creditIndex uint32) {
-
- var b bytes.Buffer
- require.NoError(t, tx.Serialize(&b), "unable to serialize tx")
-
- txBytes := b.Bytes()
-
- rec, err := wtxmgr.NewTxRecord(txBytes, time.Now())
- require.NoError(t, err)
-
- // The block meta will be inserted to tell the wallet this is a
- // confirmed transaction.
- block := &wtxmgr.BlockMeta{
- Block: wtxmgr.Block{
- Hash: *testBlockHash,
- Height: testBlockHeight,
- },
- Time: time.Unix(1387737310, 0),
- }
-
- err = walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
- ns := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
- err = w.TxStore.InsertTx(ns, rec, block)
- if err != nil {
- return err
- }
-
- // Add the specified output as credit.
- err = w.TxStore.AddCredit(ns, rec, block, creditIndex, false)
- if err != nil {
- return err
- }
-
- return nil
- })
- require.NoError(t, err, "failed inserting tx")
-}
-
-// TestInputYield verifies the functioning of the inputYieldsPositively.
-func TestInputYield(t *testing.T) {
- t.Parallel()
-
- addr, _ := address.DecodeAddress(
- "bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4", &chaincfg.MainNetParams,
- )
- pkScript, err := txscript.PayToAddrScript(addr)
- require.NoError(t, err)
-
- credit := &wire.TxOut{
- Value: 1000,
- PkScript: pkScript,
- }
-
- // At 10 sat/b this input is yielding positively.
- require.True(t, inputYieldsPositively(credit, 10000))
-
- // At 20 sat/b this input is yielding negatively.
- require.False(t, inputYieldsPositively(credit, 20000))
-}
-
-// TestTxToOutputsRandom tests random coin selection.
-func TestTxToOutputsRandom(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create an address we can use to send some coins to.
- keyScope := waddrmgr.KeyScopeBIP0049Plus
- addr, err := w.CurrentAddress(0, keyScope)
- if err != nil {
- t.Fatalf("unable to get current address: %v", addr)
- }
- p2shAddr, err := txscript.PayToAddrScript(addr)
- if err != nil {
- t.Fatalf("unable to convert wallet address to p2sh: %v", err)
- }
-
- // Add a set of utxos to the wallet.
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{
- {},
- },
- TxOut: []*wire.TxOut{},
- }
- for amt := int64(5000); amt <= 125000; amt += 10000 {
- incomingTx.AddTxOut(wire.NewTxOut(amt, p2shAddr))
- }
-
- addUtxo(t, w, incomingTx)
-
- // Now tell the wallet to create a transaction paying to the specified
- // outputs.
- txOuts := []*wire.TxOut{
- {
- PkScript: p2shAddr,
- Value: 50000,
- },
- {
- PkScript: p2shAddr,
- Value: 100000,
- },
- }
-
- const (
- feeSatPerKb = 100000
- maxIterations = 100
- )
-
- createTx := func() *txauthor.AuthoredTx {
- tx, err := w.txToOutputs(
- txOuts, nil, nil, 0, 1, feeSatPerKb,
- CoinSelectionRandom, true, nil, alwaysAllowUtxo,
- )
- require.NoError(t, err)
- return tx
- }
-
- firstTx := createTx()
- var isRandom bool
- for iteration := 0; iteration < maxIterations; iteration++ {
- tx := createTx()
-
- // Check to see if we are getting a total input value.
- // We consider this proof that the randomization works.
- if tx.TotalInput != firstTx.TotalInput {
- isRandom = true
- }
-
- // At the used fee rate of 100 sat/b, the 5000 sat input is
- // negatively yielding. We don't expect it to ever be selected.
- for _, inputValue := range tx.PrevInputValues {
- require.NotEqual(t, inputValue, btcutil.Amount(5000))
- }
- }
-
- require.True(t, isRandom)
-}
-
-// TestCreateSimpleCustomChange tests that it's possible to let the
-// CreateSimpleTx use all coins for coin selection, but specify a custom scope
-// that isn't the current default scope.
-func TestCreateSimpleCustomChange(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // First, we'll make a P2TR and a P2WKH address to send some coins to
- // (two different coin scopes).
- p2wkhAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- require.NoError(t, err)
-
- p2trAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0086)
- require.NoError(t, err)
-
- // We'll now make a transaction that'll send coins to both outputs,
- // then "credit" the wallet for that send.
- p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
- require.NoError(t, err)
- p2trScript, err := txscript.PayToAddrScript(p2trAddr)
- require.NoError(t, err)
-
- const testAmt = 1_000_000
-
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{
- {},
- },
- TxOut: []*wire.TxOut{
- wire.NewTxOut(testAmt, p2wkhScript),
- wire.NewTxOut(testAmt, p2trScript),
- },
- }
- addUtxo(t, w, incomingTx)
-
- // With the amounts credited to the wallet, we'll now do a dry run coin
- // selection w/o any default args.
- targetTxOut := &wire.TxOut{
- Value: 1_500_000,
- PkScript: p2trScript,
- }
- tx1, err := w.txToOutputs(
- []*wire.TxOut{targetTxOut}, nil, nil, 0, 1, 1000,
- CoinSelectionLargest, true, nil, alwaysAllowUtxo,
- )
- require.NoError(t, err)
-
- // We expect that all inputs were used and also the change output is a
- // taproot output (the current default).
- require.Len(t, tx1.Tx.TxIn, 2)
- require.Len(t, tx1.Tx.TxOut, 2)
- for _, txOut := range tx1.Tx.TxOut {
- scriptType, _, _, err := txscript.ExtractPkScriptAddrs(
- txOut.PkScript, w.chainParams,
- )
- require.NoError(t, err)
-
- require.Equal(t, scriptType, txscript.WitnessV1TaprootTy)
- }
-
- // Next, we'll do another dry run, but this time, specify a custom
- // change key scope. We'll also require that only inputs of P2TR are used.
- targetTxOut = &wire.TxOut{
- Value: 500_000,
- PkScript: p2trScript,
- }
- tx2, err := w.txToOutputs(
- []*wire.TxOut{targetTxOut}, &waddrmgr.KeyScopeBIP0086,
- &waddrmgr.KeyScopeBIP0084, 0, 1, 1000, CoinSelectionLargest,
- true, nil, alwaysAllowUtxo,
- )
- require.NoError(t, err)
-
- // The resulting transaction should spend a single input, and use P2WKH
- // as the output script.
- require.Len(t, tx2.Tx.TxIn, 1)
- require.Len(t, tx2.Tx.TxOut, 2)
- for i, txOut := range tx2.Tx.TxOut {
- if i != tx2.ChangeIndex {
- continue
- }
-
- scriptType, _, _, err := txscript.ExtractPkScriptAddrs(
- txOut.PkScript, w.chainParams,
- )
- require.NoError(t, err)
-
- require.Equal(t, scriptType, txscript.WitnessV0PubKeyHashTy)
- }
-}
-
-// TestSelectUtxosTxoToOutpoint tests that it is possible to use passed
-// selected utxos to craft a transaction in `txToOutpoint`.
-func TestSelectUtxosTxoToOutpoint(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // First, we'll make a P2TR and a P2WKH address to send some coins to.
- p2wkhAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- require.NoError(t, err)
-
- p2trAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0086)
- require.NoError(t, err)
-
- // We'll now make a transaction that'll send coins to both outputs,
- // then "credit" the wallet for that send.
- p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
- require.NoError(t, err)
-
- p2trScript, err := txscript.PayToAddrScript(p2trAddr)
- require.NoError(t, err)
-
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{
- {},
- },
- TxOut: []*wire.TxOut{
- wire.NewTxOut(1_000_000, p2wkhScript),
- wire.NewTxOut(2_000_000, p2trScript),
- wire.NewTxOut(3_000_000, p2trScript),
- wire.NewTxOut(7_000_000, p2trScript),
- },
- }
- addUtxo(t, w, incomingTx)
-
- // We expect 4 unspent UTXOs.
- unspent, err := w.ListUnspent(0, 80, "")
- require.NoError(t, err)
- require.Len(t, unspent, 4, "expected 4 unspent UTXOs")
-
- tCases := []struct {
- name string
- selectUTXOs []wire.OutPoint
- errString string
- }{
- {
- name: "Duplicate utxo values",
- selectUTXOs: []wire.OutPoint{
- {
- Hash: incomingTx.TxHash(),
- Index: 1,
- },
- {
- Hash: incomingTx.TxHash(),
- Index: 1,
- },
- },
- errString: "selected UTXOs contain duplicate values",
- },
- {
- name: "all selected UTXOs not eligible for spending",
- selectUTXOs: []wire.OutPoint{
- {
- Hash: chainhash.Hash([32]byte{1}),
- Index: 1,
- },
- {
- Hash: chainhash.Hash([32]byte{3}),
- Index: 1,
- },
- },
- errString: "selected outpoint not eligible for " +
- "spending",
- },
- {
- name: "some select UTXOs not eligible for spending",
- selectUTXOs: []wire.OutPoint{
- {
- Hash: chainhash.Hash([32]byte{1}),
- Index: 1,
- },
- {
- Hash: incomingTx.TxHash(),
- Index: 1,
- },
- },
- errString: "selected outpoint not eligible for " +
- "spending",
- },
- {
- name: "select utxo, no duplicates and all eligible " +
- "for spending",
- selectUTXOs: []wire.OutPoint{
- {
- Hash: incomingTx.TxHash(),
- Index: 1,
- },
- {
- Hash: incomingTx.TxHash(),
- Index: 2,
- },
- },
- },
- }
-
- for _, tc := range tCases {
- t.Run(tc.name, func(t *testing.T) {
- // Test by sending 200_000.
- targetTxOut := &wire.TxOut{
- Value: 200_000,
- PkScript: p2trScript,
- }
- tx1, err := w.txToOutputs(
- []*wire.TxOut{targetTxOut}, nil, nil, 0, 1,
- 1000, CoinSelectionLargest, true,
- tc.selectUTXOs, alwaysAllowUtxo,
- )
- if tc.errString != "" {
- require.ErrorContains(t, err, tc.errString)
- require.Nil(t, tx1)
-
- return
- }
-
- require.NoError(t, err)
- require.NotNil(t, tx1)
-
- // We expect all and only our select UTXOs to be input
- // in this transaction.
- require.Len(t, tx1.Tx.TxIn, len(tc.selectUTXOs))
-
- lookupSelectUtxos := make(map[wire.OutPoint]struct{})
- for _, utxo := range tc.selectUTXOs {
- lookupSelectUtxos[utxo] = struct{}{}
- }
-
- for _, tx := range tx1.Tx.TxIn {
- _, ok := lookupSelectUtxos[tx.PreviousOutPoint]
- require.True(t, ok)
- }
-
- // Expect two outputs, change and the actual payment to
- // the address.
- require.Len(t, tx1.Tx.TxOut, 2)
- })
- }
-}
diff --git a/wallet/db_ops.go b/wallet/db_ops.go
new file mode 100644
index 0000000000..eb5acc026e
--- /dev/null
+++ b/wallet/db_ops.go
@@ -0,0 +1,844 @@
+// Package wallet provides the implementation of a Bitcoin wallet.
+//
+// TODO(yy): This file will be removed once the Store implementation is
+// finished.
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/walletdb/migration"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+var (
+ // ErrMissingAddressManager is returned when the address manager namespace
+ // is missing from the database.
+ ErrMissingAddressManager = errors.New("missing address manager namespace")
+
+ // ErrMissingTxManager is returned when the transaction manager namespace is
+ // missing from the database.
+ ErrMissingTxManager = errors.New("missing transaction manager namespace")
+)
+
+// DBCreateWallet initializes the database structure for a new wallet.
+func DBCreateWallet(cfg Config, params CreateWalletParams,
+ rootKey *hdkeychain.ExtendedKey) error {
+
+ err := walletdb.Update(cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ // Create the top-level bucket for the address manager.
+ addrMgrNs, err := tx.CreateTopLevelBucket(waddrmgrNamespaceKey)
+ if err != nil {
+ return fmt.Errorf("create addr mgr bucket: %w", err)
+ }
+
+ // Create the top-level bucket for the transaction manager.
+ txMgrNs, err := tx.CreateTopLevelBucket(wtxmgrNamespaceKey)
+ if err != nil {
+ return fmt.Errorf("create tx mgr bucket: %w", err)
+ }
+
+ // Initialize the address manager in the database. This sets up
+ // the master keys and the initial account structure.
+ err = waddrmgr.Create(
+ addrMgrNs, rootKey, params.PubPassphrase, params.PrivatePassphrase,
+ cfg.ChainParams, nil, params.Birthday,
+ )
+ if err != nil {
+ return fmt.Errorf("create addr mgr: %w", err)
+ }
+
+ // Initialize the transaction manager in the database.
+ err = wtxmgr.Create(txMgrNs)
+ if err != nil {
+ return fmt.Errorf("create tx mgr: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("update: %w", err)
+ }
+
+ return nil
+}
+
+// DBLoadWallet initializes the database and returns the address and transaction
+// managers.
+func DBLoadWallet(cfg Config) (*waddrmgr.Manager, *wtxmgr.Store, error) {
+ var (
+ addrMgr *waddrmgr.Manager
+ txMgr *wtxmgr.Store
+ )
+
+ // Before attempting to open the wallet, we'll check if there are any
+ // database upgrades for us to proceed. We'll also create our references
+ // to the address and transaction managers, as they are backed by the
+ // database.
+ err := walletdb.Update(cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrMgrBucket := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ if addrMgrBucket == nil {
+ return ErrMissingAddressManager
+ }
+
+ txMgrBucket := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ if txMgrBucket == nil {
+ return ErrMissingTxManager
+ }
+
+ addrMgrUpgrader := waddrmgr.NewMigrationManager(addrMgrBucket)
+ txMgrUpgrader := wtxmgr.NewMigrationManager(txMgrBucket)
+
+ err := migration.Upgrade(txMgrUpgrader, addrMgrUpgrader)
+ if err != nil {
+ return fmt.Errorf("failed to upgrade database: %w", err)
+ }
+
+ addrMgr, err = waddrmgr.Open(
+ addrMgrBucket, cfg.PubPassphrase, cfg.ChainParams,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to open address manager: %w", err)
+ }
+
+ txMgr, err = wtxmgr.Open(txMgrBucket, cfg.ChainParams)
+ if err != nil {
+ return fmt.Errorf("failed to open transaction manager: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to load wallet: %w", err)
+ }
+
+ return addrMgr, txMgr, nil
+}
+
+// DBGetBirthdayBlock retrieves the current birthday block from the database.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `GetWallet` to get the birthday info.
+func (w *Wallet) DBGetBirthdayBlock(_ context.Context) (waddrmgr.BlockStamp,
+ bool, error) {
+
+ var (
+ birthdayBlock waddrmgr.BlockStamp
+ verified bool
+ )
+
+ err := walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ var err error
+
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ birthdayBlock, verified, err = w.addrStore.BirthdayBlock(ns)
+ if err != nil {
+ return fmt.Errorf("get birthday block: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return waddrmgr.BlockStamp{}, false, fmt.Errorf("view: %w", err)
+ }
+
+ return birthdayBlock, verified, nil
+}
+
+// DBPutBirthdayBlock updates the wallet's birthday block in the database
+// and marks it as verified.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `UpdateWallet` to set the birthday info.
+func (w *Wallet) DBPutBirthdayBlock(_ context.Context,
+ block waddrmgr.BlockStamp) error {
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ err := w.addrStore.SetBirthdayBlock(ns, block, true)
+ if err != nil {
+ return fmt.Errorf("set birthday block: %w", err)
+ }
+
+ return w.addrStore.SetSyncedTo(ns, &block)
+ })
+ if err != nil {
+ return fmt.Errorf("update: %w", err)
+ }
+
+ return nil
+}
+
+// DBDeleteExpiredLockedOutputs removes any expired output locks from the
+// transaction store.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `UpdateUTXOs` instead.
+func (w *Wallet) DBDeleteExpiredLockedOutputs(_ context.Context) error {
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ return w.txStore.DeleteExpiredLockedOutputs(txmgrNs)
+ })
+ if err != nil {
+ return fmt.Errorf("cleanup expired locks: %w", err)
+ }
+
+ return nil
+}
+
+// DBUnlock attempts to unlock the wallet's address manager with the provided
+// passphrase.
+//
+// TODO(yy): Refactor this in the `Store` implementation - the only db
+// operation needed is to load the account info and derive the private keys.
+func (w *Wallet) DBUnlock(_ context.Context, passphrase []byte) error {
+ err := walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ return w.addrStore.Unlock(addrmgrNs, passphrase)
+ })
+ if err != nil {
+ return fmt.Errorf("view: %w", err)
+ }
+
+ return nil
+}
+
+// DBPutPassphrase updates the wallet's public or private passphrases.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `UpdateWallet` instead.
+func (w *Wallet) DBPutPassphrase(_ context.Context,
+ req ChangePassphraseRequest) error {
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ if req.ChangePublic {
+ err := w.addrStore.ChangePassphrase(
+ addrmgrNs, req.PublicOld, req.PublicNew,
+ false, &waddrmgr.DefaultScryptOptions,
+ )
+ if err != nil {
+ return fmt.Errorf("change public passphrase: "+
+ "%w", err)
+ }
+ }
+
+ if req.ChangePrivate {
+ err := w.addrStore.ChangePassphrase(
+ addrmgrNs, req.PrivateOld,
+ req.PrivateNew, true,
+ &waddrmgr.DefaultScryptOptions,
+ )
+ if err != nil {
+ return fmt.Errorf("change private passphrase: "+
+ "%w", err)
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("update: %w", err)
+ }
+
+ return nil
+}
+
+// DBGetAllAccounts ensures all account properties are loaded into the address
+// manager's cache.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `ListAccounts` instead, without the balance info.
+func (w *Wallet) DBGetAllAccounts(_ context.Context) error {
+ scopes := w.addrStore.ActiveScopedKeyManagers()
+
+ err := walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ for _, scopedMgr := range scopes {
+ lastAccount, err := scopedMgr.LastAccount(addrmgrNs)
+ if err != nil {
+ if waddrmgr.IsError(
+ err, waddrmgr.ErrAccountNotFound,
+ ) {
+
+ continue
+ }
+
+ return fmt.Errorf("last account: %w", err)
+ }
+
+ for i := uint32(0); i <= lastAccount; i++ {
+ _, err := scopedMgr.AccountProperties(
+ addrmgrNs, i,
+ )
+ if err != nil {
+ return fmt.Errorf("account: %w", err)
+ }
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("load all accounts: %w", err)
+ }
+
+ return nil
+}
+
+// DBGetUnminedTxns retrieves all transactions currently held in the
+// wallet's unmined (mempool) store.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `ListTxns` instead.
+func (s *syncer) DBGetUnminedTxns(_ context.Context) ([]*wire.MsgTx, error) {
+ var txs []*wire.MsgTx
+
+ err := walletdb.View(
+ s.cfg.DB, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ var err error
+
+ txs, err = s.txStore.UnminedTxs(txmgrNs)
+ if err != nil {
+ return fmt.Errorf("unmined txs: %w",
+ err)
+ }
+
+ return nil
+ },
+ )
+ if err != nil {
+ return nil, fmt.Errorf("view: %w", err)
+ }
+
+ return txs, nil
+}
+
+// DBPutBlocks atomically processes a filtered block connected notification
+// by inserting relevant transactions and updating the sync tip.
+//
+// NOTE: This method is used for notifications (not scans). It performs an
+// extra step to resolve address scopes (via putRelevantTxns) before
+// committing, as notification data does not include scope information.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `UpdateWallet` instead.
+func (s *syncer) DBPutBlocks(ctx context.Context,
+ matches TxEntries, block *wtxmgr.BlockMeta) error {
+
+ err := walletdb.Update(s.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ if len(matches) > 0 {
+ err := s.putRelevantTxns(
+ ctx, tx, matches, block,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return s.putSyncTip(ctx, tx, *block)
+ })
+ if err != nil {
+ return fmt.Errorf("process filtered block: %w", err)
+ }
+
+ return nil
+}
+
+// DBPutTxns parses a batch of relevant transactions, identifies their
+// relevant outputs, and commits them to the database.
+//
+// NOTE: This method is used for notifications (not scans). It performs an
+// extra step to resolve address scopes (via putRelevantTxns) before
+// committing, as notification data does not include scope information.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `UpdateUTXOs` instead.
+func (s *syncer) DBPutTxns(ctx context.Context, matches TxEntries,
+ block *wtxmgr.BlockMeta) error {
+
+ err := walletdb.Update(s.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ return s.putRelevantTxns(ctx, tx, matches, block)
+ })
+ if err != nil {
+ return fmt.Errorf("process txns: %w", err)
+ }
+
+ return nil
+}
+
+// DBGetScanData retrieves all necessary data from the database to initialize
+// the recovery state. This includes account horizons, active addresses, and
+// unspent outputs to watch.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `ListUTXOx+ListAddress` instead, or build a dedicated sql query.
+func (s *syncer) DBGetScanData(_ context.Context,
+ targets []waddrmgr.AccountScope) ([]*waddrmgr.AccountProperties,
+ []address.Address, []wtxmgr.Credit, error) {
+
+ var (
+ horizonData []*waddrmgr.AccountProperties
+ initialAddrs []address.Address
+ initialUnspent []wtxmgr.Credit
+ )
+
+ // Perform all database reads in a single read-only transaction.
+ //
+ // TODO(yy): Refactor to build a single SQL query for these data
+ // fetches instead of multiple smaller operations within the
+ // transaction.
+ //
+ // NOTE: RecoveryState initialization and mutation are intentionally
+ // kept outside this transaction to strictly separate database I/O from
+ // in-memory state management.
+ err := walletdb.View(s.cfg.DB, func(dbtx walletdb.ReadTx) error {
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ // 1. Collect Horizons.
+ for _, target := range targets {
+ scopedMgr, err := s.addrStore.FetchScopedKeyManager(
+ target.Scope,
+ )
+ if err != nil {
+ return fmt.Errorf("fetch scoped manager: %w",
+ err)
+ }
+
+ props, err := scopedMgr.AccountProperties(
+ addrmgrNs, target.Account,
+ )
+ if err != nil {
+ return fmt.Errorf("account properties: %w", err)
+ }
+
+ horizonData = append(horizonData, props)
+ }
+
+ // 2. Load Active Addresses.
+ err := s.addrStore.ForEachRelevantActiveAddress(
+ addrmgrNs, func(addr address.Address) error {
+ initialAddrs = append(initialAddrs, addr)
+ return nil
+ },
+ )
+ if err != nil {
+ return fmt.Errorf("for each relevant address: %w", err)
+ }
+
+ // 3. Load UTXOs.
+ initialUnspent, err = s.txStore.OutputsToWatch(txmgrNs)
+ if err != nil {
+ return fmt.Errorf("outputs to watch: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("load recovery state: %w", err)
+ }
+
+ return horizonData, initialAddrs, initialUnspent, nil
+}
+
+// DBGetSyncedBlocks retrieves a batch of block hashes from the wallet's
+// database for the range [startHeight, endHeight].
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `ListSyncedBlocks` instead on `WalletStore`?
+func (s *syncer) DBGetSyncedBlocks(_ context.Context, startHeight,
+ endHeight int32) ([]*chainhash.Hash, error) {
+
+ var localHashes []*chainhash.Hash
+
+ err := walletdb.View(s.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ count := endHeight - startHeight + 1
+ localHashes = make([]*chainhash.Hash, 0, count)
+
+ // We fetch from startHeight to endHeight to match the order
+ // we'll get from the chain backend (ascending).
+ for h := startHeight; h <= endHeight; h++ {
+ hash, err := s.addrStore.BlockHash(addrmgrNs, h)
+ if err != nil {
+ return fmt.Errorf("get block hash %d: %w",
+ h, err)
+ }
+
+ localHashes = append(localHashes, hash)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("fetch synced block hashes: %w", err)
+ }
+
+ return localHashes, nil
+}
+
+// DBPutRewind rewinds the wallet state to the specified fork point.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we need to define a
+// new method and build customized query for this.
+func (s *syncer) DBPutRewind(_ context.Context,
+ bs waddrmgr.BlockStamp) error {
+
+ err := walletdb.Update(s.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ err := s.addrStore.SetSyncedTo(addrmgrNs, &bs)
+ if err != nil {
+ return fmt.Errorf("set synced to: %w", err)
+ }
+
+ return s.txStore.Rollback(txmgrNs, bs.Height+1)
+ })
+ if err != nil {
+ return fmt.Errorf("rollback wallet: %w", err)
+ }
+
+ return nil
+}
+
+// DBPutSyncBatch updates the database with the results of a batch scan. It
+// handles persisting address horizons, transactions, and connecting blocks.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we need a dedicated
+// query for this on `WalletStore`?
+func (s *syncer) DBPutSyncBatch(ctx context.Context,
+ results []scanResult) error {
+
+ // TODO(yy): build a single SQL query for this.
+ err := walletdb.Update(s.cfg.DB, func(dbtx walletdb.ReadWriteTx) error {
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // 1. Update Address State (Horizons).
+ err := s.putAddrHorizons(ctx, addrmgrNs, results)
+ if err != nil {
+ return err
+ }
+
+ // 2. Update UTXO State (Transactions).
+ err = s.putScanTxns(ctx, dbtx, results)
+ if err != nil {
+ return err
+ }
+
+ // 3. Connect Blocks.
+ // We must process blocks in order and connect each one to
+ // ensure the address manager's block index remains contiguous.
+ //
+ // TODO(yy): This is inefficient as it performs a DB
+ // write/check for each block. Implement a batch write method
+ // in waddrmgr (or wait for SQL migration) to validate and
+ // insert the entire chain segment at once.
+ for _, res := range results {
+ err = s.putSyncTip(ctx, dbtx, *res.meta)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("process scan batch: %w", err)
+ }
+
+ return nil
+}
+
+// DBPutTargetedBatch updates the database with the results of a targeted
+// rescan. It persists address horizons and transactions but does NOT connect
+// blocks or update the wallet's synced tip.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we need a dedicated
+// query for this on `WalletStore`?
+func (s *syncer) DBPutTargetedBatch(ctx context.Context,
+ results []scanResult) error {
+
+ err := walletdb.Update(s.cfg.DB, func(dbtx walletdb.ReadWriteTx) error {
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // 1. Update Address State (Horizons).
+ err := s.putAddrHorizons(ctx, addrmgrNs, results)
+ if err != nil {
+ return err
+ }
+
+ // 2. Update UTXO State (Transactions).
+ err = s.putScanTxns(ctx, dbtx, results)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("process rescan batch: %w", err)
+ }
+
+ return nil
+}
+
+// DBPutSyncTip handles a chain server notification by marking a wallet
+// that's currently in-sync with the chain server as being synced up to the
+// passed block.
+//
+// TODO(yy): Refactor this in the `Store` implementation - we can call
+// `UpdateWallet` instead.
+func (s *syncer) DBPutSyncTip(ctx context.Context,
+ b wtxmgr.BlockMeta) error {
+
+ err := walletdb.Update(s.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ return s.putSyncTip(ctx, tx, b)
+ })
+ if err != nil {
+ return fmt.Errorf("commit sync tip: %w", err)
+ }
+
+ return nil
+}
+
+// putRelevantTxns identifies the branch scopes for a batch of relevant
+// transactions received from notifications (not scans), resolves them, and
+// commits them to the database.
+func (s *syncer) putRelevantTxns(ctx context.Context,
+ dbtx walletdb.ReadWriteTx, matches TxEntries,
+ block *wtxmgr.BlockMeta) error {
+
+ // 1. Resolution: Resolve scopes and finalize entries.
+ err := s.resolveTxMatches(ctx, dbtx, matches)
+ if err != nil {
+ return err
+ }
+
+ // 2. Commit: Insert each transaction with its resolved credits.
+ return s.putTxns(ctx, dbtx, matches, block)
+}
+
+// resolveTxMatches identifies the branch scopes for a batch of pre-extracted
+// transactions and address entries, filtering out invalid ones.
+func (s *syncer) resolveTxMatches(ctx context.Context,
+ dbtx walletdb.ReadTx, matches TxEntries) error {
+
+ // 1. Resolution: Resolve scopes for all unique addresses.
+ scopeMap, err := s.filterBranchScopes(ctx, dbtx, matches)
+ if err != nil {
+ return err
+ }
+
+ // 2. Construction: Finalize entries by applying resolved scopes.
+ for i := range matches {
+ match := &matches[i]
+
+ valid := make([]AddrEntry, 0, len(match.Entries))
+ for _, entry := range match.Entries {
+ scope, ok := scopeMap[entry.Address.String()]
+ if !ok {
+ continue
+ }
+
+ entry.Credit.Change = scope.Branch ==
+ waddrmgr.InternalBranch
+
+ valid = append(valid, entry)
+ }
+
+ match.Entries = valid
+ }
+
+ return nil
+}
+
+// putSyncTip handles a chain server notification by marking a wallet that's
+// currently in-sync with the chain server as being synced up to the passed
+// block.
+func (s *syncer) putSyncTip(_ context.Context,
+ dbtx walletdb.ReadWriteTx, b wtxmgr.BlockMeta) error {
+
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+ bs := waddrmgr.BlockStamp{
+ Height: b.Height,
+ Hash: b.Hash,
+ Timestamp: b.Time,
+ }
+
+ err := s.addrStore.SetSyncedTo(addrmgrNs, &bs)
+ if err != nil {
+ return fmt.Errorf("failed to set synced to: %w", err)
+ }
+
+ return nil
+}
+
+// filterBranchScopes retrieves the branch scope for a given set of address
+// entries. It returns a map where the key is the address string and the value
+// is the corresponding branch scope.
+func (s *syncer) filterBranchScopes(_ context.Context, dbtx walletdb.ReadTx,
+ matches TxEntries) (map[string]waddrmgr.BranchScope, error) {
+
+ ns := dbtx.ReadBucket(waddrmgrNamespaceKey)
+
+ // Deduplicate addresses from the input entries to minimize expensive
+ // database lookups for transactions with multiple outputs to the same
+ // address.
+ uniqueAddrs := make(map[string]address.Address)
+ for _, match := range matches {
+ for _, entry := range match.Entries {
+ uniqueAddrs[entry.Address.String()] = entry.Address
+ }
+ }
+
+ // Resolve the branch scope (Scope, Account, Branch) for each unique
+ // address. Addresses not found in the manager are skipped.
+ scopes := make(map[string]waddrmgr.BranchScope, len(uniqueAddrs))
+ for addrStr, addr := range uniqueAddrs {
+ ma, err := s.addrStore.Address(ns, addr)
+ if err != nil {
+ if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
+ continue
+ }
+
+ return nil, fmt.Errorf("get address info: %w", err)
+ }
+
+ scopedManager, account, err := s.addrStore.AddrAccount(ns, addr)
+ if err != nil {
+ return nil, fmt.Errorf("get addr account: %w", err)
+ }
+
+ branch := waddrmgr.ExternalBranch
+ if ma.Internal() {
+ branch = waddrmgr.InternalBranch
+ }
+
+ scopes[addrStr] = waddrmgr.BranchScope{
+ Scope: scopedManager.Scope(),
+ Account: account,
+ Branch: branch,
+ }
+ }
+
+ return scopes, nil
+}
+
+// putAddrHorizons aggregates found address horizons from the scan
+// results and updates the address manager state (extends horizons) in the
+// database.
+func (s *syncer) putAddrHorizons(_ context.Context,
+ ns walletdb.ReadWriteBucket, results []scanResult) error {
+
+ // Aggregate Horizon Expansion.
+ batchHorizons := make(map[waddrmgr.BranchScope]uint32)
+ for _, res := range results {
+ for bs, idx := range res.FoundHorizons {
+ if current, ok := batchHorizons[bs]; !ok ||
+ idx > current {
+
+ batchHorizons[bs] = idx
+ }
+ }
+ }
+
+ if len(batchHorizons) == 0 {
+ return nil
+ }
+ // Update the database.
+ for bs, maxFoundIndex := range batchHorizons {
+ scopedMgr, err := s.addrStore.FetchScopedKeyManager(bs.Scope)
+ if err != nil {
+ return fmt.Errorf("fetch scoped manager: %w", err)
+ }
+
+ err = scopedMgr.ExtendAddresses(
+ ns, bs.Account, maxFoundIndex, bs.Branch,
+ )
+ if err != nil {
+ return fmt.Errorf("extend addresses: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// putScanTxns processes relevant transactions found during the scan
+// and inserts them into the transaction store (and address manager for usage).
+func (s *syncer) putScanTxns(ctx context.Context,
+ dbtx walletdb.ReadWriteTx, results []scanResult) error {
+
+ for _, result := range results {
+ matches := result.RelevantOutputs
+
+ // The RelevantTxs in scanResult are *btcutil.Tx. We need to
+ // ensure the TxEntries have the correct *wtxmgr.TxRecord.
+ for i := range matches {
+ matches[i].Rec.Received = result.meta.Time
+ }
+
+ err := s.putTxns(ctx, dbtx, matches, result.meta)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// putTxns inserts relevant transactions and their credits into the wallet
+// using pre-matched output data.
+func (s *syncer) putTxns(_ context.Context, dbtx walletdb.ReadWriteTx,
+ matches TxEntries, block *wtxmgr.BlockMeta) error {
+
+ txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ for _, match := range matches {
+ rec := match.Rec
+ entries := match.Entries
+
+ credits := make([]wtxmgr.CreditEntry, 0, len(entries))
+ for _, entry := range entries {
+ credits = append(credits, entry.Credit)
+
+ err := s.addrStore.MarkUsed(addrmgrNs, entry.Address)
+ if err != nil {
+ return fmt.Errorf("mark used: %w", err)
+ }
+ }
+
+ var err error
+ if block != nil {
+ err = s.txStore.InsertConfirmedTx(
+ txmgrNs, rec, block, credits,
+ )
+ } else {
+ err = s.txStore.InsertUnconfirmedTx(
+ txmgrNs, rec, credits,
+ )
+ }
+
+ if err != nil {
+ return fmt.Errorf("insert tx: %w", err)
+ }
+ }
+
+ return nil
+}
diff --git a/wallet/db_ops_test.go b/wallet/db_ops_test.go
new file mode 100644
index 0000000000..ed5f84a3be
--- /dev/null
+++ b/wallet/db_ops_test.go
@@ -0,0 +1,1123 @@
+package wallet
+
+import (
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ _ "github.com/btcsuite/btcwallet/walletdb/bdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestDBCreateWallet verifies that the wallet database is correctly
+// initialized with the address and transaction manager buckets.
+func TestDBCreateWallet(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet with a fresh database.
+ // Note: createTestWalletWithMocks creates the top-level buckets, but
+ // they are empty. DBCreateWallet will populate them.
+ w, _ := createTestWalletWithMocks(t)
+
+ params := CreateWalletParams{
+ PubPassphrase: []byte("public"),
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ }
+
+ // Act: Initialize the wallet database.
+ err := DBCreateWallet(w.cfg, params, nil)
+
+ // Assert: Verify initialization success.
+ require.NoError(t, err)
+
+ // Verify that the address manager and transaction manager can be
+ // opened, indicating successful initialization.
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ require.NotNil(t, addrmgrNs)
+
+ _, err := waddrmgr.Open(
+ addrmgrNs, params.PubPassphrase, w.cfg.ChainParams,
+ )
+ if err != nil {
+ return err
+ }
+
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+ require.NotNil(t, txmgrNs)
+
+ _, err = wtxmgr.Open(txmgrNs, w.cfg.ChainParams)
+
+ return err
+ })
+ require.NoError(t, err)
+}
+
+// TestDBLoadWallet verifies that the wallet database can be successfully loaded
+// and the address and transaction managers retrieved.
+func TestDBLoadWallet(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and initialize it.
+ w, _ := createTestWalletWithMocks(t)
+
+ pubPass := []byte("public")
+ w.cfg.PubPassphrase = pubPass
+
+ params := CreateWalletParams{
+ PubPassphrase: pubPass,
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ }
+
+ err := DBCreateWallet(w.cfg, params, nil)
+ require.NoError(t, err)
+
+ // Act: Load the wallet database.
+ addrMgr, txMgr, err := DBLoadWallet(w.cfg)
+
+ // Assert: Verify that both managers were loaded successfully.
+ require.NoError(t, err)
+ require.NotNil(t, addrMgr)
+ require.NotNil(t, txMgr)
+}
+
+// TestDBBirthdayBlock verifies that the wallet can successfully persist and
+// retrieve its birthday block information.
+func TestDBBirthdayBlock(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet with mocked underlying stores.
+ w, mocks := createTestWalletWithMocks(t)
+
+ block := waddrmgr.BlockStamp{
+ Height: 100,
+ Hash: chainhash.Hash{0x01},
+ Timestamp: time.Unix(1000, 0),
+ }
+
+ // 1. Test DBPutBirthdayBlock.
+ //
+ // Arrange: Setup the expected mock calls for updating the birthday
+ // block and the sync tip in the address manager.
+ mocks.addrStore.On(
+ "SetBirthdayBlock", mock.Anything, block, true,
+ ).Return(nil).Once()
+ mocks.addrStore.On(
+ "SetSyncedTo", mock.Anything, &block,
+ ).Return(nil).Once()
+
+ // Act: Persist the birthday block to the database.
+ err := w.DBPutBirthdayBlock(t.Context(), block)
+
+ // Assert: Ensure the update completed without error.
+ require.NoError(t, err)
+
+ // 2. Test DBGetBirthdayBlock.
+ //
+ // Arrange: Setup the expected mock call for retrieving the birthday
+ // block from the address manager.
+ mocks.addrStore.On(
+ "BirthdayBlock", mock.Anything,
+ ).Return(block, true, nil).Once()
+
+ // Act: Retrieve the persisted birthday block from the database.
+ retBlock, verified, err := w.DBGetBirthdayBlock(t.Context())
+
+ // Assert: Verify the retrieved block data and verification status
+ // matches what was persisted.
+ require.NoError(t, err)
+ require.True(t, verified)
+ require.Equal(t, block, retBlock)
+}
+
+// TestDBUnlock verifies that the wallet can successfully unlock its address
+// manager using the provided passphrase.
+func TestDBUnlock(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and setup the expected mock call for
+ // unlocking the address manager.
+ w, mocks := createTestWalletWithMocks(t)
+ pass := []byte("password")
+
+ mocks.addrStore.On("Unlock", mock.Anything, pass).Return(nil).Once()
+
+ // Act: Attempt to unlock the wallet with the passphrase.
+ err := w.DBUnlock(t.Context(), pass)
+
+ // Assert: Verify that the unlock operation succeeded.
+ require.NoError(t, err)
+}
+
+// TestDBDeleteExpiredLockedOutputs verifies that the wallet successfully
+// invokes the transaction store to remove any expired output locks.
+func TestDBDeleteExpiredLockedOutputs(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and setup the expected mock call for
+ // deleting expired locked outputs.
+ w, mocks := createTestWalletWithMocks(t)
+
+ mocks.txStore.On(
+ "DeleteExpiredLockedOutputs", mock.Anything,
+ ).Return(nil).Once()
+
+ // Act: Trigger the cleanup of expired locked outputs in the database.
+ err := w.DBDeleteExpiredLockedOutputs(t.Context())
+
+ // Assert: Verify that the cleanup operation finished without error.
+ require.NoError(t, err)
+}
+
+// TestDBPutPassphrase verifies that the wallet can successfully update both
+// its public and private passphrases in the address manager.
+func TestDBPutPassphrase(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and a request to change both
+ // passphrases.
+ w, mocks := createTestWalletWithMocks(t)
+
+ req := ChangePassphraseRequest{
+ ChangePublic: true,
+ PublicOld: []byte("old"),
+ PublicNew: []byte("new"),
+ ChangePrivate: true,
+ PrivateOld: []byte("old_priv"),
+ PrivateNew: []byte("new_priv"),
+ }
+
+ // Setup mock calls for both passphrase changes.
+ mocks.addrStore.On(
+ "ChangePassphrase", mock.Anything, []byte("old"), []byte("new"),
+ false, mock.Anything,
+ ).Return(nil).Once()
+
+ mocks.addrStore.On(
+ "ChangePassphrase", mock.Anything, req.PrivateOld,
+ req.PrivateNew, true,
+ mock.MatchedBy(func(opts *waddrmgr.ScryptOptions) bool {
+ return opts.N == 16 && opts.R == 8 && opts.P == 1
+ }),
+ ).Return(nil).Once()
+
+ // Act: Commit the passphrase changes to the database.
+ err := w.DBPutPassphrase(t.Context(), req)
+
+ // Assert: Verify that both passphrases were updated successfully.
+ require.NoError(t, err)
+}
+
+// TestDBPutPassphrase_Error verifies that DBPutPassphrase correctly handles
+// and returns errors encountered during the passphrase update process.
+func TestDBPutPassphrase_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and setup a mock call that simulates
+ // a database error during a private passphrase change.
+ w, mocks := createTestWalletWithMocks(t)
+
+ req := ChangePassphraseRequest{
+ ChangePrivate: true,
+ }
+
+ mocks.addrStore.On(
+ "ChangePassphrase", mock.Anything, mock.Anything,
+ mock.Anything, true, mock.Anything,
+ ).Return(errDBMock).Once()
+
+ // Act: Attempt to change the passphrase, expecting a failure.
+ err := w.DBPutPassphrase(t.Context(), req)
+
+ // Assert: Verify that the expected database error is returned.
+ require.ErrorContains(t, err, "db error")
+}
+
+// TestDBPutBlocks_Error verifies that DBPutBlocks correctly handles errors
+// that occur during the transaction matching resolution phase.
+func TestDBPutBlocks_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer with mocked stores and setup a scenario
+ // where address lookup fails during transaction resolution.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ addr, _ := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chaincfg.MainNetParams,
+ )
+ matches := TxEntries{{Entries: []AddrEntry{{Address: addr}}}}
+
+ mocks.addrStore.On("Address", mock.Anything, addr).Return(
+ nil, errDBMock,
+ ).Once()
+
+ // Act: Attempt to process a block with relevant transactions.
+ err := s.DBPutBlocks(t.Context(), matches, nil)
+
+ // Assert: Verify that the address lookup error is correctly
+ // propagated.
+ require.ErrorContains(t, err, "db error")
+}
+
+// TestDBPutSyncBatch_Error verifies that DBPutSyncBatch correctly propagates
+// errors encountered when fetching scoped key managers for horizon updates.
+func TestDBPutSyncBatch_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and a scan result that requires updating an
+ // address manager's horizon.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ res := scanResult{
+ meta: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Height: 100},
+ },
+ BlockProcessResult: &BlockProcessResult{
+ FoundHorizons: map[waddrmgr.BranchScope]uint32{
+ {
+ Scope: waddrmgr.KeyScopeBIP0084,
+ }: 5,
+ },
+ },
+ }
+
+ // Simulate a failure when fetching the scoped key manager.
+ mocks.addrStore.On(
+ "FetchScopedKeyManager", waddrmgr.KeyScopeBIP0084,
+ ).Return(nil, errMock).Once()
+
+ // Act: Attempt to commit a batch of scan results to the database.
+ err := s.DBPutSyncBatch(t.Context(), []scanResult{res})
+
+ // Assert: Verify that the expected error is returned.
+ require.ErrorIs(t, err, errMock)
+}
+
+// TestDBPutBlocks verifies the full lifecycle of DBPutBlocks, including
+// resolving transaction scopes, marking addresses as used, inserting confirmed
+// transactions, and updating the sync tip.
+func TestDBPutBlocks(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup test data for a confirmed
+ // transaction.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ tx := wire.NewMsgTx(1)
+ rec, _ := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ addr, _ := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chaincfg.MainNetParams,
+ )
+
+ matches := TxEntries{{
+ Rec: rec,
+ Entries: []AddrEntry{{
+ Address: addr,
+ }},
+ }}
+
+ block := &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Height: 100},
+ }
+
+ // 1. Transaction Resolution.
+ //
+ // Setup mocks to resolve the address scope for the transaction output.
+ mockAddr := &mockManagedPubKeyAddr{}
+ mockAddr.On("Internal").Return(false).Once()
+ mocks.addrStore.On("Address", mock.Anything, addr).Return(
+ mockAddr, nil,
+ ).Once()
+
+ scopedMgr := &mockAccountStore{}
+ scopedMgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ mocks.addrStore.On("AddrAccount", mock.Anything, addr).Return(
+ scopedMgr, uint32(0), nil,
+ ).Once()
+
+ // 2. Transaction Insertion.
+ //
+ // Expect the address to be marked as used and the transaction to be
+ // inserted as confirmed.
+ mocks.addrStore.On("MarkUsed", mock.Anything, addr).Return(nil).Once()
+ mocks.txStore.On("InsertConfirmedTx", mock.Anything, rec, block,
+ mock.Anything,
+ ).Return(nil).Once()
+
+ // 3. Sync Tip Update.
+ //
+ // Expect the wallet's sync tip to be updated to the new block.
+ mocks.addrStore.On("SetSyncedTo", mock.Anything, mock.Anything).Return(
+ nil,
+ ).Once()
+
+ // Act: Process the block and its relevant transactions.
+ err := s.DBPutBlocks(t.Context(), matches, block)
+
+ // Assert: Verify that all operations completed successfully.
+ require.NoError(t, err)
+}
+
+// TestDBPutTxns verifies that DBPutTxns can successfully resolve and persist
+// unconfirmed transactions in the wallet's transaction store.
+func TestDBPutTxns(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup test data for an unconfirmed
+ // transaction.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ tx := wire.NewMsgTx(1)
+ rec, _ := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ addr, _ := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chaincfg.MainNetParams,
+ )
+
+ matches := TxEntries{{
+ Rec: rec,
+ Entries: []AddrEntry{{
+ Address: addr,
+ }},
+ }}
+
+ // Setup mock calls to resolve the address scope and persist the
+ // unconfirmed transaction.
+ mockAddr := &mockManagedPubKeyAddr{}
+ mockAddr.On("Internal").Return(false).Once()
+ mocks.addrStore.On("Address", mock.Anything, addr).Return(
+ mockAddr, nil,
+ ).Once()
+
+ scopedMgr := &mockAccountStore{}
+ scopedMgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ mocks.addrStore.On("AddrAccount", mock.Anything, addr).Return(
+ scopedMgr, uint32(0), nil,
+ ).Once()
+
+ mocks.addrStore.On("MarkUsed", mock.Anything, addr).Return(nil).Once()
+ mocks.txStore.On("InsertUnconfirmedTx", mock.Anything, rec,
+ mock.Anything,
+ ).Return(nil).Once()
+
+ // Act: Attempt to persist the unconfirmed transaction.
+ err := s.DBPutTxns(t.Context(), matches, nil)
+
+ // Assert: Verify that the transaction was persisted successfully.
+ require.NoError(t, err)
+}
+
+// TestPutAddrHorizons verifies that address horizons are correctly extended
+// in the database based on scan results.
+func TestPutAddrHorizons(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup a scan result that indicates a
+ // horizon expansion is needed for a specific BIP84 account.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ bs := waddrmgr.BranchScope{
+ Scope: waddrmgr.KeyScopeBIP0084,
+ Account: 0,
+ Branch: waddrmgr.ExternalBranch,
+ }
+
+ res := []scanResult{{
+ BlockProcessResult: &BlockProcessResult{
+ FoundHorizons: map[waddrmgr.BranchScope]uint32{
+ bs: 10,
+ },
+ },
+ }}
+
+ // Setup mock calls for fetching the manager and extending the
+ // addresses.
+ scopedMgr := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager", bs.Scope).Return(
+ scopedMgr, nil,
+ ).Once()
+
+ scopedMgr.On("ExtendAddresses", mock.Anything, bs.Account, uint32(10),
+ bs.Branch,
+ ).Return(nil).Once()
+
+ // Act: Trigger the horizon expansion within a database transaction.
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ return s.putAddrHorizons(t.Context(), ns, res)
+ })
+
+ // Assert: Verify that the horizons were extended without error.
+ require.NoError(t, err)
+}
+
+// TestDBGetScanData verifies that the wallet can successfully retrieve all
+// necessary state (horizons, active addresses, and UTXOs) to initialize a
+// chain rescan.
+func TestDBGetScanData(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup mock expectations for all data
+ // required during rescan initialization.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ targets := []waddrmgr.AccountScope{{
+ Scope: waddrmgr.KeyScopeBIP0084,
+ Account: 0,
+ }}
+
+ // 1. Horizons lookup.
+ scopedMgr := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0084,
+ ).Return(scopedMgr, nil).Once()
+
+ props := &waddrmgr.AccountProperties{AccountNumber: 0}
+ scopedMgr.On("AccountProperties", mock.Anything, uint32(0)).Return(
+ props, nil,
+ ).Once()
+
+ // 2. Active addresses lookup.
+ mocks.addrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything,
+ ).Return(nil).Once()
+
+ // 3. UTXO lookup.
+ mocks.txStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil,
+ ).Once()
+
+ // Act: Retrieve the initial scan data from the database.
+ horizonData, initialAddrs, initialUnspent, err := s.DBGetScanData(
+ t.Context(), targets,
+ )
+
+ // Assert: Verify that the retrieved data matches our expectations and
+ // that no error occurred.
+ require.NoError(t, err)
+ require.Len(t, horizonData, 1)
+ require.Equal(t, props, horizonData[0])
+ require.Empty(t, initialAddrs)
+ require.Empty(t, initialUnspent)
+}
+
+// TestDBGetSyncedBlocks verifies that the wallet can successfully retrieve a
+// range of block hashes from its internal index.
+func TestDBGetSyncedBlocks(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup a mock expectation for fetching a
+ // block hash from the address manager.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ hash := chainhash.Hash{0x01}
+ mocks.addrStore.On("BlockHash", mock.Anything, int32(100)).Return(
+ &hash, nil,
+ ).Once()
+
+ // Act: Fetch the block hashes for the requested range.
+ hashes, err := s.DBGetSyncedBlocks(t.Context(), 100, 100)
+
+ // Assert: Verify that the retrieved hash is correct.
+ require.NoError(t, err)
+ require.Len(t, hashes, 1)
+ require.Equal(t, &hash, hashes[0])
+}
+
+// TestDBPutRewind verifies that the wallet can successfully rewind its
+// synchronized state and transaction history to a specific point.
+func TestDBPutRewind(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup mock expectations for updating
+ // the sync tip and rolling back the transaction store.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ bs := waddrmgr.BlockStamp{Height: 100, Hash: chainhash.Hash{0x01}}
+
+ mocks.addrStore.On("SetSyncedTo", mock.Anything, &bs).Return(nil).Once()
+ mocks.txStore.On("Rollback",
+ mock.Anything, int32(101),
+ ).Return(nil).Once()
+
+ // Act: Rewind the wallet state to the specified block height.
+ err := s.DBPutRewind(t.Context(), bs)
+
+ // Assert: Verify that the rewind operation succeeded.
+ require.NoError(t, err)
+}
+
+// TestDBPutBirthdayBlock_Error verifies that DBPutBirthdayBlock correctly
+// handles and returns database errors during persistence.
+func TestDBPutBirthdayBlock_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and setup a mock call that simulates a
+ // failure when setting the birthday block.
+ w, mocks := createTestWalletWithMocks(t)
+
+ bs := waddrmgr.BlockStamp{Height: 100}
+
+ mocks.addrStore.On("SetBirthdayBlock", mock.Anything, bs, true).Return(
+ errDBMock,
+ ).Once()
+
+ // Act: Attempt to persist the birthday block, expecting a failure.
+ err := w.DBPutBirthdayBlock(t.Context(), bs)
+
+ // Assert: Verify that the database error is correctly propagated.
+ require.ErrorContains(t, err, "db error")
+}
+
+// TestDBGetAllAccounts_Error verifies that DBGetAllAccounts correctly
+// handles and returns database errors encountered while iterating over
+// accounts.
+func TestDBGetAllAccounts_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet and setup a mock call that simulates a
+ // failure while querying for the last account index.
+ w, mocks := createTestWalletWithMocks(t)
+ scopedMgr := &mockAccountStore{}
+
+ mocks.addrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore{scopedMgr},
+ ).Once()
+
+ scopedMgr.On("LastAccount", mock.Anything).Return(
+ uint32(0), errDBMock,
+ ).Once()
+
+ // Act: Attempt to load all account properties.
+ err := w.DBGetAllAccounts(t.Context())
+
+ // Assert: Verify that the expected error is returned.
+ require.ErrorContains(t, err, "db error")
+}
+
+// TestDBGetScanData_MultipleTargets verifies that DBGetScanData correctly
+// aggregates horizon data when multiple account scopes are requested.
+func TestDBGetScanData_MultipleTargets(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup test data for multiple accounts
+ // across different scopes.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ targets := []waddrmgr.AccountScope{
+ {Scope: waddrmgr.KeyScopeBIP0084, Account: 0},
+ {Scope: waddrmgr.KeyScopeBIP0049Plus, Account: 1},
+ }
+
+ // Setup mock calls to handle property retrieval for both targets.
+ scopedMgr := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).Return(
+ scopedMgr, nil,
+ ).Twice()
+
+ scopedMgr.On("AccountProperties", mock.Anything, mock.Anything).Return(
+ &waddrmgr.AccountProperties{}, nil,
+ ).Twice()
+
+ mocks.addrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything,
+ ).Return(nil).Once()
+
+ mocks.txStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil,
+ ).Once()
+
+ // Act: Retrieve initial scan data for all requested targets.
+ horizons, _, _, err := s.DBGetScanData(t.Context(), targets)
+
+ // Assert: Verify that data for both targets was successfully
+ // collected.
+ require.NoError(t, err)
+ require.Len(t, horizons, 2)
+}
+
+// TestDBGetScanData_Error verifies that DBGetScanData correctly handles
+// and returns database errors during horizon lookup, ensuring no stale
+// data is returned.
+func TestDBGetScanData_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup a mock expectation for a failure
+ // while fetching a scoped key manager.
+ w, mocks := createTestWalletWithMocks(t)
+ s := newSyncer(w.cfg, w.addrStore, w.txStore, nil)
+
+ targets := []waddrmgr.AccountScope{{
+ Scope: waddrmgr.KeyScopeBIP0084,
+ Account: 0,
+ }}
+
+ mocks.addrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0084,
+ ).Return(nil, errDBMock).Once()
+
+ // Act: Attempt to retrieve scan data, which is expected to fail.
+ horizons, addrs, unspent, err := s.DBGetScanData(t.Context(), targets)
+
+ // Assert: Verify that the database error is returned and all returned
+ // data slices are nil.
+ require.ErrorContains(t, err, "db error")
+ require.Nil(t, horizons)
+ require.Nil(t, addrs)
+ require.Nil(t, unspent)
+}
+
+// TestDBPutTargetedBatch_WithTxns verifies that DBPutTargetedBatch processes
+// relevant outputs.
+func TestDBPutTargetedBatch_WithTxns(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and mock dependencies for processing a
+ // targeted batch.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+
+ rec, err := wtxmgr.NewTxRecordFromMsgTx(wire.NewMsgTx(1), time.Now())
+ require.NoError(t, err)
+
+ results := []scanResult{
+ {
+ meta: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Height: 100},
+ Time: time.Now(),
+ },
+ BlockProcessResult: &BlockProcessResult{
+ RelevantOutputs: TxEntries{
+ {Rec: rec, Entries: []AddrEntry{}},
+ },
+ },
+ },
+ }
+
+ mockTxStore.On("InsertConfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything, mock.Anything).Return(nil).Once()
+
+ // Act: Execute the targeted batch update.
+ err = s.DBPutTargetedBatch(t.Context(), results)
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+}
+
+// TestDBPutSyncTip_Error verifies error propagation in DBPutSyncTip.
+func TestDBPutSyncTip_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where SetSyncedTo fails during
+ // DBPutSyncTip.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(errSetFail).Once()
+
+ // Act: Attempt to update the sync tip.
+ err := s.DBPutSyncTip(t.Context(), wtxmgr.BlockMeta{})
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errSetFail)
+}
+
+// TestDBPutTargetedBatch_Errors verifies error paths.
+func TestDBPutTargetedBatch_Errors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where transaction insertion fails
+ // during a targeted batch update.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+
+ rec, err := wtxmgr.NewTxRecordFromMsgTx(wire.NewMsgTx(1), time.Now())
+ require.NoError(t, err)
+
+ results := []scanResult{
+ {
+ meta: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Height: 100},
+ Time: time.Now(),
+ },
+ BlockProcessResult: &BlockProcessResult{
+ RelevantOutputs: TxEntries{
+ {Rec: rec, Entries: []AddrEntry{}},
+ },
+ },
+ },
+ }
+
+ mockTxStore.On("InsertConfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything, mock.Anything).Return(errDBInsert).Once()
+
+ // Act: Execute the targeted batch update.
+ err = s.DBPutTargetedBatch(t.Context(), results)
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errDBInsert)
+}
+
+// TestDBPutTxns_Error verifies error propagation in DBPutTxns.
+func TestDBPutTxns_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where address lookup fails during
+ // transaction persistence.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+
+ matches := TxEntries{
+ {
+ Rec: &wtxmgr.TxRecord{},
+ Entries: []AddrEntry{{Address: addr}},
+ },
+ }
+
+ mockAddrStore.On("Address",
+ mock.Anything, mock.Anything).Return(nil, errAddr).Once()
+
+ // Act: Attempt to persist transactions.
+ err = s.DBPutTxns(t.Context(), matches, nil)
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errAddr)
+}
+
+// TestDBPutTxns_UnconfirmedError verifies error propagation for unconfirmed tx.
+func TestDBPutTxns_UnconfirmedError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where unconfirmed transaction
+ // insertion fails.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+
+ matches := TxEntries{
+ {
+ Rec: &wtxmgr.TxRecord{},
+ Entries: []AddrEntry{{Address: addr}},
+ },
+ }
+
+ maddr := &mockManagedAddress{}
+ maddr.On("Internal").Return(false).Maybe()
+ mockAddrStore.On("Address", mock.Anything, mock.Anything).Return(maddr,
+ nil).Once()
+
+ mgr := &mockAccountStore{}
+ mgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ mockAddrStore.On("AddrAccount", mock.Anything, mock.Anything).Return(
+ mgr,
+ uint32(0), nil).Once()
+ mockAddrStore.On("MarkUsed", mock.Anything,
+ mock.Anything).Return(nil).Once()
+ mockTxStore.On("InsertUnconfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything).Return(errInsert).Once()
+
+ // Act: Attempt to persist unconfirmed transactions.
+ err = s.DBPutTxns(t.Context(), matches, nil)
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errInsert)
+}
+
+// TestPutSyncTip_Error verifies error propagation.
+func TestPutSyncTip_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where sync tip update fails within
+ // a database transaction.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ // Act: Execute sync tip update within a database transaction.
+ err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(errSetFail).Once()
+
+ return s.putSyncTip(t.Context(), tx, wtxmgr.BlockMeta{})
+ })
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errSetFail)
+}
+
+// TestDBGetScanData_ManagerError verifies account not found is handled.
+func TestDBGetScanData_ManagerError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where key manager lookup fails
+ // during scan data retrieval.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ targets := []waddrmgr.AccountScope{
+ {Scope: waddrmgr.KeyScopeBIP0084, Account: 0},
+ }
+
+ mockAddrStore.On("FetchScopedKeyManager",
+ mock.Anything).Return(nil, errManager).Once()
+
+ // Act: Attempt to retrieve scan data.
+ horizons, addrs, unspent, err := s.DBGetScanData(t.Context(), targets)
+
+ // Assert: Verify failure.
+ require.Nil(t, horizons)
+ require.Nil(t, addrs)
+ require.Nil(t, unspent)
+ require.ErrorIs(t, err, errManager)
+}
+
+// TestDBGetScanData_UTXOError verifies UTXO loading failure.
+func TestDBGetScanData_UTXOError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where UTXO lookup fails during scan
+ // data retrieval.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.AnythingOfType("func(address.Address) error"),
+ ).Return(nil).Once()
+ mockTxStore.On("OutputsToWatch",
+ mock.Anything).Return(([]wtxmgr.Credit)(nil), errUtxo).Once()
+
+ // Act: Attempt to retrieve scan data.
+ horizons, addrs, unspent, err := s.DBGetScanData(t.Context(), nil)
+
+ // Assert: Verify failure.
+ require.Nil(t, horizons)
+ require.Nil(t, addrs)
+ require.Nil(t, unspent)
+ require.ErrorIs(t, err, errUtxo)
+}
+
+// TestPutAddrHorizons_Error verifies error propagation.
+func TestPutAddrHorizons_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where key manager lookup fails
+ // during horizon persistence.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ results := []scanResult{
+ {
+ BlockProcessResult: &BlockProcessResult{
+ FoundHorizons: map[waddrmgr.BranchScope]uint32{
+ {}: 1,
+ },
+ },
+ },
+ }
+
+ mockAddrStore.On("FetchScopedKeyManager",
+ mock.Anything).Return(nil, errManager).Once()
+
+ // Act: Attempt to persist address horizons.
+ err := s.putAddrHorizons(t.Context(), nil, results)
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errManager)
+}
+
+// TestDBGetScanData_AddressError verifies active address loading failure.
+func TestDBGetScanData_AddressError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where address iteration fails
+ // during scan data retrieval.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(errAddr).Once()
+
+ // Act: Attempt to retrieve scan data.
+ horizons, addrs, unspent, err := s.DBGetScanData(t.Context(), nil)
+
+ // Assert: Verify failure.
+ require.Nil(t, horizons)
+ require.Nil(t, addrs)
+ require.Nil(t, unspent)
+ require.ErrorIs(t, err, errAddr)
+}
+
+// TestDBPutTxns_InternalAddressAsChange verifies internal branch handling.
+func TestDBPutTxns_InternalAddressAsChange(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for a transaction match where the
+ // address is internal, requiring it to be marked as change.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+
+ matches := TxEntries{
+ {
+ Rec: &wtxmgr.TxRecord{},
+ Entries: []AddrEntry{{Address: addr}},
+ },
+ }
+
+ maddr := &mockManagedAddress{}
+ maddr.On("Internal").Return(true).Once()
+ mockAddrStore.On("Address",
+ mock.Anything, mock.Anything).Return(maddr, nil).Once()
+
+ mgr := &mockAccountStore{}
+ mgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ mockAddrStore.On("AddrAccount",
+ mock.Anything, mock.Anything).Return(mgr, uint32(0), nil).Once()
+
+ mockAddrStore.On("MarkUsed", mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ mockTxStore.On("InsertUnconfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ // Act: Persist transactions and filter branch scopes.
+ err = s.DBPutTxns(t.Context(), matches, nil)
+
+ // Assert: Verify that the output was correctly identified as change.
+ require.NoError(t, err)
+ require.True(t, matches[0].Entries[0].Credit.Change)
+}
+
+// TestDBPutTxns_AddressNotFound verifies ignoring not-found addresses.
+func TestDBPutTxns_AddressNotFound(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where an address lookup returns a
+ // "not found" error, which should lead to the entry being filtered out.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+
+ matches := TxEntries{
+ {
+ Rec: &wtxmgr.TxRecord{},
+ Entries: []AddrEntry{{Address: addr}},
+ },
+ }
+
+ mockAddrStore.On("Address",
+ mock.Anything, mock.Anything,
+ ).Return(nil, waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAddressNotFound}).Once()
+
+ mockTxStore.On("InsertUnconfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ // Act: Persist transactions.
+ err = s.DBPutTxns(t.Context(), matches, nil)
+
+ // Assert: Verify that the unknown address entry was filtered.
+ require.NoError(t, err)
+ require.Empty(t, matches[0].Entries)
+}
+
+// TestDBPutRewind_Error verifies error propagation.
+func TestDBPutRewind_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where SetSyncedTo fails during
+ // DBPutRewind.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("SetSyncedTo",
+ mock.Anything, mock.Anything).Return(errSetSync).Once()
+
+ // Act: Perform DBPutRewind.
+ err := s.DBPutRewind(t.Context(), waddrmgr.BlockStamp{})
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errSetSync)
+}
diff --git a/wallet/deprecated.go b/wallet/deprecated.go
new file mode 100644
index 0000000000..99e13c1aea
--- /dev/null
+++ b/wallet/deprecated.go
@@ -0,0 +1,7349 @@
+//nolint:lll
+package wallet
+
+import (
+ "bytes"
+ "encoding/binary"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sort"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/blockchain"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcjson"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/psbt/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/internal/prompt"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wallet/txauthor"
+ "github.com/btcsuite/btcwallet/wallet/txrules"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/walletdb/migration"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/davecgh/go-spew/spew"
+ "github.com/lightningnetwork/lnd/fn/v2"
+)
+
+// NextAccount creates the next account and returns its account number. The
+// name must be unique to the account. In order to support automatic seed
+// restoring, new accounts may not be created when all of the previous 100
+// accounts have no transaction history (this is a deviation from the BIP0044
+// spec, which allows no unused account gaps).
+func (w *Wallet) NextAccount(scope waddrmgr.KeyScope, name string) (uint32, error) {
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return 0, err
+ }
+
+ // Validate that the scope manager can add this new account.
+ err = manager.CanAddAccount()
+ if err != nil {
+ return 0, err
+ }
+
+ var (
+ account uint32
+ props *waddrmgr.AccountProperties
+ )
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ var err error
+ account, err = manager.NewAccount(addrmgrNs, name)
+ if err != nil {
+ return err
+ }
+ props, err = manager.AccountProperties(addrmgrNs, account)
+
+ return err
+ })
+ if err != nil {
+ log.Errorf("Cannot fetch new account properties for notification "+
+ "after account creation: %v", err)
+ } else {
+ w.NtfnServer.notifyAccountProperties(props)
+ }
+
+ return account, err
+}
+
+// Accounts returns the current names, numbers, and total balances of all
+// accounts in the wallet restricted to a particular key scope. The current
+// chain tip is included in the result for atomicity reasons.
+//
+// TODO(jrick): Is the chain tip really needed, since only the total balances
+// are included?
+func (w *Wallet) Accounts(scope waddrmgr.KeyScope) (*AccountsResult, error) {
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ accounts []AccountResult
+ syncBlockHash *chainhash.Hash
+ syncBlockHeight int32
+ )
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ syncBlock := w.addrStore.SyncedTo()
+ syncBlockHash = &syncBlock.Hash
+ syncBlockHeight = syncBlock.Height
+ unspent, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+ err = manager.ForEachAccount(addrmgrNs, func(acct uint32) error {
+ props, err := manager.AccountProperties(addrmgrNs, acct)
+ if err != nil {
+ return err
+ }
+ accounts = append(accounts, AccountResult{
+ AccountProperties: *props,
+ // TotalBalance set below
+ })
+
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+ m := make(map[uint32]*btcutil.Amount)
+ for i := range accounts {
+ a := &accounts[i]
+ m[a.AccountNumber] = &a.TotalBalance
+ }
+ for i := range unspent {
+ output := unspent[i]
+ var outputAcct uint32
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript, w.chainParams)
+ if err == nil && len(addrs) > 0 {
+ _, outputAcct, err = w.addrStore.AddrAccount(addrmgrNs, addrs[0])
+ }
+ if err == nil {
+ amt, ok := m[outputAcct]
+ if ok {
+ *amt += output.Amount
+ }
+ }
+ }
+
+ return nil
+ })
+
+ return &AccountsResult{
+ Accounts: accounts,
+ CurrentBlockHash: *syncBlockHash,
+ CurrentBlockHeight: syncBlockHeight,
+ }, err
+}
+
+// RenameAccountDeprecated sets the name for an account number to newName.
+func (w *Wallet) RenameAccountDeprecated(scope waddrmgr.KeyScope,
+ account uint32, newName string) error {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return err
+ }
+
+ var props *waddrmgr.AccountProperties
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ err := manager.RenameAccount(addrmgrNs, account, newName)
+ if err != nil {
+ return err
+ }
+ props, err = manager.AccountProperties(addrmgrNs, account)
+
+ return err
+ })
+ if err == nil {
+ w.NtfnServer.notifyAccountProperties(props)
+ }
+
+ return err
+}
+
+// ScriptForOutputDeprecated returns the address, witness program and redeem
+// script for a given UTXO. An error is returned if the UTXO does not
+// belong to our wallet or it is not a managed pubKey address.
+//
+// Deprecated: Use AddressManager.ScriptForOutput instead.
+func (w *Wallet) ScriptForOutputDeprecated(output *wire.TxOut) (
+ waddrmgr.ManagedPubKeyAddress, []byte, []byte, error) {
+
+ // First make sure we can sign for the input by making sure the script
+ // in the UTXO belongs to our wallet and we have the private key for it.
+ walletAddr, err := w.fetchOutputAddr(output.PkScript)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ pubKeyAddr, ok := walletAddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil, nil, nil, fmt.Errorf("address %s is not a "+
+ "p2wkh or np2wkh address", walletAddr.Address())
+ }
+
+ var (
+ witnessProgram []byte
+ sigScript []byte
+ )
+
+ switch {
+ // If we're spending p2wkh output nested within a p2sh output, then
+ // we'll need to attach a sigScript in addition to witness data.
+ case walletAddr.AddrType() == waddrmgr.NestedWitnessPubKey:
+ pubKey := pubKeyAddr.PubKey()
+ pubKeyHash := address.Hash160(pubKey.SerializeCompressed())
+
+ // Next, we'll generate a valid sigScript that will allow us to
+ // spend the p2sh output. The sigScript will contain only a
+ // single push of the p2wkh witness program corresponding to
+ // the matching public key of this address.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ pubKeyHash, w.chainParams,
+ )
+ if err != nil {
+ return nil, nil, nil, err
+ }
+ witnessProgram, err = txscript.PayToAddrScript(p2wkhAddr)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ bldr := txscript.NewScriptBuilder()
+ bldr.AddData(witnessProgram)
+ sigScript, err = bldr.Script()
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ // Otherwise, this is a regular p2wkh or p2tr output, so we include the
+ // witness program itself as the subscript to generate the proper
+ // sighash digest. As part of the new sighash digest algorithm, the
+ // p2wkh witness program will be expanded into a regular p2kh
+ // script.
+ default:
+ witnessProgram = output.PkScript
+ }
+
+ return pubKeyAddr, witnessProgram, sigScript, nil
+}
+
+// ComputeInputScript generates a complete InputScript for the passed
+// transaction with the signature as defined within the passed
+// SignDescriptor. This method is capable of generating the proper input
+// script for both regular p2wkh output and p2wkh outputs nested within a
+// regular p2sh output.
+func (w *Wallet) ComputeInputScript(tx *wire.MsgTx, output *wire.TxOut,
+ inputIndex int, sigHashes *txscript.TxSigHashes,
+ hashType txscript.SigHashType, tweaker PrivKeyTweaker) (wire.TxWitness,
+ []byte, error) {
+
+ walletAddr, witnessProgram, sigScript, err :=
+ w.ScriptForOutputDeprecated(
+ output,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ privKey, err := walletAddr.PrivKey()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // If we need to maybe tweak our private key, do it now.
+ if tweaker != nil {
+ privKey, err = tweaker(privKey)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+
+ // We need to produce a Schnorr signature for p2tr key spend addresses.
+ if txscript.IsPayToTaproot(output.PkScript) {
+ // We can now generate a valid witness which will allow us to
+ // spend this output.
+ witnessScript, err := txscript.TaprootWitnessSignature(
+ tx, sigHashes, inputIndex, output.Value,
+ output.PkScript, hashType, privKey,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return witnessScript, nil, nil
+ }
+
+ // Generate a valid witness stack for the input.
+ witnessScript, err := txscript.WitnessSignature(
+ tx, sigHashes, inputIndex, output.Value, witnessProgram,
+ hashType, privKey, true,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return witnessScript, sigScript, nil
+}
+
+var (
+ // ErrNotMine is an error denoting that a Wallet instance is unable to
+ // spend a specified output.
+ ErrNotMine = errors.New("the passed output does not belong to the " +
+ "wallet")
+)
+
+// OutputSelectionPolicy describes the rules for selecting an output from the
+// wallet.
+type OutputSelectionPolicy struct {
+ Account uint32
+ RequiredConfirmations int32
+}
+
+func (p *OutputSelectionPolicy) meetsRequiredConfs(txHeight,
+ curHeight int32) bool {
+
+ return hasMinConfs(
+ //nolint:gosec
+ uint32(p.RequiredConfirmations), txHeight, curHeight,
+ )
+}
+
+// UnspentOutputs fetches all unspent outputs from the wallet that match rules
+// described in the passed policy.
+func (w *Wallet) UnspentOutputs(policy OutputSelectionPolicy) ([]*TransactionOutput, error) {
+ var outputResults []*TransactionOutput
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ syncBlock := w.addrStore.SyncedTo()
+
+ // TODO: actually stream outputs from the db instead of fetching
+ // all of them at once.
+ outputs, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+
+ for _, output := range outputs {
+ // Ignore outputs that haven't reached the required
+ // number of confirmations.
+ if !policy.meetsRequiredConfs(output.Height, syncBlock.Height) {
+ continue
+ }
+
+ // Ignore outputs that are not controlled by the account.
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript,
+ w.chainParams)
+ if err != nil || len(addrs) == 0 {
+ // Cannot determine which account this belongs
+ // to without a valid address. TODO: Fix this
+ // by saving outputs per account, or accounts
+ // per output.
+ continue
+ }
+
+ _, outputAcct, err := w.addrStore.AddrAccount(
+ addrmgrNs, addrs[0],
+ )
+ if err != nil {
+ return err
+ }
+ if outputAcct != policy.Account {
+ continue
+ }
+
+ // Stakebase isn't exposed by wtxmgr so those will be
+ // OutputKindNormal for now.
+ outputSource := OutputKindNormal
+ if output.FromCoinBase {
+ outputSource = OutputKindCoinbase
+ }
+
+ result := &TransactionOutput{
+ OutPoint: output.OutPoint,
+ Output: wire.TxOut{
+ Value: int64(output.Amount),
+ PkScript: output.PkScript,
+ },
+ OutputKind: outputSource,
+ ContainingBlock: BlockIdentity(output.Block),
+ ReceiveTime: output.Received,
+ }
+ outputResults = append(outputResults, result)
+ }
+
+ return nil
+ })
+ return outputResults, err
+}
+
+// FetchInputInfo queries for the wallet's knowledge of the passed outpoint. If
+// the wallet determines this output is under its control, then the original
+// full transaction, the target txout, the derivation info and the number of
+// confirmations are returned. Otherwise, a non-nil error value of ErrNotMine
+// is returned instead.
+//
+// NOTE: This method is kept for compatibility.
+func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
+ *wire.TxOut, *psbt.Bip32Derivation, int64, error) {
+
+ tx, txOut, confs, err := w.FetchOutpointInfo(prevOut)
+ if err != nil {
+ return nil, nil, nil, 0, err
+ }
+
+ derivation, err := w.FetchDerivationInfo(txOut.PkScript)
+ if err != nil {
+ return nil, nil, nil, 0, err
+ }
+
+ return tx, txOut, derivation, confs, nil
+}
+
+// fetchOutputAddr attempts to fetch the managed address corresponding to the
+// passed output script. This function is used to look up the proper key which
+// should be used to sign a specified input.
+func (w *Wallet) fetchOutputAddr(script []byte) (waddrmgr.ManagedAddress, error) {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(script, w.chainParams)
+ if err != nil {
+ return nil, err
+ }
+
+ // If the case of a multi-sig output, several address may be extracted.
+ // Therefore, we simply select the key for the first address we know
+ // of.
+ for _, addr := range addrs {
+ addr, err := w.AddressInfoDeprecated(addr)
+ if err == nil {
+ return addr, nil
+ }
+ }
+
+ return nil, ErrNotMine
+}
+
+// FetchOutpointInfo queries for the wallet's knowledge of the passed outpoint.
+// If the wallet determines this output is under its control, the original full
+// transaction, the target txout and the number of confirmations are returned.
+// Otherwise, a non-nil error value of ErrNotMine is returned instead.
+func (w *Wallet) FetchOutpointInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
+ *wire.TxOut, int64, error) {
+
+ // We manually look up the output within the tx store.
+ txid := &prevOut.Hash
+ txDetail, err := UnstableAPI(w).TxDetails(txid)
+ if err != nil {
+ return nil, nil, 0, err
+ } else if txDetail == nil {
+ return nil, nil, 0, ErrNotMine
+ }
+
+ // With the output retrieved, we'll make an additional check to ensure
+ // we actually have control of this output. We do this because the
+ // check above only guarantees that the transaction is somehow relevant
+ // to us, like in the event of us being the sender of the transaction.
+ numOutputs := uint32(len(txDetail.TxRecord.MsgTx.TxOut))
+ if prevOut.Index >= numOutputs {
+ return nil, nil, 0, fmt.Errorf("invalid output index %v for "+
+ "transaction with %v outputs", prevOut.Index,
+ numOutputs)
+ }
+
+ // Exit early if the output doesn't belong to our wallet. We know it's
+ // our UTXO iff the `TxDetails` has a credit record on this output.
+ if !hasOutput(txDetail, prevOut.Index) {
+ return nil, nil, 0, ErrNotMine
+ }
+
+ pkScript := txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].PkScript
+
+ // Determine the number of confirmations the output currently has.
+ _, currentHeight, err := w.chainClient.GetBestBlock()
+ if err != nil {
+ return nil, nil, 0, fmt.Errorf("unable to retrieve current "+
+ "height: %w", err)
+ }
+
+ confs := int64(0)
+ if txDetail.Block.Height != -1 {
+ confs = int64(currentHeight - txDetail.Block.Height)
+ }
+
+ return &txDetail.TxRecord.MsgTx, &wire.TxOut{
+ Value: txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].Value,
+ PkScript: pkScript,
+ }, confs, nil
+}
+
+// FetchDerivationInfo queries for the wallet's knowledge of the passed
+// pkScript and constructs the derivation info and returns it.
+func (w *Wallet) FetchDerivationInfo(pkScript []byte) (*psbt.Bip32Derivation,
+ error) {
+
+ addr, err := w.fetchOutputAddr(pkScript)
+ if err != nil {
+ return nil, err
+ }
+
+ pubKeyAddr, ok := addr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil, ErrNotMine
+ }
+ keyScope, derivationPath, _ := pubKeyAddr.DerivationInfo()
+
+ derivation := &psbt.Bip32Derivation{
+ PubKey: pubKeyAddr.PubKey().SerializeCompressed(),
+ MasterKeyFingerprint: derivationPath.MasterKeyFingerprint,
+ Bip32Path: []uint32{
+ keyScope.Purpose + hdkeychain.HardenedKeyStart,
+ keyScope.Coin + hdkeychain.HardenedKeyStart,
+ derivationPath.Account,
+ derivationPath.Branch,
+ derivationPath.Index,
+ },
+ }
+
+ return derivation, nil
+}
+
+// hasOutpoint takes an output identified by its output index and determines
+// whether the TxDetails contains this output. If the TxDetails doesn't have
+// this output, it means this output doesn't belong to our wallet.
+//
+// TODO(yy): implement this method on `TxDetails` and update the package
+// `wtxmgr` instead.
+func hasOutput(t *wtxmgr.TxDetails, outputIndex uint32) bool {
+ for _, cred := range t.Credits {
+ if outputIndex == cred.Index {
+ return true
+ }
+ }
+
+ return false
+}
+
+// CreateSimpleTx creates a new signed transaction spending unspent outputs with
+// at least minconf confirmations spending to any number of address/amount
+// pairs. Only unspent outputs belonging to the given key scope and account will
+// be selected, unless a key scope is not specified. In that case, inputs from all
+// accounts may be selected, no matter what key scope they belong to. This is
+// done to handle the default account case, where a user wants to fund a PSBT
+// with inputs regardless of their type (NP2WKH, P2WKH, etc.). Change and an
+// appropriate transaction fee are automatically included, if necessary. All
+// transaction creation through this function is serialized to prevent the
+// creation of many transactions which spend the same outputs.
+//
+// A set of functional options can be passed in to apply modifications to the
+// tx creation process such as using a custom change scope, which otherwise
+// defaults to the same as the specified coin selection scope.
+//
+// NOTE: The dryRun argument can be set true to create a tx that doesn't alter
+// the database. A tx created with this set to true SHOULD NOT be broadcast.
+func (w *Wallet) CreateSimpleTx(coinSelectKeyScope *waddrmgr.KeyScope,
+ account uint32, outputs []*wire.TxOut, minconf int32,
+ satPerKb btcutil.Amount, coinSelectionStrategy CoinSelectionStrategy,
+ dryRun bool, optFuncs ...TxCreateOption) (*txauthor.AuthoredTx, error) {
+
+ opts := defaultTxCreateOptions()
+ for _, optFunc := range optFuncs {
+ optFunc(opts)
+ }
+
+ // If the change scope isn't set, then it should be the same as the
+ // coin selection scope in order to match existing behavior.
+ if opts.changeKeyScope == nil {
+ opts.changeKeyScope = coinSelectKeyScope
+ }
+
+ req := createTxRequest{
+ coinSelectKeyScope: coinSelectKeyScope,
+ changeKeyScope: opts.changeKeyScope,
+ account: account,
+ outputs: outputs,
+ minconf: minconf,
+ feeSatPerKB: satPerKb,
+ coinSelectionStrategy: coinSelectionStrategy,
+ dryRun: dryRun,
+ resp: make(chan createTxResponse),
+ selectUtxos: opts.selectUtxos,
+ allowUtxo: opts.allowUtxo,
+ }
+ w.createTxRequests <- req
+ resp := <-req.resp
+ return resp.tx, resp.err
+}
+
+// FundPsbtDeprecated creates a fully populated PSBT packet that contains
+// enough inputs to fund the outputs specified in the passed in packet with the
+// specified fee rate. If there is change left, a change output from the wallet
+// is added and the index of the change output is returned. If no custom change
+// scope is specified, we will use the coin selection scope (if not nil) or the
+// BIP0086 scope by default. Otherwise, no additional output is created and the
+// index -1 is returned.
+//
+// NOTE: If the packet doesn't contain any inputs, coin selection is performed
+// automatically, only selecting inputs from the account based on the given key
+// scope and account number. If a key scope is not specified, then inputs from
+// accounts matching the account number provided across all key scopes may be
+// selected. This is done to handle the default account case, where a user wants
+// to fund a PSBT with inputs regardless of their type (NP2WKH, P2WKH, etc.). If
+// the packet does contain any inputs, it is assumed that full coin selection
+// happened externally and no additional inputs are added. If the specified
+// inputs aren't enough to fund the outputs with the given fee rate, an error is
+// returned.
+//
+// NOTE: A caller of the method should hold the global coin selection lock of
+// the wallet. However, no UTXO specific lock lease is acquired for any of the
+// selected/validated inputs by this method. It is in the caller's
+// responsibility to lock the inputs before handing the partial transaction out.
+func (w *Wallet) FundPsbtDeprecated(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
+ minConfs int32, account uint32, feeSatPerKB btcutil.Amount,
+ coinSelectionStrategy CoinSelectionStrategy,
+ optFuncs ...TxCreateOption) (int32, error) {
+
+ // Make sure the packet is well formed. We only require there to be at
+ // least one input or output.
+ err := psbt.VerifyInputOutputLen(packet, false, false)
+ if err != nil {
+ return 0, err
+ }
+
+ if len(packet.UnsignedTx.TxIn) == 0 && len(packet.UnsignedTx.TxOut) == 0 {
+ return 0, fmt.Errorf("PSBT packet must contain at least one " +
+ "input or output")
+ }
+
+ txOut := packet.UnsignedTx.TxOut
+ txIn := packet.UnsignedTx.TxIn
+
+ // Make sure none of the outputs are dust.
+ for _, output := range txOut {
+ // When checking an output for things like dusty-ness, we'll
+ // use the default mempool relay fee rather than the target
+ // effective fee rate to ensure accuracy. Otherwise, we may
+ // mistakenly mark small-ish, but not quite dust output as
+ // dust.
+ err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Let's find out the amount to fund first.
+ amt := int64(0)
+ for _, output := range txOut {
+ amt += output.Value
+ }
+
+ var tx *txauthor.AuthoredTx
+ switch {
+ // We need to do coin selection.
+ case len(txIn) == 0:
+ // We ask the underlying wallet to fund a TX for us. This
+ // includes everything we need, specifically fee estimation and
+ // change address creation.
+ tx, err = w.CreateSimpleTx(
+ keyScope, account, packet.UnsignedTx.TxOut, minConfs,
+ feeSatPerKB, coinSelectionStrategy, false,
+ optFuncs...,
+ )
+ if err != nil {
+ return 0, fmt.Errorf("error creating funding TX: %w",
+ err)
+ }
+
+ // Copy over the inputs now then collect all UTXO information
+ // that we can and attach them to the PSBT as well. We don't
+ // include the witness as the resulting PSBT isn't expected not
+ // should be signed yet.
+ packet.UnsignedTx.TxIn = tx.Tx.TxIn
+ packet.Inputs = make([]psbt.PInput, len(packet.UnsignedTx.TxIn))
+
+ for idx := range packet.UnsignedTx.TxIn {
+ // We don't want to include the witness or any script
+ // on the unsigned TX just yet.
+ packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
+ packet.UnsignedTx.TxIn[idx].SignatureScript = nil
+ }
+
+ err := w.DecorateInputsDeprecated(packet, true)
+ if err != nil {
+ return 0, err
+ }
+
+ // If there are inputs, we need to check if they're sufficient and add
+ // a change output if necessary.
+ default:
+ // Make sure all inputs provided are actually ours.
+ packet.Inputs = make([]psbt.PInput, len(packet.UnsignedTx.TxIn))
+
+ for idx := range packet.UnsignedTx.TxIn {
+ // We don't want to include the witness or any script
+ // on the unsigned TX just yet.
+ packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
+ packet.UnsignedTx.TxIn[idx].SignatureScript = nil
+ }
+
+ err := w.DecorateInputsDeprecated(packet, true)
+ if err != nil {
+ return 0, err
+ }
+
+ // We can leverage the fee calculation of the txauthor package
+ // if we provide the selected UTXOs as a coin source. We just
+ // need to make sure we always return the full list of user-
+ // selected UTXOs rather than a subset, otherwise our change
+ // amount will be off (in case the user selected multiple UTXOs
+ // that are large enough on their own). That's why we use our
+ // own static input source creator instead of the more generic
+ // makeInputSource() that selects a subset that is "large
+ // enough".
+ credits := make([]wtxmgr.Credit, len(txIn))
+ for idx, in := range txIn {
+ utxo := packet.Inputs[idx].WitnessUtxo
+ credits[idx] = wtxmgr.Credit{
+ OutPoint: in.PreviousOutPoint,
+ Amount: btcutil.Amount(utxo.Value),
+ PkScript: utxo.PkScript,
+ }
+ }
+ inputSource := constantInputSource(credits)
+
+ // Build the TxCreateOption to retrieve the change scope.
+ opts := defaultTxCreateOptions()
+ for _, optFunc := range optFuncs {
+ optFunc(opts)
+ }
+
+ if opts.changeKeyScope == nil {
+ opts.changeKeyScope = keyScope
+ }
+
+ // The addrMgrWithChangeSource function of the wallet creates a
+ // new change address. The address manager uses OnCommit on the
+ // walletdb tx to update the in-memory state of the account
+ // state. But because the commit happens _after_ the account
+ // manager internal lock has been released, there is a chance
+ // for the address index to be accessed concurrently, even
+ // though the closure in OnCommit re-acquires the lock. To avoid
+ // this issue, we surround the whole address creation process
+ // with a lock.
+ w.newAddrMtx.Lock()
+
+ // We also need a change source which needs to be able to insert
+ // a new change address into the database.
+ err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
+ _, changeSource, err := w.addrMgrWithChangeSource(
+ dbtx, opts.changeKeyScope, account,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Ask the txauthor to create a transaction with our
+ // selected coins. This will perform fee estimation and
+ // add a change output if necessary.
+ tx, err = txauthor.NewUnsignedTransaction(
+ txOut, feeSatPerKB, inputSource, changeSource,
+ )
+ if err != nil {
+ return fmt.Errorf("fee estimation not "+
+ "successful: %w", err)
+ }
+
+ return nil
+ })
+ w.newAddrMtx.Unlock()
+
+ if err != nil {
+ return 0, fmt.Errorf("could not add change address to "+
+ "database: %w", err)
+ }
+ }
+
+ // If there is a change output, we need to copy it over to the PSBT now.
+ var changeTxOut *wire.TxOut
+ if tx.ChangeIndex >= 0 {
+ changeTxOut = tx.Tx.TxOut[tx.ChangeIndex]
+ packet.UnsignedTx.TxOut = append(
+ packet.UnsignedTx.TxOut, changeTxOut,
+ )
+
+ addr, _, _, err := w.ScriptForOutputDeprecated(changeTxOut)
+ if err != nil {
+ return 0, fmt.Errorf("error querying wallet for "+
+ "change addr: %w", err)
+ }
+
+ changeOutputInfo, err := createOutputInfo(changeTxOut, addr)
+ if err != nil {
+ return 0, fmt.Errorf("error adding output info to "+
+ "change output: %w", err)
+ }
+
+ packet.Outputs = append(packet.Outputs, *changeOutputInfo)
+ }
+
+ // Now that we have the final PSBT ready, we can sort it according to
+ // BIP 69. This will sort the wire inputs and outputs and move the
+ // partial inputs and outputs accordingly.
+ err = psbt.InPlaceSort(packet)
+ if err != nil {
+ return 0, fmt.Errorf("could not sort PSBT: %w", err)
+ }
+
+ // The change output index might have changed after the sorting. We need
+ // to find our index again.
+ changeIndex := int32(-1)
+ if changeTxOut != nil {
+ for idx, txOut := range packet.UnsignedTx.TxOut {
+ if psbt.TxOutsEqual(changeTxOut, txOut) {
+ changeIndex = int32(idx)
+ break
+ }
+ }
+ }
+
+ return changeIndex, nil
+}
+
+// DecorateInputsDeprecated fetches the UTXO information of all inputs it can identify and
+// adds the required information to the package's inputs. The failOnUnknown
+// boolean controls whether the method should return an error if it cannot
+// identify an input or if it should just skip it.
+func (w *Wallet) DecorateInputsDeprecated(packet *psbt.Packet, failOnUnknown bool) error {
+ for idx := range packet.Inputs {
+ txIn := packet.UnsignedTx.TxIn[idx]
+
+ tx, utxo, derivationPath, _, err := w.FetchInputInfo(
+ &txIn.PreviousOutPoint,
+ )
+
+ switch {
+ // If the error just means it's not an input our wallet controls
+ // and the user doesn't care about that, then we can just skip
+ // this input and continue.
+ case errors.Is(err, ErrNotMine) && !failOnUnknown:
+ continue
+
+ case err != nil:
+ return fmt.Errorf("error fetching UTXO: %w", err)
+ }
+
+ addr, witnessProgram, _, err := w.ScriptForOutputDeprecated(
+ utxo,
+ )
+ if err != nil {
+ return fmt.Errorf("error fetching UTXO script: %w", err)
+ }
+
+ switch {
+ case txscript.IsPayToTaproot(utxo.PkScript):
+ addInputInfoSegWitV1(
+ &packet.Inputs[idx], utxo, derivationPath,
+ )
+
+ default:
+ addInputInfoSegWitV0(
+ &packet.Inputs[idx], tx, utxo, derivationPath,
+ addr, witnessProgram,
+ )
+ }
+ }
+
+ return nil
+}
+
+// FinalizePsbtDeprecated expects a partial transaction with all inputs and outputs fully
+// declared and tries to sign all inputs that belong to the wallet. Our wallet
+// must be the last signer of the transaction. That means, if there are any
+// unsigned non-witness inputs or inputs without UTXO information attached or
+// inputs without witness data that do not belong to the wallet, this method
+// will fail. If no error is returned, the PSBT is ready to be extracted and the
+// final TX within to be broadcast.
+//
+// NOTE: This method does NOT publish the transaction after it's been finalized
+// successfully.
+func (w *Wallet) FinalizePsbtDeprecated(keyScope *waddrmgr.KeyScope, account uint32,
+ packet *psbt.Packet) error {
+
+ // Let's check that this is actually something we can and want to sign.
+ // We need at least one input and one output. In addition each
+ // input needs nonWitness Utxo or witness Utxo data specified.
+ err := psbt.InputsReadyToSign(packet)
+ if err != nil {
+ return err
+ }
+
+ // Go through each input that doesn't have final witness data attached
+ // to it already and try to sign it. We do expect that we're the last
+ // ones to sign. If there is any input without witness data that we
+ // cannot sign because it's not our UTXO, this will be a hard failure.
+ tx := packet.UnsignedTx
+ fetcher, err := PsbtPrevOutputFetcher(packet)
+ if err != nil {
+ return err
+ }
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+ for idx, txIn := range tx.TxIn {
+ in := packet.Inputs[idx]
+
+ // We can only sign if we have UTXO information available. We
+ // can just continue here as a later step will fail with a more
+ // precise error message.
+ if in.WitnessUtxo == nil && in.NonWitnessUtxo == nil {
+ continue
+ }
+
+ // Skip this input if it's got final witness data attached.
+ if len(in.FinalScriptWitness) > 0 {
+ continue
+ }
+
+ // We can only sign this input if it's ours, so we try to map it
+ // to a coin we own. If we can't, then we'll continue as it
+ // isn't our input.
+ fullTx, txOut, _, _, err := w.FetchInputInfo(
+ &txIn.PreviousOutPoint,
+ )
+ if err != nil {
+ continue
+ }
+
+ // Find out what UTXO we are signing. Wallets _should_ always
+ // provide the full non-witness UTXO for segwit v0.
+ var signOutput *wire.TxOut
+ if in.NonWitnessUtxo != nil {
+ prevIndex := txIn.PreviousOutPoint.Index
+ signOutput = in.NonWitnessUtxo.TxOut[prevIndex]
+
+ if !psbt.TxOutsEqual(txOut, signOutput) {
+ return fmt.Errorf("found UTXO %#v but it "+
+ "doesn't match PSBT's input %v", txOut,
+ signOutput)
+ }
+
+ if fullTx.TxHash() != txIn.PreviousOutPoint.Hash {
+ return fmt.Errorf("found UTXO tx %v but it "+
+ "doesn't match PSBT's input %v",
+ fullTx.TxHash(),
+ txIn.PreviousOutPoint.Hash)
+ }
+ }
+
+ // Fall back to witness UTXO only for older wallets.
+ if in.WitnessUtxo != nil {
+ signOutput = in.WitnessUtxo
+
+ if !psbt.TxOutsEqual(txOut, signOutput) {
+ return fmt.Errorf("found UTXO %#v but it "+
+ "doesn't match PSBT's input %v", txOut,
+ signOutput)
+ }
+ }
+
+ // Finally, if the input doesn't belong to a watch-only account,
+ // then we'll sign it as is, and populate the input with the
+ // witness and sigScript (if needed).
+ watchOnly := false
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+ if keyScope == nil {
+ // If a key scope wasn't specified, then coin
+ // selection was performed from the default
+ // wallet accounts (NP2WKH, P2WKH, P2TR), so any
+ // key scope provided doesn't impact the result
+ // of this call.
+ watchOnly, err = w.addrStore.IsWatchOnlyAccount(
+ ns, waddrmgr.KeyScopeBIP0084, account,
+ )
+ } else {
+ watchOnly, err = w.addrStore.IsWatchOnlyAccount(
+ ns, *keyScope, account,
+ )
+ }
+ return err
+ })
+ if err != nil {
+ return fmt.Errorf("unable to determine if account is "+
+ "watch-only: %w", err)
+ }
+ if watchOnly {
+ continue
+ }
+
+ witness, sigScript, err := w.ComputeInputScript(
+ tx, signOutput, idx, sigHashes, in.SighashType, nil,
+ )
+ if err != nil {
+ return fmt.Errorf("error computing input script for "+
+ "input %d: %w", idx, err)
+ }
+
+ // Serialize the witness format from the stack representation to
+ // the wire representation.
+ var witnessBytes bytes.Buffer
+ err = psbt.WriteTxWitness(&witnessBytes, witness)
+ if err != nil {
+ return fmt.Errorf("error serializing witness: %w", err)
+ }
+ packet.Inputs[idx].FinalScriptWitness = witnessBytes.Bytes()
+ packet.Inputs[idx].FinalScriptSig = sigScript
+ }
+
+ // Make sure the PSBT itself thinks it's finalized and ready to be
+ // broadcast.
+ err = psbt.MaybeFinalizeAll(packet)
+ if err != nil {
+ return fmt.Errorf("error finalizing PSBT: %w", err)
+ }
+
+ return nil
+}
+
+// StartDeprecated starts the goroutines necessary to manage a wallet.
+//
+// Deprecated: Use WalletController.Start instead.
+func (w *Wallet) StartDeprecated() {
+ w.quitMu.Lock()
+ select {
+ case <-w.quit:
+ // Restart the wallet goroutines after shutdown finishes.
+ w.WaitForShutdown()
+ w.quit = make(chan struct{})
+ default:
+ // Ignore when the wallet is still running.
+ if w.started {
+ w.quitMu.Unlock()
+ return
+ }
+ w.started = true
+ }
+ w.quitMu.Unlock()
+
+ w.wg.Add(2)
+ go w.txCreator()
+ go w.walletLocker()
+}
+
+// StopDeprecated signals all wallet goroutines to shutdown.
+//
+// Deprecated: Use WalletController.Stop instead.
+func (w *Wallet) StopDeprecated() {
+ <-w.endRecovery()
+
+ w.quitMu.Lock()
+ quit := w.quit
+ w.quitMu.Unlock()
+
+ select {
+ case <-quit:
+ default:
+ close(quit)
+ w.chainClientLock.Lock()
+ if w.chainClient != nil {
+ w.chainClient.Stop()
+ w.chainClient = nil
+ }
+ w.chainClientLock.Unlock()
+ }
+}
+
+// UnlockDeprecated unlocks the wallet's address manager and relocks it after timeout has
+// expired. If the wallet is already unlocked and the new passphrase is
+// correct, the current timeout is replaced with the new one. The wallet will
+// be locked if the passphrase is incorrect or any other error occurs during the
+// unlock.
+//
+// Deprecated: Use WalletController.Unlock instead.
+func (w *Wallet) UnlockDeprecated(passphrase []byte, lock <-chan time.Time) error {
+ err := make(chan error, 1)
+ w.unlockRequests <- unlockRequest{
+ passphrase: passphrase,
+ lockAfter: lock,
+ err: err,
+ }
+ return <-err
+}
+
+// LockDeprecated locks the wallet's address manager.
+//
+// Deprecated: Use WalletController.Lock instead.
+func (w *Wallet) LockDeprecated() {
+ w.lockRequests <- struct{}{}
+}
+
+// AddressInfoDeprecated returns detailed information regarding a wallet
+// address.
+func (w *Wallet) AddressInfoDeprecated(a address.Address) (
+ waddrmgr.ManagedAddress, error) {
+
+ var managedAddress waddrmgr.ManagedAddress
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+
+ managedAddress, err = w.addrStore.Address(addrmgrNs, a)
+ return err
+ })
+ return managedAddress, err
+}
+
+// SynchronizeRPC associates the wallet with the consensus RPC client,
+// synchronizes the wallet with the latest changes to the blockchain, and
+// continuously updates the wallet through RPC notifications.
+//
+// This method is unstable and will be removed when all syncing logic is moved
+// outside of the wallet package.
+func (w *Wallet) SynchronizeRPC(chainClient chain.Interface) {
+ w.quitMu.Lock()
+ select {
+ case <-w.quit:
+ w.quitMu.Unlock()
+ return
+ default:
+ }
+ w.quitMu.Unlock()
+
+ // TODO: Ignoring the new client when one is already set breaks callers
+ // who are replacing the client, perhaps after a disconnect.
+ w.chainClientLock.Lock()
+ if w.chainClient != nil {
+ w.chainClientLock.Unlock()
+ return
+ }
+ w.chainClient = chainClient
+
+ // If the chain client is a NeutrinoClient instance, set a birthday so
+ // we don't download all the filters as we go.
+ switch cc := chainClient.(type) {
+ case *chain.NeutrinoClient:
+ cc.SetStartTime(w.addrStore.Birthday())
+ case *chain.BitcoindClient:
+ cc.SetBirthday(w.addrStore.Birthday())
+ }
+ w.chainClientLock.Unlock()
+
+ // TODO: It would be preferable to either run these goroutines
+ // separately from the wallet (use wallet mutator functions to
+ // make changes from the RPC client) and not have to stop and
+ // restart them each time the client disconnects and reconnets.
+ w.wg.Add(4)
+ go w.handleChainNotifications()
+ go w.rescanBatchHandler()
+ go w.rescanProgressHandler()
+ go w.rescanRPCHandler()
+}
+
+// requireChainClient marks that a wallet method can only be completed when the
+// consensus RPC server is set. This function and all functions that call it
+// are unstable and will need to be moved when the syncing code is moved out of
+// the wallet.
+func (w *Wallet) requireChainClient() (chain.Interface, error) {
+ w.chainClientLock.Lock()
+ chainClient := w.chainClient
+ w.chainClientLock.Unlock()
+ if chainClient == nil {
+ return nil, errors.New("blockchain RPC is inactive")
+ }
+ return chainClient, nil
+}
+
+// ChainClient returns the optional consensus RPC client associated with the
+// wallet.
+//
+// This function is unstable and will be removed once sync logic is moved out of
+// the wallet.
+func (w *Wallet) ChainClient() chain.Interface {
+ w.chainClientLock.Lock()
+ chainClient := w.chainClient
+ w.chainClientLock.Unlock()
+ return chainClient
+}
+
+// quitChan atomically reads the quit channel.
+func (w *Wallet) quitChan() <-chan struct{} {
+ w.quitMu.Lock()
+ c := w.quit
+ w.quitMu.Unlock()
+ return c
+}
+
+// ShuttingDown returns whether the wallet is currently in the process of
+// shutting down or not.
+func (w *Wallet) ShuttingDown() bool {
+ select {
+ case <-w.quitChan():
+ return true
+ default:
+ return false
+ }
+}
+
+// WaitForShutdown blocks until all wallet goroutines have finished executing.
+func (w *Wallet) WaitForShutdown() {
+ w.chainClientLock.Lock()
+ if w.chainClient != nil {
+ w.chainClient.WaitForShutdown()
+ }
+ w.chainClientLock.Unlock()
+ w.wg.Wait()
+}
+
+// SynchronizingToNetwork returns whether the wallet is currently synchronizing
+// with the Bitcoin network.
+func (w *Wallet) SynchronizingToNetwork() bool {
+ // At the moment, RPC is the only synchronization method. In the
+ // future, when SPV is added, a separate check will also be needed, or
+ // SPV could always be enabled if RPC was not explicitly specified when
+ // creating the wallet.
+ w.chainClientSyncMtx.Lock()
+ syncing := w.chainClient != nil
+ w.chainClientSyncMtx.Unlock()
+ return syncing
+}
+
+// ChainSynced returns whether the wallet has been attached to a chain server
+// and synced up to the best block on the main chain.
+func (w *Wallet) ChainSynced() bool {
+ w.chainClientSyncMtx.Lock()
+ synced := w.chainClientSynced
+ w.chainClientSyncMtx.Unlock()
+ return synced
+}
+
+// SetChainSynced marks whether the wallet is connected to and currently in sync
+// with the latest block notified by the chain server.
+//
+// NOTE: Due to an API limitation with rpcclient, this may return true after
+// the client disconnected (and is attempting a reconnect). This will be unknown
+// until the reconnect notification is received, at which point the wallet can be
+// marked out of sync again until after the next rescan completes.
+func (w *Wallet) SetChainSynced(synced bool) {
+ w.chainClientSyncMtx.Lock()
+ w.chainClientSynced = synced
+ w.chainClientSyncMtx.Unlock()
+}
+
+// activeData returns the currently-active receiving addresses and all unspent
+// outputs. This is primarely intended to provide the parameters for a
+// rescan request.
+func (w *Wallet) activeData(dbtx walletdb.ReadWriteTx) ([]address.Address, []wtxmgr.Credit, error) {
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ var addrs []address.Address
+
+ err := w.addrStore.ForEachRelevantActiveAddress(
+ addrmgrNs, func(addr address.Address) error {
+ addrs = append(addrs, addr)
+ return nil
+ },
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Before requesting the list of spendable UTXOs, we'll delete any
+ // expired output locks.
+ err = w.txStore.DeleteExpiredLockedOutputs(
+ dbtx.ReadWriteBucket(wtxmgrNamespaceKey),
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+
+ unspent, err := w.txStore.OutputsToWatch(txmgrNs)
+ return addrs, unspent, err
+}
+
+// syncWithChain brings the wallet up to date with the current chain server
+// connection. It creates a rescan request and blocks until the rescan has
+// finished. The birthday block can be passed in, if set, to ensure we can
+// properly detect if it gets rolled back.
+func (w *Wallet) syncWithChain(birthdayStamp *waddrmgr.BlockStamp) error {
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return err
+ }
+
+ // Neutrino relies on the information given to it by the cfheader server
+ // so it knows exactly whether it's synced up to the server's state or
+ // not, even on dev chains. To recover a Neutrino wallet, we need to
+ // make sure it's synced before we start scanning for addresses,
+ // otherwise we might miss some if we only scan up to its current sync
+ // point.
+ neutrinoRecovery := chainClient.BackEnd() == "neutrino" &&
+ w.recoveryWindow > 0
+
+ // We'll wait until the backend is synced to ensure we get the latest
+ // MaxReorgDepth blocks to store. We don't do this for development
+ // environments as we can't guarantee a lively chain, except for
+ // Neutrino, where the cfheader server tells us what it believes the
+ // chain tip is.
+ if !w.isDevEnv() || neutrinoRecovery {
+ log.Debug("Waiting for chain backend to sync to tip")
+ if err := w.waitUntilBackendSynced(chainClient); err != nil {
+ return err
+ }
+ log.Debug("Chain backend synced to tip!")
+ }
+
+ // If we've yet to find our birthday block, we'll do so now.
+ if birthdayStamp == nil {
+ var err error
+ birthdayStamp, err = locateBirthdayBlock(
+ chainClient, w.addrStore.Birthday(),
+ )
+ if err != nil {
+ return fmt.Errorf("unable to locate birthday block: %w",
+ err)
+ }
+
+ // We'll also determine our initial sync starting height. This
+ // is needed as the wallet can now begin storing blocks from an
+ // arbitrary height, rather than all the blocks from genesis, so
+ // we persist this height to ensure we don't store any blocks
+ // before it.
+ startHeight := birthdayStamp.Height
+
+ // With the starting height obtained, get the remaining block
+ // details required by the wallet.
+ startHash, err := chainClient.GetBlockHash(int64(startHeight))
+ if err != nil {
+ return err
+ }
+ startHeader, err := chainClient.GetBlockHeader(startHash)
+ if err != nil {
+ return err
+ }
+
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ err := w.addrStore.SetSyncedTo(ns, &waddrmgr.BlockStamp{
+ Hash: *startHash,
+ Height: startHeight,
+ Timestamp: startHeader.Timestamp,
+ })
+ if err != nil {
+ return err
+ }
+
+ return w.addrStore.SetBirthdayBlock(
+ ns, *birthdayStamp, true,
+ )
+ })
+ if err != nil {
+ return fmt.Errorf("unable to persist initial sync "+
+ "data: %w", err)
+ }
+ }
+
+ // If the wallet requested an on-chain recovery of its funds, we'll do
+ // so now.
+ if w.recoveryWindow > 0 {
+ if err := w.recovery(chainClient, birthdayStamp); err != nil {
+ return fmt.Errorf("unable to perform wallet recovery: "+
+ "%w", err)
+ }
+ }
+
+ // Compare previously-seen blocks against the current chain. If any of
+ // these blocks no longer exist, rollback all of the missing blocks
+ // before catching up with the rescan.
+ rollback := false
+ rollbackStamp := w.addrStore.SyncedTo()
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ for height := rollbackStamp.Height; true; height-- {
+ hash, err := w.addrStore.BlockHash(addrmgrNs, height)
+ if err != nil {
+ return err
+ }
+ chainHash, err := chainClient.GetBlockHash(int64(height))
+ if err != nil {
+ return err
+ }
+ header, err := chainClient.GetBlockHeader(chainHash)
+ if err != nil {
+ return err
+ }
+
+ rollbackStamp.Hash = *chainHash
+ rollbackStamp.Height = height
+ rollbackStamp.Timestamp = header.Timestamp
+
+ if bytes.Equal(hash[:], chainHash[:]) {
+ break
+ }
+ rollback = true
+ }
+
+ // If a rollback did not happen, we can proceed safely.
+ if !rollback {
+ return nil
+ }
+
+ // Otherwise, we'll mark this as our new synced height.
+ err := w.addrStore.SetSyncedTo(addrmgrNs, &rollbackStamp)
+ if err != nil {
+ return err
+ }
+
+ // If the rollback happened to go beyond our birthday stamp,
+ // we'll need to find a new one by syncing with the chain again
+ // until finding one.
+ if rollbackStamp.Height <= birthdayStamp.Height &&
+ rollbackStamp.Hash != birthdayStamp.Hash {
+
+ err := w.addrStore.SetBirthdayBlock(
+ addrmgrNs, rollbackStamp, true,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ // Finally, we'll roll back our transaction store to reflect the
+ // stale state. `Rollback` unconfirms transactions at and beyond
+ // the passed height, so add one to the new synced-to height to
+ // prevent unconfirming transactions in the synced-to block.
+ return w.txStore.Rollback(txmgrNs, rollbackStamp.Height+1)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Request notifications for connected and disconnected blocks.
+ //
+ // TODO(jrick): Either request this notification only once, or when
+ // rpcclient is modified to allow some notification request to not
+ // automatically resent on reconnect, include the notifyblocks request
+ // as well. I am leaning towards allowing off all rpcclient
+ // notification re-registrations, in which case the code here should be
+ // left as is.
+ if err := chainClient.NotifyBlocks(); err != nil {
+ return err
+ }
+
+ // Finally, we'll trigger a wallet rescan and request notifications for
+ // transactions sending to all wallet addresses and spending all wallet
+ // UTXOs.
+ var (
+ addrs []address.Address
+ unspent []wtxmgr.Credit
+ )
+ err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
+ addrs, unspent, err = w.activeData(dbtx)
+ return err
+ })
+ if err != nil {
+ return err
+ }
+
+ return w.rescanWithTarget(addrs, unspent, nil)
+}
+
+// isDevEnv determines whether the wallet is currently under a local developer
+// environment, e.g. simnet or regtest.
+func (w *Wallet) isDevEnv() bool {
+ switch uint32(w.ChainParams().Net) {
+ case uint32(chaincfg.RegressionNetParams.Net):
+ case uint32(chaincfg.SimNetParams.Net):
+ default:
+ return false
+ }
+ return true
+}
+
+// waitUntilBackendSynced blocks until the chain backend considers itself
+// "current".
+func (w *Wallet) waitUntilBackendSynced(chainClient chain.Interface) error {
+ // We'll poll every second to determine if our chain considers itself
+ // "current".
+ t := time.NewTicker(time.Second)
+ defer t.Stop()
+
+ for {
+ select {
+ case <-t.C:
+ if chainClient.IsCurrent() {
+ return nil
+ }
+ case <-w.quitChan():
+ return ErrWalletShuttingDown
+ }
+ }
+}
+
+// recoverySyncer is used to synchronize wallet and address manager locking
+// with the end of recovery. (*Wallet).recovery will store a recoverySyncer
+// when invoked, and will close the done chan upon exit. Setting the quit flag
+// will cause recovery to end after the current batch of blocks.
+type recoverySyncer struct {
+ done chan struct{}
+ quit uint32 // atomic
+}
+
+// recovery attempts to recover any unspent outputs that pay to any of our
+// addresses starting from our birthday, or the wallet's tip (if higher), which
+// would indicate resuming a recovery after a restart.
+func (w *Wallet) recovery(chainClient chain.Interface,
+ birthdayBlock *waddrmgr.BlockStamp) error {
+
+ log.Infof("RECOVERY MODE ENABLED -- rescanning for used addresses "+
+ "with recovery_window=%d", w.recoveryWindow)
+
+ // Wallet locking must synchronize with the end of recovery, since use of
+ // keys in recovery is racy with manager IsLocked checks, which could
+ // result in enrypting data with a zeroed key.
+ syncer := &recoverySyncer{done: make(chan struct{})}
+ w.recovering.Store(syncer)
+ defer close(syncer.done)
+
+ // We'll initialize the recovery manager with a default batch size of
+ // 2000.
+ recoveryMgr := NewRecoveryManager(
+ w.recoveryWindow, recoveryBatchSize, w.chainParams,
+ )
+
+ // In the event that this recovery is being resumed, we will need to
+ // repopulate all found addresses from the database. Ideally, for basic
+ // recovery, we would only do so for the default scopes, but due to a
+ // bug in which the wallet would create change addresses outside of the
+ // default scopes, it's necessary to attempt all registered key scopes.
+ scopedMgrs := make(map[waddrmgr.KeyScope]waddrmgr.AccountStore)
+ for _, scopedMgr := range w.addrStore.ActiveScopedKeyManagers() {
+ scopedMgrs[scopedMgr.Scope()] = scopedMgr
+ }
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txMgrNS := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ credits, err := w.txStore.UnspentOutputs(txMgrNS)
+ if err != nil {
+ return err
+ }
+ addrMgrNS := tx.ReadBucket(waddrmgrNamespaceKey)
+ return recoveryMgr.Resurrect(addrMgrNS, scopedMgrs, credits)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Fetch the best height from the backend to determine when we should
+ // stop.
+ _, bestHeight, err := chainClient.GetBestBlock()
+ if err != nil {
+ return err
+ }
+
+ // Now we can begin scanning the chain from the wallet's current tip to
+ // ensure we properly handle restarts. Since the recovery process itself
+ // acts as rescan, we'll also update our wallet's synced state along the
+ // way to reflect the blocks we process and prevent rescanning them
+ // later on.
+ //
+ // NOTE: We purposefully don't update our best height since we assume
+ // that a wallet rescan will be performed from the wallet's tip, which
+ // will be of bestHeight after completing the recovery process.
+ var blocks []*waddrmgr.BlockStamp
+
+ startHeight := w.addrStore.SyncedTo().Height + 1
+ for height := startHeight; height <= bestHeight; height++ {
+ if atomic.LoadUint32(&syncer.quit) == 1 {
+ return errors.New("recovery: forced shutdown")
+ }
+
+ hash, err := chainClient.GetBlockHash(int64(height))
+ if err != nil {
+ return err
+ }
+ header, err := chainClient.GetBlockHeader(hash)
+ if err != nil {
+ return err
+ }
+ blocks = append(blocks, &waddrmgr.BlockStamp{
+ Hash: *hash,
+ Height: height,
+ Timestamp: header.Timestamp,
+ })
+
+ // It's possible for us to run into blocks before our birthday
+ // if our birthday is after our reorg safe height, so we'll make
+ // sure to not add those to the batch.
+ if height >= birthdayBlock.Height {
+ recoveryMgr.AddToBlockBatch(
+ hash, height, header.Timestamp,
+ )
+ }
+
+ // We'll perform our recovery in batches of 2000 blocks. It's
+ // possible for us to reach our best height without exceeding
+ // the recovery batch size, so we can proceed to commit our
+ // state to disk.
+ recoveryBatch := recoveryMgr.BlockBatch()
+ if len(recoveryBatch) == recoveryBatchSize || height == bestHeight {
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ if err := w.recoverScopedAddresses(
+ chainClient, tx, ns, recoveryBatch,
+ recoveryMgr.State(), scopedMgrs,
+ ); err != nil {
+ return err
+ }
+
+ // TODO: Any error here will roll back this
+ // entire tx. This may cause the in memory sync
+ // point to become desyncronized. Refactor so
+ // that this cannot happen.
+ for _, block := range blocks {
+ err := w.addrStore.SetSyncedTo(
+ ns, block,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ if len(recoveryBatch) > 0 {
+ log.Infof("Recovered addresses from blocks "+
+ "%d-%d", recoveryBatch[0].Height,
+ recoveryBatch[len(recoveryBatch)-1].Height)
+ }
+
+ // Clear the batch of all processed blocks to reuse the
+ // same memory for future batches.
+ blocks = blocks[:0]
+ recoveryMgr.ResetBlockBatch()
+ }
+ }
+
+ return nil
+}
+
+// recoverScopedAddresses scans a range of blocks in attempts to recover any
+// previously used addresses for a particular account derivation path. At a high
+// level, the algorithm works as follows:
+//
+// 1. Ensure internal and external branch horizons are fully expanded.
+// 2. Filter the entire range of blocks, stopping if a non-zero number of
+// address are contained in a particular block.
+// 3. Record all internal and external addresses found in the block.
+// 4. Record any outpoints found in the block that should be watched for spends
+// 5. Trim the range of blocks up to and including the one reporting the addrs.
+// 6. Repeat from (1) if there are still more blocks in the range.
+//
+// TODO(conner): parallelize/pipeline/cache intermediate network requests
+func (w *Wallet) recoverScopedAddresses(
+ chainClient chain.Interface,
+ tx walletdb.ReadWriteTx,
+ ns walletdb.ReadWriteBucket,
+ batch []wtxmgr.BlockMeta,
+ recoveryState *RecoveryState,
+ scopedMgrs map[waddrmgr.KeyScope]waddrmgr.AccountStore) error {
+
+ // If there are no blocks in the batch, we are done.
+ if len(batch) == 0 {
+ return nil
+ }
+
+ log.Infof("Scanning %d blocks for recoverable addresses", len(batch))
+
+expandHorizons:
+ for scope, scopedMgr := range scopedMgrs {
+ scopeState := recoveryState.StateForScope(scope)
+ err := expandScopeHorizons(ns, scopedMgr, scopeState)
+ if err != nil {
+ return err
+ }
+ }
+
+ // With the internal and external horizons properly expanded, we now
+ // construct the filter blocks request. The request includes the range
+ // of blocks we intend to scan, in addition to the scope-index -> addr
+ // map for all internal and external branches.
+ filterReq := newFilterBlocksRequest(batch, scopedMgrs, recoveryState)
+
+ // Initiate the filter blocks request using our chain backend. If an
+ // error occurs, we are unable to proceed with the recovery.
+ filterResp, err := chainClient.FilterBlocks(filterReq)
+ if err != nil {
+ return err
+ }
+
+ // If the filter response is empty, this signals that the rest of the
+ // batch was completed, and no other addresses were discovered. As a
+ // result, no further modifications to our recovery state are required
+ // and we can proceed to the next batch.
+ if filterResp == nil {
+ return nil
+ }
+
+ // Otherwise, retrieve the block info for the block that detected a
+ // non-zero number of address matches.
+ block := batch[filterResp.BatchIndex]
+
+ // Log any non-trivial findings of addresses or outpoints.
+ logFilterBlocksResp(block, filterResp)
+
+ // Report any external or internal addresses found as a result of the
+ // appropriate branch recovery state. Adding indexes above the
+ // last-found index of either will result in the horizons being expanded
+ // upon the next iteration. Any found addresses are also marked used
+ // using the scoped key manager.
+ err = extendFoundAddresses(ns, filterResp, scopedMgrs, recoveryState)
+ if err != nil {
+ return err
+ }
+
+ // Update the global set of watched outpoints with any that were found
+ // in the block.
+ for outPoint, addr := range filterResp.FoundOutPoints {
+ outPoint := outPoint
+ recoveryState.AddWatchedOutPoint(&outPoint, addr)
+ }
+
+ // Finally, record all of the relevant transactions that were returned
+ // in the filter blocks response. This ensures that these transactions
+ // and their outputs are tracked when the final rescan is performed.
+ for _, txn := range filterResp.RelevantTxns {
+ txRecord, err := wtxmgr.NewTxRecordFromMsgTx(
+ txn, filterResp.BlockMeta.Time,
+ )
+ if err != nil {
+ return err
+ }
+
+ err = w.addRelevantTx(tx, txRecord, &filterResp.BlockMeta)
+ if err != nil {
+ return err
+ }
+ }
+
+ // Update the batch to indicate that we've processed all block through
+ // the one that returned found addresses.
+ batch = batch[filterResp.BatchIndex+1:]
+
+ // If this was not the last block in the batch, we will repeat the
+ // filtering process again after expanding our horizons.
+ if len(batch) > 0 {
+ goto expandHorizons
+ }
+
+ return nil
+}
+
+// expandScopeHorizons ensures that the ScopeRecoveryState has an adequately
+// sized look ahead for both its internal and external branches. The keys
+// derived here are added to the scope's recovery state, but do not affect the
+// persistent state of the wallet. If any invalid child keys are detected, the
+// horizon will be properly extended such that our lookahead always includes the
+// proper number of valid child keys.
+func expandScopeHorizons(ns walletdb.ReadWriteBucket,
+ scopedMgr waddrmgr.AccountStore,
+ scopeState *ScopeRecoveryState) error {
+
+ // Compute the current external horizon and the number of addresses we
+ // must derive to ensure we maintain a sufficient recovery window for
+ // the external branch.
+ exHorizon, exWindow := scopeState.ExternalBranch.ExtendHorizon()
+ count, childIndex := uint32(0), exHorizon
+ for count < exWindow {
+ keyPath := externalKeyPath(childIndex)
+ addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
+ switch {
+ case err == hdkeychain.ErrInvalidChild:
+ // Record the existence of an invalid child with the
+ // external branch's recovery state. This also
+ // increments the branch's horizon so that it accounts
+ // for this skipped child index.
+ scopeState.ExternalBranch.MarkInvalidChild(childIndex)
+ childIndex++
+ continue
+
+ case err != nil:
+ return err
+ }
+
+ // Register the newly generated external address and child index
+ // with the external branch recovery state.
+ scopeState.ExternalBranch.AddAddr(childIndex, addr.Address())
+
+ childIndex++
+ count++
+ }
+
+ // Compute the current internal horizon and the number of addresses we
+ // must derive to ensure we maintain a sufficient recovery window for
+ // the internal branch.
+ inHorizon, inWindow := scopeState.InternalBranch.ExtendHorizon()
+ count, childIndex = 0, inHorizon
+ for count < inWindow {
+ keyPath := internalKeyPath(childIndex)
+ addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
+ switch {
+ case err == hdkeychain.ErrInvalidChild:
+ // Record the existence of an invalid child with the
+ // internal branch's recovery state. This also
+ // increments the branch's horizon so that it accounts
+ // for this skipped child index.
+ scopeState.InternalBranch.MarkInvalidChild(childIndex)
+ childIndex++
+ continue
+
+ case err != nil:
+ return err
+ }
+
+ // Register the newly generated internal address and child index
+ // with the internal branch recovery state.
+ scopeState.InternalBranch.AddAddr(childIndex, addr.Address())
+
+ childIndex++
+ count++
+ }
+
+ return nil
+}
+
+// externalKeyPath returns the relative external derivation path /0/0/index.
+func externalKeyPath(index uint32) waddrmgr.DerivationPath {
+ return waddrmgr.DerivationPath{
+ InternalAccount: waddrmgr.DefaultAccountNum,
+ Account: waddrmgr.DefaultAccountNum,
+ Branch: waddrmgr.ExternalBranch,
+ Index: index,
+ }
+}
+
+// internalKeyPath returns the relative internal derivation path /0/1/index.
+func internalKeyPath(index uint32) waddrmgr.DerivationPath {
+ return waddrmgr.DerivationPath{
+ InternalAccount: waddrmgr.DefaultAccountNum,
+ Account: waddrmgr.DefaultAccountNum,
+ Branch: waddrmgr.InternalBranch,
+ Index: index,
+ }
+}
+
+// newFilterBlocksRequest constructs FilterBlocksRequests using our current
+// block range, scoped managers, and recovery state.
+func newFilterBlocksRequest(batch []wtxmgr.BlockMeta,
+ scopedMgrs map[waddrmgr.KeyScope]waddrmgr.AccountStore,
+ recoveryState *RecoveryState) *chain.FilterBlocksRequest {
+
+ filterReq := &chain.FilterBlocksRequest{
+ Blocks: batch,
+ ExternalAddrs: make(map[waddrmgr.ScopedIndex]address.Address),
+ InternalAddrs: make(map[waddrmgr.ScopedIndex]address.Address),
+ WatchedOutPoints: recoveryState.WatchedOutPoints(),
+ }
+
+ // Populate the external and internal addresses by merging the addresses
+ // sets belong to all currently tracked scopes.
+ for scope := range scopedMgrs {
+ scopeState := recoveryState.StateForScope(scope)
+ for index, addr := range scopeState.ExternalBranch.Addrs() {
+ scopedIndex := waddrmgr.ScopedIndex{
+ Scope: scope,
+ Index: index,
+ }
+ filterReq.ExternalAddrs[scopedIndex] = addr
+ }
+ for index, addr := range scopeState.InternalBranch.Addrs() {
+ scopedIndex := waddrmgr.ScopedIndex{
+ Scope: scope,
+ Index: index,
+ }
+ filterReq.InternalAddrs[scopedIndex] = addr
+ }
+ }
+
+ return filterReq
+}
+
+// extendFoundAddresses accepts a filter blocks response that contains addresses
+// found on chain, and advances the state of all relevant derivation paths to
+// match the highest found child index for each branch.
+func extendFoundAddresses(ns walletdb.ReadWriteBucket,
+ filterResp *chain.FilterBlocksResponse,
+ scopedMgrs map[waddrmgr.KeyScope]waddrmgr.AccountStore,
+ recoveryState *RecoveryState) error {
+
+ // Mark all recovered external addresses as used. This will be done only
+ // for scopes that reported a non-zero number of external addresses in
+ // this block.
+ for scope, indexes := range filterResp.FoundExternalAddrs {
+ // First, report all external child indexes found for this
+ // scope. This ensures that the external last-found index will
+ // be updated to include the maximum child index seen thus far.
+ scopeState := recoveryState.StateForScope(scope)
+ for index := range indexes {
+ scopeState.ExternalBranch.ReportFound(index)
+ }
+
+ scopedMgr := scopedMgrs[scope]
+
+ // Now, with all found addresses reported, derive and extend all
+ // external addresses up to and including the current last found
+ // index for this scope.
+ exNextUnfound := scopeState.ExternalBranch.NextUnfound()
+
+ exLastFound := exNextUnfound
+ if exLastFound > 0 {
+ exLastFound--
+ }
+
+ err := scopedMgr.ExtendExternalAddresses(
+ ns, waddrmgr.DefaultAccountNum, exLastFound,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Finally, with the scope's addresses extended, we mark used
+ // the external addresses that were found in the block and
+ // belong to this scope.
+ for index := range indexes {
+ addr := scopeState.ExternalBranch.GetAddr(index)
+ err := scopedMgr.MarkUsed(ns, addr)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ // Mark all recovered internal addresses as used. This will be done only
+ // for scopes that reported a non-zero number of internal addresses in
+ // this block.
+ for scope, indexes := range filterResp.FoundInternalAddrs {
+ // First, report all internal child indexes found for this
+ // scope. This ensures that the internal last-found index will
+ // be updated to include the maximum child index seen thus far.
+ scopeState := recoveryState.StateForScope(scope)
+ for index := range indexes {
+ scopeState.InternalBranch.ReportFound(index)
+ }
+
+ scopedMgr := scopedMgrs[scope]
+
+ // Now, with all found addresses reported, derive and extend all
+ // internal addresses up to and including the current last found
+ // index for this scope.
+ inNextUnfound := scopeState.InternalBranch.NextUnfound()
+
+ inLastFound := inNextUnfound
+ if inLastFound > 0 {
+ inLastFound--
+ }
+ err := scopedMgr.ExtendInternalAddresses(
+ ns, waddrmgr.DefaultAccountNum, inLastFound,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Finally, with the scope's addresses extended, we mark used
+ // the internal addresses that were found in the blockand belong
+ // to this scope.
+ for index := range indexes {
+ addr := scopeState.InternalBranch.GetAddr(index)
+ err := scopedMgr.MarkUsed(ns, addr)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// logFilterBlocksResp provides useful logging information when filtering
+// succeeded in finding relevant transactions.
+func logFilterBlocksResp(block wtxmgr.BlockMeta,
+ resp *chain.FilterBlocksResponse) {
+
+ // Log the number of external addresses found in this block.
+ var nFoundExternal int
+ for _, indexes := range resp.FoundExternalAddrs {
+ nFoundExternal += len(indexes)
+ }
+ if nFoundExternal > 0 {
+ log.Infof("Recovered %d external addrs at height=%d hash=%v",
+ nFoundExternal, block.Height, block.Hash)
+ }
+
+ // Log the number of internal addresses found in this block.
+ var nFoundInternal int
+ for _, indexes := range resp.FoundInternalAddrs {
+ nFoundInternal += len(indexes)
+ }
+ if nFoundInternal > 0 {
+ log.Infof("Recovered %d internal addrs at height=%d hash=%v",
+ nFoundInternal, block.Height, block.Hash)
+ }
+
+ // Log the number of outpoints found in this block.
+ nFoundOutPoints := len(resp.FoundOutPoints)
+ if nFoundOutPoints > 0 {
+ log.Infof("Found %d spends from watched outpoints at "+
+ "height=%d hash=%v",
+ nFoundOutPoints, block.Height, block.Hash)
+ }
+}
+
+type (
+ createTxRequest struct {
+ coinSelectKeyScope *waddrmgr.KeyScope
+ changeKeyScope *waddrmgr.KeyScope
+ account uint32
+ outputs []*wire.TxOut
+ minconf int32
+ feeSatPerKB btcutil.Amount
+ coinSelectionStrategy CoinSelectionStrategy
+ dryRun bool
+ resp chan createTxResponse
+ selectUtxos []wire.OutPoint
+ allowUtxo func(wtxmgr.Credit) bool
+ }
+ createTxResponse struct {
+ tx *txauthor.AuthoredTx
+ err error
+ }
+)
+
+// txCreator is responsible for the input selection and creation of
+// transactions. These functions are the responsibility of this method
+// (designed to be run as its own goroutine) since input selection must be
+// serialized, or else it is possible to create double spends by choosing the
+// same inputs for multiple transactions. Along with input selection, this
+// method is also responsible for the signing of transactions, since we don't
+// want to end up in a situation where we run out of inputs as multiple
+// transactions are being created. In this situation, it would then be possible
+// for both requests, rather than just one, to fail due to not enough available
+// inputs.
+func (w *Wallet) txCreator() {
+ quit := w.quitChan()
+out:
+ for {
+ select {
+ case txr := <-w.createTxRequests:
+ // If the wallet can be locked because it contains
+ // private key material, we need to prevent it from
+ // doing so while we are assembling the transaction.
+ release := func() {}
+ if !w.addrStore.WatchOnly() {
+ heldUnlock, err := w.holdUnlock()
+ if err != nil {
+ txr.resp <- createTxResponse{nil, err}
+ continue
+ }
+
+ release = heldUnlock.release
+ }
+
+ tx, err := w.txToOutputs(
+ txr.outputs, txr.coinSelectKeyScope,
+ txr.changeKeyScope, txr.account, txr.minconf,
+ txr.feeSatPerKB, txr.coinSelectionStrategy,
+ txr.dryRun, txr.selectUtxos, txr.allowUtxo,
+ )
+
+ release()
+ txr.resp <- createTxResponse{tx, err}
+ case <-quit:
+ break out
+ }
+ }
+ w.wg.Done()
+}
+
+// txCreateOptions is a set of optional arguments to modify the tx creation
+// process. This can be used to do things like use a custom coin selection
+// scope, which otherwise will default to the specified coin selection scope.
+type txCreateOptions struct {
+ changeKeyScope *waddrmgr.KeyScope
+ selectUtxos []wire.OutPoint
+ allowUtxo func(wtxmgr.Credit) bool
+}
+
+// TxCreateOption is a set of optional arguments to modify the tx creation
+// process. This can be used to do things like use a custom coin selection
+// scope, which otherwise will default to the specified coin selection scope.
+type TxCreateOption func(*txCreateOptions)
+
+// defaultTxCreateOptions is the default set of options.
+func defaultTxCreateOptions() *txCreateOptions {
+ return &txCreateOptions{}
+}
+
+// WithCustomChangeScope can be used to specify a change scope for the change
+// address. If unspecified, then the same scope will be used for both inputs
+// and the change addr. Not specifying any scope at all (nil) will use all
+// available coins and the default change scope (P2TR).
+func WithCustomChangeScope(changeScope *waddrmgr.KeyScope) TxCreateOption {
+ return func(opts *txCreateOptions) {
+ opts.changeKeyScope = changeScope
+ }
+}
+
+// WithCustomSelectUtxos is used to specify the inputs to be used while
+// creating txns.
+func WithCustomSelectUtxos(utxos []wire.OutPoint) TxCreateOption {
+ return func(opts *txCreateOptions) {
+ opts.selectUtxos = utxos
+ }
+}
+
+// WithUtxoFilter is used to restrict the selection of the internal wallet
+// inputs by further external conditions. Utxos which pass the filter are
+// considered when creating the transaction.
+func WithUtxoFilter(allowUtxo func(utxo wtxmgr.Credit) bool) TxCreateOption {
+ return func(opts *txCreateOptions) {
+ opts.allowUtxo = allowUtxo
+ }
+}
+
+type (
+ unlockRequest struct {
+ passphrase []byte
+ lockAfter <-chan time.Time // nil prevents the timeout.
+ err chan error
+ }
+
+ changePassphraseRequest struct {
+ old, new []byte
+ private bool
+ err chan error
+ }
+
+ changePassphrasesRequest struct {
+ publicOld, publicNew []byte
+ privateOld, privateNew []byte
+ err chan error
+ }
+
+ // heldUnlock is a tool to prevent the wallet from automatically
+ // locking after some timeout before an operation which needed
+ // the unlocked wallet has finished. Any acquired heldUnlock
+ // *must* be released (preferably with a defer) or the wallet
+ // will forever remain unlocked.
+ heldUnlock chan struct{}
+)
+
+// endRecovery tells (*Wallet).recovery to stop, if running, and returns a
+// channel that will be closed when the recovery routine exits.
+func (w *Wallet) endRecovery() <-chan struct{} {
+ if recoverySyncI := w.recovering.Load(); recoverySyncI != nil {
+ recoverySync := recoverySyncI.(*recoverySyncer)
+
+ // If recovery is still running, it will end early with an error
+ // once we set the quit flag.
+ atomic.StoreUint32(&recoverySync.quit, 1)
+
+ return recoverySync.done
+ }
+ c := make(chan struct{})
+ close(c)
+ return c
+}
+
+// walletLocker manages the locked/unlocked state of a wallet.
+func (w *Wallet) walletLocker() {
+ var timeout <-chan time.Time
+ holdChan := make(heldUnlock)
+ quit := w.quitChan()
+out:
+ for {
+ select {
+ case req := <-w.unlockRequests:
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.Unlock(
+ addrmgrNs, req.passphrase,
+ )
+ })
+ if err != nil {
+ req.err <- err
+ continue
+ }
+ timeout = req.lockAfter
+ if timeout == nil {
+ log.Info("The wallet has been unlocked without a time limit")
+ } else {
+ log.Info("The wallet has been temporarily unlocked")
+ }
+ req.err <- nil
+ continue
+
+ case req := <-w.changePassphrase:
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.ChangePassphrase(
+ addrmgrNs, req.old, req.new, req.private,
+ &waddrmgr.DefaultScryptOptions,
+ )
+ })
+ req.err <- err
+ continue
+
+ case req := <-w.changePassphrases:
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ err := w.addrStore.ChangePassphrase(
+ addrmgrNs, req.publicOld, req.publicNew,
+ false, &waddrmgr.DefaultScryptOptions,
+ )
+ if err != nil {
+ return err
+ }
+
+ return w.addrStore.ChangePassphrase(
+ addrmgrNs, req.privateOld, req.privateNew,
+ true, &waddrmgr.DefaultScryptOptions,
+ )
+ })
+ req.err <- err
+ continue
+
+ case req := <-w.holdUnlockRequests:
+ if w.addrStore.IsLocked() {
+ close(req)
+ continue
+ }
+
+ req <- holdChan
+ <-holdChan // Block until the lock is released.
+
+ // If, after holding onto the unlocked wallet for some
+ // time, the timeout has expired, lock it now instead
+ // of hoping it gets unlocked next time the top level
+ // select runs.
+ select {
+ case <-timeout:
+ // Let the top level select fallthrough so the
+ // wallet is locked.
+ default:
+ continue
+ }
+
+ case w.lockState <- w.addrStore.IsLocked():
+ continue
+
+ case <-quit:
+ break out
+
+ case <-w.lockRequests:
+ case <-timeout:
+ }
+
+ // Select statement fell through by an explicit lock or the
+ // timer expiring. Lock the manager here.
+
+ // We can't lock the manager if recovery is active because we use
+ // cryptoKeyPriv and cryptoKeyScript in recovery.
+ <-w.endRecovery()
+
+ timeout = nil
+
+ err := w.addrStore.Lock()
+ if err != nil && !waddrmgr.IsError(err, waddrmgr.ErrLocked) {
+ log.Errorf("Could not lock wallet: %v", err)
+ } else {
+ log.Info("The wallet has been locked")
+ }
+ }
+ w.wg.Done()
+}
+
+// Locked returns whether the account manager for a wallet is locked.
+func (w *Wallet) Locked() bool {
+ return <-w.lockState
+}
+
+// holdUnlock prevents the wallet from being locked. The heldUnlock object
+// *must* be released, or the wallet will forever remain unlocked.
+//
+// TODO: To prevent the above scenario, perhaps closures should be passed
+// to the walletLocker goroutine and disallow callers from explicitly
+
+// handling the locking mechanism.
+func (w *Wallet) holdUnlock() (heldUnlock, error) {
+ req := make(chan heldUnlock)
+ w.holdUnlockRequests <- req
+ hl, ok := <-req
+ if !ok {
+ // TODO(davec): This should be defined and exported from
+ // waddrmgr.
+ return nil, waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrLocked,
+ Description: "address manager is locked",
+ }
+ }
+ return hl, nil
+}
+
+// release releases the hold on the unlocked-state of the wallet and allows the
+// wallet to be locked again. If a lock timeout has already expired, the
+// wallet is locked again as soon as release is called.
+func (c heldUnlock) release() {
+ c <- struct{}{}
+}
+
+// ChangePrivatePassphrase attempts to change the passphrase for a wallet from
+// old to new. Changing the passphrase is synchronized with all other address
+// manager locking and unlocking. The lock state will be the same as it was
+// before the password change.
+func (w *Wallet) ChangePrivatePassphrase(old, new []byte) error {
+ err := make(chan error, 1)
+ w.changePassphrase <- changePassphraseRequest{
+ old: old,
+ new: new,
+ private: true,
+ err: err,
+ }
+ return <-err
+}
+
+// ChangePublicPassphrase modifies the public passphrase of the wallet.
+func (w *Wallet) ChangePublicPassphrase(old, new []byte) error {
+ err := make(chan error, 1)
+ w.changePassphrase <- changePassphraseRequest{
+ old: old,
+ new: new,
+ private: false,
+ err: err,
+ }
+ return <-err
+}
+
+// ChangePassphrases modifies the public and private passphrase of the wallet
+// atomically.
+func (w *Wallet) ChangePassphrases(publicOld, publicNew, privateOld,
+ privateNew []byte) error {
+
+ err := make(chan error, 1)
+ w.changePassphrases <- changePassphrasesRequest{
+ publicOld: publicOld,
+ publicNew: publicNew,
+ privateOld: privateOld,
+ privateNew: privateNew,
+ err: err,
+ }
+ return <-err
+}
+
+// AccountAddresses returns the addresses for every created address for an
+// account.
+func (w *Wallet) AccountAddresses(account uint32) (addrs []address.Address, err error) {
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.ForEachAccountAddress(
+ addrmgrNs, account,
+ func(maddr waddrmgr.ManagedAddress) error {
+ addrs = append(addrs, maddr.Address())
+ return nil
+ })
+ })
+ return
+}
+
+// AccountManagedAddresses returns the managed addresses for every created
+// address for an account.
+func (w *Wallet) AccountManagedAddresses(scope waddrmgr.KeyScope,
+ accountNum uint32) ([]waddrmgr.ManagedAddress, error) {
+
+ scopedMgr, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ addrs := make([]waddrmgr.ManagedAddress, 0)
+
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ return scopedMgr.ForEachAccountAddress(
+ addrmgrNs, accountNum,
+ func(a waddrmgr.ManagedAddress) error {
+ addrs = append(addrs, a)
+
+ return nil
+ },
+ )
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return addrs, nil
+}
+
+// CalculateBalance sums the amounts of all unspent transaction
+// outputs to addresses of a wallet and returns the balance.
+//
+// If confirmations is 0, all UTXOs, even those not present in a
+// block (height -1), will be used to get the balance. Otherwise,
+// a UTXO must be in a block. If confirmations is 1 or greater,
+// the balance will be calculated based on how many how many blocks
+// include a UTXO.
+func (w *Wallet) CalculateBalance(confirms int32) (btcutil.Amount, error) {
+ var balance btcutil.Amount
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+ var err error
+
+ blk := w.addrStore.SyncedTo()
+ balance, err = w.txStore.Balance(txmgrNs, confirms, blk.Height)
+
+ return err
+ })
+ return balance, err
+}
+
+// Balances records total, spendable (by policy), and immature coinbase
+// reward balance amounts.
+type Balances struct {
+ Total btcutil.Amount
+ Spendable btcutil.Amount
+ ImmatureReward btcutil.Amount
+}
+
+// CalculateAccountBalances sums the amounts of all unspent transaction
+// outputs to the given account of a wallet and returns the balance.
+//
+// This function is much slower than it needs to be since transactions outputs
+// are not indexed by the accounts they credit to, and all unspent transaction
+// outputs must be iterated.
+func (w *Wallet) CalculateAccountBalances(account uint32, confirms int32) (Balances, error) {
+ var bals Balances
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Get current block. The block height used for calculating
+ // the number of tx confirmations.
+ syncBlock := w.addrStore.SyncedTo()
+
+ unspent, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+ for i := range unspent {
+ output := &unspent[i]
+
+ var outputAcct uint32
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, w.chainParams)
+ if err == nil && len(addrs) > 0 {
+ _, outputAcct, err = w.addrStore.AddrAccount(
+ addrmgrNs, addrs[0],
+ )
+ }
+ if err != nil || outputAcct != account {
+ continue
+ }
+
+ bals.Total += output.Amount
+ if output.FromCoinBase && !hasMinConfs(
+ uint32(w.chainParams.CoinbaseMaturity),
+ output.Height, syncBlock.Height,
+ ) {
+
+ bals.ImmatureReward += output.Amount
+ } else if hasMinConfs(
+ //nolint:gosec
+ uint32(confirms), output.Height,
+ syncBlock.Height,
+ ) {
+
+ bals.Spendable += output.Amount
+ }
+ }
+ return nil
+ })
+ return bals, err
+}
+
+// CurrentAddress gets the most recently requested Bitcoin payment address
+// from a wallet for a particular key-chain scope. If the address has already
+// been used (there is at least one transaction spending to it in the
+// blockchain or btcd mempool), the next chained address is returned.
+func (w *Wallet) CurrentAddress(account uint32, scope waddrmgr.KeyScope) (address.Address, error) {
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return nil, err
+ }
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ // The address manager uses OnCommit on the walletdb tx to update the
+ // in-memory state of the account state. But because the commit happens
+ // _after_ the account manager internal lock has been released, there
+ // is a chance for the address index to be accessed concurrently, even
+ // though the closure in OnCommit re-acquires the lock. To avoid this
+ // issue, we surround the whole address creation process with a lock.
+ w.newAddrMtx.Lock()
+ defer w.newAddrMtx.Unlock()
+
+ var (
+ addr address.Address
+ props *waddrmgr.AccountProperties
+ )
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ maddr, err := manager.LastExternalAddress(addrmgrNs, account)
+ if err != nil {
+ // If no address exists yet, create the first external
+ // address.
+ if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
+ addr, props, err = w.newAddressDeprecated(
+ addrmgrNs, account, scope,
+ )
+ }
+ return err
+ }
+
+ // Get next chained address if the last one has already been
+ // used.
+ if maddr.Used(addrmgrNs) {
+ addr, props, err = w.newAddressDeprecated(
+ addrmgrNs, account, scope,
+ )
+ return err
+ }
+
+ addr = maddr.Address()
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // If the props have been initially, then we had to create a new address
+ // to satisfy the query. Notify the rpc server about the new address.
+ if props != nil {
+ err = chainClient.NotifyReceived([]address.Address{addr})
+ if err != nil {
+ return nil, err
+ }
+
+ w.NtfnServer.notifyAccountProperties(props)
+ }
+
+ return addr, nil
+}
+
+// PubKeyForAddress looks up the associated public key for a P2PKH address.
+func (w *Wallet) PubKeyForAddress(a address.Address) (*btcec.PublicKey, error) {
+ var pubKey *btcec.PublicKey
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ managedAddr, err := w.addrStore.Address(addrmgrNs, a)
+ if err != nil {
+ return err
+ }
+ managedPubKeyAddr, ok := managedAddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return errors.New("address does not have an associated public key")
+ }
+ pubKey = managedPubKeyAddr.PubKey()
+ return nil
+ })
+ return pubKey, err
+}
+
+// LabelTransaction adds a label to the transaction with the hash provided. The
+// call will fail if the label is too long, or if the transaction already has
+// a label and the overwrite boolean is not set.
+func (w *Wallet) LabelTransaction(hash chainhash.Hash, label string,
+ overwrite bool) error {
+
+ // Check that the transaction is known to the wallet, and fail if it is
+ // unknown. If the transaction is known, check whether it already has
+ // a label.
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ dbTx, err := w.txStore.TxDetails(txmgrNs, &hash)
+ if err != nil {
+ return err
+ }
+
+ // If the transaction looked up is nil, it was not found. We
+ // do not allow labelling of unknown transactions so we fail.
+ if dbTx == nil {
+ return ErrUnknownTransaction
+ }
+
+ _, err = w.txStore.FetchTxLabel(txmgrNs, hash)
+ return err
+ })
+
+ switch err {
+ // If no labels have been written yet, we can silence the error.
+ // Likewise if there is no label, we do not need to do any overwrite
+ // checks.
+ case wtxmgr.ErrNoLabelBucket:
+ case wtxmgr.ErrTxLabelNotFound:
+
+ // If we successfully looked up a label, fail if the overwrite param
+ // is not set.
+ case nil:
+ if !overwrite {
+ return ErrTxLabelExists
+ }
+
+ // In another unrelated error occurred, return it.
+ default:
+ return err
+ }
+
+ return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ return w.txStore.PutTxLabel(txmgrNs, hash, label)
+ })
+}
+
+// HaveAddress returns whether the wallet is the owner of the address a.
+func (w *Wallet) HaveAddress(a address.Address) (bool, error) {
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ _, err := w.addrStore.Address(addrmgrNs, a)
+ return err
+ })
+ if err == nil {
+ return true, nil
+ }
+ if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
+ return false, nil
+ }
+ return false, err
+}
+
+// AccountOfAddress finds the account that an address is associated with.
+func (w *Wallet) AccountOfAddress(a address.Address) (uint32, error) {
+ var account uint32
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+
+ _, account, err = w.addrStore.AddrAccount(addrmgrNs, a)
+ return err
+ })
+ return account, err
+}
+
+// DumpPrivKeys returns the WIF-encoded private keys for all addresses with
+// private keys in a wallet.
+func (w *Wallet) DumpPrivKeys() ([]string, error) {
+ var privkeys []string
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ // Iterate over each active address, appending the private key to
+ return w.addrStore.ForEachActiveAddress(
+ addrmgrNs, func(addr address.Address) error {
+ ma, err := w.addrStore.Address(addrmgrNs, addr)
+ if err != nil {
+ return err
+ }
+
+ // Only those addresses with keys needed.
+ pka, ok := ma.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil
+ }
+
+ wif, err := pka.ExportPrivKey()
+ if err != nil {
+ // It would be nice to zero out the
+ // array here. However, since strings
+ // in go are immutable, and we have no
+ // control over the caller I don't
+ // think we can. :(
+ return err
+ }
+
+ privkeys = append(privkeys, wif.String())
+
+ return nil
+ })
+ })
+ return privkeys, err
+}
+
+// DumpWIFPrivateKey returns the WIF encoded private key for a
+// single wallet address.
+func (w *Wallet) DumpWIFPrivateKey(addr address.Address) (string, error) {
+ var maddr waddrmgr.ManagedAddress
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ waddrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ // Get private key from wallet if it exists.
+ var err error
+
+ maddr, err = w.addrStore.Address(waddrmgrNs, addr)
+ return err
+ })
+ if err != nil {
+ return "", err
+ }
+
+ pka, ok := maddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return "", fmt.Errorf("address %s is not a key type", addr)
+ }
+
+ wif, err := pka.ExportPrivKey()
+ if err != nil {
+ return "", err
+ }
+ return wif.String(), nil
+}
+
+// LockOutpoint marks an outpoint as locked, that is, it should not be used as
+// an input for newly created transactions.
+func (w *Wallet) LockOutpoint(op wire.OutPoint) {
+ w.lockedOutpointsMtx.Lock()
+ defer w.lockedOutpointsMtx.Unlock()
+
+ w.lockedOutpoints[op] = struct{}{}
+}
+
+// UnlockOutpoint marks an outpoint as unlocked, that is, it may be used as an
+// input for newly created transactions.
+func (w *Wallet) UnlockOutpoint(op wire.OutPoint) {
+ w.lockedOutpointsMtx.Lock()
+ defer w.lockedOutpointsMtx.Unlock()
+
+ delete(w.lockedOutpoints, op)
+}
+
+// LockedOutpoint returns whether an outpoint has been marked as locked and
+// should not be used as an input for created transactions.
+func (w *Wallet) LockedOutpoint(op wire.OutPoint) bool {
+ w.lockedOutpointsMtx.Lock()
+ defer w.lockedOutpointsMtx.Unlock()
+
+ _, locked := w.lockedOutpoints[op]
+
+ return locked
+}
+
+// ResetLockedOutpoints resets the set of locked outpoints so all may be used
+// as inputs for new transactions.
+func (w *Wallet) ResetLockedOutpoints() {
+ w.lockedOutpointsMtx.Lock()
+ defer w.lockedOutpointsMtx.Unlock()
+
+ w.lockedOutpoints = map[wire.OutPoint]struct{}{}
+}
+
+// LockedOutpoints returns a slice of currently locked outpoints. This is
+// intended to be used by marshaling the result as a JSON array for
+// listlockunspent RPC results.
+func (w *Wallet) LockedOutpoints() []btcjson.TransactionInput {
+ w.lockedOutpointsMtx.Lock()
+ defer w.lockedOutpointsMtx.Unlock()
+
+ locked := make([]btcjson.TransactionInput, len(w.lockedOutpoints))
+ i := 0
+ for op := range w.lockedOutpoints {
+ locked[i] = btcjson.TransactionInput{
+ Txid: op.Hash.String(),
+ Vout: op.Index,
+ }
+ i++
+ }
+ return locked
+}
+
+// LeaseOutputDeprecated locks an output to the given ID, preventing it from
+// being available for coin selection. The absolute time of the lock's
+// expiration is
+// returned. The expiration of the lock can be extended by successive
+// invocations of this call.
+//
+// Outputs can be unlocked before their expiration through `UnlockOutput`.
+// Otherwise, they are unlocked lazily through calls which iterate through all
+// known outputs, e.g., `CalculateBalance`, `ListUnspent`.
+//
+// If the output is not known, ErrUnknownOutput is returned. If the output has
+// already been locked to a different ID, then ErrOutputAlreadyLocked is
+// returned.
+//
+// NOTE: This differs from LockOutpoint in that outputs are locked for a limited
+// amount of time and their locks are persisted to disk.
+//
+// Deprecated: Use UtxoManager.LeaseOutput instead.
+func (w *Wallet) LeaseOutputDeprecated(id wtxmgr.LockID, op wire.OutPoint,
+ duration time.Duration) (time.Time, error) {
+
+ var expiry time.Time
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ var err error
+
+ expiry, err = w.txStore.LockOutput(ns, id, op, duration)
+ return err
+ })
+ return expiry, err
+}
+
+// ReleaseOutputDeprecated unlocks an output, allowing it to be available for
+// coin selection if it remains unspent. The ID should match the one used to
+// originally lock the output.
+//
+// Deprecated: Use UtxoManager.ReleaseOutput instead.
+func (w *Wallet) ReleaseOutputDeprecated(
+ id wtxmgr.LockID, op wire.OutPoint) error {
+
+ return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ return w.txStore.UnlockOutput(ns, id, op)
+ })
+}
+
+// resendUnminedTxs iterates through all transactions that spend from wallet
+// credits that are not known to have been mined into a block, and attempts
+// to send each to the chain server for relay.
+func (w *Wallet) resendUnminedTxs() {
+ var txs []*wire.MsgTx
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+ var err error
+
+ txs, err = w.txStore.UnminedTxs(txmgrNs)
+ return err
+ })
+ if err != nil {
+ log.Errorf("Unable to retrieve unconfirmed transactions to "+
+ "resend: %v", err)
+ return
+ }
+
+ for _, tx := range txs {
+ txHash, err := w.publishTransaction(tx)
+ if err != nil {
+ log.Debugf("Unable to rebroadcast transaction %v: %v",
+ tx.TxHash(), err)
+ continue
+ }
+
+ log.Debugf("Successfully rebroadcast unconfirmed transaction %v",
+ txHash)
+ }
+}
+
+// SortedActivePaymentAddresses returns a slice of all active payment
+// addresses in a wallet.
+func (w *Wallet) SortedActivePaymentAddresses() ([]string, error) {
+ var addrStrs []string
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.ForEachActiveAddress(
+ addrmgrNs, func(addr address.Address) error {
+ addrStrs = append(
+ addrStrs, addr.EncodeAddress(),
+ )
+
+ return nil
+ })
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ sort.Strings(addrStrs)
+ return addrStrs, nil
+}
+
+// NewAddressDeprecated returns the next external chained address for a wallet.
+func (w *Wallet) NewAddressDeprecated(account uint32,
+ scope waddrmgr.KeyScope) (address.Address, error) {
+
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return nil, err
+ }
+
+ // The address manager uses OnCommit on the walletdb tx to update the
+ // in-memory state of the account state. But because the commit happens
+ // _after_ the account manager internal lock has been released, there
+ // is a chance for the address index to be accessed concurrently, even
+ // though the closure in OnCommit re-acquires the lock. To avoid this
+ // issue, we surround the whole address creation process with a lock.
+ w.newAddrMtx.Lock()
+ defer w.newAddrMtx.Unlock()
+
+ var (
+ addr address.Address
+ props *waddrmgr.AccountProperties
+ )
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ var err error
+
+ addr, props, err = w.newAddressDeprecated(
+ addrmgrNs, account, scope,
+ )
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Notify the rpc server about the newly created address.
+ err = chainClient.NotifyReceived([]address.Address{addr})
+ if err != nil {
+ return nil, err
+ }
+
+ w.NtfnServer.notifyAccountProperties(props)
+
+ return addr, nil
+}
+
+// NewChangeAddress returns a new change address for a wallet.
+func (w *Wallet) NewChangeAddress(account uint32,
+ scope waddrmgr.KeyScope) (address.Address, error) {
+
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return nil, err
+ }
+
+ // The address manager uses OnCommit on the walletdb tx to update the
+ // in-memory state of the account state. But because the commit happens
+ // _after_ the account manager internal lock has been released, there
+ // is a chance for the address index to be accessed concurrently, even
+ // though the closure in OnCommit re-acquires the lock. To avoid this
+ // issue, we surround the whole address creation process with a lock.
+ w.newAddrMtx.Lock()
+ defer w.newAddrMtx.Unlock()
+
+ var addr address.Address
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ var err error
+ addr, err = w.newChangeAddress(addrmgrNs, account, scope)
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Notify the rpc server about the newly created address.
+ err = chainClient.NotifyReceived([]address.Address{addr})
+ if err != nil {
+ return nil, err
+ }
+
+ return addr, nil
+}
+
+// newChangeAddress returns a new change address for the wallet.
+//
+// NOTE: This method requires the caller to use the backend's NotifyReceived
+// method in order to detect when an on-chain transaction pays to the address
+// being created.
+func (w *Wallet) newChangeAddress(addrmgrNs walletdb.ReadWriteBucket,
+ account uint32, scope waddrmgr.KeyScope) (address.Address, error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ // Get next chained change address from wallet for account.
+ addrs, err := manager.NextInternalAddresses(addrmgrNs, account, 1)
+ if err != nil {
+ return nil, err
+ }
+
+ return addrs[0].Address(), nil
+}
+
+// AccountTotalReceivedResult is a single result for the
+// Wallet.TotalReceivedForAccounts method.
+type AccountTotalReceivedResult struct {
+ AccountNumber uint32
+ AccountName string
+ TotalReceived btcutil.Amount
+ LastConfirmation int32
+}
+
+// TotalReceivedForAccounts iterates through a wallet's transaction history,
+// returning the total amount of Bitcoin received for all accounts.
+func (w *Wallet) TotalReceivedForAccounts(scope waddrmgr.KeyScope,
+ minConf int32) ([]AccountTotalReceivedResult, error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var results []AccountTotalReceivedResult
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ syncBlock := w.addrStore.SyncedTo()
+
+ err := manager.ForEachAccount(addrmgrNs, func(account uint32) error {
+ accountName, err := manager.AccountName(addrmgrNs, account)
+ if err != nil {
+ return err
+ }
+ results = append(results, AccountTotalReceivedResult{
+ AccountNumber: account,
+ AccountName: accountName,
+ })
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ var stopHeight int32
+
+ if minConf > 0 {
+ stopHeight = syncBlock.Height - minConf + 1
+ } else {
+ stopHeight = -1
+ }
+
+ //nolint:lll
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ for i := range details {
+ detail := &details[i]
+ for _, cred := range detail.Credits {
+ pkScript := detail.MsgTx.TxOut[cred.Index].PkScript
+ var outputAcct uint32
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(pkScript, w.chainParams)
+ if err == nil && len(addrs) > 0 {
+ _, outputAcct, err = w.addrStore.AddrAccount(addrmgrNs, addrs[0])
+ }
+ if err == nil {
+ acctIndex := int(outputAcct)
+ if outputAcct == waddrmgr.ImportedAddrAccount {
+ acctIndex = len(results) - 1
+ }
+ res := &results[acctIndex]
+ res.TotalReceived += cred.Amount
+
+ confs := calcConf(
+ detail.Block.Height,
+ syncBlock.Height,
+ )
+ res.LastConfirmation = confs
+ }
+ }
+ }
+ return false, nil
+ }
+
+ return w.txStore.RangeTransactions(
+ txmgrNs, 0, stopHeight, rangeFn,
+ )
+ })
+ return results, err
+}
+
+// TotalReceivedForAddr iterates through a wallet's transaction history,
+// returning the total amount of bitcoins received for a single wallet
+// address.
+func (w *Wallet) TotalReceivedForAddr(addr address.Address, minConf int32) (btcutil.Amount, error) {
+ var amount btcutil.Amount
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ syncBlock := w.addrStore.SyncedTo()
+
+ var (
+ addrStr = addr.EncodeAddress()
+ stopHeight int32
+ )
+
+ if minConf > 0 {
+ stopHeight = syncBlock.Height - minConf + 1
+ } else {
+ stopHeight = -1
+ }
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ for i := range details {
+ detail := &details[i]
+ for _, cred := range detail.Credits {
+ pkScript := detail.MsgTx.TxOut[cred.Index].PkScript
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(pkScript,
+ w.chainParams)
+ // An error creating addresses from the output script only
+ // indicates a non-standard script, so ignore this credit.
+ if err != nil {
+ continue
+ }
+ for _, a := range addrs {
+ if addrStr == a.EncodeAddress() {
+ amount += cred.Amount
+ break
+ }
+ }
+ }
+ }
+ return false, nil
+ }
+
+ return w.txStore.RangeTransactions(
+ txmgrNs, 0, stopHeight, rangeFn,
+ )
+ })
+ return amount, err
+}
+
+// SendOutputs creates and sends payment transactions. Coin selection is
+// performed by the wallet, choosing inputs that belong to the given key scope
+// and account, unless a key scope is not specified. In that case, inputs from
+// accounts matching the account number provided across all key scopes may be
+// selected. This is done to handle the default account case, where a user wants
+// to fund a PSBT with inputs regardless of their type (NP2WKH, P2WKH, etc.). It
+// returns the transaction upon success.
+func (w *Wallet) SendOutputs(outputs []*wire.TxOut, keyScope *waddrmgr.KeyScope,
+ account uint32, minconf int32, satPerKb btcutil.Amount,
+ coinSelectionStrategy CoinSelectionStrategy, label string) (*wire.MsgTx,
+ error) {
+
+ return w.sendOutputs(
+ outputs, keyScope, account, minconf, satPerKb,
+ coinSelectionStrategy, label,
+ )
+}
+
+// SendOutputsWithInput creates and sends payment transactions using the
+// provided selected utxos. It returns the transaction upon success.
+func (w *Wallet) SendOutputsWithInput(outputs []*wire.TxOut,
+ keyScope *waddrmgr.KeyScope,
+ account uint32, minconf int32, satPerKb btcutil.Amount,
+ coinSelectionStrategy CoinSelectionStrategy, label string,
+ selectedUtxos []wire.OutPoint) (*wire.MsgTx, error) {
+
+ return w.sendOutputs(outputs, keyScope, account, minconf, satPerKb,
+ coinSelectionStrategy, label, selectedUtxos...)
+}
+
+// sendOutputs creates and sends payment transactions. It returns the
+// transaction upon success.
+func (w *Wallet) sendOutputs(outputs []*wire.TxOut, keyScope *waddrmgr.KeyScope,
+ account uint32, minconf int32, satPerKb btcutil.Amount,
+ coinSelectionStrategy CoinSelectionStrategy, label string,
+ selectedUtxos ...wire.OutPoint) (*wire.MsgTx, error) {
+
+ // If the key scope wasn't specified, then we'll default to the BIP0084
+ // key scope for this account.
+ if keyScope == nil {
+ keyScope = &waddrmgr.KeyScopeBIP0084
+ }
+
+ // Create a transaction which spends from the wallet.
+ var (
+ tx *txauthor.AuthoredTx
+ err error
+ )
+ // We'll specify the WithCustomSelectUtxos functional option if we were
+ // passed a set of utxos to spend.
+ var opts []TxCreateOption
+ if len(selectedUtxos) != 0 {
+ opts = append(opts, WithCustomSelectUtxos(selectedUtxos))
+ }
+
+ tx, err = w.CreateSimpleTx(
+ keyScope, account, outputs, minconf, satPerKb,
+ coinSelectionStrategy, false, opts...,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // If there is a label we should write, get the namespace key
+ // and record it in the tx store.
+ //
+ // TODO(yy): We should remove this `label` parameter from the function
+ // signature and instead let the caller use `LabelTransaction` to label
+ // the transaction after it's been published.
+ if len(label) != 0 {
+ err := walletdb.Update(w.db, func(txmgr walletdb.ReadWriteTx) error {
+ ns := txmgr.ReadWriteBucket(wtxmgrNamespaceKey)
+ return w.txStore.PutTxLabel(ns, tx.Tx.TxHash(), label)
+ })
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // And publish it.
+ return nil, w.PublishTransaction(tx.Tx, label)
+}
+
+// SignatureError records the underlying error when validating a transaction
+// input signature.
+type SignatureError struct {
+ InputIndex uint32
+ Error error
+}
+
+// SignTransaction uses secrets of the wallet, as well as additional secrets
+// passed in by the caller, to create and add input signatures to a transaction.
+//
+// Transaction input script validation is used to confirm that all signatures
+// are valid. For any invalid input, a SignatureError is added to the returns.
+// The final error return is reserved for unexpected or fatal errors, such as
+// being unable to determine a previous output script to redeem.
+//
+// The transaction pointed to by tx is modified by this function.
+func (w *Wallet) SignTransaction(tx *wire.MsgTx, hashType txscript.SigHashType,
+ additionalPrevScripts map[wire.OutPoint][]byte,
+ additionalKeysByAddress map[string]*btcutil.WIF,
+ p2shRedeemScriptsByAddress map[string][]byte) ([]SignatureError, error) {
+
+ var signErrors []SignatureError
+ err := walletdb.View(w.db, func(dbtx walletdb.ReadTx) error {
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ inputFetcher := txscript.NewMultiPrevOutFetcher(nil)
+ for i, txIn := range tx.TxIn {
+ prevOutScript, ok := additionalPrevScripts[txIn.PreviousOutPoint]
+ if !ok {
+ prevHash := &txIn.PreviousOutPoint.Hash
+ prevIndex := txIn.PreviousOutPoint.Index
+
+ txDetails, err := w.txStore.TxDetails(
+ txmgrNs, prevHash,
+ )
+ if err != nil {
+ return fmt.Errorf("cannot query previous transaction "+
+ "details for %v: %w", txIn.PreviousOutPoint, err)
+ }
+ if txDetails == nil {
+ return fmt.Errorf("%v not found",
+ txIn.PreviousOutPoint)
+ }
+ prevOutScript = txDetails.MsgTx.TxOut[prevIndex].PkScript
+ }
+ inputFetcher.AddPrevOut(txIn.PreviousOutPoint, &wire.TxOut{
+ PkScript: prevOutScript,
+ })
+
+ // Set up our callbacks that we pass to txscript so it can
+ // look up the appropriate keys and scripts by address.
+ //
+ //nolint:lll
+ getKey := txscript.KeyClosure(func(addr address.Address) (*btcec.PrivateKey, bool, error) {
+ if len(additionalKeysByAddress) != 0 {
+ addrStr := addr.EncodeAddress()
+ wif, ok := additionalKeysByAddress[addrStr]
+ if !ok {
+ return nil, false,
+ errors.New("no key for address")
+ }
+ return wif.PrivKey, wif.CompressPubKey, nil
+ }
+
+ address, err := w.addrStore.Address(addrmgrNs, addr)
+ if err != nil {
+ return nil, false, err
+ }
+
+ pka, ok := address.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil, false, fmt.Errorf("address %v is not "+
+ "a pubkey address", address.Address().EncodeAddress())
+ }
+
+ key, err := pka.PrivKey()
+ if err != nil {
+ return nil, false, err
+ }
+
+ return key, pka.Compressed(), nil
+ })
+ //nolint:lll
+ getScript := txscript.ScriptClosure(func(addr address.Address) ([]byte, error) {
+ // If keys were provided then we can only use the
+ // redeem scripts provided with our inputs, too.
+ if len(additionalKeysByAddress) != 0 {
+ addrStr := addr.EncodeAddress()
+ script, ok := p2shRedeemScriptsByAddress[addrStr]
+ if !ok {
+ return nil, errors.New("no script for address")
+ }
+ return script, nil
+ }
+
+ address, err := w.addrStore.Address(addrmgrNs, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ // If we found the address, we check to see if it's
+ // a p2sh address, if so, then we'll verify that it
+ // is one that we know the redeem script for.
+ shAddr, ok := address.(waddrmgr.ManagedScriptAddress)
+ if !ok {
+ return nil, errors.New("address is not a " +
+ "p2sh address")
+ }
+
+ return shAddr.Script()
+ })
+
+ // SigHashSingle inputs can only be signed if there's a
+ // corresponding output. However this could be already signed,
+ // so we always verify the output.
+ if (hashType&txscript.SigHashSingle) !=
+ txscript.SigHashSingle || i < len(tx.TxOut) {
+
+ script, err := txscript.SignTxOutput(w.ChainParams(),
+ tx, i, prevOutScript, hashType, getKey,
+ getScript, txIn.SignatureScript)
+ // Failure to sign isn't an error, it just means that
+ // the tx isn't complete.
+ if err != nil {
+ signErrors = append(signErrors, SignatureError{
+ InputIndex: uint32(i),
+ Error: err,
+ })
+ continue
+ }
+ txIn.SignatureScript = script
+ }
+
+ // Either it was already signed or we just signed it.
+ // Find out if it is completely satisfied or still needs more.
+ vm, err := txscript.NewEngine(
+ prevOutScript, tx, i,
+ txscript.StandardVerifyFlags, nil, nil, 0,
+ inputFetcher,
+ )
+ if err == nil {
+ err = vm.Execute()
+ }
+ if err != nil {
+ signErrors = append(signErrors, SignatureError{
+ InputIndex: uint32(i),
+ Error: err,
+ })
+ }
+ }
+ return nil
+ })
+ return signErrors, err
+}
+
+// ErrDoubleSpend is an error returned from PublishTransaction in case the
+// published transaction failed to propagate since it was double spending a
+// confirmed transaction or a transaction in the mempool.
+type ErrDoubleSpend struct {
+ backendError error
+}
+
+// Error returns the string representation of ErrDoubleSpend.
+//
+// NOTE: Satisfies the error interface.
+func (e *ErrDoubleSpend) Error() string {
+ return fmt.Sprintf("double spend: %v", e.backendError)
+}
+
+// Unwrap returns the underlying error returned from the backend.
+func (e *ErrDoubleSpend) Unwrap() error {
+ return e.backendError
+}
+
+// ErrMempoolFee is an error returned from PublishTransaction in case the
+// published transaction failed to propagate since it did not match the
+// current mempool fee requirement.
+type ErrMempoolFee struct {
+ backendError error
+}
+
+// Error returns the string representation of ErrMempoolFee.
+//
+// NOTE: Satisfies the error interface.
+func (e *ErrMempoolFee) Error() string {
+ return fmt.Sprintf("mempool fee not met: %v", e.backendError)
+}
+
+// Unwrap returns the underlying error returned from the backend.
+func (e *ErrMempoolFee) Unwrap() error {
+ return e.backendError
+}
+
+// ErrAlreadyConfirmed is an error returned from PublishTransaction in case
+// a transaction is already confirmed in the blockchain.
+type ErrAlreadyConfirmed struct {
+ backendError error
+}
+
+// Error returns the string representation of ErrAlreadyConfirmed.
+//
+// NOTE: Satisfies the error interface.
+func (e *ErrAlreadyConfirmed) Error() string {
+ return fmt.Sprintf("tx already confirmed: %v", e.backendError)
+}
+
+// Unwrap returns the underlying error returned from the backend.
+func (e *ErrAlreadyConfirmed) Unwrap() error {
+ return e.backendError
+}
+
+// ErrInMempool is an error returned from PublishTransaction in case a
+// transaction is already in the mempool.
+type ErrInMempool struct {
+ backendError error
+}
+
+// Error returns the string representation of ErrInMempool.
+//
+// NOTE: Satisfies the error interface.
+func (e *ErrInMempool) Error() string {
+ return fmt.Sprintf("tx already in mempool: %v", e.backendError)
+}
+
+// Unwrap returns the underlying error returned from the backend.
+func (e *ErrInMempool) Unwrap() error {
+ return e.backendError
+}
+
+// PublishTransaction sends the transaction to the consensus RPC server so it
+// can be propagated to other nodes and eventually mined.
+//
+// This function is unstable and will be removed once syncing code is moved out
+// of the wallet.
+func (w *Wallet) PublishTransaction(tx *wire.MsgTx, label string) error {
+ _, err := w.reliablyPublishTransaction(tx, label)
+ return err
+}
+
+// reliablyPublishTransaction is a superset of publishTransaction which contains
+// the primary logic required for publishing a transaction, updating the
+// relevant database state, and finally possible removing the transaction from
+// the database (along with cleaning up all inputs used, and outputs created) if
+// the transaction is rejected by the backend.
+func (w *Wallet) reliablyPublishTransaction(tx *wire.MsgTx,
+ label string) (*chainhash.Hash, error) {
+
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return nil, err
+ }
+
+ // As we aim for this to be general reliable transaction broadcast API,
+ // we'll write this tx to disk as an unconfirmed transaction. This way,
+ // upon restarts, we'll always rebroadcast it, and also add it to our
+ // set of records.
+ txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ if err != nil {
+ return nil, err
+ }
+
+ // Along the way, we'll extract our relevant destination addresses from
+ // the transaction.
+ var ourAddrs []address.Address
+ err = walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
+ addrmgrNs := dbTx.ReadWriteBucket(waddrmgrNamespaceKey)
+ for _, txOut := range tx.TxOut {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ txOut.PkScript, w.chainParams,
+ )
+ if err != nil {
+ // Non-standard outputs can safely be skipped
+ // because they're not supported by the wallet.
+ log.Warnf("Non-standard pkScript=%x in tx=%v",
+ txOut.PkScript, tx.TxHash())
+
+ continue
+ }
+ for _, addr := range addrs {
+ // Skip any addresses which are not relevant to
+ // us.
+ _, err := w.addrStore.Address(addrmgrNs, addr)
+ if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
+ continue
+ }
+ if err != nil {
+ return err
+ }
+ ourAddrs = append(ourAddrs, addr)
+ }
+ }
+
+ // If there is a label we should write, get the namespace key
+ // and record it in the tx store.
+ if len(label) != 0 {
+ txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ err = w.txStore.PutTxLabel(txmgrNs, tx.TxHash(), label)
+ if err != nil {
+ return err
+ }
+ }
+
+ return w.addRelevantTx(dbTx, txRec, nil)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // We'll also ask to be notified of the transaction once it confirms
+ // on-chain. This is done outside of the database transaction to prevent
+ // backend interaction within it.
+ if err := chainClient.NotifyReceived(ourAddrs); err != nil {
+ return nil, err
+ }
+
+ return w.publishTransaction(tx)
+}
+
+// publishTransaction attempts to send an unconfirmed transaction to the
+// wallet's current backend. In the event that sending the transaction fails for
+// whatever reason, it will be removed from the wallet's unconfirmed transaction
+// store.
+func (w *Wallet) publishTransaction(tx *wire.MsgTx) (*chainhash.Hash, error) {
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return nil, err
+ }
+
+ txid := tx.TxHash()
+ _, rpcErr := chainClient.SendRawTransaction(tx, false)
+ if rpcErr == nil {
+ return &txid, nil
+ }
+
+ switch {
+ case errors.Is(rpcErr, chain.ErrTxAlreadyInMempool):
+ log.Infof("%v: tx already in mempool", txid)
+ return &txid, nil
+
+ case errors.Is(rpcErr, chain.ErrTxAlreadyKnown),
+ errors.Is(rpcErr, chain.ErrTxAlreadyConfirmed):
+
+ dbErr := walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
+ txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
+ txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ if err != nil {
+ return err
+ }
+
+ return w.txStore.RemoveUnminedTx(txmgrNs, txRec)
+ })
+ if dbErr != nil {
+ log.Warnf("Unable to remove confirmed transaction %v "+
+ "from unconfirmed store: %v", tx.TxHash(), dbErr)
+ }
+
+ log.Infof("%v: tx already confirmed", txid)
+
+ return &txid, nil
+
+ }
+
+ // Log the causing error, even if we know how to handle it.
+ log.Infof("%v: broadcast failed because of: %v", txid, rpcErr)
+
+ // If the transaction was rejected for whatever other reason, then
+ // we'll remove it from the transaction store, as otherwise, we'll
+ // attempt to continually re-broadcast it, and the UTXO state of the
+ // wallet won't be accurate.
+ dbErr := walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
+ txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
+ txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ if err != nil {
+ return err
+ }
+
+ return w.txStore.RemoveUnminedTx(txmgrNs, txRec)
+ })
+ if dbErr != nil {
+ log.Warnf("Unable to remove invalid transaction %v: %v",
+ tx.TxHash(), dbErr)
+ } else {
+ log.Infof("Removed invalid transaction: %v", tx.TxHash())
+
+ // The serialized transaction is for logging only, don't fail
+ // on the error.
+ var txRaw bytes.Buffer
+ _ = tx.Serialize(&txRaw)
+
+ // Optionally log the tx in debug when the size is manageable.
+ if txRaw.Len() < 1_000_000 {
+ log.Debugf("Removed invalid transaction: %v \n hex=%x",
+ newLogClosure(func() string {
+ return spew.Sdump(tx)
+ }), txRaw.Bytes())
+ } else {
+ log.Debug("Removed invalid transaction due to size " +
+ "too large")
+ }
+ }
+
+ return nil, rpcErr
+}
+
+// CreateDeprecated creates an new wallet, writing it to an empty database.
+// If the passed root key is non-nil, it is used. Otherwise, a secure
+// random seed of the recommended length is generated.
+//
+// Deprecated: Use wallet.Create instead.
+func CreateDeprecated(db walletdb.DB, pubPass, privPass []byte,
+ rootKey *hdkeychain.ExtendedKey, params *chaincfg.Params,
+ birthday time.Time) error {
+
+ return create(
+ db, pubPass, privPass, rootKey, params, birthday, false, nil,
+ )
+}
+
+// AccountNumber returns the account number for an account name under a
+// particular key scope.
+func (w *Wallet) AccountNumber(scope waddrmgr.KeyScope, accountName string) (uint32, error) {
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return 0, err
+ }
+
+ var account uint32
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+ account, err = manager.LookupAccount(addrmgrNs, accountName)
+ return err
+ })
+ return account, err
+}
+
+// AccountName returns the name of an account.
+func (w *Wallet) AccountName(scope waddrmgr.KeyScope, accountNumber uint32) (string, error) {
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return "", err
+ }
+
+ var accountName string
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+ accountName, err = manager.AccountName(addrmgrNs, accountNumber)
+ return err
+ })
+ return accountName, err
+}
+
+// AccountProperties returns the properties of an account, including address
+// indexes and name. It first fetches the desynced information from the address
+// manager, then updates the indexes based on the address pools.
+func (w *Wallet) AccountProperties(scope waddrmgr.KeyScope, acct uint32) (*waddrmgr.AccountProperties, error) {
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var props *waddrmgr.AccountProperties
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ waddrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+ props, err = manager.AccountProperties(waddrmgrNs, acct)
+ return err
+ })
+ return props, err
+}
+
+// AccountPropertiesByName returns the properties of an account by its name. It
+// first fetches the desynced information from the address manager, then updates
+// the indexes based on the address pools.
+func (w *Wallet) AccountPropertiesByName(scope waddrmgr.KeyScope,
+ name string) (*waddrmgr.AccountProperties, error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var props *waddrmgr.AccountProperties
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ waddrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ acct, err := manager.LookupAccount(waddrmgrNs, name)
+ if err != nil {
+ return err
+ }
+ props, err = manager.AccountProperties(waddrmgrNs, acct)
+ return err
+ })
+ return props, err
+}
+
+// LookupAccount returns the corresponding key scope and account number for the
+// account with the given name.
+func (w *Wallet) LookupAccount(name string) (waddrmgr.KeyScope, uint32, error) {
+ var (
+ keyScope waddrmgr.KeyScope
+ account uint32
+ )
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ var err error
+
+ keyScope, account, err = w.addrStore.LookupAccount(ns, name)
+ return err
+ })
+
+ return keyScope, account, err
+}
+
+// CreditCategory describes the type of wallet transaction output. The category
+// of "sent transactions" (debits) is always "send", and is not expressed by
+// this type.
+//
+// TODO: This is a requirement of the RPC server and should be moved.
+type CreditCategory byte
+
+// These constants define the possible credit categories.
+const (
+ CreditReceive CreditCategory = iota
+ CreditGenerate
+ CreditImmature
+)
+
+// String returns the category as a string. This string may be used as the
+// JSON string for categories as part of listtransactions and gettransaction
+// RPC responses.
+func (c CreditCategory) String() string {
+ switch c {
+ case CreditReceive:
+ return "receive"
+ case CreditGenerate:
+ return "generate"
+ case CreditImmature:
+ return "immature"
+ default:
+ return "unknown"
+ }
+}
+
+// RecvCategory returns the category of received credit outputs from a
+// transaction record. The passed block chain height is used to distinguish
+// immature from mature coinbase outputs.
+//
+// TODO: This is intended for use by the RPC server and should be moved out of
+// this package at a later time.
+func RecvCategory(details *wtxmgr.TxDetails, syncHeight int32, net *chaincfg.Params) CreditCategory {
+ if blockchain.IsCoinBaseTx(&details.MsgTx) {
+ if hasMinConfs(uint32(net.CoinbaseMaturity),
+ details.Block.Height, syncHeight) {
+
+ return CreditGenerate
+ }
+ return CreditImmature
+ }
+ return CreditReceive
+}
+
+// listTransactions creates a object that may be marshalled to a response result
+// for a listtransactions RPC.
+//
+// TODO: This should be moved to the legacyrpc package.
+func listTransactions(tx walletdb.ReadTx, details *wtxmgr.TxDetails,
+ addrMgr waddrmgr.AddrStore, syncHeight int32,
+ net *chaincfg.Params) []btcjson.ListTransactionsResult {
+
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ var (
+ blockHashStr string
+ blockTime int64
+ confirmations int64
+ )
+ if details.Block.Height != -1 {
+ blockHashStr = details.Block.Hash.String()
+ blockTime = details.Block.Time.Unix()
+ confirmations = int64(
+ calcConf(details.Block.Height, syncHeight),
+ )
+ }
+
+ results := []btcjson.ListTransactionsResult{}
+ txHashStr := details.Hash.String()
+ received := details.Received.Unix()
+ generated := blockchain.IsCoinBaseTx(&details.MsgTx)
+ recvCat := RecvCategory(details, syncHeight, net).String()
+
+ send := len(details.Debits) != 0
+
+ // Fee can only be determined if every input is a debit.
+ var feeF64 float64
+ if len(details.Debits) == len(details.MsgTx.TxIn) {
+ var debitTotal btcutil.Amount
+ for _, deb := range details.Debits {
+ debitTotal += deb.Amount
+ }
+ var outputTotal btcutil.Amount
+ for _, output := range details.MsgTx.TxOut {
+ outputTotal += btcutil.Amount(output.Value)
+ }
+ // Note: The actual fee is debitTotal - outputTotal. However,
+ // this RPC reports negative numbers for fees, so the inverse
+ // is calculated.
+ feeF64 = (outputTotal - debitTotal).ToBTC()
+ }
+
+outputs:
+ for i, output := range details.MsgTx.TxOut {
+ // Determine if this output is a credit, and if so, determine
+ // its spentness.
+ var isCredit bool
+ var spentCredit bool
+ for _, cred := range details.Credits {
+ if cred.Index == uint32(i) {
+ // Change outputs are ignored.
+ if cred.Change {
+ continue outputs
+ }
+
+ isCredit = true
+ spentCredit = cred.Spent
+ break
+ }
+ }
+
+ var address string
+ var accountName string
+ _, addrs, _, _ := txscript.ExtractPkScriptAddrs(output.PkScript, net)
+ if len(addrs) == 1 {
+ addr := addrs[0]
+ address = addr.EncodeAddress()
+ mgr, account, err := addrMgr.AddrAccount(addrmgrNs, addrs[0])
+ if err == nil {
+ accountName, err = mgr.AccountName(addrmgrNs, account)
+ if err != nil {
+ accountName = ""
+ }
+ }
+ }
+
+ amountF64 := btcutil.Amount(output.Value).ToBTC()
+ result := btcjson.ListTransactionsResult{
+ // Fields left zeroed:
+ // InvolvesWatchOnly
+ // BlockIndex
+ //
+ // Fields set below:
+ // Account (only for non-"send" categories)
+ // Category
+ // Amount
+ // Fee
+ Address: address,
+ Vout: uint32(i),
+ Confirmations: confirmations,
+ Generated: generated,
+ BlockHash: blockHashStr,
+ BlockTime: blockTime,
+ TxID: txHashStr,
+ WalletConflicts: []string{},
+ Time: received,
+ TimeReceived: received,
+ }
+
+ // Add a received/generated/immature result if this is a credit.
+ // If the output was spent, create a second result under the
+ // send category with the inverse of the output amount. It is
+ // therefore possible that a single output may be included in
+ // the results set zero, one, or two times.
+ //
+ // Since credits are not saved for outputs that are not
+ // controlled by this wallet, all non-credits from transactions
+ // with debits are grouped under the send category.
+
+ if send || spentCredit {
+ result.Category = "send"
+ result.Amount = -amountF64
+ result.Fee = &feeF64
+ results = append(results, result)
+ }
+ if isCredit {
+ result.Account = accountName
+ result.Category = recvCat
+ result.Amount = amountF64
+ result.Fee = nil
+ results = append(results, result)
+ }
+ }
+ return results
+}
+
+// ListSinceBlock returns a slice of objects with details about transactions
+// since the given block. If the block is -1 then all transactions are included.
+// This is intended to be used for listsinceblock RPC replies.
+func (w *Wallet) ListSinceBlock(start, end, syncHeight int32) ([]btcjson.ListTransactionsResult, error) {
+ txList := []btcjson.ListTransactionsResult{}
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ for _, detail := range details {
+ detail := detail
+
+ jsonResults := listTransactions(
+ tx, &detail, w.addrStore, syncHeight,
+ w.chainParams,
+ )
+ txList = append(txList, jsonResults...)
+ }
+ return false, nil
+ }
+
+ return w.txStore.RangeTransactions(txmgrNs, start, end, rangeFn)
+ })
+ return txList, err
+}
+
+// ListTransactions returns a slice of objects with details about a recorded
+// transaction. This is intended to be used for listtransactions RPC
+// replies.
+func (w *Wallet) ListTransactions(from, count int) ([]btcjson.ListTransactionsResult, error) {
+ txList := []btcjson.ListTransactionsResult{}
+
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Get current block. The block height used for calculating
+ // the number of tx confirmations.
+ syncBlock := w.addrStore.SyncedTo()
+
+ // Need to skip the first from transactions, and after those, only
+ // include the next count transactions.
+ skipped := 0
+ n := 0
+
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ // Iterate over transactions at this height in reverse order.
+ // This does nothing for unmined transactions, which are
+ // unsorted, but it will process mined transactions in the
+ // reverse order they were marked mined.
+ for i := len(details) - 1; i >= 0; i-- {
+ if from > skipped {
+ skipped++
+ continue
+ }
+
+ n++
+ if n > count {
+ return true, nil
+ }
+
+ jsonResults := listTransactions(
+ tx, &details[i], w.addrStore,
+ syncBlock.Height, w.chainParams,
+ )
+ txList = append(txList, jsonResults...)
+
+ if len(jsonResults) > 0 {
+ n++
+ }
+ }
+
+ return false, nil
+ }
+
+ // Return newer results first by starting at mempool height and working
+ // down to the genesis block.
+ return w.txStore.RangeTransactions(txmgrNs, -1, 0, rangeFn)
+ })
+ return txList, err
+}
+
+// ListAddressTransactions returns a slice of objects with details about
+// recorded transactions to or from any address belonging to a set. This is
+// intended to be used for listaddresstransactions RPC replies.
+func (w *Wallet) ListAddressTransactions(pkHashes map[string]struct{}) ([]btcjson.ListTransactionsResult, error) {
+ txList := []btcjson.ListTransactionsResult{}
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Get current block. The block height used for calculating
+ // the number of tx confirmations.
+ syncBlock := w.addrStore.SyncedTo()
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ loopDetails:
+ for i := range details {
+ detail := &details[i]
+
+ for _, cred := range detail.Credits {
+ pkScript := detail.MsgTx.TxOut[cred.Index].PkScript
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ pkScript, w.chainParams)
+ if err != nil || len(addrs) != 1 {
+ continue
+ }
+ apkh, ok := addrs[0].(*address.AddressPubKeyHash)
+ if !ok {
+ continue
+ }
+ _, ok = pkHashes[string(apkh.ScriptAddress())]
+ if !ok {
+ continue
+ }
+
+ jsonResults := listTransactions(
+ tx, detail, w.addrStore,
+ syncBlock.Height, w.chainParams,
+ )
+ txList = append(txList, jsonResults...)
+ continue loopDetails
+ }
+ }
+ return false, nil
+ }
+
+ return w.txStore.RangeTransactions(txmgrNs, 0, -1, rangeFn)
+ })
+ return txList, err
+}
+
+// ListAllTransactions returns a slice of objects with details about a recorded
+// transaction. This is intended to be used for listalltransactions RPC
+// replies.
+func (w *Wallet) ListAllTransactions() ([]btcjson.ListTransactionsResult, error) {
+ txList := []btcjson.ListTransactionsResult{}
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Get current block. The block height used for calculating
+ // the number of tx confirmations.
+ syncBlock := w.addrStore.SyncedTo()
+
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ // Iterate over transactions at this height in reverse order.
+ // This does nothing for unmined transactions, which are
+ // unsorted, but it will process mined transactions in the
+ // reverse order they were marked mined.
+ for i := len(details) - 1; i >= 0; i-- {
+ jsonResults := listTransactions(
+ tx, &details[i], w.addrStore,
+ syncBlock.Height, w.chainParams,
+ )
+ txList = append(txList, jsonResults...)
+ }
+ return false, nil
+ }
+
+ // Return newer results first by starting at mempool height and
+ // working down to the genesis block.
+ return w.txStore.RangeTransactions(txmgrNs, -1, 0, rangeFn)
+ })
+ return txList, err
+}
+
+// GetTransactions returns transaction results between a starting and ending
+// block. Blocks in the block range may be specified by either a height or a
+// hash.
+//
+// Because this is a possibly lenghtly operation, a cancel channel is provided
+// to cancel the task. If this channel unblocks, the results created thus far
+// will be returned.
+//
+// Transaction results are organized by blocks in ascending order and unmined
+// transactions in an unspecified order. Mined transactions are saved in a
+// Block structure which records properties about the block.
+func (w *Wallet) GetTransactions(startBlock, endBlock *BlockIdentifier,
+ accountName string, cancel <-chan struct{}) (*GetTransactionsResult, error) {
+
+ var start, end int32 = 0, -1
+
+ w.chainClientLock.Lock()
+ chainClient := w.chainClient
+ w.chainClientLock.Unlock()
+
+ // TODO: Fetching block heights by their hashes is inherently racy
+ // because not all block headers are saved but when they are for SPV the
+ // db can be queried directly without this.
+ if startBlock != nil {
+ if startBlock.hash == nil {
+ start = startBlock.height
+ } else {
+ if chainClient == nil {
+ return nil, errors.New("no chain server client")
+ }
+ switch client := chainClient.(type) {
+ case *chain.RPCClient:
+ startHeader, err := client.GetBlockHeaderVerbose(
+ startBlock.hash,
+ )
+ if err != nil {
+ return nil, err
+ }
+ start = startHeader.Height
+ case *chain.BitcoindClient:
+ var err error
+ start, err = client.GetBlockHeight(startBlock.hash)
+ if err != nil {
+ return nil, err
+ }
+ case *chain.NeutrinoClient:
+ var err error
+ start, err = client.GetBlockHeight(startBlock.hash)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+ }
+ if endBlock != nil {
+ if endBlock.hash == nil {
+ end = endBlock.height
+ } else {
+ if chainClient == nil {
+ return nil, errors.New("no chain server client")
+ }
+ switch client := chainClient.(type) {
+ case *chain.RPCClient:
+ endHeader, err := client.GetBlockHeaderVerbose(
+ endBlock.hash,
+ )
+ if err != nil {
+ return nil, err
+ }
+ end = endHeader.Height
+ case *chain.BitcoindClient:
+ var err error
+ start, err = client.GetBlockHeight(endBlock.hash)
+ if err != nil {
+ return nil, err
+ }
+ case *chain.NeutrinoClient:
+ var err error
+ end, err = client.GetBlockHeight(endBlock.hash)
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+ }
+
+ var res GetTransactionsResult
+ err := walletdb.View(w.db, func(dbtx walletdb.ReadTx) error {
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
+ // TODO: probably should make RangeTransactions not reuse the
+ // details backing array memory.
+ dets := make([]wtxmgr.TxDetails, len(details))
+ copy(dets, details)
+ details = dets
+
+ txs := make([]TransactionSummary, 0, len(details))
+ for i := range details {
+ txs = append(txs, makeTxSummary(dbtx, w, &details[i]))
+ }
+
+ if details[0].Block.Height != -1 {
+ blockHash := details[0].Block.Hash
+ res.MinedTransactions = append(res.MinedTransactions, Block{
+ Hash: &blockHash,
+ Height: details[0].Block.Height,
+ Timestamp: details[0].Block.Time.Unix(),
+ Transactions: txs,
+ })
+ } else {
+ res.UnminedTransactions = txs
+ }
+
+ select {
+ case <-cancel:
+ return true, nil
+ default:
+ return false, nil
+ }
+ }
+
+ return w.txStore.RangeTransactions(txmgrNs, start, end, rangeFn)
+ })
+ return &res, err
+}
+
+// GetTransaction returns detailed data of a transaction given its id. In
+// addition it returns properties about its block.
+func (w *Wallet) GetTransaction(txHash chainhash.Hash) (*GetTransactionResult,
+ error) {
+
+ var res GetTransactionResult
+ err := walletdb.View(w.db, func(dbtx walletdb.ReadTx) error {
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ txDetail, err := w.txStore.TxDetails(txmgrNs, &txHash)
+ if err != nil {
+ return err
+ }
+
+ // If the transaction was not found we return an error.
+ if txDetail == nil {
+ return fmt.Errorf("%w: txid %v", ErrNoTx, txHash)
+ }
+
+ res = GetTransactionResult{
+ Summary: makeTxSummary(dbtx, w, txDetail),
+ BlockHash: nil,
+ Height: -1,
+ Confirmations: 0,
+ Timestamp: 0,
+ }
+
+ // If it is a confirmed transaction we set the corresponding
+ // block height, timestamp, hash, and confirmations.
+ if txDetail.Block.Height != -1 {
+ res.Height = txDetail.Block.Height
+ res.Timestamp = txDetail.Block.Time.Unix()
+ res.BlockHash = &txDetail.Block.Hash
+
+ bestBlock := w.SyncedTo()
+ blockHeight := txDetail.Block.Height
+ res.Confirmations = calcConf(
+ blockHeight, bestBlock.Height,
+ )
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &res, nil
+}
+
+// AccountBalances returns all accounts in the wallet and their balances.
+// Balances are determined by excluding transactions that have not met
+// requiredConfs confirmations.
+func (w *Wallet) AccountBalances(scope waddrmgr.KeyScope,
+ requiredConfs int32) ([]AccountBalanceResult, error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ var results []AccountBalanceResult
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ syncBlock := w.addrStore.SyncedTo()
+
+ // Fill out all account info except for the balances.
+ lastAcct, err := manager.LastAccount(addrmgrNs)
+ if err != nil {
+ return err
+ }
+ results = make([]AccountBalanceResult, lastAcct+2)
+ for i := range results[:len(results)-1] {
+ accountName, err := manager.AccountName(addrmgrNs, uint32(i))
+ if err != nil {
+ return err
+ }
+ results[i].AccountNumber = uint32(i)
+ results[i].AccountName = accountName
+ }
+ results[len(results)-1].AccountNumber = waddrmgr.ImportedAddrAccount
+ results[len(results)-1].AccountName = waddrmgr.ImportedAddrAccountName
+
+ // Fetch all unspent outputs, and iterate over them tallying each
+ // account's balance where the output script pays to an account address
+ // and the required number of confirmations is met.
+ unspentOutputs, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+ for i := range unspentOutputs {
+ output := &unspentOutputs[i]
+ if !hasMinConfs(
+ //nolint:gosec
+ uint32(requiredConfs), output.Height,
+ syncBlock.Height,
+ ) {
+
+ continue
+ }
+
+ if output.FromCoinBase && !hasMinConfs(
+ uint32(w.ChainParams().CoinbaseMaturity),
+ output.Height, syncBlock.Height,
+ ) {
+
+ continue
+ }
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript, w.chainParams)
+ if err != nil || len(addrs) == 0 {
+ continue
+ }
+ outputAcct, err := manager.AddrAccount(addrmgrNs, addrs[0])
+ if err != nil {
+ continue
+ }
+ switch {
+ case outputAcct == waddrmgr.ImportedAddrAccount:
+ results[len(results)-1].AccountBalance += output.Amount
+ case outputAcct > lastAcct:
+ return errors.New("waddrmgr.Manager.AddrAccount returned account " +
+ "beyond recorded last account")
+ default:
+ results[outputAcct].AccountBalance += output.Amount
+ }
+ }
+ return nil
+ })
+ return results, err
+}
+
+// ListUnspentDeprecated returns a slice of objects representing the
+// unspent wallet transactions fitting the given criteria. The confirmations
+// will be more than
+// minconf, less than maxconf and if addresses is populated only the addresses
+// contained within it will be considered. If we know nothing about a
+// transaction an empty array will be returned.
+//
+// Deprecated: Use UtxoManager.ListUnspent instead.
+//
+//nolint:funlen
+func (w *Wallet) ListUnspentDeprecated(minconf, maxconf int32,
+ accountName string) ([]*btcjson.ListUnspentResult, error) {
+
+ var results []*btcjson.ListUnspentResult
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ syncBlock := w.addrStore.SyncedTo()
+
+ filter := accountName != ""
+
+ unspent, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+ sort.Sort(sort.Reverse(creditSlice(unspent)))
+
+ defaultAccountName := "default"
+
+ results = make([]*btcjson.ListUnspentResult, 0, len(unspent))
+ for i := range unspent {
+ output := unspent[i]
+
+ // Outputs with fewer confirmations than the minimum or
+ // more confs than the maximum are excluded.
+ confs := calcConf(output.Height, syncBlock.Height)
+ if confs < minconf || confs > maxconf {
+ continue
+ }
+
+ // Only mature coinbase outputs are included.
+ if output.FromCoinBase {
+ target := uint32(
+ w.ChainParams().CoinbaseMaturity,
+ )
+ if !hasMinConfs(
+ target, output.Height, syncBlock.Height,
+ ) {
+
+ continue
+ }
+ }
+
+ // Exclude locked outputs from the result set.
+ if w.LockedOutpoint(output.OutPoint) {
+ continue
+ }
+
+ // Lookup the associated account for the output. Use the
+ // default account name in case there is no associated account
+ // for some reason, although this should never happen.
+ //
+ // This will be unnecessary once transactions and outputs are
+ // grouped under the associated account in the db.
+ outputAcctName := defaultAccountName
+ sc, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, w.chainParams)
+ if err != nil {
+ continue
+ }
+ if len(addrs) > 0 {
+ smgr, acct, err := w.addrStore.AddrAccount(
+ addrmgrNs, addrs[0],
+ )
+ if err == nil {
+ s, err := smgr.AccountName(addrmgrNs, acct)
+ if err == nil {
+ outputAcctName = s
+ }
+ }
+ }
+
+ if filter && outputAcctName != accountName {
+ continue
+ }
+
+ // At the moment watch-only addresses are not supported, so all
+ // recorded outputs that are not multisig are "spendable".
+ // Multisig outputs are only "spendable" if all keys are
+ // controlled by this wallet.
+ //
+ // TODO: Each case will need updates when watch-only addrs
+ // is added. For P2PK, P2PKH, and P2SH, the address must be
+ // looked up and not be watching-only. For multisig, all
+ // pubkeys must belong to the manager with the associated
+ // private key (currently it only checks whether the pubkey
+ // exists, since the private key is required at the moment).
+ var spendable bool
+ scSwitch:
+ switch sc {
+ case txscript.PubKeyHashTy:
+ spendable = true
+ case txscript.PubKeyTy:
+ spendable = true
+ case txscript.WitnessV0ScriptHashTy:
+ spendable = true
+ case txscript.WitnessV0PubKeyHashTy:
+ spendable = true
+ case txscript.MultiSigTy:
+ for _, a := range addrs {
+ _, err := w.addrStore.Address(
+ addrmgrNs, a,
+ )
+ if err == nil {
+ continue
+ }
+ if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
+ break scSwitch
+ }
+ return err
+ }
+ spendable = true
+ }
+
+ result := &btcjson.ListUnspentResult{
+ TxID: output.OutPoint.Hash.String(),
+ Vout: output.OutPoint.Index,
+ Account: outputAcctName,
+ ScriptPubKey: hex.EncodeToString(output.PkScript),
+ Amount: output.Amount.ToBTC(),
+ Confirmations: int64(confs),
+ Spendable: spendable,
+ }
+
+ // BUG: this should be a JSON array so that all
+ // addresses can be included, or removed (and the
+ // caller extracts addresses from the pkScript).
+ if len(addrs) > 0 {
+ result.Address = addrs[0].EncodeAddress()
+ }
+
+ results = append(results, result)
+ }
+ return nil
+ })
+ return results, err
+}
+
+// ListLeasedOutputsDeprecated returns a list of objects representing the
+// currently locked utxos.
+//
+// Deprecated: Use UtxoManager.ListLeasedOutputs instead.
+func (w *Wallet) ListLeasedOutputsDeprecated() (
+ []*ListLeasedOutputResult, error) {
+
+ var results []*ListLeasedOutputResult
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ outputs, err := w.txStore.ListLockedOutputs(ns)
+ if err != nil {
+ return err
+ }
+
+ for _, output := range outputs {
+ details, err := w.txStore.TxDetails(
+ ns, &output.Outpoint.Hash,
+ )
+ if err != nil {
+ return err
+ }
+
+ if details == nil {
+ log.Infof("unable to find tx details for "+
+ "%v:%v", output.Outpoint.Hash,
+ output.Outpoint.Index)
+ continue
+ }
+
+ txOut := details.MsgTx.TxOut[output.Outpoint.Index]
+
+ result := &ListLeasedOutputResult{
+ LockedOutput: output,
+ Value: txOut.Value,
+ PkScript: txOut.PkScript,
+ }
+
+ results = append(results, result)
+ }
+
+ return nil
+ })
+ return results, err
+}
+
+// BlockIdentifier identifies a block by either a height or a hash.
+type BlockIdentifier struct {
+ height int32
+ hash *chainhash.Hash
+}
+
+// NewBlockIdentifierFromHeight constructs a BlockIdentifier for a block height.
+func NewBlockIdentifierFromHeight(height int32) *BlockIdentifier {
+ return &BlockIdentifier{height: height}
+}
+
+// NewBlockIdentifierFromHash constructs a BlockIdentifier for a block hash.
+func NewBlockIdentifierFromHash(hash *chainhash.Hash) *BlockIdentifier {
+ return &BlockIdentifier{hash: hash}
+}
+
+// GetTransactionsResult is the result of the wallet's GetTransactions method.
+// See GetTransactions for more details.
+type GetTransactionsResult struct {
+ MinedTransactions []Block
+ UnminedTransactions []TransactionSummary
+}
+
+// GetTransactionResult returns a summary of the transaction along with
+// other block properties.
+type GetTransactionResult struct {
+ Summary TransactionSummary
+ Height int32
+ BlockHash *chainhash.Hash
+ Confirmations int32
+ Timestamp int64
+}
+
+// AccountBalanceResult is a single result for the Wallet.AccountBalances method.
+type AccountBalanceResult struct {
+ AccountNumber uint32
+ AccountName string
+ AccountBalance btcutil.Amount
+}
+
+// creditSlice satisifies the sort.Interface interface to provide sorting
+// transaction credits from oldest to newest. Credits with the same receive
+// time and mined in the same block are not guaranteed to be sorted by the order
+// they appear in the block. Credits from the same transaction are sorted by
+// output index.
+type creditSlice []wtxmgr.Credit
+
+func (s creditSlice) Len() int {
+ return len(s)
+}
+
+func (s creditSlice) Less(i, j int) bool {
+ switch {
+ // If both credits are from the same tx, sort by output index.
+ case s[i].OutPoint.Hash == s[j].OutPoint.Hash:
+ return s[i].OutPoint.Index < s[j].OutPoint.Index
+
+ // If both transactions are unmined, sort by their received date.
+ case s[i].Height == -1 && s[j].Height == -1:
+ return s[i].Received.Before(s[j].Received)
+
+ // Unmined (newer) txs always come last.
+ case s[i].Height == -1:
+ return false
+ case s[j].Height == -1:
+ return true
+
+ // If both txs are mined in different blocks, sort by block height.
+ default:
+ return s[i].Height < s[j].Height
+ }
+}
+
+func (s creditSlice) Swap(i, j int) {
+ s[i], s[j] = s[j], s[i]
+}
+
+// ListLeasedOutputResult is a single result for the Wallet.ListLeasedOutputs method.
+// See that method for more details.
+type ListLeasedOutputResult struct {
+ *wtxmgr.LockedOutput
+ Value int64
+ PkScript []byte
+}
+
+func (w *Wallet) newAddressDeprecated(addrmgrNs walletdb.ReadWriteBucket,
+ account uint32, scope waddrmgr.KeyScope) (address.Address,
+ *waddrmgr.AccountProperties, error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Get next address from wallet.
+ addrs, err := manager.NextExternalAddresses(addrmgrNs, account, 1)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ props, err := manager.AccountProperties(addrmgrNs, account)
+ if err != nil {
+ log.Errorf("Cannot fetch account properties for notification "+
+ "after deriving next external address: %v", err)
+
+ return nil, nil, err
+ }
+
+ return addrs[0].Address(), props, nil
+}
+
+// AddScopeManager creates a new scoped key manager from the root manager.
+func (w *Wallet) AddScopeManager(scope waddrmgr.KeyScope,
+ addrSchema waddrmgr.ScopeAddrSchema) (
+ waddrmgr.AccountStore, error) {
+
+ var scopedManager waddrmgr.AccountStore
+
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ manager, err := w.addrStore.NewScopedKeyManager(
+ addrmgrNs, scope, addrSchema,
+ )
+ scopedManager = manager
+
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return scopedManager, nil
+}
+
+// InitAccounts creates a number of accounts specified by `num`, with account
+// number ranges from 1 to `num`.
+func (w *Wallet) InitAccounts(scope *waddrmgr.ScopedKeyManager,
+ watchOnly bool, num uint32) error {
+
+ return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // Generate all accounts that we could ever need. This includes
+ // all key families.
+ for account := uint32(1); account <= num; account++ {
+ // Otherwise, we'll check if the account already exists,
+ // if so, we can once again bail early.
+ _, err := scope.AccountName(addrmgrNs, account)
+ if err == nil {
+ continue
+ }
+
+ // If we reach this point, then the account hasn't yet
+ // been created, so we'll need to create it before we
+ // can proceed.
+ err = scope.NewRawAccount(addrmgrNs, account)
+ if err != nil {
+ return err
+ }
+ }
+
+ // If this is the first startup with remote signing and wallet
+ // migration turned on and the wallet wasn't previously
+ // migrated, we can do that now that we made sure all accounts
+ // that we need were derived correctly.
+ if watchOnly {
+ log.Infof("Migrating wallet to watch-only mode, " +
+ "purging all private key material")
+
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ return w.addrStore.ConvertToWatchingOnly(ns)
+ }
+
+ return nil
+ })
+}
+
+// DeriveFromKeyPath derives a private key using the given derivation path.
+func (w *Wallet) DeriveFromKeyPath(scope waddrmgr.KeyScope,
+ path waddrmgr.DerivationPath) (*btcec.PrivateKey, error) {
+
+ scopedMgr, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, fmt.Errorf("error fetching manager for scope %v: "+
+ "%w", scope, err)
+ }
+
+ // Let's see if we can hit the private key cache.
+ privKey, err := scopedMgr.DeriveFromKeyPathCache(path)
+ if err == nil {
+ return privKey, nil
+ }
+
+ // The key wasn't in the cache, let's fully derive it now.
+ err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ addr, err := scopedMgr.DeriveFromKeyPath(addrmgrNs, path)
+ if err != nil {
+ return fmt.Errorf("error deriving private key: %w", err)
+ }
+
+ mpka, ok := addr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ err := fmt.Errorf("managed address type for %v is "+
+ "`%T` but want waddrmgr.ManagedPubKeyAddress",
+ addr, addr)
+
+ return err
+ }
+
+ privKey, err = mpka.PrivKey()
+
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return privKey, nil
+}
+
+// DeriveFromKeyPathAddAccount derives a private key using the given derivation
+// path. The account will be created if it doesn't exist.
+func (w *Wallet) DeriveFromKeyPathAddAccount(scope waddrmgr.KeyScope,
+ path waddrmgr.DerivationPath) (*btcec.PrivateKey, error) {
+
+ scopedMgr, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, fmt.Errorf("error fetching manager for scope %v: "+
+ "%w", scope, err)
+ }
+
+ // Let's see if we can hit the private key cache.
+ privKey, err := scopedMgr.DeriveFromKeyPathCache(path)
+ if err == nil {
+ return privKey, nil
+ }
+
+ derivePrivKey := func(addrmgrNs walletdb.ReadWriteBucket) error {
+ addr, err := scopedMgr.DeriveFromKeyPath(addrmgrNs, path)
+
+ // Exit early if there's no error.
+ if err == nil {
+ key, ok := addr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil
+ }
+
+ // Overwrite the returned private key variable.
+ privKey, err = key.PrivKey()
+
+ return err
+ }
+
+ return err
+ }
+
+ // The key wasn't in the cache, let's fully derive it now.
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ err := derivePrivKey(addrmgrNs)
+
+ // Exit early if there's no error.
+ if err == nil {
+ return nil
+ }
+
+ // Exit with the error if it's not account not found.
+ if !waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound) {
+ return fmt.Errorf("error deriving private key: %w", err)
+ }
+
+ // If we've reached this point, then the account doesn't yet
+ // exist, so we'll create it now to ensure we can sign.
+ err = scopedMgr.NewRawAccount(addrmgrNs, path.Account)
+ if err != nil {
+ return err
+ }
+
+ // Now that we know the account exists, we'll attempt to
+ // re-derive the private key.
+ return derivePrivKey(addrmgrNs)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return privKey, nil
+}
+
+func (w *Wallet) handleChainNotifications() {
+ defer w.wg.Done()
+
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ log.Errorf("handleChainNotifications called without RPC client")
+ return
+ }
+
+ catchUpHashes := func(w *Wallet, client chain.Interface,
+ height int32) error {
+ // TODO(aakselrod): There's a race condition here, which
+ // happens when a reorg occurs between the
+ // rescanProgress notification and the last GetBlockHash
+ // call. The solution when using btcd is to make btcd
+ // send blockconnected notifications with each block
+ // the way Neutrino does, and get rid of the loop. The
+ // other alternative is to check the final hash and,
+ // if it doesn't match the original hash returned by
+ // the notification, to roll back and restart the
+ // rescan.
+ log.Infof("Catching up block hashes to height %d, this"+
+ " might take a while", height)
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ startBlock := w.addrStore.SyncedTo()
+
+ for i := startBlock.Height + 1; i <= height; i++ {
+ hash, err := client.GetBlockHash(int64(i))
+ if err != nil {
+ return err
+ }
+ header, err := chainClient.GetBlockHeader(hash)
+ if err != nil {
+ return err
+ }
+
+ bs := waddrmgr.BlockStamp{
+ Height: i,
+ Hash: *hash,
+ Timestamp: header.Timestamp,
+ }
+
+ err = w.addrStore.SetSyncedTo(ns, &bs)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ log.Errorf("Failed to update address manager "+
+ "sync state for height %d: %v", height, err)
+ }
+
+ log.Info("Done catching up block hashes")
+ return err
+ }
+
+ waitForSync := func(birthdayBlock *waddrmgr.BlockStamp) error {
+ // We start with a retry delay of 0 to execute the first attempt
+ // immediately.
+ var retryDelay time.Duration
+ for {
+ select {
+ case <-time.After(retryDelay):
+ // Set the delay to the configured value in case
+ // we actually need to re-try.
+ retryDelay = w.syncRetryInterval
+
+ // Sync may be interrupted by actions such as
+ // locking the wallet. Try again after waiting a
+ // bit.
+ err = w.syncWithChain(birthdayBlock)
+ if err != nil {
+ if w.ShuttingDown() {
+ return ErrWalletShuttingDown
+ }
+
+ log.Errorf("Unable to synchronize "+
+ "wallet to chain, trying "+
+ "again in %s: %v",
+ w.syncRetryInterval, err)
+
+ continue
+ }
+
+ return nil
+
+ case <-w.quitChan():
+ return ErrWalletShuttingDown
+ }
+ }
+ }
+
+ for {
+ select {
+ case n, ok := <-chainClient.Notifications():
+ if !ok {
+ return
+ }
+
+ var notificationName string
+ var err error
+ switch n := n.(type) {
+ case chain.ClientConnected:
+ // Before attempting to sync with our backend,
+ // we'll make sure that our birthday block has
+ // been set correctly to potentially prevent
+ // missing relevant events.
+ birthdayStore := &walletBirthdayStore{
+ db: w.db,
+ manager: w.addrStore,
+ }
+ birthdayBlock, err := birthdaySanityCheck(
+ chainClient, birthdayStore,
+ )
+ if err != nil && !waddrmgr.IsError(
+ err, waddrmgr.ErrBirthdayBlockNotSet,
+ ) {
+
+ log.Errorf("Unable to sanity check "+
+ "wallet birthday block: %v",
+ err)
+ }
+
+ err = waitForSync(birthdayBlock)
+ if err != nil {
+ log.Infof("Stopped waiting for wallet "+
+ "sync due to error: %v", err)
+
+ return
+ }
+
+ case chain.BlockConnected:
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ return w.connectBlock(tx, wtxmgr.BlockMeta(n))
+ })
+ notificationName = "block connected"
+ case chain.BlockDisconnected:
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ return w.disconnectBlock(tx, wtxmgr.BlockMeta(n))
+ })
+ notificationName = "block disconnected"
+ case chain.RelevantTx:
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ return w.addRelevantTx(tx, n.TxRecord, n.Block)
+ })
+ notificationName = "relevant transaction"
+ case chain.FilteredBlockConnected:
+ // Atomically update for the whole block.
+ if len(n.RelevantTxs) > 0 {
+ err = walletdb.Update(w.db, func(
+ tx walletdb.ReadWriteTx) error {
+ var err error
+ for _, rec := range n.RelevantTxs {
+ err = w.addRelevantTx(tx, rec,
+ n.Block)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ }
+ notificationName = "filtered block connected"
+
+ // The following require some database maintenance, but also
+ // need to be reported to the wallet's rescan goroutine.
+ case *chain.RescanProgress:
+ err = catchUpHashes(w, chainClient, n.Height)
+ notificationName = "rescan progress"
+ select {
+ case w.rescanNotifications <- n:
+ case <-w.quitChan():
+ return
+ }
+ case *chain.RescanFinished:
+ err = catchUpHashes(w, chainClient, n.Height)
+ notificationName = "rescan finished"
+ w.SetChainSynced(true)
+ select {
+ case w.rescanNotifications <- n:
+ case <-w.quitChan():
+ return
+ }
+ }
+ if err != nil {
+ // If we received a block connected notification
+ // while rescanning, then we can ignore logging
+ // the error as we'll properly catch up once we
+ // process the RescanFinished notification.
+ if notificationName == "block connected" &&
+ waddrmgr.IsError(err, waddrmgr.ErrBlockNotFound) &&
+ !w.ChainSynced() {
+
+ log.Debugf("Received block connected "+
+ "notification for height %v "+
+ "while rescanning",
+ n.(chain.BlockConnected).Height)
+ continue
+ }
+
+ log.Errorf("Unable to process chain backend "+
+ "%v notification: %v", notificationName,
+ err)
+ }
+ case <-w.quit:
+ return
+ }
+ }
+}
+
+// connectBlock handles a chain server notification by marking a wallet
+// that's currently in-sync with the chain server as being synced up to
+// the passed block.
+func (w *Wallet) connectBlock(dbtx walletdb.ReadWriteTx, b wtxmgr.BlockMeta) error {
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ bs := waddrmgr.BlockStamp{
+ Height: b.Height,
+ Hash: b.Hash,
+ Timestamp: b.Time,
+ }
+
+ err := w.addrStore.SetSyncedTo(addrmgrNs, &bs)
+ if err != nil {
+ return err
+ }
+
+ // Notify interested clients of the connected block.
+ //
+ // TODO: move all notifications outside of the database transaction.
+ w.NtfnServer.notifyAttachedBlock(dbtx, &b)
+ return nil
+}
+
+// disconnectBlock handles a chain server reorganize by rolling back all
+// block history from the reorged block for a wallet in-sync with the chain
+// server.
+func (w *Wallet) disconnectBlock(dbtx walletdb.ReadWriteTx, b wtxmgr.BlockMeta) error {
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ if !w.ChainSynced() {
+ return nil
+ }
+
+ // Disconnect the removed block and all blocks after it if we know about
+ // the disconnected block. Otherwise, the block is in the future.
+ //nolint:nestif
+ if b.Height <= w.addrStore.SyncedTo().Height {
+ hash, err := w.addrStore.BlockHash(addrmgrNs, b.Height)
+ if err != nil {
+ return err
+ }
+ if bytes.Equal(hash[:], b.Hash[:]) {
+ bs := waddrmgr.BlockStamp{
+ Height: b.Height - 1,
+ }
+
+ hash, err = w.addrStore.BlockHash(addrmgrNs, bs.Height)
+ if err != nil {
+ return err
+ }
+ b.Hash = *hash
+
+ client := w.ChainClient()
+ header, err := client.GetBlockHeader(hash)
+ if err != nil {
+ return err
+ }
+
+ bs.Timestamp = header.Timestamp
+
+ err = w.addrStore.SetSyncedTo(addrmgrNs, &bs)
+ if err != nil {
+ return err
+ }
+
+ err = w.txStore.Rollback(txmgrNs, b.Height)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ // Notify interested clients of the disconnected block.
+ w.NtfnServer.notifyDetachedBlock(&b.Hash)
+
+ return nil
+}
+
+func (w *Wallet) addRelevantTx(dbtx walletdb.ReadWriteTx, rec *wtxmgr.TxRecord,
+ block *wtxmgr.BlockMeta) error {
+
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ // At the moment all notified transactions are assumed to actually be
+ // relevant. This assumption will not hold true when SPV support is
+ // added, but until then, simply insert the transaction because there
+ // should either be one or more relevant inputs or outputs.
+ exists, err := w.txStore.InsertTxCheckIfExists(txmgrNs, rec, block)
+ if err != nil {
+ return err
+ }
+
+ // If the transaction has already been recorded, we can return early.
+ // Note: Returning here is safe as we're within the context of an atomic
+ // database transaction, so we don't need to worry about the MarkUsed
+ // calls below.
+ if exists {
+ return nil
+ }
+
+ // Check every output to determine whether it is controlled by a wallet
+ // key. If so, mark the output as a credit.
+ for i, output := range rec.MsgTx.TxOut {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript,
+ w.chainParams)
+ if err != nil {
+ // Non-standard outputs are skipped.
+ log.Warnf("Cannot extract non-std pkScript=%x",
+ output.PkScript)
+
+ continue
+ }
+
+ for _, addr := range addrs {
+ ma, err := w.addrStore.Address(addrmgrNs, addr)
+
+ switch {
+ // Missing addresses are skipped.
+ case waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound):
+ continue
+
+ // Other errors should be propagated.
+ case err != nil:
+ return err
+ }
+
+ // Prevent addresses from non-default scopes to be
+ // detected here. We don't watch funds sent to
+ // non-default scopes in other places either, so
+ // detecting them here would mean we'd also not properly
+ // detect them as spent later.
+ scopedManager, _, err := w.addrStore.AddrAccount(
+ addrmgrNs, addr,
+ )
+ if err != nil {
+ return err
+ }
+ if !waddrmgr.IsDefaultScope(scopedManager.Scope()) {
+ log.Debugf("Skipping non-default scope "+
+ "address %v", addr)
+
+ continue
+ }
+
+ // TODO: Credits should be added with the
+ // account they belong to, so wtxmgr is able to
+ // track per-account balances.
+ err = w.txStore.AddCredit(
+ txmgrNs, rec, block, uint32(i), ma.Internal(),
+ )
+ if err != nil {
+ return err
+ }
+
+ err = w.addrStore.MarkUsed(addrmgrNs, addr)
+ if err != nil {
+ return err
+ }
+ log.Debugf("Marked address %v used", addr)
+ }
+ }
+
+ // Send notification of mined or unmined transaction to any interested
+ // clients.
+ //
+ // TODO: Avoid the extra db hits.
+ if block == nil {
+ w.NtfnServer.notifyUnminedTransaction(dbtx, txmgrNs, rec.Hash)
+ } else {
+ w.NtfnServer.notifyMinedTransaction(
+ dbtx, txmgrNs, rec.Hash, block,
+ )
+ }
+
+ return nil
+}
+
+// chainConn is an interface that abstracts the chain connection logic required
+// to perform a wallet's birthday block sanity check.
+type chainConn interface {
+ // GetBestBlock returns the hash and height of the best block known to
+ // the backend.
+ GetBestBlock() (*chainhash.Hash, int32, error)
+
+ // GetBlockHash returns the hash of the block with the given height.
+ GetBlockHash(int64) (*chainhash.Hash, error)
+
+ // GetBlockHeader returns the header for the block with the given hash.
+ GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error)
+}
+
+// birthdayStore is an interface that abstracts the wallet's sync-related
+// information required to perform a birthday block sanity check.
+type birthdayStore interface {
+ // Birthday returns the birthday timestamp of the wallet.
+ Birthday() time.Time
+
+ // BirthdayBlock returns the birthday block of the wallet. The boolean
+ // returned should signal whether the wallet has already verified the
+ // correctness of its birthday block.
+ BirthdayBlock() (waddrmgr.BlockStamp, bool, error)
+
+ // SetBirthdayBlock updates the birthday block of the wallet to the
+ // given block. The boolean can be used to signal whether this block
+ // should be sanity checked the next time the wallet starts.
+ //
+ // NOTE: This should also set the wallet's synced tip to reflect the new
+ // birthday block. This will allow the wallet to rescan from this point
+ // to detect any potentially missed events.
+ SetBirthdayBlock(waddrmgr.BlockStamp) error
+}
+
+// walletBirthdayStore is a wrapper around the wallet's database and address
+// manager that satisfies the birthdayStore interface.
+type walletBirthdayStore struct {
+ db walletdb.DB
+ manager waddrmgr.AddrStore
+}
+
+var _ birthdayStore = (*walletBirthdayStore)(nil)
+
+// Birthday returns the birthday timestamp of the wallet.
+func (s *walletBirthdayStore) Birthday() time.Time {
+ return s.manager.Birthday()
+}
+
+// BirthdayBlock returns the birthday block of the wallet.
+func (s *walletBirthdayStore) BirthdayBlock() (waddrmgr.BlockStamp, bool, error) {
+ var (
+ birthdayBlock waddrmgr.BlockStamp
+ birthdayBlockVerified bool
+ )
+
+ err := walletdb.View(s.db, func(tx walletdb.ReadTx) error {
+ var err error
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+ birthdayBlock, birthdayBlockVerified, err = s.manager.BirthdayBlock(ns)
+ return err
+ })
+
+ return birthdayBlock, birthdayBlockVerified, err
+}
+
+// SetBirthdayBlock updates the birthday block of the wallet to the
+// given block. The boolean can be used to signal whether this block
+// should be sanity checked the next time the wallet starts.
+//
+// NOTE: This should also set the wallet's synced tip to reflect the new
+// birthday block. This will allow the wallet to rescan from this point
+// to detect any potentially missed events.
+func (s *walletBirthdayStore) SetBirthdayBlock(block waddrmgr.BlockStamp) error {
+ return walletdb.Update(s.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ err := s.manager.SetBirthdayBlock(ns, block, true)
+ if err != nil {
+ return err
+ }
+ return s.manager.SetSyncedTo(ns, &block)
+ })
+}
+
+// birthdaySanityCheck is a helper function that ensures a birthday block
+// correctly reflects the birthday timestamp within a reasonable timestamp
+// delta. It's intended to be run after the wallet establishes its connection
+// with the backend, but before it begins syncing. This is done as the second
+// part to the wallet's address manager migration where we populate the birthday
+// block to ensure we do not miss any relevant events throughout rescans.
+// waddrmgr.ErrBirthdayBlockNotSet is returned if the birthday block has not
+// been set yet.
+func birthdaySanityCheck(chainConn chainConn,
+ birthdayStore birthdayStore) (*waddrmgr.BlockStamp, error) {
+
+ // We'll start by fetching our wallet's birthday timestamp and block.
+ birthdayTimestamp := birthdayStore.Birthday()
+ birthdayBlock, birthdayBlockVerified, err := birthdayStore.BirthdayBlock()
+ if err != nil {
+ return nil, err
+ }
+
+ // If the birthday block has already been verified to be correct, we can
+ // exit our sanity check to prevent potentially fetching a better
+ // candidate.
+ if birthdayBlockVerified {
+ log.Debugf("Birthday block has already been verified: "+
+ "height=%d, hash=%v", birthdayBlock.Height,
+ birthdayBlock.Hash)
+
+ return &birthdayBlock, nil
+ }
+
+ // Otherwise, we'll attempt to locate a better one now that we have
+ // access to the chain.
+ newBirthdayBlock, err := locateBirthdayBlock(chainConn, birthdayTimestamp)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := birthdayStore.SetBirthdayBlock(*newBirthdayBlock); err != nil {
+ return nil, err
+ }
+
+ return newBirthdayBlock, nil
+}
+
+// secretSource is an implementation of txauthor.SecretSource for the wallet's
+// address manager.
+type secretSource struct {
+ waddrmgr.AddrStore
+
+ addrmgrNs walletdb.ReadBucket
+}
+
+func (s secretSource) GetKey(addr address.Address) (*btcec.PrivateKey, bool, error) {
+ ma, err := s.Address(s.addrmgrNs, addr)
+ if err != nil {
+ return nil, false, err
+ }
+
+ mpka, ok := ma.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ e := fmt.Errorf("managed address type for %v is `%T` but "+
+ "want waddrmgr.ManagedPubKeyAddress", addr, ma)
+ return nil, false, e
+ }
+ privKey, err := mpka.PrivKey()
+ if err != nil {
+ return nil, false, err
+ }
+ return privKey, ma.Compressed(), nil
+}
+
+func (s secretSource) GetScript(addr address.Address) ([]byte, error) {
+ ma, err := s.Address(s.addrmgrNs, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ msa, ok := ma.(waddrmgr.ManagedScriptAddress)
+ if !ok {
+ e := fmt.Errorf("managed address type for %v is `%T` but "+
+ "want waddrmgr.ManagedScriptAddress", addr, ma)
+ return nil, e
+ }
+ return msa.Script()
+}
+
+// txToOutputs creates a signed transaction which includes each output from
+// outputs. Previous outputs to redeem are chosen from the passed account's
+// UTXO set and minconf policy. An additional output may be added to return
+// change to the wallet. This output will have an address generated from the
+// given key scope and account. If a key scope is not specified, the address
+// will always be generated from the P2WKH key scope. An appropriate fee is
+// included based on the wallet's current relay fee. The wallet must be
+// unlocked to create the transaction.
+//
+// NOTE: The dryRun argument can be set true to create a tx that doesn't alter
+// the database. A tx created with this set to true will intentionally have no
+// input scripts added and SHOULD NOT be broadcasted.
+func (w *Wallet) txToOutputs(outputs []*wire.TxOut,
+ coinSelectKeyScope, changeKeyScope *waddrmgr.KeyScope,
+ account uint32, minconf int32, feeSatPerKb btcutil.Amount,
+ strategy CoinSelectionStrategy, dryRun bool,
+ selectedUtxos []wire.OutPoint,
+ allowUtxo func(utxo wtxmgr.Credit) bool) (
+ *txauthor.AuthoredTx, error) {
+
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ return nil, err
+ }
+
+ // Get current block's height and hash.
+ bs, err := chainClient.BlockStamp()
+ if err != nil {
+ return nil, err
+ }
+
+ // Fall back to default coin selection strategy if none is supplied.
+ if strategy == nil {
+ strategy = CoinSelectionLargest
+ }
+
+ // The addrMgrWithChangeSource function of the wallet creates a
+ // new change address. The address manager uses OnCommit on the
+ // walletdb tx to update the in-memory state of the account
+ // state. But because the commit happens _after_ the account
+ // manager internal lock has been released, there is a chance
+ // for the address index to be accessed concurrently, even
+ // though the closure in OnCommit re-acquires the lock. To avoid
+ // this issue, we surround the whole address creation process
+ // with a lock.
+ w.newAddrMtx.Lock()
+ defer w.newAddrMtx.Unlock()
+
+ var tx *txauthor.AuthoredTx
+ err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
+ addrmgrNs, changeSource, err := w.addrMgrWithChangeSource(
+ dbtx, changeKeyScope, account,
+ )
+ if err != nil {
+ return err
+ }
+
+ eligible, err := w.findEligibleOutputs(
+ dbtx, coinSelectKeyScope, account,
+ //nolint:gosec
+ uint32(minconf),
+ bs, allowUtxo,
+ )
+ if err != nil {
+ return err
+ }
+
+ var inputSource txauthor.InputSource
+ if len(selectedUtxos) > 0 {
+ dedupUtxos := fn.NewSet(selectedUtxos...)
+ if len(dedupUtxos) != len(selectedUtxos) {
+ return errors.New("selected UTXOs contain " +
+ "duplicate values")
+ }
+
+ eligibleByOutpoint := make(
+ map[wire.OutPoint]wtxmgr.Credit,
+ )
+
+ for _, e := range eligible {
+ eligibleByOutpoint[e.OutPoint] = e
+ }
+
+ var eligibleSelectedUtxo []wtxmgr.Credit
+ for _, outpoint := range selectedUtxos {
+ e, ok := eligibleByOutpoint[outpoint]
+
+ if !ok {
+ return fmt.Errorf("selected outpoint "+
+ "not eligible for "+
+ "spending: %v", outpoint)
+ }
+ eligibleSelectedUtxo = append(
+ eligibleSelectedUtxo, e,
+ )
+ }
+
+ inputSource = constantInputSource(eligibleSelectedUtxo)
+
+ } else {
+ // Wrap our coins in a type that implements the
+ // SelectableCoin interface, so we can arrange them
+ // according to the selected coin selection strategy.
+ wrappedEligible := make([]Coin, len(eligible))
+ for i := range eligible {
+ wrappedEligible[i] = Coin{
+ TxOut: wire.TxOut{
+ Value: int64(
+ eligible[i].Amount,
+ ),
+ PkScript: eligible[i].PkScript,
+ },
+ OutPoint: eligible[i].OutPoint,
+ }
+ }
+
+ arrangedCoins, err := strategy.ArrangeCoins(
+ wrappedEligible, feeSatPerKb,
+ )
+ if err != nil {
+ return err
+ }
+ inputSource = makeInputSource(arrangedCoins)
+ }
+
+ tx, err = txauthor.NewUnsignedTransaction(
+ outputs, feeSatPerKb, inputSource, changeSource,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Randomize change position, if change exists, before signing.
+ // This doesn't affect the serialize size, so the change amount
+ // will still be valid.
+ if tx.ChangeIndex >= 0 {
+ tx.RandomizeChangePosition()
+ }
+
+ // If a dry run was requested, we return now before adding the
+ // input scripts, and don't commit the database transaction.
+ // By returning an error, we make sure the walletdb.Update call
+ // rolls back the transaction. But we'll react to this specific
+ // error outside of the DB transaction so we can still return
+ // the produced chain TX.
+ if dryRun {
+ return walletdb.ErrDryRunRollBack
+ }
+
+ // Before committing the transaction, we'll sign our inputs. If
+ // the inputs are part of a watch-only account, there's no
+ // private key information stored, so we'll skip signing such.
+ var watchOnly bool
+ if coinSelectKeyScope == nil {
+ // If a key scope wasn't specified, then coin selection
+ // was performed from the default wallet accounts
+ // (NP2WKH, P2WKH, P2TR), so any key scope provided
+ // doesn't impact the result of this call.
+ watchOnly, err = w.addrStore.IsWatchOnlyAccount(
+ addrmgrNs, waddrmgr.KeyScopeBIP0086, account,
+ )
+ } else {
+ watchOnly, err = w.addrStore.IsWatchOnlyAccount(
+ addrmgrNs, *coinSelectKeyScope, account,
+ )
+ }
+ if err != nil {
+ return err
+ }
+ if !watchOnly {
+ err = tx.AddAllInputScripts(
+ secretSource{w.addrStore, addrmgrNs},
+ )
+ if err != nil {
+ return err
+ }
+
+ err = validateMsgTx(
+ tx.Tx, tx.PrevScripts, tx.PrevInputValues,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ if tx.ChangeIndex >= 0 && account == waddrmgr.ImportedAddrAccount {
+ changeAmount := btcutil.Amount(
+ tx.Tx.TxOut[tx.ChangeIndex].Value,
+ )
+ log.Warnf("Spend from imported account produced "+
+ "change: moving %v from imported account into "+
+ "default account.", changeAmount)
+ }
+
+ // Finally, we'll request the backend to notify us of the
+ // transaction that pays to the change address, if there is one,
+ // when it confirms.
+ if tx.ChangeIndex >= 0 {
+ changePkScript := tx.Tx.TxOut[tx.ChangeIndex].PkScript
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ changePkScript, w.chainParams,
+ )
+ if err != nil {
+ return err
+ }
+ if err := chainClient.NotifyReceived(addrs); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ if err != nil && !errors.Is(err, walletdb.ErrDryRunRollBack) {
+ return nil, err
+ }
+
+ return tx, nil
+}
+
+// validateMsgTx verifies transaction input scripts for tx. All previous output
+// scripts from outputs redeemed by the transaction, in the same order they are
+// spent, must be passed in the prevScripts slice.
+func validateMsgTx(tx *wire.MsgTx, prevScripts [][]byte,
+ inputValues []btcutil.Amount) error {
+
+ inputFetcher, err := txauthor.TXPrevOutFetcher(
+ tx, prevScripts, inputValues,
+ )
+ if err != nil {
+ return err
+ }
+
+ hashCache := txscript.NewTxSigHashes(tx, inputFetcher)
+ for i, prevScript := range prevScripts {
+ vm, err := txscript.NewEngine(
+ prevScript, tx, i, txscript.StandardVerifyFlags, nil,
+ hashCache, int64(inputValues[i]), inputFetcher,
+ )
+ if err != nil {
+ return fmt.Errorf("cannot create script engine: %w", err)
+ }
+ err = vm.Execute()
+ if err != nil {
+ return fmt.Errorf("cannot validate transaction: %w", err)
+ }
+ }
+ return nil
+}
+
+const (
+ // accountPubKeyDepth is the maximum depth of an extended key for an
+ // account public key.
+ accountPubKeyDepth = 3
+
+ // pubKeyDepth is the depth of an extended key for a derived public key.
+ pubKeyDepth = 5
+)
+
+// keyScopeFromPubKey returns the corresponding wallet key scope for the given
+// extended public key. The address type can usually be inferred from the key's
+// version, but may be required for certain keys to map them into the proper
+// scope.
+func keyScopeFromPubKey(pubKey *hdkeychain.ExtendedKey,
+ addrType *waddrmgr.AddressType) (waddrmgr.KeyScope,
+ *waddrmgr.ScopeAddrSchema, error) {
+
+ switch waddrmgr.HDVersion(binary.BigEndian.Uint32(pubKey.Version())) {
+ // For BIP-0044 keys, an address type must be specified as we intend to
+ // not support importing BIP-0044 keys into the wallet using the legacy
+ // pay-to-pubkey-hash (P2PKH) scheme. A nested witness address type will
+ // force the standard BIP-0049 derivation scheme (nested witness pubkeys
+ // everywhere), while a witness address type will force the standard
+ // BIP-0084 derivation scheme.
+ case waddrmgr.HDVersionMainNetBIP0044, waddrmgr.HDVersionTestNetBIP0044,
+ waddrmgr.HDVersionSimNetBIP0044:
+
+ if addrType == nil {
+ return waddrmgr.KeyScope{}, nil, errors.New("address " +
+ "type must be specified for account public " +
+ "key with legacy version")
+ }
+
+ switch *addrType {
+ case waddrmgr.NestedWitnessPubKey:
+ return waddrmgr.KeyScopeBIP0049Plus,
+ &waddrmgr.KeyScopeBIP0049AddrSchema, nil
+
+ case waddrmgr.WitnessPubKey:
+ return waddrmgr.KeyScopeBIP0084, nil, nil
+
+ case waddrmgr.TaprootPubKey:
+ return waddrmgr.KeyScopeBIP0086, nil, nil
+
+ default:
+ return waddrmgr.KeyScope{}, nil,
+ fmt.Errorf("unsupported address type %v",
+ *addrType)
+ }
+
+ // For BIP-0049 keys, we'll need to make a distinction between the
+ // traditional BIP-0049 address schema (nested witness pubkeys
+ // everywhere) and our own BIP-0049Plus address schema (nested
+ // externally, witness internally).
+ case waddrmgr.HDVersionMainNetBIP0049, waddrmgr.HDVersionTestNetBIP0049:
+ if addrType == nil {
+ return waddrmgr.KeyScope{}, nil, errors.New("address " +
+ "type must be specified for account public " +
+ "key with BIP-0049 version")
+ }
+
+ switch *addrType {
+ case waddrmgr.NestedWitnessPubKey:
+ return waddrmgr.KeyScopeBIP0049Plus,
+ &waddrmgr.KeyScopeBIP0049AddrSchema, nil
+
+ case waddrmgr.WitnessPubKey:
+ return waddrmgr.KeyScopeBIP0049Plus, nil, nil
+
+ default:
+ return waddrmgr.KeyScope{}, nil,
+ fmt.Errorf("unsupported address type %v",
+ *addrType)
+ }
+
+ // BIP-0086 does not have its own SLIP-0132 HD version byte set (yet?).
+ // So we either expect a user to import it with a BIP-0084 or BIP-0044
+ // encoding.
+ case waddrmgr.HDVersionMainNetBIP0084, waddrmgr.HDVersionTestNetBIP0084:
+ if addrType == nil {
+ return waddrmgr.KeyScope{}, nil, errors.New("address " +
+ "type must be specified for account public " +
+ "key with BIP-0084 version")
+ }
+
+ switch *addrType {
+ case waddrmgr.WitnessPubKey:
+ return waddrmgr.KeyScopeBIP0084, nil, nil
+
+ case waddrmgr.TaprootPubKey:
+ return waddrmgr.KeyScopeBIP0086, nil, nil
+
+ default:
+ return waddrmgr.KeyScope{}, nil,
+ errors.New("address type mismatch")
+ }
+
+ default:
+ return waddrmgr.KeyScope{}, nil, fmt.Errorf("unknown version %x",
+ pubKey.Version())
+ }
+}
+
+// ImportAccountDeprecated imports an account backed by an account extended
+// public key.
+// The master key fingerprint denotes the fingerprint of the root key
+// corresponding to the account public key (also known as the key with
+// derivation path m/). This may be required by some hardware wallets for proper
+// identification and signing.
+//
+// The address type can usually be inferred from the key's version, but may be
+// required for certain keys to map them into the proper scope.
+//
+// For BIP-0044 keys, an address type must be specified as we intend to not
+// support importing BIP-0044 keys into the wallet using the legacy
+// pay-to-pubkey-hash (P2PKH) scheme. A nested witness address type will force
+// the standard BIP-0049 derivation scheme, while a witness address type will
+// force the standard BIP-0084 derivation scheme.
+//
+// For BIP-0049 keys, an address type must also be specified to make a
+// distinction between the traditional BIP-0049 address schema (nested witness
+// pubkeys everywhere) and our own BIP-0049Plus address schema (nested
+// externally, witness internally).
+func (w *Wallet) ImportAccountDeprecated(
+ name string, accountPubKey *hdkeychain.ExtendedKey,
+ masterKeyFingerprint uint32, addrType *waddrmgr.AddressType) (
+ *waddrmgr.AccountProperties, error) {
+
+ var accountProps *waddrmgr.AccountProperties
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ var err error
+ accountProps, err = w.importAccount(
+ ns, name, accountPubKey, masterKeyFingerprint, addrType,
+ )
+ return err
+ })
+ return accountProps, err
+}
+
+// ImportAccountWithScope imports an account backed by an account extended
+// public key for a specific key scope which is known in advance.
+// The master key fingerprint denotes the fingerprint of the root key
+// corresponding to the account public key (also known as the key with
+// derivation path m/). This may be required by some hardware wallets for proper
+// identification and signing.
+func (w *Wallet) ImportAccountWithScope(name string,
+ accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
+ keyScope waddrmgr.KeyScope, addrSchema waddrmgr.ScopeAddrSchema) (
+ *waddrmgr.AccountProperties, error) {
+
+ var accountProps *waddrmgr.AccountProperties
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ var err error
+ accountProps, err = w.importAccountScope(
+ ns, name, accountPubKey, masterKeyFingerprint, keyScope,
+ &addrSchema,
+ )
+ return err
+ })
+ return accountProps, err
+}
+
+// importAccount is the internal implementation of ImportAccount -- one should
+// reference its documentation for this method.
+func (w *Wallet) importAccount(ns walletdb.ReadWriteBucket, name string,
+ accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
+ addrType *waddrmgr.AddressType) (*waddrmgr.AccountProperties, error) {
+
+ // Ensure we have a valid account public key.
+ if err := validateExtendedPubKey(accountPubKey, true, w.chainParams); err != nil {
+ return nil, err
+ }
+
+ // Determine what key scope the account public key should belong to and
+ // whether it should use a custom address schema.
+ keyScope, addrSchema, err := keyScopeFromPubKey(accountPubKey, addrType)
+ if err != nil {
+ return nil, err
+ }
+
+ return w.importAccountScope(
+ ns, name, accountPubKey, masterKeyFingerprint, keyScope,
+ addrSchema,
+ )
+}
+
+// importAccountScope imports a watch-only account for a given scope.
+func (w *Wallet) importAccountScope(ns walletdb.ReadWriteBucket, name string,
+ accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
+ keyScope waddrmgr.KeyScope, addrSchema *waddrmgr.ScopeAddrSchema) (
+ *waddrmgr.AccountProperties, error) {
+
+ scopedMgr, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ scopedMgr, err = w.addrStore.NewScopedKeyManager(
+ ns, keyScope, *addrSchema,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ account, err := scopedMgr.NewAccountWatchingOnly(
+ ns, name, accountPubKey, masterKeyFingerprint, addrSchema,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return scopedMgr.AccountProperties(ns, account)
+}
+
+// ImportAccountDryRun serves as a dry run implementation of ImportAccount. This
+// method also returns the first N external and internal addresses, which can be
+// presented to users to confirm whether the account has been imported
+// correctly.
+func (w *Wallet) ImportAccountDryRun(name string,
+ accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
+ addrType *waddrmgr.AddressType, numAddrs uint32) (
+ *waddrmgr.AccountProperties, []waddrmgr.ManagedAddress,
+ []waddrmgr.ManagedAddress, error) {
+
+ // The address manager uses OnCommit on the walletdb tx to update the
+ // in-memory state of the account state. But because the commit happens
+ // _after_ the account manager internal lock has been released, there
+ // is a chance for the address index to be accessed concurrently, even
+ // though the closure in OnCommit re-acquires the lock. To avoid this
+ // issue, we surround the whole address creation process with a lock.
+ w.newAddrMtx.Lock()
+ defer w.newAddrMtx.Unlock()
+
+ var (
+ accountProps *waddrmgr.AccountProperties
+ externalAddrs []waddrmgr.ManagedAddress
+ internalAddrs []waddrmgr.ManagedAddress
+ )
+
+ // Start a database transaction that we'll never commit and always
+ // rollback because we'll return a specific error in the end.
+ err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ // Import the account as usual.
+ var err error
+ accountProps, err = w.importAccount(
+ ns, name, accountPubKey, masterKeyFingerprint, addrType,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Derive the external and internal addresses. Note that we
+ // could do this based on the provided accountPubKey alone, but
+ // we go through the ScopedKeyManager instead to ensure
+ // addresses will be derived as expected from the wallet's
+ // point-of-view.
+ manager, err := w.addrStore.FetchScopedKeyManager(
+ accountProps.KeyScope,
+ )
+ if err != nil {
+ return err
+ }
+
+ // The importAccount method above will cache the imported
+ // account within the scoped manager. Since this is a dry-run
+ // attempt, we'll want to invalidate the cache for it.
+ defer manager.InvalidateAccountCache(accountProps.AccountNumber)
+
+ externalAddrs, err = manager.NextExternalAddresses(
+ ns, accountProps.AccountNumber, numAddrs,
+ )
+ if err != nil {
+ return err
+ }
+ internalAddrs, err = manager.NextInternalAddresses(
+ ns, accountProps.AccountNumber, numAddrs,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Refresh the account's properties after generating the
+ // addresses.
+ accountProps, err = manager.AccountProperties(
+ ns, accountProps.AccountNumber,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Make sure we always roll back the dry-run transaction by
+ // returning an error here.
+ return walletdb.ErrDryRunRollBack
+ })
+ if err != nil && err != walletdb.ErrDryRunRollBack {
+ return nil, nil, nil, err
+ }
+
+ return accountProps, externalAddrs, internalAddrs, nil
+}
+
+// ImportPublicKey imports a single derived public key into the address manager.
+// The address type can usually be inferred from the key's version, but in the
+// case of legacy versions (xpub, tpub), an address type must be specified as we
+// intend to not support importing BIP-44 keys into the wallet using the legacy
+// pay-to-pubkey-hash (P2PKH) scheme.
+func (w *Wallet) ImportPublicKeyDeprecated(pubKey *btcec.PublicKey,
+ addrType waddrmgr.AddressType) error {
+
+ // Determine what key scope the public key should belong to and import
+ // it into the key scope's default imported account.
+ var keyScope waddrmgr.KeyScope
+ switch addrType {
+ case waddrmgr.NestedWitnessPubKey:
+ keyScope = waddrmgr.KeyScopeBIP0049Plus
+
+ case waddrmgr.WitnessPubKey:
+ keyScope = waddrmgr.KeyScopeBIP0084
+
+ case waddrmgr.TaprootPubKey:
+ keyScope = waddrmgr.KeyScopeBIP0086
+
+ default:
+ return fmt.Errorf("address type %v is not supported", addrType)
+ }
+
+ scopedKeyManager, err := w.addrStore.FetchScopedKeyManager(keyScope)
+ if err != nil {
+ return err
+ }
+
+ // TODO: Perform rescan if requested.
+ var addr waddrmgr.ManagedAddress
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ addr, err = scopedKeyManager.ImportPublicKey(ns, pubKey, nil)
+ return err
+ })
+ if err != nil {
+ return err
+ }
+
+ log.Infof("Imported address %v", addr.Address())
+
+ err = w.chainClient.NotifyReceived([]address.Address{addr.Address()})
+ if err != nil {
+ return fmt.Errorf("unable to subscribe for address "+
+ "notifications: %w", err)
+ }
+
+ return nil
+}
+
+// ImportTaprootScriptDeprecated imports a user-provided taproot script into the
+// address manager. The imported script will act as a pay-to-taproot address.
+//
+// Deprecated: Use AddressManager.ImportTaprootScript instead.
+func (w *Wallet) ImportTaprootScriptDeprecated(scope waddrmgr.KeyScope,
+ tapscript *waddrmgr.Tapscript, bs *waddrmgr.BlockStamp,
+ witnessVersion byte, isSecretScript bool) (waddrmgr.ManagedAddress,
+ error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return nil, err
+ }
+
+ // The starting block for the key is the genesis block unless otherwise
+ // specified.
+ if bs == nil {
+ bs = &waddrmgr.BlockStamp{
+ Hash: *w.chainParams.GenesisHash,
+ Height: 0,
+ Timestamp: w.chainParams.GenesisBlock.Header.Timestamp,
+ }
+ } else if bs.Timestamp.IsZero() {
+ // Only update the new birthday time from default value if we
+ // actually have timestamp info in the header.
+ header, err := w.chainClient.GetBlockHeader(&bs.Hash)
+ if err == nil {
+ bs.Timestamp = header.Timestamp
+ }
+ }
+
+ // TODO: Perform rescan if requested.
+ var addr waddrmgr.ManagedAddress
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ addr, err = manager.ImportTaprootScript(
+ ns, tapscript, bs, witnessVersion, isSecretScript,
+ )
+ return err
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ log.Infof("Imported address %v", addr.Address())
+
+ err = w.chainClient.NotifyReceived([]address.Address{addr.Address()})
+ if err != nil {
+ return nil, fmt.Errorf("unable to subscribe for address "+
+ "notifications: %w", err)
+ }
+
+ return addr, nil
+}
+
+// ImportPrivateKey imports a private key to the wallet and writes the new
+// wallet to disk.
+//
+// NOTE: If a block stamp is not provided, then the wallet's birthday will be
+// set to the genesis block of the corresponding chain.
+func (w *Wallet) ImportPrivateKey(scope waddrmgr.KeyScope, wif *btcutil.WIF,
+ bs *waddrmgr.BlockStamp, rescan bool) (string, error) {
+
+ manager, err := w.addrStore.FetchScopedKeyManager(scope)
+ if err != nil {
+ return "", err
+ }
+
+ // The starting block for the key is the genesis block unless otherwise
+ // specified.
+ if bs == nil {
+ bs = &waddrmgr.BlockStamp{
+ Hash: *w.chainParams.GenesisHash,
+ Height: 0,
+ Timestamp: w.chainParams.GenesisBlock.Header.Timestamp,
+ }
+ } else if bs.Timestamp.IsZero() {
+ // Only update the new birthday time from default value if we
+ // actually have timestamp info in the header.
+ header, err := w.chainClient.GetBlockHeader(&bs.Hash)
+ if err == nil {
+ bs.Timestamp = header.Timestamp
+ }
+ }
+
+ // Attempt to import private key into wallet.
+ var addr address.Address
+ var props *waddrmgr.AccountProperties
+ err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ maddr, err := manager.ImportPrivateKey(addrmgrNs, wif, bs)
+ if err != nil {
+ return err
+ }
+ addr = maddr.Address()
+ props, err = manager.AccountProperties(
+ addrmgrNs, waddrmgr.ImportedAddrAccount,
+ )
+ if err != nil {
+ return err
+ }
+
+ // We'll only update our birthday with the new one if it is
+ // before our current one. Otherwise, if we do, we can
+ // potentially miss detecting relevant chain events that
+ // occurred between them while rescanning.
+ birthdayBlock, _, err := w.addrStore.BirthdayBlock(addrmgrNs)
+ if err != nil {
+ return err
+ }
+ if bs.Height >= birthdayBlock.Height {
+ return nil
+ }
+
+ err = w.addrStore.SetBirthday(addrmgrNs, bs.Timestamp)
+ if err != nil {
+ return err
+ }
+
+ // To ensure this birthday block is correct, we'll mark it as
+ // unverified to prompt a sanity check at the next restart to
+ // ensure it is correct as it was provided by the caller.
+ return w.addrStore.SetBirthdayBlock(addrmgrNs, *bs, false)
+ })
+ if err != nil {
+ return "", err
+ }
+
+ // Rescan blockchain for transactions with txout scripts paying to the
+ // imported address.
+ if rescan {
+ job := &RescanJob{
+ Addrs: []address.Address{addr},
+ OutPoints: nil,
+ BlockStamp: *bs,
+ }
+
+ // Submit rescan job and log when the import has completed.
+ // Do not block on finishing the rescan. The rescan success
+ // or failure is logged elsewhere, and the channel is not
+ // required to be read, so discard the return value.
+ _ = w.SubmitRescan(job)
+ } else {
+ err := w.chainClient.NotifyReceived([]address.Address{addr})
+ if err != nil {
+ return "", fmt.Errorf("failed to subscribe for address ntfns for "+
+ "address %s: %w", addr.EncodeAddress(), err)
+ }
+ }
+
+ addrStr := addr.EncodeAddress()
+ log.Infof("Imported payment address %s", addrStr)
+
+ w.NtfnServer.notifyAccountProperties(props)
+
+ // Return the payment address string of the imported private key.
+ return addrStr, nil
+}
+
+// walletDeprecated encapsulates the legacy state and communication channels
+// that are being phased out in favor of the modern Controller and Syncer
+// architecture.
+//
+// Embedding this struct in the Wallet allows old logic to continue functioning
+// while clearly marking the fields as legacy. Access to these fields should
+// ideally be restricted to methods moved to this file.
+type walletDeprecated struct {
+ // Deprecated fields.
+ //
+ // NOTE: Listing below are deprecated fields and will be removed once
+ // the sqlization series is finished.
+ started bool
+ quit chan struct{}
+ quitMu sync.Mutex
+
+ // publicPassphrase is the passphrase used to encrypt and decrypt public
+ // data in the address manager.
+ publicPassphrase []byte
+
+ // db is the underlying key-value database where all wallet data is
+ // persisted.
+ db walletdb.DB
+
+ // recoveryWindow specifies the number of additional keys to derive
+ // beyond the last used one to look for previously used addresses
+ // during a rescan or recovery.
+ recoveryWindow uint32
+
+ chainClient chain.Interface
+ chainClientLock sync.Mutex
+ chainClientSynced bool
+ chainClientSyncMtx sync.Mutex
+
+ newAddrMtx sync.Mutex
+
+ lockedOutpoints map[wire.OutPoint]struct{}
+ lockedOutpointsMtx sync.Mutex
+
+ chainParams *chaincfg.Params
+
+ recovering atomic.Value
+
+ // Channels for rescan processing. Requests are added and merged with
+ // any waiting requests, before being sent to another goroutine to
+ // call the rescan RPC.
+ rescanAddJob chan *RescanJob
+ rescanBatch chan *rescanBatch
+ rescanNotifications chan any // From chain server
+ rescanProgress chan *RescanProgressMsg
+ rescanFinished chan *RescanFinishedMsg
+
+ // Channels for the manager locker.
+ unlockRequests chan unlockRequest
+ lockRequests chan struct{}
+ holdUnlockRequests chan chan heldUnlock
+ lockState chan bool
+ changePassphrase chan changePassphraseRequest
+ changePassphrases chan changePassphrasesRequest
+
+ // Channel for transaction creation requests.
+ createTxRequests chan createTxRequest
+
+ // rescanFinishedChan is a channel used to signal the completion of a
+ // rescan operation from the main loop to the rescan loop.
+ rescanFinishedChan chan *chain.RescanFinished
+
+ // syncRetryInterval is the amount of time to wait between re-tries on
+ // errors during initial sync.
+ syncRetryInterval time.Duration
+}
+
+// findEligibleOutputs finds eligible outputs for the given key scope and
+// account.
+func (w *Wallet) findEligibleOutputs(dbtx walletdb.ReadTx,
+ keyScope *waddrmgr.KeyScope, account uint32, minconf uint32,
+ bs *waddrmgr.BlockStamp,
+ allowUtxo func(utxo wtxmgr.Credit) bool) ([]wtxmgr.Credit, error) {
+
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ unspent, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO: Eventually all of these filters (except perhaps output locking)
+ // should be handled by the call to UnspentOutputs (or similar).
+ // Because one of these filters requires matching the output script to
+ // the desired account, this change depends on making wtxmgr a waddrmgr
+ // dependency and requesting unspent outputs for a single account.
+ eligible := make([]wtxmgr.Credit, 0, len(unspent))
+ for i := range unspent {
+ output := &unspent[i]
+
+ // Restrict the selected utxos if a filter function is provided.
+ if allowUtxo != nil && !allowUtxo(*output) {
+ continue
+ }
+
+ // Only include this output if it meets the required number of
+ // confirmations. Coinbase transactions must have reached
+ // maturity before their outputs may be spent.
+ if !hasMinConfs(minconf, output.Height, bs.Height) {
+ continue
+ }
+
+ if output.FromCoinBase {
+ target := w.chainParams.CoinbaseMaturity
+ if !hasMinConfs(
+ uint32(target), output.Height, bs.Height,
+ ) {
+
+ continue
+ }
+ }
+
+ // Locked unspent outputs are skipped.
+ if w.LockedOutpoint(output.OutPoint) {
+ continue
+ }
+
+ // Only include the output if it is associated with the passed
+ // account.
+ //
+ // TODO: Handle multisig outputs by determining if enough of the
+ // addresses are controlled.
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, w.chainParams)
+ if err != nil || len(addrs) != 1 {
+ continue
+ }
+
+ scopedMgr, addrAcct, err := w.addrStore.AddrAccount(
+ addrmgrNs, addrs[0],
+ )
+ if err != nil {
+ continue
+ }
+
+ if keyScope != nil && scopedMgr.Scope() != *keyScope {
+ continue
+ }
+
+ if addrAcct != account {
+ continue
+ }
+
+ eligible = append(eligible, *output)
+ }
+
+ return eligible, nil
+}
+
+// RescanDeprecated begins a rescan for all active addresses and unspent outputs
+// of a wallet. This is intended to be used to sync a wallet back up to the
+// current best block in the main chain, and is considered an initial sync
+// rescan.
+func (w *Wallet) RescanDeprecated(addrs []address.Address,
+ unspent []wtxmgr.Credit) error {
+
+ return w.rescanWithTarget(addrs, unspent, nil)
+}
+
+// RescanProgressMsg reports the current progress made by a rescan for a
+// set of wallet addresses.
+type RescanProgressMsg struct {
+ Addresses []address.Address
+ Notification chain.RescanProgress
+}
+
+// RescanFinishedMsg reports the addresses that were rescanned when a
+// rescanfinished message was received rescanning a batch of addresses.
+type RescanFinishedMsg struct {
+ Addresses []address.Address
+ Notification *chain.RescanFinished
+}
+
+// RescanJob is a job to be processed by the RescanManager. The job includes
+// a set of wallet addresses, a starting height to begin the rescan, and
+// outpoints spendable by the addresses thought to be unspent. After the
+// rescan completes, the error result of the rescan RPC is sent on the Err
+// channel.
+type RescanJob struct {
+ InitialSync bool
+ Addrs []address.Address
+ OutPoints map[wire.OutPoint]address.Address
+ BlockStamp waddrmgr.BlockStamp
+ err chan error
+}
+
+// rescanBatch is a collection of one or more RescanJobs that were merged
+// together before a rescan is performed.
+type rescanBatch struct {
+ initialSync bool
+ addrs []address.Address
+ outpoints map[wire.OutPoint]address.Address
+ bs waddrmgr.BlockStamp
+ errChans []chan error
+}
+
+// SubmitRescan submits a RescanJob to the RescanManager. A channel is
+// returned with the final error of the rescan. The channel is buffered
+// and does not need to be read to prevent a deadlock.
+func (w *Wallet) SubmitRescan(job *RescanJob) <-chan error {
+ errChan := make(chan error, 1)
+ job.err = errChan
+ select {
+ case w.rescanAddJob <- job:
+ case <-w.quitChan():
+ errChan <- ErrWalletShuttingDown
+ }
+ return errChan
+}
+
+// batch creates the rescanBatch for a single rescan job.
+func (job *RescanJob) batch() *rescanBatch {
+ return &rescanBatch{
+ initialSync: job.InitialSync,
+ addrs: job.Addrs,
+ outpoints: job.OutPoints,
+ bs: job.BlockStamp,
+ errChans: []chan error{job.err},
+ }
+}
+
+// merge merges the work from k into j, setting the starting height to
+// the minimum of the two jobs. This method does not check for
+// duplicate addresses or outpoints.
+func (b *rescanBatch) merge(job *RescanJob) {
+ if job.InitialSync {
+ b.initialSync = true
+ }
+ b.addrs = append(b.addrs, job.Addrs...)
+
+ for op, addr := range job.OutPoints {
+ b.outpoints[op] = addr
+ }
+
+ if job.BlockStamp.Height < b.bs.Height {
+ b.bs = job.BlockStamp
+ }
+ b.errChans = append(b.errChans, job.err)
+}
+
+// done iterates through all error channels, duplicating sending the error
+// to inform callers that the rescan finished (or could not complete due
+// to an error).
+func (b *rescanBatch) done(err error) {
+ for _, c := range b.errChans {
+ c <- err
+ }
+}
+
+// rescanBatchHandler handles incoming rescan request, serializing rescan
+// submissions, and possibly batching many waiting requests together so they
+// can be handled by a single rescan after the current one completes.
+func (w *Wallet) rescanBatchHandler() {
+ defer w.wg.Done()
+
+ var curBatch, nextBatch *rescanBatch
+ quit := w.quitChan()
+
+ for {
+ select {
+ case job := <-w.rescanAddJob:
+ if curBatch == nil {
+ // Set current batch as this job and send
+ // request.
+ curBatch = job.batch()
+ select {
+ case w.rescanBatch <- curBatch:
+ case <-quit:
+ job.err <- ErrWalletShuttingDown
+ return
+ }
+ } else {
+ // Create next batch if it doesn't exist, or
+ // merge the job.
+ if nextBatch == nil {
+ nextBatch = job.batch()
+ } else {
+ nextBatch.merge(job)
+ }
+ }
+
+ case n := <-w.rescanNotifications:
+ switch n := n.(type) {
+ case *chain.RescanProgress:
+ if curBatch == nil {
+ log.Warnf("Received rescan progress " +
+ "notification but no rescan " +
+ "currently running")
+ continue
+ }
+ select {
+ case w.rescanProgress <- &RescanProgressMsg{
+ Addresses: curBatch.addrs,
+ Notification: *n,
+ }:
+ case <-quit:
+ for _, errChan := range curBatch.errChans {
+ errChan <- ErrWalletShuttingDown
+ }
+ return
+ }
+
+ case *chain.RescanFinished:
+ if curBatch == nil {
+ log.Warnf("Received rescan finished " +
+ "notification but no rescan " +
+ "currently running")
+ continue
+ }
+ select {
+ case w.rescanFinished <- &RescanFinishedMsg{
+ Addresses: curBatch.addrs,
+ Notification: n,
+ }:
+ case <-quit:
+ for _, errChan := range curBatch.errChans {
+ errChan <- ErrWalletShuttingDown
+ }
+ return
+ }
+
+ curBatch, nextBatch = nextBatch, nil
+
+ if curBatch != nil {
+ select {
+ case w.rescanBatch <- curBatch:
+ case <-quit:
+ for _, errChan := range curBatch.errChans {
+ errChan <- ErrWalletShuttingDown
+ }
+ return
+ }
+ }
+
+ default:
+ // Unexpected message
+ panic(n)
+ }
+
+ case <-quit:
+ return
+ }
+ }
+}
+
+// rescanProgressHandler handles notifications for partially and fully completed
+// rescans by marking each rescanned address as partially or fully synced.
+func (w *Wallet) rescanProgressHandler() {
+ quit := w.quitChan()
+out:
+ for {
+ // These can't be processed out of order since both chans are
+ // unbuffured and are sent from same context (the batch
+ // handler).
+ select {
+ case msg := <-w.rescanProgress:
+ n := msg.Notification
+ log.Infof("Rescanned through block %v (height %d)",
+ n.Hash, n.Height)
+
+ case msg := <-w.rescanFinished:
+ n := msg.Notification
+ addrs := msg.Addresses
+ noun := pickNoun(len(addrs), "address", "addresses")
+ log.Infof("Finished rescan for %d %s (synced to block "+
+ "%s, height %d)", len(addrs), noun, n.Hash,
+ n.Height)
+
+ go w.resendUnminedTxs()
+
+ case <-quit:
+ break out
+ }
+ }
+ w.wg.Done()
+}
+
+// rescanRPCHandler reads batch jobs sent by rescanBatchHandler and sends the
+// RPC requests to perform a rescan. New jobs are not read until a rescan
+// finishes.
+func (w *Wallet) rescanRPCHandler() {
+ chainClient, err := w.requireChainClient()
+ if err != nil {
+ log.Errorf("rescanRPCHandler called without an RPC client")
+ w.wg.Done()
+ return
+ }
+
+ quit := w.quitChan()
+
+out:
+ for {
+ select {
+ case batch := <-w.rescanBatch:
+ // Log the newly-started rescan.
+ numAddrs := len(batch.addrs)
+ numOps := len(batch.outpoints)
+
+ log.Infof("Started rescan from block %v (height %d) "+
+ "for %d addrs, %d outpoints", batch.bs.Hash,
+ batch.bs.Height, numAddrs, numOps)
+
+ err := chainClient.Rescan(
+ &batch.bs.Hash, batch.addrs, batch.outpoints,
+ )
+ if err != nil {
+ log.Errorf("Rescan for %d addrs, %d outpoints "+
+ "failed: %v", numAddrs, numOps, err)
+ }
+ batch.done(err)
+ case <-quit:
+ break out
+ }
+ }
+
+ w.wg.Done()
+}
+
+// rescanWithTarget performs a rescan starting at the optional startStamp. If
+// none is provided, the rescan will begin from the manager's sync tip.
+func (w *Wallet) rescanWithTarget(addrs []address.Address,
+ unspent []wtxmgr.Credit, startStamp *waddrmgr.BlockStamp) error {
+
+ outpoints := make(map[wire.OutPoint]address.Address, len(unspent))
+ for _, output := range unspent {
+ _, outputAddrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, w.chainParams,
+ )
+ if err != nil {
+ return err
+ }
+
+ outpoints[output.OutPoint] = outputAddrs[0]
+ }
+
+ // If a start block stamp was provided, we will use that as the initial
+ // starting point for the rescan.
+ if startStamp == nil {
+ startStamp = &waddrmgr.BlockStamp{}
+ *startStamp = w.addrStore.SyncedTo()
+ }
+
+ job := &RescanJob{
+ InitialSync: true,
+ Addrs: addrs,
+ OutPoints: outpoints,
+ BlockStamp: *startStamp,
+ }
+
+ // Submit merged job and block until rescan completes.
+ select {
+ case err := <-w.SubmitRescan(job):
+ return err
+ case <-w.quitChan():
+ return ErrWalletShuttingDown
+ }
+}
+
+// ChainParams returns the network parameters for the blockchain the wallet
+// belongs to.
+func (w *Wallet) ChainParams() *chaincfg.Params {
+ return w.chainParams
+}
+
+// Database returns the underlying walletdb database. This method is provided
+// in order to allow applications wrapping btcwallet to store app-specific data
+// with the wallet's database.
+func (w *Wallet) Database() walletdb.DB {
+ return w.db
+}
+
+// Open loads an already-created wallet from the passed database and namespaces.
+func Open(db walletdb.DB, pubPass []byte, cbs *waddrmgr.OpenCallbacks,
+ params *chaincfg.Params, recoveryWindow uint32) (*Wallet, error) {
+
+ return OpenWithRetry(
+ db, pubPass, cbs, params, recoveryWindow,
+ defaultSyncRetryInterval,
+ )
+}
+
+// OpenWithRetry loads an already-created wallet from the passed database and
+// namespaces and re-tries on errors during initial sync.
+func OpenWithRetry(db walletdb.DB, pubPass []byte, cbs *waddrmgr.OpenCallbacks,
+ params *chaincfg.Params, recoveryWindow uint32,
+ syncRetryInterval time.Duration) (*Wallet, error) {
+
+ var (
+ addrMgr *waddrmgr.Manager
+ txMgr *wtxmgr.Store
+ )
+
+ // Before attempting to open the wallet, we'll check if there are any
+ // database upgrades for us to proceed. We'll also create our references
+ // to the address and transaction managers, as they are backed by the
+ // database.
+ err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
+ addrMgrBucket := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ if addrMgrBucket == nil {
+ return errors.New("missing address manager namespace")
+ }
+ txMgrBucket := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ if txMgrBucket == nil {
+ return errors.New("missing transaction manager namespace")
+ }
+
+ addrMgrUpgrader := waddrmgr.NewMigrationManager(addrMgrBucket)
+ txMgrUpgrader := wtxmgr.NewMigrationManager(txMgrBucket)
+ err := migration.Upgrade(txMgrUpgrader, addrMgrUpgrader)
+ if err != nil {
+ return err
+ }
+
+ addrMgr, err = waddrmgr.Open(addrMgrBucket, pubPass, params)
+ if err != nil {
+ return err
+ }
+ txMgr, err = wtxmgr.Open(txMgrBucket, params)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ log.Infof("Opened wallet") // TODO: log balance? last sync height?
+
+ deprecated := &walletDeprecated{
+ lockedOutpoints: map[wire.OutPoint]struct{}{},
+ publicPassphrase: pubPass,
+ db: db,
+ recoveryWindow: recoveryWindow,
+ rescanAddJob: make(chan *RescanJob),
+ rescanBatch: make(chan *rescanBatch),
+ rescanNotifications: make(chan interface{}),
+ rescanProgress: make(chan *RescanProgressMsg),
+ rescanFinished: make(chan *RescanFinishedMsg),
+ createTxRequests: make(chan createTxRequest),
+ unlockRequests: make(chan unlockRequest),
+ lockRequests: make(chan struct{}),
+ holdUnlockRequests: make(chan chan heldUnlock),
+ lockState: make(chan bool),
+ changePassphrase: make(chan changePassphraseRequest),
+ changePassphrases: make(chan changePassphrasesRequest),
+ chainParams: params,
+ quit: make(chan struct{}),
+ syncRetryInterval: syncRetryInterval,
+ }
+
+ w := &Wallet{
+ addrStore: addrMgr,
+ txStore: txMgr,
+ walletDeprecated: deprecated,
+ }
+
+ w.NtfnServer = newNotificationServer(w)
+ txMgr.NotifyUnspent = func(hash *chainhash.Hash, index uint32) {
+ w.NtfnServer.notifyUnspentOutput(0, hash, index)
+ }
+
+ return w, nil
+}
+
+// RecoveryManager maintains the state required to recover previously used
+// addresses, and coordinates batched processing of the blocks to search.
+//
+// TODO(yy): Deprecated, remove.
+type RecoveryManager struct {
+ // recoveryWindow defines the key-derivation lookahead used when
+ // attempting to recover the set of used addresses.
+ recoveryWindow uint32
+
+ // started is true after the first block has been added to the batch.
+ started bool
+
+ // blockBatch contains a list of blocks that have not yet been searched
+ // for recovered addresses.
+ blockBatch []wtxmgr.BlockMeta
+
+ // state encapsulates and allocates the necessary recovery state for all
+ // key scopes and subsidiary derivation paths.
+ state *RecoveryState
+
+ // chainParams are the parameters that describe the chain we're trying
+ // to recover funds on.
+ chainParams *chaincfg.Params
+}
+
+// NewRecoveryManager initializes a new RecoveryManager with a derivation
+// look-ahead of `recoveryWindow` child indexes, and pre-allocates a backing
+// array for `batchSize` blocks to scan at once.
+//
+// TODO(yy): Deprecated, remove.
+func NewRecoveryManager(recoveryWindow, batchSize uint32,
+ chainParams *chaincfg.Params) *RecoveryManager {
+
+ return &RecoveryManager{
+ recoveryWindow: recoveryWindow,
+ blockBatch: make([]wtxmgr.BlockMeta, 0, batchSize),
+ chainParams: chainParams,
+ state: NewRecoveryState(
+ recoveryWindow, chainParams, nil,
+ ),
+ }
+}
+
+// Resurrect restores all known addresses for the provided scopes that can be
+// found in the walletdb namespace, in addition to restoring all outpoints that
+// have been previously found. This method ensures that the recovery state's
+// horizons properly start from the last found address of a prior recovery
+// attempt.
+//
+// TODO(yy): Deprecated, remove.
+func (rm *RecoveryManager) Resurrect(ns walletdb.ReadBucket,
+ scopedMgrs map[waddrmgr.KeyScope]waddrmgr.AccountStore,
+ credits []wtxmgr.Credit) error {
+
+ // First, for each scope that we are recovering, rederive all of the
+ // addresses up to the last found address known to each branch.
+ for keyScope, scopedMgr := range scopedMgrs {
+ // Load the current account properties for this scope, using the
+ // the default account number.
+ // TODO(conner): rescan for all created accounts if we allow
+ // users to use non-default address
+ scopeState := rm.state.StateForScope(keyScope)
+ acctProperties, err := scopedMgr.AccountProperties(
+ ns, waddrmgr.DefaultAccountNum,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Fetch the external key count, which bounds the indexes we
+ // will need to rederive.
+ externalCount := acctProperties.ExternalKeyCount
+
+ // Walk through all indexes through the last external key,
+ // deriving each address and adding it to the external branch
+ // recovery state's set of addresses to look for.
+ for i := uint32(0); i < externalCount; i++ {
+ keyPath := externalKeyPath(i)
+ addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
+ if err != nil && err != hdkeychain.ErrInvalidChild {
+ return err
+ } else if err == hdkeychain.ErrInvalidChild {
+ scopeState.ExternalBranch.MarkInvalidChild(i)
+ continue
+ }
+
+ scopeState.ExternalBranch.AddAddr(i, addr.Address())
+ }
+
+ // Fetch the internal key count, which bounds the indexes we
+ // will need to rederive.
+ internalCount := acctProperties.InternalKeyCount
+
+ // Walk through all indexes through the last internal key,
+ // deriving each address and adding it to the internal branch
+ // recovery state's set of addresses to look for.
+ for i := uint32(0); i < internalCount; i++ {
+ keyPath := internalKeyPath(i)
+ addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
+ if err != nil && err != hdkeychain.ErrInvalidChild {
+ return err
+ } else if err == hdkeychain.ErrInvalidChild {
+ scopeState.InternalBranch.MarkInvalidChild(i)
+ continue
+ }
+
+ scopeState.InternalBranch.AddAddr(i, addr.Address())
+ }
+
+ // The key counts will point to the next key that can be
+ // derived, so we subtract one to point to last known key. If
+ // the key count is zero, then no addresses have been found.
+ if externalCount > 0 {
+ scopeState.ExternalBranch.ReportFound(externalCount - 1)
+ }
+ if internalCount > 0 {
+ scopeState.InternalBranch.ReportFound(internalCount - 1)
+ }
+ }
+
+ // In addition, we will re-add any outpoints that are known the wallet
+ // to our global set of watched outpoints, so that we can watch them for
+ // spends.
+ for _, credit := range credits {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ credit.PkScript, rm.chainParams,
+ )
+ if err != nil {
+ return err
+ }
+
+ rm.state.AddWatchedOutPoint(&credit.OutPoint, addrs[0])
+ }
+
+ return nil
+}
+
+// AddToBlockBatch appends the block information, consisting of hash and height,
+// to the batch of blocks to be searched.
+//
+// TODO(yy): Deprecated, remove.
+func (rm *RecoveryManager) AddToBlockBatch(hash *chainhash.Hash, height int32,
+ timestamp time.Time) {
+
+ if !rm.started {
+ log.Infof("Seed birthday surpassed, starting recovery "+
+ "of wallet from height=%d hash=%v with "+
+ "recovery-window=%d", height, *hash, rm.recoveryWindow)
+ rm.started = true
+ }
+
+ block := wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Hash: *hash,
+ Height: height,
+ },
+ Time: timestamp,
+ }
+ rm.blockBatch = append(rm.blockBatch, block)
+}
+
+// BlockBatch returns a buffer of blocks that have not yet been searched.
+//
+// TODO(yy): Deprecated, remove.
+func (rm *RecoveryManager) BlockBatch() []wtxmgr.BlockMeta {
+ return rm.blockBatch
+}
+
+// ResetBlockBatch resets the internal block buffer to conserve memory.
+//
+// TODO(yy): Deprecated, remove.
+func (rm *RecoveryManager) ResetBlockBatch() {
+ rm.blockBatch = rm.blockBatch[:0]
+}
+
+// State returns the current RecoveryState.
+//
+// TODO(yy): Deprecated, remove.
+func (rm *RecoveryManager) State() *RecoveryState {
+ return rm.state
+}
+
+// ScopeRecoveryState is used to manage the recovery of addresses generated
+// under a particular BIP32 account. Each account tracks both an external and
+// internal branch recovery state, both of which use the same recovery window.
+//
+// TODO(yy): Deprecated, remove.
+type ScopeRecoveryState struct {
+ // ExternalBranch is the recovery state of addresses generated for
+ // external use, i.e. receiving addresses.
+ ExternalBranch *BranchRecoveryState
+
+ // InternalBranch is the recovery state of addresses generated for
+ // internal use, i.e. change addresses.
+ InternalBranch *BranchRecoveryState
+}
+
+// NewScopeRecoveryState initializes an ScopeRecoveryState with the chosen
+// recovery window.
+//
+// TODO(yy): Deprecated, remove.
+func NewScopeRecoveryState(recoveryWindow uint32) *ScopeRecoveryState {
+ return &ScopeRecoveryState{
+ ExternalBranch: NewBranchRecoveryState(recoveryWindow, nil),
+ InternalBranch: NewBranchRecoveryState(recoveryWindow, nil),
+ }
+}
+
+const (
+ // WalletDBName specified the database filename for the wallet.
+ WalletDBName = "wallet.db"
+
+ // DefaultDBTimeout is the default timeout value when opening the wallet
+ // database.
+ DefaultDBTimeout = 60 * time.Second
+)
+
+var (
+ // ErrLoaded describes the error condition of attempting to load or
+ // create a wallet when the loader has already done so.
+ ErrLoaded = errors.New("wallet already loaded")
+
+ // ErrNotLoaded describes the error condition of attempting to close a
+ // loaded wallet when a wallet has not been loaded.
+ ErrNotLoaded = errors.New("wallet is not loaded")
+
+ // ErrExists describes the error condition of attempting to create a new
+ // wallet when one exists already.
+ ErrExists = errors.New("wallet already exists")
+)
+
+// loaderConfig contains the configuration options for the loader.
+type loaderConfig struct {
+ walletSyncRetryInterval time.Duration
+}
+
+// defaultLoaderConfig returns the default configuration options for the loader.
+func defaultLoaderConfig() *loaderConfig {
+ return &loaderConfig{
+ walletSyncRetryInterval: defaultSyncRetryInterval,
+ }
+}
+
+// LoaderOption is a configuration option for the loader.
+type LoaderOption func(*loaderConfig)
+
+// WithWalletSyncRetryInterval specifies the interval at which the wallet
+// should retry syncing to the chain if it encounters an error.
+func WithWalletSyncRetryInterval(interval time.Duration) LoaderOption {
+ return func(c *loaderConfig) {
+ c.walletSyncRetryInterval = interval
+ }
+}
+
+// Loader implements the creating of new and opening of existing wallets, while
+// providing a callback system for other subsystems to handle the loading of a
+// wallet. This is primarily intended for use by the RPC servers, to enable
+// methods and services which require the wallet when the wallet is loaded by
+// another subsystem.
+//
+// Loader is safe for concurrent access.
+type Loader struct {
+ cfg *loaderConfig
+ callbacks []func(*Wallet)
+ chainParams *chaincfg.Params
+ dbDirPath string
+ noFreelistSync bool
+ timeout time.Duration
+ recoveryWindow uint32
+ wallet *Wallet
+ localDB bool
+ walletExists func() (bool, error)
+ walletCreated func(db walletdb.ReadWriteTx) error
+ db walletdb.DB
+ mu sync.Mutex
+}
+
+// NewLoader constructs a Loader with an optional recovery window. If the
+// recovery window is non-zero, the wallet will attempt to recovery addresses
+// starting from the last SyncedTo height.
+func NewLoader(chainParams *chaincfg.Params, dbDirPath string,
+ noFreelistSync bool, timeout time.Duration, recoveryWindow uint32,
+ opts ...LoaderOption) *Loader {
+
+ cfg := defaultLoaderConfig()
+ for _, opt := range opts {
+ opt(cfg)
+ }
+
+ return &Loader{
+ cfg: cfg,
+ chainParams: chainParams,
+ dbDirPath: dbDirPath,
+ noFreelistSync: noFreelistSync,
+ timeout: timeout,
+ recoveryWindow: recoveryWindow,
+ localDB: true,
+ }
+}
+
+// NewLoaderWithDB constructs a Loader with an externally provided DB. This way
+// users are free to use their own walletdb implementation (eg. leveldb, etcd)
+// to store the wallet. Given that the external DB may be shared an additional
+// function is also passed which will override Loader.WalletExists().
+func NewLoaderWithDB(chainParams *chaincfg.Params, recoveryWindow uint32,
+ db walletdb.DB, walletExists func() (bool, error),
+ opts ...LoaderOption) (*Loader, error) {
+
+ if db == nil {
+ return nil, fmt.Errorf("no DB provided")
+ }
+
+ if walletExists == nil {
+ return nil, fmt.Errorf("unable to check if wallet exists")
+ }
+
+ cfg := defaultLoaderConfig()
+ for _, opt := range opts {
+ opt(cfg)
+ }
+
+ return &Loader{
+ cfg: cfg,
+ chainParams: chainParams,
+ recoveryWindow: recoveryWindow,
+ localDB: false,
+ walletExists: walletExists,
+ db: db,
+ }, nil
+}
+
+// onLoaded executes each added callback and prevents loader from loading any
+// additional wallets. Requires mutex to be locked.
+func (l *Loader) onLoaded(w *Wallet) {
+ for _, fn := range l.callbacks {
+ fn(w)
+ }
+
+ l.wallet = w
+ l.callbacks = nil // not needed anymore
+}
+
+// RunAfterLoad adds a function to be executed when the loader creates or opens
+// a wallet. Functions are executed in a single goroutine in the order they are
+// added.
+func (l *Loader) RunAfterLoad(fn func(*Wallet)) {
+ l.mu.Lock()
+ if l.wallet != nil {
+ w := l.wallet
+ l.mu.Unlock()
+ fn(w)
+ } else {
+ l.callbacks = append(l.callbacks, fn)
+ l.mu.Unlock()
+ }
+}
+
+// OnWalletCreated adds a function that will be executed the wallet structure
+// is initialized in the wallet database. This is useful if users want to add
+// extra fields in the same transaction (eg. to flag wallet existence).
+func (l *Loader) OnWalletCreated(fn func(walletdb.ReadWriteTx) error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ l.walletCreated = fn
+}
+
+// CreateNewWallet creates a new wallet using the provided public and private
+// passphrases. The seed is optional. If non-nil, addresses are derived from
+// this seed. If nil, a secure random seed is generated.
+func (l *Loader) CreateNewWallet(pubPassphrase, privPassphrase, seed []byte,
+ bday time.Time) (*Wallet, error) {
+
+ var (
+ rootKey *hdkeychain.ExtendedKey
+ err error
+ )
+
+ // If a seed was specified, we check its length now. If no seed is
+ // passed, the wallet will create a new random one.
+ if seed != nil {
+ if len(seed) < hdkeychain.MinSeedBytes ||
+ len(seed) > hdkeychain.MaxSeedBytes {
+
+ return nil, hdkeychain.ErrInvalidSeedLen
+ }
+
+ // Derive the master extended key from the seed.
+ rootKey, err = hdkeychain.NewMaster(seed, l.chainParams)
+ if err != nil {
+ return nil, fmt.Errorf("failed to derive master " +
+ "extended key")
+ }
+ }
+
+ return l.createNewWallet(
+ pubPassphrase, privPassphrase, rootKey, bday, false,
+ )
+}
+
+// CreateNewWalletExtendedKey creates a new wallet from an extended master root
+// key using the provided public and private passphrases. The root key is
+// optional. If non-nil, addresses are derived from this root key. If nil, a
+// secure random seed is generated and the root key is derived from that.
+func (l *Loader) CreateNewWalletExtendedKey(pubPassphrase, privPassphrase []byte,
+ rootKey *hdkeychain.ExtendedKey, bday time.Time) (*Wallet, error) {
+
+ return l.createNewWallet(
+ pubPassphrase, privPassphrase, rootKey, bday, false,
+ )
+}
+
+// CreateNewWatchingOnlyWallet creates a new wallet using the provided
+// public passphrase. No seed or private passphrase may be provided
+// since the wallet is watching-only.
+func (l *Loader) CreateNewWatchingOnlyWallet(pubPassphrase []byte,
+ bday time.Time) (*Wallet, error) {
+
+ return l.createNewWallet(
+ pubPassphrase, nil, nil, bday, true,
+ )
+}
+
+func (l *Loader) createNewWallet(pubPassphrase, privPassphrase []byte,
+ rootKey *hdkeychain.ExtendedKey, bday time.Time,
+ isWatchingOnly bool) (*Wallet, error) {
+
+ defer l.mu.Unlock()
+ l.mu.Lock()
+
+ if l.wallet != nil {
+ return nil, ErrLoaded
+ }
+
+ exists, err := l.WalletExists()
+ if err != nil {
+ return nil, err
+ }
+ if exists {
+ return nil, ErrExists
+ }
+
+ if l.localDB {
+ dbPath := filepath.Join(l.dbDirPath, WalletDBName)
+
+ // Create the wallet database backed by bolt db.
+ err = os.MkdirAll(l.dbDirPath, 0700)
+ if err != nil {
+ return nil, err
+ }
+ l.db, err = walletdb.Create(
+ "bdb", dbPath, l.noFreelistSync, l.timeout, false,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Initialize the newly created database for the wallet before opening.
+ if isWatchingOnly {
+ err := CreateWatchingOnlyWithCallback(
+ l.db, pubPassphrase, l.chainParams, bday,
+ l.walletCreated,
+ )
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ err := CreateWithCallback(
+ l.db, pubPassphrase, privPassphrase, rootKey,
+ l.chainParams, bday, l.walletCreated,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Open the newly-created wallet.
+ w, err := OpenWithRetry(
+ l.db, pubPassphrase, nil, l.chainParams, l.recoveryWindow,
+ l.cfg.walletSyncRetryInterval,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ l.onLoaded(w)
+ return w, nil
+}
+
+var errNoConsole = errors.New("db upgrade requires console access for additional input")
+
+func noConsole() ([]byte, error) {
+ return nil, errNoConsole
+}
+
+// OpenExistingWallet opens the wallet from the loader's wallet database path
+// and the public passphrase. If the loader is being called by a context where
+// standard input prompts may be used during wallet upgrades, setting
+// canConsolePrompt will enables these prompts.
+func (l *Loader) OpenExistingWallet(pubPassphrase []byte,
+ canConsolePrompt bool) (*Wallet, error) {
+
+ defer l.mu.Unlock()
+ l.mu.Lock()
+
+ if l.wallet != nil {
+ return nil, ErrLoaded
+ }
+
+ if l.localDB {
+ var err error
+ // Ensure that the network directory exists.
+ if err = checkCreateDir(l.dbDirPath); err != nil {
+ return nil, err
+ }
+
+ // Open the database using the boltdb backend.
+ dbPath := filepath.Join(l.dbDirPath, WalletDBName)
+ l.db, err = walletdb.Open(
+ "bdb", dbPath, l.noFreelistSync, l.timeout, false,
+ )
+ if err != nil {
+ log.Errorf("Failed to open database: %v", err)
+ return nil, err
+ }
+ }
+
+ var cbs *waddrmgr.OpenCallbacks
+ if canConsolePrompt {
+ cbs = &waddrmgr.OpenCallbacks{
+ ObtainSeed: prompt.ProvideSeed,
+ ObtainPrivatePass: prompt.ProvidePrivPassphrase,
+ }
+ } else {
+ cbs = &waddrmgr.OpenCallbacks{
+ ObtainSeed: noConsole,
+ ObtainPrivatePass: noConsole,
+ }
+ }
+ w, err := OpenWithRetry(
+ l.db, pubPassphrase, cbs, l.chainParams, l.recoveryWindow,
+ l.cfg.walletSyncRetryInterval,
+ )
+ if err != nil {
+ // If opening the wallet fails (e.g. because of wrong
+ // passphrase), we must close the backing database to
+ // allow future calls to walletdb.Open().
+ if l.localDB {
+ e := l.db.Close()
+ if e != nil {
+ log.Warnf("Error closing database: %v", e)
+ }
+ }
+
+ return nil, err
+ }
+
+ w.StartDeprecated()
+
+ l.onLoaded(w)
+ return w, nil
+}
+
+// WalletExists returns whether a file exists at the loader's database path.
+// This may return an error for unexpected I/O failures.
+func (l *Loader) WalletExists() (bool, error) {
+ if l.localDB {
+ dbPath := filepath.Join(l.dbDirPath, WalletDBName)
+ return fileExists(dbPath)
+ }
+
+ return l.walletExists()
+}
+
+// LoadedWallet returns the loaded wallet, if any, and a bool for whether the
+// wallet has been loaded or not. If true, the wallet pointer should be safe to
+// dereference.
+func (l *Loader) LoadedWallet() (*Wallet, bool) {
+ l.mu.Lock()
+ w := l.wallet
+ l.mu.Unlock()
+ return w, w != nil
+}
+
+// UnloadWallet stops the loaded wallet, if any, and closes the wallet database.
+// This returns ErrNotLoaded if the wallet has not been loaded with
+// CreateNewWallet or LoadExistingWallet. The Loader may be reused if this
+// function returns without error.
+func (l *Loader) UnloadWallet() error {
+ defer l.mu.Unlock()
+ l.mu.Lock()
+
+ if l.wallet == nil {
+ return ErrNotLoaded
+ }
+
+ l.wallet.StopDeprecated()
+ l.wallet.WaitForShutdown()
+ if l.localDB {
+ err := l.db.Close()
+ if err != nil {
+ return err
+ }
+ }
+
+ l.wallet = nil
+ l.db = nil
+ return nil
+}
+
+func fileExists(filePath string) (bool, error) {
+ _, err := os.Stat(filePath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, err
+ }
+ return true, nil
+}
diff --git a/wallet/deprecated_test.go b/wallet/deprecated_test.go
new file mode 100644
index 0000000000..89a4dba38c
--- /dev/null
+++ b/wallet/deprecated_test.go
@@ -0,0 +1,2292 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "bytes"
+ "encoding/hex"
+ "fmt"
+ "math"
+ "reflect"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/psbt/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wallet/txauthor"
+ "github.com/btcsuite/btcwallet/wallet/txrules"
+ "github.com/btcsuite/btcwallet/wallet/txsizes"
+ "github.com/btcsuite/btcwallet/walletdb"
+ _ "github.com/btcsuite/btcwallet/walletdb/bdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/sync/errgroup"
+)
+
+// TestLabelTransaction tests labelling of transactions with invalid labels,
+// and failure to label a transaction when it already has a label.
+func TestLabelTransaction(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+
+ // Whether the transaction should be known to the wallet.
+ txKnown bool
+
+ // Whether the test should write an existing label to disk.
+ existingLabel bool
+
+ // The overwrite parameter to call label transaction with.
+ overwrite bool
+
+ // The error we expect to be returned.
+ expectedErr error
+ }{
+ {
+ name: "existing label, not overwrite",
+ txKnown: true,
+ existingLabel: true,
+ overwrite: false,
+ expectedErr: ErrTxLabelExists,
+ },
+ {
+ name: "existing label, overwritten",
+ txKnown: true,
+ existingLabel: true,
+ overwrite: true,
+ expectedErr: nil,
+ },
+ {
+ name: "no prexisting label, ok",
+ txKnown: true,
+ existingLabel: false,
+ overwrite: false,
+ expectedErr: nil,
+ },
+ {
+ name: "transaction unknown",
+ txKnown: false,
+ existingLabel: false,
+ overwrite: false,
+ expectedErr: ErrUnknownTransaction,
+ },
+ }
+
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ w := testWallet(t)
+
+ // If the transaction should be known to the store, we
+ // write txdetail to disk.
+ if test.txKnown {
+ rec, err := wtxmgr.NewTxRecord(
+ TstSerializedTx, time.Now(),
+ )
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = walletdb.Update(w.db,
+ func(tx walletdb.ReadWriteTx) error {
+
+ ns := tx.ReadWriteBucket(
+ wtxmgrNamespaceKey,
+ )
+
+ return w.txStore.InsertTx(
+ ns, rec, nil,
+ )
+ })
+ if err != nil {
+ t.Fatalf("could not insert tx: %v", err)
+ }
+ }
+
+ // If we want to setup an existing label for the purpose
+ // of the test, write one to disk.
+ if test.existingLabel {
+ err := w.LabelTransaction(
+ *TstTxHash, "existing label", false,
+ )
+ if err != nil {
+ t.Fatalf("could not write label: %v",
+ err)
+ }
+ }
+
+ newLabel := "new label"
+ err := w.LabelTransaction(
+ *TstTxHash, newLabel, test.overwrite,
+ )
+ if err != test.expectedErr {
+ t.Fatalf("expected: %v, got: %v",
+ test.expectedErr, err)
+ }
+ })
+ }
+}
+
+// TestGetTransaction tests if we can fetch a mined, an existing
+// and a non-existing transaction from the wallet like we expect.
+func TestGetTransaction(t *testing.T) {
+ t.Parallel()
+ rec, err := wtxmgr.NewTxRecord(TstSerializedTx, time.Now())
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+
+ // Transaction id.
+ txid chainhash.Hash
+
+ // Expected height.
+ expectedHeight int32
+
+ // Store function.
+ f func(wtxmgr.TxStore,
+ walletdb.ReadWriteBucket) (wtxmgr.TxStore, error)
+
+ // The error we expect to be returned.
+ expectedErr error
+ }{
+ {
+ name: "existing unmined transaction",
+ txid: *TstTxHash,
+ expectedHeight: -1,
+ // We write txdetail for the tx to disk.
+ f: func(s wtxmgr.TxStore, ns walletdb.ReadWriteBucket) (
+ wtxmgr.TxStore, error) {
+
+ err = s.InsertTx(ns, rec, nil)
+ return s, err
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "existing mined transaction",
+ txid: *TstTxHash,
+ // We write txdetail for the tx to disk.
+ f: func(s wtxmgr.TxStore, ns walletdb.ReadWriteBucket) (
+ wtxmgr.TxStore, error) {
+
+ err = s.InsertTx(ns, rec, TstMinedSignedTxBlockDetails)
+ return s, err
+ },
+ expectedHeight: TstMinedTxBlockHeight,
+ expectedErr: nil,
+ },
+ {
+ name: "non-existing transaction",
+ txid: *TstTxHash,
+ // Write no txdetail to disk.
+ f: func(s wtxmgr.TxStore, _ walletdb.ReadWriteBucket) (
+ wtxmgr.TxStore, error) {
+
+ return s, nil
+ },
+ expectedErr: ErrNoTx,
+ },
+ }
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ w := testWallet(t)
+
+ err := walletdb.Update(w.db, func(rw walletdb.ReadWriteTx) error {
+ ns := rw.ReadWriteBucket(wtxmgrNamespaceKey)
+ _, err := test.f(w.txStore, ns)
+ return err
+ })
+ require.NoError(t, err)
+ tx, err := w.GetTransaction(test.txid)
+ require.ErrorIs(t, err, test.expectedErr)
+
+ // Discontinue if no transaction were found.
+ if err != nil {
+ return
+ }
+
+ // Check if we get the expected hash.
+ require.Equal(t, &test.txid, tx.Summary.Hash)
+
+ // Check the block height.
+ require.Equal(t, test.expectedHeight, tx.Height)
+ })
+ }
+}
+
+// TestGetTransactionConfirmations tests that GetTransaction correctly
+// calculates confirmations for both confirmed and unconfirmed transactions.
+// This is a regression test for a bug where confirmations were set to the
+// block height instead of being calculated as currentHeight - blockHeight + 1.
+//
+// The bug had several negative impacts:
+// - Unconfirmed transactions showed -1 confirmations instead of 0, breaking
+// zero-conf (accepting transactions before block inclusion)
+// - Confirmed transactions showed block height instead of actual confirmation
+// count
+// - LND and other consumers would make incorrect decisions based on wrong
+// counts
+func TestGetTransactionConfirmations(t *testing.T) {
+ t.Parallel()
+
+ rec, err := wtxmgr.NewTxRecord(TstSerializedTx, time.Now())
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+
+ // Block height where transaction is mined (-1 for unmined).
+ txBlockHeight int32
+
+ // Current wallet sync height.
+ currentHeight int32
+
+ // Expected confirmations.
+ expectedConfirmations int32
+
+ // Expected height in result.
+ expectedHeight int32
+
+ // Whether to check for non-zero timestamp.
+ expectTimestamp bool
+ }{
+ {
+ name: "unconfirmed tx",
+ txBlockHeight: -1,
+ currentHeight: 100,
+ expectedConfirmations: 0,
+ expectedHeight: -1,
+ expectTimestamp: false,
+ },
+ {
+ name: "tx with 1 confirmation",
+ txBlockHeight: 100,
+ currentHeight: 100,
+ expectedConfirmations: 1,
+ expectedHeight: 100,
+ expectTimestamp: true,
+ },
+ {
+ name: "tx with 3 confirmations",
+ txBlockHeight: 8,
+ currentHeight: 10,
+ expectedConfirmations: 3,
+ expectedHeight: 8,
+ expectTimestamp: true,
+ },
+ {
+ name: "old tx with many confirmations",
+ txBlockHeight: 1,
+ currentHeight: 1000,
+ expectedConfirmations: 1000,
+ expectedHeight: 1,
+ expectTimestamp: true,
+ },
+ {
+ name: "tx in future block",
+ txBlockHeight: 105,
+ currentHeight: 100,
+ expectedConfirmations: 0,
+ expectedHeight: 105,
+ expectTimestamp: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ w := testWallet(t)
+
+ // Set the wallet's synced height.
+ err := walletdb.Update(
+ w.db, func(tx walletdb.ReadWriteTx) error {
+ addrmgrNs := tx.ReadWriteBucket(
+ waddrmgrNamespaceKey,
+ )
+ bs := &waddrmgr.BlockStamp{
+ Height: tt.currentHeight,
+ Hash: chainhash.Hash{},
+ }
+
+ return w.addrStore.SetSyncedTo(
+ addrmgrNs, bs,
+ )
+ },
+ )
+ require.NoError(t, err)
+
+ // Insert transaction into wallet.
+ err = walletdb.Update(
+ w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(
+ wtxmgrNamespaceKey,
+ )
+
+ // Create block metadata if transaction
+ // is mined.
+ var blockMeta *wtxmgr.BlockMeta
+ if tt.txBlockHeight != -1 {
+ hash := chainhash.Hash{}
+ height := tt.txBlockHeight
+ block := wtxmgr.Block{
+ Hash: hash,
+ Height: height,
+ }
+ blockMeta = &wtxmgr.BlockMeta{
+ Block: block,
+ Time: time.Now(),
+ }
+ }
+
+ return w.txStore.InsertTx(
+ ns, rec, blockMeta,
+ )
+ },
+ )
+ require.NoError(t, err)
+
+ result, err := w.GetTransaction(*TstTxHash)
+ require.NoError(t, err)
+
+ require.Equal(
+ t, tt.expectedConfirmations,
+ result.Confirmations,
+ )
+
+ require.Equal(t, tt.expectedHeight, result.Height)
+
+ if tt.expectTimestamp {
+ require.NotZero(t, result.Timestamp)
+ } else {
+ require.Zero(t, result.Timestamp)
+ }
+
+ // Additional checks for unconfirmed transactions.
+ if tt.txBlockHeight == -1 {
+ require.Nil(t, result.BlockHash)
+ require.Equal(t, int32(0), result.Confirmations)
+ } else {
+ require.NotNil(t, result.BlockHash)
+ // Only expect positive confirmations when tx is
+ // not in a future block.
+ if tt.txBlockHeight <= tt.currentHeight {
+ require.Positive(
+ t, result.Confirmations,
+ )
+ } else {
+ // Confirmed txns in future blocks for
+ // example due to reorg should be
+ // treated as unconfirmed and have 0
+ // confirmations.
+ require.Equal(
+ t, int32(0),
+ result.Confirmations,
+ )
+ }
+ }
+ })
+ }
+}
+
+// TestDuplicateAddressDerivation tests that duplicate addresses are not
+// derived when multiple goroutines are concurrently requesting new addresses.
+func TestDuplicateAddressDerivation(t *testing.T) {
+ w := testWallet(t)
+ var (
+ m sync.Mutex
+ globalAddrs = make(map[string]address.Address)
+ )
+
+ for o := 0; o < 10; o++ {
+ var eg errgroup.Group
+
+ for n := 0; n < 10; n++ {
+ eg.Go(func() error {
+ addrs := make([]address.Address, 10)
+ for i := 0; i < 10; i++ {
+ addr, err := w.NewAddressDeprecated(
+ 0, waddrmgr.KeyScopeBIP0084,
+ )
+ if err != nil {
+ return err
+ }
+
+ addrs[i] = addr
+ }
+
+ m.Lock()
+ defer m.Unlock()
+
+ for idx := range addrs {
+ addrStr := addrs[idx].String()
+ if a, ok := globalAddrs[addrStr]; ok {
+ return fmt.Errorf("duplicate "+
+ "address! already "+
+ "have %v, want to "+
+ "add %v", a, addrs[idx])
+ }
+
+ globalAddrs[addrStr] = addrs[idx]
+ }
+
+ return nil
+ })
+ }
+
+ require.NoError(t, eg.Wait())
+ }
+}
+
+func TestEndRecovery(t *testing.T) {
+ // This is an unconventional unit test, but I'm trying to keep things as
+ // succint as possible so that this test is readable without having to mock
+ // up literally everything.
+ // The unmonitored goroutine we're looking at is pretty deep:
+ // SynchronizeRPC -> handleChainNotifications -> syncWithChain -> recovery
+ // The "deadlock" we're addressing isn't actually a deadlock, but the wallet
+ // will hang on Stop() -> WaitForShutdown() until (*Wallet).recovery gets
+ // every single block, which could be hours depending on hardware and
+ // network factors. The WaitGroup is incremented in SynchronizeRPC, and
+ // WaitForShutdown will not return until handleChainNotifications returns,
+ // which is blocked by a running (*Wallet).recovery loop.
+ // It is noted that the conditions for long recovery are difficult to hit
+ // when using btcwallet with a fresh seed, because it requires an early
+ // birthday to be set or established.
+
+ w := testWallet(t)
+
+ blockHashCalled := make(chan struct{})
+
+ chainClient := &mockChainClient{
+ // Force the loop to iterate about forever.
+ getBestBlockHeight: math.MaxInt32,
+ // Get control of when the loop iterates.
+ getBlockHashFunc: func() (*chainhash.Hash, error) {
+ blockHashCalled <- struct{}{}
+ return &chainhash.Hash{}, nil
+ },
+ // Avoid a panic.
+ getBlockHeader: &wire.BlockHeader{},
+ }
+
+ recoveryDone := make(chan struct{})
+ go func() {
+ defer close(recoveryDone)
+ w.recovery(chainClient, &waddrmgr.BlockStamp{})
+ }()
+
+ getBlockHashCalls := func(expCalls int) {
+ var i int
+ for {
+ select {
+ case <-blockHashCalled:
+ i++
+ case <-time.After(time.Second):
+ t.Fatal("expected BlockHash to be called")
+ }
+ if i == expCalls {
+ break
+ }
+ }
+ }
+
+ // Recovery is running.
+ getBlockHashCalls(3)
+
+ // Closing the quit channel, e.g. Stop() without endRecovery, alone will not
+ // end the recovery loop.
+ w.quitMu.Lock()
+ close(w.quit)
+ w.quitMu.Unlock()
+ // Continues scanning.
+ getBlockHashCalls(3)
+
+ // We're done with this one
+ atomic.StoreUint32(&w.recovering.Load().(*recoverySyncer).quit, 1)
+ select {
+ case <-blockHashCalled:
+ case <-recoveryDone:
+ }
+
+ // Try again.
+ w = testWallet(t)
+
+ // We'll catch the error to make sure we're hitting our desired path. The
+ // WaitGroup isn't required for the test, but does show how it completes
+ // shutdown at a higher level.
+ var err error
+ w.wg.Add(1)
+ recoveryDone = make(chan struct{})
+ go func() {
+ defer w.wg.Done()
+ defer close(recoveryDone)
+ err = w.recovery(chainClient, &waddrmgr.BlockStamp{})
+ }()
+
+ waitedForShutdown := make(chan struct{})
+ go func() {
+ w.WaitForShutdown()
+ close(waitedForShutdown)
+ }()
+
+ // Recovery is running.
+ getBlockHashCalls(3)
+
+ // endRecovery is required to exit the unmonitored goroutine.
+ end := w.endRecovery()
+ select {
+ case <-blockHashCalled:
+ case <-recoveryDone:
+ }
+ <-end
+
+ // testWallet starts a couple of other unrelated goroutines that need to be
+ // killed, so we still need to close the quit channel.
+ w.quitMu.Lock()
+ close(w.quit)
+ w.quitMu.Unlock()
+
+ select {
+ case <-waitedForShutdown:
+ case <-time.After(time.Second):
+ t.Fatal("WaitForShutdown never returned")
+ }
+
+ if !strings.EqualFold(err.Error(), "recovery: forced shutdown") {
+ t.Fatal("wrong error")
+ }
+}
+
+// mockBirthdayStore is a mock in-memory implementation of the birthdayStore interface
+// that will be used for the birthday block sanity check tests.
+type mockBirthdayStore struct {
+ birthday time.Time
+ birthdayBlock *waddrmgr.BlockStamp
+ birthdayBlockVerified bool
+ syncedTo waddrmgr.BlockStamp
+}
+
+var _ birthdayStore = (*mockBirthdayStore)(nil)
+
+// Birthday returns the birthday timestamp of the wallet.
+func (s *mockBirthdayStore) Birthday() time.Time {
+ return s.birthday
+}
+
+// BirthdayBlock returns the birthday block of the wallet.
+func (s *mockBirthdayStore) BirthdayBlock() (waddrmgr.BlockStamp, bool, error) {
+ if s.birthdayBlock == nil {
+ err := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrBirthdayBlockNotSet,
+ }
+ return waddrmgr.BlockStamp{}, false, err
+ }
+
+ return *s.birthdayBlock, s.birthdayBlockVerified, nil
+}
+
+// SetBirthdayBlock updates the birthday block of the wallet to the given block.
+// The boolean can be used to signal whether this block should be sanity checked
+// the next time the wallet starts.
+func (s *mockBirthdayStore) SetBirthdayBlock(block waddrmgr.BlockStamp) error {
+ s.birthdayBlock = &block
+ s.birthdayBlockVerified = true
+ s.syncedTo = block
+ return nil
+}
+
+// TestBirthdaySanityCheckEmptyBirthdayBlock ensures that a sanity check is not
+// done if the birthday block does not exist in the first place.
+func TestBirthdaySanityCheckEmptyBirthdayBlock(t *testing.T) {
+ t.Parallel()
+
+ chainConn := &mockChainConn{}
+
+ // Our birthday store will reflect that we don't have a birthday block
+ // set, so we should not attempt a sanity check.
+ birthdayStore := &mockBirthdayStore{}
+
+ birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
+ if !waddrmgr.IsError(err, waddrmgr.ErrBirthdayBlockNotSet) {
+ t.Fatalf("expected ErrBirthdayBlockNotSet, got %v", err)
+ }
+
+ if birthdayBlock != nil {
+ t.Fatalf("expected birthday block to be nil due to not being "+
+ "set, got %v", *birthdayBlock)
+ }
+}
+
+// TestBirthdaySanityCheckVerifiedBirthdayBlock ensures that a sanity check is
+// not performed if the birthday block has already been verified.
+func TestBirthdaySanityCheckVerifiedBirthdayBlock(t *testing.T) {
+ t.Parallel()
+
+ const chainTip = 5000
+ const defaultBlockInterval = 10 * time.Minute
+ chainConn := createMockChainConn(
+ chainParams.GenesisBlock, chainTip, defaultBlockInterval,
+ )
+ expectedBirthdayBlock := waddrmgr.BlockStamp{Height: 1337}
+
+ // Our birthday store reflects that our birthday block has already been
+ // verified and should not require a sanity check.
+ birthdayStore := &mockBirthdayStore{
+ birthdayBlock: &expectedBirthdayBlock,
+ birthdayBlockVerified: true,
+ syncedTo: waddrmgr.BlockStamp{
+ Height: chainTip,
+ },
+ }
+
+ // Now, we'll run the sanity check. We should see that the birthday
+ // block hasn't changed.
+ birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
+ if err != nil {
+ t.Fatalf("unable to sanity check birthday block: %v", err)
+ }
+ if !reflect.DeepEqual(*birthdayBlock, expectedBirthdayBlock) {
+ t.Fatalf("expected birthday block %v, got %v",
+ expectedBirthdayBlock, birthdayBlock)
+ }
+
+ // To ensure the sanity check didn't proceed, we'll check our synced to
+ // height, as this value should have been modified if a new candidate
+ // was found.
+ if birthdayStore.syncedTo.Height != chainTip {
+ t.Fatalf("expected synced height remain the same (%d), got %d",
+ chainTip, birthdayStore.syncedTo.Height)
+ }
+}
+
+// TestBirthdaySanityCheckLowerEstimate ensures that we can properly locate a
+// better birthday block candidate if our estimate happens to be too far back in
+// the chain.
+func TestBirthdaySanityCheckLowerEstimate(t *testing.T) {
+ t.Parallel()
+
+ const defaultBlockInterval = 10 * time.Minute
+
+ // We'll start by defining our birthday timestamp to be around the
+ // timestamp of the 1337th block.
+ genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp
+ birthday := genesisTimestamp.Add(1337 * defaultBlockInterval)
+
+ // We'll establish a connection to a mock chain of 5000 blocks.
+ chainConn := createMockChainConn(
+ chainParams.GenesisBlock, 5000, defaultBlockInterval,
+ )
+
+ // Our birthday store will reflect that our birthday block is currently
+ // set as the genesis block. This value is too low and should be
+ // adjusted by the sanity check.
+ birthdayStore := &mockBirthdayStore{
+ birthday: birthday,
+ birthdayBlock: &waddrmgr.BlockStamp{
+ Hash: *chainParams.GenesisHash,
+ Height: 0,
+ Timestamp: genesisTimestamp,
+ },
+ birthdayBlockVerified: false,
+ syncedTo: waddrmgr.BlockStamp{
+ Height: 5000,
+ },
+ }
+
+ // We'll perform the sanity check and determine whether we were able to
+ // find a better birthday block candidate.
+ birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
+ if err != nil {
+ t.Fatalf("unable to sanity check birthday block: %v", err)
+ }
+ if birthday.Sub(birthdayBlock.Timestamp) >= birthdayBlockDelta {
+ t.Fatalf("expected birthday block timestamp=%v to be within "+
+ "%v of birthday timestamp=%v", birthdayBlock.Timestamp,
+ birthdayBlockDelta, birthday)
+ }
+
+ // Finally, our synced to height should now reflect our new birthday
+ // block to ensure the wallet doesn't miss any events from this point
+ // forward.
+ if !reflect.DeepEqual(birthdayStore.syncedTo, *birthdayBlock) {
+ t.Fatalf("expected syncedTo and birthday block to match: "+
+ "%v vs %v", birthdayStore.syncedTo, birthdayBlock)
+ }
+}
+
+// TestBirthdaySanityCheckHigherEstimate ensures that we can properly locate a
+// better birthday block candidate if our estimate happens to be too far in the
+// chain.
+func TestBirthdaySanityCheckHigherEstimate(t *testing.T) {
+ t.Parallel()
+
+ const defaultBlockInterval = 10 * time.Minute
+
+ // We'll start by defining our birthday timestamp to be around the
+ // timestamp of the 1337th block.
+ genesisTimestamp := chainParams.GenesisBlock.Header.Timestamp
+ birthday := genesisTimestamp.Add(1337 * defaultBlockInterval)
+
+ // We'll establish a connection to a mock chain of 5000 blocks.
+ chainConn := createMockChainConn(
+ chainParams.GenesisBlock, 5000, defaultBlockInterval,
+ )
+
+ // Our birthday store will reflect that our birthday block is currently
+ // set as the chain tip. This value is too high and should be adjusted
+ // by the sanity check.
+ bestBlock := chainConn.blocks[chainConn.blockHashes[5000]]
+ birthdayStore := &mockBirthdayStore{
+ birthday: birthday,
+ birthdayBlock: &waddrmgr.BlockStamp{
+ Hash: bestBlock.BlockHash(),
+ Height: 5000,
+ Timestamp: bestBlock.Header.Timestamp,
+ },
+ birthdayBlockVerified: false,
+ syncedTo: waddrmgr.BlockStamp{
+ Height: 5000,
+ },
+ }
+
+ // We'll perform the sanity check and determine whether we were able to
+ // find a better birthday block candidate.
+ birthdayBlock, err := birthdaySanityCheck(chainConn, birthdayStore)
+ if err != nil {
+ t.Fatalf("unable to sanity check birthday block: %v", err)
+ }
+ if birthday.Sub(birthdayBlock.Timestamp) >= birthdayBlockDelta {
+ t.Fatalf("expected birthday block timestamp=%v to be within "+
+ "%v of birthday timestamp=%v", birthdayBlock.Timestamp,
+ birthdayBlockDelta, birthday)
+ }
+
+ // Finally, our synced to height should now reflect our new birthday
+ // block to ensure the wallet doesn't miss any events from this point
+ // forward.
+ if !reflect.DeepEqual(birthdayStore.syncedTo, *birthdayBlock) {
+ t.Fatalf("expected syncedTo and birthday block to match: "+
+ "%v vs %v", birthdayStore.syncedTo, birthdayBlock)
+ }
+}
+
+type testCase struct {
+ name string
+ masterPriv string
+ accountIndex uint32
+ addrType waddrmgr.AddressType
+ expectedScope waddrmgr.KeyScope
+ expectedAddr string
+ expectedChangeAddr string
+}
+
+var (
+ //nolint:lll
+ testCases = []*testCase{{
+ name: "bip44 with nested witness address type",
+ masterPriv: "tprv8ZgxMBicQKsPeWwrFuNjEGTTDSY4mRLwd2KDJAPGa1AY" +
+ "quw38bZqNMSuB3V1Va3hqJBo9Pt8Sx7kBQer5cNMrb8SYquoWPt9" +
+ "Y3BZdhdtUcw",
+ accountIndex: 0,
+ addrType: waddrmgr.NestedWitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0049Plus,
+ expectedAddr: "2N5YTxG9XtGXx1YyhZb7N2pwEjoZLLMHGKj",
+ expectedChangeAddr: "2N7wpz5Gy2zEJTvq2MAuU6BCTEBLXNQ8dUw",
+ }, {
+ name: "bip44 with witness address type",
+ masterPriv: "tprv8ZgxMBicQKsPeWwrFuNjEGTTDSY4mRLwd2KDJAPGa1AY" +
+ "quw38bZqNMSuB3V1Va3hqJBo9Pt8Sx7kBQer5cNMrb8SYquoWPt9" +
+ "Y3BZdhdtUcw",
+ accountIndex: 777,
+ addrType: waddrmgr.WitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0084,
+ expectedAddr: "bcrt1qllxcutkzsukf8u8c8stkp464j0esu9xquft3s0",
+ expectedChangeAddr: "bcrt1qu6jmqglrthscptjqj3egx54wy8xqvzn54ex9eh",
+ }, {
+ name: "traditional bip49",
+ masterPriv: "uprv8tXDerPXZ1QsVp8y6GAMSMYxPQgWi3LSY8qS5ZH9x1YRu" +
+ "1kGPFjPzR73CFSbVUhdEwJbtsUgucUJ4hGQoJnNepp3RBcE6Jhdom" +
+ "FD2KeY6G9",
+ accountIndex: 9,
+ addrType: waddrmgr.NestedWitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0049Plus,
+ expectedAddr: "2NBCJ9WzGXZqpLpXGq3Hacybj3c4eHRcqgh",
+ expectedChangeAddr: "2N3bankFu6F3ZNU41iVJQqyS9MXqp9dvn1M",
+ }, {
+ name: "bip49+",
+ masterPriv: "uprv8tXDerPXZ1QsVp8y6GAMSMYxPQgWi3LSY8qS5ZH9x1YRu" +
+ "1kGPFjPzR73CFSbVUhdEwJbtsUgucUJ4hGQoJnNepp3RBcE6Jhdom" +
+ "FD2KeY6G9",
+ accountIndex: 9,
+ addrType: waddrmgr.WitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0049Plus,
+ expectedAddr: "2NBCJ9WzGXZqpLpXGq3Hacybj3c4eHRcqgh",
+ expectedChangeAddr: "bcrt1qeqn05w2hfq6axpdprhs4y7x65gxkkvfvx0emz4",
+ }, {
+ name: "bip84",
+ masterPriv: "vprv9DMUxX4ShgxMM7L5vcwyeSeTZNpxefKwTFMerxB3L1vJ" +
+ "x7ZVdutxcUmBDTQBVPMYeaRQeM5FNGpqwysyX1CPT4VeHXJegDX8" +
+ "5VJrQvaFaz3",
+ accountIndex: 1,
+ addrType: waddrmgr.WitnessPubKey,
+ expectedScope: waddrmgr.KeyScopeBIP0084,
+ expectedAddr: "bcrt1q5vepvcl0z8xj7kps4rsux722r4dvfwlh5ntexr",
+ expectedChangeAddr: "bcrt1qlwe2kgxcsa8x4huu79yff4rze0l5mwaf2apn3y",
+ }}
+)
+
+// TestImportAccountDeprecated tests that extended public keys can successfully
+// be imported into both watch only and normal wallets.
+func TestImportAccountDeprecated(t *testing.T) {
+ t.Parallel()
+
+ for _, tc := range testCases {
+ tc := tc
+
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ testImportAccount(t, w, tc, false, tc.name)
+ })
+
+ name := tc.name + " watch-only"
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ w := testWalletWatchingOnly(t)
+
+ testImportAccount(t, w, tc, true, name)
+ })
+ }
+}
+
+func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
+ name string) {
+
+ // First derive the master public key of the account we want to import.
+ root, err := hdkeychain.NewKeyFromString(tc.masterPriv)
+ require.NoError(t, err)
+
+ // Derive the extended private and public key for our target account.
+ acct1Pub := deriveAcctPubKey(
+ t, root, tc.expectedScope, hardenedKey(tc.accountIndex),
+ )
+
+ // We want to make sure we can import and handle multiple accounts, so
+ // we create another one.
+ acct2Pub := deriveAcctPubKey(
+ t, root, tc.expectedScope, hardenedKey(tc.accountIndex+1),
+ )
+
+ // And we also want to be able to import loose extended public keys
+ // without needing to specify an explicit scope.
+ acct3ExternalExtPub := deriveAcctPubKey(
+ t, root, tc.expectedScope, hardenedKey(tc.accountIndex+2), 0, 0,
+ )
+ acct3ExternalPub, err := acct3ExternalExtPub.ECPubKey()
+ require.NoError(t, err)
+
+ // Do a dry run import first and check that it results in the expected
+ // addresses being derived.
+ _, extAddrs, intAddrs, err := w.ImportAccountDryRun(
+ name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType, 1,
+ )
+ require.NoError(t, err)
+ require.Len(t, extAddrs, 1)
+ require.Equal(t, tc.expectedAddr, extAddrs[0].Address().String())
+ require.Len(t, intAddrs, 1)
+ require.Equal(t, tc.expectedChangeAddr, intAddrs[0].Address().String())
+
+ // Import the extended public keys into new accounts.
+ acct1, err := w.ImportAccountDeprecated(
+ name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType,
+ )
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedScope, acct1.KeyScope)
+
+ acct2, err := w.ImportAccountDeprecated(
+ name+"2", acct2Pub, root.ParentFingerprint(), &tc.addrType,
+ )
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedScope, acct2.KeyScope)
+
+ err = w.ImportPublicKeyDeprecated(acct3ExternalPub, tc.addrType)
+ require.NoError(t, err)
+
+ // If the wallet is watch only, there is no default account and our
+ // imported account will be index 0.
+ firstAccountIndex := uint32(1)
+ numAccounts := 2
+ if watchOnly {
+ firstAccountIndex = 0
+ numAccounts = 1
+ }
+
+ // We should have 2 additional accounts now.
+ acctResult, err := w.Accounts(tc.expectedScope)
+ require.NoError(t, err)
+ require.Len(t, acctResult.Accounts, numAccounts+2)
+
+ // Validate the state of the accounts.
+ require.Equal(t, firstAccountIndex, acct1.AccountNumber)
+ require.Equal(t, name+"1", acct1.AccountName)
+ require.Equal(t, true, acct1.IsWatchOnly)
+ require.Equal(t, root.ParentFingerprint(), acct1.MasterKeyFingerprint)
+ require.NotNil(t, acct1.AccountPubKey)
+ require.Equal(t, acct1Pub.String(), acct1.AccountPubKey.String())
+ require.Equal(t, uint32(0), acct1.InternalKeyCount)
+ require.Equal(t, uint32(0), acct1.ExternalKeyCount)
+ require.Equal(t, uint32(0), acct1.ImportedKeyCount)
+
+ require.Equal(t, firstAccountIndex+1, acct2.AccountNumber)
+ require.Equal(t, name+"2", acct2.AccountName)
+ require.Equal(t, true, acct2.IsWatchOnly)
+ require.Equal(t, root.ParentFingerprint(), acct2.MasterKeyFingerprint)
+ require.NotNil(t, acct2.AccountPubKey)
+ require.Equal(t, acct2Pub.String(), acct2.AccountPubKey.String())
+ require.Equal(t, uint32(0), acct2.InternalKeyCount)
+ require.Equal(t, uint32(0), acct2.ExternalKeyCount)
+ require.Equal(t, uint32(0), acct2.ImportedKeyCount)
+
+ // Test address derivation.
+ extAddr, err := w.NewAddressDeprecated(
+ acct1.AccountNumber, tc.expectedScope,
+ )
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedAddr, extAddr.String())
+ intAddr, err := w.NewChangeAddress(acct1.AccountNumber, tc.expectedScope)
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedChangeAddr, intAddr.String())
+
+ // Make sure the key count was increased.
+ acct1, err = w.AccountProperties(tc.expectedScope, acct1.AccountNumber)
+ require.NoError(t, err)
+ require.Equal(t, uint32(1), acct1.InternalKeyCount)
+ require.Equal(t, uint32(1), acct1.ExternalKeyCount)
+ require.Equal(t, uint32(0), acct1.ImportedKeyCount)
+
+ // Make sure we can't get private keys for the imported
+ // accounts.
+ _, err = w.DumpWIFPrivateKey(intAddr)
+ require.True(t, waddrmgr.IsError(err, waddrmgr.ErrWatchingOnly))
+
+ // Get the address info for the single key we imported.
+ switch tc.addrType {
+ case waddrmgr.NestedWitnessPubKey:
+ witnessAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(acct3ExternalPub.SerializeCompressed()),
+ w.chainParams,
+ )
+ require.NoError(t, err)
+
+ witnessProg, err := txscript.PayToAddrScript(witnessAddr)
+ require.NoError(t, err)
+
+ intAddr, err = address.NewAddressScriptHash(
+ witnessProg, w.chainParams,
+ )
+ require.NoError(t, err)
+
+ case waddrmgr.WitnessPubKey:
+ intAddr, err = address.NewAddressWitnessPubKeyHash(
+ address.Hash160(acct3ExternalPub.SerializeCompressed()),
+ w.chainParams,
+ )
+ require.NoError(t, err)
+
+ default:
+ t.Fatalf("unhandled address type %v", tc.addrType)
+ }
+
+ addrManaged, err := w.AddressInfoDeprecated(intAddr)
+ require.NoError(t, err)
+ require.Equal(t, true, addrManaged.Imported())
+}
+
+// TestCreateWatchingOnly checks that we can construct a watching-only
+// wallet.
+func TestCreateWatchingOnly(t *testing.T) {
+ // Set up a wallet.
+ dir := t.TempDir()
+
+ pubPass := []byte("hello")
+
+ loader := NewLoader(
+ &chaincfg.TestNet3Params, dir, true, defaultDBTimeout, 250,
+ WithWalletSyncRetryInterval(10*time.Millisecond),
+ )
+ _, err := loader.CreateNewWatchingOnlyWallet(pubPass, time.Now())
+ if err != nil {
+ t.Fatalf("unable to create wallet: %v", err)
+ }
+}
+
+// defaultDBTimeout specifies the timeout value when opening the wallet
+// database.
+var defaultDBTimeout = 10 * time.Second
+
+// testWallet creates a test wallet and unlocks it.
+func testWallet(t *testing.T) *Wallet {
+ t.Helper()
+ // Set up a wallet.
+ dir := t.TempDir()
+
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.MinSeedBytes)
+ if err != nil {
+ t.Fatalf("unable to create seed: %v", err)
+ }
+
+ pubPass := []byte("hello")
+ privPass := []byte("world")
+
+ loader := NewLoader(
+ &chainParams, dir, true, defaultDBTimeout, 250,
+ WithWalletSyncRetryInterval(10*time.Millisecond),
+ )
+ w, err := loader.CreateNewWallet(pubPass, privPass, seed, time.Now())
+ if err != nil {
+ t.Fatalf("unable to create wallet: %v", err)
+ }
+
+ chainClient := &mockChainClient{}
+ w.chainClient = chainClient
+
+ // Start the wallet.
+ w.StartDeprecated()
+
+ // Add the shutdown to the test's cleanup process.
+ t.Cleanup(func() {
+ w.StopDeprecated()
+ w.WaitForShutdown()
+ })
+
+ err = w.UnlockDeprecated(privPass, time.After(10*time.Minute))
+ if err != nil {
+ t.Fatalf("unable to unlock wallet: %v", err)
+ }
+
+ return w
+}
+
+// testWalletWatchingOnly creates a test watch only wallet and unlocks it.
+func testWalletWatchingOnly(t *testing.T) *Wallet {
+ t.Helper()
+ // Set up a wallet.
+ dir := t.TempDir()
+
+ pubPass := []byte("hello")
+ loader := NewLoader(
+ &chainParams, dir, true, defaultDBTimeout, 250,
+ WithWalletSyncRetryInterval(10*time.Millisecond),
+ )
+ w, err := loader.CreateNewWatchingOnlyWallet(pubPass, time.Now())
+ if err != nil {
+ t.Fatalf("unable to create wallet: %v", err)
+ }
+ chainClient := &mockChainClient{}
+ w.chainClient = chainClient
+
+ err = walletdb.Update(w.Database(), func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
+ for scope, schema := range waddrmgr.ScopeAddrMap {
+ _, err := w.addrStore.NewScopedKeyManager(
+ ns, scope, schema,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("unable to create default scopes: %v", err)
+ }
+
+ w.StartDeprecated()
+ t.Cleanup(func() {
+ w.StopDeprecated()
+ w.WaitForShutdown()
+ })
+
+ return w
+}
+
+var (
+ testScriptP2WSH, _ = hex.DecodeString(
+ "0020d554616badeb46ccd4ce4b115e1c8d098e942d1387212d0af9ff93a1" +
+ "9c8f100e",
+ )
+ testScriptP2WKH, _ = hex.DecodeString(
+ "0014e7a43aa41ef6d72dc6baeeaad8362cedf63b79a3",
+ )
+)
+
+// TestFundPsbt tests that a given PSBT packet is funded correctly.
+func TestFundPsbt(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ // Create a P2WKH address we can use to send some coins to.
+ addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
+ require.NoError(t, err)
+ p2wkhAddr, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Also create a nested P2WKH address we can use to send some coins to.
+ addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0049Plus)
+ require.NoError(t, err)
+ np2wkhAddr, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Register two big UTXO that will be used when funding the PSBT.
+ const utxo1Amount = 1000000
+ incomingTx1 := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{wire.NewTxOut(utxo1Amount, p2wkhAddr)},
+ }
+ addUtxo(t, w, incomingTx1)
+ utxo1 := wire.OutPoint{
+ Hash: incomingTx1.TxHash(),
+ Index: 0,
+ }
+
+ const utxo2Amount = 900000
+ incomingTx2 := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{wire.NewTxOut(utxo2Amount, np2wkhAddr)},
+ }
+ addUtxo(t, w, incomingTx2)
+ utxo2 := wire.OutPoint{
+ Hash: incomingTx2.TxHash(),
+ Index: 0,
+ }
+
+ testCases := []struct {
+ name string
+ packet *psbt.Packet
+ feeRateSatPerKB btcutil.Amount
+ changeKeyScope *waddrmgr.KeyScope
+ expectedErr string
+ validatePackage bool
+ expectedChangeBeforeFee int64
+ expectedInputs []wire.OutPoint
+ additionalChecks func(*testing.T, *psbt.Packet, int32)
+ }{{
+ name: "no outputs provided",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{},
+ },
+ feeRateSatPerKB: 0,
+ expectedErr: "PSBT packet must contain at least one " +
+ "input or output",
+ }, {
+ name: "single input, no outputs",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: utxo1,
+ }},
+ },
+ Inputs: []psbt.PInput{{}},
+ },
+ feeRateSatPerKB: 20000,
+ validatePackage: true,
+ expectedInputs: []wire.OutPoint{utxo1},
+ expectedChangeBeforeFee: utxo1Amount,
+ }, {
+ name: "no dust outputs",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ PkScript: []byte("foo"),
+ Value: 100,
+ }},
+ },
+ Outputs: []psbt.POutput{{}},
+ },
+ feeRateSatPerKB: 0,
+ expectedErr: "transaction output is dust",
+ }, {
+ name: "two outputs, no inputs",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ PkScript: testScriptP2WSH,
+ Value: 100000,
+ }, {
+ PkScript: testScriptP2WKH,
+ Value: 50000,
+ }},
+ },
+ Outputs: []psbt.POutput{{}, {}},
+ },
+ feeRateSatPerKB: 2000, // 2 sat/byte
+ expectedErr: "",
+ validatePackage: true,
+ expectedChangeBeforeFee: utxo1Amount - 150000,
+ expectedInputs: []wire.OutPoint{utxo1},
+ }, {
+ name: "large output, no inputs",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ PkScript: testScriptP2WSH,
+ Value: 1500000,
+ }},
+ },
+ Outputs: []psbt.POutput{{}},
+ },
+ feeRateSatPerKB: 4000, // 4 sat/byte
+ expectedErr: "",
+ validatePackage: true,
+ expectedChangeBeforeFee: (utxo1Amount + utxo2Amount) - 1500000,
+ expectedInputs: []wire.OutPoint{utxo1, utxo2},
+ }, {
+ name: "two outputs, two inputs",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: utxo1,
+ }, {
+ PreviousOutPoint: utxo2,
+ }},
+ TxOut: []*wire.TxOut{{
+ PkScript: testScriptP2WSH,
+ Value: 100000,
+ }, {
+ PkScript: testScriptP2WKH,
+ Value: 50000,
+ }},
+ },
+ Inputs: []psbt.PInput{{}, {}},
+ Outputs: []psbt.POutput{{}, {}},
+ },
+ feeRateSatPerKB: 2000, // 2 sat/byte
+ expectedErr: "",
+ validatePackage: true,
+ expectedChangeBeforeFee: (utxo1Amount + utxo2Amount) - 150000,
+ expectedInputs: []wire.OutPoint{utxo1, utxo2},
+ additionalChecks: func(t *testing.T, packet *psbt.Packet,
+ changeIndex int32) {
+
+ // Check outputs, find index for each of the 3 expected.
+ txOuts := packet.UnsignedTx.TxOut
+ require.Len(t, txOuts, 3, "tx outputs")
+
+ p2wkhIndex := -1
+ p2wshIndex := -1
+ totalOut := int64(0)
+ for idx, txOut := range txOuts {
+ script := txOut.PkScript
+ totalOut += txOut.Value
+
+ switch {
+ case bytes.Equal(script, testScriptP2WKH):
+ p2wkhIndex = idx
+
+ case bytes.Equal(script, testScriptP2WSH):
+ p2wshIndex = idx
+
+ }
+ }
+ totalIn := int64(0)
+ for _, txIn := range packet.Inputs {
+ totalIn += txIn.WitnessUtxo.Value
+ }
+
+ // All outputs must be found.
+ require.Greater(t, p2wkhIndex, -1)
+ require.Greater(t, p2wshIndex, -1)
+ require.Greater(t, changeIndex, int32(-1))
+
+ // After BIP 69 sorting, the P2WKH output should be
+ // before the P2WSH output because the PK script is
+ // lexicographically smaller.
+ require.Less(
+ t, p2wkhIndex, p2wshIndex,
+ "index after sorting",
+ )
+ },
+ }, {
+ name: "one input and a custom change scope: BIP0084",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: utxo1,
+ }},
+ },
+ Inputs: []psbt.PInput{{}},
+ },
+ feeRateSatPerKB: 20000,
+ validatePackage: true,
+ changeKeyScope: &waddrmgr.KeyScopeBIP0084,
+ expectedInputs: []wire.OutPoint{utxo1},
+ expectedChangeBeforeFee: utxo1Amount,
+ }, {
+ name: "no inputs and a custom change scope: BIP0084",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ PkScript: testScriptP2WSH,
+ Value: 100000,
+ }, {
+ PkScript: testScriptP2WKH,
+ Value: 50000,
+ }},
+ },
+ Outputs: []psbt.POutput{{}, {}},
+ },
+ feeRateSatPerKB: 2000, // 2 sat/byte
+ expectedErr: "",
+ validatePackage: true,
+ changeKeyScope: &waddrmgr.KeyScopeBIP0084,
+ expectedChangeBeforeFee: utxo1Amount - 150000,
+ expectedInputs: []wire.OutPoint{utxo1},
+ }}
+
+ calcFee := func(feeRateSatPerKB btcutil.Amount,
+ packet *psbt.Packet) btcutil.Amount {
+
+ var numP2WKHInputs, numNP2WKHInputs int
+ for _, txin := range packet.UnsignedTx.TxIn {
+ if txin.PreviousOutPoint == utxo1 {
+ numP2WKHInputs++
+ }
+ if txin.PreviousOutPoint == utxo2 {
+ numNP2WKHInputs++
+ }
+ }
+ estimatedSize := txsizes.EstimateVirtualSize(
+ 0, 0, numP2WKHInputs, numNP2WKHInputs,
+ packet.UnsignedTx.TxOut, 0,
+ )
+ return txrules.FeeForSerializeSize(
+ feeRateSatPerKB, estimatedSize,
+ )
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ changeIndex, err := w.FundPsbtDeprecated(
+ tc.packet, nil, 1, 0,
+ tc.feeRateSatPerKB, CoinSelectionLargest,
+ WithCustomChangeScope(tc.changeKeyScope),
+ )
+
+ // In any case, unlock the UTXO before continuing, we
+ // don't want to pollute other test iterations.
+ for _, in := range tc.packet.UnsignedTx.TxIn {
+ w.UnlockOutpoint(in.PreviousOutPoint)
+ }
+
+ // Make sure the error is what we expected.
+ if tc.expectedErr != "" {
+ require.ErrorContains(t, err, tc.expectedErr)
+ return
+ }
+
+ require.NoError(t, err)
+
+ if !tc.validatePackage {
+ return
+ }
+
+ // Check wire inputs.
+ packet := tc.packet
+ assertTxInputs(t, packet, tc.expectedInputs)
+
+ // Run any additional tests if available.
+ if tc.additionalChecks != nil {
+ tc.additionalChecks(t, packet, changeIndex)
+ }
+
+ // Finally, check the change output size and fee.
+ txOuts := packet.UnsignedTx.TxOut
+ totalOut := int64(0)
+ for _, txOut := range txOuts {
+ totalOut += txOut.Value
+ }
+ totalIn := int64(0)
+ for _, txIn := range packet.Inputs {
+ totalIn += txIn.WitnessUtxo.Value
+ }
+ fee := totalIn - totalOut
+
+ expectedFee := calcFee(tc.feeRateSatPerKB, packet)
+ require.EqualValues(t, expectedFee, fee, "fee")
+ require.EqualValues(
+ t, tc.expectedChangeBeforeFee,
+ txOuts[changeIndex].Value+int64(expectedFee),
+ )
+
+ changeTxOut := txOuts[changeIndex]
+ changeOutput := packet.Outputs[changeIndex]
+
+ require.NotEmpty(t, changeOutput.Bip32Derivation)
+ b32d := changeOutput.Bip32Derivation[0]
+ require.Len(t, b32d.Bip32Path, 5, "derivation path len")
+ require.Len(t, b32d.PubKey, 33, "pubkey len")
+
+ // The third item should be the branch and should belong
+ // to a change output.
+ require.EqualValues(t, 1, b32d.Bip32Path[3])
+
+ assertChangeOutputScope(
+ t, changeTxOut.PkScript, tc.changeKeyScope,
+ )
+
+ if txscript.IsPayToTaproot(changeTxOut.PkScript) {
+ require.NotEmpty(
+ t, changeOutput.TaprootInternalKey,
+ )
+ require.Len(
+ t, changeOutput.TaprootInternalKey, 32,
+ "internal key len",
+ )
+ require.NotEmpty(
+ t, changeOutput.TaprootBip32Derivation,
+ )
+
+ trb32d := changeOutput.TaprootBip32Derivation[0]
+ require.Equal(
+ t, b32d.Bip32Path, trb32d.Bip32Path,
+ )
+ require.Len(
+ t, trb32d.XOnlyPubKey, 32,
+ "schnorr pubkey len",
+ )
+ require.Equal(
+ t, changeOutput.TaprootInternalKey,
+ trb32d.XOnlyPubKey,
+ )
+ }
+ })
+ }
+}
+
+func assertTxInputs(t *testing.T, packet *psbt.Packet,
+ expected []wire.OutPoint) {
+
+ require.Len(t, packet.UnsignedTx.TxIn, len(expected))
+
+ // The order of the UTXOs is random, we need to loop through each of
+ // them to make sure they're found. We also check that no signature data
+ // was added yet.
+ for _, txIn := range packet.UnsignedTx.TxIn {
+ if !containsUtxo(expected, txIn.PreviousOutPoint) {
+ t.Fatalf("outpoint %v not found in list of expected "+
+ "UTXOs", txIn.PreviousOutPoint)
+ }
+
+ require.Empty(t, txIn.SignatureScript)
+ require.Empty(t, txIn.Witness)
+ }
+}
+
+// assertChangeOutputScope checks if the pkScript has the right type.
+func assertChangeOutputScope(t *testing.T, pkScript []byte,
+ changeScope *waddrmgr.KeyScope) {
+
+ // By default (changeScope == nil), the script should
+ // be a pay-to-taproot one.
+ switch changeScope {
+ case nil, &waddrmgr.KeyScopeBIP0086:
+ require.True(t, txscript.IsPayToTaproot(pkScript))
+
+ case &waddrmgr.KeyScopeBIP0049Plus, &waddrmgr.KeyScopeBIP0084:
+ require.True(t, txscript.IsPayToWitnessPubKeyHash(pkScript))
+
+ case &waddrmgr.KeyScopeBIP0044:
+ require.True(t, txscript.IsPayToPubKeyHash(pkScript))
+
+ default:
+ require.Fail(t, "assertChangeOutputScope error",
+ "change scope: %s", changeScope.String())
+ }
+}
+
+func containsUtxo(list []wire.OutPoint, candidate wire.OutPoint) bool {
+ for _, utxo := range list {
+ if utxo == candidate {
+ return true
+ }
+ }
+
+ return false
+}
+
+// TestFinalizePsbt tests that a given PSBT packet can be finalized.
+func TestFinalizePsbt(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ // Create a P2WKH address we can use to send some coins to.
+ addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
+ if err != nil {
+ t.Fatalf("unable to get current address: %v", addr)
+ }
+ p2wkhAddr, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ t.Fatalf("unable to convert wallet address to p2wkh: %v", err)
+ }
+
+ // Also create a nested P2WKH address we can send coins to.
+ addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0049Plus)
+ if err != nil {
+ t.Fatalf("unable to get current address: %v", addr)
+ }
+ np2wkhAddr, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ t.Fatalf("unable to convert wallet address to np2wkh: %v", err)
+ }
+
+ // Register two big UTXO that will be used when funding the PSBT.
+ utxOutP2WKH := wire.NewTxOut(1000000, p2wkhAddr)
+ utxOutNP2WKH := wire.NewTxOut(1000000, np2wkhAddr)
+ incomingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{utxOutP2WKH, utxOutNP2WKH},
+ }
+ addUtxo(t, w, incomingTx)
+
+ // Create the packet that we want to sign.
+ packet := &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: incomingTx.TxHash(),
+ Index: 0,
+ },
+ }, {
+ PreviousOutPoint: wire.OutPoint{
+ Hash: incomingTx.TxHash(),
+ Index: 1,
+ },
+ }},
+ TxOut: []*wire.TxOut{{
+ PkScript: testScriptP2WKH,
+ Value: 50000,
+ }, {
+ PkScript: testScriptP2WSH,
+ Value: 100000,
+ }, {
+ PkScript: testScriptP2WKH,
+ Value: 849632,
+ }},
+ },
+ Inputs: []psbt.PInput{{
+ WitnessUtxo: utxOutP2WKH,
+ SighashType: txscript.SigHashAll,
+ }, {
+ NonWitnessUtxo: incomingTx,
+ SighashType: txscript.SigHashAll,
+ }},
+ Outputs: []psbt.POutput{{}, {}, {}},
+ }
+
+ // Finalize it to add all witness data then extract the final TX.
+ err = w.FinalizePsbtDeprecated(nil, 0, packet)
+ if err != nil {
+ t.Fatalf("error finalizing PSBT packet: %v", err)
+ }
+ finalTx, err := psbt.Extract(packet)
+ if err != nil {
+ t.Fatalf("error extracting final TX from PSBT: %v", err)
+ }
+
+ // Finally verify that the created witness is valid.
+ err = validateMsgTx(
+ finalTx, [][]byte{utxOutP2WKH.PkScript, utxOutNP2WKH.PkScript},
+ []btcutil.Amount{1000000, 1000000},
+ )
+ if err != nil {
+ t.Fatalf("error validating tx: %v", err)
+ }
+}
+
+var (
+ testBlockHash, _ = chainhash.NewHashFromStr(
+ "00000000000000017188b968a371bab95aa43522665353b646e41865abae" +
+ "02a4",
+ )
+ testBlockHeight int32 = 276425
+
+ alwaysAllowUtxo = func(utxo wtxmgr.Credit) bool { return true }
+)
+
+// TestTxToOutput checks that no new address is added to he database if we
+// request a dry run of the txToOutputs call. It also makes sure a subsequent
+// non-dry run call produces a similar transaction to the dry-run.
+func TestTxToOutputsDryRun(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ // Create an address we can use to send some coins to.
+ keyScope := waddrmgr.KeyScopeBIP0049Plus
+ addr, err := w.CurrentAddress(0, keyScope)
+ if err != nil {
+ t.Fatalf("unable to get current address: %v", addr)
+ }
+ p2shAddr, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ t.Fatalf("unable to convert wallet address to p2sh: %v", err)
+ }
+
+ // Add an output paying to the wallet's address to the database.
+ txOut := wire.NewTxOut(100000, p2shAddr)
+ incomingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ {},
+ },
+ TxOut: []*wire.TxOut{
+ txOut,
+ },
+ }
+ addUtxo(t, w, incomingTx)
+
+ // Now tell the wallet to create a transaction paying to the specified
+ // outputs.
+ txOuts := []*wire.TxOut{
+ {
+ PkScript: p2shAddr,
+ Value: 10000,
+ },
+ {
+ PkScript: p2shAddr,
+ Value: 20000,
+ },
+ }
+
+ // First do a few dry-runs, making sure the number of addresses in the
+ // database us not inflated.
+ dryRunTx, err := w.txToOutputs(
+ txOuts, nil, nil, 0, 1, 1000, CoinSelectionLargest, true,
+ nil, alwaysAllowUtxo,
+ )
+ if err != nil {
+ t.Fatalf("unable to author tx: %v", err)
+ }
+ change := dryRunTx.Tx.TxOut[dryRunTx.ChangeIndex]
+
+ addresses, err := w.AccountAddresses(0)
+ if err != nil {
+ t.Fatalf("unable to get addresses: %v", err)
+ }
+
+ if len(addresses) != 1 {
+ t.Fatalf("expected 1 address, found %v", len(addresses))
+ }
+
+ dryRunTx2, err := w.txToOutputs(
+ txOuts, nil, nil, 0, 1, 1000, CoinSelectionLargest, true,
+ nil, alwaysAllowUtxo,
+ )
+ if err != nil {
+ t.Fatalf("unable to author tx: %v", err)
+ }
+ change2 := dryRunTx2.Tx.TxOut[dryRunTx2.ChangeIndex]
+
+ addresses, err = w.AccountAddresses(0)
+ if err != nil {
+ t.Fatalf("unable to get addresses: %v", err)
+ }
+
+ if len(addresses) != 1 {
+ t.Fatalf("expected 1 address, found %v", len(addresses))
+ }
+
+ // The two dry-run TXs should be invalid, since they don't have
+ // signatures.
+ err = validateMsgTx(
+ dryRunTx.Tx, dryRunTx.PrevScripts, dryRunTx.PrevInputValues,
+ )
+ if err == nil {
+ t.Fatalf("Expected tx to be invalid")
+ }
+
+ err = validateMsgTx(
+ dryRunTx2.Tx, dryRunTx2.PrevScripts, dryRunTx2.PrevInputValues,
+ )
+ if err == nil {
+ t.Fatalf("Expected tx to be invalid")
+ }
+
+ // Now we do a proper, non-dry run. This should add a change address
+ // to the database.
+ tx, err := w.txToOutputs(
+ txOuts, nil, nil, 0, 1, 1000, CoinSelectionLargest, false,
+ nil, alwaysAllowUtxo,
+ )
+ if err != nil {
+ t.Fatalf("unable to author tx: %v", err)
+ }
+ change3 := tx.Tx.TxOut[tx.ChangeIndex]
+
+ addresses, err = w.AccountAddresses(0)
+ if err != nil {
+ t.Fatalf("unable to get addresses: %v", err)
+ }
+
+ if len(addresses) != 2 {
+ t.Fatalf("expected 2 addresses, found %v", len(addresses))
+ }
+
+ err = validateMsgTx(tx.Tx, tx.PrevScripts, tx.PrevInputValues)
+ if err != nil {
+ t.Fatalf("Expected tx to be valid: %v", err)
+ }
+
+ // Finally, we check that all the transaction were using the same
+ // change address.
+ if !bytes.Equal(change.PkScript, change2.PkScript) {
+ t.Fatalf("first dry-run using different change address " +
+ "than second")
+ }
+ if !bytes.Equal(change2.PkScript, change3.PkScript) {
+ t.Fatalf("dry-run using different change address " +
+ "than wet run")
+ }
+}
+
+// addUtxo add the given transaction to the wallet's database marked as a
+// confirmed UTXO .
+func addUtxo(t *testing.T, w *Wallet, incomingTx *wire.MsgTx) {
+ var b bytes.Buffer
+ if err := incomingTx.Serialize(&b); err != nil {
+ t.Fatalf("unable to serialize tx: %v", err)
+ }
+ txBytes := b.Bytes()
+
+ rec, err := wtxmgr.NewTxRecord(txBytes, time.Now())
+ if err != nil {
+ t.Fatalf("unable to create tx record: %v", err)
+ }
+
+ // The block meta will be inserted to tell the wallet this is a
+ // confirmed transaction.
+ block := &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Hash: *testBlockHash,
+ Height: testBlockHeight,
+ },
+ Time: time.Unix(1387737310, 0),
+ }
+
+ if err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ ns := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ err = w.txStore.InsertTx(ns, rec, block)
+ if err != nil {
+ return err
+ }
+ // Add all tx outputs as credits.
+ for i := 0; i < len(incomingTx.TxOut); i++ {
+ err = w.txStore.AddCredit(
+ ns, rec, block, uint32(i), false,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+ }); err != nil {
+ t.Fatalf("failed inserting tx: %v", err)
+ }
+}
+
+// addTxAndCredit adds the given transaction to the wallet's database marked as
+// a confirmed UTXO specified by the creditIndex.
+func addTxAndCredit(t *testing.T, w *Wallet, tx *wire.MsgTx,
+ creditIndex uint32) {
+
+ var b bytes.Buffer
+ require.NoError(t, tx.Serialize(&b), "unable to serialize tx")
+
+ txBytes := b.Bytes()
+
+ rec, err := wtxmgr.NewTxRecord(txBytes, time.Now())
+ require.NoError(t, err)
+
+ // The block meta will be inserted to tell the wallet this is a
+ // confirmed transaction.
+ block := &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Hash: *testBlockHash,
+ Height: testBlockHeight,
+ },
+ Time: time.Unix(1387737310, 0),
+ }
+
+ err = walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
+ ns := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ err = w.txStore.InsertTx(ns, rec, block)
+ if err != nil {
+ return err
+ }
+
+ // Add the specified output as credit.
+ err = w.txStore.AddCredit(ns, rec, block, creditIndex, false)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+ require.NoError(t, err, "failed inserting tx")
+}
+
+// TestInputYield verifies the functioning of the inputYieldsPositively.
+func TestInputYield(t *testing.T) {
+ t.Parallel()
+
+ addr, _ := address.DecodeAddress("bc1qw508d6qejxtdg4y5r3zarvary0c5xw7kv8f3t4", &chaincfg.MainNetParams)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ credit := &wire.TxOut{
+ Value: 1000,
+ PkScript: pkScript,
+ }
+
+ // At 10 sat/b this input is yielding positively.
+ require.True(t, inputYieldsPositively(credit, 10000))
+
+ // At 20 sat/b this input is yielding negatively.
+ require.False(t, inputYieldsPositively(credit, 20000))
+}
+
+// TestTxToOutputsRandom tests random coin selection.
+func TestTxToOutputsRandom(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ // Create an address we can use to send some coins to.
+ keyScope := waddrmgr.KeyScopeBIP0049Plus
+ addr, err := w.CurrentAddress(0, keyScope)
+ if err != nil {
+ t.Fatalf("unable to get current address: %v", addr)
+ }
+ p2shAddr, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ t.Fatalf("unable to convert wallet address to p2sh: %v", err)
+ }
+
+ // Add a set of utxos to the wallet.
+ incomingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ {},
+ },
+ TxOut: []*wire.TxOut{},
+ }
+ for amt := int64(5000); amt <= 125000; amt += 10000 {
+ incomingTx.AddTxOut(wire.NewTxOut(amt, p2shAddr))
+ }
+
+ addUtxo(t, w, incomingTx)
+
+ // Now tell the wallet to create a transaction paying to the specified
+ // outputs.
+ txOuts := []*wire.TxOut{
+ {
+ PkScript: p2shAddr,
+ Value: 50000,
+ },
+ {
+ PkScript: p2shAddr,
+ Value: 100000,
+ },
+ }
+
+ const (
+ feeSatPerKb = 100000
+ maxIterations = 100
+ )
+
+ createTx := func() *txauthor.AuthoredTx {
+ tx, err := w.txToOutputs(
+ txOuts, nil, nil, 0, 1, feeSatPerKb,
+ CoinSelectionRandom, true, nil, alwaysAllowUtxo,
+ )
+ require.NoError(t, err)
+ return tx
+ }
+
+ firstTx := createTx()
+ var isRandom bool
+ for iteration := 0; iteration < maxIterations; iteration++ {
+ tx := createTx()
+
+ // Check to see if we are getting a total input value.
+ // We consider this proof that the randomization works.
+ if tx.TotalInput != firstTx.TotalInput {
+ isRandom = true
+ }
+
+ // At the used fee rate of 100 sat/b, the 5000 sat input is
+ // negatively yielding. We don't expect it to ever be selected.
+ for _, inputValue := range tx.PrevInputValues {
+ require.NotEqual(t, inputValue, btcutil.Amount(5000))
+ }
+ }
+
+ require.True(t, isRandom)
+}
+
+// TestCreateSimpleCustomChange tests that it's possible to let the
+// CreateSimpleTx use all coins for coin selection, but specify a custom scope
+// that isn't the current default scope.
+func TestCreateSimpleCustomChange(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ // First, we'll make a P2TR and a P2WKH address to send some coins to
+ // (two different coin scopes).
+ p2wkhAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
+ require.NoError(t, err)
+
+ p2trAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0086)
+ require.NoError(t, err)
+
+ // We'll now make a transaction that'll send coins to both outputs,
+ // then "credit" the wallet for that send.
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+ p2trScript, err := txscript.PayToAddrScript(p2trAddr)
+ require.NoError(t, err)
+
+ const testAmt = 1_000_000
+
+ incomingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ {},
+ },
+ TxOut: []*wire.TxOut{
+ wire.NewTxOut(testAmt, p2wkhScript),
+ wire.NewTxOut(testAmt, p2trScript),
+ },
+ }
+ addUtxo(t, w, incomingTx)
+
+ // With the amounts credited to the wallet, we'll now do a dry run coin
+ // selection w/o any default args.
+ targetTxOut := &wire.TxOut{
+ Value: 1_500_000,
+ PkScript: p2trScript,
+ }
+ tx1, err := w.txToOutputs(
+ []*wire.TxOut{targetTxOut}, nil, nil, 0, 1, 1000,
+ CoinSelectionLargest, true, nil, alwaysAllowUtxo,
+ )
+ require.NoError(t, err)
+
+ // We expect that all inputs were used and also the change output is a
+ // taproot output (the current default).
+ require.Len(t, tx1.Tx.TxIn, 2)
+ require.Len(t, tx1.Tx.TxOut, 2)
+ for _, txOut := range tx1.Tx.TxOut {
+ scriptType, _, _, err := txscript.ExtractPkScriptAddrs(
+ txOut.PkScript, w.chainParams,
+ )
+ require.NoError(t, err)
+
+ require.Equal(t, scriptType, txscript.WitnessV1TaprootTy)
+ }
+
+ // Next, we'll do another dry run, but this time, specify a custom
+ // change key scope. We'll also require that only inputs of P2TR are used.
+ targetTxOut = &wire.TxOut{
+ Value: 500_000,
+ PkScript: p2trScript,
+ }
+ tx2, err := w.txToOutputs(
+ []*wire.TxOut{targetTxOut}, &waddrmgr.KeyScopeBIP0086,
+ &waddrmgr.KeyScopeBIP0084, 0, 1, 1000, CoinSelectionLargest,
+ true, nil, alwaysAllowUtxo,
+ )
+ require.NoError(t, err)
+
+ // The resulting transaction should spend a single input, and use P2WKH
+ // as the output script.
+ require.Len(t, tx2.Tx.TxIn, 1)
+ require.Len(t, tx2.Tx.TxOut, 2)
+ for i, txOut := range tx2.Tx.TxOut {
+ if i != tx2.ChangeIndex {
+ continue
+ }
+
+ scriptType, _, _, err := txscript.ExtractPkScriptAddrs(
+ txOut.PkScript, w.chainParams,
+ )
+ require.NoError(t, err)
+
+ require.Equal(t, scriptType, txscript.WitnessV0PubKeyHashTy)
+ }
+}
+
+// TestSelectUtxosTxoToOutpoint tests that it is possible to use passed
+// selected utxos to craft a transaction in `txToOutpoint`.
+func TestSelectUtxosTxoToOutpoint(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ // First, we'll make a P2TR and a P2WKH address to send some coins to.
+ p2wkhAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
+ require.NoError(t, err)
+
+ p2trAddr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0086)
+ require.NoError(t, err)
+
+ // We'll now make a transaction that'll send coins to both outputs,
+ // then "credit" the wallet for that send.
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ p2trScript, err := txscript.PayToAddrScript(p2trAddr)
+ require.NoError(t, err)
+
+ incomingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ {},
+ },
+ TxOut: []*wire.TxOut{
+ wire.NewTxOut(1_000_000, p2wkhScript),
+ wire.NewTxOut(2_000_000, p2trScript),
+ wire.NewTxOut(3_000_000, p2trScript),
+ wire.NewTxOut(7_000_000, p2trScript),
+ },
+ }
+ addUtxo(t, w, incomingTx)
+
+ // We expect 4 unspent UTXOs.
+ unspent, err := w.ListUnspentDeprecated(0, 80, "")
+ require.NoError(t, err)
+ require.Len(t, unspent, 4, "expected 4 unspent UTXOs")
+
+ tCases := []struct {
+ name string
+ selectUTXOs []wire.OutPoint
+ errString string
+ }{
+ {
+ name: "Duplicate utxo values",
+ selectUTXOs: []wire.OutPoint{
+ {
+ Hash: incomingTx.TxHash(),
+ Index: 1,
+ },
+ {
+ Hash: incomingTx.TxHash(),
+ Index: 1,
+ },
+ },
+ errString: "selected UTXOs contain duplicate values",
+ },
+ {
+ name: "all selected UTXOs not eligible for spending",
+ selectUTXOs: []wire.OutPoint{
+ {
+ Hash: chainhash.Hash([32]byte{1}),
+ Index: 1,
+ },
+ {
+ Hash: chainhash.Hash([32]byte{3}),
+ Index: 1,
+ },
+ },
+ errString: "selected outpoint not eligible for " +
+ "spending",
+ },
+ {
+ name: "some select UTXOs not eligible for spending",
+ selectUTXOs: []wire.OutPoint{
+ {
+ Hash: chainhash.Hash([32]byte{1}),
+ Index: 1,
+ },
+ {
+ Hash: incomingTx.TxHash(),
+ Index: 1,
+ },
+ },
+ errString: "selected outpoint not eligible for " +
+ "spending",
+ },
+ {
+ name: "select utxo, no duplicates and all eligible " +
+ "for spending",
+ selectUTXOs: []wire.OutPoint{
+ {
+ Hash: incomingTx.TxHash(),
+ Index: 1,
+ },
+ {
+ Hash: incomingTx.TxHash(),
+ Index: 2,
+ },
+ },
+ },
+ }
+
+ for _, tc := range tCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Test by sending 200_000.
+ targetTxOut := &wire.TxOut{
+ Value: 200_000,
+ PkScript: p2trScript,
+ }
+ tx1, err := w.txToOutputs(
+ []*wire.TxOut{targetTxOut}, nil, nil, 0, 1,
+ 1000, CoinSelectionLargest, true,
+ tc.selectUTXOs, alwaysAllowUtxo,
+ )
+ if tc.errString != "" {
+ require.ErrorContains(t, err, tc.errString)
+ require.Nil(t, tx1)
+
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, tx1)
+
+ // We expect all and only our select UTXOs to be input
+ // in this transaction.
+ require.Len(t, tx1.Tx.TxIn, len(tc.selectUTXOs))
+
+ lookupSelectUtxos := make(map[wire.OutPoint]struct{})
+ for _, utxo := range tc.selectUTXOs {
+ lookupSelectUtxos[utxo] = struct{}{}
+ }
+
+ for _, tx := range tx1.Tx.TxIn {
+ _, ok := lookupSelectUtxos[tx.PreviousOutPoint]
+ require.True(t, ok)
+ }
+
+ // Expect two outputs, change and the actual payment to
+ // the address.
+ require.Len(t, tx1.Tx.TxOut, 2)
+ })
+ }
+}
+
+// TestComputeInputScript checks that the wallet can create the full
+// witness script for a witness output.
+func TestComputeInputScript(t *testing.T) {
+ t.Parallel()
+
+ w := testWallet(t)
+
+ testCases := []struct {
+ name string
+ scope waddrmgr.KeyScope
+ expectedScriptLen int
+ }{{
+ name: "BIP084 P2WKH",
+ scope: waddrmgr.KeyScopeBIP0084,
+ expectedScriptLen: 0,
+ }, {
+ name: "BIP049 nested P2WKH",
+ scope: waddrmgr.KeyScopeBIP0049Plus,
+ expectedScriptLen: 23,
+ }}
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ runTestCase(t, w, tc.scope, tc.expectedScriptLen)
+ })
+ }
+}
+
+func runTestCase(t *testing.T, w *Wallet, scope waddrmgr.KeyScope,
+ scriptLen int) {
+
+ // Create an address we can use to send some coins to.
+ addr, err := w.CurrentAddress(0, scope)
+ if err != nil {
+ t.Fatalf("unable to get current address: %v", addr)
+ }
+ p2shAddr, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ t.Fatalf("unable to convert wallet address to p2sh: %v", err)
+ }
+
+ // Add an output paying to the wallet's address to the database.
+ utxOut := wire.NewTxOut(100000, p2shAddr)
+ incomingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{utxOut},
+ }
+ addUtxo(t, w, incomingTx)
+
+ // Create a transaction that spends the UTXO created above and spends to
+ // the same address again.
+ prevOut := wire.OutPoint{
+ Hash: incomingTx.TxHash(),
+ Index: 0,
+ }
+ outgoingTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: prevOut,
+ }},
+ TxOut: []*wire.TxOut{utxOut},
+ }
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ utxOut.PkScript, utxOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(outgoingTx, fetcher)
+
+ // Compute the input script to spend the UTXO now.
+ witness, script, err := w.ComputeInputScript(
+ outgoingTx, utxOut, 0, sigHashes, txscript.SigHashAll, nil,
+ )
+ if err != nil {
+ t.Fatalf("error computing input script: %v", err)
+ }
+ if len(script) != scriptLen {
+ t.Fatalf("unexpected script length, got %d wanted %d",
+ len(script), scriptLen)
+ }
+ if len(witness) != 2 {
+ t.Fatalf("unexpected witness stack length, got %d, wanted %d",
+ len(witness), 2)
+ }
+
+ // Finally verify that the created witness is valid.
+ outgoingTx.TxIn[0].Witness = witness
+ outgoingTx.TxIn[0].SignatureScript = script
+ err = validateMsgTx(
+ outgoingTx, [][]byte{utxOut.PkScript}, []btcutil.Amount{100000},
+ )
+ if err != nil {
+ t.Fatalf("error validating tx: %v", err)
+ }
+}
diff --git a/wallet/example_test.go b/wallet/example_test.go
deleted file mode 100644
index 698239971f..0000000000
--- a/wallet/example_test.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package wallet
-
-import (
- "testing"
- "time"
-
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
-)
-
-// defaultDBTimeout specifies the timeout value when opening the wallet
-// database.
-var defaultDBTimeout = 10 * time.Second
-
-// testWallet creates a test wallet and unlocks it.
-func testWallet(t *testing.T) (*Wallet, func()) {
- // Set up a wallet.
- dir := t.TempDir()
-
- seed, err := hdkeychain.GenerateSeed(hdkeychain.MinSeedBytes)
- if err != nil {
- t.Fatalf("unable to create seed: %v", err)
- }
-
- pubPass := []byte("hello")
- privPass := []byte("world")
-
- loader := NewLoader(
- &chaincfg.TestNet3Params, dir, true, defaultDBTimeout, 250,
- WithWalletSyncRetryInterval(10*time.Millisecond),
- )
- w, err := loader.CreateNewWallet(pubPass, privPass, seed, time.Now())
- if err != nil {
- t.Fatalf("unable to create wallet: %v", err)
- }
- chainClient := &mockChainClient{}
- w.chainClient = chainClient
- if err := w.Unlock(privPass, time.After(10*time.Minute)); err != nil {
- t.Fatalf("unable to unlock wallet: %v", err)
- }
-
- return w, func() {}
-}
-
-// testWalletWatchingOnly creates a test watch only wallet and unlocks it.
-func testWalletWatchingOnly(t *testing.T) (*Wallet, func()) {
- // Set up a wallet.
- dir := t.TempDir()
-
- pubPass := []byte("hello")
- loader := NewLoader(
- &chaincfg.TestNet3Params, dir, true, defaultDBTimeout, 250,
- WithWalletSyncRetryInterval(10*time.Millisecond),
- )
- w, err := loader.CreateNewWatchingOnlyWallet(pubPass, time.Now())
- if err != nil {
- t.Fatalf("unable to create wallet: %v", err)
- }
- chainClient := &mockChainClient{}
- w.chainClient = chainClient
-
- err = walletdb.Update(w.Database(), func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- for scope, schema := range waddrmgr.ScopeAddrMap {
- _, err := w.Manager.NewScopedKeyManager(
- ns, scope, schema,
- )
- if err != nil {
- return err
- }
- }
-
- return nil
- })
- if err != nil {
- t.Fatalf("unable to create default scopes: %v", err)
- }
-
- return w, func() {}
-}
diff --git a/wallet/import.go b/wallet/import.go
deleted file mode 100644
index fe7af70a39..0000000000
--- a/wallet/import.go
+++ /dev/null
@@ -1,597 +0,0 @@
-package wallet
-
-import (
- "encoding/binary"
- "errors"
- "fmt"
-
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/btcec/v2"
- "github.com/btcsuite/btcd/btcutil/v2"
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/netparams"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
-)
-
-const (
- // accountPubKeyDepth is the maximum depth of an extended key for an
- // account public key.
- accountPubKeyDepth = 3
-
- // pubKeyDepth is the depth of an extended key for a derived public key.
- pubKeyDepth = 5
-)
-
-// keyScopeFromPubKey returns the corresponding wallet key scope for the given
-// extended public key. The address type can usually be inferred from the key's
-// version, but may be required for certain keys to map them into the proper
-// scope.
-func keyScopeFromPubKey(pubKey *hdkeychain.ExtendedKey,
- addrType *waddrmgr.AddressType) (waddrmgr.KeyScope,
- *waddrmgr.ScopeAddrSchema, error) {
-
- switch waddrmgr.HDVersion(binary.BigEndian.Uint32(pubKey.Version())) {
- // For BIP-0044 keys, an address type must be specified as we intend to
- // not support importing BIP-0044 keys into the wallet using the legacy
- // pay-to-pubkey-hash (P2PKH) scheme. A nested witness address type will
- // force the standard BIP-0049 derivation scheme (nested witness pubkeys
- // everywhere), while a witness address type will force the standard
- // BIP-0084 derivation scheme.
- case waddrmgr.HDVersionMainNetBIP0044, waddrmgr.HDVersionTestNetBIP0044,
- waddrmgr.HDVersionSimNetBIP0044:
-
- if addrType == nil {
- return waddrmgr.KeyScope{}, nil, errors.New("address " +
- "type must be specified for account public " +
- "key with legacy version")
- }
-
- switch *addrType {
- case waddrmgr.NestedWitnessPubKey:
- return waddrmgr.KeyScopeBIP0049Plus,
- &waddrmgr.KeyScopeBIP0049AddrSchema, nil
-
- case waddrmgr.WitnessPubKey:
- return waddrmgr.KeyScopeBIP0084, nil, nil
-
- case waddrmgr.TaprootPubKey:
- return waddrmgr.KeyScopeBIP0086, nil, nil
-
- default:
- return waddrmgr.KeyScope{}, nil,
- fmt.Errorf("unsupported address type %v",
- *addrType)
- }
-
- // For BIP-0049 keys, we'll need to make a distinction between the
- // traditional BIP-0049 address schema (nested witness pubkeys
- // everywhere) and our own BIP-0049Plus address schema (nested
- // externally, witness internally).
- case waddrmgr.HDVersionMainNetBIP0049, waddrmgr.HDVersionTestNetBIP0049:
- if addrType == nil {
- return waddrmgr.KeyScope{}, nil, errors.New("address " +
- "type must be specified for account public " +
- "key with BIP-0049 version")
- }
-
- switch *addrType {
- case waddrmgr.NestedWitnessPubKey:
- return waddrmgr.KeyScopeBIP0049Plus,
- &waddrmgr.KeyScopeBIP0049AddrSchema, nil
-
- case waddrmgr.WitnessPubKey:
- return waddrmgr.KeyScopeBIP0049Plus, nil, nil
-
- default:
- return waddrmgr.KeyScope{}, nil,
- fmt.Errorf("unsupported address type %v",
- *addrType)
- }
-
- // BIP-0086 does not have its own SLIP-0132 HD version byte set (yet?).
- // So we either expect a user to import it with a BIP-0084 or BIP-0044
- // encoding.
- case waddrmgr.HDVersionMainNetBIP0084, waddrmgr.HDVersionTestNetBIP0084:
- if addrType == nil {
- return waddrmgr.KeyScope{}, nil, errors.New("address " +
- "type must be specified for account public " +
- "key with BIP-0084 version")
- }
-
- switch *addrType {
- case waddrmgr.WitnessPubKey:
- return waddrmgr.KeyScopeBIP0084, nil, nil
-
- case waddrmgr.TaprootPubKey:
- return waddrmgr.KeyScopeBIP0086, nil, nil
-
- default:
- return waddrmgr.KeyScope{}, nil,
- errors.New("address type mismatch")
- }
-
- default:
- return waddrmgr.KeyScope{}, nil, fmt.Errorf("unknown version %x",
- pubKey.Version())
- }
-}
-
-// isPubKeyForNet determines if the given public key is for the current network
-// the wallet is operating under.
-func (w *Wallet) isPubKeyForNet(pubKey *hdkeychain.ExtendedKey) bool {
- version := waddrmgr.HDVersion(binary.BigEndian.Uint32(pubKey.Version()))
- switch w.chainParams.Net {
- case wire.MainNet:
- return version == waddrmgr.HDVersionMainNetBIP0044 ||
- version == waddrmgr.HDVersionMainNetBIP0049 ||
- version == waddrmgr.HDVersionMainNetBIP0084
-
- case wire.TestNet, wire.TestNet3, wire.TestNet4,
- netparams.SigNetWire(w.chainParams):
-
- return version == waddrmgr.HDVersionTestNetBIP0044 ||
- version == waddrmgr.HDVersionTestNetBIP0049 ||
- version == waddrmgr.HDVersionTestNetBIP0084
-
- // For simnet, we'll also allow the mainnet versions since simnet
- // doesn't have defined versions for some of our key scopes, and the
- // mainnet versions are usually used as the default regardless of the
- // network/key scope.
- case wire.SimNet:
- return version == waddrmgr.HDVersionSimNetBIP0044 ||
- version == waddrmgr.HDVersionMainNetBIP0049 ||
- version == waddrmgr.HDVersionMainNetBIP0084
-
- default:
- return false
- }
-}
-
-// validateExtendedPubKey ensures a sane derived public key is provided.
-func (w *Wallet) validateExtendedPubKey(pubKey *hdkeychain.ExtendedKey,
- isAccountKey bool) error {
-
- // Private keys are not allowed.
- if pubKey.IsPrivate() {
- return errors.New("private keys cannot be imported")
- }
-
- // The public key must have a version corresponding to the current
- // chain.
- if !w.isPubKeyForNet(pubKey) {
- return fmt.Errorf("expected extended public key for current "+
- "network %v", w.chainParams.Name)
- }
-
- // Verify the extended public key's depth and child index based on
- // whether it's an account key or not.
- if isAccountKey {
- if pubKey.Depth() != accountPubKeyDepth {
- return errors.New("invalid account key, must be of the " +
- "form m/purpose'/coin_type'/account'")
- }
- if pubKey.ChildIndex() < hdkeychain.HardenedKeyStart {
- return errors.New("invalid account key, must be hardened")
- }
- } else {
- if pubKey.Depth() != pubKeyDepth {
- return errors.New("invalid account key, must be of the " +
- "form m/purpose'/coin_type'/account'/change/" +
- "address_index")
- }
- if pubKey.ChildIndex() >= hdkeychain.HardenedKeyStart {
- return errors.New("invalid pulic key, must not be " +
- "hardened")
- }
- }
-
- return nil
-}
-
-// ImportAccount imports an account backed by an account extended public key.
-// The master key fingerprint denotes the fingerprint of the root key
-// corresponding to the account public key (also known as the key with
-// derivation path m/). This may be required by some hardware wallets for proper
-// identification and signing.
-//
-// The address type can usually be inferred from the key's version, but may be
-// required for certain keys to map them into the proper scope.
-//
-// For BIP-0044 keys, an address type must be specified as we intend to not
-// support importing BIP-0044 keys into the wallet using the legacy
-// pay-to-pubkey-hash (P2PKH) scheme. A nested witness address type will force
-// the standard BIP-0049 derivation scheme, while a witness address type will
-// force the standard BIP-0084 derivation scheme.
-//
-// For BIP-0049 keys, an address type must also be specified to make a
-// distinction between the traditional BIP-0049 address schema (nested witness
-// pubkeys everywhere) and our own BIP-0049Plus address schema (nested
-// externally, witness internally).
-func (w *Wallet) ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKey,
- masterKeyFingerprint uint32, addrType *waddrmgr.AddressType) (
- *waddrmgr.AccountProperties, error) {
-
- var accountProps *waddrmgr.AccountProperties
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- var err error
- accountProps, err = w.importAccount(
- ns, name, accountPubKey, masterKeyFingerprint, addrType,
- )
- return err
- })
- return accountProps, err
-}
-
-// ImportAccountWithScope imports an account backed by an account extended
-// public key for a specific key scope which is known in advance.
-// The master key fingerprint denotes the fingerprint of the root key
-// corresponding to the account public key (also known as the key with
-// derivation path m/). This may be required by some hardware wallets for proper
-// identification and signing.
-func (w *Wallet) ImportAccountWithScope(name string,
- accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
- keyScope waddrmgr.KeyScope, addrSchema waddrmgr.ScopeAddrSchema) (
- *waddrmgr.AccountProperties, error) {
-
- var accountProps *waddrmgr.AccountProperties
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- var err error
- accountProps, err = w.importAccountScope(
- ns, name, accountPubKey, masterKeyFingerprint, keyScope,
- &addrSchema,
- )
- return err
- })
- return accountProps, err
-}
-
-// importAccount is the internal implementation of ImportAccount -- one should
-// reference its documentation for this method.
-func (w *Wallet) importAccount(ns walletdb.ReadWriteBucket, name string,
- accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
- addrType *waddrmgr.AddressType) (*waddrmgr.AccountProperties, error) {
-
- // Ensure we have a valid account public key.
- if err := w.validateExtendedPubKey(accountPubKey, true); err != nil {
- return nil, err
- }
-
- // Determine what key scope the account public key should belong to and
- // whether it should use a custom address schema.
- keyScope, addrSchema, err := keyScopeFromPubKey(accountPubKey, addrType)
- if err != nil {
- return nil, err
- }
-
- return w.importAccountScope(
- ns, name, accountPubKey, masterKeyFingerprint, keyScope,
- addrSchema,
- )
-}
-
-// importAccountScope imports a watch-only account for a given scope.
-func (w *Wallet) importAccountScope(ns walletdb.ReadWriteBucket, name string,
- accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
- keyScope waddrmgr.KeyScope, addrSchema *waddrmgr.ScopeAddrSchema) (
- *waddrmgr.AccountProperties, error) {
-
- scopedMgr, err := w.Manager.FetchScopedKeyManager(keyScope)
- if err != nil {
- scopedMgr, err = w.Manager.NewScopedKeyManager(
- ns, keyScope, *addrSchema,
- )
- if err != nil {
- return nil, err
- }
- }
-
- account, err := scopedMgr.NewAccountWatchingOnly(
- ns, name, accountPubKey, masterKeyFingerprint, addrSchema,
- )
- if err != nil {
- return nil, err
- }
- return scopedMgr.AccountProperties(ns, account)
-}
-
-// ImportAccountDryRun serves as a dry run implementation of ImportAccount. This
-// method also returns the first N external and internal addresses, which can be
-// presented to users to confirm whether the account has been imported
-// correctly.
-func (w *Wallet) ImportAccountDryRun(name string,
- accountPubKey *hdkeychain.ExtendedKey, masterKeyFingerprint uint32,
- addrType *waddrmgr.AddressType, numAddrs uint32) (
- *waddrmgr.AccountProperties, []waddrmgr.ManagedAddress,
- []waddrmgr.ManagedAddress, error) {
-
- // The address manager uses OnCommit on the walletdb tx to update the
- // in-memory state of the account state. But because the commit happens
- // _after_ the account manager internal lock has been released, there
- // is a chance for the address index to be accessed concurrently, even
- // though the closure in OnCommit re-acquires the lock. To avoid this
- // issue, we surround the whole address creation process with a lock.
- w.newAddrMtx.Lock()
- defer w.newAddrMtx.Unlock()
-
- var (
- accountProps *waddrmgr.AccountProperties
- externalAddrs []waddrmgr.ManagedAddress
- internalAddrs []waddrmgr.ManagedAddress
- )
-
- // Start a database transaction that we'll never commit and always
- // rollback because we'll return a specific error in the end.
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- // Import the account as usual.
- var err error
- accountProps, err = w.importAccount(
- ns, name, accountPubKey, masterKeyFingerprint, addrType,
- )
- if err != nil {
- return err
- }
-
- // Derive the external and internal addresses. Note that we
- // could do this based on the provided accountPubKey alone, but
- // we go through the ScopedKeyManager instead to ensure
- // addresses will be derived as expected from the wallet's
- // point-of-view.
- manager, err := w.Manager.FetchScopedKeyManager(
- accountProps.KeyScope,
- )
- if err != nil {
- return err
- }
-
- // The importAccount method above will cache the imported
- // account within the scoped manager. Since this is a dry-run
- // attempt, we'll want to invalidate the cache for it.
- defer manager.InvalidateAccountCache(accountProps.AccountNumber)
-
- externalAddrs, err = manager.NextExternalAddresses(
- ns, accountProps.AccountNumber, numAddrs,
- )
- if err != nil {
- return err
- }
- internalAddrs, err = manager.NextInternalAddresses(
- ns, accountProps.AccountNumber, numAddrs,
- )
- if err != nil {
- return err
- }
-
- // Refresh the account's properties after generating the
- // addresses.
- accountProps, err = manager.AccountProperties(
- ns, accountProps.AccountNumber,
- )
- if err != nil {
- return err
- }
-
- // Make sure we always roll back the dry-run transaction by
- // returning an error here.
- return walletdb.ErrDryRunRollBack
- })
- if err != nil && err != walletdb.ErrDryRunRollBack {
- return nil, nil, nil, err
- }
-
- return accountProps, externalAddrs, internalAddrs, nil
-}
-
-// ImportPublicKey imports a single derived public key into the address manager.
-// The address type can usually be inferred from the key's version, but in the
-// case of legacy versions (xpub, tpub), an address type must be specified as we
-// intend to not support importing BIP-44 keys into the wallet using the legacy
-// pay-to-pubkey-hash (P2PKH) scheme.
-func (w *Wallet) ImportPublicKey(pubKey *btcec.PublicKey,
- addrType waddrmgr.AddressType) error {
-
- // Determine what key scope the public key should belong to and import
- // it into the key scope's default imported account.
- var keyScope waddrmgr.KeyScope
- switch addrType {
- case waddrmgr.NestedWitnessPubKey:
- keyScope = waddrmgr.KeyScopeBIP0049Plus
-
- case waddrmgr.WitnessPubKey:
- keyScope = waddrmgr.KeyScopeBIP0084
-
- case waddrmgr.TaprootPubKey:
- keyScope = waddrmgr.KeyScopeBIP0086
-
- default:
- return fmt.Errorf("address type %v is not supported", addrType)
- }
-
- scopedKeyManager, err := w.Manager.FetchScopedKeyManager(keyScope)
- if err != nil {
- return err
- }
-
- // TODO: Perform rescan if requested.
- var addr waddrmgr.ManagedAddress
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- addr, err = scopedKeyManager.ImportPublicKey(ns, pubKey, nil)
- return err
- })
- if err != nil {
- return err
- }
-
- log.Infof("Imported address %v", addr.Address())
-
- err = w.chainClient.NotifyReceived([]address.Address{addr.Address()})
- if err != nil {
- return fmt.Errorf("unable to subscribe for address "+
- "notifications: %w", err)
- }
-
- return nil
-}
-
-// ImportTaprootScript imports a user-provided taproot script into the address
-// manager. The imported script will act as a pay-to-taproot address.
-func (w *Wallet) ImportTaprootScript(scope waddrmgr.KeyScope,
- tapscript *waddrmgr.Tapscript, bs *waddrmgr.BlockStamp,
- witnessVersion byte, isSecretScript bool) (waddrmgr.ManagedAddress,
- error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- // The starting block for the key is the genesis block unless otherwise
- // specified.
- if bs == nil {
- bs = &waddrmgr.BlockStamp{
- Hash: *w.chainParams.GenesisHash,
- Height: 0,
- Timestamp: w.chainParams.GenesisBlock.Header.Timestamp,
- }
- } else if bs.Timestamp.IsZero() {
- // Only update the new birthday time from default value if we
- // actually have timestamp info in the header.
- header, err := w.chainClient.GetBlockHeader(&bs.Hash)
- if err == nil {
- bs.Timestamp = header.Timestamp
- }
- }
-
- // TODO: Perform rescan if requested.
- var addr waddrmgr.ManagedAddress
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- addr, err = manager.ImportTaprootScript(
- ns, tapscript, bs, witnessVersion, isSecretScript,
- )
- return err
- })
- if err != nil {
- return nil, err
- }
-
- log.Infof("Imported address %v", addr.Address())
-
- err = w.chainClient.NotifyReceived([]address.Address{addr.Address()})
- if err != nil {
- return nil, fmt.Errorf("unable to subscribe for address "+
- "notifications: %w", err)
- }
-
- return addr, nil
-}
-
-// ImportPrivateKey imports a private key to the wallet and writes the new
-// wallet to disk.
-//
-// NOTE: If a block stamp is not provided, then the wallet's birthday will be
-// set to the genesis block of the corresponding chain.
-func (w *Wallet) ImportPrivateKey(scope waddrmgr.KeyScope, wif *btcutil.WIF,
- bs *waddrmgr.BlockStamp, rescan bool) (string, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return "", err
- }
-
- // The starting block for the key is the genesis block unless otherwise
- // specified.
- if bs == nil {
- bs = &waddrmgr.BlockStamp{
- Hash: *w.chainParams.GenesisHash,
- Height: 0,
- Timestamp: w.chainParams.GenesisBlock.Header.Timestamp,
- }
- } else if bs.Timestamp.IsZero() {
- // Only update the new birthday time from default value if we
- // actually have timestamp info in the header.
- header, err := w.chainClient.GetBlockHeader(&bs.Hash)
- if err == nil {
- bs.Timestamp = header.Timestamp
- }
- }
-
- // Attempt to import private key into wallet.
- var addr address.Address
- var props *waddrmgr.AccountProperties
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- maddr, err := manager.ImportPrivateKey(addrmgrNs, wif, bs)
- if err != nil {
- return err
- }
- addr = maddr.Address()
- props, err = manager.AccountProperties(
- addrmgrNs, waddrmgr.ImportedAddrAccount,
- )
- if err != nil {
- return err
- }
-
- // We'll only update our birthday with the new one if it is
- // before our current one. Otherwise, if we do, we can
- // potentially miss detecting relevant chain events that
- // occurred between them while rescanning.
- birthdayBlock, _, err := w.Manager.BirthdayBlock(addrmgrNs)
- if err != nil {
- return err
- }
- if bs.Height >= birthdayBlock.Height {
- return nil
- }
-
- err = w.Manager.SetBirthday(addrmgrNs, bs.Timestamp)
- if err != nil {
- return err
- }
-
- // To ensure this birthday block is correct, we'll mark it as
- // unverified to prompt a sanity check at the next restart to
- // ensure it is correct as it was provided by the caller.
- return w.Manager.SetBirthdayBlock(addrmgrNs, *bs, false)
- })
- if err != nil {
- return "", err
- }
-
- // Rescan blockchain for transactions with txout scripts paying to the
- // imported address.
- if rescan {
- job := &RescanJob{
- Addrs: []address.Address{addr},
- OutPoints: nil,
- BlockStamp: *bs,
- }
-
- // Submit rescan job and log when the import has completed.
- // Do not block on finishing the rescan. The rescan success
- // or failure is logged elsewhere, and the channel is not
- // required to be read, so discard the return value.
- _ = w.SubmitRescan(job)
- } else {
- err := w.chainClient.NotifyReceived([]address.Address{addr})
- if err != nil {
- return "", fmt.Errorf("failed to subscribe for address ntfns for "+
- "address %s: %w", addr.EncodeAddress(), err)
- }
- }
-
- addrStr := addr.EncodeAddress()
- log.Infof("Imported payment address %s", addrStr)
-
- w.NtfnServer.notifyAccountProperties(props)
-
- // Return the payment address string of the imported private key.
- return addrStr, nil
-}
diff --git a/wallet/import_test.go b/wallet/import_test.go
deleted file mode 100644
index 910a849f66..0000000000
--- a/wallet/import_test.go
+++ /dev/null
@@ -1,295 +0,0 @@
-package wallet
-
-import (
- "encoding/binary"
- "strings"
- "testing"
-
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/stretchr/testify/require"
-)
-
-func hardenedKey(key uint32) uint32 {
- return key + hdkeychain.HardenedKeyStart
-}
-
-func deriveAcctPubKey(t *testing.T, root *hdkeychain.ExtendedKey,
- scope waddrmgr.KeyScope, paths ...uint32) *hdkeychain.ExtendedKey {
-
- path := []uint32{hardenedKey(scope.Purpose), hardenedKey(scope.Coin)}
- path = append(path, paths...)
-
- var (
- currentKey = root
- err error
- )
- for _, pathPart := range path {
- currentKey, err = currentKey.Derive(pathPart)
- require.NoError(t, err)
- }
-
- // The Neuter() method checks the version and doesn't know any
- // non-standard methods. We need to convert them to standard, neuter,
- // then convert them back with the target extended public key version.
- pubVersionBytes := make([]byte, 4)
- copy(pubVersionBytes, chaincfg.TestNet3Params.HDPublicKeyID[:])
- switch {
- case strings.HasPrefix(root.String(), "uprv"):
- binary.BigEndian.PutUint32(pubVersionBytes, uint32(
- waddrmgr.HDVersionTestNetBIP0049,
- ))
-
- case strings.HasPrefix(root.String(), "vprv"):
- binary.BigEndian.PutUint32(pubVersionBytes, uint32(
- waddrmgr.HDVersionTestNetBIP0084,
- ))
- }
-
- currentKey, err = currentKey.CloneWithVersion(
- chaincfg.TestNet3Params.HDPrivateKeyID[:],
- )
- require.NoError(t, err)
- currentKey, err = currentKey.Neuter()
- require.NoError(t, err)
- currentKey, err = currentKey.CloneWithVersion(pubVersionBytes)
- require.NoError(t, err)
-
- return currentKey
-}
-
-type testCase struct {
- name string
- masterPriv string
- accountIndex uint32
- addrType waddrmgr.AddressType
- expectedScope waddrmgr.KeyScope
- expectedAddr string
- expectedChangeAddr string
-}
-
-var (
- testCases = []*testCase{{
- name: "bip44 with nested witness address type",
- masterPriv: "tprv8ZgxMBicQKsPeWwrFuNjEGTTDSY4mRLwd2KDJAPGa1AY" +
- "quw38bZqNMSuB3V1Va3hqJBo9Pt8Sx7kBQer5cNMrb8SYquoWPt9" +
- "Y3BZdhdtUcw",
- accountIndex: 0,
- addrType: waddrmgr.NestedWitnessPubKey,
- expectedScope: waddrmgr.KeyScopeBIP0049Plus,
- expectedAddr: "2N5YTxG9XtGXx1YyhZb7N2pwEjoZLLMHGKj",
- expectedChangeAddr: "2N7wpz5Gy2zEJTvq2MAuU6BCTEBLXNQ8dUw",
- }, {
- name: "bip44 with witness address type",
- masterPriv: "tprv8ZgxMBicQKsPeWwrFuNjEGTTDSY4mRLwd2KDJAPGa1AY" +
- "quw38bZqNMSuB3V1Va3hqJBo9Pt8Sx7kBQer5cNMrb8SYquoWPt9" +
- "Y3BZdhdtUcw",
- accountIndex: 777,
- addrType: waddrmgr.WitnessPubKey,
- expectedScope: waddrmgr.KeyScopeBIP0084,
- expectedAddr: "tb1qllxcutkzsukf8u8c8stkp464j0esu9xq7qju8x",
- expectedChangeAddr: "tb1qu6jmqglrthscptjqj3egx54wy8xqvzn5hslgw7",
- }, {
- name: "traditional bip49",
- masterPriv: "uprv8tXDerPXZ1QsVp8y6GAMSMYxPQgWi3LSY8qS5ZH9x1YRu" +
- "1kGPFjPzR73CFSbVUhdEwJbtsUgucUJ4hGQoJnNepp3RBcE6Jhdom" +
- "FD2KeY6G9",
- accountIndex: 9,
- addrType: waddrmgr.NestedWitnessPubKey,
- expectedScope: waddrmgr.KeyScopeBIP0049Plus,
- expectedAddr: "2NBCJ9WzGXZqpLpXGq3Hacybj3c4eHRcqgh",
- expectedChangeAddr: "2N3bankFu6F3ZNU41iVJQqyS9MXqp9dvn1M",
- }, {
- name: "bip49+",
- masterPriv: "uprv8tXDerPXZ1QsVp8y6GAMSMYxPQgWi3LSY8qS5ZH9x1YRu" +
- "1kGPFjPzR73CFSbVUhdEwJbtsUgucUJ4hGQoJnNepp3RBcE6Jhdom" +
- "FD2KeY6G9",
- accountIndex: 9,
- addrType: waddrmgr.WitnessPubKey,
- expectedScope: waddrmgr.KeyScopeBIP0049Plus,
- expectedAddr: "2NBCJ9WzGXZqpLpXGq3Hacybj3c4eHRcqgh",
- expectedChangeAddr: "tb1qeqn05w2hfq6axpdprhs4y7x65gxkkvfvyxqk4u",
- }, {
- name: "bip84",
- masterPriv: "vprv9DMUxX4ShgxMM7L5vcwyeSeTZNpxefKwTFMerxB3L1vJ" +
- "x7ZVdutxcUmBDTQBVPMYeaRQeM5FNGpqwysyX1CPT4VeHXJegDX8" +
- "5VJrQvaFaz3",
- accountIndex: 1,
- addrType: waddrmgr.WitnessPubKey,
- expectedScope: waddrmgr.KeyScopeBIP0084,
- expectedAddr: "tb1q5vepvcl0z8xj7kps4rsux722r4dvfwlhk6j532",
- expectedChangeAddr: "tb1qlwe2kgxcsa8x4huu79yff4rze0l5mwafg5c7xd",
- }}
-)
-
-// TestImportAccount tests that extended public keys can successfully be
-// imported into both watch only and normal wallets.
-func TestImportAccount(t *testing.T) {
- t.Parallel()
-
- for _, tc := range testCases {
- tc := tc
-
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- testImportAccount(t, w, tc, false, tc.name)
- })
-
- name := tc.name + " watch-only"
- t.Run(name, func(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWalletWatchingOnly(t)
- defer cleanup()
-
- testImportAccount(t, w, tc, true, name)
- })
- }
-}
-
-func testImportAccount(t *testing.T, w *Wallet, tc *testCase, watchOnly bool,
- name string) {
-
- // First derive the master public key of the account we want to import.
- root, err := hdkeychain.NewKeyFromString(tc.masterPriv)
- require.NoError(t, err)
-
- // Derive the extended private and public key for our target account.
- acct1Pub := deriveAcctPubKey(
- t, root, tc.expectedScope, hardenedKey(tc.accountIndex),
- )
-
- // We want to make sure we can import and handle multiple accounts, so
- // we create another one.
- acct2Pub := deriveAcctPubKey(
- t, root, tc.expectedScope, hardenedKey(tc.accountIndex+1),
- )
-
- // And we also want to be able to import loose extended public keys
- // without needing to specify an explicit scope.
- acct3ExternalExtPub := deriveAcctPubKey(
- t, root, tc.expectedScope, hardenedKey(tc.accountIndex+2), 0, 0,
- )
- acct3ExternalPub, err := acct3ExternalExtPub.ECPubKey()
- require.NoError(t, err)
-
- // Do a dry run import first and check that it results in the expected
- // addresses being derived.
- _, extAddrs, intAddrs, err := w.ImportAccountDryRun(
- name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType, 1,
- )
- require.NoError(t, err)
- require.Len(t, extAddrs, 1)
- require.Equal(t, tc.expectedAddr, extAddrs[0].Address().String())
- require.Len(t, intAddrs, 1)
- require.Equal(t, tc.expectedChangeAddr, intAddrs[0].Address().String())
-
- // Import the extended public keys into new accounts.
- acct1, err := w.ImportAccount(
- name+"1", acct1Pub, root.ParentFingerprint(), &tc.addrType,
- )
- require.NoError(t, err)
- require.Equal(t, tc.expectedScope, acct1.KeyScope)
-
- acct2, err := w.ImportAccount(
- name+"2", acct2Pub, root.ParentFingerprint(), &tc.addrType,
- )
- require.NoError(t, err)
- require.Equal(t, tc.expectedScope, acct2.KeyScope)
-
- err = w.ImportPublicKey(acct3ExternalPub, tc.addrType)
- require.NoError(t, err)
-
- // If the wallet is watch only, there is no default account and our
- // imported account will be index 0.
- firstAccountIndex := uint32(1)
- numAccounts := 2
- if watchOnly {
- firstAccountIndex = 0
- numAccounts = 1
- }
-
- // We should have 2 additional accounts now.
- acctResult, err := w.Accounts(tc.expectedScope)
- require.NoError(t, err)
- require.Len(t, acctResult.Accounts, numAccounts+2)
-
- // Validate the state of the accounts.
- require.Equal(t, firstAccountIndex, acct1.AccountNumber)
- require.Equal(t, name+"1", acct1.AccountName)
- require.Equal(t, true, acct1.IsWatchOnly)
- require.Equal(t, root.ParentFingerprint(), acct1.MasterKeyFingerprint)
- require.NotNil(t, acct1.AccountPubKey)
- require.Equal(t, acct1Pub.String(), acct1.AccountPubKey.String())
- require.Equal(t, uint32(0), acct1.InternalKeyCount)
- require.Equal(t, uint32(0), acct1.ExternalKeyCount)
- require.Equal(t, uint32(0), acct1.ImportedKeyCount)
-
- require.Equal(t, firstAccountIndex+1, acct2.AccountNumber)
- require.Equal(t, name+"2", acct2.AccountName)
- require.Equal(t, true, acct2.IsWatchOnly)
- require.Equal(t, root.ParentFingerprint(), acct2.MasterKeyFingerprint)
- require.NotNil(t, acct2.AccountPubKey)
- require.Equal(t, acct2Pub.String(), acct2.AccountPubKey.String())
- require.Equal(t, uint32(0), acct2.InternalKeyCount)
- require.Equal(t, uint32(0), acct2.ExternalKeyCount)
- require.Equal(t, uint32(0), acct2.ImportedKeyCount)
-
- // Test address derivation.
- extAddr, err := w.NewAddress(acct1.AccountNumber, tc.expectedScope)
- require.NoError(t, err)
- require.Equal(t, tc.expectedAddr, extAddr.String())
- intAddr, err := w.NewChangeAddress(acct1.AccountNumber, tc.expectedScope)
- require.NoError(t, err)
- require.Equal(t, tc.expectedChangeAddr, intAddr.String())
-
- // Make sure the key count was increased.
- acct1, err = w.AccountProperties(tc.expectedScope, acct1.AccountNumber)
- require.NoError(t, err)
- require.Equal(t, uint32(1), acct1.InternalKeyCount)
- require.Equal(t, uint32(1), acct1.ExternalKeyCount)
- require.Equal(t, uint32(0), acct1.ImportedKeyCount)
-
- // Make sure we can't get private keys for the imported accounts.
- _, err = w.DumpWIFPrivateKey(intAddr)
- require.True(t, waddrmgr.IsError(err, waddrmgr.ErrWatchingOnly))
-
- // Get the address info for the single key we imported.
- switch tc.addrType {
- case waddrmgr.NestedWitnessPubKey:
- witnessAddr, err := address.NewAddressWitnessPubKeyHash(
- address.Hash160(acct3ExternalPub.SerializeCompressed()),
- &chaincfg.TestNet3Params,
- )
- require.NoError(t, err)
-
- witnessProg, err := txscript.PayToAddrScript(witnessAddr)
- require.NoError(t, err)
-
- intAddr, err = address.NewAddressScriptHash(
- witnessProg, &chaincfg.TestNet3Params,
- )
- require.NoError(t, err)
-
- case waddrmgr.WitnessPubKey:
- intAddr, err = address.NewAddressWitnessPubKeyHash(
- address.Hash160(acct3ExternalPub.SerializeCompressed()),
- &chaincfg.TestNet3Params,
- )
- require.NoError(t, err)
-
- default:
- t.Fatalf("unhandled address type %v", tc.addrType)
- }
-
- addrManaged, err := w.AddressInfo(intAddr)
- require.NoError(t, err)
- require.Equal(t, true, addrManaged.Imported())
-}
diff --git a/wallet/interface.go b/wallet/interface.go
index e5ad53e529..3e1229590c 100644
--- a/wallet/interface.go
+++ b/wallet/interface.go
@@ -30,11 +30,15 @@ import (
//
//nolint:interfacebloat
type Interface interface {
- // Start starts the goroutines necessary to manage a wallet.
- Start()
+ // StartDeprecated starts the goroutines necessary to manage a wallet.
+ //
+ // Deprecated: Use WalletController.Start instead.
+ StartDeprecated()
- // Stop signals all wallet goroutines to shutdown.
- Stop()
+ // StopDeprecated signals all wallet goroutines to shutdown.
+ //
+ // Deprecated: Use WalletController.Stop instead.
+ StopDeprecated()
// WaitForShutdown blocks until all wallet goroutines have finished.
WaitForShutdown()
@@ -47,14 +51,14 @@ type Interface interface {
// will fail.
Locked() bool
- // Unlock unlocks the wallet with a passphrase. The wallet will
- // automatically re-lock after the timeout has expired. If the timeout
- // channel is nil, the wallet remains unlocked indefinitely.
- Unlock(passphrase []byte, lock <-chan time.Time) error
+ // UnlockDeprecated unlocks the wallet with a passphrase. The wallet
+ // will remain unlocked until the returned lock channel is closed or
+ // the timeout expires.
+ UnlockDeprecated(passphrase []byte, lock <-chan time.Time) error
- // Lock locks the wallet. Any operations that require private keys will
- // fail until the wallet is unlocked again.
- Lock()
+ // LockDeprecated locks the wallet. Any operations that require private
+ // keys will fail if the wallet is locked.
+ LockDeprecated()
// ChainSynced returns whether the wallet is synchronized with the
// blockchain. Certain operations may fail if the wallet is not synced.
@@ -82,7 +86,7 @@ type Interface interface {
NotificationServer() *NotificationServer
// AddrManager returns the internal address manager.
- AddrManager() *waddrmgr.Manager
+ AddrManager() waddrmgr.AddrStore
// Accounts returns all accounts for a particular scope.
Accounts(scope waddrmgr.KeyScope) (*AccountsResult, error)
@@ -113,14 +117,22 @@ type Interface interface {
AccountManagedAddresses(scope waddrmgr.KeyScope,
accountNum uint32) ([]waddrmgr.ManagedAddress, error)
- // RenameAccount renames an existing account. It is an error to rename
- // a reserved account or to choose a name that is already in use.
- RenameAccount(scope waddrmgr.KeyScope, account uint32,
+ // RenameAccountDeprecated renames an existing account. It is an error
+ // to rename a reserved account or to choose a name that is already in
+ // use.
+ //
+ // Deprecated: Use AccountManager.RenameAccount instead.
+ RenameAccountDeprecated(scope waddrmgr.KeyScope, account uint32,
newName string) error
- // ImportAccount imports an account backed by an extended public key.
+ // ImportAccountDeprecated imports an account backed by an extended
+ // public key.
+ //
// This creates a watch-only account.
- ImportAccount(name string, accountPubKey *hdkeychain.ExtendedKey,
+ //
+ // Deprecated: Use AccountManager.ImportAccount instead.
+ ImportAccountDeprecated(name string,
+ accountPubKey *hdkeychain.ExtendedKey,
masterKeyFingerprint uint32, addrType *waddrmgr.AddressType,
) (*waddrmgr.AccountProperties, error)
@@ -139,7 +151,7 @@ type Interface interface {
// AddScopeManager adds a new scope manager to the wallet.
AddScopeManager(scope waddrmgr.KeyScope,
addrSchema waddrmgr.ScopeAddrSchema) (
- *waddrmgr.ScopedKeyManager, error)
+ waddrmgr.AccountStore, error)
// CurrentAddress returns the current, most recently generated address
// for a given account and scope. If the current address has been used,
@@ -147,8 +159,12 @@ type Interface interface {
CurrentAddress(account uint32, scope waddrmgr.KeyScope) (
address.Address, error)
- // NewAddress returns a new address for a given account and scope.
- NewAddress(account uint32, scope waddrmgr.KeyScope) (
+ // NewAddressDeprecated returns a new address for a given account and
+ // scope.
+ //
+ // Deprecated: This method will be removed in a future release. Use the
+ // AddressManager interface instead.
+ NewAddressDeprecated(account uint32, scope waddrmgr.KeyScope) (
address.Address, error)
// NewChangeAddress returns a new change address for a given account
@@ -156,19 +172,31 @@ type Interface interface {
NewChangeAddress(account uint32, scope waddrmgr.KeyScope) (
address.Address, error)
- // AddressInfo returns detailed information about a managed address,
- // including its derivation path and whether it's compressed.
- AddressInfo(a address.Address) (waddrmgr.ManagedAddress, error)
+ // AddressInfoDeprecated returns detailed information about a managed
+ // address, including its derivation path and whether it's compressed.
+ //
+ // Deprecated: This method leaks internal waddrmgr types. Callers
+ // should use specific methods such as AccountOfAddress,
+ // IsInternalAddress, etc. instead.
+ AddressInfoDeprecated(a address.Address) (
+ waddrmgr.ManagedAddress, error,
+ )
// HaveAddress returns whether the wallet is the owner of the address.
HaveAddress(a address.Address) (bool, error)
- // ImportPublicKey imports a public key as a watch-only address.
- ImportPublicKey(pubKey *btcec.PublicKey,
+ // ImportPublicKeyDeprecated imports a public key as a watch-only
+ // address.
+ //
+ // Deprecated: Use AddressManager.ImportPublicKey instead.
+ ImportPublicKeyDeprecated(pubKey *btcec.PublicKey,
addrType waddrmgr.AddressType) error
- // ImportTaprootScript imports a taproot script into the wallet.
- ImportTaprootScript(scope waddrmgr.KeyScope,
+ // ImportTaprootScriptDeprecated imports a taproot script into the
+ // wallet.
+ //
+ // Deprecated: Use AddressManager.ImportTaprootScript instead.
+ ImportTaprootScriptDeprecated(scope waddrmgr.KeyScope,
tapscript *waddrmgr.Tapscript, bs *waddrmgr.BlockStamp,
witnessVersion byte, isSecretScript bool) (
waddrmgr.ManagedAddress, error)
@@ -185,9 +213,11 @@ type Interface interface {
CalculateAccountBalances(account uint32, requiredConfirmations int32) (
Balances, error)
- // ListUnspent returns all unspent transaction outputs for a given
- // account and confirmation requirement.
- ListUnspent(minconf, maxconf int32, accountName string) (
+ // ListUnspentDeprecated returns all unspent transaction outputs for a
+ // given account and confirmation requirement.
+ //
+ // Deprecated: Use UtxoManager.ListUnspent instead.
+ ListUnspentDeprecated(minconf, maxconf int32, accountName string) (
[]*btcjson.ListUnspentResult, error)
// FetchOutpointInfo returns the output information for a given
@@ -208,10 +238,10 @@ type Interface interface {
// and should not be used as an input for created transactions.
LockedOutpoint(op wire.OutPoint) bool
- // LeaseOutput locks an output to the given ID, preventing it from
- // being available for coin selection. The absolute time of the lock's
- // expiration is returned. The expiration of the lock can be extended by
- // successive invocations of this call.
+ // LeaseOutputDeprecated locks an output to the given ID, preventing it
+ // from being available for coin selection. The absolute time of the
+ // lock's expiration is returned. The expiration of the lock can be
+ // extended by successive invocations of this call.
//
// Outputs can be unlocked before their expiration through
// `UnlockOutput`. Otherwise, they are unlocked lazily through calls
@@ -224,16 +254,23 @@ type Interface interface {
//
// NOTE: This differs from LockOutpoint in that outputs are locked for
// a limited amount of time and their locks are persisted to disk.
- LeaseOutput(id wtxmgr.LockID, op wire.OutPoint,
+ //
+ // Deprecated: Use UtxoManager.LeaseOutput instead.
+ LeaseOutputDeprecated(id wtxmgr.LockID, op wire.OutPoint,
duration time.Duration) (time.Time, error)
- // ReleaseOutput unlocks an output, allowing it to be available for
- // coin selection if it remains unspent. The ID should match the one
- // used to originally lock the output.
- ReleaseOutput(id wtxmgr.LockID, op wire.OutPoint) error
+ // ReleaseOutputDeprecated unlocks an output, allowing it to be
+ // available for coin selection if it remains unspent. The ID should
+ // match the one used to originally lock the output.
+ //
+ // Deprecated: Use UtxoManager.ReleaseOutput instead.
+ ReleaseOutputDeprecated(id wtxmgr.LockID, op wire.OutPoint) error
- // ListLeasedOutputs returns a list of all currently leased outputs.
- ListLeasedOutputs() ([]*ListLeasedOutputResult, error)
+ // ListLeasedOutputsDeprecated returns a list of all currently leased
+ // outputs.
+ //
+ // Deprecated: Use UtxoManager.ListLeasedOutputs instead.
+ ListLeasedOutputsDeprecated() ([]*ListLeasedOutputResult, error)
// CreateSimpleTx creates a new transaction to the specified outputs,
// automatically performing coin selection and creating a change output
@@ -261,21 +298,21 @@ type Interface interface {
// PublishTransaction broadcasts a transaction to the network.
PublishTransaction(tx *wire.MsgTx, label string) error
- // FundPsbt creates a PSBT with enough inputs to fund the specified
- // outputs, adding a change output if necessary.
- FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
+ // FundPsbtDeprecated creates a PSBT with enough inputs to fund the
+ // specified outputs, adding a change output if necessary.
+ FundPsbtDeprecated(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
minConfs int32, account uint32, feeSatPerKB btcutil.Amount,
strategy CoinSelectionStrategy,
optFuncs ...TxCreateOption) (int32, error)
- // FinalizePsbt signs and finalizes a PSBT, making it ready for
- // broadcast. The wallet must be the last signer.
- FinalizePsbt(keyScope *waddrmgr.KeyScope, account uint32,
+ // FinalizePsbtDeprecated signs and finalizes a PSBT, making it ready
+ // for broadcast. The wallet must be the last signer.
+ FinalizePsbtDeprecated(keyScope *waddrmgr.KeyScope, account uint32,
packet *psbt.Packet) error
- // DecorateInputs decorates the inputs of a PSBT with the necessary
- // information to sign it.
- DecorateInputs(packet *psbt.Packet, failOnUnknown bool) error
+ // DecorateInputsDeprecated decorates the inputs of a PSBT with the
+ // necessary information to sign it.
+ DecorateInputsDeprecated(packet *psbt.Packet, failOnUnknown bool) error
// GetTransaction returns the details for a transaction given its hash.
GetTransaction(txHash chainhash.Hash) (*GetTransactionResult, error)
@@ -306,18 +343,20 @@ type Interface interface {
DeriveFromKeyPathAddAccount(scope waddrmgr.KeyScope,
path waddrmgr.DerivationPath) (*btcec.PrivateKey, error)
- // ComputeInputScript generates a complete InputScript for the passed
- // transaction with the signature as defined within the passed
- // SignDescriptor.
+ // ComputeInputScript generates a complete InputScript for the
+ // passed transaction with the signature as defined within the
+ // passed SignDescriptor.
ComputeInputScript(tx *wire.MsgTx, output *wire.TxOut,
inputIndex int, sigHashes *txscript.TxSigHashes,
hashType txscript.SigHashType,
tweaker PrivKeyTweaker) (wire.TxWitness, []byte, error)
- // ScriptForOutput returns the address, witness program and redeem
- // script for a given UTXO.
- ScriptForOutput(output *wire.TxOut) (waddrmgr.ManagedPubKeyAddress,
- []byte, []byte, error)
+ // ScriptForOutputDeprecated returns the address, witness program and
+ // redeem script for a given UTXO.
+ //
+ // Deprecated: Use AddressManager.ScriptForOutput instead.
+ ScriptForOutputDeprecated(output *wire.TxOut) (
+ waddrmgr.ManagedPubKeyAddress, []byte, []byte, error)
}
// A compile time check to ensure that Wallet implements the interface.
diff --git a/wallet/loader.go b/wallet/loader.go
deleted file mode 100644
index 15dbcd3c96..0000000000
--- a/wallet/loader.go
+++ /dev/null
@@ -1,433 +0,0 @@
-// Copyright (c) 2015-2016 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "errors"
- "fmt"
- "os"
- "path/filepath"
- "sync"
- "time"
-
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcwallet/internal/prompt"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
-)
-
-const (
- // WalletDBName specified the database filename for the wallet.
- WalletDBName = "wallet.db"
-
- // DefaultDBTimeout is the default timeout value when opening the wallet
- // database.
- DefaultDBTimeout = 60 * time.Second
-)
-
-var (
- // ErrLoaded describes the error condition of attempting to load or
- // create a wallet when the loader has already done so.
- ErrLoaded = errors.New("wallet already loaded")
-
- // ErrNotLoaded describes the error condition of attempting to close a
- // loaded wallet when a wallet has not been loaded.
- ErrNotLoaded = errors.New("wallet is not loaded")
-
- // ErrExists describes the error condition of attempting to create a new
- // wallet when one exists already.
- ErrExists = errors.New("wallet already exists")
-)
-
-// loaderConfig contains the configuration options for the loader.
-type loaderConfig struct {
- walletSyncRetryInterval time.Duration
-}
-
-// defaultLoaderConfig returns the default configuration options for the loader.
-func defaultLoaderConfig() *loaderConfig {
- return &loaderConfig{
- walletSyncRetryInterval: defaultSyncRetryInterval,
- }
-}
-
-// LoaderOption is a configuration option for the loader.
-type LoaderOption func(*loaderConfig)
-
-// WithWalletSyncRetryInterval specifies the interval at which the wallet
-// should retry syncing to the chain if it encounters an error.
-func WithWalletSyncRetryInterval(interval time.Duration) LoaderOption {
- return func(c *loaderConfig) {
- c.walletSyncRetryInterval = interval
- }
-}
-
-// Loader implements the creating of new and opening of existing wallets, while
-// providing a callback system for other subsystems to handle the loading of a
-// wallet. This is primarily intended for use by the RPC servers, to enable
-// methods and services which require the wallet when the wallet is loaded by
-// another subsystem.
-//
-// Loader is safe for concurrent access.
-type Loader struct {
- cfg *loaderConfig
- callbacks []func(*Wallet)
- chainParams *chaincfg.Params
- dbDirPath string
- noFreelistSync bool
- timeout time.Duration
- recoveryWindow uint32
- wallet *Wallet
- localDB bool
- walletExists func() (bool, error)
- walletCreated func(db walletdb.ReadWriteTx) error
- db walletdb.DB
- mu sync.Mutex
-}
-
-// NewLoader constructs a Loader with an optional recovery window. If the
-// recovery window is non-zero, the wallet will attempt to recovery addresses
-// starting from the last SyncedTo height.
-func NewLoader(chainParams *chaincfg.Params, dbDirPath string,
- noFreelistSync bool, timeout time.Duration, recoveryWindow uint32,
- opts ...LoaderOption) *Loader {
-
- cfg := defaultLoaderConfig()
- for _, opt := range opts {
- opt(cfg)
- }
-
- return &Loader{
- cfg: cfg,
- chainParams: chainParams,
- dbDirPath: dbDirPath,
- noFreelistSync: noFreelistSync,
- timeout: timeout,
- recoveryWindow: recoveryWindow,
- localDB: true,
- }
-}
-
-// NewLoaderWithDB constructs a Loader with an externally provided DB. This way
-// users are free to use their own walletdb implementation (eg. leveldb, etcd)
-// to store the wallet. Given that the external DB may be shared an additional
-// function is also passed which will override Loader.WalletExists().
-func NewLoaderWithDB(chainParams *chaincfg.Params, recoveryWindow uint32,
- db walletdb.DB, walletExists func() (bool, error),
- opts ...LoaderOption) (*Loader, error) {
-
- if db == nil {
- return nil, fmt.Errorf("no DB provided")
- }
-
- if walletExists == nil {
- return nil, fmt.Errorf("unable to check if wallet exists")
- }
-
- cfg := defaultLoaderConfig()
- for _, opt := range opts {
- opt(cfg)
- }
-
- return &Loader{
- cfg: cfg,
- chainParams: chainParams,
- recoveryWindow: recoveryWindow,
- localDB: false,
- walletExists: walletExists,
- db: db,
- }, nil
-}
-
-// onLoaded executes each added callback and prevents loader from loading any
-// additional wallets. Requires mutex to be locked.
-func (l *Loader) onLoaded(w *Wallet) {
- for _, fn := range l.callbacks {
- fn(w)
- }
-
- l.wallet = w
- l.callbacks = nil // not needed anymore
-}
-
-// RunAfterLoad adds a function to be executed when the loader creates or opens
-// a wallet. Functions are executed in a single goroutine in the order they are
-// added.
-func (l *Loader) RunAfterLoad(fn func(*Wallet)) {
- l.mu.Lock()
- if l.wallet != nil {
- w := l.wallet
- l.mu.Unlock()
- fn(w)
- } else {
- l.callbacks = append(l.callbacks, fn)
- l.mu.Unlock()
- }
-}
-
-// OnWalletCreated adds a function that will be executed the wallet structure
-// is initialized in the wallet database. This is useful if users want to add
-// extra fields in the same transaction (eg. to flag wallet existence).
-func (l *Loader) OnWalletCreated(fn func(walletdb.ReadWriteTx) error) {
- l.mu.Lock()
- defer l.mu.Unlock()
- l.walletCreated = fn
-}
-
-// CreateNewWallet creates a new wallet using the provided public and private
-// passphrases. The seed is optional. If non-nil, addresses are derived from
-// this seed. If nil, a secure random seed is generated.
-func (l *Loader) CreateNewWallet(pubPassphrase, privPassphrase, seed []byte,
- bday time.Time) (*Wallet, error) {
-
- var (
- rootKey *hdkeychain.ExtendedKey
- err error
- )
-
- // If a seed was specified, we check its length now. If no seed is
- // passed, the wallet will create a new random one.
- if seed != nil {
- if len(seed) < hdkeychain.MinSeedBytes ||
- len(seed) > hdkeychain.MaxSeedBytes {
-
- return nil, hdkeychain.ErrInvalidSeedLen
- }
-
- // Derive the master extended key from the seed.
- rootKey, err = hdkeychain.NewMaster(seed, l.chainParams)
- if err != nil {
- return nil, fmt.Errorf("failed to derive master " +
- "extended key")
- }
- }
-
- return l.createNewWallet(
- pubPassphrase, privPassphrase, rootKey, bday, false,
- )
-}
-
-// CreateNewWalletExtendedKey creates a new wallet from an extended master root
-// key using the provided public and private passphrases. The root key is
-// optional. If non-nil, addresses are derived from this root key. If nil, a
-// secure random seed is generated and the root key is derived from that.
-func (l *Loader) CreateNewWalletExtendedKey(pubPassphrase, privPassphrase []byte,
- rootKey *hdkeychain.ExtendedKey, bday time.Time) (*Wallet, error) {
-
- return l.createNewWallet(
- pubPassphrase, privPassphrase, rootKey, bday, false,
- )
-}
-
-// CreateNewWatchingOnlyWallet creates a new wallet using the provided
-// public passphrase. No seed or private passphrase may be provided
-// since the wallet is watching-only.
-func (l *Loader) CreateNewWatchingOnlyWallet(pubPassphrase []byte,
- bday time.Time) (*Wallet, error) {
-
- return l.createNewWallet(
- pubPassphrase, nil, nil, bday, true,
- )
-}
-
-func (l *Loader) createNewWallet(pubPassphrase, privPassphrase []byte,
- rootKey *hdkeychain.ExtendedKey, bday time.Time,
- isWatchingOnly bool) (*Wallet, error) {
-
- defer l.mu.Unlock()
- l.mu.Lock()
-
- if l.wallet != nil {
- return nil, ErrLoaded
- }
-
- exists, err := l.WalletExists()
- if err != nil {
- return nil, err
- }
- if exists {
- return nil, ErrExists
- }
-
- if l.localDB {
- dbPath := filepath.Join(l.dbDirPath, WalletDBName)
-
- // Create the wallet database backed by bolt db.
- err = os.MkdirAll(l.dbDirPath, 0700)
- if err != nil {
- return nil, err
- }
- l.db, err = walletdb.Create(
- "bdb", dbPath, l.noFreelistSync, l.timeout, false,
- )
- if err != nil {
- return nil, err
- }
- }
-
- // Initialize the newly created database for the wallet before opening.
- if isWatchingOnly {
- err := CreateWatchingOnlyWithCallback(
- l.db, pubPassphrase, l.chainParams, bday,
- l.walletCreated,
- )
- if err != nil {
- return nil, err
- }
- } else {
- err := CreateWithCallback(
- l.db, pubPassphrase, privPassphrase, rootKey,
- l.chainParams, bday, l.walletCreated,
- )
- if err != nil {
- return nil, err
- }
- }
-
- // Open the newly-created wallet.
- w, err := OpenWithRetry(
- l.db, pubPassphrase, nil, l.chainParams, l.recoveryWindow,
- l.cfg.walletSyncRetryInterval,
- )
- if err != nil {
- return nil, err
- }
- w.Start()
-
- l.onLoaded(w)
- return w, nil
-}
-
-var errNoConsole = errors.New("db upgrade requires console access for additional input")
-
-func noConsole() ([]byte, error) {
- return nil, errNoConsole
-}
-
-// OpenExistingWallet opens the wallet from the loader's wallet database path
-// and the public passphrase. If the loader is being called by a context where
-// standard input prompts may be used during wallet upgrades, setting
-// canConsolePrompt will enables these prompts.
-func (l *Loader) OpenExistingWallet(pubPassphrase []byte,
- canConsolePrompt bool) (*Wallet, error) {
-
- defer l.mu.Unlock()
- l.mu.Lock()
-
- if l.wallet != nil {
- return nil, ErrLoaded
- }
-
- if l.localDB {
- var err error
- // Ensure that the network directory exists.
- if err = checkCreateDir(l.dbDirPath); err != nil {
- return nil, err
- }
-
- // Open the database using the boltdb backend.
- dbPath := filepath.Join(l.dbDirPath, WalletDBName)
- l.db, err = walletdb.Open(
- "bdb", dbPath, l.noFreelistSync, l.timeout, false,
- )
- if err != nil {
- log.Errorf("Failed to open database: %v", err)
- return nil, err
- }
- }
-
- var cbs *waddrmgr.OpenCallbacks
- if canConsolePrompt {
- cbs = &waddrmgr.OpenCallbacks{
- ObtainSeed: prompt.ProvideSeed,
- ObtainPrivatePass: prompt.ProvidePrivPassphrase,
- }
- } else {
- cbs = &waddrmgr.OpenCallbacks{
- ObtainSeed: noConsole,
- ObtainPrivatePass: noConsole,
- }
- }
- w, err := OpenWithRetry(
- l.db, pubPassphrase, cbs, l.chainParams, l.recoveryWindow,
- l.cfg.walletSyncRetryInterval,
- )
- if err != nil {
- // If opening the wallet fails (e.g. because of wrong
- // passphrase), we must close the backing database to
- // allow future calls to walletdb.Open().
- if l.localDB {
- e := l.db.Close()
- if e != nil {
- log.Warnf("Error closing database: %v", e)
- }
- }
-
- return nil, err
- }
- w.Start()
-
- l.onLoaded(w)
- return w, nil
-}
-
-// WalletExists returns whether a file exists at the loader's database path.
-// This may return an error for unexpected I/O failures.
-func (l *Loader) WalletExists() (bool, error) {
- if l.localDB {
- dbPath := filepath.Join(l.dbDirPath, WalletDBName)
- return fileExists(dbPath)
- }
-
- return l.walletExists()
-}
-
-// LoadedWallet returns the loaded wallet, if any, and a bool for whether the
-// wallet has been loaded or not. If true, the wallet pointer should be safe to
-// dereference.
-func (l *Loader) LoadedWallet() (*Wallet, bool) {
- l.mu.Lock()
- w := l.wallet
- l.mu.Unlock()
- return w, w != nil
-}
-
-// UnloadWallet stops the loaded wallet, if any, and closes the wallet database.
-// This returns ErrNotLoaded if the wallet has not been loaded with
-// CreateNewWallet or LoadExistingWallet. The Loader may be reused if this
-// function returns without error.
-func (l *Loader) UnloadWallet() error {
- defer l.mu.Unlock()
- l.mu.Lock()
-
- if l.wallet == nil {
- return ErrNotLoaded
- }
-
- l.wallet.Stop()
- l.wallet.WaitForShutdown()
- if l.localDB {
- err := l.db.Close()
- if err != nil {
- return err
- }
- }
-
- l.wallet = nil
- l.db = nil
- return nil
-}
-
-func fileExists(filePath string) (bool, error) {
- _, err := os.Stat(filePath)
- if err != nil {
- if os.IsNotExist(err) {
- return false, nil
- }
- return false, err
- }
- return true, nil
-}
diff --git a/wallet/manager.go b/wallet/manager.go
new file mode 100644
index 0000000000..7264e4ddfd
--- /dev/null
+++ b/wallet/manager.go
@@ -0,0 +1,409 @@
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+)
+
+var (
+ // ErrWalletParams is returned when the creation parameters are invalid.
+ ErrWalletParams = errors.New("invalid wallet params")
+)
+
+// CreateMode determines how a new wallet is initialized.
+type CreateMode uint8
+
+const (
+ // ModeUnknown indicates no specific creation mode.
+ ModeUnknown CreateMode = iota
+
+ // ModeGenSeed indicates creating a new wallet by generating a fresh random
+ // seed.
+ ModeGenSeed
+
+ // ModeImportSeed indicates restoring a wallet from a provided seed
+ // (CreateWalletParams.Seed).
+ ModeImportSeed
+
+ // ModeImportExtKey indicates creating a wallet from an extended key
+ // (CreateWalletParams.RootKey).
+ ModeImportExtKey
+
+ // ModeShell indicates creating an empty wallet shell (no root key).
+ // Intended for importing specific Account XPubs.
+ ModeShell
+)
+
+// WatchOnlyAccount contains the information needed to import a watch-only
+// account.
+type WatchOnlyAccount struct {
+ // Scope is the key scope of the account.
+ Scope waddrmgr.KeyScope
+
+ // XPub is the extended public key for the account.
+ XPub *hdkeychain.ExtendedKey
+
+ // MasterKeyFingerprint is the fingerprint of the master key.
+ MasterKeyFingerprint uint32
+
+ // Name is the name of the account.
+ Name string
+
+ // AddrType is the address type of the account.
+ AddrType waddrmgr.AddressType
+}
+
+// CreateWalletParams holds the parameters required to initialize a new wallet.
+// These are one-time inputs used during the creation process.
+type CreateWalletParams struct {
+ // Mode determines which fields below are required.
+ Mode CreateMode
+
+ // Seed is required for ModeImportSeed. Ignored for others.
+ Seed []byte
+
+ // RootKey is required for ModeImportExtKey. Ignored for others. Can be XPrv
+ // or XPub.
+ RootKey *hdkeychain.ExtendedKey
+
+ // InitialAccounts is optional for ModeShell. Reserved for future use and
+ // currently has no effect during wallet creation.
+ InitialAccounts []WatchOnlyAccount
+
+ // WatchOnly controls whether the resulting wallet is watch-only.
+ // - If true with Seed/XPrv input: Derives Master XPub, then discards
+ // the private material.
+ // - If true with XPub/Shell input: No-op (already watch-only).
+ WatchOnly bool
+
+ // Birthday is the wallet's birthday.
+ Birthday time.Time
+
+ // PubPassphrase is the public passphrase for the wallet.
+ PubPassphrase []byte
+
+ // PrivatePassphrase is the private passphrase for the wallet.
+ PrivatePassphrase []byte
+}
+
+// validate ensures that the parameters are consistent with the chosen creation
+// mode.
+//
+// We skip cyclop because this method performs exhaustive validation of
+// mutually exclusive fields across all creation modes.
+//
+//nolint:cyclop
+func (p *CreateWalletParams) validate() error {
+ if p.Mode == ModeUnknown {
+ return fmt.Errorf("%w: unknown mode", ErrWalletParams)
+ }
+
+ // InitialAccounts should only be set for ModeShell.
+ if p.Mode != ModeShell && len(p.InitialAccounts) > 0 {
+ return fmt.Errorf("%w: initial accounts should only be set "+
+ "for ModeShell", ErrWalletParams)
+ }
+
+ if p.Mode == ModeGenSeed {
+ if len(p.Seed) != 0 {
+ return fmt.Errorf("%w: seed should not be set for "+
+ "ModeGenSeed", ErrWalletParams)
+ }
+
+ if p.RootKey != nil {
+ return fmt.Errorf("%w: root key should not be set for "+
+ "ModeGenSeed", ErrWalletParams)
+ }
+ }
+
+ if p.Mode == ModeImportSeed {
+ if len(p.Seed) == 0 {
+ return fmt.Errorf("%w: seed is required for "+
+ "ModeImportSeed", ErrWalletParams)
+ }
+
+ if p.RootKey != nil {
+ return fmt.Errorf("%w: root key should not be set for "+
+ "ModeImportSeed", ErrWalletParams)
+ }
+ }
+
+ if p.Mode == ModeImportExtKey {
+ if p.RootKey == nil {
+ return fmt.Errorf("%w: root key is required for "+
+ "ModeImportExtKey", ErrWalletParams)
+ }
+
+ if len(p.Seed) != 0 {
+ return fmt.Errorf("%w: seed should not be set for "+
+ "ModeImportExtKey", ErrWalletParams)
+ }
+ }
+
+ if p.Mode == ModeShell {
+ if len(p.Seed) != 0 {
+ return fmt.Errorf("%w: seed should not be set for "+
+ "ModeShell", ErrWalletParams)
+ }
+
+ if p.RootKey != nil {
+ return fmt.Errorf("%w: root key should not be set for "+
+ "ModeShell", ErrWalletParams)
+ }
+ }
+
+ return nil
+}
+
+// Manager is a high-level manager that handles the lifecycle of multiple
+// wallets. It acts as a factory for creating and loading wallets, and can
+// optionally track the active wallets.
+//
+// The Manager enables a one-to-many relationship, allowing a single application
+// to manage multiple distinct wallets (e.g., for different coins or different
+// accounts) simultaneously.
+type Manager struct {
+ sync.RWMutex
+
+ // wallets holds the active wallets keyed by their unique name.
+ wallets map[string]*Wallet
+}
+
+// NewManager creates a new Wallet Manager.
+func NewManager() *Manager {
+ return &Manager{
+ wallets: make(map[string]*Wallet),
+ }
+}
+
+// String returns a summary of the active wallets managed by the Manager.
+func (m *Manager) String() string {
+ m.RLock()
+ defer m.RUnlock()
+
+ names := make([]string, 0, len(m.wallets))
+ for name := range m.wallets {
+ names = append(names, name)
+ }
+
+ sort.Strings(names)
+
+ return fmt.Sprintf("active_wallets=%v", names)
+}
+
+// Create creates a new wallet based on the provided configuration and
+// initialization parameters. It initializes the database structure and then
+// loads the wallet.
+func (m *Manager) Create(cfg Config,
+ params CreateWalletParams) (*Wallet, error) {
+
+ rootKey, err := m.prepareWalletCreation(cfg, params)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the underlying database structure.
+ err = DBCreateWallet(cfg, params, rootKey)
+ if err != nil {
+ return nil, err
+ }
+
+ // Load the newly created wallet.
+ w, err := m.Load(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ // If we are in shell mode and have initial accounts, we import them now.
+ if params.Mode == ModeShell && len(params.InitialAccounts) > 0 {
+ err = w.importInitialAccounts(
+ context.Background(), params.InitialAccounts,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return w, nil
+}
+
+// importInitialAccounts imports a list of watch-only accounts into the wallet.
+// This is typically used during wallet initialization in shell mode.
+func (w *Wallet) importInitialAccounts(ctx context.Context,
+ accounts []WatchOnlyAccount) error {
+
+ for _, account := range accounts {
+ _, err := w.importAccountInternal(
+ ctx, account.Name, account.XPub, account.MasterKeyFingerprint,
+ account.AddrType, false,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to import account %s: %w", account.Name,
+ err)
+ }
+ }
+
+ return nil
+}
+
+// Load loads an existing wallet from the provided configuration. It opens the
+// database, initializes the wallet structure, and registers it with the manager
+// for tracking.
+func (m *Manager) Load(cfg Config) (*Wallet, error) {
+ err := cfg.validate()
+ if err != nil {
+ return nil, err
+ }
+
+ // Check if the wallet is already loaded.
+ m.RLock()
+ existingW, ok := m.wallets[cfg.Name]
+ m.RUnlock()
+
+ if ok {
+ return existingW, nil
+ }
+
+ addrMgr, txMgr, err := DBLoadWallet(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ // Apply the safe default for auto-lock duration if not specified.
+ if cfg.AutoLockDuration == 0 {
+ cfg.AutoLockDuration = defaultLockDuration
+ }
+
+ // Initialize the auto-lock timer in a stopped state. We perform a
+ // non-blocking drain on the channel to ensure it's empty and won't fire
+ // immediately.
+ lockTimer := time.NewTimer(0)
+ if !lockTimer.Stop() {
+ <-lockTimer.C
+ }
+
+ lifetimeCtx, cancel := context.WithCancel(context.Background())
+
+ w := &Wallet{
+ cfg: cfg,
+ addrStore: addrMgr,
+ txStore: txMgr,
+ requestChan: make(chan any),
+ lifetimeCtx: lifetimeCtx,
+ cancel: cancel,
+ lockTimer: lockTimer,
+ }
+
+ w.sync = newSyncer(cfg, w.addrStore, w.txStore, w)
+ w.state = newWalletState(w.sync)
+
+ // Register the wallet.
+ m.Lock()
+ m.wallets[cfg.Name] = w
+ m.Unlock()
+
+ return w, nil
+}
+
+// prepareWalletCreation validates the configuration and parameters, and derives
+// the root key for wallet creation.
+func (m *Manager) prepareWalletCreation(cfg Config,
+ params CreateWalletParams) (*hdkeychain.ExtendedKey, error) {
+
+ err := cfg.validate()
+ if err != nil {
+ return nil, err
+ }
+
+ err = params.validate()
+ if err != nil {
+ return nil, err
+ }
+
+ rootKey, err := m.deriveRootKey(cfg, params)
+ if err != nil {
+ return nil, err
+ }
+
+ // If the wallet is NOT watch-only, we require a private root key to be able
+ // to sign transactions and derive child private keys.
+ if !params.WatchOnly && rootKey != nil && !rootKey.IsPrivate() {
+ return nil, fmt.Errorf("%w: private key required for "+
+ "non-watch-only wallet", ErrWalletParams)
+ }
+
+ return rootKey, nil
+}
+
+// deriveRootKey resolves the master extended key based on the creation mode.
+func (m *Manager) deriveRootKey(cfg Config,
+ params CreateWalletParams) (*hdkeychain.ExtendedKey, error) {
+
+ switch params.Mode {
+ case ModeGenSeed:
+ return m.genRootKey(cfg)
+
+ case ModeImportSeed:
+ return m.deriveFromSeed(cfg, params.Seed)
+
+ case ModeImportExtKey:
+ // Ensure an extended key was provided.
+ if params.RootKey == nil {
+ return nil, fmt.Errorf("%w: root key is required",
+ ErrWalletParams)
+ }
+
+ // Use the provided extended key (can be XPrv or XPub).
+ return params.RootKey, nil
+
+ case ModeShell:
+ // In shell mode, no root key is persisted. Accounts will be
+ // imported individually.
+ return nil, nil //nolint:nilnil
+
+ case ModeUnknown:
+ fallthrough
+
+ default:
+ return nil, fmt.Errorf("%w: unknown mode %v", ErrWalletParams,
+ params.Mode)
+ }
+}
+
+// genRootKey generates a fresh random seed and derives the master extended
+// private key from it.
+func (m *Manager) genRootKey(cfg Config) (*hdkeychain.ExtendedKey, error) {
+ // Generate a fresh random seed using the recommended length.
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen)
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate seed: %w", err)
+ }
+
+ return m.deriveFromSeed(cfg, seed)
+}
+
+// deriveFromSeed derives the master extended private key from the provided
+// seed.
+func (m *Manager) deriveFromSeed(cfg Config, seed []byte) (
+ *hdkeychain.ExtendedKey, error) {
+
+ // Ensure a seed was provided for restoration.
+ if len(seed) == 0 {
+ return nil, fmt.Errorf("%w: seed is required", ErrWalletParams)
+ }
+
+ // Derive the master extended private key from the provided seed.
+ key, err := hdkeychain.NewMaster(seed, cfg.ChainParams)
+ if err != nil {
+ return nil, fmt.Errorf("failed to derive master key: %w", err)
+ }
+
+ return key, nil
+}
diff --git a/wallet/manager_test.go b/wallet/manager_test.go
new file mode 100644
index 0000000000..3ae2131039
--- /dev/null
+++ b/wallet/manager_test.go
@@ -0,0 +1,600 @@
+package wallet
+
+import (
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/stretchr/testify/require"
+)
+
+// TestManagerCreateSuccess verifies that a wallet can be successfully created
+// in various modes. It checks that the Manager correctly initializes the
+// wallet structure and registers it for tracking.
+func TestManagerCreateSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Pre-calculate common setup values to be used in multiple test cases.
+ // This ensures we have valid cryptographic material ready for import
+ // scenarios.
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen)
+ require.NoError(t, err)
+
+ rootKey, err := hdkeychain.NewMaster(seed, &chainParams)
+ require.NoError(t, err)
+
+ // Create an account XPub for ModeShell testing.
+ // Derive account key: m/44'/0'/0'
+ acctKey, err := rootKey.Derive(hdkeychain.HardenedKeyStart + 44)
+ require.NoError(t, err)
+ acctKey, err = acctKey.Derive(hdkeychain.HardenedKeyStart + 0)
+ require.NoError(t, err)
+ acctKey, err = acctKey.Derive(hdkeychain.HardenedKeyStart + 0)
+ require.NoError(t, err)
+ acctXPub, err := acctKey.Neuter()
+ require.NoError(t, err)
+
+ // Arrange: Define test cases for different creation modes.
+ tests := []struct {
+ name string
+ params CreateWalletParams
+ }{
+
+ {
+ name: "ModeGenSeed",
+ params: CreateWalletParams{
+ Mode: ModeGenSeed,
+ PubPassphrase: []byte("public"),
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ },
+ },
+ {
+ name: "ModeImportSeed",
+ params: CreateWalletParams{
+ Mode: ModeImportSeed,
+ Seed: seed,
+ PubPassphrase: []byte("public"),
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ },
+ },
+ {
+ name: "ModeImportExtKey",
+ params: CreateWalletParams{
+ Mode: ModeImportExtKey,
+ RootKey: rootKey,
+ PubPassphrase: []byte("public"),
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ },
+ },
+ {
+ name: "ModeShell",
+ params: CreateWalletParams{
+ Mode: ModeShell,
+ InitialAccounts: []WatchOnlyAccount{{
+ Scope: waddrmgr.KeyScopeBIP0049Plus,
+ XPub: acctXPub,
+ MasterKeyFingerprint: 0,
+ Name: "test-shell-account",
+ AddrType: waddrmgr.NestedWitnessPubKey,
+ }},
+ WatchOnly: true,
+ PubPassphrase: []byte("public"),
+ Birthday: time.Now(),
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Create a fresh test database for this run. We use setupTestDB
+ // which ensures we have a clean slate (empty buckets) to verify
+ // that Create correctly initializes the schema.
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ m := NewManager()
+ cfg := Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ PubPassphrase: []byte("public"),
+ RecoveryWindow: MinRecoveryWindow,
+ }
+
+ // Attempt to create the wallet with the specified parameters.
+ w, err := m.Create(cfg, tc.params)
+
+ // Verify that the wallet was created successfully and returned
+ // without error.
+ require.NoError(t, err)
+ require.NotNil(t, w)
+
+ // Verify internal state: Ensure the manager is tracking the
+ // newly created wallet in its internal map, keyed by the
+ // configuration name.
+ m.RLock()
+ loadedW, ok := m.wallets["test-wallet"]
+ m.RUnlock()
+ require.True(t, ok)
+ require.Same(t, w, loadedW)
+
+ // If ModeShell, verify account was imported.
+ if tc.params.Mode == ModeShell {
+ // We can't use w.GetAccount here because the wallet is not
+ // started. We'll verify directly against the address manager.
+ err := walletdb.View(db, func(tx walletdb.ReadTx) error {
+ ns := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ scopeMgr, err := w.addrStore.FetchScopedKeyManager(
+ tc.params.InitialAccounts[0].Scope,
+ )
+ if err != nil {
+ return err
+ }
+
+ _, err = scopeMgr.LookupAccount(
+ ns, tc.params.InitialAccounts[0].Name,
+ )
+
+ return err
+ })
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
+// TestManagerCreateError verifies that wallet creation fails when invalid
+// parameters are provided. This ensures that the Manager correctly validates
+// inputs before attempting to modify the database.
+func TestManagerCreateError(t *testing.T) {
+ t.Parallel()
+
+ // Pre-calculate cryptographic material to construct specific test
+ // scenarios.
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen)
+ require.NoError(t, err)
+
+ rootKey, err := hdkeychain.NewMaster(seed, &chainParams)
+ require.NoError(t, err)
+
+ pubKey, err := rootKey.Neuter()
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ params CreateWalletParams
+ expectedErr string
+ }{
+ {
+ name: "ModeImportSeed missing seed",
+ params: CreateWalletParams{
+ Mode: ModeImportSeed,
+ Seed: nil,
+ },
+ expectedErr: "seed is required",
+ },
+ {
+ name: "ModeImportExtKey missing key",
+ params: CreateWalletParams{
+ Mode: ModeImportExtKey,
+ RootKey: nil,
+ },
+ expectedErr: "root key is required",
+ },
+ {
+ name: "Public Key for Non-WatchOnly",
+ params: CreateWalletParams{
+ Mode: ModeImportExtKey,
+ RootKey: pubKey,
+ WatchOnly: false,
+ },
+ expectedErr: "private key required",
+ },
+ {
+ name: "Unknown Mode",
+ params: CreateWalletParams{
+ Mode: ModeUnknown,
+ },
+ expectedErr: "unknown mode",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ m := NewManager()
+ cfg := Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ RecoveryWindow: MinRecoveryWindow,
+ }
+
+ // Attempt to create the wallet. We expect this to fail due to
+ // the invalid parameters configured in the test case.
+ _, err := m.Create(cfg, tc.params)
+
+ // Verify that the error matches our expectation.
+ require.Error(t, err)
+ require.ErrorContains(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// TestCreateWalletParams_Validate verifies that the validate method enforces
+// the correct constraints for each creation mode.
+func TestCreateWalletParams_Validate(t *testing.T) {
+ t.Parallel()
+
+ // Pre-calculate cryptographic material to construct specific test
+ // scenarios.
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen)
+ require.NoError(t, err)
+
+ rootKey, err := hdkeychain.NewMaster(seed, &chainParams)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ params CreateWalletParams
+ expectedErr string
+ }{
+ {
+ name: "ModeGenSeed with Seed",
+ params: CreateWalletParams{
+ Mode: ModeGenSeed,
+ Seed: seed,
+ },
+ expectedErr: "seed should not be set for ModeGenSeed",
+ },
+ {
+ name: "ModeGenSeed with RootKey",
+ params: CreateWalletParams{
+ Mode: ModeGenSeed,
+ RootKey: rootKey,
+ },
+ expectedErr: "root key should not be set for ModeGenSeed",
+ },
+ {
+ name: "ModeImportSeed with RootKey",
+ params: CreateWalletParams{
+ Mode: ModeImportSeed,
+ Seed: seed,
+ RootKey: rootKey,
+ },
+ expectedErr: "root key should not be set for ModeImportSeed",
+ },
+ {
+ name: "ModeImportExtKey with Seed",
+ params: CreateWalletParams{
+ Mode: ModeImportExtKey,
+ RootKey: rootKey,
+ Seed: seed,
+ },
+ expectedErr: "seed should not be set for ModeImportExtKey",
+ },
+ {
+ name: "ModeShell with Seed",
+ params: CreateWalletParams{
+ Mode: ModeShell,
+ Seed: seed,
+ },
+ expectedErr: "seed should not be set for ModeShell",
+ },
+ {
+ name: "Unknown Mode",
+ params: CreateWalletParams{
+ Mode: ModeUnknown,
+ },
+ expectedErr: "unknown mode",
+ },
+ {
+ name: "InitialAccounts with ModeGenSeed",
+ params: CreateWalletParams{
+ Mode: ModeGenSeed,
+ InitialAccounts: []WatchOnlyAccount{{
+ Name: "test",
+ }},
+ },
+ expectedErr: "initial accounts should only " +
+ "be set for ModeShell",
+ }}
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := tc.params.validate()
+ require.Error(t, err)
+ require.ErrorContains(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// TestManagerCreate_InvalidConfig verifies that the Create method performs
+// configuration validation before proceeding with any operations.
+func TestManagerCreate_InvalidConfig(t *testing.T) {
+ t.Parallel()
+
+ m := NewManager()
+
+ // Call Create with an empty Config struct. This should fail because
+ // required fields like DB and ChainParams are missing.
+ w, err := m.Create(Config{}, CreateWalletParams{})
+
+ require.ErrorIs(t, err, ErrMissingParam)
+ require.ErrorContains(t, err, "DB")
+ require.Nil(t, w)
+}
+
+// TestManagerLoadSuccess verifies that an existing wallet can be successfully
+// loaded from the database. This tests the persistence and restoration flow.
+func TestManagerLoadSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Initialize a database and create a wallet to serve as our existing
+ // state.
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ m := NewManager()
+ cfg := Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ PubPassphrase: []byte("public"),
+ RecoveryWindow: MinRecoveryWindow,
+ }
+ params := CreateWalletParams{
+ Mode: ModeGenSeed,
+ PubPassphrase: []byte("public"),
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ }
+
+ wCreated, err := m.Create(cfg, params)
+ require.NoError(t, err)
+ require.NotNil(t, wCreated)
+
+ // Create a new Manager instance to simulate a fresh start (e.g., daemon
+ // restart) and attempt to load the wallet from the existing database.
+ m2 := NewManager()
+ w, err := m2.Load(cfg)
+
+ // Verify that the load operation succeeded and returned a valid wallet.
+ require.NoError(t, err)
+ require.NotNil(t, w)
+
+ // Ensure the loaded wallet is correctly registered in the new manager.
+ m2.RLock()
+ loadedW, ok := m2.wallets["test-wallet"]
+ m2.RUnlock()
+ require.True(t, ok)
+ require.Same(t, w, loadedW)
+}
+
+// TestManagerLoad_ExistingWallet verifies that if Load is called for a wallet
+// that is already managed in memory, the Manager detects this.
+func TestManagerLoad_ExistingWallet(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ m := NewManager()
+ cfg := Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ PubPassphrase: []byte("public"),
+ RecoveryWindow: MinRecoveryWindow,
+ }
+ params := CreateWalletParams{
+ Mode: ModeGenSeed,
+ PubPassphrase: []byte("public"),
+ PrivatePassphrase: []byte("private"),
+ Birthday: time.Now(),
+ }
+
+ wCreated, err := m.Create(cfg, params)
+ require.NoError(t, err)
+
+ // Attempt to load the same wallet again using the same manager instance.
+ // Since it's already loaded in memory, the manager should return the
+ // existing instance rather than reloading from disk.
+ wLoaded, err := m.Load(cfg)
+
+ // Verify that we got the same wallet instance back.
+ require.NoError(t, err)
+ require.Same(t, wCreated, wLoaded)
+}
+
+// TestManagerLoadError verifies that Load properly handles invalid
+// configurations and corrupted or uninitialized databases.
+func TestManagerLoadError(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Invalid Config", func(t *testing.T) {
+ t.Parallel()
+
+ m := NewManager()
+
+ // Attempt to load with an empty config. This should fail validation.
+ w, err := m.Load(Config{})
+ require.ErrorContains(t, err, "missing config parameter")
+ require.Nil(t, w)
+ })
+
+ t.Run("Uninitialized DB", func(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ m := NewManager()
+ cfg := Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test",
+ }
+
+ // Attempt to load from a database that has valid buckets but no
+ // wallet data (waddrmgr is not initialized). This should fail
+ // at the database loading step.
+ w, err := m.Load(cfg)
+
+ // We expect an error from waddrmgr.Open indicating the address
+ // manager namespace is missing or invalid.
+ require.Error(t, err)
+ require.Nil(t, w)
+ })
+}
+
+// TestManagerString verifies that the String representation of the Manager
+// correctly lists the tracked wallets in alphabetical order.
+func TestManagerString(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ setup func(*Manager)
+ expected string
+ }{
+ {
+ name: "empty",
+ setup: func(m *Manager) {},
+ expected: "active_wallets=[]",
+ },
+ {
+ name: "multiple sorted",
+ setup: func(m *Manager) {
+ m.wallets["wallet-b"] = &Wallet{}
+ m.wallets["wallet-a"] = &Wallet{}
+ },
+ expected: "active_wallets=[wallet-a wallet-b]",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ m := NewManager()
+ tc.setup(m)
+ require.Equal(t, tc.expected, m.String())
+ })
+ }
+}
+
+// TestManager_deriveFromSeed verifies the internal helper method
+// deriveFromSeed, checking that it correctly derives a master private key
+// from a seed and validates inputs.
+func TestManager_deriveFromSeed(t *testing.T) {
+ t.Parallel()
+
+ m := NewManager()
+ cfg := Config{ChainParams: &chainParams}
+
+ t.Run("Success", func(t *testing.T) {
+ t.Parallel()
+
+ seed, err := hdkeychain.GenerateSeed(hdkeychain.RecommendedSeedLen)
+ require.NoError(t, err)
+
+ key, err := m.deriveFromSeed(cfg, seed)
+
+ // Verify we got a valid private extended key.
+ require.NoError(t, err)
+ require.NotNil(t, key)
+ require.True(t, key.IsPrivate())
+ })
+
+ t.Run("Empty Seed", func(t *testing.T) {
+ t.Parallel()
+
+ key, err := m.deriveFromSeed(cfg, nil)
+ require.ErrorIs(t, err, ErrWalletParams)
+ require.ErrorContains(t, err, "seed is required")
+ require.Nil(t, key)
+ })
+
+ t.Run("Invalid Seed Length", func(t *testing.T) {
+ t.Parallel()
+
+ // Providing a seed that is too short for hdkeychain.NewMaster.
+ key, err := m.deriveFromSeed(cfg, []byte{0x01})
+ require.ErrorContains(t, err, "failed to derive master key")
+ require.Nil(t, key)
+ })
+}
+
+// TestManager_genRootKey verifies the internal helper method genRootKey,
+// ensuring it generates a random seed and derives a valid master key.
+func TestManager_genRootKey(t *testing.T) {
+ t.Parallel()
+
+ m := NewManager()
+ cfg := Config{ChainParams: &chainParams}
+
+ key, err := m.genRootKey(cfg)
+
+ // Verify we got a valid private extended key.
+ require.NoError(t, err)
+ require.NotNil(t, key)
+ require.True(t, key.IsPrivate())
+}
+
+// TestManager_deriveRootKey verifies the high-level key derivation logic,
+// checking that it correctly dispatches to the appropriate helper based on
+// the creation mode.
+func TestManager_deriveRootKey(t *testing.T) {
+ t.Parallel()
+
+ m := NewManager()
+ cfg := Config{ChainParams: &chainParams}
+
+ // 1. ModeShell: Should return nil/nil (no root key for shell).
+ t.Run("ModeShell", func(t *testing.T) {
+ t.Parallel()
+
+ key, err := m.deriveRootKey(cfg, CreateWalletParams{Mode: ModeShell})
+ require.NoError(t, err)
+ require.Nil(t, key)
+ })
+
+ t.Run("ModeUnknown", func(t *testing.T) {
+ t.Parallel()
+
+ key, err := m.deriveRootKey(cfg, CreateWalletParams{Mode: ModeUnknown})
+ require.ErrorIs(t, err, ErrWalletParams)
+ require.ErrorContains(t, err, "unknown mode")
+ require.Nil(t, key)
+ })
+
+ // 3. ModeGenSeed: Should return a newly generated private key.
+ t.Run("ModeGenSeed", func(t *testing.T) {
+ t.Parallel()
+
+ key, err := m.deriveRootKey(cfg, CreateWalletParams{Mode: ModeGenSeed})
+ require.NoError(t, err)
+ require.NotNil(t, key)
+ require.True(t, key.IsPrivate())
+ })
+}
diff --git a/wallet/mock.go b/wallet/mock.go
deleted file mode 100644
index ac7d1eb760..0000000000
--- a/wallet/mock.go
+++ /dev/null
@@ -1,115 +0,0 @@
-package wallet
-
-import (
- "context"
- "time"
-
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/btcjson"
- "github.com/btcsuite/btcd/chainhash/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/chain"
- "github.com/btcsuite/btcwallet/waddrmgr"
-)
-
-type mockChainClient struct {
- getBestBlockHeight int32
- getBlockHashFunc func() (*chainhash.Hash, error)
- getBlockHeader *wire.BlockHeader
-}
-
-var _ chain.Interface = (*mockChainClient)(nil)
-
-func (m *mockChainClient) Start(_ context.Context) error {
- return nil
-}
-
-func (m *mockChainClient) Stop() {
-}
-
-func (m *mockChainClient) WaitForShutdown() {}
-
-func (m *mockChainClient) GetBestBlock() (*chainhash.Hash, int32, error) {
- return nil, m.getBestBlockHeight, nil
-}
-
-func (m *mockChainClient) GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) {
- return nil, nil
-}
-
-func (m *mockChainClient) GetBlockHash(int64) (*chainhash.Hash, error) {
- if m.getBlockHashFunc != nil {
- return m.getBlockHashFunc()
- }
- return nil, nil
-}
-
-func (m *mockChainClient) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader,
- error) {
- return m.getBlockHeader, nil
-}
-
-func (m *mockChainClient) IsCurrent() bool {
- return false
-}
-
-func (m *mockChainClient) FilterBlocks(*chain.FilterBlocksRequest) (
- *chain.FilterBlocksResponse, error) {
- return nil, nil
-}
-
-func (m *mockChainClient) BlockStamp() (*waddrmgr.BlockStamp, error) {
- return &waddrmgr.BlockStamp{
- Height: 500000,
- Hash: chainhash.Hash{},
- Timestamp: time.Unix(1234, 0),
- }, nil
-}
-
-func (m *mockChainClient) SendRawTransaction(*wire.MsgTx, bool) (
- *chainhash.Hash, error) {
- return nil, nil
-}
-
-func (m *mockChainClient) Rescan(*chainhash.Hash, []address.Address,
- map[wire.OutPoint]address.Address) error {
-
- return nil
-}
-
-func (m *mockChainClient) NotifyReceived([]address.Address) error {
- return nil
-}
-
-func (m *mockChainClient) NotifyBlocks() error {
- return nil
-}
-
-func (m *mockChainClient) Notifications() <-chan interface{} {
- return nil
-}
-
-func (m *mockChainClient) BackEnd() string {
- return "mock"
-}
-
-// TestMempoolAcceptCmd returns result of mempool acceptance tests indicating
-// if raw transaction(s) would be accepted by mempool.
-//
-// NOTE: This is part of the chain.Interface interface.
-func (m *mockChainClient) TestMempoolAccept(txns []*wire.MsgTx,
- maxFeeRate float64) ([]*btcjson.TestMempoolAcceptResult, error) {
-
- return nil, nil
-}
-
-// SubmitPackage is part of the chain.Interface interface.
-func (m *mockChainClient) SubmitPackage(txns []*wire.MsgTx,
- maxFeeRate *float64) (*btcjson.SubmitPackageResult, error) {
-
- return &btcjson.SubmitPackageResult{}, nil
-}
-
-func (m *mockChainClient) MapRPCErr(err error) error {
- return nil
-}
diff --git a/wallet/mock_test.go b/wallet/mock_test.go
new file mode 100644
index 0000000000..7643c1e32d
--- /dev/null
+++ b/wallet/mock_test.go
@@ -0,0 +1,1527 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// This file contains a mock implementation of the wtxmgr.TxStore interface.
+// It is used in various tests to isolate wallet logic from the underlying
+// database.
+
+package wallet
+
+import (
+ "context"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcjson"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/lightninglabs/neutrino"
+ "github.com/lightninglabs/neutrino/banman"
+ "github.com/lightninglabs/neutrino/headerfs"
+ "github.com/stretchr/testify/mock"
+)
+
+// mockTxStore is a mock implementation of the wtxmgr.TxStore interface.
+type mockTxStore struct {
+ mock.Mock
+}
+
+// A compile-time assertion to ensure that mockTxStore implements the TxStore
+// interface.
+var _ wtxmgr.TxStore = (*mockTxStore)(nil)
+
+// Balance implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) Balance(ns walletdb.ReadBucket, minConf int32,
+ syncHeight int32) (btcutil.Amount, error) {
+
+ args := m.Called(ns, minConf, syncHeight)
+ if args.Get(0) == nil {
+ return btcutil.Amount(0), args.Error(1)
+ }
+
+ return args.Get(0).(btcutil.Amount), args.Error(1)
+}
+
+// DeleteExpiredLockedOutputs implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) DeleteExpiredLockedOutputs(
+ ns walletdb.ReadWriteBucket) error {
+
+ args := m.Called(ns)
+ return args.Error(0)
+}
+
+// InsertTx implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) InsertTx(ns walletdb.ReadWriteBucket,
+ rec *wtxmgr.TxRecord, block *wtxmgr.BlockMeta) error {
+
+ args := m.Called(ns, rec, block)
+ return args.Error(0)
+}
+
+// InsertTxCheckIfExists implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) InsertTxCheckIfExists(ns walletdb.ReadWriteBucket,
+ rec *wtxmgr.TxRecord, block *wtxmgr.BlockMeta) (bool, error) {
+
+ args := m.Called(ns, rec, block)
+ return args.Bool(0), args.Error(1)
+}
+
+// InsertConfirmedTx implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) InsertConfirmedTx(ns walletdb.ReadWriteBucket,
+ rec *wtxmgr.TxRecord, block *wtxmgr.BlockMeta,
+ credits []wtxmgr.CreditEntry) error {
+
+ args := m.Called(ns, rec, block, credits)
+ return args.Error(0)
+}
+
+// InsertUnconfirmedTx implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) InsertUnconfirmedTx(ns walletdb.ReadWriteBucket,
+ rec *wtxmgr.TxRecord, credits []wtxmgr.CreditEntry) error {
+
+ args := m.Called(ns, rec, credits)
+ return args.Error(0)
+}
+
+// AddCredit implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) AddCredit(ns walletdb.ReadWriteBucket,
+ rec *wtxmgr.TxRecord, block *wtxmgr.BlockMeta, index uint32,
+ change bool) error {
+
+ args := m.Called(ns, rec, block, index, change)
+ return args.Error(0)
+}
+
+// ListLockedOutputs implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) ListLockedOutputs(
+ ns walletdb.ReadBucket) ([]*wtxmgr.LockedOutput, error) {
+
+ args := m.Called(ns)
+ return args.Get(0).([]*wtxmgr.LockedOutput), args.Error(1)
+}
+
+// LockOutput implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) LockOutput(ns walletdb.ReadWriteBucket, id wtxmgr.LockID,
+ op wire.OutPoint, duration time.Duration) (time.Time, error) {
+
+ args := m.Called(ns, id, op, duration)
+ if args.Get(0) == nil {
+ return time.Time{}, args.Error(1)
+ }
+
+ return args.Get(0).(time.Time), args.Error(1)
+}
+
+// OutputsToWatch implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) OutputsToWatch(
+ ns walletdb.ReadBucket) ([]wtxmgr.Credit, error) {
+
+ args := m.Called(ns)
+ return args.Get(0).([]wtxmgr.Credit), args.Error(1)
+}
+
+// PutTxLabel implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) PutTxLabel(ns walletdb.ReadWriteBucket,
+ txid chainhash.Hash, label string) error {
+
+ args := m.Called(ns, txid, label)
+ return args.Error(0)
+}
+
+// RangeTransactions implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) RangeTransactions(ns walletdb.ReadBucket, begin,
+ end int32, f func([]wtxmgr.TxDetails) (bool, error)) error {
+
+ args := m.Called(ns, begin, end, f)
+ return args.Error(0)
+}
+
+// Rollback implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) Rollback(
+ ns walletdb.ReadWriteBucket, height int32) error {
+
+ args := m.Called(ns, height)
+ return args.Error(0)
+}
+
+// TxDetails implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) TxDetails(ns walletdb.ReadBucket,
+ txHash *chainhash.Hash) (*wtxmgr.TxDetails, error) {
+
+ args := m.Called(ns, txHash)
+ details, _ := args.Get(0).(*wtxmgr.TxDetails)
+
+ return details, args.Error(1)
+}
+
+// UniqueTxDetails implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) UniqueTxDetails(ns walletdb.ReadBucket,
+ txHash *chainhash.Hash,
+ block *wtxmgr.Block) (*wtxmgr.TxDetails, error) {
+
+ args := m.Called(ns, txHash, block)
+ details, _ := args.Get(0).(*wtxmgr.TxDetails)
+
+ return details, args.Error(1)
+}
+
+// UnlockOutput implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) UnlockOutput(ns walletdb.ReadWriteBucket,
+ id wtxmgr.LockID, op wire.OutPoint) error {
+
+ args := m.Called(ns, id, op)
+ return args.Error(0)
+}
+
+// UnspentOutputs implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) UnspentOutputs(
+ ns walletdb.ReadBucket) ([]wtxmgr.Credit, error) {
+
+ args := m.Called(ns)
+
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).([]wtxmgr.Credit), args.Error(1)
+}
+
+// GetUtxo implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) GetUtxo(ns walletdb.ReadBucket,
+ outpoint wire.OutPoint) (*wtxmgr.Credit, error) {
+
+ args := m.Called(ns, outpoint)
+ credit, _ := args.Get(0).(*wtxmgr.Credit)
+
+ return credit, args.Error(1)
+}
+
+// FetchTxLabel implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) FetchTxLabel(ns walletdb.ReadBucket,
+ txid chainhash.Hash) (string, error) {
+
+ args := m.Called(ns, txid)
+ return args.String(0), args.Error(1)
+}
+
+// UnminedTxs implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) UnminedTxs(
+ ns walletdb.ReadBucket) ([]*wire.MsgTx, error) {
+
+ args := m.Called(ns)
+ return args.Get(0).([]*wire.MsgTx), args.Error(1)
+}
+
+// UnminedTxHashes implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) UnminedTxHashes(
+ ns walletdb.ReadBucket) ([]*chainhash.Hash, error) {
+
+ args := m.Called(ns)
+ return args.Get(0).([]*chainhash.Hash), args.Error(1)
+}
+
+// RemoveUnminedTx implements the wtxmgr.TxStore interface.
+func (m *mockTxStore) RemoveUnminedTx(ns walletdb.ReadWriteBucket,
+ rec *wtxmgr.TxRecord) error {
+
+ args := m.Called(ns, rec)
+ return args.Error(0)
+}
+
+// mockAddrStore is a mock implementation of the waddrmgr.AddrStore interface.
+type mockAddrStore struct {
+ mock.Mock
+}
+
+// Birthday returns the birthday of the address store.
+func (m *mockAddrStore) Birthday() time.Time {
+ args := m.Called()
+ return args.Get(0).(time.Time)
+}
+
+// SetSyncedTo marks the address manager to be in sync with the
+// recently-seen block described by the blockstamp.
+func (m *mockAddrStore) SetSyncedTo(ns walletdb.ReadWriteBucket,
+ bs *waddrmgr.BlockStamp) error {
+
+ args := m.Called(ns, bs)
+ return args.Error(0)
+}
+
+// SetBirthdayBlock sets the birthday block, or earliest time a key could
+// have been used, for the manager.
+func (m *mockAddrStore) SetBirthdayBlock(ns walletdb.ReadWriteBucket,
+ block waddrmgr.BlockStamp, verified bool) error {
+
+ args := m.Called(ns, block, verified)
+ return args.Error(0)
+}
+
+// SyncedTo returns details about the block height and hash that the
+// address manager is synced through at the very least.
+func (m *mockAddrStore) SyncedTo() waddrmgr.BlockStamp {
+ args := m.Called()
+ return args.Get(0).(waddrmgr.BlockStamp)
+}
+
+// BlockHash returns the block hash at a particular block height.
+func (m *mockAddrStore) BlockHash(ns walletdb.ReadBucket,
+ height int32) (*chainhash.Hash, error) {
+
+ args := m.Called(ns, height)
+ return args.Get(0).(*chainhash.Hash), args.Error(1)
+}
+
+// ActiveScopedKeyManagers returns a slice of all the active scoped key
+// managers currently known by the root key manager.
+func (m *mockAddrStore) ActiveScopedKeyManagers() []waddrmgr.AccountStore {
+ args := m.Called()
+ return args.Get(0).([]waddrmgr.AccountStore)
+}
+
+// FetchScopedKeyManager attempts to fetch an active scoped manager
+// according to its registered scope.
+func (m *mockAddrStore) FetchScopedKeyManager(
+ scope waddrmgr.KeyScope) (waddrmgr.AccountStore, error) {
+
+ args := m.Called(scope)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(waddrmgr.AccountStore), args.Error(1)
+}
+
+// Address returns a managed address given the passed address if it is
+// known to the address manager.
+func (m *mockAddrStore) Address(ns walletdb.ReadBucket,
+ address address.Address) (waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, address)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// AddrAccount returns the account to which the given address belongs.
+func (m *mockAddrStore) AddrAccount(ns walletdb.ReadBucket,
+ address address.Address) (waddrmgr.AccountStore, uint32, error) {
+
+ args := m.Called(ns, address)
+
+ return args.Get(0).(waddrmgr.AccountStore),
+ args.Get(1).(uint32), args.Error(2)
+}
+
+// AddressDetails determines whether the wallet has access to the private
+// keys required to sign for a given address, and returns other address
+// details.
+func (m *mockAddrStore) AddressDetails(ns walletdb.ReadBucket,
+ addr address.Address) (bool, string, waddrmgr.AddressType) {
+
+ args := m.Called(ns, addr)
+ return args.Bool(0), args.String(1), args.Get(2).(waddrmgr.AddressType)
+}
+
+// ForEachRelevantActiveAddress invokes the given closure on each active
+// address relevant to the wallet.
+func (m *mockAddrStore) ForEachRelevantActiveAddress(ns walletdb.ReadBucket,
+ fn func(addr address.Address) error) error {
+
+ args := m.Called(ns, fn)
+ return args.Error(0)
+}
+
+// Unlock derives the master private key from the specified passphrase.
+func (m *mockAddrStore) Unlock(ns walletdb.ReadBucket,
+ passphrase []byte) error {
+
+ args := m.Called(ns, passphrase)
+ return args.Error(0)
+}
+
+// Lock performs a best try effort to remove and zero all secret keys
+// associated with the address manager.
+func (m *mockAddrStore) Lock() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+// IsLocked returns whether or not the address managed is locked.
+func (m *mockAddrStore) IsLocked() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// ChangePassphrase changes either the public or private passphrase to
+// the provided value depending on the private flag.
+func (m *mockAddrStore) ChangePassphrase(ns walletdb.ReadWriteBucket,
+ oldPass, newPass []byte, private bool,
+ scryptOptions *waddrmgr.ScryptOptions) error {
+
+ args := m.Called(ns, oldPass, newPass, private, scryptOptions)
+ return args.Error(0)
+}
+
+// WatchOnly returns true if the root manager is in watch only mode, and
+// false otherwise.
+func (m *mockAddrStore) WatchOnly() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// MarkUsed updates the used flag for the provided address.
+func (m *mockAddrStore) MarkUsed(ns walletdb.ReadWriteBucket,
+ address address.Address) error {
+
+ args := m.Called(ns, address)
+ return args.Error(0)
+}
+
+// BirthdayBlock returns the birthday block of the address store.
+func (m *mockAddrStore) BirthdayBlock(
+ ns walletdb.ReadBucket) (waddrmgr.BlockStamp, bool, error) {
+
+ args := m.Called(ns)
+ return args.Get(0).(waddrmgr.BlockStamp), args.Bool(1), args.Error(2)
+}
+
+// IsWatchOnlyAccount determines if the account with the given key scope
+// is set up as watch-only.
+func (m *mockAddrStore) IsWatchOnlyAccount(ns walletdb.ReadBucket,
+ keyScope waddrmgr.KeyScope, account uint32) (bool, error) {
+
+ args := m.Called(ns, keyScope, account)
+ return args.Bool(0), args.Error(1)
+}
+
+// NewScopedKeyManager creates a new scoped key manager from the root
+// manager.
+func (m *mockAddrStore) NewScopedKeyManager(ns walletdb.ReadWriteBucket,
+ scope waddrmgr.KeyScope,
+ addrSchema waddrmgr.ScopeAddrSchema) (waddrmgr.AccountStore, error) {
+
+ args := m.Called(ns, scope, addrSchema)
+ return args.Get(0).(waddrmgr.AccountStore), args.Error(1)
+}
+
+// SetBirthday sets the birthday of the address store.
+func (m *mockAddrStore) SetBirthday(ns walletdb.ReadWriteBucket,
+ birthday time.Time) error {
+
+ args := m.Called(ns, birthday)
+ return args.Error(0)
+}
+
+// ForEachAccountAddress calls the given function with each address of
+// the given account stored in the manager, breaking early on error.
+func (m *mockAddrStore) ForEachAccountAddress(ns walletdb.ReadBucket,
+ account uint32, fn func(maddr waddrmgr.ManagedAddress) error) error {
+
+ args := m.Called(ns, account, fn)
+ return args.Error(0)
+}
+
+// LookupAccount returns the corresponding key scope and account number
+// for the account with the given name.
+func (m *mockAddrStore) LookupAccount(ns walletdb.ReadBucket,
+ name string) (waddrmgr.KeyScope, uint32, error) {
+
+ args := m.Called(ns, name)
+
+ return args.Get(0).(waddrmgr.KeyScope),
+ args.Get(1).(uint32), args.Error(2)
+}
+
+// ForEachActiveAddress calls the given function with each active address
+// stored in the manager, breaking early on error.
+func (m *mockAddrStore) ForEachActiveAddress(ns walletdb.ReadBucket,
+ fn func(addr address.Address) error) error {
+
+ args := m.Called(ns, fn)
+ return args.Error(0)
+}
+
+// ConvertToWatchingOnly converts the current address manager to a locked
+// watching-only address manager.
+func (m *mockAddrStore) ConvertToWatchingOnly(
+ ns walletdb.ReadWriteBucket) error {
+
+ args := m.Called(ns)
+ return args.Error(0)
+}
+
+// ChainParams returns the chain parameters for this address manager.
+func (m *mockAddrStore) ChainParams() *chaincfg.Params {
+ args := m.Called()
+ return args.Get(0).(*chaincfg.Params)
+}
+
+// Close cleanly shuts down the manager.
+func (m *mockAddrStore) Close() {
+ m.Called()
+}
+
+// mockAccountStore is a mock implementation of the waddrmgr.AccountStore
+// interface.
+type mockAccountStore struct {
+ mock.Mock
+}
+
+// A compile-time assertion to ensure that mockAccountStore implements the
+// AccountStore interface.
+var _ waddrmgr.AccountStore = (*mockAccountStore)(nil)
+
+// Scope implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) Scope() waddrmgr.KeyScope {
+ args := m.Called()
+ return args.Get(0).(waddrmgr.KeyScope)
+}
+
+// AccountProperties implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) AccountProperties(ns walletdb.ReadBucket,
+ account uint32) (*waddrmgr.AccountProperties, error) {
+
+ args := m.Called(ns, account)
+ return args.Get(0).(*waddrmgr.AccountProperties), args.Error(1)
+}
+
+// LastExternalAddress implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) LastExternalAddress(ns walletdb.ReadBucket,
+ account uint32) (waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, account)
+ return args.Get(0).(waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// LastInternalAddress implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) LastInternalAddress(ns walletdb.ReadBucket,
+ account uint32) (waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, account)
+ return args.Get(0).(waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// ForEachAccountAddress implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ForEachAccountAddress(ns walletdb.ReadBucket,
+ account uint32, fn func(maddr waddrmgr.ManagedAddress) error) error {
+
+ args := m.Called(ns, account, fn)
+ return args.Error(0)
+}
+
+// LookupAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) LookupAccount(ns walletdb.ReadBucket,
+ name string) (uint32, error) {
+
+ args := m.Called(ns, name)
+ return args.Get(0).(uint32), args.Error(1)
+}
+
+// AccountName implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) AccountName(ns walletdb.ReadBucket,
+ account uint32) (string, error) {
+
+ args := m.Called(ns, account)
+ return args.String(0), args.Error(1)
+}
+
+// ExtendExternalAddresses implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ExtendExternalAddresses(ns walletdb.ReadWriteBucket,
+ account uint32, count uint32) error {
+
+ args := m.Called(ns, account, count)
+ return args.Error(0)
+}
+
+// ExtendInternalAddresses implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ExtendInternalAddresses(ns walletdb.ReadWriteBucket,
+ account uint32, count uint32) error {
+
+ args := m.Called(ns, account, count)
+ return args.Error(0)
+}
+
+// MarkUsed implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) MarkUsed(ns walletdb.ReadWriteBucket,
+ address address.Address) error {
+
+ args := m.Called(ns, address)
+ return args.Error(0)
+}
+
+// DeriveFromKeyPath implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) DeriveFromKeyPath(ns walletdb.ReadBucket,
+ path waddrmgr.DerivationPath) (waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, path)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// CanAddAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) CanAddAccount() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+// NewAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NewAccount(ns walletdb.ReadWriteBucket,
+ name string) (uint32, error) {
+
+ args := m.Called(ns, name)
+ return args.Get(0).(uint32), args.Error(1)
+}
+
+// LastAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) LastAccount(ns walletdb.ReadBucket) (uint32, error) {
+ args := m.Called(ns)
+ return args.Get(0).(uint32), args.Error(1)
+}
+
+// RenameAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) RenameAccount(ns walletdb.ReadWriteBucket,
+ account uint32, name string) error {
+
+ args := m.Called(ns, account, name)
+ return args.Error(0)
+}
+
+// NextExternalAddresses implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NextExternalAddresses(ns walletdb.ReadWriteBucket,
+ account uint32, count uint32) ([]waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, account, count)
+ return args.Get(0).([]waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// NextInternalAddresses implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NextInternalAddresses(ns walletdb.ReadWriteBucket,
+ account uint32, count uint32) ([]waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, account, count)
+ return args.Get(0).([]waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// NewAddress implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NewAddress(ns walletdb.ReadWriteBucket,
+ account string, internal bool) (address.Address, error) {
+
+ args := m.Called(ns, account, internal)
+ return args.Get(0).(address.Address), args.Error(1)
+}
+
+// ImportPublicKey implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ImportPublicKey(ns walletdb.ReadWriteBucket,
+ pubKey *btcec.PublicKey,
+ bs *waddrmgr.BlockStamp) (waddrmgr.ManagedAddress, error) {
+
+ args := m.Called(ns, pubKey, bs)
+ return args.Get(0).(waddrmgr.ManagedAddress), args.Error(1)
+}
+
+// ImportTaprootScript implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ImportTaprootScript(ns walletdb.ReadWriteBucket,
+ script *waddrmgr.Tapscript, bs *waddrmgr.BlockStamp, privKeyType byte,
+ isInternal bool) (waddrmgr.ManagedTaprootScriptAddress, error) {
+
+ args := m.Called(ns, script, bs, privKeyType, isInternal)
+ return args.Get(0).(waddrmgr.ManagedTaprootScriptAddress), args.Error(1)
+}
+
+// ForEachAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ForEachAccount(ns walletdb.ReadBucket,
+ fn func(account uint32) error) error {
+
+ args := m.Called(ns, fn)
+ return args.Error(0)
+}
+
+// IsWatchOnlyAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) IsWatchOnlyAccount(ns walletdb.ReadBucket,
+ account uint32) (bool, error) {
+
+ args := m.Called(ns, account)
+ return args.Bool(0), args.Error(1)
+}
+
+// NewAccountWatchingOnly implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NewAccountWatchingOnly(ns walletdb.ReadWriteBucket,
+ name string, pubKey *hdkeychain.ExtendedKey,
+ masterKeyFingerprint uint32,
+ addrSchema *waddrmgr.ScopeAddrSchema) (uint32, error) {
+
+ args := m.Called(ns, name, pubKey, masterKeyFingerprint, addrSchema)
+ return args.Get(0).(uint32), args.Error(1)
+}
+
+// InvalidateAccountCache implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) InvalidateAccountCache(account uint32) {
+ m.Called(account)
+}
+
+// ImportPrivateKey implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ImportPrivateKey(ns walletdb.ReadWriteBucket,
+ wif *btcutil.WIF,
+ bs *waddrmgr.BlockStamp) (waddrmgr.ManagedPubKeyAddress, error) {
+
+ args := m.Called(ns, wif, bs)
+ return args.Get(0).(waddrmgr.ManagedPubKeyAddress), args.Error(1)
+}
+
+// ActiveAccounts implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ActiveAccounts() []uint32 {
+ args := m.Called()
+ return args.Get(0).([]uint32)
+}
+
+// ExtendAddresses implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ExtendAddresses(ns walletdb.ReadWriteBucket,
+ account uint32, lastIndex uint32, branch uint32) error {
+
+ args := m.Called(ns, account, lastIndex, branch)
+ return args.Error(0)
+}
+
+// DeriveAddr implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) DeriveAddr(account, branch, index uint32) (
+ address.Address, []byte, error) {
+
+ args := m.Called(account, branch, index)
+
+ var addr address.Address
+ if args.Get(0) != nil {
+ addr = args.Get(0).(address.Address)
+ }
+
+ var script []byte
+ if args.Get(1) != nil {
+ script = args.Get(1).([]byte)
+ }
+
+ return addr, script, args.Error(2)
+}
+
+// AddrAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) AddrAccount(ns walletdb.ReadBucket,
+ address address.Address) (uint32, error) {
+
+ args := m.Called(ns, address)
+ return args.Get(0).(uint32), args.Error(1)
+}
+
+// DeriveFromKeyPathCache implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) DeriveFromKeyPathCache(
+ kp waddrmgr.DerivationPath) (*btcec.PrivateKey, error) {
+
+ args := m.Called(kp)
+ return args.Get(0).(*btcec.PrivateKey), args.Error(1)
+}
+
+// NewRawAccount implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NewRawAccount(ns walletdb.ReadWriteBucket,
+ number uint32) error {
+
+ args := m.Called(ns, number)
+ return args.Error(0)
+}
+
+// NewRawAccountWatchingOnly implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) NewRawAccountWatchingOnly(
+ ns walletdb.ReadWriteBucket,
+ number uint32, pubKey *hdkeychain.ExtendedKey,
+ masterKeyFingerprint uint32,
+ addrSchema *waddrmgr.ScopeAddrSchema) error {
+
+ args := m.Called(ns, number, pubKey, masterKeyFingerprint, addrSchema)
+ return args.Error(0)
+}
+
+// ImportScript implements the waddrmgr.AccountStore interface.
+func (m *mockAccountStore) ImportScript(
+ ns walletdb.ReadWriteBucket, script []byte,
+ bs *waddrmgr.BlockStamp) (waddrmgr.ManagedScriptAddress, error) {
+
+ args := m.Called(ns, script, bs)
+ return args.Get(0).(waddrmgr.ManagedScriptAddress), args.Error(1)
+}
+
+// mockManagedAddress is a mock implementation of the waddrmgr.ManagedAddress
+// interface.
+type mockManagedAddress struct {
+ mock.Mock
+}
+
+// A compile-time assertion to ensure that mockManagedAddress implements the
+// ManagedAddress interface.
+var _ waddrmgr.ManagedAddress = (*mockManagedAddress)(nil)
+
+// Address implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) Address() address.Address {
+ args := m.Called()
+ return args.Get(0).(address.Address)
+}
+
+// AddrHash implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) AddrHash() []byte {
+ args := m.Called()
+ return args.Get(0).([]byte)
+}
+
+// Imported implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) Imported() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// Internal implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) Internal() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// Compressed implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) Compressed() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// Used implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) Used(ns walletdb.ReadBucket) bool {
+ args := m.Called(ns)
+ return args.Bool(0)
+}
+
+// AddrType implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) AddrType() waddrmgr.AddressType {
+ args := m.Called()
+ return args.Get(0).(waddrmgr.AddressType)
+}
+
+// InternalAccount implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) InternalAccount() uint32 {
+ args := m.Called()
+ return args.Get(0).(uint32)
+}
+
+// DerivationInfo implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedAddress) DerivationInfo() (
+ waddrmgr.KeyScope, waddrmgr.DerivationPath, bool) {
+
+ args := m.Called()
+
+ return args.Get(0).(waddrmgr.KeyScope),
+ args.Get(1).(waddrmgr.DerivationPath), args.Bool(2)
+}
+
+// mockCoinSelectionStrategy is a mock implementation of the
+// CoinSelectionStrategy interface used for testing purposes.
+type mockCoinSelectionStrategy struct {
+ mock.Mock
+}
+
+// ArrangeCoins implements the CoinSelectionStrategy interface.
+func (m *mockCoinSelectionStrategy) ArrangeCoins(coins []Coin,
+ feePerKb btcutil.Amount) ([]Coin, error) {
+
+ args := m.Called(coins, feePerKb)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).([]Coin), args.Error(1)
+}
+
+// mockChain is a mock implementation of the chain.Interface.
+type mockChain struct {
+ mock.Mock
+}
+
+// A compile-time assertion to ensure that mockChain implements the
+// chain.Interface.
+var _ chain.Interface = (*mockChain)(nil)
+
+// Start implements the chain.Interface interface.
+func (m *mockChain) Start(_ context.Context) error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+// Stop implements the chain.Interface interface.
+func (m *mockChain) Stop() {
+ m.Called()
+}
+
+// WaitForShutdown implements the chain.Interface interface.
+func (m *mockChain) WaitForShutdown() {
+ m.Called()
+}
+
+// GetBestBlock implements the chain.Interface interface.
+func (m *mockChain) GetBestBlock() (*chainhash.Hash, int32, error) {
+ args := m.Called()
+ hash, _ := args.Get(0).(*chainhash.Hash)
+
+ return hash, args.Get(1).(int32), args.Error(2)
+}
+
+// GetBlock implements the chain.Interface interface.
+func (m *mockChain) GetBlock(hash *chainhash.Hash) (*wire.MsgBlock, error) {
+ args := m.Called(hash)
+ block, _ := args.Get(0).(*wire.MsgBlock)
+
+ return block, args.Error(1)
+}
+
+// GetBlockHash implements the chain.Interface interface.
+func (m *mockChain) GetBlockHash(height int64) (*chainhash.Hash, error) {
+ args := m.Called(height)
+ hash, _ := args.Get(0).(*chainhash.Hash)
+
+ return hash, args.Error(1)
+}
+
+// GetBlockHeader implements the chain.Interface interface.
+func (m *mockChain) GetBlockHeader(
+ hash *chainhash.Hash) (*wire.BlockHeader, error) {
+
+ args := m.Called(hash)
+ header, _ := args.Get(0).(*wire.BlockHeader)
+
+ return header, args.Error(1)
+}
+
+func (m *mockChain) GetBlockHashes(start, end int64) ([]chainhash.Hash, error) {
+ args := m.Called(start, end)
+ return args.Get(0).([]chainhash.Hash), args.Error(1)
+}
+
+func (m *mockChain) GetBlockHeaders(
+ hashes []chainhash.Hash) ([]*wire.BlockHeader, error) {
+
+ args := m.Called(hashes)
+ return args.Get(0).([]*wire.BlockHeader), args.Error(1)
+}
+
+func (m *mockChain) GetCFilters(hashes []chainhash.Hash,
+ filterType wire.FilterType) ([]*gcs.Filter, error) {
+
+ args := m.Called(hashes, filterType)
+ return args.Get(0).([]*gcs.Filter), args.Error(1)
+}
+
+func (m *mockChain) GetBlocks(
+ hashes []chainhash.Hash) ([]*wire.MsgBlock, error) {
+
+ args := m.Called(hashes)
+ return args.Get(0).([]*wire.MsgBlock), args.Error(1)
+}
+
+// IsCurrent implements the chain.Interface interface.
+func (m *mockChain) IsCurrent() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// GetCFilter implements the chain.Interface interface.
+func (m *mockChain) GetCFilter(hash *chainhash.Hash,
+ filterType wire.FilterType) (*gcs.Filter, error) {
+
+ args := m.Called(hash, filterType)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(*gcs.Filter), args.Error(1)
+}
+
+// FilterBlocks implements the chain.Interface interface.
+func (m *mockChain) FilterBlocks(req *chain.FilterBlocksRequest) (
+ *chain.FilterBlocksResponse, error) {
+
+ args := m.Called(req)
+ resp, _ := args.Get(0).(*chain.FilterBlocksResponse)
+
+ return resp, args.Error(1)
+}
+
+// BlockStamp implements the chain.Interface interface.
+func (m *mockChain) BlockStamp() (*waddrmgr.BlockStamp, error) {
+ args := m.Called()
+ stamp, _ := args.Get(0).(*waddrmgr.BlockStamp)
+
+ return stamp, args.Error(1)
+}
+
+// SendRawTransaction implements the chain.Interface interface.
+func (m *mockChain) SendRawTransaction(tx *wire.MsgTx,
+ allowHighFees bool) (*chainhash.Hash, error) {
+
+ args := m.Called(tx, allowHighFees)
+ hash, _ := args.Get(0).(*chainhash.Hash)
+
+ return hash, args.Error(1)
+}
+
+// Rescan implements the chain.Interface interface.
+func (m *mockChain) Rescan(hash *chainhash.Hash, addrs []address.Address,
+ outpoints map[wire.OutPoint]address.Address) error {
+
+ args := m.Called(hash, addrs, outpoints)
+ return args.Error(0)
+}
+
+// NotifyReceived implements the chain.Interface interface.
+func (m *mockChain) NotifyReceived(addrs []address.Address) error {
+ args := m.Called(addrs)
+ return args.Error(0)
+}
+
+// NotifyBlocks implements the chain.Interface interface.
+func (m *mockChain) NotifyBlocks() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+// Notifications implements the chain.Interface interface.
+func (m *mockChain) Notifications() <-chan any {
+ args := m.Called()
+ ch, _ := args.Get(0).(<-chan any)
+
+ return ch
+}
+
+// BackEnd implements the chain.Interface interface.
+func (m *mockChain) BackEnd() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+// TestMempoolAccept implements the chain.Interface interface.
+func (m *mockChain) TestMempoolAccept(txns []*wire.MsgTx,
+ maxFeeRate float64) ([]*btcjson.TestMempoolAcceptResult, error) {
+
+ args := m.Called(txns, maxFeeRate)
+ res, _ := args.Get(0).([]*btcjson.TestMempoolAcceptResult)
+
+ return res, args.Error(1)
+}
+
+// SubmitPackage implements the chain.Interface interface.
+func (m *mockChain) SubmitPackage(txns []*wire.MsgTx,
+ maxFeeRate *float64) (*btcjson.SubmitPackageResult, error) {
+
+ args := m.Called(txns, maxFeeRate)
+ res, _ := args.Get(0).(*btcjson.SubmitPackageResult)
+
+ return res, args.Error(1)
+}
+
+// MapRPCErr implements the chain.Interface interface.
+func (m *mockChain) MapRPCErr(err error) error {
+ args := m.Called(err)
+ return args.Error(0)
+}
+
+// mockNeutrinoChain is a mock implementation of the chain.NeutrinoChainService
+// interface.
+type mockNeutrinoChain struct {
+ mockChain
+}
+
+// A compile-time assertion to ensure that mockNeutrinoChain implements the
+// chain.NeutrinoChainService.
+var _ chain.NeutrinoChainService = (*mockNeutrinoChain)(nil)
+
+// Stop implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) Stop() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+// GetBlock implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) GetBlock(hash chainhash.Hash,
+ opts ...neutrino.QueryOption) (*btcutil.Block, error) {
+
+ args := m.Called(hash, opts)
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(*btcutil.Block); ok {
+ return val, args.Error(1)
+ }
+ }
+
+ return nil, args.Error(1)
+}
+
+// GetCFilter implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) GetCFilter(hash chainhash.Hash,
+ filterType wire.FilterType,
+ opts ...neutrino.QueryOption) (*gcs.Filter, error) {
+
+ args := m.Called(hash, filterType, opts)
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(*gcs.Filter); ok {
+ return val, args.Error(1)
+ }
+ }
+
+ return nil, args.Error(1)
+}
+
+// GetBlockHeight implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) GetBlockHeight(
+ hash *chainhash.Hash) (int32, error) {
+
+ args := m.Called(hash)
+ return args.Get(0).(int32), args.Error(1)
+}
+
+// BestBlock implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) BestBlock() (*headerfs.BlockStamp, error) {
+ args := m.Called()
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(*headerfs.BlockStamp); ok {
+ return val, args.Error(1)
+ }
+ }
+
+ return nil, args.Error(1)
+}
+
+// SendTransaction implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) SendTransaction(tx *wire.MsgTx) error {
+ args := m.Called(tx)
+ return args.Error(0)
+}
+
+// GetUtxo implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) GetUtxo(
+ opts ...neutrino.RescanOption) (*neutrino.SpendReport, error) {
+
+ args := m.Called(opts)
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(*neutrino.SpendReport); ok {
+ return val, args.Error(1)
+ }
+ }
+
+ return nil, args.Error(1)
+}
+
+// BanPeer implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) BanPeer(addr string,
+ reason banman.Reason) error {
+
+ args := m.Called(addr, reason)
+ return args.Error(0)
+}
+
+// IsBanned implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) IsBanned(addr string) bool {
+ args := m.Called(addr)
+ return args.Bool(0)
+}
+
+// AddPeer implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) AddPeer(peer *neutrino.ServerPeer) {
+ m.Called(peer)
+}
+
+// AddBytesSent implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) AddBytesSent(bytes uint64) {
+ m.Called(bytes)
+}
+
+// AddBytesReceived implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) AddBytesReceived(bytes uint64) {
+ m.Called(bytes)
+}
+
+// NetTotals implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) NetTotals() (uint64, uint64) {
+ args := m.Called()
+
+ var a, b uint64
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(uint64); ok {
+ a = val
+ }
+ }
+
+ if args.Get(1) != nil {
+ if val, ok := args.Get(1).(uint64); ok {
+ b = val
+ }
+ }
+
+ return a, b
+}
+
+// UpdatePeerHeights implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) UpdatePeerHeights(hash *chainhash.Hash,
+ height int32, peer *neutrino.ServerPeer) {
+
+ m.Called(hash, height, peer)
+}
+
+// ChainParams implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) ChainParams() chaincfg.Params {
+ args := m.Called()
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(chaincfg.Params); ok {
+ return val
+ }
+ }
+
+ return chaincfg.Params{}
+}
+
+// PeerByAddr implements the chain.NeutrinoChainService interface.
+func (m *mockNeutrinoChain) PeerByAddr(
+ addr string) *neutrino.ServerPeer {
+
+ args := m.Called(addr)
+ if args.Get(0) != nil {
+ if val, ok := args.Get(0).(*neutrino.ServerPeer); ok {
+ return val
+ }
+ }
+
+ return nil
+}
+
+// mockManagedPubKeyAddr is a mock implementation of the
+// waddrmgr.ManagedPubKeyAddress interface, used for testing.
+type mockManagedPubKeyAddr struct {
+ mock.Mock
+}
+
+// A compile-time check to ensure that mockManagedPubKeyAddr implements the
+// ManagedPubKeyAddress interface.
+var _ waddrmgr.ManagedPubKeyAddress = (*mockManagedPubKeyAddr)(nil)
+
+// PubKey implements the waddrmgr.ManagedPubKeyAddress interface.
+func (m *mockManagedPubKeyAddr) PubKey() *btcec.PublicKey {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+
+ return args.Get(0).(*btcec.PublicKey)
+}
+
+// ExportPrivKey implements the waddrmgr.ManagedPubKeyAddress interface.
+func (m *mockManagedPubKeyAddr) ExportPrivKey() (*btcutil.WIF, error) {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(*btcutil.WIF), args.Error(1)
+}
+
+// ExportPubKey implements the waddrmgr.ManagedPubKeyAddress interface.
+func (m *mockManagedPubKeyAddr) ExportPubKey() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+// PrivKey implements the waddrmgr.ManagedPubKeyAddress interface.
+func (m *mockManagedPubKeyAddr) PrivKey() (*btcec.PrivateKey, error) {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(*btcec.PrivateKey), args.Error(1)
+}
+
+// Address implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) Address() address.Address {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+
+ return args.Get(0).(address.Address)
+}
+
+// AddrHash implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) AddrHash() []byte {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+
+ return args.Get(0).([]byte)
+}
+
+// Imported implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) Imported() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// Internal implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) Internal() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// Compressed implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) Compressed() bool {
+ args := m.Called()
+ return args.Bool(0)
+}
+
+// Used implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) Used(ns walletdb.ReadBucket) bool {
+ args := m.Called(ns)
+ return args.Bool(0)
+}
+
+// AddrType implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) AddrType() waddrmgr.AddressType {
+ args := m.Called()
+ return args.Get(0).(waddrmgr.AddressType)
+}
+
+// InternalAccount implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) InternalAccount() uint32 {
+ args := m.Called()
+ return args.Get(0).(uint32)
+}
+
+// DerivationInfo implements the waddrmgr.ManagedAddress interface.
+func (m *mockManagedPubKeyAddr) DerivationInfo() (waddrmgr.KeyScope,
+ waddrmgr.DerivationPath, bool) {
+
+ args := m.Called()
+
+ return args.Get(0).(waddrmgr.KeyScope),
+ args.Get(1).(waddrmgr.DerivationPath), args.Bool(2)
+}
+
+// mockSpendDetails is a mock implementation of the SpendDetails interface.
+type mockSpendDetails struct {
+ mock.Mock
+}
+
+// A compile-time assertion to ensure that mockSpendDetails implements the
+// SpendDetails interface.
+var _ SpendDetails = (*mockSpendDetails)(nil)
+
+// Sign implements the SpendDetails interface.
+func (m *mockSpendDetails) Sign(params *RawSigParams,
+ privKey *btcec.PrivateKey) (RawSignature, error) {
+
+ args := m.Called(params, privKey)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(RawSignature), args.Error(1)
+}
+
+// isSpendDetails implements the SpendDetails interface.
+func (m *mockSpendDetails) isSpendDetails() {}
+
+// mockController is a mock implementation of the Controller interface.
+type mockController struct {
+ mock.Mock
+}
+
+// Compile-time check to ensure mockController implements Controller.
+var _ Controller = (*mockController)(nil)
+
+// Unlock implements the Controller interface.
+func (m *mockController) Unlock(ctx context.Context, req UnlockRequest) error {
+ args := m.Called(ctx, req)
+ return args.Error(0)
+}
+
+// Lock implements the Controller interface.
+func (m *mockController) Lock(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+// ChangePassphrase implements the Controller interface.
+func (m *mockController) ChangePassphrase(ctx context.Context,
+ req ChangePassphraseRequest) error {
+
+ args := m.Called(ctx, req)
+ return args.Error(0)
+}
+
+// Info implements the Controller interface.
+func (m *mockController) Info(ctx context.Context) (*Info, error) {
+ args := m.Called(ctx)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(*Info), args.Error(1)
+}
+
+// Start implements the Controller interface.
+func (m *mockController) Start(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+// Stop implements the Controller interface.
+func (m *mockController) Stop(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+// Resync implements the Controller interface.
+func (m *mockController) Resync(ctx context.Context, startHeight uint32) error {
+ args := m.Called(ctx, startHeight)
+ return args.Error(0)
+}
+
+// Rescan implements the Controller interface.
+func (m *mockController) Rescan(ctx context.Context, startHeight uint32,
+ targets []waddrmgr.AccountScope) error {
+
+ args := m.Called(ctx, startHeight, targets)
+ return args.Error(0)
+}
+
+// mockChainSyncer is a mock implementation of the chainSyncer interface.
+type mockChainSyncer struct {
+ mock.Mock
+}
+
+// A compile-time assertion to ensure that mockChainSyncer implements the
+// chainSyncer interface.
+var _ chainSyncer = (*mockChainSyncer)(nil)
+
+// run implements the chainSyncer interface.
+func (m *mockChainSyncer) run(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+// requestScan implements the chainSyncer interface.
+func (m *mockChainSyncer) requestScan(ctx context.Context, req *scanReq) error {
+ args := m.Called(ctx, req)
+ return args.Error(0)
+}
+
+// syncState implements the chainSyncer interface.
+func (m *mockChainSyncer) syncState() syncState {
+ args := m.Called()
+ return args.Get(0).(syncState)
+}
+
+// mockTxPublisher is a mock implementation of the TxPublisher interface.
+type mockTxPublisher struct {
+ mock.Mock
+}
+
+// A compile-time check to ensure that mockTxPublisher implements the
+// TxPublisher interface.
+var _ TxPublisher = (*mockTxPublisher)(nil)
+
+// CheckMempoolAcceptance implements the TxPublisher interface.
+func (m *mockTxPublisher) CheckMempoolAcceptance(ctx context.Context,
+ tx *wire.MsgTx) error {
+
+ args := m.Called(ctx, tx)
+ return args.Error(0)
+}
+
+// Broadcast implements the TxPublisher interface.
+func (m *mockTxPublisher) Broadcast(ctx context.Context, tx *wire.MsgTx,
+ label string) error {
+
+ args := m.Called(ctx, tx, label)
+ return args.Error(0)
+}
+
+// mockAddress is a mock implementation of the address.Address interface.
+// It embeds mock.Mock to allow for flexible stubbing of its methods,
+// enabling granular control over address behavior in tests.
+type mockAddress struct {
+ mock.Mock
+}
+
+// EncodeAddress mocks the EncodeAddress method.
+// It returns a predefined string based on mock expectations.
+func (m *mockAddress) EncodeAddress() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+// ScriptAddress mocks the ScriptAddress method.
+// It returns a predefined byte slice based on mock expectations.
+func (m *mockAddress) ScriptAddress() []byte {
+ args := m.Called()
+ return args.Get(0).([]byte)
+}
+
+// IsForNet mocks the IsForNet method.
+// It returns a predefined boolean based on mock expectations.
+func (m *mockAddress) IsForNet(params *chaincfg.Params) bool {
+ args := m.Called(params)
+ return args.Bool(0)
+}
+
+// String mocks the String method.
+// It returns a predefined string based on mock expectations.
+func (m *mockAddress) String() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+// mockManagedTaprootScriptAddress is a mock implementation of the
+// waddrmgr.ManagedTaprootScriptAddress interface.
+type mockManagedTaprootScriptAddress struct {
+ mockManagedAddress
+}
+
+// A compile-time assertion to ensure that mockManagedTaprootScriptAddress
+// implements the waddrmgr.ManagedTaprootScriptAddress interface.
+var _ waddrmgr.ManagedTaprootScriptAddress = (*mockManagedTaprootScriptAddress)(
+ nil,
+)
+
+// Script implements the waddrmgr.ManagedScriptAddress interface.
+func (m *mockManagedTaprootScriptAddress) Script() ([]byte, error) {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).([]byte), args.Error(1)
+}
+
+// TaprootScript implements the waddrmgr.ManagedTaprootScriptAddress interface.
+func (m *mockManagedTaprootScriptAddress) TaprootScript() (
+ *waddrmgr.Tapscript, error) {
+
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+
+ return args.Get(0).(*waddrmgr.Tapscript), args.Error(1)
+}
diff --git a/wallet/multisig.go b/wallet/multisig.go
index 46e9b9e1b5..a891664cc6 100644
--- a/wallet/multisig.go
+++ b/wallet/multisig.go
@@ -49,13 +49,15 @@ func (w *Wallet) MakeMultiSigScript(addrs []address.Address,
case *address.AddressPubKeyHash:
if dbtx == nil {
var err error
- dbtx, err = w.db.BeginReadTx()
+
+ dbtx, err = w.cfg.DB.BeginReadTx()
if err != nil {
return nil, err
}
addrmgrNs = dbtx.ReadBucket(waddrmgrNamespaceKey)
}
- addrInfo, err := w.Manager.Address(addrmgrNs, addr)
+
+ addrInfo, err := w.addrStore.Address(addrmgrNs, addr)
if err != nil {
return nil, err
}
@@ -63,7 +65,7 @@ func (w *Wallet) MakeMultiSigScript(addrs []address.Address,
PubKey().SerializeCompressed()
pubKeyAddr, err := address.NewAddressPubKey(
- serializedPubKey, w.chainParams)
+ serializedPubKey, w.cfg.ChainParams)
if err != nil {
return nil, err
}
@@ -79,7 +81,8 @@ func (w *Wallet) ImportP2SHRedeemScript(
script []byte) (*address.AddressScriptHash, error) {
var p2shAddr *address.AddressScriptHash
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+
+ err := walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
// TODO(oga) blockstamp current block?
@@ -90,7 +93,7 @@ func (w *Wallet) ImportP2SHRedeemScript(
// As this is a regular P2SH script, we'll import this into the
// BIP0044 scope.
- bip44Mgr, err := w.Manager.FetchScopedKeyManager(
+ bip44Mgr, err := w.addrStore.FetchScopedKeyManager(
waddrmgr.KeyScopeBIP0084,
)
if err != nil {
@@ -106,7 +109,7 @@ func (w *Wallet) ImportP2SHRedeemScript(
// This function will never error as it always
// hashes the script to the correct length.
p2shAddr, _ = address.NewAddressScriptHash(script,
- w.chainParams)
+ w.cfg.ChainParams)
return nil
}
return err
diff --git a/wallet/notifications.go b/wallet/notifications.go
index 4e867ed7db..f0ec498ded 100644
--- a/wallet/notifications.go
+++ b/wallet/notifications.go
@@ -54,7 +54,8 @@ func lookupInputAccount(dbtx walletdb.ReadTx, w *Wallet, details *wtxmgr.TxDetai
// TODO: Debits should record which account(s?) they
// debit from so this doesn't need to be looked up.
prevOP := &details.MsgTx.TxIn[deb.Index].PreviousOutPoint
- prev, err := w.TxStore.TxDetails(txmgrNs, &prevOP.Hash)
+
+ prev, err := w.txStore.TxDetails(txmgrNs, &prevOP.Hash)
if err != nil {
log.Errorf("Cannot query previous transaction details for %v: %v", prevOP.Hash, err)
return 0
@@ -67,7 +68,7 @@ func lookupInputAccount(dbtx walletdb.ReadTx, w *Wallet, details *wtxmgr.TxDetai
_, addrs, _, err := txscript.ExtractPkScriptAddrs(prevOut.PkScript, w.chainParams)
var inputAcct uint32
if err == nil && len(addrs) > 0 {
- _, inputAcct, err = w.Manager.AddrAccount(addrmgrNs, addrs[0])
+ _, inputAcct, err = w.addrStore.AddrAccount(addrmgrNs, addrs[0])
}
if err != nil {
log.Errorf("Cannot fetch account for previous output %v: %v", prevOP, err)
@@ -85,7 +86,7 @@ func lookupOutputChain(dbtx walletdb.ReadTx, w *Wallet, details *wtxmgr.TxDetail
_, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript, w.chainParams)
var ma waddrmgr.ManagedAddress
if err == nil && len(addrs) > 0 {
- ma, err = w.Manager.Address(addrmgrNs, addrs[0])
+ ma, err = w.addrStore.Address(addrmgrNs, addrs[0])
}
if err != nil {
log.Errorf("Cannot fetch account for wallet output: %v", err)
@@ -155,7 +156,10 @@ func makeTxSummary(dbtx walletdb.ReadTx, w *Wallet, details *wtxmgr.TxDetails) T
func totalBalances(dbtx walletdb.ReadTx, w *Wallet, m map[uint32]btcutil.Amount) error {
addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
- unspent, err := w.TxStore.UnspentOutputs(dbtx.ReadBucket(wtxmgrNamespaceKey))
+
+ unspent, err := w.txStore.UnspentOutputs(
+ dbtx.ReadBucket(wtxmgrNamespaceKey),
+ )
if err != nil {
return err
}
@@ -165,7 +169,9 @@ func totalBalances(dbtx walletdb.ReadTx, w *Wallet, m map[uint32]btcutil.Amount)
_, addrs, _, err := txscript.ExtractPkScriptAddrs(
output.PkScript, w.chainParams)
if err == nil && len(addrs) > 0 {
- _, outputAcct, err = w.Manager.AddrAccount(addrmgrNs, addrs[0])
+ _, outputAcct, err = w.addrStore.AddrAccount(
+ addrmgrNs, addrs[0],
+ )
}
if err == nil {
_, ok := m[outputAcct]
@@ -215,7 +221,7 @@ func (s *NotificationServer) notifyUnminedTransaction(dbtx walletdb.ReadTx,
// TODO(wilmer): ideally we should find the culprit to why we're
// receiving an additional unconfirmed chain.RelevantTx notification
// from the chain backend.
- details, err := s.wallet.TxStore.UniqueTxDetails(ns, &txHash, nil)
+ details, err := s.wallet.txStore.UniqueTxDetails(ns, &txHash, nil)
if err != nil {
log.Errorf("Cannot query transaction details for "+
"notification: %v", err)
@@ -231,7 +237,10 @@ func (s *NotificationServer) notifyUnminedTransaction(dbtx walletdb.ReadTx,
}
unminedTxs := []TransactionSummary{makeTxSummary(dbtx, s.wallet, details)}
- unminedHashes, err := s.wallet.TxStore.UnminedTxHashes(dbtx.ReadBucket(wtxmgrNamespaceKey))
+
+ unminedHashes, err := s.wallet.txStore.UnminedTxHashes(
+ dbtx.ReadBucket(wtxmgrNamespaceKey),
+ )
if err != nil {
log.Errorf("Cannot fetch unmined transaction hashes: %v", err)
return
@@ -284,7 +293,7 @@ func (s *NotificationServer) notifyMinedTransaction(dbtx walletdb.ReadTx,
// We'll only notify the transaction if it was found within the
// wallet's set of confirmed transactions.
- details, err := s.wallet.TxStore.UniqueTxDetails(
+ details, err := s.wallet.txStore.UniqueTxDetails(
ns, &txHash, &block.Block,
)
if err != nil {
@@ -353,7 +362,8 @@ func (s *NotificationServer) notifyAttachedBlock(dbtx walletdb.ReadTx, block *wt
// a new, previously unseen transaction appearing in unconfirmed.
txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
- unminedHashes, err := s.wallet.TxStore.UnminedTxHashes(txmgrNs)
+
+ unminedHashes, err := s.wallet.txStore.UnminedTxHashes(txmgrNs)
if err != nil {
log.Errorf("Cannot fetch unmined transaction hashes: %v", err)
return
diff --git a/wallet/psbt.go b/wallet/psbt.go
deleted file mode 100644
index 56ec4dd5d6..0000000000
--- a/wallet/psbt.go
+++ /dev/null
@@ -1,582 +0,0 @@
-// Copyright (c) 2020 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "bytes"
- "errors"
- "fmt"
-
- "github.com/btcsuite/btcd/btcutil/v2"
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/psbt/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/wallet/txauthor"
- "github.com/btcsuite/btcwallet/wallet/txrules"
- "github.com/btcsuite/btcwallet/walletdb"
- "github.com/btcsuite/btcwallet/wtxmgr"
-)
-
-// FundPsbt creates a fully populated PSBT packet that contains enough inputs to
-// fund the outputs specified in the passed in packet with the specified fee
-// rate. If there is change left, a change output from the wallet is added and
-// the index of the change output is returned. If no custom change scope is
-// specified, we will use the coin selection scope (if not nil) or the BIP0086
-// scope by default. Otherwise, no additional output is created and the
-// index -1 is returned.
-//
-// NOTE: If the packet doesn't contain any inputs, coin selection is performed
-// automatically, only selecting inputs from the account based on the given key
-// scope and account number. If a key scope is not specified, then inputs from
-// accounts matching the account number provided across all key scopes may be
-// selected. This is done to handle the default account case, where a user wants
-// to fund a PSBT with inputs regardless of their type (NP2WKH, P2WKH, etc.). If
-// the packet does contain any inputs, it is assumed that full coin selection
-// happened externally and no additional inputs are added. If the specified
-// inputs aren't enough to fund the outputs with the given fee rate, an error is
-// returned.
-//
-// NOTE: A caller of the method should hold the global coin selection lock of
-// the wallet. However, no UTXO specific lock lease is acquired for any of the
-// selected/validated inputs by this method. It is in the caller's
-// responsibility to lock the inputs before handing the partial transaction out.
-func (w *Wallet) FundPsbt(packet *psbt.Packet, keyScope *waddrmgr.KeyScope,
- minConfs int32, account uint32, feeSatPerKB btcutil.Amount,
- coinSelectionStrategy CoinSelectionStrategy,
- optFuncs ...TxCreateOption) (int32, error) {
-
- // Make sure the packet is well formed. We only require there to be at
- // least one input or output.
- err := psbt.VerifyInputOutputLen(packet, false, false)
- if err != nil {
- return 0, err
- }
-
- if len(packet.UnsignedTx.TxIn) == 0 && len(packet.UnsignedTx.TxOut) == 0 {
- return 0, fmt.Errorf("PSBT packet must contain at least one " +
- "input or output")
- }
-
- txOut := packet.UnsignedTx.TxOut
- txIn := packet.UnsignedTx.TxIn
-
- // Make sure none of the outputs are dust.
- for _, output := range txOut {
- // When checking an output for things like dusty-ness, we'll
- // use the default mempool relay fee rather than the target
- // effective fee rate to ensure accuracy. Otherwise, we may
- // mistakenly mark small-ish, but not quite dust output as
- // dust.
- err := txrules.CheckOutput(output, txrules.DefaultRelayFeePerKb)
- if err != nil {
- return 0, err
- }
- }
-
- // Let's find out the amount to fund first.
- amt := int64(0)
- for _, output := range txOut {
- amt += output.Value
- }
-
- var tx *txauthor.AuthoredTx
- switch {
- // We need to do coin selection.
- case len(txIn) == 0:
- // We ask the underlying wallet to fund a TX for us. This
- // includes everything we need, specifically fee estimation and
- // change address creation.
- tx, err = w.CreateSimpleTx(
- keyScope, account, packet.UnsignedTx.TxOut, minConfs,
- feeSatPerKB, coinSelectionStrategy, false,
- optFuncs...,
- )
- if err != nil {
- return 0, fmt.Errorf("error creating funding TX: %w",
- err)
- }
-
- // Copy over the inputs now then collect all UTXO information
- // that we can and attach them to the PSBT as well. We don't
- // include the witness as the resulting PSBT isn't expected not
- // should be signed yet.
- packet.UnsignedTx.TxIn = tx.Tx.TxIn
- packet.Inputs = make([]psbt.PInput, len(packet.UnsignedTx.TxIn))
-
- for idx := range packet.UnsignedTx.TxIn {
- // We don't want to include the witness or any script
- // on the unsigned TX just yet.
- packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
- packet.UnsignedTx.TxIn[idx].SignatureScript = nil
- }
-
- err := w.DecorateInputs(packet, true)
- if err != nil {
- return 0, err
- }
-
- // If there are inputs, we need to check if they're sufficient and add
- // a change output if necessary.
- default:
- // Make sure all inputs provided are actually ours.
- packet.Inputs = make([]psbt.PInput, len(packet.UnsignedTx.TxIn))
-
- for idx := range packet.UnsignedTx.TxIn {
- // We don't want to include the witness or any script
- // on the unsigned TX just yet.
- packet.UnsignedTx.TxIn[idx].Witness = wire.TxWitness{}
- packet.UnsignedTx.TxIn[idx].SignatureScript = nil
- }
-
- err := w.DecorateInputs(packet, true)
- if err != nil {
- return 0, err
- }
-
- // We can leverage the fee calculation of the txauthor package
- // if we provide the selected UTXOs as a coin source. We just
- // need to make sure we always return the full list of user-
- // selected UTXOs rather than a subset, otherwise our change
- // amount will be off (in case the user selected multiple UTXOs
- // that are large enough on their own). That's why we use our
- // own static input source creator instead of the more generic
- // makeInputSource() that selects a subset that is "large
- // enough".
- credits := make([]wtxmgr.Credit, len(txIn))
- for idx, in := range txIn {
- utxo := packet.Inputs[idx].WitnessUtxo
- credits[idx] = wtxmgr.Credit{
- OutPoint: in.PreviousOutPoint,
- Amount: btcutil.Amount(utxo.Value),
- PkScript: utxo.PkScript,
- }
- }
- inputSource := constantInputSource(credits)
-
- // Build the TxCreateOption to retrieve the change scope.
- opts := defaultTxCreateOptions()
- for _, optFunc := range optFuncs {
- optFunc(opts)
- }
-
- if opts.changeKeyScope == nil {
- opts.changeKeyScope = keyScope
- }
-
- // The addrMgrWithChangeSource function of the wallet creates a
- // new change address. The address manager uses OnCommit on the
- // walletdb tx to update the in-memory state of the account
- // state. But because the commit happens _after_ the account
- // manager internal lock has been released, there is a chance
- // for the address index to be accessed concurrently, even
- // though the closure in OnCommit re-acquires the lock. To avoid
- // this issue, we surround the whole address creation process
- // with a lock.
- w.newAddrMtx.Lock()
-
- // We also need a change source which needs to be able to insert
- // a new change address into the database.
- err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
- _, changeSource, err := w.addrMgrWithChangeSource(
- dbtx, opts.changeKeyScope, account,
- )
- if err != nil {
- return err
- }
-
- // Ask the txauthor to create a transaction with our
- // selected coins. This will perform fee estimation and
- // add a change output if necessary.
- tx, err = txauthor.NewUnsignedTransaction(
- txOut, feeSatPerKB, inputSource, changeSource,
- )
- if err != nil {
- return fmt.Errorf("fee estimation not "+
- "successful: %w", err)
- }
-
- return nil
- })
- w.newAddrMtx.Unlock()
-
- if err != nil {
- return 0, fmt.Errorf("could not add change address to "+
- "database: %w", err)
- }
- }
-
- // If there is a change output, we need to copy it over to the PSBT now.
- var changeTxOut *wire.TxOut
- if tx.ChangeIndex >= 0 {
- changeTxOut = tx.Tx.TxOut[tx.ChangeIndex]
- packet.UnsignedTx.TxOut = append(
- packet.UnsignedTx.TxOut, changeTxOut,
- )
-
- addr, _, _, err := w.ScriptForOutput(changeTxOut)
- if err != nil {
- return 0, fmt.Errorf("error querying wallet for "+
- "change addr: %w", err)
- }
-
- changeOutputInfo, err := createOutputInfo(changeTxOut, addr)
- if err != nil {
- return 0, fmt.Errorf("error adding output info to "+
- "change output: %w", err)
- }
-
- packet.Outputs = append(packet.Outputs, *changeOutputInfo)
- }
-
- // Now that we have the final PSBT ready, we can sort it according to
- // BIP 69. This will sort the wire inputs and outputs and move the
- // partial inputs and outputs accordingly.
- err = psbt.InPlaceSort(packet)
- if err != nil {
- return 0, fmt.Errorf("could not sort PSBT: %w", err)
- }
-
- // The change output index might have changed after the sorting. We need
- // to find our index again.
- changeIndex := int32(-1)
- if changeTxOut != nil {
- for idx, txOut := range packet.UnsignedTx.TxOut {
- if psbt.TxOutsEqual(changeTxOut, txOut) {
- changeIndex = int32(idx)
- break
- }
- }
- }
-
- return changeIndex, nil
-}
-
-// DecorateInputs fetches the UTXO information of all inputs it can identify and
-// adds the required information to the package's inputs. The failOnUnknown
-// boolean controls whether the method should return an error if it cannot
-// identify an input or if it should just skip it.
-func (w *Wallet) DecorateInputs(packet *psbt.Packet, failOnUnknown bool) error {
- for idx := range packet.Inputs {
- txIn := packet.UnsignedTx.TxIn[idx]
-
- tx, utxo, derivationPath, _, err := w.FetchInputInfo(
- &txIn.PreviousOutPoint,
- )
-
- switch {
- // If the error just means it's not an input our wallet controls
- // and the user doesn't care about that, then we can just skip
- // this input and continue.
- case errors.Is(err, ErrNotMine) && !failOnUnknown:
- continue
-
- case err != nil:
- return fmt.Errorf("error fetching UTXO: %w", err)
- }
-
- addr, witnessProgram, _, err := w.ScriptForOutput(utxo)
- if err != nil {
- return fmt.Errorf("error fetching UTXO script: %w", err)
- }
-
- switch {
- case txscript.IsPayToTaproot(utxo.PkScript):
- addInputInfoSegWitV1(
- &packet.Inputs[idx], utxo, derivationPath,
- )
-
- default:
- addInputInfoSegWitV0(
- &packet.Inputs[idx], tx, utxo, derivationPath,
- addr, witnessProgram,
- )
- }
- }
-
- return nil
-}
-
-// addInputInfoSegWitV0 adds the UTXO and BIP32 derivation info for a SegWit v0
-// PSBT input (p2wkh, np2wkh) from the given wallet information.
-func addInputInfoSegWitV0(in *psbt.PInput, prevTx *wire.MsgTx, utxo *wire.TxOut,
- derivationInfo *psbt.Bip32Derivation, addr waddrmgr.ManagedAddress,
- witnessProgram []byte) {
-
- // As a fix for CVE-2020-14199 we have to always include the full
- // non-witness UTXO in the PSBT for segwit v0.
- in.NonWitnessUtxo = prevTx
-
- // To make it more obvious that this is actually a witness output being
- // spent, we also add the same information as the witness UTXO.
- in.WitnessUtxo = &wire.TxOut{
- Value: utxo.Value,
- PkScript: utxo.PkScript,
- }
- in.SighashType = txscript.SigHashAll
-
- // Include the derivation path for each input.
- in.Bip32Derivation = []*psbt.Bip32Derivation{
- derivationInfo,
- }
-
- // For nested P2WKH we need to add the redeem script to the input,
- // otherwise an offline wallet won't be able to sign for it. For normal
- // P2WKH this will be nil.
- if addr.AddrType() == waddrmgr.NestedWitnessPubKey {
- in.RedeemScript = witnessProgram
- }
-}
-
-// addInputInfoSegWitV0 adds the UTXO and BIP32 derivation info for a SegWit v1
-// PSBT input (p2tr) from the given wallet information.
-func addInputInfoSegWitV1(in *psbt.PInput, utxo *wire.TxOut,
- derivationInfo *psbt.Bip32Derivation) {
-
- // For SegWit v1 we only need the witness UTXO information.
- in.WitnessUtxo = &wire.TxOut{
- Value: utxo.Value,
- PkScript: utxo.PkScript,
- }
- in.SighashType = txscript.SigHashDefault
-
- // Include the derivation path for each input in addition to the
- // taproot specific info we have below.
- in.Bip32Derivation = []*psbt.Bip32Derivation{
- derivationInfo,
- }
-
- // Include the derivation path for each input.
- in.TaprootBip32Derivation = []*psbt.TaprootBip32Derivation{{
- XOnlyPubKey: derivationInfo.PubKey[1:],
- MasterKeyFingerprint: derivationInfo.MasterKeyFingerprint,
- Bip32Path: derivationInfo.Bip32Path,
- }}
-}
-
-// createOutputInfo creates the BIP32 derivation info for an output from our
-// internal wallet.
-func createOutputInfo(txOut *wire.TxOut,
- addr waddrmgr.ManagedPubKeyAddress) (*psbt.POutput, error) {
-
- // We don't know the derivation path for imported keys. Those shouldn't
- // be selected as change outputs in the first place, but just to make
- // sure we don't run into an issue, we return early for imported keys.
- keyScope, derivationPath, isKnown := addr.DerivationInfo()
- if !isKnown {
- return nil, fmt.Errorf("error adding output info to PSBT, " +
- "change addr is an imported addr with unknown " +
- "derivation path")
- }
-
- // Include the derivation path for this output.
- derivation := &psbt.Bip32Derivation{
- PubKey: addr.PubKey().SerializeCompressed(),
- MasterKeyFingerprint: derivationPath.MasterKeyFingerprint,
- Bip32Path: []uint32{
- keyScope.Purpose + hdkeychain.HardenedKeyStart,
- keyScope.Coin + hdkeychain.HardenedKeyStart,
- derivationPath.Account,
- derivationPath.Branch,
- derivationPath.Index,
- },
- }
- out := &psbt.POutput{
- Bip32Derivation: []*psbt.Bip32Derivation{
- derivation,
- },
- }
-
- // Include the Taproot derivation path as well if this is a P2TR output.
- if txscript.IsPayToTaproot(txOut.PkScript) {
- schnorrPubKey := derivation.PubKey[1:]
- out.TaprootBip32Derivation = []*psbt.TaprootBip32Derivation{{
- XOnlyPubKey: schnorrPubKey,
- MasterKeyFingerprint: derivation.MasterKeyFingerprint,
- Bip32Path: derivation.Bip32Path,
- }}
- out.TaprootInternalKey = schnorrPubKey
- }
-
- return out, nil
-}
-
-// FinalizePsbt expects a partial transaction with all inputs and outputs fully
-// declared and tries to sign all inputs that belong to the wallet. Our wallet
-// must be the last signer of the transaction. That means, if there are any
-// unsigned non-witness inputs or inputs without UTXO information attached or
-// inputs without witness data that do not belong to the wallet, this method
-// will fail. If no error is returned, the PSBT is ready to be extracted and the
-// final TX within to be broadcast.
-//
-// NOTE: This method does NOT publish the transaction after it's been finalized
-// successfully.
-func (w *Wallet) FinalizePsbt(keyScope *waddrmgr.KeyScope, account uint32,
- packet *psbt.Packet) error {
-
- // Let's check that this is actually something we can and want to sign.
- // We need at least one input and one output. In addition each
- // input needs nonWitness Utxo or witness Utxo data specified.
- err := psbt.InputsReadyToSign(packet)
- if err != nil {
- return err
- }
-
- // Go through each input that doesn't have final witness data attached
- // to it already and try to sign it. We do expect that we're the last
- // ones to sign. If there is any input without witness data that we
- // cannot sign because it's not our UTXO, this will be a hard failure.
- tx := packet.UnsignedTx
- sigHashes := txscript.NewTxSigHashes(tx, PsbtPrevOutputFetcher(packet))
- for idx, txIn := range tx.TxIn {
- in := packet.Inputs[idx]
-
- // We can only sign if we have UTXO information available. We
- // can just continue here as a later step will fail with a more
- // precise error message.
- if in.WitnessUtxo == nil && in.NonWitnessUtxo == nil {
- continue
- }
-
- // Skip this input if it's got final witness data attached.
- if len(in.FinalScriptWitness) > 0 {
- continue
- }
-
- // We can only sign this input if it's ours, so we try to map it
- // to a coin we own. If we can't, then we'll continue as it
- // isn't our input.
- fullTx, txOut, _, _, err := w.FetchInputInfo(
- &txIn.PreviousOutPoint,
- )
- if err != nil {
- continue
- }
-
- // Find out what UTXO we are signing. Wallets _should_ always
- // provide the full non-witness UTXO for segwit v0.
- var signOutput *wire.TxOut
- if in.NonWitnessUtxo != nil {
- prevIndex := txIn.PreviousOutPoint.Index
- signOutput = in.NonWitnessUtxo.TxOut[prevIndex]
-
- if !psbt.TxOutsEqual(txOut, signOutput) {
- return fmt.Errorf("found UTXO %#v but it "+
- "doesn't match PSBT's input %v", txOut,
- signOutput)
- }
-
- if fullTx.TxHash() != txIn.PreviousOutPoint.Hash {
- return fmt.Errorf("found UTXO tx %v but it "+
- "doesn't match PSBT's input %v",
- fullTx.TxHash(),
- txIn.PreviousOutPoint.Hash)
- }
- }
-
- // Fall back to witness UTXO only for older wallets.
- if in.WitnessUtxo != nil {
- signOutput = in.WitnessUtxo
-
- if !psbt.TxOutsEqual(txOut, signOutput) {
- return fmt.Errorf("found UTXO %#v but it "+
- "doesn't match PSBT's input %v", txOut,
- signOutput)
- }
- }
-
- // Finally, if the input doesn't belong to a watch-only account,
- // then we'll sign it as is, and populate the input with the
- // witness and sigScript (if needed).
- watchOnly := false
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- ns := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- if keyScope == nil {
- // If a key scope wasn't specified, then coin
- // selection was performed from the default
- // wallet accounts (NP2WKH, P2WKH, P2TR), so any
- // key scope provided doesn't impact the result
- // of this call.
- watchOnly, err = w.Manager.IsWatchOnlyAccount(
- ns, waddrmgr.KeyScopeBIP0084, account,
- )
- } else {
- watchOnly, err = w.Manager.IsWatchOnlyAccount(
- ns, *keyScope, account,
- )
- }
- return err
- })
- if err != nil {
- return fmt.Errorf("unable to determine if account is "+
- "watch-only: %w", err)
- }
- if watchOnly {
- continue
- }
-
- witness, sigScript, err := w.ComputeInputScript(
- tx, signOutput, idx, sigHashes, in.SighashType, nil,
- )
- if err != nil {
- return fmt.Errorf("error computing input script for "+
- "input %d: %w", idx, err)
- }
-
- // Serialize the witness format from the stack representation to
- // the wire representation.
- var witnessBytes bytes.Buffer
- err = psbt.WriteTxWitness(&witnessBytes, witness)
- if err != nil {
- return fmt.Errorf("error serializing witness: %w", err)
- }
- packet.Inputs[idx].FinalScriptWitness = witnessBytes.Bytes()
- packet.Inputs[idx].FinalScriptSig = sigScript
- }
-
- // Make sure the PSBT itself thinks it's finalized and ready to be
- // broadcast.
- err = psbt.MaybeFinalizeAll(packet)
- if err != nil {
- return fmt.Errorf("error finalizing PSBT: %w", err)
- }
-
- return nil
-}
-
-// PsbtPrevOutputFetcher returns a txscript.PrevOutFetcher built from the UTXO
-// information in a PSBT packet.
-func PsbtPrevOutputFetcher(packet *psbt.Packet) *txscript.MultiPrevOutFetcher {
- fetcher := txscript.NewMultiPrevOutFetcher(nil)
- for idx, txIn := range packet.UnsignedTx.TxIn {
- in := packet.Inputs[idx]
-
- // Skip any input that has no UTXO.
- if in.WitnessUtxo == nil && in.NonWitnessUtxo == nil {
- continue
- }
-
- if in.NonWitnessUtxo != nil {
- prevIndex := txIn.PreviousOutPoint.Index
- fetcher.AddPrevOut(
- txIn.PreviousOutPoint,
- in.NonWitnessUtxo.TxOut[prevIndex],
- )
-
- continue
- }
-
- // Fall back to witness UTXO only for older wallets.
- if in.WitnessUtxo != nil {
- fetcher.AddPrevOut(
- txIn.PreviousOutPoint, in.WitnessUtxo,
- )
- }
- }
-
- return fetcher
-}
diff --git a/wallet/psbt_manager.go b/wallet/psbt_manager.go
new file mode 100644
index 0000000000..64a9b94abb
--- /dev/null
+++ b/wallet/psbt_manager.go
@@ -0,0 +1,2359 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "bytes"
+ "context"
+ "crypto/sha256"
+ "errors"
+ "fmt"
+ "math"
+ "slices"
+
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/psbt/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/pkg/btcunit"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wallet/txauthor"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+var (
+ // ErrNilArguments is returned when a required argument is nil.
+ ErrNilArguments = errors.New("nil arguments")
+
+ // ErrUtxoLocked is returned when a UTXO is locked.
+ ErrUtxoLocked = errors.New("utxo is locked")
+
+ // ErrChangeAddressNotManagedPubKey is returned when a change address is
+ // not a managed public key address.
+ ErrChangeAddressNotManagedPubKey = errors.New(
+ "change address is not a managed pubkey address",
+ )
+
+ // ErrChangeIndexOutOfRange is returned when the change index is out of
+ // range.
+ ErrChangeIndexOutOfRange = errors.New("change index out of range")
+
+ // ErrPacketOutputsMissing is returned when a PSBT is provided for
+ // funding with no outputs.
+ ErrPacketOutputsMissing = errors.New("psbt packet has no outputs")
+
+ // ErrInputsAndPolicy is returned when a PSBT is provided with inputs,
+ // but a coin selection policy is also specified.
+ ErrInputsAndPolicy = errors.New(
+ "cannot specify both psbt inputs and a coin selection policy",
+ )
+
+ // ErrNoPsbtsToCombine is returned when no PSBTs are provided to
+ // combine.
+ ErrNoPsbtsToCombine = errors.New("no psbts to combine")
+
+ // ErrDifferentTransactions is returned when PSBTs do not refer to the
+ // same transaction.
+ ErrDifferentTransactions = errors.New(
+ "psbts do not refer to the same transaction",
+ )
+
+ // ErrInputCountMismatch is returned when PSBTs have different input
+ // counts.
+ ErrInputCountMismatch = errors.New("input count mismatch")
+
+ // ErrOutputCountMismatch is returned when PSBTs have different output
+ // counts.
+ ErrOutputCountMismatch = errors.New("output count mismatch")
+
+ // ErrUnknownAddressType is returned when an unknown address type is
+ // encountered.
+ ErrUnknownAddressType = errors.New("unknown address type")
+
+ // ErrUnknownBip32Purpose is returned when a BIP32 path has a purpose
+ // that is not supported by the wallet.
+ ErrUnknownBip32Purpose = errors.New("unknown BIP32 purpose")
+
+ // ErrInvalidBip32Path is returned when a BIP32 derivation path is
+ // invalid (e.g. wrong length, missing hardening, wrong coin type).
+ ErrInvalidBip32Path = errors.New("invalid BIP32 path")
+
+ // ErrUnsupportedTaprootLeafCount is returned when a Taproot derivation
+ // info contains an unsupported number of leaf hashes (e.g. > 1).
+ ErrUnsupportedTaprootLeafCount = errors.New("unsupported number of " +
+ "leaf hashes in Taproot derivation")
+
+ // ErrMissingTaprootLeafScript is returned when a Taproot derivation
+ // specifies a leaf hash but the corresponding Taproot leaf script is
+ // missing from the PSBT.
+ ErrMissingTaprootLeafScript = errors.New("specified leaf hash in " +
+ "taproot BIP0032 derivation but missing taproot leaf script")
+
+ // ErrTaprootLeafHashMismatch is returned when the calculated hash of
+ // the provided Taproot leaf script does not match the leaf hash
+ // specified in the derivation info.
+ ErrTaprootLeafHashMismatch = errors.New("specified leaf hash in " +
+ "taproot BIP0032 derivation but corresponding taproot leaf " +
+ "script was not found")
+
+ // ErrUnsupportedMultipleTaprootDerivation is returned when a Taproot
+ // input has multiple derivation paths, which is not supported.
+ ErrUnsupportedMultipleTaprootDerivation = errors.New(
+ "unsupported multiple taproot BIP0032 derivation info found",
+ )
+
+ // ErrUnsupportedMultipleBip32Derivation is returned when a BIP32
+ // input has multiple derivation paths, which is not supported.
+ ErrUnsupportedMultipleBip32Derivation = errors.New(
+ "unsupported multiple BIP0032 derivation info found",
+ )
+
+ // ErrAmbiguousDerivation is returned when an input has both Taproot and
+ // BIP32 derivation information, which is an ambiguous state.
+ ErrAmbiguousDerivation = errors.New(
+ "both Taproot and BIP32 derivation info found",
+ )
+
+ // ErrInvalidTaprootMerkleRootLength is returned when the Taproot
+ // Merkle Root has an invalid length.
+ ErrInvalidTaprootMerkleRootLength = errors.New(
+ "invalid taproot merkle root length",
+ )
+
+ // ErrPsbtMergeConflict is returned when merging PSBTs with conflicting
+ // fields (e.g. different sighash types, scripts, or signatures).
+ ErrPsbtMergeConflict = errors.New("psbt merge conflict")
+
+ // ErrImportedAddrNoDerivation is returned when trying to add output
+ // info for an imported address that has no derivation path.
+ ErrImportedAddrNoDerivation = errors.New("change addr is an " +
+ "imported addr with unknown derivation path")
+
+ // ErrIndexOutOfBounds is returned when an index is out of bounds.
+ ErrIndexOutOfBounds = errors.New("index out of bounds")
+
+ // ErrInputMissingUtxoInfo is returned when an input lacks both
+ // WitnessUtxo and NonWitnessUtxo.
+ ErrInputMissingUtxoInfo = errors.New("input missing both " +
+ "WitnessUtxo and NonWitnessUtxo")
+
+ // errAlreadySigned is returned when an input is already signed.
+ //
+ // NOTE: This error is private because it is used for internal control
+ // flow within the signing loop (to skip inputs) and should not be
+ // returned to the caller.
+ errAlreadySigned = errors.New("input already signed")
+
+ // errComputeRawSig is returned when the wallet cannot produce a
+ // signature for the input (e.g. key not found, signing error).
+ //
+ // NOTE: This error is private because it is used for internal control
+ // flow (skipping inputs that don't belong to this wallet) and should
+ // not be exposed to the caller.
+ errComputeRawSig = errors.New("cannot compute raw signature")
+)
+
+const (
+ // BIP32PathLength is the expected length of a BIP32 derivation path. A
+ // full path follows the structure:
+ // m / purpose' / coin_type' / account' / branch / index.
+ BIP32PathLength = 5
+)
+
+// FundIntent represents the user's intent for funding a PSBT. It serves as a
+// blueprint for the FundPsbt method, bundling all the parameters required to
+// construct a funded transaction into a single, coherent structure.
+type FundIntent struct {
+ // Packet is the PSBT to be funded. It must contain the outputs to be
+ // funded. If inputs are also specified, the wallet will detect this and
+ // enter a "completion" mode, where it only adds a change output if
+ // necessary, rather than performing full coin selection.
+ Packet *psbt.Packet
+
+ // Policy specifies the coin selection policy to use when funding the
+ // PSBT. This field is only used when the `Packet` has no inputs,
+ // indicating that automatic coin selection should be performed. If this
+ // policy is used, the `Source` (`*ScopedAccount`) must be fully
+ // specified with both `AccountName` and `KeyScope`, as the wallet will
+ // not perform any searches or guesses. If the `Packet` already
+ // contains inputs, this field is ignored.
+ Policy *InputsPolicy
+
+ // FeeRate specifies the desired fee rate for the transaction, expressed
+ // in satoshis per kilo-virtual-byte (sat/kvb). This field is always
+ // required, regardless of whether coin selection is performed.
+ FeeRate btcunit.SatPerKVByte
+
+ // ChangeSource specifies the account and key scope to use for the
+ // change output. If this field is nil, the wallet will use a default
+ // change source based on the account and scope of the inputs.
+ ChangeSource *ScopedAccount
+
+ // Label is an optional, human-readable label for the transaction. This
+ // can be used to associate a memo with the transaction for later
+ // reference.
+ Label string
+}
+
+// SignPsbtParams encapsulates the arguments for signing a PSBT.
+type SignPsbtParams struct {
+ // Packet is the PSBT to be signed.
+ Packet *psbt.Packet
+
+ // InputTweakers is a map of input indices to a private key tweaker.
+ // This allows the caller to define a specific tweaker for each input
+ // index.
+ //
+ // NOTE: The ideal implementation would be to add a new field to
+ // psbt.PInput that holds the tweaker, but that would require a change
+ // to the core psbt package in btcd. To keep btcwallet generic and avoid
+ // that dependency, we allow the caller (e.g. lnd) to inspect the PSBT
+ // beforehand, determine the necessary tweaks (e.g. based on custom
+ // fields like PsbtKeyTypeInputSignatureTweakSingle), and pass them in
+ // via this map.
+ InputTweakers map[int]PrivKeyTweaker
+}
+
+// SignPsbtResult encapsulates the result of a PSBT signing operation.
+type SignPsbtResult struct {
+ // SignedInputs contains the indices of the inputs that were
+ // successfully signed.
+ SignedInputs []uint32
+
+ // Packet is the modified PSBT packet. This is the same pointer as
+ // passed in the params, returned for convenience.
+ Packet *psbt.Packet
+}
+
+// PsbtManager provides a cohesive, high-level interface for creating and
+// managing Partially Signed Bitcoin Transactions (PSBTs). It encapsulates the
+// entire workflow, from funding and decorating to signing and finalization,
+// allowing users to construct complex transactions in a safe and predictable
+// manner.
+//
+// The typical workflow for a single-signer transaction is as follows:
+//
+// 1. Create a bare PSBT:
+// A stateless helper function, CreatePsbt, is used to construct a PSBT
+// packet from a list of desired inputs and outputs.
+//
+// // The user specifies their desired outputs.
+// outputs := []*wire.TxOut{{Value: 100000, PkScript: carolPkScript}}
+//
+// // A bare PSBT is created, representing the transaction template.
+// barePacket, err := wallet.CreatePsbt(nil, outputs)
+//
+// 2. Fund the PSBT:
+// The FundPsbt method is called to perform coin selection. The wallet selects
+// UTXOs to cover the output value and fee, adds them as inputs, and adds a
+// change output if necessary.
+//
+// fundIntent := &wallet.FundIntent{
+// Packet: barePacket,
+// Policy: &wallet.InputsPolicy{
+// Source: &wallet.ScopedAccount{
+// AccountName: "default",
+// KeyScope: waddrmgr.KeyScopeBIP0086,
+// },
+// MinConfs: 1,
+// },
+// FeeRate: btcunit.NewSatPerKVByte(250),
+// }
+// fundedPacket, changeIndex, err := psbtManager.FundPsbt(
+// ctx, fundIntent,
+// )
+//
+// The `fundedPacket` now contains the necessary inputs (fully decorated)
+// and a change output. The `changeIndex` indicates the index of the
+// change output in the `fundedPacket.UnsignedTx.TxOut` slice, or -1 if
+// no change output was added.
+//
+// 3. Sign the PSBT:
+// The wallet signs all inputs it has the keys for.
+//
+// signParams := &wallet.SignPsbtParams{Packet: barePacket}
+// result, err := psbtManager.SignPsbt(ctx, signParams)
+//
+// 4. Finalize the PSBT:
+// The final scriptSig and/or witness for each input is constructed.
+//
+// err = psbtManager.FinalizePsbt(ctx, barePacket)
+//
+// 5. Extract and Broadcast:
+// The final, network-ready transaction is extracted and broadcast.
+//
+// finalTx, err := psbt.Extract(barePacket)
+// err = broadcaster.Broadcast(ctx, finalTx, "payment")
+//
+// For more detailed examples, including multi-party collaborative workflows,
+// see the documentation in the `wallet/docs/psbt_workflows.md` file.
+type PsbtManager interface {
+ // DecorateInputs enriches a PSBT's inputs with UTXO and derivation
+ // information known to the wallet.
+ //
+ // This is useful when importing a PSBT created externally (e.g., by a
+ // coordinator or another wallet) that only contains references to
+ // inputs (txids/indices) but lacks the necessary witness data and key
+ // derivation paths required for signing.
+ //
+ // If `skipUnknown` is true, the wallet will skip inputs it does not
+ // recognize; otherwise, it will return an error if any input is not
+ // found in the wallet's transaction store.
+ DecorateInputs(ctx context.Context, packet *psbt.Packet,
+ skipUnknown bool) (*psbt.Packet, error)
+
+ // FundPsbt performs coin selection and adds the selected inputs (and a
+ // change output, if necessary) to the PSBT.
+ //
+ // It inspects the provided `FundIntent` to determine whether to
+ // perform automatic coin selection (if no inputs are present) or to
+ // validate and fund a specific set of manual inputs.
+ //
+ // The returned PSBT is a fully funded transaction template, ready for
+ // signing. The change output index is also returned (-1 if no change
+ // was added).
+ FundPsbt(ctx context.Context, intent *FundIntent) (*psbt.Packet,
+ int32, error)
+
+ // SignPsbt adds partial signatures to the PSBT for all inputs
+ // controlled by the wallet.
+ //
+ // It iterates through the inputs, identifying those for which the
+ // wallet possesses the private key (based on derivation information),
+ // and appends a valid signature to the partial signature field.
+ //
+ // Note: This method is non-destructive; it adds signatures without
+ // finalizing the inputs, allowing for further signing in multi-party
+ // scenarios. It enforces a strict policy of one signature per input
+ // per call to avoid ambiguity in complex derivation paths.
+ SignPsbt(ctx context.Context, params *SignPsbtParams) (
+ *SignPsbtResult, error)
+
+ // FinalizePsbt attempts to finalize the PSBT, transitioning it from a
+ // partially signed state to a complete, network-ready transaction.
+ //
+ // It validates that all inputs have sufficient signatures to satisfy
+ // their spending scripts. If valid, it constructs the final
+ // `scriptSig` and `witness` fields and removes the partial signature
+ // data.
+ //
+ // Note: This implementation is "smart": if it detects an input owned
+ // by the wallet that is not yet signed, it will attempt to sign it
+ // internally before finalization.
+ FinalizePsbt(ctx context.Context, packet *psbt.Packet) error
+
+ // CombinePsbt acts as the "Combiner" role in BIP 174, merging multiple
+ // Partially Signed Bitcoin Transactions (PSBTs) into a single packet.
+ //
+ // This is distinct from FinalizePsbt: CombinePsbt aggregates partial
+ // signatures and metadata from different signers (who signed copies of
+ // the same transaction in parallel), whereas FinalizePsbt uses those
+ // aggregated signatures to construct the final valid network
+ // transaction.
+ //
+ // This method is essential for collaborative workflows (e.g. Multisig,
+ // CoinJoin) where no single party holds all necessary keys.
+ CombinePsbt(ctx context.Context, psbts ...*psbt.Packet) (
+ *psbt.Packet, error)
+}
+
+// DecorateInputs enriches a PSBT's inputs with UTXO and derivation information.
+//
+// It iterates through all inputs in the PSBT and:
+// 1. Validates ownership: Calls `fetchAndValidateUtxo` to check if the input
+// references a UTXO owned by the wallet.
+// 2. Enriches: If owned, calls `decorateInput` to add the full previous
+// transaction (`NonWitnessUtxo`) or output (`WitnessUtxo`), along with
+// BIP32 derivation paths (`Bip32Derivation` or `TaprootBip32Derivation`)
+// and script information.
+func (w *Wallet) DecorateInputs(ctx context.Context, packet *psbt.Packet,
+ skipUnknown bool) (*psbt.Packet, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // We'll iterate through all the inputs of the PSBT and decorate them
+ // if they are owned by the wallet. The `skipUnknown` parameter
+ // determines whether an error is returned if an input is not owned
+ // by the wallet.
+ for i, txIn := range packet.UnsignedTx.TxIn {
+ // Attempt to fetch the transaction details for the current
+ // input from our transaction store and validate that we own
+ // the UTXO. The `fetchAndValidateUtxo` function will return an
+ // `ErrNotMine` error if the UTXO is not found or not owned by
+ // the wallet.
+ tx, utxo, err := w.fetchAndValidateUtxo(txIn)
+ if err != nil {
+ // If the error is `ErrNotMine` and `skipUnknown` is
+ // true, we'll simply continue to the next input, as we
+ // don't own it and are not required to fail.
+ if errors.Is(err, ErrNotMine) && skipUnknown {
+ continue
+ }
+
+ // Otherwise, we'll return the error. This includes the
+ // case where the UTXO is locked.
+ return nil, err
+ }
+
+ // If we own the UTXO, we'll proceed to decorate the
+ // corresponding PSBT input with detailed information from the
+ // wallet.
+ err = w.decorateInput(ctx, &packet.Inputs[i], tx, utxo)
+ if err != nil {
+ return nil, fmt.Errorf("error decorating input %d: %w",
+ i, err)
+ }
+ }
+
+ return packet, nil
+}
+
+// decorateInput is a helper function that decorates a single PSBT input with
+// UTXO information from the wallet.
+//
+// NOTE: The `pInput` parameter is modified in-place by this function.
+func (w *Wallet) decorateInput(ctx context.Context, pInput *psbt.PInput,
+ tx *wire.MsgTx, utxo *wire.TxOut) error {
+
+ // We'll start by extracting the address from the UTXO's pkScript.
+ // This will be used to look up the managed address from the
+ // database.
+ addr := extractAddrFromPKScript(utxo.PkScript, w.cfg.ChainParams)
+ if addr == nil {
+ return fmt.Errorf("%w: from pkscript %x",
+ ErrUnableToExtractAddress, utxo.PkScript)
+ }
+
+ // We'll then use the address to look up the managed address from the
+ // database. This will give us access to the derivation information.
+ managedAddr, err := w.AddressInfo(ctx, addr)
+ if err != nil {
+ return fmt.Errorf("unable to get address info for %s: %w",
+ addr.String(), err)
+ }
+
+ // We'll ensure that the managed address is a public key address, as
+ // we can only decorate inputs for which we have the private key.
+ pubKeyAddr, ok := managedAddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return fmt.Errorf("%w: addr %s", ErrNotPubKeyAddress,
+ managedAddr.Address())
+ }
+
+ // With the managed address, we can now get the derivation information
+ // for the address.
+ derivation, err := derivationForManagedAddress(pubKeyAddr)
+ if err != nil {
+ return err
+ }
+
+ // With all the information gathered, we'll now populate the PSBT
+ // input based on its address type by calling the existing, non-
+ // deprecated helper functions.
+ switch {
+ // For SegWit v1 (Taproot) inputs, we'll use the SegWit v1 helper.
+ case txscript.IsPayToTaproot(utxo.PkScript):
+ addInputInfoSegWitV1(pInput, utxo, derivation)
+
+ // For SegWit v0 inputs, we'll use the SegWit v0 helper.
+ default:
+ // We'll need to build the redeem script for the input.
+ _, redeemScript, err := buildScriptsForManagedAddress(
+ pubKeyAddr, utxo.PkScript, w.cfg.ChainParams,
+ )
+ if err != nil {
+ return err
+ }
+
+ // With the redeem script, we can now populate the PSBT
+ // input.
+ addInputInfoSegWitV0(
+ pInput, tx, utxo, derivation, managedAddr, redeemScript,
+ )
+ }
+
+ return nil
+}
+
+// fetchAndValidateUtxo fetches the transaction details for a given input,
+// validates that the wallet owns the UTXO, and ensures it is not locked.
+//
+// This function serves as a crucial pre-check before decorating a PSBT input.
+// It performs three key validation steps:
+// 1. Transaction Lookup: It first attempts to fetch the full transaction
+// details from the wallet's transaction store using the input's previous
+// outpoint. If the transaction is not found, it returns an `ErrNotMine`
+// error.
+// 2. Ownership Verification: If the transaction is found, it verifies that the
+// specific output index is a credit to the wallet. This ensures that the
+// wallet actually owns the UTXO. If this check fails, it also returns
+// `ErrNotMine`.
+// 3. Lock Status Check: After confirming ownership, it checks if the UTXO has
+// been locked. If the UTXO is locked, it returns an `ErrUtxoLocked`
+// error.
+//
+// Only if all these checks pass, the function returns the full parent
+// transaction (`*wire.MsgTx`) and the specific unspent transaction output
+// (`*wire.TxOut`).
+func (w *Wallet) fetchAndValidateUtxo(txIn *wire.TxIn) (
+ *wire.MsgTx, *wire.TxOut, error) {
+
+ // First, we'll attempt to fetch the transaction details from our
+ // transaction store.
+ txDetail, err := w.fetchTxDetails(&txIn.PreviousOutPoint.Hash)
+ if errors.Is(err, ErrTxNotFound) {
+ return nil, nil, fmt.Errorf("%w: %v", ErrNotMine,
+ txIn.PreviousOutPoint)
+ }
+
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to fetch tx details: %w",
+ err)
+ }
+
+ // With the transaction details retrieved, we'll make an additional
+ // check to ensure we actually have control of this output.
+ cred := findCredit(txDetail, txIn.PreviousOutPoint.Index)
+ if cred == nil {
+ return nil, nil, fmt.Errorf("%w: %v", ErrNotMine,
+ txIn.PreviousOutPoint)
+ }
+
+ // Now that we've confirmed we know about the UTXO, we'll check if it
+ // is locked.
+ if cred.Locked {
+ return nil, nil, fmt.Errorf("%w: %v", ErrUtxoLocked,
+ txIn.PreviousOutPoint)
+ }
+
+ // Now that we've confirmed we know about the UTXO, we'll proceed to
+ // gather the rest of the information required to decorate the PSBT
+ // input.
+ tx := &txDetail.MsgTx
+ utxo := tx.TxOut[txIn.PreviousOutPoint.Index]
+
+ return tx, utxo, nil
+}
+
+// findCredit determines whether a transaction's details contain a credit for a
+// specific output index.
+func findCredit(txDetail *wtxmgr.TxDetails,
+ outputIndex uint32) *wtxmgr.CreditRecord {
+
+ for i := range txDetail.Credits {
+ if txDetail.Credits[i].Index == outputIndex {
+ return &txDetail.Credits[i]
+ }
+ }
+
+ return nil
+}
+
+// FundPsbt performs coin selection and funds the PSBT.
+//
+// It executes the funding logic by:
+// 1. Validation: Checking the `FundIntent` for consistency.
+// 2. Creation: Converting the intent into a `TxIntent` and delegating to the
+// `CreateTransaction` method (which handles the underlying coin selection
+// and change calculation algorithms).
+// 3. Population: Calling `populatePsbtPacket` to apply the selected inputs and
+// change output to the PSBT structure and sort it according to BIP 69.
+func (w *Wallet) FundPsbt(ctx context.Context, intent *FundIntent) (
+ *psbt.Packet, int32, error) {
+
+ err := w.state.validateSynced()
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Validate the funding intent before proceeding.
+ err = w.validateFundIntent(intent)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Create a TxIntent from the FundIntent.
+ txIntent := w.createTxIntent(intent)
+
+ // Create the transaction.
+ authoredTx, err := w.CreateTransaction(ctx, txIntent)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Populate the PSBT packet with the new transaction details.
+ packet, changeIndex, err := w.populatePsbtPacket(
+ ctx, intent.Packet, authoredTx,
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+
+ return packet, changeIndex, nil
+}
+
+// populatePsbtPacket updates the PSBT packet with the new transaction details,
+// decorates the inputs, and handles the change output. It returns the modified
+// packet and the index of the change output, or -1 if no change output was
+// added.
+func (w *Wallet) populatePsbtPacket(ctx context.Context, packet *psbt.Packet,
+ authoredTx *txauthor.AuthoredTx) (*psbt.Packet, int32, error) {
+
+ // The authored transaction contains the selected inputs and the change
+ // output (if any). We'll update the PSBT packet with this new
+ // unsigned transaction.
+ packet.UnsignedTx = authoredTx.Tx
+
+ // We'll also re-initialize the input and output slices to match the
+ // dimensions of the new transaction. This is crucial because the
+ // `authoredTx` may have a different output order than the original PSBT
+ // (e.g., due to change output randomization in txauthor.AuthoredTx),
+ // which would otherwise cause a misalignment between the wire outputs
+ // and the PSBT's output metadata. By resetting, we ensure consistency.
+ packet.Inputs = make([]psbt.PInput, len(authoredTx.Tx.TxIn))
+ packet.Outputs = make([]psbt.POutput, len(authoredTx.Tx.TxOut))
+
+ // With the new inputs in place, we'll decorate them with UTXO and
+ // derivation information from the wallet. We set `skipUnknown` to
+ // false because all inputs in the `authoredTx` must be known to the
+ // wallet.
+ _, err := w.DecorateInputs(ctx, packet, false)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // If a change output was created, we need to add its derivation
+ // information to the corresponding PSBT output.
+ var changeOutput *wire.TxOut
+ if authoredTx.ChangeIndex >= 0 {
+ err := w.addChangeOutputInfo(ctx, packet, authoredTx)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ changeOutput = authoredTx.Tx.TxOut[authoredTx.ChangeIndex]
+ }
+
+ // The PSBT specification recommends that inputs and outputs are
+ // sorted. This is done for privacy and standardization. We'll sort
+ // the packet in place.
+ err = psbt.InPlaceSort(packet)
+ if err != nil {
+ return nil, 0, fmt.Errorf("cannot sort psbt: %w", err)
+ }
+
+ // After sorting, the original change index from `authoredTx` is no
+ // longer valid. We need to find the new index of the change output in
+ // the sorted list.
+ changeIndex, err := findChangeIndex(changeOutput, packet)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ return packet, changeIndex, nil
+}
+
+// addChangeOutputInfo is a helper function that adds the derivation information
+// for a change output to a PSBT packet.
+func (w *Wallet) addChangeOutputInfo(ctx context.Context, packet *psbt.Packet,
+ authoredTx *txauthor.AuthoredTx) error {
+
+ // TODO(yy): The calls to `w.ScriptForOutput` and `w.AddressInfo` both
+ // involve database lookups. This could be optimized to a single
+ // database call to fetch all necessary address information. However,
+ // for now, this approach favors readability over micro-optimization,
+ // as this path is not performance-critical.
+ //
+ // First, we'll get the script information for the change output.
+ changeScriptInfo, err := w.ScriptForOutput(
+ ctx, *authoredTx.Tx.TxOut[authoredTx.ChangeIndex],
+ )
+ if err != nil {
+ return err
+ }
+
+ // Then, we'll get the managed address for the change output.
+ changeAddr, err := w.AddressInfo(ctx, changeScriptInfo.Addr.Address())
+ if err != nil {
+ return err
+ }
+
+ // We'll ensure that the change address is a public key address.
+ managedPubKeyAddr, ok := changeAddr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return ErrChangeAddressNotManagedPubKey
+ }
+
+ // With the managed address, we can now create the PSBT output
+ // information.
+ changeOutputInfo, err := createOutputInfo(
+ authoredTx.Tx.TxOut[authoredTx.ChangeIndex],
+ managedPubKeyAddr,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Finally, we'll add the change output information to the PSBT packet.
+ packet.Outputs[authoredTx.ChangeIndex] = *changeOutputInfo
+
+ return nil
+}
+
+// validateFundIntent performs a series of checks on a FundIntent to ensure it
+// is well-formed and unambiguous. This function is called before any funding
+// logic to ensure that the caller has provided a valid intent.
+//
+// The following checks are performed:
+// 1. The intent must not be nil.
+// 2. The PSBT packet must not be nil.
+// 3. If the PSBT has no inputs (automatic coin selection mode), it must have
+// at least one output.
+// 4. If the PSBT has inputs, a coin selection policy must not be specified
+// (mutual exclusivity).
+func (w *Wallet) validateFundIntent(intent *FundIntent) error {
+ // The intent must not be nil.
+ if intent == nil {
+ return ErrNilArguments
+ }
+
+ // The PSBT packet must not be nil.
+ if intent.Packet == nil {
+ return fmt.Errorf(
+ "%w: psbt packet cannot be nil", ErrNilTxIntent,
+ )
+ }
+
+ // If the PSBT has no inputs (automatic coin selection mode), it must
+ // have at least one output.
+ if len(intent.Packet.UnsignedTx.TxIn) == 0 &&
+ len(intent.Packet.UnsignedTx.TxOut) == 0 {
+
+ return ErrPacketOutputsMissing
+ }
+
+ // If the PSBT has inputs, a coin selection policy must not be
+ // specified (mutual exclusivity).
+ if len(intent.Packet.UnsignedTx.TxIn) > 0 && intent.Policy != nil {
+ return ErrInputsAndPolicy
+ }
+
+ return nil
+}
+
+// findChangeIndex finds the new index of the change output after the PSBT has
+// been sorted.
+func findChangeIndex(changeOutput *wire.TxOut,
+ packet *psbt.Packet) (int32, error) {
+
+ if changeOutput == nil {
+ return -1, nil
+ }
+
+ for i, txOut := range packet.UnsignedTx.TxOut {
+ if i > math.MaxInt32 {
+ return 0, ErrChangeIndexOutOfRange
+ }
+
+ if psbt.TxOutsEqual(changeOutput, txOut) {
+ // The above check ensures that the conversion to int32
+ // is safe.
+ //
+ //nolint:gosec
+ return int32(i), nil
+ }
+ }
+
+ return -1, nil
+}
+
+// createTxIntent creates a TxIntent from a FundIntent. This helper function
+// acts as a pure adapter, translating the high-level funding request into a
+// concrete transaction creation plan for the wallet's underlying `TxCreator`.
+//
+// It does not perform any database lookups or validation. Instead, it relies
+// on the API contract that the caller must provide a fully specified
+// `InputsPolicy` (with both `AccountName` and `KeyScope`) if automatic coin
+// selection is desired. The underlying `TxCreator` is responsible for
+// validating the existence of the specified account.
+//
+// The function is responsible for two main pieces of logic:
+// 1. Input Source Determination: It inspects the incoming PSBT. If it has no
+// inputs, it uses the `InputsPolicy` from the intent. If it has inputs,
+// it creates an `InputsManual` source.
+// 2. Change Source Mapping: It directly maps the `FundIntent.ChangeSource`
+// to `TxIntent.ChangeSource`. Any default change source determination
+// (e.g., when `FundIntent.ChangeSource` is `nil`) is delegated to the
+// underlying `TxCreator`'s `determineChangeSource` method.
+func (w *Wallet) createTxIntent(intent *FundIntent) *TxIntent {
+ // First, we'll copy the outputs from the PSBT packet to the TxIntent.
+ outputs := make([]wire.TxOut, len(intent.Packet.UnsignedTx.TxOut))
+ for i, txOut := range intent.Packet.UnsignedTx.TxOut {
+ outputs[i] = *txOut
+ }
+
+ // The fee rate and label are passed through directly.
+ txIntent := &TxIntent{
+ Outputs: outputs,
+ FeeRate: intent.FeeRate,
+ Label: intent.Label,
+ }
+
+ // Now, we'll determine the input source based on whether the PSBT
+ // packet already contains inputs.
+ if len(intent.Packet.UnsignedTx.TxIn) == 0 {
+ // If the packet has no inputs, we'll use the policy-based input
+ // source from the intent. This will trigger automatic coin
+ // selection by the wallet. The caller is responsible for
+ // providing a complete `ScopedAccount` with both `AccountName`
+ // and `KeyScope`.
+ txIntent.Inputs = intent.Policy
+ } else {
+ // If the packet already has inputs, we'll use a manual input
+ // source. This bypasses coin selection and tells the wallet to
+ // use the exact inputs provided in the PSBT.
+ utxos := make(
+ []wire.OutPoint, len(intent.Packet.UnsignedTx.TxIn),
+ )
+ for i, txIn := range intent.Packet.UnsignedTx.TxIn {
+ utxos[i] = txIn.PreviousOutPoint
+ }
+
+ txIntent.Inputs = &InputsManual{
+ UTXOs: utxos,
+ }
+ }
+
+ // The change source is directly mapped from the FundIntent. If it is
+ // nil, the underlying `TxCreator` will determine a default.
+ txIntent.ChangeSource = intent.ChangeSource
+
+ return txIntent
+}
+
+// SignPsbt adds partial signatures to the PSBT.
+//
+// It achieves this by:
+// 1. Pre-computation: Creating a `PsbtPrevOutputFetcher` and calculating the
+// transaction sighashes once for efficiency.
+// 2. Iteration: Processing each input to determine if it is owned by the
+// wallet and ready for signing.
+// 3. Derivation Validation: Enforcing strict rules on derivation paths (one
+// path per input) to ensure deterministic key selection.
+// 4. Signing: dispatching to `signTaprootPsbtInput` or `signBip32PsbtInput` to
+// generate the raw ECDSA or Schnorr signature using the underlying
+// `Signer`.
+func (w *Wallet) SignPsbt(ctx context.Context, params *SignPsbtParams) (
+ *SignPsbtResult, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ if params == nil {
+ return nil, ErrNilArguments
+ }
+
+ packet := params.Packet
+
+ // signedInputs will track the indices of all inputs that we
+ // successfully sign during this operation. This is useful for callers
+ // (e.g., LND) to know which inputs were partially signed by this
+ // wallet.
+ var signedInputs = make([]uint32, 0, len(packet.Inputs))
+
+ // Before proceeding, we ensure that the PSBT inputs are in a state
+ // that allows them to be signed. This check verifies that each input
+ // has at least a WitnessUtxo or NonWitnessUtxo, which is crucial for
+ // signature generation. If this check fails, it indicates a malformed
+ // or incomplete PSBT that cannot be signed.
+ err = psbt.InputsReadyToSign(packet)
+ if err != nil {
+ return nil, fmt.Errorf("psbt inputs not ready: %w", err)
+ }
+
+ // We create a `PrevOutputFetcher` to allow `txscript` to retrieve the
+ // previous transaction outputs needed for sighash generation. This is
+ // a critical component as the value and script of the UTXO being spent
+ // are part of the data signed. Following this, we compute the
+ // transaction's sighashes, which are integral to producing valid
+ // signatures for each input.
+ prevOutFetcher, err := PsbtPrevOutputFetcher(packet)
+ if err != nil {
+ return nil, fmt.Errorf("error creating prevOutFetcher: %w", err)
+ }
+
+ sigHashes := txscript.NewTxSigHashes(
+ packet.UnsignedTx, prevOutFetcher,
+ )
+
+ // Iterate through each input in the PSBT. For each input, we attempt
+ // to sign it if the wallet can provide the necessary key material and
+ // if the input itself is in a signable state. This loop handles both
+ // Taproot (SegWit v1) and legacy/SegWit v0 inputs, adapting the
+ // signing process accordingly.
+ for i := range packet.Inputs {
+ signed, err := w.signPsbtInput(
+ ctx, packet, i, sigHashes, params.InputTweakers,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("input %d: %w", i, err)
+ }
+
+ if signed {
+ // If signing was successful, we record the index of
+ // this input as one that the wallet has contributed a
+ // signature to.
+ //
+ // We convert the index i (int) to uint32. This is safe
+ // because a Bitcoin transaction is strictly bounded by
+ // the block size limit.
+ //nolint:gosec
+ signedInputs = append(signedInputs, uint32(i))
+ }
+ }
+
+ // Finally, return the result, which includes the list of inputs that
+ // were successfully signed and the modified (partially) signed PSBT
+ // packet.
+ return &SignPsbtResult{
+ SignedInputs: signedInputs,
+ Packet: packet,
+ }, nil
+}
+
+// signPsbtInput attempts to sign a single input of the PSBT. It returns true
+// if the input was successfully signed, false if it was skipped (e.g. already
+// signed or not owned), and an error if a fatal signing error occurred.
+func (w *Wallet) signPsbtInput(ctx context.Context, packet *psbt.Packet,
+ i int, sigHashes *txscript.TxSigHashes,
+ tweakers map[int]PrivKeyTweaker) (bool, error) {
+
+ pInput := &packet.Inputs[i]
+
+ // First, we check if the current input should be skipped. This
+ // helper function identifies inputs that are already finalized
+ // or lack any derivation information (meaning we don't own the
+ // key or it's not intended for us to sign). Skipping these
+ // allows the wallet to focus on relevant inputs and gracefully
+ // handle multi-signer PSBTs.
+ if shouldSkipInput(pInput, i) {
+ return false, nil
+ }
+
+ // Validate the derivation information to ensure we have an
+ // unambiguous signing path. Our policy enforces that an input
+ // should not contain conflicting Taproot and BIP32 derivation
+ // paths, nor multiple paths of the same type. This prevents
+ // misinterpretations and ensures deterministic signing.
+ isTaproot, err := validateDerivation(pInput, i)
+ if err != nil {
+ return false, err
+ }
+
+ // Based on the validated derivation information, we dispatch
+ // the signing task to the appropriate helper function. If the
+ // input is identified as Taproot, we use
+ // `signTaprootPsbtInput`; otherwise, we assume it's a legacy
+ // or SegWit v0 input and use `signBip32PsbtInput`.
+ if isTaproot {
+ err = w.signTaprootPsbtInput(
+ ctx, packet, i, sigHashes, tweakers[i],
+ )
+ } else {
+ err = w.signBip32PsbtInput(
+ ctx, packet, i, sigHashes, tweakers[i],
+ )
+ }
+
+ // If an error occurred during signing, we first check if it's
+ // an error that permits us to skip the current input (e.g., if
+ // the key is not found, implying it's another signer's input
+ // in a collaborative PSBT). If the error is *not* skippable,
+ // it indicates a critical issue, and we return it immediately.
+ // Otherwise, we continue to the next input.
+ if err != nil {
+ if shouldSkipSigningError(err, i) {
+ return false, nil
+ }
+
+ return false, err
+ }
+
+ return true, nil
+}
+
+// parseBip32Path parses a raw derivation path (sequence of uint32s) and
+// verifies that it conforms to the BIP44-like hierarchy structure
+// (m / purpose' / coin_type' / account' / branch / index) used by this wallet.
+//
+// It enforces the following wallet-specific constraints (based on BIP44/49/84
+// conventions):
+// 1. Path length must be exactly 5.
+// 2. First 3 elements must be hardened.
+// 3. Coin type must match the wallet's chain parameters.
+//
+// NOTE: While the underlying cryptographic derivation is defined by BIP32, the
+// specific requirement for a 5-level path with hardened prefixes is strictly a
+// convention of the BIP44/49/84/86 standards, not a constraint of BIP32
+// itself.
+//
+// Returns `ErrInvalidBip32Path` if the path is invalid.
+func (w *Wallet) parseBip32Path(path []uint32) (BIP32Path, error) {
+ // The BIP32 path must have exactly 5 elements:
+ // m / purpose' / coin_type' / account' / branch / index
+ if len(path) != BIP32PathLength {
+ return BIP32Path{}, fmt.Errorf("%w: length %d",
+ ErrInvalidBip32Path, len(path))
+ }
+
+ // The first 3 elements (Purpose, CoinType, Account) must be hardened.
+ // We check this by verifying they are >= HardenedKeyStart.
+ for i := range 3 {
+ if path[i] < hdkeychain.HardenedKeyStart {
+ return BIP32Path{}, fmt.Errorf("%w: element %d not "+
+ "hardened", ErrInvalidBip32Path, i)
+ }
+ }
+
+ // Helper to extract values (remove hardened flag).
+ purpose := path[0] - hdkeychain.HardenedKeyStart
+ coinType := path[1] - hdkeychain.HardenedKeyStart
+ account := path[2] - hdkeychain.HardenedKeyStart
+ branch := path[3]
+ index := path[4]
+
+ // Verify that the coin type matches the wallet's chain parameters.
+ if coinType != w.cfg.ChainParams.HDCoinType {
+ return BIP32Path{}, fmt.Errorf("%w: expected coin type %d, "+
+ "got %d", ErrInvalidBip32Path,
+ w.cfg.ChainParams.HDCoinType, coinType)
+ }
+
+ scope := waddrmgr.KeyScope{
+ Purpose: purpose,
+ Coin: coinType,
+ }
+
+ bip32Path := BIP32Path{
+ KeyScope: scope,
+ DerivationPath: waddrmgr.DerivationPath{
+ Account: account,
+ Branch: branch,
+ Index: index,
+ },
+ }
+
+ return bip32Path, nil
+}
+
+// addressTypeFromPurpose maps a BIP purpose to a wallet address type.
+func addressTypeFromPurpose(purpose uint32) (waddrmgr.AddressType, error) {
+ // TODO(yy): Currently, we hardcode the supported BIP purposes.
+ // A more robust solution would dynamically query the `waddrmgr` to
+ // determine supported key scopes configured in the database, allowing
+ // for custom purposes (e.g., LND's 1017 purpose key) to be seamlessly
+ // supported without code changes here.
+ switch purpose {
+ case waddrmgr.KeyScopeBIP0044.Purpose:
+ return waddrmgr.PubKeyHash, nil
+
+ case waddrmgr.KeyScopeBIP0049Plus.Purpose:
+ return waddrmgr.NestedWitnessPubKey, nil
+
+ case waddrmgr.KeyScopeBIP0084.Purpose:
+ return waddrmgr.WitnessPubKey, nil
+
+ case waddrmgr.KeyScopeBIP0086.Purpose:
+ return waddrmgr.TaprootPubKey, nil
+
+ default:
+ return 0, fmt.Errorf("%w: %d", ErrUnknownBip32Purpose, purpose)
+ }
+}
+
+// shouldSkipInput determines whether the input at the given index should be
+// skipped during the signing process.
+//
+// It checks for two conditions:
+// 1. If the input already has a final script witness, it is considered
+// complete and is skipped.
+// 2. If the input lacks any derivation information (both Taproot and BIP32),
+// it implies that the wallet does not have the key to sign it, so it is
+// skipped.
+func shouldSkipInput(pInput *psbt.PInput, idx int) bool {
+ // Skip if already finalized.
+ if len(pInput.FinalScriptWitness) > 0 {
+ log.Debugf("Skipping input %d: already has final "+
+ "script witness", idx)
+
+ return true
+ }
+
+ // Check if we have any derivation info.
+ tapCount := len(pInput.TaprootBip32Derivation)
+ bip32Count := len(pInput.Bip32Derivation)
+
+ if tapCount == 0 && bip32Count == 0 {
+ // No derivation info, so we can't sign this input. We skip it
+ // silently, assuming it's not ours or not meant to be signed
+ // by us.
+ log.Debugf("Skipping input %d: no derivation info", idx)
+
+ return true
+ }
+
+ return false
+}
+
+// shouldSkipSigningError determines whether a signing error should be skipped
+// (logged and ignored) or returned as a fatal error.
+//
+// It handles cases typical in collaborative workflows where an input might
+// belong to another signer, already be signed, or use an unknown derivation
+// scheme.
+func shouldSkipSigningError(err error, idx int) bool {
+ // If the input is already signed, we can just skip it.
+ if errors.Is(err, errAlreadySigned) {
+ log.Debugf("Skipping input %d: already signed", idx)
+ return true
+ }
+
+ // In a collaborative PSBT workflow, the transaction may contain inputs
+ // that belong to other parties. Even if a derivation path is present
+ // and valid (e.g. BIP-84), it might correspond to a different signer's
+ // key (same path, different seed).
+ //
+ // If we encounter `errComputeRawSig`, it means we failed to produce a
+ // signature. This usually happens because we don't have the private
+ // key for the derived address (it's someone else's input). In this
+ // case, we skip the input and log a debug message, allowing us to
+ // proceed and sign the inputs that we DO own.
+ if errors.Is(err, errComputeRawSig) {
+ log.Debugf("Skipping input %d: %v", idx, err)
+ return true
+ }
+
+ // If the derivation path has an unknown purpose, it likely belongs to
+ // another signer or a scheme we don't support. We skip these as well.
+ if errors.Is(err, ErrUnknownBip32Purpose) {
+ log.Debugf("Skipping input %d: unknown BIP32 purpose", idx)
+ return true
+ }
+
+ return false
+}
+
+// validateDerivation inspects the derivation information for the input and
+// ensures it conforms to the supported signing modes.
+//
+// It enforces the following rules:
+// 1. Only one derivation path per type is supported.
+// 2. Taproot and BIP32 derivation information cannot be present
+// simultaneously. This avoids ambiguity about which signing path to take.
+//
+// It returns a boolean indicating whether the input is a Taproot input (true)
+// or a legacy/SegWit input (false), and an error if the validation fails.
+func validateDerivation(pInput *psbt.PInput, idx int) (bool, error) {
+ tapCount := len(pInput.TaprootBip32Derivation)
+ bip32Count := len(pInput.Bip32Derivation)
+
+ if tapCount > 1 {
+ return false, fmt.Errorf("input %d: %w", idx,
+ ErrUnsupportedMultipleTaprootDerivation)
+ }
+
+ if bip32Count > 1 {
+ return false, fmt.Errorf("input %d: %w", idx,
+ ErrUnsupportedMultipleBip32Derivation)
+ }
+
+ if tapCount == 1 && bip32Count == 1 {
+ // This is ambiguous/invalid state in the PSBT.
+ return false, fmt.Errorf("input %d: %w", idx,
+ ErrAmbiguousDerivation)
+ }
+
+ // If we have Taproot info, it's a Taproot input.
+ return tapCount == 1, nil
+}
+
+// fetchPsbtUtxo extracts the UTXO for the given input index from the PSBT
+// packet. It prioritizes the WitnessUtxo if present, otherwise falls back to
+// the NonWitnessUtxo.
+//
+// NOTE: While psbt.InputsReadyToSign guarantees that at least one of these
+// fields is set, this function performs additional checks and returns an error
+// if the UTXO information is missing or the index is out of bounds, preventing
+// panics on malformed packets.
+func fetchPsbtUtxo(packet *psbt.Packet, idx int) (*wire.TxOut, error) {
+ if idx >= len(packet.Inputs) {
+ return nil, fmt.Errorf("%w: psbt input index %d",
+ ErrIndexOutOfBounds, idx)
+ }
+
+ pInput := &packet.Inputs[idx]
+
+ if pInput.WitnessUtxo != nil {
+ return pInput.WitnessUtxo, nil
+ }
+
+ if pInput.NonWitnessUtxo == nil {
+ return nil, fmt.Errorf("%w: %d",
+ ErrInputMissingUtxoInfo, idx)
+ }
+
+ if idx >= len(packet.UnsignedTx.TxIn) {
+ return nil, fmt.Errorf("%w: psbt input index %d for "+
+ "UnsignedTx inputs", ErrIndexOutOfBounds, idx)
+ }
+
+ prevIdx := packet.UnsignedTx.TxIn[idx].PreviousOutPoint.Index
+
+ if int(prevIdx) >= len(pInput.NonWitnessUtxo.TxOut) {
+ return nil, fmt.Errorf("%w: input %d prevOut index %d",
+ ErrIndexOutOfBounds, idx, prevIdx)
+ }
+
+ return pInput.NonWitnessUtxo.TxOut[prevIdx], nil
+}
+
+// checkTaprootScriptSpendSig checks if a Taproot script-path signature already
+// exists for the given input and derivation details. It returns
+// errAlreadySigned, if a matching signature is found, otherwise nil.
+func checkTaprootScriptSpendSig(pInput *psbt.PInput,
+ tapDerivation *psbt.TaprootBip32Derivation) error {
+
+ for _, sig := range pInput.TaprootScriptSpendSig {
+ if bytes.Equal(
+ sig.XOnlyPubKey, tapDerivation.XOnlyPubKey,
+ ) && bytes.Equal(
+ sig.LeafHash, tapDerivation.LeafHashes[0],
+ ) {
+
+ return errAlreadySigned
+ }
+ }
+
+ return nil
+}
+
+// addTaprootSigToPInput adds the generated signature to the PSBT input.
+//
+// NOTE: This method modifies the `pInput` in-place.
+func addTaprootSigToPInput(pInput *psbt.PInput, sig []byte,
+ sighashType txscript.SigHashType, details TaprootSpendDetails,
+ tapDerivation *psbt.TaprootBip32Derivation) {
+
+ if details.SpendPath == KeyPathSpend {
+ if sighashType != txscript.SigHashDefault {
+ sig = append(sig, byte(sighashType))
+ }
+
+ pInput.TaprootKeySpendSig = sig
+ } else {
+ tsSig := &psbt.TaprootScriptSpendSig{
+ XOnlyPubKey: tapDerivation.XOnlyPubKey,
+ LeafHash: tapDerivation.LeafHashes[0],
+ Signature: sig,
+ SigHash: pInput.SighashType,
+ }
+ pInput.TaprootScriptSpendSig = append(
+ pInput.TaprootScriptSpendSig, tsSig,
+ )
+ }
+}
+
+// addBip32SigToPInput adds the generated signature to the PSBT input for
+// non-Taproot (Legacy/SegWit) inputs.
+//
+// NOTE: This method modifies the `pInput` in-place.
+func addBip32SigToPInput(pInput *psbt.PInput, sig []byte,
+ sighashType txscript.SigHashType, derivation *psbt.Bip32Derivation,
+ addrType waddrmgr.AddressType) {
+
+ // Append sighash type if needed (SegWit v0).
+ if addrType == waddrmgr.NestedWitnessPubKey ||
+ addrType == waddrmgr.WitnessPubKey {
+
+ sig = append(sig, byte(sighashType))
+ }
+
+ pInput.PartialSigs = append(pInput.PartialSigs,
+ &psbt.PartialSig{
+ PubKey: derivation.PubKey,
+ Signature: sig,
+ },
+ )
+}
+
+// createTaprootSpendDetails determines the signing method (Key Path vs Script
+// Path) and constructs the necessary details for generating a Taproot
+// signature.
+//
+// It inspects the derivation info and the PSBT input to decide:
+// 1. Key Path Spend: If `LeafHashes` is empty, it assumes a key path spend.
+// It validates that the input hasn't already been signed with a key path
+// signature.
+// 2. Script Path Spend: If `LeafHashes` has exactly one entry, it assumes a
+// script path spend. It validates the presence and correctness of the
+// corresponding `TaprootLeafScript` and checks if a signature for this
+// specific leaf and key already exists.
+//
+// Returns `ErrUnsupportedTaprootLeafCount` if `LeafHashes` has more than 1
+// entry.
+// Returns `ErrMissingTaprootLeafScript` or `ErrTaprootLeafHashMismatch` for
+// invalid script path state.
+// Returns `errAlreadySigned` if a valid signature already exists for the
+// target path.
+func createTaprootSpendDetails(pInput *psbt.PInput,
+ tapDerivation *psbt.TaprootBip32Derivation) (
+ TaprootSpendDetails, error) {
+
+ var details TaprootSpendDetails
+
+ nLeafHashes := len(tapDerivation.LeafHashes)
+ switch nLeafHashes {
+ // Case 1: Key Path Spend.
+ // A non-empty merkle root means we committed to a taproot hash
+ // that we need to use in the tap tweak. If LeafHashes is empty, it
+ // means we are signing for the internal key (Key Path).
+ case 0:
+ // If a Merkle Root is provided, it must be exactly 32 bytes.
+ if len(pInput.TaprootMerkleRoot) > 0 &&
+ len(pInput.TaprootMerkleRoot) != sha256.Size {
+
+ return details, fmt.Errorf("%w: expected %d, got %d",
+ ErrInvalidTaprootMerkleRootLength,
+ sha256.Size, len(pInput.TaprootMerkleRoot))
+ }
+
+ details = TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ Tweak: pInput.TaprootMerkleRoot,
+ }
+
+ // Check if we have already signed this input.
+ if len(pInput.TaprootKeySpendSig) > 0 {
+ return details, errAlreadySigned
+ }
+
+ // Case 2: Script Path Spend (Single Leaf).
+ // Currently, we only support signing for one leaf at a time.
+ case 1:
+ // If we're supposed to be signing for a leaf hash, we also
+ // expect the leaf script that hashes to that hash in the
+ // appropriate field.
+ if len(pInput.TaprootLeafScript) != 1 {
+ return details, fmt.Errorf("%w: expected 1, got %d",
+ ErrMissingTaprootLeafScript,
+ len(pInput.TaprootLeafScript))
+ }
+
+ leafScript := pInput.TaprootLeafScript[0]
+ leaf := txscript.TapLeaf{
+ LeafVersion: leafScript.LeafVersion,
+ Script: leafScript.Script,
+ }
+ h := leaf.TapHash()
+
+ // Verify that the calculated hash of the provided script
+ // matches the leaf hash specified in the derivation info.
+ if !bytes.Equal(h[:], tapDerivation.LeafHashes[0]) {
+ return details, ErrTaprootLeafHashMismatch
+ }
+
+ details = TaprootSpendDetails{
+ SpendPath: ScriptPathSpend,
+ WitnessScript: leafScript.Script,
+ }
+
+ // Check if we have already signed this input.
+ err := checkTaprootScriptSpendSig(pInput, tapDerivation)
+ if err != nil {
+ return details, err
+ }
+
+ default:
+ return details, fmt.Errorf("%w: %d",
+ ErrUnsupportedTaprootLeafCount, nLeafHashes)
+ }
+
+ return details, nil
+}
+
+// createBip32SpendDetails constructs the spending details (e.g. redeem scripts,
+// witness scripts) required for signing a BIP32 input.
+//
+// It inspects the input's address type and existing script information in the
+// PSBT to determine the correct spending path (Legacy, SegWit v0, or Nested
+// SegWit).
+//
+// Returns `ErrUnknownAddressType` if the address type is not supported.
+// Returns `errAlreadySigned` if a valid signature for the derived key already
+// exists.
+func createBip32SpendDetails(pInput *psbt.PInput, utxo *wire.TxOut,
+ addrType waddrmgr.AddressType,
+ derivation *psbt.Bip32Derivation) (SpendDetails, error) {
+
+ // Determine the script to use for signing (subScript).
+ var subScript []byte
+ switch {
+ case len(pInput.RedeemScript) > 0:
+ subScript = pInput.RedeemScript
+
+ case len(pInput.WitnessScript) > 0:
+ subScript = pInput.WitnessScript
+
+ default:
+ subScript = utxo.PkScript
+ }
+
+ var details SpendDetails
+ switch addrType {
+ case waddrmgr.WitnessPubKey, waddrmgr.NestedWitnessPubKey:
+ details = SegwitV0SpendDetails{WitnessScript: subScript}
+
+ case waddrmgr.PubKeyHash:
+ details = LegacySpendDetails{RedeemScript: subScript}
+
+ case waddrmgr.Script, waddrmgr.RawPubKey,
+ waddrmgr.WitnessScript, waddrmgr.TaprootPubKey,
+ waddrmgr.TaprootScript:
+ return nil, fmt.Errorf("%w: %v", ErrUnknownAddressType,
+ addrType)
+ default:
+ return nil, fmt.Errorf("%w: %v", ErrUnknownAddressType,
+ addrType)
+ }
+
+ // Check if we have already signed this input.
+ for _, sig := range pInput.PartialSigs {
+ if bytes.Equal(sig.PubKey, derivation.PubKey) {
+ return nil, errAlreadySigned
+ }
+ }
+
+ return details, nil
+}
+
+// signTaprootPsbtInput attempts to sign a single Taproot input of a PSBT.
+//
+// It performs the following steps:
+// 1. Parses the BIP32 derivation path to ensure it is valid.
+// 2. Determines the specific spending path (Key Path vs Script Path) using
+// createTaprootSpendDetails.
+// 3. Computes the raw Schnorr signature using the wallet's signer.
+// 4. Adds the generated signature to the PSBT input (either as the key spend
+// signature or as a script spend signature).
+//
+// Returns an error if the input is invalid, the key is not found, or the
+// signing operation fails.
+func (w *Wallet) signTaprootPsbtInput(ctx context.Context, packet *psbt.Packet,
+ idx int, sigHashes *txscript.TxSigHashes,
+ tweaker PrivKeyTweaker) error {
+
+ // It is safe to access packet.Inputs[idx] directly here because
+ // SignPsbt calls psbt.InputsReadyToSign before this method, which
+ // ensures that the Inputs slice corresponds to the UnsignedTx inputs.
+ pInput := &packet.Inputs[idx]
+
+ // Fetch the UTXO (Witness or NonWitness) needed for signing.
+ utxo, err := fetchPsbtUtxo(packet, idx)
+ if err != nil {
+ return err
+ }
+
+ tapDerivation := pInput.TaprootBip32Derivation[0]
+
+ // Parse and validate the BIP32 derivation path.
+ path, err := w.parseBip32Path(tapDerivation.Bip32Path)
+ if err != nil {
+ // If the derivation path is invalid, we can't sign.
+ return fmt.Errorf("invalid derivation path: %w", err)
+ }
+
+ // Determine the SpendDetails (Key Path or Script Path).
+ details, err := createTaprootSpendDetails(pInput, tapDerivation)
+ if err != nil {
+ return err
+ }
+
+ params := &RawSigParams{
+ Tx: packet.UnsignedTx,
+ InputIndex: idx,
+ Output: utxo,
+ SigHashes: sigHashes,
+ HashType: pInput.SighashType,
+ Path: path,
+ Tweaker: tweaker,
+ Details: details,
+ }
+
+ // Compute the raw signature.
+ sig, err := w.ComputeRawSig(ctx, params)
+ if err != nil {
+ return fmt.Errorf("%w: %w", errComputeRawSig, err)
+ }
+
+ // Apply the signature to the PSBT input.
+ addTaprootSigToPInput(
+ pInput, sig, params.HashType, details, tapDerivation,
+ )
+
+ return nil
+}
+
+// signBip32PsbtInput attempts to sign a single non-Taproot (Legacy/SegWit)
+// input of a PSBT.
+//
+// It performs the following steps:
+// 1. Parses the BIP32 derivation path to determine the address type.
+// 2. Constructs the spending details (redeem scripts, etc.) using
+// createBip32SpendDetails.
+// 3. Computes the raw ECDSA signature using the wallet's signer.
+// 4. Adds the generated signature to the PSBT input's PartialSigs list.
+//
+// Returns an error if the input is invalid, the key is not found, or the
+// signing operation fails.
+func (w *Wallet) signBip32PsbtInput(ctx context.Context, packet *psbt.Packet,
+ idx int, sigHashes *txscript.TxSigHashes,
+ tweaker PrivKeyTweaker) error {
+
+ // It is safe to access packet.Inputs[idx] directly here because
+ // SignPsbt calls psbt.InputsReadyToSign before this method, which
+ // ensures that the Inputs slice corresponds to the UnsignedTx inputs.
+ pInput := &packet.Inputs[idx]
+
+ // Fetch the UTXO (Witness or NonWitness) needed for signing.
+ utxo, err := fetchPsbtUtxo(packet, idx)
+ if err != nil {
+ return err
+ }
+
+ derivation := pInput.Bip32Derivation[0]
+
+ // Parse and validate the BIP32 derivation path.
+ path, err := w.parseBip32Path(derivation.Bip32Path)
+ if err != nil {
+ return fmt.Errorf("invalid derivation path: %w", err)
+ }
+
+ addrType, err := addressTypeFromPurpose(path.KeyScope.Purpose)
+ if err != nil {
+ return err
+ }
+
+ // Construct SpendDetails for Legacy/SegWit input.
+ details, err := createBip32SpendDetails(
+ pInput, utxo, addrType, derivation,
+ )
+ if err != nil {
+ return err
+ }
+
+ params := &RawSigParams{
+ Tx: packet.UnsignedTx,
+ InputIndex: idx,
+ Output: utxo,
+ SigHashes: sigHashes,
+ HashType: pInput.SighashType,
+ Path: path,
+ Tweaker: tweaker,
+ Details: details,
+ }
+
+ // Compute the raw signature.
+ sig, err := w.ComputeRawSig(ctx, params)
+ if err != nil {
+ return fmt.Errorf("%w: %w", errComputeRawSig, err)
+ }
+
+ // Apply the signature to the PSBT input.
+ addBip32SigToPInput(pInput, sig, params.HashType, derivation, addrType)
+
+ return nil
+}
+
+// FinalizePsbt finalizes the PSBT.
+//
+// It performs the finalization by:
+// 1. Auto-Signing: Iterating through all inputs and calling `finalizeInput`.
+// This helper attempts to generate a signature and script witness for any
+// inputs owned by the wallet that are missing them.
+// 2. Completion: Calling `psbt.MaybeFinalizeAll`, which checks if every input
+// in the packet has the necessary data to pass script validation. If so, it
+// constructs the final witnesses and strips the PSBT metadata, leaving a
+// ready-to-broadcast transaction.
+func (w *Wallet) FinalizePsbt(ctx context.Context, packet *psbt.Packet) error {
+ err := w.state.canSign()
+ if err != nil {
+ return err
+ }
+
+ // Check that the PSBT is structurally ready to be signed/finalized.
+ err = psbt.InputsReadyToSign(packet)
+ if err != nil {
+ return fmt.Errorf("psbt inputs not ready: %w", err)
+ }
+
+ tx := packet.UnsignedTx
+
+ // We create a `PrevOutputFetcher` to allow `txscript` to retrieve the
+ // previous transaction outputs needed for sighash generation. This is
+ // required for generating valid signatures, as the value and script of
+ // the UTXO being spent are part of the signed digest.
+ prevOutFetcher, err := PsbtPrevOutputFetcher(packet)
+ if err != nil {
+ return fmt.Errorf("error creating prevOutFetcher: %w", err)
+ }
+
+ // Compute the transaction's sighashes. This is an optimization to
+ // calculate the sighashes once and reuse them for all inputs, rather
+ // than recalculating them for each signature. This is particularly
+ // beneficial for transactions with many inputs.
+ sigHashes := txscript.NewTxSigHashes(tx, prevOutFetcher)
+
+ // Iterate through each input in the PSBT. For each input, we will
+ // check if we can sign and finalize it (i.e., if we own the UTXO and
+ // have the private key).
+ for i := range packet.Inputs {
+ err := w.finalizeInput(ctx, packet, i, sigHashes)
+ if err != nil {
+ return err
+ }
+ }
+
+ // Finally, attempt to finalize the entire PSBT. This will check if all
+ // inputs have final scripts (either added by us above or constructed
+ // from PartialSigs by the psbt library) and strip the partial data.
+ err = psbt.MaybeFinalizeAll(packet)
+ if err != nil {
+ return fmt.Errorf("error finalizing PSBT: %w", err)
+ }
+
+ return nil
+}
+
+// finalizeInput attempts to finalize a single input of the PSBT.
+func (w *Wallet) finalizeInput(ctx context.Context, packet *psbt.Packet,
+ idx int, sigHashes *txscript.TxSigHashes) error {
+
+ pInput := &packet.Inputs[idx]
+
+ // If the input is already finalized, we can skip it.
+ if len(pInput.FinalScriptWitness) > 0 ||
+ len(pInput.FinalScriptSig) > 0 {
+
+ log.Debugf("Skipping input %d: already finalized", idx)
+ return nil
+ }
+
+ // Fetch the UTXO for this input.
+ utxo, err := fetchPsbtUtxo(packet, idx)
+ if err != nil {
+ // This should not happen if InputsReadyToSign passed (which is
+ // called at the start of the function), as it guarantees the
+ // presence of WitnessUtxo or NonWitnessUtxo. However, for
+ // defensive programming, we log an error and continue to avoid
+ // aborting the process in case of unexpected data
+ // inconsistency.
+ log.Errorf("Input %d has no UTXO info: %v", idx, err)
+ return nil
+ }
+
+ // Attempt to compute the unlocking script (witness and/or
+ // sigScript) for this input.
+ params := &UnlockingScriptParams{
+ Tx: packet.UnsignedTx,
+ InputIndex: idx,
+ Output: utxo,
+ SigHashes: sigHashes,
+ HashType: pInput.SighashType,
+ }
+
+ unlockingScript, err := w.ComputeUnlockingScript(ctx, params)
+ if err != nil {
+ // If we can't generate the script (e.g. we don't own the key,
+ // or it's a type we don't support yet, or the account is
+ // watch-only), we just skip this input and let the finalizer
+ // try to use any existing partial signatures.
+ log.Debugf("Could not compute unlocking script for "+
+ "input %d: %v", idx, err)
+
+ return nil
+ }
+
+ err = addScriptToPInput(pInput, unlockingScript)
+ if err != nil {
+ return fmt.Errorf("failed to patch input %d: %w",
+ idx, err)
+ }
+
+ return nil
+}
+
+// addScriptToPInput applies the generated witness and/or sigScript to the PSBT
+// input.
+func addScriptToPInput(pInput *psbt.PInput,
+ unlockingScript *UnlockingScript) error {
+
+ // If we successfully generated a witness, serialize and attach
+ // it.
+ if len(unlockingScript.Witness) > 0 {
+ var witnessBuf bytes.Buffer
+
+ err := psbt.WriteTxWitness(&witnessBuf, unlockingScript.Witness)
+ if err != nil {
+ return fmt.Errorf("failed to serialize witness: %w",
+ err)
+ }
+
+ pInput.FinalScriptWitness = witnessBuf.Bytes()
+ }
+
+ // If we generated a sigScript (for legacy/nested P2SH), attach
+ // it.
+ if len(unlockingScript.SigScript) > 0 {
+ pInput.FinalScriptSig = unlockingScript.SigScript
+ }
+
+ return nil
+}
+
+// CombinePsbt merges multiple PSBTs into one.
+//
+// It implements the "Combiner" role by performing two passes:
+// 1. Validation Pass: It iterates through all packets to ensure they refer to
+// the exact same global transaction (TXID) and have matching input/output
+// counts.
+// 2. Construction Pass: It creates a new, combined PSBT packet (to avoid
+// mutating inputs). It then iterates through every provided packet
+// (including the first) and merges its data into the combined result. This
+// includes deduplicating signatures and aggregating scripts/derivations.
+func (w *Wallet) CombinePsbt(_ context.Context, psbts ...*psbt.Packet) (
+ *psbt.Packet, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // 1. Validation Pass: Ensure compatibility of all packets and prepare
+ // a fresh result packet.
+ combined, err := validatePsbtMerge(psbts)
+ if err != nil {
+ return nil, err
+ }
+
+ // 2. Construction Pass: Merge data into the prepared packet.
+ //
+ // Iterate through ALL packets (including the first) and merge their
+ // contents into the combined packet.
+ for _, p := range psbts {
+ // Merge Global Unknowns.
+ combined.Unknowns = deduplicateUnknowns(
+ combined.Unknowns, p.Unknowns,
+ )
+
+ // Merge Inputs.
+ for j := range combined.Inputs {
+ err := mergePsbtInputs(
+ &combined.Inputs[j], &p.Inputs[j],
+ )
+ if err != nil {
+ return nil, fmt.Errorf("input %d merge "+
+ "failed: %w", j, err)
+ }
+ }
+
+ // Merge Outputs.
+ for j := range combined.Outputs {
+ err := mergePsbtOutputs(
+ &combined.Outputs[j], &p.Outputs[j],
+ )
+ if err != nil {
+ return nil, fmt.Errorf("output %d merge "+
+ "failed: %w", j, err)
+ }
+ }
+ }
+
+ // Post-merge Validation: Ensure the resulting packet is structurally
+ // sound (e.g. has necessary UTXO info). This acts as a final sanity
+ // check.
+ err = psbt.InputsReadyToSign(combined)
+ if err != nil {
+ return nil, fmt.Errorf("combined psbt validation failed: %w",
+ err)
+ }
+
+ return combined, nil
+}
+
+// validatePsbtMerge checks that a set of PSBT packets are compatible for
+// merging and returns a new, empty packet initialized with the transaction
+// structure, ready to be populated.
+func validatePsbtMerge(psbts []*psbt.Packet) (*psbt.Packet, error) {
+ if len(psbts) == 0 {
+ return nil, ErrNoPsbtsToCombine
+ }
+
+ base := psbts[0]
+ baseTxHash := base.UnsignedTx.TxHash()
+ nInputs := len(base.Inputs)
+ nOutputs := len(base.Outputs)
+
+ for i, p := range psbts[1:] {
+ if p.UnsignedTx.TxHash() != baseTxHash {
+ return nil, fmt.Errorf("%w: packet index %d",
+ ErrDifferentTransactions, i+1)
+ }
+
+ if len(p.Inputs) != nInputs {
+ return nil, fmt.Errorf("%w: packet index %d",
+ ErrInputCountMismatch, i+1)
+ }
+
+ if len(p.Outputs) != nOutputs {
+ return nil, fmt.Errorf("%w: packet index %d",
+ ErrOutputCountMismatch, i+1)
+ }
+ }
+
+ // Initialize a fresh packet using a deep copy of the unsigned
+ // transaction to ensure we don't mutate any of the input packets.
+ combined := &psbt.Packet{
+ UnsignedTx: base.UnsignedTx.Copy(),
+ Inputs: make([]psbt.PInput, nInputs),
+ Outputs: make([]psbt.POutput, nOutputs),
+ Unknowns: make([]*psbt.Unknown, 0),
+ }
+
+ return combined, nil
+}
+
+// mergePsbtInputs merges the source input into the destination input.
+//
+// It returns an error if any immutable fields (Scripts, UTXOs, SighashType)
+// conflict between the two inputs.
+func mergePsbtInputs(dest, src *psbt.PInput) error {
+ // Merge PartialSigs (deduplicating by pubkey).
+ dest.PartialSigs = deduplicatePartialSigs(
+ dest.PartialSigs, src.PartialSigs,
+ )
+
+ var err error
+
+ err = mergeSighashType(dest, src)
+ if err != nil {
+ return err
+ }
+
+ err = mergeInputScripts(dest, src)
+ if err != nil {
+ return err
+ }
+
+ // Merge BIP32 Derivations (deduplicating by pubkey).
+ dest.Bip32Derivation = deduplicateBip32Derivations(
+ dest.Bip32Derivation, src.Bip32Derivation,
+ )
+
+ // Merge Taproot Derivations (deduplicating by x-only pubkey).
+ dest.TaprootBip32Derivation = deduplicateTaprootBip32Derivations(
+ dest.TaprootBip32Derivation, src.TaprootBip32Derivation,
+ )
+
+ err = mergeTaprootKeySpendSig(dest, src)
+ if err != nil {
+ return err
+ }
+
+ dest.TaprootScriptSpendSig = deduplicateTaprootScriptSpendSigs(
+ dest.TaprootScriptSpendSig, src.TaprootScriptSpendSig,
+ )
+
+ err = mergeWitnessUtxo(dest, src)
+ if err != nil {
+ return err
+ }
+
+ err = mergeNonWitnessUtxo(dest, src)
+ if err != nil {
+ return err
+ }
+
+ // Merge Unknowns.
+ dest.Unknowns = deduplicateUnknowns(dest.Unknowns, src.Unknowns)
+
+ return nil
+}
+
+// mergePsbtOutputs merges the source output into the destination output.
+//
+// It returns an error if any immutable fields (Taproot Internal Key, Scripts)
+// conflict.
+func mergePsbtOutputs(dest, src *psbt.POutput) error {
+ // Merge BIP32 Derivations for outputs.
+ dest.Bip32Derivation = deduplicateBip32Derivations(
+ dest.Bip32Derivation, src.Bip32Derivation,
+ )
+
+ var err error
+
+ err = mergeTaprootInternalKey(dest, src)
+ if err != nil {
+ return err
+ }
+
+ // Merge Taproot BIP32 Derivations for outputs.
+ dest.TaprootBip32Derivation = deduplicateTaprootBip32Derivations(
+ dest.TaprootBip32Derivation, src.TaprootBip32Derivation,
+ )
+
+ err = mergeOutputScripts(dest, src)
+ if err != nil {
+ return err
+ }
+
+ // Merge Unknowns.
+ dest.Unknowns = deduplicateUnknowns(dest.Unknowns, src.Unknowns)
+
+ return nil
+}
+
+// deduplicatePartialSigs adds new partial signatures from src to dest,
+// avoiding duplicates based on pubkey.
+func deduplicatePartialSigs(dest, src []*psbt.PartialSig) []*psbt.PartialSig {
+ for _, sig := range src {
+ if !slices.ContainsFunc(dest, func(dSig *psbt.PartialSig) bool {
+ return bytes.Equal(dSig.PubKey, sig.PubKey)
+ }) {
+
+ dest = append(dest, sig)
+ }
+ }
+
+ return dest
+}
+
+// deduplicateBip32Derivations adds new BIP32 derivations from src to dest,
+// avoiding duplicates based on pubkey.
+func deduplicateBip32Derivations(
+ dest, src []*psbt.Bip32Derivation) []*psbt.Bip32Derivation {
+
+ for _, der := range src {
+ if !slices.ContainsFunc(
+ dest, func(dDer *psbt.Bip32Derivation) bool {
+ return bytes.Equal(dDer.PubKey, der.PubKey)
+ },
+ ) {
+
+ dest = append(dest, der)
+ }
+ }
+
+ return dest
+}
+
+// deduplicateTaprootBip32Derivations adds new Taproot BIP32 derivations
+// from src to dest, avoiding duplicates based on x-only pubkey.
+func deduplicateTaprootBip32Derivations(dest,
+ src []*psbt.TaprootBip32Derivation) []*psbt.TaprootBip32Derivation {
+
+ for _, der := range src {
+ if !slices.ContainsFunc(
+ dest, func(dDer *psbt.TaprootBip32Derivation) bool {
+ return bytes.Equal(
+ dDer.XOnlyPubKey, der.XOnlyPubKey,
+ )
+ },
+ ) {
+
+ dest = append(dest, der)
+ }
+ }
+
+ return dest
+}
+
+// deduplicateTaprootScriptSpendSigs adds new Taproot Script Spend Signatures
+// from src to dest, avoiding duplicates based on XOnlyPubKey and LeafHash.
+func deduplicateTaprootScriptSpendSigs(dest,
+ src []*psbt.TaprootScriptSpendSig) []*psbt.TaprootScriptSpendSig {
+
+ for _, srcSig := range src {
+ if !slices.ContainsFunc(
+ dest, func(destSig *psbt.TaprootScriptSpendSig) bool {
+ return bytes.Equal(
+ destSig.XOnlyPubKey, srcSig.XOnlyPubKey,
+ ) && bytes.Equal(
+ destSig.LeafHash, srcSig.LeafHash,
+ )
+ },
+ ) {
+
+ dest = append(dest, srcSig)
+ }
+ }
+
+ return dest
+}
+
+// mergeSighashType merges the SighashType field. Returns error on conflict.
+func mergeSighashType(dest, src *psbt.PInput) error {
+ if dest.SighashType != 0 && src.SighashType != 0 &&
+ dest.SighashType != src.SighashType {
+
+ return fmt.Errorf("%w: sighash type mismatch %v vs %v",
+ ErrPsbtMergeConflict, dest.SighashType, src.SighashType)
+ }
+
+ if dest.SighashType == 0 {
+ dest.SighashType = src.SighashType
+ }
+
+ return nil
+}
+
+// mergeInputScripts merges RedeemScript, WitnessScript, FinalScriptSig, and
+// FinalScriptWitness for inputs. Returns error on conflict.
+func mergeInputScripts(dest, src *psbt.PInput) error {
+ err := mergeRedeemScript(dest, src)
+ if err != nil {
+ return err
+ }
+
+ err = mergeWitnessScript(dest, src)
+ if err != nil {
+ return err
+ }
+
+ err = mergeFinalScriptSig(dest, src)
+ if err != nil {
+ return err
+ }
+
+ return mergeFinalScriptWitness(dest, src)
+}
+
+// mergeRedeemScript merges the RedeemScript field.
+func mergeRedeemScript(dest, src *psbt.PInput) error {
+ if len(dest.RedeemScript) > 0 && len(src.RedeemScript) > 0 &&
+ !bytes.Equal(dest.RedeemScript, src.RedeemScript) {
+
+ return fmt.Errorf("%w: redeem script mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.RedeemScript) == 0 {
+ dest.RedeemScript = src.RedeemScript
+ }
+
+ return nil
+}
+
+// mergeWitnessScript merges the WitnessScript field.
+func mergeWitnessScript(dest, src *psbt.PInput) error {
+ if len(dest.WitnessScript) > 0 && len(src.WitnessScript) > 0 &&
+ !bytes.Equal(dest.WitnessScript, src.WitnessScript) {
+
+ return fmt.Errorf("%w: witness script mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.WitnessScript) == 0 {
+ dest.WitnessScript = src.WitnessScript
+ }
+
+ return nil
+}
+
+// mergeFinalScriptSig merges the FinalScriptSig field.
+func mergeFinalScriptSig(dest, src *psbt.PInput) error {
+ if len(dest.FinalScriptSig) > 0 && len(src.FinalScriptSig) > 0 &&
+ !bytes.Equal(dest.FinalScriptSig, src.FinalScriptSig) {
+
+ return fmt.Errorf("%w: final script sig mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.FinalScriptSig) == 0 {
+ dest.FinalScriptSig = src.FinalScriptSig
+ }
+
+ return nil
+}
+
+// mergeFinalScriptWitness merges the FinalScriptWitness field.
+func mergeFinalScriptWitness(dest, src *psbt.PInput) error {
+ if len(dest.FinalScriptWitness) > 0 &&
+ len(src.FinalScriptWitness) > 0 &&
+ !bytes.Equal(dest.FinalScriptWitness, src.FinalScriptWitness) {
+
+ return fmt.Errorf("%w: final script witness mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.FinalScriptWitness) == 0 {
+ dest.FinalScriptWitness = src.FinalScriptWitness
+ }
+
+ return nil
+}
+
+// mergeTaprootKeySpendSig merges the Taproot Key Spend Signature.
+// Returns error on conflict.
+func mergeTaprootKeySpendSig(dest, src *psbt.PInput) error {
+ if len(dest.TaprootKeySpendSig) > 0 &&
+ len(src.TaprootKeySpendSig) > 0 &&
+ !bytes.Equal(dest.TaprootKeySpendSig, src.TaprootKeySpendSig) {
+
+ return fmt.Errorf("%w: taproot key spend sig mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.TaprootKeySpendSig) == 0 {
+ dest.TaprootKeySpendSig = src.TaprootKeySpendSig
+ }
+
+ return nil
+}
+
+// mergeWitnessUtxo merges the Witness UTXO field. Returns error on conflict.
+func mergeWitnessUtxo(dest, src *psbt.PInput) error {
+ if dest.WitnessUtxo != nil && src.WitnessUtxo != nil {
+ if dest.WitnessUtxo.Value != src.WitnessUtxo.Value ||
+ !bytes.Equal(dest.WitnessUtxo.PkScript,
+ src.WitnessUtxo.PkScript) {
+
+ return fmt.Errorf("%w: witness utxo mismatch",
+ ErrPsbtMergeConflict)
+ }
+ }
+
+ if dest.WitnessUtxo == nil {
+ dest.WitnessUtxo = src.WitnessUtxo
+ }
+
+ return nil
+}
+
+// mergeNonWitnessUtxo merges the Non-Witness UTXO field. Returns error on
+// conflict (by TXID).
+func mergeNonWitnessUtxo(dest, src *psbt.PInput) error {
+ if dest.NonWitnessUtxo != nil && src.NonWitnessUtxo != nil {
+ if dest.NonWitnessUtxo.TxHash() != src.NonWitnessUtxo.TxHash() {
+ return fmt.Errorf("%w: non-witness utxo mismatch",
+ ErrPsbtMergeConflict)
+ }
+ }
+
+ if dest.NonWitnessUtxo == nil {
+ dest.NonWitnessUtxo = src.NonWitnessUtxo
+ }
+
+ return nil
+}
+
+// mergeTaprootInternalKey merges the Taproot Internal Key for outputs.
+// Returns error on conflict.
+func mergeTaprootInternalKey(dest, src *psbt.POutput) error {
+ if len(dest.TaprootInternalKey) > 0 &&
+ len(src.TaprootInternalKey) > 0 &&
+ !bytes.Equal(dest.TaprootInternalKey, src.TaprootInternalKey) {
+
+ return fmt.Errorf("%w: taproot internal key mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.TaprootInternalKey) == 0 {
+ dest.TaprootInternalKey = src.TaprootInternalKey
+ }
+
+ return nil
+}
+
+// mergeOutputScripts merges RedeemScript and WitnessScript for outputs.
+// Returns error on conflict.
+func mergeOutputScripts(dest, src *psbt.POutput) error {
+ if len(dest.RedeemScript) > 0 && len(src.RedeemScript) > 0 &&
+ !bytes.Equal(dest.RedeemScript, src.RedeemScript) {
+
+ return fmt.Errorf("%w: redeem script mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.RedeemScript) == 0 {
+ dest.RedeemScript = src.RedeemScript
+ }
+
+ if len(dest.WitnessScript) > 0 && len(src.WitnessScript) > 0 &&
+ !bytes.Equal(dest.WitnessScript, src.WitnessScript) {
+
+ return fmt.Errorf("%w: witness script mismatch",
+ ErrPsbtMergeConflict)
+ }
+
+ if len(dest.WitnessScript) == 0 {
+ dest.WitnessScript = src.WitnessScript
+ }
+
+ return nil
+}
+
+// addInputInfoSegWitV0 adds the UTXO and BIP32 derivation info for a
+// SegWit v0 PSBT input (p2wkh, np2wkh) from the given wallet
+// information.
+func addInputInfoSegWitV0(in *psbt.PInput, prevTx *wire.MsgTx, utxo *wire.TxOut,
+ derivationInfo *psbt.Bip32Derivation, addr waddrmgr.ManagedAddress,
+ witnessProgram []byte) {
+
+ // As a fix for CVE-2020-14199 we have to always include the full
+ // non-witness UTXO in the PSBT for segwit v0.
+ in.NonWitnessUtxo = prevTx
+
+ // To make it more obvious that this is actually a witness output being
+ // spent, we also add the same information as the witness UTXO.
+ in.WitnessUtxo = &wire.TxOut{
+ Value: utxo.Value,
+ PkScript: utxo.PkScript,
+ }
+ in.SighashType = txscript.SigHashAll
+
+ // Include the derivation path for each input.
+ in.Bip32Derivation = []*psbt.Bip32Derivation{
+ derivationInfo,
+ }
+
+ // For nested P2WKH we need to add the redeem script to the input,
+ // otherwise an offline wallet won't be able to sign for it. For normal
+ // P2WKH this will be nil.
+ if addr.AddrType() == waddrmgr.NestedWitnessPubKey {
+ in.RedeemScript = witnessProgram
+ }
+}
+
+// addInputInfoSegWitV1 adds the UTXO and BIP32 derivation info for a SegWit v1
+// PSBT input (p2tr) from the given wallet information.
+func addInputInfoSegWitV1(in *psbt.PInput, utxo *wire.TxOut,
+ derivationInfo *psbt.Bip32Derivation) {
+
+ // For SegWit v1 we only need the witness UTXO information.
+ in.WitnessUtxo = &wire.TxOut{
+ Value: utxo.Value,
+ PkScript: utxo.PkScript,
+ }
+ in.SighashType = txscript.SigHashDefault
+
+ // Include the derivation path for each input in addition to the
+ // taproot specific info we have below.
+ in.Bip32Derivation = []*psbt.Bip32Derivation{
+ derivationInfo,
+ }
+
+ // Include the derivation path for each input.
+ in.TaprootBip32Derivation = []*psbt.TaprootBip32Derivation{{
+ XOnlyPubKey: derivationInfo.PubKey[1:],
+ MasterKeyFingerprint: derivationInfo.MasterKeyFingerprint,
+ Bip32Path: derivationInfo.Bip32Path,
+ }}
+}
+
+// createOutputInfo creates the BIP32 derivation info for an output from our
+// internal wallet.
+func createOutputInfo(txOut *wire.TxOut,
+ addr waddrmgr.ManagedPubKeyAddress) (*psbt.POutput, error) {
+
+ // We don't know the derivation path for imported keys. Those shouldn't
+ // be selected as change outputs in the first place, but just to make
+ // sure we don't run into an issue, we return early for imported keys.
+ keyScope, derivationPath, isKnown := addr.DerivationInfo()
+ if !isKnown {
+ return nil, fmt.Errorf("error adding output info to PSBT: %w",
+ ErrImportedAddrNoDerivation)
+ }
+
+ // Include the derivation path for this output.
+ derivation := &psbt.Bip32Derivation{
+ PubKey: addr.PubKey().SerializeCompressed(),
+ MasterKeyFingerprint: derivationPath.MasterKeyFingerprint,
+ Bip32Path: []uint32{
+ keyScope.Purpose + hdkeychain.HardenedKeyStart,
+ keyScope.Coin + hdkeychain.HardenedKeyStart,
+ derivationPath.Account,
+ derivationPath.Branch,
+ derivationPath.Index,
+ },
+ }
+ out := &psbt.POutput{
+ Bip32Derivation: []*psbt.Bip32Derivation{
+ derivation,
+ },
+ }
+
+ // Include the Taproot derivation path as well if this is a P2TR output.
+ if txscript.IsPayToTaproot(txOut.PkScript) {
+ schnorrPubKey := derivation.PubKey[1:]
+ out.TaprootBip32Derivation = []*psbt.TaprootBip32Derivation{{
+ XOnlyPubKey: schnorrPubKey,
+ MasterKeyFingerprint: derivation.MasterKeyFingerprint,
+ Bip32Path: derivation.Bip32Path,
+ }}
+ out.TaprootInternalKey = schnorrPubKey
+ }
+
+ return out, nil
+}
+
+// PsbtPrevOutputFetcher returns a txscript.PrevOutputFetcher that is
+// backed by the UTXO information in a PSBT packet.
+func PsbtPrevOutputFetcher(packet *psbt.Packet) (
+ *txscript.MultiPrevOutFetcher, error) {
+
+ fetcher := txscript.NewMultiPrevOutFetcher(nil)
+ for idx, txIn := range packet.UnsignedTx.TxIn {
+ // Use the robust fetchPsbtUtxo helper.
+ utxo, err := fetchPsbtUtxo(packet, idx)
+ if err != nil {
+ // If the input is missing UTXO info entirely, we skip
+ // it (matching previous behavior).
+ if errors.Is(err, ErrInputMissingUtxoInfo) {
+ continue
+ }
+
+ // Other errors (e.g. index out of bounds) are fatal
+ // as they indicate a malformed PSBT.
+ return nil, err
+ }
+
+ fetcher.AddPrevOut(txIn.PreviousOutPoint, utxo)
+ }
+
+ return fetcher, nil
+}
+
+// deduplicateUnknowns adds new Unknowns from src to dest, avoiding duplicates
+// based on Key.
+//
+// TODO(yy): A more efficient approach would be to use a map to track the keys
+// of unknowns already in the dest slice, reducing the complexity to O(N+M).
+func deduplicateUnknowns(dest, src []*psbt.Unknown) []*psbt.Unknown {
+ for _, unknown := range src {
+ if !slices.ContainsFunc(dest, func(dU *psbt.Unknown) bool {
+ return bytes.Equal(dU.Key, unknown.Key)
+ }) {
+
+ dest = append(dest, unknown)
+ }
+ }
+
+ return dest
+}
diff --git a/wallet/psbt_manager_test.go b/wallet/psbt_manager_test.go
new file mode 100644
index 0000000000..4bc19ae3fb
--- /dev/null
+++ b/wallet/psbt_manager_test.go
@@ -0,0 +1,4719 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "slices"
+ "testing"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcec/v2/schnorr"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/psbt/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/pkg/btcunit"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wallet/txauthor"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ errDb = errors.New("db error")
+ errKeyNotFound = errors.New("key not found")
+ errAddrNotFound = errors.New("addr not found")
+)
+
+// TestFindCredit tests that the findCredit helper returns true if a credit
+// exists at the specified index, and false otherwise.
+func TestFindCredit(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create TxDetails with credits at indices 0 and 2.
+ txDetails := &wtxmgr.TxDetails{
+ Credits: []wtxmgr.CreditRecord{
+ {Index: 0},
+ {Index: 2},
+ },
+ }
+
+ // Arrange: Define test cases to check for credits at various indices.
+ testCases := []struct {
+ name string
+ index uint32
+ expectedFound bool
+ }{
+ {
+ name: "credit exists at index 0",
+ index: 0,
+ expectedFound: true,
+ },
+ {
+ name: "credit exists at index 2",
+ index: 2,
+ expectedFound: true,
+ },
+ {
+ name: "credit does not exist at index 1",
+ index: 1,
+ expectedFound: false,
+ },
+ {
+ name: "credit does not exist at index 3",
+ index: 3,
+ expectedFound: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call findCredit with the configured TxDetails
+ // and index.
+ cred := findCredit(txDetails, tc.index)
+
+ // Assert: Verify that the returned credit record
+ // matches the expected outcome.
+ if tc.expectedFound {
+ require.NotNil(t, cred)
+ require.Equal(t, tc.index, cred.Index)
+ } else {
+ require.Nil(t, cred)
+ }
+ })
+ }
+}
+
+// TestFetchAndValidateUtxoSuccess tests that fetchAndValidateUtxo correctly
+// retrieves transaction details and validates ownership.
+func TestFetchAndValidateUtxoSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a transaction input (txHash:0) and mock the wallet's
+ // transaction store to return a corresponding credit at index 0.
+ txHash := chainhash.Hash{1}
+ txIn := &wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Hash: txHash, Index: 0},
+ }
+
+ txDetails := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{
+ {Value: 1000},
+ },
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{
+ {Index: 0},
+ },
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Mock the transaction store to return the details for our txHash.
+ mocks.txStore.On(
+ "TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash)
+ }),
+ ).Return(txDetails, nil)
+
+ // Act: Call fetchAndValidateUtxo with the valid input.
+ tx, utxo, err := w.fetchAndValidateUtxo(txIn)
+
+ // Assert: Verify that no error occurred and that the returned
+ // transaction and UTXO match the expected values from the store.
+ require.NoError(t, err)
+ require.NotNil(t, tx)
+ require.NotNil(t, utxo)
+ require.Equal(t, txDetails.MsgTx.TxOut[0], utxo)
+}
+
+// TestFetchAndValidateUtxoError tests that fetchAndValidateUtxo returns the
+// expected errors for various failure conditions.
+func TestFetchAndValidateUtxoError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Prepare common data structures for the test cases.
+ txHash := chainhash.Hash{1}
+
+ // txIn pointing to an unlocked outpoint (Index 0).
+ txIn := &wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Hash: txHash, Index: 0},
+ }
+
+ // txInLocked pointing to a locked outpoint (Index 1).
+ txInLocked := &wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Hash: txHash, Index: 1},
+ }
+
+ // txDetails contains credits for both Index 0 and Index 1.
+ // Index 0 is used for unlocked tests.
+ // Index 1 is used for the locked test.
+ txDetails := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{
+ {Value: 1000},
+ {Value: 1000},
+ },
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{
+ {Index: 0},
+ {Index: 1},
+ },
+ }
+
+ noCreditDetails := &wtxmgr.TxDetails{
+ TxRecord: txDetails.TxRecord,
+ Credits: []wtxmgr.CreditRecord{},
+ }
+
+ lockedDetails := &wtxmgr.TxDetails{
+ TxRecord: txDetails.TxRecord,
+ Credits: []wtxmgr.CreditRecord{
+ {Index: 0},
+ {Index: 1, Locked: true},
+ },
+ }
+
+ testCases := []struct {
+ name string
+ txIn *wire.TxIn
+ mockTxDetails *wtxmgr.TxDetails
+ mockErr error
+ expectedErr error
+ }{
+ {
+ name: "tx not found",
+ txIn: txIn,
+ mockTxDetails: nil,
+ mockErr: ErrTxNotFound,
+ expectedErr: ErrNotMine,
+ },
+ {
+ name: "store error",
+ txIn: txIn,
+ mockTxDetails: nil,
+ mockErr: errDb,
+ expectedErr: errDb,
+ },
+ {
+ name: "not credit",
+ txIn: txIn,
+ mockTxDetails: noCreditDetails,
+ mockErr: nil,
+ expectedErr: ErrNotMine,
+ },
+ {
+ name: "utxo locked",
+ txIn: txInLocked,
+ mockTxDetails: lockedDetails,
+ mockErr: nil,
+ expectedErr: ErrUtxoLocked,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock the transaction store to return the
+ // configured details or error for the specific test
+ // case.
+ mocks.txStore.On(
+ "TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash)
+ }),
+ ).Return(tc.mockTxDetails, tc.mockErr)
+
+ // Act: Call fetchAndValidateUtxo with the configured
+ // input.
+ tx, utxo, err := w.fetchAndValidateUtxo(tc.txIn)
+
+ // Assert: Verify that the returned error matches the
+ // expected error and that no transaction or UTXO is
+ // returned.
+ require.ErrorIs(t, err, tc.expectedErr)
+ require.Nil(t, tx)
+ require.Nil(t, utxo)
+ })
+ }
+}
+
+// TestDecorateInputSegWitV0 tests that decorateInput correctly populates
+// PSBT input fields for a SegWit v0 (P2WKH) input.
+func TestDecorateInputSegWitV0(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup private and public keys for a P2WKH address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ // Arrange: Create a P2WKH address and its corresponding script.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Define key scope and derivation path for address manager
+ // mocks.
+ keyScope := waddrmgr.KeyScopeBIP0084
+ derivationPath := waddrmgr.DerivationPath{
+ Account: 0,
+ Branch: 0,
+ Index: 0,
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock the address manager to return our P2WKH address as a
+ // ManagedPubKeyAddress when `Address` is called with the P2WKH
+ // address.
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(addr address.Address) bool {
+ return addr.String() == p2wkhAddr.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock the ManagedPubKeyAddress methods to return relevant
+ // derivation and public key information.
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ keyScope, derivationPath, true,
+ )
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+
+ // Arrange: Create a UTXO with the P2WKH script and an empty PSBT input.
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput to populate the PSBT input.
+ err = w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify no error occurred and that the PSBT input is correctly
+ // populated with WitnessUtxo, NonWitnessUtxo, SighashType, and BIP32
+ // derivation info.
+ require.NoError(t, err)
+ require.Equal(t, utxo, pInput.WitnessUtxo)
+ require.Equal(t, tx, pInput.NonWitnessUtxo)
+ require.Equal(t, txscript.SigHashAll, pInput.SighashType)
+ require.Len(t, pInput.Bip32Derivation, 1)
+ require.Equal(
+ t, pubKey.SerializeCompressed(),
+ pInput.Bip32Derivation[0].PubKey,
+ )
+}
+
+// TestDecorateInputTaproot tests that decorateInput correctly populates
+// PSBT input fields for a Taproot (SegWit v1) input.
+func TestDecorateInputTaproot(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup private and public keys for a Taproot address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ // Arrange: Create a Taproot address and its corresponding script.
+ taprootAddr, err := address.NewAddressTaproot(
+ schnorr.SerializePubKey(pubKey), &chainParams,
+ )
+ require.NoError(t, err)
+
+ taprootScript, err := txscript.PayToAddrScript(taprootAddr)
+ require.NoError(t, err)
+
+ // Arrange: Define key scope and derivation path for address manager
+ // mocks.
+ keyScope := waddrmgr.KeyScopeBIP0084
+ derivationPath := waddrmgr.DerivationPath{
+ Account: 0,
+ Branch: 0,
+ Index: 0,
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock the address manager to return our Taproot address as a
+ // ManagedPubKeyAddress when `Address` is called with the Taproot
+ // address.
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(addr address.Address) bool {
+ return addr.String() == taprootAddr.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock the ManagedPubKeyAddress methods to return relevant
+ // derivation and public key information. AddrType is not strictly
+ // checked for Taproot inputs in decorateInput, so no mock is needed
+ // for it.
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ keyScope, derivationPath, true,
+ )
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+
+ // Arrange: Create a UTXO with the Taproot script and an empty PSBT
+ // input.
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: taprootScript,
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput to populate the PSBT input.
+ err = w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify no error occurred and that the PSBT input is
+ // correctly populated with WitnessUtxo, SighashType, and Taproot BIP32
+ // derivation info, including the x-only public key.
+ require.NoError(t, err)
+ require.Equal(t, utxo, pInput.WitnessUtxo)
+ require.Equal(t, txscript.SigHashDefault, pInput.SighashType)
+ require.Len(t, pInput.TaprootBip32Derivation, 1)
+ require.Equal(
+ t, schnorr.SerializePubKey(pubKey),
+ pInput.TaprootBip32Derivation[0].XOnlyPubKey,
+ )
+}
+
+// TestDecorateInputErrExtractAddr tests that decorateInput returns
+// ErrUnableToExtractAddress when the pkScript does not contain a valid
+// address.
+func TestDecorateInputErrExtractAddr(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Create a UTXO with an OP_RETURN script, which cannot be
+ // parsed into a valid address.
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: []byte{0x6a}, // OP_RETURN
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput.
+ err := w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify the error.
+ require.ErrorIs(t, err, ErrUnableToExtractAddress)
+}
+
+// TestDecorateInputErrAddrInfo tests that decorateInput returns an error when
+// the address lookup fails.
+func TestDecorateInputErrAddrInfo(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys and address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock AddressInfo to return an error.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(nil, errDb)
+
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput.
+ err = w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify the error.
+ require.ErrorIs(t, err, errDb)
+}
+
+// TestDecorateInputErrNotPubKey tests that decorateInput returns
+// ErrNotPubKeyAddress when the address is not a ManagedPubKeyAddress.
+func TestDecorateInputErrNotPubKey(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys and address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock AddressInfo to return a generic ManagedAddress
+ // (mocks.addr) instead of a ManagedPubKeyAddress (mocks.pubKeyAddr).
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(mocks.addr, nil)
+
+ mocks.addr.On("Address").Return(p2wkhAddr)
+
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput.
+ err = w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify the error.
+ require.ErrorIs(t, err, ErrNotPubKeyAddress)
+}
+
+// TestDecorateInputErrImported tests that decorateInput returns
+// ErrDerivationPathNotFound when the address is imported.
+func TestDecorateInputErrImported(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys and address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock AddressInfo to return a ManagedPubKeyAddress that is
+ // marked as imported.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+
+ mocks.pubKeyAddr.On("Imported").Return(true)
+ mocks.pubKeyAddr.On("Address").Return(p2wkhAddr)
+
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput.
+ err = w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify the error.
+ require.ErrorIs(t, err, ErrDerivationPathNotFound)
+}
+
+// TestDecorateInputErrDerivationMissing tests that decorateInput returns
+// ErrDerivationPathNotFound when derivation info is missing.
+func TestDecorateInputErrDerivationMissing(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys and address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock AddressInfo to return a ManagedPubKeyAddress that has
+ // no derivation info.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScope{}, waddrmgr.DerivationPath{}, false,
+ )
+ mocks.pubKeyAddr.On("Address").Return(p2wkhAddr)
+
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ tx := &wire.MsgTx{}
+ pInput := &psbt.PInput{}
+
+ // Act: Call decorateInput.
+ err = w.decorateInput(t.Context(), pInput, tx, utxo)
+
+ // Assert: Verify the error.
+ require.ErrorIs(t, err, ErrDerivationPathNotFound)
+}
+
+// TestDecorateInputsSuccess tests that DecorateInputs correctly decorates
+// known inputs and skips unknown inputs when skipUnknown is true.
+func TestDecorateInputsSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys and address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Define 3 inputs.
+ // Input 0: Known (TxHash0)
+ // Input 1: Unknown (TxHash1)
+ // Input 2: Known (TxHash2)
+ txHash0 := chainhash.Hash{0}
+ txHash1 := chainhash.Hash{1}
+ txHash2 := chainhash.Hash{2}
+
+ unsignedTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ {PreviousOutPoint: wire.OutPoint{
+ Hash: txHash0, Index: 0,
+ }},
+ {PreviousOutPoint: wire.OutPoint{
+ Hash: txHash1, Index: 0,
+ }},
+ {PreviousOutPoint: wire.OutPoint{
+ Hash: txHash2, Index: 0,
+ }},
+ },
+ }
+
+ packet, err := psbt.NewFromUnsignedTx(unsignedTx)
+ require.NoError(t, err)
+
+ // Arrange: Setup TxDetails for known inputs.
+ txDetails0 := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 1000, PkScript: p2wkhScript,
+ }},
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{{Index: 0}},
+ }
+ txDetails2 := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 2000, PkScript: p2wkhScript,
+ }},
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{{Index: 0}},
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock TxDetails lookups.
+ // Input 0 -> Found
+ mocks.txStore.On(
+ "TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash0)
+ }),
+ ).Return(txDetails0, nil)
+
+ // Input 1 -> Not Found
+ mocks.txStore.On(
+ "TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash1)
+ }),
+ ).Return(nil, ErrTxNotFound)
+
+ // Input 2 -> Found
+ mocks.txStore.On(
+ "TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash2)
+ }),
+ ).Return(txDetails2, nil)
+
+ // Arrange: Mock Address lookup (common for both known inputs).
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(addr address.Address) bool {
+ return addr.String() == p2wkhAddr.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock ManagedPubKeyAddress methods.
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0084, waddrmgr.DerivationPath{}, true,
+ )
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+
+ // Act: Call DecorateInputs with skipUnknown=true.
+ _, err = w.DecorateInputs(t.Context(), packet, true)
+ require.NoError(t, err)
+
+ // Assert: Input 0 is decorated.
+ require.NotNil(t, packet.Inputs[0].WitnessUtxo)
+ require.Equal(t, int64(1000), packet.Inputs[0].WitnessUtxo.Value)
+ require.Len(t, packet.Inputs[0].Bip32Derivation, 1)
+
+ // Assert: Input 1 is NOT decorated.
+ require.Nil(t, packet.Inputs[1].WitnessUtxo)
+ require.Nil(t, packet.Inputs[1].NonWitnessUtxo)
+ require.Empty(t, packet.Inputs[1].Bip32Derivation)
+
+ // Assert: Input 2 is decorated.
+ require.NotNil(t, packet.Inputs[2].WitnessUtxo)
+ require.Equal(t, int64(2000), packet.Inputs[2].WitnessUtxo.Value)
+ require.Len(t, packet.Inputs[2].Bip32Derivation, 1)
+}
+
+// TestDecorateInputsErrUnknownRequired tests that DecorateInputs returns
+// ErrNotMine when an input is unknown and skipUnknown is false.
+func TestDecorateInputsErrUnknownRequired(t *testing.T) {
+ t.Parallel()
+
+ txHash := chainhash.Hash{1}
+ unsignedTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: txHash,
+ Index: 0,
+ },
+ }},
+ }
+ packet, err := psbt.NewFromUnsignedTx(unsignedTx)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock TxDetails to return ErrTxNotFound.
+ mocks.txStore.On(
+ "TxDetails", mock.Anything, mock.Anything,
+ ).Return(nil, ErrTxNotFound)
+
+ // Act: Call DecorateInputs with skipUnknown=false.
+ _, err = w.DecorateInputs(t.Context(), packet, false)
+
+ // Assert: Error is ErrNotMine.
+ require.ErrorIs(t, err, ErrNotMine)
+}
+
+// TestDecorateInputsErrFetchFailed tests that DecorateInputs returns an error
+// when fetching/validating a UTXO fails with a non-ErrNotMine error.
+func TestDecorateInputsErrFetchFailed(t *testing.T) {
+ t.Parallel()
+
+ txHash := chainhash.Hash{1}
+ unsignedTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: txHash,
+ Index: 0,
+ },
+ }},
+ }
+ packet, err := psbt.NewFromUnsignedTx(unsignedTx)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock TxDetails to return a database error.
+ mocks.txStore.On(
+ "TxDetails", mock.Anything, mock.Anything,
+ ).Return(nil, errDb)
+
+ // Act: Call DecorateInputs (skipUnknown irrelevant for other errors).
+ _, err = w.DecorateInputs(t.Context(), packet, true)
+
+ // Assert: Error is errDb.
+ require.ErrorIs(t, err, errDb)
+}
+
+// TestDecorateInputsErrDecorationFailed tests that DecorateInputs returns an
+// error when the internal decorateInput call fails.
+func TestDecorateInputsErrDecorationFailed(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup valid key/address/script for a known input.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ txHash := chainhash.Hash{1}
+ unsignedTx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: txHash,
+ Index: 0,
+ },
+ }},
+ }
+ packet, err := psbt.NewFromUnsignedTx(unsignedTx)
+ require.NoError(t, err)
+
+ txDetails := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 1000, PkScript: p2wkhScript,
+ }},
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{{Index: 0}},
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock TxDetails success.
+ mocks.txStore.On(
+ "TxDetails", mock.Anything, mock.Anything,
+ ).Return(txDetails, nil)
+
+ // Arrange: Mock AddressInfo to fail (causing decorateInput to fail).
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(nil, errDb)
+
+ // Act: Call DecorateInputs.
+ _, err = w.DecorateInputs(t.Context(), packet, true)
+
+ // Assert: Error is errDb.
+ require.ErrorIs(t, err, errDb)
+}
+
+// TestValidateFundIntentSuccess tests that validateFundIntent returns no error
+// for valid funding intents.
+func TestValidateFundIntentSuccess(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Create a valid PSBT packet with one output (for auto
+ // selection).
+ tx := wire.NewMsgTx(2)
+ tx.AddTxOut(&wire.TxOut{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Create a FundIntent for automatic coin selection (no inputs
+ // in packet).
+ intentAuto := &FundIntent{
+ Packet: packet,
+ }
+
+ // Arrange: Create a valid PSBT packet with one input and one output
+ // (for manual selection).
+ txWithInputs := wire.NewMsgTx(2)
+ txWithInputs.AddTxIn(&wire.TxIn{})
+ txWithInputs.AddTxOut(&wire.TxOut{})
+ packetWithInputs, err := psbt.NewFromUnsignedTx(txWithInputs)
+ require.NoError(t, err)
+
+ // Arrange: Create a FundIntent for manual coin selection (inputs
+ // present in packet).
+ intentManual := &FundIntent{
+ Packet: packetWithInputs,
+ }
+
+ // Act & Assert: Validate the auto selection intent. Expect no error.
+ err = w.validateFundIntent(intentAuto)
+ require.NoError(t, err)
+
+ // Act & Assert: Validate the manual selection intent. Expect no error.
+ err = w.validateFundIntent(intentManual)
+ require.NoError(t, err)
+}
+
+// TestValidateFundIntentError tests that validateFundIntent returns expected
+// errors for invalid funding intents.
+func TestValidateFundIntentError(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Helper function to create a PSBT packet with specified
+ // inputs and outputs.
+ createPacket := func(numInputs, numOutputs int) *psbt.Packet {
+ tx := wire.NewMsgTx(2)
+ for range numInputs {
+ tx.AddTxIn(&wire.TxIn{})
+ }
+
+ for range numOutputs {
+ tx.AddTxOut(&wire.TxOut{})
+ }
+
+ p, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ return p
+ }
+
+ // Arrange: Define test cases for various error scenarios.
+ testCases := []struct {
+ name string
+ intent *FundIntent
+ expectedErr error
+ }{
+ {
+ name: "nil intent",
+ intent: nil,
+ expectedErr: ErrNilArguments,
+ }, {
+ name: "nil packet",
+ intent: &FundIntent{Packet: nil},
+ expectedErr: ErrNilTxIntent,
+ },
+ {
+ name: "no inputs and no outputs",
+ intent: &FundIntent{Packet: createPacket(0, 0)},
+ expectedErr: ErrPacketOutputsMissing,
+ },
+ {
+ name: "inputs and policy conflict",
+ intent: &FundIntent{
+ Packet: createPacket(1, 1),
+ Policy: &InputsPolicy{},
+ },
+ expectedErr: ErrInputsAndPolicy,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call validateFundIntent with the configured
+ // invalid intent.
+ err := w.validateFundIntent(tc.intent)
+
+ // Assert: Verify that the returned error matches the
+ // expected error.
+ require.ErrorIs(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// TestCreateTxIntentAuto tests that createTxIntent correctly converts
+// FundIntent to TxIntent for automatic coin selection.
+func TestCreateTxIntentAuto(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Create a PSBT packet with two outputs and no inputs,
+ // which signals automatic coin selection.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxOut(&wire.TxOut{Value: 1})
+ tx.AddTxOut(&wire.TxOut{Value: 2})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Define the fee rate, coin selection policy, and change
+ // source.
+ feeRate := btcunit.NewSatPerKVByte(1000)
+ policy := &InputsPolicy{
+ MinConfs: 1,
+ }
+ changeSource := &ScopedAccount{}
+
+ // Arrange: Create the FundIntent with the above parameters.
+ intent := &FundIntent{
+ Packet: packet,
+ Policy: policy,
+ FeeRate: feeRate,
+ Label: "test",
+ ChangeSource: changeSource,
+ }
+
+ // Act: Call createTxIntent to convert the FundIntent.
+ txIntent := w.createTxIntent(intent)
+
+ // Assert: Verify that the basic fields of the resulting TxIntent
+ // match the input FundIntent.
+ expectedOutputs := []wire.TxOut{{Value: 1}, {Value: 2}}
+ require.Equal(t, expectedOutputs, txIntent.Outputs)
+ require.Equal(t, feeRate, txIntent.FeeRate)
+ require.Equal(t, "test", txIntent.Label)
+ require.Equal(t, changeSource, txIntent.ChangeSource)
+
+ // Assert: Verify that the Inputs field of TxIntent is of type
+ // *InputsPolicy and matches the expected policy for auto selection.
+ inputsPolicy, ok := txIntent.Inputs.(*InputsPolicy)
+ require.True(t, ok)
+ require.Equal(t, policy, inputsPolicy)
+}
+
+// TestCreateTxIntentManual tests that createTxIntent correctly converts
+// FundIntent to TxIntent for manual coin selection.
+func TestCreateTxIntentManual(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Create a PSBT packet with two inputs and one output,
+ // which signals manual coin selection.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Index: 0},
+ })
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{Index: 1},
+ })
+ tx.AddTxOut(&wire.TxOut{Value: 1})
+
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Define the fee rate and change source. Policy is not needed
+ // for manual selection.
+ feeRate := btcunit.NewSatPerKVByte(1000)
+ changeSource := &ScopedAccount{}
+
+ // Arrange: Create the FundIntent with the above parameters.
+ intent := &FundIntent{
+ Packet: packet,
+ FeeRate: feeRate,
+ Label: "manual",
+ ChangeSource: changeSource,
+ }
+
+ // Act: Call createTxIntent to convert the FundIntent.
+ txIntent := w.createTxIntent(intent)
+
+ // Assert: Verify that the basic fields of the resulting TxIntent
+ // match the input FundIntent.
+ expectedOutputs := []wire.TxOut{{Value: 1}}
+ require.Equal(t, expectedOutputs, txIntent.Outputs)
+ require.Equal(t, feeRate, txIntent.FeeRate)
+ require.Equal(t, "manual", txIntent.Label)
+ require.Equal(t, changeSource, txIntent.ChangeSource)
+
+ // Assert: Verify that the Inputs field of TxIntent is of type
+ // *InputsManual and contains the expected UTXOs from the packet inputs.
+ inputsManual, ok := txIntent.Inputs.(*InputsManual)
+ require.True(t, ok)
+
+ expectedUTXOs := []wire.OutPoint{{Index: 0}, {Index: 1}}
+ require.Equal(t, expectedUTXOs, inputsManual.UTXOs)
+}
+
+// TestFindChangeIndex tests that findChangeIndex correctly locates the change
+// output in the sorted PSBT packet.
+func TestFindChangeIndex(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create three distinct transaction outputs.
+ out1 := &wire.TxOut{Value: 1000, PkScript: []byte{1}}
+ out2 := &wire.TxOut{Value: 2000, PkScript: []byte{2}}
+
+ // Identified as the change output.
+ changeOut := &wire.TxOut{Value: 500, PkScript: []byte{3}}
+
+ // Arrange: Setup a PSBT Packet where the outputs are sorted
+ // differently, with the change output now at index 0: [changeOut,
+ // out1, out2].
+ packet := &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{changeOut, out1, out2},
+ },
+ }
+
+ // Act: Call findChangeIndex to locate the change output within the
+ // sorted packet.
+ idx, err := findChangeIndex(changeOut, packet)
+
+ // Assert: Verify that no error occurred and the change index found in
+ // the packet is 0, matching its new sorted position.
+ require.NoError(t, err)
+ require.Equal(t, int32(0), idx)
+
+ // Act: Call findChangeIndex for the case with no change output (nil).
+ idx, err = findChangeIndex(nil, packet)
+
+ // Assert: Verify that no error occurred and the returned index is -1,
+ // correctly indicating the absence of a change output.
+ require.NoError(t, err)
+ require.Equal(t, int32(-1), idx)
+
+ // Act: Call findChangeIndex for a change output not present in the
+ // packet.
+ unknownOut := &wire.TxOut{Value: 9999, PkScript: []byte{4}}
+ idx, err = findChangeIndex(unknownOut, packet)
+
+ // Assert: Verify that no error occurred and the returned index is -1.
+ require.NoError(t, err)
+ require.Equal(t, int32(-1), idx)
+}
+
+// TestAddChangeOutputInfoSuccess tests that addChangeOutputInfo correctly adds
+// derivation information to the change output.
+func TestAddChangeOutputInfoSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys and address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Create an AuthoredTx with a change output at index 0.
+ changeOut := &wire.TxOut{
+ Value: 500,
+ PkScript: p2wkhScript,
+ }
+ authoredTx := &txauthor.AuthoredTx{
+ Tx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{changeOut},
+ },
+ ChangeIndex: 0,
+ }
+
+ // Arrange: Create a PSBT packet with a corresponding output.
+ packet, err := psbt.NewFromUnsignedTx(authoredTx.Tx)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock Address lookup.
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(addr address.Address) bool {
+ return addr.String() == p2wkhAddr.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock ManagedPubKeyAddress methods.
+ mocks.pubKeyAddr.On("Address").Return(p2wkhAddr)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+ // Removed Imported() as addChangeOutputInfo does not call it.
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0084, waddrmgr.DerivationPath{}, true,
+ )
+
+ // Act: Call addChangeOutputInfo.
+ err = w.addChangeOutputInfo(t.Context(), packet, authoredTx)
+
+ // Assert: Verify success and that derivation info is added.
+ require.NoError(t, err)
+ require.Len(t, packet.Outputs[0].Bip32Derivation, 1)
+ require.Equal(
+ t, pubKey.SerializeCompressed(),
+ packet.Outputs[0].Bip32Derivation[0].PubKey,
+ )
+}
+
+// TestAddChangeOutputInfoErrScriptFail tests that addChangeOutputInfo returns
+// an error if the script cannot be resolved (e.g. address lookup fails).
+func TestAddChangeOutputInfoErrScriptFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys/address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Create authoredTx with change output.
+ authoredTx := &txauthor.AuthoredTx{
+ Tx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 500, PkScript: p2wkhScript,
+ }},
+ },
+ ChangeIndex: 0,
+ }
+ packet, err := psbt.NewFromUnsignedTx(authoredTx.Tx)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock Address lookup to fail.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(nil, errDb)
+
+ // Act: Call addChangeOutputInfo.
+ err = w.addChangeOutputInfo(t.Context(), packet, authoredTx)
+
+ // Assert: Verify error (from ScriptForOutput).
+ require.ErrorIs(t, err, errDb)
+}
+
+// TestAddChangeOutputInfoErrNotPubKey tests that addChangeOutputInfo returns
+// ErrChangeAddressNotManagedPubKey if the change address is not a pubkey addr.
+func TestAddChangeOutputInfoErrNotPubKey(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys/address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ authoredTx := &txauthor.AuthoredTx{
+ Tx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 500, PkScript: p2wkhScript,
+ }},
+ },
+ ChangeIndex: 0,
+ }
+ packet, err := psbt.NewFromUnsignedTx(authoredTx.Tx)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock Address lookup to return a generic address.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(mocks.addr, nil)
+ mocks.addr.On("Address").Return(p2wkhAddr)
+
+ // Act: Call addChangeOutputInfo.
+ err = w.addChangeOutputInfo(t.Context(), packet, authoredTx)
+
+ // Assert: Verify error (ErrNotPubKeyAddress from ScriptForOutput
+ // check).
+ require.ErrorIs(t, err, ErrNotPubKeyAddress)
+}
+
+// TestAddChangeOutputInfoErrDerivationUnknown tests that addChangeOutputInfo
+// returns an error if the change address has no derivation info.
+func TestAddChangeOutputInfoErrDerivationUnknown(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys/address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ authoredTx := &txauthor.AuthoredTx{
+ Tx: &wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 500, PkScript: p2wkhScript,
+ }},
+ },
+ ChangeIndex: 0,
+ }
+ packet, err := psbt.NewFromUnsignedTx(authoredTx.Tx)
+ require.NoError(t, err)
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock Address lookup.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock ManagedPubKeyAddress methods.
+ mocks.pubKeyAddr.On("Address").Return(p2wkhAddr)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ // PubKey is not called because DerivationInfo returns false.
+ // DerivationInfo returns false (unknown/imported).
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScope{}, waddrmgr.DerivationPath{}, false,
+ )
+
+ // Act: Call addChangeOutputInfo.
+ err = w.addChangeOutputInfo(t.Context(), packet, authoredTx)
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "change addr is an imported addr")
+}
+
+// TestPopulatePsbtPacketErrors tests error paths in populatePsbtPacket.
+func TestPopulatePsbtPacketErrors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ // Input Address (Valid)
+ addrIn, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ scriptIn, err := txscript.PayToAddrScript(addrIn)
+ require.NoError(t, err)
+
+ // Output Address (Valid struct, but will mock failure)
+ addrOut, err := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+ scriptOut, err := txscript.PayToAddrScript(addrOut)
+ require.NoError(t, err)
+
+ txHash := chainhash.Hash{1}
+ authoredTx := &txauthor.AuthoredTx{
+ Tx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: txHash,
+ Index: 0,
+ },
+ }},
+ TxOut: []*wire.TxOut{{
+ Value: 500,
+ PkScript: scriptOut,
+ }},
+ },
+ ChangeIndex: 0, // Output 0 is change
+ }
+
+ t.Run("DecorateInputs fails", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createStartedWalletWithMocks(t)
+ packet := &psbt.Packet{}
+
+ // Mock TxDetails failure (DecorateInputs ->
+ // fetchAndValidateUtxo)
+ mocks.txStore.On("TxDetails", mock.Anything, mock.Anything).
+ Return(nil, errDb)
+
+ _, _, err := w.populatePsbtPacket(
+ t.Context(), packet, authoredTx,
+ )
+ require.ErrorIs(t, err, errDb)
+ })
+
+ t.Run("addChangeOutputInfo fails", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createStartedWalletWithMocks(t)
+ packet := &psbt.Packet{}
+
+ // Mock TxDetails success (DecorateInputs)
+ txDetails := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 1000,
+ PkScript: scriptIn,
+ }},
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{{Index: 0}},
+ }
+ mocks.txStore.On("TxDetails", mock.Anything, mock.Anything).
+ Return(txDetails, nil)
+
+ // Mock Address lookup for Input (Success)
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(a address.Address) bool {
+ return a.String() == addrIn.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil)
+
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0084, waddrmgr.DerivationPath{},
+ true,
+ )
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+
+ // Mock Address lookup for Output (Fail)
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(a address.Address) bool {
+ return a.String() == addrOut.String()
+ }),
+ ).Return(nil, errDb)
+
+ _, _, err := w.populatePsbtPacket(
+ t.Context(), packet, authoredTx,
+ )
+ require.ErrorIs(t, err, errDb)
+ })
+}
+
+// TestPopulatePsbtPacketSuccess tests that populatePsbtPacket correctly
+// updates the packet with the transaction, decorates inputs, adds change info,
+// and sorts the packet.
+func TestPopulatePsbtPacketSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys/address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Create AuthoredTx with 1 input and 2 outputs.
+ // Output 0: Change (Value 1001)
+ // Output 1: Payment (Value 1000)
+ txHash := chainhash.Hash{}
+ changeOut := &wire.TxOut{
+ Value: 1001,
+ PkScript: p2wkhScript,
+ }
+ paymentOut := &wire.TxOut{
+ Value: 1000,
+ PkScript: []byte{0x00}, // Simple script
+ }
+
+ authoredTx := &txauthor.AuthoredTx{
+ Tx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: txHash,
+ Index: 0,
+ },
+ }},
+ TxOut: []*wire.TxOut{changeOut, paymentOut},
+ },
+ ChangeIndex: 0,
+ }
+
+ // Arrange: Create empty packet (will be overwritten).
+ packet := &psbt.Packet{}
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock TxDetails for input decoration.
+ txDetails := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{{
+ Value: 1000, PkScript: p2wkhScript,
+ }},
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{{Index: 0}},
+ }
+ mocks.txStore.On("TxDetails", mock.Anything, mock.Anything).
+ Return(txDetails, nil)
+
+ // Arrange: Mock Address lookup (used for both input decoration and
+ // change output info).
+ mocks.addrStore.On("Address", mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock ManagedPubKeyAddress methods.
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0084, waddrmgr.DerivationPath{}, true,
+ )
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ mocks.pubKeyAddr.On("Address").Return(p2wkhAddr)
+
+ // Act: Call populatePsbtPacket.
+ updatedPacket, changeIdx, err := w.populatePsbtPacket(
+ t.Context(), packet, authoredTx,
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.NotNil(t, updatedPacket)
+
+ // Assert: Verify that the returned changeIdx points to the change
+ // output. We know the change output has Value 1001.
+ require.GreaterOrEqual(t, changeIdx, int32(0))
+ require.Less(t, changeIdx, int32(len(updatedPacket.UnsignedTx.TxOut)))
+ require.Equal(
+ t, int64(1001), updatedPacket.UnsignedTx.TxOut[changeIdx].Value,
+ )
+
+ // Assert: Verify that the decorated output is indeed the change
+ // output. The test setup ensures only the change address (p2wkhAddr)
+ // returns derivation info in the mock. The payment output (simple
+ // script) won't trigger address lookup that leads to derivation info
+ // in this specific mock setup.
+ require.Len(t, updatedPacket.Outputs[changeIdx].Bip32Derivation, 1)
+ require.Equal(
+ t, pubKey.SerializeCompressed(),
+ updatedPacket.Outputs[changeIdx].Bip32Derivation[0].PubKey,
+ )
+
+ // Assert: Input decorated.
+ require.Len(t, updatedPacket.Inputs, 1)
+ require.NotNil(t, updatedPacket.Inputs[0].WitnessUtxo)
+}
+
+// TestFundPsbtWorkflow tests the high-level FundPsbt workflow with manual
+// inputs.
+func TestFundPsbtWorkflow(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup private and public keys for a P2WKH address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Create a PSBT with one input and one output to simulate a
+ // transaction that needs funding and decoration.
+ // Input: 1.0 BTC (100,000,000 sat)
+ // Output: 0.5 BTC (50,000,000 sat)
+ // Fee: ~1000 sat (simplified)
+ // Expected Change: ~0.5 BTC (after fees)
+ txHash := chainhash.Hash{1}
+ outPoint := wire.OutPoint{Hash: txHash, Index: 0}
+ inputAmount := btcutil.Amount(100000000)
+ outputAmount := btcutil.Amount(50000000)
+
+ unsignedTx := wire.NewMsgTx(2)
+ unsignedTx.AddTxIn(&wire.TxIn{PreviousOutPoint: outPoint})
+ unsignedTx.AddTxOut(&wire.TxOut{
+ Value: int64(outputAmount), PkScript: p2wkhScript,
+ })
+
+ packet, err := psbt.NewFromUnsignedTx(unsignedTx)
+ require.NoError(t, err)
+
+ // Arrange: Mock data for UTXO and Transaction Details required by
+ // internal calls.
+ credit := &wtxmgr.Credit{
+ OutPoint: outPoint,
+ Amount: inputAmount,
+ PkScript: p2wkhScript,
+ }
+
+ txDetails := &wtxmgr.TxDetails{
+ TxRecord: wtxmgr.TxRecord{
+ MsgTx: wire.MsgTx{
+ TxOut: []*wire.TxOut{
+ {
+ Value: int64(inputAmount),
+ PkScript: p2wkhScript,
+ },
+ },
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{
+ {Index: 0},
+ },
+ }
+
+ // Arrange: Define the FundIntent for the PSBT, including fee rate and
+ // change source.
+ intent := &FundIntent{
+ Packet: packet,
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ ChangeSource: &ScopedAccount{
+ AccountName: "default",
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ },
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Times(2)
+
+ // Arrange: Mock the internal dependencies for the FundPsbt workflow.
+
+ // --- Mock txStore ---
+ // 1. Mock `txStore.GetUtxo` for `createManualInputSource`:
+ mocks.txStore.On("GetUtxo", mock.Anything, outPoint).
+ Return(credit, nil).
+ Once()
+
+ // 7. Mock `txStore.TxDetails` for `fetchAndValidateUtxo` during
+ // `DecorateInputs`:
+ mocks.txStore.On("TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash)
+ }),
+ ).Return(txDetails, nil).Once()
+
+ // --- Mock addrStore ---
+ // 2. Mock `addrStore.FetchScopedKeyManager` to retrieve the account
+ // manager:
+ mocks.addrStore.On("FetchScopedKeyManager", waddrmgr.KeyScopeBIP0084).
+ Return(mocks.accountManager, nil).Times(3)
+
+ // 8. Mock `addrStore.Address` for `decorateInput` during
+ // `DecorateInputs`:
+ mocks.addrStore.On("Address", mock.Anything,
+ mock.MatchedBy(func(addr address.Address) bool {
+ return addr.String() == p2wkhAddr.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil).Times(3)
+
+ // --- Mock accountManager ---
+ // 3. Mock `accountManager.LookupAccount` for the default account:
+ mocks.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(waddrmgr.DefaultAccountNum), nil).
+ Once()
+
+ // 4. Mock `accountManager.AccountProperties` to return properties for
+ // the default account:
+ mocks.accountManager.On("AccountProperties", mock.Anything,
+ uint32(waddrmgr.DefaultAccountNum),
+ ).Return(&waddrmgr.AccountProperties{
+ AccountName: "default",
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ }, nil).Once()
+
+ // 5. Mock `accountManager.NextInternalAddresses` to generate a change
+ // address:
+ changeAddr := p2wkhAddr // Reusing p2wkhAddr for simplicity as change
+ mockManagedAddr := mocks.pubKeyAddr
+ mocks.accountManager.On("NextInternalAddresses", mock.Anything,
+ uint32(waddrmgr.DefaultAccountNum), uint32(1),
+ ).Return([]waddrmgr.ManagedAddress{mockManagedAddr}, nil).Once()
+
+ // --- Mock pubKeyAddr (and ManagedAddress) ---
+ // 6. Mock `mockManagedAddr.Address` to return the change address:
+ mockManagedAddr.On("Address").Return(changeAddr)
+
+ // 9. Mock `ManagedPubKeyAddress` methods for `decorateInput`:
+ mocks.pubKeyAddr.On("Imported").Return(false)
+ mocks.pubKeyAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0084, waddrmgr.DerivationPath{}, true,
+ )
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+
+ // Act: Execute the FundPsbt workflow with the configured intent.
+ fundedPacket, changeIndex, err := w.FundPsbt(t.Context(), intent)
+
+ // Assert: Verify that no error occurred, a funded PSBT packet is
+ // returned, and a valid change index is provided.
+ require.NoError(t, err)
+ require.NotNil(t, fundedPacket)
+ require.GreaterOrEqual(t, changeIndex, int32(0))
+}
+
+// TestFundPsbtDecorateFailure tests that FundPsbt returns an error if the
+// internal DecorateInputs call fails (e.g. due to database error).
+func TestFundPsbtDecorateFailure(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys/address.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ // Arrange: Create packet with 1 input.
+ txHash := chainhash.Hash{1}
+ outPoint := wire.OutPoint{Hash: txHash, Index: 0}
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{PreviousOutPoint: outPoint})
+ tx.AddTxOut(&wire.TxOut{Value: 90000, PkScript: p2wkhScript})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Intent with manual inputs (so CreateTransaction uses
+ // GetUtxo).
+ intent := &FundIntent{
+ Packet: packet,
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ ChangeSource: &ScopedAccount{
+ AccountName: "default",
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ },
+ }
+
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Times(2)
+
+ // Arrange: Mock GetUtxo for CreateTransaction (Success).
+
+ credit := &wtxmgr.Credit{
+ OutPoint: outPoint,
+ Amount: 100000,
+ PkScript: p2wkhScript,
+ }
+ mocks.txStore.On("GetUtxo", mock.Anything, outPoint).Return(credit, nil)
+
+ // Arrange: Mock TxDetails for DecorateInputs (Failure).
+ // This triggers the error in populatePsbtPacket -> DecorateInputs.
+ mocks.txStore.On("TxDetails", mock.Anything,
+ mock.MatchedBy(func(h *chainhash.Hash) bool {
+ return h.IsEqual(&txHash)
+ }),
+ ).Return(nil, errDb)
+
+ // Arrange: Mock account manager for change address generation, which is
+ // required because the input (100k) exceeds the output (90k) + fees.
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(mocks.accountManager, nil)
+
+ mocks.accountManager.On("LookupAccount", mock.Anything, "default").
+ Return(uint32(waddrmgr.DefaultAccountNum), nil)
+ mocks.accountManager.On("AccountProperties", mock.Anything,
+ uint32(waddrmgr.DefaultAccountNum),
+ ).Return(&waddrmgr.AccountProperties{
+ AccountName: "default",
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ }, nil)
+
+ // Change address generation.
+ mocks.accountManager.On(
+ "NextInternalAddresses", mock.Anything, mock.Anything,
+ mock.Anything,
+ ).Return([]waddrmgr.ManagedAddress{mocks.pubKeyAddr}, nil)
+ mocks.pubKeyAddr.On("Address").Return(p2wkhAddr)
+
+ // Act: FundPsbt.
+ _, _, err = w.FundPsbt(t.Context(), intent)
+
+ // Assert: Should fail due to DecorateInputs error.
+ require.ErrorIs(t, err, errDb)
+}
+
+// TestFundPsbtErrors tests various error conditions in FundPsbt.
+func TestFundPsbtErrors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Common intent setup (auto coin selection).
+ tx := wire.NewMsgTx(2)
+ tx.AddTxOut(&wire.TxOut{Value: 1000, PkScript: []byte{0x00}})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ intent := &FundIntent{
+ Packet: packet,
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ // Policy implies auto selection
+ Policy: &InputsPolicy{
+ Source: &ScopedAccount{
+ AccountName: "default",
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ },
+ },
+ }
+
+ t.Run("validate intent fails", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Once()
+ // Invalid intent (nil packet)
+ _, _, err := w.FundPsbt(t.Context(), &FundIntent{})
+ require.ErrorIs(t, err, ErrNilTxIntent)
+ })
+
+ t.Run("CreateTransaction fails", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Times(2)
+
+ // Mock CreateTransaction failure via Account lookup failure
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(nil, errDb)
+
+ _, _, err := w.FundPsbt(t.Context(), intent)
+
+ // AccountNumber failure is wrapped in ErrAccountNotFound by
+ // prepareTxAuthSources.
+ require.ErrorIs(t, err, ErrAccountNotFound)
+ })
+}
+
+// TestParseBip32Path tests that parseBip32Path correctly parses valid BIP32
+// paths and returns the appropriate KeyScope and DerivationPath, while also
+// flagging invalid paths.
+func TestParseBip32Path(t *testing.T) {
+ t.Parallel()
+
+ // Use mainnet params for testing (HDCoinType = 0).
+ chainParams := &chaincfg.MainNetParams
+ w := &Wallet{
+ cfg: Config{
+ ChainParams: chainParams,
+ },
+ walletDeprecated: &walletDeprecated{
+ chainParams: chainParams,
+ },
+ }
+
+ hardened := func(i uint32) uint32 {
+ return i + hdkeychain.HardenedKeyStart
+ }
+
+ tests := []struct {
+ name string
+ path []uint32
+ wantPath BIP32Path
+ expectedErr error // Use error type for require.ErrorIs
+ }{
+ {
+ name: "valid BIP44",
+ path: []uint32{
+ hardened(44), hardened(0), hardened(0), 0, 0,
+ },
+ wantPath: BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0044,
+ DerivationPath: waddrmgr.DerivationPath{
+ Account: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ },
+ },
+ {
+ name: "valid BIP84",
+ path: []uint32{
+ hardened(84), hardened(0), hardened(1), 0, 5,
+ },
+ wantPath: BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ DerivationPath: waddrmgr.DerivationPath{
+ Account: 1,
+ Branch: 0,
+ Index: 5,
+ },
+ },
+ },
+ {
+ name: "invalid length",
+ path: []uint32{hardened(84)},
+ expectedErr: ErrInvalidBip32Path,
+ },
+ {
+ name: "unhardened purpose",
+ path: []uint32{
+ 84, hardened(0), hardened(0), 0, 0,
+ },
+ expectedErr: ErrInvalidBip32Path,
+ },
+ {
+ name: "unhardened coin type",
+ path: []uint32{
+ hardened(84), 0, hardened(0), 0, 0,
+ },
+ expectedErr: ErrInvalidBip32Path,
+ },
+ {
+ name: "unhardened account",
+ path: []uint32{
+ hardened(84), hardened(0), 0, 0, 0,
+ },
+ expectedErr: ErrInvalidBip32Path,
+ },
+ {
+ name: "coin type mismatch",
+ path: []uint32{
+ hardened(84), hardened(1), hardened(0), 0, 0,
+ },
+ expectedErr: ErrInvalidBip32Path,
+ },
+ {
+ name: "unknown purpose (now allowed in parseBip32Path)",
+ path: []uint32{
+ hardened(999), hardened(0), hardened(0), 0, 0,
+ },
+ wantPath: BIP32Path{
+ KeyScope: waddrmgr.KeyScope{
+ Purpose: 999, Coin: 0,
+ },
+ DerivationPath: waddrmgr.DerivationPath{
+ Account: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ },
+ expectedErr: nil,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call parseBip32Path with the test path.
+ gotPath, err := w.parseBip32Path(tc.path)
+
+ // Assert: Verify that the function returns the expected
+ // error (if any) or that the parsed path components
+ // (KeyScope, DerivationPath) match the expected
+ // structure.
+ require.ErrorIs(t, err, tc.expectedErr)
+ require.Equal(t, tc.wantPath, gotPath)
+ })
+ }
+}
+
+// TestAddressTypeFromPurpose tests that addressTypeFromPurpose returns the
+// correct AddressType for supported BIP32 purposes and returns an error for
+// unknown purposes.
+func TestAddressTypeFromPurpose(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ purpose uint32
+ want waddrmgr.AddressType
+ expectedErr error
+ }{
+ {
+ name: "BIP44",
+ purpose: waddrmgr.KeyScopeBIP0044.Purpose,
+ want: waddrmgr.PubKeyHash,
+ },
+ {
+ name: "BIP49",
+ purpose: waddrmgr.KeyScopeBIP0049Plus.Purpose,
+ want: waddrmgr.NestedWitnessPubKey,
+ },
+ {
+ name: "BIP84",
+ purpose: waddrmgr.KeyScopeBIP0084.Purpose,
+ want: waddrmgr.WitnessPubKey,
+ },
+ {
+ name: "BIP86",
+ purpose: waddrmgr.KeyScopeBIP0086.Purpose,
+ want: waddrmgr.TaprootPubKey,
+ },
+ {
+ name: "unknown",
+ purpose: 999,
+ expectedErr: ErrUnknownBip32Purpose,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call addressTypeFromPurpose with the test
+ // purpose.
+ got, err := addressTypeFromPurpose(tc.purpose)
+
+ // Assert: Verify that the returned address type matches
+ // the expected type for the given purpose, or that the
+ // expected error is returned for unknown purposes.
+ require.ErrorIs(t, err, tc.expectedErr)
+ require.Equal(t, tc.want, got)
+ })
+ }
+}
+
+// TestShouldSkipInput tests that shouldSkipInput correctly identifies inputs
+// that should be skipped during signing (e.g., finalized inputs, inputs with
+// no derivation info) and those that should be processed.
+func TestShouldSkipInput(t *testing.T) {
+ t.Parallel()
+
+ // Define shared variables for long literals to satisfy linter.
+ taprootDerivation := []*psbt.TaprootBip32Derivation{{
+ XOnlyPubKey: []byte{0x01},
+ }}
+
+ tests := []struct {
+ name string
+ pInput *psbt.PInput
+ expected bool
+ }{
+ {
+ name: "finalized input should be skipped",
+ pInput: &psbt.PInput{
+ FinalScriptWitness: []byte{1, 2, 3},
+ },
+ expected: true,
+ },
+ {
+ name: "no derivation info should be skipped",
+ pInput: &psbt.PInput{
+ FinalScriptWitness: nil,
+ Bip32Derivation: nil,
+ TaprootBip32Derivation: nil,
+ },
+ expected: true,
+ },
+ {
+ name: "valid BIP32 derivation, not skipped",
+ pInput: &psbt.PInput{
+ Bip32Derivation: []*psbt.Bip32Derivation{{
+ PubKey: []byte{0x01},
+ }},
+ },
+ expected: false,
+ },
+ {
+ name: "valid Taproot derivation, not skipped",
+ pInput: &psbt.PInput{
+ TaprootBip32Derivation: taprootDerivation,
+ },
+ expected: false,
+ },
+ }
+
+ for i, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call shouldSkipInput with the configured PSBT
+ // input.
+ result := shouldSkipInput(tc.pInput, i)
+
+ // Assert: Verify that the returned boolean matches the
+ // expectation (true for skippable inputs, false
+ // otherwise).
+ require.Equal(t, tc.expected, result)
+ })
+ }
+}
+
+// TestShouldSkipSigningError tests that shouldSkipSigningError correctly
+// determines whether a signing error is non-critical (and thus the input can
+// be skipped) or critical.
+func TestShouldSkipSigningError(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ expected bool
+ }{
+ {
+ name: "already signed error should be skipped",
+ err: errAlreadySigned,
+ expected: true,
+ },
+ {
+ name: "compute raw sig error should be skipped",
+ err: fmt.Errorf("wrapped: %w", errComputeRawSig),
+ expected: true,
+ },
+ {
+ name: "unknown BIP32 purpose error should be " +
+ "skipped",
+ err: ErrUnknownBip32Purpose,
+ expected: true,
+ },
+
+ {
+ name: "generic error should not be skipped",
+ err: errDb,
+ expected: false,
+ },
+ }
+
+ for i, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call shouldSkipSigningError with the test error.
+ result := shouldSkipSigningError(tc.err, i)
+
+ // Assert: Verify that the function correctly identifies
+ // whether the error should cause the input to be
+ // skipped (true) or treated as a failure (false).
+ require.Equal(t, tc.expected, result)
+ })
+ }
+}
+
+// TestValidateDerivation tests that validateDerivation correctly identifies the
+// derivation type (Taproot vs. BIP32) and validates that there are no
+// conflicting or multiple derivation paths.
+func TestValidateDerivation(t *testing.T) {
+ t.Parallel()
+
+ // Define shared variables for long literals.
+ taprootDerivation := []*psbt.TaprootBip32Derivation{{}}
+ multiTapDerivation := []*psbt.TaprootBip32Derivation{{}, {}}
+ multiBip32Derivation := []*psbt.Bip32Derivation{{}, {}}
+ singleBip32Derivation := []*psbt.Bip32Derivation{{}}
+
+ // Arrange: Define test cases for derivation validation.
+ tests := []struct {
+ name string
+ pInput *psbt.PInput
+ isTaproot bool
+ err error
+ }{
+ {
+ name: "single BIP32 derivation",
+ pInput: &psbt.PInput{
+ Bip32Derivation: []*psbt.Bip32Derivation{{}},
+ },
+ isTaproot: false,
+ err: nil,
+ },
+ {
+ name: "single Taproot derivation",
+ pInput: &psbt.PInput{
+ TaprootBip32Derivation: taprootDerivation,
+ },
+ isTaproot: true,
+ err: nil,
+ },
+ {
+ name: "multiple BIP32 derivations error",
+ pInput: &psbt.PInput{
+ Bip32Derivation: multiBip32Derivation,
+ },
+ isTaproot: false,
+ err: ErrUnsupportedMultipleBip32Derivation,
+ },
+ {
+ name: "multiple Taproot derivations error",
+ pInput: &psbt.PInput{
+ TaprootBip32Derivation: multiTapDerivation,
+ },
+ isTaproot: false,
+ err: ErrUnsupportedMultipleTaprootDerivation,
+ },
+ {
+ name: "ambiguous derivation",
+ pInput: &psbt.PInput{
+ Bip32Derivation: singleBip32Derivation,
+ TaprootBip32Derivation: taprootDerivation,
+ },
+ isTaproot: false,
+ err: ErrAmbiguousDerivation,
+ },
+ {
+ name: "no derivation info (valid)",
+ pInput: &psbt.PInput{},
+ isTaproot: false,
+ err: nil,
+ },
+ }
+
+ for i, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Act: Call validateDerivation with the configured
+ // input to check for validity and determine if it's a
+ // Taproot input.
+ isTaproot, err := validateDerivation(tc.pInput, i)
+
+ // Assert: Verify that the returned error matches the
+ // expected error (e.g., for ambiguous or multiple
+ // derivations) and that the Taproot flag is set
+ // correctly for valid inputs.
+ require.ErrorIs(t, err, tc.err)
+ require.Equal(t, tc.isTaproot, isTaproot)
+ })
+ }
+}
+
+// TestFetchPsbtUtxo tests that fetchPsbtUtxo correctly prioritizes WitnessUtxo
+// over NonWitnessUtxo when retrieving the UTXO for a PSBT input, and safely
+// handles missing data.
+func TestFetchPsbtUtxo(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create dummy UTXOs for testing.
+ witnessUtxo := &wire.TxOut{Value: 1000, PkScript: []byte{0x00}}
+ nonWitnessUtxo := &wire.TxOut{Value: 2000, PkScript: []byte{0x01}}
+
+ // Arrange: Create a transaction that has the nonWitnessUtxo at index 0
+ tx := wire.NewMsgTx(2)
+ tx.AddTxOut(nonWitnessUtxo)
+
+ // Arrange: Define test cases for fetching UTXOs.
+ tests := []struct {
+ name string
+ packet *psbt.Packet
+ inputIdx int
+ expected *wire.TxOut
+ expectedErr error
+ }{
+ {
+ name: "prioritize WitnessUtxo",
+ // Arrange: PSBT input with both WitnessUtxo and
+ // NonWitnessUtxo.
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ },
+ Inputs: []psbt.PInput{{
+ WitnessUtxo: witnessUtxo,
+ NonWitnessUtxo: tx,
+ }},
+ },
+ inputIdx: 0,
+ // Assert: Expect WitnessUtxo to be returned.
+ expected: witnessUtxo,
+ expectedErr: nil,
+ },
+ {
+ name: "fallback to NonWitnessUtxo",
+ // Arrange: PSBT input with only NonWitnessUtxo.
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Index: 0,
+ },
+ }},
+ },
+ Inputs: []psbt.PInput{{
+ WitnessUtxo: nil,
+ NonWitnessUtxo: tx,
+ }},
+ },
+ inputIdx: 0,
+ // Assert: Expect NonWitnessUtxo to be returned.
+ expected: nonWitnessUtxo,
+ expectedErr: nil,
+ },
+ {
+ name: "missing all utxo info",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ },
+ Inputs: []psbt.PInput{{
+ WitnessUtxo: nil,
+ NonWitnessUtxo: nil,
+ }},
+ },
+ inputIdx: 0,
+ expected: nil,
+ expectedErr: ErrInputMissingUtxoInfo,
+ },
+ {
+ name: "input index out of bounds",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{},
+ },
+ Inputs: []psbt.PInput{},
+ },
+ inputIdx: 0,
+ expected: nil,
+ expectedErr: ErrIndexOutOfBounds,
+ },
+ {
+ name: "prevout index out of bounds",
+ packet: &psbt.Packet{
+ UnsignedTx: &wire.MsgTx{
+ TxIn: []*wire.TxIn{{
+ PreviousOutPoint: wire.OutPoint{
+ Index: 99,
+ },
+ }},
+ },
+ Inputs: []psbt.PInput{{
+ NonWitnessUtxo: tx,
+ }},
+ },
+ inputIdx: 0,
+ expected: nil,
+ expectedErr: ErrIndexOutOfBounds,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Act: Call fetchPsbtUtxo with the configured packet
+ // and input index.
+ got, err := fetchPsbtUtxo(tc.packet, tc.inputIdx)
+
+ // Assert: Verify that the function returns the correct
+ // UTXO (or nil on error) and the expected error
+ // status.
+ require.ErrorIs(t, err, tc.expectedErr)
+ require.Equal(t, tc.expected, got)
+ })
+ }
+}
+
+// TestCheckTaprootScriptSpendSig tests that checkTaprootScriptSpendSig
+// correctly detects if a Taproot script spend signature already exists for the
+// given key and leaf.
+func TestCheckTaprootScriptSpendSig(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create dummy public key and leaf hash for Taproot
+ // signatures.
+ xOnlyPubKey := bytes.Repeat([]byte{0x01}, 32)
+ leafHash := bytes.Repeat([]byte{0x02}, 32)
+
+ // Pre-define complex slice literals.
+ diffKeySig := []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: bytes.Repeat(
+ []byte{0x03}, 32,
+ ),
+ LeafHash: leafHash,
+ },
+ }
+
+ sameKeySig := []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHash: leafHash,
+ },
+ }
+
+ // Arrange: Define test cases for checking existing Taproot script
+ // spend signatures.
+ tests := []struct {
+ name string
+ pInput *psbt.PInput
+ tapDerivation *psbt.TaprootBip32Derivation
+ err error
+ }{
+ {
+ name: "no existing signature",
+ // Arrange: No TaprootScriptSpendSig in the input.
+ pInput: &psbt.PInput{
+ TaprootScriptSpendSig: nil,
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash},
+ },
+ err: nil,
+ },
+ {
+ name: "existing signature for different key",
+ // Arrange: A TaprootScriptSpendSig exists, but for a
+ // different XOnlyPubKey.
+ pInput: &psbt.PInput{
+ TaprootScriptSpendSig: diffKeySig,
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash},
+ },
+ err: nil,
+ },
+ {
+ name: "existing signature for same key and leaf",
+ // Arrange: A matching TaprootScriptSpendSig already
+ // exists.
+ pInput: &psbt.PInput{
+ TaprootScriptSpendSig: sameKeySig,
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash},
+ },
+ // Assert: Expect errAlreadySigned.
+ err: errAlreadySigned,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Act: Call checkTaprootScriptSpendSig with the
+ // configured input and derivation info.
+ err := checkTaprootScriptSpendSig(
+ tc.pInput, tc.tapDerivation,
+ )
+
+ // Assert: Verify that the function returns an error
+ // only if a valid signature for the same key and leaf
+ // hash already exists.
+ require.ErrorIs(t, err, tc.err)
+ })
+ }
+}
+
+// TestAddTaprootSigToPInput tests that addTaprootSigToPInput correctly adds a
+// generated Taproot signature to the PSBT input, handling both Key Spend and
+// Script Spend paths.
+func TestAddTaprootSigToPInput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Define dummy signature, public key, and leaf hash.
+ sig := []byte{0x01, 0x02}
+ xOnlyPubKey := bytes.Repeat([]byte{0x03}, 32)
+ leafHash := bytes.Repeat([]byte{0x04}, 32)
+
+ // Helper to create signature with appended sighash
+ sigWithHash := append(slices.Clone(sig), byte(txscript.SigHashAll))
+
+ // Arrange: Define test cases for adding Taproot signatures.
+ tests := []struct {
+ name string
+ initialPInput *psbt.PInput
+ sighashType txscript.SigHashType
+ details TaprootSpendDetails
+ tapDerivation *psbt.TaprootBip32Derivation
+ expectedPInput *psbt.PInput
+ }{
+ {
+ name: "key path spend default sighash",
+ // Arrange: Initial empty PSBT input.
+ initialPInput: &psbt.PInput{},
+ sighashType: txscript.SigHashDefault,
+ // Arrange: Key path spend details.
+ details: TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ },
+ tapDerivation: nil, // Not used for key path spend
+ // Assert: Expect TaprootKeySpendSig to be set.
+ expectedPInput: &psbt.PInput{
+ TaprootKeySpendSig: sig,
+ },
+ },
+ {
+ name: "key path spend non-default sighash",
+ // Arrange: Initial empty PSBT input.
+ initialPInput: &psbt.PInput{},
+ sighashType: txscript.SigHashAll,
+ // Arrange: Key path spend details.
+ details: TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ },
+ tapDerivation: nil,
+ // Assert: Expect TaprootKeySpendSig with appended
+ // sighash.
+ expectedPInput: &psbt.PInput{
+ TaprootKeySpendSig: sigWithHash,
+ },
+ },
+ {
+ name: "script path spend",
+ // Arrange: Initial PSBT input with default SighashType.
+ initialPInput: &psbt.PInput{
+ SighashType: txscript.SigHashDefault,
+ },
+ sighashType: txscript.SigHashDefault,
+ // Arrange: Script path spend details.
+ details: TaprootSpendDetails{
+ SpendPath: ScriptPathSpend,
+ },
+ // Arrange: Taproot BIP32 derivation with XOnlyPubKey
+ // and LeafHashes.
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash},
+ },
+ // Assert: Expect TaprootScriptSpendSig to be appended.
+ expectedPInput: &psbt.PInput{
+ SighashType: txscript.SigHashDefault,
+ TaprootScriptSpendSig: []*psbt.TaprootScriptSpendSig{{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHash: leafHash,
+ Signature: sig,
+ SigHash: txscript.SigHashDefault,
+ }},
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Create a copy of the initial input to ensure
+ // test isolation and avoid side effects.
+ pInput := *tc.initialPInput
+
+ // Act: Call addTaprootSigToPInput to add the signature
+ // to the PSBT input.
+ addTaprootSigToPInput(
+ &pInput, sig, tc.sighashType, tc.details,
+ tc.tapDerivation,
+ )
+
+ // Assert: Verify that the resulting PSBT input matches
+ // the expected state.
+ require.Equal(t, tc.expectedPInput, &pInput)
+ })
+ }
+}
+
+// TestAddBip32SigToPInput tests that addBip32SigToPInput correctly adds a
+// generated BIP32 signature to the PSBT input's partial signatures, appending
+// the sighash type if necessary.
+func TestAddBip32SigToPInput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Define dummy signature and public key.
+ sig := []byte{0x01, 0x02}
+ pubKey := bytes.Repeat([]byte{0x03}, 33)
+
+ // Helper to create signature with appended sighash
+ sigWithHash := append(slices.Clone(sig), byte(txscript.SigHashAll))
+
+ // Arrange: Define test cases for adding BIP32 signatures.
+ tests := []struct {
+ name string
+ initialPInput *psbt.PInput
+ sighashType txscript.SigHashType
+ addrType waddrmgr.AddressType
+ derivation *psbt.Bip32Derivation
+ expectedPInput *psbt.PInput
+ }{
+ {
+ name: "legacy p2pkh (no sighash append)",
+ // Arrange: Initial empty PSBT input.
+ initialPInput: &psbt.PInput{},
+ sighashType: txscript.SigHashAll,
+ // Arrange: Public Key Hash address type.
+ addrType: waddrmgr.PubKeyHash,
+ // Arrange: BIP32 derivation with PubKey.
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ // Assert: Expect PartialSigs to be appended with raw
+ // sig.
+ expectedPInput: &psbt.PInput{
+ PartialSigs: []*psbt.PartialSig{{
+ PubKey: pubKey,
+ Signature: sig,
+ }},
+ },
+ },
+ {
+ name: "segwit p2wkh (append sighash)",
+ // Arrange: Initial empty PSBT input.
+ initialPInput: &psbt.PInput{},
+ sighashType: txscript.SigHashAll,
+ // Arrange: Witness Public Key Hash address type.
+ addrType: waddrmgr.WitnessPubKey,
+ // Arrange: BIP32 derivation with PubKey.
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ // Assert: Expect PartialSigs to be appended with sig +
+ // sighash.
+ expectedPInput: &psbt.PInput{
+ PartialSigs: []*psbt.PartialSig{{
+ PubKey: pubKey,
+ Signature: sigWithHash,
+ }},
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Create a copy of the initial input to ensure
+ // test isolation.
+ pInput := *tc.initialPInput
+
+ // Act: Call addBip32SigToPInput to add the signature
+ // to the PSBT input.
+ addBip32SigToPInput(
+ &pInput, sig, tc.sighashType, tc.derivation,
+ tc.addrType,
+ )
+
+ // Assert: Verify that the resulting PSBT input matches
+ // the expected state.
+ require.Equal(t, tc.expectedPInput, &pInput)
+ })
+ }
+}
+
+// TestCreateTaprootSpendDetails tests that createTaprootSpendDetails correctly
+// constructs the TaprootSpendDetails required for signing, handling both Key
+// Path and Script Path spends.
+func TestCreateTaprootSpendDetails(t *testing.T) {
+ t.Parallel()
+
+ // Helpers
+ xOnlyPubKey := bytes.Repeat([]byte{0x01}, 32)
+ leafHash := bytes.Repeat([]byte{0x02}, 32)
+ leafScript := []byte{0x51} // OP_TRUE
+ merkleRoot := bytes.Repeat([]byte{0x03}, 32)
+
+ // Calculate expected hash for success case
+ tapLeaf := txscript.NewBaseTapLeaf(leafScript)
+ tapHash := tapLeaf.TapHash()
+ leafHashCalculated := tapHash[:]
+
+ // Define slice literals.
+ tapLeafScriptSuccess := []*psbt.TaprootTapLeafScript{{
+ LeafVersion: txscript.BaseLeafVersion,
+ Script: leafScript,
+ }}
+
+ tests := []struct {
+ name string
+ pInput *psbt.PInput
+ tapDerivation *psbt.TaprootBip32Derivation
+ expected TaprootSpendDetails
+ err error
+ }{
+ {
+ name: "key path spend success",
+ pInput: &psbt.PInput{
+ TaprootMerkleRoot: merkleRoot,
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: nil, // Empty -> Key Path
+ },
+ expected: TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ Tweak: merkleRoot,
+ },
+ err: nil,
+ },
+ {
+ name: "key path spend invalid merkle root length",
+ pInput: &psbt.PInput{
+ // Invalid length
+ TaprootMerkleRoot: []byte{0x01},
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: nil,
+ },
+ expected: TaprootSpendDetails{},
+ err: ErrInvalidTaprootMerkleRootLength,
+ },
+ {
+ name: "key path spend already signed",
+ pInput: &psbt.PInput{
+ // Already signed
+ TaprootKeySpendSig: []byte{0x01},
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: nil,
+ },
+ expected: TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ Tweak: nil,
+ },
+ err: errAlreadySigned,
+ },
+ {
+ name: "script path spend success",
+ pInput: &psbt.PInput{
+ TaprootLeafScript: tapLeafScriptSuccess,
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHashCalculated},
+ },
+ expected: TaprootSpendDetails{
+ SpendPath: ScriptPathSpend,
+ WitnessScript: leafScript,
+ },
+ err: nil,
+ },
+ {
+ name: "script path spend mismatch hash",
+ pInput: &psbt.PInput{
+ TaprootLeafScript: tapLeafScriptSuccess,
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash}, // Mismatch
+ },
+ expected: TaprootSpendDetails{},
+ err: ErrTaprootLeafHashMismatch,
+ },
+ {
+ name: "script path spend missing script",
+ pInput: &psbt.PInput{
+ TaprootLeafScript: nil, // Missing
+ },
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash},
+ },
+ expected: TaprootSpendDetails{},
+ err: ErrMissingTaprootLeafScript,
+ },
+ {
+ name: "script path spend multiple leaves unsupported",
+ pInput: &psbt.PInput{},
+ tapDerivation: &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ LeafHashes: [][]byte{leafHash, leafHash},
+ },
+ expected: TaprootSpendDetails{},
+ err: ErrUnsupportedTaprootLeafCount,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call createTaprootSpendDetails with the
+ // configured input and derivation.
+ details, err := createTaprootSpendDetails(
+ tc.pInput, tc.tapDerivation,
+ )
+
+ // Assert: Verify that the returned spend details match
+ // the expected values (SpendPath, Tweak, or
+ // WitnessScript) and that any expected errors are
+ // returned.
+ require.ErrorIs(t, err, tc.err)
+ require.Equal(t, tc.expected, details)
+ })
+ }
+}
+
+// TestCreateBip32SpendDetails tests that createBip32SpendDetails correctly
+// constructs the SpendDetails required for signing BIP32 inputs, supporting
+// various address types.
+func TestCreateBip32SpendDetails(t *testing.T) {
+ t.Parallel()
+
+ pubKey := bytes.Repeat([]byte{0x02}, 33)
+ sig := []byte{0x01}
+
+ tests := []struct {
+ name string
+ pInput *psbt.PInput
+ utxo *wire.TxOut
+ addrType waddrmgr.AddressType
+ derivation *psbt.Bip32Derivation
+ expected SpendDetails
+ err error
+ }{
+ {
+ name: "p2wkh success",
+ pInput: &psbt.PInput{
+ WitnessScript: []byte{0x03},
+ },
+ utxo: &wire.TxOut{},
+ addrType: waddrmgr.WitnessPubKey,
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ expected: SegwitV0SpendDetails{
+ WitnessScript: []byte{0x03},
+ },
+ err: nil,
+ },
+ {
+ name: "p2pkh success",
+ pInput: &psbt.PInput{},
+ utxo: &wire.TxOut{
+ PkScript: []byte{0x04},
+ },
+ addrType: waddrmgr.PubKeyHash,
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ expected: LegacySpendDetails{
+ RedeemScript: []byte{0x04},
+ },
+ err: nil,
+ },
+ {
+ name: "nested p2wkh success",
+ pInput: &psbt.PInput{
+ RedeemScript: []byte{0x05},
+ },
+ utxo: &wire.TxOut{},
+ addrType: waddrmgr.NestedWitnessPubKey,
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ expected: SegwitV0SpendDetails{
+ WitnessScript: []byte{0x05},
+ },
+ err: nil,
+ },
+ {
+ name: "unknown address type",
+ pInput: &psbt.PInput{},
+ utxo: &wire.TxOut{},
+ addrType: waddrmgr.Script, // Not supported
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ expected: nil,
+ err: ErrUnknownAddressType,
+ },
+ {
+ name: "already signed",
+ pInput: &psbt.PInput{
+ PartialSigs: []*psbt.PartialSig{{
+ PubKey: pubKey,
+ Signature: sig,
+ }},
+ },
+ utxo: &wire.TxOut{},
+ addrType: waddrmgr.WitnessPubKey,
+ derivation: &psbt.Bip32Derivation{
+ PubKey: pubKey,
+ },
+ expected: nil,
+ err: errAlreadySigned,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call createBip32SpendDetails with the configured
+ // input and UTXO information.
+ details, err := createBip32SpendDetails(
+ tc.pInput, tc.utxo, tc.addrType, tc.derivation,
+ )
+
+ // Assert: Verify that the returned spend details
+ // correctly reflect the address type (Legacy vs Segwit)
+ // and contain the expected scripts, or that an error is
+ // returned for invalid states.
+ require.ErrorIs(t, err, tc.err)
+ require.Equal(t, tc.expected, details)
+ })
+ }
+}
+
+// TestSignTaprootPsbtInput tests that signTaprootPsbtInput successfully
+// generates and appends a signature for a valid Taproot input.
+func TestSignTaprootPsbtInput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ xOnlyPubKey := schnorr.SerializePubKey(pubKey)
+
+ // Arrange: Define the BIP32 derivation path and Taproot derivation
+ // information for the input key.
+ derivationPath := []uint32{
+ hdkeychain.HardenedKeyStart + 86,
+ hdkeychain.HardenedKeyStart + 1,
+ hdkeychain.HardenedKeyStart + 0,
+ 0, 0,
+ }
+ tapDerivation := &psbt.TaprootBip32Derivation{
+ XOnlyPubKey: xOnlyPubKey,
+ Bip32Path: derivationPath,
+ }
+
+ // Arrange: Create a dummy UTXO with a Taproot script to be signed.
+ utxo := &wire.TxOut{
+ Value: 1000,
+ // Dummy Taproot script
+ PkScript: bytes.Repeat([]byte{0x51}, 34),
+ }
+
+ // Arrange: Create a PSBT packet containing the transaction with the
+ // Taproot input.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = utxo
+ tapDerivations := []*psbt.TaprootBip32Derivation{tapDerivation}
+ packet.Inputs[0].TaprootBip32Derivation = tapDerivations
+ packet.Inputs[0].SighashType = txscript.SigHashDefault
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Mock address lookup flow.
+ // 1. FetchScopedKeyManager
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(mocks.accountManager, nil)
+
+ // 2. DeriveFromKeyPath (called inside walletdb.View)
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // 3. Address/PrivKey from ManagedAddress
+ mocks.pubKeyAddr.On("PrivKey").Return(privKey, nil)
+
+ // Act: Call signTaprootPsbtInput to sign the input using the mocked
+ // wallet and keys.
+ sigHashes := txscript.NewTxSigHashes(
+ tx, txscript.NewCannedPrevOutputFetcher(
+ packet.Inputs[0].WitnessUtxo.PkScript,
+ packet.Inputs[0].WitnessUtxo.Value,
+ ),
+ )
+ err = w.signTaprootPsbtInput(t.Context(), packet, 0, sigHashes, nil)
+
+ // Assert: Verify that no error occurred and that the TaprootKeySpendSig
+ // field in the PSBT input is now populated with a signature.
+ require.NoError(t, err)
+ require.NotEmpty(t, packet.Inputs[0].TaprootKeySpendSig)
+}
+
+// TestSignBip32PsbtInput tests that signBip32PsbtInput successfully generates
+// and appends a signature for a valid BIP32 (SegWit v0) input.
+func TestSignBip32PsbtInput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ pubKeyBytes := pubKey.SerializeCompressed()
+
+ // Arrange: Define the BIP32 derivation path for the input key (BIP-84
+ // P2WKH).
+ derivationPath := []uint32{
+ hdkeychain.HardenedKeyStart + 84,
+ hdkeychain.HardenedKeyStart + 1,
+ hdkeychain.HardenedKeyStart + 0,
+ 0, 0,
+ }
+ derivation := &psbt.Bip32Derivation{
+ PubKey: pubKeyBytes,
+ Bip32Path: derivationPath,
+ }
+
+ // Arrange: Create a P2WKH UTXO using the public key derived from the
+ // path.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKeyBytes), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+
+ // Arrange: Create a PSBT packet containing the transaction with the
+ // BIP32 input.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = utxo
+ packet.Inputs[0].Bip32Derivation = []*psbt.Bip32Derivation{derivation}
+ packet.Inputs[0].SighashType = txscript.SigHashAll
+ packet.Inputs[0].WitnessScript = p2wkhScript
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Mock address lookup flow.
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(mocks.accountManager, nil)
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("PrivKey").Return(privKey, nil)
+
+ // Act: Call signBip32PsbtInput to sign the input using the mocked
+ // wallet and keys.
+ sigHashes := txscript.NewTxSigHashes(
+ tx, txscript.NewCannedPrevOutputFetcher(
+ packet.Inputs[0].WitnessUtxo.PkScript,
+ packet.Inputs[0].WitnessUtxo.Value,
+ ),
+ )
+ err = w.signBip32PsbtInput(t.Context(), packet, 0, sigHashes, nil)
+
+ // Assert: Verify that no error occurred and that the PartialSigs field
+ // in the PSBT input is populated with a signature from the expected
+ // public key.
+ require.NoError(t, err)
+ require.Len(t, packet.Inputs[0].PartialSigs, 1)
+ require.Equal(t, pubKeyBytes, packet.Inputs[0].PartialSigs[0].PubKey)
+}
+
+// TestSignPsbtFailNilParams tests that SignPsbt returns ErrNilSignPsbtParams
+// when provided with nil parameters.
+func TestSignPsbtFailNilParams(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a mock wallet.
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act: Call SignPsbt with nil params.
+ _, err := w.SignPsbt(t.Context(), nil)
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrNilArguments)
+}
+
+// TestSignPsbt tests the high-level SignPsbt method, ensuring it correctly
+// orchestrates the signing process for a PSBT packet with valid inputs.
+func TestSignPsbt(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+ pubKeyBytes := pubKey.SerializeCompressed()
+
+ // Arrange: Define the BIP32 derivation path for the input key
+ // (RegressionNet, P2WKH).
+ derivationPath := []uint32{
+ hdkeychain.HardenedKeyStart + 84,
+ // CoinType 1 (RegressionNet)
+ hdkeychain.HardenedKeyStart + 1,
+ hdkeychain.HardenedKeyStart + 0,
+ 0, 0,
+ }
+ derivation := &psbt.Bip32Derivation{
+ PubKey: pubKeyBytes,
+ Bip32Path: derivationPath,
+ }
+
+ // Arrange: Create a P2WKH UTXO that corresponds to the derivation path,
+ // representing the input to be signed.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKeyBytes), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ utxo := &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+
+ // Arrange: Create a PSBT packet containing the transaction to be
+ // signed.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000, PkScript: []byte{}})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = utxo
+ packet.Inputs[0].Bip32Derivation = []*psbt.Bip32Derivation{derivation}
+ packet.Inputs[0].SighashType = txscript.SigHashAll
+
+ signParams := &SignPsbtParams{Packet: packet}
+ // Arrange: Wrap the packet in SignPsbtParams.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Configure mock expectations for key derivation and private
+ // key retrieval.
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(mocks.accountManager, nil)
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("PrivKey").Return(privKey, nil)
+
+ // Act: Call SignPsbt to perform the full signing workflow on the
+ // packet.
+ result, err := w.SignPsbt(t.Context(), signParams)
+
+ // Assert: Verify that the operation succeeded, the input is reported as
+ // signed, and the underlying PSBT packet contains the generated
+ // signature.
+ require.NoError(t, err)
+ require.Len(t, result.SignedInputs, 1)
+ require.Equal(t, uint32(0), result.SignedInputs[0])
+ require.Len(t, packet.Inputs[0].PartialSigs, 1)
+}
+
+// TestSignPsbtInputsNotReady tests that SignPsbt fails if inputs are not ready
+// (missing WitnessUtxo/NonWitnessUtxo).
+func TestSignPsbtInputsNotReady(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Packet with input but no UTXO info.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ signParams := &SignPsbtParams{Packet: packet}
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act.
+ _, err = w.SignPsbt(t.Context(), signParams)
+
+ // Assert.
+ require.ErrorContains(t, err, "psbt inputs not ready")
+}
+
+// TestSignPsbtInvalidDerivationPath tests that SignPsbt returns a fatal error
+// if the derivation path is invalid.
+func TestSignPsbtInvalidDerivationPath(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Packet with valid UTXO but invalid derivation path.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000, PkScript: []byte{}})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{
+ Value: 1000,
+ PkScript: []byte{0x00, 0x14}, // P2WKH dummy
+ }
+ // Invalid path (too short).
+ packet.Inputs[0].Bip32Derivation = []*psbt.Bip32Derivation{{
+ Bip32Path: []uint32{1, 2, 3},
+ PubKey: make([]byte, 33),
+ }}
+
+ signParams := &SignPsbtParams{Packet: packet}
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act.
+ _, err = w.SignPsbt(t.Context(), signParams)
+
+ // Assert.
+ require.ErrorIs(t, err, ErrInvalidBip32Path)
+}
+
+// TestSignPsbtSignErrorSkippable tests that SignPsbt skips an input if
+// signing fails with a skippable error (e.g. key not found).
+func TestSignPsbtSignErrorSkippable(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Packet with valid input.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000, PkScript: []byte{}})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ p2wkhScript, _ := txscript.PayToAddrScript(
+ &address.AddressWitnessPubKeyHash{},
+ ) // Dummy script
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ // Valid path.
+ packet.Inputs[0].Bip32Derivation = []*psbt.Bip32Derivation{{
+ Bip32Path: []uint32{
+ hdkeychain.HardenedKeyStart + 84,
+ hdkeychain.HardenedKeyStart + 1,
+ hdkeychain.HardenedKeyStart + 0,
+ 0, 0,
+ },
+ PubKey: make([]byte, 33),
+ }}
+ packet.Inputs[0].SighashType = txscript.SigHashAll
+
+ signParams := &SignPsbtParams{Packet: packet}
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Mocks to simulate signing failure.
+ mocks.addrStore.On("FetchScopedKeyManager", mock.Anything).
+ Return(mocks.accountManager, nil)
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // PrivKey returns error!
+ mocks.pubKeyAddr.On("PrivKey").Return(nil, errKeyNotFound)
+
+ // Act.
+ result, err := w.SignPsbt(t.Context(), signParams)
+
+ // Assert: No error, but nothing signed.
+ require.NoError(t, err)
+ require.Empty(t, result.SignedInputs)
+}
+
+// TestSignTaprootPsbtInputErrors tests various error conditions in
+// signTaprootPsbtInput.
+func TestSignTaprootPsbtInputErrors(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createUnlockedWalletWithMocks(t)
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Add a dummy Witness UTXO to satisfy validity checks.
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{}
+
+ // Case 1: Invalid Derivation Path.
+ tapDerivation := []*psbt.TaprootBip32Derivation{{
+ Bip32Path: []uint32{1}, // Too short
+ }}
+ packet.Inputs[0].TaprootBip32Derivation = tapDerivation
+ err = w.signTaprootPsbtInput(t.Context(), packet, 0, nil, nil)
+ require.ErrorIs(t, err, ErrInvalidBip32Path)
+
+ // Case 2: CreateTaprootSpendDetails error (e.g. invalid merkle root).
+ packet.Inputs[0].TaprootBip32Derivation[0].Bip32Path = []uint32{
+ hdkeychain.HardenedKeyStart + 86,
+ hdkeychain.HardenedKeyStart + 1,
+ hdkeychain.HardenedKeyStart + 0,
+ 0, 0,
+ }
+ packet.Inputs[0].TaprootMerkleRoot = []byte{0x01} // Invalid length
+ err = w.signTaprootPsbtInput(t.Context(), packet, 0, nil, nil)
+ require.ErrorIs(t, err, ErrInvalidTaprootMerkleRootLength)
+}
+
+// TestSignBip32PsbtInputErrors tests various error conditions in
+// signBip32PsbtInput.
+func TestSignBip32PsbtInputErrors(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createUnlockedWalletWithMocks(t)
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Add a dummy Witness UTXO to satisfy validity checks.
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{}
+
+ // Case 1: Invalid Derivation Path.
+ packet.Inputs[0].Bip32Derivation = []*psbt.Bip32Derivation{{
+ Bip32Path: []uint32{1}, // Too short
+ }}
+ err = w.signBip32PsbtInput(t.Context(), packet, 0, nil, nil)
+ require.ErrorIs(t, err, ErrInvalidBip32Path)
+
+ // Case 2: Unknown Purpose.
+ packet.Inputs[0].Bip32Derivation[0].Bip32Path = []uint32{
+ hdkeychain.HardenedKeyStart + 999, // Unknown
+ hdkeychain.HardenedKeyStart + 1,
+ hdkeychain.HardenedKeyStart + 0,
+ 0, 0,
+ }
+ err = w.signBip32PsbtInput(t.Context(), packet, 0, nil, nil)
+ require.ErrorIs(t, err, ErrUnknownBip32Purpose)
+}
+
+// TestAddScriptToPInput tests that addScriptToPInput correctly updates
+// the PSBT input with the provided witness and/or sigScript.
+func TestAddScriptToPInput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Dummy witness and sigScript.
+ witness := wire.TxWitness{[]byte{0x01}, []byte{0x02}}
+ sigScript := []byte{0x03}
+
+ // Arrange: Expected serialized witness:
+ // - 0x02 (stack items)
+ // - 0x01 (len) + 0x01 (data)
+ // - 0x01 (len) + 0x02 (data)
+ expectedWitness := []byte{0x02, 0x01, 0x01, 0x01, 0x02}
+
+ tests := []struct {
+ name string
+ witness wire.TxWitness
+ sigScript []byte
+ expectedWitness []byte
+ expectedSig []byte
+ }{
+ {
+ name: "witness only",
+ witness: witness,
+ sigScript: nil,
+ expectedWitness: expectedWitness,
+ expectedSig: nil,
+ },
+ {
+ name: "sigScript only",
+ witness: nil,
+ sigScript: sigScript,
+ expectedWitness: nil,
+ expectedSig: sigScript,
+ },
+ {
+ name: "both",
+ witness: witness,
+ sigScript: sigScript,
+ expectedWitness: expectedWitness,
+ expectedSig: sigScript,
+ },
+ {
+ name: "none",
+ witness: nil,
+ sigScript: nil,
+ expectedWitness: nil,
+ expectedSig: nil,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: UnlockingScript and target PInput.
+ script := &UnlockingScript{
+ Witness: tc.witness,
+ SigScript: tc.sigScript,
+ }
+ pInput := &psbt.PInput{}
+
+ // Act: Call addScriptToPInput.
+ err := addScriptToPInput(pInput, script)
+
+ // Assert: Verify no error and fields match
+ // expectations.
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedWitness,
+ pInput.FinalScriptWitness)
+ require.Equal(t, tc.expectedSig,
+ pInput.FinalScriptSig)
+ })
+ }
+}
+
+// TestFinalizeInput tests that finalizeInput correctly processes a single PSBT
+// input, handling success, skips, and errors.
+func TestFinalizeInput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Valid PSBT input.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{
+ Value: 1000,
+ PkScript: p2wkhScript,
+ }
+ packet.Inputs[0].SighashType = txscript.SigHashAll
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Mock dependencies.
+ mocks.addrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ mocks.pubKeyAddr.On("PrivKey").Return(privKey, nil)
+
+ sigHashes := txscript.NewTxSigHashes(
+ tx, txscript.NewCannedPrevOutputFetcher(
+ packet.Inputs[0].WitnessUtxo.PkScript,
+ packet.Inputs[0].WitnessUtxo.Value,
+ ),
+ )
+
+ // Act.
+ err = w.finalizeInput(t.Context(), packet, 0, sigHashes)
+
+ // Assert.
+ require.NoError(t, err)
+ require.NotEmpty(t, packet.Inputs[0].FinalScriptWitness)
+ })
+
+ t.Run("skip finalized", func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Already finalized input.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].FinalScriptWitness = []byte{0x01}
+
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act.
+ err = w.finalizeInput(t.Context(), packet, 0, nil)
+
+ // Assert: No error, remains unchanged (mock not called).
+ require.NoError(t, err)
+ })
+
+ t.Run("skip missing utxo", func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Input without UTXO.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act.
+ err = w.finalizeInput(t.Context(), packet, 0, nil)
+
+ // Assert: No error (logs error but continues).
+ require.NoError(t, err)
+ })
+
+ t.Run("skip malformed script", func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Input with malformed pkScript.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // OP_RETURN script cannot be extracted as an address.
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{
+ Value: 1000,
+ PkScript: []byte{0x6a},
+ }
+
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act.
+ err = w.finalizeInput(t.Context(), packet, 0, nil)
+
+ // Assert: No error (logs error but continues).
+ require.NoError(t, err)
+ })
+}
+
+// TestFinalizePsbtSuccess tests that FinalizePsbt successfully generates
+// witnesses for supported input types (P2WKH, Taproot).
+func TestFinalizePsbtSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup keys.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ // Arrange: Create addresses/scripts.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ trAddr, err := address.NewAddressTaproot(
+ schnorr.SerializePubKey(pubKey), &chainParams,
+ )
+ require.NoError(t, err)
+ trScript, err := txscript.PayToAddrScript(trAddr)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ pkScript []byte
+ addrType waddrmgr.AddressType
+ addr address.Address
+ }{
+ {
+ name: "p2wkh",
+ pkScript: p2wkhScript,
+ addrType: waddrmgr.WitnessPubKey,
+ addr: p2wkhAddr,
+ },
+ {
+ name: "taproot",
+ pkScript: trScript,
+ addrType: waddrmgr.TaprootPubKey,
+ addr: trAddr,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create PSBT.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000}) // Add output
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{
+ Value: 1000,
+ PkScript: tc.pkScript,
+ }
+ packet.Inputs[0].SighashType = txscript.SigHashDefault
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Mock address lookup.
+ mocks.addrStore.On(
+ "Address", mock.Anything,
+ mock.MatchedBy(func(a address.Address) bool {
+ return a.String() == tc.addr.String()
+ }),
+ ).Return(mocks.pubKeyAddr, nil)
+
+ // Arrange: Mock ManagedPubKeyAddress.
+ // Note: Address() and PubKey() are not called for
+ // P2WKH/Taproot signing paths in
+ // ComputeUnlockingScript.
+ mocks.pubKeyAddr.On("AddrType").Return(tc.addrType)
+
+ // Create a copy of the private key to avoid data races
+ // when parallel tests call Zero() on it.
+ privKeyCopy := *privKey
+ mocks.pubKeyAddr.On("PrivKey").Return(&privKeyCopy, nil)
+
+ // Act: Call FinalizePsbt.
+ err = w.FinalizePsbt(t.Context(), packet)
+
+ // Assert: Verify success and witness presence.
+ require.NoError(t, err)
+ require.NotEmpty(
+ t, packet.Inputs[0].FinalScriptWitness,
+ )
+ })
+ }
+}
+
+// TestFinalizePsbtErrors tests error conditions for FinalizePsbt.
+func TestFinalizePsbtErrors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("inputs not ready", func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Packet with input but no UTXO info.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ w, _ := createUnlockedWalletWithMocks(t)
+
+ // Act.
+ err = w.FinalizePsbt(t.Context(), packet)
+
+ // Assert.
+ require.ErrorContains(t, err, "psbt inputs not ready")
+ })
+
+ t.Run("finalization failed", func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Packet with valid UTXO but we can't sign it (watch
+ // only).
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000})
+ packet, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Use a valid P2WKH script so extraction succeeds.
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+ dummyScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ packet.Inputs[0].WitnessUtxo = &wire.TxOut{
+ Value: 1000,
+ PkScript: dummyScript,
+ }
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Arrange: Mock Address lookup to return error (or watch only).
+ // Simulating "Address not found" or "Key not found".
+ // ComputeUnlockingScript will fail, log, and continue.
+ // Then MaybeFinalizeAll will fail because no witness.
+ mocks.addrStore.On("Address", mock.Anything, mock.Anything).
+ Return(nil, errAddrNotFound)
+
+ // Act.
+ err = w.FinalizePsbt(t.Context(), packet)
+
+ // Assert: Should return error from MaybeFinalizeAll.
+ require.ErrorContains(t, err, "error finalizing PSBT")
+ })
+}
+
+// TestValidatePsbtMerge tests the validatePsbtMerge helper function.
+func TestValidatePsbtMerge(t *testing.T) {
+ t.Parallel()
+
+ // Helper to create a dummy packet with specific tx hash and IO counts.
+ createPacket := func(txHash byte, inCount, outCount int) *psbt.Packet {
+ tx := wire.NewMsgTx(2)
+ // Add dummy inputs/outputs to affect count and hash.
+ for i := range inCount {
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{txHash},
+ Index: uint32(i),
+ },
+ })
+ }
+
+ for i := range outCount {
+ tx.AddTxOut(&wire.TxOut{Value: int64(i)})
+ }
+
+ p, _ := psbt.NewFromUnsignedTx(tx)
+
+ return p
+ }
+
+ base := createPacket(1, 1, 1)
+
+ tests := []struct {
+ name string
+ psbts []*psbt.Packet
+ wantErr error
+ }{
+ {
+ name: "success single",
+ psbts: []*psbt.Packet{base},
+ wantErr: nil,
+ },
+ {
+ name: "success multiple identical",
+ psbts: []*psbt.Packet{base, base},
+ wantErr: nil,
+ },
+ {
+ name: "empty list",
+ psbts: []*psbt.Packet{},
+ wantErr: ErrNoPsbtsToCombine,
+ },
+ {
+ name: "mismatched txid",
+ psbts: []*psbt.Packet{base, createPacket(2, 1, 1)},
+ wantErr: ErrDifferentTransactions,
+ },
+ {
+ name: "mismatched input count",
+ psbts: func() []*psbt.Packet {
+ // Create a packet with same TXID but corrupted
+ // input count.
+ p2 := createPacket(1, 1, 1)
+ p2.Inputs = append(p2.Inputs, psbt.PInput{})
+
+ return []*psbt.Packet{base, p2}
+ }(),
+ wantErr: ErrInputCountMismatch,
+ },
+ {
+ name: "mismatched output count",
+ psbts: func() []*psbt.Packet {
+ // Create a packet with same TXID but corrupted
+ // output count.
+ p2 := createPacket(1, 1, 1)
+ p2.Outputs = append(p2.Outputs, psbt.POutput{})
+
+ return []*psbt.Packet{base, p2}
+ }(),
+ wantErr: ErrOutputCountMismatch,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := validatePsbtMerge(tc.psbts)
+ if tc.wantErr != nil {
+ require.ErrorIs(t, err, tc.wantErr)
+ require.Nil(t, got)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, got)
+
+ // Verify it is a different object (copy).
+ require.NotSame(t, tc.psbts[0], got)
+
+ // Verify structure matches base.
+ require.Equal(t,
+ tc.psbts[0].UnsignedTx.TxHash(),
+ got.UnsignedTx.TxHash(),
+ )
+ require.Len(t, got.Inputs,
+ len(tc.psbts[0].Inputs))
+ require.Len(t, got.Outputs,
+ len(tc.psbts[0].Outputs))
+ }
+ })
+ }
+}
+
+// TestMergePsbtInputs tests that mergePsbtInputs correctly merges and
+// deduplicates input fields.
+func TestMergePsbtInputs(t *testing.T) {
+ t.Parallel()
+
+ t.Run("partial sigs deduplication", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a destination input with one signature, and
+ // a source input with the same signature (duplicate) plus a
+ // new one. This simulates merging updates from multiple
+ // signers where some data overlaps.
+ dest := &psbt.PInput{
+ PartialSigs: []*psbt.PartialSig{
+ {
+ PubKey: []byte{1},
+ Signature: []byte{10},
+ },
+ },
+ }
+ src := &psbt.PInput{
+ PartialSigs: []*psbt.PartialSig{
+ {
+ PubKey: []byte{1},
+ Signature: []byte{10},
+ }, // Duplicate
+ {
+ PubKey: []byte{2},
+ Signature: []byte{20},
+ }, // New
+ },
+ }
+
+ // Act: Merge the source input into the destination.
+ err := mergePsbtInputs(dest, src)
+ require.NoError(t, err)
+
+ // Assert: Verify that the destination now contains exactly two
+ // signatures. The first one should be preserved, and the
+ // second one should be the new signature from the source.
+ require.Len(t, dest.PartialSigs, 2)
+ require.Equal(t, []byte{1}, dest.PartialSigs[0].PubKey)
+ require.Equal(t, []byte{2}, dest.PartialSigs[1].PubKey)
+ })
+
+ t.Run("sighash type adoption", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a destination input with the default sighash
+ // type (0) and a source input with a specific type
+ // (SigHashSingle).
+ dest := &psbt.PInput{SighashType: 0} // Default
+ src := &psbt.PInput{SighashType: txscript.SigHashSingle}
+
+ // Act: Merge the inputs.
+ err := mergePsbtInputs(dest, src)
+
+ // Assert: Verify that the destination adopted the source's
+ // sighash type, as 0 is treated as "unset".
+ require.NoError(t, err)
+ require.Equal(t, txscript.SigHashSingle, dest.SighashType)
+
+ // Arrange: Create a scenario with conflicting sighash types.
+ // Destination has SigHashAll, Source has SigHashSingle.
+ dest.SighashType = txscript.SigHashAll
+ src.SighashType = txscript.SigHashSingle
+
+ // Act: Attempt to merge conflicting inputs.
+ err = mergePsbtInputs(dest, src)
+
+ // Assert: Verify that the merge returns an error indicating
+ // the mismatch.
+ require.ErrorContains(t, err, "sighash type mismatch")
+ })
+
+ t.Run("scripts merging", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a destination input missing script info, and
+ // a source input containing it.
+ dest := &psbt.PInput{}
+ src := &psbt.PInput{
+ RedeemScript: []byte{1, 2, 3},
+ WitnessScript: []byte{4, 5, 6},
+ }
+
+ // Act: Merge the inputs.
+ err := mergePsbtInputs(dest, src)
+
+ // Assert: Verify that the scripts were successfully copied to
+ // the destination.
+ require.NoError(t, err)
+ require.Equal(t, src.RedeemScript, dest.RedeemScript)
+ require.Equal(t, src.WitnessScript, dest.WitnessScript)
+ })
+}
+
+// TestMergePsbtOutputs tests that mergePsbtOutputs correctly merges and
+// deduplicates output fields.
+func TestMergePsbtOutputs(t *testing.T) {
+ t.Parallel()
+
+ t.Run("bip32 derivation deduplication", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create destination and source outputs with
+ // overlapping BIP32 derivation paths.
+ dest := &psbt.POutput{
+ Bip32Derivation: []*psbt.Bip32Derivation{
+ {
+ PubKey: []byte{1},
+ MasterKeyFingerprint: 10,
+ },
+ },
+ }
+ src := &psbt.POutput{
+ Bip32Derivation: []*psbt.Bip32Derivation{
+ {
+ PubKey: []byte{1},
+ MasterKeyFingerprint: 10,
+ }, // Duplicate
+ {
+ PubKey: []byte{2},
+ MasterKeyFingerprint: 20,
+ }, // New
+ },
+ }
+
+ // Act: Merge the outputs.
+ err := mergePsbtOutputs(dest, src)
+ require.NoError(t, err)
+
+ // Assert: Verify that the destination now contains both unique
+ // derivations.
+ require.Len(t, dest.Bip32Derivation, 2)
+ require.Equal(t, []byte{1}, dest.Bip32Derivation[0].PubKey)
+ require.Equal(t, []byte{2}, dest.Bip32Derivation[1].PubKey)
+ })
+
+ t.Run("taproot internal key adoption", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a destination output missing the Taproot
+ // internal key, and a source output that has it.
+ dest := &psbt.POutput{}
+ src := &psbt.POutput{
+ TaprootInternalKey: []byte{1, 2, 3},
+ }
+
+ // Act: Merge the outputs.
+ err := mergePsbtOutputs(dest, src)
+
+ require.NoError(t, err)
+ require.Equal(t, src.TaprootInternalKey,
+ dest.TaprootInternalKey)
+ })
+}
+
+// TestAddInputInfoSegWitV0 tests the legacy helper for adding SegWit v0 input
+// info.
+func TestAddInputInfoSegWitV0(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup input parameters (prevTx, utxo, derivation).
+ in := &psbt.PInput{}
+ prevTx := wire.NewMsgTx(1)
+ utxo := &wire.TxOut{Value: 1000, PkScript: []byte{1}}
+ derivation := &psbt.Bip32Derivation{PubKey: []byte{2}}
+ witnessProgram := []byte{3}
+
+ // Mock address type.
+ mockAddr := &mockManagedAddress{}
+ mockAddr.On("AddrType").Return(waddrmgr.NestedWitnessPubKey)
+
+ // Act: Call the helper.
+ addInputInfoSegWitV0(in, prevTx, utxo, derivation, mockAddr,
+ witnessProgram)
+
+ // Assert: Verify fields are populated correctly.
+ require.Equal(t, prevTx, in.NonWitnessUtxo)
+ require.Equal(t, utxo.Value, in.WitnessUtxo.Value)
+ require.Equal(t, utxo.PkScript, in.WitnessUtxo.PkScript)
+ require.Equal(t, txscript.SigHashAll, in.SighashType)
+ require.Equal(t, derivation, in.Bip32Derivation[0])
+ require.Equal(t, witnessProgram, in.RedeemScript)
+}
+
+// TestAddInputInfoSegWitV1 tests the legacy helper for adding SegWit v1 input
+// info.
+func TestAddInputInfoSegWitV1(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup input parameters.
+ in := &psbt.PInput{}
+ utxo := &wire.TxOut{Value: 1000, PkScript: []byte{1}}
+ // PubKey must be valid length for slicing [1:].
+ pubKey := make([]byte, 33)
+ pubKey[0] = 0x02
+ derivation := &psbt.Bip32Derivation{PubKey: pubKey}
+
+ // Act: Call the helper.
+ addInputInfoSegWitV1(in, utxo, derivation)
+
+ // Assert: Verify fields are populated correctly.
+ require.Equal(t, utxo.Value, in.WitnessUtxo.Value)
+ require.Equal(t, txscript.SigHashDefault, in.SighashType)
+ require.Equal(t, derivation, in.Bip32Derivation[0])
+ require.Equal(t, pubKey[1:], in.TaprootBip32Derivation[0].XOnlyPubKey)
+}
+
+// TestPsbtPrevOutputFetcher tests that the prev output fetcher correctly
+// retrieves UTXOs from the PSBT packet.
+func TestPsbtPrevOutputFetcher(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a PSBT packet with multiple inputs.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{PreviousOutPoint: wire.OutPoint{Index: 0}})
+ tx.AddTxIn(&wire.TxIn{PreviousOutPoint: wire.OutPoint{Index: 1}})
+
+ packet, _ := psbt.NewFromUnsignedTx(tx)
+
+ // Input 0: NonWitnessUtxo.
+ prevTx := wire.NewMsgTx(1)
+ prevTx.AddTxOut(&wire.TxOut{Value: 1000})
+ packet.Inputs[0].NonWitnessUtxo = prevTx
+
+ // Input 1: WitnessUtxo.
+ packet.Inputs[1].WitnessUtxo = &wire.TxOut{Value: 2000}
+
+ // Act: Create the fetcher.
+ fetcher, err := PsbtPrevOutputFetcher(packet)
+ require.NoError(t, err)
+
+ // Assert: Check input 0 (NonWitness).
+ out0 := fetcher.FetchPrevOutput(wire.OutPoint{Index: 0})
+ require.NotNil(t, out0)
+ require.Equal(t, int64(1000), out0.Value)
+
+ // Assert: Check input 1 (Witness).
+ out1 := fetcher.FetchPrevOutput(wire.OutPoint{Index: 1})
+ require.NotNil(t, out1)
+ require.Equal(t, int64(2000), out1.Value)
+}
+
+// TestMergeSighashType tests the mergeSighashType helper function.
+//
+// It verifies two key behaviors:
+// 1. Conflict Detection: It ensures that an error is returned if the
+// destination and source inputs have different, non-zero sighash types.
+// 2. Adoption: It ensures that if the destination has a default (0) sighash
+// type, it correctly adopts the type from the source.
+func TestMergeSighashType(t *testing.T) {
+ t.Parallel()
+
+ t.Run("detect mismatch", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Construct a 'destination' PSBT input that has
+ // already declared a sighash type of SigHashAll.
+ dest := &psbt.PInput{SighashType: txscript.SigHashAll}
+
+ // Arrange: Construct a 'source' PSBT input that declares a
+ // different, conflicting sighash type of SigHashSingle.
+ src := &psbt.PInput{SighashType: txscript.SigHashSingle}
+
+ // Act: Attempt to merge the conflicting inputs.
+ err := mergeSighashType(dest, src)
+
+ // Assert: Verify that the function identified the conflict and
+ // returned the expected error.
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("adopt source type", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Construct a 'destination' PSBT input with the
+ // default (zero) sighash type, indicating it hasn't been set
+ // yet.
+ dest := &psbt.PInput{SighashType: 0}
+
+ // Arrange: Construct a 'source' PSBT input with a specific
+ // sighash type (SigHashSingle) that should be propagated.
+ src := &psbt.PInput{SighashType: txscript.SigHashSingle}
+
+ // Act: Merge the source into the destination.
+ err := mergeSighashType(dest, src)
+
+ // Assert: Verify that the operation was successful (no error)
+ // and that the destination input has been updated to match the
+ // source's sighash type.
+ require.NoError(t, err)
+ require.Equal(t, txscript.SigHashSingle, dest.SighashType)
+ })
+}
+
+// TestMergeRedeemScript tests the mergeRedeemScript helper function.
+//
+// It verifies that:
+// 1. Conflicting redeem scripts cause an error.
+// 2. A missing redeem script in the destination is populated from the source.
+func TestMergeRedeemScript(t *testing.T) {
+ t.Parallel()
+
+ t.Run("detect mismatch", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a 'destination' input with a specific redeem
+ // script (byte sequence {1}).
+ dest := &psbt.PInput{RedeemScript: []byte{1}}
+
+ // Arrange: Create a 'source' input with a different redeem
+ // script (byte sequence {2}).
+ src := &psbt.PInput{RedeemScript: []byte{2}}
+
+ // Act: Attempt to merge the source into the destination.
+ err := mergeRedeemScript(dest, src)
+
+ // Assert: Verify that the function returns
+ // ErrPsbtMergeConflict, preventing the corruption of the
+ // redeem script.
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("adopt source script", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a 'destination' input with no redeem script
+ // (nil or empty).
+ dest := &psbt.PInput{}
+
+ // Arrange: Create a 'source' input that contains a valid
+ // redeem script.
+ src := &psbt.PInput{RedeemScript: []byte{1, 2, 3}}
+
+ // Act: Merge the source into the destination.
+ err := mergeRedeemScript(dest, src)
+
+ // Assert: Verify that the merge succeeded and the destination
+ // now contains the redeem script from the source.
+ require.NoError(t, err)
+ require.Equal(t, src.RedeemScript, dest.RedeemScript)
+ })
+}
+
+// TestMergeWitnessUtxo tests the mergeWitnessUtxo helper function.
+//
+// It verifies that:
+// 1. Conflicting Witness UTXO values (amount or script) trigger an error.
+// 2. A missing Witness UTXO in the destination is correctly copied from the
+// source.
+func TestMergeWitnessUtxo(t *testing.T) {
+ t.Parallel()
+
+ t.Run("detect value mismatch", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a 'destination' input with a Witness UTXO
+ // valued at 1000 sats.
+ dest := &psbt.PInput{WitnessUtxo: &wire.TxOut{Value: 1000}}
+
+ // Arrange: Create a 'source' input with a Witness UTXO valued
+ // at 2000 sats (conflicting).
+ src := &psbt.PInput{WitnessUtxo: &wire.TxOut{Value: 2000}}
+
+ // Act: Attempt to merge the inputs.
+ err := mergeWitnessUtxo(dest, src)
+
+ // Assert: Verify that the function returns
+ // ErrPsbtMergeConflict.
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("detect script mismatch", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a 'destination' input with a Witness UTXO
+ // script {1}.
+ dest := &psbt.PInput{
+ WitnessUtxo: &wire.TxOut{
+ Value: 1000, PkScript: []byte{1},
+ },
+ }
+
+ // Arrange: Create a 'source' input with the same value but a
+ // different script {2}.
+ src := &psbt.PInput{
+ WitnessUtxo: &wire.TxOut{
+ Value: 1000, PkScript: []byte{2},
+ },
+ }
+
+ // Act: Attempt to merge the inputs.
+ err := mergeWitnessUtxo(dest, src)
+
+ // Assert: Verify that the function returns
+ // ErrPsbtMergeConflict due to the script difference.
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("adopt source utxo", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a 'destination' input with no Witness UTXO
+ // info.
+ dest := &psbt.PInput{}
+
+ // Arrange: Create a 'source' input with a full Witness UTXO.
+ src := &psbt.PInput{
+ WitnessUtxo: &wire.TxOut{
+ Value: 1000, PkScript: []byte{1},
+ },
+ }
+
+ // Act: Merge the source into the destination.
+ err := mergeWitnessUtxo(dest, src)
+
+ // Assert: Verify that the destination structure now holds the
+ // exact Witness UTXO pointer/value from the source.
+ require.NoError(t, err)
+ require.Equal(t, src.WitnessUtxo, dest.WitnessUtxo)
+ })
+}
+
+// TestMergeNonWitnessUtxo tests the mergeNonWitnessUtxo helper function.
+//
+// It ensures that full transaction data (for legacy/SegWit v0 inputs) is
+// merged safely, rejecting conflicts where the transaction hash differs.
+func TestMergeNonWitnessUtxo(t *testing.T) {
+ t.Parallel()
+
+ t.Run("detect mismatch", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create two distinct wire transactions to serve as
+ // conflicting NonWitnessUtxo data.
+ tx1 := wire.NewMsgTx(1)
+ tx2 := wire.NewMsgTx(2)
+
+ // Arrange: Assign tx1 to destination and tx2 to source.
+ dest := &psbt.PInput{NonWitnessUtxo: tx1}
+ src := &psbt.PInput{NonWitnessUtxo: tx2}
+
+ // Act: Attempt to merge.
+ err := mergeNonWitnessUtxo(dest, src)
+
+ // Assert: Verify that ErrPsbtMergeConflict is returned
+ // because the transactions differ.
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("adopt source utxo", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a 'destination' input with no
+ // NonWitnessUtxo.
+ dest := &psbt.PInput{}
+
+ // Arrange: Create a 'source' input with a valid NonWitnessUtxo
+ // transaction.
+ tx := wire.NewMsgTx(1)
+ src := &psbt.PInput{NonWitnessUtxo: tx}
+
+ // Act: Merge the inputs.
+ err := mergeNonWitnessUtxo(dest, src)
+
+ // Assert: Verify that the destination adopted the
+ // NonWitnessUtxo from the source.
+ require.NoError(t, err)
+ require.Equal(t, src.NonWitnessUtxo, dest.NonWitnessUtxo)
+ })
+}
+
+// TestMergeTaprootInternalKeyMismatch verifies that conflicting Taproot
+// internal keys are detected.
+func TestMergeTaprootInternalKeyMismatch(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup conflicting Taproot internal keys (byte {1} vs
+ // byte {2}).
+ dest := &psbt.POutput{TaprootInternalKey: []byte{1}}
+ src := &psbt.POutput{TaprootInternalKey: []byte{2}}
+
+ // Act: Attempt to merge the outputs.
+ err := mergeTaprootInternalKey(dest, src)
+
+ // Assert: Verify that ErrPsbtMergeConflict is returned.
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+}
+
+// TestMergeTaprootInternalKeyAdoption verifies that a source key is adopted
+// if the destination key is missing.
+func TestMergeTaprootInternalKeyAdoption(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a destination output with no internal key.
+ dest := &psbt.POutput{}
+
+ // Arrange: Create a source output with a valid internal key.
+ src := &psbt.POutput{TaprootInternalKey: []byte{1, 2, 3}}
+
+ // Act: Merge the outputs.
+ err := mergeTaprootInternalKey(dest, src)
+
+ // Assert: Verify that the internal key was successfully copied to
+ // the destination.
+ require.NoError(t, err)
+ require.Equal(t, src.TaprootInternalKey, dest.TaprootInternalKey)
+}
+
+// TestDeduplicateTaprootBip32Derivations tests the deduplication logic for
+// Taproot BIP32 derivations.
+func TestDeduplicateTaprootBip32Derivations(t *testing.T) {
+ t.Parallel()
+
+ t.Run("deduplicate", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup destination with one derivation.
+ dest := []*psbt.TaprootBip32Derivation{
+ {XOnlyPubKey: []byte{1}},
+ }
+
+ // Arrange: Setup source with duplicate and new derivation.
+ src := []*psbt.TaprootBip32Derivation{
+ {XOnlyPubKey: []byte{1}}, // Duplicate
+ {XOnlyPubKey: []byte{2}}, // New
+ }
+
+ // Act.
+ got := deduplicateTaprootBip32Derivations(dest, src)
+
+ // Assert: Verify deduplication.
+ require.Len(t, got, 2)
+ require.Equal(t, []byte{1}, got[0].XOnlyPubKey)
+ require.Equal(t, []byte{2}, got[1].XOnlyPubKey)
+ })
+}
+
+// TestDeduplicateTaprootScriptSpendSigs tests the deduplication logic for
+// Taproot script spend signatures based on XOnlyPubKey and LeafHash.
+func TestDeduplicateTaprootScriptSpendSigs(t *testing.T) {
+ t.Parallel()
+
+ // Define common elements for signatures.
+ xOnlyPubKey1 := []byte{1}
+ xOnlyPubKey2 := []byte{2}
+ leafHash1 := []byte{10}
+ leafHash2 := []byte{20}
+
+ tests := []struct {
+ name string
+ dest []*psbt.TaprootScriptSpendSig
+ src []*psbt.TaprootScriptSpendSig
+ want []*psbt.TaprootScriptSpendSig
+ }{
+ {
+ name: "empty slices",
+ dest: nil,
+ src: nil,
+ want: nil,
+ },
+ {
+ name: "add unique from src",
+ dest: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ },
+ src: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey2,
+ LeafHash: leafHash2,
+ },
+ },
+ want: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ {
+ XOnlyPubKey: xOnlyPubKey2,
+ LeafHash: leafHash2,
+ },
+ },
+ },
+ {
+ name: "skip duplicate from src",
+ dest: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ },
+ src: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ }, // Duplicate
+ want: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ },
+ },
+ {
+ name: "mix unique and duplicate",
+ dest: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ },
+ src: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ }, // Duplicate
+ {
+ XOnlyPubKey: xOnlyPubKey2,
+ LeafHash: leafHash2,
+ }, // Unique
+ },
+ want: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ {
+ XOnlyPubKey: xOnlyPubKey2,
+ LeafHash: leafHash2,
+ },
+ },
+ },
+ {
+ name: "complex mix",
+ dest: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ {
+ XOnlyPubKey: xOnlyPubKey2,
+ LeafHash: leafHash1,
+ }, // Same LeafHash, diff PubKey
+ },
+ src: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ }, // Duplicate of dest[0]
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash2,
+ }, // Unique
+ },
+ want: []*psbt.TaprootScriptSpendSig{
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash1,
+ },
+ {
+ XOnlyPubKey: xOnlyPubKey2,
+ LeafHash: leafHash1,
+ },
+ {
+ XOnlyPubKey: xOnlyPubKey1,
+ LeafHash: leafHash2,
+ },
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Deduplicate the signatures.
+ got := deduplicateTaprootScriptSpendSigs(
+ tc.dest, tc.src,
+ )
+
+ // Assert: Verify the result.
+ require.ElementsMatch(t, tc.want, got)
+ })
+ }
+}
+
+// TestMergeInputScripts tests the aggregate mergeInputScripts function to
+// ensure it propagates errors from all sub-steps.
+func TestMergeInputScripts(t *testing.T) {
+ t.Parallel()
+
+ t.Run("fail on redeem script", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.PInput{RedeemScript: []byte{1}}
+ src := &psbt.PInput{RedeemScript: []byte{2}}
+ err := mergeInputScripts(dest, src)
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("fail on witness script", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.PInput{WitnessScript: []byte{1}}
+ src := &psbt.PInput{WitnessScript: []byte{2}}
+ err := mergeInputScripts(dest, src)
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("fail on final script sig", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.PInput{FinalScriptSig: []byte{1}}
+ src := &psbt.PInput{FinalScriptSig: []byte{2}}
+ err := mergeInputScripts(dest, src)
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("fail on final script witness", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.PInput{FinalScriptWitness: []byte{1}}
+ src := &psbt.PInput{FinalScriptWitness: []byte{2}}
+ err := mergeInputScripts(dest, src)
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.PInput{}
+ src := &psbt.PInput{
+ RedeemScript: []byte{1},
+ WitnessScript: []byte{2},
+ FinalScriptSig: []byte{3},
+ FinalScriptWitness: []byte{4},
+ }
+ err := mergeInputScripts(dest, src)
+ require.NoError(t, err)
+ require.Equal(t, src.RedeemScript, dest.RedeemScript)
+ require.Equal(t, src.WitnessScript, dest.WitnessScript)
+ require.Equal(t, src.FinalScriptSig, dest.FinalScriptSig)
+ require.Equal(
+ t, src.FinalScriptWitness, dest.FinalScriptWitness,
+ )
+ })
+}
+
+// TestMergeOutputScripts tests the aggregate mergeOutputScripts function.
+func TestMergeOutputScripts(t *testing.T) {
+ t.Parallel()
+
+ t.Run("fail on redeem script", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.POutput{RedeemScript: []byte{1}}
+ src := &psbt.POutput{RedeemScript: []byte{2}}
+ err := mergeOutputScripts(dest, src)
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("fail on witness script", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.POutput{WitnessScript: []byte{1}}
+ src := &psbt.POutput{WitnessScript: []byte{2}}
+ err := mergeOutputScripts(dest, src)
+ require.ErrorIs(t, err, ErrPsbtMergeConflict)
+ })
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ dest := &psbt.POutput{}
+ src := &psbt.POutput{
+ RedeemScript: []byte{1},
+ WitnessScript: []byte{2},
+ }
+ err := mergeOutputScripts(dest, src)
+ require.NoError(t, err)
+ require.Equal(t, src.RedeemScript, dest.RedeemScript)
+ require.Equal(t, src.WitnessScript, dest.WitnessScript)
+ })
+}
+
+// TestCombinePsbt tests that CombinePsbt correctly merges multiple PSBTs.
+func TestCombinePsbt(t *testing.T) {
+ t.Parallel()
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Create a base transaction with 1 input and 1 output.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ tx.AddTxOut(&wire.TxOut{Value: 1000}) // Add output
+
+ // Arrange: Create two PSBT packets from this transaction.
+ packet1, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ packet2, err := psbt.NewFromUnsignedTx(tx)
+ require.NoError(t, err)
+
+ // Arrange: Add UTXO info to satisfy structural validation
+ // checks.
+ dummyUtxo := &wire.TxOut{Value: 1000, PkScript: []byte{0x00}}
+ packet1.Inputs[0].WitnessUtxo = dummyUtxo
+ packet2.Inputs[0].WitnessUtxo = dummyUtxo
+
+ // Arrange: Add a unique partial signature to the second
+ // packet.
+ packet2.Inputs[0].PartialSigs = []*psbt.PartialSig{{
+ PubKey: []byte{1}, Signature: []byte{1},
+ }}
+
+ // Act: Combine the two packets.
+ combined, err := w.CombinePsbt(t.Context(), packet1, packet2)
+
+ // Assert: Verify the merge was successful and the resulting
+ // packet contains the signature from packet2.
+ require.NoError(t, err)
+ require.Len(t, combined.Inputs[0].PartialSigs, 1)
+ require.Equal(t, []byte{1},
+ combined.Inputs[0].PartialSigs[0].PubKey)
+ })
+
+ t.Run("empty inputs", func(t *testing.T) {
+ t.Parallel()
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Act: Attempt to combine with no packets.
+ _, err := w.CombinePsbt(t.Context())
+
+ // Assert: Verify it returns the expected error.
+ require.ErrorIs(t, err, ErrNoPsbtsToCombine)
+ })
+
+ t.Run("mismatch tx", func(t *testing.T) {
+ t.Parallel()
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Arrange: Create two packets with DIFFERENT transactions.
+ tx1 := wire.NewMsgTx(2)
+ tx1.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{1},
+ },
+ })
+ packet1, err := psbt.NewFromUnsignedTx(tx1)
+ require.NoError(t, err)
+
+ tx2 := wire.NewMsgTx(2)
+ tx2.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{2},
+ },
+ })
+ packet2, err := psbt.NewFromUnsignedTx(tx2)
+ require.NoError(t, err)
+
+ // Act: Attempt to combine conflicting packets.
+ _, err = w.CombinePsbt(t.Context(), packet1, packet2)
+
+ // Assert: Verify it returns the specific mismatch error.
+ require.ErrorIs(t, err, ErrDifferentTransactions)
+ })
+}
+
+// TestDeduplicateUnknowns tests that deduplicateUnknowns correctly adds new
+// unknowns from src to dest while avoiding duplicates based on the key.
+func TestDeduplicateUnknowns(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create some sample unknowns.
+ unknown1 := &psbt.Unknown{Key: []byte{1}, Value: []byte{1}}
+ unknown2 := &psbt.Unknown{Key: []byte{2}, Value: []byte{2}}
+ unknown3 := &psbt.Unknown{Key: []byte{3}, Value: []byte{3}}
+
+ // Arrange: Create a duplicate of unknown1 (same key, different value to
+ // ensure we only check key).
+ unknown1Dup := &psbt.Unknown{Key: []byte{1}, Value: []byte{99}}
+
+ tests := []struct {
+ name string
+ dest []*psbt.Unknown
+ src []*psbt.Unknown
+ expected []*psbt.Unknown
+ }{
+ {
+ name: "no duplicates",
+ dest: []*psbt.Unknown{unknown1},
+ src: []*psbt.Unknown{unknown2, unknown3},
+ expected: []*psbt.Unknown{unknown1, unknown2, unknown3},
+ },
+ {
+ name: "duplicates in src",
+ dest: []*psbt.Unknown{unknown1},
+ src: []*psbt.Unknown{unknown1Dup, unknown2},
+ expected: []*psbt.Unknown{unknown1, unknown2},
+ },
+ {
+ name: "empty dest",
+ dest: []*psbt.Unknown{},
+ src: []*psbt.Unknown{unknown1, unknown2},
+ expected: []*psbt.Unknown{unknown1, unknown2},
+ },
+ {
+ name: "empty src",
+ dest: []*psbt.Unknown{unknown1},
+ src: []*psbt.Unknown{},
+ expected: []*psbt.Unknown{unknown1},
+ },
+ {
+ name: "all duplicates",
+ dest: []*psbt.Unknown{unknown1, unknown2},
+ src: []*psbt.Unknown{unknown1Dup, unknown2},
+ expected: []*psbt.Unknown{unknown1, unknown2},
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act: Call deduplicateUnknowns.
+ got := deduplicateUnknowns(tc.dest, tc.src)
+
+ // Assert: Verify the result.
+ require.Equal(t, tc.expected, got)
+ })
+ }
+}
+
+// TestSignPsbtLocked tests that SignPsbt fails when the wallet is locked.
+func TestSignPsbtLocked(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+ // Minimal packet.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, _ := psbt.NewFromUnsignedTx(tx)
+
+ _, err := w.SignPsbt(t.Context(), &SignPsbtParams{Packet: packet})
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestFinalizePsbtLocked tests that FinalizePsbt fails when the wallet is
+// locked.
+func TestFinalizePsbtLocked(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{})
+ packet, _ := psbt.NewFromUnsignedTx(tx)
+
+ err := w.FinalizePsbt(t.Context(), packet)
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
diff --git a/wallet/psbt_test.go b/wallet/psbt_test.go
deleted file mode 100644
index 7110b5c2d6..0000000000
--- a/wallet/psbt_test.go
+++ /dev/null
@@ -1,518 +0,0 @@
-// Copyright (c) 2020 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "bytes"
- "encoding/hex"
- "testing"
-
- "github.com/btcsuite/btcd/btcutil/v2"
- "github.com/btcsuite/btcd/psbt/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/wallet/txrules"
- "github.com/btcsuite/btcwallet/wallet/txsizes"
- "github.com/stretchr/testify/require"
-)
-
-var (
- testScriptP2WSH, _ = hex.DecodeString(
- "0020d554616badeb46ccd4ce4b115e1c8d098e942d1387212d0af9ff93a1" +
- "9c8f100e",
- )
- testScriptP2WKH, _ = hex.DecodeString(
- "0014e7a43aa41ef6d72dc6baeeaad8362cedf63b79a3",
- )
-)
-
-// TestFundPsbt tests that a given PSBT packet is funded correctly.
-func TestFundPsbt(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create a P2WKH address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- require.NoError(t, err)
- p2wkhAddr, err := txscript.PayToAddrScript(addr)
- require.NoError(t, err)
-
- // Also create a nested P2WKH address we can use to send some coins to.
- addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0049Plus)
- require.NoError(t, err)
- np2wkhAddr, err := txscript.PayToAddrScript(addr)
- require.NoError(t, err)
-
- // Register two big UTXO that will be used when funding the PSBT.
- const utxo1Amount = 1000000
- incomingTx1 := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{wire.NewTxOut(utxo1Amount, p2wkhAddr)},
- }
- addUtxo(t, w, incomingTx1)
- utxo1 := wire.OutPoint{
- Hash: incomingTx1.TxHash(),
- Index: 0,
- }
-
- const utxo2Amount = 900000
- incomingTx2 := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{wire.NewTxOut(utxo2Amount, np2wkhAddr)},
- }
- addUtxo(t, w, incomingTx2)
- utxo2 := wire.OutPoint{
- Hash: incomingTx2.TxHash(),
- Index: 0,
- }
-
- testCases := []struct {
- name string
- packet *psbt.Packet
- feeRateSatPerKB btcutil.Amount
- changeKeyScope *waddrmgr.KeyScope
- expectedErr string
- validatePackage bool
- expectedChangeBeforeFee int64
- expectedInputs []wire.OutPoint
- additionalChecks func(*testing.T, *psbt.Packet, int32)
- }{{
- name: "no outputs provided",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{},
- },
- feeRateSatPerKB: 0,
- expectedErr: "PSBT packet must contain at least one " +
- "input or output",
- }, {
- name: "single input, no outputs",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxIn: []*wire.TxIn{{
- PreviousOutPoint: utxo1,
- }},
- },
- Inputs: []psbt.PInput{{}},
- },
- feeRateSatPerKB: 20000,
- validatePackage: true,
- expectedInputs: []wire.OutPoint{utxo1},
- expectedChangeBeforeFee: utxo1Amount,
- }, {
- name: "no dust outputs",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxOut: []*wire.TxOut{{
- PkScript: []byte("foo"),
- Value: 100,
- }},
- },
- Outputs: []psbt.POutput{{}},
- },
- feeRateSatPerKB: 0,
- expectedErr: "transaction output is dust",
- }, {
- name: "two outputs, no inputs",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxOut: []*wire.TxOut{{
- PkScript: testScriptP2WSH,
- Value: 100000,
- }, {
- PkScript: testScriptP2WKH,
- Value: 50000,
- }},
- },
- Outputs: []psbt.POutput{{}, {}},
- },
- feeRateSatPerKB: 2000, // 2 sat/byte
- expectedErr: "",
- validatePackage: true,
- expectedChangeBeforeFee: utxo1Amount - 150000,
- expectedInputs: []wire.OutPoint{utxo1},
- }, {
- name: "large output, no inputs",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxOut: []*wire.TxOut{{
- PkScript: testScriptP2WSH,
- Value: 1500000,
- }},
- },
- Outputs: []psbt.POutput{{}},
- },
- feeRateSatPerKB: 4000, // 4 sat/byte
- expectedErr: "",
- validatePackage: true,
- expectedChangeBeforeFee: (utxo1Amount + utxo2Amount) - 1500000,
- expectedInputs: []wire.OutPoint{utxo1, utxo2},
- }, {
- name: "two outputs, two inputs",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxIn: []*wire.TxIn{{
- PreviousOutPoint: utxo1,
- }, {
- PreviousOutPoint: utxo2,
- }},
- TxOut: []*wire.TxOut{{
- PkScript: testScriptP2WSH,
- Value: 100000,
- }, {
- PkScript: testScriptP2WKH,
- Value: 50000,
- }},
- },
- Inputs: []psbt.PInput{{}, {}},
- Outputs: []psbt.POutput{{}, {}},
- },
- feeRateSatPerKB: 2000, // 2 sat/byte
- expectedErr: "",
- validatePackage: true,
- expectedChangeBeforeFee: (utxo1Amount + utxo2Amount) - 150000,
- expectedInputs: []wire.OutPoint{utxo1, utxo2},
- additionalChecks: func(t *testing.T, packet *psbt.Packet,
- changeIndex int32) {
-
- // Check outputs, find index for each of the 3 expected.
- txOuts := packet.UnsignedTx.TxOut
- require.Len(t, txOuts, 3, "tx outputs")
-
- p2wkhIndex := -1
- p2wshIndex := -1
- totalOut := int64(0)
- for idx, txOut := range txOuts {
- script := txOut.PkScript
- totalOut += txOut.Value
-
- switch {
- case bytes.Equal(script, testScriptP2WKH):
- p2wkhIndex = idx
-
- case bytes.Equal(script, testScriptP2WSH):
- p2wshIndex = idx
-
- }
- }
- totalIn := int64(0)
- for _, txIn := range packet.Inputs {
- totalIn += txIn.WitnessUtxo.Value
- }
-
- // All outputs must be found.
- require.Greater(t, p2wkhIndex, -1)
- require.Greater(t, p2wshIndex, -1)
- require.Greater(t, changeIndex, int32(-1))
-
- // After BIP 69 sorting, the P2WKH output should be
- // before the P2WSH output because the PK script is
- // lexicographically smaller.
- require.Less(
- t, p2wkhIndex, p2wshIndex,
- "index after sorting",
- )
- },
- }, {
- name: "one input and a custom change scope: BIP0084",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxIn: []*wire.TxIn{{
- PreviousOutPoint: utxo1,
- }},
- },
- Inputs: []psbt.PInput{{}},
- },
- feeRateSatPerKB: 20000,
- validatePackage: true,
- changeKeyScope: &waddrmgr.KeyScopeBIP0084,
- expectedInputs: []wire.OutPoint{utxo1},
- expectedChangeBeforeFee: utxo1Amount,
- }, {
- name: "no inputs and a custom change scope: BIP0084",
- packet: &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxOut: []*wire.TxOut{{
- PkScript: testScriptP2WSH,
- Value: 100000,
- }, {
- PkScript: testScriptP2WKH,
- Value: 50000,
- }},
- },
- Outputs: []psbt.POutput{{}, {}},
- },
- feeRateSatPerKB: 2000, // 2 sat/byte
- expectedErr: "",
- validatePackage: true,
- changeKeyScope: &waddrmgr.KeyScopeBIP0084,
- expectedChangeBeforeFee: utxo1Amount - 150000,
- expectedInputs: []wire.OutPoint{utxo1},
- }}
-
- calcFee := func(feeRateSatPerKB btcutil.Amount,
- packet *psbt.Packet) btcutil.Amount {
-
- var numP2WKHInputs, numNP2WKHInputs int
- for _, txin := range packet.UnsignedTx.TxIn {
- if txin.PreviousOutPoint == utxo1 {
- numP2WKHInputs++
- }
- if txin.PreviousOutPoint == utxo2 {
- numNP2WKHInputs++
- }
- }
- estimatedSize := txsizes.EstimateVirtualSize(
- 0, 0, numP2WKHInputs, numNP2WKHInputs,
- packet.UnsignedTx.TxOut, 0,
- )
- return txrules.FeeForSerializeSize(
- feeRateSatPerKB, estimatedSize,
- )
- }
-
- for _, tc := range testCases {
- tc := tc
- t.Run(tc.name, func(t *testing.T) {
- changeIndex, err := w.FundPsbt(
- tc.packet, nil, 1, 0,
- tc.feeRateSatPerKB, CoinSelectionLargest,
- WithCustomChangeScope(tc.changeKeyScope),
- )
-
- // In any case, unlock the UTXO before continuing, we
- // don't want to pollute other test iterations.
- for _, in := range tc.packet.UnsignedTx.TxIn {
- w.UnlockOutpoint(in.PreviousOutPoint)
- }
-
- // Make sure the error is what we expected.
- if tc.expectedErr != "" {
- require.ErrorContains(t, err, tc.expectedErr)
- return
- }
-
- require.NoError(t, err)
-
- if !tc.validatePackage {
- return
- }
-
- // Check wire inputs.
- packet := tc.packet
- assertTxInputs(t, packet, tc.expectedInputs)
-
- // Run any additional tests if available.
- if tc.additionalChecks != nil {
- tc.additionalChecks(t, packet, changeIndex)
- }
-
- // Finally, check the change output size and fee.
- txOuts := packet.UnsignedTx.TxOut
- totalOut := int64(0)
- for _, txOut := range txOuts {
- totalOut += txOut.Value
- }
- totalIn := int64(0)
- for _, txIn := range packet.Inputs {
- totalIn += txIn.WitnessUtxo.Value
- }
- fee := totalIn - totalOut
-
- expectedFee := calcFee(tc.feeRateSatPerKB, packet)
- require.EqualValues(t, expectedFee, fee, "fee")
- require.EqualValues(
- t, tc.expectedChangeBeforeFee,
- txOuts[changeIndex].Value+int64(expectedFee),
- )
-
- changeTxOut := txOuts[changeIndex]
- changeOutput := packet.Outputs[changeIndex]
-
- require.NotEmpty(t, changeOutput.Bip32Derivation)
- b32d := changeOutput.Bip32Derivation[0]
- require.Len(t, b32d.Bip32Path, 5, "derivation path len")
- require.Len(t, b32d.PubKey, 33, "pubkey len")
-
- // The third item should be the branch and should belong
- // to a change output.
- require.EqualValues(t, 1, b32d.Bip32Path[3])
-
- assertChangeOutputScope(
- t, changeTxOut.PkScript, tc.changeKeyScope,
- )
-
- if txscript.IsPayToTaproot(changeTxOut.PkScript) {
- require.NotEmpty(
- t, changeOutput.TaprootInternalKey,
- )
- require.Len(
- t, changeOutput.TaprootInternalKey, 32,
- "internal key len",
- )
- require.NotEmpty(
- t, changeOutput.TaprootBip32Derivation,
- )
-
- trb32d := changeOutput.TaprootBip32Derivation[0]
- require.Equal(
- t, b32d.Bip32Path, trb32d.Bip32Path,
- )
- require.Len(
- t, trb32d.XOnlyPubKey, 32,
- "schnorr pubkey len",
- )
- require.Equal(
- t, changeOutput.TaprootInternalKey,
- trb32d.XOnlyPubKey,
- )
- }
- })
- }
-}
-
-func assertTxInputs(t *testing.T, packet *psbt.Packet,
- expected []wire.OutPoint) {
-
- require.Len(t, packet.UnsignedTx.TxIn, len(expected))
-
- // The order of the UTXOs is random, we need to loop through each of
- // them to make sure they're found. We also check that no signature data
- // was added yet.
- for _, txIn := range packet.UnsignedTx.TxIn {
- if !containsUtxo(expected, txIn.PreviousOutPoint) {
- t.Fatalf("outpoint %v not found in list of expected "+
- "UTXOs", txIn.PreviousOutPoint)
- }
-
- require.Empty(t, txIn.SignatureScript)
- require.Empty(t, txIn.Witness)
- }
-}
-
-// assertChangeOutputScope checks if the pkScript has the right type.
-func assertChangeOutputScope(t *testing.T, pkScript []byte,
- changeScope *waddrmgr.KeyScope) {
-
- // By default (changeScope == nil), the script should
- // be a pay-to-taproot one.
- switch changeScope {
- case nil, &waddrmgr.KeyScopeBIP0086:
- require.True(t, txscript.IsPayToTaproot(pkScript))
-
- case &waddrmgr.KeyScopeBIP0049Plus, &waddrmgr.KeyScopeBIP0084:
- require.True(t, txscript.IsPayToWitnessPubKeyHash(pkScript))
-
- case &waddrmgr.KeyScopeBIP0044:
- require.True(t, txscript.IsPayToPubKeyHash(pkScript))
-
- default:
- require.Fail(t, "assertChangeOutputScope error",
- "change scope: %s", changeScope.String())
- }
-}
-
-func containsUtxo(list []wire.OutPoint, candidate wire.OutPoint) bool {
- for _, utxo := range list {
- if utxo == candidate {
- return true
- }
- }
-
- return false
-}
-
-// TestFinalizePsbt tests that a given PSBT packet can be finalized.
-func TestFinalizePsbt(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create a P2WKH address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- if err != nil {
- t.Fatalf("unable to get current address: %v", addr)
- }
- p2wkhAddr, err := txscript.PayToAddrScript(addr)
- if err != nil {
- t.Fatalf("unable to convert wallet address to p2wkh: %v", err)
- }
-
- // Also create a nested P2WKH address we can send coins to.
- addr, err = w.CurrentAddress(0, waddrmgr.KeyScopeBIP0049Plus)
- if err != nil {
- t.Fatalf("unable to get current address: %v", addr)
- }
- np2wkhAddr, err := txscript.PayToAddrScript(addr)
- if err != nil {
- t.Fatalf("unable to convert wallet address to np2wkh: %v", err)
- }
-
- // Register two big UTXO that will be used when funding the PSBT.
- utxOutP2WKH := wire.NewTxOut(1000000, p2wkhAddr)
- utxOutNP2WKH := wire.NewTxOut(1000000, np2wkhAddr)
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{utxOutP2WKH, utxOutNP2WKH},
- }
- addUtxo(t, w, incomingTx)
-
- // Create the packet that we want to sign.
- packet := &psbt.Packet{
- UnsignedTx: &wire.MsgTx{
- TxIn: []*wire.TxIn{{
- PreviousOutPoint: wire.OutPoint{
- Hash: incomingTx.TxHash(),
- Index: 0,
- },
- }, {
- PreviousOutPoint: wire.OutPoint{
- Hash: incomingTx.TxHash(),
- Index: 1,
- },
- }},
- TxOut: []*wire.TxOut{{
- PkScript: testScriptP2WKH,
- Value: 50000,
- }, {
- PkScript: testScriptP2WSH,
- Value: 100000,
- }, {
- PkScript: testScriptP2WKH,
- Value: 849632,
- }},
- },
- Inputs: []psbt.PInput{{
- WitnessUtxo: utxOutP2WKH,
- SighashType: txscript.SigHashAll,
- }, {
- NonWitnessUtxo: incomingTx,
- SighashType: txscript.SigHashAll,
- }},
- Outputs: []psbt.POutput{{}, {}, {}},
- }
-
- // Finalize it to add all witness data then extract the final TX.
- err = w.FinalizePsbt(nil, 0, packet)
- if err != nil {
- t.Fatalf("error finalizing PSBT packet: %v", err)
- }
- finalTx, err := psbt.Extract(packet)
- if err != nil {
- t.Fatalf("error extracting final TX from PSBT: %v", err)
- }
-
- // Finally verify that the created witness is valid.
- err = validateMsgTx(
- finalTx, [][]byte{utxOutP2WKH.PkScript, utxOutNP2WKH.PkScript},
- []btcutil.Amount{1000000, 1000000},
- )
- if err != nil {
- t.Fatalf("error validating tx: %v", err)
- }
-}
diff --git a/wallet/recovery.go b/wallet/recovery.go
index 11c2a7890d..b19a492eea 100644
--- a/wallet/recovery.go
+++ b/wallet/recovery.go
@@ -1,185 +1,20 @@
package wallet
import (
+ "errors"
+ "fmt"
"time"
"github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
"github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
"github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/txscript/v2"
"github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/wtxmgr"
)
-// RecoveryManager maintains the state required to recover previously used
-// addresses, and coordinates batched processing of the blocks to search.
-type RecoveryManager struct {
- // recoveryWindow defines the key-derivation lookahead used when
- // attempting to recover the set of used addresses.
- recoveryWindow uint32
-
- // started is true after the first block has been added to the batch.
- started bool
-
- // blockBatch contains a list of blocks that have not yet been searched
- // for recovered addresses.
- blockBatch []wtxmgr.BlockMeta
-
- // state encapsulates and allocates the necessary recovery state for all
- // key scopes and subsidiary derivation paths.
- state *RecoveryState
-
- // chainParams are the parameters that describe the chain we're trying
- // to recover funds on.
- chainParams *chaincfg.Params
-}
-
-// NewRecoveryManager initializes a new RecoveryManager with a derivation
-// look-ahead of `recoveryWindow` child indexes, and pre-allocates a backing
-// array for `batchSize` blocks to scan at once.
-func NewRecoveryManager(recoveryWindow, batchSize uint32,
- chainParams *chaincfg.Params) *RecoveryManager {
-
- return &RecoveryManager{
- recoveryWindow: recoveryWindow,
- blockBatch: make([]wtxmgr.BlockMeta, 0, batchSize),
- chainParams: chainParams,
- state: NewRecoveryState(recoveryWindow),
- }
-}
-
-// Resurrect restores all known addresses for the provided scopes that can be
-// found in the walletdb namespace, in addition to restoring all outpoints that
-// have been previously found. This method ensures that the recovery state's
-// horizons properly start from the last found address of a prior recovery
-// attempt.
-func (rm *RecoveryManager) Resurrect(ns walletdb.ReadBucket,
- scopedMgrs map[waddrmgr.KeyScope]*waddrmgr.ScopedKeyManager,
- credits []wtxmgr.Credit) error {
-
- // First, for each scope that we are recovering, rederive all of the
- // addresses up to the last found address known to each branch.
- for keyScope, scopedMgr := range scopedMgrs {
- // Load the current account properties for this scope, using the
- // the default account number.
- // TODO(conner): rescan for all created accounts if we allow
- // users to use non-default address
- scopeState := rm.state.StateForScope(keyScope)
- acctProperties, err := scopedMgr.AccountProperties(
- ns, waddrmgr.DefaultAccountNum,
- )
- if err != nil {
- return err
- }
-
- // Fetch the external key count, which bounds the indexes we
- // will need to rederive.
- externalCount := acctProperties.ExternalKeyCount
-
- // Walk through all indexes through the last external key,
- // deriving each address and adding it to the external branch
- // recovery state's set of addresses to look for.
- for i := uint32(0); i < externalCount; i++ {
- keyPath := externalKeyPath(i)
- addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
- if err != nil && err != hdkeychain.ErrInvalidChild {
- return err
- } else if err == hdkeychain.ErrInvalidChild {
- scopeState.ExternalBranch.MarkInvalidChild(i)
- continue
- }
-
- scopeState.ExternalBranch.AddAddr(i, addr.Address())
- }
-
- // Fetch the internal key count, which bounds the indexes we
- // will need to rederive.
- internalCount := acctProperties.InternalKeyCount
-
- // Walk through all indexes through the last internal key,
- // deriving each address and adding it to the internal branch
- // recovery state's set of addresses to look for.
- for i := uint32(0); i < internalCount; i++ {
- keyPath := internalKeyPath(i)
- addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
- if err != nil && err != hdkeychain.ErrInvalidChild {
- return err
- } else if err == hdkeychain.ErrInvalidChild {
- scopeState.InternalBranch.MarkInvalidChild(i)
- continue
- }
-
- scopeState.InternalBranch.AddAddr(i, addr.Address())
- }
-
- // The key counts will point to the next key that can be
- // derived, so we subtract one to point to last known key. If
- // the key count is zero, then no addresses have been found.
- if externalCount > 0 {
- scopeState.ExternalBranch.ReportFound(externalCount - 1)
- }
- if internalCount > 0 {
- scopeState.InternalBranch.ReportFound(internalCount - 1)
- }
- }
-
- // In addition, we will re-add any outpoints that are known the wallet
- // to our global set of watched outpoints, so that we can watch them for
- // spends.
- for _, credit := range credits {
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- credit.PkScript, rm.chainParams,
- )
- if err != nil {
- return err
- }
-
- rm.state.AddWatchedOutPoint(&credit.OutPoint, addrs[0])
- }
-
- return nil
-}
-
-// AddToBlockBatch appends the block information, consisting of hash and height,
-// to the batch of blocks to be searched.
-func (rm *RecoveryManager) AddToBlockBatch(hash *chainhash.Hash, height int32,
- timestamp time.Time) {
-
- if !rm.started {
- log.Infof("Seed birthday surpassed, starting recovery "+
- "of wallet from height=%d hash=%v with "+
- "recovery-window=%d", height, *hash, rm.recoveryWindow)
- rm.started = true
- }
-
- block := wtxmgr.BlockMeta{
- Block: wtxmgr.Block{
- Hash: *hash,
- Height: height,
- },
- Time: timestamp,
- }
- rm.blockBatch = append(rm.blockBatch, block)
-}
-
-// BlockBatch returns a buffer of blocks that have not yet been searched.
-func (rm *RecoveryManager) BlockBatch() []wtxmgr.BlockMeta {
- return rm.blockBatch
-}
-
-// ResetBlockBatch resets the internal block buffer to conserve memory.
-func (rm *RecoveryManager) ResetBlockBatch() {
- rm.blockBatch = rm.blockBatch[:0]
-}
-
-// State returns the current RecoveryState.
-func (rm *RecoveryManager) State() *RecoveryState {
- return rm.state
-}
-
// RecoveryState manages the initialization and lookup of ScopeRecoveryStates
// for any actively used key scopes.
//
@@ -201,31 +36,70 @@ type RecoveryState struct {
recoveryWindow uint32
// scopes maintains a map of each requested key scope to its active
- // RecoveryState.
+ // RecoveryState. Used for legacy compatibility.
+ //
+ // TODO(yy): Deprecated, remove.
scopes map[waddrmgr.KeyScope]*ScopeRecoveryState
+ // branchStates maintains the recovery state for every branch (scope +
+ // account + branch). This is the source of truth.
+ branchStates map[waddrmgr.BranchScope]*BranchRecoveryState
+
// watchedOutPoints contains the set of all outpoints known to the
// wallet. This is updated iteratively as new outpoints are found during
// a rescan.
+ //
+ // TODO(yy): Deprecated, remove.
watchedOutPoints map[wire.OutPoint]address.Address
+
+ // chainParams are the parameters that describe the chain we're trying
+ // to recover funds on. These are set at initialization and remain
+ // constant.
+ chainParams *chaincfg.Params
+
+ // addrMgr is the address manager used to derive new keys and manage
+ // account state.
+ addrMgr waddrmgr.AddrStore
+
+ // outpoints tracks unspent outpoints to detect spends. The value is
+ // the PkScript of the outpoint. This map is transient, initialized by
+ // InitScanState at the beginning of a batch scan and pruned by Prune()
+ // at the end to manage memory.
+ outpoints map[wire.OutPoint][]byte
+
+ // addrFilters maps encoded addresses to their derivation info for
+ // identifying incoming payments. This map is transient, initialized by
+ // InitScanState at the beginning of a batch scan and pruned by Prune()
+ // at the end to manage memory.
+ addrFilters map[string]AddrEntry
}
// NewRecoveryState creates a new RecoveryState using the provided
// recoveryWindow. Each RecoveryState that is subsequently initialized for a
// particular key scope will receive the same recoveryWindow.
-func NewRecoveryState(recoveryWindow uint32) *RecoveryState {
- scopes := make(map[waddrmgr.KeyScope]*ScopeRecoveryState)
+func NewRecoveryState(recoveryWindow uint32,
+ chainParams *chaincfg.Params,
+ addrMgr waddrmgr.AddrStore) *RecoveryState {
return &RecoveryState{
- recoveryWindow: recoveryWindow,
- scopes: scopes,
+ recoveryWindow: recoveryWindow,
+ scopes: make(
+ map[waddrmgr.KeyScope]*ScopeRecoveryState,
+ ),
+ branchStates: make(
+ map[waddrmgr.BranchScope]*BranchRecoveryState,
+ ),
watchedOutPoints: make(map[wire.OutPoint]address.Address),
+ chainParams: chainParams,
+ addrMgr: addrMgr,
}
}
-// StateForScope returns a ScopeRecoveryState for the provided key scope. If one
-// does not already exist, a new one will be generated with the RecoveryState's
-// recoveryWindow.
+// StateForScope returns the recovery state for the default account of the
+// provided key scope. This exists for backward compatibility with legacy
+// recovery logic which only supports the default account.
+//
+// TODO(yy): Deprecated, remove.
func (rs *RecoveryState) StateForScope(
keyScope waddrmgr.KeyScope) *ScopeRecoveryState {
@@ -243,38 +117,464 @@ func (rs *RecoveryState) StateForScope(
// WatchedOutPoints returns the global set of outpoints that are known to belong
// to the wallet during recovery.
+//
+// TODO(yy): Deprecated, remove.
func (rs *RecoveryState) WatchedOutPoints() map[wire.OutPoint]address.Address {
return rs.watchedOutPoints
}
// AddWatchedOutPoint updates the recovery state's set of known outpoints that
// we will monitor for spends during recovery.
+//
+// TODO(yy): Deprecated, remove.
func (rs *RecoveryState) AddWatchedOutPoint(outPoint *wire.OutPoint,
addr address.Address) {
rs.watchedOutPoints[*outPoint] = addr
}
-// ScopeRecoveryState is used to manage the recovery of addresses generated
-// under a particular BIP32 account. Each account tracks both an external and
-// internal branch recovery state, both of which use the same recovery window.
-type ScopeRecoveryState struct {
- // ExternalBranch is the recovery state of addresses generated for
- // external use, i.e. receiving addresses.
- ExternalBranch *BranchRecoveryState
-
- // InternalBranch is the recovery state of addresses generated for
- // internal use, i.e. change addresses.
- InternalBranch *BranchRecoveryState
+// String returns a summary of the recovery state.
+func (rs *RecoveryState) String() string {
+ return fmt.Sprintf("RecoveryState(addrs=%d, outpoints=%d)",
+ len(rs.addrFilters), len(rs.outpoints))
+}
+
+// Empty returns true if there are no addresses or outpoints to watch.
+func (rs *RecoveryState) Empty() bool {
+ return len(rs.addrFilters) == 0 && len(rs.outpoints) == 0
+}
+
+// WatchListSize returns the total number of items (addresses + outpoints)
+// in the current watchlist.
+func (rs *RecoveryState) WatchListSize() int {
+ return len(rs.addrFilters) + len(rs.outpoints)
+}
+
+// GetBranchState returns the recovery state for the provided branch scope.
+// It acts as the source of truth for branch states by either retrieving an
+// existing in-memory BranchRecoveryState for the given `bs` (branch scope)
+// or creating a new one if it doesn't already exist.
+//
+// When a new state is created, it fetches the appropriate AccountStore (key
+// manager) from the Address Manager. This ensures that the BranchRecoveryState
+// is correctly linked to its derivation logic and maintains a consistent,
+// up-to-date view of the branch's lookahead horizon and derived addresses
+// throughout the recovery process. This centralization prevents redundant
+// state creation and ensures all recovery operations for a specific branch
+// operate on the same instance.
+func (rs *RecoveryState) GetBranchState(bs waddrmgr.BranchScope) (
+ *BranchRecoveryState, error) {
+
+ if s, ok := rs.branchStates[bs]; ok {
+ return s, nil
+ }
+
+ // We assume the scope is valid and active if we are requesting state
+ // for it.
+ var mgr waddrmgr.AccountStore
+ if rs.addrMgr != nil {
+ var err error
+
+ mgr, err = rs.addrMgr.FetchScopedKeyManager(bs.Scope)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch manager for "+
+ "scope %v: %w", bs.Scope, err)
+ }
+ }
+
+ s := NewBranchRecoveryState(rs.recoveryWindow, mgr)
+ rs.branchStates[bs] = s
+
+ return s, nil
+}
+
+// AddrEntry holds the derivation info for an address to support
+// reverse lookups during filtering.
+type AddrEntry struct {
+ // Address is the cached address for script generation.
+ Address address.Address
+
+ // Credit records the transaction credit metadata (index, change)
+ // when this address matches a transaction output.
+ Credit wtxmgr.CreditEntry
+
+ // IsLookahead indicates whether this address is part of the current
+ // lookahead window. If true, finding this address *in the block*
+ // triggers horizon expansion.
+ IsLookahead bool
+
+ // addrScope identifies the specific address derivation path.
+ addrScope waddrmgr.AddrScope
+}
+
+// Initialize prepares the recovery state for a new batch scan by syncing
+// horizons, populating history/UTXOs, and expanding the lookahead window.
+//
+// TODO(yy): Once RecoveryManager is removed, privatize this method and call
+// it directly from NewRecoveryState to simplify the initialization flow.
+func (rs *RecoveryState) Initialize(accounts []*waddrmgr.AccountProperties,
+ initialAddrs []address.Address, initialUnspent []wtxmgr.Credit) error {
+
+ rs.outpoints = make(map[wire.OutPoint][]byte)
+ rs.addrFilters = make(map[string]AddrEntry)
+
+ // 1. Sync Horizons & Derive Lookahead.
+ //
+ // We iterate over all accounts loaded from the database (horizonData)
+ // to sync the recovery horizons. This loop will also populate the
+ // rs.branchStates map with all active branches. For each branch, it
+ // will derive addresses up to the recovery window size and add them to
+ // rs.addrFilters.
+ for _, props := range accounts {
+ err := rs.initAccountState(props)
+ if err != nil {
+ return err
+ }
+ }
+
+ // 2. Populate the filter with "History" - addresses that are already
+ // active/used in the wallet database. We monitor these to detect any
+ // new payments to existing keys.
+ for _, addr := range initialAddrs {
+ addrStr := addr.EncodeAddress()
+
+ entry := AddrEntry{
+ Address: addr,
+ IsLookahead: false,
+ }
+ rs.addrFilters[addrStr] = entry
+ }
+
+ // 3. Populate the set of unspent outputs (UTXOs) to watch. We monitor
+ // these outpoints to detect when they are spent by a transaction in a
+ // block.
+ for _, u := range initialUnspent {
+ rs.outpoints[u.OutPoint] = u.PkScript
+ }
+
+ return nil
}
-// NewScopeRecoveryState initializes an ScopeRecoveryState with the chosen
-// recovery window.
-func NewScopeRecoveryState(recoveryWindow uint32) *ScopeRecoveryState {
- return &ScopeRecoveryState{
- ExternalBranch: NewBranchRecoveryState(recoveryWindow),
- InternalBranch: NewBranchRecoveryState(recoveryWindow),
+// BuildCFilterData constructs the list of scripts (addresses + outpoints) used
+// for CFilter matching. This is an expensive operation (script derivation) and
+// should only be called when filters are actually used.
+func (rs *RecoveryState) BuildCFilterData() ([][]byte, error) {
+ // Calculate size: addrFilters (Addrs) + outpoints (UTXOs).
+ size := len(rs.addrFilters) + len(rs.outpoints)
+ watchList := make([][]byte, 0, size)
+
+ for _, entry := range rs.addrFilters {
+ script, err := txscript.PayToAddrScript(entry.Address)
+ if err != nil {
+ return nil, fmt.Errorf("failed to gen script for %s: "+
+ "%w", entry.Address, err)
+ }
+
+ watchList = append(watchList, script)
}
+
+ for _, script := range rs.outpoints {
+ watchList = append(watchList, script)
+ }
+
+ return watchList, nil
+}
+
+// TxEntry pairs a transaction record with its extracted address entries.
+type TxEntry struct {
+ Rec *wtxmgr.TxRecord
+ Entries []AddrEntry
+}
+
+// TxEntries is a list of matched transaction entries, preserving the order of
+// transactions.
+type TxEntries []TxEntry
+
+// BlockProcessResult contains the results of processing a block for recovery.
+type BlockProcessResult struct {
+ // RelevantTxs is a slice of transactions within the block that are
+ // relevant to the wallet (i.e., they spend one of our watched
+ // outpoints or send funds to one of our addresses).
+ RelevantTxs []*btcutil.Tx
+
+ // FoundHorizons maps the BranchScope to the highest child index found
+ // in this block. This is used for persistent horizon expansion.
+ FoundHorizons map[waddrmgr.BranchScope]uint32
+
+ // RelevantOutputs holds the details of transaction outputs that
+ // matched the wallet's filters. This allows efficient access to
+ // derivation information without re-parsing scripts or re-fetching
+ // addresses.
+ RelevantOutputs TxEntries
+
+ // Expanded indicates whether any new addresses were derived and added
+ // to the address filters as a result of processing this block (i.e., a
+ // lookahead horizon expansion was triggered).
+ Expanded bool
+}
+
+// ProcessBlock filters a block for relevant transactions and expands the
+// recovery horizons if new addresses are found. It handles the "Filter ->
+// Expand -> Retry" loop internally and returns the relevant transactions,
+// found horizons (for state update), relevant matches (for efficient
+// ingestion), and a boolean indicating if any expansion occurred.
+func (rs *RecoveryState) ProcessBlock(block *wire.MsgBlock) (
+ *BlockProcessResult, error) {
+
+ var (
+ expanded bool
+ relevantTxs []*btcutil.Tx
+ foundScopes map[waddrmgr.AddrScope]struct{}
+ relevantOutputs TxEntries
+ foundHorizons map[waddrmgr.BranchScope]uint32
+ )
+
+ for {
+ relevantTxs, foundScopes, relevantOutputs = rs.filterBlock(
+ block,
+ )
+
+ foundHorizons = rs.reportFound(foundScopes)
+ if len(foundHorizons) == 0 {
+ break
+ }
+
+ expandedNow, err := rs.expandHorizons()
+ if err != nil {
+ return nil, fmt.Errorf("expand horizons: %w", err)
+ }
+
+ if !expandedNow {
+ break
+ }
+
+ expanded = true
+ }
+
+ return &BlockProcessResult{
+ RelevantTxs: relevantTxs,
+ FoundHorizons: foundHorizons,
+ RelevantOutputs: relevantOutputs,
+ Expanded: expanded,
+ }, nil
+}
+
+// initAccountState initializes the recovery state for a specific account by
+// setting up branch recovery states for both external and internal branches.
+// It iterates through the known address counts (from the provided account
+// properties) to sync the horizons and populate the address filters with
+// derived addresses up to the recovery window.
+func (rs *RecoveryState) initAccountState(
+ props *waddrmgr.AccountProperties) error {
+
+ initBranch := func(branch uint32, lastKnownIndex uint32) error {
+ bs := waddrmgr.BranchScope{
+ Scope: props.KeyScope,
+ Account: props.AccountNumber,
+ Branch: branch,
+ }
+
+ branchState, err := rs.GetBranchState(bs)
+ if err != nil {
+ return err
+ }
+
+ entries, err := branchState.buildAddrFilters(bs, lastKnownIndex)
+ if err != nil {
+ return err
+ }
+
+ for _, entry := range entries {
+ rs.addrFilters[entry.Address.EncodeAddress()] = entry
+ }
+
+ return nil
+ }
+
+ err := initBranch(waddrmgr.ExternalBranch, props.ExternalKeyCount)
+ if err != nil {
+ return fmt.Errorf("derive external addrs for %s/%d': %w",
+ props.KeyScope, props.AccountNumber, err)
+ }
+
+ err = initBranch(waddrmgr.InternalBranch, props.InternalKeyCount)
+ if err != nil {
+ return fmt.Errorf("derive internal addrs for %s/%d': %w",
+ props.KeyScope, props.AccountNumber, err)
+ }
+
+ return nil
+}
+
+// reportFound updates the recovery state with any addresses found in the
+// current block. It returns the set of found horizons (max index per branch).
+func (rs *RecoveryState) reportFound(
+ found map[waddrmgr.AddrScope]struct{}) map[waddrmgr.BranchScope]uint32 {
+
+ foundHorizons := make(map[waddrmgr.BranchScope]uint32)
+
+ // Group by branch and find max index.
+ for addrScope := range found {
+ bs := addrScope.BranchScope
+
+ idx := addrScope.Index
+ if currentMax, ok := foundHorizons[bs]; !ok ||
+ idx > currentMax {
+
+ foundHorizons[bs] = idx
+ }
+ }
+
+ // Update memory state.
+ for bs, maxIdx := range foundHorizons {
+ state, err := rs.GetBranchState(bs)
+ if err != nil {
+ // This should theoretically not happen if the found
+ // map was populated correctly from filters that
+ // correspond to valid branch states. Log this as an
+ // error for debugging.
+ log.Errorf("Failed to get branch state for %v: %v", bs,
+ err)
+
+ continue
+ }
+
+ state.ReportFound(maxIdx)
+ }
+
+ return foundHorizons
+}
+
+// filterBlock checks a block for any transactions relevant to the wallet.
+// It returns the relevant transactions and the set of found addresses (by
+// branch scope and index).
+//
+// NOTE: This method mutates the recovery state's outpoints in-place by
+// removing spent inputs and adding new relevant outputs. This handles
+// intra-block chains correctly.
+func (rs *RecoveryState) filterBlock(block *wire.MsgBlock) ([]*btcutil.Tx,
+ map[waddrmgr.AddrScope]struct{}, TxEntries) {
+
+ var relevant []*btcutil.Tx
+
+ foundScopes := make(map[waddrmgr.AddrScope]struct{})
+
+ var relevantOutputs TxEntries
+ for _, tx := range block.Transactions {
+ isRelevant, entries := rs.filterTx(tx, foundScopes)
+ if isRelevant {
+ relevant = append(relevant, btcutil.NewTx(tx))
+
+ // We create a temporary record here. The timestamp
+ // will be updated during commitment.
+ rec, _ := wtxmgr.NewTxRecordFromMsgTx(
+ tx, time.Time{},
+ )
+
+ relevantOutputs = append(relevantOutputs, TxEntry{
+ Rec: rec,
+ Entries: entries,
+ })
+ }
+ }
+
+ return relevant, foundScopes, relevantOutputs
+}
+
+// filterTx checks a single transaction for relevance and returns any matching
+// address entries.
+func (rs *RecoveryState) filterTx(tx *wire.MsgTx,
+ foundScopes map[waddrmgr.AddrScope]struct{}) (bool, []AddrEntry) {
+
+ var (
+ isRelevant bool
+ entries []AddrEntry
+ )
+
+ // Check if the transaction spends any of our watched outpoints. If so,
+ // it's relevant (a debit).
+ for _, txIn := range tx.TxIn {
+ if _, ok := rs.outpoints[txIn.PreviousOutPoint]; ok {
+ isRelevant = true
+
+ delete(rs.outpoints, txIn.PreviousOutPoint)
+ }
+ }
+
+ // Check if the transaction pays to any of our watched addresses. If
+ // so, it's relevant (a credit).
+ for i, txOut := range tx.TxOut {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ txOut.PkScript, rs.chainParams,
+ )
+ if err != nil {
+ log.Debugf("Could not extract addresses from script "+
+ "%x: %v", txOut.PkScript, err)
+
+ continue
+ }
+
+ for _, a := range addrs {
+ entry, ok := rs.addrFilters[a.EncodeAddress()]
+ if !ok {
+ continue
+ }
+
+ isRelevant = true
+
+ if entry.IsLookahead {
+ foundScopes[entry.addrScope] = struct{}{}
+ }
+
+ //nolint:gosec // Output index fits in uint32.
+ idx := uint32(i)
+
+ // Create result entry with Credit populated.
+ entry.Credit.Index = idx
+ entries = append(entries, entry)
+
+ // Add new output to map immediately to catch
+ // intra-block spends.
+ op := wire.OutPoint{Hash: tx.TxHash(), Index: idx}
+ rs.outpoints[op] = txOut.PkScript
+ }
+ }
+
+ return isRelevant, entries
+}
+
+// expandHorizons ensures that the recovery state's lookahead horizon is
+// sufficient by deriving new addresses if needed, and then updates the
+// internal batch artifacts (addrFilters) with the lookahead addresses.
+func (rs *RecoveryState) expandHorizons() (bool, error) {
+ // We iterate over all active branch states and ensure their lookahead
+ // windows are sufficiently expanded.
+ //
+ // NOTE: rs.branchStates contains the set of all active branches
+ // determined at initialization. This set remains static for the
+ // duration of the batch scan, even as the internal state of each
+ // branch (horizon) evolves.
+ var expanded bool
+ for bs, branchState := range rs.branchStates {
+ // Passing 0 for lastKnownIndex means we don't want to update
+ // the found status based on historical data, just ensure the
+ // lookahead is sufficient based on the current state.
+ newEntries, err := branchState.buildAddrFilters(bs, 0)
+ if err != nil {
+ return false, err
+ }
+
+ if len(newEntries) > 0 {
+ expanded = true
+
+ for _, entry := range newEntries {
+ rs.addrFilters[entry.Address.EncodeAddress()] =
+ entry
+ }
+ }
+ }
+
+ return expanded, nil
}
// BranchRecoveryState maintains the required state in-order to properly
@@ -288,6 +588,8 @@ func NewScopeRecoveryState(recoveryWindow uint32) *ScopeRecoveryState {
// - Reporting that an address has been found.
// - Retrieving all currently derived addresses for the branch.
// - Looking up a particular address by its child index.
+//
+// TODO(yy): Privatize this struct and all its methods.
type BranchRecoveryState struct {
// recoveryWindow defines the key-derivation lookahead used when
// attempting to recover the set of addresses on this branch.
@@ -307,15 +609,22 @@ type BranchRecoveryState struct {
// invalidChildren records the set of child indexes that derive to
// invalid keys.
invalidChildren map[uint32]struct{}
+
+ // manager is the scoped key manager used to derive addresses for this
+ // branch.
+ manager waddrmgr.AccountStore
}
// NewBranchRecoveryState creates a new BranchRecoveryState that can be used to
// track either the external or internal branch of an account's derivation path.
-func NewBranchRecoveryState(recoveryWindow uint32) *BranchRecoveryState {
+func NewBranchRecoveryState(recoveryWindow uint32,
+ manager waddrmgr.AccountStore) *BranchRecoveryState {
+
return &BranchRecoveryState{
recoveryWindow: recoveryWindow,
addresses: make(map[uint32]address.Address),
invalidChildren: make(map[uint32]struct{}),
+ manager: manager,
}
}
@@ -410,3 +719,74 @@ func (brs *BranchRecoveryState) NumInvalidInHorizon() uint32 {
return nInvalid
}
+
+// buildAddrFilters is a helper method that maintains the address lookahead
+// window for this branch. It performs two main tasks:
+// 1. Syncs the branch state to the provided `lastKnownIndex` (if non-zero),
+// ensuring the state reflects what is known from disk or previous scans.
+// 2. Extends the lookahead window if necessary, deriving new addresses and
+// creating filter entries for them.
+//
+// The returned entries are used to populate the batch-wide address filter.
+func (brs *BranchRecoveryState) buildAddrFilters(bs waddrmgr.BranchScope,
+ lastKnownIndex uint32) ([]AddrEntry, error) {
+
+ // 1. Sync State.
+ // If a last known index is provided (e.g., from DB during
+ // initialization), we update our state to reflect that we've found
+ // addresses up to this point.
+ if lastKnownIndex > 0 {
+ brs.ReportFound(lastKnownIndex - 1)
+ }
+
+ // 2. Compute Extension.
+ // Determine the current horizon and how many new addresses are needed
+ // to maintain the required lookahead window (recoveryWindow) beyond
+ // the last found address.
+ curHorizon, windowToDerive := brs.ExtendHorizon()
+ count, childIndex := uint32(0), curHorizon
+
+ var newEntries []AddrEntry
+
+ // 3. Derive & Cache.
+ // Iterate to derive the required number of new addresses.
+ for count < windowToDerive {
+ addr, _, err := brs.manager.DeriveAddr(
+ bs.Account, bs.Branch, childIndex,
+ )
+ if err != nil {
+ // Handle invalid children (rare in HD, but possible).
+ // We skip the invalid index, mark it, and continue to
+ // ensure we still generate the full window of *valid*
+ // addresses.
+ if errors.Is(err, hdkeychain.ErrInvalidChild) {
+ brs.MarkInvalidChild(childIndex)
+ childIndex++
+
+ continue
+ }
+
+ return nil, fmt.Errorf("derive addr: %w", err)
+ }
+
+ // Cache the valid address in the branch state for future
+ // lookups.
+ brs.AddAddr(childIndex, addr)
+
+ // Create a filter entry for the new address. This entry
+ // contains the metadata (Scope, Account, Branch, Index) needed
+ // to map a future hit back to this specific derivation path.
+ as := waddrmgr.AddrScope{BranchScope: bs, Index: childIndex}
+ entry := AddrEntry{
+ Address: addr,
+ addrScope: as,
+ IsLookahead: true,
+ }
+ newEntries = append(newEntries, entry)
+
+ childIndex++
+ count++
+ }
+
+ return newEntries, nil
+}
diff --git a/wallet/recovery_test.go b/wallet/recovery_test.go
index cc65d74f76..d26affa2da 100644
--- a/wallet/recovery_test.go
+++ b/wallet/recovery_test.go
@@ -1,10 +1,19 @@
-package wallet_test
+package wallet
import (
- "runtime"
+ "errors"
+ "fmt"
"testing"
+ "time"
- "github.com/btcsuite/btcwallet/wallet"
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/require"
)
// Harness holds the BranchRecoveryState being tested, the recovery window being
@@ -12,7 +21,7 @@ import (
// and next unfound values.
type Harness struct {
t *testing.T
- brs *wallet.BranchRecoveryState
+ brs *BranchRecoveryState
recoveryWindow uint32
expHorizon uint32
expNextUnfound uint32
@@ -209,7 +218,7 @@ func TestBranchRecoveryState(t *testing.T) {
// Expected horizon: 30.
}
- brs := wallet.NewBranchRecoveryState(recoveryWindow)
+ brs := NewBranchRecoveryState(recoveryWindow, nil)
harness := &Harness{
t: t,
brs: brs,
@@ -238,9 +247,835 @@ func assertNumInvalid(t *testing.T, i int, have, want uint32) {
}
func assertHaveWant(t *testing.T, i int, msg string, have, want uint32) {
- _, _, line, _ := runtime.Caller(2)
- if want != have {
- t.Fatalf("[line: %d, step: %d] %s: got %d, want %d",
- line, i, msg, have, want)
+ t.Helper()
+ require.Equal(t, want, have, "[step: %d] %s", i, msg)
+}
+
+// TestRecoveryManagerBatch verifies that the RecoveryManager correctly tracks
+// and resets its internal batch of processed blocks.
+func TestRecoveryManagerBatch(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a new recovery manager with a recovery window of 10
+ // and a lookahead distance of 5.
+ rm := NewRecoveryManager(10, 5, &chainParams)
+
+ // Act: Add a block to the current batch.
+ hash := chainhash.Hash{0x01}
+ rm.AddToBlockBatch(&hash, 100, time.Now())
+
+ // Assert: Verify that the block was correctly added to the batch.
+ batch := rm.BlockBatch()
+ require.Len(t, batch, 1)
+ require.Equal(t, int32(100), batch[0].Height)
+
+ // Act: Clear the current batch.
+ rm.ResetBlockBatch()
+
+ // Assert: Verify that the batch is now empty.
+ require.Empty(t, rm.BlockBatch())
+}
+
+// TestBranchRecoveryStateHorizon verifies horizon expansion logic.
+func TestBranchRecoveryStateHorizon(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Window 10.
+ brs := NewBranchRecoveryState(10, nil)
+
+ // Act: Initial horizon extend.
+ // Horizon is 0. NextUnfound is 0. MinValid = 0 + 10 = 10.
+ // Delta = 10 - 0 = 10.
+ // Returns current horizon (start index) and delta.
+ horizon, delta := brs.ExtendHorizon()
+ require.Equal(t, uint32(0), horizon)
+ require.Equal(t, uint32(10), delta)
+
+ // Act: Report found at 5.
+ brs.ReportFound(5)
+
+ // NextUnfound becomes 6.
+ require.Equal(t, uint32(6), brs.NextUnfound())
+
+ // Act: Extend again.
+ // MinValid = 6 + 10 = 16.
+ // Current Horizon = 10.
+ // Delta = 16 - 10 = 6.
+ horizon, delta = brs.ExtendHorizon()
+ require.Equal(t, uint32(10), horizon)
+ require.Equal(t, uint32(6), delta)
+}
+
+// TestBranchRecoveryStateInvalidChild verifies handling of invalid keys.
+func TestBranchRecoveryStateInvalidChild(t *testing.T) {
+ t.Parallel()
+
+ brs := NewBranchRecoveryState(5, nil)
+ // Initial: Horizon 5.
+ brs.ExtendHorizon()
+
+ // Act: Mark index 2 as invalid.
+ brs.MarkInvalidChild(2)
+
+ // Assert: Horizon incremented to 6.
+ require.Equal(t, uint32(1), brs.NumInvalidInHorizon())
+
+ // Act: Extend.
+ // NextUnfound = 0. Window = 5. Invalid = 1.
+ // MinValid = 0 + 5 + 1 = 6.
+ // Current Horizon = 6.
+ // Delta = 0.
+ horizon, delta := brs.ExtendHorizon()
+ require.Equal(t, uint32(6), horizon)
+ require.Equal(t, uint32(0), delta)
+
+ // Act: Found 3.
+ brs.ReportFound(3)
+
+ // Invalid child 2 is < 3, so it should be pruned.
+ require.Equal(t, uint32(0), brs.NumInvalidInHorizon())
+}
+
+// TestGetBranchState verifies the GetBranchState method of RecoveryState.
+// It ensures that the method correctly fetches and caches BranchRecoveryState
+// instances based on the provided BranchScope, optimizing for subsequent
+// lookups by returning the cached state instead of re-creating it.
+func TestGetBranchState(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ scope := waddrmgr.KeyScope{
+ Purpose: waddrmgr.KeyScopeBIP0084.Purpose,
+ Coin: waddrmgr.KeyScopeBIP0084.Coin,
+ }
+ bs := waddrmgr.BranchScope{
+ Scope: scope,
+ Account: 0,
+ Branch: 0,
+ }
+
+ // Expect FetchScopedKeyManager to be called only once for a given
+ // scope.
+ addrMgr.On("FetchScopedKeyManager", scope).Return(
+ &mockAccountStore{}, nil,
+ ).Once()
+
+ rs := NewRecoveryState(10, &chainParams, addrMgr)
+
+ // First call should fetch the manager and create a new state.
+ state1, err := rs.GetBranchState(bs)
+ require.NoError(t, err)
+ require.NotNil(t, state1)
+
+ // Second call with the same scope should return the cached state.
+ state2, err := rs.GetBranchState(bs)
+ require.NoError(t, err)
+ require.Equal(t, state1, state2)
+}
+
+// TestInitialize verifies the public Initialize method of RecoveryState.
+// It ensures the method correctly sets up transient address filters and
+// outpoints based on account properties and mock address derivations.
+// This includes verifying that the lookahead horizons are properly synced
+// and that the address filters are populated with the expected number of
+// derived addresses for both external and internal branches.
+func TestInitialize(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ accountStore := &mockAccountStore{}
+ defer accountStore.AssertExpectations(t)
+
+ scope := waddrmgr.KeyScope{Purpose: 84, Coin: 0}
+
+ props := &waddrmgr.AccountProperties{
+ KeyScope: scope,
+ AccountNumber: 0,
+ ExternalKeyCount: 5, // 5 found addresses
+ InternalKeyCount: 3, // 3 found addresses
+ }
+
+ // FetchScopedKeyManager is called twice (once for external, once for
+ // internal branch)
+ addrMgr.On("FetchScopedKeyManager", scope).Return(
+ accountStore, nil,
+ ).Times(2)
+
+ // Helper to mock DeriveAddr calls for a given branch and
+ // range of indices.
+ mockDerive := func(branch, count uint32) {
+ for i := range count {
+ id := int(branch)*1000 + int(i)
+ addr := &mockAddress{}
+ addrStr := fmt.Sprintf("addr-%d", id)
+ script := fmt.Appendf(nil, "script-%d", id)
+
+ // Configure mockAddress expectations.
+ addr.On("EncodeAddress").Return(addrStr)
+ addr.On("ScriptAddress").Return(script)
+
+ accountStore.On(
+ "DeriveAddr", uint32(0), branch, i,
+ ).Return(addr, script, nil).Once()
+ }
+ }
+
+ // External branch: 5 found, recovery window 10. Total 15 derivations
+ // (0-14).
+ mockDerive(0, 15)
+
+ // Internal branch: 3 found, recovery window 10. Total 13 derivations
+ // (0-12).
+ mockDerive(1, 13)
+
+ rs := NewRecoveryState(10, &chainParams, addrMgr)
+
+ err := rs.Initialize([]*waddrmgr.AccountProperties{props}, nil, nil)
+ require.NoError(t, err)
+
+ // Verify that the address filters are populated with the expected
+ // number of addresses.
+ require.Len(t, rs.addrFilters, 15+13)
+ require.Equal(t, 15+13, rs.WatchListSize())
+}
+
+// TestProcessBlock verifies the core private filterTx and expandHorizons
+// methods through the public ProcessBlock entry point. It simulates a
+// block containing transactions that trigger address discovery and horizon
+// expansion. This test ensures the method correctly identifies relevant
+// transactions, updates outpoints, and manages the lookahead window by
+// repeatedly filtering the block and expanding horizons until convergence,
+// correctly handling intra-block chains and lookahead expansions.
+func TestProcessBlock(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ accountStore := &mockAccountStore{}
+ defer accountStore.AssertExpectations(t)
+
+ scope := waddrmgr.KeyScope{Purpose: 84, Coin: 0}
+ props := &waddrmgr.AccountProperties{
+ KeyScope: scope,
+ AccountNumber: 0,
+
+ // Start fresh for easier expansion testing.
+ ExternalKeyCount: 0,
+ InternalKeyCount: 0,
+ }
+
+ addrMgr.On("FetchScopedKeyManager", scope).Return(
+ accountStore, nil,
+ ).Maybe() // Called by GetBranchState within Initialize/ProcessBlock
+
+ // Store generated addresses to construct block data.
+ addrs := make(map[int]address.Address)
+
+ // Helper to mock DeriveAddr and store the generated address.
+ setupDerive := func(branch, idx uint32) {
+ id := int(branch)*1000 + int(idx)
+
+ // Create a valid P2PKH address for deterministic scripts.
+ hash := make([]byte, 20)
+ hash[0] = byte(id >> 8)
+ hash[1] = byte(id)
+ addr, _ := address.NewAddressPubKeyHash(
+ hash, &chainParams,
+ )
+ addrs[id] = addr
+
+ // Set up mock expectation for DeriveAddr.
+ script, _ := txscript.PayToAddrScript(addr)
+ accountStore.On(
+ "DeriveAddr", uint32(0), branch, idx,
+ ).Return(addr, script, nil).Maybe()
+ }
+
+ // 1. Initialize RecoveryState (derives initial lookahead 0-9 for both
+ // branches).
+ for i := range uint32(10) {
+ setupDerive(0, i) // External addresses
+ setupDerive(1, i) // Internal addresses
+ }
+
+ rs := NewRecoveryState(10, &chainParams, addrMgr)
+ err := rs.Initialize([]*waddrmgr.AccountProperties{props}, nil, nil)
+ require.NoError(t, err)
+
+ // 2. Setup expectations for subsequent horizon expansions:
+ //
+ // Finding address at index 5 (External) should make next unfound 6.
+ // Horizon expands to 6 + window (10) = 16. New addresses derived from
+ // 10 to 15.
+ for i := uint32(10); i < 16; i++ {
+ setupDerive(0, i)
+ }
+
+ // Finding address at index 12 (External) should make next unfound 13.
+ // Horizon expands to 13 + window (10) = 23. New addresses derived from
+ // 16 to 22.
+ for i := uint32(16); i < 23; i++ {
+ setupDerive(0, i)
+ }
+
+ // 3. Construct a mock block with transactions.
+ block := wire.NewMsgBlock(wire.NewBlockHeader(
+ 0, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+
+ // Tx1: Pays to Addr 5 (External Branch) - an address within initial
+ // lookahead.
+ addr5 := addrs[5] // Corresponds to id 5 (branch 0, index 5)
+ script5, _ := txscript.PayToAddrScript(addr5)
+ tx1 := wire.NewMsgTx(2)
+ tx1.AddTxOut(wire.NewTxOut(1000, script5))
+ _ = block.AddTransaction(tx1)
+
+ // Tx2: Pays to Addr 12 (External Branch) - an address initially
+ // outside the lookahead, but becomes visible after the first
+ // expansion.
+ addr12 := addrs[12] // Corresponds to id 12 (branch 0, index 12)
+ script12, _ := txscript.PayToAddrScript(addr12)
+ tx2 := wire.NewMsgTx(2)
+ tx2.AddTxOut(wire.NewTxOut(2000, script12))
+ _ = block.AddTransaction(tx2)
+
+ // 4. Process the block and verify results.
+ res, err := rs.ProcessBlock(block)
+ require.NoError(t, err)
+
+ // Expect expansion occurred due to finding index 12.
+ require.True(t, res.Expanded)
+
+ // Expect both transactions to be identified as relevant.
+ require.Len(t, res.RelevantTxs, 2)
+
+ // Verify the maximum index found for Branch 0 (External) is 12.
+ bs := waddrmgr.BranchScope{Scope: scope, Account: 0, Branch: 0}
+ require.Contains(t, res.FoundHorizons, bs)
+ require.Equal(t, uint32(12), res.FoundHorizons[bs])
+}
+
+// TestBuildCFilterData verifies the BuildCFilterData method of RecoveryState.
+// It ensures that the method correctly aggregates all relevant scripts from
+// both the transient address filters and the watched outpoints into a single
+// list, which is then used for CFilter construction. This tests the data
+// aggregation logic, independent of address derivation or block processing.
+func TestBuildCFilterData(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ rs := NewRecoveryState(10, &chainParams, addrMgr)
+
+ // Initially, the recovery state should be empty.
+ require.True(t, rs.Empty())
+
+ // Manually initialize maps as Initialize is not called for this test.
+ rs.addrFilters = make(map[string]AddrEntry)
+ rs.outpoints = make(map[wire.OutPoint][]byte)
+
+ // Add a sample address filter entry.
+ addr1, _ := address.DecodeAddress(
+ "mrCDrCybB6J1vRfbwM5hemdJz73FwDBC8r", &chainParams,
+ )
+ rs.addrFilters[addr1.EncodeAddress()] = AddrEntry{
+ Address: addr1,
+ }
+
+ // Add a sample watched outpoint.
+ op := wire.OutPoint{Hash: chainhash.Hash{1}, Index: 0}
+ pkScript := []byte{0x00, 0x14, 0x01, 0x02} // Dummy P2WPKH script
+ rs.outpoints[op] = pkScript
+
+ // Verify WatchListSize reflects the manually added entries.
+ require.Equal(t, 2, rs.WatchListSize())
+
+ // Build the CFilter data.
+ data, err := rs.BuildCFilterData()
+ require.NoError(t, err)
+
+ // Construct the expected script for addr1.
+ script1, _ := txscript.PayToAddrScript(addr1)
+
+ // Verify the returned data contains both scripts.
+ require.Len(t, data, 2)
+ require.Contains(t, data, script1)
+ require.Contains(t, data, pkScript)
+}
+
+// TestInitAccountState verifies the private initAccountState method.
+// It focuses on ensuring that when an account's properties are processed,
+// the method correctly initializes branch recovery states for both external
+// and internal branches. This involves verifying that the address manager's
+// FetchScopedKeyManager is called appropriately and that the recovery
+// state's addrFilters are populated with the initial set of derived
+// addresses based on the configured recovery window.
+func TestInitAccountState(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ accountStore := &mockAccountStore{}
+ defer accountStore.AssertExpectations(t)
+
+ scope := waddrmgr.KeyScope{Purpose: 84, Coin: 0}
+ props := &waddrmgr.AccountProperties{
+ KeyScope: scope,
+ AccountNumber: 0,
+ ExternalKeyCount: 0,
+ InternalKeyCount: 0,
+ }
+
+ // RecoveryWindow 2.
+ rs := NewRecoveryState(2, &chainParams, addrMgr)
+ rs.addrFilters = make(map[string]AddrEntry)
+
+ // FetchScopedKeyManager is called twice (once for external, once for
+ // internal branch).
+ addrMgr.On("FetchScopedKeyManager", scope).Return(
+ accountStore, nil,
+ ).Times(2)
+
+ // Helper to mock DeriveAddr calls for a given branch and range of
+ // indices.
+ mockDerive := func(branch uint32) {
+ for i := range uint32(2) {
+ id := int(branch)*1000 + int(i)
+ addr := &mockAddress{}
+ addrStr := fmt.Sprintf("b%d-%d", branch, i)
+ script := fmt.Appendf(nil, "script-%d", id)
+
+ // Configure mockAddress expectations.
+ addr.On("EncodeAddress").Return(addrStr)
+ addr.On("ScriptAddress").Return(script)
+
+ accountStore.On(
+ "DeriveAddr", uint32(0), branch, i,
+ ).Return(addr, script, nil).Once()
+ }
+ }
+
+ mockDerive(0) // External
+ mockDerive(1) // Internal
+
+ err := rs.initAccountState(props)
+ require.NoError(t, err)
+ require.Len(t, rs.addrFilters, 4)
+}
+
+// TestExpandHorizons verifies the private expandHorizons method.
+// It ensures that the lookahead horizon is correctly expanded when new
+// addresses are reported as found. This test specifically checks that new
+// addresses are derived (via the mocked AccountStore) to maintain the recovery
+// window, and that these newly derived addresses are subsequently added to the
+// RecoveryState's transient addrFilters.
+func TestExpandHorizons(t *testing.T) {
+ t.Parallel()
+
+ store := &mockAccountStore{}
+ defer store.AssertExpectations(t)
+
+ rs := NewRecoveryState(2, nil, nil)
+ rs.addrFilters = make(map[string]AddrEntry)
+
+ // Setup a branch state manually.
+ bs := waddrmgr.BranchScope{Branch: 0}
+ brs := NewBranchRecoveryState(2, store)
+ rs.branchStates[bs] = brs
+
+ // Simulate finding index 0, which requires derivation of 0, 1, 2
+ // because NextUnfound becomes 1, MinHorizon = 1+2=3.
+ brs.ReportFound(0)
+
+ // Expect derivation of 0, 1, 2.
+ for i := range uint32(3) {
+ addr := &mockAddress{}
+ addrStr := fmt.Sprintf("addr-%d", i)
+ script := fmt.Appendf(nil, "script-%d", i)
+
+ addr.On("EncodeAddress").Return(addrStr)
+ addr.On("ScriptAddress").Return(script)
+
+ store.On("DeriveAddr", uint32(0), uint32(0), i).Return(
+ addr, script, nil,
+ ).Once()
+ }
+
+ expanded, err := rs.expandHorizons()
+ require.NoError(t, err)
+ require.True(t, expanded)
+ require.Len(t, rs.addrFilters, 3)
+}
+
+// TestReportFound verifies the private reportFound method.
+// It ensures that the method correctly processes a map of found AddrScopes,
+// identifying the maximum index found for each BranchScope. It then verifies
+// that the corresponding BranchRecoveryState's internal `nextUnfound` value is
+// updated appropriately based on these findings, triggering potential future
+// horizon expansions.
+func TestReportFound(t *testing.T) {
+ t.Parallel()
+
+ rs := NewRecoveryState(10, nil, nil)
+ bs := waddrmgr.BranchScope{Branch: 0}
+ brs := NewBranchRecoveryState(10, nil)
+ rs.branchStates[bs] = brs
+
+ // Simulate finding index 5 on this branch.
+ found := map[waddrmgr.AddrScope]struct{}{
+ {BranchScope: bs, Index: 5}: {},
+ }
+
+ horizons := rs.reportFound(found)
+
+ require.Contains(t, horizons, bs)
+ require.Equal(t, uint32(5), horizons[bs])
+ require.Equal(t, uint32(6), brs.NextUnfound())
+}
+
+// TestFilterTx verifies the private filterTx method.
+// It simulates a single transaction and checks its relevance against the
+// RecoveryState's configured address filters and watched outpoints. This test
+// ensures that the method correctly identifies credits (payments to our
+// addresses) and debits (spends from our outpoints), updates the transient
+// outpoints map (removing spent inputs, adding new relevant outputs), and
+// populates the foundScopes and relevantOutputs maps for subsequent
+// processing.
+func TestFilterTx(t *testing.T) {
+ t.Parallel()
+
+ rs := NewRecoveryState(10, &chainParams, nil)
+ rs.addrFilters = make(map[string]AddrEntry)
+ rs.outpoints = make(map[wire.OutPoint][]byte)
+
+ // 1. Setup Watched Address.
+ // Use real address for Script parsing interaction with txscript.
+ addr, _ := address.DecodeAddress(
+ "mrCDrCybB6J1vRfbwM5hemdJz73FwDBC8r", &chainParams,
+ )
+
+ rs.addrFilters[addr.EncodeAddress()] = AddrEntry{
+ Address: addr,
+ addrScope: waddrmgr.AddrScope{
+ BranchScope: waddrmgr.BranchScope{Branch: 0},
+ Index: 10,
+ },
+ IsLookahead: true,
+ }
+
+ // 2. Setup Watched Outpoint.
+ opHash := chainhash.Hash{0x01}
+ op := wire.OutPoint{Hash: opHash, Index: 0}
+ rs.outpoints[op] = []byte{0x00} // Dummy script
+
+ // 3. Construct Tx.
+ // Input spending 'op'.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(wire.NewTxIn(&op, nil, nil))
+
+ // Output paying to 'addr'.
+ pkScript, _ := txscript.PayToAddrScript(addr)
+ tx.AddTxOut(wire.NewTxOut(1000, pkScript))
+
+ // 4. Filter.
+ foundScopes := make(map[waddrmgr.AddrScope]struct{})
+ isRelevant, entries := rs.filterTx(tx, foundScopes)
+ require.True(t, isRelevant)
+
+ // Verify Outpoint spent (removed).
+ _, ok := rs.outpoints[op]
+ require.False(t, ok, "outpoint should be removed")
+
+ // Verify Output matched.
+ txHash := tx.TxHash()
+
+ require.Len(t, entries, 1)
+
+ // Verify Scope found.
+ expectedScope := waddrmgr.AddrScope{
+ BranchScope: waddrmgr.BranchScope{Branch: 0},
+ Index: 10,
+ }
+ require.Contains(t, foundScopes, expectedScope)
+
+ // Verify new outpoint added.
+ newOp := wire.OutPoint{Hash: txHash, Index: 0}
+ _, ok = rs.outpoints[newOp]
+ require.True(t, ok, "new outpoint should be watched")
+}
+
+// TestRecoveryStateWatchedOutPoints verifies the management of persistent
+// watched outpoints through AddWatchedOutPoint and WatchedOutPoints.
+func TestRecoveryStateWatchedOutPoints(t *testing.T) {
+ t.Parallel()
+
+ rs := NewRecoveryState(10, nil, nil)
+
+ // Initially, no watched outpoints.
+ require.Empty(t, rs.WatchedOutPoints())
+
+ op1 := wire.OutPoint{Hash: chainhash.Hash{1}, Index: 0}
+ addr1 := &mockAddress{}
+ op2 := wire.OutPoint{Hash: chainhash.Hash{2}, Index: 1}
+ addr2 := &mockAddress{}
+
+ rs.AddWatchedOutPoint(&op1, addr1)
+ rs.AddWatchedOutPoint(&op2, addr2)
+
+ watched := rs.WatchedOutPoints()
+ require.Len(t, watched, 2)
+ require.Equal(t, addr1, watched[op1])
+ require.Equal(t, addr2, watched[op2])
+}
+
+// TestRecoveryStateStringAndEmpty verifies the String and Empty methods of
+// RecoveryState. It ensures that the String method produces a non-empty
+// summary and that the Empty method accurately reflects the presence or
+// absence of filters and outpoints.
+func TestRecoveryStateStringAndEmpty(t *testing.T) {
+ t.Parallel()
+
+ rs := NewRecoveryState(10, nil, nil)
+
+ // Initially, state should be empty.
+ require.True(t, rs.Empty())
+ require.Contains(t, rs.String(), "RecoveryState(addrs=0, outpoints=0)")
+
+ // Add an address filter entry.
+ rs.addrFilters = make(map[string]AddrEntry)
+ rs.addrFilters["a"] = AddrEntry{}
+
+ require.False(t, rs.Empty())
+ require.Contains(t, rs.String(), "RecoveryState(addrs=1, outpoints=0)")
+
+ // Add an outpoint.
+ rs.outpoints = make(map[wire.OutPoint][]byte)
+ op := wire.OutPoint{Hash: chainhash.Hash{1}, Index: 0}
+ rs.outpoints[op] = []byte{}
+
+ require.False(t, rs.Empty())
+ require.Contains(t, rs.String(), "RecoveryState(addrs=1, outpoints=1)")
+}
+
+// Define a static error for testing FetchScopedKeyManager failure.
+var errFetch = errors.New("fetch error")
+
+// TestExpandHorizonsWithInvalidChild verifies that expandHorizons correctly
+// handles hdkeychain.ErrInvalidChild by skipping the invalid index and
+// continuing derivation until the window is full.
+func TestExpandHorizonsWithInvalidChild(t *testing.T) {
+ t.Parallel()
+
+ store := &mockAccountStore{}
+ defer store.AssertExpectations(t)
+
+ rs := NewRecoveryState(2, nil, nil)
+ rs.addrFilters = make(map[string]AddrEntry)
+
+ // Setup a branch state manually.
+ bs := waddrmgr.BranchScope{Branch: 0}
+ brs := NewBranchRecoveryState(2, store)
+ rs.branchStates[bs] = brs
+
+ // Simulate finding index 0. This triggers expansion.
+ brs.ReportFound(0)
+
+ // Expect derivation of 0 -> Success
+ addr0 := &mockAddress{}
+ addr0.On("EncodeAddress").Return("addr-0")
+ addr0.On("ScriptAddress").Return([]byte("script-0"))
+ store.On("DeriveAddr", uint32(0), uint32(0), uint32(0)).Return(
+ addr0, []byte("script-0"), nil,
+ ).Once()
+
+ // Expect derivation of 1 -> ErrInvalidChild
+ store.On("DeriveAddr", uint32(0), uint32(0), uint32(1)).Return(
+ nil, nil, hdkeychain.ErrInvalidChild,
+ ).Once()
+
+ // Expect derivation of 2 -> Success
+ addr2 := &mockAddress{}
+ addr2.On("EncodeAddress").Return("addr-2")
+ addr2.On("ScriptAddress").Return([]byte("script-2"))
+ store.On("DeriveAddr", uint32(0), uint32(0), uint32(2)).Return(
+ addr2, []byte("script-2"), nil,
+ ).Once()
+
+ // Expect derivation of 3 -> Success (to fill window)
+ addr3 := &mockAddress{}
+ addr3.On("EncodeAddress").Return("addr-3")
+ addr3.On("ScriptAddress").Return([]byte("script-3"))
+ store.On("DeriveAddr", uint32(0), uint32(0), uint32(3)).Return(
+ addr3, []byte("script-3"), nil,
+ ).Once()
+
+ expanded, err := rs.expandHorizons()
+ require.NoError(t, err)
+ require.True(t, expanded)
+
+ // Verify filters contain 0, 2, 3 (3 valid addresses).
+ require.Len(t, rs.addrFilters, 3)
+ require.Contains(t, rs.addrFilters, "addr-0")
+ require.Contains(t, rs.addrFilters, "addr-2")
+ require.Contains(t, rs.addrFilters, "addr-3")
+ require.NotContains(t, rs.addrFilters, "addr-1")
+}
+
+// TestInitializeError verifies that Initialize propagates errors from
+// initAccountState (e.g. FetchScopedKeyManager failures).
+func TestInitializeError(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ rs := NewRecoveryState(10, nil, addrMgr)
+ scope := waddrmgr.KeyScope{Purpose: 84, Coin: 0}
+ props := &waddrmgr.AccountProperties{KeyScope: scope}
+
+ // Mock failure.
+ addrMgr.On("FetchScopedKeyManager", scope).Return(nil, errFetch).Once()
+
+ err := rs.Initialize([]*waddrmgr.AccountProperties{props}, nil, nil)
+ require.ErrorIs(t, err, errFetch)
+}
+
+// TestInitAccountStateDeriveError verifies that initAccountState propagates
+// errors from DeriveAddr.
+func TestInitAccountStateDeriveError(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ accountStore := &mockAccountStore{}
+
+ defer addrMgr.AssertExpectations(t)
+ defer accountStore.AssertExpectations(t)
+
+ rs := NewRecoveryState(10, nil, addrMgr)
+ rs.addrFilters = make(map[string]AddrEntry)
+ scope := waddrmgr.KeyScope{Purpose: 84, Coin: 0}
+ props := &waddrmgr.AccountProperties{KeyScope: scope}
+
+ // First call succeeds (External).
+ addrMgr.On("FetchScopedKeyManager", scope).Return(
+ accountStore, nil,
+ ).Once()
+
+ // Derive fails immediately.
+ accountStore.On(
+ "DeriveAddr", uint32(0), uint32(0), uint32(0),
+ ).Return(nil, nil, errFetch).Once()
+
+ err := rs.initAccountState(props)
+ require.ErrorIs(t, err, errFetch)
+}
+
+// TestExpandHorizonsError verifies that expandHorizons propagates errors from
+// DeriveAddr when attempting to extend the lookahead window.
+func TestExpandHorizonsError(t *testing.T) {
+ t.Parallel()
+
+ accountStore := &mockAccountStore{}
+ defer accountStore.AssertExpectations(t)
+
+ rs := NewRecoveryState(2, nil, nil)
+ rs.addrFilters = make(map[string]AddrEntry)
+ bs := waddrmgr.BranchScope{Branch: 0}
+ brs := NewBranchRecoveryState(2, accountStore)
+ rs.branchStates[bs] = brs
+
+ // Trigger expansion requirement.
+ brs.ReportFound(0)
+
+ // Mock DeriveAddr error.
+ accountStore.On(
+ "DeriveAddr", uint32(0), uint32(0), uint32(0),
+ ).Return(nil, nil, errFetch).Once()
+
+ _, err := rs.expandHorizons()
+ require.ErrorIs(t, err, errFetch)
+}
+
+// TestInitializeWithState verifies Initialize with existing state.
+func TestInitializeWithState(t *testing.T) {
+ t.Parallel()
+
+ addrMgr := &mockAddrStore{}
+ defer addrMgr.AssertExpectations(t)
+
+ rs := NewRecoveryState(10, nil, addrMgr)
+
+ // Mock address and outpoint.
+ addr := &mockAddress{}
+ addr.On("EncodeAddress").Return("addr1")
+
+ outpoint := wtxmgr.Credit{
+ OutPoint: wire.OutPoint{Hash: chainhash.Hash{1}, Index: 0},
+ PkScript: []byte{1},
}
+
+ err := rs.Initialize(
+ nil, []address.Address{addr}, []wtxmgr.Credit{outpoint},
+ )
+ require.NoError(t, err)
+ require.Len(t, rs.addrFilters, 1)
+ require.Len(t, rs.outpoints, 1)
+}
+
+// TestProcessBlockError verifies that ProcessBlock propagates errors from
+// expandHorizons.
+func TestProcessBlockError(t *testing.T) {
+ t.Parallel()
+
+ store := &mockAccountStore{}
+ defer store.AssertExpectations(t)
+
+ rs := NewRecoveryState(10, &chainParams, nil)
+ rs.addrFilters = make(map[string]AddrEntry)
+ rs.outpoints = make(map[wire.OutPoint][]byte)
+
+ // Setup branch.
+ bs := waddrmgr.BranchScope{Branch: 0}
+ brs := NewBranchRecoveryState(10, store)
+ rs.branchStates[bs] = brs
+
+ // Add filter entry that triggers expansion (using real address for
+ // txscript compatibility).
+ realAddr, _ := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ rs.addrFilters[realAddr.EncodeAddress()] = AddrEntry{
+ Address: realAddr,
+ addrScope: waddrmgr.AddrScope{
+ BranchScope: bs, Index: 0,
+ },
+ IsLookahead: true,
+ }
+
+ // Block with tx paying to realAddr.
+ block := wire.NewMsgBlock(&wire.BlockHeader{})
+ tx := wire.NewMsgTx(2)
+ txOut := wire.NewTxOut(1000, nil)
+
+ var err error
+
+ txOut.PkScript, err = txscript.PayToAddrScript(realAddr)
+ require.NoError(t, err)
+ tx.AddTxOut(txOut)
+ _ = block.AddTransaction(tx)
+
+ // Mock failure.
+ store.On("DeriveAddr", uint32(0), uint32(0), uint32(0)).Return(
+ nil, nil, errFetch).Once()
+
+ _, err = rs.ProcessBlock(block)
+ require.ErrorIs(t, err, errFetch)
}
diff --git a/wallet/rescan.go b/wallet/rescan.go
deleted file mode 100644
index 4f98841bf5..0000000000
--- a/wallet/rescan.go
+++ /dev/null
@@ -1,323 +0,0 @@
-// Copyright (c) 2013-2017 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/chain"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/wtxmgr"
-)
-
-// RescanProgressMsg reports the current progress made by a rescan for a
-// set of wallet addresses.
-type RescanProgressMsg struct {
- Addresses []address.Address
- Notification chain.RescanProgress
-}
-
-// RescanFinishedMsg reports the addresses that were rescanned when a
-// rescanfinished message was received rescanning a batch of addresses.
-type RescanFinishedMsg struct {
- Addresses []address.Address
- Notification *chain.RescanFinished
-}
-
-// RescanJob is a job to be processed by the RescanManager. The job includes
-// a set of wallet addresses, a starting height to begin the rescan, and
-// outpoints spendable by the addresses thought to be unspent. After the
-// rescan completes, the error result of the rescan RPC is sent on the Err
-// channel.
-type RescanJob struct {
- InitialSync bool
- Addrs []address.Address
- OutPoints map[wire.OutPoint]address.Address
- BlockStamp waddrmgr.BlockStamp
- err chan error
-}
-
-// rescanBatch is a collection of one or more RescanJobs that were merged
-// together before a rescan is performed.
-type rescanBatch struct {
- initialSync bool
- addrs []address.Address
- outpoints map[wire.OutPoint]address.Address
- bs waddrmgr.BlockStamp
- errChans []chan error
-}
-
-// SubmitRescan submits a RescanJob to the RescanManager. A channel is
-// returned with the final error of the rescan. The channel is buffered
-// and does not need to be read to prevent a deadlock.
-func (w *Wallet) SubmitRescan(job *RescanJob) <-chan error {
- errChan := make(chan error, 1)
- job.err = errChan
- select {
- case w.rescanAddJob <- job:
- case <-w.quitChan():
- errChan <- ErrWalletShuttingDown
- }
- return errChan
-}
-
-// batch creates the rescanBatch for a single rescan job.
-func (job *RescanJob) batch() *rescanBatch {
- return &rescanBatch{
- initialSync: job.InitialSync,
- addrs: job.Addrs,
- outpoints: job.OutPoints,
- bs: job.BlockStamp,
- errChans: []chan error{job.err},
- }
-}
-
-// merge merges the work from k into j, setting the starting height to
-// the minimum of the two jobs. This method does not check for
-// duplicate addresses or outpoints.
-func (b *rescanBatch) merge(job *RescanJob) {
- if job.InitialSync {
- b.initialSync = true
- }
- b.addrs = append(b.addrs, job.Addrs...)
-
- for op, addr := range job.OutPoints {
- b.outpoints[op] = addr
- }
-
- if job.BlockStamp.Height < b.bs.Height {
- b.bs = job.BlockStamp
- }
- b.errChans = append(b.errChans, job.err)
-}
-
-// done iterates through all error channels, duplicating sending the error
-// to inform callers that the rescan finished (or could not complete due
-// to an error).
-func (b *rescanBatch) done(err error) {
- for _, c := range b.errChans {
- c <- err
- }
-}
-
-// rescanBatchHandler handles incoming rescan request, serializing rescan
-// submissions, and possibly batching many waiting requests together so they
-// can be handled by a single rescan after the current one completes.
-func (w *Wallet) rescanBatchHandler() {
- defer w.wg.Done()
-
- var curBatch, nextBatch *rescanBatch
- quit := w.quitChan()
-
- for {
- select {
- case job := <-w.rescanAddJob:
- if curBatch == nil {
- // Set current batch as this job and send
- // request.
- curBatch = job.batch()
- select {
- case w.rescanBatch <- curBatch:
- case <-quit:
- job.err <- ErrWalletShuttingDown
- return
- }
- } else {
- // Create next batch if it doesn't exist, or
- // merge the job.
- if nextBatch == nil {
- nextBatch = job.batch()
- } else {
- nextBatch.merge(job)
- }
- }
-
- case n := <-w.rescanNotifications:
- switch n := n.(type) {
- case *chain.RescanProgress:
- if curBatch == nil {
- log.Warnf("Received rescan progress " +
- "notification but no rescan " +
- "currently running")
- continue
- }
- select {
- case w.rescanProgress <- &RescanProgressMsg{
- Addresses: curBatch.addrs,
- Notification: *n,
- }:
- case <-quit:
- for _, errChan := range curBatch.errChans {
- errChan <- ErrWalletShuttingDown
- }
- return
- }
-
- case *chain.RescanFinished:
- if curBatch == nil {
- log.Warnf("Received rescan finished " +
- "notification but no rescan " +
- "currently running")
- continue
- }
- select {
- case w.rescanFinished <- &RescanFinishedMsg{
- Addresses: curBatch.addrs,
- Notification: n,
- }:
- case <-quit:
- for _, errChan := range curBatch.errChans {
- errChan <- ErrWalletShuttingDown
- }
- return
- }
-
- curBatch, nextBatch = nextBatch, nil
-
- if curBatch != nil {
- select {
- case w.rescanBatch <- curBatch:
- case <-quit:
- for _, errChan := range curBatch.errChans {
- errChan <- ErrWalletShuttingDown
- }
- return
- }
- }
-
- default:
- // Unexpected message
- panic(n)
- }
-
- case <-quit:
- return
- }
- }
-}
-
-// rescanProgressHandler handles notifications for partially and fully completed
-// rescans by marking each rescanned address as partially or fully synced.
-func (w *Wallet) rescanProgressHandler() {
- quit := w.quitChan()
-out:
- for {
- // These can't be processed out of order since both chans are
- // unbuffured and are sent from same context (the batch
- // handler).
- select {
- case msg := <-w.rescanProgress:
- n := msg.Notification
- log.Infof("Rescanned through block %v (height %d)",
- n.Hash, n.Height)
-
- case msg := <-w.rescanFinished:
- n := msg.Notification
- addrs := msg.Addresses
- noun := pickNoun(len(addrs), "address", "addresses")
- log.Infof("Finished rescan for %d %s (synced to block "+
- "%s, height %d)", len(addrs), noun, n.Hash,
- n.Height)
-
- go w.resendUnminedTxs()
-
- case <-quit:
- break out
- }
- }
- w.wg.Done()
-}
-
-// rescanRPCHandler reads batch jobs sent by rescanBatchHandler and sends the
-// RPC requests to perform a rescan. New jobs are not read until a rescan
-// finishes.
-func (w *Wallet) rescanRPCHandler() {
- chainClient, err := w.requireChainClient()
- if err != nil {
- log.Errorf("rescanRPCHandler called without an RPC client")
- w.wg.Done()
- return
- }
-
- quit := w.quitChan()
-
-out:
- for {
- select {
- case batch := <-w.rescanBatch:
- // Log the newly-started rescan.
- numAddrs := len(batch.addrs)
- numOps := len(batch.outpoints)
-
- log.Infof("Started rescan from block %v (height %d) "+
- "for %d addrs, %d outpoints", batch.bs.Hash,
- batch.bs.Height, numAddrs, numOps)
-
- err := chainClient.Rescan(
- &batch.bs.Hash, batch.addrs, batch.outpoints,
- )
- if err != nil {
- log.Errorf("Rescan for %d addrs, %d outpoints "+
- "failed: %v", numAddrs, numOps, err)
- }
- batch.done(err)
- case <-quit:
- break out
- }
- }
-
- w.wg.Done()
-}
-
-// Rescan begins a rescan for all active addresses and unspent outputs of
-// a wallet. This is intended to be used to sync a wallet back up to the
-// current best block in the main chain, and is considered an initial sync
-// rescan.
-func (w *Wallet) Rescan(addrs []address.Address,
- unspent []wtxmgr.Credit) error {
-
- return w.rescanWithTarget(addrs, unspent, nil)
-}
-
-// rescanWithTarget performs a rescan starting at the optional startStamp. If
-// none is provided, the rescan will begin from the manager's sync tip.
-func (w *Wallet) rescanWithTarget(addrs []address.Address,
- unspent []wtxmgr.Credit, startStamp *waddrmgr.BlockStamp) error {
-
- outpoints := make(map[wire.OutPoint]address.Address, len(unspent))
- for _, output := range unspent {
- _, outputAddrs, _, err := txscript.ExtractPkScriptAddrs(
- output.PkScript, w.chainParams,
- )
- if err != nil {
- return err
- }
-
- outpoints[output.OutPoint] = outputAddrs[0]
- }
-
- // If a start block stamp was provided, we will use that as the initial
- // starting point for the rescan.
- if startStamp == nil {
- startStamp = &waddrmgr.BlockStamp{}
- *startStamp = w.Manager.SyncedTo()
- }
-
- job := &RescanJob{
- InitialSync: true,
- Addrs: addrs,
- OutPoints: outpoints,
- BlockStamp: *startStamp,
- }
-
- // Submit merged job and block until rescan completes.
- select {
- case err := <-w.SubmitRescan(job):
- return err
- case <-w.quitChan():
- return ErrWalletShuttingDown
- }
-}
diff --git a/wallet/signer.go b/wallet/signer.go
index b16d047453..71509b891c 100644
--- a/wallet/signer.go
+++ b/wallet/signer.go
@@ -1,138 +1,912 @@
-// Copyright (c) 2020 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
package wallet
import (
+ "context"
+ "errors"
"fmt"
"github.com/btcsuite/btcd/address/v2"
"github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcec/v2/ecdsa"
+ "github.com/btcsuite/btcd/btcec/v2/schnorr"
+ "github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/txscript/v2"
"github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+)
+
+var (
+ // ErrUnknownSignMethod is returned when a transaction is signed with an
+ // unknown sign method.
+ ErrUnknownSignMethod = errors.New("unknown sign method")
+
+ // ErrUnsupportedAddressType is returned when a transaction is signed
+ // for an unsupported address type.
+ ErrUnsupportedAddressType = errors.New("unsupported address type")
+
+ // ErrInvalidDigestSize is returned when a signature digest is not 32
+ // bytes.
+ ErrInvalidDigestSize = errors.New("digest must be 32 bytes")
+
+ // ErrInvalidSignParam is returned when the parameters for the signing
+ // operation are invalid.
+ ErrInvalidSignParam = errors.New("invalid signing parameters")
+)
+
+// Signer provides an interface for common, safe cryptographic operations,
+// including signing and key derivation.
+type Signer interface {
+ // DerivePubKey derives a public key from a full BIP-32 derivation
+ // path.
+ DerivePubKey(ctx context.Context, path BIP32Path) (
+ *btcec.PublicKey, error)
+
+ // ECDH performs a scalar multiplication (ECDH-like operation) between
+ // a key from the wallet and a remote public key. The output returned
+ // will be the raw 32-byte shared secret (the X-coordinate of the
+ // result point).
+ ECDH(ctx context.Context, path BIP32Path, pub *btcec.PublicKey) (
+ [32]byte, error)
+
+ // SignDigest signs a message digest based on the provided intent. The
+ // returned Signature is a marker interface that can be asserted to the
+ // concrete signature types, ECDSASignature or SchnorrSignature.
+ SignDigest(ctx context.Context, path BIP32Path,
+ intent *SignDigestIntent) (Signature, error)
+
+ // ComputeUnlockingScript generates the full sigScript and witness
+ // required to spend a UTXO. The resulting UnlockingScript struct
+ // contains the raw witness and/or sigScript, which can be used to
+ // populate the final transaction input.
+ //
+ // This method is designed for spending single-signature outputs, which
+ // are outputs that can be spent with a single signature from a single
+ // private key. This includes P2PKH, P2WKH, NP2WKH, and P2TR key-path
+ // spends. For more complex script-based spends, such as P2SH or P2WSH
+ // multisig, the ComputeRawSig method should be used to generate the raw
+ // signature, which can then be manually assembled into the final
+ // witness.
+ ComputeUnlockingScript(ctx context.Context,
+ params *UnlockingScriptParams) (*UnlockingScript, error)
+
+ // ComputeRawSig generates a raw signature for a single transaction
+ // input. The caller is responsible for assembling the final witness.
+ //
+ // This method is a low-level specialist function that should only be
+ // used when the caller needs to generate a raw signature for a
+ // specific key, without the wallet assembling the final witness. This
+ // is useful for multi-party protocols like multisig or Lightning,
+ // where signatures may need to be exchanged and combined before the
+ // final witness is created. For most common, single-signature spends,
+ // ComputeUnlockingScript should be used instead.
+ ComputeRawSig(ctx context.Context, params *RawSigParams) (
+ RawSignature, error)
+}
+
+// UnsafeSigner provides an interface for security-sensitive cryptographic
+// operations that export raw private key material. This interface should be
+// used with extreme care and only when absolutely necessary.
+type UnsafeSigner interface {
+ Signer
+
+ // DerivePrivKey derives a private key from a full BIP-32 derivation
+ // path.
+ //
+ // DANGER: This method exports sensitive key material.
+ DerivePrivKey(ctx context.Context, path BIP32Path) (
+ *btcec.PrivateKey, error)
+
+ // GetPrivKeyForAddress returns the private key for a given address.
+ //
+ // DANGER: This method exports sensitive key material.
+ GetPrivKeyForAddress(ctx context.Context, a address.Address) (
+ *btcec.PrivateKey, error)
+}
+
+// A compile-time check to ensure that Wallet implements the Signer and
+// UnsafeSigner interfaces.
+var _ Signer = (*Wallet)(nil)
+var _ UnsafeSigner = (*Wallet)(nil)
+
+// BIP32Path contains the full information needed to derive a key from the
+// wallet's master seed, as defined by BIP-32. It combines the high-level key
+// scope with the specific derivation path.
+type BIP32Path struct {
+ // KeyScope specifies the key scope (e.g., P2WKH, P2TR, or lnd's custom
+ // scope).
+ KeyScope waddrmgr.KeyScope
+
+ // DerivationPath specifies the full derivation path within the scope.
+ DerivationPath waddrmgr.DerivationPath
+}
+
+// SignatureType represents the type of signature to produce.
+type SignatureType uint8
+
+const (
+ // SigTypeECDSA represents an ECDSA signature.
+ SigTypeECDSA SignatureType = iota
+
+ // SigTypeSchnorr represents a Schnorr signature.
+ SigTypeSchnorr
+)
+
+// SignDigestIntent represents the user's intent to sign a message digest. It
+// serves as a blueprint for the Signer, bundling all the parameters
+// required to produce a signature into a single, coherent structure.
+//
+// # Usage Examples
+//
+// ## Standard ECDSA Signature (DER Encoded)
+// To produce a standard ECDSA signature, set SigType to SigTypeECDSA.
+//
+// intent := &wallet.SignDigestIntent{
+// Digest: chainhash.HashB([]byte("a message")),
+// SigType: wallet.SigTypeECDSA,
+// }
+// rawSig, err := signer.SignDigest(ctx, path, intent)
+// // Type-assert the result to ECDSASignature.
+// ecdsaSig := rawSig.(wallet.ECDSASignature)
+//
+// ## Compact, Recoverable ECDSA Signature
+// To produce a compact, recoverable signature, set CompactSig to true.
+//
+// intent := &wallet.SignDigestIntent{
+// Digest: chainhash.DoubleHashB([]byte("a message")),
+// SigType: wallet.SigTypeECDSA,
+// CompactSig: true,
+// }
+// rawSig, err := signer.SignDigest(ctx, path, intent)
+// // Type-assert the result to CompactSignature.
+// compactSig := rawSig.(wallet.CompactSignature)
+//
+// ## Schnorr Signature
+// To produce a Schnorr signature, set SigType to SigTypeSchnorr.
+//
+// intent := &wallet.SignDigestIntent{
+// Digest: chainhash.TaggedHash(
+// []byte("my_protocol_tag"), []byte("a message"),
+// ),
+// SigType: wallet.SigTypeSchnorr,
+// }
+// rawSig, err := signer.SignDigest(ctx, path, intent)
+// // Type-assert the result to SchnorrSignature.
+// schnorrSig := rawSig.(wallet.SchnorrSignature)
+type SignDigestIntent struct {
+ // Digest is the 32-byte hash digest to be signed.
+ Digest []byte
+
+ // SigType specifies the type of signature to generate.
+ SigType SignatureType
+
+ // CompactSig specifies whether the signature should be returned in the
+ // compact, recoverable format. This is only valid for ECDSA signatures.
+ CompactSig bool
+
+ // TaprootTweak is an optional private key tweak to be applied before
+ // signing. This is only valid for Schnorr signatures.
+ TaprootTweak []byte
+}
+
+// Signature is an interface that represents a cryptographic signature.
+// It is a marker interface to allow returning different signature types.
+type Signature interface {
+ // isSignature is a marker method to ensure that only the types defined
+ // in this package can implement this interface.
+ isSignature()
+}
+
+// ECDSASignature wraps an ecdsa.Signature to implement the Signature interface.
+type ECDSASignature struct {
+ *ecdsa.Signature
+}
+
+// CompactSignature wraps a compact signature byte slice to implement the
+// Signature interface.
+type CompactSignature []byte
+
+// SchnorrSignature wraps a schnorr.Signature to implement the Signature
+// interface.
+type SchnorrSignature struct {
+ *schnorr.Signature
+}
+
+// isSignature implements the Signature marker interface.
+func (ECDSASignature) isSignature() {}
+
+// isSignature implements the Signature marker interface.
+func (CompactSignature) isSignature() {}
+
+// isSignature implements the Signature marker interface.
+func (SchnorrSignature) isSignature() {}
+
+// UnlockingScript is a struct that contains the witness and sigScript for a
+// transaction input.
+type UnlockingScript struct {
+ // Witness is the witness stack for the input. For non-SegWit inputs,
+ // this will be nil.
+ Witness wire.TxWitness
+
+ // SigScript is the signature script for the input. For native SegWit
+ // inputs, this will be nil.
+ SigScript []byte
+}
+
+// PrivKeyTweaker is a function type that can be used to pass in a callback for
+// tweaking a private key before it's used to sign an input.
+type PrivKeyTweaker func(*btcec.PrivateKey) (*btcec.PrivateKey, error)
+
+// UnlockingScriptParams provides all the necessary parameters to generate an
+// unlocking script (witness and sigScript) for a transaction input.
+type UnlockingScriptParams struct {
+ // Tx is the transaction containing the input to be signed.
+ Tx *wire.MsgTx
+
+ // InputIndex is the index of the input to be signed.
+ InputIndex int
+
+ // Output is the previous output that is being spent.
+ Output *wire.TxOut
+
+ // SigHashes is the sighash cache for the transaction.
+ SigHashes *txscript.TxSigHashes
+
+ // HashType is the signature hash type to use.
+ HashType txscript.SigHashType
+
+ // Tweaker is an optional function that can be used to tweak the
+ // private key before signing.
+ Tweaker PrivKeyTweaker
+}
+
+// RawSigParams provides all the necessary parameters to generate a raw
+// signature for a transaction input.
+type RawSigParams struct {
+ // Tx is the transaction containing the input to be signed.
+ Tx *wire.MsgTx
+
+ // InputIndex is the index of the input to be signed.
+ InputIndex int
+
+ // Output is the previous output that is being spent.
+ Output *wire.TxOut
+
+ // SigHashes is the sighash cache for the transaction.
+ SigHashes *txscript.TxSigHashes
+
+ // HashType is the signature hash type to use.
+ HashType txscript.SigHashType
+
+ // Path is the BIP-32 derivation path of the key to be used for
+ // signing.
+ Path BIP32Path
+
+ // Tweaker is an optional function that can be used to tweak the
+ // private key before signing.
+ Tweaker PrivKeyTweaker
+
+ // Details specifies the version-specific information for signing.
+ // This field MUST be set to either LegacySpendDetails,
+ // SegwitV0SpendDetails or TaprootSpendDetails.
+ Details SpendDetails
+}
+
+// RawSignature is a raw signature.
+type RawSignature []byte
+
+// TaprootSpendPath is an enum that specifies the spending path to be used for a
+// Taproot input.
+type TaprootSpendPath uint8
+
+const (
+ // KeyPathSpend indicates that the output should be spent using the key
+ // path.
+ KeyPathSpend TaprootSpendPath = iota
+
+ // ScriptPathSpend indicates that the output should be spent using the
+ // script path.
+ ScriptPathSpend
)
-// ScriptForOutput returns the address, witness program and redeem script for a
-// given UTXO. An error is returned if the UTXO does not belong to our wallet or
-// it is not a managed pubKey address.
-func (w *Wallet) ScriptForOutput(output *wire.TxOut) (
- waddrmgr.ManagedPubKeyAddress, []byte, []byte, error) {
+// SpendDetails is a sealed interface that provides the version-specific
+// details required to generate a raw signature.
+type SpendDetails interface {
+ // isSpendDetails is a marker method to ensure that only the types
+ // defined in this package can implement this interface.
+ isSpendDetails()
- // First make sure we can sign for the input by making sure the script
- // in the UTXO belongs to our wallet and we have the private key for it.
- walletAddr, err := w.fetchOutputAddr(output.PkScript)
+ // Sign performs the version-specific signing operation.
+ Sign(params *RawSigParams, privKey *btcec.PrivateKey) (
+ RawSignature, error)
+}
+
+// LegacySpendDetails provides the details for signing a legacy P2PKH input.
+type LegacySpendDetails struct {
+ // RedeemScript is the redeem script for P2SH spends.
+ RedeemScript []byte
+}
+
+// Sign performs the version-specific signing operation for a legacy input.
+func (l LegacySpendDetails) Sign(params *RawSigParams,
+ privKey *btcec.PrivateKey) (RawSignature, error) {
+
+ // For P2SH, the redeem script must be provided. For P2PKH, the pkscript
+ // of the output is used.
+ script := l.RedeemScript
+ if script == nil {
+ script = params.Output.PkScript
+ }
+
+ rawSig, err := txscript.RawTxInSignature(
+ params.Tx, params.InputIndex, script,
+ params.HashType, privKey,
+ )
if err != nil {
- return nil, nil, nil, err
+ return nil, fmt.Errorf("cannot create raw signature: %w", err)
}
- pubKeyAddr, ok := walletAddr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return nil, nil, nil, fmt.Errorf("address %s is not a "+
- "p2wkh or np2wkh address", walletAddr.Address())
+ return rawSig, nil
+}
+
+// isSpendDetails implements the sealed interface.
+func (l LegacySpendDetails) isSpendDetails() {}
+
+// SegwitV0SpendDetails provides the details for signing a SegWit v0 input.
+type SegwitV0SpendDetails struct {
+ // WitnessScript is the witness script for P2WSH spends. For P2WKH,
+ // this should be the P2PKH script of the key.
+ WitnessScript []byte
+}
+
+// Sign performs the version-specific signing operation for a SegWit v0 input.
+func (s SegwitV0SpendDetails) Sign(params *RawSigParams,
+ privKey *btcec.PrivateKey) (RawSignature, error) {
+
+ sig, err := txscript.RawTxInWitnessSignature(
+ params.Tx, params.SigHashes, params.InputIndex,
+ params.Output.Value, s.WitnessScript,
+ params.HashType, privKey,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("cannot create witness sig: %w", err)
}
+ // Validate the signature by parsing it. This serves as a sanity check
+ // to ensure the generated signature is valid.
+ _, err = ecdsa.ParseDERSignature(sig[:len(sig)-1])
+ if err != nil {
+ return nil, fmt.Errorf("generated invalid witness sig: %w", err)
+ }
+
+ return sig[:len(sig)-1], nil
+}
+
+// isSpendDetails implements the sealed interface.
+func (s SegwitV0SpendDetails) isSpendDetails() {}
+
+// TaprootSpendDetails provides the details for signing a Taproot input.
+type TaprootSpendDetails struct {
+ // SpendPath specifies which spending path to use.
+ SpendPath TaprootSpendPath
+
+ // Tweak is the tweak to apply to the internal key. For a key-path
+ // spend, this is typically the merkle root of the script tree.
+ Tweak []byte
+
+ // WitnessScript is the specific script leaf being spent. This is
+ // only used for ScriptPathSpend.
+ WitnessScript []byte
+}
+
+// Sign performs the version-specific signing operation for a Taproot input.
+func (t TaprootSpendDetails) Sign(params *RawSigParams,
+ privKey *btcec.PrivateKey) (RawSignature, error) {
+
var (
- witnessProgram []byte
- sigScript []byte
+ rawSig []byte
+ err error
)
-
- switch {
- // If we're spending p2wkh output nested within a p2sh output, then
- // we'll need to attach a sigScript in addition to witness data.
- case walletAddr.AddrType() == waddrmgr.NestedWitnessPubKey:
- pubKey := pubKeyAddr.PubKey()
- pubKeyHash := address.Hash160(pubKey.SerializeCompressed())
-
- // Next, we'll generate a valid sigScript that will allow us to
- // spend the p2sh output. The sigScript will contain only a
- // single push of the p2wkh witness program corresponding to
- // the matching public key of this address.
- p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
- pubKeyHash, w.chainParams,
+ switch t.SpendPath {
+ case KeyPathSpend:
+ rawSig, err = txscript.RawTxInTaprootSignature(
+ params.Tx, params.SigHashes,
+ params.InputIndex, params.Output.Value,
+ params.Output.PkScript, t.Tweak,
+ params.HashType, privKey,
)
if err != nil {
- return nil, nil, nil, err
+ return nil, fmt.Errorf("taproot sig error: %w", err)
}
- witnessProgram, err = txscript.PayToAddrScript(p2wkhAddr)
+ case ScriptPathSpend:
+ leaf := txscript.TapLeaf{
+ LeafVersion: txscript.BaseLeafVersion,
+ Script: t.WitnessScript,
+ }
+
+ rawSig, err = txscript.RawTxInTapscriptSignature(
+ params.Tx, params.SigHashes,
+ params.InputIndex, params.Output.Value,
+ params.Output.PkScript, leaf,
+ params.HashType, privKey,
+ )
if err != nil {
- return nil, nil, nil, err
+ return nil, fmt.Errorf("tapscript sig error: %w", err)
}
+ default:
+ return nil, fmt.Errorf("%w: %v", ErrUnknownSignMethod,
+ t.SpendPath)
+ }
+
+ // Validate the signature by parsing it. This serves as a sanity check
+ // to ensure the generated signature is valid.
+ _, err = schnorr.ParseSignature(rawSig[:schnorr.SignatureSize])
+ if err != nil {
+ return nil, fmt.Errorf("generated invalid taproot sig: %w", err)
+ }
+
+ return rawSig, nil
+}
+
+// isSpendDetails implements the sealed interface.
+func (t TaprootSpendDetails) isSpendDetails() {}
+
+// A compile-time assertion to ensure that all SpendDetails implementations
+// adhere to the interface.
+var _ SpendDetails = (*LegacySpendDetails)(nil)
+var _ SpendDetails = (*SegwitV0SpendDetails)(nil)
+var _ SpendDetails = (*TaprootSpendDetails)(nil)
+
+// DerivePubKey derives a public key from a full BIP-32 derivation path.
+func (w *Wallet) DerivePubKey(_ context.Context, path BIP32Path) (
+ *btcec.PublicKey, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ managedPubKeyAddr, err := w.fetchManagedPubKeyAddress(path)
+ if err != nil {
+ return nil, err
+ }
+
+ return managedPubKeyAddr.PubKey(), nil
+}
- bldr := txscript.NewScriptBuilder()
- bldr.AddData(witnessProgram)
- sigScript, err = bldr.Script()
+// fetchManagedPubKeyAddress is a helper function that encapsulates the common
+// logic of fetching a scoped key manager, deriving a managed address from a
+// BIP32 path, and ensuring it is a public key address.
+//
+// Time Complexity:
+// - Average Case: O(1) - This is the common case where the account
+// information is already cached in memory. The function performs a few
+// map lookups and constant-time cryptographic operations.
+// - Worst Case: O(log N) - This occurs on a cache miss (e.g., the first
+// time an account is used). The function must perform a single, indexed
+// database lookup to fetch the account's master key. N is the number of
+// accounts in the wallet.
+//
+// Database Actions:
+// - This method performs a single read-only database transaction
+// (`walletdb.View`).
+// - The transaction's only purpose is to call `DeriveFromKeyPath`, which
+// performs at most one indexed database lookup for account information if
+// that information is not already in the in-memory cache.
+func (w *Wallet) fetchManagedPubKeyAddress(path BIP32Path) (
+ waddrmgr.ManagedPubKeyAddress, error) {
+
+ // Fetch the scoped key manager for the given key scope. This can be
+ // done outside of the database transaction as it only deals with
+ // in-memory state.
+ manager, err := w.addrStore.FetchScopedKeyManager(path.KeyScope)
+ if err != nil {
+ return nil, fmt.Errorf("cannot fetch scoped key manager: %w",
+ err)
+ }
+
+ // The derivation of the address is the only part that requires a
+ // database transaction.
+ var addr waddrmgr.ManagedAddress
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ // Derive the managed address from the derivation path.
+ derivedAddr, err := manager.DeriveFromKeyPath(
+ addrmgrNs, path.DerivationPath,
+ )
if err != nil {
- return nil, nil, nil, err
+ return fmt.Errorf("cannot derive from key path: %w",
+ err)
}
- // Otherwise, this is a regular p2wkh or p2tr output, so we include the
- // witness program itself as the subscript to generate the proper
- // sighash digest. As part of the new sighash digest algorithm, the
- // p2wkh witness program will be expanded into a regular p2kh
- // script.
- default:
- witnessProgram = output.PkScript
+ addr = derivedAddr
+
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("cannot view wallet database: %w", err)
+ }
+
+ // The post-processing of the address can be done outside of the
+ // database transaction as it only deals with the in-memory struct.
+ managedPubKeyAddr, ok := addr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil, fmt.Errorf("%w: addr %s", ErrNotPubKeyAddress,
+ addr.Address())
}
- return pubKeyAddr, witnessProgram, sigScript, nil
+ return managedPubKeyAddr, nil
}
-// PrivKeyTweaker is a function type that can be used to pass in a callback for
-// tweaking a private key before it's used to sign an input.
-type PrivKeyTweaker func(*btcec.PrivateKey) (*btcec.PrivateKey, error)
+// ECDH performs a scalar multiplication (ECDH-like operation) between a key
+// from the wallet and a remote public key. The output returned will be the
+// sha256 of the resulting shared point serialized in compressed format.
+func (w *Wallet) ECDH(_ context.Context, path BIP32Path,
+ pub *btcec.PublicKey) ([32]byte, error) {
-// ComputeInputScript generates a complete InputScript for the passed
-// transaction with the signature as defined within the passed SignDescriptor.
-// This method is capable of generating the proper input script for both
-// regular p2wkh output and p2wkh outputs nested within a regular p2sh output.
-func (w *Wallet) ComputeInputScript(tx *wire.MsgTx, output *wire.TxOut,
- inputIndex int, sigHashes *txscript.TxSigHashes,
- hashType txscript.SigHashType, tweaker PrivKeyTweaker) (wire.TxWitness,
- []byte, error) {
+ err := w.state.canSign()
+ if err != nil {
+ return [32]byte{}, err
+ }
+
+ managedPubKeyAddr, err := w.fetchManagedPubKeyAddress(path)
+ if err != nil {
+ return [32]byte{}, err
+ }
+
+ // Get the private key for the derived address.
+ privKey, err := managedPubKeyAddr.PrivKey()
+ if err != nil {
+ return [32]byte{}, fmt.Errorf("cannot get private key: %w",
+ err)
+ }
+ defer privKey.Zero()
+
+ // Perform the scalar multiplication and hash the result.
+ secret := btcec.GenerateSharedSecret(privKey, pub)
+
+ var sharedSecret [32]byte
+ copy(sharedSecret[:], secret)
+
+ return sharedSecret, nil
+}
+
+// validateSignDigestIntent validates the parameters of a SignDigestIntent.
+func validateSignDigestIntent(intent *SignDigestIntent) error {
+ // The digest must be exactly 32 bytes.
+ if len(intent.Digest) != chainhash.HashSize {
+ return ErrInvalidDigestSize
+ }
+
+ // Validate parameters based on signature type.
+ switch intent.SigType {
+ case SigTypeECDSA:
+ if intent.TaprootTweak != nil {
+ return fmt.Errorf("%w: taproot tweak cannot be used "+
+ "with ECDSA", ErrInvalidSignParam)
+ }
+
+ case SigTypeSchnorr:
+ if intent.CompactSig {
+ return fmt.Errorf("%w: compact signature cannot be "+
+ "used with Schnorr", ErrInvalidSignParam)
+ }
+ }
+
+ return nil
+}
+
+// SignDigest signs a message digest based on the provided intent.
+func (w *Wallet) SignDigest(_ context.Context, path BIP32Path,
+ intent *SignDigestIntent) (Signature, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ err = validateSignDigestIntent(intent)
+ if err != nil {
+ return nil, err
+ }
+
+ managedPubKeyAddr, err := w.fetchManagedPubKeyAddress(path)
+ if err != nil {
+ return nil, err
+ }
+
+ // Get the private key for the derived address.
+ privKey, err := managedPubKeyAddr.PrivKey()
+ if err != nil {
+ return nil, fmt.Errorf("cannot get private key: %w", err)
+ }
+ defer privKey.Zero()
+
+ // Now, sign the message using the derived private key. This is all
+ // pure computation, so it can be done outside the DB transaction.
+ return signDigestWithPrivKey(privKey, intent)
+}
+
+// signDigestWithPrivKey performs the actual signing of a digest with a given
+// private key, based on the options specified in the SignDigestIntent. It
+// acts as a dispatcher to the appropriate signing algorithm.
+func signDigestWithPrivKey(privKey *btcec.PrivateKey,
+ intent *SignDigestIntent) (Signature, error) {
+
+ // If Schnorr is specified, we'll generate a Schnorr signature.
+ if intent.SigType == SigTypeSchnorr {
+ return signDigestSchnorr(privKey, intent)
+ }
+
+ // Otherwise, we'll generate an ECDSA signature.
+ return signDigestECDSA(privKey, intent)
+}
+
+// signDigestSchnorr performs the actual signing of a digest with a given
+// private key, using the Schnorr signature algorithm.
+func signDigestSchnorr(privKey *btcec.PrivateKey,
+ intent *SignDigestIntent) (Signature, error) {
- walletAddr, witnessProgram, sigScript, err := w.ScriptForOutput(output)
+ if intent.TaprootTweak != nil {
+ privKey = txscript.TweakTaprootPrivKey(
+ *privKey, intent.TaprootTweak,
+ )
+ }
+
+ sig, err := schnorr.Sign(privKey, intent.Digest)
if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("cannot create schnorr sig: %w", err)
+ }
+
+ return SchnorrSignature{sig}, nil
+}
+
+// signDigestECDSA performs the actual signing of a digest with a given
+// private key, using the ECDSA signature algorithm.
+func signDigestECDSA(privKey *btcec.PrivateKey,
+ intent *SignDigestIntent) (Signature, error) {
+
+ if intent.CompactSig {
+ sig := ecdsa.SignCompact(privKey, intent.Digest, true)
+ return CompactSignature(sig), nil
}
- privKey, err := walletAddr.PrivKey()
+ sig := ecdsa.Sign(privKey, intent.Digest)
+
+ return ECDSASignature{sig}, nil
+}
+
+// ComputeUnlockingScript generates the full sigScript and witness required to
+// spend a UTXO.
+func (w *Wallet) ComputeUnlockingScript(ctx context.Context,
+ params *UnlockingScriptParams) (*UnlockingScript, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ // First, we'll fetch the managed address that corresponds to the
+ // output being spent. This will be used to look up the private key
+ // required for signing.
+ scriptInfo, err := w.ScriptForOutput(ctx, *params.Output)
+ if err != nil {
+ return nil, err
+ }
+
+ // The address must be a public key address.
+ pubKeyAddr, ok := scriptInfo.Addr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return nil, fmt.Errorf("%w: addr %s",
+ ErrNotPubKeyAddress, scriptInfo.Addr.Address())
+ }
+
+ // Get the private key for the derived address.
+ privKey, err := pubKeyAddr.PrivKey()
if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("cannot get private key: %w", err)
}
+ defer privKey.Zero()
- // If we need to maybe tweak our private key, do it now.
- if tweaker != nil {
- privKey, err = tweaker(privKey)
+ // If a tweaker is provided, we'll use it to tweak the private key.
+ if params.Tweaker != nil {
+ privKey, err = params.Tweaker(privKey)
if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("error tweaking private key: %w",
+ err)
}
}
- // We need to produce a Schnorr signature for p2tr key spend addresses.
- if txscript.IsPayToTaproot(output.PkScript) {
- // We can now generate a valid witness which will allow us to
- // spend this output.
- witnessScript, err := txscript.TaprootWitnessSignature(
- tx, sigHashes, inputIndex, output.Value,
- output.PkScript, hashType, privKey,
+ // With the private key retrieved and tweaked, we can now generate the
+ // unlocking script.
+ return signAndAssembleScript(params, privKey, &scriptInfo)
+}
+
+// signAndAssembleScript is a helper function that performs the final signing
+// and script assembly for a given set of parameters and a private key.
+func signAndAssembleScript(params *UnlockingScriptParams,
+ privKey *btcec.PrivateKey,
+ scriptInfo *Script) (*UnlockingScript, error) {
+
+ // Dispatch to the correct signing logic based on the address type of
+ // the output.
+ switch scriptInfo.Addr.AddrType() {
+ // For Taproot key-path spends, we produce a Schnorr signature.
+ case waddrmgr.TaprootPubKey:
+ witness, err := txscript.TaprootWitnessSignature(
+ params.Tx, params.SigHashes, params.InputIndex,
+ params.Output.Value, params.Output.PkScript,
+ params.HashType, privKey,
)
if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("taproot witness error: %w", err)
}
- return witnessScript, nil, nil
+ return &UnlockingScript{
+ Witness: witness,
+ }, nil
+
+ // For SegWit v0 outputs, we'll generate a standard ECDSA signature.
+ case waddrmgr.WitnessPubKey, waddrmgr.NestedWitnessPubKey:
+ witness, err := txscript.WitnessSignature(
+ params.Tx, params.SigHashes, params.InputIndex,
+ params.Output.Value, scriptInfo.WitnessProgram,
+ params.HashType, privKey, true,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("witness sig error: %w", err)
+ }
+
+ return &UnlockingScript{
+ Witness: witness,
+ SigScript: scriptInfo.RedeemScript,
+ }, nil
+
+ // For legacy P2PKH outputs, we'll generate a signature script.
+ case waddrmgr.PubKeyHash:
+ sigScript, err := txscript.SignatureScript(
+ params.Tx, params.InputIndex, params.Output.PkScript,
+ params.HashType, privKey, true,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("sig script error: %w", err)
+ }
+
+ return &UnlockingScript{
+ SigScript: sigScript,
+ }, nil
+
+ // The following address types are not supported by this function.
+ case waddrmgr.Script, waddrmgr.RawPubKey, waddrmgr.WitnessScript,
+ waddrmgr.TaprootScript:
+ return nil, fmt.Errorf("%w: %v", ErrUnsupportedAddressType,
+ scriptInfo.Addr.AddrType())
+
+ default:
+ return nil, fmt.Errorf("%w: %v", ErrUnsupportedAddressType,
+ scriptInfo.Addr.AddrType())
}
+}
- // Generate a valid witness stack for the input.
- witnessScript, err := txscript.WitnessSignature(
- tx, sigHashes, inputIndex, output.Value, witnessProgram,
- hashType, privKey, true,
- )
+// ComputeRawSig generates a raw signature for a single transaction input. The
+// caller is responsible for assembling the final witness.
+func (w *Wallet) ComputeRawSig(_ context.Context, params *RawSigParams) (
+ RawSignature, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ // Get the managed address for the specified derivation path. This will
+ // be used to retrieve the private key.
+ managedAddr, err := w.fetchManagedPubKeyAddress(params.Path)
+ if err != nil {
+ return nil, err
+ }
+
+ // Get the private key for the address.
+ privKey, err := managedAddr.PrivKey()
+ if err != nil {
+ return nil, fmt.Errorf("cannot get private key: %w", err)
+ }
+ defer privKey.Zero()
+
+ // If a tweaker is provided, we'll use it to tweak the private key.
+ if params.Tweaker != nil {
+ privKey, err = params.Tweaker(privKey)
+ if err != nil {
+ return nil, fmt.Errorf("error tweaking private key: %w",
+ err)
+ }
+ }
+
+ // With the private key retrieved and tweaked, we can now delegate the
+ // actual signing to the version-specific details object.
+ rawSig, err := params.Details.Sign(params, privKey)
+ if err != nil {
+ return nil, fmt.Errorf("cannot sign transaction: %w", err)
+ }
+
+ return rawSig, nil
+}
+
+// DerivePrivKey derives a private key from a full BIP-32 derivation
+// path.
+//
+// DANGER: This method exports sensitive key material.
+func (w *Wallet) DerivePrivKey(_ context.Context, path BIP32Path) (
+ *btcec.PrivateKey, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ managedPubKeyAddr, err := w.fetchManagedPubKeyAddress(path)
+ if err != nil {
+ return nil, err
+ }
+
+ privKey, err := managedPubKeyAddr.PrivKey()
+ if err != nil {
+ return nil, fmt.Errorf("cannot get private key: %w", err)
+ }
+
+ return privKey, nil
+}
+
+// GetPrivKeyForAddress returns the private key for a given address.
+//
+// DANGER: This method exports sensitive key material.
+func (w *Wallet) GetPrivKeyForAddress(_ context.Context, a address.Address) (
+ *btcec.PrivateKey, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ return w.PrivKeyForAddress(a)
+}
+
+// PrivKeyForAddress looks up the associated private key for a P2PKH or P2PK
+// address.
+func (w *Wallet) PrivKeyForAddress(a address.Address) (
+ *btcec.PrivateKey, error) {
+
+ err := w.state.canSign()
+ if err != nil {
+ return nil, err
+ }
+
+ var privKey *btcec.PrivateKey
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+
+ addr, err := w.addrStore.Address(addrmgrNs, a)
+ if err != nil {
+ return fmt.Errorf("failed to get address: %w", err)
+ }
+
+ managedPubKeyAddr, ok := addr.(waddrmgr.ManagedPubKeyAddress)
+ if !ok {
+ return ErrNoAssocPrivateKey
+ }
+
+ privKey, err = managedPubKeyAddr.PrivKey()
+ if err != nil {
+ return fmt.Errorf("failed to get private key: %w", err)
+ }
+
+ return nil
+ })
if err != nil {
- return nil, nil, err
+ return nil, fmt.Errorf("failed to view database: %w", err)
}
- return witnessScript, sigScript, nil
+ return privKey, nil
}
diff --git a/wallet/signer_benchmark_test.go b/wallet/signer_benchmark_test.go
new file mode 100644
index 0000000000..b5e7f272c4
--- /dev/null
+++ b/wallet/signer_benchmark_test.go
@@ -0,0 +1,1460 @@
+package wallet
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkDerivePubKey benchmarks the DerivePubKey method across different
+// wallet sizes. The benchmark measures the performance of deriving a public
+// key from a BIP-32 path, which involves database lookups and cryptographic
+// operations.
+func BenchmarkDerivePubKey(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses linearGrowth to test how performance
+ // scales with the number of accounts in the wallet. Key
+ // derivation uses the account index in the BIP-32 path, so
+ // database lookup time should remain constant due to indexed
+ // lookups.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the key derivation's time complexity - it derives from
+ // an explicit path without address search.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the key derivation's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{
+ waddrmgr.KeyScopeBIP0084,
+ waddrmgr.KeyScopeBIP0086,
+ }
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Use a path from the middle of the account range
+ // for representative performance.
+ accountIndex := uint32(accountGrowth[i] / 2)
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.DerivePubKey(b.Context(), path)
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkECDH benchmarks the ECDH method across different wallet sizes.
+// The benchmark measures the performance of performing an ECDH operation
+// between a wallet key and a remote public key.
+func BenchmarkECDH(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses linearGrowth to test scaling with wallet
+ // size. ECDH derives the wallet's private key using the account
+ // index in the BIP-32 path for the scalar multiplication.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // the ECDH operation's time complexity. It uses an explicit
+ // path.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the cryptographic operation's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ // Generate a remote public key for ECDH.
+ remotePrivKey, err := btcec.NewPrivateKey()
+ require.NoError(b, err)
+
+ remotePubKey := remotePrivKey.PubKey()
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ accountIndex := uint32(accountGrowth[i] / 2)
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := w.ECDH(
+ b.Context(), path, remotePubKey,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkSignDigestECDSA benchmarks the SignDigest method for ECDSA
+// signatures across different wallet sizes. The benchmark measures the
+// performance of signing a digest with ECDSA.
+func BenchmarkSignDigestECDSA(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses linearGrowth to test scaling with wallet
+ // size. Signature operations derive keys using the account
+ // index in the BIP-32 path.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the signature generation's time complexity when using
+ // an explicit path.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the signature generation's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ // Create a test digest to sign.
+ digest := chainhash.HashB([]byte("test message"))
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ accountIndex := uint32(accountGrowth[i] / 2)
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ intent := &SignDigestIntent{
+ Digest: digest,
+ SigType: SigTypeECDSA,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := w.SignDigest(
+ b.Context(), path, intent,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+ }
+}
+
+// BenchmarkSignDigestECDSACompact benchmarks the SignDigest method for
+// compact ECDSA signatures across different wallet sizes.
+func BenchmarkSignDigestECDSACompact(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses linearGrowth to test scaling with wallet
+ // size. Signature operations derive keys using the account
+ // index in the BIP-32 path.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the signature generation's time complexity when using
+ // an explicit path.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the signature generation's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ digest := chainhash.DoubleHashB([]byte("test message"))
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ accountIndex := uint32(accountGrowth[i] / 2)
+
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ intent := &SignDigestIntent{
+ Digest: digest,
+ SigType: SigTypeECDSA,
+ CompactSig: true,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := w.SignDigest(
+ b.Context(), path, intent,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+ }
+}
+
+// BenchmarkSignDigestSchnorr benchmarks the SignDigest method for Schnorr
+// signatures across different wallet sizes. The benchmark measures the
+// performance of signing a digest with Schnorr signatures.
+func BenchmarkSignDigestSchnorr(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0086}
+ )
+
+ digest := chainhash.TaggedHash(
+ []byte("BIP0340/challenge"), []byte("test message"),
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ accountIndx := uint32(accountGrowth[i] / 2)
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndx,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ intent := &SignDigestIntent{
+ Digest: digest[:],
+ SigType: SigTypeSchnorr,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := w.SignDigest(
+ b.Context(), path, intent,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+ }
+}
+
+// BenchmarkComputeUnlockingScriptP2WKH benchmarks the ComputeUnlockingScript
+// method for P2WKH outputs across different wallet sizes and UTXO counts.
+func BenchmarkComputeUnlockingScriptP2WKH(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d/UTXOs-%0*d",
+ accountGrowthPadding, accountGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Get a test address and create a P2WKH output.
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ prevOut := &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+
+ // Create a spending transaction.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{},
+ Index: 0,
+ },
+ })
+ tx.AddTxOut(&wire.TxOut{
+ Value: 50000,
+ PkScript: pkScript,
+ })
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ unlockScript, err := bw.ComputeUnlockingScript(
+ b.Context(), params,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, unlockScript)
+ }
+ })
+ }
+}
+
+// BenchmarkComputeUnlockingScriptP2TR benchmarks the ComputeUnlockingScript
+// method for P2TR (Taproot) key-path spends across different wallet sizes.
+func BenchmarkComputeUnlockingScriptP2TR(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses linearGrowth to test scaling with wallet
+ // size. ComputeUnlockingScript derives the signing key from the
+ // output's address, which requires account lookup.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since we're testing with a
+ // single specific output address.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect signing a single input.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0086}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Get a test Taproot address.
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ prevOut := &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+
+ // Create a spending transaction.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{},
+ Index: 0,
+ },
+ })
+ tx.AddTxOut(&wire.TxOut{
+ Value: 50000,
+ PkScript: pkScript,
+ })
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashDefault,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ unlockScript, err := bw.ComputeUnlockingScript(
+ b.Context(), params,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, unlockScript)
+ }
+ })
+ }
+}
+
+// BenchmarkComputeRawSigSegwitV0 benchmarks the ComputeRawSig method for
+// SegWit v0 inputs.
+func BenchmarkComputeRawSigSegwitV0(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses constantGrowth to verify that wallet size
+ // doesn't affect performance. ComputeRawSig uses an explicit
+ // BIP-32 path, so performance should be constant regardless of
+ // account count.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the ComputeRawSig operation's time complexity - it
+ // uses an explicit BIP-32 path.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the ComputeRawSig operation's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Get a test address and create witness script.
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ witnessPubKeyHash := testAddr.ScriptAddress()
+ witnessScript, err := txscript.NewScriptBuilder().
+ AddOp(txscript.OP_DUP).
+ AddOp(txscript.OP_HASH160).
+ AddData(witnessPubKeyHash).
+ AddOp(txscript.OP_EQUALVERIFY).
+ AddOp(txscript.OP_CHECKSIG).
+ Script()
+ require.NoError(b, err)
+
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ prevOut := &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+
+ // Create a spending transaction.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{},
+ Index: 0,
+ },
+ })
+ tx.AddTxOut(&wire.TxOut{
+ Value: 50000,
+ PkScript: pkScript,
+ })
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ accountIndex := uint32(accountGrowth[i] / 2)
+ addressIndex := uint32(addressGrowth[i] / 2)
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: addressIndex,
+ },
+ }
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: SegwitV0SpendDetails{
+ WitnessScript: witnessScript,
+ },
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := bw.ComputeRawSig(
+ b.Context(), params,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+ }
+}
+
+// BenchmarkComputeRawSigTaproot benchmarks the ComputeRawSig method for
+// Taproot key-path spends. Since ComputeRawSig uses an explicit path and
+// doesn't search through accounts, wallet size shouldn't affect performance -
+// we use constantGrowth to verify performance remains constant regardless of
+// wallet size.
+func BenchmarkComputeRawSigTaproot(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 10
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0086}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Get a test Taproot address.
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ prevOut := &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+
+ // Create a spending transaction.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{},
+ Index: 0,
+ },
+ })
+ tx.AddTxOut(&wire.TxOut{
+ Value: 50000,
+ PkScript: pkScript,
+ })
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ accountIndex := uint32(accountGrowth[i] / 2)
+ addressIndex := uint32(addressGrowth[i] / 2)
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: addressIndex,
+ },
+ }
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashDefault,
+ Path: path,
+ Details: TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ },
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := bw.ComputeRawSig(
+ b.Context(), params,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+ }
+}
+
+// BenchmarkDerivePrivKey benchmarks the DerivePrivKey method (UnsafeSigner)
+// across different wallet sizes. This benchmark measures the performance of
+// deriving a private key from a BIP-32 path.
+func BenchmarkDerivePrivKey(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 10
+ )
+
+ var (
+ // accountGrowth uses linearGrowth to test scaling with wallet
+ // size. Signature operations derive keys using the account
+ // index in the BIP-32 path.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the signature generation's time complexity when using
+ // an explicit path.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the signature generation's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d", accountGrowthPadding,
+ accountGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ accountIndex := uint32(accountGrowth[i] / 2)
+
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: accountIndex,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ privKey, err := w.DerivePrivKey(
+ b.Context(), path,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, privKey)
+ privKey.Zero()
+ }
+ })
+ }
+}
+
+// BenchmarkGetPrivKeyForAddress benchmarks the GetPrivKeyForAddress method
+// (UnsafeSigner) across different wallet sizes and address counts.
+func BenchmarkGetPrivKeyForAddress(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 10
+ )
+
+ var (
+ // accountGrowth uses linearGrowth since the account is part of
+ // the address search space.
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // addressGrowth uses linearGrowth to test how address lookup
+ // performance scales. GetPrivKeyForAddress searches through
+ // the address manager to find the matching address, so
+ // performance scales with total address count.
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ // utxoGrowth uses constantGrowth since UTXO count doesn't
+ // affect the address lookup's time complexity.
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("Accounts-%0*d/Addresses-%0*d",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ privKey, err := bw.GetPrivKeyForAddress(
+ b.Context(), testAddr,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, privKey)
+ privKey.Zero()
+ }
+ })
+ }
+}
+
+// BenchmarkSignDigestComparisonECDSAvsSchnorr compares ECDSA and Schnorr
+// signature performance side-by-side. This benchmark helps understand the
+// performance characteristics of different signature algorithms.
+func BenchmarkSignDigestComparisonECDSAvsSchnorr(b *testing.B) {
+ const numAccounts = 5
+
+ scopes := []waddrmgr.KeyScope{
+ waddrmgr.KeyScopeBIP0084, // For ECDSA
+ waddrmgr.KeyScopeBIP0086, // For Schnorr
+ }
+
+ ecdsaDigest := chainhash.HashB([]byte("test message"))
+ schnorrDigest := chainhash.TaggedHash(
+ []byte("BIP0340/challenge"), []byte("test message"),
+ )
+
+ b.Run("ECDSA", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: []waddrmgr.KeyScope{scopes[0]},
+ numAccounts: numAccounts,
+ numAddresses: 5,
+ numWalletTxs: 0,
+ },
+ )
+
+ path := BIP32Path{
+ KeyScope: scopes[0],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ intent := &SignDigestIntent{
+ Digest: ecdsaDigest,
+ SigType: SigTypeECDSA,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := w.SignDigest(b.Context(), path, intent)
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+
+ b.Run("Schnorr", func(b *testing.B) {
+ w := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: []waddrmgr.KeyScope{scopes[1]},
+ numAccounts: numAccounts,
+ numAddresses: 5,
+ numWalletTxs: 0,
+ },
+ )
+
+ path := BIP32Path{
+ KeyScope: scopes[1],
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ intent := &SignDigestIntent{
+ Digest: schnorrDigest[:],
+ SigType: SigTypeSchnorr,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ sig, err := w.SignDigest(b.Context(), path, intent)
+ require.NoError(b, err)
+ require.NotNil(b, sig)
+ }
+ })
+}
+
+// BenchmarkMultiInputTransaction benchmarks signing a transaction with
+// multiple inputs.
+func BenchmarkMultiInputTransaction(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // Test with growing number of inputs using linear growth.
+ inputGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ inputGrowthPadding = decimalWidth(
+ inputGrowth[len(inputGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ numInputs := inputGrowth[i]
+
+ name := fmt.Sprintf("Inputs-%0*d", inputGrowthPadding,
+ numInputs)
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Get test addresses for outputs.
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ tx := wire.NewMsgTx(2)
+
+ // Create previous outputs for each input.
+ prevOuts := make([]*wire.TxOut, numInputs)
+ for j := range numInputs {
+ prevOuts[j] = &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{byte(j)},
+ Index: 0,
+ },
+ })
+ }
+
+ // Add a single output.
+ tx.AddTxOut(&wire.TxOut{
+ Value: int64(numInputs) * 100000,
+ PkScript: pkScript,
+ })
+
+ // Pre-compute sigHashes.
+ fetcher := txscript.NewMultiPrevOutFetcher(nil)
+ for j, prevOut := range prevOuts {
+ fetcher.AddPrevOut(wire.OutPoint{
+ Hash: chainhash.Hash{byte(j)},
+ Index: 0,
+ }, prevOut)
+ }
+
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ signMultipleInputs(
+ b, bw.Wallet, tx, prevOuts, sigHashes,
+ txscript.SigHashAll,
+ )
+ }
+ })
+ }
+}
+
+// BenchmarkComputeUnlockingScriptWithTweaker benchmarks the
+// ComputeUnlockingScript method with a custom private key tweaker function
+// across different numbers of inputs.
+func BenchmarkComputeUnlockingScriptWithTweaker(b *testing.B) {
+ const (
+ startGrowthIteration = 0
+ maxGrowthIteration = 5
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ // Test with growing number of inputs to see tweaker overhead
+ // with multiple inputs.
+ inputGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ inputGrowthPadding = decimalWidth(
+ inputGrowth[len(inputGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ // Create a tweaker function that adds a scalar to the key.
+ tweakScalar := new(btcec.ModNScalar)
+ tweakScalar.SetByteSlice([]byte{0x01, 0x02, 0x03})
+
+ // Without that ignore linting directive getting result 1 (error) is
+ // always nil. This is acceptable since it is used for convenient
+ // testing purposes.
+ //
+ //nolint:unparam
+ tweaker := func(privKey *btcec.PrivateKey) (*btcec.PrivateKey, error) {
+ // Add the tweak to the private key scalar.
+ var privKeyScalar btcec.ModNScalar
+ privKeyScalar.Set(&privKey.Key)
+ privKeyScalar.Add(tweakScalar)
+
+ return btcec.PrivKeyFromScalar(&privKeyScalar), nil
+ }
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ numInputs := inputGrowth[i]
+
+ name := fmt.Sprintf("Inputs-%0*d", inputGrowthPadding,
+ numInputs)
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Get a test address and create P2WKH outputs.
+ testAddr := getTestAddress(
+ b, bw.Wallet, accountGrowth[i],
+ )
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ // Create multiple previous outputs.
+ prevOuts := make([]*wire.TxOut, numInputs)
+ for j := range numInputs {
+ prevOuts[j] = &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+ }
+
+ // Create a spending transaction with multiple inputs.
+ tx := wire.NewMsgTx(2)
+ for j := range numInputs {
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{byte(j)},
+ Index: 0,
+ },
+ })
+ }
+
+ tx.AddTxOut(&wire.TxOut{
+ Value: int64(numInputs) * 50000,
+ PkScript: pkScript,
+ })
+
+ // Pre-compute sigHashes.
+ fetcher := txscript.NewMultiPrevOutFetcher(nil)
+ for j, prevOut := range prevOuts {
+ fetcher.AddPrevOut(wire.OutPoint{
+ Hash: chainhash.Hash{byte(j)},
+ Index: 0,
+ }, prevOut)
+ }
+
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ signMultipleInputsWithTweaker(
+ b, bw.Wallet, tx, prevOuts, sigHashes,
+ txscript.SigHashAll, tweaker,
+ )
+ }
+ })
+ }
+}
+
+// BenchmarkDifferentAddressTypes compares signing performance across
+// different address types (P2PKH, P2WKH, P2TR).
+func BenchmarkDifferentAddressTypes(b *testing.B) {
+ const (
+ numAccounts = 5
+ numAddresses = 10
+ )
+
+ testCases := []struct {
+ name string
+ scope waddrmgr.KeyScope
+ hashType txscript.SigHashType
+ }{
+ {
+ name: "P2PKH-Legacy",
+ scope: waddrmgr.KeyScopeBIP0044,
+ hashType: txscript.SigHashAll,
+ },
+ {
+ name: "P2WKH-SegWit",
+ scope: waddrmgr.KeyScopeBIP0084,
+ hashType: txscript.SigHashAll,
+ },
+ {
+ name: "P2TR-Taproot",
+ scope: waddrmgr.KeyScopeBIP0086,
+ hashType: txscript.SigHashDefault,
+ },
+ }
+
+ for _, tc := range testCases {
+ b.Run(tc.name, func(b *testing.B) {
+ scopes := []waddrmgr.KeyScope{tc.scope}
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: numAccounts,
+ numAddresses: numAddresses,
+ numWalletTxs: 0,
+ },
+ )
+
+ // Get a test address for this scope.
+ testAddr, err := bw.CurrentAddress(0, tc.scope)
+ if err != nil {
+ // Fallback to getting any address.
+ testAddr = getTestAddress(
+ b, bw.Wallet, numAccounts,
+ )
+ }
+
+ pkScript, err := txscript.PayToAddrScript(testAddr)
+ require.NoError(b, err)
+
+ prevOut := &wire.TxOut{
+ Value: 100000,
+ PkScript: pkScript,
+ }
+
+ // Create a spending transaction.
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(&wire.TxIn{
+ PreviousOutPoint: wire.OutPoint{
+ Hash: chainhash.Hash{},
+ Index: 0,
+ },
+ })
+ tx.AddTxOut(&wire.TxOut{
+ Value: 50000,
+ PkScript: pkScript,
+ })
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: tc.hashType,
+ }
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ unlockScript, err := bw.ComputeUnlockingScript(
+ b.Context(), params,
+ )
+ require.NoError(b, err)
+ require.NotNil(b, unlockScript)
+ }
+ })
+ }
+}
diff --git a/wallet/signer_test.go b/wallet/signer_test.go
index e55dc18add..8ed482daa0 100644
--- a/wallet/signer_test.go
+++ b/wallet/signer_test.go
@@ -5,105 +5,1935 @@
package wallet
import (
+ "encoding/hex"
+ "errors"
"testing"
- "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcec/v2/ecdsa"
+ "github.com/btcsuite/btcd/btcec/v2/schnorr"
+ "github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/txscript/v2"
"github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
)
-// TestComputeInputScript checks that the wallet can create the full
-// witness script for a witness output.
-func TestComputeInputScript(t *testing.T) {
+var (
+ // errManagerNotFound is returned when a scoped manager cannot be found.
+ errManagerNotFound = errors.New("manager not found")
+
+ // errDerivationFailed is returned when a key derivation fails.
+ errDerivationFailed = errors.New("derivation failed")
+
+ // errPrivKeyMock is a mock error for private key retrieval.
+ errPrivKeyMock = errors.New("privkey error")
+
+ // errTweakMock is a mock error for private key tweaking.
+ errTweakMock = errors.New("tweak error")
+
+ // errSignMock is a mock error for signing operations.
+ errSignMock = errors.New("sign error")
+)
+
+// TestDerivePubKeySuccess tests the successful derivation of a public key.
+func TestDerivePubKeySuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet with mocks, a test key, and a
+ // derivation path.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey := privKey.PubKey()
+
+ path := BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ // Set up the mock account manager and the mock address that will be
+ // returned by the derivation call.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, path.DerivationPath,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey).Once()
+
+ // Act: Derive the public key.
+ derivedKey, err := w.DerivePubKey(t.Context(), path)
+
+ // Assert: Check that the correct key is returned without error.
+ require.NoError(t, err)
+ require.True(t, pubKey.IsEqual(derivedKey))
+}
+
+// TestDerivePubKeyFetchManagerFails tests the failure case where the scoped
+// key manager cannot be fetched.
+func TestDerivePubKeyFetchManagerFails(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet and a test path. Configure the mock
+ // addrStore to return an error when fetching the key manager.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil), errManagerNotFound).Once()
+
+ // Act: Attempt to derive the public key.
+ _, err := w.DerivePubKey(t.Context(), path)
+
+ // Assert: Check that the error is propagated correctly.
+ require.ErrorIs(t, err, errManagerNotFound)
+ mocks.addrStore.AssertExpectations(t)
+}
+
+// TestDerivePubKeyDeriveFails tests the failure case where the key derivation
+// from the path fails.
+func TestDerivePubKeyDeriveFails(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, and a test path. Configure the
+ // mock account manager to return an error on derivation.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, path.DerivationPath,
+ ).Return((*mockManagedPubKeyAddr)(nil), errDerivationFailed).Once()
+
+ // Act: Attempt to derive the public key.
+ _, err := w.DerivePubKey(t.Context(), path)
+
+ // Assert: Check that the error is propagated correctly.
+ require.ErrorIs(t, err, errDerivationFailed)
+}
+
+// TestDerivePubKeyNotPubKeyAddr tests the failure case where the derived
+// address is not a public key address.
+func TestDerivePubKeyNotPubKeyAddr(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet and mocks. Configure the mock derivation
+ // to return a managed address that is NOT a ManagedPubKeyAddress.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+
+ // We need a valid address for the error message.
+ addr, err := address.NewAddressWitnessPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything,
+ ).Return(mocks.addr, nil).Once()
+ mocks.addr.On("Address").Return(addr).Once()
+
+ // Act: Attempt to derive the public key.
+ _, err = w.DerivePubKey(t.Context(), path)
+
+ // Assert: Check that the specific ErrNotPubKeyAddress is returned.
+ require.ErrorIs(t, err, ErrNotPubKeyAddress)
+ require.ErrorContains(t, err, "addr "+addr.String())
+}
+
+// TestECDHSuccess tests the successful ECDH key exchange.
+func TestECDHSuccess(t *testing.T) {
t.Parallel()
+ // Arrange: Set up the wallet, mocks, and test keys.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Use a hardcoded private key for deterministic test results.
+ privKey, _ := deterministicPrivKey(t)
+
+ remoteKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ remotePubKey := remoteKey.PubKey()
+
+ path := BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, path.DerivationPath,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Act: Perform the ECDH operation.
+ sharedSecret, err := w.ECDH(t.Context(), path, remotePubKey)
+
+ // Assert: Check that the correct shared secret is returned.
+ require.NoError(t, err)
+
+ // Calculate the expected secret independently to verify.
+ expectedSecret := btcec.GenerateSharedSecret(privKey, remotePubKey)
+
+ var expectedSecretArray [32]byte
+ copy(expectedSecretArray[:], expectedSecret)
+
+ require.Equal(t, expectedSecretArray, sharedSecret)
+
+ // Finally, assert that the private key is zeroed out.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestECDHFails tests the failure case where the key derivation fails during
+// an ECDH operation.
+func TestECDHFails(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet and configure the mock addrStore to return
+ // an error.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+
+ remoteKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ remotePubKey := remoteKey.PubKey()
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil), errDerivationFailed).Once()
+
+ // Act: Attempt to perform the ECDH operation.
+ _, err = w.ECDH(t.Context(), path, remotePubKey)
+
+ // Assert: Check that the error is propagated correctly.
+ require.ErrorIs(t, err, errDerivationFailed)
+}
+
+// deterministicPrivKey is a helper function that returns a deterministic
+// private and public key pair for testing purposes.
+func deterministicPrivKey(t *testing.T) (*btcec.PrivateKey, *btcec.PublicKey) {
+ t.Helper()
+
+ pkBytes, err := hex.DecodeString("22a47fa09a223f2aa079edf85a7c2d4f87" +
+ "20ee63e502ee2869afab7de234b80c")
+ require.NoError(t, err)
+
+ privKey, pubKey := btcec.PrivKeyFromBytes(pkBytes)
+
+ return privKey, pubKey
+}
+
+// TestSignDigest tests the signing of a message digest with different signature
+// types.
+func TestSignDigest(t *testing.T) {
+ t.Parallel()
+
+ // We'll use a common set of parameters for all signing test cases to
+ // ensure the only variable is the signing intent itself.
+ privKey, pubKey := deterministicPrivKey(t)
+ path := BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+ msg := []byte("test message")
+ msgHash := chainhash.HashB(msg)
+ msgDoubleHash := chainhash.DoubleHashB(msg)
+ tag := []byte("test tag")
+ taggedHash := chainhash.TaggedHash(tag, msg)
+
testCases := []struct {
- name string
- scope waddrmgr.KeyScope
- expectedScriptLen int
- }{{
- name: "BIP084 P2WKH",
- scope: waddrmgr.KeyScopeBIP0084,
- expectedScriptLen: 0,
- }, {
- name: "BIP049 nested P2WKH",
- scope: waddrmgr.KeyScopeBIP0049Plus,
- expectedScriptLen: 23,
- }}
-
- w, cleanup := testWallet(t)
- defer cleanup()
+ // name is the name of the test case.
+ name string
+
+ // intent is the signing intent to use for the test.
+ intent *SignDigestIntent
+
+ // verify is a function that verifies the signature produced by
+ // the signing intent.
+ verify func(t *testing.T, sig Signature,
+ pubKey *btcec.PublicKey)
+ }{
+ {
+ name: "ECDSA success",
+ intent: &SignDigestIntent{
+ Digest: msgHash,
+ SigType: SigTypeECDSA,
+ CompactSig: false,
+ },
+ verify: func(t *testing.T, sig Signature,
+ pubKey *btcec.PublicKey) {
+
+ t.Helper()
+
+ ecdsaSig, ok := sig.(ECDSASignature)
+ require.True(t, ok, "expected ECDSASignature")
+ require.True(
+ t, ecdsaSig.Verify(msgHash, pubKey),
+ "signature invalid",
+ )
+ },
+ },
+ {
+ name: "ECDSA compact success",
+ intent: &SignDigestIntent{
+ Digest: msgDoubleHash,
+ SigType: SigTypeECDSA,
+ CompactSig: true,
+ },
+ verify: func(t *testing.T, sig Signature,
+ pubKey *btcec.PublicKey) {
+
+ t.Helper()
+
+ compactSig, ok := sig.(CompactSignature)
+ require.True(t, ok, "expected CompactSignature")
+ recoveredKey, _, err := ecdsa.RecoverCompact(
+ compactSig, msgDoubleHash,
+ )
+ require.NoError(t, err)
+ require.True(
+ t, recoveredKey.IsEqual(pubKey),
+ "recovered key mismatch",
+ )
+ },
+ },
+ {
+ name: "Schnorr success",
+ intent: &SignDigestIntent{
+ Digest: taggedHash[:],
+ SigType: SigTypeSchnorr,
+ },
+ verify: func(t *testing.T, sig Signature,
+ pubKey *btcec.PublicKey) {
+
+ t.Helper()
+
+ schnorrSig, ok := sig.(SchnorrSignature)
+ require.True(t, ok, "expected SchnorrSignature")
+
+ require.True(t,
+ schnorrSig.Verify(
+ taggedHash[:], pubKey,
+ ),
+ "signature invalid",
+ )
+ },
+ },
+ {
+ name: "Schnorr success with tweak",
+ intent: &SignDigestIntent{
+ Digest: msgHash,
+ SigType: SigTypeSchnorr,
+ TaprootTweak: []byte("test tweak"),
+ },
+ verify: func(t *testing.T, sig Signature,
+ pubKey *btcec.PublicKey) {
+
+ t.Helper()
+
+ schnorrSig, ok := sig.(SchnorrSignature)
+ require.True(t, ok, "expected SchnorrSignature")
+
+ // Calculate expected tweaked key and hash
+ tweak := []byte("test tweak")
+ tweakedKey := txscript.TweakTaprootPrivKey(
+ *privKey, tweak,
+ )
+ tweakedPub := tweakedKey.PubKey()
+
+ require.True(t,
+ schnorrSig.Verify(msgHash, tweakedPub),
+ "signature invalid for tweaked key",
+ )
+ },
+ },
+ }
for _, tc := range testCases {
- tc := tc
t.Run(tc.name, func(t *testing.T) {
- runTestCase(t, w, tc.scope, tc.expectedScriptLen)
+ t.Parallel()
+
+ // Arrange: Set up a mock wallet that will return our
+ // deterministic private key for the specified
+ // derivation path. This allows us to test the signing
+ // logic in isolation.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Configure the full mock chain to return the test
+ // private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will
+ // zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(
+ privKey.Serialize(),
+ )
+
+ mocks.addrStore.On(
+ "FetchScopedKeyManager", path.KeyScope,
+ ).Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything,
+ path.DerivationPath,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(
+ privKeyCopy, nil,
+ ).Once()
+
+ // Act: Attempt to sign the message with the wallet.
+ sig, err := w.SignDigest(t.Context(), path, tc.intent)
+
+ // Assert: Verify that the signature was created
+ // successfully and is valid for the given public key.
+ // We also assert that the private key was cleared from
+ // memory after the operation.
+ require.NoError(t, err)
+ tc.verify(t, sig, pubKey)
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
})
}
}
-func runTestCase(t *testing.T, w *Wallet, scope waddrmgr.KeyScope,
- scriptLen int) {
+// TestSignDigestFail tests failure modes of SignDigest.
+func TestSignDigestFail(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+
+ digest := make([]byte, 32)
+ intent := &SignDigestIntent{Digest: digest}
+
+ // Test Case 1: Fetching the key manager fails.
+ // We expect an `errManagerNotFound` error to be returned.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil), errManagerNotFound).Once()
+
+ _, err := w.SignDigest(t.Context(), path, intent)
+ require.ErrorIs(t, err, errManagerNotFound)
+
+ // Test Case 2: Obtaining the private key for signing fails.
+ // We expect a `privkey error` to be returned.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return((*btcec.PrivateKey)(nil),
+ errPrivKeyMock).Once()
- // Create an address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, scope)
- if err != nil {
- t.Fatalf("unable to get current address: %v", addr)
+ _, err = w.SignDigest(t.Context(), path, intent)
+ require.ErrorContains(t, err, "privkey error")
+}
+
+// TestValidateSignDigestIntent tests the validation logic for SignDigestIntent.
+func TestValidateSignDigestIntent(t *testing.T) {
+ t.Parallel()
+
+ validDigest := make([]byte, 32)
+ invalidDigest := make([]byte, 31)
+
+ testCases := []struct {
+ name string
+ intent *SignDigestIntent
+ wantErr error
+ }{
+ {
+ // A valid ECDSA intent with a 32-byte digest and no
+ // restricted fields should pass validation.
+ name: "valid ECDSA",
+ intent: &SignDigestIntent{
+ Digest: validDigest,
+ SigType: SigTypeECDSA,
+ },
+ wantErr: nil,
+ },
+ {
+ // A valid Schnorr intent with a 32-byte digest and no
+ // restricted fields should pass validation.
+ name: "valid Schnorr",
+ intent: &SignDigestIntent{
+ Digest: validDigest,
+ SigType: SigTypeSchnorr,
+ },
+ wantErr: nil,
+ },
+ {
+ // If the digest length is not 32 bytes, we expect an
+ // ErrInvalidDigestSize error.
+ name: "invalid digest length",
+ intent: &SignDigestIntent{
+ Digest: invalidDigest,
+ SigType: SigTypeECDSA,
+ },
+ wantErr: ErrInvalidDigestSize,
+ },
+ {
+ // If an ECDSA intent provides a Taproot Tweak, we
+ // expect an ErrInvalidSignParam error as tweaks are
+ // Schnorr-specific.
+ name: "ECDSA with Taproot Tweak",
+ intent: &SignDigestIntent{
+ Digest: validDigest,
+ SigType: SigTypeECDSA,
+ TaprootTweak: []byte("tweak"),
+ },
+ wantErr: ErrInvalidSignParam,
+ },
+ {
+ // If a Schnorr intent requests a Compact Signature, we
+ // expect an ErrInvalidSignParam error as compact sigs
+ // are ECDSA-specific.
+ name: "Schnorr with CompactSig",
+ intent: &SignDigestIntent{
+ Digest: validDigest,
+ SigType: SigTypeSchnorr,
+ CompactSig: true,
+ },
+ wantErr: ErrInvalidSignParam,
+ },
}
- p2shAddr, err := txscript.PayToAddrScript(addr)
- if err != nil {
- t.Fatalf("unable to convert wallet address to p2sh: %v", err)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := validateSignDigestIntent(tc.intent)
+ require.ErrorIs(t, err, tc.wantErr)
+ })
+ }
+}
+
+// TestComputeUnlockingScriptP2PKH tests that the wallet can generate a valid
+// unlocking script for a P2PKH output.
+func TestComputeUnlockingScriptP2PKH(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, keys, and a dummy transaction that will
+ // be used to spend the P2PKH output.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a P2PKH address and the corresponding previous output script.
+ // This is the output we want to create an unlocking script for.
+ addr, err := address.NewAddressPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()),
+ w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // The wallet needs to be able to find the private key for the given
+ // address. We mock the address store to return a mock address that,
+ // when queried, will provide the private key for signing. This
+ // simulates a real scenario where the wallet's address manager would
+ // fetch the key from the database.
+ mocks.addrStore.On("Address",
+ mock.Anything, addr,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.PubKeyHash).Twice()
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil)
+
+ // Act: With the setup complete, we can now ask the wallet to compute
+ // the unlocking script.
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
}
+ script, err := w.ComputeUnlockingScript(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: The computed script should be a valid unlocking script for
+ // the P2PKH output. We verify this by creating a new script engine
+ // and executing it with the generated script. A successful execution
+ // proves the script is correct.
+ require.NotNil(t, script.SigScript)
+ require.Nil(t, script.Witness)
+ tx.TxIn[0].SignatureScript = script.SigScript
+
+ vm, err := txscript.NewEngine(
+ prevOut.PkScript, tx, 0, txscript.StandardVerifyFlags, nil,
+ sigHashes, prevOut.Value, fetcher,
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "script execution failed")
- // Add an output paying to the wallet's address to the database.
- utxOut := wire.NewTxOut(100000, p2shAddr)
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{utxOut},
+ // Finally, we ensure that the private key was not mutated during the
+ // signing process.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeUnlockingScriptP2WKH tests that the wallet can generate a valid
+// unlocking script for a P2WKH output.
+func TestComputeUnlockingScriptP2WKH(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, keys, and a dummy transaction that will
+ // be used to spend the P2WKH output.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a P2WKH address and the corresponding previous output script.
+ addr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()),
+ w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // The wallet needs to be able to find the private key for the given
+ // address. We mock the address store to return a mock address that,
+ // when queried, will provide the private key for signing.
+ mocks.addrStore.On("Address",
+ mock.Anything, addr,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.WitnessPubKey).Twice()
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil)
+
+ // Act: With the setup complete, we can now ask the wallet to compute
+ // the unlocking script.
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
}
- addUtxo(t, w, incomingTx)
+ script, err := w.ComputeUnlockingScript(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: The computed script should be a valid unlocking script. For
+ // P2WKH, this means a nil SigScript and a non-nil Witness. We verify
+ // this by creating a new script engine and executing it.
+ require.Nil(t, script.SigScript)
+ require.NotNil(t, script.Witness)
+ tx.TxIn[0].Witness = script.Witness
+
+ vm, err := txscript.NewEngine(
+ prevOut.PkScript, tx, 0, txscript.StandardVerifyFlags, nil,
+ sigHashes, prevOut.Value, fetcher,
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "script execution failed")
+
+ // Finally, we ensure that the private key was not mutated during the
+ // signing process.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeUnlockingScriptNP2WKH tests that the wallet can generate a valid
+// unlocking script for a nested P2WKH output.
+func TestComputeUnlockingScriptNP2WKH(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, keys, and a dummy transaction.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a NP2WKH address. This is a P2WKH output nested within a
+ // P2SH output. This is done by creating the witness program first,
+ // and then using its hash in a P2SH script.
+ p2sh, err := txscript.NewScriptBuilder().
+ AddOp(txscript.OP_0).
+ AddData(address.Hash160(pubKey.SerializeCompressed())).
+ Script()
+ require.NoError(t, err)
+ addr, err := address.NewAddressScriptHash(p2sh, w.cfg.ChainParams)
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
- // Create a transaction that spends the UTXO created above and spends to
- // the same address again.
- prevOut := wire.OutPoint{
- Hash: incomingTx.TxHash(),
- Index: 0,
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // The wallet needs to be able to find the private key for the given
+ // address. We mock the address store to return a mock address that,
+ // when queried, will provide the private key for signing. For NP2WKH,
+ // the wallet also needs the public key to reconstruct the witness
+ // program, so we mock that as well.
+ mocks.addrStore.On("Address",
+ mock.Anything, addr,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("AddrType").Return(
+ waddrmgr.NestedWitnessPubKey).Twice()
+ mocks.pubKeyAddr.On("PubKey").Return(pubKey)
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil)
+
+ // Act: With the setup complete, we can now ask the wallet to compute
+ // the unlocking script.
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
}
- outgoingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{
- PreviousOutPoint: prevOut,
- }},
- TxOut: []*wire.TxOut{utxOut},
+ script, err := w.ComputeUnlockingScript(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: The computed script should be a valid unlocking script. For
+ // NP2WKH, this means both a non-nil SigScript (containing the redeem
+ // script) and a non-nil Witness. We verify this by creating a new
+ // script engine and executing it.
+ require.NotNil(t, script.SigScript)
+ require.NotNil(t, script.Witness)
+ tx.TxIn[0].SignatureScript = script.SigScript
+ tx.TxIn[0].Witness = script.Witness
+
+ vm, err := txscript.NewEngine(
+ prevOut.PkScript, tx, 0, txscript.StandardVerifyFlags, nil,
+ sigHashes, prevOut.Value, fetcher,
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "script execution failed")
+
+ // Finally, we ensure that the private key was not mutated during the
+ // signing process.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeUnlockingScriptP2TR tests that the wallet can generate a valid
+// unlocking script for a P2TR key-path spend.
+func TestComputeUnlockingScriptP2TR(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, keys, and a dummy transaction.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a P2TR address for a key-path spend. This involves computing
+ // the taproot output key from the internal public key.
+ addr, err := address.NewAddressTaproot(
+ schnorr.SerializePubKey(
+ txscript.ComputeTaprootOutputKey(pubKey, nil),
+ ), w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // The wallet needs to be able to find the private key for the given
+ // address. We mock the address store to return a mock address that,
+ // when queried, will provide the private key for signing.
+ mocks.addrStore.On("Address",
+ mock.Anything, addr,
+ ).Return(mocks.pubKeyAddr, nil)
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.TaprootPubKey).Twice()
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil)
+
+ // Act: With the setup complete, we can now ask the wallet to compute
+ // the unlocking script. For Taproot, we must use a multi-output
+ // fetcher, as the sighash calculation (specifically with
+ // SigHashDefault) requires access to all previous outputs being spent
+ // in the transaction.
+ fetcher := txscript.NewMultiPrevOutFetcher(
+ map[wire.OutPoint]*wire.TxOut{{Index: 0}: prevOut},
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashDefault,
}
+ script, err := w.ComputeUnlockingScript(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: The computed script should be a valid unlocking script. For a
+ // P2TR key-path spend, this means a nil SigScript and a non-nil
+ // Witness containing just the Schnorr signature. We verify this by
+ // creating a new script engine and executing it.
+ require.Nil(t, script.SigScript)
+ require.NotNil(t, script.Witness)
+ tx.TxIn[0].Witness = script.Witness
+
+ vm, err := txscript.NewEngine(
+ prevOut.PkScript, tx, 0, txscript.StandardVerifyFlags, nil,
+ sigHashes, prevOut.Value, fetcher,
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "script execution failed")
+
+ // Finally, we ensure that the private key was not mutated during the
+ // signing process.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeUnlockingScriptFail_ScriptForOutput tests failure when
+// ScriptForOutput returns an error.
+func TestComputeUnlockingScriptFail_ScriptForOutput(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up keys, address, and transaction.
+ _, pubKey := deterministicPrivKey(t)
+ addr, err := address.NewAddressPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Create fresh mutable state.
+ prevOut, tx := createDummyTestTx(pkScript)
fetcher := txscript.NewCannedPrevOutputFetcher(
- utxOut.PkScript, utxOut.Value,
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ // Arrange: Set up the wallet and mocks.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Mock the address store to return an error.
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return((*mockManagedAddress)(nil), errManagerNotFound).Once()
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ }
+
+ // Act: Attempt to compute the unlocking script.
+ _, err = w.ComputeUnlockingScript(t.Context(), params)
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "unable to get address info")
+}
+
+// TestComputeUnlockingScriptFail_PrivKey tests failure when private key
+// retrieval fails.
+func TestComputeUnlockingScriptFail_PrivKey(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up keys, address, and transaction.
+ _, pubKey := deterministicPrivKey(t)
+ addr, err := address.NewAddressPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
)
- sigHashes := txscript.NewTxSigHashes(outgoingTx, fetcher)
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
- // Compute the input script to spend the UTXO now.
- witness, script, err := w.ComputeInputScript(
- outgoingTx, utxOut, 0, sigHashes, txscript.SigHashAll, nil,
+ // Create fresh mutable state.
+ prevOut, tx := createDummyTestTx(pkScript)
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
)
- if err != nil {
- t.Fatalf("error computing input script: %v", err)
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ // Arrange: Set up the wallet and mocks.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Mock address store and managed address.
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.PubKeyHash)
+
+ // Mock private key retrieval failure.
+ mocks.pubKeyAddr.On("PrivKey").Return((*btcec.PrivateKey)(nil),
+ errPrivKeyMock).Once()
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
}
- if len(script) != scriptLen {
- t.Fatalf("unexpected script length, got %d wanted %d",
- len(script), scriptLen)
+
+ // Act: Attempt to compute the unlocking script.
+ _, err = w.ComputeUnlockingScript(t.Context(), params)
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "privkey error")
+}
+
+// TestComputeUnlockingScriptFail_Tweak tests failure when the tweaker fails.
+func TestComputeUnlockingScriptFail_Tweak(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up keys, address, and transaction.
+ privKey, pubKey := deterministicPrivKey(t)
+ addr, err := address.NewAddressPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Create fresh mutable state.
+ prevOut, tx := createDummyTestTx(pkScript)
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ // Arrange: Set up the wallet and mocks.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Mock address store and managed address.
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.PubKeyHash)
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Define failing tweaker.
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Tweaker: func(*btcec.PrivateKey) (*btcec.PrivateKey, error) {
+ return nil, errTweakMock
+ },
}
- if len(witness) != 2 {
- t.Fatalf("unexpected witness stack length, got %d, wanted %d",
- len(witness), 2)
+
+ // Act: Attempt to compute the unlocking script.
+ _, err = w.ComputeUnlockingScript(t.Context(), params)
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "tweak error")
+}
+
+// TestComputeUnlockingScriptFail_UnsupportedAddr tests failure when the
+// address type is unsupported.
+func TestComputeUnlockingScriptFail_UnsupportedAddr(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up keys, address, and transaction.
+ privKey, pubKey := deterministicPrivKey(t)
+ addr, err := address.NewAddressPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Create fresh mutable state.
+ prevOut, tx := createDummyTestTx(pkScript)
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ // Arrange: Set up the wallet and mocks.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ // Mock address store and managed address.
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Mock unsupported address type.
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.RawPubKey)
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
}
- // Finally verify that the created witness is valid.
- outgoingTx.TxIn[0].Witness = witness
- outgoingTx.TxIn[0].SignatureScript = script
- err = validateMsgTx(
- outgoingTx, [][]byte{utxOut.PkScript}, []btcutil.Amount{100000},
+ // Act: Attempt to compute the unlocking script.
+ _, err = w.ComputeUnlockingScript(t.Context(), params)
+
+ // Assert: Verify error.
+ require.ErrorIs(t, err, ErrUnsupportedAddressType)
+}
+
+// TestComputeUnlockingScriptUnknownAddrType tests the default case in
+// signAndAssembleScript by using an address with an unknown type.
+func TestComputeUnlockingScriptUnknownAddrType(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, keys, and transaction.
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ privKey, pubKey := deterministicPrivKey(t)
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ addr, err := address.NewAddressPubKeyHash(
+ address.Hash160(pubKey.SerializeCompressed()),
+ w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // Mock address lookup to return a valid managed address.
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ // Mock private key retrieval to succeed.
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Mock the address type to return an unknown type (e.g. 99) that falls
+ // through the switch statement in signAndAssembleScript.
+ mocks.pubKeyAddr.On("AddrType").Return(waddrmgr.AddressType(99))
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(pkScript, 10000)
+
+ params := &UnlockingScriptParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: txscript.NewTxSigHashes(tx, fetcher),
+ HashType: txscript.SigHashAll,
+ }
+
+ // Act: Attempt to compute the unlocking script.
+ _, err = w.ComputeUnlockingScript(t.Context(), params)
+
+ // Assert: Verify that the unsupported address type error is returned.
+ require.ErrorIs(t, err, ErrUnsupportedAddressType)
+}
+
+// createDummyTestTx creates a dummy transaction for testing purposes.
+func createDummyTestTx(pkScript []byte) (*wire.TxOut, *wire.MsgTx) {
+ prevOut := wire.NewTxOut(100000, pkScript)
+ tx := wire.NewMsgTx(2)
+ tx.AddTxIn(wire.NewTxIn(&wire.OutPoint{Index: 0}, nil, nil))
+ tx.AddTxOut(wire.NewTxOut(90000, nil))
+
+ return prevOut, tx
+}
+
+// TestComputeRawSigLegacyP2PKH tests the successful signing of a legacy P2PKH
+// input.
+func TestComputeRawSigLegacyP2PKH(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, and a deterministic private key.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a P2PKH address from the public key.
+ pubKeyHash := address.Hash160(pubKey.SerializeCompressed())
+ addr, err := address.NewAddressPubKeyHash(
+ pubKeyHash, w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Create a previous output and a transaction to spend it.
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Create the raw signature parameters.
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: LegacySpendDetails{},
+ }
+
+ // Act: Compute the raw signature.
+ rawSig, err := w.ComputeRawSig(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: Verify that the signature is valid.
+ sigScript, err := txscript.NewScriptBuilder().
+ AddData(rawSig).
+ AddData(pubKey.SerializeCompressed()).
+ Script()
+ require.NoError(t, err)
+
+ tx.TxIn[0].SignatureScript = sigScript
+
+ // The signature is valid if the script engine executes without error.
+ vm, err := txscript.NewEngine(
+ prevOut.PkScript, tx, 0, txscript.StandardVerifyFlags, nil,
+ sigHashes, prevOut.Value, txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ ),
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "signature verification failed")
+
+ // Finally, assert that the private key is zeroed out.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeRawSigLegacyP2SH tests the signing of a legacy P2SH input.
+func TestComputeRawSigLegacyP2SH(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet with mocks and a deterministic private
+ // key for testing.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ // Create a P2SH redeem script. This involves pushing the public key
+ // and the CHECKSIG opcode.
+ redeemScript, err := txscript.NewScriptBuilder().
+ AddOp(txscript.OP_DATA_33).
+ AddData(pubKey.SerializeCompressed()).
+ AddOp(txscript.OP_CHECKSIG).
+ Script()
+ require.NoError(t, err)
+
+ // Create the P2SH address corresponding to the redeem script hash.
+ addr, err := address.NewAddressScriptHash(
+ redeemScript, w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Create the Pay-To-Addr script (P2SH script) which will be the
+ // pkScript of the previous output.
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ // Create a dummy transaction and a previous output to spend.
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // Configure the address manager mock to return the correct key manager
+ // and address information. P2SH addresses use BIP0049 derivation scope.
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0049Plus}
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Prepare the inputs for the signing operation.
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
)
- if err != nil {
- t.Fatalf("error validating tx: %v", err)
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: LegacySpendDetails{
+ RedeemScript: redeemScript,
+ },
}
+
+ // Act: Compute the raw signature using the wallet.
+ rawSig, err := w.ComputeRawSig(t.Context(), params)
+
+ // Assert: Verify that no error occurred and a signature was generated.
+ require.NoError(t, err)
+ require.NotEmpty(t, rawSig)
+}
+
+// TestComputeRawSigSegwitV0 tests the successful signing of a SegWit v0 P2WKH
+// input.
+func TestComputeRawSigSegwitV0(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, and a deterministic private key.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a P2WKH address from the public key.
+ pubKeyHash := address.Hash160(pubKey.SerializeCompressed())
+ addr, err := address.NewAddressWitnessPubKeyHash(
+ pubKeyHash, w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Create a previous output and a transaction to spend it.
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Create the raw signature parameters.
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+ witnessScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: SegwitV0SpendDetails{
+ WitnessScript: witnessScript,
+ },
+ }
+
+ // Act: Compute the raw signature.
+ rawSig, err := w.ComputeRawSig(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: Verify that the signature is valid.
+ // We need to append the sighash type to the raw signature.
+ rawSig = append(rawSig, byte(txscript.SigHashAll))
+ tx.TxIn[0].Witness = wire.TxWitness{
+ rawSig, pubKey.SerializeCompressed(),
+ }
+
+ // The signature is valid if the script engine executes without error.
+ vm, err := txscript.NewEngine(
+ prevOut.PkScript, tx, 0, txscript.StandardVerifyFlags, nil,
+ sigHashes, prevOut.Value, txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ ),
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "signature verification failed")
+
+ // Finally, assert that the private key is zeroed out.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeRawSigTaprootKeySpendPath tests the successful signing of a
+// Taproot P2TR input using the key-path spend.
+func TestComputeRawSigTaprootKeySpendPath(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, and a deterministic private key.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, internalKey := deterministicPrivKey(t)
+
+ // Create a P2TR address from the public key.
+ addr, err := address.NewAddressTaproot(
+ schnorr.SerializePubKey(
+ txscript.ComputeTaprootOutputKey(internalKey, nil),
+ ), w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Create a previous output and a transaction to spend it.
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // Configure the full mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the ECDH method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0086}
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Create the raw signature parameters.
+ fetcher := txscript.NewMultiPrevOutFetcher(
+ map[wire.OutPoint]*wire.TxOut{
+ {Index: 0}: prevOut,
+ },
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashDefault,
+ Path: path,
+ Details: TaprootSpendDetails{
+ SpendPath: KeyPathSpend,
+ },
+ }
+
+ // Act: Compute the raw signature.
+ rawSig, err := w.ComputeRawSig(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: Verify that the signature is valid.
+ tx.TxIn[0].Witness = wire.TxWitness{rawSig}
+ vm, err := txscript.NewEngine(
+ pkScript, tx, 0, txscript.StandardVerifyFlags, nil, sigHashes,
+ prevOut.Value, txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ ),
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "signature verification failed")
+
+ // Finally, assert that the private key is zeroed out.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeRawSigTaprootScriptPath tests the successful signing of a Taproot
+// P2TR input using the script-path spend.
+func TestComputeRawSigTaprootScriptPath(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, and a deterministic private key.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, internalKey := deterministicPrivKey(t)
+
+ // Create a script to spend.
+ script, err := txscript.NewScriptBuilder().
+ AddData(schnorr.SerializePubKey(internalKey)).
+ AddOp(txscript.OP_CHECKSIG).
+ Script()
+ require.NoError(t, err)
+
+ leaf := txscript.NewBaseTapLeaf(script)
+ tapScriptTree := txscript.AssembleTaprootScriptTree(leaf)
+ rootHash := tapScriptTree.RootNode.TapHash()
+ outputKey := txscript.ComputeTaprootOutputKey(internalKey, rootHash[:])
+
+ // Create a P2TR address from the output key.
+ addr, err := address.NewAddressTaproot(
+ schnorr.SerializePubKey(outputKey), w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Create a previous output and a transaction to spend it.
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ prevOut, tx := createDummyTestTx(pkScript)
+
+ // Configure the full mock chain to return the test private key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0086}
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Create the raw signature parameters.
+ fetcher := txscript.NewMultiPrevOutFetcher(
+ map[wire.OutPoint]*wire.TxOut{
+ {Index: 0}: prevOut,
+ },
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ params := &RawSigParams{
+ Tx: tx,
+ InputIndex: 0,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashDefault,
+ Path: path,
+ Details: TaprootSpendDetails{
+ SpendPath: ScriptPathSpend,
+ WitnessScript: script,
+ },
+ }
+
+ // Act: Compute the raw signature.
+ rawSig, err := w.ComputeRawSig(t.Context(), params)
+ require.NoError(t, err)
+
+ // Assert: Verify that the signature is valid.
+ // For script path, we need the control block.
+ ctrlBlock := tapScriptTree.LeafMerkleProofs[0].ToControlBlock(
+ internalKey,
+ )
+ ctrlBlockBytes, err := ctrlBlock.ToBytes()
+ require.NoError(t, err)
+
+ tx.TxIn[0].Witness = wire.TxWitness{
+ rawSig, script, ctrlBlockBytes,
+ }
+ vm, err := txscript.NewEngine(
+ pkScript, tx, 0, txscript.StandardVerifyFlags, nil, sigHashes,
+ prevOut.Value, txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ ),
+ )
+ require.NoError(t, err)
+ require.NoError(t, vm.Execute(), "signature verification failed")
+
+ // Finally, assert that the private key is zeroed out.
+ require.Equal(t, byte(0), privKeyCopy.Serialize()[0])
+}
+
+// TestComputeRawSigFail tests various failure modes of ComputeRawSig.
+func TestComputeRawSigFail(t *testing.T) {
+ t.Parallel()
+
+ privKey, _ := deterministicPrivKey(t)
+
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+ prevOut := &wire.TxOut{PkScript: []byte{0x00}}
+ tx := wire.NewMsgTx(2)
+
+ fetcher := txscript.NewCannedPrevOutputFetcher(
+ prevOut.PkScript, prevOut.Value,
+ )
+ sigHashes := txscript.NewTxSigHashes(tx, fetcher)
+
+ // This subtest ensures that if fetching the key manager fails during
+ // the raw signature computation, the error is correctly propagated.
+ t.Run("Fetch Address Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil),
+ errManagerNotFound).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: LegacySpendDetails{},
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.ErrorIs(t, err, errManagerNotFound)
+ })
+
+ // This subtest ensures that if obtaining the private key from the
+ // managed address fails during raw signature computation, the error is
+ // correctly propagated.
+ t.Run("PrivKey Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ mocks.pubKeyAddr.On("PrivKey").Return((*btcec.PrivateKey)(nil),
+ errPrivKeyMock).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: LegacySpendDetails{},
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.ErrorContains(t, err, "privkey error")
+ })
+
+ // This subtest verifies that if the private key tweaking function
+ // returns an error, the raw signature computation correctly propagates
+ // that error.
+ t.Run("Tweak Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: LegacySpendDetails{},
+ Tweaker: func(*btcec.PrivateKey) (
+ *btcec.PrivateKey, error) {
+
+ return nil, errTweakMock
+ },
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.ErrorContains(t, err, "tweak error")
+ })
+
+ // This subtest ensures that if the underlying `Sign` method of the
+ // spend details returns an error, the raw signature computation
+ // correctly propagates that error.
+ t.Run("Sign Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: txscript.SigHashAll,
+ Path: path,
+ Details: LegacySpendDetails{},
+ }
+ mockDetails := &mockSpendDetails{}
+ params.Details = mockDetails
+
+ mockDetails.On("Sign", params, privKeyCopy).
+ Return((RawSignature)(nil),
+ errSignMock)
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.ErrorContains(t, err, "sign error")
+ mockDetails.AssertExpectations(t)
+ })
+
+ // This subtest verifies that an error is returned when an unsupported
+ // Taproot spend path is provided, ensuring robust error handling for
+ // invalid configurations.
+ t.Run("Invalid Taproot Path", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0086}
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ params := &RawSigParams{
+ Tx: wire.NewMsgTx(2),
+ Path: path,
+ Details: TaprootSpendDetails{
+ SpendPath: TaprootSpendPath(99), // Invalid path
+ },
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.ErrorIs(t, err, ErrUnknownSignMethod)
+ })
+
+ // This subtest verifies that if the SegWit v0 signing process fails
+ // (e.g., due to invalid parameters like an invalid hash type), the
+ // error is correctly propagated.
+ t.Run("Segwit Sign Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: 0xff,
+ Path: path,
+ Details: SegwitV0SpendDetails{
+ WitnessScript: []byte{},
+ },
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.Error(t, err)
+ })
+
+ // This subtest verifies that if the Taproot KeyPath signing process
+ // fails (e.g., due to invalid parameters), the error is correctly
+ // propagated.
+ t.Run("Taproot KeyPath Sign Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: 0xff,
+ Path: path,
+ Details: TaprootSpendDetails{SpendPath: KeyPathSpend},
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.Error(t, err)
+ })
+
+ // This subtest verifies that if the Taproot ScriptPath signing process
+ // fails (e.g., due to invalid parameters), the error is correctly
+ // propagated.
+ t.Run("Taproot ScriptPath Sign Fail", func(t *testing.T) {
+ t.Parallel()
+ w, mocks := createUnlockedWalletWithMocks(t)
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On("DeriveFromKeyPath",
+ mock.Anything, mock.Anything).
+ Return(mocks.pubKeyAddr, nil).Once()
+
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ params := &RawSigParams{
+ Tx: tx,
+ Output: prevOut,
+ SigHashes: sigHashes,
+ HashType: 0xff,
+ Path: path,
+ Details: TaprootSpendDetails{
+ SpendPath: ScriptPathSpend,
+ WitnessScript: []byte{0x51},
+ },
+ }
+
+ _, err := w.ComputeRawSig(t.Context(), params)
+ require.Error(t, err)
+ })
+}
+
+// TestDerivePrivKeySuccess tests the successful derivation of a private key.
+func TestDerivePrivKeySuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet with mocks, a test key, and a
+ // derivation path.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ path := BIP32Path{
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ DerivationPath: waddrmgr.DerivationPath{
+ InternalAccount: 0,
+ Branch: 0,
+ Index: 0,
+ },
+ }
+
+ // Set up the mock account manager and the mock address that will be
+ // returned by the derivation call.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, path.DerivationPath,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Act: Derive the private key.
+ derivedKey, err := w.DerivePrivKey(t.Context(), path)
+
+ // Assert: Check that the correct key is returned without error.
+ require.NoError(t, err)
+ require.Equal(t, privKey.Serialize(), derivedKey.Serialize())
+}
+
+// TestDerivePrivKeyFails tests the failure case where the key derivation fails.
+func TestDerivePrivKeyFails(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet and a test path. Configure the mock
+ // addrStore to return an error when fetching the key manager.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil), errManagerNotFound).Once()
+
+ // Act: Attempt to derive the private key.
+ _, err := w.DerivePrivKey(t.Context(), path)
+
+ // Assert: Check that the error is propagated correctly.
+ require.ErrorIs(t, err, errManagerNotFound)
+}
+
+// TestGetPrivKeyForAddressSuccess tests the successful retrieval of a private
+// key by address.
+func TestGetPrivKeyForAddressSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet, mocks, and a deterministic private key.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ privKey, pubKey := deterministicPrivKey(t)
+
+ // Create a P2PKH address from the public key.
+ pubKeyHash := address.Hash160(pubKey.SerializeCompressed())
+ addr, err := address.NewAddressPubKeyHash(
+ pubKeyHash, w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Configure the mock chain to return the test private key.
+ //
+ // NOTE: We must use a copy since the method will zero out the key.
+ privKeyCopy, _ := btcec.PrivKeyFromBytes(privKey.Serialize())
+
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").Return(privKeyCopy, nil).Once()
+
+ // Act: Get the private key for the address.
+ retrievedKey, err := w.GetPrivKeyForAddress(t.Context(), addr)
+
+ // Assert: Check that the correct key is returned.
+ require.NoError(t, err)
+ require.Equal(t, privKey.Serialize(), retrievedKey.Serialize())
+}
+
+// TestGetPrivKeyForAddressFail tests the failure cases for retrieval of a
+// private key by address.
+func TestGetPrivKeyForAddressFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Set up the wallet and mocks.
+ w, mocks := createUnlockedWalletWithMocks(t)
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Case 1: Address lookup fails.
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return((*mockManagedAddress)(nil), errManagerNotFound).Once()
+
+ _, err = w.GetPrivKeyForAddress(t.Context(), addr)
+ require.ErrorIs(t, err, errManagerNotFound)
+
+ // Case 2: Address is not a pubkey address.
+ // We need a separate mock for this to ensure clean separation.
+ //
+ // NOTE: We can reuse the existing mocks but need to reset expectations
+ // or ensure ordering. Since we are in a single test function, we can
+ // just sequence them.
+ mockScriptAddr := &mockManagedAddress{}
+ mocks.addrStore.On("Address", mock.Anything, addr).
+ Return(mockScriptAddr, nil).Once()
+
+ _, err = w.GetPrivKeyForAddress(t.Context(), addr)
+ require.ErrorIs(t, err, ErrNoAssocPrivateKey)
+}
+
+// TestDerivePrivKeyFail tests failure modes of DerivePrivKey.
+func TestDerivePrivKeyFail(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+
+ // Test Case 1: Fetching key manager fails.
+ //
+ // We mock the address store to return an error when fetching the key
+ // manager. We expect this error to be propagated.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil), errManagerNotFound).Once()
+
+ _, err := w.DerivePrivKey(t.Context(), path)
+ require.ErrorIs(t, err, errManagerNotFound)
+
+ // Test Case 2: PrivKey retrieval fails.
+ //
+ // We mock the key manager to return a valid address, but mock the
+ // address to return an error when fetching the private key. We expect
+ // a wrapped error indicating the failure.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").
+ Return((*btcec.PrivateKey)(nil), errPrivKeyMock).Once()
+
+ _, err = w.DerivePrivKey(t.Context(), path)
+ require.ErrorContains(t, err, "cannot get private key")
+}
+
+// TestECDHFail tests failure modes of ECDH.
+func TestECDHFail(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createUnlockedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+ privKey, _ := btcec.NewPrivateKey()
+
+ // Test Case 1: Fetching key manager fails.
+ //
+ // We mock the address store to return an error when fetching the key
+ // manager. We expect this error to be propagated.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return((*mockAccountStore)(nil), errManagerNotFound).Once()
+
+ _, err := w.ECDH(t.Context(), path, privKey.PubKey())
+ require.ErrorIs(t, err, errManagerNotFound)
+
+ // Test Case 2: PrivKey retrieval fails.
+ //
+ // We mock the key manager to return a valid address, but mock the
+ // address to return an error when fetching the private key. We expect
+ // a wrapped error indicating the failure.
+ mocks.addrStore.On("FetchScopedKeyManager", path.KeyScope).
+ Return(mocks.accountManager, nil).Once()
+ mocks.accountManager.On(
+ "DeriveFromKeyPath", mock.Anything, mock.Anything,
+ ).Return(mocks.pubKeyAddr, nil).Once()
+ mocks.pubKeyAddr.On("PrivKey").
+ Return((*btcec.PrivateKey)(nil), errPrivKeyMock).Once()
+
+ _, err = w.ECDH(t.Context(), path, privKey.PubKey())
+ require.ErrorContains(t, err, "cannot get private key")
+}
+
+// TestSignDigestLocked tests that SignDigest fails when the wallet is locked.
+func TestSignDigestLocked(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a locked wallet.
+ w, _ := createStartedWalletWithMocks(t)
+ path := BIP32Path{KeyScope: waddrmgr.KeyScopeBIP0084}
+ intent := &SignDigestIntent{
+ Digest: make([]byte, 32),
+ SigType: SigTypeECDSA,
+ }
+
+ // Act: Call SignDigest.
+ _, err := w.SignDigest(t.Context(), path, intent)
+
+ // Assert: Check for forbidden/locked error.
+ require.ErrorIs(t, err, ErrStateForbidden)
}
diff --git a/wallet/state.go b/wallet/state.go
new file mode 100644
index 0000000000..ae86c83ad8
--- /dev/null
+++ b/wallet/state.go
@@ -0,0 +1,310 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "errors"
+ "fmt"
+ "sync/atomic"
+)
+
+var (
+ // ErrStateForbidden is returned when an operation cannot be performed
+ // due to the current state of the wallet (e.g., locked, not started,
+ // not synced).
+ ErrStateForbidden = errors.New("operation forbidden in current state")
+)
+
+// lifecycle represents the lifecycle state of the wallet's main event loop.
+type lifecycle uint32
+
+const (
+ // lifecycleStopped indicates the wallet is stopped.
+ lifecycleStopped lifecycle = iota
+
+ // lifecycleStarting indicates the wallet is starting up.
+ lifecycleStarting
+
+ // lifecycleStarted indicates the wallet is started.
+ lifecycleStarted
+
+ // lifecycleStopping indicates the wallet is currently stopping.
+ lifecycleStopping
+)
+
+// String returns the string representation of a lifecycle.
+func (l lifecycle) String() string {
+ switch l {
+ case lifecycleStopped:
+ return "stopped"
+
+ case lifecycleStarting:
+ return "starting"
+
+ case lifecycleStarted:
+ return "started"
+
+ case lifecycleStopping:
+ return "stopping"
+
+ default:
+ return "unknown lifecycle state"
+ }
+}
+
+// walletState is a thread-safe wrapper that manages the state of the wallet
+// across three orthogonal dimensions. These dimensions are independent of each
+// other, allowing for a precise representation of the wallet's condition at any
+// given moment.
+//
+// The three dimensions are:
+// 1. Lifecycle (System State): Tracks whether the wallet is running, stopped,
+// or in transition. This dictates whether background processes are active.
+// 2. Synchronization (Chain State): Tracks the wallet's progress in syncing
+// with the blockchain (e.g., syncing, synced, scanning). This dictates
+// data freshness and availability.
+// 3. Authentication (Security State): Tracks whether the wallet is locked or
+// unlocked. This dictates the ability to perform sensitive operations like
+// signing.
+type walletState struct {
+ // lifecycle tracks the start/stop state of the wallet.
+ lifecycle atomic.Uint32
+
+ // syncer is the interface used to retrieve the current chain
+ // synchronization status from the synchronization component.
+ //
+ // This approach is chosen to enforce a strict separation of concerns
+ // and ownership:
+ // 1. Ownership: The syncer exclusively owns and manages the writes to
+ // the sync state as it is the only component driving the sync.
+ // 2. Decoupling: walletState provides a unified view of the wallet's
+ // atomic conditions without needing to know the implementation
+ // details of the synchronization subsystem.
+ // 3. Consistency: By reading directly from the syncer's internal
+ // state (via this interface), we ensure that the wallet always
+ // reports a real-time, consistent view of its data freshness.
+ syncer chainSyncer
+
+ // unlocked tracks whether the wallet is unlocked (true) or locked
+ // (false). The zero value is false (Locked), which is secure by
+ // default.
+ unlocked atomic.Bool
+}
+
+// newWalletState creates a new walletState initialized with the provided
+// syncer and secure defaults:
+// - Lifecycle: Stopped (awaiting Start() call).
+// - Synchronization: BackendSyncing (until syncer is running and connected).
+// - Authentication: Locked (secure by default).
+func newWalletState(syncer chainSyncer) walletState {
+ return walletState{
+ syncer: syncer,
+ }
+}
+
+// String returns a summary of the wallet's state.
+func (s *walletState) String() string {
+ lc := lifecycle(s.lifecycle.Load())
+ sync := s.syncState()
+ unlocked := s.unlocked.Load()
+
+ return fmt.Sprintf("status=%v, sync=%v, locked=%v", lc, sync, !unlocked)
+}
+
+// toStarting transitions the wallet state from Stopped to Starting.
+// It initializes the synchronization and authentication states to their
+// secure defaults. It returns an error if the wallet is already started or
+// not in the Stopped state.
+func (s *walletState) toStarting() error {
+ // 1. Lifecycle (System State): Atomic transition from Stopped to
+ // Starting.
+ if !s.lifecycle.CompareAndSwap(
+ uint32(lifecycleStopped), uint32(lifecycleStarting)) {
+
+ return fmt.Errorf("%w: current state is %v",
+ ErrWalletAlreadyStarted, lifecycle(s.lifecycle.Load()))
+ }
+
+ // 2. Authentication (Security State): Reset to Locked. This ensures
+ // the wallet always starts in a secure state.
+ s.unlocked.Store(false)
+
+ return nil
+}
+
+// toStarted marks the wallet as fully started. This should be called only
+// after all resource initialization is complete.
+func (s *walletState) toStarted() error {
+ if !s.lifecycle.CompareAndSwap(
+ uint32(lifecycleStarting), uint32(lifecycleStarted)) {
+
+ return fmt.Errorf("%w: cannot transition to started from %v",
+ ErrStateForbidden, lifecycle(s.lifecycle.Load()))
+ }
+
+ return nil
+}
+
+// toStopping transitions the wallet from Started to Stopping.
+// It returns an error if the wallet is not running.
+func (s *walletState) toStopping() error {
+ // Atomic transition from Started to Stopping.
+ if !s.lifecycle.CompareAndSwap(
+ uint32(lifecycleStarted), uint32(lifecycleStopping)) {
+
+ // If we are not Started, we cannot Stop.
+ // This covers Stopped, Starting, and Stopping.
+ return ErrStateForbidden
+ }
+
+ // Lock the wallet during shutdown to prevent any further signing
+ // operations.
+ s.unlocked.Store(false)
+
+ return nil
+}
+
+// toStopped marks the wallet as fully stopped.
+func (s *walletState) toStopped() error {
+ // We allow transition from Stopping (normal shutdown) or Starting
+ // (failure during startup).
+ //
+ // We use a CAS loop here to handle potential races where the state
+ // might change between Load and CompareAndSwap.
+ //
+ // This loop is guaranteed to terminate because:
+ // 1. If CAS succeeds, we break.
+ // 2. If CAS fails, it means the state changed. We reload the new state.
+ // 3. If the new state is not Stopping or Starting (e.g. it became
+ // Started or already Stopped), the validation check fails and we
+ // return an error.
+ for {
+ current := s.lifecycle.Load()
+ lc := lifecycle(current)
+
+ if lc != lifecycleStopping && lc != lifecycleStarting {
+ return fmt.Errorf("%w: cannot transition to stopped "+
+ "from %v", ErrStateForbidden, lc)
+ }
+
+ if s.lifecycle.CompareAndSwap(
+ current, uint32(lifecycleStopped),
+ ) {
+
+ break
+ }
+ }
+
+ // Force lock the wallet on shutdown for security.
+ s.unlocked.Store(false)
+
+ return nil
+}
+
+// toUnlocked marks the wallet as unlocked.
+func (s *walletState) toUnlocked() {
+ s.unlocked.Store(true)
+}
+
+// toLocked marks the wallet as locked.
+func (s *walletState) toLocked() {
+ s.unlocked.Store(false)
+}
+
+// syncState returns the current synchronization state.
+func (s *walletState) syncState() syncState {
+ if s.syncer == nil {
+ return syncStateBackendSyncing
+ }
+
+ return s.syncer.syncState()
+}
+
+// isSynced returns true if the wallet is fully synchronized with the
+// blockchain.
+func (s *walletState) isSynced() bool {
+ return s.syncState() == syncStateSynced
+}
+
+// isUnlocked returns true if the wallet is currently unlocked.
+func (s *walletState) isUnlocked() bool {
+ return s.unlocked.Load()
+}
+
+// isStarted returns true if the wallet is in the Started state.
+func (s *walletState) isStarted() bool {
+ return lifecycle(s.lifecycle.Load()) == lifecycleStarted
+}
+
+// isRunning returns true if the wallet is in any active state (not stopped
+// or stopping).
+func (s *walletState) isRunning() bool {
+ lc := lifecycle(s.lifecycle.Load())
+ return lc != lifecycleStopped && lc != lifecycleStopping
+}
+
+// canSign checks if the wallet is in a state allowing message/transaction
+// signing. The wallet must be Started and Unlocked.
+func (s *walletState) canSign() error {
+ if !s.isStarted() {
+ return fmt.Errorf("%w: wallet not started", ErrStateForbidden)
+ }
+
+ if !s.isUnlocked() {
+ return fmt.Errorf("%w: wallet locked", ErrStateForbidden)
+ }
+
+ return nil
+}
+
+// validateSynced checks if the wallet is running and fully synchronized.
+// It returns an error if the wallet is not started or if it is currently
+// syncing/rescanning.
+func (s *walletState) validateSynced() error {
+ if !s.isStarted() {
+ return fmt.Errorf("%w: wallet not started", ErrStateForbidden)
+ }
+
+ // TODO(yy): Should we allow creating txs while syncing?
+ // Currently we enforce sync to ensure accurate coin selection.
+ sync := s.syncState()
+ if sync != syncStateSynced {
+ return fmt.Errorf("%w: wallet is currently %s",
+ ErrStateForbidden, sync)
+ }
+
+ return nil
+}
+
+// validateStarted checks if the wallet is currently running.
+func (s *walletState) validateStarted() error {
+ if !s.isStarted() {
+ return fmt.Errorf("%w: wallet not started", ErrStateForbidden)
+ }
+
+ return nil
+}
+
+// canUnlock checks if the wallet is in a state that allows unlocking.
+func (s *walletState) canUnlock() error {
+ return s.validateStarted()
+}
+
+// canLock checks if the wallet is in a state that allows locking.
+func (s *walletState) canLock() error {
+ return s.validateStarted()
+}
+
+// canChangePassphrase checks if the wallet is in a state that allows changing
+// the passphrase.
+func (s *walletState) canChangePassphrase() error {
+ return s.validateStarted()
+}
+
+// isRecoveryMode returns true if the wallet is currently syncing or rescanning.
+func (s *walletState) isRecoveryMode() bool {
+ sync := s.syncState()
+ return sync == syncStateSyncing || sync == syncStateRescanning
+}
diff --git a/wallet/state_test.go b/wallet/state_test.go
new file mode 100644
index 0000000000..a48f8ef208
--- /dev/null
+++ b/wallet/state_test.go
@@ -0,0 +1,439 @@
+package wallet
+
+import (
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestStateSecureByDefault verifies that the zero-value of walletState
+// represents a safe, locked condition.
+func TestStateSecureByDefault(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a new state in Stopped (default) mode.
+ syncer := &mockChainSyncer{}
+ s := newWalletState(syncer)
+
+ // Act & Assert: Verify initial state.
+ require.False(t, s.isStarted())
+ require.False(t, s.isRunning())
+
+ // Act: Transition to Starting.
+ err := s.toStarting()
+ require.NoError(t, err)
+
+ // Act: Transition to Started.
+ err = s.toStarted()
+ require.NoError(t, err)
+ require.True(t, s.isStarted())
+ require.True(t, s.isRunning())
+
+ // Act: Transition to Stopping.
+ err = s.toStopping()
+ require.NoError(t, err)
+ require.False(t, s.isStarted())
+
+ // Stopping is NOT running.
+ require.False(t, s.isRunning())
+
+ // Act: Transition to Stopped.
+ err = s.toStopped()
+ require.NoError(t, err)
+ require.False(t, s.isRunning())
+
+ // Assert: Invalid transition (Stop when already Stopped).
+ err = s.toStopping()
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestStateAuthentication verifies locking and unlocking logic.
+func TestStateAuthentication(t *testing.T) {
+ t.Parallel()
+
+ syncer := &mockChainSyncer{}
+ s := newWalletState(syncer)
+
+ // Arrange: Start the wallet (must be started to be useful).
+ require.NoError(t, s.toStarting())
+ err := s.toStarted()
+ require.NoError(t, err)
+
+ // Assert: Default is Locked.
+ require.False(t, s.isUnlocked())
+
+ // Act: Unlock.
+ s.toUnlocked()
+ require.NoError(t, err)
+ require.True(t, s.isUnlocked())
+
+ // Act: Lock.
+ s.toLocked()
+ require.NoError(t, err)
+ require.False(t, s.isUnlocked())
+
+ // Act: Verify canSign checks.
+ // Case 1: Locked -> Error.
+ err = s.canSign()
+ require.ErrorIs(t, err, ErrStateForbidden)
+ require.ErrorContains(t, err, "wallet locked")
+
+ // Case 2: Unlocked -> Success.
+ s.toUnlocked()
+ err = s.canSign()
+ require.NoError(t, err)
+
+ // Case 3: Stopped -> Error (even if unlocked, though stopped forces
+ // lock).
+ require.NoError(t, s.toStopping())
+ err = s.toStopped()
+ require.NoError(t, err)
+ // Note: toStopped forces lock, so we must check that logic too.
+ require.False(t, s.isUnlocked())
+
+ // Manually unlock while stopped to test canSign check.
+ s.toUnlocked()
+ err = s.canSign()
+ require.ErrorIs(t, err, ErrStateForbidden)
+ require.ErrorContains(t, err, "wallet not started")
+}
+
+// TestStateSynchronization verifies that the wallet state correctly reflects
+// the syncer's status.
+func TestStateSynchronization(t *testing.T) {
+ t.Parallel()
+
+ syncer := &mockChainSyncer{}
+ s := newWalletState(syncer)
+ require.NoError(t, s.toStarting())
+ require.NoError(t, s.toStarted())
+
+ // Arrange: Mock syncer to return Synced.
+ syncer.On("syncState").Return(syncStateSynced)
+
+ // Act & Assert.
+ require.Equal(t, syncStateSynced, s.syncState())
+ require.True(t, s.isSynced())
+ require.False(t, s.isRecoveryMode())
+
+ // Arrange: Mock syncer to return Syncing.
+ // Note: We need to reset expectations or use a new mock/state if rigid.
+ // testify/mock allows updating expectations usually.
+ syncer.ExpectedCalls = nil
+ syncer.On("syncState").Return(syncStateSyncing)
+
+ // Act & Assert.
+ require.Equal(t, syncStateSyncing, s.syncState())
+ require.False(t, s.isSynced())
+ require.True(t, s.isRecoveryMode())
+}
+
+// TestStateNilSyncer verifies behavior when syncer is nil (defensive check).
+func TestStateNilSyncer(t *testing.T) {
+ t.Parallel()
+
+ s := newWalletState(nil)
+
+ // Act & Assert: Should default to BackendSyncing safely.
+ require.Equal(t, syncStateBackendSyncing, s.syncState())
+}
+
+// TestStateThreadSafety verifies that state transitions are safe under
+// concurrent access.
+func TestStateThreadSafety(t *testing.T) {
+ t.Parallel()
+
+ syncer := &mockChainSyncer{}
+ s := newWalletState(syncer)
+
+ // Arrange: Hammer the start/stop transitions.
+ var wg sync.WaitGroup
+
+ start := make(chan struct{})
+
+ for range 100 {
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+
+ <-start
+
+ // NOTE: We ignore errors here because we are
+ // purposefully hammering the state machine from
+ // multiple goroutines. Many of these transitions will
+ // fail (e.g., trying to start an already starting
+ // wallet), which is expected behavior. We are
+ // primarily verifying that no data races or panics
+ // occur.
+ //
+ // Try to start.
+ _ = s.toStarting()
+
+ // Try to stop.
+ _ = s.toStopping()
+ }()
+ }
+
+ close(start)
+ wg.Wait()
+
+ // Assert: State should be valid (either stopped, starting, or
+ // stopping).
+ // Just ensure no panics occurred.
+}
+
+// TestValidateSynced verifies the validation logic for operations requiring
+// synchronization.
+func TestValidateSynced(t *testing.T) {
+ t.Parallel()
+
+ syncer := &mockChainSyncer{}
+ s := newWalletState(syncer)
+
+ // Case 1: Not started.
+ err := s.validateSynced()
+ require.ErrorIs(t, err, ErrStateForbidden)
+
+ // Case 2: Started but not synced.
+ require.NoError(t, s.toStarting())
+ require.NoError(t, s.toStarted())
+ syncer.On("syncState").Return(syncStateSyncing)
+
+ err = s.validateSynced()
+ require.ErrorIs(t, err, ErrStateForbidden)
+
+ // Case 3: Started and synced.
+ syncer.ExpectedCalls = nil
+ syncer.On("syncState").Return(syncStateSynced)
+
+ err = s.validateSynced()
+ require.NoError(t, err)
+}
+
+// TestStateLifecycleTransitions verifies valid and invalid lifecycle
+// state transitions.
+func TestStateLifecycleTransitions(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ lifecycle lifecycle
+ running bool
+ }{
+ {
+ name: "started is running",
+ lifecycle: lifecycleStarted,
+ running: true,
+ },
+ {
+ name: "stopped is not running",
+ lifecycle: lifecycleStopped,
+ running: false,
+ },
+ {
+ name: "stopping is not running",
+ lifecycle: lifecycleStopping,
+ running: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup state.
+ state := newWalletState(nil)
+ state.lifecycle.Store(uint32(tc.lifecycle))
+
+ // Act & Assert: Verify isRunning result.
+ require.Equal(t, tc.running, state.isRunning())
+ })
+ }
+}
+
+// TestStateString verifies the summary string format.
+func TestStateString(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a specific state.
+ ms := &mockChainSyncer{}
+ ms.On("syncState").Return(syncStateSyncing)
+
+ state := newWalletState(ms)
+ state.lifecycle.Store(uint32(lifecycleStarted))
+ state.unlocked.Store(true)
+
+ // Act: Get the summary string.
+ got := state.String()
+
+ // Assert: Verify exact format and values.
+ // Note: String uses !unlocked for "locked" boolean value.
+ expected := "status=started, sync=syncing, locked=false"
+ require.Equal(t, expected, got)
+}
+
+// TestStateStartStop verifies the transition logic for start and stop.
+func TestStateStartStop(t *testing.T) {
+ t.Parallel()
+
+ t.Run("start success", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+
+ // Set initial random state to verify reset.
+ state.unlocked.Store(true)
+
+ err := state.toStarting()
+ require.NoError(t, err)
+ require.Equal(t, uint32(lifecycleStarting),
+ state.lifecycle.Load())
+ require.False(t, state.unlocked.Load())
+
+ // Now mark as started.
+ err = state.toStarted()
+ require.NoError(t, err)
+ require.Equal(t, uint32(lifecycleStarted),
+ state.lifecycle.Load())
+ })
+
+ t.Run("start fail already started", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+ state.lifecycle.Store(uint32(lifecycleStarted))
+
+ err := state.toStarting()
+ require.ErrorIs(t, err, ErrWalletAlreadyStarted)
+ })
+
+ t.Run("stop success", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+ state.lifecycle.Store(uint32(lifecycleStarted))
+ state.unlocked.Store(true)
+
+ err := state.toStopping()
+ require.NoError(t, err)
+
+ require.Equal(t, uint32(lifecycleStopping),
+ state.lifecycle.Load())
+ require.False(t, state.unlocked.Load())
+ })
+
+ t.Run("stop fail not started", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+ state.lifecycle.Store(uint32(lifecycleStopped))
+
+ err := state.toStopping()
+ require.ErrorIs(t, err, ErrStateForbidden)
+ })
+}
+
+// TestStateValidateStarted verifies the validateStarted check.
+func TestStateValidateStarted(t *testing.T) {
+ t.Parallel()
+
+ t.Run("success started", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+ state.lifecycle.Store(uint32(lifecycleStarted))
+ require.NoError(t, state.validateStarted())
+ })
+
+ t.Run("fail stopped", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+ state.lifecycle.Store(uint32(lifecycleStopped))
+ require.ErrorIs(t, state.validateStarted(), ErrStateForbidden)
+ })
+}
+
+// TestStateAuthChecks verifies the semantic auth check methods.
+func TestStateAuthChecks(t *testing.T) {
+ t.Parallel()
+
+ // Helper to set state
+ setState := func(s *walletState, lc lifecycle) {
+ s.lifecycle.Store(uint32(lc))
+ }
+
+ t.Run("started allowed", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+
+ setState(&state, lifecycleStarted)
+ require.NoError(t, state.canUnlock())
+ require.NoError(t, state.canLock())
+ require.NoError(t, state.canChangePassphrase())
+ })
+
+ t.Run("stopped forbidden", func(t *testing.T) {
+ t.Parallel()
+
+ state := newWalletState(nil)
+
+ setState(&state, lifecycleStopped)
+ require.ErrorIs(t, state.canUnlock(), ErrStateForbidden)
+ require.ErrorIs(t, state.canLock(), ErrStateForbidden)
+ require.ErrorIs(t, state.canChangePassphrase(),
+ ErrStateForbidden)
+ })
+}
+
+// TestStateIsRecoveryMode verifies the recovery mode check.
+func TestStateIsRecoveryMode(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ sync syncState
+ isRecovery bool
+ }{
+ {"backend syncing", syncStateBackendSyncing, false},
+ {"syncing", syncStateSyncing, true},
+ {"synced", syncStateSynced, false},
+ {"rescanning", syncStateRescanning, true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ ms := &mockChainSyncer{}
+ ms.On("syncState").Return(tc.sync)
+
+ state := newWalletState(ms)
+ require.Equal(t, tc.isRecovery, state.isRecoveryMode())
+ })
+ }
+}
+
+// TestStateAuxiliaryMethods verifies helper methods like canUnlock, canLock,
+// and canChangePassphrase.
+func TestStateAuxiliaryMethods(t *testing.T) {
+ t.Parallel()
+
+ syncer := &mockChainSyncer{}
+ s := newWalletState(syncer)
+
+ // Case 1: Stopped -> All forbidden.
+ require.ErrorIs(t, s.canUnlock(), ErrStateForbidden)
+ require.ErrorIs(t, s.canLock(), ErrStateForbidden)
+ require.ErrorIs(t, s.canChangePassphrase(), ErrStateForbidden)
+
+ // Case 2: Started -> All allowed.
+ require.NoError(t, s.toStarting())
+ require.NoError(t, s.toStarted())
+ require.NoError(t, s.canUnlock())
+ require.NoError(t, s.canLock())
+ require.NoError(t, s.canChangePassphrase())
+}
diff --git a/wallet/syncer.go b/wallet/syncer.go
new file mode 100644
index 0000000000..59a06ef27b
--- /dev/null
+++ b/wallet/syncer.go
@@ -0,0 +1,1351 @@
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs/builder"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+var (
+ // ErrCFiltersUnavailable is returned when the chain backend cannot
+ // serve compact filters.
+ ErrCFiltersUnavailable = errors.New("cfilters unavailable")
+
+ // ErrUnknownSyncMethod is returned when an unknown synchronization
+ // method is specified.
+ ErrUnknownSyncMethod = errors.New("unknown sync method")
+
+ // ErrScanBatchEmpty is returned when a scan batch contains no blocks.
+ ErrScanBatchEmpty = errors.New("scan batch empty")
+
+ // ErrUnknownRescanJobType is returned when an unknown rescan job type
+ // is encountered.
+ ErrUnknownRescanJobType = errors.New("unknown rescan job type")
+
+ // ErrInvalidStartHeight is returned when a resync or rescan is
+ // requested with an invalid start height (e.g., zero if not allowed).
+ ErrInvalidStartHeight = errors.New("invalid start height")
+
+ // ErrStartHeightTooHigh is returned when a resync or rescan is
+ // requested with a start height that is greater than the current
+ // chain tip.
+ ErrStartHeightTooHigh = errors.New("start height is greater than " +
+ "current chain tip")
+
+ // ErrStartHeightTooLarge is returned when a resync or rescan is
+ // requested with a start height that exceeds the maximum value of
+ // an int32, which is the underlying type for block heights.
+ ErrStartHeightTooLarge = errors.New("start height too large, exceeds " +
+ "maximum int32 value")
+
+ // ErrNoScanTargets is returned when a targeted rescan is requested with
+ // no targets.
+ ErrNoScanTargets = errors.New("at least one target must be specified")
+)
+
+const (
+ // syncStateSwitchThreshold is the number of blocks behind the chain
+ // tip at which the wallet switches to the "Syncing" state. Gaps
+ // smaller than this are handled silently (blocking DB lock) to avoid
+ // disrupting UX with "Wallet Busy" errors for minor lags.
+ //
+ // Value 6 (approx 1 hour) is chosen based on a balance of two factors:
+ //
+ // 1. Database Contention: Synchronization requires a database write
+ // lock. Processing 6 blocks typically takes less than 1 second,
+ // which is an acceptable duration for other operations (like
+ // CreateTx) to block (wait) on the database lock. Gaps larger than
+ // this would result in noticeable UI hangs, so we switch to the
+ // explicit "Syncing" state which allows the wallet to return an
+ // immediate error instead of blocking.
+ //
+ // 2. Data Integrity vs. UX: While in a silent sync, the wallet might
+ // allow the user to initiate actions based on slightly outdated
+ // data (e.g., spending an output that was actually spent in one of
+ // the missing blocks). For a 6-block gap, the risk is minimal, and
+ // such transactions would be rejected by the network mempool or
+ // miners. However, for large gaps, the risk of false "Insufficient
+ // Funds" errors or extremely inaccurate fee estimates increases,
+ // making the explicit "Syncing" state a necessary safeguard.
+ syncStateSwitchThreshold = 6
+)
+
+// syncState represents the synchronization status of the wallet with the
+// blockchain.
+type syncState uint32
+
+const (
+ // syncStateBackendSyncing indicates the wallet is waiting for the
+ // chain backend to finish syncing.
+ syncStateBackendSyncing syncState = iota
+
+ // syncStateSyncing indicates the wallet is running but catching up to
+ // the chain tip (or rewinding).
+ syncStateSyncing
+
+ // syncStateSynced indicates the wallet is running and synced to the
+ // chain tip.
+ syncStateSynced
+
+ // syncStateRescanning indicates the wallet is running a historical
+ // scan for specific user-provided targets, such as accounts or
+ // addresses, without rewinding the global synchronization state.
+ syncStateRescanning
+)
+
+// String returns the string representation of a syncState.
+func (s syncState) String() string {
+ switch s {
+ case syncStateBackendSyncing:
+ return "backend-syncing"
+
+ case syncStateSyncing:
+ return "syncing"
+
+ case syncStateSynced:
+ return "synced"
+
+ case syncStateRescanning:
+ return "rescanning"
+
+ default:
+ return "unknown sync state"
+ }
+}
+
+// scanType represents the type of rescan being requested.
+type scanType uint8
+
+const (
+ // scanTypeRewind represents a full rescan which rewinds the wallet's
+ // state to a specific point and scans forward.
+ scanTypeRewind scanType = iota
+
+ // scanTypeTargeted represents a targeted rescan for specific addresses
+ // or accounts without altering the global sync state.
+ scanTypeTargeted
+)
+
+// scanReq is an internal request to perform a rescan.
+type scanReq struct {
+ // typ specifies the type of rescan to perform.
+ typ scanType
+
+ // startBlock specifies the block height and hash to start the rescan
+ // from.
+ startBlock waddrmgr.BlockStamp
+
+ // targets specifies the accounts to scan for. This is only used for
+ // targeted rescans.
+ targets []waddrmgr.AccountScope
+}
+
+// scanResult holds the result of processing a single block during a batch
+// scan.
+type scanResult struct {
+ // BlockProcessResult embeds the results of filtering the block.
+ *BlockProcessResult
+
+ // meta contains block metadata (hash, height, time).
+ meta *wtxmgr.BlockMeta
+}
+
+// chainSyncer is a private interface that abstracts the chain synchronization
+// logic, allowing it to be mocked for testing the wallet and controller.
+type chainSyncer interface {
+ // run executes the main synchronization loop.
+ run(ctx context.Context) error
+
+ // requestScan submits a rescan job to the syncer.
+ requestScan(ctx context.Context, req *scanReq) error
+
+ // syncState returns the current synchronization state.
+ syncState() syncState
+}
+
+// syncer is a stateless blocking worker responsible for synchronizing the
+// wallet with the blockchain. It operates within the lifecycle provided by the
+// caller via context and manages the chain loop, scanning, and reorg handling.
+type syncer struct {
+ // cfg holds the configuration parameters for the syncer.
+ cfg Config
+
+ // addrStore is the address and key manager.
+ addrStore waddrmgr.AddrStore
+
+ // txStore is the transaction manager.
+ txStore wtxmgr.TxStore
+
+ // state tracks the chain synchronization status.
+ state atomic.Uint32
+
+ // scanReqChan is the internal mailbox used to receive scan requests
+ // from the controller. It is buffered to ensure that submitting a
+ // request does not unnecessarily block the calling goroutine.
+ scanReqChan chan *scanReq
+
+ // publisher is the component responsible for broadcasting transactions
+ // to the network. It is primarily used during the maintenance phase to
+ // ensure unmined transactions remain in the mempool.
+ publisher TxPublisher
+}
+
+// newSyncer creates a new syncer instance.
+func newSyncer(cfg Config, addrStore waddrmgr.AddrStore,
+ txStore wtxmgr.TxStore, publisher TxPublisher) *syncer {
+
+ return &syncer{
+ cfg: cfg,
+ addrStore: addrStore,
+ txStore: txStore,
+ scanReqChan: make(chan *scanReq, 1),
+ publisher: publisher,
+ }
+}
+
+// syncState returns the current synchronization state of the wallet.
+func (s *syncer) syncState() syncState {
+ return syncState(s.state.Load())
+}
+
+// isRecoveryMode returns true if the wallet is currently syncing or
+// rescanning.
+func (s *syncer) isRecoveryMode() bool {
+ status := s.syncState()
+ return status == syncStateSyncing || status == syncStateRescanning
+}
+
+// initChainSync performs the initial setup for the chain synchronization loop.
+// This includes waiting for the backend to sync, checking for rollbacks, and
+// enabling block notifications. It returns an error if any of these setup
+// steps fail.
+func (s *syncer) initChainSync(ctx context.Context) error {
+ var err error
+
+ // Inform the backend about our birthday for optimization. For backends
+ // like Neutrino (SPV), this provides a starting point for the internal
+ // synchronization of block headers and compact filters. Without this
+ // hint, the backend might attempt to sync from genesis or its latest
+ // checkpoint, leading to unnecessary network I/O and delayed wallet
+ // readiness.
+ if cc, ok := s.cfg.Chain.(*chain.NeutrinoClient); ok {
+ cc.SetStartTime(s.addrStore.Birthday())
+ }
+
+ // Wait for the backend to be synced to the network. We require the
+ // backend to be synced before we start scanning to ensure we have a
+ // consistent view of the chain and can perform recovery correctly.
+ s.state.Store(uint32(syncStateBackendSyncing))
+
+ err = s.waitUntilBackendSynced(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to wait for backend sync: %w", err)
+ }
+
+ // Check for any reorgs that happened while we were down.
+ err = s.checkRollback(ctx)
+ if err != nil {
+ return fmt.Errorf("unable to check for rollback: %w", err)
+ }
+
+ // Enable block notifications from the chain backend.
+ err = s.cfg.Chain.NotifyBlocks()
+ if err != nil {
+ return fmt.Errorf("unable to start block notifications: %w",
+ err)
+ }
+
+ return nil
+}
+
+// waitUntilBackendSynced blocks until the chain backend considers itself
+// "current".
+func (s *syncer) waitUntilBackendSynced(ctx context.Context) error {
+ // Check immediately if the backend is already synced.
+ if s.cfg.Chain.IsCurrent() {
+ return nil
+ }
+
+ // We'll poll every second to determine if our chain considers itself
+ // "current".
+ t := time.NewTicker(time.Second)
+ defer t.Stop()
+
+ for {
+ select {
+ case <-t.C:
+ if s.cfg.Chain.IsCurrent() {
+ return nil
+ }
+
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ }
+}
+
+// checkRollback ensures the wallet is synchronized with the current chain tip.
+// It checks if the wallet's synced tip is still on the main chain, and if not,
+// rewinds the wallet state to the common ancestor.
+func (s *syncer) checkRollback(ctx context.Context) error {
+ var err error
+
+ // batchSize is the number of blocks to fetch from the chain backend in
+ // a single batch when checking for a rollback. A value of 10 is chosen
+ // as a conservative default that covers the vast majority of reorg
+ // scenarios (typically 1-3 blocks) while keeping individual batch
+ // requests lightweight.
+ const batchSize = 10
+
+ syncedTo := s.addrStore.SyncedTo()
+ syncedHeight := syncedTo.Height
+
+ var (
+ localHashes []*chainhash.Hash
+ remoteHashes []chainhash.Hash
+ header *wire.BlockHeader
+ )
+
+ for syncedHeight > 0 {
+ // Calculate the range for this batch. We scan backwards:
+ // [startHeight, endHeight] where endHeight is syncedHeight.
+ endHeight := syncedHeight
+ startHeight := max(0, endHeight-batchSize+1)
+
+ // Fetch Local Batch (from wallet's database).
+ localHashes, err = s.DBGetSyncedBlocks(
+ ctx, startHeight, endHeight,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Fetch Remote Batch - Fetch corresponding hashes from the
+ // chain backend.
+ remoteHashes, err = s.cfg.Chain.GetBlockHashes(
+ int64(startHeight), int64(endHeight),
+ )
+ if err != nil {
+ return fmt.Errorf("remote get block hashes: %w", err)
+ }
+
+ // Compare Batches. Iterate backwards to find the last matching
+ // block (the fork point).
+ matchIndex := s.findForkPoint(localHashes, remoteHashes)
+
+ // Case A: Tip matches. No rollback needed (if we are at the
+ // tip). If syncedHeight == syncedTo.Height and matchIndex is
+ // the last element, then we are fully synced on the main
+ // chain.
+ if syncedHeight == syncedTo.Height &&
+ matchIndex == len(localHashes)-1 {
+
+ return nil
+ }
+
+ // Case B: Mismatch found within this batch. A fork point has
+ // been detected. This indicates a blockchain reorganization
+ // where the wallet's local chain history diverges from the
+ // chain backend's view within the current batch of blocks.
+ if matchIndex != -1 {
+ //nolint:gosec // matchIndex < batchSize (10).
+ forkHeight := startHeight + int32(matchIndex)
+ forkHash := localHashes[matchIndex]
+
+ log.Infof("Rollback detected! Rewinding to height %d "+
+ "(%v)", forkHeight, forkHash)
+
+ // Fetch the block header outside the DB transaction to
+ // avoid holding the lock during an RPC call.
+ header, err = s.cfg.Chain.GetBlockHeader(forkHash)
+ if err != nil {
+ return fmt.Errorf("get fork header: %w", err)
+ }
+
+ // Perform the rollback.
+ return s.DBPutRewind(ctx, waddrmgr.BlockStamp{
+ Height: forkHeight,
+ Hash: *forkHash,
+ Timestamp: header.Timestamp,
+ })
+ }
+
+ // Case C: No match in this batch. The fork point is deeper.
+ // Move syncedHeight back and continue loop.
+ syncedHeight = startHeight - 1
+ }
+
+ return nil
+}
+
+// findForkPoint compares local and remote block hashes to find the last
+// matching block (fork point). It returns the index of the last match in the
+// slices, or -1 if no match is found.
+func (s *syncer) findForkPoint(localHashes []*chainhash.Hash,
+ remoteHashes []chainhash.Hash) int {
+
+ // Compare up to the length of the shortest slice to avoid
+ // out-of-bounds panics if the chain backend returns fewer hashes than
+ // expected.
+ minLen := min(len(localHashes), len(remoteHashes))
+
+ for i := minLen - 1; i >= 0; i-- {
+ if localHashes[i].IsEqual(&remoteHashes[i]) {
+ return i
+ }
+ }
+
+ return -1
+}
+
+// run executes the main synchronization loop.
+func (s *syncer) run(ctx context.Context) error {
+ // Initialize the chain sync state.
+ err := s.initChainSync(ctx)
+ if err != nil {
+ if errors.Is(err, context.Canceled) ||
+ errors.Is(err, ErrWalletShuttingDown) {
+
+ return nil
+ }
+
+ return fmt.Errorf("initialize chain sync: %w", err)
+ }
+
+ for {
+ err := s.runSyncStep(ctx)
+ if err != nil {
+ if errors.Is(err, context.Canceled) ||
+ errors.Is(err, ErrWalletShuttingDown) {
+
+ return nil
+ }
+
+ return err
+ }
+ }
+}
+
+// runSyncStep performs a single iteration of the synchronization loop. It
+// advances the chain sync state, broadcasts unmined transactions, and then
+// waits for the next event (notification or job).
+func (s *syncer) runSyncStep(ctx context.Context) error {
+ // Attempt to advance the wallet's sync state.
+ syncFinished, err := s.advanceChainSync(ctx)
+ if err != nil {
+ return fmt.Errorf("advance chain sync: %w", err)
+ }
+
+ if !syncFinished {
+ return nil
+ }
+
+ // Rebroadcast unmined transactions.
+ err = s.broadcastUnminedTxns(ctx)
+ if err != nil {
+ return fmt.Errorf("broadcast unmined txns: %w", err)
+ }
+
+ // Proceed to idle mode, waiting for notifications or jobs.
+ err = s.waitForEvent(ctx)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// requestScan submits a rescan job to the syncer.
+func (s *syncer) requestScan(ctx context.Context, req *scanReq) error {
+ select {
+ case s.scanReqChan <- req:
+ return nil
+
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// broadcastUnminedTxns retrieves all unmined transactions from the wallet and
+// attempts to re-broadcast them to the network.
+func (s *syncer) broadcastUnminedTxns(ctx context.Context) error {
+ txs, err := s.DBGetUnminedTxns(ctx)
+ if err != nil {
+ log.Errorf("Unable to retrieve unconfirmed transactions to "+
+ "resend: %v", err)
+
+ return fmt.Errorf("failed to retrieve unconfirmed txs: %w", err)
+ }
+
+ for _, tx := range txs {
+ err := s.publisher.Broadcast(ctx, tx, "")
+ if err != nil {
+ log.Warnf("Unable to rebroadcast tx %v: %v",
+ tx.TxHash(), err)
+ }
+ }
+
+ return nil
+}
+
+// scanBatchHeadersOnly performs a lightweight scan by only fetching block
+// headers. This is used when the wallet has no addresses or outpoints to
+// watch, allowing it to fast-forward its sync state.
+func (s *syncer) scanBatchHeadersOnly(_ context.Context,
+ startHeight, endHeight int32) ([]scanResult, error) {
+
+ // Batch 1: Fetch Block Hashes.
+ hashes, err := s.cfg.Chain.GetBlockHashes(
+ int64(startHeight), int64(endHeight),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch get block hashes: %w", err)
+ }
+
+ // Batch 2: Fetch Block Headers (for timestamps).
+ headers, err := s.cfg.Chain.GetBlockHeaders(hashes)
+ if err != nil {
+ return nil, fmt.Errorf("batch get block headers: %w", err)
+ }
+
+ results := make([]scanResult, 0, len(hashes))
+ for i := range hashes {
+ hash := hashes[i]
+ header := headers[i]
+
+ //nolint:gosec // i is bounded by batch size (2000), so
+ // addition to startHeight won't overflow int32.
+ height := startHeight + int32(i)
+
+ meta := &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Hash: hash, Height: height},
+ Time: header.Timestamp,
+ }
+
+ results = append(results, scanResult{
+ meta: meta,
+ // We provide an empty BlockProcessResult to avoid nil
+ // pointer dereferences when accessing embedded fields
+ // (like RelevantTxs) in commitSyncBatch. This
+ // effectively acts as a "no-op" result.
+ BlockProcessResult: &BlockProcessResult{},
+ })
+ }
+
+ return results, nil
+}
+
+// loadFullScanState initializes a fresh recovery state for a new batch scan.
+// It loads active data, syncs horizons from DB, and prepares the initial
+// lookahead window.
+func (s *syncer) loadFullScanState(
+ ctx context.Context) (*RecoveryState, error) {
+
+ horizonData, initialAddrs, initialUnspent, err := s.loadWalletScanData(
+ ctx,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Initialize a fresh recovery state for this batch to ensure no stale
+ // state leaks between batches.
+ scanState := NewRecoveryState(
+ s.cfg.RecoveryWindow, s.cfg.ChainParams, s.addrStore,
+ )
+
+ // Initialize Batch State (History + Lookahead)
+ err = scanState.Initialize(horizonData, initialAddrs, initialUnspent)
+ if err != nil {
+ return nil, fmt.Errorf("init scan state: %w", err)
+ }
+
+ return scanState, nil
+}
+
+// scanBatchWithFullBlocks implements the fallback scanning by downloading and
+// checking every block in the batch.
+func (s *syncer) scanBatchWithFullBlocks(_ context.Context,
+ scanState *RecoveryState, startHeight int32,
+ hashes []chainhash.Hash) ([]scanResult, error) {
+
+ results := make([]scanResult, 0, len(hashes))
+
+ // 1. Fetch ALL Blocks.
+ blocks, err := s.cfg.Chain.GetBlocks(hashes)
+ if err != nil {
+ return nil, fmt.Errorf("batch get blocks (fallback): %w", err)
+ }
+
+ // Iterate and Process Blocks. Now that all blocks in the batch have
+ // been fetched, process each block individually. This involves
+ // creating the necessary block metadata and then feeding the full
+ // block into the recovery state for filtering and horizon expansion.
+ for i := range hashes {
+ hash := hashes[i]
+ block := blocks[i]
+
+ //nolint:gosec // i is bounded by batch size (2000), so
+ // addition to startHeight won't overflow int32.
+ height := startHeight + int32(i)
+
+ meta := &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Hash: hash, Height: height},
+ }
+
+ // Process the block using the recovery state. This involves:
+ // 1. Filtering the block for relevant transactions.
+ // 2. Expanding the address lookahead horizons if new addresses
+ // are found.
+ // 3. Re-filtering if horizons were expanded to ensure we catch
+ // all transactions relevant to the newly derived addresses.
+ res, err := scanState.ProcessBlock(block)
+ if err != nil {
+ return nil, fmt.Errorf("process block %d (%s): %w",
+ height, hash, err)
+ }
+
+ results = append(results, scanResult{
+ meta: meta,
+ BlockProcessResult: res,
+ })
+ }
+
+ return results, nil
+}
+
+// initResultsForCFilterScan fetches block headers for the given hashes and
+// initializes a slice of scanResult with basic metadata (hash, height, time).
+// This is a preparatory step specifically for CFilter-based scans.
+func (s *syncer) initResultsForCFilterScan(_ context.Context,
+ startHeight int32, hashes []chainhash.Hash) ([]scanResult, error) {
+
+ headers, err := s.cfg.Chain.GetBlockHeaders(hashes)
+ if err != nil {
+ return nil, fmt.Errorf("batch get block headers: %w", err)
+ }
+
+ results := make([]scanResult, len(hashes))
+ for i := range hashes {
+ results[i] = scanResult{
+ meta: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Hash: hashes[i],
+
+ //nolint:gosec // i is bounded by batch
+ // size (2000), so addition to
+ // startHeight won't overflow int32.
+ Height: startHeight + int32(i),
+ },
+ Time: headers[i].Timestamp,
+ },
+ // Initialize with empty result to avoid nil
+ // dereference if block is not processed.
+ BlockProcessResult: &BlockProcessResult{},
+ }
+ }
+
+ return results, nil
+}
+
+// filterBatch iterates over the scan results and matches them against the
+// provided filters using the watchlist. It returns a list of block hashes that
+// matched the filter.
+func (s *syncer) filterBatch(ctx context.Context, results []scanResult,
+ filters []*gcs.Filter,
+ blockMap map[chainhash.Hash]*wire.MsgBlock,
+ watchList [][]byte) ([]chainhash.Hash, error) {
+
+ var matchedHashes []chainhash.Hash
+ for i := range results {
+ // Check context cancellation.
+ select {
+ case <-ctx.Done():
+ return nil, fmt.Errorf("context done: %w", ctx.Err())
+ default:
+ }
+
+ // Skip if we already fetched this block.
+ if _, ok := blockMap[results[i].meta.Hash]; ok {
+ continue
+ }
+
+ filter := filters[i]
+
+ // If the filter is nil or has no elements (N=0), it indicates
+ // a potential issue with the chain backend (e.g., filter not
+ // available, corrupted, or for an invalid block). While N=0 is
+ // theoretically impossible for valid Bitcoin blocks with
+ // regular filters (due to coinbase transactions), we
+ // conservatively treat both cases as a match to ensure no
+ // relevant transactions are missed. This prioritizes safety
+ // over strict filter efficiency, forcing the download of the
+ // full block for later processing.
+ if filter == nil || filter.N() == 0 {
+ var n uint32
+ if filter != nil {
+ n = filter.N()
+ }
+
+ log.Errorf("Filter missing or empty for block %v "+
+ "(nil=%v, N=%d), forcing download",
+ results[i].meta.Hash, filter == nil, n)
+
+ matchedHashes = append(
+ matchedHashes, results[i].meta.Hash,
+ )
+
+ continue
+ }
+
+ key := builder.DeriveKey(&results[i].meta.Hash)
+
+ matched, err := filter.MatchAny(key, watchList)
+ if err != nil {
+ return nil, fmt.Errorf("filter match failed: %w", err)
+ }
+
+ if matched {
+ matchedHashes = append(
+ matchedHashes, results[i].meta.Hash,
+ )
+ }
+ }
+
+ return matchedHashes, nil
+}
+
+// matchAndFetchBatch performs the core logic of matching CFilters against the
+// wallet's watchlist and fetching the corresponding blocks. It iterates over
+// the provided `results`, checking filters for each. Blocks that match (and
+// haven't been fetched yet) are downloaded and added to the `blockMap`.
+//
+// NOTE: This method mutates the provided `blockMap` parameter by adding new
+// blocks to it.
+func (s *syncer) matchAndFetchBatch(ctx context.Context, state *RecoveryState,
+ results []scanResult,
+ filters []*gcs.Filter,
+ blockMap map[chainhash.Hash]*wire.MsgBlock) error {
+
+ // Generate the watchlist for CFilter matching.
+ watchList, err := state.BuildCFilterData()
+ if err != nil {
+ return fmt.Errorf("build cfilter data: %w", err)
+ }
+
+ matchedHashes, err := s.filterBatch(
+ ctx, results, filters, blockMap, watchList,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Fetch Matched Blocks.
+ if len(matchedHashes) > 0 {
+ blocks, err := s.cfg.Chain.GetBlocks(matchedHashes)
+ if err != nil {
+ return fmt.Errorf("batch get blocks: %w", err)
+ }
+
+ for i, block := range blocks {
+ blockMap[matchedHashes[i]] = block
+ }
+ }
+
+ return nil
+}
+
+// scanBatchWithCFilters implements the fast-path scanning using Compact
+// Filters. It fetches filters, matches them locally, fetches only matched
+// blocks, and handles horizon expansion with an in-place resume logic.
+func (s *syncer) scanBatchWithCFilters(ctx context.Context,
+ scanState *RecoveryState, startHeight int32,
+ hashes []chainhash.Hash) ([]scanResult, error) {
+
+ // Fetch CFilters for the batch.
+ filters, err := s.cfg.Chain.GetCFilters(
+ hashes, wire.GCSFilterRegular,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrCFiltersUnavailable, err)
+ }
+
+ // Fetch headers and initialize results with metadata.
+ results, err := s.initResultsForCFilterScan(ctx, startHeight, hashes)
+ if err != nil {
+ return nil, err
+ }
+
+ // blockMap serves as a cache for full block data that has been
+ // fetched. It is populated by `matchAndFetchBatch` during both the
+ // initial matching phase and any subsequent re-matching due to horizon
+ // expansion. This map ensures that once a block is identified as
+ // relevant and downloaded, it's available for processing without
+ // redundant network requests, maintaining I/O efficiency across the
+ // processing loops.
+ blockMap := make(map[chainhash.Hash]*wire.MsgBlock, len(hashes))
+
+ // Initial Match: Optimistically match the entire batch of filters
+ // against the current watchlist. This allows us to fetch all likely
+ // relevant blocks in a single batch operation, maximizing I/O
+ // parallelism.
+ err = s.matchAndFetchBatch(ctx, scanState, results, filters, blockMap)
+ if err != nil {
+ return nil, err
+ }
+
+ // Process Blocks: Iterate through the results and process any blocks
+ // that were matched and fetched.
+ for i := range results {
+ res := &results[i]
+ block := blockMap[res.meta.Hash]
+
+ // If block was not matched/fetched, skip processing.
+ if block == nil {
+ continue
+ }
+
+ processRes, err := scanState.ProcessBlock(block)
+ if err != nil {
+ return nil, fmt.Errorf("process block %d (%s): %w",
+ res.meta.Height, res.meta.Hash, err)
+ }
+
+ // Attach the real result to the pre-allocated scanResult.
+ res.BlockProcessResult = processRes
+
+ // Move to the next if the horizon is not expanded.
+ if !processRes.Expanded {
+ continue
+ }
+
+ log.Debugf("Horizon expanded at height %d, updating filters",
+ res.meta.Height)
+
+ // If the horizon expanded, our watchlist has changed. We must
+ // re-evaluate the remaining filters in the batch (i+1 onwards)
+ // against the new addresses or outpoints to ensure we don't
+ // miss any relevant transactions that were previously skipped.
+ // This "in-place resume" logic ensures correctness despite the
+ // batch pre-fetching optimization.
+ err = s.matchAndFetchBatch(
+ ctx, scanState, results[i+1:], filters[i+1:], blockMap,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return results, nil
+}
+
+// fetchAndFilterBlocks retrieves and processes a batch of blocks from the
+// chain backend. It handles CFilter matching, block fetching, local filtering,
+// and dynamic address discovery (expanding horizons in memory/read-only DB).
+func (s *syncer) fetchAndFilterBlocks(ctx context.Context,
+ scanState *RecoveryState, startHeight, chainTip int32) (
+ []scanResult, error) {
+
+ // Cap the batch size to recoveryBatchSize to manage memory usage.
+ endHeight := min(startHeight+int32(recoveryBatchSize)-1, chainTip)
+
+ // Optimization: If we have nothing to watch, performing a
+ // "header-only" scan to advance the wallet's sync state without
+ // downloading full blocks or filters.
+ //
+ // NOTE: For targeted rescans, the state will never be empty as it is
+ // initialized with specific targets.
+ if scanState.Empty() {
+ log.Debugf("Performing header-only scan for %d blocks",
+ endHeight-startHeight+1)
+
+ return s.scanBatchHeadersOnly(ctx, startHeight, endHeight)
+ }
+
+ log.Debugf("Scanning %d blocks (height %d to %d) with %s",
+ endHeight-startHeight+1, startHeight, endHeight, scanState)
+
+ // Batch 1: Fetch all Block Hashes.
+ // TODO: Pass ctx when chainClient supports it.
+ hashes, err := s.cfg.Chain.GetBlockHashes(
+ int64(startHeight), int64(endHeight),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("batch get block hashes: %w", err)
+ }
+
+ return s.dispatchScanStrategy(ctx, scanState, startHeight, hashes)
+}
+
+// defaultMaxCFilterItems is the heuristic threshold for the number of items
+// (addresses + outpoints) in the watchlist at which the cost of client-side
+// GCS filter matching exceeds the cost of downloading and parsing full blocks.
+//
+// Calculation:
+// - CFilter Match: ~50ns per item (SIP hash). 100k items = 5ms per block.
+// - Full Block: ~10ms transfer (local) + ~10ms parse. Total ~20ms per block.
+//
+// While 100k items suggests ~5ms matching time, this is for a single filter.
+// In a batch of 200 blocks, total matching time is 1 second. However, if the
+// match rate is non-zero, we incur additional block download costs.
+//
+// At >100k items, the CPU load of matching becomes significant enough that
+// bypassing filters and streaming full blocks (especially from a local node)
+// is often more performant and uses less CPU time overall. This threshold is
+// conservative to favor CFilters for typical wallet sizes (<10k items).
+const defaultMaxCFilterItems = 100000
+
+// dispatchScanStrategy chooses and executes the appropriate scanning strategy
+// based on the wallet's configuration and heuristics.
+func (s *syncer) dispatchScanStrategy(ctx context.Context,
+ scanState *RecoveryState, startHeight int32,
+ hashes []chainhash.Hash) ([]scanResult, error) {
+
+ switch s.cfg.SyncMethod {
+ case SyncMethodFullBlocks:
+ return s.scanBatchWithFullBlocks(
+ ctx, scanState, startHeight, hashes,
+ )
+
+ // Attempt to use CFilters. If this fails (e.g. not supported by
+ // backend), we return the error directly as the user explicitly
+ // requested this method.
+ case SyncMethodCFilters:
+ return s.scanBatchWithCFilters(
+ ctx, scanState, startHeight, hashes,
+ )
+
+ case SyncMethodAuto:
+ // Check address/UTXO count heuristic. If we have > 100k items
+ // to watch, full block scanning is likely faster due to
+ // client-side filter matching CPU bottleneck.
+ threshold := s.cfg.MaxCFilterItems
+ if threshold == 0 {
+ threshold = defaultMaxCFilterItems
+ }
+
+ if scanState.WatchListSize() > threshold {
+ log.Infof("Auto sync: Watchlist size %d > %d, "+
+ "switching to full blocks for performance",
+ scanState.WatchListSize(), threshold)
+
+ return s.scanBatchWithFullBlocks(
+ ctx, scanState, startHeight, hashes,
+ )
+ }
+
+ // Try CFilters (Fast Path).
+ results, err := s.scanBatchWithCFilters(
+ ctx, scanState, startHeight, hashes,
+ )
+ if err == nil {
+ return results, nil
+ }
+
+ // If CFilters are unavailable (e.g. backend doesn't support
+ // them), fall back to full block scanning.
+ if errors.Is(err, ErrCFiltersUnavailable) {
+ log.Warnf("Batch GetCFilters unavailable: %v. "+
+ "Falling back to full block download.", err)
+
+ return s.scanBatchWithFullBlocks(
+ ctx, scanState, startHeight, hashes,
+ )
+ }
+
+ // If scanBatchWithCFilters failed for another reason, return
+ // the error.
+ return nil, err
+
+ default:
+ return nil, fmt.Errorf("%w: %v", ErrUnknownSyncMethod,
+ s.cfg.SyncMethod)
+ }
+}
+
+// advanceChainSync checks if the wallet is behind the chain tip and processes
+// a batch of blocks if necessary. It returns (syncFinished, error) where
+// syncFinished is true if the wallet is caught up to the best known tip, and
+// false if a sync operation was performed (or attempted) indicating that the
+// caller should continue polling.
+func (s *syncer) advanceChainSync(ctx context.Context) (bool, error) {
+ // Check the chain tip.
+ _, bestHeight, err := s.cfg.Chain.GetBestBlock()
+ if err != nil {
+ // An error getting best block height means we couldn't
+ // determine sync status. We are NOT finished, and an error
+ // occurred. Caller should retry.
+ return false, fmt.Errorf("unable to get best block height: %w",
+ err)
+ }
+
+ // Determine our current sync state.
+ syncedTo := s.addrStore.SyncedTo()
+
+ // If the wallet is caught up to the best known tip, log this and
+ // return.
+ if syncedTo.Height >= bestHeight {
+ s.state.Store(uint32(syncStateSynced))
+ log.Infof("Wallet is synced to chain tip: height=%d",
+ syncedTo.Height)
+
+ return true, nil
+ }
+
+ // Calculate the gap.
+ gap := bestHeight - syncedTo.Height
+
+ // If the gap is large (> 6 blocks), we treat it as a major event
+ // requiring Syncing state protection. Smaller gaps are handled
+ // silently to avoid disrupting user operations like CreateTx.
+ isLargeGap := gap > syncStateSwitchThreshold
+
+ if isLargeGap {
+ s.state.Store(uint32(syncStateSyncing))
+ }
+
+ // Wallet is behind, log the sync range and attempt to scan a batch.
+ log.Infof("Wallet is in syncing mode: from height %d to %d (gap=%d)",
+ syncedTo.Height+1, bestHeight, gap)
+
+ err = s.scanBatch(ctx, syncedTo, bestHeight)
+ if err != nil {
+ // Scan failed. Sync operation was attempted but not finished
+ // due to error.
+ return false, fmt.Errorf("failed to process batch: %w", err)
+ }
+
+ // Scan successful, but wallet might still be behind. Synchronization
+ // is NOT finished. Caller should continue looping to process the next
+ // batch.
+ return false, nil
+}
+
+// scanBatch fetches and processes a batch of blocks from the chain backend. It
+// handles fetching, CFilter matching, and DB updates.
+func (s *syncer) scanBatch(ctx context.Context, syncedTo waddrmgr.BlockStamp,
+ bestHeight int32) error {
+
+ // Prepare the full recovery state for syncing.
+ scanState, err := s.loadFullScanState(ctx)
+ if err != nil {
+ return err
+ }
+
+ // Fetch and Filter Blocks. The `fetchAndFilterBlocks` method is
+ // responsible for fetching a batch of blocks from the chain backend,
+ // filtering them for relevant transactions, and expanding address
+ // horizons. This phase primarily involves network I/O and in-memory
+ // processing. While it internally performs brief read-only database
+ // accesses (e.g., in `loadFullScanState`), it avoids holding
+ // long-lived write locks during potentially extensive network
+ // operations.
+ results, err := s.fetchAndFilterBlocks(
+ ctx, scanState, syncedTo.Height+1, bestHeight,
+ )
+ if err != nil {
+ return err
+ }
+ // Batch might be empty if:
+ // 1. We were interrupted by a quit signal or rescan job (handled
+ // above).
+ // 2. We encountered a backend error fetching the first block
+ // hash or filter (loop broke early).
+ // In either case, we return an error to let the chain loop sleep and
+ // retry.
+ if len(results) == 0 {
+ return fmt.Errorf("%w: scan batch empty", ErrScanBatchEmpty)
+ }
+ // Process Batch (Update). We do this in a single DB transaction.
+ return s.DBPutSyncBatch(ctx, results)
+}
+
+// handleChainUpdate processes a notification immediately.
+// It returns an error if processing fails or if the wallet is shutting down.
+func (s *syncer) handleChainUpdate(ctx context.Context, n any) error {
+ // For a single update, we process it and commit immediately.
+ err := s.processChainUpdate(ctx, n)
+ if err != nil {
+ return fmt.Errorf("failed to process chain update: %w", err)
+ }
+
+ switch msg := n.(type) {
+ case *chain.RescanProgress:
+ log.Debugf("Rescanned through block %v (height %d)",
+ msg.Hash, msg.Height)
+
+ // Consume and log the legacy RescanFinished notification. We no longer
+ // perform state updates here as the new controller- driven sync loop
+ // manages wallet synchronization.
+ case *chain.RescanFinished:
+ log.Debugf("Received legacy RescanFinished notification for "+
+ "block %v (height %d). No wallet state updates "+
+ "performed.", msg.Hash, msg.Height)
+ }
+
+ return nil
+}
+
+// processChainUpdate writes a single chain update to the database.
+func (s *syncer) processChainUpdate(ctx context.Context, update any) error {
+ switch n := update.(type) {
+ case chain.BlockConnected:
+ return s.DBPutSyncTip(ctx, wtxmgr.BlockMeta(n))
+
+ // A block was disconnected. We use checkRollback to safely verify our
+ // chain state against the backend and rewind if necessary. This
+ // handles both single block disconnects and deeper reorgs robustly.
+ case chain.BlockDisconnected:
+ return s.checkRollback(ctx)
+
+ // We only expect individual transaction notifications for unconfirmed
+ // transactions as they enter the mempool. Confirmed transactions are
+ // handled atomically via FilteredBlockConnected.
+ case chain.RelevantTx:
+ matches := s.prepareTxMatches([]*wtxmgr.TxRecord{n.TxRecord})
+ return s.DBPutTxns(ctx, matches, n.Block)
+
+ case chain.FilteredBlockConnected:
+ matches := s.prepareTxMatches(n.RelevantTxs)
+ return s.DBPutBlocks(ctx, matches, n.Block)
+ }
+
+ return nil
+}
+
+// prepareTxMatches extracts address entries from a batch of transactions and
+// groups them by transaction hash.
+func (s *syncer) prepareTxMatches(recs []*wtxmgr.TxRecord) TxEntries {
+ matches := make(TxEntries, 0, len(recs))
+ for _, rec := range recs {
+ entries := s.extractAddrEntries(rec.MsgTx.TxOut)
+ matches = append(matches, TxEntry{
+ Rec: rec,
+ Entries: entries,
+ })
+ }
+
+ return matches
+}
+
+// extractAddrEntries collects all addresses from transaction outputs and
+// creates initial AddrEntry objects with output indices.
+func (s *syncer) extractAddrEntries(txOuts []*wire.TxOut) []AddrEntry {
+ var entries []AddrEntry
+ for i, output := range txOuts {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, s.cfg.ChainParams,
+ )
+ if err != nil {
+ log.Warnf("Cannot extract non-std pkScript=%x",
+ output.PkScript)
+
+ continue
+ }
+
+ for _, addr := range addrs {
+ entries = append(entries, AddrEntry{
+ Address: addr,
+ Credit: wtxmgr.CreditEntry{
+ //nolint:gosec // bounded.
+ Index: uint32(i),
+ },
+ })
+ }
+ }
+
+ return entries
+}
+
+// handleScanReq processes a user-initiated rescan request.
+func (s *syncer) handleScanReq(ctx context.Context,
+ req *scanReq) error {
+
+ // If the wallet is already syncing or rescanning, we can't accept a
+ // full resync request. This prevents conflicting rescan operations.
+ if s.isRecoveryMode() {
+ return fmt.Errorf("%w: wallet is currently %s",
+ ErrStateForbidden, s.syncState())
+ }
+
+ if req.typ == scanTypeTargeted {
+ return s.scanWithTargets(ctx, req)
+ }
+
+ return s.scanWithRewind(ctx, req)
+}
+
+// waitForEvent blocks until a notification, rescan job, or context
+// cancellation occurs, processing the event accordingly.
+func (s *syncer) waitForEvent(ctx context.Context) error {
+ select {
+ // Process asynchronous notifications from the chain backend, such as
+ // new blocks or transactions.
+ case n, ok := <-s.cfg.Chain.Notifications():
+ if !ok {
+ return ErrWalletShuttingDown
+ }
+
+ return s.handleChainUpdate(ctx, n)
+
+ // Handle synchronous rescan or resync requests submitted via the
+ // controller.
+ case job := <-s.scanReqChan:
+ return s.handleScanReq(ctx, job)
+
+ // Exit gracefully if the context is canceled or the wallet is shutting
+ // down.
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// scanWithRewind rewinds the wallet's sync status to the requested start
+// block.
+func (s *syncer) scanWithRewind(ctx context.Context, req *scanReq) error {
+ current := s.addrStore.SyncedTo()
+
+ if req.startBlock.Height >= current.Height {
+ // Requested start is ahead of or equal to current sync.
+ // Nothing to do (we are already synced past it).
+ return nil
+ }
+
+ log.Infof("Rewinding sync status from %d to %d for rescan",
+ current.Height, req.startBlock.Height)
+
+ // Rewind the database status.
+ err := s.DBPutRewind(ctx, req.startBlock)
+ if err != nil {
+ log.Errorf("Failed to rewind sync status: %v", err)
+
+ return err
+ }
+
+ return nil
+}
+
+// scanWithTargets performs a targeted rescan for specific accounts without
+// rewinding the global sync state.
+func (s *syncer) scanWithTargets(ctx context.Context, req *scanReq) error {
+ scanState, err := s.loadTargetedScanState(ctx, req.targets)
+ if err != nil {
+ return err
+ }
+
+ s.state.Store(uint32(syncStateRescanning))
+ defer s.state.Store(uint32(syncStateSynced))
+
+ startHeight := req.startBlock.Height
+
+ _, bestHeight, err := s.cfg.Chain.GetBestBlock()
+ if err != nil {
+ return fmt.Errorf("get best block: %w", err)
+ }
+
+ log.Infof("Starting targeted rescan from height %d to %d for %d "+
+ "accounts", startHeight, bestHeight, len(req.targets))
+
+ // Loop until caught up. We use an inclusive condition (<=) because
+ // startHeight represents the first block of the missing range and
+ // bestHeight is the last block (the chain tip). If we used a strict
+ // inequality (<), the tip would be skipped when the wallet is only one
+ // block behind.
+ for startHeight <= bestHeight {
+ // Cap end height.
+ endHeight := min(
+ startHeight+int32(recoveryBatchSize)-1, bestHeight,
+ )
+
+ // Use fetchAndFilterBlocks directly.
+ results, err := s.fetchAndFilterBlocks(
+ ctx, scanState, startHeight, endHeight,
+ )
+ if err != nil {
+ return err
+ }
+
+ if len(results) == 0 {
+ return fmt.Errorf("%w: fetchAndFilterBlocks returned "+
+ "0 results", ErrScanBatchEmpty)
+ }
+
+ // Process results (update DB).
+ err = s.DBPutTargetedBatch(ctx, results)
+ if err != nil {
+ return err
+ }
+
+ // Advance startHeight.
+ //nolint:gosec // batch size is bounded.
+ startHeight += int32(len(results))
+ }
+
+ log.Infof("Targeted rescan complete")
+
+ return nil
+}
+
+// loadTargetedScanState initializes a recovery state for a targeted rescan of
+// specific accounts.
+func (s *syncer) loadTargetedScanState(ctx context.Context,
+ targets []waddrmgr.AccountScope) (*RecoveryState, error) {
+
+ horizonData, initialAddrs, initialUnspent, err :=
+ s.loadTargetedScanData(ctx, targets)
+ if err != nil {
+ return nil, err
+ }
+
+ state := NewRecoveryState(
+ s.cfg.RecoveryWindow, s.cfg.ChainParams, s.addrStore,
+ )
+
+ err = state.Initialize(horizonData, initialAddrs, initialUnspent)
+ if err != nil {
+ return nil, fmt.Errorf("init scan state: %w", err)
+ }
+
+ return state, nil
+}
+
+// loadTargetedScanData retrieves all necessary data from the database to
+// initialize the recovery state for a targeted rescan.
+func (s *syncer) loadTargetedScanData(ctx context.Context,
+ targets []waddrmgr.AccountScope) ([]*waddrmgr.AccountProperties,
+ []address.Address, []wtxmgr.Credit, error) {
+
+ return s.DBGetScanData(ctx, targets)
+}
+
+// loadWalletScanData retrieves all necessary data from the database to
+// initialize the recovery state. This includes account horizons, active
+// addresses, and unspent outputs to watch.
+func (s *syncer) loadWalletScanData(ctx context.Context) (
+ []*waddrmgr.AccountProperties, []address.Address,
+ []wtxmgr.Credit, error) {
+
+ var targets []waddrmgr.AccountScope
+ for _, scopedMgr := range s.addrStore.ActiveScopedKeyManagers() {
+ for _, accNum := range scopedMgr.ActiveAccounts() {
+ targets = append(targets, waddrmgr.AccountScope{
+ Scope: scopedMgr.Scope(),
+ Account: accNum,
+ })
+ }
+ }
+
+ return s.DBGetScanData(ctx, targets)
+}
diff --git a/wallet/syncer_test.go b/wallet/syncer_test.go
new file mode 100644
index 0000000000..513bca9067
--- /dev/null
+++ b/wallet/syncer_test.go
@@ -0,0 +1,3528 @@
+package wallet
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs"
+ "github.com/btcsuite/btcd/btcutil/v2/gcs/builder"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestSyncerInitialization verifies that a new syncer is created with the
+// correct default state.
+func TestSyncerInitialization(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize mock dependencies for the syncer.
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ // Act: Create a new syncer instance with a recovery window of 1.
+ s := newSyncer(
+ Config{RecoveryWindow: 1}, mockAddrStore, mockTxStore,
+ mockPublisher,
+ )
+
+ // Assert: Verify that the syncer is correctly initialized in the
+ // backend syncing state and is not in recovery mode.
+ require.NotNil(t, s)
+ require.Equal(t, syncStateBackendSyncing, s.syncState())
+ require.False(t, s.isRecoveryMode())
+}
+
+// TestSyncerRequestScan verifies that scan requests are correctly accepted
+// by the syncer's buffered channel.
+func TestSyncerRequestScan(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and a rewind scan request.
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{}, mockAddrStore, mockTxStore, mockPublisher)
+
+ req := &scanReq{
+ typ: scanTypeRewind,
+ startBlock: waddrmgr.BlockStamp{
+ Height: 100,
+ },
+ }
+
+ // Act: Submit the rewind request to the syncer.
+ err := s.requestScan(t.Context(), req)
+
+ // Assert: Ensure the request is accepted without error and is
+ // correctly placed in the scan request channel.
+ require.NoError(t, err)
+
+ select {
+ case received := <-s.scanReqChan:
+ require.Equal(t, req, received)
+ default:
+ require.Fail(t, "request not received")
+ }
+}
+
+// TestSyncerRequestScanBlocked verifies behavior when the channel is full.
+func TestSyncerRequestScanBlocked(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and fill its scan request buffer.
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{}, mockAddrStore, mockTxStore, mockPublisher)
+
+ // Fill the buffer (size 1).
+ s.scanReqChan <- &scanReq{}
+
+ // Act: Attempt to submit another request with a context that is
+ // already canceled.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ err := s.requestScan(ctx, &scanReq{})
+
+ // Assert: Verify that the request fails as expected due to the
+ // context cancellation.
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+// TestSyncerRun verifies the run implementation.
+func TestSyncerRun(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock its chain and address store.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain}, mockAddrStore, nil, mockPublisher,
+ )
+
+ // context cancellation.
+ mockAddrStore.On("Birthday").Return(time.Now()).Maybe()
+ mockChain.On("IsCurrent").Return(false).Maybe()
+ mockAddrStore.On("SyncedTo").Return(waddrmgr.BlockStamp{}).Maybe()
+ mockChain.On("NotifyBlocks").Return(nil).Maybe()
+
+ // Act: Execute the syncer's run loop with a context that is canceled
+ // immediately to stop the loop.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Assert: The run loop should exit without error.
+ err := s.run(ctx)
+ require.NoError(t, err)
+}
+
+// TestWaitUntilBackendSynced verifies polling logic.
+func TestWaitUntilBackendSynced(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock its chain to simulate a
+ // delayed synchronization.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ // Simulate the backend not being current on the first check, but
+ // becoming current on the second check.
+ mockChain.On("IsCurrent").Return(false).Once()
+ mockChain.On("IsCurrent").Return(true).Once()
+
+ // Act & Assert: Call waitUntilBackendSynced and verify it waits for
+ // the backend to sync before returning successfully.
+ err := s.waitUntilBackendSynced(t.Context())
+ require.NoError(t, err)
+ mockChain.AssertExpectations(t)
+}
+
+// TestCheckRollbackNoReorg verifies checkRollback when tips match.
+func TestCheckRollbackNoReorg(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and mock chain.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockChain := &mockChain{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, nil, nil,
+ )
+
+ tip := waddrmgr.BlockStamp{Height: 100, Hash: chainhash.Hash{0x01}}
+ mockAddrStore.On("SyncedTo").Return(tip)
+
+ // Mock retrieval of synced block hashes from the database for the
+ // last 10 blocks.
+ for i := int32(91); i <= 100; i++ {
+ hash := chainhash.Hash{byte(i)}
+ mockAddrStore.On(
+ "BlockHash", mock.Anything, i,
+ ).Return(&hash, nil)
+ }
+
+ // Mock retrieval of matching block hashes from the remote chain.
+ remoteHashes := make([]chainhash.Hash, 10)
+ for i := range 10 {
+ remoteHashes[i] = chainhash.Hash{byte(91 + i)}
+ }
+
+ mockChain.On(
+ "GetBlockHashes", int64(91), int64(100),
+ ).Return(remoteHashes, nil).Once()
+
+ // Act & Assert: Verify that checkRollback completes without error
+ // and no rollback is triggered when hashes match.
+ err := s.checkRollback(t.Context())
+ require.NoError(t, err)
+}
+
+// TestCheckRollbackDetected verifies checkRollback when reorg is detected.
+func TestCheckRollbackDetected(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and mocks to
+ // simulate a chain reorganization.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockChain := &mockChain{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, mockTxStore,
+ mockPublisher,
+ )
+
+ tip := waddrmgr.BlockStamp{Height: 100, Hash: chainhash.Hash{0x01}}
+ mockAddrStore.On("SyncedTo").Return(tip)
+
+ // Mock retrieval of synced block hashes from the database for blocks
+ // 91 to 100.
+ for i := int32(91); i <= 100; i++ {
+ hash := chainhash.Hash{byte(i)}
+ mockAddrStore.On(
+ "BlockHash", mock.Anything, i,
+ ).Return(&hash, nil)
+ }
+
+ // Mock retrieval of remote block hashes where a fork occurs at
+ // height 95.
+ remoteHashes := make([]chainhash.Hash, 10)
+ for i := range 10 {
+ h := 91 + i
+ if h > 95 {
+ remoteHashes[i] = chainhash.Hash{0xff} // Mismatch
+ } else {
+ remoteHashes[i] = chainhash.Hash{byte(h)} // Match
+ }
+ }
+
+ mockChain.On(
+ "GetBlockHashes", int64(91), int64(100),
+ ).Return(remoteHashes, nil).Once()
+
+ // Mock header retrieval for the detected fork point at height 95.
+ forkHash := chainhash.Hash{byte(95)}
+ header := &wire.BlockHeader{Timestamp: time.Now()}
+ mockChain.On("GetBlockHeader", &forkHash).Return(header, nil).Once()
+
+ // Expect a rollback to the common ancestor at height 95 and a
+ // corresponding transaction store rollback.
+ mockAddrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+ mockTxStore.On("Rollback", mock.Anything, int32(96)).Return(nil).Once()
+
+ // Act & Assert: Verify that checkRollback correctly identifies the
+ // fork and performs the rollback.
+ err := s.checkRollback(t.Context())
+ require.NoError(t, err)
+}
+
+// TestInitChainSync verifies the initial synchronization sequence.
+func TestInitChainSync(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock its dependencies for the
+ // initial synchronization sequence.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain}, mockAddrStore, nil, mockPublisher,
+ )
+
+ // Mock backend synchronization check.
+ mockChain.On("IsCurrent").Return(true).Once()
+
+ // Mock block notification registration.
+ mockChain.On("NotifyBlocks").Return(nil).Once()
+
+ // Mock rollback check at the start of synchronization.
+ tip := waddrmgr.BlockStamp{Height: 0}
+ mockAddrStore.On("SyncedTo").Return(tip)
+
+ // Act & Assert: Verify that the initial chain synchronization
+ // sequence completes successfully.
+ err := s.initChainSync(t.Context())
+ require.NoError(t, err)
+}
+
+// TestScanBatchHeadersOnly verifies header-only scan logic.
+func TestScanBatchHeadersOnly(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock block and header retrieval.
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, mockPublisher)
+
+ hashes := []chainhash.Hash{{0x01}, {0x02}}
+ mockChain.On(
+ "GetBlockHashes", int64(10), int64(11),
+ ).Return(hashes, nil).Once()
+
+ headers := []*wire.BlockHeader{
+ {Timestamp: time.Unix(100, 0)},
+ {Timestamp: time.Unix(200, 0)},
+ }
+ mockChain.On("GetBlockHeaders", hashes).Return(headers, nil).Once()
+
+ // Act: Perform a header-only scan for blocks 10 and 11.
+ results, err := s.scanBatchHeadersOnly(t.Context(), 10, 11)
+
+ // Assert: Verify that the correct block results are returned with
+ // expected heights.
+ require.NoError(t, err)
+ require.Len(t, results, 2)
+ require.Equal(t, int32(10), results[0].meta.Height)
+ require.Equal(t, int32(11), results[1].meta.Height)
+}
+
+// TestSyncerLoadScanState verifies full scan state loading.
+func TestSyncerLoadScanState(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and set up complex
+ // mock expectations for loading wallet scan data.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{
+ DB: db,
+ RecoveryWindow: 10,
+ ChainParams: &chainParams,
+ },
+ mockAddrStore, mockTxStore, mockPublisher,
+ )
+
+ // Mock active scoped key managers.
+ scopedMgr := &mockAccountStore{}
+ mockAddrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore{scopedMgr}).Once()
+
+ // Mock active accounts for the key manager scope.
+ scopedMgr.On("ActiveAccounts").Return([]uint32{0}).Once()
+ scopedMgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+
+ // Mock database operations to fetch scan data, including key managers,
+ // account properties, active addresses, and outputs to watch.
+ mockAddrStore.On(
+ "FetchScopedKeyManager", mock.Anything,
+ ).Return(scopedMgr, nil).Times(3)
+
+ props := &waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ }
+ scopedMgr.On(
+ "AccountProperties", mock.Anything, uint32(0),
+ ).Return(props, nil).Twice()
+
+ mockAddrStore.On(
+ "ForEachRelevantActiveAddress", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ mockTxStore.On(
+ "OutputsToWatch", mock.Anything,
+ ).Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // Mock address derivation for the lookahead window (10 addresses for
+ // each branch).
+ mockAddr := &mockAddress{}
+ mockAddr.On("EncodeAddress").Return("addr")
+ mockAddr.On("ScriptAddress").Return([]byte{0x00})
+ scopedMgr.On(
+ "DeriveAddr", mock.Anything, mock.Anything, mock.Anything,
+ ).Return(
+ mockAddr, []byte{0x00}, nil,
+ ).Maybe()
+
+ // Act: Load the full scan state from the database.
+ state, err := s.loadFullScanState(t.Context())
+
+ // Assert: Verify that the scan state is correctly loaded and not nil.
+ require.NoError(t, err)
+ require.NotNil(t, state)
+}
+
+// TestScanBatchWithFullBlocks verifies fallback scan logic.
+func TestScanBatchWithFullBlocks(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and a recovery state for scanning.
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, mockPublisher)
+
+ mockAddrStore := &mockAddrStore{}
+ scanState := NewRecoveryState(
+ 10, &chainParams, mockAddrStore,
+ )
+
+ hashes := []chainhash.Hash{{0x01}}
+
+ // Create a mock block message for testing.
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ blocks := []*wire.MsgBlock{msgBlock}
+ mockChain.On(
+ "GetBlocks", hashes,
+ ).Return(blocks, nil).Once()
+
+ // Act: Perform a batch scan using full blocks.
+ results, err := s.scanBatchWithFullBlocks(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify that the scan returned the expected block result.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+ require.Equal(t, int32(10), results[0].meta.Height)
+}
+
+// TestScanBatchWithCFilters verifies CFilter-based scan logic.
+func TestScanBatchWithCFilters(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and set up a recovery state.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, nil, nil, mockPublisher,
+ )
+
+ mockAddrStore := &mockAddrStore{}
+ scanState := NewRecoveryState(
+ 10, &chainParams, mockAddrStore,
+ )
+
+ hashes := []chainhash.Hash{{0x01}}
+
+ // Mock retrieval of compact filters for the block batch.
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, nil,
+ )
+ require.NoError(t, err)
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return([]*gcs.Filter{filter}, nil).Once()
+
+ // Mock retrieval of block headers for the batch.
+ headers := []*wire.BlockHeader{{Timestamp: time.Unix(100, 0)}}
+ mockChain.On("GetBlockHeaders", hashes).Return(headers, nil).Once()
+
+ // Mock retrieval of full blocks for the batch (simulating a filter
+ // match).
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ mockChain.On("GetBlocks", hashes).Return(
+ []*wire.MsgBlock{msgBlock}, nil,
+ ).Once()
+
+ // Mock address store failures to simplify the test path and avoid
+ // deep derivation logic.
+ mockAddrStore.On(
+ "Address", mock.Anything, mock.Anything,
+ ).Return(nil, waddrmgr.ErrAddressNotFound).Maybe()
+ mockAddrStore.On(
+ "FetchScopedKeyManager", mock.Anything,
+ ).Return(nil, waddrmgr.ErrAddressNotFound).Maybe()
+
+ // Act: Perform a batch scan using CFilters.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify that the scan results are correct.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+ require.Equal(t, int32(10), results[0].meta.Height)
+}
+
+// TestDispatchScanStrategy verifies strategy selection.
+func TestDispatchScanStrategy(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock dependencies.
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, mockPublisher)
+
+ scanState := NewRecoveryState(10, &chainParams, nil)
+ hashes := []chainhash.Hash{{0x01}}
+
+ // 1. Test the SyncMethodFullBlocks strategy.
+ s.cfg.SyncMethod = SyncMethodFullBlocks
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ mockChain.On(
+ "GetBlocks", hashes,
+ ).Return([]*wire.MsgBlock{msgBlock}, nil).Once()
+
+ // Act: Dispatch the scan strategy for full blocks.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify that full blocks strategy was used.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+
+ // 2. Test the SyncMethodCFilters strategy.
+ s.cfg.SyncMethod = SyncMethodCFilters
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, nil,
+ )
+ require.NoError(t, err)
+
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return([]*gcs.Filter{filter}, nil).Once()
+ mockChain.On(
+ "GetBlockHeaders", hashes,
+ ).Return([]*wire.BlockHeader{{}}, nil).Once()
+
+ // Simulate a filter match (N=0) to force a full block download.
+ mockChain.On(
+ "GetBlocks", hashes,
+ ).Return([]*wire.MsgBlock{msgBlock}, nil).Once()
+
+ // Act: Dispatch the scan strategy for CFilters.
+ results, err = s.dispatchScanStrategy(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify that CFilters strategy was used.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestScanBatch verifies the batch scanning entry point.
+func TestScanBatch(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and set up mocks
+ // for the batch scan.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, nil,
+ mockPublisher,
+ )
+
+ // Mock loading of the full scan state required by the batch scan.
+ scopedMgr := &mockAccountStore{}
+ scopedMgr.On("ActiveAccounts").Return([]uint32{0}).Once()
+ scopedMgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ scopedMgr.On(
+ "AccountProperties", mock.Anything, uint32(0),
+ ).Return(&waddrmgr.AccountProperties{}, nil).Twice()
+ mockAddrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore{scopedMgr}).Once()
+ mockAddrStore.On(
+ "FetchScopedKeyManager", mock.Anything,
+ ).Return(scopedMgr, nil).Times(3)
+ mockAddrStore.On(
+ "ForEachRelevantActiveAddress", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ mockTxStore := &mockTxStore{}
+ s.txStore = mockTxStore
+ mockTxStore.On(
+ "OutputsToWatch", mock.Anything,
+ ).Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // Mock expectations for header-only scanning when no targets are
+ // present.
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On(
+ "GetBlockHashes", int64(11), int64(11),
+ ).Return(hashes, nil).Once()
+ mockChain.On(
+ "GetBlockHeaders", hashes,
+ ).Return([]*wire.BlockHeader{{}}, nil).Once()
+
+ // Expect the sync progress to be updated in the database.
+ mockAddrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ // Act: Perform a batch scan from height 10 to 11.
+ err := s.scanBatch(t.Context(), waddrmgr.BlockStamp{Height: 10}, 11)
+
+ // Assert: Verify that the batch scan completed successfully.
+ require.NoError(t, err)
+}
+
+// TestFetchAndFilterBlocks verifies the block fetching and filtering helper.
+func TestFetchAndFilterBlocks(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock chain for block fetching.
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, mockPublisher)
+
+ // Create an empty recovery state for testing.
+ scanState := NewRecoveryState(10, &chainParams, nil)
+ hashes := []chainhash.Hash{{0x01}}
+
+ // Mock expectations for header-only scanning when the recovery state
+ // is empty.
+ mockChain.On(
+ "GetBlockHashes", int64(10), int64(11),
+ ).Return(hashes, nil).Once()
+ mockChain.On(
+ "GetBlockHeaders", hashes,
+ ).Return([]*wire.BlockHeader{{}}, nil).Once()
+
+ // Act: Fetch and filter blocks for heights 10 to 11.
+ results, err := s.fetchAndFilterBlocks(t.Context(), scanState, 10, 11)
+
+ // Assert: Verify that the block results are correct.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestAdvanceChainSync verifies advancement logic.
+func TestAdvanceChainSync(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and mocks to
+ // test the chain synchronization advancement logic.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, mockTxStore,
+ mockPublisher,
+ )
+
+ // Case 1: Test advancement when the wallet is already synced to the
+ // best block.
+ mockChain.On(
+ "GetBestBlock",
+ ).Return(&chainhash.Hash{}, int32(100), nil).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+
+ // Act & Assert: Advance the chain sync and verify that it correctly
+ // identifies the synced state.
+ finished, err := s.advanceChainSync(t.Context())
+ require.NoError(t, err)
+ require.True(t, finished)
+ require.Equal(t, syncStateSynced, s.syncState())
+
+ // Case 2: Test advancement when the wallet is behind and needs to
+ // trigger a scan.
+ mockChain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(105), nil,
+ ).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+
+ // Set up mocks for the batch scan triggered by advancement.
+ // Mock loading of the full scan state.
+ scopedMgr := &mockAccountStore{}
+ mockAddrStore.On(
+ "ActiveScopedKeyManagers",
+ ).Return([]waddrmgr.AccountStore{scopedMgr}).Once()
+ scopedMgr.On("ActiveAccounts").Return([]uint32{0}).Once()
+ scopedMgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ mockAddrStore.On(
+ "FetchScopedKeyManager", mock.Anything,
+ ).Return(scopedMgr, nil).Times(3)
+
+ props := &waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ }
+ scopedMgr.On(
+ "AccountProperties", mock.Anything, uint32(0),
+ ).Return(props, nil).Twice()
+ mockAddrStore.On(
+ "ForEachRelevantActiveAddress", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ mockTxStore.On(
+ "OutputsToWatch", mock.Anything,
+ ).Return([]wtxmgr.Credit(nil), nil).Once()
+
+ scopedMgr.On(
+ "DeriveAddr", mock.Anything, mock.Anything, mock.Anything,
+ ).Return(
+ &mockAddress{}, []byte{}, nil,
+ ).Maybe()
+
+ // Mock fetching and filtering of blocks for the missing height range.
+ // Mock retrieval of block hashes when scan targets are present.
+ hashes := []chainhash.Hash{{0x01}, {0x02}, {0x03}, {0x04}, {0x05}}
+ mockChain.On(
+ "GetBlockHashes", int64(101), int64(105),
+ ).Return(hashes, nil).Once()
+
+ // Mock the scan strategy dispatch for the block batch.
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, nil,
+ )
+ require.NoError(t, err)
+
+ filters := make([]*gcs.Filter, 5)
+ for i := range 5 {
+ filters[i] = filter
+ }
+
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return(filters, nil).Once()
+
+ headers := make([]*wire.BlockHeader, 5)
+ for i := range 5 {
+ headers[i] = &wire.BlockHeader{}
+ }
+
+ mockChain.On("GetBlockHeaders", hashes).Return(headers, nil).Once()
+
+ // Simulate filter matches for all blocks to force full block downloads.
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+
+ blocks := make([]*wire.MsgBlock, 5)
+ for i := range 5 {
+ blocks[i] = msgBlock
+ }
+
+ mockChain.On("GetBlocks", hashes).Return(blocks, nil).Once()
+
+ // Expect the sync progress to be updated for each block in the batch.
+ mockAddrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Times(5)
+
+ // Act & Assert: Advance the chain sync and verify that it triggers
+ // the expected batch scan.
+ finished, err = s.advanceChainSync(t.Context())
+ require.NoError(t, err)
+ require.False(t, finished)
+}
+
+// TestHandleChainUpdate verifies notification handling.
+func TestHandleChainUpdate(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock its dependencies for
+ // handling chain updates.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, mockTxStore,
+ mockPublisher,
+ )
+
+ // Case 1: Test handling of a BlockConnected notification.
+ meta := wtxmgr.BlockMeta{Block: wtxmgr.Block{Height: 100}}
+
+ mockAddrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ // Act & Assert: Verify that a BlockConnected notification is
+ // correctly processed.
+ err := s.handleChainUpdate(t.Context(), chain.BlockConnected(meta))
+ require.NoError(t, err)
+
+ // Case 2: Test handling of a RelevantTx notification.
+ tx := wire.NewMsgTx(1)
+ rec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ require.NoError(t, err)
+ mockTxStore.On(
+ "InsertUnconfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything,
+ ).Return(nil).Once()
+
+ // Act & Assert: Verify that a RelevantTx notification is correctly
+ // processed.
+ err = s.handleChainUpdate(t.Context(), chain.RelevantTx{TxRecord: rec})
+ require.NoError(t, err)
+}
+
+// TestExtractAddrEntries verifies address extraction from outputs.
+func TestExtractAddrEntries(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and create a P2PKH output for address
+ // extraction.
+ mockPublisher := &mockTxPublisher{}
+ s := newSyncer(
+ Config{ChainParams: &chainParams}, nil, nil,
+ mockPublisher,
+ )
+
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ txOut := &wire.TxOut{Value: 1000, PkScript: pkScript}
+
+ // Act: Extract address entries from the output.
+ entries := s.extractAddrEntries([]*wire.TxOut{txOut})
+
+ // Assert: Verify that the correct address was extracted.
+ require.Len(t, entries, 1)
+ require.Equal(t, addr.String(), entries[0].Address.String())
+ require.Equal(t, uint32(0), entries[0].Credit.Index)
+}
+
+// TestHandleScanReq verifies scan request handling.
+func TestHandleScanReq(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and mocks to
+ // test handling of different scan request types.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{DB: db}, mockAddrStore, nil, mockPublisher,
+ )
+
+ // Case 1: Test handling of a rewind scan request.
+ req := &scanReq{
+ typ: scanTypeRewind,
+ startBlock: waddrmgr.BlockStamp{Height: 50},
+ }
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+
+ // Expect sync state update and transaction rollback for the rewind.
+ mockAddrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ mockTxStore := &mockTxStore{}
+ s.txStore = mockTxStore
+ mockTxStore.On("Rollback", mock.Anything, int32(51)).Return(nil).Once()
+
+ // Act & Assert: Verify that a rewind scan request is correctly handled.
+ err := s.handleScanReq(t.Context(), req)
+ require.NoError(t, err)
+
+ // Case 2: Test handling of a targeted scan request.
+ req = &scanReq{
+ typ: scanTypeTargeted,
+ startBlock: waddrmgr.BlockStamp{Height: 100},
+ targets: []waddrmgr.AccountScope{{Account: 1}},
+ }
+ mockChain := &mockChain{}
+ s.cfg.Chain = mockChain
+ mockChain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(101), nil,
+ ).Once()
+
+ // Mock loading of targeted scan data.
+ scopedMgr := &mockAccountStore{}
+ mockAddrStore.On(
+ "FetchScopedKeyManager", mock.Anything,
+ ).Return(scopedMgr, nil).Times(3)
+
+ // Set up mocks for initializing targeted scan state.
+ props := &waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ KeyScope: waddrmgr.KeyScopeBIP0084,
+ }
+ scopedMgr.On(
+ "AccountProperties", mock.Anything, uint32(1),
+ ).Return(props, nil).Twice()
+ // ActiveAccounts might not be called in targeted scan flow.
+ scopedMgr.On("ActiveAccounts").Return([]uint32{1}).Maybe()
+ mockAddrStore.On(
+ "ForEachRelevantActiveAddress", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+ mockTxStore.On(
+ "OutputsToWatch", mock.Anything,
+ ).Return([]wtxmgr.Credit(nil), nil).Once()
+
+ // DeriveAddr is called multiple times during state initialization.
+ // Use Maybe() to avoid assertions on specific iteration counts.
+ scopedMgr.On(
+ "DeriveAddr", mock.Anything, mock.Anything, mock.Anything,
+ ).Return(&mockAddress{}, []byte{}, nil).Maybe()
+
+ // Mock block hash retrieval for the targeted scan range.
+ mockChain.On(
+ "GetBlockHashes", int64(100), int64(101),
+ ).Return([]chainhash.Hash{{0x01}, {0x02}}, nil).Once()
+
+ // Mock CFilter-based scanning for the targeted scan.
+ mockChain.On(
+ "GetCFilters", mock.Anything, mock.Anything,
+ ).Return([]*gcs.Filter{nil, nil}, nil).Once()
+ mockChain.On(
+ "GetBlockHeaders", mock.Anything,
+ ).Return([]*wire.BlockHeader{{}, {}}, nil).Once()
+
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+
+ blocks := make([]*wire.MsgBlock, 2)
+ for i := range 2 {
+ blocks[i] = msgBlock
+ }
+
+ mockChain.On("GetBlocks", mock.Anything).Return(blocks, nil).Once()
+
+ // Act & Assert: Verify that a targeted scan request is correctly
+ // handled.
+ err = s.handleScanReq(t.Context(), req)
+ require.NoError(t, err)
+}
+
+// TestWaitForEvent verifies event loop idling and dispatch.
+func TestWaitForEvent(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock its dependencies for testing
+ // the event loop.
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+ mockAddrStore := &mockAddrStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ DB: db,
+ },
+ mockAddrStore, nil, mockPublisher,
+ )
+
+ // Mock chain notifications channel.
+ notificationChan := make(chan any, 1)
+ mockChain.On("Notifications").Return((<-chan any)(notificationChan))
+
+ // Case 1: Test event handling when a chain notification arrives.
+ notificationChan <- chain.BlockConnected{}
+
+ // Mock sync progress update resulting from the chain notification.
+ mockAddrStore.On(
+ "SetSyncedTo", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ // Act & Assert: Call waitForEvent and verify it correctly processes
+ // the arriving notification.
+ err := s.waitForEvent(t.Context())
+ require.NoError(t, err)
+
+ // Case 2: Test event handling when a scan request arrives.
+ s.scanReqChan <- &scanReq{typ: scanTypeRewind}
+
+ mockAddrStore.On("SyncedTo").Return(waddrmgr.BlockStamp{}).Once()
+
+ // Act & Assert: Call waitForEvent and verify it correctly processes
+ // the arriving scan request.
+ err = s.waitForEvent(t.Context())
+ require.NoError(t, err)
+}
+
+// TestSyncerFullRun verifies the full run loop coordination.
+func TestSyncerFullRun(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer with a test database and set up
+ // extensive mocks to simulate a full run loop execution.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, nil,
+ mockPublisher,
+ )
+
+ // Mock initial chain sync sequence.
+ mockAddrStore.On("Birthday").Return(time.Now()).Once()
+ mockChain.On("IsCurrent").Return(true).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+
+ // Mock rollback check dependencies.
+ mockAddrStore.On(
+ "BlockHash", mock.Anything, mock.Anything,
+ ).Return(&chainhash.Hash{}, nil).Maybe()
+
+ // Mock remote hashes for rollback check (batch size 10).
+ remoteHashes := make([]chainhash.Hash, 10)
+ mockChain.On(
+ "GetBlockHashes", mock.Anything, mock.Anything,
+ ).Return(remoteHashes, nil).Maybe()
+ mockChain.On("NotifyBlocks").Return(nil).Once()
+
+ // Mock advancement to the current best block.
+ mockChain.On(
+ "GetBestBlock",
+ ).Return(&chainhash.Hash{}, int32(100), nil).Once()
+
+ // Mock synced state retrieval.
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+
+ // Mock retrieval of unmined transactions from the store.
+ mockTxStore := &mockTxStore{}
+ s.txStore = mockTxStore
+ mockTxStore.On("UnminedTxs", mock.Anything).Return(
+ []*wire.MsgTx(nil), nil,
+ ).Once()
+
+ // Set up for the event waiting phase of the run loop.
+ ctx, cancel := context.WithCancel(t.Context())
+
+ // Use a goroutine to cancel the context after a delay to allow the
+ // syncer to enter its event loop.
+ go func() {
+ time.Sleep(1500 * time.Millisecond)
+ cancel()
+ }()
+
+ notificationChan := make(chan any)
+ mockChain.On("Notifications").Return((<-chan any)(notificationChan))
+
+ // Act & Assert: Execute the syncer's run loop and verify that it
+ // completes all initial sync steps and enters the idle loop.
+ err := s.run(ctx)
+ require.NoError(t, err)
+}
+
+var (
+ errDBMockSync = errors.New("db error")
+ errCFilter = errors.New("not supported")
+)
+
+// TestProcessChainUpdate_Disconnect verifies rollback on block disconnect.
+func TestProcessChainUpdate_Disconnect(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock its dependencies for handling
+ // a block disconnect.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, mockAddrStore, mockTxStore,
+ mockPublisher,
+ )
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+
+ mockAddrStore.On(
+ "BlockHash", mock.Anything, mock.Anything,
+ ).Return(&chainhash.Hash{}, nil).Maybe()
+
+ remoteHashes := make([]chainhash.Hash, 10)
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ remoteHashes, nil,
+ ).Once()
+
+ // Act & Assert: Process a BlockDisconnected notification and verify
+ // that it triggers a rollback check.
+ err := s.processChainUpdate(t.Context(), chain.BlockDisconnected{})
+ require.NoError(t, err)
+}
+
+// TestBroadcastUnminedTxns_Error verifies error handling.
+func TestBroadcastUnminedTxns_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock an error during unmined
+ // transactions retrieval.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{DB: db}, nil, mockTxStore, mockPublisher)
+
+ mockTxStore.On("UnminedTxs", mock.Anything).Return(
+ ([]*wire.MsgTx)(nil), errDBMockSync,
+ ).Once()
+
+ // Act & Assert: Verify that broadcasting unmined transactions
+ // returns the expected database error.
+ err := s.broadcastUnminedTxns(t.Context())
+ require.Error(t, err)
+}
+
+// TestInitChainSync_BackendNotSynced verifies it waits/errors.
+func TestInitChainSync_BackendNotSynced(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock the backend as not being
+ // current to test initialization timeout.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(
+ Config{Chain: mockChain}, mockAddrStore, nil, mockPublisher,
+ )
+
+ mockAddrStore.On("Birthday").Return(time.Now()).Once()
+ mockChain.On("IsCurrent").Return(false)
+
+ ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
+ defer cancel()
+
+ // Act & Assert: Verify that initialization fails due to timeout
+ // when the backend never becomes current.
+ err := s.initChainSync(ctx)
+ require.Error(t, err)
+}
+
+// TestDispatchScanStrategy_CFilterFail verifies fallback.
+func TestDispatchScanStrategy_CFilterFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and mock a CFilter retrieval failure
+ // to test fallback to full block scanning.
+ mockChain := &mockChain{}
+ mockPublisher := &mockTxPublisher{}
+ s := newSyncer(
+ Config{Chain: mockChain, SyncMethod: SyncMethodAuto}, nil, nil,
+ mockPublisher,
+ )
+ mockAddrStore := &mockAddrStore{}
+ scanState := NewRecoveryState(
+ 10, &chainParams, mockAddrStore,
+ )
+ hashes := []chainhash.Hash{{0x01}}
+
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return(([]*gcs.Filter)(nil), errCFilter).Once()
+
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ mockChain.On(
+ "GetBlocks", hashes,
+ ).Return([]*wire.MsgBlock{msgBlock}, nil).Once()
+
+ // Act: Dispatch the scan strategy when CFilters are not supported.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify that the scan fell back to full blocks successfully.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestFilterBatch_MatchFound verifies logic when CFilter matches.
+func TestFilterBatch_MatchFound(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer configured for CFilter scanning.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{Chain: mockChain, SyncMethod: SyncMethodCFilters},
+ nil, nil, nil,
+ )
+
+ // Create a filter that matches "data".
+ data := []byte("match_me")
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, [][]byte{data},
+ )
+ require.NoError(t, err)
+
+ // Setup scan state watching the data.
+ scanState := NewRecoveryState(10, &chainParams, nil)
+
+ mockAddr := &mockAddress{}
+ mockAddr.On("ScriptAddress").Return(data)
+ mockAddr.On("String").Return("addr")
+
+ scopeState := scanState.StateForScope(waddrmgr.KeyScopeBIP0084)
+ scopeState.ExternalBranch.AddAddr(0, mockAddr)
+
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return([]*gcs.Filter{filter}, nil).Once()
+
+ // Expect full block fetch due to filter match.
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ mockChain.On("GetBlocks", hashes).Return(
+ []*wire.MsgBlock{msgBlock}, nil,
+ ).Once()
+
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader{{}}, nil,
+ ).Once()
+
+ // Act: Perform the scan.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify results.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestScanBatchWithCFilters_GetHeadersFail verifies error handling.
+func TestScanBatchWithCFilters_GetHeadersFail(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and mock CFilter success but header retrieval
+ // failure.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+ scanState := NewRecoveryState(10, &chainParams, nil)
+ hashes := []chainhash.Hash{{0x01}}
+
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, nil,
+ )
+ require.NoError(t, err)
+
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return([]*gcs.Filter{filter}, nil).Once()
+
+ mockChain.On(
+ "GetBlockHeaders", hashes,
+ ).Return(([]*wire.BlockHeader)(nil), errHeaders).Once()
+
+ // Act: Attempt to scan the batch.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify error propagation.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "headers fail")
+}
+
+// TestFetchAndFilterBlocks_NonEmpty verifies block fetching and filtering
+// when the scan state is NOT empty.
+func TestFetchAndFilterBlocks_NonEmpty(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with a non-empty scan state.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ scanState := NewRecoveryState(10, &chainParams, nil)
+ scanState.AddWatchedOutPoint(&wire.OutPoint{Index: 0}, nil)
+
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On(
+ "GetBlockHashes", int64(10), int64(11),
+ ).Return(hashes, nil).Once()
+
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, nil,
+ )
+ require.NoError(t, err)
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return([]*gcs.Filter{filter}, nil).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader{{}}, nil).Once()
+
+ // Act: Fetch and filter blocks.
+ results, err := s.fetchAndFilterBlocks(
+ t.Context(), scanState, 10, 11,
+ )
+
+ // Assert: Verify results.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestFetchAndFilterBlocks_Errors verifies error paths.
+func TestFetchAndFilterBlocks_Errors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with a non-empty scan state and mock a hash
+ // fetch failure.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+ scanState := NewRecoveryState(10, &chainParams, nil)
+ scanState.AddWatchedOutPoint(&wire.OutPoint{Index: 0}, nil)
+
+ mockChain.On(
+ "GetBlockHashes", int64(10), int64(11),
+ ).Return([]chainhash.Hash(nil), errChainMock).Once()
+
+ // Act: Attempt to fetch and filter blocks.
+ results, err := s.fetchAndFilterBlocks(
+ t.Context(), scanState, 10, 11,
+ )
+
+ // Assert: Verify error propagation.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "chain error")
+}
+
+// TestScanBatch_Empty verifies error when fetchAndFilterBlocks returns 0.
+func TestScanBatch_Empty(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ // Arrange: Setup a syncer that returns empty blocks during a batch
+ // scan.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, mockTxStore, nil,
+ )
+
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore{}).Once()
+
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil).Once()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ []chainhash.Hash{}, nil).Once()
+ mockChain.On("GetBlockHeaders", []chainhash.Hash{}).Return(
+ []*wire.BlockHeader{}, nil).Once()
+
+ // Act: Attempt to scan the batch.
+ err := s.scanBatch(
+ t.Context(), waddrmgr.BlockStamp{Height: 10}, 11,
+ )
+
+ // Assert: Verify that the empty batch error is returned.
+ require.ErrorIs(t, err, ErrScanBatchEmpty)
+}
+
+// TestInitChainSync_Errors verifies initChainSync error paths.
+func TestInitChainSync_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("CheckRollback_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ // Arrange: Setup a syncer where DB operations fail during
+ // rollback check.
+ mockChain := &mockChain{}
+ addrStore := &mockAddrStore{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, addrStore, nil, nil,
+ )
+
+ mockChain.On("IsCurrent").Return(true).Maybe()
+ addrStore.On("Birthday").Return(time.Now()).Maybe()
+ addrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ )
+ addrStore.On("BlockHash", mock.Anything, mock.Anything).Return(
+ &chainhash.Hash{}, errDBMock).Once()
+
+ // Act: Attempt initialization.
+ err := s.initChainSync(t.Context())
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "db error")
+ })
+
+ t.Run("NotifyBlocks_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ // Arrange: Setup a syncer where block notifications fail.
+ mockChain := &mockChain{}
+ addrStore := &mockAddrStore{}
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db}, addrStore, nil, nil,
+ )
+
+ mockChain.On("IsCurrent").Return(true).Maybe()
+ addrStore.On("Birthday").Return(time.Now()).Maybe()
+ addrStore.On("SyncedTo").Return(waddrmgr.BlockStamp{Height: 0})
+ mockChain.On("NotifyBlocks").Return(errNotify).Once()
+
+ // Act: Attempt initialization.
+ err := s.initChainSync(t.Context())
+
+ // Assert: Verify error.
+ require.ErrorContains(t, err, "notify fail")
+ })
+}
+
+// TestHandleScanReq_Errors verifies handleScanReq error paths.
+func TestHandleScanReq_Errors(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer already in syncing state.
+ s := newSyncer(Config{}, nil, nil, nil)
+ s.state.Store(uint32(syncStateSyncing))
+
+ // Act: Attempt to handle a scan request.
+ err := s.handleScanReq(t.Context(), &scanReq{})
+
+ // Assert: Verify state forbidden error.
+ require.ErrorIs(t, err, ErrStateForbidden)
+}
+
+// TestSyncerRun_InitError verifies run failure when initChainSync fails.
+func TestSyncerRun_InitError(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ // Arrange: Setup a syncer where initialization fails.
+ mockChain := &mockChain{}
+ addrStore := &mockAddrStore{}
+
+ s := newSyncer(Config{Chain: mockChain, DB: db}, addrStore, nil, nil)
+
+ addrStore.On("Birthday").Return(time.Now()).Once()
+ mockChain.On("IsCurrent").Return(true).Once()
+
+ addrStore.On("SyncedTo").Return(waddrmgr.BlockStamp{Height: 100})
+ addrStore.On("BlockHash", mock.Anything, mock.Anything).Return(
+ &chainhash.Hash{}, errDBMock).Once()
+
+ // Act: Run the syncer.
+ err := s.run(t.Context())
+
+ // Assert: Verify error propagation.
+ require.ErrorContains(t, err, "db error")
+}
+
+// TestHandleChainUpdate_BlockDisconnected verifies handleChainUpdate for
+// BlockDisconnected.
+func TestHandleChainUpdate_BlockDisconnected(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and dependencies for handling updates.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ ChainParams: &chainParams,
+ DB: db,
+ },
+ mockAddrStore, mockTxStore, nil,
+ )
+
+ // 1. BlockDisconnected.
+ mockTxStore.On("Rollback", mock.Anything, int32(100)).Return(nil).Once()
+ mockAddrStore.On("SetSyncedTo", mock.Anything, mock.Anything).Return(
+ nil).Once()
+ mockAddrStore.On("SyncedTo").Return(waddrmgr.BlockStamp{Height: 100})
+
+ for i := int32(91); i <= 100; i++ {
+ hash := chainhash.Hash{byte(i)}
+ mockAddrStore.On("BlockHash", mock.Anything, i).Return(
+ &hash, nil).Maybe()
+ }
+
+ remoteHashes := make([]chainhash.Hash, 10)
+ for i := range 10 {
+ remoteHashes[i] = chainhash.Hash{byte(91 + i)}
+ }
+
+ mockChain.On("GetBlockHashes", int64(91), int64(100)).Return(
+ remoteHashes, nil).Once()
+
+ // Act: Handle BlockDisconnected.
+ err := s.handleChainUpdate(
+ t.Context(), chain.BlockDisconnected{
+ Block: wtxmgr.Block{Height: 100},
+ },
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+}
+
+// TestDispatchScanStrategy_AutoFallback verifies fallback to full blocks
+// when watchlist is too large.
+func TestDispatchScanStrategy_AutoFallback(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with a low filter item threshold to force
+ // fallback.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodAuto,
+ MaxCFilterItems: 1,
+ }, nil, nil, nil,
+ )
+ scanState := NewRecoveryState(10, &chainParams, nil)
+
+ // Add 2 items (threshold 1).
+ credits := make([]wtxmgr.Credit, 2)
+ for i := range credits {
+ credits[i] = wtxmgr.Credit{
+ OutPoint: wire.OutPoint{Index: uint32(i)},
+ PkScript: []byte{0x00},
+ }
+ }
+
+ err := scanState.Initialize(nil, nil, credits)
+ require.NoError(t, err)
+
+ hashes := []chainhash.Hash{{0x01}}
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ mockChain.On(
+ "GetBlocks", hashes,
+ ).Return([]*wire.MsgBlock{msgBlock}, nil).Once()
+
+ // Act: Dispatch the scan strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify results.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestBroadcastUnminedTxns_Success verifies successful broadcast.
+func TestBroadcastUnminedTxns_Success(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and mock successful transaction retrieval
+ // and broadcast.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockTxStore := &mockTxStore{}
+ mockPublisher := &mockTxPublisher{}
+
+ s := newSyncer(Config{DB: db}, nil, mockTxStore, mockPublisher)
+
+ tx := wire.NewMsgTx(1)
+ mockTxStore.On("UnminedTxs", mock.Anything).Return(
+ []*wire.MsgTx{tx}, nil,
+ ).Once()
+ mockPublisher.On("Broadcast", mock.Anything, tx, "").Return(nil).Once()
+
+ // Act: Broadcast unmined transactions.
+ err := s.broadcastUnminedTxns(t.Context())
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+}
+
+// TestFilterBatch_EmptyFilter verifies that empty filters force download.
+func TestFilterBatch_EmptyFilter(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and mock an empty filter response.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{Chain: mockChain, SyncMethod: SyncMethodCFilters},
+ nil, nil, nil,
+ )
+
+ emptyFilter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, [16]byte{}, nil,
+ )
+ require.NoError(t, err)
+
+ scanState := NewRecoveryState(10, &chainParams, nil)
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On(
+ "GetCFilters", hashes, wire.GCSFilterRegular,
+ ).Return([]*gcs.Filter{emptyFilter}, nil).Once()
+
+ msgBlock := wire.NewMsgBlock(wire.NewBlockHeader(
+ 1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0,
+ ))
+ mockChain.On("GetBlocks", hashes).Return(
+ []*wire.MsgBlock{msgBlock}, nil,
+ ).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader{{}}, nil,
+ ).Once()
+
+ // Act: Scan the batch with CFilters.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), scanState, 10, hashes,
+ )
+
+ // Assert: Verify that the block was fetched.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestWaitForEvent_NotificationsClosed verifies that the loop exits when the
+// notifications channel is closed.
+func TestWaitForEvent_NotificationsClosed(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with a closed notification channel.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ closedChan := make(chan any)
+ close(closedChan)
+
+ mockChain.On("Notifications").Return((<-chan any)(closedChan)).Once()
+
+ // Act: Start waiting for events.
+ err := s.waitForEvent(t.Context())
+
+ // Assert: Verify that the loop exits with the expected error.
+ require.ErrorIs(t, err, ErrWalletShuttingDown)
+}
+
+// TestWaitForEvent_ContextCancelled verifies exit on context cancellation.
+func TestWaitForEvent_ContextCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with a blocking notification channel and a
+ // cancelled context.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ blockChan := make(chan any)
+ mockChain.On("Notifications").Return((<-chan any)(blockChan)).Once()
+
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Act: Attempt to wait for events.
+ err := s.waitForEvent(ctx)
+
+ // Assert: Verify cancellation error.
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+// TestMatchAndFetchBatch_GetBlocksError verifies error propagation.
+func TestMatchAndFetchBatch_GetBlocksError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer and setup a recovery state.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ state := NewRecoveryState(1, nil, nil)
+
+ // Setup results and mock filters such that a match is forced, then
+ // mock a block fetch failure.
+ results := []scanResult{
+ {meta: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Hash: chainhash.Hash{0x01}},
+ }},
+ }
+
+ filters := []*gcs.Filter{nil}
+
+ blockMap := make(map[chainhash.Hash]*wire.MsgBlock)
+ mockChain.On("GetBlocks", mock.Anything).Return(
+ ([]*wire.MsgBlock)(nil), errGetBlocks).Once()
+
+ // Act: Attempt to match and fetch the batch.
+ err := s.matchAndFetchBatch(
+ t.Context(), state, results, filters, blockMap,
+ )
+
+ // Assert: Verify error propagation.
+ require.ErrorIs(t, err, errGetBlocks)
+}
+
+// TestFilterBatch_ContextCancelled verifies early exit.
+func TestFilterBatch_ContextCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and a cancelled context.
+ s := newSyncer(Config{}, nil, nil, nil)
+
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Act: Attempt to filter a batch.
+ results := []scanResult{{}}
+ matched, err := s.filterBatch(ctx, results, nil, nil, nil)
+
+ // Assert: Verify failure.
+ require.Nil(t, matched)
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+// TestFilterBatch_BlockAlreadyFetched verifies skipping.
+func TestFilterBatch_BlockAlreadyFetched(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer where the target block has already been
+ // fetched.
+ s := newSyncer(Config{}, nil, nil, nil)
+
+ hash := chainhash.Hash{0x01}
+ results := []scanResult{
+ {meta: &wtxmgr.BlockMeta{Block: wtxmgr.Block{Hash: hash}}},
+ }
+ blockMap := map[chainhash.Hash]*wire.MsgBlock{
+ hash: {},
+ }
+
+ // Act: Filter the batch.
+ matched, err := s.filterBatch(t.Context(), results, nil, blockMap, nil)
+
+ // Assert: Verify that the block was skipped.
+ require.NoError(t, err)
+ require.Empty(t, matched)
+}
+
+// TestInitChainSync_WaitUntilSyncedError verifies error propagation.
+func TestInitChainSync_WaitUntilSyncedError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where the backend is not current,
+ // then cancel the context.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ mockChain.On("IsCurrent").Return(false).Maybe()
+
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Act: Attempt chain sync initialization.
+ err := s.initChainSync(ctx)
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "unable to wait for backend sync")
+}
+
+// TestScanBatchHeadersOnly_ContextCancelled verifies early exit.
+func TestScanBatchHeadersOnly_ContextCancelled(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations and a cancelled context.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ []chainhash.Hash{}, context.Canceled).Maybe()
+
+ // Act: Attempt header-only scan.
+ results, err := s.scanBatchHeadersOnly(ctx, 0, 0)
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+// TestBroadcastUnminedTxns_BroadcastError verifies warning log (no error
+// returned).
+func TestBroadcastUnminedTxns_BroadcastError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where a transaction broadcast fails.
+ mockPublisher := &mockTxPublisher{}
+ mockTxStore := &mockTxStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(Config{DB: db}, nil, mockTxStore, mockPublisher)
+
+ tx := wire.NewMsgTx(1)
+ mockTxStore.On("UnminedTxs", mock.Anything).Return(
+ []*wire.MsgTx{tx}, nil).Once()
+ mockPublisher.On("Broadcast", mock.Anything, tx, "").Return(
+ errBroadcast).Once()
+
+ // Act: Broadcast unmined transactions.
+ err := s.broadcastUnminedTxns(t.Context())
+
+ // Assert: Verify that the error is not propagated (it's only logged).
+ require.NoError(t, err)
+}
+
+// TestCheckRollback_DBError verifies error propagation.
+func TestCheckRollback_DBError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where local block hash lookup fails
+ // during a rollback check.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Once()
+ mockAddrStore.On("BlockHash", mock.Anything, mock.Anything).Return(
+ (*chainhash.Hash)(nil), errBlockHash).Once()
+
+ // Act: Perform a rollback check.
+ err := s.checkRollback(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errBlockHash)
+}
+
+// TestCheckRollback_RemoteError verifies error propagation from
+// GetBlockHashes.
+func TestCheckRollback_RemoteError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where remote hash lookup fails
+ // during a rollback check.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, nil, nil,
+ )
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Once()
+ mockAddrStore.On("BlockHash", mock.Anything, mock.Anything).Return(
+ &chainhash.Hash{}, nil).Maybe()
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ ([]chainhash.Hash)(nil), errRemote).Once()
+
+ // Act: Perform a rollback check.
+ err := s.checkRollback(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errRemote)
+}
+
+// TestFilterBatch_NilFilter verifies logging and forcing download.
+func TestFilterBatch_NilFilter(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a batch with a nil filter.
+ s := newSyncer(Config{}, nil, nil, nil)
+
+ hash := chainhash.Hash{0x01}
+ results := []scanResult{
+ {meta: &wtxmgr.BlockMeta{Block: wtxmgr.Block{Hash: hash}}},
+ }
+ filters := []*gcs.Filter{nil}
+ blockMap := make(map[chainhash.Hash]*wire.MsgBlock)
+
+ // Act: Filter the batch.
+ matched, err := s.filterBatch(
+ t.Context(), results, filters, blockMap, nil,
+ )
+
+ // Assert: Verify that the block is matched to force download.
+ require.NoError(t, err)
+ require.Len(t, matched, 1)
+ require.Equal(t, hash, matched[0])
+}
+
+// TestInitChainSync_NotifyBlocksError verifies error propagation.
+func TestInitChainSync_NotifyBlocksError(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ // Arrange: Setup mock expectations where block notification fails.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, nil, nil,
+ )
+
+ mockChain.On("IsCurrent").Return(true).Once()
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ []chainhash.Hash{}, nil).Once()
+ mockChain.On("NotifyBlocks").Return(errNotify).Once()
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 0}).Once()
+ mockAddrStore.On("Birthday").Return(time.Time{}).Maybe()
+
+ // Act: Attempt chain sync initialization.
+ err := s.initChainSync(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "unable to start block notifications")
+}
+
+// TestScanBatchHeadersOnly_Errors verifies error paths.
+func TestScanBatchHeadersOnly_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("GetBlockHashes_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where GetBlockHashes fails.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ mockChain.On("GetBlockHashes", mock.Anything,
+ mock.Anything).Return(([]chainhash.Hash)(nil),
+ errHashes).Once()
+
+ // Act: Perform header-only scan.
+ results, err := s.scanBatchHeadersOnly(t.Context(), 0, 0)
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorIs(t, err, errHashes)
+ })
+
+ t.Run("GetBlockHeaders_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where GetBlockHeaders fails.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ mockChain.On("GetBlockHashes", mock.Anything,
+ mock.Anything).Return([]chainhash.Hash{{}}, nil).Once()
+ mockChain.On("GetBlockHeaders", mock.Anything).Return(
+ ([]*wire.BlockHeader)(nil), errHeaders).Once()
+
+ // Act: Perform header-only scan again.
+ results, err := s.scanBatchHeadersOnly(t.Context(), 0, 0)
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorIs(t, err, errHeaders)
+ })
+}
+
+// TestCheckRollback_HeaderError verifies error when fetching fork header.
+func TestCheckRollback_HeaderError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for a rollback check where a
+ // header fetch failure occurs at the fork point.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, nil, nil,
+ )
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 101}).Once()
+
+ hashA := &chainhash.Hash{0x0A}
+ hashB := &chainhash.Hash{0x0B}
+ hashC := chainhash.Hash{0x0C}
+
+ mockAddrStore.On("BlockHash", mock.Anything, int32(101)).Return(hashB,
+ nil).Once()
+ mockAddrStore.On("BlockHash", mock.Anything, int32(100)).Return(hashA,
+ nil).Once()
+ mockAddrStore.On("BlockHash", mock.Anything, mock.Anything).Return(
+ &chainhash.Hash{}, nil).Maybe()
+
+ remoteHashes := make([]chainhash.Hash, 10)
+ remoteHashes[8] = *hashA
+ remoteHashes[9] = hashC
+ mockChain.On("GetBlockHashes", int64(92), int64(101)).Return(
+ remoteHashes, nil).Once()
+ mockChain.On("GetBlockHeader", hashA).Return(
+ (*wire.BlockHeader)(nil), errHeader).Once()
+
+ // Act: Perform the rollback check.
+ err := s.checkRollback(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errHeader)
+}
+
+// TestFilterBatch_Match verifies positive match logic.
+func TestFilterBatch_Match(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a batch with a matching filter.
+ s := newSyncer(Config{}, nil, nil, nil)
+
+ hash := chainhash.Hash{0x01}
+ results := []scanResult{
+ {meta: &wtxmgr.BlockMeta{Block: wtxmgr.Block{Hash: hash}}},
+ }
+ blockMap := make(map[chainhash.Hash]*wire.MsgBlock)
+
+ key := builder.DeriveKey(&hash)
+ filter, err := gcs.BuildGCSFilter(
+ builder.DefaultP, builder.DefaultM, key, [][]byte{{0x01}},
+ )
+ require.NoError(t, err)
+
+ filters := []*gcs.Filter{filter}
+ watchList := [][]byte{{0x01}}
+
+ // Act: Filter the batch.
+ matched, err := s.filterBatch(
+ t.Context(), results, filters, blockMap, watchList,
+ )
+
+ // Assert: Verify the match.
+ require.NoError(t, err)
+ require.Len(t, matched, 1)
+ require.Equal(t, hash, matched[0])
+}
+
+// TestScanWithTargets_Empty verifies handling of empty batch results.
+func TestScanWithTargets_Empty(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a targeted scan where the resulting block batch is
+ // empty.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ defer mockChain.AssertExpectations(t)
+ defer mockAddrStore.AssertExpectations(t)
+ defer mockTxStore.AssertExpectations(t)
+
+ s := newSyncer(Config{
+ DB: db,
+ Chain: mockChain,
+ SyncMethod: SyncMethodAuto,
+ MaxCFilterItems: 100,
+ }, mockAddrStore, mockTxStore, nil)
+
+ req := &scanReq{
+ startBlock: waddrmgr.BlockStamp{Height: 100},
+ targets: []waddrmgr.AccountScope{
+ {Scope: waddrmgr.KeyScopeBIP0084, Account: 0}},
+ }
+
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit{{PkScript: []byte{0x01}}}, nil).Once()
+
+ mgr := &mockAccountStore{}
+ mockAddrStore.On("FetchScopedKeyManager", mock.Anything).Return(mgr,
+ nil).Times(3)
+ mgr.On("AccountProperties", mock.Anything, mock.Anything).Return(
+ &waddrmgr.AccountProperties{}, nil).Once()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.AnythingOfType("func(address.Address) error")).Return(
+ nil).Once()
+ // SyncedTo is not called in the targeted scan path.
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Maybe()
+
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(100),
+ nil).Once()
+ mockChain.On("GetBlockHashes", int64(100), int64(100)).Return(
+ []chainhash.Hash{}, nil).Once()
+ mockChain.On("GetCFilters", []chainhash.Hash{},
+ wire.GCSFilterRegular).Return([]*gcs.Filter{}, nil).Once()
+ mockChain.On("GetBlockHeaders", []chainhash.Hash{}).Return(
+ []*wire.BlockHeader{}, nil).Once()
+
+ // Act: Perform the scan.
+ err := s.scanWithTargets(t.Context(), req)
+
+ // Assert: Verify that an empty batch error is returned.
+ require.ErrorIs(t, err, ErrScanBatchEmpty)
+}
+
+// TestInitChainSync_Neutrino verifies the type switch case for NeutrinoClient.
+func TestInitChainSync_Neutrino(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock neutrino chain service and a syncer with a
+ // NeutrinoClient.
+ mockCS := &mockNeutrinoChain{}
+ // IsCurrent called by waitUntilBackendSynced.
+ // Return false to keep polling until context cancel.
+ mockCS.On("IsCurrent").Return(false).Maybe()
+
+ nc := &chain.NeutrinoClient{
+ CS: mockCS,
+ }
+ mockAddrStore := &mockAddrStore{}
+ // Birthday called by SetStartTime.
+ mockAddrStore.On("Birthday").Return(time.Time{}).Once()
+
+ s := newSyncer(Config{Chain: nc}, mockAddrStore, nil, nil)
+
+ // Cancel context immediately to abort waitUntilBackendSynced.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ // Act: Attempt chain sync initialization.
+ err := s.initChainSync(ctx)
+
+ // Assert: Verify cancellation error and that Birthday was accessed.
+ require.Error(t, err)
+ mockAddrStore.AssertExpectations(t)
+}
+
+// TestFetchAndFilterBlocks_HeaderScan verifies the optimization for empty scan
+// state.
+func TestFetchAndFilterBlocks_HeaderScan(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer with an empty scan state.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ scanState := NewRecoveryState(10, nil, nil)
+
+ // Expect hash and header fetching for the empty scan state.
+ mockChain.On("GetBlockHashes", int64(100), int64(100)).Return(
+ []chainhash.Hash{{0x01}}, nil,
+ ).Once()
+ mockChain.On("GetBlockHeaders", mock.Anything).Return(
+ []*wire.BlockHeader{{Timestamp: time.Unix(12345, 0)}}, nil,
+ ).Once()
+
+ // Act: Perform the fetch and filter operation.
+ results, err := s.fetchAndFilterBlocks(
+ t.Context(), scanState, 100, 100,
+ )
+
+ // Assert: Verify results.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+ require.Equal(t, int32(100), results[0].meta.Height)
+}
+
+// TestScanBatchWithFullBlocks_ProcessError verifies error from ProcessBlock.
+func TestScanBatchWithFullBlocks_ProcessError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations to simulate an expansion failure
+ // during full block scanning.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain, DB: db}, nil, nil, nil)
+
+ addrStore := &mockAccountStore{}
+ rs := NewRecoveryState(10, &chainParams, nil)
+ rs.addrFilters = make(map[string]AddrEntry)
+ rs.outpoints = make(map[wire.OutPoint][]byte)
+ rs.branchStates[waddrmgr.BranchScope{}] = NewBranchRecoveryState(
+ 10, addrStore,
+ )
+
+ // Force expansion by finding an address.
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+
+ rs.addrFilters[addr.EncodeAddress()] = AddrEntry{
+ Address: addr,
+ IsLookahead: true,
+ addrScope: waddrmgr.AddrScope{Index: 0},
+ }
+ block := wire.NewMsgBlock(&wire.BlockHeader{})
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ tx := wire.NewMsgTx(1)
+ tx.AddTxOut(wire.NewTxOut(100, pkScript))
+ require.NoError(t, block.AddTransaction(tx))
+
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On("GetBlocks", hashes).Return([]*wire.MsgBlock{block},
+ nil).Once()
+ addrStore.On("DeriveAddr", mock.Anything, mock.Anything,
+ mock.Anything).Return(nil, nil, errDeriveFail).Once()
+
+ // Act: Execute the scan.
+ results, err := s.scanBatchWithFullBlocks(
+ t.Context(), rs, 100, hashes,
+ )
+
+ // Assert: Verify derivation failure.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "derive fail")
+}
+
+// TestDispatchScanStrategy_Auto verifies heuristics.
+func TestDispatchScanStrategy_Auto(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for the auto dispatch strategy.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodAuto,
+ MaxCFilterItems: 1,
+ }, nil, nil, nil,
+ )
+ scanState := NewRecoveryState(10, nil, nil)
+
+ scanState.outpoints = make(map[wire.OutPoint][]byte)
+ for i := range 5 {
+ scanState.outpoints[wire.OutPoint{Index: uint32(i)}] = []byte{}
+ }
+
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On("GetBlocks", hashes).Return(
+ []*wire.MsgBlock{wire.NewMsgBlock(&wire.BlockHeader{})}, nil,
+ ).Once()
+
+ // Act: Dispatch the scan strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify successful dispatch.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestDispatchScanStrategy_AutoFallback_Final verifies fallback on
+// ErrCFiltersUnavailable.
+func TestDispatchScanStrategy_AutoFallback_Final(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where CFilters are unavailable,
+ // triggering a fallback to full blocks.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodAuto,
+ }, nil, nil, nil,
+ )
+
+ scanState := NewRecoveryState(10, nil, nil)
+ scanState.outpoints = make(map[wire.OutPoint][]byte)
+ scanState.outpoints[wire.OutPoint{}] = []byte{}
+ hashes := []chainhash.Hash{{0x01}}
+
+ mockChain.On("GetCFilters", hashes, mock.Anything).Return(
+ []*gcs.Filter(nil), ErrCFiltersUnavailable).Once()
+ mockChain.On("GetBlocks", hashes).Return(
+ []*wire.MsgBlock{wire.NewMsgBlock(&wire.BlockHeader{})}, nil,
+ ).Once()
+
+ // Act: Dispatch the strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify successful fallback.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestProcessChainUpdate_Disconnected verifies rollback on disconnect.
+func TestProcessChainUpdate_Disconnected(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a syncer with a database and verify initial sync
+ // state.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 0}).Once()
+
+ // Act: Process a BlockDisconnected update.
+ err := s.processChainUpdate(
+ t.Context(), chain.BlockDisconnected{},
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+}
+
+// TestScanWithTargets_Errors verifies error paths in scanWithTargets.
+func TestScanWithTargets_Errors(t *testing.T) {
+ t.Parallel()
+
+ t.Run("GetBestBlock_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where GetBestBlock fails
+ // during targeted scan initialization.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ DB: db,
+ }, mockAddrStore, mockTxStore, nil,
+ )
+
+ req := &scanReq{
+ startBlock: waddrmgr.BlockStamp{Height: 100},
+ targets: []waddrmgr.AccountScope{{
+ Scope: waddrmgr.KeyScopeBIP0084, Account: 0,
+ }},
+ }
+
+ mgr := &mockAccountStore{}
+ mockAddrStore.On("FetchScopedKeyManager",
+ mock.Anything).Return(mgr, nil)
+ mgr.On("AccountProperties", mock.Anything, mock.Anything).Return(
+ &waddrmgr.AccountProperties{}, nil)
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(nil)
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil)
+ mockChain.On("GetBestBlock").Return(nil, int32(0),
+ errBestBlock).Once()
+
+ // Act: Attempt targeted scan.
+ err := s.scanWithTargets(t.Context(), req)
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "best block fail")
+ })
+
+ t.Run("GetBlockHashes_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where GetBlockHashes fails.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ DB: db,
+ }, mockAddrStore, mockTxStore, nil,
+ )
+
+ req := &scanReq{
+ startBlock: waddrmgr.BlockStamp{Height: 100},
+ targets: []waddrmgr.AccountScope{{
+ Scope: waddrmgr.KeyScopeBIP0084, Account: 0,
+ }},
+ }
+
+ mgr := &mockAccountStore{}
+ mockAddrStore.On("FetchScopedKeyManager",
+ mock.Anything).Return(mgr, nil)
+ mgr.On("AccountProperties", mock.Anything, mock.Anything).Return(
+ &waddrmgr.AccountProperties{}, nil).Once()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(nil).Once()
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil).Once()
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{},
+ int32(200), nil).Once()
+ mockChain.On("GetBlockHashes", mock.Anything,
+ mock.Anything).Return([]chainhash.Hash(nil),
+ errHashes).Once()
+
+ // Act: Attempt targeted scan.
+ err := s.scanWithTargets(t.Context(), req)
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "hashes fail")
+ })
+
+ t.Run("FetchScopedKeyManager_Failure", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations to simulate a fetch failure during
+ // targeted scan initialization.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("FetchScopedKeyManager", mock.Anything).Return(
+ nil, errFetchFail).Once()
+
+ targets := []waddrmgr.AccountScope{{
+ Scope: waddrmgr.KeyScopeBIP0084, Account: 0,
+ }}
+
+ // Act: Attempt a targeted scan.
+ err := s.scanWithTargets(
+ t.Context(), &scanReq{
+ targets: targets,
+ startBlock: waddrmgr.BlockStamp{Height: 100},
+ },
+ )
+
+ // Assert: Verify propagation.
+ require.ErrorContains(t, err, "fetch fail")
+ })
+}
+
+// TestScanBatchWithCFilters_InitResultsError verifies error propagation.
+func TestScanBatchWithCFilters_InitResultsError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where header retrieval fails during
+ // initialization for a CFilter scan.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodCFilters,
+ }, nil, nil, nil,
+ )
+
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On("GetCFilters", hashes, mock.Anything).Return(
+ []*gcs.Filter{{}}, nil).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader(nil), errHeaders).Once()
+
+ // Act: Attempt batch scan with CFilters.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), nil, 100, hashes,
+ )
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "headers fail")
+}
+
+// TestProcessChainUpdate verifies processChainUpdate for all update types.
+func TestProcessChainUpdate(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ tests := []struct {
+ name string
+ update interface{}
+ setup func(*mockAddrStore, *mockTxStore, *mockChain)
+ }{
+ {
+ name: "BlockConnected",
+ update: chain.BlockConnected{
+ Block: wtxmgr.Block{Height: 100},
+ },
+ setup: func(as *mockAddrStore, ts *mockTxStore, c *mockChain) {
+ as.On("SetSyncedTo", mock.Anything, mock.MatchedBy(
+ func(bs *waddrmgr.BlockStamp) bool {
+ return bs.Height == 100
+ })).Return(nil).Once()
+ },
+ },
+ {
+ name: "RelevantTx",
+ update: chain.RelevantTx{
+ TxRecord: &wtxmgr.TxRecord{MsgTx: *wire.NewMsgTx(1)},
+ },
+ setup: func(as *mockAddrStore, ts *mockTxStore, c *mockChain) {
+ ts.On("InsertUnconfirmedTx", mock.Anything, mock.Anything,
+ mock.Anything).Return(nil).Once()
+ },
+ },
+ {
+ name: "FilteredBlockConnected",
+ update: chain.FilteredBlockConnected{
+ Block: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Height: 102},
+ },
+ },
+ setup: func(as *mockAddrStore, ts *mockTxStore, c *mockChain) {
+ as.On("SetSyncedTo", mock.Anything, mock.MatchedBy(
+ func(bs *waddrmgr.BlockStamp) bool {
+ return bs.Height == 102
+ })).Return(nil).Once()
+ },
+ },
+ {
+ name: "BlockDisconnected",
+ update: chain.BlockDisconnected{
+ Block: wtxmgr.Block{Height: 100, Hash: chainhash.Hash{0x01}},
+ },
+ setup: func(as *mockAddrStore, ts *mockTxStore, c *mockChain) {
+ as.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100},
+ ).Once()
+ as.On(
+ "BlockHash", mock.Anything, mock.Anything,
+ ).Return(&chainhash.Hash{}, nil).Maybe()
+
+ remoteHashes := make([]chainhash.Hash, 10)
+ c.On(
+ "GetBlockHashes", mock.Anything, mock.Anything,
+ ).Return(remoteHashes, nil).Once()
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ ChainParams: &chainParams,
+ DB: db,
+ },
+ mockAddrStore, mockTxStore, nil,
+ )
+
+ tc.setup(mockAddrStore, mockTxStore, mockChain)
+
+ // Act
+ err := s.processChainUpdate(t.Context(), tc.update)
+
+ // Assert
+ require.NoError(t, err)
+ })
+ }
+}
+
+// TestHandleChainUpdate_SpecialNotifs verifies RescanProgress and
+// RescanFinished.
+func TestHandleChainUpdate_SpecialNotifs(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer for special notification handling.
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{}, mockAddrStore, nil, nil)
+
+ // 1. RescanProgress
+ // Act: Handle RescanProgress.
+ err := s.handleChainUpdate(
+ t.Context(), &chain.RescanProgress{
+ Height: 100, Hash: chainhash.Hash{0x01},
+ },
+ )
+ require.NoError(t, err)
+
+ // 2. RescanFinished
+ // Act: Handle RescanFinished.
+ err = s.handleChainUpdate(
+ t.Context(), &chain.RescanFinished{
+ Height: 100, Hash: &chainhash.Hash{0x01},
+ },
+ )
+ require.NoError(t, err)
+}
+
+// TestSyncStateString verifies String representations.
+func TestSyncStateString(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Define test cases for syncState string conversion.
+ tests := []struct {
+ state syncState
+ want string
+ }{
+ {syncStateBackendSyncing, "backend-syncing"},
+ {syncStateSyncing, "syncing"},
+ {syncStateSynced, "synced"},
+ {syncStateRescanning, "rescanning"},
+ {syncState(99), "unknown sync state"},
+ }
+
+ // Act & Assert: Execute test cases.
+ for _, tt := range tests {
+ require.Equal(t, tt.want, tt.state.String())
+ }
+}
+
+// TestFetchAndFilterBlocks_BatchCapping verifies endHeight calculation.
+func TestFetchAndFilterBlocks_BatchCapping(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with expectations for batch capping.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+ scanState := NewRecoveryState(10, nil, nil)
+
+ // Expect GetBlockHashes with a capped range based on recoveryBatchSize.
+ mockChain.On("GetBlockHashes", int64(100), int64(2099)).Return(
+ []chainhash.Hash{{0x01}}, nil,
+ ).Once()
+ mockChain.On("GetBlockHeaders", mock.Anything).Return(
+ []*wire.BlockHeader{{}}, nil,
+ ).Once()
+
+ // Act: Perform the fetch.
+ results, err := s.fetchAndFilterBlocks(
+ t.Context(), scanState, 100, 5000,
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestRunSyncStep_Unfinished verifies the early return if sync not finished.
+func TestRunSyncStep_Unfinished(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and mock an incomplete sync state.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ DB: db,
+ }, mockAddrStore, mockTxStore, nil,
+ )
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 90}).Maybe()
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(100),
+ nil).Once()
+
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore(nil)).Maybe()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(nil).Maybe()
+
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil).Maybe()
+ mockChain.On("GetBlockHashes", int64(91), int64(100)).Return(
+ []chainhash.Hash{{0x01}}, nil).Once()
+ mockChain.On("GetBlockHeaders", mock.Anything).Return(
+ []*wire.BlockHeader{{}}, nil).Once()
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(nil).Maybe()
+
+ // Act: Execute a sync step.
+ err := s.runSyncStep(t.Context())
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+}
+
+// TestDispatchScanStrategy_OtherMethods verifies FullBlocks, CFilters and
+// Default.
+func TestDispatchScanStrategy_OtherMethods(t *testing.T) {
+ t.Parallel()
+
+ hashes := []chainhash.Hash{{0x01}}
+ scanState := NewRecoveryState(10, nil, nil)
+
+ t.Run("FullBlocks", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer for FullBlocks strategy.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodFullBlocks,
+ }, nil, nil, nil,
+ )
+ mockChain.On("GetBlocks", hashes).Return([]*wire.MsgBlock{
+ wire.NewMsgBlock(&wire.BlockHeader{})}, nil).Once()
+
+ // Act: Dispatch the strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+ })
+
+ t.Run("CFilters", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer for CFilters strategy.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodCFilters,
+ }, nil, nil, nil,
+ )
+ mockChain.On("GetCFilters", hashes, mock.Anything).Return(
+ []*gcs.Filter{{}}, nil).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader{{}}, nil).Once()
+ mockChain.On("GetBlocks", mock.Anything).Return(
+ []*wire.MsgBlock{wire.NewMsgBlock(&wire.BlockHeader{})},
+ nil).Once()
+
+ // Act: Dispatch the strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+ })
+
+ t.Run("Default_Unknown", func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer with an unknown method.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ SyncMethod: 99,
+ }, nil, nil, nil,
+ )
+
+ // Act: Dispatch the strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify failure for unknown method.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "unknown sync method")
+ })
+}
+
+// TestHandleChainUpdate_Error verifies that handleChainUpdate returns error if
+// processChainUpdate fails.
+func TestHandleChainUpdate_Error(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ // Arrange: Setup a syncer where chain update processing will fail due
+ // to a database error.
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Maybe()
+ mockAddrStore.On("BlockHash", mock.Anything, mock.Anything).Return(
+ (*chainhash.Hash)(nil), errDBFail).Once()
+
+ // Act: Attempt to handle a BlockDisconnected update.
+ err := s.handleChainUpdate(
+ t.Context(), chain.BlockDisconnected{},
+ )
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "failed to process chain update")
+}
+
+// TestRunSyncStep_Success verifies the idle path in runSyncStep.
+func TestRunSyncStep_Success(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a syncer and mock a notification arrival to trigger
+ // the idle processing path.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+ s := newSyncer(
+ Config{
+ Chain: mockChain,
+ DB: db,
+ }, mockAddrStore, mockTxStore, nil,
+ )
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Maybe()
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(100),
+ nil).Once()
+ mockTxStore.On("UnminedTxs", mock.Anything).Return([]*wire.MsgTx{},
+ nil).Once()
+
+ notifChan := make(chan any, 1)
+ mockChain.On("Notifications").Return((<-chan any)(notifChan)).Maybe()
+
+ notifChan <- chain.BlockConnected{Block: wtxmgr.Block{Height: 101}}
+
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ // Act: Execute a sync step.
+ err := s.runSyncStep(t.Context())
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+}
+
+// TestScanBatchWithCFilters_HorizonExpansion verifies the re-matching logic
+// when a horizon is expanded.
+func TestScanBatchWithCFilters_HorizonExpansion(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup a complex mock scenario where finding an address
+ // triggers a horizon expansion, requiring a re-match of the block
+ // batch.
+ mockChain := &mockChain{}
+ addrStore := &mockAddrStore{}
+ accountStore := &mockAccountStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(Config{Chain: mockChain, DB: db}, addrStore, nil, nil)
+
+ hashes := []chainhash.Hash{{0x01}, {0x02}}
+
+ mockChain.On("GetCFilters", hashes, wire.GCSFilterRegular).Return(
+ []*gcs.Filter{nil, nil}, nil).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader{{}, {}}, nil).Once()
+
+ block1 := wire.NewMsgBlock(&wire.BlockHeader{})
+ block2 := wire.NewMsgBlock(&wire.BlockHeader{})
+ mockChain.On("GetBlocks", mock.MatchedBy(func(h []chainhash.Hash) bool {
+ return len(h) == 2
+ })).Return([]*wire.MsgBlock{block1, block2}, nil).Once()
+
+ scanState := NewRecoveryState(1, &chainParams, addrStore)
+
+ scanState.addrFilters = make(map[string]AddrEntry)
+ scanState.outpoints = make(map[wire.OutPoint][]byte)
+
+ bs := waddrmgr.BranchScope{}
+ scanState.branchStates[bs] = NewBranchRecoveryState(1, accountStore)
+
+ // Found address in block 1.
+ addr, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+
+ scanState.addrFilters[addr.EncodeAddress()] = AddrEntry{
+ Address: addr,
+ IsLookahead: true,
+ addrScope: waddrmgr.AddrScope{Index: 0},
+ }
+ pkScript, err := txscript.PayToAddrScript(addr)
+ require.NoError(t, err)
+
+ tx1 := wire.NewMsgTx(1)
+ tx1.AddTxOut(wire.NewTxOut(100, pkScript))
+ require.NoError(t, block1.AddTransaction(tx1))
+
+ // Mock DeriveAddr and ExtendAddresses for expansion.
+ expAddr, err := address.NewAddressPubKeyHash(
+ append([]byte{1}, make([]byte, 19)...), &chainParams,
+ )
+ require.NoError(t, err)
+
+ accountStore.On("DeriveAddr", mock.Anything, mock.Anything,
+ mock.Anything).Return(expAddr, []byte{}, nil).Maybe()
+ accountStore.On("ExtendAddresses", mock.Anything, mock.Anything,
+ mock.Anything, mock.Anything).Return(
+ []address.Address{expAddr}, [][]byte{make([]byte, 20)}, nil,
+ ).Once()
+
+ // Act: Perform a batch scan with CFilters.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify that both blocks were returned after expansion.
+ require.NoError(t, err)
+ require.Len(t, results, 2)
+}
+
+// TestRunSyncStep_AdvanceError verifies that runSyncStep returns errors
+// from advanceChainSync.
+func TestRunSyncStep_AdvanceError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations to simulate a failure during
+ // loadFullScanState within runSyncStep.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, nil, nil,
+ )
+
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Maybe()
+
+ mockChain.On("GetBestBlock").Return(
+ &chainhash.Hash{}, int32(101), nil).Once()
+
+ mgr := &mockAccountStore{}
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore{mgr}).Once()
+ mgr.On("ActiveAccounts").Return([]uint32{0}).Once()
+ mgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+
+ mockAddrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0084).Return(nil, errLoadStateFail).Once()
+
+ // Act: Execute a single sync step.
+ err := s.runSyncStep(t.Context())
+
+ // Assert: Verify error propagation.
+ require.ErrorContains(t, err, "load state fail")
+}
+
+// TestLoadFullScanState_Error verifies error propagation.
+func TestLoadFullScanState_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations to simulate a database failure
+ // when loading scan state.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mgr := &mockAccountStore{}
+ mgr.On("ActiveAccounts").Return([]uint32{0}).Once()
+ mgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore{mgr}).Once()
+ mockAddrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0084).Return(nil, errDBMock).Once()
+
+ // Act: Attempt to load the full scan state.
+ state, err := s.loadFullScanState(t.Context())
+
+ // Assert: Verify failure.
+ require.Nil(t, state)
+ require.ErrorContains(t, err, "db error")
+}
+
+// TestScanWithRewind_Error verifies error propagation from DBPutRewind.
+func TestScanWithRewind_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for a rewind scan where a database
+ // rollback failure occurs.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockTxStore := &mockTxStore{}
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, mockTxStore, nil)
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Maybe()
+
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(nil).Maybe()
+ mockTxStore.On("Rollback", mock.Anything, mock.Anything).Return(
+ errRollbackFail).Once()
+
+ // Act: Attempt to perform a scan with rewind.
+ err := s.scanWithRewind(
+ t.Context(), &scanReq{
+ startBlock: waddrmgr.BlockStamp{Height: 90},
+ },
+ )
+
+ // Assert: Verify rollback failure is propagated.
+ require.ErrorContains(t, err, "rollback fail")
+}
+
+// TestMatchAndFetchBatch_GetBlockHeadersError verifies error handling.
+func TestMatchAndFetchBatch_GetBlockHeadersError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a nil filter to force a match, bypassing complex
+ // filter logic, then mock a block fetch failure.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ filters := []*gcs.Filter{nil}
+ results := []scanResult{{
+ meta: &wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{Hash: chainhash.Hash{0x01}},
+ },
+ }}
+
+ blockMap := make(map[chainhash.Hash]*wire.MsgBlock)
+
+ state := NewRecoveryState(10, nil, nil)
+
+ mockChain.On("GetBlocks", mock.Anything).Return(
+ []*wire.MsgBlock(nil), errGetBlocks).Once()
+
+ // Act: Attempt to match and fetch a batch.
+ err := s.matchAndFetchBatch(
+ t.Context(), state, results, filters, blockMap,
+ )
+
+ // Assert: Verify failure.
+ require.ErrorContains(t, err, "get blocks fail")
+}
+
+// TestScanBatchWithCFilters_FilterBatchError verifies error propagation.
+func TestScanBatchWithCFilters_FilterBatchError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where CFilter retrieval fails.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ hashes := []chainhash.Hash{{0x01}}
+
+ mockChain.On("GetCFilters", hashes, wire.GCSFilterRegular).Return(
+ []*gcs.Filter(nil), errCFilterFail).Once()
+
+ // Act: Attempt a batch scan using CFilters.
+ results, err := s.scanBatchWithCFilters(
+ t.Context(), nil, 100, hashes,
+ )
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "cfilter fail")
+}
+
+// TestScanBatch_GetScanDataError verifies scanBatch failure.
+func TestScanBatch_GetScanDataError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where scan data loading fails
+ // during a batch scan.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockAddrStore := &mockAddrStore{}
+ s := newSyncer(Config{DB: db}, mockAddrStore, nil, nil)
+
+ mgr := &mockAccountStore{}
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore{mgr}).Once()
+ mgr.On("ActiveAccounts").Return([]uint32{0}).Once()
+ mgr.On("Scope").Return(waddrmgr.KeyScopeBIP0084).Once()
+ mockAddrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0084).Return(nil, errActiveMgrsFail).Once()
+
+ // Act: Attempt to execute scanBatch.
+ err := s.scanBatch(
+ t.Context(), waddrmgr.BlockStamp{Height: 100}, 105,
+ )
+
+ // Assert: Verify error propagation.
+ require.ErrorContains(t, err, "active managers fail")
+}
+
+// TestInitResultsForCFilterScan_Error verifies basic error propagation (e.g.
+// GetBlockHeader).
+func TestInitResultsForCFilterScan_Error(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where header retrieval fails during
+ // initialization for a CFilter scan.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ hashes := []chainhash.Hash{{0x01}}
+
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader(nil), errHeaders).Once()
+
+ // Act: Initialize results for a CFilter scan.
+ results, err := s.initResultsForCFilterScan(t.Context(), 100, hashes)
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "headers fail")
+}
+
+// TestDispatchScanStrategy_AutoError verifies error return in Auto mode.
+func TestDispatchScanStrategy_AutoError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where header retrieval fails during
+ // an auto-dispatch scan.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{Chain: mockChain, SyncMethod: SyncMethodAuto},
+ nil, nil, nil,
+ )
+
+ hashes := []chainhash.Hash{{0x01}}
+ scanState := NewRecoveryState(1, nil, nil)
+
+ mockChain.On("GetCFilters", hashes, mock.Anything).Return(
+ []*gcs.Filter{{}}, nil).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ ([]*wire.BlockHeader)(nil), errOther).Once()
+
+ // Act: Dispatch the scan strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorIs(t, err, errOther)
+}
+
+// TestAdvanceChainSync_SmallGap verifies the silent sync path.
+func TestAdvanceChainSync_SmallGap(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for a small gap where silent sync
+ // is preferred.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, mockTxStore, nil,
+ )
+
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(105),
+ nil).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Once()
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore(nil)).Once()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil).Once()
+ mockChain.On("GetBlockHashes", int64(101), int64(105)).Return(
+ []chainhash.Hash{{0x01}}, nil).Once()
+ mockChain.On("GetBlockHeaders", mock.Anything).Return(
+ []*wire.BlockHeader{{}}, nil).Once()
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(nil).Once()
+
+ // Act: Advance chain sync.
+ finished, err := s.advanceChainSync(t.Context())
+
+ // Assert: Verify state transition to backend-syncing.
+ require.NoError(t, err)
+ require.False(t, finished)
+ require.Equal(t, uint32(syncStateBackendSyncing), s.state.Load())
+}
+
+// TestRunSyncStep_BroadcastError verifies error propagation.
+func TestRunSyncStep_BroadcastError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where a broadcast-related failure
+ // occurs during a sync step.
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, mockTxStore, nil,
+ )
+
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(100),
+ nil).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Maybe()
+ mockTxStore.On("UnminedTxs", mock.Anything).Return([]*wire.MsgTx(nil),
+ errBroadcast).Once()
+
+ // Act: Execute a sync step.
+ err := s.runSyncStep(t.Context())
+
+ // Assert: Verify failure.
+ require.ErrorIs(t, err, errBroadcast)
+}
+
+// TestFetchAndFilterBlocks_DispatchError verifies error from
+// dispatchScanStrategy.
+func TestFetchAndFilterBlocks_DispatchError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where an invalid sync method is
+ // encountered during block filtering.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain, SyncMethod: 99}, nil, nil, nil)
+
+ hashes := []chainhash.Hash{{0x01}}
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ hashes, nil).Once()
+
+ scanState := NewRecoveryState(1, nil, nil)
+ scanState.outpoints = make(map[wire.OutPoint][]byte)
+ scanState.outpoints[wire.OutPoint{}] = []byte{}
+
+ // Act: Attempt to fetch and filter blocks.
+ results, err := s.fetchAndFilterBlocks(t.Context(), scanState, 100, 100)
+
+ // Assert: Verify unknown sync method error.
+ require.Nil(t, results)
+ require.ErrorContains(t, err, "unknown sync method")
+}
+
+// TestAdvanceChainSync_ScanBatchError verifies error propagation.
+func TestAdvanceChainSync_ScanBatchError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where address iteration fails
+ // during chain sync advancement.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, nil, nil,
+ )
+
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(105),
+ nil).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Once()
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore(nil)).Once()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(errScan).Once()
+
+ // Act: Advance chain sync.
+ finished, err := s.advanceChainSync(t.Context())
+
+ // Assert: Verify failure.
+ require.False(t, finished)
+ require.ErrorIs(t, err, errScan)
+}
+
+// TestDispatchScanStrategy_FullBlocksError verifies error propagation.
+func TestDispatchScanStrategy_FullBlocksError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where block retrieval fails during
+ // a full-block scan.
+ mockChain := &mockChain{}
+ s := newSyncer(
+ Config{Chain: mockChain, SyncMethod: SyncMethodFullBlocks},
+ nil, nil, nil,
+ )
+
+ hashes := []chainhash.Hash{{0x01}}
+ scanState := NewRecoveryState(1, nil, nil)
+
+ mockChain.On("GetBlocks", hashes).Return([]*wire.MsgBlock(nil),
+ errBlocks).Once()
+
+ // Act: Dispatch the strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify failure.
+ require.Nil(t, results)
+ require.ErrorIs(t, err, errBlocks)
+}
+
+// TestExtractAddrEntries_NonStd verifies non-standard script handling.
+func TestExtractAddrEntries_NonStd(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Initialize a syncer and create various output scripts,
+ // including non-standard and OP_RETURN scripts.
+ s := newSyncer(
+ Config{ChainParams: &chainParams},
+ nil, nil, nil,
+ )
+
+ pkh, err := address.NewAddressPubKeyHash(
+ make([]byte, 20), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(pkh)
+ require.NoError(t, err)
+
+ txOuts := []*wire.TxOut{
+ {
+ // OP_DATA_1 but no data byte follows. (Error path)
+ PkScript: []byte{0x01},
+ },
+ {
+ // OP_RETURN (Empty addrs, no error)
+ PkScript: []byte{0x6a, 0x04, 0xde, 0xad, 0xbe, 0xef},
+ },
+ {
+ // Standard P2PKH (Success path)
+ PkScript: pkScript,
+ },
+ }
+
+ // Act: Extract address entries.
+ entries := s.extractAddrEntries(txOuts)
+
+ // Assert: Verify that only the standard P2PKH output was extracted.
+ require.Len(t, entries, 1)
+}
+
+// TestAdvanceChainSync_GetBestBlockError verifies error propagation.
+func TestAdvanceChainSync_GetBestBlockError(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations where GetBestBlock fails during
+ // chain sync advancement.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{Chain: mockChain}, nil, nil, nil)
+
+ mockChain.On("GetBestBlock").Return((*chainhash.Hash)(nil), int32(0),
+ errBestBlock).Once()
+
+ // Act: Advance chain sync.
+ finished, err := s.advanceChainSync(t.Context())
+
+ // Assert: Verify failure.
+ require.False(t, finished)
+ require.ErrorIs(t, err, errBestBlock)
+}
+
+// TestDispatchScanStrategy_AutoDefaultThreshold verifies threshold=0 branch.
+func TestDispatchScanStrategy_AutoDefaultThreshold(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for auto strategy with a zero
+ // threshold for compact filters.
+ mockChain := &mockChain{}
+ s := newSyncer(Config{
+ Chain: mockChain,
+ SyncMethod: SyncMethodAuto,
+ MaxCFilterItems: 0,
+ }, nil, nil, nil)
+
+ hashes := []chainhash.Hash{{0x01}}
+ scanState := NewRecoveryState(1, nil, nil)
+
+ mockChain.On("GetCFilters", hashes, mock.Anything).Return(
+ []*gcs.Filter{{}}, nil).Once()
+ mockChain.On("GetBlockHeaders", hashes).Return(
+ []*wire.BlockHeader{{}}, nil).Once()
+ mockChain.On("GetBlocks", mock.Anything).Return(
+ []*wire.MsgBlock{
+ wire.NewMsgBlock(&wire.BlockHeader{})}, nil).Once()
+
+ // Act: Dispatch the strategy.
+ results, err := s.dispatchScanStrategy(
+ t.Context(), scanState, 100, hashes,
+ )
+
+ // Assert: Verify success.
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+}
+
+// TestAdvanceChainSync_LargeGap verifies the explicit syncing state.
+func TestAdvanceChainSync_LargeGap(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Setup mock expectations for a large sync gap where explicit
+ // scanning is triggered.
+ mockChain := &mockChain{}
+ mockAddrStore := &mockAddrStore{}
+ mockTxStore := &mockTxStore{}
+
+ db, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ s := newSyncer(
+ Config{Chain: mockChain, DB: db},
+ mockAddrStore, mockTxStore, nil,
+ )
+
+ mockChain.On("GetBestBlock").Return(&chainhash.Hash{}, int32(110),
+ nil).Once()
+ mockAddrStore.On("SyncedTo").Return(
+ waddrmgr.BlockStamp{Height: 100}).Once()
+
+ // The following mocks use Maybe() because for a large gap, the syncer
+ // transitions to SyncStateSyncing and returns early, skipping these
+ // calls.
+ mockAddrStore.On("ActiveScopedKeyManagers").Return(
+ []waddrmgr.AccountStore(nil)).Maybe()
+ mockAddrStore.On("ForEachRelevantActiveAddress", mock.Anything,
+ mock.Anything).Return(nil).Maybe()
+
+ mockTxStore.On("OutputsToWatch", mock.Anything).Return(
+ []wtxmgr.Credit(nil), nil).Maybe()
+ mockChain.On("GetBlockHashes", mock.Anything, mock.Anything).Return(
+ []chainhash.Hash{{0x01}}, nil).Maybe()
+ mockChain.On("GetBlockHeaders", mock.Anything).Return(
+ []*wire.BlockHeader{{}}, nil).Maybe()
+ mockAddrStore.On("SetSyncedTo", mock.Anything,
+ mock.Anything).Return(nil).Maybe()
+
+ // Act: Advance chain sync.
+ finished, err := s.advanceChainSync(t.Context())
+
+ // Assert: Verify state transition to syncing.
+ require.NoError(t, err)
+ require.False(t, finished)
+ require.Equal(t, uint32(syncStateSyncing), s.state.Load())
+}
diff --git a/wallet/tx_creator.go b/wallet/tx_creator.go
new file mode 100644
index 0000000000..fc32b2e44b
--- /dev/null
+++ b/wallet/tx_creator.go
@@ -0,0 +1,1157 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// Package wallet provides a bitcoin wallet implementation that is ready for
+// use.
+//
+// TODO(yy): bring wrapcheck back when implementing the `Store` interface.
+//
+//nolint:wrapcheck
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math/rand"
+ "sort"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/pkg/btcunit"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wallet/txauthor"
+ "github.com/btcsuite/btcwallet/wallet/txrules"
+ "github.com/btcsuite/btcwallet/wallet/txsizes"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+var (
+ // ErrManualInputsEmpty is returned when manual inputs are specified but
+ // the list is empty.
+ ErrManualInputsEmpty = errors.New("manual inputs cannot be empty")
+
+ // ErrDuplicatedUtxo is returned when a UTXO is specified multiple
+ // times.
+ ErrDuplicatedUtxo = errors.New("duplicated utxo")
+
+ // ErrUnsupportedTxInputs is returned when the `Inputs` field of a
+ // TxIntent is not of a supported type.
+ ErrUnsupportedTxInputs = errors.New("unsupported tx inputs type")
+
+ // ErrUtxoNotEligible is returned when a UTXO is not eligible to be
+ // spent.
+ ErrUtxoNotEligible = errors.New("utxo not eligible to spend")
+
+ // ErrAccountNotFound is returned when an account is not found.
+ ErrAccountNotFound = errors.New("account not found")
+
+ // ErrNoTxOutputs is returned when a transaction is created without any
+ // outputs.
+ ErrNoTxOutputs = errors.New("tx has no outputs")
+
+ // ErrFeeRateTooLarge is returned when a transaction is created with a
+ // fee rate that is larger than the configured max allowed fee rate.
+ // The default max fee rate is 1000 sat/vb.
+ ErrFeeRateTooLarge = errors.New("fee rate too large")
+
+ // ErrMissingFeeRate is returned when a transaction is created without
+ // a fee rate.
+ ErrMissingFeeRate = errors.New("missing fee rate")
+
+ // ErrMissingAccountName is returned when an account name is required
+ // but not provided.
+ ErrMissingAccountName = errors.New("account name cannot be empty")
+
+ // ErrUnsupportedCoinSource is returned when the `Source` field of a
+ // CoinSelectionPolicy is not of a supported type.
+ ErrUnsupportedCoinSource = errors.New("unsupported coin source type")
+
+ // ErrMissingInputs is returned when a transaction is created without
+ // any inputs.
+ ErrMissingInputs = errors.New("tx has no inputs")
+
+ // ErrNilTxIntent is returned when a nil `TxIntent` is provided.
+ ErrNilTxIntent = errors.New("nil TxIntent")
+)
+
+var (
+ // DefaultMaxFeeRate is the default maximum fee rate in sat/kvb that
+ // the wallet will consider sane. This is currently set to 1000 sat/vb
+ // (1,000,000 sat/kvb).
+ //
+ // TODO(yy): The max fee rate should be made configurable as part of
+ // the WalletController interface implementation.
+ //
+ //nolint:mnd // 1M sat/kvb default max fee.
+ DefaultMaxFeeRate = btcunit.NewSatPerKVByte(1_000_000)
+)
+
+// Coin represents a spendable UTXO which is available for coin selection.
+type Coin struct {
+ wire.TxOut
+ wire.OutPoint
+}
+
+// CoinSelectionStrategy is an interface that represents a coin selection
+// strategy. A coin selection strategy is responsible for ordering, shuffling or
+// filtering a list of coins before they are passed to the coin selection
+// algorithm.
+type CoinSelectionStrategy interface {
+ // ArrangeCoins takes a list of coins and arranges them according to the
+ // specified coin selection strategy and fee rate.
+ ArrangeCoins(eligible []Coin, feeSatPerKb btcutil.Amount) ([]Coin,
+ error)
+}
+
+var (
+ // CoinSelectionLargest always picks the largest available utxo to add
+ // to the transaction next.
+ CoinSelectionLargest CoinSelectionStrategy = &LargestFirstCoinSelector{}
+
+ // CoinSelectionRandom randomly selects the next utxo to add to the
+ // transaction. This strategy prevents the creation of ever smaller
+ // utxos over time.
+ CoinSelectionRandom CoinSelectionStrategy = &RandomCoinSelector{}
+)
+
+// TxCreator provides an interface for creating transactions. Its primary
+// role is to produce a fully-formed, unsigned transaction that can be passed
+// to the Signer interface.
+type TxCreator interface {
+ // CreateTransaction creates a new, unsigned transaction based on the
+ // provided intent. The resulting AuthoredTx will contain the unsigned
+ // transaction and all the necessary metadata to sign it.
+ CreateTransaction(ctx context.Context, intent *TxIntent) (
+ *txauthor.AuthoredTx, error)
+}
+
+// A compile time check to ensure that Wallet implements the interface.
+var _ TxCreator = (*Wallet)(nil)
+
+// TxIntent represents the user's intent to create a transaction. It serves as
+// a blueprint for the TxCreator, bundling all the parameters required to
+// construct a transaction into a single, coherent structure.
+//
+// A TxIntent can be used to create a transaction in four main ways:
+//
+// 1. Automatic Coin Selection from the Default Account:
+// The simplest way to create a transaction is to specify only the outputs
+// and the fee rate. By leaving the `Inputs` field as nil, the wallet will
+// automatically select coins from the default account to fund the
+// transaction.
+//
+// Example:
+//
+// intent := &TxIntent{
+// Outputs: outputs,
+// FeeRate: feeRate,
+// }
+//
+// 2. Manual Input Selection:
+// To have direct control over the inputs used, the caller can specify the
+// exact UTXOs to spend. This is achieved by setting the `Inputs` field to
+// an `InputsManual` struct, which contains a slice of the desired
+// `wire.OutPoint`s. In this mode, all coin selection logic is bypassed; the
+// wallet simply uses the provided inputs.
+//
+// Example:
+//
+// intent := &TxIntent{
+// Outputs: outputs,
+// Inputs: &InputsManual{UTXOs: []wire.OutPoint{...}},
+// FeeRate: feeRate,
+// ChangeSource: changeSource,
+// }
+//
+// 3. Policy-Based Coin Selection from an Account:
+// To have the wallet select inputs from a specific account, the caller can
+// specify a policy. This is achieved by setting the `Inputs` field to an
+// `InputsPolicy` struct. This struct defines the strategy (e.g.,
+// largest-first), the minimum number of confirmations, and the source of
+// the coins. If the `Source` is a `ScopedAccount`, the wallet will select
+// coins from that account. If the `Source` field is nil, the wallet will
+// use a default source, typically the default account.
+//
+// Example:
+//
+// intent := &TxIntent{
+// Outputs: outputs,
+// Inputs: &InputsPolicy{
+// Strategy: CoinSelectionLargest,
+// MinConfs: 1,
+// Source: &ScopedAccount{AccountName: "default", ...},
+// },
+// FeeRate: feeRate,
+// ChangeSource: changeSource,
+// }
+//
+// 4. Policy-Based Coin Selection from a specific set of UTXOs:
+// For more advanced control, the caller can provide a specific list of UTXOs
+// and have the coin selection algorithm choose the best subset from that
+// list. This is useful for scenarios like coin control where the user wants
+// to limit the potential inputs for a transaction. This is achieved by
+// setting the `Source` of an `InputsPolicy` to a `CoinSourceUTXOs` struct.
+//
+// Example:
+//
+// intent := &TxIntent{
+// Outputs: outputs,
+// Inputs: &InputsPolicy{
+// Strategy: CoinSelectionLargest,
+// MinConfs: 1,
+// Source: &CoinSourceUTXOs{
+// UTXOs: []wire.OutPoint{...},
+// },
+// },
+// FeeRate: feeRate,
+// ChangeSource: changeSource,
+// }
+type TxIntent struct {
+ // Outputs specifies the recipients and amounts for the transaction.
+ // This field is required.
+ Outputs []wire.TxOut
+
+ // Inputs defines the source of the inputs for the transaction. This
+ // must be one of the Inputs implementations (InputsManual or
+ // InputsPolicy). This field is required.
+ Inputs Inputs
+
+ // ChangeSource specifies the destination for the transaction's change
+ // output. If this field is nil, the wallet will use a default change
+ // source based on the account and scope of the inputs.
+ ChangeSource *ScopedAccount
+
+ // FeeRate specifies the desired fee rate for the transaction,
+ // expressed in satoshis per kilo-virtual-byte (sat/kvb). This field is
+ // required.
+ FeeRate btcunit.SatPerKVByte
+
+ // Label is an optional, human-readable label for the transaction. This
+ // can be used to associate a memo with the transaction for later
+ // reference.
+ Label string
+}
+
+// Inputs is a sealed interface that defines the source of inputs for a
+// transaction. It can either be a manually specified set of UTXOs or a policy
+// for coin selection. The sealed interface pattern is used here to
+// provide compile-time safety, ensuring that only the intended implementations
+// can be used.
+type Inputs interface {
+ // isInputs is a marker method that is part of the sealed interface
+ // pattern. It is unexported, so it can only be implemented by types
+ // within this package. This ensures that only the intended types
+ // can be used as an Inputs implementation.
+ isInputs()
+
+ // validate performs a series of checks on the input source to ensure
+ // it is well-formed. This method is called before any database
+ // transactions are opened, allowing for early, efficient validation.
+ validate() error
+}
+
+// InputsManual implements the Inputs interface and specifies the exact UTXOs
+// to be used as transaction inputs. When this is used, all automatic coin
+// selection logic is bypassed.
+type InputsManual struct {
+ // UTXOs is a slice of outpoints to be used as the exact inputs for the
+ // transaction. The wallet will validate that these UTXOs are known and
+ // spendable but will not perform any further coin selection.
+ UTXOs []wire.OutPoint
+}
+
+// InputsPolicy implements the Inputs interface and specifies the policy
+// for coin selection by the wallet.
+type InputsPolicy struct {
+ // Strategy is the algorithm to use for selecting coins (e.g., largest
+ // first, random). If this is nil, the wallet's default coin selection
+ // strategy will be used.
+ Strategy CoinSelectionStrategy
+
+ // MinConfs is the minimum number of confirmations a UTXO must have to
+ // be considered eligible for coin selection.
+ MinConfs uint32
+
+ // Source specifies the pool of UTXOs to select from. If this is nil,
+ // the wallet will use a default source (e.g., the default account).
+ // Otherwise, this must be one of the CoinSource implementations.
+ Source CoinSource
+}
+
+// isInputs marks InputsManual as an implementation of the Inputs interface.
+func (*InputsManual) isInputs() {}
+
+// validate performs validation on the manual inputs.
+func (i *InputsManual) validate() error {
+ return validateOutPoints(i.UTXOs)
+}
+
+// isInputs marks InputsPolicy as an implementation of the Inputs
+// interface.
+func (*InputsPolicy) isInputs() {}
+
+// validate performs validation on the input policy.
+func (i *InputsPolicy) validate() error {
+ if i.Source == nil {
+ return nil
+ }
+
+ switch source := i.Source.(type) {
+ // If the source is a scoped account, it must have a non-empty account
+ // name.
+ case *ScopedAccount:
+ if source.AccountName == "" {
+ return ErrMissingAccountName
+ }
+
+ // If the source is a list of UTXOs, it must not be empty and must not
+ // contain duplicates.
+ case *CoinSourceUTXOs:
+ return validateOutPoints(source.UTXOs)
+
+ // Any other source type is unsupported.
+ default:
+ return fmt.Errorf("%w: %T", ErrUnsupportedCoinSource, source)
+ }
+
+ return nil
+}
+
+// A compile-time assertion to ensure that all types implementing the Inputs
+// interface adhere to it.
+var _ Inputs = (*InputsManual)(nil)
+var _ Inputs = (*InputsPolicy)(nil)
+
+// CoinSource is a sealed interface that defines the pool of UTXOs available
+// for coin selection. The sealed interface pattern ensures that only
+// the intended implementations can be used.
+type CoinSource interface {
+ // isCoinSource is a marker method that is part of the sealed interface
+ // pattern. It is unexported, so it can only be implemented by types
+ // within this package. This ensures that only the intended types
+ // can be used as a CoinSource implementation.
+ isCoinSource()
+}
+
+// ScopedAccount defines a wallet account within a particular key scope. It is
+// used to specify the source of funds for coin selection and the
+// destination for change outputs.
+type ScopedAccount struct {
+ // AccountName specifies the name of the account. This must be a
+ // non-empty string.
+ AccountName string
+
+ // KeyScope specifies the key scope (e.g., P2WKH, P2TR).
+ KeyScope waddrmgr.KeyScope
+}
+
+// CoinSourceUTXOs specifies that the wallet should select coins from a
+// specific, predefined list of candidate UTXOs.
+type CoinSourceUTXOs struct {
+ // UTXOs is a slice of outpoints from which the coin selection
+ // algorithm will choose. This list must not be empty.
+ UTXOs []wire.OutPoint
+}
+
+// isCoinSource marks ScopedAccount as an implementation of the CoinSource
+// interface.
+func (ScopedAccount) isCoinSource() {}
+
+// isCoinSource marks CoinSourceUTXOs as an implementation of the CoinSource
+// interface.
+func (CoinSourceUTXOs) isCoinSource() {}
+
+// validateOutPoints checks a slice of `wire.OutPoint`s for emptiness and
+// duplicate entries. It returns `ErrManualInputsEmpty` if the slice is empty
+// and `ErrDuplicatedUtxo` if any duplicates are found.
+func validateOutPoints(outpoints []wire.OutPoint) error {
+ if len(outpoints) == 0 {
+ return ErrManualInputsEmpty
+ }
+
+ seenUTXOs := make(map[wire.OutPoint]struct{})
+ for _, utxo := range outpoints {
+ if _, ok := seenUTXOs[utxo]; ok {
+ return ErrDuplicatedUtxo
+ }
+
+ seenUTXOs[utxo] = struct{}{}
+ }
+
+ return nil
+}
+
+// A compile-time assertion to ensure that all types implementing the CoinSource
+// interface adhere to it.
+var _ CoinSource = (*ScopedAccount)(nil)
+var _ CoinSource = (*CoinSourceUTXOs)(nil)
+
+// validateTxIntent performs a series of checks on a TxIntent to ensure it is
+// well-formed. This function is called before any transaction creation logic
+// to ensure that the caller has provided a valid intent. This function is for
+// validation only and does not modify the TxIntent.
+//
+// The following checks are performed:
+// - The intent must have at least one output.
+// - Each output must not be a dust output.
+// - If a change source is specified, it must have a non-empty account name.
+// - The intent must have a valid, non-nil input source.
+// - The input source itself is validated via the `validate` method.
+func validateTxIntent(intent *TxIntent) error {
+ // The intent must have at least one output.
+ if len(intent.Outputs) == 0 {
+ return ErrNoTxOutputs
+ }
+
+ // Each output must not be a dust output according to the default relay
+ // fee policy.
+ for _, output := range intent.Outputs {
+ err := txrules.CheckOutput(
+ &output, txrules.DefaultRelayFeePerKb,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ // If a change source is specified, it must have a non-empty account
+ // name.
+ if intent.ChangeSource != nil && intent.ChangeSource.AccountName == "" {
+ return ErrMissingAccountName
+ }
+
+ // If no input source is specified, an error is returned.
+ if intent.Inputs == nil {
+ return ErrMissingInputs
+ }
+
+ // Validate the inputs.
+ err := intent.Inputs.validate()
+ if err != nil {
+ return err
+ }
+
+ // The intent must have a non-zero fee rate.
+ if intent.FeeRate.LessThanOrEqual(btcunit.ZeroSatPerKVByte) {
+ return ErrMissingFeeRate
+ }
+
+ // Ensure the fee rate is not "insane". This prevents users from
+ // accidentally paying exorbitant fees.
+ if intent.FeeRate.GreaterThan(DefaultMaxFeeRate) {
+ return fmt.Errorf("%w: fee rate of %s is too high, "+
+ "max sane fee rate is %s", ErrFeeRateTooLarge,
+ intent.FeeRate, DefaultMaxFeeRate)
+ }
+
+ return nil
+}
+
+// prepareTxAuthSources creates the input and change sources required to
+// author a transaction.
+func (w *Wallet) prepareTxAuthSources(intent *TxIntent) (
+ txauthor.InputSource, *txauthor.ChangeSource, error) {
+ // Determine the change source. If not specified, a default will be
+ // used.
+ changeAccount := w.determineChangeSource(intent)
+
+ manager, err := w.addrStore.FetchScopedKeyManager(
+ changeAccount.KeyScope,
+ )
+ if err != nil {
+ return nil, nil, fmt.Errorf("%w: %s", ErrAccountNotFound,
+ changeAccount.AccountName)
+ }
+
+ var (
+ changeSource *txauthor.ChangeSource
+ inputSource txauthor.InputSource
+ )
+ // We perform the core logic of creating the input and change sources
+ // within a single database transaction to ensure atomicity.
+ err = walletdb.Update(w.cfg.DB, func(dbtx walletdb.ReadWriteTx) error {
+ changeKeyScope := &changeAccount.KeyScope
+ accountName := changeAccount.AccountName
+
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+
+ // Query the account's number using the account name.
+ //
+ // TODO(yy): Remove this query in upcoming SQL.
+ account, err := manager.LookupAccount(addrmgrNs, accountName)
+ if err != nil {
+ return fmt.Errorf("%w: %s", ErrAccountNotFound,
+ accountName)
+ }
+
+ // Create the change source, which is a closure that the
+ // txauthor package will use to generate a new change address
+ // when needed.
+ //
+ // TODO(yy): Refactor to ensure atomicity. The underlying
+ // `GetUnusedAddress` call creates its own database
+ // transaction, breaking the atomicity of this
+ // `walletdb.Update` block. A new method should be added to
+ // `AccountStore` that accepts an active database transaction
+ // and returns an unused address. This will allow the address
+ // derivation to occur within the same atomic transaction as
+ // the rest of the tx creation logic.
+ _, changeSource, err = w.addrMgrWithChangeSource(
+ dbtx, changeKeyScope, account,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Create the input source, which is a closure that the
+ // txauthor package will use to select coins.
+ inputSource, err = w.createInputSource(dbtx, intent)
+ if err != nil {
+ return err
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return inputSource, changeSource, nil
+}
+
+// CreateTransaction creates a new unsigned transaction spending unspent outputs
+// to the given outputs. It is the main implementation of the TxCreator
+// interface. The method will produce a valid, unsigned transaction, which can
+// then be passed to the Signer interface to be signed.
+func (w *Wallet) CreateTransaction(_ context.Context, intent *TxIntent) (
+ *txauthor.AuthoredTx, error) {
+
+ err := w.state.validateSynced()
+ if err != nil {
+ return nil, err
+ }
+
+ // Check that the intent is not nil.
+ if intent == nil {
+ return nil, ErrNilTxIntent
+ }
+
+ // If no input source is specified, an auto coin selection with the
+ // default account will be used.
+ if intent.Inputs == nil {
+ log.Debug("No input source specified, using default policy " +
+ "for automatic coin selection")
+
+ intent.Inputs = &InputsPolicy{}
+ }
+
+ err = validateTxIntent(intent)
+ if err != nil {
+ return nil, err
+ }
+
+ inputSource, changeSource, err := w.prepareTxAuthSources(intent)
+ if err != nil {
+ return nil, err
+ }
+
+ // The txauthor.NewUnsignedTransaction function expects a slice of
+ // *wire.TxOut, but our intent has a slice of wire.TxOut. We perform
+ // the conversion here.
+ //
+ // TODO(yy): change the signature of `NewUnsignedTransaction` to take a
+ // list of `wire.TxOut`.
+ outputs := make([]*wire.TxOut, 0, len(intent.Outputs))
+ for _, output := range intent.Outputs {
+ outputs = append(outputs, &output)
+ }
+
+ // With the input source and change source prepared, we can now call the
+ // txauthor package to perform the actual coin selection and create the
+ // unsigned transaction.
+ feeSatPerKb := intent.FeeRate.Val()
+
+ tx, err := txauthor.NewUnsignedTransaction(
+ outputs, feeSatPerKb, inputSource, changeSource,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Randomize the position of the change output, if one was created. This
+ // helps to improve privacy by making it harder to distinguish change
+ // outputs from other outputs.
+ if tx.ChangeIndex >= 0 {
+ tx.RandomizeChangePosition()
+ }
+
+ return tx, nil
+}
+
+// determineChangeSource determines the source for the transaction's change
+// output. If a source is specified in the intent, it is used. Otherwise, a
+// default is determined based on the input source or the wallet's default
+// account. When falling back to the default account, the P2TR (Taproot) key
+// scope is used.
+func (w *Wallet) determineChangeSource(intent *TxIntent) *ScopedAccount {
+ // If a change source is specified in the intent, use it.
+ if intent.ChangeSource != nil {
+ return intent.ChangeSource
+ }
+
+ // If the inputs are from a specific account, use that for change.
+ if policy, ok := intent.Inputs.(*InputsPolicy); ok {
+ if account, ok := policy.Source.(*ScopedAccount); ok {
+ return account
+ }
+ }
+
+ // Otherwise, use the default account.
+ // TODO(yy): The default key scope is currently hardcoded to P2TR
+ // (Taproot). This should be made configurable.
+ return &ScopedAccount{
+ AccountName: waddrmgr.DefaultAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ }
+}
+
+// createInputSource creates a txauthor.InputSource that will be used to select
+// inputs for a transaction. It acts as a dispatcher, delegating to either the
+// manual or policy-based input source creator based on the type of the intent's
+// Inputs field.
+//
+// TODO(yy): We use customized queries here to make the utxo lookups atomic
+// inside a big tx that's created in `CreateTransaction`, however, we should
+// instead have methods made on the `txStore`, which takes a db tx and use them
+// here, as the logic will be largely overlapped with the interface methods used
+// in `wallet/utxo_manager.go`.
+func (w *Wallet) createInputSource(dbtx walletdb.ReadTx, intent *TxIntent) (
+ txauthor.InputSource, error) {
+
+ switch inputs := intent.Inputs.(type) {
+ // If the inputs are manually specified, we create a "constant" input
+ // source that will only ever return the specified UTXOs.
+ case *InputsManual:
+ return w.createManualInputSource(dbtx, inputs)
+
+ // If the inputs are policy-based, we create an input source that will
+ // perform coin selection.
+ case *InputsPolicy:
+ return w.createPolicyInputSource(dbtx, inputs, intent.FeeRate)
+
+ // Any other type is unsupported.
+ default:
+ return nil, ErrUnsupportedTxInputs
+ }
+}
+
+// createManualInputSource creates an input source from a list of manually
+// specified UTXOs. It fetches the UTXOs directly from the database and ensures
+// that they are eligible for spending.
+func (w *Wallet) createManualInputSource(dbtx walletdb.ReadTx,
+ inputs *InputsManual) (
+ txauthor.InputSource, error) {
+
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Create a slice to hold the eligible UTXOs.
+ eligibleSelectedUtxo := make(
+ []wtxmgr.Credit, 0, len(inputs.UTXOs),
+ )
+
+ // Iterate through the manually specified UTXOs and ensure that each
+ // one is eligible for spending.
+ for _, outpoint := range inputs.UTXOs {
+ // Fetch the UTXO from the database.
+ credit, err := w.txStore.GetUtxo(txmgrNs, outpoint)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %v", ErrUtxoNotEligible,
+ outpoint)
+ }
+
+ // TODO(yy): check for locked utxos and log a warning.
+ eligibleSelectedUtxo = append(eligibleSelectedUtxo, *credit)
+ }
+
+ // Return a constant input source that will only provide the selected
+ // UTXOs.
+ return constantInputSource(eligibleSelectedUtxo), nil
+}
+
+// createPolicyInputSource creates an input source that will perform automatic
+// coin selection based on the provided policy.
+func (w *Wallet) createPolicyInputSource(dbtx walletdb.ReadTx,
+ policy *InputsPolicy, feeRate btcunit.SatPerKVByte) (
+ txauthor.InputSource, error) {
+
+ // Fall back to the default coin selection strategy if none is supplied.
+ strategy := policy.Strategy
+ if strategy == nil {
+ strategy = CoinSelectionLargest
+ }
+
+ // Get the full set of eligible UTXOs based on the policy's source
+ // and confirmation requirements.
+ eligible, err := w.getEligibleUTXOs(
+ dbtx, policy.Source, policy.MinConfs,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Wrap our wtxmgr.Credit coins in a `Coin` type that implements the
+ // SelectableCoin interface. This allows the coin selection strategy
+ // to operate on them.
+ //
+ // TODO(yy): unify the types here - we should use `Utxo` instead of
+ // `Credit` or `Coin`.
+ wrappedEligible := make([]Coin, len(eligible))
+ for i := range eligible {
+ wrappedEligible[i] = Coin{
+ TxOut: wire.TxOut{
+ Value: int64(
+ eligible[i].Amount,
+ ),
+ PkScript: eligible[i].PkScript,
+ },
+ OutPoint: eligible[i].OutPoint,
+ }
+ }
+
+ // Arrange the eligible coins according to the chosen strategy (e.g.,
+ // sort by largest first, or shuffle for random selection).
+ feeSatPerKb := feeRate.Val()
+
+ arrangedCoins, err := strategy.ArrangeCoins(
+ wrappedEligible, feeSatPerKb,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Return an input source that will dispense the arranged coins one by
+ // one as requested by the txauthor.
+ return makeInputSource(arrangedCoins), nil
+}
+
+// getEligibleUTXOs returns a slice of eligible UTXOs that can be used as
+// inputs for a transaction, based on the specified source and confirmation
+// requirements. A UTXO is considered ineligible if it is not found in the
+// wallet's transaction store or if it does not meet the minimum confirmation
+// requirements.
+func (w *Wallet) getEligibleUTXOs(dbtx walletdb.ReadTx,
+ source CoinSource, minconf uint32) ([]wtxmgr.Credit, error) {
+
+ // TODO(yy): remove this block stamp check. The block stamp should be
+ // passed in as a parameter.
+ bs, err := w.cfg.Chain.BlockStamp()
+ if err != nil {
+ return nil, err
+ }
+
+ // Dispatch based on the type of the coin source.
+ switch source := source.(type) {
+ // If the source is nil, we'll use the default account.
+ case nil:
+ return w.filterEligibleOutputs(
+ dbtx, &waddrmgr.KeyScopeBIP0086,
+ waddrmgr.DefaultAccountNum, minconf, bs,
+ )
+
+ // If the source is a scoped account, we find all eligible outputs for
+ // that specific account and key scope.
+ case *ScopedAccount:
+ return w.getEligibleUTXOsFromAccount(dbtx, source, minconf, bs)
+
+ // If the source is a list of UTXOs, we validate and fetch each UTXO
+ // from the provided list.
+ case *CoinSourceUTXOs:
+ return w.getEligibleUTXOsFromList(dbtx, source, minconf, bs)
+
+ // Any other source type is unsupported.
+ default:
+ return nil, ErrUnsupportedCoinSource
+ }
+}
+
+// getEligibleUTXOsFromAccount returns a slice of eligible UTXOs for a specific
+// account and key scope.
+func (w *Wallet) getEligibleUTXOsFromAccount(dbtx walletdb.ReadTx,
+ source *ScopedAccount, minconf uint32, bs *waddrmgr.BlockStamp) (
+ []wtxmgr.Credit, error) {
+
+ keyScope := &source.KeyScope
+
+ manager, err := w.addrStore.FetchScopedKeyManager(*keyScope)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %s", ErrAccountNotFound,
+ source.AccountName)
+ }
+
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+
+ account, err := manager.LookupAccount(addrmgrNs, source.AccountName)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %s", ErrAccountNotFound,
+ source.AccountName)
+ }
+
+ return w.filterEligibleOutputs(dbtx, keyScope, account, minconf, bs)
+}
+
+// getEligibleUTXOsFromList returns a slice of eligible UTXOs from a specified
+// list of outpoints.
+func (w *Wallet) getEligibleUTXOsFromList(dbtx walletdb.ReadTx,
+ source *CoinSourceUTXOs, minconf uint32, bs *waddrmgr.BlockStamp) (
+ []wtxmgr.Credit, error) {
+
+ // Create a slice to hold the eligible UTXOs.
+ eligible := make([]wtxmgr.Credit, 0, len(source.UTXOs))
+
+ // Get the transaction manager's namespace.
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ // Iterate through the manually specified UTXOs and ensure that each
+ // one is eligible for spending.
+ for _, outpoint := range source.UTXOs {
+ // Fetch the UTXO from the database.
+ credit, err := w.txStore.GetUtxo(txmgrNs, outpoint)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %v",
+ ErrUtxoNotEligible, outpoint)
+ }
+
+ // A UTXO is only eligible if it has reached the required
+ // number of confirmations.
+ if !hasMinConfs(minconf, credit.Height, bs.Height) {
+ // Calculate the number of confirmations for the
+ // warning message.
+ confs := calcConf(credit.Height, bs.Height)
+
+ log.Warnf("Skipping user-specified UTXO %v "+
+ "because it has %d confs but needs %d",
+ credit.OutPoint, confs, minconf)
+
+ continue
+ }
+
+ // If the UTXO is eligible, add it to the list.
+ eligible = append(eligible, *credit)
+ }
+
+ return eligible, nil
+}
+
+func makeInputSource(eligible []Coin) txauthor.InputSource {
+ // Current inputs and their total value. These are closed over by the
+ // returned input source and reused across multiple calls.
+ currentTotal := btcutil.Amount(0)
+ currentInputs := make([]*wire.TxIn, 0, len(eligible))
+ currentScripts := make([][]byte, 0, len(eligible))
+ currentInputValues := make([]btcutil.Amount, 0, len(eligible))
+
+ return func(target btcutil.Amount) (btcutil.Amount, []*wire.TxIn,
+ []btcutil.Amount, [][]byte, error) {
+
+ for currentTotal < target && len(eligible) != 0 {
+ nextCredit := eligible[0]
+ prevOut := nextCredit.TxOut
+ outpoint := nextCredit.OutPoint
+ eligible = eligible[1:]
+
+ nextInput := wire.NewTxIn(&outpoint, nil, nil)
+ currentTotal += btcutil.Amount(prevOut.Value)
+
+ currentInputs = append(currentInputs, nextInput)
+ currentScripts = append(
+ currentScripts, prevOut.PkScript,
+ )
+ currentInputValues = append(
+ currentInputValues,
+ btcutil.Amount(prevOut.Value),
+ )
+ }
+
+ return currentTotal, currentInputs, currentInputValues,
+ currentScripts, nil
+ }
+}
+
+// constantInputSource creates an input source function that always returns the
+// static set of user-selected UTXOs.
+func constantInputSource(eligible []wtxmgr.Credit) txauthor.InputSource {
+ // Current inputs and their total value. These won't change over
+ // different invocations as we want our inputs to remain static since
+ // they're selected by the user.
+ currentTotal := btcutil.Amount(0)
+ currentInputs := make([]*wire.TxIn, 0, len(eligible))
+ currentScripts := make([][]byte, 0, len(eligible))
+ currentInputValues := make([]btcutil.Amount, 0, len(eligible))
+
+ for _, credit := range eligible {
+ nextInput := wire.NewTxIn(&credit.OutPoint, nil, nil)
+ currentTotal += credit.Amount
+
+ currentInputs = append(currentInputs, nextInput)
+ currentScripts = append(currentScripts, credit.PkScript)
+ currentInputValues = append(currentInputValues, credit.Amount)
+ }
+
+ return func(_ btcutil.Amount) (btcutil.Amount, []*wire.TxIn,
+ []btcutil.Amount, [][]byte, error) {
+
+ return currentTotal, currentInputs, currentInputValues,
+ currentScripts, nil
+ }
+}
+
+// filterEligibleOutputs finds eligible outputs for the given key scope and
+// account.
+//
+// We will build a single query for this operation, so skip the linter for now.
+//
+//nolint:cyclop
+func (w *Wallet) filterEligibleOutputs(dbtx walletdb.ReadTx,
+ keyScope *waddrmgr.KeyScope, account uint32, minconf uint32,
+ bs *waddrmgr.BlockStamp) ([]wtxmgr.Credit, error) {
+
+ addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ unspent, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return nil, err
+ }
+
+ // TODO: Eventually all of these filters (except perhaps output locking)
+ // should be handled by the call to UnspentOutputs (or similar).
+ // Because one of these filters requires matching the output script to
+ // the desired account, this change depends on making wtxmgr a waddrmgr
+ // dependency and requesting unspent outputs for a single account.
+ eligible := make([]wtxmgr.Credit, 0, len(unspent))
+ for i := range unspent {
+ output := &unspent[i]
+
+ // Only include this output if it meets the required number of
+ // confirmations. Coinbase transactions must have reached
+ // maturity before their outputs may be spent.
+ if !hasMinConfs(minconf, output.Height, bs.Height) {
+ continue
+ }
+
+ if output.FromCoinBase {
+ target := w.cfg.ChainParams.CoinbaseMaturity
+ if !hasMinConfs(
+ uint32(target), output.Height, bs.Height,
+ ) {
+
+ continue
+ }
+ }
+
+ // Locked unspent outputs are skipped.
+ if output.Locked {
+ continue
+ }
+
+ // Only include the output if it is associated with the passed
+ // account.
+ //
+ // TODO: Handle multisig outputs by determining if enough of the
+ // addresses are controlled.
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, w.cfg.ChainParams,
+ )
+ if err != nil || len(addrs) != 1 {
+ continue
+ }
+
+ scopedMgr, addrAcct, err := w.addrStore.AddrAccount(
+ addrmgrNs, addrs[0],
+ )
+ if err != nil {
+ continue
+ }
+
+ if keyScope != nil && scopedMgr.Scope() != *keyScope {
+ continue
+ }
+
+ if addrAcct != account {
+ continue
+ }
+
+ eligible = append(eligible, *output)
+ }
+
+ return eligible, nil
+}
+
+// inputYieldsPositively returns a boolean indicating whether this input yields
+// positively if added to a transaction. This determination is based on the
+// best-case added virtual size. For edge cases this function can return true
+// while the input is yielding slightly negative as part of the final
+// transaction.
+func inputYieldsPositively(credit *wire.TxOut,
+ feeRatePerKb btcutil.Amount) bool {
+
+ inputSize := txsizes.GetMinInputVirtualSize(credit.PkScript)
+ feeRate := btcunit.NewSatPerKVByte(feeRatePerKb)
+ inputFee := feeRate.FeeForVByte(btcunit.NewVByte(inputSize))
+
+ return inputFee < btcutil.Amount(credit.Value)
+}
+
+func getScriptSize(addrType waddrmgr.AddressType) (int, error) {
+ switch addrType {
+ case waddrmgr.PubKeyHash:
+ return txsizes.P2PKHPkScriptSize, nil
+
+ case waddrmgr.NestedWitnessPubKey:
+ return txsizes.NestedP2WPKHPkScriptSize, nil
+
+ case waddrmgr.WitnessPubKey:
+ return txsizes.P2WPKHPkScriptSize, nil
+
+ case waddrmgr.TaprootPubKey:
+ return txsizes.P2TRPkScriptSize, nil
+
+ case waddrmgr.Script, waddrmgr.RawPubKey, waddrmgr.WitnessScript,
+ waddrmgr.TaprootScript:
+ return 0, fmt.Errorf("%w: %v", ErrUnsupportedAddressType,
+ addrType)
+
+ default:
+ return 0, fmt.Errorf("%w: %v", ErrUnsupportedAddressType,
+ addrType)
+ }
+}
+
+// addrMgrWithChangeSource returns the address manager bucket and a change
+// source that returns change addresses from said address manager. The change
+// addresses will come from the specified key scope and account, unless a key
+// scope is not specified. In that case, change addresses will always come from
+// the P2WKH key scope.
+func (w *Wallet) addrMgrWithChangeSource(dbtx walletdb.ReadWriteTx,
+ changeKeyScope *waddrmgr.KeyScope, account uint32) (
+ walletdb.ReadWriteBucket, *txauthor.ChangeSource, error) {
+
+ // Determine the address type for change addresses of the given
+ // account.
+ if changeKeyScope == nil {
+ changeKeyScope = &waddrmgr.KeyScopeBIP0086
+ }
+
+ addrType := waddrmgr.ScopeAddrMap[*changeKeyScope].InternalAddrType
+
+ // It's possible for the account to have an address schema override, so
+ // prefer that if it exists.
+ addrmgrNs := dbtx.ReadWriteBucket(waddrmgrNamespaceKey)
+
+ scopeMgr, err := w.addrStore.FetchScopedKeyManager(*changeKeyScope)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ accountInfo, err := scopeMgr.AccountProperties(addrmgrNs, account)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if accountInfo.AddrSchema != nil {
+ addrType = accountInfo.AddrSchema.InternalAddrType
+ }
+
+ // Compute the expected size of the script for the change address type.
+ scriptSize, err := getScriptSize(addrType)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ newChangeScript := func() ([]byte, error) {
+ // Derive the change output script. As a hack to allow spending
+ // from the imported account, change addresses are created from
+ // account 0.
+ var (
+ changeAddr address.Address
+ err error
+ )
+ if account == waddrmgr.ImportedAddrAccount {
+ changeAddr, err = w.newChangeAddress(
+ addrmgrNs, 0, *changeKeyScope,
+ )
+ } else {
+ changeAddr, err = w.newChangeAddress(
+ addrmgrNs, account, *changeKeyScope,
+ )
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ return txscript.PayToAddrScript(changeAddr)
+ }
+
+ return addrmgrNs, &txauthor.ChangeSource{
+ ScriptSize: scriptSize,
+ NewScript: newChangeScript,
+ }, nil
+}
+
+// sortByAmount is a generic sortable type for sorting coins by their amount.
+type sortByAmount []Coin
+
+func (s sortByAmount) Len() int { return len(s) }
+func (s sortByAmount) Less(i, j int) bool {
+ return s[i].Value < s[j].Value
+}
+func (s sortByAmount) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// LargestFirstCoinSelector is an implementation of the CoinSelectionStrategy
+// that always selects the largest coins first.
+type LargestFirstCoinSelector struct{}
+
+// ArrangeCoins takes a list of coins and arranges them according to the
+// specified coin selection strategy and fee rate.
+func (*LargestFirstCoinSelector) ArrangeCoins(eligible []Coin,
+ _ btcutil.Amount) ([]Coin, error) {
+
+ sort.Sort(sort.Reverse(sortByAmount(eligible)))
+
+ return eligible, nil
+}
+
+// RandomCoinSelector is an implementation of the CoinSelectionStrategy that
+// selects coins at random. This prevents the creation of ever smaller UTXOs
+// over time that may never become economical to spend.
+type RandomCoinSelector struct{}
+
+// ArrangeCoins takes a list of coins and arranges them according to the
+// specified coin selection strategy and fee rate.
+func (*RandomCoinSelector) ArrangeCoins(eligible []Coin,
+ feeSatPerKb btcutil.Amount) ([]Coin, error) {
+
+ // Skip inputs that do not raise the total transaction output
+ // value at the requested fee rate.
+ positivelyYielding := make([]Coin, 0, len(eligible))
+ for _, output := range eligible {
+ if !inputYieldsPositively(&output.TxOut, feeSatPerKb) {
+ continue
+ }
+
+ positivelyYielding = append(positivelyYielding, output)
+ }
+
+ rand.Shuffle(len(positivelyYielding), func(i, j int) {
+ positivelyYielding[i], positivelyYielding[j] =
+ positivelyYielding[j], positivelyYielding[i]
+ })
+
+ return positivelyYielding, nil
+}
diff --git a/wallet/tx_creator_test.go b/wallet/tx_creator_test.go
new file mode 100644
index 0000000000..2f41dac38f
--- /dev/null
+++ b/wallet/tx_creator_test.go
@@ -0,0 +1,1362 @@
+package wallet
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/pkg/btcunit"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wallet/txrules"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ // errStrategy is used to simulate failures in coin selection
+ // strategies within tests.
+ errStrategy = errors.New("strategy error")
+
+ // errDB is used to simulate database operation failures within tests.
+ errDB = errors.New("db error")
+
+ // defaultAccountName is the name of the default account.
+ defaultAccountName = "default"
+)
+
+// TestValidateTxIntent ensures that the validateTxIntent function returns
+// errors for all expected invalid transaction intents, and that it returns nil
+// for valid intents. The test covers a range of scenarios, including missing
+// inputs or outputs, dust outputs, duplicate UTXOs, and invalid account or
+// change source configurations.
+func TestValidateTxIntent(t *testing.T) {
+ t.Parallel()
+
+ const defaultAccountName = "default"
+
+ // Define a set of valid outputs and inputs to be reused across test
+ // cases.
+ validOutput := wire.TxOut{Value: 10000, PkScript: []byte{}}
+ validUTXO := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+ validAccountName := defaultAccountName
+ validScopedAccount := &ScopedAccount{
+ AccountName: validAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ }
+ defaultFeeRate := btcunit.NewSatPerKVByte(1000)
+
+ // Define the test cases, each representing a different scenario for
+ // validating a TxIntent.
+ testCases := []struct {
+ name string
+ intent *TxIntent
+ expectedErr error
+ }{
+ {
+ name: "valid intent with manual inputs",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "valid intent with policy inputs " +
+ "(scoped account)",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: validScopedAccount,
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "valid intent with policy inputs (utxo source)",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{
+ validUTXO,
+ },
+ },
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "valid intent with nil source in policy",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{Source: nil},
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "invalid intent - nil inputs",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: nil,
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrMissingInputs,
+ },
+ {
+ name: "invalid intent - no outputs",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrNoTxOutputs,
+ },
+ {
+ name: "invalid intent - dust output",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{{Value: 1}},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: txrules.ErrOutputIsDust,
+ },
+ {
+ name: "invalid intent - empty manual inputs",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrManualInputsEmpty,
+ },
+ {
+ name: "invalid intent - duplicate manual inputs",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{
+ validUTXO, validUTXO,
+ },
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrDuplicatedUtxo,
+ },
+ {
+ name: "invalid intent - empty account name in source",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: &ScopedAccount{AccountName: ""},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrMissingAccountName,
+ },
+ {
+ name: "invalid intent - empty utxo list in source",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{},
+ },
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrManualInputsEmpty,
+ },
+ {
+ name: "invalid intent - duplicate utxos in policy",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{
+ validUTXO, validUTXO,
+ },
+ },
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrDuplicatedUtxo,
+ },
+ {
+ name: "invalid intent - unsupported coin source",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: &unsupportedCoinSource{},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrUnsupportedCoinSource,
+ },
+ {
+ name: "invalid intent - unsupported inputs type",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &unsupportedInputs{},
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "invalid intent - empty account name in change",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: "",
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ FeeRate: defaultFeeRate,
+ },
+ expectedErr: ErrMissingAccountName,
+ },
+ {
+ name: "invalid intent - zero fee rate",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ FeeRate: btcunit.ZeroSatPerKVByte,
+ },
+ expectedErr: ErrMissingFeeRate,
+ },
+ {
+ name: "invalid intent - insane fee rate",
+ intent: &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ FeeRate: btcunit.NewSatPerKVByte(2_000_000),
+ },
+ expectedErr: ErrFeeRateTooLarge,
+ },
+ }
+
+ // Iterate through all test cases and run them.
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Call the validate function and check that the error
+ // matches the expected error.
+ err := validateTxIntent(tc.intent)
+ require.ErrorIs(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// unsupportedInputs is a mock implementation of the Inputs interface used for
+// testing purposes.
+type unsupportedInputs struct{}
+
+func (u *unsupportedInputs) isInputs() {}
+func (u *unsupportedInputs) validate() error { return nil }
+
+// unsupportedCoinSource is a mock implementation of the CoinSource interface
+// used for testing purposes.
+type unsupportedCoinSource struct{}
+
+func (u *unsupportedCoinSource) isCoinSource() {}
+
+// TestDetermineChangeSource tests the behavior of the determineChangeSource
+// method, ensuring that it correctly selects a change source based on the
+// transaction intent. It covers scenarios where the change source is
+// explicitly provided, derived from the input policy, or falls back to the
+// default P2TR account.
+func TestDetermineChangeSource(t *testing.T) {
+ t.Parallel()
+
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Define a set of accounts to be reused across test cases.
+ explicitChangeSource := &ScopedAccount{
+ AccountName: "explicit",
+ KeyScope: waddrmgr.KeyScopeBIP0044,
+ }
+ policyAccountSource := &ScopedAccount{
+ AccountName: "policy",
+ KeyScope: waddrmgr.KeyScopeBIP0049Plus,
+ }
+ defaultAccountSource := &ScopedAccount{
+ AccountName: waddrmgr.DefaultAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ }
+
+ testCases := []struct {
+ name string
+ intent *TxIntent
+ expectedSource *ScopedAccount
+ }{
+ {
+ name: "explicit change source",
+ intent: &TxIntent{
+ ChangeSource: explicitChangeSource,
+ },
+ expectedSource: explicitChangeSource,
+ },
+ {
+ name: "nil change source with policy account",
+ intent: &TxIntent{
+ Inputs: &InputsPolicy{
+ Source: policyAccountSource,
+ },
+ ChangeSource: nil,
+ },
+ expectedSource: policyAccountSource,
+ },
+ {
+ name: "nil change source with manual inputs",
+ intent: &TxIntent{
+ Inputs: &InputsManual{},
+ ChangeSource: nil,
+ },
+ expectedSource: defaultAccountSource,
+ },
+ {
+ name: "nil change source with non-account policy",
+ intent: &TxIntent{
+ Inputs: &InputsPolicy{
+ Source: &CoinSourceUTXOs{},
+ },
+ ChangeSource: nil,
+ },
+ expectedSource: defaultAccountSource,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ source := w.determineChangeSource(tc.intent)
+ require.Equal(t, tc.expectedSource, source)
+ })
+ }
+}
+
+type mockReadBucket struct {
+ walletdb.ReadBucket
+}
+
+type mockReadTx struct {
+ walletdb.ReadTx
+}
+
+func (m *mockReadTx) ReadBucket(key []byte) walletdb.ReadBucket {
+ return &mockReadBucket{}
+}
+
+// TestGetEligibleUTXOsFromList tests that the getEligibleUTXOsFromList method
+// correctly filters a list of UTXOs based on their confirmation status. It
+// ensures that UTXOs with sufficient confirmations are included, while those
+// that are unconfirmed or do not meet the minimum confirmation requirement are
+// excluded. The test also verifies that an error is returned if a specified
+// UTXO is not found in the wallet.
+func TestGetEligibleUTXOsFromList(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Define a block stamp for the current chain height.
+ currentHeight := int32(100)
+ blockStamp := &waddrmgr.BlockStamp{
+ Height: currentHeight,
+ }
+
+ // Define some UTXOs.
+ // This UTXO has 1 confirmation.
+ utxo1 := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+
+ // This UTXO has 6 confirmations.
+ utxo2 := wire.OutPoint{Hash: [32]byte{2}, Index: 0}
+
+ // This UTXO is unconfirmed.
+ utxo3 := wire.OutPoint{Hash: [32]byte{3}, Index: 0}
+
+ // This UTXO is not found.
+ utxo4 := wire.OutPoint{Hash: [32]byte{4}, Index: 0}
+
+ // Define the corresponding credits.
+ credit1 := &wtxmgr.Credit{
+ OutPoint: utxo1,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ // 1 conf = 100 - 100 + 1.
+ Height: currentHeight,
+ },
+ },
+ }
+ credit2 := &wtxmgr.Credit{
+ OutPoint: utxo2,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ // 6 confs = 100 - 95 + 1.
+ Height: currentHeight - 5,
+ },
+ },
+ }
+ credit3 := &wtxmgr.Credit{
+ OutPoint: utxo3,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ // Unconfirmed.
+ Height: -1,
+ },
+ },
+ }
+
+ // Set up mock calls for txStore.GetUtxo.
+ mocks.txStore.On("GetUtxo", mock.Anything, utxo1).Return(credit1, nil)
+ mocks.txStore.On("GetUtxo", mock.Anything, utxo2).Return(credit2, nil)
+ mocks.txStore.On("GetUtxo", mock.Anything, utxo3).Return(credit3, nil)
+ mocks.txStore.On("GetUtxo", mock.Anything, utxo4).Return(
+ nil, wtxmgr.ErrUtxoNotFound,
+ )
+
+ testCases := []struct {
+ name string
+ source *CoinSourceUTXOs
+ minconf uint32
+ expectedUtxos []wtxmgr.Credit
+ expectedErr error
+ }{
+ {
+ name: "all utxos with minconf 0",
+ source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{utxo1, utxo2, utxo3},
+ },
+ minconf: 0,
+ expectedUtxos: []wtxmgr.Credit{
+ *credit1, *credit2, *credit3,
+ },
+ },
+ {
+ name: "1 conf required",
+ source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{utxo1, utxo2, utxo3},
+ },
+ minconf: 1,
+ expectedUtxos: []wtxmgr.Credit{*credit1, *credit2},
+ },
+ {
+ name: "6 confs required",
+ source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{utxo1, utxo2, utxo3},
+ },
+ minconf: 6,
+ expectedUtxos: []wtxmgr.Credit{*credit2},
+ },
+ {
+ name: "7 confs required",
+ source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{utxo1, utxo2, utxo3},
+ },
+ minconf: 7,
+ expectedUtxos: []wtxmgr.Credit{},
+ },
+ {
+ name: "utxo not found",
+ source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{utxo1, utxo4},
+ },
+ minconf: 1,
+ expectedErr: ErrUtxoNotEligible,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ dbtx := &mockReadTx{}
+ utxos, err := w.getEligibleUTXOsFromList(
+ dbtx, tc.source, tc.minconf, blockStamp,
+ )
+
+ require.ErrorIs(t, err, tc.expectedErr)
+
+ if err == nil {
+ require.ElementsMatch(
+ t, tc.expectedUtxos, utxos,
+ )
+ }
+ })
+ }
+}
+
+// TestGetEligibleUTXOsFromAccount tests that the getEligibleUTXOsFromAccount
+// method correctly returns an ErrAccountNotFound when the specified account
+// does not exist. This ensures that the function properly handles cases where
+// UTXOs are requested from a non-existent account.
+func TestGetEligibleUTXOsFromAccount(t *testing.T) {
+ t.Parallel()
+
+ // Define a block stamp for the current chain height.
+ blockStamp := &waddrmgr.BlockStamp{
+ Height: 100,
+ }
+
+ keyScope := waddrmgr.KeyScopeBIP0086
+ minconf := uint32(1)
+
+ w, mocks := createStartedWalletWithMocks(t)
+ accountStore := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager", keyScope).
+ Return(accountStore, nil)
+
+ // We need to define the error type explicitly to avoid mock panics.
+ errNotFound := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }
+ accountStore.On("LookupAccount", mock.Anything, "unknown").
+ Return(uint32(0), errNotFound)
+
+ _, err := w.getEligibleUTXOsFromAccount(
+ &mockReadTx{},
+ &ScopedAccount{
+ AccountName: "unknown",
+ KeyScope: keyScope,
+ },
+ minconf, blockStamp,
+ )
+ require.ErrorIs(t, err, ErrAccountNotFound)
+}
+
+// TestGetEligibleUTXOs serves as a comprehensive test suite for the
+// getEligibleUTXOs method, which acts as a dispatcher based on the provided
+// CoinSource type. This test ensures that the method correctly delegates to the
+// appropriate sub-handler for each source type (scoped account, UTXO list, or
+// nil for default) and that it properly returns an error for unsupported
+// source types.
+func TestGetEligibleUTXOs(t *testing.T) {
+ t.Parallel()
+
+ minconf := uint32(1)
+ utxo := wire.OutPoint{}
+ credit := &wtxmgr.Credit{}
+ scopedAccount := &ScopedAccount{
+ AccountName: defaultAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ }
+
+ testCases := []struct {
+ name string
+ source CoinSource
+ setupMocks func(m *mockWalletDeps, source CoinSource)
+ expectedErr error
+ }{
+ {
+ name: "scoped account",
+ source: scopedAccount,
+ setupMocks: func(
+ m *mockWalletDeps, source CoinSource,
+ ) {
+
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ )
+ scopedSrc, ok := source.(*ScopedAccount)
+ require.True(t, ok)
+ accountStore := &mockAccountStore{}
+
+ m.addrStore.On("FetchScopedKeyManager",
+ scopedSrc.KeyScope,
+ ).Return(accountStore, nil)
+
+ accountStore.On("LookupAccount",
+ mock.Anything, scopedSrc.AccountName,
+ ).Return(uint32(0), nil)
+
+ m.txStore.On("UnspentOutputs",
+ mock.Anything,
+ ).Return([]wtxmgr.Credit{}, nil)
+ },
+ },
+ {
+ name: "utxo source",
+ source: &CoinSourceUTXOs{
+ UTXOs: []wire.OutPoint{utxo},
+ },
+ setupMocks: func(m *mockWalletDeps, source CoinSource) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ )
+ m.txStore.On("GetUtxo", mock.Anything, utxo).
+ Return(credit, nil)
+ },
+ },
+ {
+ name: "nil source",
+ source: nil,
+ setupMocks: func(m *mockWalletDeps, source CoinSource) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ )
+ m.txStore.On("UnspentOutputs",
+ mock.Anything,
+ ).Return([]wtxmgr.Credit{}, nil)
+ },
+ },
+ {
+ name: "unsupported source",
+ source: &unsupportedCoinSource{},
+ setupMocks: func(m *mockWalletDeps, source CoinSource) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ )
+ },
+ expectedErr: ErrUnsupportedCoinSource,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+ tc.setupMocks(mocks, tc.source)
+
+ _, err := w.getEligibleUTXOs(
+ &mockReadTx{}, tc.source, minconf,
+ )
+
+ require.ErrorIs(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// TestCreateManualInputSource verifies that the createManualInputSource
+// function correctly creates an input source from a manually specified list of
+// UTXOs. It tests the success path, where all UTXOs are valid and spendable,
+// and the failure path, where a UTXO is not found in the wallet, ensuring that
+// the function returns the expected error in that case.
+func TestCreateManualInputSource(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+ dbtx := &mockReadTx{}
+
+ utxo1 := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+ credit1 := &wtxmgr.Credit{OutPoint: utxo1}
+
+ utxo2 := wire.OutPoint{Hash: [32]byte{2}, Index: 0}
+
+ testCases := []struct {
+ name string
+ inputs *InputsManual
+ setupMocks func()
+ expectedErr error
+ }{
+ {
+ name: "success",
+ inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{utxo1},
+ },
+ setupMocks: func() {
+ mocks.txStore.On("GetUtxo",
+ mock.Anything, utxo1,
+ ).Return(credit1, nil).Once()
+ },
+ },
+ {
+ name: "utxo not found",
+ inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{utxo2},
+ },
+ setupMocks: func() {
+ mocks.txStore.On("GetUtxo",
+ mock.Anything, utxo2,
+ ).Return(nil, wtxmgr.ErrUtxoNotFound).Once()
+ },
+ expectedErr: ErrUtxoNotEligible,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ tc.setupMocks()
+
+ source, err := w.createManualInputSource(
+ dbtx, tc.inputs,
+ )
+
+ require.ErrorIs(t, err, tc.expectedErr)
+
+ if err == nil {
+ require.NotNil(t, source)
+ } else {
+ require.Nil(t, source)
+ }
+ })
+ }
+}
+
+// TestCreatePolicyInputSource tests the functionality of the
+// createPolicyInputSource method. It ensures that the method correctly creates
+// an input source for coin selection based on a given policy. The test covers
+// scenarios where a default coin selection strategy is used, as well as cases
+// with a custom strategy. It also verifies that errors from underlying
+// dependencies, such as the database or the coin selection strategy itself, are
+// properly propagated.
+func TestCreatePolicyInputSource(t *testing.T) {
+ t.Parallel()
+
+ dbtx := &mockReadTx{}
+ feeRate := btcunit.NewSatPerKVByte(1000)
+
+ utxo1 := wtxmgr.Credit{
+ OutPoint: wire.OutPoint{Hash: [32]byte{1}, Index: 0},
+ }
+ utxo2 := wtxmgr.Credit{
+ OutPoint: wire.OutPoint{Hash: [32]byte{2}, Index: 0},
+ }
+ eligibleUtxos := []wtxmgr.Credit{utxo1, utxo2}
+
+ // A mock strategy that just returns the coins as is.
+ mockStrategy := &mockCoinSelectionStrategy{}
+ mockStrategy.On("ArrangeCoins", mock.Anything, mock.Anything).
+ Return(make([]Coin, 0), nil)
+
+ // A mock strategy that returns an error.
+ errCoinSelection := &mockCoinSelectionStrategy{}
+ errCoinSelection.On("ArrangeCoins", mock.Anything, mock.Anything).
+ Return(([]Coin)(nil), errStrategy)
+
+ testCases := []struct {
+ name string
+ policy *InputsPolicy
+ setupMocks func(m *mockWalletDeps)
+ expectedErr error
+ }{
+ {
+ name: "success with default strategy",
+ policy: &InputsPolicy{
+ // Should default to default account
+ Source: nil,
+ MinConfs: 1,
+ },
+ setupMocks: func(m *mockWalletDeps) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ ).Once()
+ m.txStore.On("UnspentOutputs", mock.Anything).
+ Return(eligibleUtxos, nil).Once()
+ },
+ },
+ {
+ name: "success with custom strategy",
+ policy: &InputsPolicy{
+ Strategy: mockStrategy,
+ Source: nil,
+ MinConfs: 1,
+ },
+ setupMocks: func(m *mockWalletDeps) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ ).Once()
+ m.txStore.On("UnspentOutputs", mock.Anything).
+ Return(eligibleUtxos, nil).Once()
+ },
+ },
+ {
+ name: "getEligibleUTXOs fails on UnspentOutputs",
+ policy: &InputsPolicy{
+ Source: nil,
+ MinConfs: 1,
+ },
+ setupMocks: func(m *mockWalletDeps) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ ).Once()
+ m.txStore.On("UnspentOutputs",
+ mock.Anything,
+ ).Return(nil, errDB).Once()
+ },
+ expectedErr: errDB,
+ },
+ {
+ name: "strategy ArrangeCoins fails",
+ policy: &InputsPolicy{
+ Strategy: errCoinSelection,
+ Source: nil,
+ MinConfs: 1,
+ },
+ setupMocks: func(m *mockWalletDeps) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ ).Once()
+ m.txStore.On("UnspentOutputs", mock.Anything).
+ Return(eligibleUtxos, nil).Once()
+ },
+ expectedErr: errStrategy,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+ tc.setupMocks(mocks)
+
+ source, err := w.createPolicyInputSource(
+ dbtx, tc.policy, feeRate,
+ )
+
+ if tc.expectedErr != nil {
+ require.Error(t, err)
+ require.Contains(t, err.Error(),
+ tc.expectedErr.Error())
+ require.Nil(t, source)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, source)
+ }
+ })
+ }
+}
+
+// TestCreateInputSource serves as a dispatcher test for the createInputSource
+// method. It verifies that the method correctly delegates to the appropriate
+// specialized input source creator—either for manual or policy-based coin
+// selection—based on the type of the `Inputs` field in the `TxIntent`. The test
+// also ensures that an `ErrUnsupportedTxInputs` error is returned if an
+// unknown input type is provided.
+func TestCreateInputSource(t *testing.T) {
+ t.Parallel()
+
+ dbtx := &mockReadTx{}
+
+ utxo := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+ credit := &wtxmgr.Credit{OutPoint: utxo}
+
+ manualInputs := &InputsManual{UTXOs: []wire.OutPoint{utxo}}
+ policyInputs := &InputsPolicy{}
+ unsupported := &unsupportedInputs{}
+
+ intentManual := &TxIntent{Inputs: manualInputs}
+ intentPolicy := &TxIntent{
+ Inputs: policyInputs,
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ }
+ intentUnsupported := &TxIntent{Inputs: unsupported}
+
+ testCases := []struct {
+ name string
+ intent *TxIntent
+ setupMocks func(m *mockWalletDeps)
+ expectedErr error
+ }{
+ {
+ name: "manual inputs",
+ intent: intentManual,
+ setupMocks: func(m *mockWalletDeps) {
+ m.txStore.On("GetUtxo", mock.Anything, utxo).
+ Return(credit, nil).Once()
+ },
+ },
+ {
+ name: "policy inputs",
+ intent: intentPolicy,
+ setupMocks: func(m *mockWalletDeps) {
+ m.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ ).Once()
+ m.txStore.On("UnspentOutputs",
+ mock.Anything,
+ ).Return([]wtxmgr.Credit{*credit}, nil).Once()
+ },
+ },
+ {
+ name: "unsupported inputs",
+ intent: intentUnsupported,
+ setupMocks: func(m *mockWalletDeps) {},
+ expectedErr: ErrUnsupportedTxInputs,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+ tc.setupMocks(mocks)
+
+ source, err := w.createInputSource(dbtx, tc.intent)
+
+ require.ErrorIs(t, err, tc.expectedErr)
+
+ if err == nil {
+ require.NotNil(t, source)
+ } else {
+ require.Nil(t, source)
+ }
+ })
+ }
+}
+
+// TestCreateTransactionSuccessManualInputs tests the success path for creating
+// a transaction with manually specified inputs.
+func TestCreateTransactionSuccessManualInputs(t *testing.T) {
+ t.Parallel()
+
+ // Arrange.
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Once()
+
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(privKey.PubKey().SerializeCompressed()),
+ &chainParams,
+ )
+ require.NoError(t, err)
+ validPkScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ validOutput := wire.TxOut{Value: 10000, PkScript: validPkScript}
+ validUTXO := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+
+ changeKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ changeAddr, err := address.NewAddressPubKey(
+ changeKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ mockChangeAddr := &mockManagedAddress{}
+ mockChangeAddr.On("Address").Return(changeAddr)
+ mockChangeAddr.On("Internal").Return(true)
+ mockChangeAddr.On("Compressed").Return(true)
+ mockChangeAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ mockChangeAddr.On("InternalAccount").Return(uint32(0))
+ mockChangeAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0086, waddrmgr.DerivationPath{}, true,
+ )
+
+ credit := &wtxmgr.Credit{
+ OutPoint: validUTXO,
+ Amount: btcutil.Amount(50000), // Generous amount
+ PkScript: []byte{4, 5, 6},
+ }
+
+ intent := &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: defaultAccountName,
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ }
+
+ accountStore := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0086).Return(accountStore, nil)
+
+ accountStore.On("LookupAccount",
+ mock.Anything, "default",
+ ).Return(uint32(0), nil)
+
+ accountProps := &waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }
+ accountStore.On("AccountProperties",
+ mock.Anything, uint32(0),
+ ).Return(accountProps, nil)
+
+ accountStore.On("NextInternalAddresses",
+ mock.Anything, uint32(0), uint32(1),
+ ).Return(
+ []waddrmgr.ManagedAddress{
+ mockChangeAddr,
+ }, nil,
+ )
+
+ mocks.txStore.On("GetUtxo",
+ mock.Anything, validUTXO,
+ ).Return(credit, nil)
+
+ // Act.
+ tx, err := w.CreateTransaction(t.Context(), intent)
+
+ // Assert.
+ require.NoError(t, err)
+ require.NotNil(t, tx)
+}
+
+// TestCreateTransactionSuccessNilChangeSourceManualInputs tests the success
+// path for creating a transaction with manually specified inputs and a nil
+// change source.
+func TestCreateTransactionSuccessNilChangeSourceManualInputs(t *testing.T) {
+ t.Parallel()
+
+ // Arrange.
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Once()
+
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(privKey.PubKey().SerializeCompressed()),
+ &chainParams,
+ )
+ require.NoError(t, err)
+ validPkScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ validOutput := wire.TxOut{Value: 10000, PkScript: validPkScript}
+ validUTXO := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+
+ changeKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ changeAddr, err := address.NewAddressPubKey(
+ changeKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ mockChangeAddr := &mockManagedAddress{}
+ mockChangeAddr.On("Address").Return(changeAddr)
+ mockChangeAddr.On("Internal").Return(true)
+ mockChangeAddr.On("Compressed").Return(true)
+ mockChangeAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ mockChangeAddr.On("InternalAccount").Return(uint32(0))
+ mockChangeAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0086, waddrmgr.DerivationPath{}, true,
+ )
+
+ credit := &wtxmgr.Credit{
+ OutPoint: validUTXO,
+ Amount: btcutil.Amount(50000), // Generous amount
+ PkScript: []byte{4, 5, 6},
+ }
+
+ intent := &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: nil,
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ }
+
+ accountStore := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0086,
+ ).Return(accountStore, nil)
+
+ // Should look up the default account
+ accountStore.On("LookupAccount",
+ mock.Anything, "default",
+ ).Return(uint32(0), nil)
+
+ accountProps := &waddrmgr.AccountProperties{
+ AccountNumber: 0,
+ AccountName: "default",
+ }
+ accountStore.On("AccountProperties",
+ mock.Anything, uint32(0),
+ ).Return(accountProps, nil)
+
+ accountStore.On("NextInternalAddresses",
+ mock.Anything, uint32(0), uint32(1),
+ ).Return(
+ []waddrmgr.ManagedAddress{
+ mockChangeAddr,
+ }, nil,
+ )
+
+ mocks.txStore.On("GetUtxo",
+ mock.Anything, validUTXO,
+ ).Return(credit, nil)
+
+ // Act.
+ tx, err := w.CreateTransaction(t.Context(), intent)
+
+ // Assert.
+ require.NoError(t, err)
+ require.NotNil(t, tx)
+}
+
+// TestCreateTransactionSuccessNilChangeSourcePolicyInputs tests the success
+// path for creating a transaction with policy-based inputs and a nil change
+// source.
+func TestCreateTransactionSuccessNilChangeSourcePolicyInputs(t *testing.T) {
+ t.Parallel()
+
+ // Arrange.
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Once()
+
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(privKey.PubKey().SerializeCompressed()),
+ &chainParams,
+ )
+ require.NoError(t, err)
+ validPkScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ validOutput := wire.TxOut{Value: 10000, PkScript: validPkScript}
+ validUTXO := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+
+ changeKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ changeAddr, err := address.NewAddressPubKey(
+ changeKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ mockChangeAddr := &mockManagedAddress{}
+ mockChangeAddr.On("Address").Return(changeAddr)
+ mockChangeAddr.On("Internal").Return(true)
+ mockChangeAddr.On("Compressed").Return(true)
+ mockChangeAddr.On("AddrType").Return(waddrmgr.WitnessPubKey)
+ mockChangeAddr.On("InternalAccount").Return(uint32(0))
+ mockChangeAddr.On("DerivationInfo").Return(
+ waddrmgr.KeyScopeBIP0086, waddrmgr.DerivationPath{}, true,
+ )
+
+ credit := &wtxmgr.Credit{
+ OutPoint: validUTXO,
+ Amount: btcutil.Amount(50000), // Generous amount
+ PkScript: []byte{4, 5, 6},
+ }
+
+ intent := &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsPolicy{
+ Source: &ScopedAccount{
+ AccountName: "test-account",
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ },
+ ChangeSource: nil,
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ }
+
+ accountStore := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0086,
+ ).Return(accountStore, nil)
+
+ // Should look up the "test-account" for the change source.
+ accountStore.On("LookupAccount",
+ mock.Anything, "test-account",
+ ).Return(uint32(1), nil)
+
+ accountProps := &waddrmgr.AccountProperties{
+ AccountNumber: 1,
+ AccountName: "test-account",
+ }
+ accountStore.On("AccountProperties",
+ mock.Anything, uint32(1),
+ ).Return(accountProps, nil)
+
+ accountStore.On(
+ "NextInternalAddresses", mock.Anything,
+ uint32(1), uint32(1),
+ ).Return(
+ []waddrmgr.ManagedAddress{
+ mockChangeAddr,
+ }, nil,
+ )
+
+ // Mocks for createPolicyInputSource.
+ mocks.chain.On("BlockStamp").Return(
+ &waddrmgr.BlockStamp{}, nil,
+ )
+
+ // We need to return the credit for the test-account.
+ testAddr, err := address.NewAddressPubKey(
+ changeKey.PubKey().SerializeCompressed(),
+ &chainParams,
+ )
+ require.NoError(t, err)
+ testPkScript, err := txscript.PayToAddrScript(
+ testAddr,
+ )
+ require.NoError(t, err)
+
+ credit.PkScript = testPkScript
+
+ // We'll also need to set up the address store to know about the test
+ // account.
+ mockAddr := &mockManagedAddress{}
+ mockAddr.On("Account").Return(uint32(1))
+ accountStore.On("Address",
+ mock.Anything, testAddr,
+ ).Return(mockAddr, nil)
+ mocks.addrStore.On("AddrAccount",
+ mock.Anything, mock.Anything,
+ ).Return(accountStore, uint32(1), nil)
+ accountStore.On("Scope").Return(waddrmgr.KeyScopeBIP0086)
+
+ mocks.txStore.On("UnspentOutputs",
+ mock.Anything,
+ ).Return([]wtxmgr.Credit{*credit}, nil)
+
+ // Act.
+ tx, err := w.CreateTransaction(t.Context(), intent)
+
+ // Assert.
+ require.NoError(t, err)
+ require.NotNil(t, tx)
+}
+
+// TestCreateTransactionInvalidIntent tests that an error is returned when an
+// invalid transaction intent is provided.
+func TestCreateTransactionInvalidIntent(t *testing.T) {
+ t.Parallel()
+
+ // Arrange.
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Once()
+
+ intent := &TxIntent{
+ Outputs: []wire.TxOut{}, // No outputs
+ }
+
+ // Act.
+ tx, err := w.CreateTransaction(t.Context(), intent)
+
+ // Assert.
+ require.ErrorIs(t, err, ErrNoTxOutputs)
+ require.Nil(t, tx)
+}
+
+// TestCreateTransactionAccountNotFound tests that an error is returned when
+// the specified account is not found.
+func TestCreateTransactionAccountNotFound(t *testing.T) {
+ t.Parallel()
+
+ // Arrange.
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.syncer.On("syncState").Return(syncStateSynced).Once()
+
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ p2wkhAddr, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(privKey.PubKey().SerializeCompressed()),
+ &chainParams,
+ )
+ require.NoError(t, err)
+ validPkScript, err := txscript.PayToAddrScript(p2wkhAddr)
+ require.NoError(t, err)
+
+ validOutput := wire.TxOut{Value: 10000, PkScript: validPkScript}
+ validUTXO := wire.OutPoint{Hash: [32]byte{1}, Index: 0}
+
+ intent := &TxIntent{
+ Outputs: []wire.TxOut{validOutput},
+ Inputs: &InputsManual{
+ UTXOs: []wire.OutPoint{validUTXO},
+ },
+ ChangeSource: &ScopedAccount{
+ AccountName: "unknown",
+ KeyScope: waddrmgr.KeyScopeBIP0086,
+ },
+ FeeRate: btcunit.NewSatPerKVByte(1000),
+ }
+
+ accountStore := &mockAccountStore{}
+ mocks.addrStore.On("FetchScopedKeyManager",
+ waddrmgr.KeyScopeBIP0086).Return(
+ accountStore, nil,
+ )
+ errNotFound := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAccountNotFound,
+ }
+ accountStore.On("LookupAccount",
+ mock.Anything, "unknown",
+ ).Return(uint32(0), errNotFound)
+
+ // Act.
+ tx, err := w.CreateTransaction(t.Context(), intent)
+
+ // Assert.
+ require.ErrorIs(t, err, ErrAccountNotFound)
+ require.Nil(t, tx)
+}
diff --git a/wallet/tx_publisher.go b/wallet/tx_publisher.go
new file mode 100644
index 0000000000..2be2c7b386
--- /dev/null
+++ b/wallet/tx_publisher.go
@@ -0,0 +1,556 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// Package wallet provides a bitcoin wallet implementation that is ready for
+// use.
+//
+// TODO(yy): bring wrapcheck back when implementing the `Store` interface.
+//
+//nolint:wrapcheck
+package wallet
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/rpcclient"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/davecgh/go-spew/spew"
+)
+
+var (
+ // ErrMempoolAccept is a sentinel error used to indicate that the
+ // mempool acceptance test returned an unexpected number of results.
+ ErrMempoolAccept = errors.New(
+ "expected 1 result from TestMempoolAccept",
+ )
+)
+
+// TxPublisher provides an interface for publishing transactions.
+type TxPublisher interface {
+ // CheckMempoolAcceptance checks if a transaction would be accepted by
+ // the mempool without broadcasting.
+ CheckMempoolAcceptance(ctx context.Context, tx *wire.MsgTx) error
+
+ // Broadcast broadcasts a transaction to the network.
+ Broadcast(ctx context.Context, tx *wire.MsgTx, label string) error
+}
+
+// A compile time check to ensure that Wallet implements the interface.
+var _ TxPublisher = (*Wallet)(nil)
+
+// CheckMempoolAcceptance checks if a transaction would be accepted by the
+// mempool without broadcasting.
+func (w *Wallet) CheckMempoolAcceptance(_ context.Context,
+ tx *wire.MsgTx) error {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return err
+ }
+
+ if tx == nil {
+ return ErrTxCannotBeNil
+ }
+
+ // TODO(yy): thread context through.
+ // The TestMempoolAccept rpc expects a slice of transactions.
+ txns := []*wire.MsgTx{tx}
+
+ // Use a max feerate of 0 means the default value will be used when
+ // testing mempool acceptance. The default max feerate is 0.10 BTC/kvb,
+ // or 10,000 sat/vb.
+ maxFeeRate := float64(0)
+
+ results, err := w.cfg.Chain.TestMempoolAccept(txns, maxFeeRate)
+ if err != nil {
+ return err
+ }
+
+ // Sanity check that the expected single result is returned.
+ if len(results) != 1 {
+ return ErrMempoolAccept
+ }
+
+ result := results[0]
+
+ // If the transaction is allowed, we can return early.
+ if result.Allowed {
+ return nil
+ }
+
+ // Otherwise, we'll map the reason to a concrete error type and return
+ // it.
+ //
+ //nolint:err113
+ err = errors.New(result.RejectReason)
+
+ return w.cfg.Chain.MapRPCErr(err)
+}
+
+// Broadcast broadcasts a tx to the network. It is the main implementation of
+// the TxPublisher interface.
+func (w *Wallet) Broadcast(ctx context.Context, tx *wire.MsgTx,
+ label string) error {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return err
+ }
+
+ if tx == nil {
+ return ErrTxCannotBeNil
+ }
+
+ // We'll start by checking if the tx is acceptable to the mempool.
+ err = w.checkMempool(ctx, tx)
+ if errors.Is(err, errAlreadyBroadcasted) {
+ return nil
+ }
+
+ if err != nil {
+ return err
+ }
+
+ // First, we'll attempt to add the tx to our wallet's DB. This will
+ // allow us to track the tx's confirmation status, and also
+ // re-broadcast it upon startup. If any of the subsequent steps fail,
+ // this tx must be removed.
+ ourAddrs, err := w.addTxToWallet(tx, label)
+ if err != nil {
+ return err
+ }
+
+ // Now, we'll attempt to publish the tx. On successful attempt, we
+ // return immediately. On any failures, we remove it from the tx store
+ // to prevent subsequent attempts with stale transaction data.
+ err = w.publishTx(tx, ourAddrs)
+ if err == nil {
+ return nil
+ }
+
+ txid := tx.TxHash()
+ log.Errorf("%v: broadcast failed: %v", txid, err)
+
+ // If the tx was rejected for any other reason, then we'll remove it
+ // from the tx store, as otherwise, we'll attempt to continually
+ // re-broadcast it, and the UTXO state of the wallet won't be accurate.
+ removeErr := w.removeUnminedTx(tx)
+ if removeErr != nil {
+ log.Warnf("Unable to remove tx %v after broadcast failed: %v",
+ txid, removeErr)
+
+ // Return a wrapped error to give the caller full context.
+ return fmt.Errorf("broadcast failed: %w; and failed to "+
+ "remove from wallet: %v", err, removeErr)
+ }
+
+ return err
+}
+
+var (
+ // errAlreadyBroadcasted is a sentinel error used to indicate that a tx
+ // has already been broadcast.
+ errAlreadyBroadcasted = errors.New("tx already broadcasted")
+
+ // ErrTxCannotBeNil is returned when a nil transaction is passed to a
+ // function.
+ ErrTxCannotBeNil = errors.New("tx cannot be nil")
+)
+
+// checkMempool is a helper function that checks if a tx is acceptable to the
+// mempool before broadcasting.
+func (w *Wallet) checkMempool(ctx context.Context,
+ tx *wire.MsgTx) error {
+
+ // We'll start by checking if the tx is acceptable to the mempool.
+ err := w.CheckMempoolAcceptance(ctx, tx)
+
+ switch {
+ // If the tx is already in the mempool or confirmed, we can return
+ // early.
+ case errors.Is(err, chain.ErrTxAlreadyInMempool),
+ errors.Is(err, chain.ErrTxAlreadyKnown),
+ errors.Is(err, chain.ErrTxAlreadyConfirmed):
+
+ log.Infof("Tx %v already broadcasted", tx.TxHash())
+
+ // TODO(yy): Add a new method UpdateTxLabel to allow updating
+ // the label of a tx. With this change, the label passed in
+ // will be ignored if the tx is already known.
+ return errAlreadyBroadcasted
+
+ // If the backend does not support the mempool acceptance test, we'll
+ // just attempt to publish the tx.
+ case errors.Is(err, rpcclient.ErrBackendVersion),
+ errors.Is(err, chain.ErrUnimplemented):
+
+ log.Warnf("Backend does not support mempool acceptance test, "+
+ "broadcasting directly: %v", err)
+
+ return nil
+
+ // If the tx was rejected for any other reason, we'll return the error
+ // directly.
+ case err != nil:
+ return fmt.Errorf("tx rejected by mempool: %w", err)
+
+ // Otherwise, the tx is valid and we can publish it.
+ default:
+ return nil
+ }
+}
+
+// creditInfo is a struct that holds all the information needed to atomically
+// record a transaction credit.
+type creditInfo struct {
+ // index is the output index of the credit.
+ index uint32
+
+ // ma is the managed address of the credit.
+ ma waddrmgr.ManagedAddress
+
+ // addr is the address of the credit.
+ addr address.Address
+}
+
+// ownedAddrInfo holds information about a wallet-owned address and the
+// transaction output indices that pay to it.
+type ownedAddrInfo struct {
+ // managedAddr represents the managed address.
+ managedAddr waddrmgr.ManagedAddress
+
+ // outputIndices contains the transaction output indices that contain
+ // this address. The indices are not guaranteed to be sorted in any
+ // order.
+ outputIndices []uint32
+}
+
+// addTxToWallet adds a tx to the wallet's database. This function is a critical
+// part of the wallet's transaction processing pipeline and is designed for high
+// performance and atomicity. It follows a four-stage process:
+//
+// 1. Extract: First, it performs a CPU-intensive, in-memory pre-processing
+// step to parse all transaction outputs and extract all potential addresses.
+// This is done outside of any database transaction to avoid holding locks
+// during computationally expensive work.
+//
+// 2. Filter: Second, it uses a fast, read-only database transaction to
+// filter the large list of potential addresses down to the small set that is
+// actually owned by the wallet. This minimizes the time spent in the final,
+// more expensive write transaction.
+//
+// 3. Plan: Third, it prepares a definitive "write plan" in memory. This plan
+// is a simple slice of structs that contains all the information needed to
+// atomically update the database. This step ensures that transactions with
+// multiple outputs to the same address are handled correctly.
+//
+// 4. Execute: Finally, it executes this plan within a minimal, atomic write
+// transaction. This transaction contains no business logic and only performs
+// the necessary database writes, ensuring that the exclusive database lock is
+// held for the shortest possible time.
+func (w *Wallet) addTxToWallet(tx *wire.MsgTx,
+ label string) ([]address.Address, error) {
+
+ txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ if err != nil {
+ return nil, err
+ }
+
+ // Stage 1: Extract potential addresses from all transaction outputs.
+ // This is a CPU-intensive operation that is performed entirely in
+ // memory, without holding any database locks.
+ txOutAddrs := w.extractTxAddrs(tx)
+
+ // Stage 2: Filter the extracted addresses to find which ones are owned
+ // by the wallet. This is done in a fast, read-only database
+ // transaction to minimize contention.
+ ownedAddrs, err := w.filterOwnedAddresses(txOutAddrs)
+ if err != nil {
+ return nil, err
+ }
+
+ // If the transaction has no outputs relevant to us, we can exit early.
+ if len(ownedAddrs) == 0 {
+ return nil, nil
+ }
+
+ // Stage 3: Prepare a definitive "write plan". This plan is created in
+ // memory and contains all the information needed for the final atomic
+ // database update.
+ //
+ // Pre-allocate slices with exact capacity to avoid reallocations.
+ // We know the exact number of credits from the total output indices
+ // across all owned addresses.
+ var totalCredits int
+ for _, info := range ownedAddrs {
+ totalCredits += len(info.outputIndices)
+ }
+
+ creditsToWrite := make([]creditInfo, 0, totalCredits)
+ ourAddrs := make([]address.Address, 0, len(ownedAddrs))
+
+ // Iterate directly over owned addresses and their pre-computed output
+ // indices. This correctly handles the edge case where a single
+ // transaction has multiple outputs paying to the same address.
+ for addr, info := range ownedAddrs {
+ for _, index := range info.outputIndices {
+ creditsToWrite = append(creditsToWrite, creditInfo{
+ index: index,
+ ma: info.managedAddr,
+ addr: addr,
+ })
+ }
+
+ ourAddrs = append(ourAddrs, addr)
+ }
+
+ // Stage 4: Atomically execute the write plan. This is the only stage
+ // that takes an exclusive database lock, and it is designed to be as
+ // fast as possible, containing no business logic.
+ err = w.recordTxAndCredits(txRec, label, creditsToWrite)
+ if err != nil {
+ return nil, err
+ }
+
+ return ourAddrs, nil
+}
+
+// recordTxAndCredits performs a single atomic database transaction to execute a
+// pre-computed "write plan" for a transaction.
+func (w *Wallet) recordTxAndCredits(txRec *wtxmgr.TxRecord, label string,
+ creditsToWrite []creditInfo) error {
+
+ return walletdb.Update(w.cfg.DB, func(dbTx walletdb.ReadWriteTx) error {
+ addrmgrNs := dbTx.ReadWriteBucket(waddrmgrNamespaceKey)
+ txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ // If there is a label we should write, get the namespace key
+ // and record it in the tx store.
+ if len(label) != 0 {
+ txHash := txRec.MsgTx.TxHash()
+
+ err := w.txStore.PutTxLabel(txmgrNs, txHash, label)
+ if err != nil {
+ return err
+ }
+ }
+
+ // At the moment all notified txs are assumed to actually be
+ // relevant. This assumption will not hold true when SPV
+ // support is added, but until then, simply insert the tx
+ // because there should either be one or more relevant inputs
+ // or outputs.
+ exists, err := w.txStore.InsertTxCheckIfExists(
+ txmgrNs, txRec, nil,
+ )
+ if err != nil {
+ return err
+ }
+
+ // If the tx has already been recorded, we can return early.
+ if exists {
+ return nil
+ }
+
+ // Now, execute the write plan.
+ for _, credit := range creditsToWrite {
+ err := w.txStore.AddCredit(
+ txmgrNs, txRec, nil, credit.index,
+ credit.ma.Internal(),
+ )
+ if err != nil {
+ return err
+ }
+
+ err = w.addrStore.MarkUsed(addrmgrNs, credit.addr)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// extractTxAddrs extracts all potential addresses from a transaction's outputs.
+// This is a CPU-intensive function that should be run outside of a database
+// transaction.
+func (w *Wallet) extractTxAddrs(tx *wire.MsgTx) map[uint32][]address.Address {
+ txOutAddrs := make(map[uint32][]address.Address)
+ for i, output := range tx.TxOut {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ output.PkScript, w.cfg.ChainParams,
+ )
+ // Ignore non-standard scripts.
+ if err != nil {
+ log.Warnf("Cannot extract non-std pkScript=%x",
+ output.PkScript)
+
+ continue
+ }
+
+ // It's not possible for a transaction to have this many
+ // outputs, so we can ignore the gosec error.
+ //
+ //nolint:gosec
+ txOutAddrs[uint32(i)] = addrs
+ }
+
+ return txOutAddrs
+}
+
+// filterOwnedAddresses takes a map of output indexes to addresses and returns a
+// new map containing only the addresses that are owned by the wallet. This
+// function is a key part of the wallet's performance strategy. It efficiently
+// filters a potentially large set of addresses down to the small subset that
+// the wallet needs to act on.
+//
+// The function is optimized to handle transactions with multiple outputs
+// paying to the same address. It internally de-duplicates the addresses to
+// ensure that the expensive database lookup (`w.addrStore.Address`) is
+// performed only once for each unique address.
+func (w *Wallet) filterOwnedAddresses(
+ txOutAddrs map[uint32][]address.Address) (
+ map[address.Address]ownedAddrInfo, error) {
+
+ ownedAddrs := make(map[address.Address]ownedAddrInfo)
+
+ // Pre-deduplicate addresses outside the DB transaction.
+ uniqueAddrs := make(map[address.Address][]uint32)
+ for index, addrs := range txOutAddrs {
+ for _, addr := range addrs {
+ uniqueAddrs[addr] = append(uniqueAddrs[addr], index)
+ }
+ }
+
+ err := walletdb.View(w.cfg.DB, func(dbTx walletdb.ReadTx) error {
+ addrmgrNs := dbTx.ReadBucket(waddrmgrNamespaceKey)
+
+ for addr, indices := range uniqueAddrs {
+ ma, err := w.addrStore.Address(addrmgrNs, addr)
+
+ // If the address is not found, it simply means
+ // it does not belong to the wallet. This is
+ // the expected case for most addresses, so we
+ // can safely continue to the next one.
+ if waddrmgr.IsError(
+ err, waddrmgr.ErrAddressNotFound) {
+
+ continue
+ }
+
+ if err != nil {
+ return err
+ }
+
+ ownedAddrs[addr] = ownedAddrInfo{
+ managedAddr: ma,
+ outputIndices: indices,
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return ownedAddrs, nil
+}
+
+// publishTx is a helper function that handles the process of broadcasting a
+// transaction to the network. This includes getting a chain client,
+// registering for notifications, and sending the raw transaction.
+func (w *Wallet) publishTx(tx *wire.MsgTx, ourAddrs []address.Address) error {
+ // We'll also ask to be notified of the tx once it confirms on-chain.
+ // This is done outside of the database tx to prevent backend
+ // interaction within it.
+ err := w.cfg.Chain.NotifyReceived(ourAddrs)
+ if err != nil {
+ return err
+ }
+
+ txid := tx.TxHash()
+
+ // allowHighFees is always false such that the max fee rate allowed is
+ // capped at 10,000 sat/vb for bitcoind. Note that this flag is only
+ // used in bitcoind chain backend. See,
+ // - https://github.com/btcsuite/btcd/blob/442ef28bcf03797e845c8e957e5cd6d4bffb5764/rpcclient/rawtransactions.go#L22
+ //
+ //nolint:lll
+ allowHighFees := false
+
+ _, rpcErr := w.cfg.Chain.SendRawTransaction(tx, allowHighFees)
+ if rpcErr == nil {
+ return nil
+ }
+
+ // If the tx was rejected, we need to determine why and act
+ // accordingly.
+ //
+ // NOTE: This check for ErrTxAlreadyInMempool should only be triggered
+ // if the wallet is running without mempool acceptance checks (e.g.,
+ // with an older version of the chain backend or with Neutrino).
+ // Otherwise, this condition should have been caught earlier by the
+ // `checkMempool` function.
+ if errors.Is(rpcErr, chain.ErrTxAlreadyInMempool) {
+ log.Infof("%v: tx already in mempool", txid)
+ return nil
+ }
+
+ // If the tx was rejected for any other reason, then we'll return the
+ // error and let the caller handle the cleanup.
+ return rpcErr
+}
+
+// removeUnminedTx removes a tx from the unconfirmed store.
+func (w *Wallet) removeUnminedTx(tx *wire.MsgTx) error {
+ txHash := tx.TxHash()
+
+ dbErr := walletdb.Update(w.cfg.DB, func(dbTx walletdb.ReadWriteTx) error {
+ txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ if err != nil {
+ return err
+ }
+
+ return w.txStore.RemoveUnminedTx(txmgrNs, txRec)
+ })
+ if dbErr != nil {
+ log.Warnf("Unable to remove invalid tx %v: %v", txHash, dbErr)
+ return dbErr
+ }
+
+ log.Infof("Removed invalid tx: %v", txHash)
+
+ // The serialized tx is for logging only, don't fail on the error.
+ var txRaw bytes.Buffer
+
+ _ = tx.Serialize(&txRaw)
+
+ // Optionally log the tx in debug when the size is manageable.
+ const maxTxSizeForLog = 1_000_000
+ if txRaw.Len() < maxTxSizeForLog {
+ log.Debugf("Removed invalid tx: %v \n hex=%x",
+ newLogClosure(func() string {
+ return spew.Sdump(tx)
+ }), txRaw.Bytes())
+ } else {
+ log.Debugf("Removed invalid tx %v due to its size "+
+ "being too large", txHash)
+ }
+
+ return nil
+}
diff --git a/wallet/tx_publisher_benchmark_test.go b/wallet/tx_publisher_benchmark_test.go
new file mode 100644
index 0000000000..d49d36bfd3
--- /dev/null
+++ b/wallet/tx_publisher_benchmark_test.go
@@ -0,0 +1,429 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkBroadcastAPI benchmarks the Broadcast API against the legacy
+// PublishTransaction API using identical test data under sequential load.
+// Test names start with transaction pool size to group API comparisons for
+// benchstat analysis.
+//
+// Time Complexity Analysis:
+// Broadcast is a write operation with amortized cost. The time complexity is
+// O(n + m·log(k)) where:
+// - n: number of transaction outputs (address extraction)
+// - m: number of unique addresses extracted from outputs
+// - k: total number of addresses in the wallet (B-tree lookup)
+//
+// The API is optimized with a 4-stage pipeline:
+// 1. Extract: O(n) - CPU-intensive address extraction (no DB locks)
+// 2. Filter: O(m·log(k)) - Read-only DB transaction to filter owned addresses
+// 3. Plan: O(n·m) - In-memory write plan preparation (typically O(n) as m≈1-2)
+// 4. Execute: O(c) - Atomic write transaction
+// (c = owned outputs, typically c << n)
+//
+// This design ensures DB locks are held only during minimal read/write
+// operations, maximizing throughput under concurrent load.
+func BenchmarkBroadcastAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses constantGrowth since account count doesn't
+ // affect the Broadcast API's time complexity.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses linearGrowth to test O(log k) wallet
+ // address lookup scaling. As the address count grows linearly,
+ // the filterOwnedAddresses lookup time should grow
+ // logarithmically due to B-tree indexing.
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth to establish baseline
+ // transaction pool size. This represents the number of
+ // unconfirmed transactions being broadcast, stressing the
+ // idempotency checks and mempool state management.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ // txIOGrowth uses linearGrowth for both inputs and outputs to
+ // test the O(n) address extraction and O(n·m) write plan
+ // preparation costs. As transaction complexity grows linearly,
+ // processing time should scale linearly with output count.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ txIOGrowthPadding = decimalWidth(
+ txIOGrowth[len(txIOGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ chainBackend = &mockChainClient{}
+ )
+
+ err := chainBackend.Start(b.Context())
+ require.NoError(b, err)
+ b.Cleanup(chainBackend.Stop)
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d-Addrs-%0*d-Ins-%0*d-Outs-%0*d",
+ txPoolGrowthPadding, txPoolGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ txIOGrowthPadding, txIOGrowth[i],
+ txIOGrowthPadding, txIOGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+ bw.cfg.Chain = chainBackend
+
+ var (
+ beforeResult map[chainhash.Hash]*wire.MsgTx
+ afterResult map[chainhash.Hash]*wire.MsgTx
+ )
+
+ b.Run("0-Before", func(b *testing.B) {
+ result := make(map[chainhash.Hash]*wire.MsgTx)
+ baselineResult := make(
+ map[chainhash.Hash]*wire.MsgTx,
+ )
+
+ broadcastLabel := "sequential-before"
+
+ // Clear mempool to ensure clean state for
+ // benchmark baseline.
+ chainBackend.ResetMempool()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ index := i % len(bw.unconfirmedTxs)
+ tx := bw.unconfirmedTxs[index]
+
+ err := bw.PublishTransaction(
+ tx, broadcastLabel,
+ )
+ require.NoError(b, err)
+
+ result, err = chainBackend.GetMempool()
+ require.NoError(b, err)
+
+ // Capture baseline after each
+ // transaction in the first cycle. This
+ // ensures we get the complete mempool
+ // state after all transactions are
+ // published, since benchmark iteration
+ // count varies based on runtime
+ // performance.
+ if i < len(bw.unconfirmedTxs) {
+ baselineResult = result
+ }
+ }
+
+ require.Equal(
+ b, baselineResult, result,
+ "PublishTransaction API should be "+
+ "idempotent",
+ )
+
+ beforeResult = result
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ result := make(map[chainhash.Hash]*wire.MsgTx)
+ baselineResult := make(
+ map[chainhash.Hash]*wire.MsgTx,
+ )
+
+ broadcastLabel := "sequential-after"
+
+ // Clear mempool to ensure clean state for
+ // benchmark baseline.
+ chainBackend.ResetMempool()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ index := i % len(bw.unconfirmedTxs)
+ tx := bw.unconfirmedTxs[index]
+
+ err := bw.Broadcast(
+ b.Context(), tx, broadcastLabel,
+ )
+ require.NoError(b, err)
+
+ result, err = chainBackend.GetMempool()
+ require.NoError(b, err)
+
+ // Capture baseline after each
+ // transaction in the first cycle. This
+ // ensures we get the complete mempool
+ // state after all transactions are
+ // published, since benchmark iteration
+ // count varies based on runtime
+ // performance.
+ if i < len(bw.unconfirmedTxs) {
+ baselineResult = result
+ }
+ }
+
+ require.Equal(
+ b, baselineResult, result,
+ "PublishTransaction API should be "+
+ "idempotent",
+ )
+
+ afterResult = result
+ })
+
+ assertBroadcastAPIsEquivalent(
+ b, beforeResult, afterResult,
+ )
+ })
+ }
+}
+
+// BenchmarkBroadcastAPIConcurrently benchmarks the Broadcast API against the
+// legacy PublishTransaction API using identical test data under concurrent
+// load. Test names start with transaction pool size to group API comparisons
+// for benchstat analysis.
+//
+// Time Complexity Analysis:
+// Under concurrent load, the API maintains the same per-transaction complexity
+// of O(n + m·log(k)) as the sequential benchmark, where:
+// - n: number of transaction outputs (address extraction)
+// - m: number of unique addresses extracted from outputs
+// - k: total number of addresses in the wallet (B-tree lookup)
+//
+// The 4-stage pipeline design provides excellent concurrent performance:
+// 1. Extract: O(n) - Parallel CPU work, no contention
+// 2. Filter: O(m·log(k)) - Read-only transactions, minimal lock contention
+// 3. Plan: O(n·m) - Parallel in-memory work, no contention
+// 4. Execute: O(c) - Short write transactions reduce lock contention
+//
+// This benchmark stresses the lock contention characteristics during Stage 2
+// (read locks) and Stage 4 (write locks), demonstrating scalability under
+// concurrent broadcast operations.
+func BenchmarkBroadcastAPIConcurrently(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses constantGrowth since account count doesn't
+ // affect the Broadcast API's time complexity.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses linearGrowth to test O(log k) wallet
+ // address lookup scaling. As the address count grows linearly,
+ // the filterOwnedAddresses lookup time should grow
+ // logarithmically due to B-tree indexing.
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth to establish baseline
+ // transaction pool size. This represents the number of
+ // unconfirmed transactions being broadcast, stressing the
+ // idempotency checks and mempool state management.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ // txIOGrowth uses linearGrowth for both inputs and outputs to
+ // test the O(n) address extraction and O(n·m) write plan
+ // preparation costs. As transaction complexity grows linearly,
+ // processing time should scale linearly with output count.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ txIOGrowthPadding = decimalWidth(
+ txIOGrowth[len(txIOGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ chainBackend = &mockChainClient{}
+ )
+
+ err := chainBackend.Start(b.Context())
+ require.NoError(b, err)
+ b.Cleanup(chainBackend.Stop)
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d-Addrs-%0*d-Ins-%0*d-Outs-%0*d",
+ txPoolGrowthPadding, txPoolGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ txIOGrowthPadding, txIOGrowth[i],
+ txIOGrowthPadding, txIOGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+ bw.cfg.Chain = chainBackend
+
+ var (
+ beforeResult map[chainhash.Hash]*wire.MsgTx
+ afterResult map[chainhash.Hash]*wire.MsgTx
+ )
+
+ b.Run("0-Before", func(b *testing.B) {
+ broadcastLabel := "concurrent-before"
+
+ // Clear mempool to ensure clean state for
+ // benchmark baseline.
+ chainBackend.ResetMempool()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ j := len(bw.unconfirmedTxs)
+ for i := 0; pb.Next(); i++ {
+ k := i % j
+ tx := bw.unconfirmedTxs[k]
+ err := bw.PublishTransaction(
+ tx, broadcastLabel,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ var err error
+
+ beforeResult, err = chainBackend.GetMempool()
+ require.NoError(b, err)
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ broadcastAfter := "concurrent-after"
+
+ // Clear mempool to ensure clean state for
+ // benchmark baseline.
+ chainBackend.ResetMempool()
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ j := len(bw.unconfirmedTxs)
+ for i := 0; pb.Next(); i++ {
+ k := i % j
+ tx := bw.unconfirmedTxs[k]
+ err := bw.Broadcast(
+ b.Context(), tx,
+ broadcastAfter,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ var err error
+
+ afterResult, err = chainBackend.GetMempool()
+ require.NoError(b, err)
+ })
+
+ assertBroadcastAPIsEquivalent(
+ b, beforeResult, afterResult,
+ )
+ })
+ }
+}
+
+// assertBroadcastAPIsEquivalent verifies that PublishTransaction (legacy) and
+// Broadcast (new) produce equivalent results by comparing the transactions
+// that ended up in the mock mempool.
+func assertBroadcastAPIsEquivalent(b *testing.B,
+ before, after map[chainhash.Hash]*wire.MsgTx) {
+
+ b.Helper()
+
+ require.NotNil(b, before)
+ require.NotNil(b, after)
+
+ // require.Equal uses reflect.DeepEqual internally which compares maps
+ // by matching corresponding keys to deeply equal values, regardless of
+ // iteration order as stated in the official go package dev docs.
+ require.Equal(
+ b, before, after,
+ "PublishTransaction and Broadcast APIs should produce "+
+ "equivalent mempool state",
+ )
+}
diff --git a/wallet/tx_publisher_test.go b/wallet/tx_publisher_test.go
new file mode 100644
index 0000000000..3eab15bcc8
--- /dev/null
+++ b/wallet/tx_publisher_test.go
@@ -0,0 +1,926 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "crypto/sha256"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/btcjson"
+ "github.com/btcsuite/btcd/rpcclient"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/chain"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ errDummy = errors.New("dummy")
+ errInsufficientFee = errors.New("insufficient fee")
+ errRpc = errors.New("rpc error")
+ errPublish = errors.New("publish error")
+ errRemove = errors.New("remove error")
+)
+
+const testTxLabel = "test-tx"
+
+// TestCheckMempoolAcceptance tests the CheckMempoolAcceptance method.
+func TestCheckMempoolAcceptance(t *testing.T) {
+ t.Parallel()
+
+ tx := &wire.MsgTx{}
+
+ mempoolAcceptResultAllowed := []*btcjson.TestMempoolAcceptResult{
+ {Allowed: true},
+ }
+ mempoolAcceptResultRejected := []*btcjson.TestMempoolAcceptResult{
+ {
+ Allowed: false,
+ RejectReason: errInsufficientFee.Error(),
+ },
+ }
+
+ testCases := []struct {
+ name string
+ tx *wire.MsgTx
+ rpcResult []*btcjson.TestMempoolAcceptResult
+ rpcErr error
+ expectedErr error
+ }{
+ {
+ name: "nil tx",
+ tx: nil,
+ expectedErr: ErrTxCannotBeNil,
+ },
+ {
+ name: "accepted",
+ tx: tx,
+ rpcResult: mempoolAcceptResultAllowed,
+ rpcErr: nil,
+ expectedErr: nil,
+ },
+ {
+ name: "rejected",
+ tx: tx,
+ rpcResult: mempoolAcceptResultRejected,
+ rpcErr: nil,
+ expectedErr: errInsufficientFee,
+ },
+ {
+ name: "rpc error",
+ tx: tx,
+ rpcResult: nil,
+ rpcErr: errRpc,
+ expectedErr: errRpc,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ w, m := createStartedWalletWithMocks(t)
+
+ if tc.tx != nil {
+ m.chain.On("TestMempoolAccept",
+ mock.Anything, mock.Anything,
+ ).Return(tc.rpcResult, tc.rpcErr)
+ }
+
+ // We only need to mock the MapRPCErr function if the
+ // RPC call is expected to succeed but the tx is
+ // rejected.
+ if tc.rpcErr == nil && tc.rpcResult != nil &&
+ !tc.rpcResult[0].Allowed {
+
+ m.chain.On("MapRPCErr",
+ mock.Anything,
+ ).Return(errInsufficientFee)
+ }
+
+ err := w.CheckMempoolAcceptance(t.Context(), tc.tx)
+ require.ErrorIs(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// testTxData is a helper struct to hold the results of createTestTx.
+type testTxData struct {
+ // tx is the generated transaction.
+ tx *wire.MsgTx
+
+ // addr1 is the P2WKH address used in the transaction.
+ addr1 address.Address
+
+ // addr2 is the P2SH address used in the transaction.
+ addr2 address.Address
+
+ // addr3 is the P2WSH address used in the transaction.
+ addr3 address.Address
+}
+
+// createTestTx is a helper function to create a transaction with various
+// output types for testing. The created transaction has a single placeholder
+// input and four outputs:
+// - Output 0: A P2WKH (Pay-to-Witness-Key-Hash) output.
+// - Output 1: A P2SH (Pay-to-Script-Hash) output.
+// - Output 2: An OP_RETURN output for data embedding.
+// - Output 3: A 2-of-2 multi-sig P2WSH (Pay-to-Witness-Script-Hash) output.
+func createTestTx(t *testing.T, w *Wallet) *testTxData {
+ t.Helper()
+
+ // Create some keys and addresses for testing.
+ privKey1, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey1 := privKey1.PubKey()
+ addr1, err := address.NewAddressWitnessPubKeyHash(
+ address.Hash160(pubKey1.SerializeCompressed()),
+ w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ privKey2, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+
+ pubKey2 := privKey2.PubKey()
+
+ // Create a transaction with various output types.
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{
+ {},
+ },
+ TxOut: []*wire.TxOut{},
+ }
+
+ // Output 0: P2WKH
+ pkScript1, err := txscript.PayToAddrScript(addr1)
+ require.NoError(t, err)
+
+ tx.TxOut = append(tx.TxOut, &wire.TxOut{PkScript: pkScript1, Value: 1})
+
+ // Output 1: P2SH
+ script2 := []byte{txscript.OP_1}
+ addr2, err := address.NewAddressScriptHash(
+ script2, w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+ pkScript2, err := txscript.PayToAddrScript(addr2)
+ require.NoError(t, err)
+
+ tx.TxOut = append(tx.TxOut, &wire.TxOut{PkScript: pkScript2, Value: 1})
+
+ // Output 2: OP_RETURN
+ opReturnBuilder := txscript.NewScriptBuilder()
+ opReturnBuilder.AddOp(txscript.OP_RETURN).AddData([]byte("test"))
+ pkScript3, err := opReturnBuilder.Script()
+ require.NoError(t, err)
+
+ tx.TxOut = append(tx.TxOut, &wire.TxOut{PkScript: pkScript3, Value: 0})
+
+ // Output 3: Multi-sig P2WSH
+ builder := txscript.NewScriptBuilder()
+ builder.AddInt64(2)
+ builder.AddData(pubKey1.SerializeCompressed())
+ builder.AddData(pubKey2.SerializeCompressed())
+ builder.AddInt64(2)
+ builder.AddOp(txscript.OP_CHECKMULTISIG)
+ multiSigScript, err := builder.Script()
+ require.NoError(t, err)
+
+ scriptHash := sha256.Sum256(multiSigScript)
+ addr3, err := address.NewAddressWitnessScriptHash(
+ scriptHash[:], w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+ pkScript4, err := txscript.PayToAddrScript(addr3)
+ require.NoError(t, err)
+
+ tx.TxOut = append(tx.TxOut, &wire.TxOut{PkScript: pkScript4, Value: 1})
+
+ return &testTxData{
+ tx: tx,
+ addr1: addr1,
+ addr2: addr2,
+ addr3: addr3,
+ }
+}
+
+// TestExtractTxAddrs tests the extractTxAddrs method to ensure it correctly
+// extracts all potential addresses from a transaction's outputs.
+func TestExtractTxAddrs(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet.
+ w, _ := createStartedWalletWithMocks(t)
+
+ // Create the test transaction.
+ testData := createTestTx(t, w)
+
+ // Extract addresses.
+ extractedAddrs := w.extractTxAddrs(testData.tx)
+
+ // Check the results.
+ // We expect 4 entries in the map, one for each output.
+ require.Len(t, extractedAddrs, 4, "expected 4 outputs")
+
+ // Output 0 should have one address.
+ require.Len(t, extractedAddrs[0], 1)
+ require.Equal(t, testData.addr1.String(), extractedAddrs[0][0].String())
+
+ // Output 1 should have one address.
+ require.Len(t, extractedAddrs[1], 1)
+ require.Equal(t, testData.addr2.String(), extractedAddrs[1][0].String())
+
+ // Output 2 (OP_RETURN) should have no addresses.
+ require.Empty(t, extractedAddrs[2], "OP_RETURN output should have "+
+ "no addresses")
+
+ // Output 3 should have one address (the script hash address).
+ require.Len(t, extractedAddrs[3], 1)
+ require.Equal(t, testData.addr3.String(), extractedAddrs[3][0].String())
+}
+
+// TestFilterOwnedAddresses tests the filterOwnedAddresses method to ensure it
+// correctly identifies owned addresses and handles de-duplication.
+func TestFilterOwnedAddresses(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Create two addresses, one owned and one not.
+ ownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ ownedAddr, err := address.NewAddressPubKey(
+ ownedPrivKey.PubKey().SerializeCompressed(), w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ unownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ unownedAddr, err := address.NewAddressPubKey(
+ unownedPrivKey.PubKey().SerializeCompressed(),
+ w.cfg.ChainParams,
+ )
+ require.NoError(t, err)
+
+ // Create an input map with both addresses, with the owned address
+ // appearing twice.
+ txOutAddrs := map[uint32][]address.Address{
+ 0: {ownedAddr},
+ 1: {unownedAddr},
+ 2: {ownedAddr}, // Duplicate
+ }
+
+ // Set up the mock for the address store.
+ mockManagedAddr := &mockManagedAddress{}
+ errAddrNotFound := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAddressNotFound,
+ }
+
+ mocks.addrStore.On("Address", mock.Anything, ownedAddr).
+ Return(mockManagedAddr, nil).Once()
+ mocks.addrStore.On("Address", mock.Anything, unownedAddr).
+ Return(nil, errAddrNotFound).Once()
+
+ // Filter the addresses.
+ ownedAddrs, err := w.filterOwnedAddresses(txOutAddrs)
+ require.NoError(t, err)
+
+ // Check that the result contains only the owned address.
+ require.Len(t, ownedAddrs, 1)
+ _, ok := ownedAddrs[ownedAddr]
+ require.True(t, ok)
+}
+
+// TestRecordTxAndCredits tests the recordTxAndCredits method to ensure it
+// correctly records transactions and credits in the database.
+func TestRecordTxAndCredits(t *testing.T) {
+ t.Parallel()
+
+ // Create a sample TxRecord from a transaction with one input and one
+ // output with a value of 10000.
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{{Value: 10000}},
+ }
+ txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
+ require.NoError(t, err)
+
+ // Create a sample credit for a P2PK address.
+ privKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ addr, err := address.NewAddressPubKey(
+ privKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ mockManagedAddr := &mockManagedAddress{}
+ mockManagedAddr.On("Internal").Return(false)
+ credits := []creditInfo{{
+ index: 0,
+ ma: mockManagedAddr,
+ addr: addr,
+ }}
+
+ testCases := []struct {
+ name string
+ withLabel bool
+ txExists bool
+ }{
+ {
+ name: "new tx with label",
+ withLabel: true,
+ txExists: false,
+ },
+ {
+ name: "existing tx",
+ withLabel: true,
+ txExists: true,
+ },
+ {
+ name: "no label",
+ withLabel: false,
+ txExists: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+ txid := tx.TxHash()
+
+ label := ""
+ if tc.withLabel {
+ label = testTxLabel
+ }
+
+ mocks.txStore.On("InsertTxCheckIfExists",
+ mock.Anything, txRec, mock.Anything,
+ ).Return(tc.txExists, nil).Once()
+
+ if tc.withLabel {
+ mocks.txStore.On("PutTxLabel",
+ mock.Anything, txid, label,
+ ).Return(nil).Once()
+ }
+
+ if !tc.txExists {
+ mocks.txStore.On("AddCredit",
+ mock.Anything, txRec, mock.Anything,
+ uint32(0), false,
+ ).Return(nil).Once()
+ mocks.addrStore.On("MarkUsed",
+ mock.Anything, addr,
+ ).Return(nil).Once()
+ }
+
+ err := w.recordTxAndCredits(txRec, label, credits)
+ require.NoError(t, err)
+ })
+ }
+}
+
+// TestAddTxToWallet tests the addTxToWallet method, which serves as an
+// integration test for the transaction extraction, filtering, and recording
+// process.
+func TestAddTxToWallet(t *testing.T) {
+ t.Parallel()
+
+ // Create some addresses for testing.
+ ownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ ownedAddr, err := address.NewAddressPubKey(
+ ownedPrivKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ unownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ unownedAddr, err := address.NewAddressPubKey(
+ unownedPrivKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ // Create a transaction with outputs to both owned and unowned
+ // addresses.
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{
+ {
+ Value: 10000,
+ PkScript: mustPayToAddrScript(ownedAddr),
+ },
+ {
+ Value: 20000,
+ PkScript: mustPayToAddrScript(unownedAddr),
+ },
+ {
+ Value: 30000,
+ PkScript: mustPayToAddrScript(ownedAddr),
+ },
+ },
+ }
+ txid := tx.TxHash()
+ label := testTxLabel
+
+ t.Run("tx with owned outputs", func(t *testing.T) {
+ t.Parallel()
+
+ w, m := createStartedWalletWithMocks(t)
+
+ // This test case simulates the scenario where the
+ // transaction has outputs owned by the wallet. We expect
+ // the wallet to identify these outputs, record the
+ // transaction, and credit the wallet with the new UTXOs.
+ //
+ // Set up the mock for the address store.
+ mockManagedAddr := &mockManagedAddress{}
+ mockManagedAddr.On("Internal").Return(false)
+
+ errAddrNotFound := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAddressNotFound,
+ }
+
+ m.addrStore.On("Address",
+ mock.Anything, ownedAddr,
+ ).Return(mockManagedAddr, nil)
+ m.addrStore.On("Address",
+ mock.Anything, unownedAddr,
+ ).Return(nil, errAddrNotFound)
+
+ // Set up the mocks for the transaction store.
+ m.txStore.On("PutTxLabel",
+ mock.Anything, txid, label,
+ ).Return(nil).Once()
+ m.txStore.On("InsertTxCheckIfExists",
+ mock.Anything, mock.Anything,
+ mock.Anything,
+ ).Return(false, nil).Once()
+
+ // We expect two credits to be added for the two owned
+ // outputs.
+ m.txStore.On("AddCredit",
+ mock.Anything, mock.Anything,
+ mock.Anything, uint32(0), false,
+ ).Return(nil).Once()
+ m.txStore.On("AddCredit",
+ mock.Anything, mock.Anything,
+ mock.Anything, uint32(2), false,
+ ).Return(nil).Once()
+ m.addrStore.On("MarkUsed",
+ mock.Anything, ownedAddr,
+ ).Return(nil).Twice()
+
+ // Add the transaction to the wallet.
+ ourAddrs, err := w.addTxToWallet(tx, label)
+ require.NoError(t, err)
+
+ // Check that the returned addresses are correct.
+ require.Len(t, ourAddrs, 2)
+ require.Equal(
+ t, ownedAddr.String(),
+ ourAddrs[0].String(),
+ )
+ require.Equal(
+ t, ownedAddr.String(),
+ ourAddrs[1].String(),
+ )
+ })
+
+ t.Run("tx with no owned outputs", func(t *testing.T) {
+ t.Parallel()
+
+ w, m := createStartedWalletWithMocks(t)
+
+ // This test case simulates the scenario where the
+ // transaction has no outputs owned by the wallet. We
+ // expect the wallet to identify this and exit early
+ // without recording the transaction.
+ //
+ // Set up the mock for the address store to own no
+ // addresses.
+ errAddrNotFound := waddrmgr.ManagerError{
+ ErrorCode: waddrmgr.ErrAddressNotFound,
+ }
+ m.addrStore.On("Address",
+ mock.Anything, ownedAddr,
+ ).Return(nil, errAddrNotFound)
+ m.addrStore.On("Address",
+ mock.Anything, unownedAddr,
+ ).Return(nil, errAddrNotFound)
+
+ // Add the transaction to the wallet.
+ ourAddrs, err := w.addTxToWallet(tx, label)
+ require.NoError(t, err)
+
+ // We expect no addresses to be returned and no calls to the
+ // transaction store.
+ require.Nil(t, ourAddrs)
+ })
+}
+
+// mustPayToAddrScript is a helper function to create a PkScript for a given
+// address. It panics on error.
+func mustPayToAddrScript(addr address.Address) []byte {
+ pkScript, err := txscript.PayToAddrScript(addr)
+ if err != nil {
+ panic(err)
+ }
+
+ return pkScript
+}
+
+// TestRemoveUnminedTx tests the removeUnminedTx method to ensure it correctly
+// removes a transaction from the unconfirmed store.
+func TestRemoveUnminedTx(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Create a sample transaction with one input and one output.
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{{
+ Value: 10000,
+ }},
+ }
+
+ // Set up the mock for the transaction store.
+ mocks.txStore.On(
+ "RemoveUnminedTx", mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ // Call the method under test.
+ err := w.removeUnminedTx(tx)
+ require.NoError(t, err)
+}
+
+// TestCheckMempool tests the checkMempool helper function.
+func TestCheckMempool(t *testing.T) {
+ t.Parallel()
+
+ tx := &wire.MsgTx{}
+
+ testCases := []struct {
+ name string
+ mempoolAcceptErr error
+ expectedErr error
+ expectWrappedErr bool
+ rejectionReason string
+ mapRPCErr func(error) error
+ }{
+ {
+ name: "accepted",
+ mempoolAcceptErr: nil,
+ expectedErr: nil,
+ },
+ {
+ name: "already in mempool",
+ mempoolAcceptErr: chain.ErrTxAlreadyInMempool,
+ expectedErr: errAlreadyBroadcasted,
+ },
+ {
+ name: "already known",
+ mempoolAcceptErr: chain.ErrTxAlreadyKnown,
+ expectedErr: errAlreadyBroadcasted,
+ },
+ {
+ name: "already confirmed",
+ mempoolAcceptErr: chain.ErrTxAlreadyConfirmed,
+ expectedErr: errAlreadyBroadcasted,
+ },
+ {
+ name: "backend version",
+ mempoolAcceptErr: rpcclient.ErrBackendVersion,
+ expectedErr: nil,
+ },
+ {
+ name: "unimplemented",
+ mempoolAcceptErr: chain.ErrUnimplemented,
+ expectedErr: nil,
+ },
+ {
+ name: "rejected",
+ mempoolAcceptErr: errDummy,
+ expectedErr: errDummy,
+ expectWrappedErr: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ w, m := createStartedWalletWithMocks(t)
+
+ // Setup the mock for TestMempoolAccept.
+ if tc.mempoolAcceptErr == nil {
+ m.chain.On("TestMempoolAccept",
+ mock.Anything, mock.Anything,
+ ).Return([]*btcjson.TestMempoolAcceptResult{
+ {Allowed: true},
+ }, nil)
+ } else {
+ m.chain.On("TestMempoolAccept",
+ mock.Anything, mock.Anything,
+ ).Return(nil, tc.mempoolAcceptErr)
+ }
+
+ err := w.checkMempool(t.Context(), tx)
+ require.ErrorIs(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// TestPublishTx tests the publishTx helper function.
+func TestPublishTx(t *testing.T) {
+ t.Parallel()
+
+ tx := &wire.MsgTx{}
+ addrs := []address.Address{&address.AddressPubKey{}}
+
+ testCases := []struct {
+ name string
+ notifyErr error
+ sendErr error
+ expectedErr error
+ }{
+ {
+ name: "success",
+ notifyErr: nil,
+ sendErr: nil,
+ expectedErr: nil,
+ },
+ {
+ name: "notify received fails",
+ notifyErr: errDummy,
+ sendErr: nil,
+ expectedErr: errDummy,
+ },
+ {
+ name: "send raw transaction fails",
+ notifyErr: nil,
+ sendErr: errDummy,
+ expectedErr: errDummy,
+ },
+ {
+ name: "already in mempool",
+ notifyErr: nil,
+ sendErr: chain.ErrTxAlreadyInMempool,
+ expectedErr: nil,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ w, m := createStartedWalletWithMocks(t)
+
+ m.chain.On("NotifyReceived",
+ mock.Anything, mock.Anything,
+ mock.Anything).Return(tc.notifyErr)
+
+ // We only expect SendRawTransaction to be called if
+ // NotifyReceived succeeds.
+ if tc.notifyErr == nil {
+ m.chain.On("SendRawTransaction",
+ mock.Anything, mock.Anything,
+ ).Return(nil, tc.sendErr)
+ }
+
+ err := w.publishTx(tx, addrs)
+ require.ErrorIs(t, err, tc.expectedErr)
+ })
+ }
+}
+
+// TestBroadcastSuccess tests the Broadcast method for a successful broadcast.
+func TestBroadcastSuccess(t *testing.T) {
+ t.Parallel()
+
+ label := testTxLabel
+ w, m := createStartedWalletWithMocks(t)
+
+ // Create a transaction with an owned output.
+ ownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ ownedAddr, err := address.NewAddressPubKey(
+ ownedPrivKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(ownedAddr)
+ require.NoError(t, err)
+
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{{Value: 10000, PkScript: pkScript}},
+ }
+
+ // Mock checkMempool to succeed.
+ m.chain.On("TestMempoolAccept",
+ mock.Anything, mock.Anything,
+ ).Return([]*btcjson.TestMempoolAcceptResult{{Allowed: true}}, nil)
+
+ // Mock addTxToWallet to succeed.
+ mockManagedAddr := &mockManagedAddress{}
+ mockManagedAddr.On("Internal").Return(false)
+ m.addrStore.On("Address",
+ mock.Anything, ownedAddr,
+ ).Return(mockManagedAddr, nil).Once()
+ m.txStore.On("PutTxLabel",
+ mock.Anything, tx.TxHash(), label,
+ ).Return(nil).Once()
+ m.txStore.On("InsertTxCheckIfExists",
+ mock.Anything, mock.Anything, mock.Anything,
+ ).Return(false, nil).Once()
+ m.txStore.On("AddCredit",
+ mock.Anything, mock.Anything, mock.Anything, uint32(0), false,
+ ).Return(nil).Once()
+ m.addrStore.On("MarkUsed",
+ mock.Anything, ownedAddr,
+ ).Return(nil).Once()
+
+ // Mock publishTx to succeed.
+ m.chain.On("NotifyReceived", mock.Anything).Return(nil)
+ m.chain.On("SendRawTransaction",
+ mock.Anything, mock.Anything,
+ ).Return(nil, nil)
+
+ err = w.Broadcast(t.Context(), tx, label)
+ require.NoError(t, err)
+}
+
+// TestBroadcastAlreadyBroadcasted tests the Broadcast method when the
+// transaction has already been broadcasted.
+func TestBroadcastAlreadyBroadcasted(t *testing.T) {
+ t.Parallel()
+
+ label := testTxLabel
+ w, m := createStartedWalletWithMocks(t)
+
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{{Value: 10000}},
+ }
+
+ // Mock checkMempool to return already broadcasted.
+ m.chain.On("TestMempoolAccept", mock.Anything, mock.Anything).
+ Return(nil, chain.ErrTxAlreadyInMempool)
+
+ err := w.Broadcast(t.Context(), tx, label)
+ require.NoError(t, err)
+}
+
+// TestBroadcastPublishFailsRemoveSucceeds tests the Broadcast method when
+// publishing fails but removing the transaction from the wallet succeeds.
+func TestBroadcastPublishFailsRemoveSucceeds(t *testing.T) {
+ t.Parallel()
+
+ label := testTxLabel
+ w, m := createStartedWalletWithMocks(t)
+
+ // Create a transaction with an owned output.
+ ownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ ownedAddr, err := address.NewAddressPubKey(
+ ownedPrivKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(ownedAddr)
+ require.NoError(t, err)
+
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{{Value: 10000, PkScript: pkScript}},
+ }
+
+ // Mock checkMempool to succeed.
+ m.chain.On("TestMempoolAccept",
+ mock.Anything, mock.Anything,
+ ).Return([]*btcjson.TestMempoolAcceptResult{{Allowed: true}}, nil)
+
+ // Mock addTxToWallet to succeed.
+ mockManagedAddr := &mockManagedAddress{}
+ mockManagedAddr.On("Internal").Return(false)
+ m.addrStore.On("Address",
+ mock.Anything, ownedAddr,
+ ).Return(mockManagedAddr, nil).Once()
+ m.txStore.On("PutTxLabel",
+ mock.Anything, tx.TxHash(), label,
+ ).Return(nil).Once()
+ m.txStore.On("InsertTxCheckIfExists",
+ mock.Anything, mock.Anything, mock.Anything,
+ ).Return(false, nil).Once()
+ m.txStore.On("AddCredit",
+ mock.Anything, mock.Anything, mock.Anything, uint32(0), false,
+ ).Return(nil).Once()
+ m.addrStore.On("MarkUsed",
+ mock.Anything, ownedAddr,
+ ).Return(nil).Once()
+
+ // Mock publishTx to fail.
+ m.chain.On("NotifyReceived", mock.Anything).Return(nil)
+ m.chain.On("SendRawTransaction",
+ mock.Anything, mock.Anything,
+ ).Return(nil, errPublish)
+
+ // Mock removeUnminedTx to succeed.
+ m.txStore.On("RemoveUnminedTx",
+ mock.Anything, mock.Anything,
+ ).Return(nil).Once()
+
+ err = w.Broadcast(t.Context(), tx, label)
+ require.ErrorIs(t, err, errPublish)
+}
+
+// TestBroadcastPublishFailsRemoveFails tests the Broadcast method when both
+// publishing and removing the transaction from the wallet fail.
+func TestBroadcastPublishFailsRemoveFails(t *testing.T) {
+ t.Parallel()
+
+ label := testTxLabel
+ w, m := createStartedWalletWithMocks(t)
+
+ // Create a transaction with an owned output.
+ ownedPrivKey, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ ownedAddr, err := address.NewAddressPubKey(
+ ownedPrivKey.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+ pkScript, err := txscript.PayToAddrScript(ownedAddr)
+ require.NoError(t, err)
+
+ tx := &wire.MsgTx{
+ TxIn: []*wire.TxIn{{}},
+ TxOut: []*wire.TxOut{{Value: 10000, PkScript: pkScript}},
+ }
+
+ // Mock checkMempool to succeed.
+ m.chain.On("TestMempoolAccept",
+ mock.Anything, mock.Anything,
+ ).Return([]*btcjson.TestMempoolAcceptResult{{Allowed: true}}, nil)
+
+ // Mock addTxToWallet to succeed.
+ mockManagedAddr := &mockManagedAddress{}
+ mockManagedAddr.On("Internal").Return(false)
+
+ // Mock addrStore to succeed.
+ m.addrStore.On("Address",
+ mock.Anything, ownedAddr,
+ ).Return(mockManagedAddr, nil).Once()
+ m.addrStore.On("MarkUsed",
+ mock.Anything, ownedAddr,
+ ).Return(nil).Once()
+
+ // Mock txStore to succeed.
+ m.txStore.On("PutTxLabel",
+ mock.Anything, tx.TxHash(), label,
+ ).Return(nil).Once()
+ m.txStore.On("InsertTxCheckIfExists",
+ mock.Anything, mock.Anything, mock.Anything,
+ ).Return(false, nil).Once()
+ m.txStore.On("AddCredit",
+ mock.Anything, mock.Anything, mock.Anything, uint32(0), false,
+ ).Return(nil).Once()
+
+ // Mock publishTx to fail.
+ m.chain.On("NotifyReceived", mock.Anything).Return(nil)
+ m.chain.On("SendRawTransaction",
+ mock.Anything, mock.Anything,
+ ).Return(nil, errPublish)
+
+ // Mock removeUnminedTx to fail.
+ m.txStore.On("RemoveUnminedTx",
+ mock.Anything, mock.Anything,
+ ).Return(errRemove).Once()
+
+ err = w.Broadcast(t.Context(), tx, label)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), errPublish.Error())
+ require.Contains(t, err.Error(), errRemove.Error())
+}
+
+// TestBroadcastNilTx tests that the Broadcast method returns an error when a
+// nil transaction is passed.
+func TestBroadcastNilTx(t *testing.T) {
+ t.Parallel()
+
+ label := testTxLabel
+ w, _ := createStartedWalletWithMocks(t)
+
+ err := w.Broadcast(t.Context(), nil, label)
+ require.ErrorIs(t, err, ErrTxCannotBeNil)
+}
diff --git a/wallet/tx_reader.go b/wallet/tx_reader.go
new file mode 100644
index 0000000000..30e2875e9a
--- /dev/null
+++ b/wallet/tx_reader.go
@@ -0,0 +1,442 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/blockchain"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/pkg/btcunit"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+var (
+ // ErrTxNotFound is returned when a transaction is not found in the
+ // store.
+ ErrTxNotFound = errors.New("tx not found")
+)
+
+// TxReader provides an interface for querying tx history.
+type TxReader interface {
+ // GetTx returns a detailed description of a tx given its tx hash.
+ GetTx(ctx context.Context, txHash chainhash.Hash) (*TxDetail, error)
+
+ // ListTxns returns a list of all txns which are relevant to the wallet
+ // over a given block range.
+ ListTxns(ctx context.Context, startHeight, endHeight int32) (
+ []*TxDetail, error)
+}
+
+// A compile-time assertion to ensure that Wallet implements the TxReader
+// interface.
+var _ TxReader = (*Wallet)(nil)
+
+// Output contains details for a tx output.
+type Output struct {
+ // Addresses are the addresses associated with the output script.
+ Addresses []address.Address
+
+ // PkScript is the raw output script.
+ PkScript []byte
+
+ // Index is the index of the output in the tx.
+ Index int
+
+ // Amount is the value of the output.
+ Amount btcutil.Amount
+
+ // Type is the script class of the output.
+ Type txscript.ScriptClass
+
+ // IsOurs is true if the output is controlled by the wallet.
+ IsOurs bool
+}
+
+// PrevOut describes a tx input.
+type PrevOut struct {
+ // OutPoint is the unique reference to the output being spent.
+ OutPoint wire.OutPoint
+
+ // IsOurs is true if the input spends an output controlled by the
+ // wallet.
+ IsOurs bool
+}
+
+// BlockDetails contains details about the block that includes a tx.
+type BlockDetails struct {
+ // Hash is the hash of the block.
+ Hash chainhash.Hash
+
+ // Height is the height of the block.
+ Height int32
+
+ // Timestamp is the unix timestamp of the block.
+ Timestamp int64
+}
+
+// TxDetail describes a tx relevant to a wallet. This is a flattened
+// and information-dense structure designed to be returned by the TxReader
+// interface.
+type TxDetail struct {
+ // Hash is the tx hash.
+ Hash chainhash.Hash
+
+ // RawTx is the serialized tx.
+ RawTx []byte
+
+ // Value is the net value of this tx (in satoshis) from the
+ // POV of the wallet.
+ Value btcutil.Amount
+
+ // Fee is the total fee in satoshis paid by this tx.
+ //
+ // NOTE: This is only calculated if all inputs are known to the wallet.
+ // Otherwise, it will be zero.
+ //
+ // TODO(yy): This should also be calculated for txns with external
+ // inputs. This requires adding a `GetRawTransaction` method to the
+ // `chain.Interface`.
+ Fee btcutil.Amount
+
+ // FeeRate is the fee rate of the tx in sat/vbyte.
+ //
+ // NOTE: This is only calculated if all inputs are known to the wallet.
+ // Otherwise, it will be zero.
+ FeeRate btcunit.SatPerVByte
+
+ // Weight is the tx's weight.
+ Weight btcunit.WeightUnit
+
+ // Confirmations is the number of confirmations this tx has.
+ // This will be 0 for unconfirmed txns.
+ Confirmations int32
+
+ // Block contains details of the block that includes this tx.
+ Block *BlockDetails
+
+ // ReceivedTime is the time the tx was received by the wallet.
+ ReceivedTime time.Time
+
+ // Outputs contains data for each tx output.
+ Outputs []Output
+
+ // PrevOuts are the inputs for the tx.
+ PrevOuts []PrevOut
+
+ // Label is an optional tx label.
+ Label string
+}
+
+// GetTx returns a detailed description of a tx given its tx hash.
+//
+// NOTE: This method is part of the TxReader interface.
+//
+// Time complexity: O(log n + I + O), where n is the number of
+// transactions in the database, I is the number of inputs, and O is the
+// number of outputs. The lookup is dominated by a key-based B-tree lookup
+// in the database and the processing of the transaction's inputs and
+// outputs.
+func (w *Wallet) GetTx(_ context.Context, txHash chainhash.Hash) (
+ *TxDetail, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ txDetails, err := w.fetchTxDetails(&txHash)
+ if err != nil {
+ return nil, err
+ }
+
+ bestBlock := w.SyncedTo()
+ currentHeight := bestBlock.Height
+
+ return w.buildTxDetail(txDetails, currentHeight), nil
+}
+
+// ListTxns returns a list of all txns which are relevant to the
+// wallet over a given block range. The block range is inclusive of the
+// start and end heights.
+//
+// The underlying transaction store allows for reverse iteration, so if
+// startHeight > endHeight, the transactions will be returned in reverse
+// order.
+//
+// The special height -1 may be used to include unmined transactions. For
+// example, to get all transactions from block 100 to the current tip including
+// unmined, use a startHeight of 100 and an endHeight of -1. To get all
+// transactions in the wallet, use a startHeight of 0 and an endHeight of -1.
+//
+// NOTE: This method is part of the TxReader interface.
+//
+// Time complexity: O(B + N), where B is the number of blocks in the
+// range and N is the total number of inputs and outputs across all
+// transactions in the range.
+func (w *Wallet) ListTxns(_ context.Context, startHeight,
+ endHeight int32) ([]*TxDetail, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ bestBlock := w.SyncedTo()
+ currentHeight := bestBlock.Height
+
+ // We'll first fetch all the transaction records from the database
+ // within a single database transaction. This is done to minimize the
+ // time we hold the database lock.
+ var records []wtxmgr.TxDetails
+
+ err = walletdb.View(w.cfg.DB, func(dbtx walletdb.ReadTx) error {
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ err := w.txStore.RangeTransactions(
+ txmgrNs, startHeight, endHeight,
+ func(d []wtxmgr.TxDetails) (bool, error) {
+ records = append(records, d...)
+
+ return false, nil
+ },
+ )
+ if err != nil {
+ return fmt.Errorf("tx range failed: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to view wallet db: %w", err)
+ }
+
+ // Now that we have all the records, we can build the detailed
+ // response without holding the database lock.
+ details := make([]*TxDetail, 0, len(records))
+ for _, detail := range records {
+ txDetail := w.buildTxDetail(&detail, currentHeight)
+ details = append(details, txDetail)
+ }
+
+ return details, nil
+}
+
+// fetchTxDetails fetches the tx details for the given tx hash
+// from the wallet's tx store.
+func (w *Wallet) fetchTxDetails(txHash *chainhash.Hash) (
+ *wtxmgr.TxDetails, error) {
+
+ var txDetails *wtxmgr.TxDetails
+
+ err := walletdb.View(w.cfg.DB, func(dbtx walletdb.ReadTx) error {
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ var err error
+
+ txDetails, err = w.txStore.TxDetails(txmgrNs, txHash)
+ if err != nil {
+ return fmt.Errorf("failed to fetch tx details: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to view wallet db: %w", err)
+ }
+
+ // TxDetails will return nil when the tx is not found.
+ //
+ // TODO(yy): We should instead return an error when the tx cannot be
+ // found in the db.
+ if txDetails == nil {
+ return nil, ErrTxNotFound
+ }
+
+ return txDetails, nil
+}
+
+// buildTxDetail builds a TxDetail from the given wtxmgr.TxDetails.
+func (w *Wallet) buildTxDetail(txDetails *wtxmgr.TxDetails,
+ currentHeight int32) *TxDetail {
+
+ details := w.buildBasicTxDetail(txDetails)
+
+ w.populateBlockDetails(details, txDetails, currentHeight)
+ w.calculateValueAndFee(details, txDetails)
+ w.populateOutputs(details, txDetails)
+ w.populatePrevOuts(details, txDetails)
+
+ return details
+}
+
+// buildBasicTxDetail builds the basic TxDetail from the given wtxmgr.TxDetails.
+func (w *Wallet) buildBasicTxDetail(txDetails *wtxmgr.TxDetails) *TxDetail {
+ txWeight := blockchain.GetTransactionWeight(
+ btcutil.NewTx(&txDetails.MsgTx),
+ )
+
+ return &TxDetail{
+ Hash: txDetails.Hash,
+ RawTx: txDetails.SerializedTx,
+ Label: txDetails.Label,
+ ReceivedTime: txDetails.Received,
+ Weight: safeInt64ToWeightUnit(txWeight),
+ FeeRate: btcunit.ZeroSatPerVByte,
+ }
+}
+
+// populateBlockDetails populates the block details for the given TxDetail.
+func (w *Wallet) populateBlockDetails(details *TxDetail,
+ txDetails *wtxmgr.TxDetails, currentHeight int32) {
+
+ height := txDetails.Block.Height
+ if height == -1 {
+ return
+ }
+
+ details.Block = &BlockDetails{
+ Hash: txDetails.Block.Hash,
+ Height: txDetails.Block.Height,
+ Timestamp: txDetails.Block.Time.Unix(),
+ }
+
+ details.Confirmations = calcConf(height, currentHeight)
+}
+
+// calculateValueAndFee calculates the value and fee for the given TxDetail.
+func (w *Wallet) calculateValueAndFee(details *TxDetail,
+ txDetails *wtxmgr.TxDetails) {
+
+ var balanceDelta btcutil.Amount
+ for _, debit := range txDetails.Debits {
+ balanceDelta -= debit.Amount
+ }
+
+ for _, credit := range txDetails.Credits {
+ balanceDelta += credit.Amount
+ }
+
+ details.Value = balanceDelta
+
+ // If not all inputs are ours, we can't calculate the total fee.
+ // txDetails.Debits contains only our inputs, while
+ // txDetails.MsgTx.TxIn contains all inputs. If they differ, some
+ // inputs belong to external wallets and we don't know their input
+ // values.
+ if len(txDetails.Debits) != len(txDetails.MsgTx.TxIn) {
+ return
+ }
+
+ var totalInput btcutil.Amount
+ for _, debit := range txDetails.Debits {
+ totalInput += debit.Amount
+ }
+
+ var totalOutput btcutil.Amount
+ for _, txOut := range txDetails.MsgTx.TxOut {
+ totalOutput += btcutil.Amount(txOut.Value)
+ }
+
+ details.Fee = totalInput - totalOutput
+ details.FeeRate = btcunit.CalcSatPerVByte(
+ details.Fee, details.Weight.ToVB(),
+ )
+}
+
+// populateOutputs populates the outputs for the given TxDetail.
+func (w *Wallet) populateOutputs(details *TxDetail,
+ txDetails *wtxmgr.TxDetails) {
+
+ isOurAddress := make(map[uint32]bool)
+ for _, credit := range txDetails.Credits {
+ isOurAddress[credit.Index] = true
+ }
+
+ for i, txOut := range txDetails.MsgTx.TxOut {
+ sc, outAddresses, _, err := txscript.ExtractPkScriptAddrs(
+ txOut.PkScript, w.cfg.ChainParams,
+ )
+
+ var addresses []address.Address
+ if err != nil {
+ log.Warnf("Cannot extract addresses from pkScript for "+
+ "tx %v, output %d: %v", details.Hash, i, err)
+ } else {
+ addresses = outAddresses
+ }
+
+ idx, ok := safeIntToUint32(i)
+ if !ok {
+ log.Warnf("Output index %d out of uint32 range", i)
+ continue
+ }
+
+ details.Outputs = append(
+ details.Outputs, Output{
+ Type: sc,
+ Addresses: addresses,
+ PkScript: txOut.PkScript,
+ Index: i,
+ Amount: btcutil.Amount(txOut.Value),
+ IsOurs: isOurAddress[idx],
+ },
+ )
+ }
+}
+
+// populatePrevOuts populates the previous outputs for the given TxDetail.
+func (w *Wallet) populatePrevOuts(details *TxDetail,
+ txDetails *wtxmgr.TxDetails) {
+
+ isOurOutput := make(map[uint32]bool)
+ for _, debit := range txDetails.Debits {
+ isOurOutput[debit.Index] = true
+ }
+
+ for i, txIn := range txDetails.MsgTx.TxIn {
+ idx, ok := safeIntToUint32(i)
+ if !ok {
+ log.Warnf("Input index %d out of uint32 range", i)
+ continue
+ }
+
+ details.PrevOuts = append(
+ details.PrevOuts, PrevOut{
+ OutPoint: txIn.PreviousOutPoint,
+ IsOurs: isOurOutput[idx],
+ },
+ )
+ }
+}
+
+// safeInt64ToWeightUnit converts an int64 to a unit.WeightUnit, ensuring the
+// value is non-negative.
+func safeInt64ToWeightUnit(w int64) btcunit.WeightUnit {
+ if w < 0 {
+ return btcunit.NewWeightUnit(0)
+ }
+
+ return btcunit.NewWeightUnit(uint64(w))
+}
+
+// safeIntToUint32 converts an int to a uint32, returning false if the
+// conversion would overflow.
+func safeIntToUint32(i int) (uint32, bool) {
+ if i < 0 || i > math.MaxUint32 {
+ return 0, false
+ }
+
+ return uint32(i), true
+}
diff --git a/wallet/tx_reader_benchmark_test.go b/wallet/tx_reader_benchmark_test.go
new file mode 100644
index 0000000000..461fea0655
--- /dev/null
+++ b/wallet/tx_reader_benchmark_test.go
@@ -0,0 +1,677 @@
+package wallet
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkGetTxAPI benchmarks GetTx API and its deprecated variant
+// GetTransaction using identical test data across transactions with varying
+// complexity (input/output counts). Test names start with transaction
+// complexity to group API comparisons for benchstat analysis.
+//
+// Time Complexity Analysis:
+// GetTx has no amortization - it's a read operation with consistent upper/tight
+// bound cost every time. The time complexity is O(log n + I + O) where:
+// - n: number of transactions in the database (B-tree lookup)
+// - I: number of inputs in the transaction
+// - O: number of outputs in the transaction
+func BenchmarkGetTxAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses constantGrowth since account count doesn't
+ // affect the API's time complexity.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the API's time complexity.
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth to test O(log n) B-tree lookup
+ // scaling. As database size grows linearly, lookup time should
+ // grow logarithmically, demonstrating sublinear scaling.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ // txIOGrowth uses symmetric linearGrowth for both inputs
+ // and outputs to stress test the O(I + O) processing cost with
+ // rapidly growing transaction complexity, exposing potential
+ // performance bottlenecks in input/output iteration and address
+ // extraction.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ txIOGrowthPadding = decimalWidth(
+ txIOGrowth[len(txIOGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d-Ins-%0*d-Outs-%0*d",
+ txPoolGrowthPadding, txPoolGrowth[i], txIOGrowthPadding,
+ txIOGrowth[i], txIOGrowthPadding, txIOGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+
+ // Get a transaction hash from the middle of the dataset
+ // for representative benchmarking.
+ medianIndex := len(bw.allTxs) / 2
+ testTxHash := bw.allTxs[medianIndex].TxHash()
+
+ var (
+ beforeResult *GetTransactionResult
+ afterResult *TxDetail
+ )
+
+ b.Run("0-Before", func(b *testing.B) {
+ var (
+ result *GetTransactionResult
+ baselineResult *GetTransactionResult
+ err error
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ result, err = bw.GetTransaction(
+ testTxHash,
+ )
+ require.NoError(b, err)
+
+ // Capture first result only.
+ if i == 0 {
+ baselineResult = result
+ }
+ }
+
+ require.Equal(
+ b, baselineResult, result,
+ "GetTransaction API should be "+
+ "idempotent",
+ )
+
+ beforeResult = result
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ var (
+ result *TxDetail
+ baselineResult *TxDetail
+ err error
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ result, err = bw.GetTx(
+ b.Context(), testTxHash,
+ )
+ require.NoError(b, err)
+
+ // Capture first baseline result only.
+ if i == 0 {
+ baselineResult = result
+ }
+ }
+
+ require.Equal(
+ b, baselineResult, result,
+ "GetTx API should be idempotent",
+ )
+
+ afterResult = result
+ })
+
+ // Verify API equivalence after benchmarks complete.
+ // This ensures:
+ // - Both APIs return consistent results for the same
+ // transaction
+ // - The new API maintains compatibility with the
+ // legacy API
+ // - Regression prevention for future changes
+ assertGetTxAPIsEquivalent(
+ b, bw.Wallet, beforeResult, afterResult,
+ )
+ })
+ }
+}
+
+// BenchmarkGetTxAPIConcurrently benchmarks GetTx API and its deprecated
+// variant GetTransaction using identical test data under concurrent load.
+// Test names start with transaction pool size to group API comparisons for
+// benchstat analysis.
+//
+// Time Complexity Analysis:
+// Under concurrent load, the API maintains the same per-transaction complexity
+// of O(log n + I + O) as the sequential benchmark, where:
+// - n: number of transactions in the database (B-tree lookup)
+// - I: number of inputs in the transaction
+// - O: number of outputs in the transaction
+//
+// This benchmark stresses the lock contention characteristics during database
+// reads, demonstrating scalability under concurrent read operations.
+func BenchmarkGetTxAPIConcurrently(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses constantGrowth since account count doesn't
+ // affect the API's time complexity.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the API's time complexity.
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth to test O(log n) B-tree lookup
+ // scaling. As database size grows linearly, lookup time should
+ // grow logarithmically, demonstrating sublinear scaling.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ // txIOGrowth uses symmetric linearGrowth for both inputs
+ // and outputs to stress test the O(I + O) processing cost with
+ // rapidly growing transaction complexity, exposing potential
+ // performance bottlenecks in input/output iteration and address
+ // extraction.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ txIOGrowthPadding = decimalWidth(
+ txIOGrowth[len(txIOGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d-Ins-%0*d-Outs-%0*d",
+ txPoolGrowthPadding, txPoolGrowth[i], txIOGrowthPadding,
+ txIOGrowth[i], txIOGrowthPadding, txIOGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+
+ // Get a transaction hash from the middle of the dataset
+ // for representative benchmarking.
+ medianIndex := len(bw.allTxs) / 2
+ testTxHash := bw.allTxs[medianIndex].TxHash()
+
+ var (
+ before *GetTransactionResult
+ after *TxDetail
+ )
+
+ b.Run("0-Before", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ res, err := bw.GetTransaction(
+ testTxHash,
+ )
+ before = res
+
+ require.NoError(b, err)
+ require.NotNil(b, before)
+ }
+ })
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ res, err := bw.GetTx(
+ b.Context(), testTxHash,
+ )
+ after = res
+
+ require.NoError(b, err)
+ require.NotNil(b, after)
+ }
+ })
+ })
+
+ assertGetTxAPIsEquivalent(b, bw.Wallet, before, after)
+ })
+ }
+}
+
+// BenchmarkListTxnsAPI benchmarks ListTxns API and its deprecated variant
+// GetTransactions using identical test data across varying block ranges and
+// transaction densities. Test names start with complexity metrics to group API
+// comparisons for benchstat analysis.
+//
+// Time Complexity Analysis:
+// ListTxns has no amortization - it's a read operation with consistent
+// upper/tight bound cost. The time complexity is O(B * T * (I + O)) where:
+// - B: number of blocks in the range [startHeight, endHeight]
+// - T: average transactions per block
+// - I: average inputs per transaction
+// - O: average outputs per transaction
+//
+// This simplifies to O(N) where N = total inputs + outputs across all
+// transactions in the block range.
+func BenchmarkListTxnsAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 10
+ )
+
+ var (
+ // accountGrowth uses constantGrowth since account count doesn't
+ // affect the API's time complexity.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the API's time complexity.
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // txPoolGrowth uses exponentialGrowth to stress test the
+ // O(B * T) component - total transactions across blocks.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ // txIOGrowth uses exponentialGrowth for both inputs and outputs
+ // to stress test the O(I + O) per-transaction processing cost
+ // with rapidly growing transaction complexity.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ txIOGrowthPadding = decimalWidth(
+ txIOGrowth[len(txIOGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d-Ins-%0*d-Outs-%0*d",
+ txPoolGrowthPadding, txPoolGrowth[i], txIOGrowthPadding,
+ txIOGrowth[i], txIOGrowthPadding, txIOGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ // Setup wallet once for both API benchmarks.
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+
+ // List all transactions (no height filter).
+ // For GetTransactions (old): nil, nil means all blocks
+ // For ListTxns (new): 0, -1 means all blocks
+ // (0=genesis, -1=unlimited).
+ var (
+ startBlock *BlockIdentifier
+ endBlock *BlockIdentifier
+ startHeight int32 = 0
+ endHeight int32 = -1
+ )
+
+ var (
+ beforeResult *GetTransactionsResult
+ afterResult []*TxDetail
+ )
+
+ b.Run("0-Before", func(b *testing.B) {
+ var (
+ result *GetTransactionsResult
+ firstResult *GetTransactionsResult
+ err error
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ result, err = bw.GetTransactions(
+ startBlock, endBlock, "", nil,
+ )
+ require.NoError(b, err)
+
+ // Capture first result only.
+ if i == 0 {
+ firstResult = result
+ }
+ }
+
+ require.Equal(
+ b, firstResult, result,
+ "GetTransactions API should be "+
+ "idempotent",
+ )
+
+ beforeResult = result
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ var (
+ result []*TxDetail
+ firstResult []*TxDetail
+ err error
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ result, err = bw.ListTxns(
+ b.Context(), startHeight,
+ endHeight,
+ )
+ require.NoError(b, err)
+
+ // Capture first result only.
+ if i == 0 {
+ firstResult = result
+ }
+ }
+
+ require.Equal(
+ b, firstResult, result,
+ "ListTxns API should be idempotent ",
+ )
+
+ afterResult = result
+ })
+
+ // Verify API equivalence after benchmarks complete.
+ // This ensures:
+ // - Both APIs return consistent results for the same
+ // block range
+ // - The new API maintains compatibility with the
+ // legacy API
+ // - Regression prevention for future changes
+ assertListTxnsAPIsEquivalent(
+ b, bw.Wallet, beforeResult, afterResult,
+ )
+ })
+ }
+}
+
+// BenchmarkListTxnsAPIConcurrently benchmarks ListTxns API and its deprecated
+// variant GetTransactions using identical test data under concurrent load.
+// Test names start with complexity metrics to group API comparisons for
+// benchstat analysis.
+//
+// Time Complexity Analysis:
+// Under concurrent load, the API maintains the same per-request complexity
+// of O(B * T * (I + O)) as the sequential benchmark, where:
+// - B: number of blocks in the range [startHeight, endHeight]
+// - T: average transactions per block
+// - I: average inputs per transaction
+// - O: average outputs per transaction
+//
+// This benchmark stresses lock contention characteristics during database
+// reads, demonstrating scalability under concurrent read operations.
+func BenchmarkListTxnsAPIConcurrently(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth uses constantGrowth since account count doesn't
+ // affect the API's time complexity.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // addressGrowth uses constantGrowth since address count doesn't
+ // affect the API's time complexity.
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth for CI-friendly execution
+ // while still testing scaling behavior.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ // txIOGrowth uses linearGrowth for CI-friendly execution
+ // while still testing scaling behavior.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration, linearGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ txIOGrowthPadding = decimalWidth(
+ txIOGrowth[len(txIOGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d-Ins-%0*d-Outs-%0*d",
+ txPoolGrowthPadding, txPoolGrowth[i], txIOGrowthPadding,
+ txIOGrowth[i], txIOGrowthPadding, txIOGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+
+ // List all transactions (no height filter).
+ var (
+ startBlock *BlockIdentifier
+ endBlock *BlockIdentifier
+ startHeight int32 = 0
+ endHeight int32 = -1
+ )
+
+ var (
+ beforeResult *GetTransactionsResult
+ afterResult []*TxDetail
+ )
+
+ b.Run("0-Before", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ res, err := bw.GetTransactions(
+ startBlock, endBlock,
+ "", nil,
+ )
+ beforeResult = res
+
+ require.NoError(b, err)
+ require.NotNil(b, res)
+ }
+ })
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ res, err := bw.ListTxns(
+ b.Context(),
+ startHeight, endHeight,
+ )
+ afterResult = res
+
+ require.NoError(b, err)
+ require.NotNil(b, res)
+ }
+ })
+ })
+
+ assertListTxnsAPIsEquivalent(
+ b, bw.Wallet, beforeResult, afterResult,
+ )
+ })
+ }
+}
+
+// assertGetTxAPIsEquivalent verifies that GetTransaction (legacy) and GetTx
+// (new) return equivalent data for the same transaction.
+func assertGetTxAPIsEquivalent(b *testing.B, w *Wallet,
+ before *GetTransactionResult, after *TxDetail) {
+
+ b.Helper()
+
+ require.NotNil(b, before)
+ require.NotNil(b, after)
+
+ afterConverted, err := w.GetTransaction(after.Hash)
+ require.NoError(b, err)
+
+ require.GreaterOrEqual(b, afterConverted.Confirmations, int32(0))
+
+ require.Equal(b, before, afterConverted)
+}
+
+// assertListTxnsAPIsEquivalent verifies that GetTransactions (legacy) and
+// ListTxns (new) return equivalent data for the same block range.
+func assertListTxnsAPIsEquivalent(b *testing.B, w *Wallet,
+ before *GetTransactionsResult, after []*TxDetail) {
+
+ b.Helper()
+
+ require.NotNil(b, before)
+ require.NotNil(b, after)
+
+ // Use GetTransactions API to fetch all transactions (both confirmed
+ // and unconfirmed) for comparison. Parameters match the benchmark's
+ // "before" case:
+ // - startBlock: nil (from genesis)
+ // - endBlock: nil (to current tip)
+ // - accountName: "" (all accounts)
+ // - cancel: nil (no cancellation)
+ var (
+ startBlock *BlockIdentifier
+ endBlock *BlockIdentifier
+ accountName string
+ cancel <-chan struct{}
+ )
+
+ afterConverted, err := w.GetTransactions(
+ startBlock, endBlock, accountName, cancel,
+ )
+ require.NoError(b, err)
+
+ require.NotEmpty(b, before.MinedTransactions)
+ require.NotEmpty(b, before.UnminedTransactions)
+
+ require.Equal(
+ b, before, afterConverted,
+ "GetTransactions and ListTxns APIs should return equivalent "+
+ "data",
+ )
+}
diff --git a/wallet/tx_reader_test.go b/wallet/tx_reader_test.go
new file mode 100644
index 0000000000..23d3430538
--- /dev/null
+++ b/wallet/tx_reader_test.go
@@ -0,0 +1,304 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/blockchain"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/chaincfg/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcwallet/pkg/btcunit"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestBuildTxDetail tests the buildTxDetail function.
+func TestBuildTxDetail(t *testing.T) {
+ t.Parallel()
+
+ // Create the various test cases.
+ minedDetails, minedTxDetail := createMinedTxDetail(t)
+ unminedDetails, unminedTxDetail := createUnminedTxDetail(t)
+ unminedNoFeeDetails, unminedNoFeeTxDetail := createUnminedTxDetail(t)
+ unminedNoFeeDetails.Debits = nil
+ unminedNoFeeTxDetail.Fee = 0
+ unminedNoFeeTxDetail.FeeRate = btcunit.ZeroSatPerVByte
+ unminedNoFeeTxDetail.Value = unminedNoFeeDetails.Credits[0].Amount +
+ unminedNoFeeDetails.Credits[1].Amount
+ unminedNoFeeTxDetail.PrevOuts[0].IsOurs = false
+
+ testCases := []struct {
+ name string
+ details *wtxmgr.TxDetails
+ expectedTxDetail *TxDetail
+ }{
+ {
+ name: "mined tx",
+ details: minedDetails,
+ expectedTxDetail: minedTxDetail,
+ },
+ {
+ name: "unmined tx",
+ details: unminedDetails,
+ expectedTxDetail: unminedTxDetail,
+ },
+ {
+ name: "unmined tx no fee",
+ details: unminedNoFeeDetails,
+ expectedTxDetail: unminedNoFeeTxDetail,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Create a test wallet with mocks.
+ w, _ := createStartedWalletWithMocks(t)
+ currentHeight := int32(1)
+
+ // Act: Build the TxDetail.
+ result := w.buildTxDetail(tc.details, currentHeight)
+
+ // Assert: Check that the correct details are returned.
+ require.Equal(t, tc.expectedTxDetail, result)
+ })
+ }
+}
+
+// TestGetTxSuccess tests the GetTx method of the wallet for success scenarios.
+func TestGetTxSuccess(t *testing.T) {
+ t.Parallel()
+
+ minedDetails, minedTxDetail := createMinedTxDetail(t)
+ unminedDetails, unminedTxDetail := createUnminedTxDetail(t)
+
+ testCases := []struct {
+ name string
+ mockDetails *wtxmgr.TxDetails
+ expectedTxDetail *TxDetail
+ }{
+ {
+ name: "mined tx",
+ mockDetails: minedDetails,
+ expectedTxDetail: minedTxDetail,
+ },
+ {
+ name: "unmined tx",
+ mockDetails: unminedDetails,
+ expectedTxDetail: unminedTxDetail,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ // Arrange: Create a test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+ // SyncedTo is mocked in createStartedWalletWithMocks (height 1).
+
+ mocks.txStore.On("TxDetails", mock.Anything, TstTxHash).
+ Return(tc.mockDetails, nil).Once()
+
+ // Act: Get the transaction.
+ details, err := w.GetTx(t.Context(), *TstTxHash)
+
+ // Assert: Check that the correct details are returned.
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedTxDetail, details)
+ })
+ }
+}
+
+// TestGetTxNotFound tests that GetTx returns the correct error when a
+// transaction is not found.
+func TestGetTxNotFound(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet with mocks and mock the TxDetails call
+ // to return nil, simulating a non-existing tx.
+ w, mocks := createStartedWalletWithMocks(t)
+ mocks.txStore.On("TxDetails", mock.Anything, TstTxHash).
+ Return(nil, nil).Once()
+
+ // Act: Attempt to get the transaction.
+ _, err := w.GetTx(t.Context(), *TstTxHash)
+
+ // Assert that the correct error is returned.
+ require.ErrorIs(t, err, ErrTxNotFound)
+}
+
+// TestListTxnsSuccess tests the ListTxns method of the wallet.
+func TestListTxnsSuccess(t *testing.T) {
+ t.Parallel()
+
+ // Arrange: Create a test wallet with mocks and a mock tx record.
+ w, mocks := createStartedWalletWithMocks(t)
+ _, expectedTxDetail := createMinedTxDetail(t)
+
+ // SyncedTo is mocked in createStartedWalletWithMocks (height 1).
+
+ // Set up the mock for the tx store. We use .Run to execute the
+ // callback function that's passed in as an argument to the mock.
+ mocks.txStore.On("RangeTransactions",
+ mock.Anything, mock.Anything, mock.Anything, mock.Anything,
+ ).Run(func(args mock.Arguments) {
+ // Get the callback function from the arguments.
+ f, ok := args.Get(3).(func([]wtxmgr.TxDetails) (bool, error))
+ require.True(t, ok)
+
+ // Create the mock details to pass to the callback.
+ minedDetails, _ := createMinedTxDetail(t)
+ details := []wtxmgr.TxDetails{*minedDetails}
+
+ // Call the callback.
+ _, err := f(details)
+ require.NoError(t, err)
+ }).Return(nil).Once()
+
+ // Act: List txns.
+ details, err := w.ListTxns(t.Context(), 0, 1000)
+
+ // Assert: Check that the correct details are returned.
+ require.NoError(t, err)
+ require.Len(t, details, 1)
+ require.Equal(t, expectedTxDetail, details[0])
+}
+
+// createUnminedTxDetail creates a test transaction that sends funds from the
+// wallet to two of its own addresses. The transaction is unmined and has no
+// confirmations.
+//
+// The transaction details are as follows:
+// - The transaction has one input, which is owned by the wallet.
+// - The transaction has two outputs, both of which are owned by the wallet.
+// - The total value of the outputs (totalCredits) is the sum of the two
+// output amounts.
+// - The total value of the inputs (debitAmt) is the sum of the credits plus
+// a fee.
+// - The net value of the transaction from the wallet's perspective (Value) is
+// totalCredits - debitAmt, which is equal to -fee.
+func createUnminedTxDetail(t *testing.T) (*wtxmgr.TxDetails, *TxDetail) {
+ t.Helper()
+
+ // Create a deterministic timestamp for the test tx record.
+ txTime := time.Unix(1616161616, 0)
+ rec, err := wtxmgr.NewTxRecord(TstSerializedTx, txTime)
+ require.NoError(t, err)
+
+ // Deserialize the test transaction to avoid using a global variable.
+ tx, err := btcutil.NewTxFromBytes(TstSerializedTx)
+ require.NoError(t, err)
+
+ msgTx := tx.MsgTx()
+
+ // The credits are the sum of all outputs of the test tx.
+ var totalCredits btcutil.Amount
+ for _, txOut := range msgTx.TxOut {
+ totalCredits += btcutil.Amount(txOut.Value)
+ }
+
+ // The debit amount is the total credit amount plus a fee.
+ fee := btcutil.Amount(1000)
+ debitAmt := totalCredits + fee
+ testLabel := "test"
+
+ out0Amt := btcutil.Amount(msgTx.TxOut[0].Value)
+ out1Amt := btcutil.Amount(msgTx.TxOut[1].Value)
+
+ // Create a fully populated TxDetail for the unmined case.
+ unminedDetails := &wtxmgr.TxDetails{
+ TxRecord: *rec,
+ Block: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: -1,
+ },
+ },
+ Credits: []wtxmgr.CreditRecord{
+ {
+ Index: 0,
+ Amount: out0Amt,
+ },
+ {
+ Index: 1,
+ Amount: out1Amt,
+ },
+ },
+ Debits: []wtxmgr.DebitRecord{
+ {
+ Index: 0,
+ Amount: debitAmt,
+ },
+ },
+ Label: testLabel,
+ }
+
+ // Manually build the expected outputs for the test tx.
+ expectedOutputs := make([]Output, len(msgTx.TxOut))
+ for i, txOut := range msgTx.TxOut {
+ _, addrs, _, err := txscript.ExtractPkScriptAddrs(
+ txOut.PkScript, &chaincfg.RegressionNetParams,
+ )
+ require.NoError(t, err)
+
+ expectedOutputs[i] = Output{
+ Type: 2,
+ Addresses: addrs,
+ PkScript: txOut.PkScript,
+ Index: i,
+ Amount: btcutil.Amount(txOut.Value),
+ IsOurs: true,
+ }
+ }
+
+ // Manually build the expected previous outputs for the test tx.
+ expectedPrevOuts := []PrevOut{
+ {
+ OutPoint: msgTx.TxIn[0].PreviousOutPoint,
+ IsOurs: true,
+ },
+ }
+
+ // Define the expected TxDetail for the unmined case.
+ weight := btcunit.NewWeightUnit(uint64(blockchain.GetTransactionWeight(
+ btcutil.NewTx(&rec.MsgTx),
+ )))
+ unminedTxDetail := &TxDetail{
+ Hash: *TstTxHash,
+ RawTx: TstSerializedTx,
+ Label: testLabel,
+ Value: totalCredits - debitAmt,
+ Fee: fee,
+ FeeRate: btcunit.CalcSatPerVByte(fee, weight.ToVB()),
+ Confirmations: 0,
+ Weight: weight,
+ ReceivedTime: txTime,
+ Outputs: expectedOutputs,
+ PrevOuts: expectedPrevOuts,
+ }
+
+ return unminedDetails, unminedTxDetail
+}
+
+// createMinedTxDetail builds on createUnminedTxDetail to create a mined
+// transaction. The transaction has one confirmation.
+func createMinedTxDetail(t *testing.T) (*wtxmgr.TxDetails, *TxDetail) {
+ t.Helper()
+
+ minedDetails, minedTxDetail := createUnminedTxDetail(t)
+ // Set height to 1 to match the default SyncedTo mock (height 1).
+ minedDetails.Block.Height = 1
+ minedDetails.Block.Time = time.Unix(1616161617, 0)
+ minedTxDetail.Confirmations = 1
+ minedTxDetail.Block = &BlockDetails{
+ Height: 1,
+ Timestamp: minedDetails.Block.Time.Unix(),
+ }
+
+ return minedDetails, minedTxDetail
+}
diff --git a/wallet/tx_writer.go b/wallet/tx_writer.go
new file mode 100644
index 0000000000..d9abfd7f37
--- /dev/null
+++ b/wallet/tx_writer.go
@@ -0,0 +1,60 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcwallet/walletdb"
+)
+
+// TxWriter provides an interface for updating wallet txns.
+type TxWriter interface {
+ // LabelTx adds a label to a tx. If a label already exists, it will be
+ // overwritten.
+ LabelTx(ctx context.Context, hash chainhash.Hash, label string) error
+}
+
+// A compile time check to ensure that Wallet implements the interface.
+var _ TxWriter = (*Wallet)(nil)
+
+// LabelTx adds a label to a tx. If a label already exists, it will be
+// overwritten.
+func (w *Wallet) LabelTx(_ context.Context,
+ hash chainhash.Hash, label string) error {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return err
+ }
+
+ err = walletdb.Update(w.cfg.DB, func(dbtx walletdb.ReadWriteTx) error {
+ txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ // Check that the transaction is known to the wallet.
+ details, err := w.txStore.TxDetails(txmgrNs, &hash)
+ if err != nil {
+ return fmt.Errorf("failed to get tx details: %w", err)
+ }
+
+ if details == nil {
+ return ErrTxNotFound
+ }
+
+ err = w.txStore.PutTxLabel(txmgrNs, hash, label)
+ if err != nil {
+ return fmt.Errorf("failed to put tx label: %w", err)
+ }
+
+ return nil
+ })
+ if err != nil {
+ return fmt.Errorf("failed to update wallet db: %w", err)
+ }
+
+ return nil
+}
diff --git a/wallet/tx_writer_benchmark_test.go b/wallet/tx_writer_benchmark_test.go
new file mode 100644
index 0000000000..0069f6ebea
--- /dev/null
+++ b/wallet/tx_writer_benchmark_test.go
@@ -0,0 +1,302 @@
+package wallet
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkLabelTxAPI benchmarks LabelTx API against its deprecated variant
+// LabelTransaction (when overwrite is true) using identical test data.
+// Test names use the wallet size metric to group API comparisons for benchstat
+// analysis.
+//
+// Time Complexity Analysis:
+// Both APIs are dominated by a single key-value write (PutTxLabel) operation,
+// which is typically O(log n) on a B-tree where n is the number of transactions
+// (keys). Since the new API eliminates an initial read operation
+// (FetchTxLabel), it should show better performance.
+func BenchmarkLabelTxAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth and addressGrowth use constantGrowth since
+ // ccount/address count doesn't directly affect the API's time
+ // complexity for a single tx label.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth to test O(log n) B-tree write
+ // scaling. As database size grows linearly, write time should
+ // grow logarithmically.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ // txIOGrowth uses constantGrowth since I/O count doesn't affect
+ // the LabelTx API's time complexity.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ // testLabel is the string used for labeling the transaction.
+ testLabel = "bench_label"
+ )
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d", txPoolGrowthPadding,
+ txPoolGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+
+ // Get a transaction hash from the middle of the dataset
+ // for representative benchmarking.
+ medianIndex := len(bw.unconfirmedTxs) / 2
+ testTxHash := bw.unconfirmedTxs[medianIndex].TxHash()
+
+ // Initial write of the label to ensure both APIs are
+ // testing the _overwrite_ case, which aligns the
+ // functional behavior of LabelTransaction
+ // (overwrite=true) with LabelTx.
+ err := bw.LabelTx(b.Context(), testTxHash, testLabel)
+ require.NoError(b, err)
+
+ b.Run("0-Before", func(b *testing.B) {
+ const overwrite = true
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ err = bw.LabelTransaction(
+ testTxHash, testLabel,
+ overwrite,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for i := 0; b.Loop(); i++ {
+ err = bw.LabelTx(
+ b.Context(), testTxHash,
+ testLabel,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ // Verification: Ensure the label was successfully
+ // written and is identical after both benchmarks. Since
+ // we are testing the overwrite case repeatedly, we only
+ // need to check the final state. That way we are sure
+ // that we are benchmarking the thing right.
+ assertLabelTxAPIsEquivalent(
+ b, bw.Wallet, testTxHash, testLabel,
+ )
+ })
+ }
+}
+
+// BenchmarkLabelTxAPIConcurrently benchmarks LabelTx API and its deprecated
+// variant LabelTransaction (when overwrite is true) using identical test data
+// under concurrent load. Test names use the wallet size metric to group API
+// comparisons for benchstat analysis.
+//
+// Time Complexity Analysis:
+// Under concurrent load, the API maintains the same per-transaction complexity
+// of O(log n) as the sequential benchmark, where n is the number of
+// transactions in the database (B-tree write operation). Since the new API
+// eliminates an initial read operation (FetchTxLabel), it should show better
+// performance even under concurrent load.
+//
+// This benchmark stresses the lock contention characteristics during database
+// writes, demonstrating scalability under concurrent write operations.
+func BenchmarkLabelTxAPIConcurrently(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // endGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ endGrowthIteration = 5
+ )
+
+ var (
+ // accountGrowth and addressGrowth use constantGrowth since
+ // account/address count doesn't directly affect the API's time
+ // complexity for a single tx label.
+ accountGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ // txPoolGrowth uses linearGrowth to test O(log n) B-tree write
+ // scaling. As database size grows linearly, write time should
+ // grow logarithmically.
+ txPoolGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ linearGrowth,
+ )
+
+ // txIOGrowth uses constantGrowth since I/O count doesn't affect
+ // the LabelTx API's time complexity.
+ txIOGrowth = mapRange(
+ startGrowthIteration, endGrowthIteration,
+ constantGrowth,
+ )
+
+ txPoolGrowthPadding = decimalWidth(
+ txPoolGrowth[len(txPoolGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ // testLabel is the string used for labeling the transaction.
+ testLabel = "bench_label"
+ )
+
+ for i := 0; i <= endGrowthIteration; i++ {
+ name := fmt.Sprintf("TxPool-%0*d", txPoolGrowthPadding,
+ txPoolGrowth[i])
+
+ b.Run(name, func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: txPoolGrowth[i],
+ numTxInputs: txIOGrowth[i],
+ numTxOutputs: txIOGrowth[i],
+ },
+ )
+
+ // Get a transaction hash from the middle of the dataset
+ // for representative benchmarking.
+ medianIndex := len(bw.unconfirmedTxs) / 2
+ testTxHash := bw.unconfirmedTxs[medianIndex].TxHash()
+
+ // Initial write of the label to ensure both APIs are
+ // testing the _overwrite_ case, which aligns the
+ // functional behavior of LabelTransaction
+ // (overwrite=true) with LabelTx.
+ err := bw.LabelTx(b.Context(), testTxHash, testLabel)
+ require.NoError(b, err)
+
+ b.Run("0-Before", func(b *testing.B) {
+ const overwrite = true
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ err := bw.LabelTransaction(
+ testTxHash, testLabel,
+ overwrite,
+ )
+ require.NoError(b, err)
+ }
+ })
+ })
+
+ b.Run("1-After", func(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ err := bw.LabelTx(
+ b.Context(), testTxHash,
+ testLabel,
+ )
+ require.NoError(b, err)
+ }
+ })
+ })
+
+ // Verification: Ensure the label was successfully
+ // written and is identical after both benchmarks. Since
+ // we are testing the overwrite case repeatedly, we only
+ // need to check the final state. That way we are sure
+ // that we are benchmarking the thing right.
+ assertLabelTxAPIsEquivalent(
+ b, bw.Wallet, testTxHash, testLabel,
+ )
+ })
+ }
+}
+
+// assertLabelTxAPIsEquivalent verifies that the transaction label is correctly
+// set after the benchmark run.
+func assertLabelTxAPIsEquivalent(b *testing.B, w *Wallet, hash chainhash.Hash,
+ expectedLabel string) {
+
+ b.Helper()
+
+ var actualLabel string
+
+ err := walletdb.View(w.cfg.DB, func(dbtx walletdb.ReadTx) error {
+ txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
+
+ var err error
+
+ actualLabel, err = w.txStore.FetchTxLabel(txmgrNs, hash)
+
+ return err
+ })
+
+ require.NoError(b, err)
+ require.Equal(
+ b, expectedLabel, actualLabel,
+ "LabelTx and LabelTransaction should result in the same label "+
+ "value",
+ )
+}
diff --git a/wallet/tx_writer_test.go b/wallet/tx_writer_test.go
new file mode 100644
index 0000000000..8545a82ae8
--- /dev/null
+++ b/wallet/tx_writer_test.go
@@ -0,0 +1,59 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "testing"
+
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestLabelTxSuccess tests that we can successfully label a transaction.
+func TestLabelTxSuccess(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock the TxDetails call to simulate a known transaction.
+ // We return a non-nil TxDetails to pass the check.
+ mocks.txStore.On("TxDetails", mock.Anything, TstTxHash).
+ Return(&wtxmgr.TxDetails{}, nil).Once()
+
+ // Arrange: Mock the PutTxLabel call. We expect it to be called with
+ // the new label.
+ newLabel := "new label"
+ mocks.txStore.On("PutTxLabel", mock.Anything, *TstTxHash, newLabel).
+ Return(nil).Once()
+
+ // Act: Call the LabelTx function.
+ err := w.LabelTx(t.Context(), *TstTxHash, newLabel)
+
+ // Assert: Check that there was no error and that the mocks were called
+ // as expected.
+ require.NoError(t, err)
+ mocks.txStore.AssertExpectations(t)
+}
+
+// TestLabelTxNotFound tests that we get an error when we try to label a tx
+// that is not known to the wallet.
+func TestLabelTxNotFound(t *testing.T) {
+ t.Parallel()
+
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Arrange: Mock the TxDetails call to return nil, simulating a tx
+ // that is not known to the wallet.
+ mocks.txStore.On("TxDetails", mock.Anything, TstTxHash).
+ Return(nil, nil).Once()
+
+ // Act: Attempt to label a tx that is not known to the wallet.
+ err := w.LabelTx(t.Context(), *TstTxHash, "some label")
+
+ // Assert: Check that the correct error is returned.
+ require.ErrorIs(t, err, ErrTxNotFound)
+ mocks.txStore.AssertExpectations(t)
+}
diff --git a/wallet/txsizes/size.go b/wallet/txsizes/size.go
index 4bf8e1fa4a..3fa44b36ad 100644
--- a/wallet/txsizes/size.go
+++ b/wallet/txsizes/size.go
@@ -246,8 +246,8 @@ func EstimateVirtualSize(numP2PKHIns, numP2TRIns, numP2WPKHIns, numNestedP2WPKHI
// GetMinInputVirtualSize returns the minimum number of vbytes that this input
// adds to a transaction.
-func GetMinInputVirtualSize(pkScript []byte) int {
- var baseSize, witnessWeight int
+func GetMinInputVirtualSize(pkScript []byte) uint64 {
+ var baseSize, witnessWeight uint64
switch {
// If this is a p2sh output, we assume this is a
// nested P2WKH.
@@ -267,7 +267,6 @@ func GetMinInputVirtualSize(pkScript []byte) int {
baseSize = RedeemP2PKHInputSize
}
- return baseSize +
- (witnessWeight+blockchain.WitnessScaleFactor-1)/
- blockchain.WitnessScaleFactor
+ return baseSize + (witnessWeight+blockchain.WitnessScaleFactor-1)/
+ blockchain.WitnessScaleFactor
}
diff --git a/wallet/unstable.go b/wallet/unstable.go
index f75667b7ed..8f455b6ce8 100644
--- a/wallet/unstable.go
+++ b/wallet/unstable.go
@@ -28,7 +28,8 @@ func (u unstableAPI) TxDetails(txHash *chainhash.Hash) (*wtxmgr.TxDetails, error
err := walletdb.View(u.w.db, func(dbtx walletdb.ReadTx) error {
txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
var err error
- details, err = u.w.TxStore.TxDetails(txmgrNs, txHash)
+
+ details, err = u.w.txStore.TxDetails(txmgrNs, txHash)
return err
})
return details, err
@@ -39,6 +40,6 @@ func (u unstableAPI) TxDetails(txHash *chainhash.Hash) (*wtxmgr.TxDetails, error
func (u unstableAPI) RangeTransactions(begin, end int32, f func([]wtxmgr.TxDetails) (bool, error)) error {
return walletdb.View(u.w.db, func(dbtx walletdb.ReadTx) error {
txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
- return u.w.TxStore.RangeTransactions(txmgrNs, begin, end, f)
+ return u.w.txStore.RangeTransactions(txmgrNs, begin, end, f)
})
}
diff --git a/wallet/utxo_manager.go b/wallet/utxo_manager.go
new file mode 100644
index 0000000000..9fea82e751
--- /dev/null
+++ b/wallet/utxo_manager.go
@@ -0,0 +1,552 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+// Package wallet provides a bitcoin wallet implementation that is centered
+// around the concept of a UtxoManager, which is responsible for managing the
+// wallet's UTXO set.
+//
+// TODO(yy): bring wrapcheck back when implementing the `Store` interface.
+//
+//nolint:wrapcheck
+package wallet
+
+import (
+ "context"
+ "sort"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/walletdb"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+)
+
+// Utxo provides a detailed overview of an unspent transaction output.
+type Utxo struct {
+ // OutPoint is the transaction output identifier.
+ OutPoint wire.OutPoint
+
+ // Amount is the value of the output.
+ Amount btcutil.Amount
+
+ // PkScript is the public key script for the output.
+ PkScript []byte
+
+ // Confirmations is the number of confirmations the output has.
+ Confirmations int32
+
+ // Spendable indicates whether the output is considered spendable.
+ Spendable bool
+
+ // Address is the address associated with the output.
+ Address address.Address
+
+ // Account is the name of the account that owns the output.
+ Account string
+
+ // AddressType is the type of the address.
+ AddressType waddrmgr.AddressType
+
+ // Locked indicates whether the output is locked.
+ Locked bool
+}
+
+// UtxoQuery holds the set of options for a ListUnspent query.
+type UtxoQuery struct {
+ // Account specifies the account to query UTXOs for. If empty,
+ // UTXOs from all accounts are returned.
+ Account string
+
+ // MinConfs is the minimum number of confirmations a UTXO must have.
+ MinConfs int32
+
+ // MaxConfs is the maximum number of confirmations a UTXO can have.
+ MaxConfs int32
+}
+
+// UtxoManager provides an interface for querying and managing the wallet's
+// UTXO set.
+type UtxoManager interface {
+ // ListUnspent returns a slice of all unspent transaction outputs that
+ // match the query. The returned UTXOs are sorted by amount in
+ // ascending order.
+ ListUnspent(ctx context.Context, query UtxoQuery) ([]*Utxo, error)
+
+ // GetUtxo returns the output information for a given outpoint.
+ GetUtxo(ctx context.Context, prevOut wire.OutPoint) (*Utxo, error)
+
+ // LeaseOutput locks an output for a given duration, preventing it from
+ // being used in transactions.
+ LeaseOutput(ctx context.Context, id wtxmgr.LockID,
+ op wire.OutPoint, duration time.Duration) (time.Time, error)
+
+ // ReleaseOutput unlocks a previously leased output, making it available
+ // for use.
+ ReleaseOutput(ctx context.Context, id wtxmgr.LockID,
+ op wire.OutPoint) error
+
+ // ListLeasedOutputs returns a list of all currently leased outputs.
+ ListLeasedOutputs(ctx context.Context) ([]*wtxmgr.LockedOutput, error)
+}
+
+// ListUnspent returns a slice of unspent transaction outputs that match the
+// query.
+//
+// This method provides a comprehensive view of the wallet's UTXO set, allowing
+// for filtering by account and confirmation status. The results are enriched
+// with detailed information about each UTXO, such as its address, account,
+// and spendability.
+//
+// How it works:
+// The method performs a full scan of all UTXOs in the wallet's transaction
+// store (`wtxmgr`). For each UTXO, it applies the specified filters (account,
+// confirmations). If a UTXO matches, the method then performs an additional
+// lookup in the address manager (`waddrmgr`) to enrich the UTXO data with
+// details like the owning account name, address type, and spendability. This
+// process of fetching a list and then performing a lookup for each item is
+// known as the "N+1 query problem" and is a known inefficiency (see TODO).
+//
+// Logical Steps:
+// 1. Initiate a single, read-only database transaction to ensure a
+// consistent view of the data.
+// 2. Fetch all unspent transaction outputs from the `wtxmgr` namespace.
+// 3. Sort the outputs in ascending order of value. This is a convention to
+// make the list more predictable and potentially useful for coin
+// selection algorithms that prefer larger UTXOs.
+// 4. Iterate through each UTXO:
+// a. Calculate its current confirmation status based on the wallet's
+// synced block height.
+// b. Apply the `MinConfs` and `MaxConfs` filters from the query.
+// c. Extract the address from the UTXO's public key script. For
+// multi-address scripts, the first address is used.
+// d. Call `waddrmgr.AddressDetails` to get the spendability status,
+// account name, and address type in a single, efficient lookup.
+// e. Apply the `Account` filter from the query.
+// f. If all filters pass, construct the final `Utxo` struct with all
+// the combined data.
+// 5. Append the `Utxo` to the result slice.
+// 6. After iterating through all UTXOs, return the final slice.
+//
+// Database Actions:
+// - This method performs a single read-only database transaction
+// (`walletdb.View`).
+// - It reads from both the `wtxmgr` (for UTXOs) and `waddrmgr` (for
+// address details) namespaces.
+//
+// Time Complexity:
+// - The complexity is O(U * A_l), where U is the total number of unspent
+// transaction outputs in the wallet and A_l is the average cost of the
+// address and account lookups (`AddressDetails`). This is due to the N+1
+// query problem where each UTXO requires additional lookups.
+//
+// TODO(yy): The current implementation of ListUnspent fetches all UTXOs
+// from the database and then filters them in memory. This is inefficient for
+// wallets with a large number of UTXOs. The upcoming SQL schema redesign should
+// address the following issues:
+//
+// 1. **N+1 Query Problem:** The function iterates through all unspent outputs
+// and performs separate database lookups for each one to retrieve its full
+// details. The database schema should be denormalized to include this data
+// directly in the `unspent` value, which would turn the N+1 query into a
+// single, efficient bucket scan.
+//
+// 2. **Lack of Pagination:** The function loads all results into a single
+// in-memory slice, which can be memory-intensive for wallets with a large
+// UTXO set. A more scalable approach would use an iterator pattern.
+//
+// NOTE: This is part of the UtxoManager interface implementation.
+func (w *Wallet) ListUnspent(_ context.Context,
+ query UtxoQuery) ([]*Utxo, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ log.Debugf("ListUnspent using query: %v", query)
+
+ syncBlock := w.addrStore.SyncedTo()
+ currentHeight := syncBlock.Height
+
+ var utxos []*Utxo
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // First, fetch all unspent transaction outputs from the UTXO
+ // set.
+ unspent, err := w.txStore.UnspentOutputs(txmgrNs)
+ if err != nil {
+ return err
+ }
+
+ // Iterate through each UTXO to apply filters and enrich it with
+ // address-specific details.
+ for _, output := range unspent {
+ utxo := w.processUnspentOutput(
+ addrmgrNs, output, currentHeight, query,
+ )
+ if utxo != nil {
+ utxos = append(utxos, utxo)
+ }
+ }
+
+ return nil
+ })
+
+ // Sort the outputs in ascending order of value. This is a convention
+ // to make the list more predictable and potentially useful for coin
+ // selection algorithms that prefer smaller UTXOs.
+ sort.Slice(utxos, func(i, j int) bool {
+ return utxos[i].Amount < utxos[j].Amount
+ })
+
+ return utxos, err
+}
+
+// processUnspentOutput processes a single unspent output, applying filters and
+// enriching it with address details. Returns nil if the output should be
+// skipped.
+func (w *Wallet) processUnspentOutput(addrmgrNs walletdb.ReadBucket,
+ output wtxmgr.Credit, currentHeight int32, query UtxoQuery) *Utxo {
+
+ confs := calcConf(output.Height, currentHeight)
+
+ log.Tracef("Checking utxo[%v]: current height=%v, "+
+ "confirm height=%v, conf=%v", output.OutPoint,
+ currentHeight, output.Height, confs)
+
+ // Apply the MinConfs and MaxConfs filters from the query.
+ if confs < query.MinConfs || confs > query.MaxConfs {
+ return nil
+ }
+
+ // Extract the address from the UTXO's public key script.
+ // For multi-address scripts, the first address is used.
+ addr := extractAddrFromPKScript(
+ output.PkScript, w.cfg.ChainParams,
+ )
+ if addr == nil {
+ return nil
+ }
+
+ // Get all the required address-related details.
+ //
+ // NOTE: This lookup is the source of the N+1 query problem.
+ spendable, account, addrType := w.addrStore.AddressDetails(
+ addrmgrNs, addr,
+ )
+
+ log.Debugf("Found address: %s from account: %s",
+ addr.String(), account)
+
+ // Apply the Account filter from the query.
+ if query.Account != "" && account != query.Account {
+ return nil
+ }
+
+ // A UTXO is also unspendable if it is an immature coinbase output.
+ if output.FromCoinBase {
+ maturity := w.cfg.ChainParams.CoinbaseMaturity
+ if confs < int32(maturity) {
+ spendable = false
+ }
+ }
+
+ // TODO(yy): This should be a column in the new utxo SQL table. Note
+ // that currently UnspentOutputs only returns unlocked outputs, so this
+ // field will always be false. This will be fixed in the upcoming
+ // sqlization PRs.
+ locked := output.Locked
+
+ // If all filters pass, construct the final Utxo struct with all the
+ // combined data.
+ return &Utxo{
+ OutPoint: output.OutPoint,
+ Amount: output.Amount,
+ PkScript: output.PkScript,
+ Confirmations: confs,
+ Spendable: spendable,
+ Address: addr,
+ Account: account,
+ AddressType: addrType,
+ Locked: locked,
+ }
+}
+
+// GetUtxo returns the output information for a given outpoint.
+//
+// This method provides a detailed view of a single UTXO, identified by its
+// outpoint. The result is enriched with detailed information about the UTXO,
+// such as its address, account, and spendability.
+//
+// How it works:
+// The method performs a direct lookup of the UTXO in the wallet's transaction
+// store (`wtxmgr`). If the UTXO is found, it then performs an additional
+// lookup in the address manager (`waddrmgr`) to enrich the UTXO data with
+// details like the owning account name, address type, and spendability.
+//
+// Logical Steps:
+// 1. Initiate a single, read-only database transaction to ensure a
+// consistent view of the data.
+// 2. Fetch the unspent transaction output from the `wtxmgr` namespace using
+// the provided outpoint.
+// 3. If the UTXO is not found, return a `wtxmgr.ErrUtxoNotFound` error.
+// 4. Calculate its current confirmation status based on the wallet's
+// synced block height.
+// 5. Extract the address from the UTXO's public key script. For
+// multi-address scripts, the first address is used.
+// 6. Call `waddrmgr.AddressDetails` to get the spendability status,
+// account name, and address type in a single, efficient lookup.
+// 7. Construct the final `Utxo` struct with all the combined data.
+// 8. Return the final `Utxo` struct.
+//
+// Database Actions:
+// - This method performs a single read-only database transaction
+// (`walletdb.View`).
+// - It reads from both the `wtxmgr` (for the UTXO) and `waddrmgr` (for
+// address details) namespaces.
+//
+// Time Complexity:
+// - The complexity is O(A_l), where A_l is the average cost of the
+// address and account lookups (`AddressDetails`).
+//
+// TODO(yy): The current implementation of GetUtxo performs separate database
+// lookups for the UTXO and its details. The upcoming SQL schema redesign should
+// address this issue by denormalizing the data, which would turn the multiple
+// lookups into a single, efficient query.
+//
+// NOTE: This is part of the UtxoManager interface implementation.
+func (w *Wallet) GetUtxo(_ context.Context,
+ prevOut wire.OutPoint) (*Utxo, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ // Calculate the current confirmation status based on the wallet's
+ // synced block height.
+ syncBlock := w.addrStore.SyncedTo()
+ currentHeight := syncBlock.Height
+
+ var utxo *Utxo
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+
+ // First, fetch the unspent transaction output from the UTXO
+ // set.
+ output, err := w.txStore.GetUtxo(txmgrNs, prevOut)
+ if err != nil {
+ return err
+ }
+
+ // If the output is not found, return an error.
+ if output == nil {
+ return wtxmgr.ErrUtxoNotFound
+ }
+
+ confs := calcConf(output.Height, currentHeight)
+
+ // Extract the address from the UTXO's public key script.
+ // For multi-address scripts, the first address is used.
+ addr := extractAddrFromPKScript(
+ output.PkScript, w.cfg.ChainParams,
+ )
+ if addr == nil {
+ return wtxmgr.ErrUtxoNotFound
+ }
+
+ // In a single lookup, get all the required
+ // address-related details: spendability, account name,
+ // and address type. This avoids the N+1 query problem.
+ spendable, account, addrType := w.addrStore.AddressDetails(
+ addrmgrNs, addr,
+ )
+
+ // If all filters pass, construct the final Utxo struct
+ // with all the combined data.
+ utxo = &Utxo{
+ OutPoint: output.OutPoint,
+ Amount: output.Amount,
+ PkScript: output.PkScript,
+ Confirmations: confs,
+ Spendable: spendable,
+ Address: addr,
+ Account: account,
+ AddressType: addrType,
+ Locked: output.Locked,
+ }
+
+ return nil
+ })
+
+ return utxo, err
+}
+
+// LeaseOutput locks an output for a given duration, preventing it from being
+// used in transactions.
+//
+// This method allows a caller to reserve a specific UTXO for a certain period,
+// making it unavailable for other operations like coin selection. This is
+// useful in scenarios where a transaction is being built and its inputs need to
+// be protected from being used by other concurrent operations.
+//
+// How it works:
+// The method delegates the locking operation to the underlying transaction
+// store (`wtxmgr`), which maintains a record of all leased outputs. The lease
+// is identified by a unique `LockID` and has a specific `duration`.
+//
+// Logical Steps:
+// 1. Initiate a read-write database transaction.
+// 2. Call the `wtxmgr.LockOutput` method with the provided `LockID`,
+// outpoint, and `duration`.
+// 3. The `wtxmgr` checks if the output is known and not already locked by a
+// different ID.
+// 4. If the checks pass, it records the lock with an expiration time.
+// 5. The expiration time is returned to the caller.
+//
+// Database Actions:
+// - This method performs a single read-write database transaction
+// (`walletdb.Update`).
+// - It writes to the `wtxmgr` namespace to record the output lock.
+//
+// Time Complexity:
+// - The complexity is O(1) as it involves a direct lookup and write in the
+// database.
+//
+// TODO(yy): The current `wtxmgr.LockOutput` implementation does not check if
+// the output is already spent by an unmined transaction. This could lead to a
+// scenario where a spent output is leased. The implementation should be
+// improved to perform this check.
+//
+// NOTE: This is part of the UtxoManager interface implementation.
+func (w *Wallet) LeaseOutput(_ context.Context, id wtxmgr.LockID,
+ op wire.OutPoint, duration time.Duration) (time.Time, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return time.Time{}, err
+ }
+
+ var expiration time.Time
+
+ err = walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ expiration, err = w.txStore.LockOutput(
+ txmgrNs, id, op, duration,
+ )
+
+ return err
+ })
+
+ return expiration, err
+}
+
+// ReleaseOutput unlocks a previously leased output, making it available for
+// use.
+//
+// This method allows a caller to manually release a lock on a UTXO before its
+// expiration time. This is useful when a transaction-building process is
+// aborted and the reserved inputs need to be returned to the pool of available
+// UTXOs.
+//
+// How it works:
+// The method delegates the unlocking operation to the underlying transaction
+// store (`wtxmgr`), which removes the lock record for the specified outpoint.
+//
+// Logical Steps:
+// 1. Initiate a read-write database transaction.
+// 2. Call the `wtxmgr.UnlockOutput` method with the provided `LockID` and
+// outpoint.
+// 3. The `wtxmgr` verifies that the output is indeed locked by the same
+// `LockID` before removing the lock.
+//
+// Database Actions:
+// - This method performs a single read-write database transaction
+// (`walletdb.Update`).
+// - It deletes from the `wtxmgr` namespace to remove the output lock.
+//
+// Time Complexity:
+// - The complexity is O(1) as it involves a direct lookup and delete in the
+// database.
+//
+// TODO(yy): The current `wtxmgr.UnlockOutput` implementation does not validate
+// that the `LockID` matches the one that currently holds the lock. This could
+// allow any caller to unlock an output, which could be a potential security
+// risk in a multi-user environment. The implementation should be improved to
+// perform this check.
+//
+// NOTE: This is part of the UtxoManager interface implementation.
+func (w *Wallet) ReleaseOutput(_ context.Context, id wtxmgr.LockID,
+ op wire.OutPoint) error {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return err
+ }
+
+ return walletdb.Update(w.cfg.DB, func(tx walletdb.ReadWriteTx) error {
+ txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+ return w.txStore.UnlockOutput(txmgrNs, id, op)
+ })
+}
+
+// ListLeasedOutputs returns a list of all currently leased outputs.
+//
+// This method provides a way to inspect which UTXOs are currently locked and
+// when their leases expire. This can be useful for debugging and for managing
+// long-lived locks.
+//
+// How it works:
+// The method delegates the listing operation to the underlying transaction
+// store (`wtxmgr`), which scans its record of all leased outputs.
+//
+// Logical Steps:
+// 1. Initiate a read-only database transaction.
+// 2. Call the `wtxmgr.ListLeasedOutputs` method.
+// 3. The `wtxmgr` iterates through all the recorded locks and returns them
+// as a slice.
+//
+// Database Actions:
+// - This method performs a single read-only database transaction
+// (`walletdb.View`).
+// - It reads from the `wtxmgr` namespace to get the list of leased
+// outputs.
+//
+// Time Complexity:
+// - The complexity is O(L), where L is the number of leased outputs, as it
+// involves a full scan of the leased outputs bucket.
+//
+// TODO(yy): The current `wtxmgr.ListLeasedOutputs` implementation returns a
+// struct from the `wtxmgr` package. This is a leaky abstraction. The method
+// should return a struct defined in the `wallet` package to maintain a clean
+// separation of concerns.
+//
+// NOTE: This is part of the UtxoManager interface implementation.
+func (w *Wallet) ListLeasedOutputs(
+ _ context.Context) ([]*wtxmgr.LockedOutput, error) {
+
+ err := w.state.validateStarted()
+ if err != nil {
+ return nil, err
+ }
+
+ var leasedOutputs []*wtxmgr.LockedOutput
+
+ err = walletdb.View(w.cfg.DB, func(tx walletdb.ReadTx) error {
+ txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
+ leasedOutputs, err = w.txStore.ListLockedOutputs(txmgrNs)
+
+ return err
+ })
+
+ return leasedOutputs, err
+}
diff --git a/wallet/utxo_manager_benchmark_test.go b/wallet/utxo_manager_benchmark_test.go
new file mode 100644
index 0000000000..74fd093a0a
--- /dev/null
+++ b/wallet/utxo_manager_benchmark_test.go
@@ -0,0 +1,553 @@
+package wallet
+
+import (
+ "fmt"
+ "math"
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/require"
+)
+
+// BenchmarkGetUtxoAPI benchmarks GetUtxo API and its deprecated variant
+// FetchOutpointInfo using same key scope and identical test data across
+// multiple dataset sizes. Test names start with dataset size to group API
+// comparisons for benchstat analysis.
+func BenchmarkGetUtxoAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ outpoints := txsToOutpoints(bw.confirmedTxs)
+ testOutpoint := getTestUtxoOutpoint(outpoints)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := getUtxoDeprecated(
+ bw.Wallet, testOutpoint,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ outpoints := txsToOutpoints(bw.confirmedTxs)
+ testOutpoint := getTestUtxoOutpoint(outpoints)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.GetUtxo(
+ b.Context(), testOutpoint,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkListUnspentAPI benchmarks ListUnspent API and its deprecated
+// variant ListUnspentDeprecated using same key scope and identical test data
+// across multiple dataset sizes. Test names start with dataset size to group
+// API comparisons for benchstat analysis.
+func BenchmarkListUnspentAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ minConfs = 0
+
+ maxConfs = math.MaxInt32
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ accountName, _ := generateAccountName(accountGrowth[i], scopes)
+
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.ListUnspentDeprecated(
+ int32(minConfs), int32(maxConfs),
+ accountName,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.ListUnspent(
+ b.Context(), UtxoQuery{
+ Account: accountName,
+ MinConfs: int32(minConfs),
+ MaxConfs: int32(maxConfs),
+ },
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkLeaseOutputAPI benchmarks LeaseOutput API and its deprecated
+// variant LeaseOutputDeprecated. Although LeaseOutput is an O(1) operation,
+// testing across different dataset sizes helps identify any database bucket
+// depth effects or positional bias as the UTXO set grows.
+func BenchmarkLeaseOutputAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ lockID = wtxmgr.LockID{0x01, 0x02, 0x03, 0x04}
+
+ duration = time.Hour
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ outpoints := txsToOutpoints(bw.confirmedTxs)
+ testOutpoint := getTestUtxoOutpoint(outpoints)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.LeaseOutputDeprecated(
+ lockID, testOutpoint, duration,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ outpoints := txsToOutpoints(bw.confirmedTxs)
+ testOutpoint := getTestUtxoOutpoint(outpoints)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.LeaseOutput(
+ b.Context(), lockID, testOutpoint,
+ duration,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkReleaseOutputAPI benchmarks ReleaseOutput API and its deprecated
+// variant ReleaseOutputDeprecated. Although ReleaseOutput is an O(1) operation,
+// testing across different dataset sizes helps identify any database bucket
+// depth effects or positional bias as the UTXO set grows. Outputs must be
+// leased before they can be released.
+func BenchmarkReleaseOutputAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ linearGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ lockID = wtxmgr.LockID{0x01, 0x02, 0x03, 0x04}
+
+ duration = time.Hour
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ outpoints := txsToOutpoints(bw.confirmedTxs)
+ testOutpoint := getTestUtxoOutpoint(outpoints)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.LeaseOutputDeprecated(
+ lockID, testOutpoint, duration,
+ )
+ require.NoError(b, err)
+
+ err = bw.ReleaseOutputDeprecated(
+ lockID, testOutpoint,
+ )
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ outpoints := txsToOutpoints(bw.confirmedTxs)
+ testOutpoint := getTestUtxoOutpoint(outpoints)
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.LeaseOutput(
+ b.Context(), lockID, testOutpoint,
+ duration,
+ )
+ require.NoError(b, err)
+
+ err = bw.ReleaseOutput(
+ b.Context(), lockID, testOutpoint,
+ )
+ require.NoError(b, err)
+ }
+ })
+ }
+}
+
+// BenchmarkListLeasedOutputsAPI benchmarks ListLeasedOutputs API and its
+// deprecated variant ListLeasedOutputsDeprecated. The deprecated API performs
+// N+1 transaction lookups to enrich each leased output with value and pkScript,
+// while the new API returns minimal lock metadata in a single scan. Performance
+// difference scales with the number of leased outputs.
+func BenchmarkListLeasedOutputsAPI(b *testing.B) {
+ const (
+ // startGrowthIteration is the starting iteration index for the
+ // growth sequence.
+ startGrowthIteration = 0
+
+ // maxGrowthIteration is the maximum iteration index for the
+ // growth sequence.
+ maxGrowthIteration = 14
+ )
+
+ var (
+ accountGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ addressGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ constantGrowth,
+ )
+
+ utxoGrowth = mapRange(
+ startGrowthIteration, maxGrowthIteration,
+ exponentialGrowth,
+ )
+
+ accountGrowthPadding = decimalWidth(
+ accountGrowth[len(accountGrowth)-1],
+ )
+
+ addressGrowthPadding = decimalWidth(
+ addressGrowth[len(addressGrowth)-1],
+ )
+
+ utxoGrowthPadding = decimalWidth(
+ utxoGrowth[len(utxoGrowth)-1],
+ )
+
+ scopes = []waddrmgr.KeyScope{waddrmgr.KeyScopeBIP0084}
+
+ duration = time.Hour
+ )
+
+ for i := 0; i <= maxGrowthIteration; i++ {
+ name := fmt.Sprintf("%0*d-Accounts-%0*d-Addresses-%0*d-UTXOs",
+ accountGrowthPadding, accountGrowth[i],
+ addressGrowthPadding, addressGrowth[i],
+ utxoGrowthPadding, utxoGrowth[i])
+
+ b.Run(name+"/0-Before", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Lease all outputs to maximize the N+1 query impact.
+ leaseAllOutputs(
+ b, bw.Wallet, txsToOutpoints(bw.confirmedTxs),
+ duration,
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.ListLeasedOutputsDeprecated()
+ require.NoError(b, err)
+ }
+ })
+
+ b.Run(name+"/1-After", func(b *testing.B) {
+ bw := setupBenchmarkWallet(
+ b, benchmarkWalletConfig{
+ scopes: scopes,
+ numAccounts: accountGrowth[i],
+ numAddresses: addressGrowth[i],
+ numWalletTxs: utxoGrowth[i],
+ },
+ )
+
+ // Lease all outputs to maximize the N+1 query impact.
+ leaseAllOutputs(
+ b, bw.Wallet, txsToOutpoints(bw.confirmedTxs),
+ duration,
+ )
+
+ b.ReportAllocs()
+ b.ResetTimer()
+
+ for b.Loop() {
+ _, err := bw.ListLeasedOutputs(b.Context())
+ require.NoError(b, err)
+ }
+ })
+ }
+}
diff --git a/wallet/utxo_manager_test.go b/wallet/utxo_manager_test.go
new file mode 100644
index 0000000000..d1a23cd17a
--- /dev/null
+++ b/wallet/utxo_manager_test.go
@@ -0,0 +1,360 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wallet
+
+import (
+ "testing"
+ "time"
+
+ "github.com/btcsuite/btcd/address/v2"
+ "github.com/btcsuite/btcd/btcec/v2"
+ "github.com/btcsuite/btcd/txscript/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/waddrmgr"
+ "github.com/btcsuite/btcwallet/wtxmgr"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
+// TestListUnspent tests the ListUnspent method with various filters.
+func TestListUnspent(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Define account names.
+ account1 := defaultAccountName
+ account2 := "test"
+
+ // Create the addresses that our mocks will return.
+ privKeyDefault, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ addrDefault, err := address.NewAddressPubKey(
+ privKeyDefault.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ privKeyTest, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ addrTest, err := address.NewAddressPubKey(
+ privKeyTest.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ // Set the current block height to match the default mock (1).
+ currentHeight := int32(1)
+
+ mocks.addrStore.On("AddressDetails", mock.Anything, addrDefault).Return(
+ false, account1, waddrmgr.WitnessPubKey,
+ )
+ mocks.addrStore.On("AddressDetails", mock.Anything, addrTest).Return(
+ false, account2, waddrmgr.NestedWitnessPubKey,
+ )
+
+ // Now that the mocks are set up, we can create the pkScripts.
+ pkScriptDefault, err := txscript.PayToAddrScript(addrDefault)
+ require.NoError(t, err)
+ pkScriptTest, err := txscript.PayToAddrScript(addrTest)
+ require.NoError(t, err)
+
+ const (
+ minConf = 2
+ maxConf = 6
+ )
+
+ // Create two UTXOs, one for each address.
+ utxo1 := wtxmgr.Credit{
+ OutPoint: wire.OutPoint{
+ Hash: [32]byte{1},
+ Index: 0,
+ },
+ Amount: 100000,
+ PkScript: pkScriptDefault,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: currentHeight - minConf + 1,
+ },
+ },
+ }
+ utxo2 := wtxmgr.Credit{
+ OutPoint: wire.OutPoint{
+ Hash: [32]byte{2},
+ Index: 0,
+ },
+ Amount: 200000,
+ PkScript: pkScriptTest,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: currentHeight - maxConf + 1,
+ },
+ },
+ }
+
+ // Mock the UnspentOutputs method to return the two UTXOs.
+ mocks.txStore.On("UnspentOutputs", mock.Anything).Return(
+ []wtxmgr.Credit{utxo1, utxo2}, nil,
+ )
+
+ testCases := []struct {
+ name string
+ query UtxoQuery
+ expectedCount int
+ expectedAddrs map[string]bool
+ }{
+ {
+ name: "no filter",
+ query: UtxoQuery{MinConfs: 0, MaxConfs: 999999},
+ expectedCount: 2,
+ expectedAddrs: map[string]bool{
+ addrDefault.String(): true,
+ addrTest.String(): true,
+ },
+ },
+ {
+ name: "filter by default account",
+ query: UtxoQuery{
+ Account: account1,
+ MinConfs: 0,
+ MaxConfs: 999999,
+ },
+ expectedCount: 1,
+ expectedAddrs: map[string]bool{
+ addrDefault.String(): true,
+ },
+ },
+ {
+ name: "filter by test account",
+ query: UtxoQuery{
+ Account: account2,
+ MinConfs: 0,
+ MaxConfs: 999999,
+ },
+ expectedCount: 1,
+ expectedAddrs: map[string]bool{
+ addrTest.String(): true,
+ },
+ },
+ {
+ name: "filter by min confs",
+ query: UtxoQuery{
+ MinConfs: minConf + 1,
+ MaxConfs: 999999,
+ },
+ expectedCount: 1,
+ expectedAddrs: map[string]bool{
+ addrTest.String(): true,
+ },
+ },
+ {
+ name: "filter by max confs",
+ query: UtxoQuery{
+ MinConfs: 0,
+ MaxConfs: maxConf - 1,
+ },
+ expectedCount: 1,
+ expectedAddrs: map[string]bool{
+ addrDefault.String(): true,
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ utxos, err := w.ListUnspent(t.Context(), tc.query)
+ require.NoError(t, err)
+ require.Len(t, utxos, tc.expectedCount)
+
+ // Check that the correct addresses are returned. We do
+ // this by creating a map of the returned addresses and
+ // comparing it to the expected map. This ensures that
+ // all expected addresses are present and there are no
+ // duplicates.
+ returnedAddrs := make(map[string]bool)
+ for _, utxo := range utxos {
+ returnedAddrs[utxo.Address.String()] = true
+ }
+
+ require.Equal(t, tc.expectedAddrs, returnedAddrs)
+
+ // Check that the UTXOs are sorted by amount in
+ // ascending order.
+ for i := range len(utxos) - 1 {
+ require.LessOrEqual(
+ t, utxos[i].Amount, utxos[i+1].Amount,
+ )
+ }
+ })
+ }
+}
+
+// TestGetUtxo tests that the GetUtxo method can successfully retrieve a UTXO.
+func TestGetUtxo(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Define account names.
+ account1 := "default"
+
+ // Create the addresses that our mocks will return.
+ privKeyDefault, err := btcec.NewPrivateKey()
+ require.NoError(t, err)
+ addrDefault, err := address.NewAddressPubKey(
+ privKeyDefault.PubKey().SerializeCompressed(), &chainParams,
+ )
+ require.NoError(t, err)
+
+ // Set the current block height to match the default mock (1).
+ currentHeight := int32(1)
+
+ mocks.addrStore.On("AddressDetails", mock.Anything, addrDefault).Return(
+ false, account1, waddrmgr.WitnessPubKey,
+ )
+
+ // Now that the mocks are set up, we can create the pkScripts.
+ pkScriptDefault, err := txscript.PayToAddrScript(addrDefault)
+ require.NoError(t, err)
+
+ // Create a UTXO.
+ utxo1 := wtxmgr.Credit{
+ OutPoint: wire.OutPoint{
+ Hash: [32]byte{1},
+ Index: 0,
+ },
+ Amount: 100000,
+ PkScript: pkScriptDefault,
+ BlockMeta: wtxmgr.BlockMeta{
+ Block: wtxmgr.Block{
+ Height: currentHeight,
+ },
+ },
+ }
+
+ // Mock the GetUtxo method to return the UTXO.
+ mocks.txStore.On("GetUtxo", mock.Anything, utxo1.OutPoint).Return(
+ &utxo1, nil,
+ )
+
+ // Construct the expected Utxo.
+ expectedUtxo := &Utxo{
+ OutPoint: utxo1.OutPoint,
+ Amount: utxo1.Amount,
+ PkScript: utxo1.PkScript,
+ Confirmations: 1,
+ Spendable: false,
+ Address: addrDefault,
+ Account: account1,
+ AddressType: waddrmgr.WitnessPubKey,
+ }
+
+ // Now, try to get the UTXO and compare it to our expected result.
+ utxo, err := w.GetUtxo(t.Context(), utxo1.OutPoint)
+ require.NoError(t, err)
+ require.Equal(t, expectedUtxo, utxo)
+}
+
+// TestGetUtxo_Err tests the error conditions of the GetUtxo method.
+func TestGetUtxo_Err(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Test the case where the UTXO is not found.
+ utxoNotFound := wire.OutPoint{
+ Hash: [32]byte{2},
+ Index: 0,
+ }
+ mocks.txStore.On("GetUtxo", mock.Anything, utxoNotFound).Return(
+ nil, wtxmgr.ErrUtxoNotFound,
+ )
+ utxo, err := w.GetUtxo(t.Context(), utxoNotFound)
+ require.ErrorIs(t, err, wtxmgr.ErrUtxoNotFound)
+ require.Nil(t, utxo)
+}
+
+// TestLeaseOutput tests the LeaseOutput method.
+func TestLeaseOutput(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Create a UTXO.
+ utxo := wire.OutPoint{
+ Hash: [32]byte{1},
+ Index: 0,
+ }
+
+ // Mock the LockOutput method to return a fixed expiration time.
+ expiration := time.Now().Add(time.Hour)
+ mocks.txStore.On("LockOutput", mock.Anything, mock.Anything, utxo,
+ mock.Anything).Return(expiration, nil)
+
+ // Now, try to lease the output.
+ leaseID := wtxmgr.LockID{1}
+ leaseDuration := time.Hour
+ actualExpiration, err := w.LeaseOutput(
+ t.Context(), leaseID, utxo, leaseDuration,
+ )
+ require.NoError(t, err)
+ require.Equal(t, expiration, actualExpiration)
+}
+
+// TestReleaseOutput tests the ReleaseOutput method.
+func TestReleaseOutput(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Create a UTXO.
+ utxo := wire.OutPoint{
+ Hash: [32]byte{1},
+ Index: 0,
+ }
+
+ // Mock the UnlockOutput method to return nil.
+ mocks.txStore.On("UnlockOutput",
+ mock.Anything, mock.Anything, utxo,
+ ).Return(nil)
+
+ // Now, try to release the output.
+ leaseID := wtxmgr.LockID{1}
+ err := w.ReleaseOutput(t.Context(), leaseID, utxo)
+ require.NoError(t, err)
+}
+
+// TestListLeasedOutputs tests the ListLeasedOutputs method.
+func TestListLeasedOutputs(t *testing.T) {
+ t.Parallel()
+
+ // Create a new test wallet with mocks.
+ w, mocks := createStartedWalletWithMocks(t)
+
+ // Create a leased output.
+ leasedOutput := &wtxmgr.LockedOutput{
+ Outpoint: wire.OutPoint{
+ Hash: [32]byte{1},
+ Index: 0,
+ },
+ LockID: wtxmgr.LockID{1},
+ Expiration: time.Now().Add(time.Hour),
+ }
+
+ // Mock the ListLockedOutputs method to return the leased output.
+ mocks.txStore.On("ListLockedOutputs", mock.Anything).Return(
+ []*wtxmgr.LockedOutput{leasedOutput}, nil,
+ )
+
+ // Now, try to list the leased outputs.
+ leasedOutputs, err := w.ListLeasedOutputs(t.Context())
+ require.NoError(t, err)
+ require.Len(t, leasedOutputs, 1)
+ require.Equal(t, leasedOutput, leasedOutputs[0])
+}
diff --git a/wallet/utxos.go b/wallet/utxos.go
deleted file mode 100644
index 1c82b8f7eb..0000000000
--- a/wallet/utxos.go
+++ /dev/null
@@ -1,251 +0,0 @@
-// Copyright (c) 2016 The Decred developers
-// Copyright (c) 2017 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "errors"
- "fmt"
-
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/psbt/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
- "github.com/btcsuite/btcwallet/wtxmgr"
-)
-
-var (
- // ErrNotMine is an error denoting that a Wallet instance is unable to
- // spend a specified output.
- ErrNotMine = errors.New("the passed output does not belong to the " +
- "wallet")
-)
-
-// OutputSelectionPolicy describes the rules for selecting an output from the
-// wallet.
-type OutputSelectionPolicy struct {
- Account uint32
- RequiredConfirmations int32
-}
-
-func (p *OutputSelectionPolicy) meetsRequiredConfs(txHeight,
- curHeight int32) bool {
-
- return hasMinConfs(p.RequiredConfirmations, txHeight, curHeight)
-}
-
-// UnspentOutputs fetches all unspent outputs from the wallet that match rules
-// described in the passed policy.
-func (w *Wallet) UnspentOutputs(policy OutputSelectionPolicy) ([]*TransactionOutput, error) {
- var outputResults []*TransactionOutput
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- syncBlock := w.Manager.SyncedTo()
-
- // TODO: actually stream outputs from the db instead of fetching
- // all of them at once.
- outputs, err := w.TxStore.UnspentOutputs(txmgrNs)
- if err != nil {
- return err
- }
-
- for _, output := range outputs {
- // Ignore outputs that haven't reached the required
- // number of confirmations.
- if !policy.meetsRequiredConfs(output.Height, syncBlock.Height) {
- continue
- }
-
- // Ignore outputs that are not controlled by the account.
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(output.PkScript,
- w.chainParams)
- if err != nil || len(addrs) == 0 {
- // Cannot determine which account this belongs
- // to without a valid address. TODO: Fix this
- // by saving outputs per account, or accounts
- // per output.
- continue
- }
- _, outputAcct, err := w.Manager.AddrAccount(addrmgrNs, addrs[0])
- if err != nil {
- return err
- }
- if outputAcct != policy.Account {
- continue
- }
-
- // Stakebase isn't exposed by wtxmgr so those will be
- // OutputKindNormal for now.
- outputSource := OutputKindNormal
- if output.FromCoinBase {
- outputSource = OutputKindCoinbase
- }
-
- result := &TransactionOutput{
- OutPoint: output.OutPoint,
- Output: wire.TxOut{
- Value: int64(output.Amount),
- PkScript: output.PkScript,
- },
- OutputKind: outputSource,
- ContainingBlock: BlockIdentity(output.Block),
- ReceiveTime: output.Received,
- }
- outputResults = append(outputResults, result)
- }
-
- return nil
- })
- return outputResults, err
-}
-
-// FetchInputInfo queries for the wallet's knowledge of the passed outpoint. If
-// the wallet determines this output is under its control, then the original
-// full transaction, the target txout, the derivation info and the number of
-// confirmations are returned. Otherwise, a non-nil error value of ErrNotMine
-// is returned instead.
-//
-// NOTE: This method is kept for compatibility.
-func (w *Wallet) FetchInputInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
- *wire.TxOut, *psbt.Bip32Derivation, int64, error) {
-
- tx, txOut, confs, err := w.FetchOutpointInfo(prevOut)
- if err != nil {
- return nil, nil, nil, 0, err
- }
-
- derivation, err := w.FetchDerivationInfo(txOut.PkScript)
- if err != nil {
- return nil, nil, nil, 0, err
- }
-
- return tx, txOut, derivation, confs, nil
-}
-
-// fetchOutputAddr attempts to fetch the managed address corresponding to the
-// passed output script. This function is used to look up the proper key which
-// should be used to sign a specified input.
-func (w *Wallet) fetchOutputAddr(script []byte) (waddrmgr.ManagedAddress, error) {
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(script, w.chainParams)
- if err != nil {
- return nil, err
- }
-
- // If the case of a multi-sig output, several address may be extracted.
- // Therefore, we simply select the key for the first address we know
- // of.
- for _, addr := range addrs {
- addr, err := w.AddressInfo(addr)
- if err == nil {
- return addr, nil
- }
- }
-
- return nil, ErrNotMine
-}
-
-// FetchOutpointInfo queries for the wallet's knowledge of the passed outpoint.
-// If the wallet determines this output is under its control, the original full
-// transaction, the target txout and the number of confirmations are returned.
-// Otherwise, a non-nil error value of ErrNotMine is returned instead.
-func (w *Wallet) FetchOutpointInfo(prevOut *wire.OutPoint) (*wire.MsgTx,
- *wire.TxOut, int64, error) {
-
- // We manually look up the output within the tx store.
- txid := &prevOut.Hash
- txDetail, err := UnstableAPI(w).TxDetails(txid)
- if err != nil {
- return nil, nil, 0, err
- } else if txDetail == nil {
- return nil, nil, 0, ErrNotMine
- }
-
- // With the output retrieved, we'll make an additional check to ensure
- // we actually have control of this output. We do this because the
- // check above only guarantees that the transaction is somehow relevant
- // to us, like in the event of us being the sender of the transaction.
- numOutputs := uint32(len(txDetail.TxRecord.MsgTx.TxOut))
- if prevOut.Index >= numOutputs {
- return nil, nil, 0, fmt.Errorf("invalid output index %v for "+
- "transaction with %v outputs", prevOut.Index,
- numOutputs)
- }
-
- // Exit early if the output doesn't belong to our wallet. We know it's
- // our UTXO iff the `TxDetails` has a credit record on this output.
- if !hasOutput(txDetail, prevOut.Index) {
- return nil, nil, 0, ErrNotMine
- }
-
- pkScript := txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].PkScript
-
- // Determine the number of confirmations the output currently has.
- _, currentHeight, err := w.chainClient.GetBestBlock()
- if err != nil {
- return nil, nil, 0, fmt.Errorf("unable to retrieve current "+
- "height: %w", err)
- }
-
- confs := int64(0)
- if txDetail.Block.Height != -1 {
- confs = int64(currentHeight - txDetail.Block.Height)
- }
-
- return &txDetail.TxRecord.MsgTx, &wire.TxOut{
- Value: txDetail.TxRecord.MsgTx.TxOut[prevOut.Index].Value,
- PkScript: pkScript,
- }, confs, nil
-}
-
-// FetchDerivationInfo queries for the wallet's knowledge of the passed
-// pkScript and constructs the derivation info and returns it.
-func (w *Wallet) FetchDerivationInfo(pkScript []byte) (*psbt.Bip32Derivation,
- error) {
-
- addr, err := w.fetchOutputAddr(pkScript)
- if err != nil {
- return nil, err
- }
-
- pubKeyAddr, ok := addr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return nil, ErrNotMine
- }
- keyScope, derivationPath, _ := pubKeyAddr.DerivationInfo()
-
- derivation := &psbt.Bip32Derivation{
- PubKey: pubKeyAddr.PubKey().SerializeCompressed(),
- MasterKeyFingerprint: derivationPath.MasterKeyFingerprint,
- Bip32Path: []uint32{
- keyScope.Purpose + hdkeychain.HardenedKeyStart,
- keyScope.Coin + hdkeychain.HardenedKeyStart,
- derivationPath.Account,
- derivationPath.Branch,
- derivationPath.Index,
- },
- }
-
- return derivation, nil
-}
-
-// hasOutpoint takes an output identified by its output index and determines
-// whether the TxDetails contains this output. If the TxDetails doesn't have
-// this output, it means this output doesn't belong to our wallet.
-//
-// TODO(yy): implement this method on `TxDetails` and update the package
-// `wtxmgr` instead.
-func hasOutput(t *wtxmgr.TxDetails, outputIndex uint32) bool {
- for _, cred := range t.Credits {
- if outputIndex == cred.Index {
- return true
- }
- }
-
- return false
-}
diff --git a/wallet/utxos_test.go b/wallet/utxos_test.go
deleted file mode 100644
index 43ac1675b3..0000000000
--- a/wallet/utxos_test.go
+++ /dev/null
@@ -1,244 +0,0 @@
-// Copyright (c) 2020 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "bytes"
- "testing"
-
- "github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
- "github.com/btcsuite/btcd/chainhash/v2"
- "github.com/btcsuite/btcd/txscript/v2"
- "github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/stretchr/testify/require"
-)
-
-// TestFetchInputInfo checks that the wallet can gather information about an
-// output based on the address.
-func TestFetchInputInfo(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create an address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- if err != nil {
- t.Fatalf("unable to get current address: %v", addr)
- }
- p2shAddr, err := txscript.PayToAddrScript(addr)
- if err != nil {
- t.Fatalf("unable to convert wallet address to p2sh: %v", err)
- }
-
- // Add an output paying to the wallet's address to the database.
- utxOut := wire.NewTxOut(100000, p2shAddr)
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{utxOut},
- }
- addUtxo(t, w, incomingTx)
-
- // Look up the UTXO for the outpoint now and compare it to our
- // expectations.
- prevOut := &wire.OutPoint{
- Hash: incomingTx.TxHash(),
- Index: 0,
- }
- tx, out, derivationPath, confirmations, err := w.FetchInputInfo(prevOut)
- if err != nil {
- t.Fatalf("error fetching input info: %v", err)
- }
- if !bytes.Equal(out.PkScript, utxOut.PkScript) || out.Value != utxOut.Value {
- t.Fatalf("unexpected TX out, got %v wanted %v", out, utxOut)
- }
- if !bytes.Equal(tx.TxOut[prevOut.Index].PkScript, utxOut.PkScript) {
- t.Fatalf("unexpected TX out, got %v wanted %v",
- tx.TxOut[prevOut.Index].PkScript, utxOut)
- }
- if len(derivationPath.Bip32Path) != 5 {
- t.Fatalf("expected derivation path of length %v, got %v", 3,
- len(derivationPath.Bip32Path))
- }
- if derivationPath.Bip32Path[0] !=
- waddrmgr.KeyScopeBIP0084.Purpose+hdkeychain.HardenedKeyStart {
- t.Fatalf("expected purpose %v, got %v",
- waddrmgr.KeyScopeBIP0084.Purpose,
- derivationPath.Bip32Path[0])
- }
- if derivationPath.Bip32Path[1] !=
- waddrmgr.KeyScopeBIP0084.Coin+hdkeychain.HardenedKeyStart {
- t.Fatalf("expected coin type %v, got %v",
- waddrmgr.KeyScopeBIP0084.Coin,
- derivationPath.Bip32Path[1])
- }
- if derivationPath.Bip32Path[2] != hdkeychain.HardenedKeyStart {
- t.Fatalf("expected account %v, got %v",
- hdkeychain.HardenedKeyStart, derivationPath.Bip32Path[2])
- }
- if derivationPath.Bip32Path[3] != 0 {
- t.Fatalf("expected branch %v, got %v", 0,
- derivationPath.Bip32Path[3])
- }
- if derivationPath.Bip32Path[4] != 0 {
- t.Fatalf("expected index %v, got %v", 0,
- derivationPath.Bip32Path[4])
- }
- if confirmations != int64(0-testBlockHeight) {
- t.Fatalf("unexpected number of confirmations, got %d wanted %d",
- confirmations, 0-testBlockHeight)
- }
-}
-
-// TestFetchOutpointInfo checks that the wallet can gather information about an
-// output based on the outpoint.
-func TestFetchOutpointInfo(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create an address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- require.NoError(t, err)
- p2shAddr, err := txscript.PayToAddrScript(addr)
- require.NoError(t, err)
-
- // Add an output paying to the wallet's address to the database.
- utxOut := wire.NewTxOut(100000, p2shAddr)
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{utxOut},
- }
- addUtxo(t, w, incomingTx)
-
- // Look up the UTXO for the outpoint now and compare it to our
- // expectations.
- prevOut := &wire.OutPoint{
- Hash: incomingTx.TxHash(),
- Index: 0,
- }
- tx, out, confirmations, err := w.FetchOutpointInfo(prevOut)
- require.NoError(t, err)
-
- require.Equal(t, utxOut.PkScript, out.PkScript)
- require.Equal(t, utxOut.Value, out.Value)
- require.Equal(t, utxOut.PkScript, tx.TxOut[prevOut.Index].PkScript)
- require.Equal(t, int64(0-testBlockHeight), confirmations)
-}
-
-// TestFetchOutpointInfoErr checks when the wallet cannot find an output, a
-// proper error is returned.
-func TestFetchOutpointInfoErr(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create an address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- require.NoError(t, err)
- p2shAddr, err := txscript.PayToAddrScript(addr)
- require.NoError(t, err)
-
- // Create a tx that has two outputs - output1 belongs to the wallet,
- // output2 is external.
- output1 := wire.NewTxOut(100000, p2shAddr)
- output2 := wire.NewTxOut(100000, p2shAddr)
- tx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{
- output1,
- output2,
- },
- }
-
- // Add the tx and its first output as the credit.
- addTxAndCredit(t, w, tx, 0)
-
- testCases := []struct {
- name string
- prevOut *wire.OutPoint
-
- // TODO(yy): refator `FetchOutpointInfo` to return wrapped
- // errors.
- errExpected string
- }{
- {
- name: "no tx details",
- prevOut: &wire.OutPoint{
- Hash: chainhash.Hash{1, 2, 3},
- Index: 0,
- },
- errExpected: "does not belong to the wallet",
- },
- {
- name: "invalid output index",
- prevOut: &wire.OutPoint{
- Hash: tx.TxHash(),
- Index: 1000,
- },
- errExpected: "invalid output index",
- },
- {
- name: "no credit found",
- prevOut: &wire.OutPoint{
- Hash: tx.TxHash(),
- Index: 1,
- },
- errExpected: "does not belong to the wallet",
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- // Look up the UTXO for the outpoint now and compare it
- // to the expected error.
- tx, out, conf, err := w.FetchOutpointInfo(tc.prevOut)
- require.ErrorContains(t, err, tc.errExpected)
- require.Nil(t, tx)
- require.Nil(t, out)
- require.Zero(t, conf)
- })
- }
-}
-
-// TestFetchDerivationInfo checks that the wallet can gather the derivation
-// info about an output based on the pkScript.
-func TestFetchDerivationInfo(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // Create an address we can use to send some coins to.
- addr, err := w.CurrentAddress(0, waddrmgr.KeyScopeBIP0084)
- require.NoError(t, err)
- p2shAddr, err := txscript.PayToAddrScript(addr)
- require.NoError(t, err)
-
- // Add an output paying to the wallet's address to the database.
- utxOut := wire.NewTxOut(100000, p2shAddr)
- incomingTx := &wire.MsgTx{
- TxIn: []*wire.TxIn{{}},
- TxOut: []*wire.TxOut{utxOut},
- }
- addUtxo(t, w, incomingTx)
-
- info, err := w.FetchDerivationInfo(utxOut.PkScript)
- require.NoError(t, err)
-
- require.Len(t, info.Bip32Path, 5)
- require.Equal(t, waddrmgr.KeyScopeBIP0084.Purpose+
- hdkeychain.HardenedKeyStart, info.Bip32Path[0])
- require.Equal(t, waddrmgr.KeyScopeBIP0084.Coin+
- hdkeychain.HardenedKeyStart, info.Bip32Path[1])
- require.EqualValues(t, hdkeychain.HardenedKeyStart, info.Bip32Path[2])
- require.Equal(t, uint32(0), info.Bip32Path[3])
- require.Equal(t, uint32(0), info.Bip32Path[4])
-}
diff --git a/wallet/wallet.go b/wallet/wallet.go
index e27211753a..ecfa297091 100644
--- a/wallet/wallet.go
+++ b/wallet/wallet.go
@@ -3,36 +3,29 @@
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
+// Package wallet provides a bitcoin wallet that is capable of fulfilling all
+// the duties of a typical bitcoin wallet such as creating and managing keys,
+// creating and signing transactions, and customizing of transaction fees.
+//
+// TODO(yy): bring wrapcheck back when implementing the `Store` interface.
+//
+//nolint:wrapcheck,cyclop
package wallet
import (
- "bytes"
- "encoding/hex"
+ "context"
"errors"
"fmt"
- "sort"
"sync"
- "sync/atomic"
"time"
- "github.com/btcsuite/btcd/address/v2"
- "github.com/btcsuite/btcd/blockchain"
- "github.com/btcsuite/btcd/btcec/v2"
- "github.com/btcsuite/btcd/btcjson"
- "github.com/btcsuite/btcd/btcutil/v2"
"github.com/btcsuite/btcd/btcutil/v2/hdkeychain"
"github.com/btcsuite/btcd/chaincfg/v2"
- "github.com/btcsuite/btcd/chainhash/v2"
- "github.com/btcsuite/btcd/txscript/v2"
"github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/chain"
"github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/wallet/txauthor"
- "github.com/btcsuite/btcwallet/wallet/txrules"
"github.com/btcsuite/btcwallet/walletdb"
- "github.com/btcsuite/btcwallet/walletdb/migration"
"github.com/btcsuite/btcwallet/wtxmgr"
- "github.com/davecgh/go-spew/spew"
)
const (
@@ -54,14 +47,24 @@ const (
// defaultSyncRetryInterval is the default amount of time to wait
// between re-tries on errors during initial sync.
defaultSyncRetryInterval = 5 * time.Second
+
+ // birthdayBlockDelta is the maximum time delta allowed between our
+ // birthday timestamp and our birthday block's timestamp when searching
+ // for a better birthday block candidate (if possible).
+ birthdayBlockDelta = 2 * time.Hour
+
+ // defaultLockDuration is the default duration for automatic wallet
+ // locking.
+ defaultLockDuration = 10 * time.Minute
+
+ // MinRecoveryWindow is the minimum allowed value for the RecoveryWindow
+ // configuration parameter. This value ensures that a sufficient number
+ // of addresses are scanned during wallet recovery to avoid missing
+ // funds due to gaps in the address chain.
+ MinRecoveryWindow = 20
)
var (
- // ErrNotSynced describes an error where an operation cannot complete
- // due wallet being out of sync (and perhaps currently syncing with)
- // the remote chain server.
- ErrNotSynced = errors.New("wallet is not synchronized with the chain server")
-
// ErrWalletShuttingDown is an error returned when we attempt to make a
// request to the wallet but it is in the process of or has already shut
// down.
@@ -84,527 +87,155 @@ var (
// watch-only mode where we can select coins but not sign any inputs.
ErrTxUnsigned = errors.New("watch-only wallet, transaction not signed")
- // Namespace bucket keys.
- waddrmgrNamespaceKey = []byte("waddrmgr")
- wtxmgrNamespaceKey = []byte("wtxmgr")
-)
-
-// Coin represents a spendable UTXO which is available for coin selection.
-type Coin struct {
- wire.TxOut
+ // ErrNoAssocPrivateKey is returned when a private key is requested for
+ // an address that has no associated private key.
+ ErrNoAssocPrivateKey = errors.New("address does not have an " +
+ "associated private key")
- wire.OutPoint
-}
+ // ErrInvalidAccountKey is returned when the provided extended public key
+ // does not meet the requirements for an account key (e.g., wrong depth
+ // or not hardened).
+ ErrInvalidAccountKey = errors.New("invalid account key")
-// CoinSelectionStrategy is an interface that represents a coin selection
-// strategy. A coin selection strategy is responsible for ordering, shuffling or
-// filtering a list of coins before they are passed to the coin selection
-// algorithm.
-type CoinSelectionStrategy interface {
- // ArrangeCoins takes a list of coins and arranges them according to the
- // specified coin selection strategy and fee rate.
- ArrangeCoins(eligible []Coin, feeSatPerKb btcutil.Amount) ([]Coin,
- error)
-}
+ // ErrMissingParam is returned when a required parameter is missing from
+ // the configuration.
+ ErrMissingParam = errors.New("missing config parameter")
-var (
- // CoinSelectionLargest always picks the largest available utxo to add
- // to the transaction next.
- CoinSelectionLargest CoinSelectionStrategy = &LargestFirstCoinSelector{}
+ // ErrInvalidParam is returned when a parameter is invalid.
+ ErrInvalidParam = errors.New("invalid config parameter")
- // CoinSelectionRandom randomly selects the next utxo to add to the
- // transaction. This strategy prevents the creation of ever smaller
- // utxos over time.
- CoinSelectionRandom CoinSelectionStrategy = &RandomCoinSelector{}
+ // Namespace bucket keys.
+ waddrmgrNamespaceKey = []byte("waddrmgr")
+ wtxmgrNamespaceKey = []byte("wtxmgr")
)
-// Wallet is a structure containing all the components for a
-// complete wallet. It contains the Armory-style key store
-// addresses and keys),
-type Wallet struct {
- publicPassphrase []byte
-
- // Data stores
- db walletdb.DB
- Manager *waddrmgr.Manager
- TxStore *wtxmgr.Store
-
- chainClient chain.Interface
- chainClientLock sync.Mutex
- chainClientSynced bool
- chainClientSyncMtx sync.Mutex
-
- newAddrMtx sync.Mutex
-
- lockedOutpoints map[wire.OutPoint]struct{}
- lockedOutpointsMtx sync.Mutex
-
- recovering atomic.Value
- recoveryWindow uint32
-
- // Channels for rescan processing. Requests are added and merged with
- // any waiting requests, before being sent to another goroutine to
- // call the rescan RPC.
- rescanAddJob chan *RescanJob
- rescanBatch chan *rescanBatch
- rescanNotifications chan interface{} // From chain server
- rescanProgress chan *RescanProgressMsg
- rescanFinished chan *RescanFinishedMsg
-
- // Channel for transaction creation requests.
- createTxRequests chan createTxRequest
-
- // Channels for the manager locker.
- unlockRequests chan unlockRequest
- lockRequests chan struct{}
- holdUnlockRequests chan chan heldUnlock
- lockState chan bool
- changePassphrase chan changePassphraseRequest
- changePassphrases chan changePassphrasesRequest
-
- NtfnServer *NotificationServer
-
- chainParams *chaincfg.Params
- wg sync.WaitGroup
-
- started bool
- quit chan struct{}
- quitMu sync.Mutex
-
- // syncRetryInterval is the amount of time to wait between re-tries on
- // errors during initial sync.
- syncRetryInterval time.Duration
-}
-
-// Start starts the goroutines necessary to manage a wallet.
-func (w *Wallet) Start() {
- w.quitMu.Lock()
- select {
- case <-w.quit:
- // Restart the wallet goroutines after shutdown finishes.
- w.WaitForShutdown()
- w.quit = make(chan struct{})
- default:
- // Ignore when the wallet is still running.
- if w.started {
- w.quitMu.Unlock()
- return
- }
- w.started = true
- }
- w.quitMu.Unlock()
-
- w.wg.Add(2)
- go w.txCreator()
- go w.walletLocker()
-}
-
-// SynchronizeRPC associates the wallet with the consensus RPC client,
-// synchronizes the wallet with the latest changes to the blockchain, and
-// continuously updates the wallet through RPC notifications.
-//
-// This method is unstable and will be removed when all syncing logic is moved
-// outside of the wallet package.
-func (w *Wallet) SynchronizeRPC(chainClient chain.Interface) {
- w.quitMu.Lock()
- select {
- case <-w.quit:
- w.quitMu.Unlock()
- return
- default:
- }
- w.quitMu.Unlock()
-
- // TODO: Ignoring the new client when one is already set breaks callers
- // who are replacing the client, perhaps after a disconnect.
- w.chainClientLock.Lock()
- if w.chainClient != nil {
- w.chainClientLock.Unlock()
- return
- }
- w.chainClient = chainClient
-
- // If the chain client is a NeutrinoClient instance, set a birthday so
- // we don't download all the filters as we go.
- switch cc := chainClient.(type) {
- case *chain.NeutrinoClient:
- cc.SetStartTime(w.Manager.Birthday())
- case *chain.BitcoindClient:
- cc.SetBirthday(w.Manager.Birthday())
- }
- w.chainClientLock.Unlock()
-
- // TODO: It would be preferable to either run these goroutines
- // separately from the wallet (use wallet mutator functions to
- // make changes from the RPC client) and not have to stop and
- // restart them each time the client disconnects and reconnets.
- w.wg.Add(4)
- go w.handleChainNotifications()
- go w.rescanBatchHandler()
- go w.rescanProgressHandler()
- go w.rescanRPCHandler()
-}
-
-// requireChainClient marks that a wallet method can only be completed when the
-// consensus RPC server is set. This function and all functions that call it
-// are unstable and will need to be moved when the syncing code is moved out of
-// the wallet.
-func (w *Wallet) requireChainClient() (chain.Interface, error) {
- w.chainClientLock.Lock()
- chainClient := w.chainClient
- w.chainClientLock.Unlock()
- if chainClient == nil {
- return nil, errors.New("blockchain RPC is inactive")
- }
- return chainClient, nil
-}
-
-// ChainClient returns the optional consensus RPC client associated with the
-// wallet.
-//
-// This function is unstable and will be removed once sync logic is moved out of
-// the wallet.
-func (w *Wallet) ChainClient() chain.Interface {
- w.chainClientLock.Lock()
- chainClient := w.chainClient
- w.chainClientLock.Unlock()
- return chainClient
-}
-
-// quitChan atomically reads the quit channel.
-func (w *Wallet) quitChan() <-chan struct{} {
- w.quitMu.Lock()
- c := w.quit
- w.quitMu.Unlock()
- return c
-}
-
-// Stop signals all wallet goroutines to shutdown.
-func (w *Wallet) Stop() {
- <-w.endRecovery()
-
- w.quitMu.Lock()
- quit := w.quit
- w.quitMu.Unlock()
+// SyncMethod determines the strategy used to synchronize the wallet with the
+// blockchain.
+type SyncMethod uint8
- select {
- case <-quit:
- default:
- close(quit)
- w.chainClientLock.Lock()
- if w.chainClient != nil {
- w.chainClient.Stop()
- w.chainClient = nil
- }
- w.chainClientLock.Unlock()
- }
-}
+const (
+ // SyncMethodAuto defaults to CFilters if available (Neutrino/Bitcoind),
+ // falling back to Full Block scan if not.
+ //
+ // Use Case: Default for most users.
+ //
+ // Logic:
+ // 1. Checks if the number of watched items (Addresses + UTXOs) exceeds
+ // a heuristic threshold (100,000). If so, switches to Full Block
+ // scanning to avoid the CPU bottleneck of client-side filter
+ // matching.
+ // 2. Attempts to fetch CFilters. If successful, uses CFilters.
+ // 3. If CFilters are unavailable, falls back to Full Block scanning.
+ SyncMethodAuto SyncMethod = iota
+
+ // SyncMethodCFilters forces the use of Compact Filters (BIP 157/158).
+ // The sync process will fail if the backend does not support filters.
+ //
+ // Use Case: Bandwidth-constrained environments (mobile) or when privacy
+ // is paramount (Neutrino P2P).
+ //
+ // Pros:
+ // - Minimal Bandwidth: Only downloads headers and filters (approx 4MB
+ // per 200 blocks) plus relevant blocks. Ideal for sparse wallets.
+ //
+ // Cons:
+ // - CPU Intensive: Client-side matching is O(N*M) where N=Blocks,
+ // M=Addresses. Can be slow for massive wallets (>100k addresses).
+ // - Slower if Match Rate is High: If the wallet has transactions in
+ // nearly every block, it downloads filters AND blocks, resulting in
+ // higher overhead than full block scanning.
+ SyncMethodCFilters
+
+ // SyncMethodFullBlocks forces the use of full block downloading and
+ // scanning, bypassing filters entirely.
+ //
+ // Use Case: High-bandwidth/Local environments (Bitcoind on localhost)
+ // or massive wallets (exchanges, heavy users).
+ //
+ // Pros:
+ // - Low CPU: Block parsing and map lookup is extremely fast compared
+ // to filter matching. Scaling is O(1) or O(TxOutputs) for address
+ // lookups, independent of watchlist size.
+ // - Faster for High Match Rates: Avoids the overhead of
+ // fetching/matching filters when most blocks are going to be
+ // downloaded anyway.
+ //
+ // Cons:
+ // - High Bandwidth: Downloads all block data (approx 200MB per 200
+ // blocks). Slow on limited connections.
+ SyncMethodFullBlocks
+)
-// ShuttingDown returns whether the wallet is currently in the process of
-// shutting down or not.
-func (w *Wallet) ShuttingDown() bool {
- select {
- case <-w.quitChan():
- return true
- default:
- return false
- }
-}
+// Config holds the configuration options for creating a new
+// WalletController.
+type Config struct {
+ // DB is the underlying database for the wallet.
+ DB walletdb.DB
-// WaitForShutdown blocks until all wallet goroutines have finished executing.
-func (w *Wallet) WaitForShutdown() {
- w.chainClientLock.Lock()
- if w.chainClient != nil {
- w.chainClient.WaitForShutdown()
- }
- w.chainClientLock.Unlock()
- w.wg.Wait()
-}
+ // Chain is the interface to the blockchain (e.g. bitcoind,
+ // neutrino). If set, the wallet will automatically synchronize with
+ // the chain upon Start.
+ Chain chain.Interface
-// SynchronizingToNetwork returns whether the wallet is currently synchronizing
-// with the Bitcoin network.
-func (w *Wallet) SynchronizingToNetwork() bool {
- // At the moment, RPC is the only synchronization method. In the
- // future, when SPV is added, a separate check will also be needed, or
- // SPV could always be enabled if RPC was not explicitly specified when
- // creating the wallet.
- w.chainClientSyncMtx.Lock()
- syncing := w.chainClient != nil
- w.chainClientSyncMtx.Unlock()
- return syncing
-}
+ // ChainParams defines the network parameters (e.g. mainnet, testnet).
+ ChainParams *chaincfg.Params
-// ChainSynced returns whether the wallet has been attached to a chain server
-// and synced up to the best block on the main chain.
-func (w *Wallet) ChainSynced() bool {
- w.chainClientSyncMtx.Lock()
- synced := w.chainClientSynced
- w.chainClientSyncMtx.Unlock()
- return synced
-}
+ // RecoveryWindow specifies the address lookahead for recovery.
+ RecoveryWindow uint32
-// SetChainSynced marks whether the wallet is connected to and currently in sync
-// with the latest block notified by the chain server.
-//
-// NOTE: Due to an API limitation with rpcclient, this may return true after
-// the client disconnected (and is attempting a reconnect). This will be unknown
-// until the reconnect notification is received, at which point the wallet can be
-// marked out of sync again until after the next rescan completes.
-func (w *Wallet) SetChainSynced(synced bool) {
- w.chainClientSyncMtx.Lock()
- w.chainClientSynced = synced
- w.chainClientSyncMtx.Unlock()
-}
+ // WalletSyncRetryInterval is the interval at which the wallet should
+ // retry syncing to the chain if it encounters an error.
+ WalletSyncRetryInterval time.Duration
-// activeData returns the currently-active receiving addresses and all unspent
-// outputs. This is primarely intended to provide the parameters for a
-// rescan request.
-func (w *Wallet) activeData(
- dbtx walletdb.ReadWriteTx) ([]address.Address, []wtxmgr.Credit, error) {
+ // SyncMethod specifies the synchronization strategy to use.
+ SyncMethod SyncMethod
- addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := dbtx.ReadWriteBucket(wtxmgrNamespaceKey)
+ // AutoLockDuration is the default duration after which the wallet will
+ // automatically lock itself if no specific duration is provided during
+ // unlock. If zero or negative, the wallet will default to a hardcoded
+ // safe duration (e.g. 10m) unless explicitly overridden by the unlock
+ // request.
+ AutoLockDuration time.Duration
- var addrs []address.Address
- err := w.Manager.ForEachRelevantActiveAddress(
- addrmgrNs, func(addr address.Address) error {
- addrs = append(addrs, addr)
- return nil
- },
- )
- if err != nil {
- return nil, nil, err
- }
+ // Name is the unique identifier for the wallet. It is used to track
+ // active wallet instances within the Manager.
+ Name string
- // Before requesting the list of spendable UTXOs, we'll delete any
- // expired output locks.
- err = w.TxStore.DeleteExpiredLockedOutputs(
- dbtx.ReadWriteBucket(wtxmgrNamespaceKey),
- )
- if err != nil {
- return nil, nil, err
- }
+ // PubPassphrase is the public passphrase for the wallet.
+ PubPassphrase []byte
- unspent, err := w.TxStore.OutputsToWatch(txmgrNs)
- return addrs, unspent, err
+ // MaxCFilterItems is the threshold of watched items (addresses +
+ // outpoints) above which the wallet will fallback to full block
+ // scanning when SyncMethodAuto is used. This avoids the CPU bottleneck
+ // of client-side filter matching for large watchlists. If 0, a default
+ // of 100,000 is used.
+ MaxCFilterItems int
}
-// syncWithChain brings the wallet up to date with the current chain server
-// connection. It creates a rescan request and blocks until the rescan has
-// finished. The birthday block can be passed in, if set, to ensure we can
-// properly detect if it gets rolled back.
-func (w *Wallet) syncWithChain(birthdayStamp *waddrmgr.BlockStamp) error {
- chainClient, err := w.requireChainClient()
- if err != nil {
- return err
- }
-
- // Neutrino relies on the information given to it by the cfheader server
- // so it knows exactly whether it's synced up to the server's state or
- // not, even on dev chains. To recover a Neutrino wallet, we need to
- // make sure it's synced before we start scanning for addresses,
- // otherwise we might miss some if we only scan up to its current sync
- // point.
- neutrinoRecovery := chainClient.BackEnd() == "neutrino" &&
- w.recoveryWindow > 0
-
- // We'll wait until the backend is synced to ensure we get the latest
- // MaxReorgDepth blocks to store. We don't do this for development
- // environments as we can't guarantee a lively chain, except for
- // Neutrino, where the cfheader server tells us what it believes the
- // chain tip is.
- if !w.isDevEnv() || neutrinoRecovery {
- log.Debug("Waiting for chain backend to sync to tip")
- if err := w.waitUntilBackendSynced(chainClient); err != nil {
- return err
- }
- log.Debug("Chain backend synced to tip!")
- }
-
- // If we've yet to find our birthday block, we'll do so now.
- if birthdayStamp == nil {
- var err error
- birthdayStamp, err = locateBirthdayBlock(
- chainClient, w.Manager.Birthday(),
- )
- if err != nil {
- return fmt.Errorf("unable to locate birthday block: %w",
- err)
- }
-
- // We'll also determine our initial sync starting height. This
- // is needed as the wallet can now begin storing blocks from an
- // arbitrary height, rather than all the blocks from genesis, so
- // we persist this height to ensure we don't store any blocks
- // before it.
- startHeight := birthdayStamp.Height
-
- // With the starting height obtained, get the remaining block
- // details required by the wallet.
- startHash, err := chainClient.GetBlockHash(int64(startHeight))
- if err != nil {
- return err
- }
- startHeader, err := chainClient.GetBlockHeader(startHash)
- if err != nil {
- return err
- }
-
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- err := w.Manager.SetSyncedTo(ns, &waddrmgr.BlockStamp{
- Hash: *startHash,
- Height: startHeight,
- Timestamp: startHeader.Timestamp,
- })
- if err != nil {
- return err
- }
- return w.Manager.SetBirthdayBlock(ns, *birthdayStamp, true)
- })
- if err != nil {
- return fmt.Errorf("unable to persist initial sync "+
- "data: %w", err)
- }
- }
-
- // If the wallet requested an on-chain recovery of its funds, we'll do
- // so now.
- if w.recoveryWindow > 0 {
- if err := w.recovery(chainClient, birthdayStamp); err != nil {
- return fmt.Errorf("unable to perform wallet recovery: "+
- "%w", err)
- }
+// validate checks the configuration for consistency and completeness.
+func (c *Config) validate() error {
+ if c.DB == nil {
+ return fmt.Errorf("%w: DB", ErrMissingParam)
}
- // Compare previously-seen blocks against the current chain. If any of
- // these blocks no longer exist, rollback all of the missing blocks
- // before catching up with the rescan.
- rollback := false
- rollbackStamp := w.Manager.SyncedTo()
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
-
- for height := rollbackStamp.Height; true; height-- {
- hash, err := w.Manager.BlockHash(addrmgrNs, height)
- if err != nil {
- return err
- }
- chainHash, err := chainClient.GetBlockHash(int64(height))
- if err != nil {
- return err
- }
- header, err := chainClient.GetBlockHeader(chainHash)
- if err != nil {
- return err
- }
-
- rollbackStamp.Hash = *chainHash
- rollbackStamp.Height = height
- rollbackStamp.Timestamp = header.Timestamp
-
- if bytes.Equal(hash[:], chainHash[:]) {
- break
- }
- rollback = true
- }
-
- // If a rollback did not happen, we can proceed safely.
- if !rollback {
- return nil
- }
-
- // Otherwise, we'll mark this as our new synced height.
- err := w.Manager.SetSyncedTo(addrmgrNs, &rollbackStamp)
- if err != nil {
- return err
- }
-
- // If the rollback happened to go beyond our birthday stamp,
- // we'll need to find a new one by syncing with the chain again
- // until finding one.
- if rollbackStamp.Height <= birthdayStamp.Height &&
- rollbackStamp.Hash != birthdayStamp.Hash {
-
- err := w.Manager.SetBirthdayBlock(
- addrmgrNs, rollbackStamp, true,
- )
- if err != nil {
- return err
- }
- }
-
- // Finally, we'll roll back our transaction store to reflect the
- // stale state. `Rollback` unconfirms transactions at and beyond
- // the passed height, so add one to the new synced-to height to
- // prevent unconfirming transactions in the synced-to block.
- return w.TxStore.Rollback(txmgrNs, rollbackStamp.Height+1)
- })
- if err != nil {
- return err
+ if c.Chain == nil {
+ return fmt.Errorf("%w: Chain", ErrMissingParam)
}
- // Request notifications for connected and disconnected blocks.
- //
- // TODO(jrick): Either request this notification only once, or when
- // rpcclient is modified to allow some notification request to not
- // automatically resent on reconnect, include the notifyblocks request
- // as well. I am leaning towards allowing off all rpcclient
- // notification re-registrations, in which case the code here should be
- // left as is.
- if err := chainClient.NotifyBlocks(); err != nil {
- return err
+ if c.ChainParams == nil {
+ return fmt.Errorf("%w: ChainParams", ErrMissingParam)
}
- // Finally, we'll trigger a wallet rescan and request notifications for
- // transactions sending to all wallet addresses and spending all wallet
- // UTXOs.
- var (
- addrs []address.Address
- unspent []wtxmgr.Credit
- )
- err = walletdb.Update(w.db, func(dbtx walletdb.ReadWriteTx) error {
- addrs, unspent, err = w.activeData(dbtx)
- return err
- })
- if err != nil {
- return err
+ if c.Name == "" {
+ return fmt.Errorf("%w: Name", ErrMissingParam)
}
- return w.rescanWithTarget(addrs, unspent, nil)
-}
-
-// isDevEnv determines whether the wallet is currently under a local developer
-// environment, e.g. simnet or regtest.
-func (w *Wallet) isDevEnv() bool {
- switch uint32(w.ChainParams().Net) {
- case uint32(chaincfg.RegressionNetParams.Net):
- case uint32(chaincfg.SimNetParams.Net):
- default:
- return false
+ if c.RecoveryWindow < MinRecoveryWindow {
+ return fmt.Errorf("%w: RecoveryWindow must be at least %d",
+ ErrInvalidParam, MinRecoveryWindow)
}
- return true
-}
-// waitUntilBackendSynced blocks until the chain backend considers itself
-// "current".
-func (w *Wallet) waitUntilBackendSynced(chainClient chain.Interface) error {
- // We'll poll every second to determine if our chain considers itself
- // "current".
- t := time.NewTicker(time.Second)
- defer t.Stop()
-
- for {
- select {
- case <-t.C:
- if chainClient.IsCurrent() {
- return nil
- }
- case <-w.quitChan():
- return ErrWalletShuttingDown
- }
- }
+ return nil
}
// locateBirthdayBlock returns a block that meets the given birthday timestamp
@@ -615,6 +246,7 @@ func locateBirthdayBlock(chainClient chainConn,
// Retrieve the lookup range for our block.
startHeight := int32(0)
+
_, bestHeight, err := chainClient.GetBestBlock()
if err != nil {
return nil, err
@@ -633,11 +265,15 @@ func locateBirthdayBlock(chainClient chainConn,
for {
// Retrieve the timestamp for the block halfway through our
// range.
+ //
+ //nolint:mnd // Division by 2 is standard for binary search.
mid := left + (right-left)/2
+
hash, err := chainClient.GetBlockHash(int64(mid))
if err != nil {
return nil, err
}
+
header, err := chainClient.GetBlockHeader(hash)
if err != nil {
return nil, err
@@ -654,6 +290,7 @@ func locateBirthdayBlock(chainClient chainConn,
Height: mid,
Timestamp: header.Timestamp,
}
+
break
}
@@ -676,6 +313,7 @@ func locateBirthdayBlock(chainClient chainConn,
Height: mid,
Timestamp: header.Timestamp,
}
+
break
}
@@ -686,3625 +324,156 @@ func locateBirthdayBlock(chainClient chainConn,
return birthdayBlock, nil
}
-// recoverySyncer is used to synchronize wallet and address manager locking
-// with the end of recovery. (*Wallet).recovery will store a recoverySyncer
-// when invoked, and will close the done chan upon exit. Setting the quit flag
-// will cause recovery to end after the current batch of blocks.
-type recoverySyncer struct {
- done chan struct{}
- quit uint32 // atomic
-}
-
-// recovery attempts to recover any unspent outputs that pay to any of our
-// addresses starting from our birthday, or the wallet's tip (if higher), which
-// would indicate resuming a recovery after a restart.
-func (w *Wallet) recovery(chainClient chain.Interface,
- birthdayBlock *waddrmgr.BlockStamp) error {
+// Wallet is a structure containing all the components for a complete wallet.
+// It manages the cryptographic keys, transaction history, and synchronization
+// with the blockchain.
+type Wallet struct {
+ // walletDeprecated embeds the legacy state and channels. Access to
+ // these should be phased out as refactoring progresses.
+ *walletDeprecated
- log.Infof("RECOVERY MODE ENABLED -- rescanning for used addresses "+
- "with recovery_window=%d", w.recoveryWindow)
+ // addrStore is the address and key manager responsible for hierarchical
+ // deterministic (HD) derivation and storage of cryptographic keys.
+ addrStore waddrmgr.AddrStore
- // Wallet locking must synchronize with the end of recovery, since use of
- // keys in recovery is racy with manager IsLocked checks, which could
- // result in enrypting data with a zeroed key.
- syncer := &recoverySyncer{done: make(chan struct{})}
- w.recovering.Store(syncer)
- defer close(syncer.done)
+ // txStore is the transaction manager responsible for storing and
+ // querying the wallet's transaction history and unspent outputs.
+ txStore wtxmgr.TxStore
- // We'll initialize the recovery manager with a default batch size of
- // 2000.
- recoveryMgr := NewRecoveryManager(
- w.recoveryWindow, recoveryBatchSize, w.chainParams,
- )
+ // NtfnServer handles the delivery of wallet-related events (e.g., new
+ // transactions, block connections) to connected clients.
+ //
+ // TODO(yy): Deprecate.
+ NtfnServer *NotificationServer
- // In the event that this recovery is being resumed, we will need to
- // repopulate all found addresses from the database. Ideally, for basic
- // recovery, we would only do so for the default scopes, but due to a
- // bug in which the wallet would create change addresses outside of the
- // default scopes, it's necessary to attempt all registered key scopes.
- scopedMgrs := make(map[waddrmgr.KeyScope]*waddrmgr.ScopedKeyManager)
- for _, scopedMgr := range w.Manager.ActiveScopedKeyManagers() {
- scopedMgrs[scopedMgr.Scope()] = scopedMgr
- }
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txMgrNS := tx.ReadBucket(wtxmgrNamespaceKey)
- credits, err := w.TxStore.UnspentOutputs(txMgrNS)
- if err != nil {
- return err
- }
- addrMgrNS := tx.ReadBucket(waddrmgrNamespaceKey)
- return recoveryMgr.Resurrect(addrMgrNS, scopedMgrs, credits)
- })
- if err != nil {
- return err
- }
+ // wg is a wait group used to track and wait for all long-running
+ // background goroutines to finish during a graceful shutdown.
+ wg sync.WaitGroup
- // Fetch the best height from the backend to determine when we should
- // stop.
- _, bestHeight, err := chainClient.GetBestBlock()
- if err != nil {
- return err
- }
+ // cfg holds the static configuration parameters provided when the
+ // wallet was created or loaded.
+ cfg Config
- // Now we can begin scanning the chain from the wallet's current tip to
- // ensure we properly handle restarts. Since the recovery process itself
- // acts as rescan, we'll also update our wallet's synced state along the
- // way to reflect the blocks we process and prevent rescanning them
- // later on.
- //
- // NOTE: We purposefully don't update our best height since we assume
- // that a wallet rescan will be performed from the wallet's tip, which
- // will be of bestHeight after completing the recovery process.
- var blocks []*waddrmgr.BlockStamp
- startHeight := w.Manager.SyncedTo().Height + 1
- for height := startHeight; height <= bestHeight; height++ {
- if atomic.LoadUint32(&syncer.quit) == 1 {
- return errors.New("recovery: forced shutdown")
- }
+ // sync is the dedicated synchronization component that manages the
+ // chain loop, scanning, and reorganization handling.
+ sync chainSyncer
- hash, err := chainClient.GetBlockHash(int64(height))
- if err != nil {
- return err
- }
- header, err := chainClient.GetBlockHeader(hash)
- if err != nil {
- return err
- }
- blocks = append(blocks, &waddrmgr.BlockStamp{
- Hash: *hash,
- Height: height,
- Timestamp: header.Timestamp,
- })
+ // state maintains the wallet's atomic, three-dimensional status:
+ // Lifecycle (System), Synchronization (Chain), and Authentication
+ // (Security).
+ state walletState
- // It's possible for us to run into blocks before our birthday
- // if our birthday is after our reorg safe height, so we'll make
- // sure to not add those to the batch.
- if height >= birthdayBlock.Height {
- recoveryMgr.AddToBlockBatch(
- hash, height, header.Timestamp,
- )
- }
+ // lifetimeCtx defines the runtime scope of the wallet. It is created
+ // when the wallet starts and canceled when it stops, providing a
+ // standard way to signal shutdown to all context-aware background
+ // routines.
+ //
+ // Storing a context in a struct is generally considered an
+ // anti-pattern because contexts are usually request-scoped. However,
+ // for long-lived service objects that manage their own background
+ // goroutines, maintaining a parent context for those routines is a
+ // valid exception.
+ //
+ //nolint:containedctx
+ lifetimeCtx context.Context
- // We'll perform our recovery in batches of 2000 blocks. It's
- // possible for us to reach our best height without exceeding
- // the recovery batch size, so we can proceed to commit our
- // state to disk.
- recoveryBatch := recoveryMgr.BlockBatch()
- if len(recoveryBatch) == recoveryBatchSize || height == bestHeight {
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- if err := w.recoverScopedAddresses(
- chainClient, tx, ns, recoveryBatch,
- recoveryMgr.State(), scopedMgrs,
- ); err != nil {
- return err
- }
+ // cancel is the cancellation function for lifetimeCtx.
+ cancel context.CancelFunc
- // TODO: Any error here will roll back this
- // entire tx. This may cause the in memory sync
- // point to become desyncronized. Refactor so
- // that this cannot happen.
- for _, block := range blocks {
- err := w.Manager.SetSyncedTo(ns, block)
- if err != nil {
- return err
- }
- }
+ // requestChan is the central communication channel for incoming
+ // lifecycle and authentication requests.
+ requestChan chan any
- return nil
- })
- if err != nil {
- return err
- }
+ // lockTimer is the timer used to automatically lock the wallet after a
+ // timeout.
+ lockTimer *time.Timer
- if len(recoveryBatch) > 0 {
- log.Infof("Recovered addresses from blocks "+
- "%d-%d", recoveryBatch[0].Height,
- recoveryBatch[len(recoveryBatch)-1].Height)
- }
+ // birthdayBlock is the block from which the wallet started scanning.
+ // It is loaded on startup and cached to avoid database lookups.
+ birthdayBlock waddrmgr.BlockStamp
+}
- // Clear the batch of all processed blocks to reuse the
- // same memory for future batches.
- blocks = blocks[:0]
- recoveryMgr.ResetBlockBatch()
- }
+// hasMinConfs checks whether a transaction at height txHeight has met minconf
+// confirmations for a blockchain at height curHeight.
+func hasMinConfs(minconf uint32, txHeight, curHeight int32) bool {
+ confs := calcConf(txHeight, curHeight)
+ if confs < 0 {
+ return false
}
- return nil
+ return uint32(confs) >= minconf
}
-// recoverScopedAddresses scans a range of blocks in attempts to recover any
-// previously used addresses for a particular account derivation path. At a high
-// level, the algorithm works as follows:
-//
-// 1. Ensure internal and external branch horizons are fully expanded.
-// 2. Filter the entire range of blocks, stopping if a non-zero number of
-// address are contained in a particular block.
-// 3. Record all internal and external addresses found in the block.
-// 4. Record any outpoints found in the block that should be watched for spends
-// 5. Trim the range of blocks up to and including the one reporting the addrs.
-// 6. Repeat from (1) if there are still more blocks in the range.
-//
-// TODO(conner): parallelize/pipeline/cache intermediate network requests
-func (w *Wallet) recoverScopedAddresses(
- chainClient chain.Interface,
- tx walletdb.ReadWriteTx,
- ns walletdb.ReadWriteBucket,
- batch []wtxmgr.BlockMeta,
- recoveryState *RecoveryState,
- scopedMgrs map[waddrmgr.KeyScope]*waddrmgr.ScopedKeyManager) error {
-
- // If there are no blocks in the batch, we are done.
- if len(batch) == 0 {
- return nil
- }
+// calcConf returns the number of confirmations for a transaction given its
+// containing block height and the current best block height. Unconfirmed
+// transactions have a height of -1 and are considered to have 0 confirmations.
+func calcConf(txHeight, curHeight int32) int32 {
+ switch {
+ // Unconfirmed transactions have 0 confirmations.
+ case txHeight == -1:
+ return 0
- log.Infof("Scanning %d blocks for recoverable addresses", len(batch))
+ // A transaction in a block after the current best block is considered
+ // unconfirmed. This can happen during a chain reorg.
+ case txHeight > curHeight:
+ return 0
-expandHorizons:
- for scope, scopedMgr := range scopedMgrs {
- scopeState := recoveryState.StateForScope(scope)
- err := expandScopeHorizons(ns, scopedMgr, scopeState)
- if err != nil {
- return err
- }
+ // Confirmed transactions have at least one confirmation.
+ default:
+ return curHeight - txHeight + 1
}
+}
- // With the internal and external horizons properly expanded, we now
- // construct the filter blocks request. The request includes the range
- // of blocks we intend to scan, in addition to the scope-index -> addr
- // map for all internal and external branches.
- filterReq := newFilterBlocksRequest(batch, scopedMgrs, recoveryState)
-
- // Initiate the filter blocks request using our chain backend. If an
- // error occurs, we are unable to proceed with the recovery.
- filterResp, err := chainClient.FilterBlocks(filterReq)
+// RemoveDescendants attempts to remove any transaction from the wallet's tx
+// store (that may be unconfirmed) that spends outputs created by the passed
+// transaction. This remove propagates recursively down the chain of descendent
+// transactions.
+func (w *Wallet) RemoveDescendants(tx *wire.MsgTx) error {
+ txRecord, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
if err != nil {
return err
}
- // If the filter response is empty, this signals that the rest of the
- // batch was completed, and no other addresses were discovered. As a
- // result, no further modifications to our recovery state are required
- // and we can proceed to the next batch.
- if filterResp == nil {
- return nil
- }
+ return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
+ wtxmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
+
+ return w.txStore.RemoveUnminedTx(wtxmgrNs, txRecord)
+ })
+}
+
+// BirthdayBlock returns the birthday block of the wallet.
+//
+// NOTE: The wallet won't start until the backend is synced, thus the birthday
+// block won't be set and `ErrBirthdayBlockNotSet` will be returned.
+func (w *Wallet) BirthdayBlock() (*waddrmgr.BlockStamp, error) {
+ var birthdayBlock waddrmgr.BlockStamp
- // Otherwise, retrieve the block info for the block that detected a
- // non-zero number of address matches.
- block := batch[filterResp.BatchIndex]
+ // Query the wallet's birthday block height from db.
+ err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
+ addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- // Log any non-trivial findings of addresses or outpoints.
- logFilterBlocksResp(block, filterResp)
+ bb, _, err := w.addrStore.BirthdayBlock(addrmgrNs)
+ birthdayBlock = bb
- // Report any external or internal addresses found as a result of the
- // appropriate branch recovery state. Adding indexes above the
- // last-found index of either will result in the horizons being expanded
- // upon the next iteration. Any found addresses are also marked used
- // using the scoped key manager.
- err = extendFoundAddresses(ns, filterResp, scopedMgrs, recoveryState)
- if err != nil {
return err
- }
-
- // Update the global set of watched outpoints with any that were found
- // in the block.
- for outPoint, addr := range filterResp.FoundOutPoints {
- outPoint := outPoint
- recoveryState.AddWatchedOutPoint(&outPoint, addr)
- }
-
- // Finally, record all of the relevant transactions that were returned
- // in the filter blocks response. This ensures that these transactions
- // and their outputs are tracked when the final rescan is performed.
- for _, txn := range filterResp.RelevantTxns {
- txRecord, err := wtxmgr.NewTxRecordFromMsgTx(
- txn, filterResp.BlockMeta.Time,
- )
- if err != nil {
- return err
- }
-
- err = w.addRelevantTx(tx, txRecord, &filterResp.BlockMeta)
- if err != nil {
- return err
- }
- }
-
- // Update the batch to indicate that we've processed all block through
- // the one that returned found addresses.
- batch = batch[filterResp.BatchIndex+1:]
-
- // If this was not the last block in the batch, we will repeat the
- // filtering process again after expanding our horizons.
- if len(batch) > 0 {
- goto expandHorizons
- }
-
- return nil
-}
-
-// expandScopeHorizons ensures that the ScopeRecoveryState has an adequately
-// sized look ahead for both its internal and external branches. The keys
-// derived here are added to the scope's recovery state, but do not affect the
-// persistent state of the wallet. If any invalid child keys are detected, the
-// horizon will be properly extended such that our lookahead always includes the
-// proper number of valid child keys.
-func expandScopeHorizons(ns walletdb.ReadWriteBucket,
- scopedMgr *waddrmgr.ScopedKeyManager,
- scopeState *ScopeRecoveryState) error {
-
- // Compute the current external horizon and the number of addresses we
- // must derive to ensure we maintain a sufficient recovery window for
- // the external branch.
- exHorizon, exWindow := scopeState.ExternalBranch.ExtendHorizon()
- count, childIndex := uint32(0), exHorizon
- for count < exWindow {
- keyPath := externalKeyPath(childIndex)
- addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
- switch {
- case err == hdkeychain.ErrInvalidChild:
- // Record the existence of an invalid child with the
- // external branch's recovery state. This also
- // increments the branch's horizon so that it accounts
- // for this skipped child index.
- scopeState.ExternalBranch.MarkInvalidChild(childIndex)
- childIndex++
- continue
-
- case err != nil:
- return err
- }
-
- // Register the newly generated external address and child index
- // with the external branch recovery state.
- scopeState.ExternalBranch.AddAddr(childIndex, addr.Address())
-
- childIndex++
- count++
- }
-
- // Compute the current internal horizon and the number of addresses we
- // must derive to ensure we maintain a sufficient recovery window for
- // the internal branch.
- inHorizon, inWindow := scopeState.InternalBranch.ExtendHorizon()
- count, childIndex = 0, inHorizon
- for count < inWindow {
- keyPath := internalKeyPath(childIndex)
- addr, err := scopedMgr.DeriveFromKeyPath(ns, keyPath)
- switch {
- case err == hdkeychain.ErrInvalidChild:
- // Record the existence of an invalid child with the
- // internal branch's recovery state. This also
- // increments the branch's horizon so that it accounts
- // for this skipped child index.
- scopeState.InternalBranch.MarkInvalidChild(childIndex)
- childIndex++
- continue
-
- case err != nil:
- return err
- }
-
- // Register the newly generated internal address and child index
- // with the internal branch recovery state.
- scopeState.InternalBranch.AddAddr(childIndex, addr.Address())
-
- childIndex++
- count++
- }
-
- return nil
-}
-
-// externalKeyPath returns the relative external derivation path /0/0/index.
-func externalKeyPath(index uint32) waddrmgr.DerivationPath {
- return waddrmgr.DerivationPath{
- InternalAccount: waddrmgr.DefaultAccountNum,
- Account: waddrmgr.DefaultAccountNum,
- Branch: waddrmgr.ExternalBranch,
- Index: index,
- }
-}
-
-// internalKeyPath returns the relative internal derivation path /0/1/index.
-func internalKeyPath(index uint32) waddrmgr.DerivationPath {
- return waddrmgr.DerivationPath{
- InternalAccount: waddrmgr.DefaultAccountNum,
- Account: waddrmgr.DefaultAccountNum,
- Branch: waddrmgr.InternalBranch,
- Index: index,
- }
-}
-
-// newFilterBlocksRequest constructs FilterBlocksRequests using our current
-// block range, scoped managers, and recovery state.
-func newFilterBlocksRequest(batch []wtxmgr.BlockMeta,
- scopedMgrs map[waddrmgr.KeyScope]*waddrmgr.ScopedKeyManager,
- recoveryState *RecoveryState) *chain.FilterBlocksRequest {
-
- filterReq := &chain.FilterBlocksRequest{
- Blocks: batch,
- ExternalAddrs: make(map[waddrmgr.ScopedIndex]address.Address),
- InternalAddrs: make(map[waddrmgr.ScopedIndex]address.Address),
- WatchedOutPoints: recoveryState.WatchedOutPoints(),
- }
-
- // Populate the external and internal addresses by merging the addresses
- // sets belong to all currently tracked scopes.
- for scope := range scopedMgrs {
- scopeState := recoveryState.StateForScope(scope)
- for index, addr := range scopeState.ExternalBranch.Addrs() {
- scopedIndex := waddrmgr.ScopedIndex{
- Scope: scope,
- Index: index,
- }
- filterReq.ExternalAddrs[scopedIndex] = addr
- }
- for index, addr := range scopeState.InternalBranch.Addrs() {
- scopedIndex := waddrmgr.ScopedIndex{
- Scope: scope,
- Index: index,
- }
- filterReq.InternalAddrs[scopedIndex] = addr
- }
- }
-
- return filterReq
-}
-
-// extendFoundAddresses accepts a filter blocks response that contains addresses
-// found on chain, and advances the state of all relevant derivation paths to
-// match the highest found child index for each branch.
-func extendFoundAddresses(ns walletdb.ReadWriteBucket,
- filterResp *chain.FilterBlocksResponse,
- scopedMgrs map[waddrmgr.KeyScope]*waddrmgr.ScopedKeyManager,
- recoveryState *RecoveryState) error {
-
- // Mark all recovered external addresses as used. This will be done only
- // for scopes that reported a non-zero number of external addresses in
- // this block.
- for scope, indexes := range filterResp.FoundExternalAddrs {
- // First, report all external child indexes found for this
- // scope. This ensures that the external last-found index will
- // be updated to include the maximum child index seen thus far.
- scopeState := recoveryState.StateForScope(scope)
- for index := range indexes {
- scopeState.ExternalBranch.ReportFound(index)
- }
-
- scopedMgr := scopedMgrs[scope]
-
- // Now, with all found addresses reported, derive and extend all
- // external addresses up to and including the current last found
- // index for this scope.
- exNextUnfound := scopeState.ExternalBranch.NextUnfound()
-
- exLastFound := exNextUnfound
- if exLastFound > 0 {
- exLastFound--
- }
-
- err := scopedMgr.ExtendExternalAddresses(
- ns, waddrmgr.DefaultAccountNum, exLastFound,
- )
- if err != nil {
- return err
- }
-
- // Finally, with the scope's addresses extended, we mark used
- // the external addresses that were found in the block and
- // belong to this scope.
- for index := range indexes {
- addr := scopeState.ExternalBranch.GetAddr(index)
- err := scopedMgr.MarkUsed(ns, addr)
- if err != nil {
- return err
- }
- }
- }
-
- // Mark all recovered internal addresses as used. This will be done only
- // for scopes that reported a non-zero number of internal addresses in
- // this block.
- for scope, indexes := range filterResp.FoundInternalAddrs {
- // First, report all internal child indexes found for this
- // scope. This ensures that the internal last-found index will
- // be updated to include the maximum child index seen thus far.
- scopeState := recoveryState.StateForScope(scope)
- for index := range indexes {
- scopeState.InternalBranch.ReportFound(index)
- }
-
- scopedMgr := scopedMgrs[scope]
-
- // Now, with all found addresses reported, derive and extend all
- // internal addresses up to and including the current last found
- // index for this scope.
- inNextUnfound := scopeState.InternalBranch.NextUnfound()
-
- inLastFound := inNextUnfound
- if inLastFound > 0 {
- inLastFound--
- }
- err := scopedMgr.ExtendInternalAddresses(
- ns, waddrmgr.DefaultAccountNum, inLastFound,
- )
- if err != nil {
- return err
- }
-
- // Finally, with the scope's addresses extended, we mark used
- // the internal addresses that were found in the blockand belong
- // to this scope.
- for index := range indexes {
- addr := scopeState.InternalBranch.GetAddr(index)
- err := scopedMgr.MarkUsed(ns, addr)
- if err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-// logFilterBlocksResp provides useful logging information when filtering
-// succeeded in finding relevant transactions.
-func logFilterBlocksResp(block wtxmgr.BlockMeta,
- resp *chain.FilterBlocksResponse) {
-
- // Log the number of external addresses found in this block.
- var nFoundExternal int
- for _, indexes := range resp.FoundExternalAddrs {
- nFoundExternal += len(indexes)
- }
- if nFoundExternal > 0 {
- log.Infof("Recovered %d external addrs at height=%d hash=%v",
- nFoundExternal, block.Height, block.Hash)
- }
-
- // Log the number of internal addresses found in this block.
- var nFoundInternal int
- for _, indexes := range resp.FoundInternalAddrs {
- nFoundInternal += len(indexes)
- }
- if nFoundInternal > 0 {
- log.Infof("Recovered %d internal addrs at height=%d hash=%v",
- nFoundInternal, block.Height, block.Hash)
- }
-
- // Log the number of outpoints found in this block.
- nFoundOutPoints := len(resp.FoundOutPoints)
- if nFoundOutPoints > 0 {
- log.Infof("Found %d spends from watched outpoints at "+
- "height=%d hash=%v",
- nFoundOutPoints, block.Height, block.Hash)
- }
-}
-
-type (
- createTxRequest struct {
- coinSelectKeyScope *waddrmgr.KeyScope
- changeKeyScope *waddrmgr.KeyScope
- account uint32
- outputs []*wire.TxOut
- minconf int32
- feeSatPerKB btcutil.Amount
- coinSelectionStrategy CoinSelectionStrategy
- dryRun bool
- resp chan createTxResponse
- selectUtxos []wire.OutPoint
- allowUtxo func(wtxmgr.Credit) bool
- }
- createTxResponse struct {
- tx *txauthor.AuthoredTx
- err error
- }
-)
-
-// txCreator is responsible for the input selection and creation of
-// transactions. These functions are the responsibility of this method
-// (designed to be run as its own goroutine) since input selection must be
-// serialized, or else it is possible to create double spends by choosing the
-// same inputs for multiple transactions. Along with input selection, this
-// method is also responsible for the signing of transactions, since we don't
-// want to end up in a situation where we run out of inputs as multiple
-// transactions are being created. In this situation, it would then be possible
-// for both requests, rather than just one, to fail due to not enough available
-// inputs.
-func (w *Wallet) txCreator() {
- quit := w.quitChan()
-out:
- for {
- select {
- case txr := <-w.createTxRequests:
- // If the wallet can be locked because it contains
- // private key material, we need to prevent it from
- // doing so while we are assembling the transaction.
- release := func() {}
- if !w.Manager.WatchOnly() {
- heldUnlock, err := w.holdUnlock()
- if err != nil {
- txr.resp <- createTxResponse{nil, err}
- continue
- }
-
- release = heldUnlock.release
- }
-
- tx, err := w.txToOutputs(
- txr.outputs, txr.coinSelectKeyScope,
- txr.changeKeyScope, txr.account, txr.minconf,
- txr.feeSatPerKB, txr.coinSelectionStrategy,
- txr.dryRun, txr.selectUtxos, txr.allowUtxo,
- )
-
- release()
- txr.resp <- createTxResponse{tx, err}
- case <-quit:
- break out
- }
- }
- w.wg.Done()
-}
-
-// txCreateOptions is a set of optional arguments to modify the tx creation
-// process. This can be used to do things like use a custom coin selection
-// scope, which otherwise will default to the specified coin selection scope.
-type txCreateOptions struct {
- changeKeyScope *waddrmgr.KeyScope
- selectUtxos []wire.OutPoint
- allowUtxo func(wtxmgr.Credit) bool
-}
-
-// TxCreateOption is a set of optional arguments to modify the tx creation
-// process. This can be used to do things like use a custom coin selection
-// scope, which otherwise will default to the specified coin selection scope.
-type TxCreateOption func(*txCreateOptions)
-
-// defaultTxCreateOptions is the default set of options.
-func defaultTxCreateOptions() *txCreateOptions {
- return &txCreateOptions{}
-}
-
-// WithCustomChangeScope can be used to specify a change scope for the change
-// address. If unspecified, then the same scope will be used for both inputs
-// and the change addr. Not specifying any scope at all (nil) will use all
-// available coins and the default change scope (P2TR).
-func WithCustomChangeScope(changeScope *waddrmgr.KeyScope) TxCreateOption {
- return func(opts *txCreateOptions) {
- opts.changeKeyScope = changeScope
- }
-}
-
-// WithCustomSelectUtxos is used to specify the inputs to be used while
-// creating txns.
-func WithCustomSelectUtxos(utxos []wire.OutPoint) TxCreateOption {
- return func(opts *txCreateOptions) {
- opts.selectUtxos = utxos
- }
-}
-
-// WithUtxoFilter is used to restrict the selection of the internal wallet
-// inputs by further external conditions. Utxos which pass the filter are
-// considered when creating the transaction.
-func WithUtxoFilter(allowUtxo func(utxo wtxmgr.Credit) bool) TxCreateOption {
- return func(opts *txCreateOptions) {
- opts.allowUtxo = allowUtxo
- }
-}
-
-// CreateSimpleTx creates a new signed transaction spending unspent outputs with
-// at least minconf confirmations spending to any number of address/amount
-// pairs. Only unspent outputs belonging to the given key scope and account will
-// be selected, unless a key scope is not specified. In that case, inputs from all
-// accounts may be selected, no matter what key scope they belong to. This is
-// done to handle the default account case, where a user wants to fund a PSBT
-// with inputs regardless of their type (NP2WKH, P2WKH, etc.). Change and an
-// appropriate transaction fee are automatically included, if necessary. All
-// transaction creation through this function is serialized to prevent the
-// creation of many transactions which spend the same outputs.
-//
-// A set of functional options can be passed in to apply modifications to the
-// tx creation process such as using a custom change scope, which otherwise
-// defaults to the same as the specified coin selection scope.
-//
-// NOTE: The dryRun argument can be set true to create a tx that doesn't alter
-// the database. A tx created with this set to true SHOULD NOT be broadcast.
-func (w *Wallet) CreateSimpleTx(coinSelectKeyScope *waddrmgr.KeyScope,
- account uint32, outputs []*wire.TxOut, minconf int32,
- satPerKb btcutil.Amount, coinSelectionStrategy CoinSelectionStrategy,
- dryRun bool, optFuncs ...TxCreateOption) (*txauthor.AuthoredTx, error) {
-
- opts := defaultTxCreateOptions()
- for _, optFunc := range optFuncs {
- optFunc(opts)
- }
-
- // If the change scope isn't set, then it should be the same as the
- // coin selection scope in order to match existing behavior.
- if opts.changeKeyScope == nil {
- opts.changeKeyScope = coinSelectKeyScope
- }
-
- req := createTxRequest{
- coinSelectKeyScope: coinSelectKeyScope,
- changeKeyScope: opts.changeKeyScope,
- account: account,
- outputs: outputs,
- minconf: minconf,
- feeSatPerKB: satPerKb,
- coinSelectionStrategy: coinSelectionStrategy,
- dryRun: dryRun,
- resp: make(chan createTxResponse),
- selectUtxos: opts.selectUtxos,
- allowUtxo: opts.allowUtxo,
- }
- w.createTxRequests <- req
- resp := <-req.resp
- return resp.tx, resp.err
-}
-
-type (
- unlockRequest struct {
- passphrase []byte
- lockAfter <-chan time.Time // nil prevents the timeout.
- err chan error
- }
-
- changePassphraseRequest struct {
- old, new []byte
- private bool
- err chan error
- }
-
- changePassphrasesRequest struct {
- publicOld, publicNew []byte
- privateOld, privateNew []byte
- err chan error
- }
-
- // heldUnlock is a tool to prevent the wallet from automatically
- // locking after some timeout before an operation which needed
- // the unlocked wallet has finished. Any acquired heldUnlock
- // *must* be released (preferably with a defer) or the wallet
- // will forever remain unlocked.
- heldUnlock chan struct{}
-)
-
-// endRecovery tells (*Wallet).recovery to stop, if running, and returns a
-// channel that will be closed when the recovery routine exits.
-func (w *Wallet) endRecovery() <-chan struct{} {
- if recoverySyncI := w.recovering.Load(); recoverySyncI != nil {
- recoverySync := recoverySyncI.(*recoverySyncer)
-
- // If recovery is still running, it will end early with an error
- // once we set the quit flag.
- atomic.StoreUint32(&recoverySync.quit, 1)
-
- return recoverySync.done
- }
- c := make(chan struct{})
- close(c)
- return c
-}
-
-// walletLocker manages the locked/unlocked state of a wallet.
-func (w *Wallet) walletLocker() {
- var timeout <-chan time.Time
- holdChan := make(heldUnlock)
- quit := w.quitChan()
-out:
- for {
- select {
- case req := <-w.unlockRequests:
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- return w.Manager.Unlock(addrmgrNs, req.passphrase)
- })
- if err != nil {
- req.err <- err
- continue
- }
- timeout = req.lockAfter
- if timeout == nil {
- log.Info("The wallet has been unlocked without a time limit")
- } else {
- log.Info("The wallet has been temporarily unlocked")
- }
- req.err <- nil
- continue
-
- case req := <-w.changePassphrase:
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- return w.Manager.ChangePassphrase(
- addrmgrNs, req.old, req.new, req.private,
- &waddrmgr.DefaultScryptOptions,
- )
- })
- req.err <- err
- continue
-
- case req := <-w.changePassphrases:
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- err := w.Manager.ChangePassphrase(
- addrmgrNs, req.publicOld, req.publicNew,
- false, &waddrmgr.DefaultScryptOptions,
- )
- if err != nil {
- return err
- }
-
- return w.Manager.ChangePassphrase(
- addrmgrNs, req.privateOld, req.privateNew,
- true, &waddrmgr.DefaultScryptOptions,
- )
- })
- req.err <- err
- continue
-
- case req := <-w.holdUnlockRequests:
- if w.Manager.IsLocked() {
- close(req)
- continue
- }
-
- req <- holdChan
- <-holdChan // Block until the lock is released.
-
- // If, after holding onto the unlocked wallet for some
- // time, the timeout has expired, lock it now instead
- // of hoping it gets unlocked next time the top level
- // select runs.
- select {
- case <-timeout:
- // Let the top level select fallthrough so the
- // wallet is locked.
- default:
- continue
- }
-
- case w.lockState <- w.Manager.IsLocked():
- continue
-
- case <-quit:
- break out
-
- case <-w.lockRequests:
- case <-timeout:
- }
-
- // Select statement fell through by an explicit lock or the
- // timer expiring. Lock the manager here.
-
- // We can't lock the manager if recovery is active because we use
- // cryptoKeyPriv and cryptoKeyScript in recovery.
- <-w.endRecovery()
-
- timeout = nil
- err := w.Manager.Lock()
- if err != nil && !waddrmgr.IsError(err, waddrmgr.ErrLocked) {
- log.Errorf("Could not lock wallet: %v", err)
- } else {
- log.Info("The wallet has been locked")
- }
- }
- w.wg.Done()
-}
-
-// Unlock unlocks the wallet's address manager and relocks it after timeout has
-// expired. If the wallet is already unlocked and the new passphrase is
-// correct, the current timeout is replaced with the new one. The wallet will
-// be locked if the passphrase is incorrect or any other error occurs during the
-// unlock.
-func (w *Wallet) Unlock(passphrase []byte, lock <-chan time.Time) error {
- err := make(chan error, 1)
- w.unlockRequests <- unlockRequest{
- passphrase: passphrase,
- lockAfter: lock,
- err: err,
- }
- return <-err
-}
-
-// Lock locks the wallet's address manager.
-func (w *Wallet) Lock() {
- w.lockRequests <- struct{}{}
-}
-
-// Locked returns whether the account manager for a wallet is locked.
-func (w *Wallet) Locked() bool {
- return <-w.lockState
-}
-
-// holdUnlock prevents the wallet from being locked. The heldUnlock object
-// *must* be released, or the wallet will forever remain unlocked.
-//
-// TODO: To prevent the above scenario, perhaps closures should be passed
-// to the walletLocker goroutine and disallow callers from explicitly
-// handling the locking mechanism.
-func (w *Wallet) holdUnlock() (heldUnlock, error) {
- req := make(chan heldUnlock)
- w.holdUnlockRequests <- req
- hl, ok := <-req
- if !ok {
- // TODO(davec): This should be defined and exported from
- // waddrmgr.
- return nil, waddrmgr.ManagerError{
- ErrorCode: waddrmgr.ErrLocked,
- Description: "address manager is locked",
- }
- }
- return hl, nil
-}
-
-// release releases the hold on the unlocked-state of the wallet and allows the
-// wallet to be locked again. If a lock timeout has already expired, the
-// wallet is locked again as soon as release is called.
-func (c heldUnlock) release() {
- c <- struct{}{}
-}
-
-// ChangePrivatePassphrase attempts to change the passphrase for a wallet from
-// old to new. Changing the passphrase is synchronized with all other address
-// manager locking and unlocking. The lock state will be the same as it was
-// before the password change.
-func (w *Wallet) ChangePrivatePassphrase(old, new []byte) error {
- err := make(chan error, 1)
- w.changePassphrase <- changePassphraseRequest{
- old: old,
- new: new,
- private: true,
- err: err,
- }
- return <-err
-}
-
-// ChangePublicPassphrase modifies the public passphrase of the wallet.
-func (w *Wallet) ChangePublicPassphrase(old, new []byte) error {
- err := make(chan error, 1)
- w.changePassphrase <- changePassphraseRequest{
- old: old,
- new: new,
- private: false,
- err: err,
- }
- return <-err
-}
-
-// ChangePassphrases modifies the public and private passphrase of the wallet
-// atomically.
-func (w *Wallet) ChangePassphrases(publicOld, publicNew, privateOld,
- privateNew []byte) error {
-
- err := make(chan error, 1)
- w.changePassphrases <- changePassphrasesRequest{
- publicOld: publicOld,
- publicNew: publicNew,
- privateOld: privateOld,
- privateNew: privateNew,
- err: err,
- }
- return <-err
-}
-
-// AccountAddresses returns the addresses for every created address for an
-// account.
-func (w *Wallet) AccountAddresses(account uint32) ([]address.Address, error) {
- var addrs []address.Address
-
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
-
- return w.Manager.ForEachAccountAddress(
- addrmgrNs, account, func(maddr waddrmgr.ManagedAddress) error {
- addrs = append(addrs, maddr.Address())
- return nil
- },
- )
- })
- if err != nil {
- return nil, err
- }
-
- return addrs, nil
-}
-
-// AccountManagedAddresses returns the managed addresses for every created
-// address for an account.
-func (w *Wallet) AccountManagedAddresses(scope waddrmgr.KeyScope,
- accountNum uint32) ([]waddrmgr.ManagedAddress, error) {
-
- scopedMgr, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- addrs := make([]waddrmgr.ManagedAddress, 0)
-
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
-
- return scopedMgr.ForEachAccountAddress(
- addrmgrNs, accountNum,
- func(a waddrmgr.ManagedAddress) error {
- addrs = append(addrs, a)
-
- return nil
- },
- )
- },
- )
- if err != nil {
- return nil, err
- }
-
- return addrs, nil
-}
-
-// CalculateBalance sums the amounts of all unspent transaction
-// outputs to addresses of a wallet and returns the balance.
-//
-// If confirmations is 0, all UTXOs, even those not present in a
-// block (height -1), will be used to get the balance. Otherwise,
-// a UTXO must be in a block. If confirmations is 1 or greater,
-// the balance will be calculated based on how many how many blocks
-// include a UTXO.
-func (w *Wallet) CalculateBalance(confirms int32) (btcutil.Amount, error) {
- var balance btcutil.Amount
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
- var err error
- blk := w.Manager.SyncedTo()
- balance, err = w.TxStore.Balance(txmgrNs, confirms, blk.Height)
- return err
- })
- return balance, err
-}
-
-// Balances records total, spendable (by policy), and immature coinbase
-// reward balance amounts.
-type Balances struct {
- Total btcutil.Amount
- Spendable btcutil.Amount
- ImmatureReward btcutil.Amount
-}
-
-// CalculateAccountBalances sums the amounts of all unspent transaction
-// outputs to the given account of a wallet and returns the balance.
-//
-// This function is much slower than it needs to be since transactions outputs
-// are not indexed by the accounts they credit to, and all unspent transaction
-// outputs must be iterated.
-func (w *Wallet) CalculateAccountBalances(account uint32,
- confirms int32) (Balances, error) {
-
- var bals Balances
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- // Get current block. The block height used for calculating
- // the number of tx confirmations.
- syncBlock := w.Manager.SyncedTo()
-
- unspent, err := w.TxStore.UnspentOutputs(txmgrNs)
- if err != nil {
- return err
- }
- for i := range unspent {
- output := &unspent[i]
-
- var outputAcct uint32
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- output.PkScript, w.chainParams)
- if err == nil && len(addrs) > 0 {
- _, outputAcct, err = w.Manager.AddrAccount(addrmgrNs, addrs[0])
- }
- if err != nil || outputAcct != account {
- continue
- }
-
- bals.Total += output.Amount
- if output.FromCoinBase && !hasMinConfs(
- int32(w.chainParams.CoinbaseMaturity),
- output.Height, syncBlock.Height,
- ) {
-
- bals.ImmatureReward += output.Amount
- } else if hasMinConfs(
- confirms, output.Height, syncBlock.Height,
- ) {
-
- bals.Spendable += output.Amount
- }
- }
- return nil
- })
- return bals, err
-}
-
-// CurrentAddress gets the most recently requested Bitcoin payment address
-// from a wallet for a particular key-chain scope. If the address has already
-// been used (there is at least one transaction spending to it in the
-// blockchain or btcd mempool), the next chained address is returned.
-func (w *Wallet) CurrentAddress(account uint32,
- scope waddrmgr.KeyScope) (address.Address, error) {
-
- chainClient, err := w.requireChainClient()
- if err != nil {
- return nil, err
- }
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- // The address manager uses OnCommit on the walletdb tx to update the
- // in-memory state of the account state. But because the commit happens
- // _after_ the account manager internal lock has been released, there
- // is a chance for the address index to be accessed concurrently, even
- // though the closure in OnCommit re-acquires the lock. To avoid this
- // issue, we surround the whole address creation process with a lock.
- w.newAddrMtx.Lock()
- defer w.newAddrMtx.Unlock()
-
- var (
- addr address.Address
- props *waddrmgr.AccountProperties
- )
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- maddr, err := manager.LastExternalAddress(addrmgrNs, account)
- if err != nil {
- // If no address exists yet, create the first external
- // address.
- if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
- addr, props, err = w.newAddress(
- addrmgrNs, account, scope,
- )
- }
- return err
- }
-
- // Get next chained address if the last one has already been
- // used.
- if maddr.Used(addrmgrNs) {
- addr, props, err = w.newAddress(
- addrmgrNs, account, scope,
- )
- return err
- }
-
- addr = maddr.Address()
- return nil
- })
- if err != nil {
- return nil, err
- }
-
- // If the props have been initially, then we had to create a new address
- // to satisfy the query. Notify the rpc server about the new address.
- if props != nil {
- err = chainClient.NotifyReceived([]address.Address{addr})
- if err != nil {
- return nil, err
- }
-
- w.NtfnServer.notifyAccountProperties(props)
- }
-
- return addr, nil
-}
-
-// PubKeyForAddress looks up the associated public key for a P2PKH address.
-func (w *Wallet) PubKeyForAddress(a address.Address) (*btcec.PublicKey, error) {
- var pubKey *btcec.PublicKey
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- managedAddr, err := w.Manager.Address(addrmgrNs, a)
- if err != nil {
- return err
- }
- managedPubKeyAddr, ok := managedAddr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return errors.New("address does not have an associated public key")
- }
- pubKey = managedPubKeyAddr.PubKey()
- return nil
- })
- return pubKey, err
-}
-
-// LabelTransaction adds a label to the transaction with the hash provided. The
-// call will fail if the label is too long, or if the transaction already has
-// a label and the overwrite boolean is not set.
-func (w *Wallet) LabelTransaction(hash chainhash.Hash, label string,
- overwrite bool) error {
-
- // Check that the transaction is known to the wallet, and fail if it is
- // unknown. If the transaction is known, check whether it already has
- // a label.
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- dbTx, err := w.TxStore.TxDetails(txmgrNs, &hash)
- if err != nil {
- return err
- }
-
- // If the transaction looked up is nil, it was not found. We
- // do not allow labelling of unknown transactions so we fail.
- if dbTx == nil {
- return ErrUnknownTransaction
- }
-
- _, err = wtxmgr.FetchTxLabel(txmgrNs, hash)
- return err
- })
-
- switch err {
- // If no labels have been written yet, we can silence the error.
- // Likewise if there is no label, we do not need to do any overwrite
- // checks.
- case wtxmgr.ErrNoLabelBucket:
- case wtxmgr.ErrTxLabelNotFound:
-
- // If we successfully looked up a label, fail if the overwrite param
- // is not set.
- case nil:
- if !overwrite {
- return ErrTxLabelExists
- }
-
- // In another unrelated error occurred, return it.
- default:
- return err
- }
-
- return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- txmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
- return w.TxStore.PutTxLabel(txmgrNs, hash, label)
- })
-}
-
-// PrivKeyForAddress looks up the associated private key for a P2PKH or P2PK
-// address.
-func (w *Wallet) PrivKeyForAddress(
- a address.Address) (*btcec.PrivateKey, error) {
-
- var privKey *btcec.PrivateKey
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- managedAddr, err := w.Manager.Address(addrmgrNs, a)
- if err != nil {
- return err
- }
- managedPubKeyAddr, ok := managedAddr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return errors.New("address does not have an associated private key")
- }
- privKey, err = managedPubKeyAddr.PrivKey()
- return err
- })
- return privKey, err
-}
-
-// HaveAddress returns whether the wallet is the owner of the address a.
-func (w *Wallet) HaveAddress(a address.Address) (bool, error) {
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- _, err := w.Manager.Address(addrmgrNs, a)
- return err
- })
- if err == nil {
- return true, nil
- }
- if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
- return false, nil
- }
- return false, err
-}
-
-// AccountOfAddress finds the account that an address is associated with.
-func (w *Wallet) AccountOfAddress(a address.Address) (uint32, error) {
- var account uint32
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- _, account, err = w.Manager.AddrAccount(addrmgrNs, a)
- return err
- })
- return account, err
-}
-
-// AddressInfo returns detailed information regarding a wallet address.
-func (w *Wallet) AddressInfo(
- a address.Address) (waddrmgr.ManagedAddress, error) {
-
- var managedAddress waddrmgr.ManagedAddress
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- managedAddress, err = w.Manager.Address(addrmgrNs, a)
- return err
- })
- return managedAddress, err
-}
-
-// AccountNumber returns the account number for an account name under a
-// particular key scope.
-func (w *Wallet) AccountNumber(scope waddrmgr.KeyScope,
- accountName string) (uint32, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return 0, err
- }
-
- var account uint32
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- account, err = manager.LookupAccount(addrmgrNs, accountName)
- return err
- })
- return account, err
-}
-
-// AccountName returns the name of an account.
-func (w *Wallet) AccountName(scope waddrmgr.KeyScope,
- accountNumber uint32) (string, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return "", err
- }
-
- var accountName string
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- accountName, err = manager.AccountName(addrmgrNs, accountNumber)
- return err
- })
- return accountName, err
-}
-
-// AccountProperties returns the properties of an account, including address
-// indexes and name. It first fetches the desynced information from the address
-// manager, then updates the indexes based on the address pools.
-func (w *Wallet) AccountProperties(scope waddrmgr.KeyScope,
- acct uint32) (*waddrmgr.AccountProperties, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- var props *waddrmgr.AccountProperties
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- waddrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- props, err = manager.AccountProperties(waddrmgrNs, acct)
- return err
- })
- return props, err
-}
-
-// AccountPropertiesByName returns the properties of an account by its name. It
-// first fetches the desynced information from the address manager, then updates
-// the indexes based on the address pools.
-func (w *Wallet) AccountPropertiesByName(scope waddrmgr.KeyScope,
- name string) (*waddrmgr.AccountProperties, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- var props *waddrmgr.AccountProperties
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- waddrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- acct, err := manager.LookupAccount(waddrmgrNs, name)
- if err != nil {
- return err
- }
- props, err = manager.AccountProperties(waddrmgrNs, acct)
- return err
- })
- return props, err
-}
-
-// LookupAccount returns the corresponding key scope and account number for the
-// account with the given name.
-func (w *Wallet) LookupAccount(name string) (waddrmgr.KeyScope, uint32, error) {
- var (
- keyScope waddrmgr.KeyScope
- account uint32
- )
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- ns := tx.ReadBucket(waddrmgrNamespaceKey)
- var err error
- keyScope, account, err = w.Manager.LookupAccount(ns, name)
- return err
- })
- return keyScope, account, err
-}
-
-// RenameAccount sets the name for an account number to newName.
-func (w *Wallet) RenameAccount(scope waddrmgr.KeyScope, account uint32,
- newName string) error {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return err
- }
-
- var props *waddrmgr.AccountProperties
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- err := manager.RenameAccount(addrmgrNs, account, newName)
- if err != nil {
- return err
- }
- props, err = manager.AccountProperties(addrmgrNs, account)
- return err
- })
- if err == nil {
- w.NtfnServer.notifyAccountProperties(props)
- }
- return err
-}
-
-// NextAccount creates the next account and returns its account number. The
-// name must be unique to the account. In order to support automatic seed
-// restoring, new accounts may not be created when all of the previous 100
-// accounts have no transaction history (this is a deviation from the BIP0044
-// spec, which allows no unused account gaps).
-func (w *Wallet) NextAccount(scope waddrmgr.KeyScope,
- name string) (uint32, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return 0, err
- }
-
- var (
- account uint32
- props *waddrmgr.AccountProperties
- )
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- var err error
- account, err = manager.NewAccount(addrmgrNs, name)
- if err != nil {
- return err
- }
- props, err = manager.AccountProperties(addrmgrNs, account)
- return err
- })
- if err != nil {
- log.Errorf("Cannot fetch new account properties for notification "+
- "after account creation: %v", err)
- } else {
- w.NtfnServer.notifyAccountProperties(props)
- }
- return account, err
-}
-
-// CreditCategory describes the type of wallet transaction output. The category
-// of "sent transactions" (debits) is always "send", and is not expressed by
-// this type.
-//
-// TODO: This is a requirement of the RPC server and should be moved.
-type CreditCategory byte
-
-// These constants define the possible credit categories.
-const (
- CreditReceive CreditCategory = iota
- CreditGenerate
- CreditImmature
-)
-
-// String returns the category as a string. This string may be used as the
-// JSON string for categories as part of listtransactions and gettransaction
-// RPC responses.
-func (c CreditCategory) String() string {
- switch c {
- case CreditReceive:
- return "receive"
- case CreditGenerate:
- return "generate"
- case CreditImmature:
- return "immature"
- default:
- return "unknown"
- }
-}
-
-// RecvCategory returns the category of received credit outputs from a
-// transaction record. The passed block chain height is used to distinguish
-// immature from mature coinbase outputs.
-//
-// TODO: This is intended for use by the RPC server and should be moved out of
-// this package at a later time.
-func RecvCategory(details *wtxmgr.TxDetails, syncHeight int32,
- net *chaincfg.Params) CreditCategory {
-
- if blockchain.IsCoinBaseTx(&details.MsgTx) {
- if hasMinConfs(
- int32(net.CoinbaseMaturity), details.Block.Height,
- syncHeight,
- ) {
-
- return CreditGenerate
- }
- return CreditImmature
- }
- return CreditReceive
-}
-
-// listTransactions creates a object that may be marshalled to a response result
-// for a listtransactions RPC.
-//
-// TODO: This should be moved to the legacyrpc package.
-//
-//nolint:cyclop,gocognit
-func listTransactions(tx walletdb.ReadTx, details *wtxmgr.TxDetails,
- addrMgr *waddrmgr.Manager, syncHeight int32,
- net *chaincfg.Params) []btcjson.ListTransactionsResult {
-
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
-
- var (
- blockHashStr string
- blockTime int64
- confirmations int64
- )
- if details.Block.Height != -1 {
- blockHashStr = details.Block.Hash.String()
- blockTime = details.Block.Time.Unix()
- confirmations = int64(
- calcConf(details.Block.Height, syncHeight),
- )
- }
-
- results := []btcjson.ListTransactionsResult{}
- txHashStr := details.Hash.String()
- received := details.Received.Unix()
- generated := blockchain.IsCoinBaseTx(&details.MsgTx)
- recvCat := RecvCategory(details, syncHeight, net).String()
-
- send := len(details.Debits) != 0
-
- // Fee can only be determined if every input is a debit.
- var feeF64 float64
- if len(details.Debits) == len(details.MsgTx.TxIn) {
- var debitTotal btcutil.Amount
- for _, deb := range details.Debits {
- debitTotal += deb.Amount
- }
- var outputTotal btcutil.Amount
- for _, output := range details.MsgTx.TxOut {
- outputTotal += btcutil.Amount(output.Value)
- }
- // Note: The actual fee is debitTotal - outputTotal. However,
- // this RPC reports negative numbers for fees, so the inverse
- // is calculated.
- feeF64 = (outputTotal - debitTotal).ToBTC()
- }
-
-outputs:
- for i, output := range details.MsgTx.TxOut {
- // Determine if this output is a credit, and if so, determine
- // its spentness.
- var isCredit bool
- var spentCredit bool
- for _, cred := range details.Credits {
- if cred.Index == uint32(i) {
- // Change outputs are ignored.
- if cred.Change {
- continue outputs
- }
-
- isCredit = true
- spentCredit = cred.Spent
- break
- }
- }
-
- var address string
- var accountName string
- _, addrs, _, _ := txscript.ExtractPkScriptAddrs(output.PkScript, net)
- if len(addrs) == 1 {
- addr := addrs[0]
- address = addr.EncodeAddress()
- mgr, account, err := addrMgr.AddrAccount(addrmgrNs, addrs[0])
- if err == nil {
- accountName, err = mgr.AccountName(addrmgrNs, account)
- if err != nil {
- accountName = ""
- }
- }
- }
-
- amountF64 := btcutil.Amount(output.Value).ToBTC()
- result := btcjson.ListTransactionsResult{
- // Fields left zeroed:
- // InvolvesWatchOnly
- // BlockIndex
- //
- // Fields set below:
- // Account (only for non-"send" categories)
- // Category
- // Amount
- // Fee
- Address: address,
- Vout: uint32(i),
- Confirmations: confirmations,
- Generated: generated,
- BlockHash: blockHashStr,
- BlockTime: blockTime,
- TxID: txHashStr,
- WalletConflicts: []string{},
- Time: received,
- TimeReceived: received,
- }
-
- // Add a received/generated/immature result if this is a credit.
- // If the output was spent, create a second result under the
- // send category with the inverse of the output amount. It is
- // therefore possible that a single output may be included in
- // the results set zero, one, or two times.
- //
- // Since credits are not saved for outputs that are not
- // controlled by this wallet, all non-credits from transactions
- // with debits are grouped under the send category.
-
- if send || spentCredit {
- result.Category = "send"
- result.Amount = -amountF64
- result.Fee = &feeF64
- results = append(results, result)
- }
- if isCredit {
- result.Account = accountName
- result.Category = recvCat
- result.Amount = amountF64
- result.Fee = nil
- results = append(results, result)
- }
- }
- return results
-}
-
-// ListSinceBlock returns a slice of objects with details about transactions
-// since the given block. If the block is -1 then all transactions are included.
-// This is intended to be used for listsinceblock RPC replies.
-func (w *Wallet) ListSinceBlock(start, end,
- syncHeight int32) ([]btcjson.ListTransactionsResult, error) {
-
- txList := []btcjson.ListTransactionsResult{}
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- for _, detail := range details {
- detail := detail
-
- jsonResults := listTransactions(
- tx, &detail, w.Manager, syncHeight,
- w.chainParams,
- )
- txList = append(txList, jsonResults...)
- }
- return false, nil
- }
-
- return w.TxStore.RangeTransactions(txmgrNs, start, end, rangeFn)
- })
- return txList, err
-}
-
-// ListTransactions returns a slice of objects with details about a recorded
-// transaction. This is intended to be used for listtransactions RPC
-// replies.
-func (w *Wallet) ListTransactions(from,
- count int) ([]btcjson.ListTransactionsResult, error) {
-
- txList := []btcjson.ListTransactionsResult{}
-
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- // Get current block. The block height used for calculating
- // the number of tx confirmations.
- syncBlock := w.Manager.SyncedTo()
-
- // Need to skip the first from transactions, and after those, only
- // include the next count transactions.
- skipped := 0
- n := 0
-
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- // Iterate over transactions at this height in reverse order.
- // This does nothing for unmined transactions, which are
- // unsorted, but it will process mined transactions in the
- // reverse order they were marked mined.
- for i := len(details) - 1; i >= 0; i-- {
- if from > skipped {
- skipped++
- continue
- }
-
- n++
- if n > count {
- return true, nil
- }
-
- jsonResults := listTransactions(tx, &details[i],
- w.Manager, syncBlock.Height, w.chainParams)
- txList = append(txList, jsonResults...)
-
- if len(jsonResults) > 0 {
- n++
- }
- }
-
- return false, nil
- }
-
- // Return newer results first by starting at mempool height and working
- // down to the genesis block.
- return w.TxStore.RangeTransactions(txmgrNs, -1, 0, rangeFn)
- })
- return txList, err
-}
-
-// ListAddressTransactions returns a slice of objects with details about
-// recorded transactions to or from any address belonging to a set. This is
-// intended to be used for listaddresstransactions RPC replies.
-func (w *Wallet) ListAddressTransactions(
- pkHashes map[string]struct{}) ([]btcjson.ListTransactionsResult, error) {
-
- txList := []btcjson.ListTransactionsResult{}
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- // Get current block. The block height used for calculating
- // the number of tx confirmations.
- syncBlock := w.Manager.SyncedTo()
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- loopDetails:
- for i := range details {
- detail := &details[i]
-
- for _, cred := range detail.Credits {
- pkScript := detail.MsgTx.TxOut[cred.Index].PkScript
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- pkScript, w.chainParams)
- if err != nil || len(addrs) != 1 {
- continue
- }
-
- apkh, ok := addrs[0].(*address.AddressPubKeyHash)
- if !ok {
- continue
- }
- _, ok = pkHashes[string(apkh.ScriptAddress())]
- if !ok {
- continue
- }
-
- jsonResults := listTransactions(tx, detail,
- w.Manager, syncBlock.Height, w.chainParams)
- txList = append(txList, jsonResults...)
- continue loopDetails
- }
- }
- return false, nil
- }
-
- return w.TxStore.RangeTransactions(txmgrNs, 0, -1, rangeFn)
- })
- return txList, err
-}
-
-// ListAllTransactions returns a slice of objects with details about a recorded
-// transaction. This is intended to be used for listalltransactions RPC
-// replies.
-func (w *Wallet) ListAllTransactions() ([]btcjson.ListTransactionsResult,
- error) {
-
- txList := []btcjson.ListTransactionsResult{}
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- // Get current block. The block height used for calculating
- // the number of tx confirmations.
- syncBlock := w.Manager.SyncedTo()
-
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- // Iterate over transactions at this height in reverse order.
- // This does nothing for unmined transactions, which are
- // unsorted, but it will process mined transactions in the
- // reverse order they were marked mined.
- for i := len(details) - 1; i >= 0; i-- {
- jsonResults := listTransactions(tx, &details[i], w.Manager,
- syncBlock.Height, w.chainParams)
- txList = append(txList, jsonResults...)
- }
- return false, nil
- }
-
- // Return newer results first by starting at mempool height and
- // working down to the genesis block.
- return w.TxStore.RangeTransactions(txmgrNs, -1, 0, rangeFn)
- })
- return txList, err
-}
-
-// BlockIdentifier identifies a block by either a height or a hash.
-type BlockIdentifier struct {
- height int32
- hash *chainhash.Hash
-}
-
-// NewBlockIdentifierFromHeight constructs a BlockIdentifier for a block height.
-func NewBlockIdentifierFromHeight(height int32) *BlockIdentifier {
- return &BlockIdentifier{height: height}
-}
-
-// NewBlockIdentifierFromHash constructs a BlockIdentifier for a block hash.
-func NewBlockIdentifierFromHash(hash *chainhash.Hash) *BlockIdentifier {
- return &BlockIdentifier{hash: hash}
-}
-
-// GetTransactionsResult is the result of the wallet's GetTransactions method.
-// See GetTransactions for more details.
-type GetTransactionsResult struct {
- MinedTransactions []Block
- UnminedTransactions []TransactionSummary
-}
-
-// GetTransactions returns transaction results between a starting and ending
-// block. Blocks in the block range may be specified by either a height or a
-// hash.
-//
-// Because this is a possibly lenghtly operation, a cancel channel is provided
-// to cancel the task. If this channel unblocks, the results created thus far
-// will be returned.
-//
-// Transaction results are organized by blocks in ascending order and unmined
-// transactions in an unspecified order. Mined transactions are saved in a
-// Block structure which records properties about the block.
-func (w *Wallet) GetTransactions(startBlock, endBlock *BlockIdentifier,
- _ string, cancel <-chan struct{}) (*GetTransactionsResult, error) {
-
- var start, end int32 = 0, -1
-
- w.chainClientLock.Lock()
- chainClient := w.chainClient
- w.chainClientLock.Unlock()
-
- // TODO: Fetching block heights by their hashes is inherently racy
- // because not all block headers are saved but when they are for SPV the
- // db can be queried directly without this.
- if startBlock != nil {
- if startBlock.hash == nil {
- start = startBlock.height
- } else {
- if chainClient == nil {
- return nil, errors.New("no chain server client")
- }
- switch client := chainClient.(type) {
- case *chain.RPCClient:
- startHeader, err := client.GetBlockHeaderVerbose(
- startBlock.hash,
- )
- if err != nil {
- return nil, err
- }
- start = startHeader.Height
- case *chain.BitcoindClient:
- var err error
- start, err = client.GetBlockHeight(startBlock.hash)
- if err != nil {
- return nil, err
- }
- case *chain.NeutrinoClient:
- var err error
- start, err = client.GetBlockHeight(startBlock.hash)
- if err != nil {
- return nil, err
- }
- }
- }
- }
- if endBlock != nil {
- if endBlock.hash == nil {
- end = endBlock.height
- } else {
- if chainClient == nil {
- return nil, errors.New("no chain server client")
- }
- switch client := chainClient.(type) {
- case *chain.RPCClient:
- endHeader, err := client.GetBlockHeaderVerbose(
- endBlock.hash,
- )
- if err != nil {
- return nil, err
- }
- end = endHeader.Height
- case *chain.BitcoindClient:
- var err error
- start, err = client.GetBlockHeight(endBlock.hash)
- if err != nil {
- return nil, err
- }
- case *chain.NeutrinoClient:
- var err error
- end, err = client.GetBlockHeight(endBlock.hash)
- if err != nil {
- return nil, err
- }
- }
- }
- }
-
- var res GetTransactionsResult
- err := walletdb.View(w.db, func(dbtx walletdb.ReadTx) error {
- txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
-
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- // TODO: probably should make RangeTransactions not reuse the
- // details backing array memory.
- dets := make([]wtxmgr.TxDetails, len(details))
- copy(dets, details)
- details = dets
-
- txs := make([]TransactionSummary, 0, len(details))
- for i := range details {
- txs = append(txs, makeTxSummary(dbtx, w, &details[i]))
- }
-
- if details[0].Block.Height != -1 {
- blockHash := details[0].Block.Hash
- res.MinedTransactions = append(res.MinedTransactions, Block{
- Hash: &blockHash,
- Height: details[0].Block.Height,
- Timestamp: details[0].Block.Time.Unix(),
- Transactions: txs,
- })
- } else {
- res.UnminedTransactions = txs
- }
-
- select {
- case <-cancel:
- return true, nil
- default:
- return false, nil
- }
- }
-
- return w.TxStore.RangeTransactions(txmgrNs, start, end, rangeFn)
- })
- return &res, err
-}
-
-// GetTransactionResult returns a summary of the transaction along with
-// other block properties.
-type GetTransactionResult struct {
- Summary TransactionSummary
- Height int32
- BlockHash *chainhash.Hash
- Confirmations int32
- Timestamp int64
-}
-
-// GetTransaction returns detailed data of a transaction given its id. In
-// addition it returns properties about its block.
-func (w *Wallet) GetTransaction(txHash chainhash.Hash) (*GetTransactionResult,
- error) {
-
- var res GetTransactionResult
- err := walletdb.View(w.db, func(dbtx walletdb.ReadTx) error {
- txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
-
- txDetail, err := w.TxStore.TxDetails(txmgrNs, &txHash)
- if err != nil {
- return err
- }
-
- // If the transaction was not found we return an error.
- if txDetail == nil {
- return fmt.Errorf("%w: txid %v", ErrNoTx, txHash)
- }
-
- res = GetTransactionResult{
- Summary: makeTxSummary(dbtx, w, txDetail),
- BlockHash: nil,
- Height: -1,
- Confirmations: 0,
- Timestamp: 0,
- }
-
- // If it is a confirmed transaction we set the corresponding
- // block height, timestamp, hash, and confirmations.
- if txDetail.Block.Height != -1 {
- res.Height = txDetail.Block.Height
- res.Timestamp = txDetail.Block.Time.Unix()
- res.BlockHash = &txDetail.Block.Hash
-
- bestBlock := w.SyncedTo()
- blockHeight := txDetail.Block.Height
- res.Confirmations = calcConf(
- blockHeight, bestBlock.Height,
- )
- }
-
- return nil
- })
- if err != nil {
- return nil, err
- }
- return &res, nil
-}
-
-// AccountResult is a single account result for the AccountsResult type.
-type AccountResult struct {
- waddrmgr.AccountProperties
- TotalBalance btcutil.Amount
-}
-
-// AccountsResult is the result of the wallet's Accounts method. See that
-// method for more details.
-type AccountsResult struct {
- Accounts []AccountResult
- CurrentBlockHash *chainhash.Hash
- CurrentBlockHeight int32
-}
-
-// Accounts returns the current names, numbers, and total balances of all
-// accounts in the wallet restricted to a particular key scope. The current
-// chain tip is included in the result for atomicity reasons.
-//
-// TODO(jrick): Is the chain tip really needed, since only the total balances
-// are included?
-func (w *Wallet) Accounts(scope waddrmgr.KeyScope) (*AccountsResult, error) {
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- var (
- accounts []AccountResult
- syncBlockHash *chainhash.Hash
- syncBlockHeight int32
- )
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- syncBlock := w.Manager.SyncedTo()
- syncBlockHash = &syncBlock.Hash
- syncBlockHeight = syncBlock.Height
- unspent, err := w.TxStore.UnspentOutputs(txmgrNs)
- if err != nil {
- return err
- }
- err = manager.ForEachAccount(addrmgrNs, func(acct uint32) error {
- props, err := manager.AccountProperties(addrmgrNs, acct)
- if err != nil {
- return err
- }
- accounts = append(accounts, AccountResult{
- AccountProperties: *props,
- // TotalBalance set below
- })
- return nil
- })
- if err != nil {
- return err
- }
- m := make(map[uint32]*btcutil.Amount)
- for i := range accounts {
- a := &accounts[i]
- m[a.AccountNumber] = &a.TotalBalance
- }
- for i := range unspent {
- output := unspent[i]
- var outputAcct uint32
-
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- output.PkScript, w.chainParams,
- )
- if err == nil && len(addrs) > 0 {
- _, outputAcct, err = w.Manager.AddrAccount(addrmgrNs, addrs[0])
- }
- if err == nil {
- amt, ok := m[outputAcct]
- if ok {
- *amt += output.Amount
- }
- }
- }
- return nil
- })
- return &AccountsResult{
- Accounts: accounts,
- CurrentBlockHash: syncBlockHash,
- CurrentBlockHeight: syncBlockHeight,
- }, err
-}
-
-// AccountBalanceResult is a single result for the Wallet.AccountBalances method.
-type AccountBalanceResult struct {
- AccountNumber uint32
- AccountName string
- AccountBalance btcutil.Amount
-}
-
-// AccountBalances returns all accounts in the wallet and their balances.
-// Balances are determined by excluding transactions that have not met
-// requiredConfs confirmations.
-func (w *Wallet) AccountBalances(scope waddrmgr.KeyScope,
- requiredConfs int32) ([]AccountBalanceResult, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- var results []AccountBalanceResult
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- syncBlock := w.Manager.SyncedTo()
-
- // Fill out all account info except for the balances.
- lastAcct, err := manager.LastAccount(addrmgrNs)
- if err != nil {
- return err
- }
- results = make([]AccountBalanceResult, lastAcct+2)
- for i := range results[:len(results)-1] {
- accountName, err := manager.AccountName(addrmgrNs, uint32(i))
- if err != nil {
- return err
- }
- results[i].AccountNumber = uint32(i)
- results[i].AccountName = accountName
- }
- results[len(results)-1].AccountNumber = waddrmgr.ImportedAddrAccount
- results[len(results)-1].AccountName = waddrmgr.ImportedAddrAccountName
-
- // Fetch all unspent outputs, and iterate over them tallying each
- // account's balance where the output script pays to an account address
- // and the required number of confirmations is met.
- unspentOutputs, err := w.TxStore.UnspentOutputs(txmgrNs)
- if err != nil {
- return err
- }
- for i := range unspentOutputs {
- output := &unspentOutputs[i]
- if !hasMinConfs(
- requiredConfs, output.Height, syncBlock.Height,
- ) {
-
- continue
- }
-
- if output.FromCoinBase && !hasMinConfs(
- int32(w.ChainParams().CoinbaseMaturity),
- output.Height, syncBlock.Height,
- ) {
-
- continue
- }
-
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- output.PkScript, w.chainParams,
- )
- if err != nil || len(addrs) == 0 {
- continue
- }
- outputAcct, err := manager.AddrAccount(addrmgrNs, addrs[0])
- if err != nil {
- continue
- }
- switch {
- case outputAcct == waddrmgr.ImportedAddrAccount:
- results[len(results)-1].AccountBalance += output.Amount
- case outputAcct > lastAcct:
- return errors.New("waddrmgr.Manager.AddrAccount returned " +
- "account beyond recorded last account")
- default:
- results[outputAcct].AccountBalance += output.Amount
- }
- }
- return nil
- })
- return results, err
-}
-
-// creditSlice satisifies the sort.Interface interface to provide sorting
-// transaction credits from oldest to newest. Credits with the same receive
-// time and mined in the same block are not guaranteed to be sorted by the order
-// they appear in the block. Credits from the same transaction are sorted by
-// output index.
-type creditSlice []wtxmgr.Credit
-
-func (s creditSlice) Len() int {
- return len(s)
-}
-
-func (s creditSlice) Less(i, j int) bool {
- switch {
- // If both credits are from the same tx, sort by output index.
- case s[i].OutPoint.Hash == s[j].OutPoint.Hash:
- return s[i].OutPoint.Index < s[j].OutPoint.Index
-
- // If both transactions are unmined, sort by their received date.
- case s[i].Height == -1 && s[j].Height == -1:
- return s[i].Received.Before(s[j].Received)
-
- // Unmined (newer) txs always come last.
- case s[i].Height == -1:
- return false
- case s[j].Height == -1:
- return true
-
- // If both txs are mined in different blocks, sort by block height.
- default:
- return s[i].Height < s[j].Height
- }
-}
-
-func (s creditSlice) Swap(i, j int) {
- s[i], s[j] = s[j], s[i]
-}
-
-// ListUnspent returns a slice of objects representing the unspent wallet
-// transactions fitting the given criteria. The confirmations will be more than
-// minconf, less than maxconf and if addresses is populated only the addresses
-// contained within it will be considered. If we know nothing about a
-// transaction an empty array will be returned.
-func (w *Wallet) ListUnspent(minconf, maxconf int32,
- accountName string) ([]*btcjson.ListUnspentResult, error) {
-
- var results []*btcjson.ListUnspentResult
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- syncBlock := w.Manager.SyncedTo()
-
- filter := accountName != ""
- unspent, err := w.TxStore.UnspentOutputs(txmgrNs)
- if err != nil {
- return err
- }
- sort.Sort(sort.Reverse(creditSlice(unspent)))
-
- defaultAccountName := "default"
-
- results = make([]*btcjson.ListUnspentResult, 0, len(unspent))
- for i := range unspent {
- output := unspent[i]
-
- // Outputs with fewer confirmations than the minimum or
- // more confs than the maximum are excluded.
- confs := calcConf(output.Height, syncBlock.Height)
- if confs < minconf || confs > maxconf {
- continue
- }
-
- // Only mature coinbase outputs are included.
- if output.FromCoinBase {
- target := int32(w.ChainParams().CoinbaseMaturity)
- if !hasMinConfs(
- target, output.Height, syncBlock.Height,
- ) {
-
- continue
- }
- }
-
- // Exclude locked outputs from the result set.
- if w.LockedOutpoint(output.OutPoint) {
- continue
- }
-
- // Lookup the associated account for the output. Use the
- // default account name in case there is no associated account
- // for some reason, although this should never happen.
- //
- // This will be unnecessary once transactions and outputs are
- // grouped under the associated account in the db.
- outputAcctName := defaultAccountName
- sc, addrs, _, err := txscript.ExtractPkScriptAddrs(
- output.PkScript, w.chainParams)
- if err != nil {
- continue
- }
- if len(addrs) > 0 {
- smgr, acct, err := w.Manager.AddrAccount(addrmgrNs, addrs[0])
- if err == nil {
- s, err := smgr.AccountName(addrmgrNs, acct)
- if err == nil {
- outputAcctName = s
- }
- }
- }
-
- if filter && outputAcctName != accountName {
- continue
- }
-
- // At the moment watch-only addresses are not supported, so all
- // recorded outputs that are not multisig are "spendable".
- // Multisig outputs are only "spendable" if all keys are
- // controlled by this wallet.
- //
- // TODO: Each case will need updates when watch-only addrs
- // is added. For P2PK, P2PKH, and P2SH, the address must be
- // looked up and not be watching-only. For multisig, all
- // pubkeys must belong to the manager with the associated
- // private key (currently it only checks whether the pubkey
- // exists, since the private key is required at the moment).
- var spendable bool
- scSwitch:
- switch sc {
- case txscript.PubKeyHashTy:
- spendable = true
- case txscript.PubKeyTy:
- spendable = true
- case txscript.WitnessV0ScriptHashTy:
- spendable = true
- case txscript.WitnessV0PubKeyHashTy:
- spendable = true
- case txscript.MultiSigTy:
- for _, a := range addrs {
- _, err := w.Manager.Address(addrmgrNs, a)
- if err == nil {
- continue
- }
- if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
- break scSwitch
- }
- return err
- }
- spendable = true
- }
-
- result := &btcjson.ListUnspentResult{
- TxID: output.OutPoint.Hash.String(),
- Vout: output.OutPoint.Index,
- Account: outputAcctName,
- ScriptPubKey: hex.EncodeToString(output.PkScript),
- Amount: output.Amount.ToBTC(),
- Confirmations: int64(confs),
- Spendable: spendable,
- }
-
- // BUG: this should be a JSON array so that all
- // addresses can be included, or removed (and the
- // caller extracts addresses from the pkScript).
- if len(addrs) > 0 {
- result.Address = addrs[0].EncodeAddress()
- }
-
- results = append(results, result)
- }
- return nil
- })
- return results, err
-}
-
-// ListLeasedOutputResult is a single result for the Wallet.ListLeasedOutputs
-// method. See that method for more details.
-type ListLeasedOutputResult struct {
- *wtxmgr.LockedOutput
- Value int64
- PkScript []byte
-}
-
-// ListLeasedOutputs returns a list of objects representing the currently locked
-// utxos.
-func (w *Wallet) ListLeasedOutputs() ([]*ListLeasedOutputResult, error) {
- var results []*ListLeasedOutputResult
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- ns := tx.ReadBucket(wtxmgrNamespaceKey)
- outputs, err := w.TxStore.ListLockedOutputs(ns)
- if err != nil {
- return err
- }
-
- for _, output := range outputs {
- details, err := w.TxStore.TxDetails(ns, &output.Outpoint.Hash)
- if err != nil {
- return err
- }
-
- if details == nil {
- log.Infof("unable to find tx details for "+
- "%v:%v", output.Outpoint.Hash,
- output.Outpoint.Index)
- continue
- }
-
- txOut := details.MsgTx.TxOut[output.Outpoint.Index]
-
- result := &ListLeasedOutputResult{
- LockedOutput: output,
- Value: txOut.Value,
- PkScript: txOut.PkScript,
- }
-
- results = append(results, result)
- }
-
- return nil
- })
- return results, err
-}
-
-// DumpPrivKeys returns the WIF-encoded private keys for all addresses with
-// private keys in a wallet.
-func (w *Wallet) DumpPrivKeys() ([]string, error) {
- var privkeys []string
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- // Iterate over each active address, appending the private key to
- // privkeys.
- return w.Manager.ForEachActiveAddress(
- addrmgrNs, func(addr address.Address) error {
- ma, err := w.Manager.Address(addrmgrNs, addr)
- if err != nil {
- return err
- }
-
- // Only those addresses with keys needed.
- pka, ok := ma.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return nil
- }
-
- wif, err := pka.ExportPrivKey()
- if err != nil {
- // It would be nice to zero out the array here. However,
- // since strings in go are immutable, and we have no
- // control over the caller I don't think we can. :(
- return err
- }
-
- privkeys = append(privkeys, wif.String())
-
- return nil
- },
- )
- })
- return privkeys, err
-}
-
-// DumpWIFPrivateKey returns the WIF encoded private key for a
-// single wallet address.
-func (w *Wallet) DumpWIFPrivateKey(addr address.Address) (string, error) {
- var maddr waddrmgr.ManagedAddress
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- waddrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- // Get private key from wallet if it exists.
- var err error
- maddr, err = w.Manager.Address(waddrmgrNs, addr)
- return err
- })
- if err != nil {
- return "", err
- }
-
- pka, ok := maddr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return "", fmt.Errorf("address %s is not a key type", addr)
- }
-
- wif, err := pka.ExportPrivKey()
- if err != nil {
- return "", err
- }
- return wif.String(), nil
-}
-
-// LockedOutpoint returns whether an outpoint has been marked as locked and
-// should not be used as an input for created transactions.
-func (w *Wallet) LockedOutpoint(op wire.OutPoint) bool {
- w.lockedOutpointsMtx.Lock()
- defer w.lockedOutpointsMtx.Unlock()
-
- _, locked := w.lockedOutpoints[op]
- return locked
-}
-
-// LockOutpoint marks an outpoint as locked, that is, it should not be used as
-// an input for newly created transactions.
-func (w *Wallet) LockOutpoint(op wire.OutPoint) {
- w.lockedOutpointsMtx.Lock()
- defer w.lockedOutpointsMtx.Unlock()
-
- w.lockedOutpoints[op] = struct{}{}
-}
-
-// UnlockOutpoint marks an outpoint as unlocked, that is, it may be used as an
-// input for newly created transactions.
-func (w *Wallet) UnlockOutpoint(op wire.OutPoint) {
- w.lockedOutpointsMtx.Lock()
- defer w.lockedOutpointsMtx.Unlock()
-
- delete(w.lockedOutpoints, op)
-}
-
-// ResetLockedOutpoints resets the set of locked outpoints so all may be used
-// as inputs for new transactions.
-func (w *Wallet) ResetLockedOutpoints() {
- w.lockedOutpointsMtx.Lock()
- defer w.lockedOutpointsMtx.Unlock()
-
- w.lockedOutpoints = map[wire.OutPoint]struct{}{}
-}
-
-// LockedOutpoints returns a slice of currently locked outpoints. This is
-// intended to be used by marshaling the result as a JSON array for
-// listlockunspent RPC results.
-func (w *Wallet) LockedOutpoints() []btcjson.TransactionInput {
- w.lockedOutpointsMtx.Lock()
- defer w.lockedOutpointsMtx.Unlock()
-
- locked := make([]btcjson.TransactionInput, len(w.lockedOutpoints))
- i := 0
- for op := range w.lockedOutpoints {
- locked[i] = btcjson.TransactionInput{
- Txid: op.Hash.String(),
- Vout: op.Index,
- }
- i++
- }
- return locked
-}
-
-// LeaseOutput locks an output to the given ID, preventing it from being
-// available for coin selection. The absolute time of the lock's expiration is
-// returned. The expiration of the lock can be extended by successive
-// invocations of this call.
-//
-// Outputs can be unlocked before their expiration through `UnlockOutput`.
-// Otherwise, they are unlocked lazily through calls which iterate through all
-// known outputs, e.g., `CalculateBalance`, `ListUnspent`.
-//
-// If the output is not known, ErrUnknownOutput is returned. If the output has
-// already been locked to a different ID, then ErrOutputAlreadyLocked is
-// returned.
-//
-// NOTE: This differs from LockOutpoint in that outputs are locked for a limited
-// amount of time and their locks are persisted to disk.
-func (w *Wallet) LeaseOutput(id wtxmgr.LockID, op wire.OutPoint,
- duration time.Duration) (time.Time, error) {
-
- var expiry time.Time
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(wtxmgrNamespaceKey)
- var err error
- expiry, err = w.TxStore.LockOutput(ns, id, op, duration)
- return err
- })
- return expiry, err
-}
-
-// ReleaseOutput unlocks an output, allowing it to be available for coin
-// selection if it remains unspent. The ID should match the one used to
-// originally lock the output.
-func (w *Wallet) ReleaseOutput(id wtxmgr.LockID, op wire.OutPoint) error {
- return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(wtxmgrNamespaceKey)
- return w.TxStore.UnlockOutput(ns, id, op)
- })
-}
-
-// resendUnminedTxs iterates through all transactions that spend from wallet
-// credits that are not known to have been mined into a block, and attempts
-// to send each to the chain server for relay.
-func (w *Wallet) resendUnminedTxs() {
- var txs []*wire.MsgTx
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
- var err error
- txs, err = w.TxStore.UnminedTxs(txmgrNs)
- return err
- })
- if err != nil {
- log.Errorf("Unable to retrieve unconfirmed transactions to "+
- "resend: %v", err)
- return
- }
-
- for _, tx := range txs {
- txHash, err := w.publishTransaction(tx)
- if err != nil {
- log.Debugf("Unable to rebroadcast transaction %v: %v",
- tx.TxHash(), err)
- continue
- }
-
- log.Debugf("Successfully rebroadcast unconfirmed transaction %v",
- txHash)
- }
-}
-
-// SortedActivePaymentAddresses returns a slice of all active payment
-// addresses in a wallet.
-func (w *Wallet) SortedActivePaymentAddresses() ([]string, error) {
- var addrStrs []string
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
-
- return w.Manager.ForEachActiveAddress(
- addrmgrNs, func(addr address.Address) error {
- addrStrs = append(addrStrs, addr.EncodeAddress())
- return nil
- },
- )
- })
- if err != nil {
- return nil, err
- }
-
- sort.Strings(addrStrs)
- return addrStrs, nil
-}
-
-// NewAddress returns the next external chained address for a wallet.
-func (w *Wallet) NewAddress(account uint32,
- scope waddrmgr.KeyScope) (address.Address, error) {
-
- chainClient, err := w.requireChainClient()
- if err != nil {
- return nil, err
- }
-
- // The address manager uses OnCommit on the walletdb tx to update the
- // in-memory state of the account state. But because the commit happens
- // _after_ the account manager internal lock has been released, there
- // is a chance for the address index to be accessed concurrently, even
- // though the closure in OnCommit re-acquires the lock. To avoid this
- // issue, we surround the whole address creation process with a lock.
- w.newAddrMtx.Lock()
- defer w.newAddrMtx.Unlock()
-
- var (
- addr address.Address
- props *waddrmgr.AccountProperties
- )
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- var err error
- addr, props, err = w.newAddress(addrmgrNs, account, scope)
- return err
- })
- if err != nil {
- return nil, err
- }
-
- // Notify the rpc server about the newly created address.
- err = chainClient.NotifyReceived([]address.Address{addr})
- if err != nil {
- return nil, err
- }
-
- w.NtfnServer.notifyAccountProperties(props)
-
- return addr, nil
-}
-
-func (w *Wallet) newAddress(addrmgrNs walletdb.ReadWriteBucket, account uint32,
- scope waddrmgr.KeyScope) (address.Address, *waddrmgr.AccountProperties,
- error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, nil, err
- }
-
- // Get next address from wallet.
- addrs, err := manager.NextExternalAddresses(addrmgrNs, account, 1)
- if err != nil {
- return nil, nil, err
- }
-
- props, err := manager.AccountProperties(addrmgrNs, account)
- if err != nil {
- log.Errorf("Cannot fetch account properties for notification "+
- "after deriving next external address: %v", err)
- return nil, nil, err
- }
-
- return addrs[0].Address(), props, nil
-}
-
-// NewChangeAddress returns a new change address for a wallet.
-func (w *Wallet) NewChangeAddress(account uint32,
- scope waddrmgr.KeyScope) (address.Address, error) {
-
- chainClient, err := w.requireChainClient()
- if err != nil {
- return nil, err
- }
-
- // The address manager uses OnCommit on the walletdb tx to update the
- // in-memory state of the account state. But because the commit happens
- // _after_ the account manager internal lock has been released, there
- // is a chance for the address index to be accessed concurrently, even
- // though the closure in OnCommit re-acquires the lock. To avoid this
- // issue, we surround the whole address creation process with a lock.
- w.newAddrMtx.Lock()
- defer w.newAddrMtx.Unlock()
-
- var addr address.Address
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- var err error
- addr, err = w.newChangeAddress(addrmgrNs, account, scope)
- return err
- })
- if err != nil {
- return nil, err
- }
-
- // Notify the rpc server about the newly created address.
- err = chainClient.NotifyReceived([]address.Address{addr})
- if err != nil {
- return nil, err
- }
-
- return addr, nil
-}
-
-// newChangeAddress returns a new change address for the wallet.
-//
-// NOTE: This method requires the caller to use the backend's NotifyReceived
-// method in order to detect when an on-chain transaction pays to the address
-// being created.
-func (w *Wallet) newChangeAddress(addrmgrNs walletdb.ReadWriteBucket,
- account uint32, scope waddrmgr.KeyScope) (address.Address, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- // Get next chained change address from wallet for account.
- addrs, err := manager.NextInternalAddresses(addrmgrNs, account, 1)
- if err != nil {
- return nil, err
- }
-
- return addrs[0].Address(), nil
-}
-
-// hasMinConfs returns whether a transaction has met at least minConf
-// confirmations at the current block height.
-func hasMinConfs(minConf, txHeight, curHeight int32) bool {
- return calcConf(txHeight, curHeight) >= minConf
-}
-
-// calcConf returns the number of confirmations for a transaction given its
-// containing block height and the current best block height. Unconfirmed
-// transactions have a height of -1 and are considered to have 0 confirmations.
-func calcConf(txHeight, curHeight int32) int32 {
- switch {
- // Unconfirmed transactions have 0 confirmations.
- case txHeight == -1:
- return 0
-
- // A transaction in a block after the current best block is considered
- // unconfirmed. This can happen during a chain reorg.
- case txHeight > curHeight:
- return 0
-
- // Confirmed transactions have at least one confirmation.
- default:
- return curHeight - txHeight + 1
- }
-}
-
-// AccountTotalReceivedResult is a single result for the
-// Wallet.TotalReceivedForAccounts method.
-type AccountTotalReceivedResult struct {
- AccountNumber uint32
- AccountName string
- TotalReceived btcutil.Amount
- LastConfirmation int32
-}
-
-// TotalReceivedForAccounts iterates through a wallet's transaction history,
-// returning the total amount of Bitcoin received for all accounts.
-func (w *Wallet) TotalReceivedForAccounts(scope waddrmgr.KeyScope,
- minConf int32) ([]AccountTotalReceivedResult, error) {
-
- manager, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, err
- }
-
- var results []AccountTotalReceivedResult
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- syncBlock := w.Manager.SyncedTo()
-
- err := manager.ForEachAccount(addrmgrNs, func(account uint32) error {
- accountName, err := manager.AccountName(addrmgrNs, account)
- if err != nil {
- return err
- }
- results = append(results, AccountTotalReceivedResult{
- AccountNumber: account,
- AccountName: accountName,
- })
- return nil
- })
- if err != nil {
- return err
- }
-
- var stopHeight int32
-
- if minConf > 0 {
- stopHeight = syncBlock.Height - minConf + 1
- } else {
- stopHeight = -1
- }
-
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- for i := range details {
- detail := &details[i]
- for _, cred := range detail.Credits {
- pkScript := detail.MsgTx.TxOut[cred.Index].PkScript
- var outputAcct uint32
-
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- pkScript, w.chainParams,
- )
- if err == nil && len(addrs) > 0 {
- _, outputAcct, err = w.Manager.AddrAccount(
- addrmgrNs, addrs[0],
- )
- }
- if err == nil {
- acctIndex := int(outputAcct)
- if outputAcct == waddrmgr.ImportedAddrAccount {
- acctIndex = len(results) - 1
- }
- res := &results[acctIndex]
- res.TotalReceived += cred.Amount
-
- confs := calcConf(
- detail.Block.Height,
- syncBlock.Height,
- )
- res.LastConfirmation = confs
- }
- }
- }
- return false, nil
- }
- return w.TxStore.RangeTransactions(txmgrNs, 0, stopHeight, rangeFn)
- })
- return results, err
-}
-
-// TotalReceivedForAddr iterates through a wallet's transaction history,
-// returning the total amount of bitcoins received for a single wallet
-// address.
-func (w *Wallet) TotalReceivedForAddr(addr address.Address,
- minConf int32) (btcutil.Amount, error) {
-
- var amount btcutil.Amount
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- txmgrNs := tx.ReadBucket(wtxmgrNamespaceKey)
-
- syncBlock := w.Manager.SyncedTo()
-
- var (
- addrStr = addr.EncodeAddress()
- stopHeight int32
- )
-
- if minConf > 0 {
- stopHeight = syncBlock.Height - minConf + 1
- } else {
- stopHeight = -1
- }
- rangeFn := func(details []wtxmgr.TxDetails) (bool, error) {
- for i := range details {
- detail := &details[i]
- for _, cred := range detail.Credits {
- pkScript := detail.MsgTx.TxOut[cred.Index].PkScript
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(pkScript,
- w.chainParams)
- // An error creating addresses from the output script only
- // indicates a non-standard script, so ignore this credit.
- if err != nil {
- continue
- }
- for _, a := range addrs {
- if addrStr == a.EncodeAddress() {
- amount += cred.Amount
- break
- }
- }
- }
- }
- return false, nil
- }
- return w.TxStore.RangeTransactions(txmgrNs, 0, stopHeight, rangeFn)
- })
- return amount, err
-}
-
-// SendOutputs creates and sends payment transactions. Coin selection is
-// performed by the wallet, choosing inputs that belong to the given key scope
-// and account, unless a key scope is not specified. In that case, inputs from
-// accounts matching the account number provided across all key scopes may be
-// selected. This is done to handle the default account case, where a user wants
-// to fund a PSBT with inputs regardless of their type (NP2WKH, P2WKH, etc.). It
-// returns the transaction upon success.
-func (w *Wallet) SendOutputs(outputs []*wire.TxOut, keyScope *waddrmgr.KeyScope,
- account uint32, minconf int32, satPerKb btcutil.Amount,
- coinSelectionStrategy CoinSelectionStrategy, label string) (*wire.MsgTx,
- error) {
-
- return w.sendOutputs(
- outputs, keyScope, account, minconf, satPerKb,
- coinSelectionStrategy, label,
- )
-}
-
-// SendOutputsWithInput creates and sends payment transactions using the
-// provided selected utxos. It returns the transaction upon success.
-func (w *Wallet) SendOutputsWithInput(outputs []*wire.TxOut,
- keyScope *waddrmgr.KeyScope,
- account uint32, minconf int32, satPerKb btcutil.Amount,
- coinSelectionStrategy CoinSelectionStrategy, label string,
- selectedUtxos []wire.OutPoint) (*wire.MsgTx, error) {
-
- return w.sendOutputs(outputs, keyScope, account, minconf, satPerKb,
- coinSelectionStrategy, label, selectedUtxos...)
-}
-
-// sendOutputs creates and sends payment transactions. It returns the
-// transaction upon success.
-func (w *Wallet) sendOutputs(outputs []*wire.TxOut, keyScope *waddrmgr.KeyScope,
- account uint32, minconf int32, satPerKb btcutil.Amount,
- coinSelectionStrategy CoinSelectionStrategy, label string,
- selectedUtxos ...wire.OutPoint) (*wire.MsgTx, error) {
-
- // Ensure the outputs to be created adhere to the network's consensus
- // rules.
- for _, output := range outputs {
- err := txrules.CheckOutput(
- output, txrules.DefaultRelayFeePerKb,
- )
- if err != nil {
- return nil, err
- }
- }
-
- // Create the transaction and broadcast it to the network. The
- // transaction will be added to the database in order to ensure that we
- // continue to re-broadcast the transaction upon restarts until it has
- // been confirmed.
- createdTx, err := w.CreateSimpleTx(
- keyScope, account, outputs, minconf, satPerKb,
- coinSelectionStrategy, false, WithCustomSelectUtxos(
- selectedUtxos,
- ),
- )
- if err != nil {
- return nil, err
- }
-
- // If our wallet is read-only, we'll get a transaction with coins
- // selected but no witness data. In such a case we need to inform our
- // caller that they'll actually need to go ahead and sign the TX.
- if w.Manager.WatchOnly() {
- return createdTx.Tx, ErrTxUnsigned
- }
-
- txHash, err := w.reliablyPublishTransaction(createdTx.Tx, label)
- if err != nil {
- return nil, err
- }
-
- // Sanity check on the returned tx hash.
- if *txHash != createdTx.Tx.TxHash() {
- return nil, errors.New("tx hash mismatch")
- }
-
- return createdTx.Tx, nil
-}
-
-// SignatureError records the underlying error when validating a transaction
-// input signature.
-type SignatureError struct {
- InputIndex uint32
- Error error
-}
-
-// SignTransaction uses secrets of the wallet, as well as additional secrets
-// passed in by the caller, to create and add input signatures to a transaction.
-//
-// Transaction input script validation is used to confirm that all signatures
-// are valid. For any invalid input, a SignatureError is added to the returns.
-// The final error return is reserved for unexpected or fatal errors, such as
-// being unable to determine a previous output script to redeem.
-//
-// The transaction pointed to by tx is modified by this function.
-func (w *Wallet) SignTransaction(tx *wire.MsgTx, hashType txscript.SigHashType,
- additionalPrevScripts map[wire.OutPoint][]byte,
- additionalKeysByAddress map[string]*btcutil.WIF,
- p2shRedeemScriptsByAddress map[string][]byte) ([]SignatureError, error) {
-
- var signErrors []SignatureError
- err := walletdb.View(w.db, func(dbtx walletdb.ReadTx) error {
- addrmgrNs := dbtx.ReadBucket(waddrmgrNamespaceKey)
- txmgrNs := dbtx.ReadBucket(wtxmgrNamespaceKey)
-
- inputFetcher := txscript.NewMultiPrevOutFetcher(nil)
- for i, txIn := range tx.TxIn {
- prevOutScript, ok := additionalPrevScripts[txIn.PreviousOutPoint]
- if !ok {
- prevHash := &txIn.PreviousOutPoint.Hash
- prevIndex := txIn.PreviousOutPoint.Index
- txDetails, err := w.TxStore.TxDetails(txmgrNs, prevHash)
- if err != nil {
- return fmt.Errorf("cannot query previous transaction "+
- "details for %v: %w", txIn.PreviousOutPoint, err)
- }
- if txDetails == nil {
- return fmt.Errorf("%v not found",
- txIn.PreviousOutPoint)
- }
- prevOutScript = txDetails.MsgTx.TxOut[prevIndex].PkScript
- }
- inputFetcher.AddPrevOut(txIn.PreviousOutPoint, &wire.TxOut{
- PkScript: prevOutScript,
- })
-
- // Set up our callbacks that we pass to txscript so it can
- // look up the appropriate keys and scripts by address.
- getKey := txscript.KeyClosure(func(
- addr address.Address) (*btcec.PrivateKey, bool, error) {
-
- if len(additionalKeysByAddress) != 0 {
- addrStr := addr.EncodeAddress()
- wif, ok := additionalKeysByAddress[addrStr]
- if !ok {
- return nil, false,
- errors.New("no key for address")
- }
- return wif.PrivKey, wif.CompressPubKey, nil
- }
- address, err := w.Manager.Address(addrmgrNs, addr)
- if err != nil {
- return nil, false, err
- }
-
- pka, ok := address.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return nil, false, fmt.Errorf("address %v is not "+
- "a pubkey address", address.Address().EncodeAddress())
- }
-
- key, err := pka.PrivKey()
- if err != nil {
- return nil, false, err
- }
-
- return key, pka.Compressed(), nil
- })
- getScript := txscript.ScriptClosure(func(
- addr address.Address) ([]byte, error) {
-
- // If keys were provided then we can only use the
- // redeem scripts provided with our inputs, too.
- if len(additionalKeysByAddress) != 0 {
- addrStr := addr.EncodeAddress()
- script, ok := p2shRedeemScriptsByAddress[addrStr]
- if !ok {
- return nil, errors.New("no script for address")
- }
- return script, nil
- }
- address, err := w.Manager.Address(addrmgrNs, addr)
- if err != nil {
- return nil, err
- }
- sa, ok := address.(waddrmgr.ManagedScriptAddress)
- if !ok {
- return nil, errors.New("address is not a script" +
- " address")
- }
-
- return sa.Script()
- })
-
- // SigHashSingle inputs can only be signed if there's a
- // corresponding output. However this could be already signed,
- // so we always verify the output.
- if (hashType&txscript.SigHashSingle) !=
- txscript.SigHashSingle || i < len(tx.TxOut) {
-
- script, err := txscript.SignTxOutput(w.ChainParams(),
- tx, i, prevOutScript, hashType, getKey,
- getScript, txIn.SignatureScript)
- // Failure to sign isn't an error, it just means that
- // the tx isn't complete.
- if err != nil {
- signErrors = append(signErrors, SignatureError{
- InputIndex: uint32(i),
- Error: err,
- })
- continue
- }
- txIn.SignatureScript = script
- }
-
- // Either it was already signed or we just signed it.
- // Find out if it is completely satisfied or still needs more.
- vm, err := txscript.NewEngine(
- prevOutScript, tx, i,
- txscript.StandardVerifyFlags, nil, nil, 0,
- inputFetcher,
- )
- if err == nil {
- err = vm.Execute()
- }
- if err != nil {
- signErrors = append(signErrors, SignatureError{
- InputIndex: uint32(i),
- Error: err,
- })
- }
- }
- return nil
- })
- return signErrors, err
-}
-
-// ErrDoubleSpend is an error returned from PublishTransaction in case the
-// published transaction failed to propagate since it was double spending a
-// confirmed transaction or a transaction in the mempool.
-type ErrDoubleSpend struct {
- backendError error
-}
-
-// Error returns the string representation of ErrDoubleSpend.
-//
-// NOTE: Satisfies the error interface.
-func (e *ErrDoubleSpend) Error() string {
- return fmt.Sprintf("double spend: %v", e.backendError)
-}
-
-// Unwrap returns the underlying error returned from the backend.
-func (e *ErrDoubleSpend) Unwrap() error {
- return e.backendError
-}
-
-// ErrMempoolFee is an error returned from PublishTransaction in case the
-// published transaction failed to propagate since it did not match the
-// current mempool fee requirement.
-type ErrMempoolFee struct {
- backendError error
-}
-
-// Error returns the string representation of ErrMempoolFee.
-//
-// NOTE: Satisfies the error interface.
-func (e *ErrMempoolFee) Error() string {
- return fmt.Sprintf("mempool fee not met: %v", e.backendError)
-}
-
-// Unwrap returns the underlying error returned from the backend.
-func (e *ErrMempoolFee) Unwrap() error {
- return e.backendError
-}
-
-// ErrAlreadyConfirmed is an error returned from PublishTransaction in case
-// a transaction is already confirmed in the blockchain.
-type ErrAlreadyConfirmed struct {
- backendError error
-}
-
-// Error returns the string representation of ErrAlreadyConfirmed.
-//
-// NOTE: Satisfies the error interface.
-func (e *ErrAlreadyConfirmed) Error() string {
- return fmt.Sprintf("tx already confirmed: %v", e.backendError)
-}
-
-// Unwrap returns the underlying error returned from the backend.
-func (e *ErrAlreadyConfirmed) Unwrap() error {
- return e.backendError
-}
-
-// ErrInMempool is an error returned from PublishTransaction in case a
-// transaction is already in the mempool.
-type ErrInMempool struct {
- backendError error
-}
-
-// Error returns the string representation of ErrInMempool.
-//
-// NOTE: Satisfies the error interface.
-func (e *ErrInMempool) Error() string {
- return fmt.Sprintf("tx already in mempool: %v", e.backendError)
-}
-
-// Unwrap returns the underlying error returned from the backend.
-func (e *ErrInMempool) Unwrap() error {
- return e.backendError
-}
-
-// PublishTransaction sends the transaction to the consensus RPC server so it
-// can be propagated to other nodes and eventually mined.
-//
-// This function is unstable and will be removed once syncing code is moved out
-// of the wallet.
-func (w *Wallet) PublishTransaction(tx *wire.MsgTx, label string) error {
- _, err := w.reliablyPublishTransaction(tx, label)
- return err
-}
-
-// reliablyPublishTransaction is a superset of publishTransaction which contains
-// the primary logic required for publishing a transaction, updating the
-// relevant database state, and finally possible removing the transaction from
-// the database (along with cleaning up all inputs used, and outputs created) if
-// the transaction is rejected by the backend.
-func (w *Wallet) reliablyPublishTransaction(tx *wire.MsgTx,
- label string) (*chainhash.Hash, error) {
-
- chainClient, err := w.requireChainClient()
- if err != nil {
- return nil, err
- }
-
- // As we aim for this to be general reliable transaction broadcast API,
- // we'll write this tx to disk as an unconfirmed transaction. This way,
- // upon restarts, we'll always rebroadcast it, and also add it to our
- // set of records.
- txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
- if err != nil {
- return nil, err
- }
-
- // Along the way, we'll extract our relevant destination addresses from
- // the transaction.
- var ourAddrs []address.Address
- err = walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
- addrmgrNs := dbTx.ReadWriteBucket(waddrmgrNamespaceKey)
- for _, txOut := range tx.TxOut {
- _, addrs, _, err := txscript.ExtractPkScriptAddrs(
- txOut.PkScript, w.chainParams,
- )
- if err != nil {
- // Non-standard outputs can safely be skipped
- // because they're not supported by the wallet.
- log.Warnf("Non-standard pkScript=%x in tx=%v",
- txOut.PkScript, tx.TxHash())
-
- continue
- }
- for _, addr := range addrs {
- // Skip any addresses which are not relevant to
- // us.
- _, err := w.Manager.Address(addrmgrNs, addr)
- if waddrmgr.IsError(err, waddrmgr.ErrAddressNotFound) {
- continue
- }
- if err != nil {
- return err
- }
- ourAddrs = append(ourAddrs, addr)
- }
- }
-
- // If there is a label we should write, get the namespace key
- // and record it in the tx store.
- if len(label) != 0 {
- txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
-
- err = w.TxStore.PutTxLabel(
- txmgrNs, tx.TxHash(), label,
- )
- if err != nil {
- return err
- }
- }
-
- return w.addRelevantTx(dbTx, txRec, nil)
- })
- if err != nil {
- return nil, err
- }
-
- // We'll also ask to be notified of the transaction once it confirms
- // on-chain. This is done outside of the database transaction to prevent
- // backend interaction within it.
- if err := chainClient.NotifyReceived(ourAddrs); err != nil {
- return nil, err
- }
-
- return w.publishTransaction(tx)
-}
-
-// publishTransaction attempts to send an unconfirmed transaction to the
-// wallet's current backend. In the event that sending the transaction fails for
-// whatever reason, it will be removed from the wallet's unconfirmed transaction
-// store.
-func (w *Wallet) publishTransaction(tx *wire.MsgTx) (*chainhash.Hash, error) {
- chainClient, err := w.requireChainClient()
- if err != nil {
- return nil, err
- }
-
- txid := tx.TxHash()
- _, rpcErr := chainClient.SendRawTransaction(tx, false)
- if rpcErr == nil {
- return &txid, nil
- }
-
- switch {
- case errors.Is(rpcErr, chain.ErrTxAlreadyInMempool):
- log.Infof("%v: tx already in mempool", txid)
- return &txid, nil
-
- case errors.Is(rpcErr, chain.ErrTxAlreadyKnown),
- errors.Is(rpcErr, chain.ErrTxAlreadyConfirmed):
-
- dbErr := walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
- txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
- txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
- if err != nil {
- return err
- }
- return w.TxStore.RemoveUnminedTx(txmgrNs, txRec)
- })
- if dbErr != nil {
- log.Warnf("Unable to remove confirmed transaction %v "+
- "from unconfirmed store: %v", tx.TxHash(), dbErr)
- }
-
- log.Infof("%v: tx already confirmed", txid)
-
- return &txid, nil
-
- }
-
- // Log the causing error, even if we know how to handle it.
- log.Infof("%v: broadcast failed because of: %v", txid, rpcErr)
-
- // If the transaction was rejected for whatever other reason, then
- // we'll remove it from the transaction store, as otherwise, we'll
- // attempt to continually re-broadcast it, and the UTXO state of the
- // wallet won't be accurate.
- dbErr := walletdb.Update(w.db, func(dbTx walletdb.ReadWriteTx) error {
- txmgrNs := dbTx.ReadWriteBucket(wtxmgrNamespaceKey)
- txRec, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
- if err != nil {
- return err
- }
- return w.TxStore.RemoveUnminedTx(txmgrNs, txRec)
- })
- if dbErr != nil {
- log.Warnf("Unable to remove invalid transaction %v: %v",
- tx.TxHash(), dbErr)
- } else {
- log.Infof("Removed invalid transaction: %v", tx.TxHash())
-
- // The serialized transaction is for logging only, don't fail
- // on the error.
- var txRaw bytes.Buffer
- _ = tx.Serialize(&txRaw)
-
- // Optionally log the tx in debug when the size is manageable.
- if txRaw.Len() < 1_000_000 {
- log.Debugf("Removed invalid transaction: %v \n hex=%x",
- newLogClosure(func() string {
- return spew.Sdump(tx)
- }), txRaw.Bytes())
- } else {
- log.Debug("Removed invalid transaction due to size " +
- "too large")
- }
- }
-
- return nil, rpcErr
-}
-
-// ChainParams returns the network parameters for the blockchain the wallet
-// belongs to.
-func (w *Wallet) ChainParams() *chaincfg.Params {
- return w.chainParams
-}
-
-// Database returns the underlying walletdb database. This method is provided
-// in order to allow applications wrapping btcwallet to store app-specific data
-// with the wallet's database.
-func (w *Wallet) Database() walletdb.DB {
- return w.db
-}
-
-// RemoveDescendants attempts to remove any transaction from the wallet's tx
-// store (that may be unconfirmed) that spends outputs created by the passed
-// transaction. This remove propagates recursively down the chain of descendent
-// transactions.
-func (w *Wallet) RemoveDescendants(tx *wire.MsgTx) error {
- txRecord, err := wtxmgr.NewTxRecordFromMsgTx(tx, time.Now())
- if err != nil {
- return err
- }
-
- return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- wtxmgrNs := tx.ReadWriteBucket(wtxmgrNamespaceKey)
-
- return w.TxStore.RemoveUnminedTx(wtxmgrNs, txRecord)
- })
-}
-
-// BirthdayBlock returns the birthday block of the wallet.
-//
-// NOTE: The wallet won't start until the backend is synced, thus the birthday
-// block won't be set and `ErrBirthdayBlockNotSet` will be returned.
-func (w *Wallet) BirthdayBlock() (*waddrmgr.BlockStamp, error) {
- var birthdayBlock waddrmgr.BlockStamp
-
- // Query the wallet's birthday block height from db.
- err := walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
-
- bb, _, err := w.Manager.BirthdayBlock(addrmgrNs)
- birthdayBlock = bb
-
- return err
- })
- if err != nil {
- return nil, err
+ })
+ if err != nil {
+ return nil, err
}
return &birthdayBlock, nil
}
-// AddScopeManager creates a new scoped key manager from the root manager.
-func (w *Wallet) AddScopeManager(scope waddrmgr.KeyScope,
- addrSchema waddrmgr.ScopeAddrSchema) (
- *waddrmgr.ScopedKeyManager, error) {
-
- var scopedManager *waddrmgr.ScopedKeyManager
-
- err := walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- manager, err := w.Manager.NewScopedKeyManager(
- addrmgrNs, scope, addrSchema,
- )
- scopedManager = manager
-
- return err
- })
- if err != nil {
- return nil, err
- }
-
- return scopedManager, nil
-}
-
-// InitAccounts creates a number of accounts specified by `num`, with account
-// number ranges from 1 to `num`.
-func (w *Wallet) InitAccounts(scope *waddrmgr.ScopedKeyManager,
- watchOnly bool, num uint32) error {
-
- return walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- // Generate all accounts that we could ever need. This includes
- // all key families.
- for account := uint32(1); account <= num; account++ {
- // Otherwise, we'll check if the account already exists,
- // if so, we can once again bail early.
- _, err := scope.AccountName(addrmgrNs, account)
- if err == nil {
- continue
- }
-
- // If we reach this point, then the account hasn't yet
- // been created, so we'll need to create it before we
- // can proceed.
- err = scope.NewRawAccount(addrmgrNs, account)
- if err != nil {
- return err
- }
- }
-
- // If this is the first startup with remote signing and wallet
- // migration turned on and the wallet wasn't previously
- // migrated, we can do that now that we made sure all accounts
- // that we need were derived correctly.
- if watchOnly {
- log.Infof("Migrating wallet to watch-only mode, " +
- "purging all private key material")
-
- ns := tx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- return w.Manager.ConvertToWatchingOnly(ns)
- }
-
- return nil
- })
-}
-
-// DeriveFromKeyPath derives a private key using the given derivation path.
-func (w *Wallet) DeriveFromKeyPath(scope waddrmgr.KeyScope,
- path waddrmgr.DerivationPath) (*btcec.PrivateKey, error) {
-
- scopedMgr, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, fmt.Errorf("error fetching manager for scope %v: "+
- "%w", scope, err)
- }
-
- // Let's see if we can hit the private key cache.
- privKey, err := scopedMgr.DeriveFromKeyPathCache(path)
- if err == nil {
- return privKey, nil
- }
-
- // The key wasn't in the cache, let's fully derive it now.
- err = walletdb.View(w.db, func(tx walletdb.ReadTx) error {
- addrmgrNs := tx.ReadBucket(waddrmgrNamespaceKey)
-
- addr, err := scopedMgr.DeriveFromKeyPath(addrmgrNs, path)
- if err != nil {
- return fmt.Errorf("error deriving private key: %w", err)
- }
-
- mpka, ok := addr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- err := fmt.Errorf("managed address type for %v is "+
- "`%T` but want waddrmgr.ManagedPubKeyAddress",
- addr, addr)
-
- return err
- }
- privKey, err = mpka.PrivKey()
-
- return err
- })
- if err != nil {
- return nil, err
- }
-
- return privKey, nil
-}
-
-// DeriveFromKeyPathAddAccount derives a private key using the given derivation
-// path. The account will be created if it doesn't exist.
-func (w *Wallet) DeriveFromKeyPathAddAccount(scope waddrmgr.KeyScope,
- path waddrmgr.DerivationPath) (*btcec.PrivateKey, error) {
-
- scopedMgr, err := w.Manager.FetchScopedKeyManager(scope)
- if err != nil {
- return nil, fmt.Errorf("error fetching manager for scope %v: "+
- "%w", scope, err)
- }
-
- // Let's see if we can hit the private key cache.
- privKey, err := scopedMgr.DeriveFromKeyPathCache(path)
- if err == nil {
- return privKey, nil
- }
-
- derivePrivKey := func(addrmgrNs walletdb.ReadWriteBucket) error {
- addr, err := scopedMgr.DeriveFromKeyPath(addrmgrNs, path)
-
- // Exit early if there's no error.
- if err == nil {
- key, ok := addr.(waddrmgr.ManagedPubKeyAddress)
- if !ok {
- return nil
- }
-
- // Overwrite the returned private key variable.
- privKey, err = key.PrivKey()
-
- return err
- }
-
- return err
- }
-
- // The key wasn't in the cache, let's fully derive it now.
- err = walletdb.Update(w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
-
- err := derivePrivKey(addrmgrNs)
-
- // Exit early if there's no error.
- if err == nil {
- return nil
- }
-
- // Exit with the error if it's not account not found.
- if !waddrmgr.IsError(err, waddrmgr.ErrAccountNotFound) {
- return fmt.Errorf("error deriving private key: %w", err)
- }
-
- // If we've reached this point, then the account doesn't yet
- // exist, so we'll create it now to ensure we can sign.
- err = scopedMgr.NewRawAccount(addrmgrNs, path.Account)
- if err != nil {
- return err
- }
-
- // Now that we know the account exists, we'll attempt to
- // re-derive the private key.
- return derivePrivKey(addrmgrNs)
- })
- if err != nil {
- return nil, err
- }
-
- return privKey, nil
-}
-
// SyncedTo calls the `SyncedTo` method on the wallet's manager.
func (w *Wallet) SyncedTo() waddrmgr.BlockStamp {
- return w.Manager.SyncedTo()
+ return w.addrStore.SyncedTo()
}
// AddrManager returns the internal address manager.
//
// TODO(yy): Refactor it in lnd and remove the method.
-func (w *Wallet) AddrManager() *waddrmgr.Manager {
- return w.Manager
+func (w *Wallet) AddrManager() waddrmgr.AddrStore {
+ return w.addrStore
}
// NotificationServer returns the internal NotificationServer.
@@ -4337,18 +506,6 @@ func CreateWatchingOnlyWithCallback(db walletdb.DB, pubPass []byte,
)
}
-// Create creates an new wallet, writing it to an empty database. If the passed
-// root key is non-nil, it is used. Otherwise, a secure random seed of the
-// recommended length is generated.
-func Create(db walletdb.DB, pubPass, privPass []byte,
- rootKey *hdkeychain.ExtendedKey, params *chaincfg.Params,
- birthday time.Time) error {
-
- return create(
- db, pubPass, privPass, rootKey, params, birthday, false, nil,
- )
-}
-
// CreateWatchingOnly creates an new watch-only wallet, writing it to
// an empty database. No root key can be provided as this wallet will be
// watching only. Likewise no private passphrase may be provided
@@ -4421,94 +578,3 @@ func create(db walletdb.DB, pubPass, privPass []byte,
return nil
})
}
-
-// Open loads an already-created wallet from the passed database and namespaces.
-func Open(db walletdb.DB, pubPass []byte, cbs *waddrmgr.OpenCallbacks,
- params *chaincfg.Params, recoveryWindow uint32) (*Wallet, error) {
-
- return OpenWithRetry(
- db, pubPass, cbs, params, recoveryWindow,
- defaultSyncRetryInterval,
- )
-}
-
-// OpenWithRetry loads an already-created wallet from the passed database and
-// namespaces and re-tries on errors during initial sync.
-func OpenWithRetry(db walletdb.DB, pubPass []byte, cbs *waddrmgr.OpenCallbacks,
- params *chaincfg.Params, recoveryWindow uint32,
- syncRetryInterval time.Duration) (*Wallet, error) {
-
- var (
- addrMgr *waddrmgr.Manager
- txMgr *wtxmgr.Store
- )
-
- // Before attempting to open the wallet, we'll check if there are any
- // database upgrades for us to proceed. We'll also create our references
- // to the address and transaction managers, as they are backed by the
- // database.
- err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
- addrMgrBucket := tx.ReadWriteBucket(waddrmgrNamespaceKey)
- if addrMgrBucket == nil {
- return errors.New("missing address manager namespace")
- }
- txMgrBucket := tx.ReadWriteBucket(wtxmgrNamespaceKey)
- if txMgrBucket == nil {
- return errors.New("missing transaction manager namespace")
- }
-
- addrMgrUpgrader := waddrmgr.NewMigrationManager(addrMgrBucket)
- txMgrUpgrader := wtxmgr.NewMigrationManager(txMgrBucket)
- err := migration.Upgrade(txMgrUpgrader, addrMgrUpgrader)
- if err != nil {
- return err
- }
-
- addrMgr, err = waddrmgr.Open(addrMgrBucket, pubPass, params)
- if err != nil {
- return err
- }
- txMgr, err = wtxmgr.Open(txMgrBucket, params)
- if err != nil {
- return err
- }
-
- return nil
- })
- if err != nil {
- return nil, err
- }
-
- log.Infof("Opened wallet") // TODO: log balance? last sync height?
-
- w := &Wallet{
- publicPassphrase: pubPass,
- db: db,
- Manager: addrMgr,
- TxStore: txMgr,
- lockedOutpoints: map[wire.OutPoint]struct{}{},
- recoveryWindow: recoveryWindow,
- rescanAddJob: make(chan *RescanJob),
- rescanBatch: make(chan *rescanBatch),
- rescanNotifications: make(chan interface{}),
- rescanProgress: make(chan *RescanProgressMsg),
- rescanFinished: make(chan *RescanFinishedMsg),
- createTxRequests: make(chan createTxRequest),
- unlockRequests: make(chan unlockRequest),
- lockRequests: make(chan struct{}),
- holdUnlockRequests: make(chan chan heldUnlock),
- lockState: make(chan bool),
- changePassphrase: make(chan changePassphraseRequest),
- changePassphrases: make(chan changePassphrasesRequest),
- chainParams: params,
- quit: make(chan struct{}),
- syncRetryInterval: syncRetryInterval,
- }
-
- w.NtfnServer = newNotificationServer(w)
- w.TxStore.NotifyUnspent = func(hash *chainhash.Hash, index uint32) {
- w.NtfnServer.notifyUnspentOutput(0, hash, index)
- }
-
- return w, nil
-}
diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go
index a12527b794..ebfa96434d 100644
--- a/wallet/wallet_test.go
+++ b/wallet/wallet_test.go
@@ -2,23 +2,24 @@ package wallet
import (
"encoding/hex"
+ "errors"
"fmt"
- "math"
- "strings"
- "sync"
- "sync/atomic"
"testing"
"time"
- "github.com/btcsuite/btcd/address/v2"
"github.com/btcsuite/btcd/btcutil/v2"
"github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/wire/v2"
- "github.com/btcsuite/btcwallet/waddrmgr"
- "github.com/btcsuite/btcwallet/walletdb"
"github.com/btcsuite/btcwallet/wtxmgr"
"github.com/stretchr/testify/require"
- "golang.org/x/sync/errgroup"
+)
+
+var (
+ // errBlockNotFound is an error returned when a block is not found.
+ errBlockNotFound = errors.New("block not found")
+
+ // errHeaderNotFound is an error returned when a header is not found.
+ errHeaderNotFound = errors.New("header not found")
)
var (
@@ -36,6 +37,174 @@ var (
}
)
+// TestConfigValidate ensures that the Config.validate method correctly
+// identifies missing required parameters.
+func TestConfigValidate(t *testing.T) {
+ t.Parallel()
+
+ db, cleanup := setupTestDB(t)
+ t.Cleanup(cleanup)
+
+ testCases := []struct {
+ name string
+ config Config
+ expectedErr string
+ }{
+ {
+ name: "valid config",
+ config: Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ RecoveryWindow: MinRecoveryWindow,
+ },
+ },
+ {
+ name: "invalid RecoveryWindow",
+ config: Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ RecoveryWindow: MinRecoveryWindow - 1,
+ },
+ expectedErr: "RecoveryWindow",
+ },
+ {
+ name: "missing DB",
+ config: Config{
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ RecoveryWindow: MinRecoveryWindow,
+ },
+ expectedErr: "DB",
+ },
+ {
+ name: "missing Chain",
+ config: Config{
+ DB: db,
+ ChainParams: &chainParams,
+ Name: "test-wallet",
+ RecoveryWindow: MinRecoveryWindow,
+ },
+ expectedErr: "Chain",
+ },
+ {
+ name: "missing ChainParams",
+ config: Config{
+ DB: db,
+ Chain: &mockChain{},
+ Name: "test-wallet",
+ RecoveryWindow: MinRecoveryWindow,
+ },
+ expectedErr: "ChainParams",
+ },
+ {
+ name: "missing Name",
+ config: Config{
+ DB: db,
+ Chain: &mockChain{},
+ ChainParams: &chainParams,
+ RecoveryWindow: MinRecoveryWindow,
+ },
+ expectedErr: "Name",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := tc.config.validate()
+ if tc.expectedErr == "" {
+ require.NoError(t, err)
+ } else {
+ require.ErrorContains(t, err, tc.expectedErr)
+ }
+ })
+ }
+}
+
+// mockChainConn is a mock in-memory implementation of the chainConn interface
+// that will be used for the birthday block sanity check tests. The struct is
+// capable of being backed by a chain in order to reproduce real-world
+// scenarios.
+type mockChainConn struct {
+ chainTip uint32
+ blockHashes map[uint32]chainhash.Hash
+ blocks map[chainhash.Hash]*wire.MsgBlock
+}
+
+var _ chainConn = (*mockChainConn)(nil)
+
+// createMockChainConn creates a new mock chain connection backed by a chain
+// with N blocks. Each block has a timestamp that is exactly blockInterval after
+// the previous block's timestamp.
+func createMockChainConn(genesis *wire.MsgBlock, n uint32,
+ blockInterval time.Duration) *mockChainConn {
+
+ c := &mockChainConn{
+ chainTip: n,
+ blockHashes: make(map[uint32]chainhash.Hash),
+ blocks: make(map[chainhash.Hash]*wire.MsgBlock),
+ }
+
+ genesisHash := genesis.BlockHash()
+ c.blockHashes[0] = genesisHash
+ c.blocks[genesisHash] = genesis
+
+ for i := uint32(1); i <= n; i++ {
+ prevTimestamp := c.blocks[c.blockHashes[i-1]].Header.Timestamp
+ block := &wire.MsgBlock{
+ Header: wire.BlockHeader{
+ Timestamp: prevTimestamp.Add(blockInterval),
+ },
+ }
+
+ blockHash := block.BlockHash()
+ c.blockHashes[i] = blockHash
+ c.blocks[blockHash] = block
+ }
+
+ return c
+}
+
+// GetBestBlock returns the hash and height of the best block known to the
+// backend.
+func (c *mockChainConn) GetBestBlock() (*chainhash.Hash, int32, error) {
+ bestHash, ok := c.blockHashes[c.chainTip]
+ if !ok {
+ return nil, 0, fmt.Errorf("%w: height %d",
+ errBlockNotFound, c.chainTip)
+ }
+
+ return &bestHash, int32(c.chainTip), nil
+}
+
+// GetBlockHash returns the hash of the block with the given height.
+func (c *mockChainConn) GetBlockHash(height int64) (*chainhash.Hash, error) {
+ hash, ok := c.blockHashes[uint32(height)]
+ if !ok {
+ return nil, fmt.Errorf("%w: height %d", errBlockNotFound, height)
+ }
+
+ return &hash, nil
+}
+
+// GetBlockHeader returns the header for the block with the given hash.
+func (c *mockChainConn) GetBlockHeader(
+ hash *chainhash.Hash) (*wire.BlockHeader, error) {
+
+ block, ok := c.blocks[*hash]
+ if !ok {
+ return nil, fmt.Errorf("%w: hash %v", errHeaderNotFound, hash)
+ }
+
+ return &block.Header, nil
+}
+
// TestLocateBirthdayBlock ensures we can properly map a block in the chain to a
// timestamp.
func TestLocateBirthdayBlock(t *testing.T) {
@@ -115,550 +284,3 @@ func TestLocateBirthdayBlock(t *testing.T) {
}
}
}
-
-// TestLabelTransaction tests labelling of transactions with invalid labels,
-// and failure to label a transaction when it already has a label.
-func TestLabelTransaction(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
-
- // Whether the transaction should be known to the wallet.
- txKnown bool
-
- // Whether the test should write an existing label to disk.
- existingLabel bool
-
- // The overwrite parameter to call label transaction with.
- overwrite bool
-
- // The error we expect to be returned.
- expectedErr error
- }{
- {
- name: "existing label, not overwrite",
- txKnown: true,
- existingLabel: true,
- overwrite: false,
- expectedErr: ErrTxLabelExists,
- },
- {
- name: "existing label, overwritten",
- txKnown: true,
- existingLabel: true,
- overwrite: true,
- expectedErr: nil,
- },
- {
- name: "no prexisting label, ok",
- txKnown: true,
- existingLabel: false,
- overwrite: false,
- expectedErr: nil,
- },
- {
- name: "transaction unknown",
- txKnown: false,
- existingLabel: false,
- overwrite: false,
- expectedErr: ErrUnknownTransaction,
- },
- }
-
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- t.Parallel()
-
- w, cleanup := testWallet(t)
- defer cleanup()
-
- // If the transaction should be known to the store, we
- // write txdetail to disk.
- if test.txKnown {
- rec, err := wtxmgr.NewTxRecord(
- TstSerializedTx, time.Now(),
- )
- if err != nil {
- t.Fatal(err)
- }
-
- err = walletdb.Update(w.db,
- func(tx walletdb.ReadWriteTx) error {
-
- ns := tx.ReadWriteBucket(
- wtxmgrNamespaceKey,
- )
-
- return w.TxStore.InsertTx(
- ns, rec, nil,
- )
- })
- if err != nil {
- t.Fatalf("could not insert tx: %v", err)
- }
- }
-
- // If we want to setup an existing label for the purpose
- // of the test, write one to disk.
- if test.existingLabel {
- err := w.LabelTransaction(
- *TstTxHash, "existing label", false,
- )
- if err != nil {
- t.Fatalf("could not write label: %v",
- err)
- }
- }
-
- newLabel := "new label"
- err := w.LabelTransaction(
- *TstTxHash, newLabel, test.overwrite,
- )
- if err != test.expectedErr {
- t.Fatalf("expected: %v, got: %v",
- test.expectedErr, err)
- }
- })
- }
-}
-
-// TestGetTransaction tests if we can fetch a mined, an existing
-// and a non-existing transaction from the wallet like we expect.
-func TestGetTransaction(t *testing.T) {
- t.Parallel()
- rec, err := wtxmgr.NewTxRecord(TstSerializedTx, time.Now())
- require.NoError(t, err)
-
- tests := []struct {
- name string
-
- // Transaction id.
- txid chainhash.Hash
-
- // Expected height.
- expectedHeight int32
-
- // Store function.
- f func(*wtxmgr.Store, walletdb.ReadWriteBucket) (*wtxmgr.Store, error)
-
- // The error we expect to be returned.
- expectedErr error
- }{
- {
- name: "existing unmined transaction",
- txid: *TstTxHash,
- expectedHeight: -1,
- // We write txdetail for the tx to disk.
- f: func(s *wtxmgr.Store, ns walletdb.ReadWriteBucket) (
- *wtxmgr.Store, error) {
-
- err = s.InsertTx(ns, rec, nil)
- return s, err
- },
- expectedErr: nil,
- },
- {
- name: "existing mined transaction",
- txid: *TstTxHash,
- // We write txdetail for the tx to disk.
- f: func(s *wtxmgr.Store, ns walletdb.ReadWriteBucket) (
- *wtxmgr.Store, error) {
-
- err = s.InsertTx(ns, rec, TstMinedSignedTxBlockDetails)
- return s, err
- },
- expectedHeight: TstMinedTxBlockHeight,
- expectedErr: nil,
- },
- {
- name: "non-existing transaction",
- txid: *TstTxHash,
- // Write no txdetail to disk.
- f: func(s *wtxmgr.Store, ns walletdb.ReadWriteBucket) (
- *wtxmgr.Store, error) {
-
- return s, nil
- },
- expectedErr: ErrNoTx,
- },
- }
- for _, test := range tests {
- test := test
-
- t.Run(test.name, func(t *testing.T) {
- w, cleanup := testWallet(t)
- defer cleanup()
-
- err := walletdb.Update(w.db, func(rw walletdb.ReadWriteTx) error {
- ns := rw.ReadWriteBucket(wtxmgrNamespaceKey)
- _, err := test.f(w.TxStore, ns)
- return err
- })
- require.NoError(t, err)
- tx, err := w.GetTransaction(test.txid)
- require.ErrorIs(t, err, test.expectedErr)
-
- // Discontinue if no transaction were found.
- if err != nil {
- return
- }
-
- // Check if we get the expected hash.
- require.Equal(t, &test.txid, tx.Summary.Hash)
-
- // Check the block height.
- require.Equal(t, test.expectedHeight, tx.Height)
- })
- }
-}
-
-// TestGetTransactionConfirmations tests that GetTransaction correctly
-// calculates confirmations for both confirmed and unconfirmed transactions.
-// This is a regression test for a bug where confirmations were set to the
-// block height instead of being calculated as currentHeight - blockHeight + 1.
-//
-// The bug had several negative impacts:
-// - Unconfirmed transactions showed -1 confirmations instead of 0, breaking
-// zero-conf (accepting transactions before block inclusion)
-// - Confirmed transactions showed block height instead of actual confirmation
-// count
-// - LND and other consumers would make incorrect decisions based on wrong
-// counts
-func TestGetTransactionConfirmations(t *testing.T) {
- t.Parallel()
-
- rec, err := wtxmgr.NewTxRecord(TstSerializedTx, time.Now())
- require.NoError(t, err)
-
- tests := []struct {
- name string
-
- // Block height where transaction is mined (-1 for unmined).
- txBlockHeight int32
-
- // Current wallet sync height.
- currentHeight int32
-
- // Expected confirmations.
- expectedConfirmations int32
-
- // Expected height in result.
- expectedHeight int32
-
- // Whether to check for non-zero timestamp.
- expectTimestamp bool
- }{
- {
- name: "unconfirmed tx",
- txBlockHeight: -1,
- currentHeight: 100,
- expectedConfirmations: 0,
- expectedHeight: -1,
- expectTimestamp: false,
- },
- {
- name: "tx with 1 confirmation",
- txBlockHeight: 100,
- currentHeight: 100,
- expectedConfirmations: 1,
- expectedHeight: 100,
- expectTimestamp: true,
- },
- {
- name: "tx with 3 confirmations",
- txBlockHeight: 8,
- currentHeight: 10,
- expectedConfirmations: 3,
- expectedHeight: 8,
- expectTimestamp: true,
- },
- {
- name: "old tx with many confirmations",
- txBlockHeight: 1,
- currentHeight: 1000,
- expectedConfirmations: 1000,
- expectedHeight: 1,
- expectTimestamp: true,
- },
- {
- name: "tx in future block",
- txBlockHeight: 105,
- currentHeight: 100,
- expectedConfirmations: 0,
- expectedHeight: 105,
- expectTimestamp: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- w, cleanup := testWallet(t)
- t.Cleanup(cleanup)
-
- // Set the wallet's synced height.
- err := walletdb.Update(
- w.db, func(tx walletdb.ReadWriteTx) error {
- addrmgrNs := tx.ReadWriteBucket(
- waddrmgrNamespaceKey,
- )
- bs := &waddrmgr.BlockStamp{
- Height: tt.currentHeight,
- Hash: chainhash.Hash{},
- }
-
- return w.Manager.SetSyncedTo(
- addrmgrNs, bs,
- )
- },
- )
- require.NoError(t, err)
-
- // Insert transaction into wallet.
- err = walletdb.Update(
- w.db, func(tx walletdb.ReadWriteTx) error {
- ns := tx.ReadWriteBucket(
- wtxmgrNamespaceKey,
- )
-
- // Create block metadata if transaction
- // is mined.
- var blockMeta *wtxmgr.BlockMeta
- if tt.txBlockHeight != -1 {
- hash := chainhash.Hash{}
- height := tt.txBlockHeight
- block := wtxmgr.Block{
- Hash: hash,
- Height: height,
- }
- blockMeta = &wtxmgr.BlockMeta{
- Block: block,
- Time: time.Now(),
- }
- }
-
- return w.TxStore.InsertTx(
- ns, rec, blockMeta,
- )
- },
- )
- require.NoError(t, err)
-
- result, err := w.GetTransaction(*TstTxHash)
- require.NoError(t, err)
-
- require.Equal(
- t, tt.expectedConfirmations,
- result.Confirmations,
- )
-
- require.Equal(t, tt.expectedHeight, result.Height)
-
- if tt.expectTimestamp {
- require.NotZero(t, result.Timestamp)
- } else {
- require.Zero(t, result.Timestamp)
- }
-
- // Additional checks for unconfirmed transactions.
- if tt.txBlockHeight == -1 {
- require.Nil(t, result.BlockHash)
- require.Equal(t, int32(0), result.Confirmations)
- } else {
- require.NotNil(t, result.BlockHash)
- // Only expect positive confirmations when tx is
- // not in a future block.
- if tt.txBlockHeight <= tt.currentHeight {
- require.Positive(
- t, result.Confirmations,
- )
- } else {
- // Confirmed txns in future blocks for
- // example due to reorg should be
- // treated as unconfirmed and have 0
- // confirmations.
- require.Equal(
- t, int32(0),
- result.Confirmations,
- )
- }
- }
- })
- }
-}
-
-// TestDuplicateAddressDerivation tests that duplicate addresses are not
-// derived when multiple goroutines are concurrently requesting new addresses.
-func TestDuplicateAddressDerivation(t *testing.T) {
- w, cleanup := testWallet(t)
- defer cleanup()
-
- var (
- m sync.Mutex
- globalAddrs = make(map[string]address.Address)
- )
-
- for o := 0; o < 10; o++ {
- var eg errgroup.Group
-
- for n := 0; n < 10; n++ {
- eg.Go(func() error {
- addrs := make([]address.Address, 10)
- for i := 0; i < 10; i++ {
- addr, err := w.NewAddress(
- 0, waddrmgr.KeyScopeBIP0084,
- )
- if err != nil {
- return err
- }
-
- addrs[i] = addr
- }
-
- m.Lock()
- defer m.Unlock()
-
- for idx := range addrs {
- addrStr := addrs[idx].String()
- if a, ok := globalAddrs[addrStr]; ok {
- return fmt.Errorf("duplicate "+
- "address! already "+
- "have %v, want to "+
- "add %v", a, addrs[idx])
- }
-
- globalAddrs[addrStr] = addrs[idx]
- }
-
- return nil
- })
- }
-
- require.NoError(t, eg.Wait())
- }
-}
-
-func TestEndRecovery(t *testing.T) {
- // This is an unconventional unit test, but I'm trying to keep things as
- // succint as possible so that this test is readable without having to mock
- // up literally everything.
- // The unmonitored goroutine we're looking at is pretty deep:
- // SynchronizeRPC -> handleChainNotifications -> syncWithChain -> recovery
- // The "deadlock" we're addressing isn't actually a deadlock, but the wallet
- // will hang on Stop() -> WaitForShutdown() until (*Wallet).recovery gets
- // every single block, which could be hours depending on hardware and
- // network factors. The WaitGroup is incremented in SynchronizeRPC, and
- // WaitForShutdown will not return until handleChainNotifications returns,
- // which is blocked by a running (*Wallet).recovery loop.
- // It is noted that the conditions for long recovery are difficult to hit
- // when using btcwallet with a fresh seed, because it requires an early
- // birthday to be set or established.
-
- w, cleanup := testWallet(t)
-
- blockHashCalled := make(chan struct{})
-
- chainClient := &mockChainClient{
- // Force the loop to iterate about forever.
- getBestBlockHeight: math.MaxInt32,
- // Get control of when the loop iterates.
- getBlockHashFunc: func() (*chainhash.Hash, error) {
- blockHashCalled <- struct{}{}
- return &chainhash.Hash{}, nil
- },
- // Avoid a panic.
- getBlockHeader: &wire.BlockHeader{},
- }
-
- recoveryDone := make(chan struct{})
- go func() {
- defer close(recoveryDone)
- w.recovery(chainClient, &waddrmgr.BlockStamp{})
- }()
-
- getBlockHashCalls := func(expCalls int) {
- var i int
- for {
- select {
- case <-blockHashCalled:
- i++
- case <-time.After(time.Second):
- t.Fatal("expected BlockHash to be called")
- }
- if i == expCalls {
- break
- }
- }
- }
-
- // Recovery is running.
- getBlockHashCalls(3)
-
- // Closing the quit channel, e.g. Stop() without endRecovery, alone will not
- // end the recovery loop.
- w.quitMu.Lock()
- close(w.quit)
- w.quitMu.Unlock()
- // Continues scanning.
- getBlockHashCalls(3)
-
- // We're done with this one
- atomic.StoreUint32(&w.recovering.Load().(*recoverySyncer).quit, 1)
- select {
- case <-blockHashCalled:
- case <-recoveryDone:
- }
- cleanup()
-
- // Try again.
- w, cleanup = testWallet(t)
- defer cleanup()
-
- // We'll catch the error to make sure we're hitting our desired path. The
- // WaitGroup isn't required for the test, but does show how it completes
- // shutdown at a higher level.
- var err error
- w.wg.Add(1)
- recoveryDone = make(chan struct{})
- go func() {
- defer w.wg.Done()
- defer close(recoveryDone)
- err = w.recovery(chainClient, &waddrmgr.BlockStamp{})
- }()
-
- waitedForShutdown := make(chan struct{})
- go func() {
- w.WaitForShutdown()
- close(waitedForShutdown)
- }()
-
- // Recovery is running.
- getBlockHashCalls(3)
-
- // endRecovery is required to exit the unmonitored goroutine.
- end := w.endRecovery()
- select {
- case <-blockHashCalled:
- case <-recoveryDone:
- }
- <-end
-
- // testWallet starts a couple of other unrelated goroutines that need to be
- // killed, so we still need to close the quit channel.
- w.quitMu.Lock()
- close(w.quit)
- w.quitMu.Unlock()
-
- select {
- case <-waitedForShutdown:
- case <-time.After(time.Second):
- t.Fatal("WaitForShutdown never returned")
- }
-
- if !strings.EqualFold(err.Error(), "recovery: forced shutdown") {
- t.Fatal("wrong error")
- }
-}
diff --git a/wallet/watchingonly_test.go b/wallet/watchingonly_test.go
deleted file mode 100644
index 73d43b3e1b..0000000000
--- a/wallet/watchingonly_test.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright (c) 2018 The btcsuite developers
-// Use of this source code is governed by an ISC
-// license that can be found in the LICENSE file.
-
-package wallet
-
-import (
- "testing"
- "time"
-
- "github.com/btcsuite/btcd/chaincfg/v2"
- _ "github.com/btcsuite/btcwallet/walletdb/bdb"
-)
-
-// TestCreateWatchingOnly checks that we can construct a watching-only
-// wallet.
-func TestCreateWatchingOnly(t *testing.T) {
- // Set up a wallet.
- dir := t.TempDir()
-
- pubPass := []byte("hello")
-
- loader := NewLoader(
- &chaincfg.TestNet3Params, dir, true, defaultDBTimeout, 250,
- WithWalletSyncRetryInterval(10*time.Millisecond),
- )
- _, err := loader.CreateNewWatchingOnlyWallet(pubPass, time.Now())
- if err != nil {
- t.Fatalf("unable to create wallet: %v", err)
- }
-}
diff --git a/walletsetup.go b/walletsetup.go
index 9eae9c05fc..6578efa168 100644
--- a/walletsetup.go
+++ b/walletsetup.go
@@ -152,7 +152,14 @@ func createWallet(cfg *config) error {
defer func() {
lockChan <- time.Time{}
}()
- err := w.Unlock(privPass, lockChan)
+
+ //nolint:staticcheck // This should be fixed once
+ // the interface refactor is finished, and new wallet
+ // RPC is built.
+ err := w.UnlockDeprecated(
+ privPass,
+ lockChan,
+ )
if err != nil {
fmt.Printf("ERR: Failed to unlock new wallet "+
"during old wallet key import: %v", err)
@@ -193,7 +200,7 @@ func createWallet(cfg *config) error {
return err
}
- w.Manager.Close()
+ w.AddrManager().Close()
fmt.Println("The wallet has been created successfully.")
return nil
}
@@ -221,7 +228,13 @@ func createSimulationWallet(cfg *config) error {
defer db.Close()
// Create the wallet.
- err = wallet.Create(db, pubPass, privPass, nil, activeNet.Params, time.Now())
+ //
+ //nolint:staticcheck // This should be fixed once the interface
+ // refactor is finished, and new wallet RPC is built.
+ err = wallet.CreateDeprecated(
+ db, pubPass, privPass, nil, activeNet.Params,
+ time.Now(),
+ )
if err != nil {
return err
}
diff --git a/wtxmgr/db.go b/wtxmgr/db.go
index d3509a6193..aa41ec1a5c 100644
--- a/wtxmgr/db.go
+++ b/wtxmgr/db.go
@@ -434,6 +434,10 @@ func readRawTxRecord(txHash *chainhash.Hash, v []byte, rec *TxRecord) error {
bucketTxRecords, txHash)
return storeError(ErrData, str, err)
}
+
+ // Cache the raw bytes when the above deserialization succeeded.
+ rec.SerializedTx = v[8:]
+
return nil
}
diff --git a/wtxmgr/error.go b/wtxmgr/error.go
index e1cd1d41e6..5069dd1e28 100644
--- a/wtxmgr/error.go
+++ b/wtxmgr/error.go
@@ -5,7 +5,10 @@
package wtxmgr
-import "fmt"
+import (
+ "errors"
+ "fmt"
+)
// ErrorCode identifies a category of error.
type ErrorCode uint8
@@ -51,6 +54,11 @@ const (
ErrUnknownVersion
)
+var (
+ // ErrUtxoNotFound is returned when a UTXO is not found in the store.
+ ErrUtxoNotFound = errors.New("utxo not found")
+)
+
var errStrs = [...]string{
ErrDatabase: "ErrDatabase",
ErrData: "ErrData",
diff --git a/wtxmgr/interface.go b/wtxmgr/interface.go
new file mode 100644
index 0000000000..2b55fbf44a
--- /dev/null
+++ b/wtxmgr/interface.go
@@ -0,0 +1,181 @@
+// Copyright (c) 2025 The btcsuite developers
+// Use of this source code is governed by an ISC
+// license that can be found in the LICENSE file.
+
+package wtxmgr
+
+import (
+ "time"
+
+ "github.com/btcsuite/btcd/btcutil/v2"
+ "github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
+ "github.com/btcsuite/btcwallet/walletdb"
+)
+
+// TODO(yy): The TxStore interface is a temporary solution to decouple the
+// wallet from the wtxmgr. It is not a good example of a well-designed
+// interface. It has the following issues:
+//
+// 1. Violation of the Interface Segregation Principle (ISP):
+// The current TxStore interface is a "fat" interface, containing over 15
+// methods that span a wide range of responsibilities, from simple balance
+// lookups to administrative tasks like database rollbacks. A component that
+// only needs to read transaction details is forced to depend on the entire
+// interface, including methods for writing data and performing
+// administrative actions. This creates an unnecessarily large dependency
+// surface.
+//
+// 2. Lack of Cohesion and CRUD-like Grouping:
+// The methods in TxStore are not grouped by the domain entity they operate
+// on. A more intuitive design would follow a classic Create, Read, Update,
+// Delete (CRUD) pattern for each major entity (transactions, UTXOs,
+// labels). The flat structure of the interface makes it harder to
+// understand the available operations for a specific entity. For example,
+// PutTxLabel, FetchTxLabel, and TxDetails are all at the same level, despite
+// operating on different aspects of a transaction.
+//
+// 3. Leaky Abstractions:
+// The interface methods currently require the caller (the wallet package)
+// to pass in walletdb.ReadWriteBucket or walletdb.ReadBucket handles. This
+// leaks the implementation detail that the store is built on walletdb. The
+// wallet should not need to know about the underlying database technology
+// or manage database transactions for the wtxmgr. This also violates the
+// "Pull Complexity Downwards" principle, as the TxStore should be
+// responsible for its own data access logic.
+//
+// 4. Missing context.Context Propagation:
+// None of the interface methods accept a context.Context. This is a
+// critical omission. Without a context, we cannot enforce timeouts,
+// propagate cancellation signals, or ensure the graceful shutdown of
+// long-running database queries.
+//
+// TxStore is an interface that describes a transaction store.
+type TxStore interface {
+ // Balance returns the spendable wallet balance (total value of all
+ // unspent transaction outputs) given a minimum of minConf confirmations,
+ // calculated at a current chain height of curHeight. Coinbase outputs
+ // are only included in the balance if maturity has been reached.
+ Balance(ns walletdb.ReadBucket, minConf int32,
+ syncHeight int32) (btcutil.Amount, error)
+
+ // DeleteExpiredLockedOutputs iterates through all existing locked
+ // outputs and deletes those which have already expired.
+ DeleteExpiredLockedOutputs(ns walletdb.ReadWriteBucket) error
+
+ // InsertTx records a transaction as belonging to a wallet's transaction
+ // history. If block is nil, the transaction is considered unspent, and
+ // the transaction's index must be unset.
+ InsertTx(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ block *BlockMeta) error
+
+ // InsertTxCheckIfExists records a transaction as belonging to a wallet's
+ // transaction history. If block is nil, the transaction is considered
+ // unspent, and the transaction's index must be unset. It will return
+ // true if the transaction was already recorded prior to the call.
+ InsertTxCheckIfExists(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ block *BlockMeta) (bool, error)
+
+ // InsertConfirmedTx records a mined transaction and its associated
+ // credits in a single operation.
+ InsertConfirmedTx(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ block *BlockMeta, credits []CreditEntry) error
+
+ // InsertUnconfirmedTx records an unmined transaction and its associated
+ // credits in a single operation.
+ InsertUnconfirmedTx(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ credits []CreditEntry) error
+
+ // AddCredit marks a transaction record as containing a transaction
+ // output spendable by wallet. The output is added unspent, and is
+ // marked spent when a new transaction spending the output is inserted
+ // into the store.
+ AddCredit(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ block *BlockMeta, index uint32, change bool) error
+
+ // ListLockedOutputs returns a list of objects representing the currently
+ // locked utxos.
+ ListLockedOutputs(ns walletdb.ReadBucket) ([]*LockedOutput, error)
+
+ // LockOutput locks an output to the given ID, preventing it from being
+ // available for coin selection. The absolute time of the lock's
+ // expiration is returned. The expiration of the lock can be extended by
+ // successive invocations of this call.
+ LockOutput(ns walletdb.ReadWriteBucket, id LockID, op wire.OutPoint,
+ duration time.Duration) (time.Time, error)
+
+ // OutputsToWatch returns a list of outputs to monitor during the
+ // wallet's startup. The returned items are similar to UnspentOutputs,
+ // exccept the locked outputs and unmined credits are also returned
+ // here. In addition, we only set the field `OutPoint` and `PkScript`
+ // for the `Credit`, as these are the only fields used during the
+ // rescan.
+ OutputsToWatch(ns walletdb.ReadBucket) ([]Credit, error)
+
+ // PutTxLabel validates transaction labels and writes them to disk if
+ // they are non-zero and within the label length limit.
+ PutTxLabel(ns walletdb.ReadWriteBucket, txid chainhash.Hash,
+ label string) error
+
+ // RangeTransactions runs the function f on all transaction details
+ // between blocks on the best chain over the height range [begin,end].
+ // The special height -1 may be used to also include unmined
+ // transactions. If the end height comes before the begin height, blocks
+ // are iterated in reverse order and unmined transactions (if any) are
+ // processed first.
+ RangeTransactions(ns walletdb.ReadBucket, begin, end int32,
+ f func([]TxDetails) (bool, error)) error
+
+ // Rollback removes all blocks at height onwards, moving any transactions
+ // within each block to the unconfirmed pool.
+ Rollback(ns walletdb.ReadWriteBucket, height int32) error
+
+ // TxDetails looks up all recorded details regarding a transaction with
+ // some hash. In case of a hash collision, the most recent transaction
+ // with a matching hash is returned.
+ TxDetails(ns walletdb.ReadBucket,
+ txHash *chainhash.Hash) (*TxDetails, error)
+
+ // UniqueTxDetails looks up all recorded details for a transaction
+ // recorded mined in some particular block, or an unmined transaction if
+ // block is nil.
+ UniqueTxDetails(ns walletdb.ReadBucket, txHash *chainhash.Hash,
+ block *Block) (*TxDetails, error)
+
+ // UnlockOutput unlocks an output, allowing it to be available for coin
+ // selection if it remains unspent. The ID should match the one used to
+ // originally lock the output.
+ UnlockOutput(ns walletdb.ReadWriteBucket, id LockID,
+ op wire.OutPoint) error
+
+ // UnspentOutputs returns all unspent received transaction outputs.
+ // The order is undefined.
+ UnspentOutputs(ns walletdb.ReadBucket) ([]Credit, error)
+
+ // FetchTxLabel reads a transaction label from the tx labels bucket. If
+ // a label with 0 length was written, we return an error, since this is
+ // unexpected.
+ FetchTxLabel(ns walletdb.ReadBucket,
+ txid chainhash.Hash) (string, error)
+
+ // GetUtxo returns the credit for a given outpoint, if it is known to
+ // the store as a UTXO. It checks for mined (confirmed) UTXOs first,
+ // and then unmined (unconfirmed) credits. If the UTXO is not found,
+ // ErrUtxoNotFound is returned. This function does not determine if the
+ // UTXO is spent by an unmined transaction or locked.
+ GetUtxo(ns walletdb.ReadBucket,
+ outpoint wire.OutPoint) (*Credit, error)
+
+ // UnminedTxs returns the underlying transactions for all unmined
+ // transactions which are not known to have been mined in a block.
+ // Transactions are guaranteed to be sorted by their dependency order.
+ UnminedTxs(ns walletdb.ReadBucket) ([]*wire.MsgTx, error)
+
+ // UnminedTxHashes returns the hashes of all transactions not known to
+ // have been mined in a block.
+ UnminedTxHashes(ns walletdb.ReadBucket) ([]*chainhash.Hash, error)
+
+ // RemoveUnminedTx attempts to remove an unmined transaction from the
+ // transaction store.
+ RemoveUnminedTx(ns walletdb.ReadWriteBucket, rec *TxRecord) error
+}
diff --git a/wtxmgr/query.go b/wtxmgr/query.go
index c2c9cfce03..c61079f220 100644
--- a/wtxmgr/query.go
+++ b/wtxmgr/query.go
@@ -8,8 +8,10 @@ package wtxmgr
import (
"fmt"
+ "github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/btcutil/v2"
"github.com/btcsuite/btcd/chainhash/v2"
+ "github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/walletdb"
)
@@ -21,6 +23,9 @@ type CreditRecord struct {
Index uint32
Spent bool
Change bool
+
+ // Locked indicates whether the output is locked by the wallet.
+ Locked bool
}
// DebitRecord contains metadata regarding a transaction debit for a known
@@ -76,6 +81,12 @@ func (s *Store) minedTxDetails(ns walletdb.ReadBucket, txHash *chainhash.Hash, r
spent := existsRawUnminedInput(ns, k) != nil
credIter.elem.Spent = spent
}
+
+ // Check if locked.
+ op := wire.OutPoint{Hash: *txHash, Index: credIter.elem.Index}
+ _, _, locked := isLockedOutput(ns, op, s.clock.Now())
+ credIter.elem.Locked = locked
+
details.Credits = append(details.Credits, credIter.elem)
}
if credIter.err != nil {
@@ -124,6 +135,12 @@ func (s *Store) unminedTxDetails(ns walletdb.ReadBucket, txHash *chainhash.Hash,
// Set the Spent field since this is not done by the iterator.
it.elem.Spent = existsRawUnminedInput(ns, it.ck) != nil
+
+ // Check if locked.
+ op := wire.OutPoint{Hash: *txHash, Index: it.elem.Index}
+ _, _, locked := isLockedOutput(ns, op, s.clock.Now())
+ it.elem.Locked = locked
+
details.Credits = append(details.Credits, it.elem)
}
if it.err != nil {
@@ -184,7 +201,7 @@ func (s *Store) unminedTxDetails(ns walletdb.ReadBucket, txHash *chainhash.Hash,
func (s *Store) TxLabel(ns walletdb.ReadBucket, txHash chainhash.Hash) (string,
error) {
- label, err := FetchTxLabel(ns, txHash)
+ label, err := s.FetchTxLabel(ns, txHash)
switch err {
// If there are no saved labels yet (the bucket has not been created) or
// there is not a label for this particular tx, we ignore the error.
@@ -456,3 +473,129 @@ func (s *Store) PreviousPkScripts(ns walletdb.ReadBucket, rec *TxRecord, block *
return pkScripts, nil
}
+
+// getMinedUtxo constructs a Credit for a mined UTXO from the raw database
+// value.
+func (s *Store) getMinedUtxo(ns walletdb.ReadBucket, outpoint wire.OutPoint,
+ unspentVal []byte) (*Credit, error) {
+
+ var block Block
+ err := readUnspentBlock(unspentVal, &block)
+ if err != nil {
+ return nil, err
+ }
+
+ // We have the block, now fetch the full transaction record.
+ rec, err := fetchTxRecord(ns, &outpoint.Hash, &block)
+ if err != nil {
+ return nil, err
+ }
+
+ // Ensure the output index is valid for the transaction.
+ if int(outpoint.Index) >= len(rec.MsgTx.TxOut) {
+ str := fmt.Sprintf("mined credit %v references "+
+ "non-existent output index", outpoint)
+ return nil, storeError(ErrData, str, nil)
+ }
+ txOut := rec.MsgTx.TxOut[outpoint.Index]
+
+ blockTime, err := fetchBlockTime(ns, block.Height)
+ if err != nil {
+ return nil, err
+ }
+
+ _, _, locked := isLockedOutput(ns, outpoint, s.clock.Now())
+
+ credit := &Credit{
+ OutPoint: outpoint,
+ BlockMeta: BlockMeta{
+ Block: block,
+ Time: blockTime,
+ },
+ Amount: btcutil.Amount(txOut.Value),
+ PkScript: txOut.PkScript,
+ Received: rec.Received,
+ FromCoinBase: blockchain.IsCoinBaseTx(&rec.MsgTx),
+ Locked: locked,
+ }
+ return credit, nil
+}
+
+// getUnminedUtxo constructs a Credit for an unmined UTXO.
+func (s *Store) getUnminedUtxo(ns walletdb.ReadBucket,
+ outpoint wire.OutPoint) (*Credit, error) {
+
+ // The outpoint is an unmined credit. We need to fetch the
+ // full transaction to get all the credit details.
+ recVal := existsRawUnmined(ns, outpoint.Hash[:])
+ if recVal == nil {
+ // This would indicate a store inconsistency.
+ str := fmt.Sprintf("unmined credit %v has no matching "+
+ "unmined tx record", outpoint)
+ return nil, storeError(ErrData, str, nil)
+ }
+
+ var rec TxRecord
+ err := readRawTxRecord(&outpoint.Hash, recVal, &rec)
+ if err != nil {
+ return nil, err
+ }
+
+ // Ensure the output index is valid for the transaction.
+ if int(outpoint.Index) >= len(rec.MsgTx.TxOut) {
+ str := fmt.Sprintf("unmined credit %v references "+
+ "non-existent output index", outpoint)
+ return nil, storeError(ErrData, str, nil)
+ }
+ txOut := rec.MsgTx.TxOut[outpoint.Index]
+
+ _, _, locked := isLockedOutput(ns, outpoint, s.clock.Now())
+
+ credit := &Credit{
+ OutPoint: outpoint,
+ BlockMeta: BlockMeta{
+ Block: Block{Height: -1},
+ },
+ Amount: btcutil.Amount(txOut.Value),
+ PkScript: txOut.PkScript,
+ Received: rec.Received,
+ FromCoinBase: false, // Unmined can't be coinbase.
+ Locked: locked,
+ }
+ return credit, nil
+}
+
+// TODO(yy): This method is inefficient as it requires multiple database
+// lookups (unspent, tx records) to construct the full credit. This
+// should be optimized by denormalizing the unspent bucket to include
+// the amount and pkScript directly.
+//
+// TODO(yy): This function only confirms the existence of a UTXO but does
+// not guarantee its spendability. A more comprehensive version should
+// also check the 'unminedInputs' and 'lockedOutputs' buckets.
+//
+// GetUtxo returns the credit for a given outpoint, if it is known to the
+// store as a UTXO. It checks for mined (confirmed) UTXOs first, and then
+// unmined (unconfirmed) credits. If the UTXO is not found, ErrUtxoNotFound is
+// returned. This function does not determine if the UTXO is spent by an
+// unmined transaction or locked.
+func (s *Store) GetUtxo(ns walletdb.ReadBucket,
+ outpoint wire.OutPoint) (*Credit, error) {
+
+ k := canonicalOutPoint(&outpoint.Hash, outpoint.Index)
+
+ // First, check if the UTXO is a mined and unspent credit.
+ unspentVal := ns.NestedReadBucket(bucketUnspent).Get(k)
+ if unspentVal != nil {
+ return s.getMinedUtxo(ns, outpoint, unspentVal)
+ }
+
+ // If not found in mined, check if it's an unconfirmed credit.
+ v := existsRawUnminedCredit(ns, k)
+ if v != nil {
+ return s.getUnminedUtxo(ns, outpoint)
+ }
+
+ // If not found in either bucket, it's not a known UTXO.
+ return nil, ErrUtxoNotFound
+}
diff --git a/wtxmgr/query_test.go b/wtxmgr/query_test.go
index 4d40cc1c2b..3ac42d787e 100644
--- a/wtxmgr/query_test.go
+++ b/wtxmgr/query_test.go
@@ -16,6 +16,7 @@ import (
"github.com/btcsuite/btcd/chainhash/v2"
"github.com/btcsuite/btcd/wire/v2"
"github.com/btcsuite/btcwallet/walletdb"
+ "github.com/stretchr/testify/require"
)
type queryState struct {
@@ -214,7 +215,8 @@ func equalTxs(got, exp *wire.MsgTx) error {
// Returns time.Now() with seconds resolution, this is what Store saves.
func timeNow() time.Time {
- return time.Unix(time.Now().Unix(), 0)
+ // Truncate to the second to match the precision of the database.
+ return time.Now().Truncate(time.Second)
}
// Returns a copy of a TxRecord without the serialized tx.
@@ -275,11 +277,12 @@ func TestStoreQueries(t *testing.T) {
newState.blocks = [][]TxDetails{
{
{
- TxRecord: *stripSerializedTx(recA),
+ TxRecord: *recA,
Block: BlockMeta{Block: Block{Height: -1}},
},
},
}
+
newState.txDetails[recA.Hash] = []TxDetails{
newState.blocks[0][0],
}
@@ -322,7 +325,7 @@ func TestStoreQueries(t *testing.T) {
newState = lastState.deepCopy()
newState.blocks[0][0].Credits[0].Spent = true
newState.blocks[0] = append(newState.blocks[0], TxDetails{
- TxRecord: *stripSerializedTx(recB),
+ TxRecord: *recB,
Block: BlockMeta{Block: Block{Height: -1}},
Debits: []DebitRecord{
{
@@ -741,3 +744,65 @@ func TestPreviousPkScripts(t *testing.T) {
t.Fatal("Failed after inserting tx D")
}
}
+
+// TestGetUtxo tests the GetUtxo method to ensure it correctly retrieves both
+// mined and unmined UTXOs, and that it returns the expected error when a UTXO
+// cannot be found.
+func TestGetUtxo(t *testing.T) {
+ t.Parallel()
+
+ s, db, err := testStore(t)
+ require.NoError(t, err)
+ defer db.Close()
+
+ dbtx, err := db.BeginReadWriteTx()
+ require.NoError(t, err)
+ defer dbtx.Commit()
+ ns := dbtx.ReadWriteBucket(namespaceKey)
+
+ // We'll start by querying for a UTXO that does not exist in the
+ // store. This should result in a ErrUtxoNotFound error.
+ op := wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}
+ cred, err := s.GetUtxo(ns, op)
+ require.ErrorIs(t, err, ErrUtxoNotFound)
+ require.Nil(t, cred)
+
+ // Now, we'll add a mined transaction and its credit to the store. This
+ // will serve as our confirmed UTXO.
+ b100 := makeBlockMeta(100)
+ txA := spendOutput(&chainhash.Hash{}, 0, 100e8)
+ recA, err := NewTxRecordFromMsgTx(txA, timeNow())
+ require.NoError(t, err)
+
+ err = s.InsertTx(ns, recA, &b100)
+ require.NoError(t, err)
+ err = s.AddCredit(ns, recA, &b100, 0, false)
+ require.NoError(t, err)
+
+ // We should now be able to query for the mined UTXO and get back the
+ // correct credit details.
+ op = wire.OutPoint{Hash: recA.Hash, Index: 0}
+ cred, err = s.GetUtxo(ns, op)
+ require.NoError(t, err)
+ require.NotNil(t, cred)
+ require.Equal(t, op, cred.OutPoint)
+
+ // We'll do the same for an unmined transaction and its credit. This
+ // will serve as our unconfirmed UTXO.
+ txB := spendOutput(&recA.Hash, 0, 50e8)
+ recB, err := NewTxRecordFromMsgTx(txB, timeNow())
+ require.NoError(t, err)
+
+ err = s.InsertTx(ns, recB, nil)
+ require.NoError(t, err)
+ err = s.AddCredit(ns, recB, nil, 0, false)
+ require.NoError(t, err)
+
+ // We should now be able to query for the unmined UTXO and get back
+ // the correct credit details.
+ op = wire.OutPoint{Hash: recB.Hash, Index: 0}
+ cred, err = s.GetUtxo(ns, op)
+ require.NoError(t, err)
+ require.NotNil(t, cred)
+ require.Equal(t, op, cred.OutPoint)
+}
diff --git a/wtxmgr/tx.go b/wtxmgr/tx.go
index c7787c4d7f..253827f016 100644
--- a/wtxmgr/tx.go
+++ b/wtxmgr/tx.go
@@ -134,6 +134,13 @@ type LockedOutput struct {
Expiration time.Time
}
+// CreditEntry specifies a transaction output that should be recorded as a
+// credit (spendable output) for the wallet.
+type CreditEntry struct {
+ Index uint32
+ Change bool
+}
+
// NewTxRecord creates a new transaction record that may be inserted into the
// store. It uses memoization to save the transaction hash and the serialized
// transaction.
@@ -180,6 +187,9 @@ type Credit struct {
PkScript []byte
Received time.Time
FromCoinBase bool
+
+ // Locked indicates whether the output is locked by the wallet.
+ Locked bool
}
// LockID represents a unique context-specific ID assigned to an output lock.
@@ -198,6 +208,10 @@ type Store struct {
NotifyUnspent func(hash *chainhash.Hash, index uint32)
}
+// A compile-time assertion to ensure that Store implements the TxStore
+// interface.
+var _ TxStore = (*Store)(nil)
+
// Open opens the wallet transaction store from a walletdb namespace. If the
// store does not exist, ErrNoExist is returned. `lockDuration` represents how
// long outputs are locked for.
@@ -385,6 +399,50 @@ func (s *Store) InsertTxCheckIfExists(ns walletdb.ReadWriteBucket,
return false, err
}
+// InsertConfirmedTx records a mined transaction and its associated credits in
+// a single operation. This is more efficient than calling InsertTx followed by
+// AddCredit for each output.
+func (s *Store) InsertConfirmedTx(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ block *BlockMeta, credits []CreditEntry) error {
+
+ if err := s.insertMinedTx(ns, rec, block); err != nil && err != ErrDuplicateTx {
+ return err
+ }
+
+ for _, c := range credits {
+ isNew, err := s.addCredit(ns, rec, block, c.Index, c.Change)
+ if err != nil {
+ return err
+ }
+ if isNew && s.NotifyUnspent != nil {
+ s.NotifyUnspent(&rec.Hash, c.Index)
+ }
+ }
+ return nil
+}
+
+// InsertUnconfirmedTx records an unmined transaction and its associated credits
+// in a single operation. This is more efficient than calling InsertTx followed
+// by AddCredit for each output.
+func (s *Store) InsertUnconfirmedTx(ns walletdb.ReadWriteBucket, rec *TxRecord,
+ credits []CreditEntry) error {
+
+ if err := s.insertMemPoolTx(ns, rec); err != nil && err != ErrDuplicateTx {
+ return err
+ }
+
+ for _, c := range credits {
+ isNew, err := s.addCredit(ns, rec, nil, c.Index, c.Change)
+ if err != nil {
+ return err
+ }
+ if isNew && s.NotifyUnspent != nil {
+ s.NotifyUnspent(&rec.Hash, c.Index)
+ }
+ }
+ return nil
+}
+
// RemoveUnminedTx attempts to remove an unmined transaction from the
// transaction store. This is to be used in the scenario that a transaction
// that we attempt to rebroadcast, turns out to double spend one of our
@@ -805,6 +863,34 @@ func (s *Store) rollback(ns walletdb.ReadWriteBucket, height int32) error {
return putMinedBalance(ns, minedBalance)
}
+// TODO(yy): The fetchCredits method suffers from several architectural and
+// performance issues that should be addressed in a future refactoring:
+//
+// 1. **N+1 Query Problem:** The function iterates through all unspent outputs
+// and performs a separate database lookup (`fetchTxRecord`) for each one to
+// retrieve its full details. For a wallet with a large number of UTXOs,
+// this results in an excessive number of database reads, leading to poor
+// performance.
+//
+// 2. **Inefficient Data Storage:** The root cause of the N+1 problem is that
+// the `unspent` bucket only stores a reference to the transaction, not the
+// critical data (Amount, PkScript) itself. The schema should be
+// denormalized to include this data directly in the `unspent` value, which
+// would turn the N+1 query into a single, efficient bucket scan.
+//
+// 3. **Code Duplication:** The logic for iterating over mined and unmined
+// credits is nearly identical, leading to significant code duplication. This
+// should be consolidated into a more generic helper function.
+//
+// 4. **Leaky Abstraction:** The use of multiple boolean flags
+// (`includeLocked`, `populateFullDetails`) to control behavior is a sign of
+// a leaky abstraction. A better API would provide more specific query
+// functions rather than a single, complex function with many toggles.
+//
+// 5. **Lack of Pagination:** The function loads all results into a single
+// in-memory slice, which can be memory-intensive for wallets with a large
+// UTXO set. A more scalable approach would use an iterator pattern.
+//
// fetchCredits retrieves credits from the store based on the provided filters.
// It iterates over both mined (unspent) and unmined credits.
//
@@ -832,12 +918,13 @@ func (s *Store) fetchCredits(ns walletdb.ReadBucket, includeLocked bool,
return err
}
+ // We check if this output is actually locked and set
+ // the Locked field.
+ _, _, isLocked := isLockedOutput(ns, op, now)
+
// Check if locked, skip if necessary.
- if !includeLocked {
- _, _, isLocked := isLockedOutput(ns, op, now)
- if isLocked {
- return nil
- }
+ if isLocked && !includeLocked {
+ return nil
}
// Check if spent by unmined, skip if necessary.
@@ -869,6 +956,7 @@ func (s *Store) fetchCredits(ns walletdb.ReadBucket, includeLocked bool,
cred := Credit{
OutPoint: op,
PkScript: txOut.PkScript,
+ Locked: isLocked,
}
// Populate full details if requested.
@@ -918,12 +1006,13 @@ func (s *Store) fetchCredits(ns walletdb.ReadBucket, includeLocked bool,
return err
}
+ // We check if this output is actually locked and set
+ // the Locked field.
+ _, _, isLocked := isLockedOutput(ns, op, now)
+
// Check if locked, skip if necessary.
- if !includeLocked {
- _, _, isLocked := isLockedOutput(ns, op, now)
- if isLocked {
- return nil
- }
+ if isLocked && !includeLocked {
+ return nil
}
// Check if spent by unmined, skip if necessary.
@@ -961,6 +1050,7 @@ func (s *Store) fetchCredits(ns walletdb.ReadBucket, includeLocked bool,
cred := Credit{
OutPoint: op,
PkScript: txOut.PkScript,
+ Locked: isLocked,
}
// Populate full details if requested.
@@ -1232,7 +1322,7 @@ func PutTxLabel(labelBucket walletdb.ReadWriteBucket, txid chainhash.Hash,
// FetchTxLabel reads a transaction label from the tx labels bucket. If a label
// with 0 length was written, we return an error, since this is unexpected.
-func FetchTxLabel(ns walletdb.ReadBucket, txid chainhash.Hash) (string, error) {
+func (s *Store) FetchTxLabel(ns walletdb.ReadBucket, txid chainhash.Hash) (string, error) {
labelBucket := ns.NestedReadBucket(bucketTxLabels)
if labelBucket == nil {
return "", ErrNoLabelBucket
diff --git a/wtxmgr/tx_test.go b/wtxmgr/tx_test.go
index 45257491c2..8ff4b482c2 100644
--- a/wtxmgr/tx_test.go
+++ b/wtxmgr/tx_test.go
@@ -2383,7 +2383,7 @@ func TestTxLabel(t *testing.T) {
err := walletdb.Update(db, func(tx walletdb.ReadWriteTx) error {
var err error
- label, err = FetchTxLabel(getBucket(tx), labelTx)
+ label, err = store.FetchTxLabel(getBucket(tx), labelTx)
return err
})