diff --git a/.golangci.json b/.golangci.json index e81ced3..ec7a586 100644 --- a/.golangci.json +++ b/.golangci.json @@ -46,7 +46,17 @@ "disable": [ "gochecknoglobals", "gocritic", - "godot" + "godot", + "godox", + "testifylint", + "revive", + "gochecknoinits", + "forbidigo", + "err113", + "errorlint", + "gosec", + "unused", + "gomoddirectives" ], "enable": [ "asasalint", diff --git a/coinbase_placeholder.go b/coinbase_placeholder.go new file mode 100644 index 0000000..ff6d17f --- /dev/null +++ b/coinbase_placeholder.go @@ -0,0 +1,46 @@ +package subtree + +import ( + "github.com/libsv/go-bt/v2" + "github.com/libsv/go-bt/v2/chainhash" +) + +var ( + // CoinbasePlaceholder hard code this value to avoid having to calculate it every time + // to help the compiler optimize the code. + CoinbasePlaceholder = [32]byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + } + CoinbasePlaceholderHashValue = chainhash.Hash(CoinbasePlaceholder) + CoinbasePlaceholderHash = &CoinbasePlaceholderHashValue + + FrozenBytes = [36]byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + } + FrozenBytesTxBytes = FrozenBytes[0:32] + FrozenBytesTxHash = chainhash.Hash(FrozenBytesTxBytes) +) + +var ( + CoinbasePlaceholderTx *bt.Tx + coinbasePlaceholderTxHash *chainhash.Hash +) + +func init() { + CoinbasePlaceholderTx = bt.NewTx() + CoinbasePlaceholderTx.Version = 0xFFFFFFFF + CoinbasePlaceholderTx.LockTime = 0xFFFFFFFF + + coinbasePlaceholderTxHash = CoinbasePlaceholderTx.TxIDChainHash() +} + +func IsCoinbasePlaceHolderTx(tx *bt.Tx) bool { + return tx.TxIDChainHash().IsEqual(coinbasePlaceholderTxHash) +} diff --git a/coinbase_placeholder_test.go 
b/coinbase_placeholder_test.go new file mode 100644 index 0000000..7c542d4 --- /dev/null +++ b/coinbase_placeholder_test.go @@ -0,0 +1,17 @@ +package subtree + +import ( + "testing" + + "github.com/libsv/go-bt/v2" + "github.com/stretchr/testify/assert" +) + +func TestCoinbasePlaceholderTx(t *testing.T) { + assert.True(t, IsCoinbasePlaceHolderTx(CoinbasePlaceholderTx)) + assert.Equal(t, CoinbasePlaceholderTx.Version, uint32(0xFFFFFFFF)) + assert.Equal(t, CoinbasePlaceholderTx.LockTime, uint32(0xFFFFFFFF)) + assert.Equal(t, CoinbasePlaceholderTx.TxIDChainHash(), coinbasePlaceholderTxHash) + assert.False(t, IsCoinbasePlaceHolderTx(bt.NewTx())) + assert.Equal(t, "a8502e9c08b3c851201a71d25bf29fd38a664baedb777318b12d19242f0e46ab", CoinbasePlaceholderTx.TxIDChainHash().String()) +} diff --git a/compare.go b/compare.go new file mode 100644 index 0000000..05fbfd6 --- /dev/null +++ b/compare.go @@ -0,0 +1,19 @@ +package subtree + +import "golang.org/x/exp/constraints" + +func Min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + + return b +} + +func Max[T constraints.Ordered](a, b T) T { + if a > b { + return a + } + + return b +} diff --git a/compare_test.go b/compare_test.go new file mode 100644 index 0000000..05a4ded --- /dev/null +++ b/compare_test.go @@ -0,0 +1,43 @@ +package subtree + +import "testing" + +func TestMin(t *testing.T) { + tests := []struct { + a, b, expected int + }{ + {1, 2, 1}, + {2, 1, 1}, + {3, 3, 3}, + {-1, 1, -1}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + result := Min(tt.a, tt.b) + if result != tt.expected { + t.Errorf("Min(%d, %d) = %d; want %d", tt.a, tt.b, result, tt.expected) + } + }) + } +} + +func TestMax(t *testing.T) { + tests := []struct { + a, b, expected int + }{ + {1, 2, 2}, + {2, 1, 2}, + {3, 3, 3}, + {-1, 1, 1}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + result := Max(tt.a, tt.b) + if result != tt.expected { + t.Errorf("Max(%d, %d) = %d; want %d", tt.a, tt.b, 
result, tt.expected) + } + }) + } +} diff --git a/examples/example.go b/examples/example.go deleted file mode 100644 index f7b8277..0000000 --- a/examples/example.go +++ /dev/null @@ -1,15 +0,0 @@ -// Package main is an example of how to use the go-subtree package -package main - -import ( - "log" - - "github.com/bsv-blockchain/go-subtree" -) - -func main() { - // Greet the user with a custom name - name := "Alice" - greeting := template.Greet(name) - log.Println(greeting) -} diff --git a/go.mod b/go.mod index 3ca389d..f171c92 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,24 @@ module github.com/bsv-blockchain/go-subtree go 1.24 -require github.com/stretchr/testify v1.10.0 +require ( + github.com/bsv-blockchain/go-safe-conversion v0.0.0-20250701040542-ca4e7b9ca0da + github.com/bsv-blockchain/go-tx-map v1.0.0 + github.com/libsv/go-bt/v2 v2.0.0-00010101000000-000000000000 + github.com/stretchr/testify v1.10.0 + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b +) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dolthub/maphash v0.1.0 // indirect + github.com/dolthub/swiss v0.2.1 // indirect + github.com/libsv/go-bk v0.1.6 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.14.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/libsv/go-bt/v2 => github.com/ordishs/go-bt/v2 v2.2.22 diff --git a/go.sum b/go.sum index 713a0b4..a836bba 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,39 @@ +github.com/bsv-blockchain/go-safe-conversion v0.0.0-20250701040542-ca4e7b9ca0da h1:WYjLgrUvgq1HjGcP1mU0aXPrqZ5l0UoDB/b9ssa/nxo= +github.com/bsv-blockchain/go-safe-conversion v0.0.0-20250701040542-ca4e7b9ca0da/go.mod h1:Fmat8fhPfMrdGCGv9PZ+QOkpQutD61hssbaLto4+3ks= +github.com/bsv-blockchain/go-tx-map v1.0.0 h1:oeGI6et039crvzuELKHojYdlZwDNf+UCv9r1+63sZHE= +github.com/bsv-blockchain/go-tx-map v1.0.0/go.mod 
h1:xCauj1rtF8dxuxP8WusegkFLzmcVCuLzcYZfjwXOi3Y= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= +github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= +github.com/dolthub/swiss v0.2.1 h1:gs2osYs5SJkAaH5/ggVJqXQxRXtWshF6uE0lgR/Y3Gw= +github.com/dolthub/swiss v0.2.1/go.mod h1:8AhKZZ1HK7g18j7v7k6c5cYIGEZJcPn0ARsai8cUrh0= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/libsv/go-bk v0.1.6 h1:c9CiT5+64HRDbzxPl1v/oiFmbvWZTuUYqywCf+MBs/c= +github.com/libsv/go-bk v0.1.6/go.mod h1:khJboDoH18FPUaZlzRFKzlVN84d4YfdmlDtdX4LAjQA= +github.com/ordishs/go-bt/v2 v2.2.22 h1:5WmTQoX74g9FADM9hpbXZOE34uep4EqeSwpIy4CbWYE= +github.com/ordishs/go-bt/v2 v2.2.22/go.mod h1:bOaZFOoazYognJH/nfcBjuDFud1XmIc05n7bp4Tvvfk= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inpoints.go b/inpoints.go new file mode 100644 index 0000000..3af63bc --- /dev/null +++ b/inpoints.go @@ -0,0 +1,240 @@ +package subtree + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + "slices" + + "github.com/libsv/go-bt/v2" + "github.com/libsv/go-bt/v2/chainhash" +) + +type Inpoint struct { + Hash chainhash.Hash + Index uint32 +} + +type TxInpoints struct { + ParentTxHashes []chainhash.Hash + Idxs [][]uint32 + + // internal variable + nrInpoints int +} + +func NewTxInpoints() TxInpoints { + return TxInpoints{ + ParentTxHashes: make([]chainhash.Hash, 0, 8), // initial capacity of 8, can grow as needed + Idxs: make([][]uint32, 0, 16), // initial capacity of 16, can grow as needed + } +} + +func NewTxInpointsFromTx(tx *bt.Tx) (TxInpoints, error) { + 
p := NewTxInpoints() + p.addTx(tx) + + return p, nil +} + +func NewTxInpointsFromInputs(inputs []*bt.Input) (TxInpoints, error) { + p := TxInpoints{} + + tx := &bt.Tx{} + tx.Inputs = inputs + + p.addTx(tx) + + return p, nil +} + +func NewTxInpointsFromBytes(data []byte) (TxInpoints, error) { + p := TxInpoints{} + + if err := p.deserializeFromReader(bytes.NewReader(data)); err != nil { + return p, err + } + + return p, nil +} + +func NewTxInpointsFromReader(buf io.Reader) (TxInpoints, error) { + p := TxInpoints{} + + if err := p.deserializeFromReader(buf); err != nil { + return p, err + } + + return p, nil +} + +func (p *TxInpoints) String() string { + return fmt.Sprintf("TxInpoints{ParentTxHashes: %v, Idxs: %v}", p.ParentTxHashes, p.Idxs) +} + +func (p *TxInpoints) addTx(tx *bt.Tx) { + // Do not error out for transactions without inputs, seeded Teranodes will have txs without inputs + + for _, input := range tx.Inputs { + hash := *input.PreviousTxIDChainHash() + + index := slices.Index(p.ParentTxHashes, hash) + if index != -1 { + p.Idxs[index] = append(p.Idxs[index], input.PreviousTxOutIndex) + } else { + p.ParentTxHashes = append(p.ParentTxHashes, hash) + p.Idxs = append(p.Idxs, []uint32{input.PreviousTxOutIndex}) + } + + p.nrInpoints++ + } +} + +// GetParentTxHashes returns the unique parent tx hashes +func (p *TxInpoints) GetParentTxHashes() []chainhash.Hash { + return p.ParentTxHashes +} + +func (p *TxInpoints) GetParentTxHashAtIndex(index int) (chainhash.Hash, error) { + if index >= len(p.ParentTxHashes) { + return chainhash.Hash{}, fmt.Errorf("index out of range") + } + + return p.ParentTxHashes[index], nil +} + +// GetTxInpoints returns the unique parent inpoints for the tx +func (p *TxInpoints) GetTxInpoints() []Inpoint { + inpoints := make([]Inpoint, 0, p.nrInpoints) + + for i, hash := range p.ParentTxHashes { + for _, index := range p.Idxs[i] { + inpoints = append(inpoints, Inpoint{ + Hash: hash, + Index: index, + }) + } + } + + return inpoints +} + +func 
(p *TxInpoints) GetParentVoutsAtIndex(index int) ([]uint32, error) { + if index >= len(p.ParentTxHashes) { + return nil, fmt.Errorf("index out of range") + } + + return p.Idxs[index], nil +} + +func (p *TxInpoints) Serialize() ([]byte, error) { + if len(p.ParentTxHashes) != len(p.Idxs) { + return nil, fmt.Errorf("parent tx hashes and indexes length mismatch") + } + + bufBytes := make([]byte, 0, 1024) // 1KB (arbitrary size, should be enough for most cases) + buf := bytes.NewBuffer(bufBytes) + + var ( + err error + bytesUint32 [4]byte + ) + + binary.LittleEndian.PutUint32(bytesUint32[:], len32(p.ParentTxHashes)) + + if _, err = buf.Write(bytesUint32[:]); err != nil { + return nil, fmt.Errorf("unable to write number of parent inpoints: %s", err) + } + + // write the parent tx hashes + for _, hash := range p.ParentTxHashes { + if _, err = buf.Write(hash[:]); err != nil { + return nil, fmt.Errorf("unable to write parent tx hash: %s", err) + } + } + + // write the parent indexes + for _, indexes := range p.Idxs { + binary.LittleEndian.PutUint32(bytesUint32[:], len32(indexes)) + + if _, err = buf.Write(bytesUint32[:]); err != nil { + return nil, fmt.Errorf("unable to write number of parent indexes: %s", err) + } + + for _, idx := range indexes { + binary.LittleEndian.PutUint32(bytesUint32[:], idx) + + if _, err = buf.Write(bytesUint32[:]); err != nil { + return nil, fmt.Errorf("unable to write parent index: %s", err) + } + } + } + + return buf.Bytes(), nil +} + +func (p *TxInpoints) deserializeFromReader(buf io.Reader) error { + // read the number of parent inpoints + var bytesUint32 [4]byte + + if _, err := io.ReadFull(buf, bytesUint32[:]); err != nil { + return fmt.Errorf("unable to read number of parent inpoints: %s", err) + } + + totalInpointsLen := binary.LittleEndian.Uint32(bytesUint32[:]) + + if totalInpointsLen == 0 { + return nil + } + + p.nrInpoints = int(totalInpointsLen) + + // read the parent inpoints + p.ParentTxHashes = make([]chainhash.Hash, 
totalInpointsLen) + p.Idxs = make([][]uint32, totalInpointsLen) + + // read the parent tx hash + for i := uint32(0); i < totalInpointsLen; i++ { + if _, err := io.ReadFull(buf, p.ParentTxHashes[i][:]); err != nil { + return fmt.Errorf("unable to read parent tx hash: %s", err) + } + } + + // read the number of parent indexes + for i := uint32(0); i < totalInpointsLen; i++ { + if _, err := io.ReadFull(buf, bytesUint32[:]); err != nil { + return fmt.Errorf("unable to read number of parent indexes: %s", err) + } + + parentIndexesLen := binary.LittleEndian.Uint32(bytesUint32[:]) + + // read the parent indexes + p.Idxs[i] = make([]uint32, parentIndexesLen) + + for j := uint32(0); j < parentIndexesLen; j++ { + if _, err := io.ReadFull(buf, bytesUint32[:]); err != nil { + return fmt.Errorf("unable to read parent index: %s", err) + } + + p.Idxs[i][j] = binary.LittleEndian.Uint32(bytesUint32[:]) + } + } + + return nil +} + +func len32[V any](b []V) uint32 { + if b == nil { + return 0 + } + + l := len(b) + + if l > math.MaxUint32 { + return math.MaxUint32 + } + + return uint32(l) +} diff --git a/inpoints_test.go b/inpoints_test.go new file mode 100644 index 0000000..07a4afd --- /dev/null +++ b/inpoints_test.go @@ -0,0 +1,129 @@ +package subtree + +import ( + "testing" + + "github.com/libsv/go-bt/v2/chainhash" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTxInpoints(t *testing.T) { + t.Run("TestTxInpoints", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + assert.Equal(t, 1, len(p.ParentTxHashes)) + assert.Equal(t, 1, len(p.Idxs[0])) + }) + + t.Run("serialize", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + b, err := p.Serialize() + require.NoError(t, err) + assert.Equal(t, 44, len(b)) + + p2, err := NewTxInpointsFromBytes(b) + require.NoError(t, err) + + assert.Equal(t, 1, len(p2.ParentTxHashes)) + assert.Equal(t, 1, len(p2.Idxs[0])) + + assert.Equal(t, 
p.ParentTxHashes[0], p2.ParentTxHashes[0]) + assert.Equal(t, p.Idxs[0][0], p2.Idxs[0][0]) + }) + + t.Run("serialize with error", func(t *testing.T) { + p := NewTxInpoints() + p.ParentTxHashes = []chainhash.Hash{chainhash.HashH([]byte("test"))} + p.Idxs = [][]uint32{} + + _, err := p.Serialize() + require.Error(t, err) + }) + + t.Run("from inputs", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + p2, err := NewTxInpointsFromInputs(tx.Inputs) + require.NoError(t, err) + + // make sure they are the same + assert.Equal(t, len(p.ParentTxHashes), len(p2.ParentTxHashes)) + assert.Equal(t, len(p.Idxs), len(p2.Idxs)) + assert.Equal(t, p.ParentTxHashes[0], p2.ParentTxHashes[0]) + assert.Equal(t, p.Idxs[0][0], p2.Idxs[0][0]) + }) +} + +func TestGetTxInpoints(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + // Test getting inpoints + inpoints := p.GetTxInpoints() + assert.Equal(t, 1, len(inpoints)) + assert.Equal(t, uint32(5), inpoints[0].Index) + assert.Equal(t, *tx.Inputs[0].PreviousTxIDChainHash(), inpoints[0].Hash) +} + +func TestGetParentTxHashAtIndex(t *testing.T) { + t.Run("TestGetParentTxHashAtIndex", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + // Test getting parent tx hash at index + hash, err := p.GetParentTxHashAtIndex(0) + require.NoError(t, err) + + assert.Equal(t, *tx.Inputs[0].PreviousTxIDChainHash(), hash) + }) + + t.Run("out of range", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + // Test getting parent tx hash at index + hash, err := p.GetParentTxHashAtIndex(1) + require.Error(t, err) + + assert.Equal(t, chainhash.Hash{}, hash) + }) +} + +func TestGetParentVoutsAtIndex(t *testing.T) { + t.Run("TestGetParentVoutsAtIndex", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + // Test getting parent vouts at index + vouts, err := p.GetParentVoutsAtIndex(0) + require.NoError(t, 
err) + + assert.Equal(t, 1, len(vouts)) + assert.Equal(t, uint32(5), vouts[0]) + }) + + t.Run("out of range", func(t *testing.T) { + p, err := NewTxInpointsFromTx(tx) + require.NoError(t, err) + + // Test getting parent vouts at index + vouts, err := p.GetParentVoutsAtIndex(1) + require.Error(t, err) + + assert.Nil(t, vouts) + }) +} + +func BenchmarkNewTxInpoints(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := NewTxInpointsFromTx(tx) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/merkle_tree.go b/merkle_tree.go new file mode 100644 index 0000000..8843f34 --- /dev/null +++ b/merkle_tree.go @@ -0,0 +1,150 @@ +package subtree + +import ( + "crypto/sha256" + "fmt" + "math" + "sync" + + "github.com/libsv/go-bt/v2/chainhash" +) + +func GetMerkleProofForCoinbase(subtrees []*Subtree) ([]*chainhash.Hash, error) { + if len(subtrees) == 0 { + return nil, fmt.Errorf("no subtrees available") + } + + merkleProof, err := subtrees[0].GetMerkleProof(0) + if err != nil { + return nil, fmt.Errorf("failed creating merkle proof for subtree: %s", err) + } + + // Create a new tree with the subtreeHashes of the subtrees + topTree, err := NewTreeByLeafCount(CeilPowerOfTwo(len(subtrees))) + if err != nil { + return nil, err + } + + for _, subtree := range subtrees { + err = topTree.AddNode(*subtree.RootHash(), subtree.Fees, subtree.SizeInBytes) + if err != nil { + return nil, err + } + } + + topMerkleProof, err := topTree.GetMerkleProof(0) + if err != nil { + return nil, fmt.Errorf("failed creating merkle proofs for top tree: %s", err) + } + + return append(merkleProof, topMerkleProof...), nil +} + +func BuildMerkleTreeStoreFromBytes(nodes []SubtreeNode) (*[]chainhash.Hash, error) { + if len(nodes) == 0 { + return &[]chainhash.Hash{}, nil + } + + // Calculate how many entries are in an array of that size. 
+ length := len(nodes) + nextPoT := NextPowerOfTwo(length) + arraySize := nextPoT*2 - 1 + // we do not include the original nodes in the merkle tree + merkles := make([]chainhash.Hash, nextPoT-1) + // merkles := []byte{MaxSubtreeSize: 0} + + if arraySize == 1 { + // Handle this Bitcoin exception that the merkle root is the same as the transaction hash if there + // is only one transaction. + return &[]chainhash.Hash{nodes[0].Hash}, nil + } + + // Start the array offset after the last transaction and adjusted to the + // next power of two. + height := int(math.Ceil(math.Log2(float64(length)))) + routineSplitSize := 1024 // should be a power of two + + merkleFrom := 0 + for h := 0; h <= height; h++ { + merkleTo := merkleFrom + int(math.Pow(2, float64(height-h))) - 1 + if merkleTo-merkleFrom > routineSplitSize { + var wg sync.WaitGroup + // if we are calculating a large merkle tree, we can do this in parallel + for i := merkleFrom; i < merkleTo; i += routineSplitSize { + wg.Add(1) + + go func(i int) { + defer wg.Done() + calcMerkles(nodes, i, Min(i+routineSplitSize, merkleTo), nextPoT, length, merkles) + }(i) + } + + wg.Wait() + } else { + calcMerkles(nodes, merkleFrom, merkleTo, nextPoT, length, merkles) + } + + merkleFrom = merkleTo + 1 + } + + return &merkles, nil +} + +func calcMerkles(nodes []SubtreeNode, merkleFrom, merkleTo, nextPoT, length int, merkles []chainhash.Hash) { + var offset int + + var currentMerkle chainhash.Hash + + var currentMerkle1 chainhash.Hash + + for i := merkleFrom; i < merkleTo; i += 2 { + offset = i / 2 + + if i < nextPoT { + if i >= length { + currentMerkle = chainhash.Hash{} + } else { + currentMerkle = nodes[i].Hash + } + + if i+1 >= length { + currentMerkle1 = chainhash.Hash{} + } else { + currentMerkle1 = nodes[i+1].Hash + } + } else { + currentMerkle = merkles[i-nextPoT] + currentMerkle1 = merkles[i-nextPoT+1] + } + + merkles[offset] = calcMerkle(currentMerkle, currentMerkle1) + } +} + +func calcMerkle(currentMerkle chainhash.Hash, 
currentMerkle1 chainhash.Hash) [32]byte { + switch { + // When there is no left child node, the parent is nil ("") too. + case currentMerkle.Equal(chainhash.Hash{}): + return chainhash.Hash{} + + // When there is no right child, the parent is generated by + // hashing the concatenation of the left child with itself. + case currentMerkle1.Equal(chainhash.Hash{}): + shaBytes := [64]byte{} + copy(shaBytes[0:32], currentMerkle[:]) + copy(shaBytes[32:64], currentMerkle[:]) + hash := sha256.Sum256(shaBytes[:]) + + return sha256.Sum256(hash[:]) + + // The normal case sets the parent node to the double sha256 + // of the concatenation of the left and right children. + default: + shaBytes := [64]byte{} + copy(shaBytes[0:32], currentMerkle[:]) + copy(shaBytes[32:64], currentMerkle1[:]) + hash := sha256.Sum256(shaBytes[:]) + + return sha256.Sum256(hash[:]) + } +} diff --git a/merkle_tree_test.go b/merkle_tree_test.go new file mode 100644 index 0000000..974c459 --- /dev/null +++ b/merkle_tree_test.go @@ -0,0 +1,53 @@ +package subtree + +import ( + "testing" + + "github.com/libsv/go-bt/v2/chainhash" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetMerkleProofForCoinbase(t *testing.T) { + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hash2, _ := chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hash3, _ := chainhash.NewHashFromStr("3070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hash4, _ := chainhash.NewHashFromStr("d3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + + hash5, _ := chainhash.NewHashFromStr("87af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hash6, _ := chainhash.NewHashFromStr("6ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hash7, _ := chainhash.NewHashFromStr("2070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hash8, 
_ := chainhash.NewHashFromStr("c3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + + expectedRootHash := "86867b9f3e7dcb4bdf5b5cc99322122fe492bc466621f3709d4e389e7e14c16c" + + t.Run("", func(t *testing.T) { + subtree1, err := NewTree(2) + require.NoError(t, err) + + require.NoError(t, subtree1.AddNode(*hash1, 12, 0)) + require.NoError(t, subtree1.AddNode(*hash2, 13, 0)) + require.NoError(t, subtree1.AddNode(*hash3, 14, 0)) + require.NoError(t, subtree1.AddNode(*hash4, 15, 0)) + + subtree2, err := NewTree(2) + require.NoError(t, err) + + require.NoError(t, subtree2.AddNode(*hash5, 16, 0)) + require.NoError(t, subtree2.AddNode(*hash6, 17, 0)) + require.NoError(t, subtree2.AddNode(*hash7, 18, 0)) + require.NoError(t, subtree2.AddNode(*hash8, 19, 0)) + + merkleProof, err := GetMerkleProofForCoinbase([]*Subtree{subtree1, subtree2}) + require.NoError(t, err) + assert.Equal(t, "7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d", merkleProof[0].String()) + assert.Equal(t, "c32db78e5f8437648888713982ea3d49628dbde0b4b48857147f793b55d26f09", merkleProof[1].String()) + + topTree, err := NewTreeByLeafCount(2) + require.NoError(t, err) + + require.NoError(t, topTree.AddNode(*subtree1.RootHash(), subtree1.Fees, subtree1.SizeInBytes)) + require.NoError(t, topTree.AddNode(*subtree2.RootHash(), subtree2.Fees, subtree1.SizeInBytes)) + assert.Equal(t, expectedRootHash, topTree.RootHash().String()) + }) +} diff --git a/power_of_two.go b/power_of_two.go new file mode 100644 index 0000000..11db277 --- /dev/null +++ b/power_of_two.go @@ -0,0 +1,53 @@ +package subtree + +import ( + "math" + "math/bits" +) + +func CeilPowerOfTwo(num int) int { + if num <= 0 { + return 1 + } + + // Find the position of the most significant bit + msbPos := uint(math.Ceil(math.Log2(float64(num)))) + + // Calculate the power of 2 with the next higher position + ceilValue := int(math.Pow(2, float64(msbPos))) + + return ceilValue +} + +func IsPowerOfTwo(num int) bool { + if num <= 
0 { + return false + } + + return (num & (num - 1)) == 0 +} + +// NextPowerOfTwo returns the next highest power of two from a given number if +// it is not already a power of two. This is a helper function used during the +// calculation of a merkle tree. +func NextPowerOfTwo(n int) int { + // Return the number if it's already a power of 2. + if n&(n-1) == 0 { + return n + } + + // Figure out and return the next power of two. + exponent := uint(math.Log2(float64(n))) + 1 + + return 1 << exponent // 2^exponent +} + +// NextLowerPowerOfTwo finds the next power of 2 that is less than x. +func NextLowerPowerOfTwo(x uint) uint { + if x == 0 { + return 0 + } + + // bit length minus one gives the exponent of the next lower (or equal) power of 2 + return 1 << (bits.Len(x) - 1) +} diff --git a/power_of_two_test.go b/power_of_two_test.go new file mode 100644 index 0000000..9c18bf0 --- /dev/null +++ b/power_of_two_test.go @@ -0,0 +1,31 @@ +package subtree + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsPowerOf2(t *testing.T) { + // Testing the function + numbers := []int{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 1048576, 70368744177664} + for _, num := range numbers { + assert.True(t, IsPowerOfTwo(num), fmt.Sprintf("%d should be a power of 2", num)) + } + + numbers = []int{-1, 0, 41, 13} + for _, num := range numbers { + assert.False(t, IsPowerOfTwo(num), fmt.Sprintf("%d should not be a power of 2", num)) + } +} + +func TestNextLowerPowerOf2(t *testing.T) { + // Testing the function + numbers := []uint{17, 32, 120, 128, 0, 231072} + expected := []uint{16, 32, 64, 128, 0, 131072} + + for i, num := range numbers { + assert.Equal(t, expected[i], NextLowerPowerOfTwo(num), fmt.Sprintf("unexpected next lower power of two for %d", num)) + } +} diff --git a/subtree.go b/subtree.go new file mode 100644 index 0000000..9961a9a --- /dev/null +++ b/subtree.go @@ -0,0 +1,790 @@ +package subtree + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + 
"math" + "sync" + + safe "github.com/bsv-blockchain/go-safe-conversion" + txmap "github.com/bsv-blockchain/go-tx-map" + "github.com/libsv/go-bt/v2/chainhash" +) + +type SubtreeNode struct { + Hash chainhash.Hash `json:"txid"` // This is called txid so that the UI knows to add a link to /tx/ + Fee uint64 `json:"fee"` + SizeInBytes uint64 `json:"size"` +} + +type Subtree struct { + Height int + Fees uint64 + SizeInBytes uint64 + FeeHash chainhash.Hash + Nodes []SubtreeNode + ConflictingNodes []chainhash.Hash // conflicting nodes need to be checked when doing block assembly + + // temporary (calculated) variables + rootHash *chainhash.Hash + treeSize int + feeBytes []byte + feeHashBytes []byte + + mu sync.RWMutex // protects Nodes slice + nodeIndex map[chainhash.Hash]int // maps txid to index in Nodes slice +} + +type TxMap interface { + Put(hash chainhash.Hash, value uint64) error + Get(hash chainhash.Hash) (uint64, bool) + Exists(hash chainhash.Hash) bool + Length() int + Keys() []chainhash.Hash +} + +// NewTree creates a new Subtree with a fixed height +// +// height is the number if levels in a merkle tree of the subtree +func NewTree(height int) (*Subtree, error) { + if height < 0 { + return nil, fmt.Errorf("height must be at least 0") + } + + var treeSize = int(math.Pow(2, float64(height))) + + return &Subtree{ + Nodes: make([]SubtreeNode, 0, treeSize), + Height: height, + FeeHash: chainhash.Hash{}, + treeSize: treeSize, + // feeBytes: make([]byte, 8), + // feeHashBytes: make([]byte, 40), + }, nil +} + +func NewTreeByLeafCount(maxNumberOfLeaves int) (*Subtree, error) { + if !IsPowerOfTwo(maxNumberOfLeaves) { + return nil, fmt.Errorf("numberOfLeaves must be a power of two") + } + + height := math.Ceil(math.Log2(float64(maxNumberOfLeaves))) + + return NewTree(int(height)) +} + +func NewIncompleteTreeByLeafCount(maxNumberOfLeaves int) (*Subtree, error) { + height := math.Ceil(math.Log2(float64(maxNumberOfLeaves))) + + return NewTree(int(height)) +} + +func 
NewSubtreeFromBytes(b []byte) (*Subtree, error) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Recovered in NewSubtreeFromBytes: %v\n", r) + } + }() + + subtree := &Subtree{} + + err := subtree.Deserialize(b) + if err != nil { + return nil, err + } + + return subtree, nil +} + +func NewSubtreeFromReader(reader io.Reader) (*Subtree, error) { + defer func() { + if r := recover(); r != nil { + fmt.Printf("Recovered in NewSubtreeFromReader: %v\n", r) + } + }() + + subtree := &Subtree{} + + if err := subtree.DeserializeFromReader(reader); err != nil { + return nil, err + } + + return subtree, nil +} + +func DeserializeNodesFromReader(reader io.Reader) (subtreeBytes []byte, err error) { + buf := bufio.NewReaderSize(reader, 1024*1024*16) // 16MB buffer + + // root len(st.rootHash[:]) bytes + // first 8 bytes, fees + // second 8 bytes, sizeInBytes + // third 8 bytes, number of leaves + // total read at once = len(st.rootHash[:]) + 8 + 8 + 8 + byteBuffer := make([]byte, chainhash.HashSize+24) + if _, err = ReadBytes(buf, byteBuffer); err != nil { + return nil, fmt.Errorf("unable to read subtree root information: %w", err) + } + + numLeaves := binary.LittleEndian.Uint64(byteBuffer[chainhash.HashSize+16 : chainhash.HashSize+24]) + subtreeBytes = make([]byte, chainhash.HashSize*int(numLeaves)) + + byteBuffer = byteBuffer[8:] // reduce read byteBuffer size by 8 + for i := uint64(0); i < numLeaves; i++ { + if _, err = ReadBytes(buf, byteBuffer); err != nil { + return nil, fmt.Errorf("unable to read subtree node information: %w", err) + } + + copy(subtreeBytes[i*chainhash.HashSize:(i+1)*chainhash.HashSize], byteBuffer[:chainhash.HashSize]) + } + + return subtreeBytes, nil +} + +func (st *Subtree) Duplicate() *Subtree { + newSubtree := &Subtree{ + Height: st.Height, + Fees: st.Fees, + SizeInBytes: st.SizeInBytes, + FeeHash: st.FeeHash, + Nodes: make([]SubtreeNode, len(st.Nodes)), + ConflictingNodes: make([]chainhash.Hash, len(st.ConflictingNodes)), + rootHash: 
st.rootHash, + treeSize: st.treeSize, + // feeBytes: make([]byte, 8), + // feeHashBytes: make([]byte, 40), + } + + copy(newSubtree.Nodes, st.Nodes) + copy(newSubtree.ConflictingNodes, st.ConflictingNodes) + + return newSubtree +} + +// Size returns the capacity of the subtree +func (st *Subtree) Size() int { + st.mu.RLock() + size := cap(st.Nodes) + st.mu.RUnlock() + + return size +} + +// Length returns the number of nodes in the subtree +func (st *Subtree) Length() int { + st.mu.RLock() + length := len(st.Nodes) + st.mu.RUnlock() + + return length +} + +func (st *Subtree) IsComplete() bool { + st.mu.RLock() + isComplete := len(st.Nodes) == cap(st.Nodes) + st.mu.RUnlock() + + return isComplete +} + +func (st *Subtree) ReplaceRootNode(node *chainhash.Hash, fee uint64, sizeInBytes uint64) *chainhash.Hash { + if len(st.Nodes) < 1 { + st.Nodes = append(st.Nodes, SubtreeNode{ + Hash: *node, + Fee: fee, + SizeInBytes: sizeInBytes, + }) + } else { + st.Nodes[0] = SubtreeNode{ + Hash: *node, + Fee: fee, + SizeInBytes: sizeInBytes, + } + } + + st.rootHash = nil // reset rootHash + st.SizeInBytes += sizeInBytes + + return st.RootHash() +} + +func (st *Subtree) AddSubtreeNode(node SubtreeNode) error { + st.mu.Lock() + defer st.mu.Unlock() + + if (len(st.Nodes) + 1) > st.treeSize { + return fmt.Errorf("subtree is full") + } + + if node.Hash.Equal(CoinbasePlaceholder) { + return fmt.Errorf("[AddSubtreeNode] coinbase placeholder node should be added with AddCoinbaseNode, tree length is %d", len(st.Nodes)) + } + + // AddNode is not concurrency safe, so we can reuse the same byte arrays + // binary.LittleEndian.PutUint64(st.feeBytes, fee) + // st.feeHashBytes = append(node[:], st.feeBytes[:]...) 
+ // if len(st.Nodes) == 0 { + // st.FeeHash = chainhash.HashH(st.feeHashBytes) + // } else { + // st.FeeHash = chainhash.HashH(append(st.FeeHash[:], st.feeHashBytes...)) + // } + + st.Nodes = append(st.Nodes, node) + st.rootHash = nil // reset rootHash + st.Fees += node.Fee + st.SizeInBytes += node.SizeInBytes + + if st.nodeIndex != nil { + // node index map exists, add the node to it + st.nodeIndex[node.Hash] = len(st.Nodes) - 1 + } + + return nil +} + +func (st *Subtree) AddCoinbaseNode() error { + if len(st.Nodes) != 0 { + return fmt.Errorf("subtree should be empty before adding a coinbase node") + } + + st.Nodes = append(st.Nodes, SubtreeNode{ + Hash: CoinbasePlaceholder, + Fee: 0, + SizeInBytes: 0, + }) + st.rootHash = nil // reset rootHash + st.Fees = 0 + st.SizeInBytes = 0 + + return nil +} + +func (st *Subtree) AddConflictingNode(newConflictingNode chainhash.Hash) error { + if st.ConflictingNodes == nil { + st.ConflictingNodes = make([]chainhash.Hash, 0, 1) + } + + // check the conflicting node is actually in the subtree + found := false + + for _, n := range st.Nodes { + if n.Hash.Equal(newConflictingNode) { + found = true + break + } + } + + if !found { + return fmt.Errorf("conflicting node is not in the subtree") + } + + // check whether the conflicting node has already been added + for _, conflictingNode := range st.ConflictingNodes { + if conflictingNode.Equal(newConflictingNode) { + return nil + } + } + + st.ConflictingNodes = append(st.ConflictingNodes, newConflictingNode) + + return nil +} + +// AddNode adds a node to the subtree +// NOTE: this function is not concurrency safe, so it should be called from a single goroutine +// +// Parameters: +// - node: the transaction id of the node to add +// - fee: the fee of the node +// - sizeInBytes: the size of the node in bytes +// +// Returns: +// - error: an error if the node could not be added +func (st *Subtree) AddNode(node chainhash.Hash, fee uint64, sizeInBytes uint64) error { + if (len(st.Nodes) + 
1) > st.treeSize { + return fmt.Errorf("subtree is full") + } + + if node.Equal(CoinbasePlaceholder) { + return fmt.Errorf("[AddNode] coinbase placeholder node should be added with AddCoinbaseNode") + } + + // AddNode is not concurrency safe, so we can reuse the same byte arrays + // binary.LittleEndian.PutUint64(st.feeBytes, fee) + // st.feeHashBytes = append(node[:], st.feeBytes[:]...) + // if len(st.Nodes) == 0 { + // st.FeeHash = chainhash.HashH(st.feeHashBytes) + // } else { + // st.FeeHash = chainhash.HashH(append(st.FeeHash[:], st.feeHashBytes...)) + // } + + st.Nodes = append(st.Nodes, SubtreeNode{ + Hash: node, + Fee: fee, + SizeInBytes: sizeInBytes, + }) + st.rootHash = nil // reset rootHash + st.Fees += fee + st.SizeInBytes += sizeInBytes + + if st.nodeIndex != nil { + // node index map exists, add the node to it + st.nodeIndex[node] = len(st.Nodes) - 1 + } + + return nil +} + +// RemoveNodeAtIndex removes a node at the given index and makes sure the subtree is still valid +func (st *Subtree) RemoveNodeAtIndex(index int) error { + st.mu.Lock() + defer st.mu.Unlock() + + if index >= len(st.Nodes) { + return fmt.Errorf("index out of range") + } + + st.Fees -= st.Nodes[index].Fee + st.SizeInBytes -= st.Nodes[index].SizeInBytes + + hash := st.Nodes[index].Hash + st.Nodes = append(st.Nodes[:index], st.Nodes[index+1:]...) 
+ st.rootHash = nil // reset rootHash + + if st.nodeIndex != nil { + // remove the node from the node index map + delete(st.nodeIndex, hash) + } + + return nil +} + +func (st *Subtree) RootHash() *chainhash.Hash { + if st == nil { + return nil + } + + if st.rootHash != nil { + return st.rootHash + } + + if st.Length() == 0 { + return nil + } + + // calculate rootHash + store, err := BuildMerkleTreeStoreFromBytes(st.Nodes) + if err != nil { + return nil + } + + st.rootHash, _ = chainhash.NewHash((*store)[len(*store)-1][:]) + + return st.rootHash +} + +func (st *Subtree) RootHashWithReplaceRootNode(node *chainhash.Hash, fee uint64, sizeInBytes uint64) (*chainhash.Hash, error) { + if st == nil { + return nil, fmt.Errorf("subtree is nil") + } + + // clone the subtree, so we do not overwrite anything in it + subtreeClone := st.Duplicate() + subtreeClone.ReplaceRootNode(node, fee, sizeInBytes) + + // calculate rootHash + store, err := BuildMerkleTreeStoreFromBytes(subtreeClone.Nodes) + if err != nil { + return nil, err + } + + rootHash := chainhash.Hash((*store)[len(*store)-1][:]) + + return &rootHash, nil +} + +func (st *Subtree) GetMap() (TxMap, error) { + lengthUint32, err := safe.IntToUint32(len(st.Nodes)) + if err != nil { + return nil, err + } + + m := txmap.NewSwissMapUint64(lengthUint32) + for idx, node := range st.Nodes { + _ = m.Put(node.Hash, uint64(idx)) + } + + return m, nil +} + +func (st *Subtree) NodeIndex(hash chainhash.Hash) int { + if st.nodeIndex == nil { + // create the node index map + st.mu.Lock() + st.nodeIndex = make(map[chainhash.Hash]int, len(st.Nodes)) + + for idx, node := range st.Nodes { + st.nodeIndex[node.Hash] = idx + } + + st.mu.Unlock() + } + + nodeIndex, ok := st.nodeIndex[hash] + if ok { + return nodeIndex + } + + return -1 +} + +func (st *Subtree) HasNode(hash chainhash.Hash) bool { + return st.NodeIndex(hash) != -1 +} + +func (st *Subtree) GetNode(hash chainhash.Hash) (*SubtreeNode, error) { + nodeIndex := st.NodeIndex(hash) + if 
nodeIndex != -1 { + return &st.Nodes[nodeIndex], nil + } + + return nil, fmt.Errorf("node not found") +} + +func (st *Subtree) Difference(ids TxMap) ([]SubtreeNode, error) { + // return all the ids that are in st.Nodes, but not in ids + diff := make([]SubtreeNode, 0, 1_000) + + for _, node := range st.Nodes { + if !ids.Exists(node.Hash) { + diff = append(diff, node) + } + } + + return diff, nil +} + +// GetMerkleProof returns the merkle proof for the given index +// TODO rewrite this to calculate this from the subtree nodes needed, and not the whole tree +func (st *Subtree) GetMerkleProof(index int) ([]*chainhash.Hash, error) { + if index >= len(st.Nodes) { + return nil, fmt.Errorf("index out of range") + } + + merkleTree, err := BuildMerkleTreeStoreFromBytes(st.Nodes) + if err != nil { + return nil, err + } + + height := math.Ceil(math.Log2(float64(len(st.Nodes)))) + totalLength := int(math.Pow(2, height)) + len(*merkleTree) + + treeIndexPos := 0 + treeIndex := index + nodes := make([]*chainhash.Hash, 0, int(height)) + + for i := height; i > 0; i-- { + if i == height { + // we are at the leaf level and read from the Nodes array + if index%2 == 0 { + nodes = append(nodes, &st.Nodes[index+1].Hash) + } else { + nodes = append(nodes, &st.Nodes[index-1].Hash) + } + } else { + treePos := treeIndexPos + treeIndex + if treePos%2 == 0 { + if totalLength > treePos+1 && !(*merkleTree)[treePos+1].Equal(chainhash.Hash{}) { + treePos++ + } + } else { + if !(*merkleTree)[treePos-1].Equal(chainhash.Hash{}) { + treePos-- + } + } + + nodes = append(nodes, &(*merkleTree)[treePos]) + treeIndexPos += int(math.Pow(2, i)) + } + + treeIndex = int(math.Floor(float64(treeIndex) / 2)) + } + + return nodes, nil +} + +func (st *Subtree) Serialize() ([]byte, error) { + bufBytes := make([]byte, 0, 32+8+8+8+(len(st.Nodes)*32)+8+(len(st.ConflictingNodes)*32)) + buf := bytes.NewBuffer(bufBytes) + + // write root hash - this is only for checking the correctness of the data + _, err := 
buf.Write(st.RootHash()[:]) + if err != nil { + return nil, fmt.Errorf("unable to write root hash: %w", err) + } + + var b [8]byte + + // write fees + binary.LittleEndian.PutUint64(b[:], st.Fees) + + if _, err = buf.Write(b[:]); err != nil { + return nil, fmt.Errorf("unable to write fees: %w", err) + } + + // write size + binary.LittleEndian.PutUint64(b[:], st.SizeInBytes) + + if _, err = buf.Write(b[:]); err != nil { + return nil, fmt.Errorf("unable to write sizeInBytes: %w", err) + } + + // write number of nodes + binary.LittleEndian.PutUint64(b[:], uint64(len(st.Nodes))) + + if _, err = buf.Write(b[:]); err != nil { + return nil, fmt.Errorf("unable to write number of nodes: %w", err) + } + + // write nodes + feeBytes := make([]byte, 8) + sizeBytes := make([]byte, 8) + + for _, subtreeNode := range st.Nodes { + _, err = buf.Write(subtreeNode.Hash[:]) + if err != nil { + return nil, fmt.Errorf("unable to write node: %w", err) + } + + binary.LittleEndian.PutUint64(feeBytes, subtreeNode.Fee) + + _, err = buf.Write(feeBytes) + if err != nil { + return nil, fmt.Errorf("unable to write fee: %w", err) + } + + binary.LittleEndian.PutUint64(sizeBytes, subtreeNode.SizeInBytes) + + _, err = buf.Write(sizeBytes) + if err != nil { + return nil, fmt.Errorf("unable to write sizeInBytes: %w", err) + } + } + + // write number of conflicting nodes + binary.LittleEndian.PutUint64(b[:], uint64(len(st.ConflictingNodes))) + + if _, err = buf.Write(b[:]); err != nil { + return nil, fmt.Errorf("unable to write number of conflicting nodes: %w", err) + } + + // write conflicting nodes + for _, nodeHash := range st.ConflictingNodes { + _, err = buf.Write(nodeHash[:]) + if err != nil { + return nil, fmt.Errorf("unable to write conflicting node: %w", err) + } + } + + return buf.Bytes(), nil +} + +// SerializeNodes serializes only the nodes (list of transaction ids), not the root hash, fees, etc. 
+func (st *Subtree) SerializeNodes() ([]byte, error) { + b := make([]byte, 0, len(st.Nodes)*32) + buf := bytes.NewBuffer(b) + + var err error + + // write nodes + for _, subtreeNode := range st.Nodes { + if _, err = buf.Write(subtreeNode.Hash[:]); err != nil { + return nil, fmt.Errorf("unable to write node: %w", err) + } + } + + return buf.Bytes(), nil +} + +func (st *Subtree) Deserialize(b []byte) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered in Deserialize: %s", r) + } + }() + + buf := bytes.NewBuffer(b) + + return st.DeserializeFromReader(buf) +} + +func (st *Subtree) DeserializeFromReader(reader io.Reader) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered in DeserializeFromReader: %s", r) + } + }() + + buf := bufio.NewReaderSize(reader, 1024*1024*16) // 16MB buffer + + var ( + n int + bytes8 = make([]byte, 8) + ) + + // read root hash + st.rootHash = new(chainhash.Hash) + if n, err = buf.Read(st.rootHash[:]); err != nil || n != chainhash.HashSize { + // if _, err = io.ReadFull(buf, st.rootHash[:]); err != nil { + return fmt.Errorf("unable to read root hash: %w", err) + } + + // read fees + if n, err = buf.Read(bytes8); err != nil || n != 8 { + // if _, err = io.ReadFull(buf, bytes8); err != nil { + return fmt.Errorf("unable to read fees: %w", err) + } + + st.Fees = binary.LittleEndian.Uint64(bytes8) + + // read sizeInBytes + if n, err = buf.Read(bytes8); err != nil || n != 8 { + // if _, err = io.ReadFull(buf, bytes8); err != nil { + return fmt.Errorf("unable to read sizeInBytes: %w", err) + } + + st.SizeInBytes = binary.LittleEndian.Uint64(bytes8) + + if err = st.deserializeNodes(buf); err != nil { + return err + } + + if err = st.deserializeConflictingNodes(buf); err != nil { + return err + } + + return nil +} + +func (st *Subtree) deserializeNodes(buf *bufio.Reader) error { + bytes8 := make([]byte, 8) + + // read number of leaves + if n, err := buf.Read(bytes8); err != 
nil || n != 8 { + // if _, err = io.ReadFull(buf, bytes8); err != nil { + return fmt.Errorf("unable to read number of leaves: %w", err) + } + + numLeaves := binary.LittleEndian.Uint64(bytes8) + + st.treeSize = int(numLeaves) + // the height of a subtree is always a power of two + st.Height = int(math.Ceil(math.Log2(float64(numLeaves)))) + + // read leaves + st.Nodes = make([]SubtreeNode, numLeaves) + + bytes48 := make([]byte, 48) + for i := uint64(0); i < numLeaves; i++ { + // read all the node data in 1 go + if n, err := ReadBytes(buf, bytes48); err != nil || n != 48 { + // if _, err = io.ReadFull(buf, bytes48); err != nil { + return fmt.Errorf("unable to read node: %w", err) + } + + st.Nodes[i].Hash = chainhash.Hash(bytes48[:32]) + st.Nodes[i].Fee = binary.LittleEndian.Uint64(bytes48[32:40]) + st.Nodes[i].SizeInBytes = binary.LittleEndian.Uint64(bytes48[40:48]) + } + + return nil +} + +func (st *Subtree) deserializeConflictingNodes(buf *bufio.Reader) error { + bytes8 := make([]byte, 8) + + // read number of conflicting nodes + if n, err := buf.Read(bytes8); err != nil || n != 8 { + // if _, err = io.ReadFull(buf, bytes8); err != nil { + return fmt.Errorf("unable to read number of conflicting nodes: %w", err) + } + + numConflictingLeaves := binary.LittleEndian.Uint64(bytes8) + + // read conflicting nodes + st.ConflictingNodes = make([]chainhash.Hash, numConflictingLeaves) + + for i := uint64(0); i < numConflictingLeaves; i++ { + if n, err := buf.Read(st.ConflictingNodes[i][:]); err != nil || n != 32 { + return fmt.Errorf("unable to read conflicting node %d: %s", i, err) + } + } + + return nil +} + +func ReadBytes(buf *bufio.Reader, p []byte) (n int, err error) { + minRead := len(p) + for n < minRead && err == nil { + p[n], err = buf.ReadByte() + n++ + } + + if n >= minRead { + err = nil + } else if n > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + + return +} + +func DeserializeSubtreeConflictingFromReader(reader io.Reader) (conflictingNodes 
[]chainhash.Hash, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("recovered in DeserializeFromReader: %s", r) + } + }() + + buf := bufio.NewReaderSize(reader, 1024*1024*16) // 16MB buffer + + // skip root hash 32 bytes + // skip fees, 8 bytes + // skip sizeInBytes, 8 bytes + _, _ = buf.Discard(32 + 8 + 8) + + bytes8 := make([]byte, 8) + + // read number of leaves + if _, err = io.ReadFull(buf, bytes8); err != nil { + return nil, fmt.Errorf("unable to read number of leaves: %w", err) + } + + numLeaves := binary.LittleEndian.Uint64(bytes8) + + numLeavesInt, err := safe.Uint64ToInt(numLeaves) + if err != nil { + return nil, err + } + + _, _ = buf.Discard(48 * numLeavesInt) + + // read number of conflicting nodes + if _, err = io.ReadFull(buf, bytes8); err != nil { + return nil, fmt.Errorf("unable to read number of conflicting nodes: %w", err) + } + + numConflictingLeaves := binary.LittleEndian.Uint64(bytes8) + + // read conflicting nodes + conflictingNodes = make([]chainhash.Hash, numConflictingLeaves) + for i := uint64(0); i < numConflictingLeaves; i++ { + if _, err = io.ReadFull(buf, conflictingNodes[i][:]); err != nil { + return nil, fmt.Errorf("unable to read conflicting node: %w", err) + } + } + + return conflictingNodes, nil +} diff --git a/subtree_benchmark_test.go b/subtree_benchmark_test.go new file mode 100644 index 0000000..22b0286 --- /dev/null +++ b/subtree_benchmark_test.go @@ -0,0 +1,71 @@ +package subtree_test + +import ( + "crypto/rand" + "encoding/binary" + "github.com/bsv-blockchain/go-subtree" + "github.com/libsv/go-bt/v2/chainhash" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func BenchmarkSubtree_AddNode(b *testing.B) { + st, err := subtree.NewIncompleteTreeByLeafCount(b.N) + require.NoError(b, err) + + // create a slice of random hashes + hashes := make([]chainhash.Hash, b.N) + + b32 := make([]byte, 32) + + for i := 0; i < b.N; i++ { + // create random 32 bytes 
+ _, _ = rand.Read(b32) + hashes[i] = chainhash.Hash(b32) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = st.AddNode(hashes[i], 111, 0) + } +} + +func BenchmarkSubtree_Serialize(b *testing.B) { + st, err := subtree.NewIncompleteTreeByLeafCount(b.N) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + // int to bytes + var bb [32]byte + + binary.LittleEndian.PutUint32(bb[:], uint32(i)) + _ = st.AddNode(*(*chainhash.Hash)(&bb), 111, 234) + } + + b.ResetTimer() + + ser, err := st.Serialize() + require.NoError(b, err) + assert.GreaterOrEqual(b, len(ser), 48*b.N) +} + +func BenchmarkSubtree_SerializeNodes(b *testing.B) { + st, err := subtree.NewIncompleteTreeByLeafCount(b.N) + require.NoError(b, err) + + for i := 0; i < b.N; i++ { + // int to bytes + var bb [32]byte + + binary.LittleEndian.PutUint32(bb[:], uint32(i)) + _ = st.AddNode(*(*chainhash.Hash)(&bb), 111, 234) + } + + b.ResetTimer() + + ser, err := st.SerializeNodes() + require.NoError(b, err) + assert.GreaterOrEqual(b, len(ser), 32*b.N) +} diff --git a/subtree_data.go b/subtree_data.go new file mode 100644 index 0000000..0503d74 --- /dev/null +++ b/subtree_data.go @@ -0,0 +1,157 @@ +package subtree + +import ( + "bytes" + "fmt" + "io" + + "github.com/libsv/go-bt/v2" + "github.com/libsv/go-bt/v2/chainhash" +) + +type SubtreeData struct { + Subtree *Subtree + Txs []*bt.Tx +} + +// NewSubtreeData creates a new SubtreeData object +// the size parameter is the number of nodes in the subtree, +// the index in that array should match the index of the node in the subtree +func NewSubtreeData(subtree *Subtree) *SubtreeData { + return &SubtreeData{ + Subtree: subtree, + Txs: make([]*bt.Tx, subtree.Size()), + } +} + +func NewSubtreeDataFromBytes(subtree *Subtree, dataBytes []byte) (*SubtreeData, error) { + s := &SubtreeData{ + Subtree: subtree, + } + if err := s.serializeFromReader(bytes.NewReader(dataBytes)); err != nil { + return nil, fmt.Errorf("unable to create subtree data from bytes: %s", err) + } 
+ + return s, nil +} + +func NewSubtreeDataFromReader(subtree *Subtree, dataReader io.Reader) (*SubtreeData, error) { + s := &SubtreeData{ + Subtree: subtree, + } + if err := s.serializeFromReader(dataReader); err != nil { + return nil, fmt.Errorf("unable to create subtree data from reader: %s", err) + } + + return s, nil +} + +func (s *SubtreeData) RootHash() *chainhash.Hash { + return s.Subtree.RootHash() +} + +func (s *SubtreeData) AddTx(tx *bt.Tx, index int) error { + if index == 0 && tx.IsCoinbase() && s.Subtree.Nodes[index].Hash.Equal(CoinbasePlaceholderHashValue) { + // we got the coinbase tx as the first tx, we need to add it as the first tx and stop further processing + s.Txs[index] = tx + + return nil + } + + // check whether this is set in the main subtree + if !s.Subtree.Nodes[index].Hash.Equal(*tx.TxIDChainHash()) { + return fmt.Errorf("transaction hash does not match subtree node hash") + } + + s.Txs[index] = tx + + return nil +} + +func (s *SubtreeData) serializeFromReader(buf io.Reader) error { + var ( + err error + txIndex int + ) + + if s.Subtree == nil || len(s.Subtree.Nodes) == 0 { + return fmt.Errorf("subtree nodes slice is empty") + } + + if s.Subtree.Nodes[0].Hash.Equal(CoinbasePlaceholderHashValue) { + txIndex = 1 + } + + // initialize the txs array + s.Txs = make([]*bt.Tx, s.Subtree.Length()) + + for { + tx := &bt.Tx{} + + _, err = tx.ReadFrom(buf) + if err != nil { + if err == io.EOF { + break + } + + return fmt.Errorf("error reading transaction: %s", err) + } + + if txIndex == 1 && tx.IsCoinbase() { + // we got the coinbase tx as the first tx, we need to add it as the first tx and continue + s.Txs[0] = tx + + continue + } + + if txIndex >= len(s.Subtree.Nodes) { + return fmt.Errorf("transaction index out of bounds") + } + + if !s.Subtree.Nodes[txIndex].Hash.Equal(*tx.TxIDChainHash()) { + return fmt.Errorf("transaction hash does not match subtree node hash") + } + + s.Txs[txIndex] = tx + txIndex++ + } + + return nil +} + +// Serialize 
returns the serialized form of the subtree meta +func (s *SubtreeData) Serialize() ([]byte, error) { + var err error + + // only serialize when we have the matching subtree + if s.Subtree == nil { + return nil, fmt.Errorf("cannot serialize, subtree is not set") + } + + var txStartIndex int + if s.Subtree.Nodes[0].Hash.Equal(*CoinbasePlaceholderHash) { + txStartIndex = 1 + } + + // check the data in the subtree matches the data in the tx data + subtreeLen := s.Subtree.Length() + for i := txStartIndex; i < subtreeLen; i++ { + if s.Txs[i] == nil && i != 0 { + return nil, fmt.Errorf("subtree length does not match tx data length") + } + } + + bufBytes := make([]byte, 0, 32*1024) // 16MB (arbitrary size, should be enough for most cases) + buf := bytes.NewBuffer(bufBytes) + + for i := txStartIndex; i < subtreeLen; i++ { + b := s.Txs[i].ExtendedBytes() + + _, err = buf.Write(b) + if err != nil { + return nil, fmt.Errorf("error writing tx data: %s", err) + } + } + + return buf.Bytes(), nil +} diff --git a/subtree_data_test.go b/subtree_data_test.go new file mode 100644 index 0000000..24dca36 --- /dev/null +++ b/subtree_data_test.go @@ -0,0 +1,300 @@ +package subtree + +import ( + "bytes" + "fmt" + "io" + "testing" + + "github.com/libsv/go-bt/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + tx, _ = bt.NewTxFromString("010000000000000000ef0158ef6d539bf88c850103fa127a92775af48dba580c36bbde4dc6d8b9da83256d050000006a47304402200ca69c5672d0e0471cd4ff1f9993f16103fc29b98f71e1a9760c828b22cae61c0220705e14aa6f3149130c3a6aa8387c51e4c80c6ae52297b2dabfd68423d717be4541210286dbe9cd647f83a4a6b29d2a2d3227a897a4904dc31769502cb013cbe5044dddffffffff8c2f6002000000001976a914308254c746057d189221c36418ba93337de33bc988ac03002d3101000000001976a91498cde576de501ceb5bb1962c6e49a4d1af17730788ac80969800000000001976a914eb7772212c334c0bdccee75c0369aa675fc21d2088ac706b9600000000001976a914a32f7eaae3afd5f73a2d6009b93f91aa11d16eef88ac00000000") +) + +func 
TestNewSubtreeData(t *testing.T) { + tx1 := tx.Clone() + tx1.Version = 1 + + tx2 := tx.Clone() + tx2.Version = 2 + + tx3 := tx.Clone() + tx3.Version = 3 + + tx4 := tx.Clone() + tx4.Version = 4 + + t.Run("create new subtree data", func(t *testing.T) { + subtree, err := NewTree(2) + require.NoError(t, err) + + _ = subtree.AddNode(*tx1.TxIDChainHash(), 111, 0) + _ = subtree.AddNode(*tx2.TxIDChainHash(), 111, 0) + _ = subtree.AddNode(*tx3.TxIDChainHash(), 111, 0) + _ = subtree.AddNode(*tx4.TxIDChainHash(), 111, 0) + + // Test the constructor + subtreeData := NewSubtreeData(subtree) + + // Verify the subtree data + assert.Equal(t, subtree, subtreeData.Subtree) + assert.Equal(t, subtree.Size(), len(subtreeData.Txs)) + assert.Equal(t, 4, len(subtreeData.Txs)) + + // All transactions should be initially nil + for i := 0; i < len(subtreeData.Txs); i++ { + assert.Nil(t, subtreeData.Txs[i]) + } + }) + + t.Run("add transaction successfully", func(t *testing.T) { + subtree, err := NewTree(2) + require.NoError(t, err) + + _ = subtree.AddNode(*tx1.TxIDChainHash(), 111, 0) + + subtreeData := NewSubtreeData(subtree) + + // Add the transaction to the subtree data + err = subtreeData.AddTx(tx1, 0) + require.NoError(t, err) + + // Verify the transaction was added + assert.Equal(t, tx1, subtreeData.Txs[0]) + }) + + t.Run("add with coinbase tx", func(t *testing.T) { + coinbaseTx, _ := bt.NewTxFromString("02000000010000000000000000000000000000000000000000000000000000000000000000ffffffff03510101ffffffff0100f2052a01000000232103656065e6886ca1e947de3471c9e723673ab6ba34724476417fa9fcef8bafa604ac00000000") + + subtree, err := NewTree(2) + require.NoError(t, err) + + require.NoError(t, subtree.AddCoinbaseNode()) + require.NoError(t, subtree.AddNode(*tx1.TxIDChainHash(), 111, 0)) + + subtreeData := NewSubtreeData(subtree) + + // Add the transaction to the subtree data + err = subtreeData.AddTx(coinbaseTx, 0) + require.NoError(t, err) + + err = subtreeData.AddTx(tx1, 1) + require.NoError(t, err) 
+ + // Verify the transaction was added + assert.Equal(t, coinbaseTx, subtreeData.Txs[0]) + assert.Equal(t, tx1, subtreeData.Txs[1]) + }) + + t.Run("add transaction with mismatched hash", func(t *testing.T) { + subtree, err := NewTree(2) + require.NoError(t, err) + + _ = subtree.AddNode(*tx1.TxIDChainHash(), 111, 0) + + subtreeData := NewSubtreeData(subtree) + + // Add the transaction should fail due to hash mismatch + err = subtreeData.AddTx(tx2, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "transaction hash does not match subtree node hash") + + // Verify the transaction was not added + assert.Nil(t, subtreeData.Txs[0]) + }) +} + +func setupSubtreeData(t *testing.T) (*Subtree, *SubtreeData) { + tx1 := tx.Clone() + tx1.Version = 1 + + tx2 := tx.Clone() + tx2.Version = 2 + + tx3 := tx.Clone() + tx3.Version = 3 + + tx4 := tx.Clone() + tx4.Version = 4 + + subtree, err := NewTree(2) + require.NoError(t, err) + + _ = subtree.AddNode(*tx1.TxIDChainHash(), 111, 1) + _ = subtree.AddNode(*tx2.TxIDChainHash(), 111, 2) + _ = subtree.AddNode(*tx3.TxIDChainHash(), 111, 3) + _ = subtree.AddNode(*tx4.TxIDChainHash(), 111, 4) + + subtreeData := NewSubtreeData(subtree) + + // Add transactions to the subtree data + _ = subtreeData.AddTx(tx1, 0) + _ = subtreeData.AddTx(tx2, 1) + _ = subtreeData.AddTx(tx3, 2) + _ = subtreeData.AddTx(tx4, 3) + + return subtree, subtreeData +} + +func TestSerialize(t *testing.T) { + tx1 := tx.Clone() + tx1.Version = 1 + + tx2 := tx.Clone() + tx2.Version = 2 + + tx3 := tx.Clone() + tx3.Version = 3 + + tx4 := tx.Clone() + tx4.Version = 4 + + t.Run("serialize subtree data", func(t *testing.T) { + _, subtreeData := setupSubtreeData(t) + + // Serialize the subtree data + serializedData, err := subtreeData.Serialize() + require.NoError(t, err) + + // Ensure we have data + assert.NotEmpty(t, serializedData) + }) + + t.Run("serialize with nil subtree", func(t *testing.T) { + subtreeData := &SubtreeData{ + Subtree: nil, + Txs: make([]*bt.Tx, 0), 
+ } + + // Serialize should fail with nil subtree + serializedData, err := subtreeData.Serialize() + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot serialize, subtree is not set") + assert.Nil(t, serializedData) + }) + + t.Run("serialize with missing transactions", func(t *testing.T) { + subtree, err := NewTree(2) + require.NoError(t, err) + + _ = subtree.AddNode(*tx1.TxIDChainHash(), 111, 0) + _ = subtree.AddNode(*tx2.TxIDChainHash(), 111, 0) + + subtreeData := NewSubtreeData(subtree) + + _ = subtreeData.AddTx(tx1, 0) + + // Second transaction is missing, so serialization should fail + serializedData, err := subtreeData.Serialize() + require.Error(t, err) + assert.Contains(t, err.Error(), "subtree length does not match tx data length") + assert.Nil(t, serializedData) + }) +} + +func TestNewSubtreeDataFromBytes(t *testing.T) { + t.Run("create from valid bytes", func(t *testing.T) { + subtree, origData := setupSubtreeData(t) + + // Serialize the original data + serializedData, err := origData.Serialize() + require.NoError(t, err) + + // Create new subtree data from bytes + newData, err := NewSubtreeDataFromBytes(subtree, serializedData) + require.NoError(t, err) + + // Verify the new subtree data + assert.Equal(t, subtree, newData.Subtree) + assert.Equal(t, len(origData.Txs), len(newData.Txs)) + + // Compare transactions (skipping first if it's a coinbase placeholder) + startIdx := 0 + if subtree.Nodes[0].Hash.Equal(*CoinbasePlaceholderHash) { + startIdx = 1 + } + + for i := startIdx; i < len(newData.Txs); i++ { + if origData.Txs[i] != nil && newData.Txs[i] != nil { + assert.Equal(t, origData.Txs[i].TxID(), newData.Txs[i].TxID()) + } + } + }) + + t.Run("create from invalid bytes", func(t *testing.T) { + subtree, _ := setupSubtreeData(t) + + // Create invalid serialized data + invalidData := []byte("invalid data") + + // Create new subtree data from invalid bytes should fail + newData, err := NewSubtreeDataFromBytes(subtree, invalidData) + 
require.Error(t, err) + assert.Nil(t, newData) + }) +} + +func TestNewSubtreeDataFromReader(t *testing.T) { + t.Run("create from valid reader", func(t *testing.T) { + subtree, origData := setupSubtreeData(t) + + // Serialize the original data + serializedData, err := origData.Serialize() + require.NoError(t, err) + + // Create a reader from the serialized data + reader := bytes.NewReader(serializedData) + + // Create new subtree data from reader + newData, err := NewSubtreeDataFromReader(subtree, reader) + require.NoError(t, err) + + // Verify the new subtree data + assert.Equal(t, subtree, newData.Subtree) + + // Compare transactions (skipping first if it's a coinbase placeholder) + startIdx := 0 + if subtree.Nodes[0].Hash.Equal(*CoinbasePlaceholderHash) { + startIdx = 1 + } + + for i := startIdx; i < len(newData.Txs); i++ { + if origData.Txs[i] != nil && newData.Txs[i] != nil { + assert.Equal(t, origData.Txs[i].TxID(), newData.Txs[i].TxID()) + } + } + }) + + t.Run("create from invalid reader", func(t *testing.T) { + subtree, _ := setupSubtreeData(t) + + // Create invalid reader that returns EOF + reader := &mockReader{err: io.EOF} + + // Create new subtree data from invalid reader + newData, err := NewSubtreeDataFromReader(subtree, reader) + assert.NoError(t, err) // EOF is handled specially and considered normal end of data + assert.NotNil(t, newData) + + // Create invalid reader that returns error other than EOF + reader = &mockReader{err: fmt.Errorf("read error")} + + // Create new subtree data from invalid reader should fail + newData, err = NewSubtreeDataFromReader(subtree, reader) + require.Error(t, err) + assert.Nil(t, newData) + }) +} + +// Mock reader for testing +type mockReader struct { + err error +} + +func (r *mockReader) Read(p []byte) (n int, err error) { + return 0, r.err +} diff --git a/subtree_meta.go b/subtree_meta.go new file mode 100644 index 0000000..c1f9158 --- /dev/null +++ b/subtree_meta.go @@ -0,0 +1,293 @@ +package subtree + +import ( 
+ "bytes" + "encoding/binary" + "fmt" + "io" + + safe "github.com/bsv-blockchain/go-safe-conversion" + "github.com/libsv/go-bt/v2" + "github.com/libsv/go-bt/v2/chainhash" +) + +type SubtreeMeta struct { + // Subtree is the subtree this meta is for + Subtree *Subtree + // TxInpoints is a lookup of the parent tx inpoints for each node in the subtree + TxInpoints []TxInpoints + + // RootHash is the hash of the root node of the subtree + rootHash chainhash.Hash +} + +// NewSubtreeMeta creates a new SubtreeMeta object +// the size parameter is the number of nodes in the subtree, +// the index in that array should match the index of the node in the subtree +// +// Parameters: +// - subtree: The subtree for which to create the meta +// +// Returns: +// - *SubtreeMeta: A new SubtreeMeta object with the specified subtree and an empty TxInpoints slice +func NewSubtreeMeta(subtree *Subtree) *SubtreeMeta { + return &SubtreeMeta{ + Subtree: subtree, + TxInpoints: make([]TxInpoints, subtree.Size()), + } +} + +// NewSubtreeMetaFromBytes creates a new SubtreeMeta object from the provided byte slice. +// It reads the subtree meta data from the byte slice and populates the SubtreeMeta struct. +// +// Parameters: +// - subtree: The subtree for which to create the meta +// - dataBytes: The byte slice containing the serialized subtree meta data +// +// Returns: +// - *SubtreeMeta: A new SubtreeMeta object populated with data from the byte slice +// - error: An error if the deserialization fails +func NewSubtreeMetaFromBytes(subtree *Subtree, dataBytes []byte) (*SubtreeMeta, error) { + s := &SubtreeMeta{ + Subtree: subtree, + } + if err := s.deserializeFromReader(bytes.NewReader(dataBytes)); err != nil { + return nil, fmt.Errorf("unable to create subtree meta from bytes: %s", err) + } + + return s, nil +} + +// NewSubtreeMetaFromReader creates a new SubtreeMeta object from the provided reader. 
+// +// Parameters: +// - subtree: The subtree for which to create the meta +// - dataReader: The reader from which to read the subtree meta data +// +// Returns: +// - *SubtreeMeta: A new SubtreeMeta object populated with data from the reader +// - error: An error if the deserialization fails +func NewSubtreeMetaFromReader(subtree *Subtree, dataReader io.Reader) (*SubtreeMeta, error) { + s := &SubtreeMeta{ + Subtree: subtree, + TxInpoints: make([]TxInpoints, subtree.Size()), + } + + if err := s.deserializeFromReader(dataReader); err != nil { + return nil, fmt.Errorf("unable to create subtree meta from reader: %s", err) + } + + return s, nil +} + +// GetParentTxHashes returns the unique parent transaction hashes for the specified index in the subtree meta. +// It returns an error if the index is out of range. +// +// Parameters: +// - index: The index of the subtree node for which to get the parent transaction hashes +// +// Returns: +// - []chainhash.Hash: The unique parent transaction hashes for the specified index +// - error: An error if the index is out of range or if there is an issue retrieving the parent transaction hashes +func (s *SubtreeMeta) GetParentTxHashes(index int) ([]chainhash.Hash, error) { + if index >= len(s.TxInpoints) { + return nil, fmt.Errorf("index out of range") + } + + return s.TxInpoints[index].GetParentTxHashes(), nil +} + +// GetTxInpoints returns the TxInpoints for the specified index in the subtree meta. +// It returns an error if the index is out of range. 
+// +// Parameters: +// - index: The index of the subtree node for which to get the TxInpoints +// +// Returns: +// - []meta.Inpoint: The TxInpoints for the specified index +// - error: An error if the index is out of range or if there is an issue retrieving the TxInpoints +func (s *SubtreeMeta) GetTxInpoints(index int) ([]Inpoint, error) { + if index >= len(s.TxInpoints) { + return nil, fmt.Errorf("index out of range getting tx inpoints") + } + + return s.TxInpoints[index].GetTxInpoints(), nil +} + +// deserializeFromReader reads the subtree meta from the provided reader +// and populates the SubtreeMeta struct with the data. +// +// Parameters: +// - buf: The reader from which to read the subtree meta data +// +// Returns: +// - error: An error if the deserialization fails +func (s *SubtreeMeta) deserializeFromReader(buf io.Reader) error { + var ( + err error + dataBytes [4]byte + hashBytes [32]byte + ) + + // read the root hash + if _, err = io.ReadFull(buf, hashBytes[:]); err != nil { + return fmt.Errorf("unable to read root hash: %s", err) + } + + s.rootHash = hashBytes + + // read the number of parent tx hashes + if _, err = io.ReadFull(buf, dataBytes[:]); err != nil { + return fmt.Errorf("unable to read number of parent tx hashes: %s", err) + } + + txInpointsLen := binary.LittleEndian.Uint32(dataBytes[:]) + + // read the parent tx hashes + s.TxInpoints = make([]TxInpoints, s.Subtree.Size()) + + return s.deserializeTxInpointsFromReader(buf, txInpointsLen) +} + +// deserializeTxInpointsFromReader reads the TxInpoints from the provided reader +// and populates the TxInpoints slice in the SubtreeMeta. 
+// +// Parameters: +// - buf: The reader from which to read the TxInpoints +// - txInpointsLen: The number of TxInpoints to read +// +// Returns: +// - error: An error if the deserialization fails +func (s *SubtreeMeta) deserializeTxInpointsFromReader(buf io.Reader, txInpointsLen uint32) error { + var ( + err error + txInpoints TxInpoints + ) + + for i := uint32(0); i < txInpointsLen; i++ { + txInpoints, err = NewTxInpointsFromReader(buf) + if err != nil { + return fmt.Errorf("unable to deserialize parent outpoints: %s", err) + } + + s.TxInpoints[i] = txInpoints + } + + return nil +} + +// SetTxInpointsFromTx sets the TxInpoints for the subtree meta from a transaction. +// It finds the index of the transaction in the subtree and sets the TxInpoints at that index. +// If the transaction is not found in the subtree, it returns an error. +// +// Parameters: +// - tx: The transaction to set the TxInpoints from +// +// Returns: +// - error: An error if the transaction is not found in the subtree or if there is an issue creating the TxInpoints +func (s *SubtreeMeta) SetTxInpointsFromTx(tx *bt.Tx) error { + index := s.Subtree.NodeIndex(*tx.TxIDChainHash()) + if index == -1 { + return fmt.Errorf("[SetParentTxHashesFromTx][%s] node not found in subtree", tx.TxID()) + } + + p, err := NewTxInpointsFromTx(tx) + if err != nil { + return err + } + + s.TxInpoints[index] = p + + return nil +} + +// SetTxInpoints sets the TxInpoints at the specified index in the subtree meta. +// It returns an error if the index is out of range. 
+// +// Parameters: +// - idx: The index at which to set the TxInpoints +// - txInpoints: The TxInpoints to set at the specified index +// +// Returns: +// - error: An error if the index is out of range +func (s *SubtreeMeta) SetTxInpoints(idx int, txInpoints TxInpoints) error { + if idx >= len(s.TxInpoints) { + return fmt.Errorf("index out of range") + } + + s.TxInpoints[idx] = txInpoints + + return nil +} + +// Serialize returns the serialized form of the subtree meta +func (s *SubtreeMeta) Serialize() ([]byte, error) { + var err error + + // only serialize when we have the matching subtree + if s.Subtree == nil { + return nil, fmt.Errorf("cannot serialize, subtree is not set") + } + + // check the data in the subtree matches the data in the parent tx hashes + subtreeLen := s.Subtree.Length() + for i := 0; i < subtreeLen; i++ { + if i != 0 && s.TxInpoints[i].ParentTxHashes == nil { + return nil, fmt.Errorf("cannot serialize, parent tx hashes are not set for node %d: %s", i, s.Subtree.Nodes[i].Hash.String()) + } + } + + bufBytes := make([]byte, 0, 32*1024) // 32MB (arbitrary size, should be enough for most cases) + buf := bytes.NewBuffer(bufBytes) + + s.rootHash = *s.Subtree.RootHash() + + // write root hash + if _, err = buf.Write(s.rootHash[:]); err != nil { + return nil, fmt.Errorf("cannot serialize, unable to write root hash: %s", err) + } + + if err = s.serializeTxInpoints(buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func (s *SubtreeMeta) serializeTxInpoints(buf *bytes.Buffer) error { + var ( + err error + bytesUint32 [4]byte + ) + + parentTxHashesLen32, err := safe.IntToUint32(s.Subtree.Length()) + if err != nil { + return fmt.Errorf("cannot serialize, unable to get safe uint32: %s", err) + } + + // write number of parent tx hashes + binary.LittleEndian.PutUint32(bytesUint32[:], parentTxHashesLen32) + + if _, err = buf.Write(bytesUint32[:]); err != nil { + return fmt.Errorf("cannot serialize, unable to write total number of 
nodes: %s", err) + } + + var txInPointBytes []byte + + // write parent txInpoints + // for _, txInpoint := range s.TxInpoints { + for i := uint32(0); i < parentTxHashesLen32; i++ { + txInpoint := s.TxInpoints[i] + + txInPointBytes, err = txInpoint.Serialize() + if err != nil { + return fmt.Errorf("cannot serialize, unable to write parent tx hash: %s", err) + } + + if _, err = buf.Write(txInPointBytes); err != nil { + return fmt.Errorf("cannot serialize, unable to write parent tx hash: %s", err) + } + } + + return nil +} diff --git a/subtree_meta_test.go b/subtree_meta_test.go new file mode 100644 index 0000000..c79012f --- /dev/null +++ b/subtree_meta_test.go @@ -0,0 +1,344 @@ +package subtree + +import ( + "bytes" + "testing" + + "github.com/libsv/go-bt/v2" + "github.com/libsv/go-bt/v2/chainhash" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSubtreeMeta(t *testing.T) { + tx1 := tx.Clone() + tx1.Version = 1 + + tx2 := tx.Clone() + tx2.Version = 2 + + tx3 := tx.Clone() + tx3.Version = 3 + + tx4 := tx.Clone() + tx4.Version = 4 + + t.Run("TestNewSubtreeMeta", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + _ = subtree.AddNode(*tx1.TxIDChainHash(), 1, 1) + subtreeMeta := NewSubtreeMeta(subtree) + + assert.Equal(t, 4, len(subtreeMeta.TxInpoints)) + + for i := 0; i < 4; i++ { + assert.Nil(t, subtreeMeta.TxInpoints[i].ParentTxHashes) + } + }) + + t.Run("TestNewSubtreeMeta without subtree node", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + _ = subtree.AddNode(*tx1.TxIDChainHash(), 1, 1) + subtreeMeta := NewSubtreeMeta(subtree) + + err := subtreeMeta.SetTxInpointsFromTx(tx1) + require.NoError(t, err) + + err = subtreeMeta.SetTxInpointsFromTx(tx2) + require.Error(t, err) + }) + + t.Run("TestNewSubtreeMeta with 1 set", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + _ = subtree.AddNode(*tx1.TxIDChainHash(), 1, 1) + subtreeMeta := NewSubtreeMeta(subtree) + + err := 
subtreeMeta.SetTxInpointsFromTx(tx1) + require.NoError(t, err) + + assert.Equal(t, 4, len(subtreeMeta.TxInpoints)) + + assert.Equal(t, 1, len(subtreeMeta.TxInpoints[0].GetParentTxHashes())) + + allParentTxHashes, err := subtreeMeta.GetParentTxHashes(0) + require.NoError(t, err) + assert.Equal(t, 1, len(allParentTxHashes)) + + for i := 1; i < 4; i++ { + assert.Nil(t, subtreeMeta.TxInpoints[i].ParentTxHashes) + } + }) + + t.Run("TestNewSubtreeMeta with all set", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + require.NoError(t, subtree.AddNode(*tx1.TxIDChainHash(), 1, 1)) + require.NoError(t, subtree.AddNode(*tx2.TxIDChainHash(), 2, 2)) + require.NoError(t, subtree.AddNode(*tx3.TxIDChainHash(), 3, 3)) + require.NoError(t, subtree.AddNode(*tx4.TxIDChainHash(), 4, 4)) + + subtreeMeta := NewSubtreeMeta(subtree) + + _ = subtreeMeta.SetTxInpointsFromTx(tx1) + _ = subtreeMeta.SetTxInpointsFromTx(tx2) + _ = subtreeMeta.SetTxInpointsFromTx(tx3) + _ = subtreeMeta.SetTxInpointsFromTx(tx4) + + assert.Equal(t, 4, len(subtreeMeta.TxInpoints)) + + for i := 1; i < 4; i++ { + assert.Equal(t, 1, len(subtreeMeta.TxInpoints[i].GetParentTxHashes())) + } + }) +} + +func TestNewSubtreeMetaFromBytes(t *testing.T) { + tx1 := tx.Clone() + tx1.Version = 1 + + tx2 := tx.Clone() + tx2.Version = 2 + + tx3 := tx.Clone() + tx3.Version = 3 + + tx4 := tx.Clone() + tx4.Version = 4 + + t.Run("TestNewSubtreeMetaFromBytes", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + require.NoError(t, subtree.AddNode(*tx1.TxIDChainHash(), 1, 1)) + + subtreeMeta := NewSubtreeMeta(subtree) + require.NoError(t, subtreeMeta.SetTxInpointsFromTx(tx1)) + + b, err := subtreeMeta.Serialize() + require.NoError(t, err) + + subtreeMeta2, err := NewSubtreeMetaFromBytes(subtree, b) + require.NoError(t, err) + + assert.Equal(t, subtreeMeta.rootHash, subtreeMeta2.rootHash) + assert.Equal(t, len(subtreeMeta.TxInpoints), len(subtreeMeta2.TxInpoints)) + + for i := 0; i < 4; i++ { + if 
subtreeMeta.TxInpoints[i].ParentTxHashes == nil { + assert.Nil(t, subtreeMeta2.TxInpoints[i].ParentTxHashes) + continue + } + + assert.Equal(t, len(subtreeMeta.TxInpoints[i].GetParentTxHashes()), len(subtreeMeta2.TxInpoints[i].GetParentTxHashes())) + + for j := 0; j < len(subtreeMeta.TxInpoints[i].GetParentTxHashes()); j++ { + assert.Equal(t, subtreeMeta.TxInpoints[i].GetParentTxHashes()[j], subtreeMeta2.TxInpoints[i].GetParentTxHashes()[j]) + } + } + }) + + t.Run("TestNewSubtreeMetaFromReader", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + _ = subtree.AddNode(*tx1.TxIDChainHash(), 1, 1) + subtreeMeta := NewSubtreeMeta(subtree) + _ = subtreeMeta.SetTxInpointsFromTx(tx1) + + b, err := subtreeMeta.Serialize() + require.NoError(t, err) + + buf := bytes.NewReader(b) + + subtreeMeta2, err := NewSubtreeMetaFromReader(subtree, buf) + require.NoError(t, err) + + assert.Equal(t, subtreeMeta.rootHash, subtreeMeta2.rootHash) + assert.Equal(t, len(subtreeMeta.TxInpoints), len(subtreeMeta2.TxInpoints)) + + for i := 0; i < 4; i++ { + if subtreeMeta.TxInpoints[i].ParentTxHashes == nil { + assert.Nil(t, subtreeMeta2.TxInpoints[i].ParentTxHashes) + continue + } + + assert.Equal(t, len(subtreeMeta.TxInpoints[i].GetParentTxHashes()), len(subtreeMeta2.TxInpoints[i].GetParentTxHashes())) + + for j := 0; j < len(subtreeMeta.TxInpoints[i].GetParentTxHashes()); j++ { + assert.Equal(t, subtreeMeta.TxInpoints[i].GetParentTxHashes()[j], subtreeMeta2.TxInpoints[i].GetParentTxHashes()[j]) + } + } + }) + + t.Run("TestNewSubtreeMetaFromReader with large cap", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(32 * 1024) + _ = subtree.AddNode(*tx1.TxIDChainHash(), 1, 1) + + b, err := subtree.Serialize() + require.NoError(t, err) + + subtree2, err := NewSubtreeFromBytes(b) + require.NoError(t, err) + + subtreeMeta := NewSubtreeMeta(subtree2) + _ = subtreeMeta.SetTxInpointsFromTx(tx1) + + b, err = subtreeMeta.Serialize() + require.NoError(t, err) + + buf := bytes.NewReader(b) + 
+ subtreeMeta2, err := NewSubtreeMetaFromReader(subtree2, buf) + require.NoError(t, err) + + assert.Equal(t, subtreeMeta.rootHash, subtreeMeta2.rootHash) + assert.Equal(t, len(subtreeMeta.TxInpoints), len(subtreeMeta2.TxInpoints)) + + assert.Equal(t, subtree2.Size(), cap(subtreeMeta.TxInpoints)) + }) + + t.Run("TestNewSubtreeMetaFromBytes with all set", func(t *testing.T) { + _, subtree, subtreeMeta := initSubtreeMeta(t) + + b, err := subtreeMeta.Serialize() + require.NoError(t, err) + subtreeMeta2, err := NewSubtreeMetaFromBytes(subtree, b) + require.NoError(t, err) + + assert.Equal(t, subtreeMeta.rootHash, subtreeMeta2.rootHash) + assert.Equal(t, len(subtreeMeta.TxInpoints), len(subtreeMeta2.TxInpoints)) + + for i := 0; i < 4; i++ { + assert.Equal(t, len(subtreeMeta.TxInpoints[i].GetParentTxHashes()), len(subtreeMeta2.TxInpoints[i].GetParentTxHashes())) + + for j := 0; j < len(subtreeMeta.TxInpoints[i].GetParentTxHashes()); j++ { + assert.Equal(t, subtreeMeta.TxInpoints[i].GetParentTxHashes()[j], subtreeMeta2.TxInpoints[i].GetParentTxHashes()[j]) + } + } + }) +} + +func TestNewSubtreeMetaGetParentTxHashes(t *testing.T) { + txs, _, subtreeMeta := initSubtreeMeta(t) + + t.Run("TestGetParentTxHashes", func(t *testing.T) { + for i := 0; i < 4; i++ { + allParentTxHashes, err := subtreeMeta.GetParentTxHashes(i) + require.NoError(t, err) + + assert.Equal(t, 1, len(allParentTxHashes)) + assert.Equal(t, *txs[i].Inputs[0].PreviousTxIDChainHash(), allParentTxHashes[0]) + } + }) + + t.Run("TestGetParentTxHashes with out of range index", func(t *testing.T) { + allParentTxHashes, err := subtreeMeta.GetParentTxHashes(5) + require.Error(t, err) + assert.Nil(t, allParentTxHashes) + }) +} + +func TestSubtreeMetaGetTxInpoints(t *testing.T) { + txs, _, subtreeMeta := initSubtreeMeta(t) + + t.Run("empty subtree", func(t *testing.T) { + subtree, _ := NewTreeByLeafCount(4) + emptySubtreeMeta := NewSubtreeMeta(subtree) + + inpoints, err := emptySubtreeMeta.GetTxInpoints(0) + 
require.NoError(t, err) + + assert.Equal(t, 0, len(inpoints)) + }) + + t.Run("out of range index", func(t *testing.T) { + inpoints, err := subtreeMeta.GetTxInpoints(5) + require.Error(t, err) + + assert.Nil(t, inpoints) + }) + + t.Run("TestGetTxInpoints", func(t *testing.T) { + for i := 0; i < 4; i++ { + inpoints, err := subtreeMeta.GetTxInpoints(0) + require.NoError(t, err) + + assert.Equal(t, 1, len(inpoints)) + assert.Equal(t, *txs[i].Inputs[0].PreviousTxIDChainHash(), inpoints[0].Hash) + assert.Equal(t, txs[i].Inputs[0].PreviousTxOutIndex, inpoints[0].Index) + } + }) +} + +func TestSubtreeMetaSetTxInpoints(t *testing.T) { + t.Run("TestSetTxInpointsFromTx", func(t *testing.T) { + txs, _, subtreeMeta := initSubtreeMeta(t) + + for i := 0; i < 4; i++ { + err := subtreeMeta.SetTxInpointsFromTx(txs[i]) + require.NoError(t, err) + + inpoints, err := subtreeMeta.GetTxInpoints(i) + require.NoError(t, err) + + assert.Equal(t, 1, len(inpoints)) + assert.Equal(t, *txs[i].Inputs[0].PreviousTxIDChainHash(), inpoints[0].Hash) + assert.Equal(t, txs[i].Inputs[0].PreviousTxOutIndex, inpoints[0].Index) + } + }) + + t.Run("TestSetTxInpoints", func(t *testing.T) { + txs, _, subtreeMeta := initSubtreeMeta(t) + + // Test setting inpoints for a subtree node that does not exist + err := subtreeMeta.SetTxInpoints(2, TxInpoints{ + ParentTxHashes: []chainhash.Hash{*txs[0].Inputs[0].PreviousTxIDChainHash()}, + Idxs: [][]uint32{{1, 2, 3}}, + nrInpoints: 3, + }) + require.NoError(t, err) + + inpoints, err := subtreeMeta.GetTxInpoints(2) + require.NoError(t, err) + + assert.Equal(t, 3, len(inpoints)) + assert.Equal(t, *txs[0].Inputs[0].PreviousTxIDChainHash(), inpoints[0].Hash) + assert.Equal(t, *txs[0].Inputs[0].PreviousTxIDChainHash(), inpoints[0].Hash) + assert.Equal(t, *txs[0].Inputs[0].PreviousTxIDChainHash(), inpoints[0].Hash) + assert.Equal(t, uint32(1), inpoints[0].Index) + assert.Equal(t, uint32(2), inpoints[1].Index) + assert.Equal(t, uint32(3), inpoints[2].Index) + + // Test 
setting inpoints for a subtree node that does not exist + err = subtreeMeta.SetTxInpoints(5, TxInpoints{ + ParentTxHashes: []chainhash.Hash{*txs[0].Inputs[0].PreviousTxIDChainHash()}, + Idxs: [][]uint32{{1, 2, 3}}, + nrInpoints: 3, + }) + require.Error(t, err) + assert.Equal(t, "index out of range", err.Error()) + }) +} + +func initSubtreeMeta(t *testing.T) ([]*bt.Tx, *Subtree, *SubtreeMeta) { + tx1 := tx.Clone() + tx1.Version = 1 + + tx2 := tx.Clone() + tx2.Version = 2 + + tx3 := tx.Clone() + tx3.Version = 3 + + tx4 := tx.Clone() + tx4.Version = 4 + + subtree, _ := NewTreeByLeafCount(4) + require.NoError(t, subtree.AddNode(*tx1.TxIDChainHash(), 1, 1)) + require.NoError(t, subtree.AddNode(*tx2.TxIDChainHash(), 2, 2)) + require.NoError(t, subtree.AddNode(*tx3.TxIDChainHash(), 3, 3)) + require.NoError(t, subtree.AddNode(*tx4.TxIDChainHash(), 4, 4)) + + subtreeMeta := NewSubtreeMeta(subtree) + + require.NoError(t, subtreeMeta.SetTxInpointsFromTx(tx1)) + require.NoError(t, subtreeMeta.SetTxInpointsFromTx(tx2)) + require.NoError(t, subtreeMeta.SetTxInpointsFromTx(tx3)) + require.NoError(t, subtreeMeta.SetTxInpointsFromTx(tx4)) + + return []*bt.Tx{tx1, tx2, tx3, tx4}, subtree, subtreeMeta +} diff --git a/subtree_test.go b/subtree_test.go new file mode 100644 index 0000000..71efb25 --- /dev/null +++ b/subtree_test.go @@ -0,0 +1,1068 @@ +package subtree + +import ( + "bytes" + "fmt" + "testing" + + "github.com/libsv/go-bt/v2/chainhash" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewTree(t *testing.T) { + t.Run("invalid size", func(t *testing.T) { + _, err := NewTreeByLeafCount(123) + require.Error(t, err) + }) + + t.Run("valid size", func(t *testing.T) { + st, err := NewTree(20) + require.NoError(t, err) + + if st.Size() != 1048576 { + t.Errorf("expected size to be 1048576, got %d", st.Size()) + } + }) +} + +func TestNewIncompleteTreeByLeafCount(t *testing.T) { + t.Run("invalid size", func(t *testing.T) { + _, err := 
NewIncompleteTreeByLeafCount(0) + require.Error(t, err) + }) + + t.Run("valid size", func(t *testing.T) { + st, err := NewIncompleteTreeByLeafCount(20) + require.NoError(t, err) + + // should be the next power of 2 + assert.Equal(t, 32, st.Size()) + }) +} + +func TestRootHash(t *testing.T) { + t.Run("root hash", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hash2, _ := chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hash3, _ := chainhash.NewHashFromStr("3070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hash4, _ := chainhash.NewHashFromStr("d3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + rootHash := st.RootHash() + assert.Equal(t, "b47df6aa4fe0a1d3841c635444be4e33eb8cdc2f2e929ced06d0a8454fb28225", rootHash.String()) + }) +} + +func Test_RootHashWithReplaceRootNode(t *testing.T) { + t.Run("empty tree", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + rootHash := st.RootHash() + assert.Nil(t, rootHash) + }) + + t.Run("replace with 0 noded", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + + rootHash := st.RootHash() + assert.Nil(t, rootHash) + + rootHash2, err := st.RootHashWithReplaceRootNode(hash1, 111, 0) + require.NoError(t, err) + assert.NotEqual(t, rootHash, rootHash2) + assert.Equal(t, "97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115", rootHash2.String()) + }) + + t.Run("replace with only 1 node", func(t *testing.T) { + st, err := NewTree(2) + 
require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + _ = st.AddCoinbaseNode() + + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + + rootHash := st.RootHash() + assert.Equal(t, "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", rootHash.String()) + + rootHash2, err := st.RootHashWithReplaceRootNode(hash1, 111, 0) + require.NoError(t, err) + assert.NotEqual(t, rootHash, rootHash2) + assert.Equal(t, "97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115", rootHash2.String()) + }) + + t.Run("root hash with replace root node", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hash2, _ := chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hash3, _ := chainhash.NewHashFromStr("3070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hash4, _ := chainhash.NewHashFromStr("d3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + rootHash := st.RootHash() + assert.Equal(t, "b47df6aa4fe0a1d3841c635444be4e33eb8cdc2f2e929ced06d0a8454fb28225", rootHash.String()) + + rootHash2, err := st.RootHashWithReplaceRootNode(hash4, 111, 0) + require.NoError(t, err) + assert.NotEqual(t, rootHash, rootHash2) + assert.Equal(t, "dfec71cf72403643187e9e02d7c436e87251fa098cffa54d182022153da3d09a", rootHash2.String()) + }) +} + +func TestGetMap(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := 
chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 101) + _ = st.AddNode(*hash2, 112, 102) + _ = st.AddNode(*hash3, 113, 103) + _ = st.AddNode(*hash4, 114, 104) + + txMap, err := st.GetMap() + require.NoError(t, err) + + assert.Equal(t, 4, txMap.Length()) + + for _, node := range st.Nodes { + txIdx, ok := txMap.Get(node.Hash) + require.True(t, ok, fmt.Sprintf("expected to find hash %s in map", node.Hash.String())) + + // find node in original tree + originalIdx := st.NodeIndex(node.Hash) + assert.Equal(t, uint64(originalIdx), txIdx) //nolint:gosec // Ignore for now + } +} + +func TestHasNode(t *testing.T) { + t.Run("exists", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + _ = st.AddNode(*hash1, 111, 0) + + exists := st.HasNode(*hash1) + assert.True(t, exists) + }) + + t.Run("does not exist", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + exists := st.HasNode(*hash1) + assert.False(t, exists) + }) +} + +func TestGetNode(t *testing.T) { + t.Run("exists", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := 
chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + _ = st.AddNode(*hash1, 111, 0) + + node, err := st.GetNode(*hash1) + assert.NoError(t, err) + assert.Equal(t, *hash1, node.Hash) + assert.Equal(t, uint64(111), node.Fee) + }) + + t.Run("does not exist", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + node, err := st.GetNode(*hash1) + assert.Error(t, err) + assert.Nil(t, node) + }) +} + +func TestDifference(t *testing.T) { + st1, err := NewTree(2) + require.NoError(t, err) + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st1.AddNode(*hash1, 111, 0) + _ = st1.AddNode(*hash2, 112, 0) + _ = st1.AddNode(*hash3, 113, 0) + _ = st1.AddNode(*hash4, 114, 0) + + st2, err := NewTree(2) + require.NoError(t, err) + + _ = st2.AddNode(*hash3, 113, 0) + _ = st2.AddNode(*hash4, 114, 0) + + st2Map, err := st2.GetMap() + require.NoError(t, err) + + diff, err := st1.Difference(st2Map) + require.NoError(t, err) + + assert.Equal(t, 2, len(diff)) + assert.Equal(t, *hash1, diff[0].Hash) + assert.Equal(t, *hash2, diff[1].Hash) +} + +func TestRootHashSimon(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := 
chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + rootHash := st.RootHash() + assert.Equal(t, "f3e94742aca4b5ef85488dc37c06c3282295ffec960994b2c0d5ac2a25a95766", rootHash.String()) +} + +func TestTwoTransactions(t *testing.T) { + st, err := NewTree(1) + require.NoError(t, err) + + if st.Size() != 2 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("de2c2e8628ab837ceff3de0217083d9d5feb71f758a5d083ada0b33a36e1b30e") + hash2, _ := chainhash.NewHashFromStr("89878bfd69fba52876e5217faec126fc6a20b1845865d4038c12f03200793f48") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + + rootHash := st.RootHash() + assert.Equal(t, "7a059188283323a2ef0e02dd9f8ba1ac550f94646290d0a52a586e5426c956c5", rootHash.String()) +} + +func TestSubtree_GetMerkleProof(t *testing.T) { + st, err := NewTree(3) + require.NoError(t, err) + + if st.Size() != 8 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + txIDS := []string{ + "4634057867994ae379e82b408cc9eb145a6e921b95ca38f2ced7eb880685a290", + "7f87fe1100963977975cef49344e442b4fa3dd9d41de19bc94609c100210ca05", + "a28c1021f07263101f5a5052c6a7bdc970ac1d0ab09d8d20aa7a4a61ad9d6597", + "dcd31c71368f757f65105d68ee1a2e5598db84900e28dabecba23651c5cda468", + "7bac32882547cbb540914f48c6ac99ac682ef001c3aa3d4dcdb5951c8db79678", + "67c0f4eb336057ecdf940497a75fcbd1a131e981edf568b54eed2f944889e441", + } + + var txHash *chainhash.Hash + for _, txID := range txIDS { + txHash, _ = chainhash.NewHashFromStr(txID) + _ = st.AddNode(*txHash, 101, 0) + } + + proof, err := st.GetMerkleProof(1) + 
require.NoError(t, err) + assert.Equal(t, 3, len(proof)) + assert.Equal(t, "4634057867994ae379e82b408cc9eb145a6e921b95ca38f2ced7eb880685a290", proof[0].String()) + assert.Equal(t, "a9e6413abb02b534ff5250cbabdc673480656d0e053cfd23fd010241d5e045f2", proof[1].String()) + assert.Equal(t, "63fd0f07ff87223f688d0809f46a8118f185bab04d300406513acdc8832bad5e", proof[2].String()) + assert.Equal(t, "68e239fc6684a224142add79ebed60569baedf667c6be03a5f8719aba44a488b", st.RootHash().String()) + + proof, err = st.GetMerkleProof(4) + require.NoError(t, err) + assert.Equal(t, 3, len(proof)) + assert.Equal(t, "67c0f4eb336057ecdf940497a75fcbd1a131e981edf568b54eed2f944889e441", proof[0].String()) + assert.Equal(t, "e2a6065233b307b77a5f73f9f27843d42e48d5e061567416b4508517ef2dd452", proof[1].String()) + assert.Equal(t, "bfd8a13a5cb1ba128319ee95e09a7e2ff67a52d0c9af8485bfffae737e32d6bf", proof[2].String()) + assert.Equal(t, "68e239fc6684a224142add79ebed60569baedf667c6be03a5f8719aba44a488b", st.RootHash().String()) + + proof, err = st.GetMerkleProof(6) + require.Error(t, err) // out of range + assert.Len(t, proof, 0) +} + +func Test_Serialize(t *testing.T) { + t.Run("Serialize", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + serializedBytes, err := st.Serialize() + require.NoError(t, err) + + newSubtree, err := NewTree(2) + 
require.NoError(t, err) + + err = newSubtree.Deserialize(serializedBytes) + require.NoError(t, err) + assert.Equal(t, st.Fees, newSubtree.Fees) + assert.Equal(t, st.Size(), newSubtree.Size()) + assert.Equal(t, st.RootHash(), newSubtree.RootHash()) + + assert.Equal(t, len(st.Nodes), len(newSubtree.Nodes)) + + for i := 0; i < len(st.Nodes); i++ { + assert.Equal(t, st.Nodes[i].Hash.String(), newSubtree.Nodes[i].Hash.String()) + assert.Equal(t, st.Nodes[i].Fee, newSubtree.Nodes[i].Fee) + } + }) + + t.Run("Serialize nodes", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + subtreeBytes, err := st.SerializeNodes() + require.NoError(t, err) + + require.Equal(t, chainhash.HashSize*4, len(subtreeBytes)) + + txHashes := make([]chainhash.Hash, len(subtreeBytes)/chainhash.HashSize) + for i := 0; i < len(subtreeBytes); i += chainhash.HashSize { + txHashes[i/chainhash.HashSize] = chainhash.Hash(subtreeBytes[i : i+chainhash.HashSize]) + } + + assert.Equal(t, hash1.String(), txHashes[0].String()) + assert.Equal(t, hash2.String(), txHashes[1].String()) + assert.Equal(t, hash3.String(), txHashes[2].String()) + assert.Equal(t, hash4.String(), txHashes[3].String()) + }) + + t.Run("New subtree from bytes", func(t *testing.T) { + st, serializedBytes := getSubtreeBytes(t) + + newSubtree, err := NewSubtreeFromBytes(serializedBytes) 
+ require.NoError(t, err) + + for i := 0; i < newSubtree.Size(); i += chainhash.HashSize { + assert.Equal(t, st.Nodes[i/chainhash.HashSize].Hash.String(), newSubtree.Nodes[i/chainhash.HashSize].Hash.String()) + } + }) + + t.Run("New subtree from reader", func(t *testing.T) { + st, serializedBytes := getSubtreeBytes(t) + + newSubtree, err := NewSubtreeFromReader(bytes.NewReader(serializedBytes)) + require.NoError(t, err) + + for i := 0; i < newSubtree.Size(); i += chainhash.HashSize { + assert.Equal(t, st.Nodes[i/chainhash.HashSize].Hash.String(), newSubtree.Nodes[i/chainhash.HashSize].Hash.String()) + } + }) + + t.Run("DeserializeNodes with reader", func(t *testing.T) { + st, serializedBytes := getSubtreeBytes(t) + + subtreeBytes, err := DeserializeNodesFromReader(bytes.NewReader(serializedBytes)) + require.NoError(t, err) + + require.Equal(t, chainhash.HashSize*4, len(subtreeBytes)) + + for i := 0; i < len(subtreeBytes); i += chainhash.HashSize { + txHash := chainhash.Hash(subtreeBytes[i : i+chainhash.HashSize]) + assert.Equal(t, st.Nodes[i/chainhash.HashSize].Hash.String(), txHash.String()) + } + }) + + t.Run("Deserialize with reader", func(t *testing.T) { + st, serializedBytes := getSubtreeBytes(t) + + newSubtree, err := NewTree(2) + require.NoError(t, err) + + r := bytes.NewReader(serializedBytes) + + err = newSubtree.DeserializeFromReader(r) + require.NoError(t, err) + assert.Equal(t, st.Fees, newSubtree.Fees) + assert.Equal(t, st.Size(), newSubtree.Size()) + assert.Equal(t, st.RootHash(), newSubtree.RootHash()) + + assert.Equal(t, len(st.Nodes), len(newSubtree.Nodes)) + + for i := 0; i < len(st.Nodes); i++ { + assert.Equal(t, st.Nodes[i].Hash.String(), newSubtree.Nodes[i].Hash.String()) + assert.Equal(t, st.Nodes[i].Fee, newSubtree.Nodes[i].Fee) + } + }) + + t.Run("Serialize with conflicting", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ 
:= chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + err = st.AddConflictingNode(*hash3) + require.NoError(t, err) + + err = st.AddConflictingNode(*hash4) + require.NoError(t, err) + + serializedBytes, err := st.Serialize() + require.NoError(t, err) + + newSubtree, err := NewTree(2) + require.NoError(t, err) + + err = newSubtree.Deserialize(serializedBytes) + require.NoError(t, err) + assert.Equal(t, st.Fees, newSubtree.Fees) + assert.Equal(t, st.Size(), newSubtree.Size()) + assert.Equal(t, st.RootHash(), newSubtree.RootHash()) + + assert.Equal(t, len(st.Nodes), len(newSubtree.Nodes)) + + for i := 0; i < len(st.Nodes); i++ { + assert.Equal(t, st.Nodes[i].Hash.String(), newSubtree.Nodes[i].Hash.String()) + assert.Equal(t, st.Nodes[i].Fee, newSubtree.Nodes[i].Fee) + } + + assert.Equal(t, len(st.ConflictingNodes), len(newSubtree.ConflictingNodes)) + + for i := 0; i < len(st.ConflictingNodes); i++ { + assert.Equal(t, st.ConflictingNodes[i].String(), newSubtree.ConflictingNodes[i].String()) + } + + conflictingNodes, err := DeserializeSubtreeConflictingFromReader(bytes.NewReader(serializedBytes)) + require.NoError(t, err) + + assert.Equal(t, len(st.ConflictingNodes), len(conflictingNodes)) + + for i := 0; i < len(st.ConflictingNodes); i++ { + assert.Equal(t, st.ConflictingNodes[i].String(), conflictingNodes[i].String()) + } + }) +} + +func Test_Duplicate(t *testing.T) { + t.Run("Duplicate", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 
4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + newSubtree := st.Duplicate() + + require.NoError(t, err) + assert.Equal(t, st.Fees, newSubtree.Fees) + assert.Equal(t, st.Size(), newSubtree.Size()) + assert.Equal(t, st.RootHash(), newSubtree.RootHash()) + + assert.Equal(t, len(st.Nodes), len(newSubtree.Nodes)) + + for i := 0; i < len(st.Nodes); i++ { + assert.Equal(t, st.Nodes[i].Hash.String(), newSubtree.Nodes[i].Hash.String()) + assert.Equal(t, st.Nodes[i].Fee, newSubtree.Nodes[i].Fee) + } + }) + + t.Run("Clone - not same root hash", func(t *testing.T) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + newSubtree := st.Duplicate() + newSubtree.ReplaceRootNode(hash4, 111, 0) + assert.NotEqual(t, st.RootHash(), newSubtree.RootHash()) + }) +} + +func 
TestSubtree_NodeIndex(t *testing.T) { + tx1 := tx.Clone() + tx1.Version = 1 + hash1 := *tx1.TxIDChainHash() + + tx2 := tx.Clone() + tx2.Version = 2 + hash2 := *tx2.TxIDChainHash() + + tx3 := tx.Clone() + tx3.Version = 3 + hash3 := *tx3.TxIDChainHash() + + t.Run("existing node", func(t *testing.T) { + st, err := NewTree(4) + require.NoError(t, err) + + _ = st.AddNode(hash1, 111, 1) + _ = st.AddNode(hash2, 112, 2) + + index := st.NodeIndex(hash1) + assert.Equal(t, 0, index) + + index = st.NodeIndex(hash2) + assert.Equal(t, 1, index) + }) + + t.Run("non-existing node", func(t *testing.T) { + st, err := NewTree(4) + require.NoError(t, err) + + _ = st.AddNode(hash1, 111, 1) + _ = st.AddNode(hash2, 112, 2) + + index := st.NodeIndex(hash3) + assert.Equal(t, -1, index) + }) + + t.Run("remove existing node", func(t *testing.T) { + st, err := NewTree(4) + require.NoError(t, err) + + _ = st.AddNode(hash1, 111, 1) + _ = st.AddNode(hash2, 112, 2) + + assert.Equal(t, 2, st.Length()) + + err = st.RemoveNodeAtIndex(0) + require.NoError(t, err) + assert.Equal(t, 1, st.Length()) + + // hash2 should now be at node 0 + assert.Equal(t, hash2, st.Nodes[0].Hash) + }) + + t.Run("remove non-existing node", func(t *testing.T) { + st, err := NewTree(4) + require.NoError(t, err) + + err = st.RemoveNodeAtIndex(0) + assert.Error(t, err, "index out of range") + }) +} + +func getSubtreeBytes(t *testing.T) (*Subtree, []byte) { + st, err := NewTree(2) + require.NoError(t, err) + + if st.Size() != 4 { + t.Errorf("expected size to be 4, got %d", st.Size()) + } + + hash1, _ := chainhash.NewHashFromStr("8c14f0db3df150123e6f3dbbf30f8b955a8249b62ac1d1ff16284aefa3d06d87") + hash2, _ := chainhash.NewHashFromStr("fff2525b8931402dd09222c50775608f75787bd2b87e56995a7bdd30f79702c4") + hash3, _ := chainhash.NewHashFromStr("6359f0868171b1d194cbee1af2f16ea598ae8fad666d9b012c8ed2b79a236ec4") + hash4, _ := chainhash.NewHashFromStr("e9a66845e05d5abc0ad04ec80f774a7e585c6e8db975962d069a522137b80c1d") + _ = 
st.AddNode(*hash1, 111, 0) + _ = st.AddNode(*hash2, 111, 0) + _ = st.AddNode(*hash3, 111, 0) + _ = st.AddNode(*hash4, 111, 0) + + serializedBytes, err := st.Serialize() + require.NoError(t, err) + + return st, serializedBytes +} + +func Test_BuildMerkleTreeStoreFromBytes(t *testing.T) { + t.Run("complete tree", func(t *testing.T) { + hashes := make([]*chainhash.Hash, 8) + hashes[0], _ = chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hashes[1], _ = chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hashes[2], _ = chainhash.NewHashFromStr("3070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hashes[3], _ = chainhash.NewHashFromStr("d3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + hashes[4], _ = chainhash.NewHashFromStr("87af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hashes[5], _ = chainhash.NewHashFromStr("6ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hashes[6], _ = chainhash.NewHashFromStr("2070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hashes[7], _ = chainhash.NewHashFromStr("c3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + + subtree, err := NewTreeByLeafCount(8) + require.NoError(t, err) + + for _, hash := range hashes { + _ = subtree.AddNode(*hash, 111, 0) + } + + merkleStore, err := BuildMerkleTreeStoreFromBytes(subtree.Nodes) + require.NoError(t, err) + + expectedMerkleStore := []string{ + "2207df31366e6fdd96a7ef3286278422c1c6dd3d74c3f85bbcfee82a8d31da25", + "c32db78e5f8437648888713982ea3d49628dbde0b4b48857147f793b55d26f09", + "4cfd8f882dc64dd7a123d545785bd2670c981493ea85ec058e6428cb95f04fa7", + "0bb2f84f4071e1a04f61bb04a10dc17affcf7fd558945a3a31b1d1f0fb6ec121", + "b47df6aa4fe0a1d3841c635444be4e33eb8cdc2f2e929ced06d0a8454fb28225", + "1e3cfb94c292e8fc2ac692c4c4db4ea73784978ff47424668233a7f491e218a3", + 
"86867b9f3e7dcb4bdf5b5cc99322122fe492bc466621f3709d4e389e7e14c16c", + } + + actualMerkleStore := make([]string, len(*merkleStore)) + for idx, merkle := range *merkleStore { + actualMerkleStore[idx] = merkle.String() + } + + assert.Equal(t, expectedMerkleStore, actualMerkleStore) + }) + + t.Run("incomplete tree", func(t *testing.T) { + st, err := NewTreeByLeafCount(8) + require.NoError(t, err) + + txIDS := []string{ + "4634057867994ae379e82b408cc9eb145a6e921b95ca38f2ced7eb880685a290", + "7f87fe1100963977975cef49344e442b4fa3dd9d41de19bc94609c100210ca05", + "a28c1021f07263101f5a5052c6a7bdc970ac1d0ab09d8d20aa7a4a61ad9d6597", + "dcd31c71368f757f65105d68ee1a2e5598db84900e28dabecba23651c5cda468", + "7bac32882547cbb540914f48c6ac99ac682ef001c3aa3d4dcdb5951c8db79678", + "67c0f4eb336057ecdf940497a75fcbd1a131e981edf568b54eed2f944889e441", + } + + var txHash *chainhash.Hash + for _, txID := range txIDS { + txHash, _ = chainhash.NewHashFromStr(txID) + _ = st.AddNode(*txHash, 101, 0) + } + + merkleStore, err := BuildMerkleTreeStoreFromBytes(st.Nodes) + require.NoError(t, err) + + expectedMerkleStore := []string{ + "dc9ab938cd3124ad36e90c30bcb02256eb73eb62dc657d93e89a0a29f323c3c7", + "a9e6413abb02b534ff5250cbabdc673480656d0e053cfd23fd010241d5e045f2", + "e2a6065233b307b77a5f73f9f27843d42e48d5e061567416b4508517ef2dd452", + "", + "bfd8a13a5cb1ba128319ee95e09a7e2ff67a52d0c9af8485bfffae737e32d6bf", + "63fd0f07ff87223f688d0809f46a8118f185bab04d300406513acdc8832bad5e", + "68e239fc6684a224142add79ebed60569baedf667c6be03a5f8719aba44a488b", + } + + actualMerkleStore := make([]string, len(*merkleStore)) + + for idx, merkle := range *merkleStore { + if merkle.Equal(chainhash.Hash{}) { + actualMerkleStore[idx] = "" + } else { + actualMerkleStore[idx] = merkle.String() + } + } + + assert.Equal(t, expectedMerkleStore, actualMerkleStore) + }) + + t.Run("incomplete tree 2", func(t *testing.T) { + hashes := make([]*chainhash.Hash, 5) + hashes[0], _ = 
chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + hashes[1], _ = chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + hashes[2], _ = chainhash.NewHashFromStr("3070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + hashes[3], _ = chainhash.NewHashFromStr("d3cde0ab7142cc99acb31c5b5e1e941faed1c5cf5f8b63ed663972845d663487") + hashes[4], _ = chainhash.NewHashFromStr("87af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + + subtree, err := NewTreeByLeafCount(8) + require.NoError(t, err) + + for _, hash := range hashes { + _ = subtree.AddNode(*hash, 111, 0) + } + + merkleStore, err := BuildMerkleTreeStoreFromBytes(subtree.Nodes) + require.NoError(t, err) + + expectedMerkleStore := []string{ + "2207df31366e6fdd96a7ef3286278422c1c6dd3d74c3f85bbcfee82a8d31da25", + "c32db78e5f8437648888713982ea3d49628dbde0b4b48857147f793b55d26f09", + "61a34fe6c63b5276e042a10a559e9ee9bb785f7b40f753fefdf0fe615d8a6be1", + "", + "b47df6aa4fe0a1d3841c635444be4e33eb8cdc2f2e929ced06d0a8454fb28225", + "95d960d5691c5a92beb94501d0f775dbc161e6fe1c6ca420e158ef22f25320cb", + "e641bf2a1c0a2298d628ad70e25976cbda419e825eeb21d854976d6f93192a24", + } + + actualMerkleStore := make([]string, len(*merkleStore)) + + for idx, merkle := range *merkleStore { + if merkle.Equal(chainhash.Hash{}) { + actualMerkleStore[idx] = "" + } else { + actualMerkleStore[idx] = merkle.String() + } + } + + assert.Equal(t, expectedMerkleStore, actualMerkleStore) + }) +} + +// func TestSubtree_AddNode(t *testing.T) { +// t.Run("fee hash", func(t *testing.T) { +// st := NewTree(1) +// assert.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", st.FeeHash.String()) +// }) +// +// t.Run("fee hash 1", func(t *testing.T) { +// st := NewTree(1) +// hash1, _ := chainhash.NewHashFromStr("de2c2e8628ab837ceff3de0217083d9d5feb71f758a5d083ada0b33a36e1b30e") +// _ = st.AddNode(hash1, 111, 0) +// +// 
assert.Equal(t, "66e4e66648f366400333d922e2371ad132b37054d53410b2767876089707eb43", st.FeeHash.String()) +// }) +// +// t.Run("fee hash 2", func(t *testing.T) { +// st := NewTree(1) +// hash1, _ := chainhash.NewHashFromStr("de2c2e8628ab837ceff3de0217083d9d5feb71f758a5d083ada0b33a36e1b30e") +// hash2, _ := chainhash.NewHashFromStr("89878bfd69fba52876e5217faec126fc6a20b1845865d4038c12f03200793f48") +// _ = st.AddNode(hash1, 111, 0) +// _ = st.AddNode(hash2, 123, 0) +// +// assert.Equal(t, "e6e65a874a12c4753485b3b42d1c378b36b02196ef2b3461da1d452d7d1434fb", st.FeeHash.String()) +// }) +// } + +func TestAddSubtreeNode(t *testing.T) { + t.Run("successfully add node to empty subtree", func(t *testing.T) { + st, err := NewTree(1) // Creates a subtree that can hold 2 nodes + require.NoError(t, err) + + hash, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + node := SubtreeNode{ + Hash: *hash, + Fee: 1000, + SizeInBytes: 250, + } + + err = st.AddSubtreeNode(node) + require.NoError(t, err) + + require.Equal(t, 1, len(st.Nodes)) + require.Equal(t, *hash, st.Nodes[0].Hash) + require.Equal(t, uint64(1000), st.Fees) + require.Equal(t, uint64(250), st.SizeInBytes) + require.Nil(t, st.rootHash) // Should be reset + }) + + t.Run("successfully add multiple nodes", func(t *testing.T) { + st, err := NewTree(1) // Creates a subtree that can hold 2 nodes + require.NoError(t, err) + + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + node1 := SubtreeNode{ + Hash: *hash1, + Fee: 1000, + SizeInBytes: 250, + } + + hash2, _ := chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + node2 := SubtreeNode{ + Hash: *hash2, + Fee: 2000, + SizeInBytes: 500, + } + + err = st.AddSubtreeNode(node1) + require.NoError(t, err) + err = st.AddSubtreeNode(node2) + require.NoError(t, err) + + require.Equal(t, 2, len(st.Nodes)) + require.Equal(t, *hash1, 
st.Nodes[0].Hash) + require.Equal(t, *hash2, st.Nodes[1].Hash) + require.Equal(t, uint64(3000), st.Fees) + require.Equal(t, uint64(750), st.SizeInBytes) + }) + + t.Run("error when subtree is full", func(t *testing.T) { + st, err := NewTree(1) // Creates a subtree that can hold 2 nodes + require.NoError(t, err) + + // Add two nodes to fill the subtree + hash1, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + node1 := SubtreeNode{Hash: *hash1, Fee: 1000, SizeInBytes: 250} + hash2, _ := chainhash.NewHashFromStr("7ce05dda56bc523048186c0f0474eb21c92fe35de6d014bd016834637a3ed08d") + node2 := SubtreeNode{Hash: *hash2, Fee: 2000, SizeInBytes: 500} + + err = st.AddSubtreeNode(node1) + require.NoError(t, err) + err = st.AddSubtreeNode(node2) + require.NoError(t, err) + + require.True(t, st.IsComplete()) + + // Try to add a third node + hash3, _ := chainhash.NewHashFromStr("3070fb937289e24720c827cbc24f3fce5c361cd7e174392a700a9f42051609e0") + node3 := SubtreeNode{Hash: *hash3, Fee: 3000, SizeInBytes: 750} + err = st.AddSubtreeNode(node3) + require.Error(t, err) + require.Equal(t, "subtree is full", err.Error()) + require.Equal(t, 2, len(st.Nodes)) + }) + + t.Run("error when adding coinbase placeholder", func(t *testing.T) { + st, err := NewTree(1) + require.NoError(t, err) + + node := SubtreeNode{ + Hash: CoinbasePlaceholder, + Fee: 0, + SizeInBytes: 0, + } + + err = st.AddSubtreeNode(node) + require.Error(t, err) + require.Contains(t, err.Error(), "coinbase placeholder node should be added with AddCoinbaseNode") + require.Equal(t, 0, len(st.Nodes)) + }) + + t.Run("node index is updated when it exists", func(t *testing.T) { + st, err := NewTree(1) + require.NoError(t, err) + + // Initialize node index + st.nodeIndex = make(map[chainhash.Hash]int) + + hash, _ := chainhash.NewHashFromStr("97af9ad3583e2f83fc1e44e475e3a3ee31ec032449cc88b491479ef7d187c115") + node := SubtreeNode{ + Hash: *hash, + Fee: 1000, + SizeInBytes: 250, + } 
+ + err = st.AddSubtreeNode(node) + require.NoError(t, err) + + // Check that the node was added to the index + index, exists := st.nodeIndex[*hash] + require.True(t, exists) + require.Equal(t, 0, index) + }) +} + +func TestSubtree_ConflictingNodes(t *testing.T) { + tx1 := tx.Clone() + tx1.Version = 1 + hash1 := *tx1.TxIDChainHash() + + tx2 := tx.Clone() + tx2.Version = 2 + hash2 := *tx2.TxIDChainHash() + + st, err := NewTree(4) + require.NoError(t, err) + + err = st.AddNode(hash1, 111, 1) + require.NoError(t, err) + + err = st.AddNode(hash2, 112, 2) + require.NoError(t, err) + + assert.Len(t, st.Nodes, 2) + assert.Equal(t, 2, st.Length()) + + err = st.AddConflictingNode(hash1) + require.NoError(t, err) + assert.Len(t, st.Nodes, 2) + assert.Equal(t, 2, st.Length()) + assert.Len(t, st.ConflictingNodes, 1) + + bytes, err := st.Serialize() + require.NoError(t, err) + assert.GreaterOrEqual(t, len(bytes), 48) + + newSt, err := NewSubtreeFromBytes(bytes) + require.NoError(t, err) + assert.Len(t, newSt.Nodes, 2) + assert.Equal(t, 2, newSt.Length()) + assert.Len(t, newSt.ConflictingNodes, 1) +} + +func BenchmarkSubtree_Deserialize(b *testing.B) { + // populate subtree for test + subtree, _ := NewTreeByLeafCount(1024 * 1024) + + for i := uint64(0); i < 1024*1024; i++ { + hash, _ := chainhash.NewHashFromStr(fmt.Sprintf("%x", i)) + _ = subtree.AddNode(*hash, i, i) + } + + subtreeBytes, _ := subtree.Serialize() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := NewSubtreeFromBytes(subtreeBytes) + require.NoError(b, err) + } +} + +func BenchmarkSubtree_DeserializeNodesFromReader(b *testing.B) { + // populate subtree for test + subtree, _ := NewTreeByLeafCount(1024 * 1024) + + for i := uint64(0); i < 1024*1024; i++ { + hash, _ := chainhash.NewHashFromStr(fmt.Sprintf("%x", i)) + _ = subtree.AddNode(*hash, i, i) + } + + subtreeBytes, _ := subtree.Serialize() + subtreeReader := bytes.NewReader(subtreeBytes) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := 
DeserializeNodesFromReader(subtreeReader) + require.NoError(b, err) + + // reset the subtree reader for the next loop + _, _ = subtreeReader.Seek(0, 0) + } +} + +func BenchmarkSubtree_DeserializeFromReader(b *testing.B) { + // populate subtree for test + subtree, _ := NewTreeByLeafCount(1024 * 1024) + + for i := uint64(0); i < 1024*1024; i++ { + hash, _ := chainhash.NewHashFromStr(fmt.Sprintf("%x", i)) + _ = subtree.AddNode(*hash, i, i) + } + + subtreeBytes, _ := subtree.Serialize() + subtreeReader := bytes.NewReader(subtreeBytes) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + s := &Subtree{} + err := s.DeserializeFromReader(subtreeReader) + require.NoError(b, err) + + // reset the subtree reader for the next loop + _, _ = subtreeReader.Seek(0, 0) + } +} + +func Benchmark_SubtreeNodeIndex(b *testing.B) { + // populate subtree for test + subtree, _ := NewTreeByLeafCount(1024 * 1024) + + for i := uint64(0); i < 1024*1024; i++ { + hash := chainhash.HashH([]byte(fmt.Sprintf("tx_%x", i))) + _ = subtree.AddNode(hash, i, i) + } + + subtreeLength := subtree.Length() + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + index := subtree.NodeIndex(subtree.Nodes[i%subtreeLength].Hash) + require.GreaterOrEqual(b, index, 0) + } +} diff --git a/template.go b/template.go deleted file mode 100644 index 53978ff..0000000 --- a/template.go +++ /dev/null @@ -1,45 +0,0 @@ -// Package template provides a robust starter template for building new Go libraries. -// -// This package implements foundational patterns and utilities for Go library development and is designed to help developers quickly scaffold, test, and maintain secure, idiomatic Go code. -// -// Key features include: -// - Built-in support for code quality, testing, and CI/CD workflows -// - Example functions and best-practice patterns for Go libraries -// -// The package is structured for modularity and ease of extension, following Go community conventions. 
It relies on the Go standard library and popular tools for testing and linting. -// -// Usage examples: -// -// msg := template.Greet("Alice") -// fmt.Println(msg) // Output: Hello Alice -// -// Important notes: -// - Assumes Go modules are used for dependency management -// - No external configuration is required for basic usage -// - Designed for use as a starting point for new Go projects -// -// This package is part of the go-subtree project and is intended to be copied or forked for new Go library development. -package template - -import "fmt" - -// Greet returns a greeting message for the given first name. -// -// This function performs the following steps: -// - Formats a greeting string using the provided first name. -// -// Parameters: -// - firstname: The first name to include in the greeting message. -// -// Returns: -// - A string containing the greeting message. -// -// Side Effects: -// - None. -// -// Notes: -// - Assumes firstname is a non-empty string; no validation is performed. -// - This function is standalone and not part of a larger workflow. -func Greet(firstname string) string { - return fmt.Sprintf("Hello %s", firstname) -} diff --git a/template_benchmark_test.go b/template_benchmark_test.go deleted file mode 100644 index d0769c4..0000000 --- a/template_benchmark_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package template_test - -import ( - "testing" - - "github.com/bsv-blockchain/go-subtree" -) - -// BenchmarkGreet benchmarks the Greet function. -func BenchmarkGreet(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = template.Greet("BenchmarkUser") - } -} diff --git a/template_example_test.go b/template_example_test.go deleted file mode 100644 index 226539f..0000000 --- a/template_example_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package template_test - -import ( - "fmt" - - "github.com/bsv-blockchain/go-subtree" -) - -// ExampleGreet demonstrates the usage of the Greet function. 
-func ExampleGreet() { - msg := template.Greet("Alice") - fmt.Println(msg) - // Output: Hello Alice -} diff --git a/template_fuzz_test.go b/template_fuzz_test.go deleted file mode 100644 index 4b77081..0000000 --- a/template_fuzz_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package template_test - -import ( - "testing" - - "github.com/bsv-blockchain/go-subtree" - "github.com/stretchr/testify/require" -) - -// FuzzGreet validates that Greet always returns a string that starts with "Hello " and is at least 6 characters long. -func FuzzGreet(f *testing.F) { - seed := []string{ - "Alice", - "", - "123", - "!@#", - "世界", //nolint: gosmopolitan // Test with non-ASCII characters - } - for _, tc := range seed { - f.Add(tc) - } - f.Fuzz(func(t *testing.T, input string) { - out := template.Greet(input) - require.GreaterOrEqualf(t, len(out), 6, "output too short: %q", out) - require.Equalf(t, "Hello ", out[:6], "output does not start with 'Hello ': %q", out) - }) -} diff --git a/template_test.go b/template_test.go deleted file mode 100644 index 27b6d06..0000000 --- a/template_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package template_test - -import ( - "testing" - - "github.com/bsv-blockchain/go-subtree" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestGreet tests the Greet function with various input scenarios using table-driven tests. -func TestGreet(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expected string - }{ - { - name: "normal name", - input: "Alice", - expected: "Hello Alice", - }, - { - name: "empty string", - input: "", - expected: "Hello ", - }, - { - name: "whitespace", - input: " ", - expected: "Hello ", - }, - { - name: "special characters", - input: "@!$", - expected: "Hello @!$", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := template.Greet(tt.input) - require.NotNil(t, result) - assert.Equal(t, tt.expected, result) - }) - } -}