Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions core/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ func ParSignedDataFromProto(typ DutyType, data *pbv1.ParSignedData) (_ ParSigned
// This is to respect the technical possibility of unmarshalling to panic.
// However, our protobuf generated types do not have custom marshallers that may panic.
if r := recover(); r != nil {
rowStr := fmt.Sprintf("%v", r)
oerr = errors.Wrap(errors.New(rowStr), "panic recovered")
oerr = recoverPanicErr(r)
}
}()

Expand Down Expand Up @@ -242,7 +241,13 @@ func UnsignedDataSetToProto(set UnsignedDataSet) (*pbv1.UnsignedDataSet, error)
}

// UnsignedDataSetFromProto returns the set from a protobuf.
func UnsignedDataSetFromProto(typ DutyType, set *pbv1.UnsignedDataSet) (UnsignedDataSet, error) {
func UnsignedDataSetFromProto(typ DutyType, set *pbv1.UnsignedDataSet) (_ UnsignedDataSet, oerr error) {
defer func() {
if r := recover(); r != nil {
oerr = recoverPanicErr(r)
}
}()

if set == nil || len(set.GetSet()) == 0 {
return nil, errors.New("invalid unsigned data set fields", z.Any("set", set))
}
Expand All @@ -261,6 +266,17 @@ func UnsignedDataSetFromProto(typ DutyType, set *pbv1.UnsignedDataSet) (Unsigned
return resp, nil
}

func recoverPanicErr(r any) error {
var err error
if recoveredErr, ok := r.(error); ok {
err = recoveredErr
} else {
err = errors.New(fmt.Sprint(r))
}

return errors.Wrap(err, "panic recovered")
}

// marshal marshals the given value into bytes, either as SSZ if supported by the type (and if enabled) or as json.
func marshal(v any) ([]byte, error) {
// First try SSZ
Expand Down
111 changes: 111 additions & 0 deletions core/proto_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@ package core
import (
"encoding/hex"
"encoding/json"
stderrors "errors"
"testing"

eth2v1 "github.com/attestantio/go-eth2-client/api/v1"
eth2p0 "github.com/attestantio/go-eth2-client/spec/phase0"
ssz "github.com/ferranbt/fastssz"
"github.com/stretchr/testify/require"

"github.com/obolnetwork/charon/app/errors"
pbv1 "github.com/obolnetwork/charon/core/corepb/v1"
"github.com/obolnetwork/charon/eth2util"
)

// TestMarshal tests the marshal() internal function.
Expand Down Expand Up @@ -94,6 +100,111 @@ func TestMarshal(t *testing.T) {
}
}

// TestUnsignedDataSetFromProtoMalformedSSZOffset verifies that malformed SSZ bytes with an
// out-of-bounds offset field return an error instead of panicking.
func TestUnsignedDataSetFromProtoMalformedSSZOffset(t *testing.T) {
// versionedBlindedOffset = 13: 8 (version uint64) + 1 (blinded uint8) + 4 (offset uint32).
// A 13-byte buffer with the offset field encoding 14 caused slice bounds [14:13] before the fix.
proposerBuf := []byte{
0, 0, 0, 0, 0, 0, 0, 0, // version = Phase0 (0)
0, // blinded = false
14, 0, 0, 0, // offset = 14, but len(buf) = 13
}

// versionedOffset = 12: 8 (version uint64) + 4 (offset uint32).
// A 12-byte buffer with the offset field encoding 13 caused slice bounds [13:12] before the fix.
aggregatorBuf := []byte{
0, 0, 0, 0, 0, 0, 0, 0, // version = Phase0 (0)
13, 0, 0, 0, // offset = 13, but len(buf) = 12
}

t.Run("versioned_blinded_helper_returns_offset_error", func(t *testing.T) {
_, _, err := unmarshalSSZVersionedBlinded(proposerBuf, func(eth2util.DataVersion, bool) (sszType, error) {
t.Fatal("valFunc must not be called for an out-of-bounds offset")

return nil, stderrors.New("unexpected valFunc call")
})
require.Error(t, err)
require.True(t, errors.Is(err, ssz.ErrOffset), "error must wrap ssz.ErrOffset: %v", err)
require.NotContains(t, err.Error(), "panic recovered")
})

t.Run("versioned_helper_returns_offset_error", func(t *testing.T) {
_, err := unmarshalSSZVersioned(aggregatorBuf, func(eth2util.DataVersion) (sszType, error) {
t.Fatal("valFunc must not be called for an out-of-bounds offset")

return nil, stderrors.New("unexpected valFunc call")
})
require.Error(t, err)
require.True(t, errors.Is(err, ssz.ErrOffset), "error must wrap ssz.ErrOffset: %v", err)
require.NotContains(t, err.Error(), "panic recovered")
})

tests := []struct {
name string
duty DutyType
buf []byte
}{
{
name: "proposer/versioned_blinded_offset_oob",
duty: DutyProposer,
buf: proposerBuf,
},
{
name: "aggregator/versioned_offset_oob",
duty: DutyAggregator,
buf: aggregatorBuf,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
set := &pbv1.UnsignedDataSet{
Set: map[string][]byte{"pk": tt.buf},
}

_, err := UnsignedDataSetFromProto(tt.duty, set)
require.Error(t, err)
require.NotContains(t, err.Error(), "panic recovered")
})
}
}

func TestRecoverPanicErr(t *testing.T) {
sentinel := stderrors.New("sentinel")

tests := []struct {
name string
recovered any
is error
contains string
}{
{
name: "error",
recovered: sentinel,
is: sentinel,
contains: "sentinel",
},
{
name: "string",
recovered: "plain panic",
contains: "plain panic",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := recoverPanicErr(tt.recovered)
require.ErrorContains(t, err, "panic recovered")
require.ErrorContains(t, err, tt.contains)

if tt.is != nil {
require.ErrorIs(t, err, tt.is)
}
})
}
}

// TestUnmarshal tests the unmarshal() internal function.
// unmarshal() tries SSZ first (if the type implements ssz.Unmarshaler),
// falling back to JSON when SSZ fails and the data starts with '{'.
Expand Down
4 changes: 2 additions & 2 deletions core/ssz.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ func unmarshalSSZVersionedBlinded(buf []byte, valFunc func(eth2util.DataVersion,

// Offset (2) 'Value'
o1 := ssz.ReadOffset(buf[9:13])
if versionedBlindedOffset > o1 {
if versionedBlindedOffset > o1 || o1 > uint64(len(buf)) {
return "", false, errors.Wrap(ssz.ErrOffset, "sszValFromVersion offset", z.Any("version", version), z.Bool("blinded", blinded))
}

Expand Down Expand Up @@ -801,7 +801,7 @@ func unmarshalSSZVersioned(buf []byte, valFunc func(eth2util.DataVersion) (sszTy

// Offset (1) 'Value'
o1 := ssz.ReadOffset(buf[8:12])
if versionedOffset > o1 {
if versionedOffset > o1 || o1 > uint64(len(buf)) {
return "", errors.Wrap(ssz.ErrOffset, "sszValFromVersion offset", z.Any("version", version))
}

Expand Down
Loading