diff --git a/modal-go/client.go b/modal-go/client.go index 3a18c5ac..11b96b37 100644 --- a/modal-go/client.go +++ b/modal-go/client.go @@ -59,6 +59,7 @@ type Client struct { Apps AppService CloudBucketMounts CloudBucketMountService Cls ClsService + Dicts DictService Functions FunctionService FunctionCalls FunctionCallService Images ImageService @@ -187,6 +188,7 @@ func NewClientWithOptions(params *ClientParams) (*Client, error) { c.Apps = &appServiceImpl{client: c} c.CloudBucketMounts = &cloudBucketMountServiceImpl{client: c} c.Cls = &clsServiceImpl{client: c} + c.Dicts = &dictServiceImpl{client: c} c.Functions = &functionServiceImpl{client: c} c.FunctionCalls = &functionCallServiceImpl{client: c} c.Images = &imageServiceImpl{client: c} diff --git a/modal-go/dict.go b/modal-go/dict.go new file mode 100644 index 00000000..9863bb1d --- /dev/null +++ b/modal-go/dict.go @@ -0,0 +1,784 @@ +package modal + +// Dict implements the Modal Dict distributed key-value store. +// +// Dict keys are matched on the server by byte-equality of their serialized +// pickle representation. The Go SDK serializes keys using og-rek (pickle +// protocol 4, StrictUnicode) with post-processing to inject FRAME and MEMOIZE +// opcodes, producing bytes byte-identical to Python's cloudpickle for all +// supported primitive types. This enables cross-language interop: keys written +// by Go can be read by Python and vice versa. +// +// Per the Modal docs: "cloudpickle serialization is not guaranteed to be +// deterministic, so it is generally recommended to use primitive types for keys." +// See https://modal.com/docs/reference/modal.Dict +// +// Values are serialized using og-rek (protocol 4) without post-processing. +// Values only need to be valid pickle for deserialization — they do not require +// byte-equality with cloudpickle. This supports complex types (maps, slices, +// nested structures) as values. +// +// Supported key types (Go → Python): +// +// nil → None +// bool → bool +// int, int8-64 → int +// uint, uint8-64 → int +// float32, float64 → float +// string → str +// []byte → bytes + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "io" + "iter" + "math/big" + "strings" + + pickle "github.com/kisielk/og-rek" + pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Pickle protocol 4 opcodes used by ogrekToCloudpickle / cloudpickleToOgRek. +const ( + pickleShortBinBytes = 'C' // bytes len < 256 + pickleBinBytes = 'B' // bytes len < 2^32 + pickleShortBinUnicode = 0x8c // string len < 256 + pickleBinUnicode = 'X' // string len >= 256 + pickleLong1 = 0x8a // variable-length two's complement int + pickleASCIIInt = 'I' // og-rek uses this for ints outside int32 range + pickleMemoize = 0x94 + pickleFrame = 0x95 + pickleStop = '.' +) + +// dictOgRekP4Serialize uses og-rek with Protocol 4 + StrictUnicode. +// This is the base encoder used by both dictSerializeKey (with post-processing) +// and dictSerializeValue (without post-processing). +func dictOgRekP4Serialize(v any) ([]byte, error) { + var buf bytes.Buffer + e := pickle.NewEncoderWithConfig(&buf, &pickle.EncoderConfig{Protocol: 4, StrictUnicode: true}) + if err := e.Encode(v); err != nil { + return nil, fmt.Errorf("og-rek pickle error: %w", err) + } + return buf.Bytes(), nil +} + +// dictSerializeKey serializes a Dict key to pickle bytes that are byte-identical +// to Python's cloudpickle (protocol 4) for all supported primitive types. +// This enables cross-language Dict key lookups: keys written in Go can be +// read by Python and vice versa. +// +// Internally: runs og-rek (protocol 4, StrictUnicode) then post-processes the +// output via ogrekToCloudpickle to inject FRAME and MEMOIZE opcodes. +// +// Per the Modal docs: "cloudpickle serialization is not guaranteed to be +// deterministic, so it is generally recommended to use primitive types for keys." +func dictSerializeKey(v any) ([]byte, error) { + raw, err := dictOgRekP4Serialize(v) + if err != nil { + return nil, err + } + return ogrekToCloudpickle(raw), nil +} + +// og-rek encodes []byte as builtins.bytearray(SHORT_BINBYTES(data)), but +// cloudpickle uses bare SHORT_BINBYTES/BINBYTES. This prefix is the og-rek +// bytearray constructor pattern that we convert. +var ogrekBytearrayPrefix = []byte{ + 0x8c, 0x08, 'b', 'u', 'i', 'l', 't', 'i', 'n', 's', + 0x8c, 0x09, 'b', 'y', 't', 'e', 'a', 'r', 'r', 'a', 'y', + 0x93, // STACK_GLOBAL +} + +// ogrekToCloudpickle post-processes og-rek protocol 4 bytes into cloudpickle- +// compatible bytes by injecting FRAME and MEMOIZE opcodes, converting the +// ASCII 'I' opcode to LONG1, and unwrapping bytearray constructors. +// +// og-rek protocol 4 produces: PROTO(4) + + STOP +// cloudpickle protocol 4 needs: PROTO(4) + [FRAME(len)] + [+ MEMOIZE] + STOP +// +// The post-processing: +// 1. Strips the PROTO header (0x80 0x04) and trailing STOP (0x2e) +// 2. Converts ASCII 'I' opcode (og-rek uses for ints > int32) to LONG1 (cloudpickle) +// 3. Unwraps builtins.bytearray(...) to bare SHORT_BINBYTES/BINBYTES +// 4. If the first opcode is a string or bytes opcode, appends MEMOIZE (0x94) +// 5. Re-appends STOP +// 6. If the content (after PROTO) is >= 4 bytes, wraps it in a FRAME opcode +// 7. Reassembles: PROTO(4) + [FRAME] + content + STOP +func ogrekToCloudpickle(raw []byte) []byte { + // raw layout: [0x80 0x04] [0x2e] + // Minimum valid pickle is 3 bytes: PROTO(2) + STOP(1). + if len(raw) < 3 { + return raw + } + + body := raw[2 : len(raw)-1] // strip PROTO header and STOP + + // og-rek uses the ASCII 'I' opcode (protocol 0) for ints outside int32 range. + // cloudpickle uses LONG1 (0x8a) with binary two's complement encoding. + // Convert: I\n → LONG1 + if len(body) > 0 && body[0] == pickleASCIIInt { + s := strings.TrimSuffix(string(body[1:]), "\n") + n, ok := new(big.Int).SetString(s, 10) + if ok { + body = encodeLong1(n) + } + } + + // og-rek encodes []byte as builtins.bytearray(SHORT_BINBYTES(data)) + // but cloudpickle uses bare SHORT_BINBYTES/BINBYTES. Unwrap. + if bytes.HasPrefix(body, ogrekBytearrayPrefix) { + inner := body[len(ogrekBytearrayPrefix):] + if extracted := extractBytesOpcode(inner); extracted != nil { + body = extracted + } + } + + // Inject MEMOIZE after string/bytes opcodes. + // cloudpickle always memoizes strings and bytes; og-rek never does. + needsMemoize := len(body) > 0 && + (body[0] == pickleShortBinUnicode || body[0] == pickleBinUnicode || + body[0] == pickleShortBinBytes || body[0] == pickleBinBytes) + + var content bytes.Buffer + content.Write(body) + if needsMemoize { + content.WriteByte(pickleMemoize) + } + content.WriteByte(pickleStop) + + // Assemble: PROTO(4) + optional FRAME + content. + // cloudpickle emits a FRAME when content >= 4 bytes. + var buf bytes.Buffer + buf.Write([]byte{0x80, 0x04}) + + data := content.Bytes() + if len(data) >= 4 { + buf.WriteByte(pickleFrame) + binary.Write(&buf, binary.LittleEndian, uint64(len(data))) + } + buf.Write(data) + + return buf.Bytes() +} + +// extractBytesOpcode extracts the bare SHORT_BINBYTES/BINBYTES from an og-rek +// bytearray constructor body: TUPLE1(0x85) REDUCE(0x52). +// Returns nil if the pattern doesn't match. +func extractBytesOpcode(inner []byte) []byte { + if len(inner) < 4 { + return nil + } + switch inner[0] { + case pickleShortBinBytes: + dataLen := int(inner[1]) + end := 2 + dataLen + if end+2 <= len(inner) && inner[end] == 0x85 && inner[end+1] == 0x52 { + return inner[:end] + } + case pickleBinBytes: + if len(inner) < 5 { + return nil + } + dataLen := int(binary.LittleEndian.Uint32(inner[1:5])) + end := 5 + dataLen + if end+2 <= len(inner) && inner[end] == 0x85 && inner[end+1] == 0x52 { + return inner[:end] + } + } + return nil +} + +// encodeLong1 encodes a big.Int as a pickle LONG1 opcode: 0x8a . +// The bytes are minimal-length little-endian two's complement, matching Python's +// pickle LONG1 format. +func encodeLong1(n *big.Int) []byte { + if n.Sign() == 0 { + return []byte{pickleLong1, 0x00} + } + + // big.Int.Bytes() returns unsigned big-endian bytes. We need signed + // little-endian (two's complement). For positive numbers, Bytes() is + // already the unsigned representation. For negative, compute two's + // complement manually. + var data []byte + if n.Sign() > 0 { + be := n.Bytes() // big-endian unsigned + data = make([]byte, len(be)) + for i, b := range be { + data[len(be)-1-i] = b // reverse to little-endian + } + // If high bit is set, append 0x00 to keep it positive. + if data[len(data)-1] >= 0x80 { + data = append(data, 0x00) + } + } else { + // Two's complement for negative: subtract 1 from abs, then invert bits. + abs := new(big.Int).Abs(n) + abs.Sub(abs, big.NewInt(1)) + be := abs.Bytes() + if len(be) == 0 { + // n == -1: abs-1 == 0, Bytes() is empty. Two's complement is 0xff. + data = []byte{0xff} + } else { + data = make([]byte, len(be)) + for i, b := range be { + data[len(be)-1-i] = ^b // reverse and invert + } + // If high bit is not set, append 0xff to keep it negative. + if data[len(data)-1] < 0x80 { + data = append(data, 0xff) + } + } + } + + result := make([]byte, 0, 2+len(data)) + result = append(result, pickleLong1, byte(len(data))) + result = append(result, data...) + return result +} + +// decodeLong1 decodes the body portion of a LONG1 opcode (after the 0x8a byte) +// back to an ASCII 'I' opcode string for cloudpickleToOgRek. The input is: +// . +func decodeLong1(body []byte) (string, int) { + if len(body) < 1 { + return "", 0 + } + length := int(body[0]) + if len(body) < 1+length { + return "", 0 + } + data := body[1 : 1+length] + + if length == 0 { + return "I0\n", 1 + length + } + + // Determine sign from the high bit of the last byte (most significant in LE). + negative := data[length-1] >= 0x80 + + n := new(big.Int) + if !negative { + // Convert LE to BE. + be := make([]byte, length) + for i, b := range data { + be[length-1-i] = b + } + n.SetBytes(be) + } else { + // Two's complement negative: invert bits, convert, then negate and subtract 1. + be := make([]byte, length) + for i, b := range data { + be[length-1-i] = ^b + } + n.SetBytes(be) + n.Add(n, big.NewInt(1)) + n.Neg(n) + } + + return "I" + n.String() + "\n", 1 + length +} + +// cloudpickleToOgRek strips FRAME and MEMOIZE opcodes from cloudpickle protocol 4 +// bytes, producing plain og-rek protocol 4 bytes. This is the inverse of +// ogrekToCloudpickle. +func cloudpickleToOgRek(raw []byte) []byte { + if len(raw) < 3 { + return raw + } + + // Start after PROTO header (0x80 0x04). + pos := 2 + + // Skip FRAME opcode + 8-byte length if present. + if pos < len(raw) && raw[pos] == pickleFrame { + pos += 1 + 8 // opcode + uint64 frame length + } + + body := raw[pos:] + + // Strip trailing MEMOIZE before STOP, but only for string/bytes opcodes + // (mirroring ogrekToCloudpickle which only injects MEMOIZE for these types). + // Without this check, a data byte 0x94 (e.g. in BININT LE payload) would + // be falsely stripped. + hasMemoize := len(body) >= 2 && + body[len(body)-2] == pickleMemoize && body[len(body)-1] == pickleStop + isStringOrBytes := len(body) > 0 && + (body[0] == pickleShortBinUnicode || body[0] == pickleBinUnicode || + body[0] == pickleShortBinBytes || body[0] == pickleBinBytes) + if hasMemoize && isStringOrBytes { + body = append(body[:len(body)-2], pickleStop) + } + + // Convert LONG1 back to ASCII 'I' opcode (inverse of ogrekToCloudpickle). + if len(body) >= 2 && body[0] == pickleLong1 { + asciiInt, consumed := decodeLong1(body[1:]) + if consumed > 0 { + var newBody bytes.Buffer + newBody.WriteString(asciiInt) + newBody.Write(body[1+consumed:]) + body = newBody.Bytes() + } + } + + // Wrap bare SHORT_BINBYTES/BINBYTES in builtins.bytearray() constructor + // (inverse of ogrekToCloudpickle's bytearray unwrapping). + if len(body) >= 2 && (body[0] == pickleShortBinBytes || body[0] == pickleBinBytes) { + // body = STOP + // Wrap: prefix + + TUPLE1 + REDUCE + STOP + stopIdx := len(body) - 1 // last byte should be STOP + if body[stopIdx] == pickleStop { + var newBody bytes.Buffer + newBody.Write(ogrekBytearrayPrefix) + newBody.Write(body[:stopIdx]) // bytes opcode + data (without STOP) + newBody.WriteByte(0x85) // TUPLE1 + newBody.WriteByte(0x52) // REDUCE + newBody.WriteByte(pickleStop) + body = newBody.Bytes() + } + } + + var buf bytes.Buffer + buf.Write([]byte{0x80, 0x04}) + buf.Write(body) + return buf.Bytes() +} + +// DictService provides Dict related operations. +type DictService interface { + Ephemeral(ctx context.Context, params *DictEphemeralParams) (*Dict, error) + FromName(ctx context.Context, name string, params *DictFromNameParams) (*Dict, error) + Delete(ctx context.Context, name string, params *DictDeleteParams) error +} + +type dictServiceImpl struct{ client *Client } + +// Dict is a distributed dictionary for storage in Modal Apps. +// +// Keys should be primitive types (see package doc for the full list). +// cloudpickle serialization is not guaranteed to be deterministic, so +// complex types as keys may produce inconsistent lookups across languages. +type Dict struct { + DictID string + Name string + cancelEphemeral context.CancelFunc + + client *Client +} + +// DictEphemeralParams are options for client.Dicts.Ephemeral. +type DictEphemeralParams struct { + Environment string +} + +// Ephemeral creates a nameless, temporary Dict that persists until CloseEphemeral is called, or the process exits. +func (s *dictServiceImpl) Ephemeral(ctx context.Context, params *DictEphemeralParams) (*Dict, error) { + if params == nil { + params = &DictEphemeralParams{} + } + + resp, err := s.client.cpClient.DictGetOrCreate(ctx, pb.DictGetOrCreateRequest_builder{ + ObjectCreationType: pb.ObjectCreationType_OBJECT_CREATION_TYPE_EPHEMERAL, + EnvironmentName: environmentName(params.Environment, s.client.profile), + }.Build()) + if err != nil { + return nil, err + } + + s.client.logger.DebugContext(ctx, "Created ephemeral Dict", "dict_id", resp.GetDictId()) + + ephemeralCtx, cancel := context.WithCancel(context.Background()) + startEphemeralHeartbeat(ephemeralCtx, func() error { + _, err := s.client.cpClient.DictHeartbeat(ephemeralCtx, pb.DictHeartbeatRequest_builder{ + DictId: resp.GetDictId(), + }.Build()) + return err + }) + + return &Dict{ + DictID: resp.GetDictId(), + cancelEphemeral: cancel, + client: s.client, + }, nil +} + +// CloseEphemeral deletes an ephemeral Dict, only used with DictEphemeral. +func (d *Dict) CloseEphemeral() { + if d.cancelEphemeral != nil { + d.cancelEphemeral() + } else { + panic(fmt.Sprintf("Dict %s is not ephemeral", d.DictID)) + } +} + +// DictFromNameParams are options for client.Dicts.FromName. +type DictFromNameParams struct { + Environment string + CreateIfMissing bool +} + +// FromName references a named Dict, creating if necessary. +func (s *dictServiceImpl) FromName(ctx context.Context, name string, params *DictFromNameParams) (*Dict, error) { + if params == nil { + params = &DictFromNameParams{} + } + + creationType := pb.ObjectCreationType_OBJECT_CREATION_TYPE_UNSPECIFIED + if params.CreateIfMissing { + creationType = pb.ObjectCreationType_OBJECT_CREATION_TYPE_CREATE_IF_MISSING + } + + resp, err := s.client.cpClient.DictGetOrCreate(ctx, pb.DictGetOrCreateRequest_builder{ + DeploymentName: name, + EnvironmentName: environmentName(params.Environment, s.client.profile), + ObjectCreationType: creationType, + }.Build()) + + if status, ok := status.FromError(err); ok && status.Code() == codes.NotFound { + return nil, NotFoundError{fmt.Sprintf("Dict '%s' not found", name)} + } + if err != nil { + return nil, err + } + + s.client.logger.DebugContext(ctx, "Retrieved Dict", "dict_id", resp.GetDictId(), "dict_name", name) + return &Dict{ + DictID: resp.GetDictId(), + Name: name, + cancelEphemeral: nil, + client: s.client, + }, nil +} + +// DictDeleteParams are options for client.Dicts.Delete. +type DictDeleteParams struct { + Environment string + AllowMissing bool +} + +// Delete removes a Dict by name. +// +// Warning: Deletion is irreversible and will affect any Apps currently using the Dict. +func (s *dictServiceImpl) Delete(ctx context.Context, name string, params *DictDeleteParams) error { + if params == nil { + params = &DictDeleteParams{} + } + + d, err := s.FromName(ctx, name, &DictFromNameParams{ + Environment: params.Environment, + CreateIfMissing: false, + }) + + if err != nil { + if _, ok := err.(NotFoundError); ok && params.AllowMissing { + return nil + } + return err + } + + _, err = s.client.cpClient.DictDelete(ctx, pb.DictDeleteRequest_builder{DictId: d.DictID}.Build()) + if err != nil { + if st, ok := status.FromError(err); ok && st.Code() == codes.NotFound && params.AllowMissing { + return nil + } + return err + } + + s.client.logger.DebugContext(ctx, "Deleted Dict", "dict_name", name, "dict_id", d.DictID) + return nil +} + +// Clear removes all items from the Dict. +func (d *Dict) Clear(ctx context.Context) error { + _, err := d.client.cpClient.DictClear(ctx, pb.DictClearRequest_builder{ + DictId: d.DictID, + }.Build()) + return err +} + +// Get returns the value for a key. The second return value indicates whether the key was found. +func (d *Dict) Get(ctx context.Context, key any) (any, bool, error) { + keyBytes, err := dictSerializeKey(key) + if err != nil { + return nil, false, err + } + + resp, err := d.client.cpClient.DictGet(ctx, pb.DictGetRequest_builder{ + DictId: d.DictID, + Key: keyBytes, + }.Build()) + if err != nil { + return nil, false, err + } + + if !resp.GetFound() { + return nil, false, nil + } + + val, err := pickleDeserialize(resp.GetValue()) + if err != nil { + return nil, false, err + } + return val, true, nil +} + +// Contains returns whether a key is present in the Dict. +func (d *Dict) Contains(ctx context.Context, key any) (bool, error) { + keyBytes, err := dictSerializeKey(key) + if err != nil { + return false, err + } + + resp, err := d.client.cpClient.DictContains(ctx, pb.DictContainsRequest_builder{ + DictId: d.DictID, + Key: keyBytes, + }.Build()) + if err != nil { + return false, err + } + return resp.GetFound(), nil +} + +// Len returns the number of items in the Dict. +func (d *Dict) Len(ctx context.Context) (int, error) { + resp, err := d.client.cpClient.DictLen(ctx, pb.DictLenRequest_builder{ + DictId: d.DictID, + }.Build()) + if err != nil { + return 0, err + } + return int(resp.GetLen()), nil +} + +// DictPutParams are options for Dict.Put. +type DictPutParams struct { + SkipIfExists bool +} + +// Put adds a key-value pair to the Dict. +// Returns true if the entry was created, false if the key already existed and SkipIfExists was set. +func (d *Dict) Put(ctx context.Context, key any, value any, params *DictPutParams) (bool, error) { + if params == nil { + params = &DictPutParams{} + } + + entries, err := serializeDictEntries(key, value) + if err != nil { + return false, err + } + + resp, err := d.client.cpClient.DictUpdate(ctx, pb.DictUpdateRequest_builder{ + DictId: d.DictID, + Updates: entries, + IfNotExists: params.SkipIfExists, + }.Build()) + if err != nil { + return false, err + } + return resp.GetCreated(), nil +} + +// Pop removes a key from the Dict and returns its value. +// The second return value indicates whether the key was found. +func (d *Dict) Pop(ctx context.Context, key any) (any, bool, error) { + keyBytes, err := dictSerializeKey(key) + if err != nil { + return nil, false, err + } + + resp, err := d.client.cpClient.DictPop(ctx, pb.DictPopRequest_builder{ + DictId: d.DictID, + Key: keyBytes, + }.Build()) + if err != nil { + return nil, false, err + } + + if !resp.GetFound() { + return nil, false, nil + } + + val, err := pickleDeserialize(resp.GetValue()) + if err != nil { + return nil, false, err + } + return val, true, nil +} + +// Update adds multiple key-value pairs to the Dict. +func (d *Dict) Update(ctx context.Context, data map[any]any) error { + entries, err := serializeDictEntriesMap(data) + if err != nil { + return err + } + + _, err = d.client.cpClient.DictUpdate(ctx, pb.DictUpdateRequest_builder{ + DictId: d.DictID, + Updates: entries, + }.Build()) + return err +} + +// DictItem holds a key-value pair from a Dict iteration. +type DictItem struct { + Key any + Value any +} + +// Keys returns an iterator over the keys in the Dict. +func (d *Dict) Keys(ctx context.Context) iter.Seq2[any, error] { + return func(yield func(any, error) bool) { + stream, err := d.client.cpClient.DictContents(ctx, pb.DictContentsRequest_builder{ + DictId: d.DictID, + Keys: true, + }.Build()) + if err != nil { + yield(nil, err) + return + } + + for { + entry, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + yield(nil, err) + return + } + key, err := pickleDeserialize(entry.GetKey()) + if err != nil { + yield(nil, err) + return + } + if !yield(key, nil) { + return + } + } + } +} + +// Values returns an iterator over the values in the Dict. +func (d *Dict) Values(ctx context.Context) iter.Seq2[any, error] { + return func(yield func(any, error) bool) { + stream, err := d.client.cpClient.DictContents(ctx, pb.DictContentsRequest_builder{ + DictId: d.DictID, + Values: true, + }.Build()) + if err != nil { + yield(nil, err) + return + } + + for { + entry, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + yield(nil, err) + return + } + val, err := pickleDeserialize(entry.GetValue()) + if err != nil { + yield(nil, err) + return + } + if !yield(val, nil) { + return + } + } + } +} + +// Items returns an iterator over the key-value pairs in the Dict. +func (d *Dict) Items(ctx context.Context) iter.Seq2[DictItem, error] { + return func(yield func(DictItem, error) bool) { + stream, err := d.client.cpClient.DictContents(ctx, pb.DictContentsRequest_builder{ + DictId: d.DictID, + Keys: true, + Values: true, + }.Build()) + if err != nil { + yield(DictItem{}, err) + return + } + + for { + entry, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + yield(DictItem{}, err) + return + } + key, err := pickleDeserialize(entry.GetKey()) + if err != nil { + yield(DictItem{}, err) + return + } + val, err := pickleDeserialize(entry.GetValue()) + if err != nil { + yield(DictItem{}, err) + return + } + if !yield(DictItem{Key: key, Value: val}, nil) { + return + } + } + } +} + +// dictSerializeValue serializes a Dict value using og-rek (protocol 4, StrictUnicode). +// Unlike keys, values don't need byte-equality with cloudpickle — they just need to be +// valid pickle that Python can deserialize. og-rek handles primitives and complex types +// (maps, slices, nested structures). +func dictSerializeValue(v any) ([]byte, error) { + return dictOgRekP4Serialize(v) +} + +// serializeDictEntries serializes a single key-value pair into a DictEntry slice. +func serializeDictEntries(key any, value any) ([]*pb.DictEntry, error) { + keyBytes, err := dictSerializeKey(key) + if err != nil { + return nil, err + } + valBytes, err := dictSerializeValue(value) + if err != nil { + return nil, err + } + return []*pb.DictEntry{ + pb.DictEntry_builder{ + Key: keyBytes, + Value: valBytes, + }.Build(), + }, nil +} + +// serializeDictEntriesMap serializes a map of key-value pairs into a DictEntry slice. +func serializeDictEntriesMap(data map[any]any) ([]*pb.DictEntry, error) { + entries := make([]*pb.DictEntry, 0, len(data)) + for k, v := range data { + keyBytes, err := dictSerializeKey(k) + if err != nil { + return nil, err + } + valBytes, err := dictSerializeValue(v) + if err != nil { + return nil, err + } + entries = append(entries, pb.DictEntry_builder{ + Key: keyBytes, + Value: valBytes, + }.Build()) + } + return entries, nil +} diff --git a/modal-go/dict_serialization_test.go b/modal-go/dict_serialization_test.go new file mode 100644 index 00000000..f8453a41 --- /dev/null +++ b/modal-go/dict_serialization_test.go @@ -0,0 +1,189 @@ +package modal + +import ( + "math" + "math/big" + "testing" + + pickle "github.com/kisielk/og-rek" + "github.com/onsi/gomega" +) + +// Golden bytes generated by Python's cloudpickle.dumps(value, protocol=4). +// These are the canonical byte sequences that the Modal Dict server expects. + +var goldenCases = []struct { + name string + value any + bytes []byte +}{ + {"nil", nil, []byte{0x80, 0x04, 0x4e, 0x2e}}, + {"true", true, []byte{0x80, 0x04, 0x88, 0x2e}}, + {"false", false, []byte{0x80, 0x04, 0x89, 0x2e}}, + {"int 0", 0, []byte{0x80, 0x04, 0x4b, 0x00, 0x2e}}, + {"int 1", 1, []byte{0x80, 0x04, 0x4b, 0x01, 0x2e}}, + {"int 42", 42, []byte{0x80, 0x04, 0x4b, 0x2a, 0x2e}}, + {"int 255", 255, []byte{0x80, 0x04, 0x4b, 0xff, 0x2e}}, + {"int 256", 256, []byte{0x80, 0x04, 0x95, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4d, 0x00, 0x01, 0x2e}}, + {"int 65535", 65535, []byte{0x80, 0x04, 0x95, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4d, 0xff, 0xff, 0x2e}}, + {"int 65536", 65536, []byte{0x80, 0x04, 0x95, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x01, 0x00, 0x2e}}, + {"int -1", -1, []byte{0x80, 0x04, 0x95, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a, 0xff, 0xff, 0xff, 0xff, 0x2e}}, + {"int -100", -100, []byte{0x80, 0x04, 0x95, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a, 0x9c, 0xff, 0xff, 0xff, 0x2e}}, + {"int MaxInt32", int64(math.MaxInt32), []byte{0x80, 0x04, 0x95, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a, 0xff, 0xff, 0xff, 0x7f, 0x2e}}, + {"int MinInt32", int64(math.MinInt32), []byte{0x80, 0x04, 0x95, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x80, 0x2e}}, + {"int MaxInt32+1", int64(math.MaxInt32) + 1, []byte{0x80, 0x04, 0x95, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x05, 0x00, 0x00, 0x00, 0x80, 0x00, 0x2e}}, + {"int MinInt32-1", int64(math.MinInt32) - 1, []byte{0x80, 0x04, 0x95, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x05, 0xff, 0xff, 0xff, 0x7f, 0xff, 0x2e}}, + {"int MaxInt64", int64(math.MaxInt64), []byte{0x80, 0x04, 0x95, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0x2e}}, + {"int MinInt64", int64(math.MinInt64), []byte{0x80, 0x04, 0x95, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x2e}}, + {"uint MaxUint64", uint64(math.MaxUint64), []byte{0x80, 0x04, 0x95, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x09, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x2e}}, + {"float 1.5", 1.5, []byte{0x80, 0x04, 0x95, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x3f, 0xf8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2e}}, + {"float 3.14", 3.14, []byte{0x80, 0x04, 0x95, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x40, 0x09, 0x1e, 0xb8, 0x51, 0xeb, 0x85, 0x1f, 0x2e}}, + {"string hello", "hello", []byte{0x80, 0x04, 0x95, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8c, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x94, 0x2e}}, + {"string empty", "", []byte{0x80, 0x04, 0x95, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8c, 0x00, 0x94, 0x2e}}, + {"bytes short", []byte("hello bytes"), []byte{0x80, 0x04, 0x95, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x43, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x62, 0x79, 0x74, 0x65, 0x73, 0x94, 0x2e}}, + {"bytes empty", []byte{}, []byte{0x80, 0x04, 0x95, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x43, 0x00, 0x94, 0x2e}}, +} + +// TestDictSerializeKeyGoldenBytes verifies Go → Python compatibility: +// dictSerializeKey produces bytes identical to Python's cloudpickle.dumps(value, protocol=4). +func TestDictSerializeKeyGoldenBytes(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + + for _, tc := range goldenCases { + t.Run(tc.name, func(t *testing.T) { + got, err := dictSerializeKey(tc.value) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(got).To(gomega.Equal(tc.bytes), "dictSerializeKey(%v) bytes mismatch", tc.value) + }) + } +} + +// TestDictDeserializeCloudpickleBytes verifies Python → Go compatibility: +// cloudpickle bytes can be deserialized correctly by og-rek's decoder. +func TestDictDeserializeCloudpickleBytes(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + + // og-rek decodes Python ints as int64, floats as float64. + // For ints beyond int64 range (e.g. MaxUint64), og-rek returns *big.Int. + // og-rek decodes None as pickle.None{}, not Go nil. + expected := map[string]any{ + "nil": pickle.None{}, + "true": true, + "false": false, + "int 0": int64(0), + "int 1": int64(1), + "int 42": int64(42), + "int 255": int64(255), + "int 256": int64(256), + "int 65535": int64(65535), + "int 65536": int64(65536), + "int -1": int64(-1), + "int -100": int64(-100), + "int MaxInt32": int64(math.MaxInt32), + "int MinInt32": int64(math.MinInt32), + "int MaxInt32+1": new(big.Int).SetInt64(int64(math.MaxInt32) + 1), + "int MinInt32-1": new(big.Int).SetInt64(int64(math.MinInt32) - 1), + "int MaxInt64": new(big.Int).SetInt64(math.MaxInt64), + "int MinInt64": new(big.Int).SetInt64(math.MinInt64), + "uint MaxUint64": new(big.Int).SetUint64(math.MaxUint64), + "float 1.5": float64(1.5), + "float 3.14": float64(3.14), + "string hello": "hello", + "string empty": "", + "bytes short": pickle.Bytes("hello bytes"), + "bytes empty": pickle.Bytes(""), + } + + for _, tc := range goldenCases { + exp, ok := expected[tc.name] + if !ok { + continue + } + t.Run(tc.name, func(t *testing.T) { + got, err := pickleDeserialize(tc.bytes) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(got).To(gomega.Equal(exp), "pickleDeserialize(%s) value mismatch", tc.name) + }) + } +} + +// TestDictSerializeKeyRoundTrip verifies Go → Go compatibility: +// serialize with dictSerializeKey, then deserialize with pickleDeserialize. +func TestDictSerializeKeyRoundTrip(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + + cases := []struct { + name string + input any + want any // expected value after round-trip (may differ due to type widening) + }{ + {"nil", nil, pickle.None{}}, + {"true", true, true}, + {"false", false, false}, + {"int 0", 0, int64(0)}, + {"int 42", 42, int64(42)}, + {"int -100", -100, int64(-100)}, + {"int MaxInt32", int64(math.MaxInt32), int64(math.MaxInt32)}, + {"int MinInt32", int64(math.MinInt32), int64(math.MinInt32)}, + {"int MaxInt32+1", int64(math.MaxInt32) + 1, new(big.Int).SetInt64(int64(math.MaxInt32) + 1)}, + {"int MinInt32-1", int64(math.MinInt32) - 1, new(big.Int).SetInt64(int64(math.MinInt32) - 1)}, + {"int MaxInt64", int64(math.MaxInt64), new(big.Int).SetInt64(math.MaxInt64)}, + {"int MinInt64", int64(math.MinInt64), new(big.Int).SetInt64(math.MinInt64)}, + {"uint MaxUint64", uint64(math.MaxUint64), new(big.Int).SetUint64(math.MaxUint64)}, + {"int8", int8(127), int64(127)}, + {"int16", int16(32000), int64(32000)}, + {"int32", int32(math.MaxInt32), int64(math.MaxInt32)}, + {"uint8", uint8(200), int64(200)}, + {"uint16", uint16(50000), int64(50000)}, + {"uint32", uint32(math.MaxUint32), new(big.Int).SetInt64(int64(math.MaxUint32))}, + {"float64", 3.14, float64(3.14)}, + {"float32", float32(1.5), float64(1.5)}, + {"string", "hello world", "hello world"}, + {"string empty", "", ""}, + {"bytes", []byte{1, 2, 3}, pickle.Bytes(string([]byte{1, 2, 3}))}, + {"bytes empty", []byte{}, pickle.Bytes("")}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + serialized, err := dictSerializeKey(tc.input) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + deserialized, err := pickleDeserialize(serialized) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(deserialized).To(gomega.Equal(tc.want), "round-trip mismatch for %s", tc.name) + }) + } +} + +// TestOgrekCloudpickleTransformRoundTrip verifies that cloudpickleToOgRek +// is the inverse of ogrekToCloudpickle for all supported types. +func TestOgrekCloudpickleTransformRoundTrip(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + + values := []any{ + nil, true, false, + 0, 42, -1, 255, 256, 65535, 65536, -100, + int64(math.MaxInt32), int64(math.MinInt32), + int64(math.MaxInt32) + 1, int64(math.MinInt32) - 1, + int64(math.MaxInt64), int64(math.MinInt64), + uint64(math.MaxUint64), + 1.5, 3.14, + "hello", "", + []byte{1, 2, 3}, []byte{}, + } + + for _, v := range values { + raw, err := dictOgRekP4Serialize(v) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + patched := ogrekToCloudpickle(raw) + stripped := cloudpickleToOgRek(patched) + + g.Expect(stripped).To(gomega.Equal(raw), "round-trip transform mismatch for %v", v) + } +} diff --git a/modal-go/test/dict_test.go b/modal-go/test/dict_test.go new file mode 100644 index 00000000..df413a0d --- /dev/null +++ b/modal-go/test/dict_test.go @@ -0,0 +1,395 @@ +package test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/modal-labs/libmodal/modal-go" + "github.com/modal-labs/libmodal/modal-go/internal/grpcmock" + pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" + "github.com/onsi/gomega" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestDictEphemeral(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + g.Expect(dict.Name).To(gomega.BeEmpty()) + + created, err := dict.Put(ctx, "key1", "value1", nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(created).To(gomega.BeTrue()) + + n, err := dict.Len(ctx) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(n).To(gomega.Equal(1)) + + val, found, err := dict.Get(ctx, "key1") + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + g.Expect(val).To(gomega.Equal("value1")) +} + +func TestDictPutAndGet(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + _, err = dict.Put(ctx, "hello", "world", nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + val, found, err := dict.Get(ctx, "hello") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + g.Expect(val).To(gomega.Equal("world")) + + // missing key + _, found, err = dict.Get(ctx, "missing") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeFalse()) +} + +func TestDictContains(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + _, err = dict.Put(ctx, "exists", 123, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + found, err := dict.Contains(ctx, "exists") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + + found, err = dict.Contains(ctx, "nope") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeFalse()) +} + +func TestDictLen(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + n, err := dict.Len(ctx) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(n).To(gomega.Equal(0)) + + _, err = dict.Put(ctx, "a", 1, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + _, err = dict.Put(ctx, "b", 2, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + _, err = dict.Put(ctx, "c", 3, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + n, err = dict.Len(ctx) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(n).To(gomega.Equal(3)) +} + +func TestDictPop(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + _, err = dict.Put(ctx, "key", "value", nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + val, found, err := dict.Pop(ctx, "key") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + g.Expect(val).To(gomega.Equal("value")) + + // key should be gone + _, found, err = dict.Get(ctx, "key") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeFalse()) + + // pop missing key + _, found, err = dict.Pop(ctx, "key") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeFalse()) +} + +func TestDictClear(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + _, err = dict.Put(ctx, "a", 1, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + _, err = dict.Put(ctx, "b", 2, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + err = dict.Clear(ctx) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + n, err := dict.Len(ctx) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(n).To(gomega.Equal(0)) +} + +func TestDictUpdate(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + err = dict.Update(ctx, map[any]any{ + "x": 10, + "y": 20, + "z": 30, + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + n, err := dict.Len(ctx) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(n).To(gomega.Equal(3)) + + val, found, err := dict.Get(ctx, "x") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + g.Expect(val).To(gomega.Equal(int64(10))) +} + +func TestDictPutSkipIfExists(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + created, err := dict.Put(ctx, "key", "first", nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(created).To(gomega.BeTrue()) + + created, err = dict.Put(ctx, "key", "second", &modal.DictPutParams{SkipIfExists: true}) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(created).To(gomega.BeFalse()) + + // original value should be preserved + val, found, err := dict.Get(ctx, "key") + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + g.Expect(val).To(gomega.Equal("first")) +} + +func TestDictNonEphemeral(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dictName := "test-dict-" + strconv.FormatInt(time.Now().UnixNano(), 10) + dict1, err := tc.Dicts.FromName(ctx, dictName, &modal.DictFromNameParams{CreateIfMissing: true}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(dict1.Name).To(gomega.Equal(dictName)) + + defer func() { + err := tc.Dicts.Delete(ctx, dictName, nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + _, err = tc.Dicts.FromName(ctx, dictName, nil) + g.Expect(err).Should(gomega.HaveOccurred()) + }() + + _, err = dict1.Put(ctx, "data-key", "data-value", nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + dict2, err := tc.Dicts.FromName(ctx, dictName, nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + val, found, err := dict2.Get(ctx, "data-key") + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(found).To(gomega.BeTrue()) + g.Expect(val).To(gomega.Equal("data-value")) +} + +func TestDictKeysValuesItems(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + tc := newTestClient(t) + + dict, err := tc.Dicts.Ephemeral(ctx, nil) + g.Expect(err).ToNot(gomega.HaveOccurred()) + defer dict.CloseEphemeral() + + err = dict.Update(ctx, map[any]any{ + "a": int64(1), + "b": int64(2), + "c": int64(3), + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + // Keys + keys := make(map[any]bool) + for k, err := range dict.Keys(ctx) { + g.Expect(err).ToNot(gomega.HaveOccurred()) + keys[k] = true + } + g.Expect(keys).To(gomega.HaveLen(3)) + g.Expect(keys).To(gomega.HaveKey("a")) + g.Expect(keys).To(gomega.HaveKey("b")) + g.Expect(keys).To(gomega.HaveKey("c")) + + // Values + values := make([]any, 0, 3) + for v, err := range dict.Values(ctx) { + g.Expect(err).ToNot(gomega.HaveOccurred()) + values = append(values, v) + } + g.Expect(values).To(gomega.HaveLen(3)) + g.Expect(values).To(gomega.ContainElements(int64(1), int64(2), int64(3))) + + // Items + items := make(map[any]any) + for item, err := range dict.Items(ctx) { + g.Expect(err).ToNot(gomega.HaveOccurred()) + items[item.Key] = item.Value + } + g.Expect(items).To(gomega.HaveLen(3)) + g.Expect(items["a"]).To(gomega.Equal(int64(1))) + g.Expect(items["b"]).To(gomega.Equal(int64(2))) + g.Expect(items["c"]).To(gomega.Equal(int64(3))) +} + +func TestDictDeleteSuccess(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + + mock := newGRPCMockClient(t) + + grpcmock.HandleUnary( + mock, "/DictGetOrCreate", + func(req *pb.DictGetOrCreateRequest) (*pb.DictGetOrCreateResponse, error) { + return pb.DictGetOrCreateResponse_builder{ + DictId: "di-test-123", + }.Build(), nil + }, + ) + + grpcmock.HandleUnary( + mock, "/DictDelete", + func(req *pb.DictDeleteRequest) (*emptypb.Empty, error) { + g.Expect(req.GetDictId()).To(gomega.Equal("di-test-123")) + return &emptypb.Empty{}, nil + }, + ) + + err := mock.Dicts.Delete(ctx, "test-dict", nil) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + g.Expect(mock.AssertExhausted()).ShouldNot(gomega.HaveOccurred()) +} + +func TestDictDeleteWithAllowMissing(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + + mock := newGRPCMockClient(t) + + grpcmock.HandleUnary( + mock, "/DictGetOrCreate", + func(req *pb.DictGetOrCreateRequest) (*pb.DictGetOrCreateResponse, error) { + return nil, modal.NotFoundError{Exception: "Dict 'missing' not found"} + }, + ) + + err := mock.Dicts.Delete(ctx, "missing", &modal.DictDeleteParams{ + AllowMissing: true, + }) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + + g.Expect(mock.AssertExhausted()).ShouldNot(gomega.HaveOccurred()) +} + +func TestDictDeleteWithAllowMissingDeleteRPCNotFound(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + + mock := newGRPCMockClient(t) + + grpcmock.HandleUnary(mock, "/DictGetOrCreate", + func(req *pb.DictGetOrCreateRequest) (*pb.DictGetOrCreateResponse, error) { + return pb.DictGetOrCreateResponse_builder{DictId: "di-test-123"}.Build(), nil + }, + ) + + grpcmock.HandleUnary(mock, "/DictDelete", + func(req *pb.DictDeleteRequest) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.NotFound, "Dict not found") + }, + ) + + err := mock.Dicts.Delete(ctx, "test-dict", &modal.DictDeleteParams{AllowMissing: true}) + g.Expect(err).ShouldNot(gomega.HaveOccurred()) + g.Expect(mock.AssertExhausted()).ShouldNot(gomega.HaveOccurred()) +} + +func TestDictDeleteWithAllowMissingFalseThrows(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + + mock := newGRPCMockClient(t) + + grpcmock.HandleUnary( + mock, "/DictGetOrCreate", + func(req *pb.DictGetOrCreateRequest) (*pb.DictGetOrCreateResponse, error) { + return nil, modal.NotFoundError{Exception: "Dict 'missing' not found"} + }, + ) + + err := mock.Dicts.Delete(ctx, "missing", &modal.DictDeleteParams{ + AllowMissing: false, + }) + g.Expect(err).Should(gomega.HaveOccurred()) + + g.Expect(mock.AssertExhausted()).ShouldNot(gomega.HaveOccurred()) +}