Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 31fbd0b667 |
@@ -96,13 +96,13 @@ jobs:
|
||||
repo,
|
||||
per_page: 100
|
||||
});
|
||||
|
||||
|
||||
release = releases.data.find(r => r.draft && r.tag_name === tagName);
|
||||
if (!release) {
|
||||
throw new Error(`No release found with tag ${tagName}`);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
await github.rest.repos.updateRelease({
|
||||
owner,
|
||||
repo,
|
||||
@@ -110,10 +110,10 @@ jobs:
|
||||
draft: false,
|
||||
prerelease: release.prerelease
|
||||
});
|
||||
|
||||
|
||||
const status = release.draft ? "was draft" : "was already published";
|
||||
core.info(`Release ${tagName} ensured to be published (${status}).`);
|
||||
|
||||
|
||||
} catch (error) {
|
||||
core.warning(`Could not find or update release for tag ${tagName}: ${error.message}`);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
rpc:
|
||||
# The IP address where this RPC service registers itself; if left blank, it defaults to the internal network IP
|
||||
registerIP:
|
||||
registerIP:
|
||||
# IP address that the RPC service listens on; setting to 0.0.0.0 listens on both internal and external IPs. If left blank, it automatically uses the internal network IP
|
||||
listenIP: 0.0.0.0
|
||||
# autoSetPorts indicates whether to automatically set the ports
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/openimsdk/open-im-server/v3
|
||||
|
||||
go 1.22.7
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
firebase.google.com/go/v4 v4.14.1
|
||||
@@ -27,7 +27,6 @@ require (
|
||||
require github.com/google/uuid v1.6.0
|
||||
|
||||
require (
|
||||
github.com/IBM/sarama v1.43.0
|
||||
github.com/fatih/color v1.14.1
|
||||
github.com/gin-contrib/gzip v1.0.1
|
||||
github.com/go-redis/redis v6.15.9+incompatible
|
||||
@@ -42,7 +41,7 @@ require (
|
||||
github.com/spf13/viper v1.18.2
|
||||
go.etcd.io/etcd/client/v3 v3.5.13
|
||||
go.uber.org/automaxprocs v1.5.3
|
||||
golang.org/x/sync v0.10.0
|
||||
golang.org/x/sync v0.18.0
|
||||
k8s.io/api v0.31.2
|
||||
k8s.io/apimachinery v0.31.2
|
||||
k8s.io/client-go v0.31.2
|
||||
@@ -55,6 +54,7 @@ require (
|
||||
cloud.google.com/go/iam v1.1.7 // indirect
|
||||
cloud.google.com/go/longrunning v0.5.5 // indirect
|
||||
cloud.google.com/go/storage v1.40.0 // indirect
|
||||
github.com/IBM/sarama v1.43.0 // indirect
|
||||
github.com/MicahParks/keyfunc v1.9.0 // indirect
|
||||
github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.32.5 // indirect
|
||||
@@ -194,11 +194,11 @@ require (
|
||||
golang.org/x/arch v0.7.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/image v0.15.0 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/oauth2 v0.25.0 // indirect
|
||||
golang.org/x/sys v0.29.0 // indirect
|
||||
golang.org/x/term v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.37.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
google.golang.org/appengine/v2 v2.0.2 // indirect
|
||||
google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 // indirect
|
||||
@@ -222,6 +222,6 @@ require (
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
go.uber.org/zap v1.24.0 // indirect
|
||||
golang.org/x/crypto v0.32.0 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
@@ -510,8 +510,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
||||
@@ -539,8 +539,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
|
||||
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
|
||||
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
@@ -550,8 +550,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -567,14 +567,14 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
|
||||
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
@@ -583,8 +583,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -596,8 +596,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@@ -290,7 +290,6 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, cf
|
||||
conversationGroup.POST("/get_not_notify_conversation_ids", c.GetNotNotifyConversationIDs)
|
||||
conversationGroup.POST("/get_pinned_conversation_ids", c.GetPinnedConversationIDs)
|
||||
conversationGroup.POST("/delete_conversations", c.DeleteConversations)
|
||||
conversationGroup.POST("/update_conversations_by_user", c.UpdateConversationsByUser)
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
+151
-10
@@ -16,6 +16,7 @@ package msggateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -30,6 +31,7 @@ import (
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"github.com/openimsdk/tools/log"
|
||||
"github.com/openimsdk/tools/mcontext"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -62,7 +64,7 @@ type PingPongHandler func(string) error
|
||||
|
||||
type Client struct {
|
||||
w *sync.Mutex
|
||||
conn ClientConn
|
||||
conn LongConn
|
||||
PlatformID int `json:"platformID"`
|
||||
IsCompress bool `json:"isCompress"`
|
||||
UserID string `json:"userID"`
|
||||
@@ -82,10 +84,10 @@ type Client struct {
|
||||
}
|
||||
|
||||
// ResetClient updates the client's state with new connection and context information.
|
||||
func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServer LongConnServer) {
|
||||
func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, longConnServer LongConnServer) {
|
||||
c.w = new(sync.Mutex)
|
||||
c.conn = conn
|
||||
c.PlatformID = ctx.GetPlatformID()
|
||||
c.PlatformID = stringutil.StringToInt(ctx.GetPlatformID())
|
||||
c.IsCompress = ctx.GetCompression()
|
||||
c.IsBackground = ctx.GetBackground()
|
||||
c.UserID = ctx.GetUserID()
|
||||
@@ -110,6 +112,22 @@ func (c *Client) ResetClient(ctx *UserConnContext, conn ClientConn, longConnServ
|
||||
c.subUserIDs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
func (c *Client) pingHandler(appData string) error {
|
||||
if err := c.conn.SetReadDeadline(pongWait); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.ZDebug(c.ctx, "ping Handler Success.", "appData", appData)
|
||||
return c.writePongMsg(appData)
|
||||
}
|
||||
|
||||
func (c *Client) pongHandler(_ string) error {
|
||||
if err := c.conn.SetReadDeadline(pongWait); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readMessage continuously reads messages from the connection.
|
||||
func (c *Client) readMessage() {
|
||||
defer func() {
|
||||
@@ -120,25 +138,52 @@ func (c *Client) readMessage() {
|
||||
c.close()
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
_ = c.conn.SetReadDeadline(pongWait)
|
||||
c.conn.SetPongHandler(c.pongHandler)
|
||||
c.conn.SetPingHandler(c.pingHandler)
|
||||
c.activeHeartbeat(c.hbCtx)
|
||||
|
||||
for {
|
||||
log.ZDebug(c.ctx, "readMessage")
|
||||
message, returnErr := c.conn.ReadMessage()
|
||||
messageType, message, returnErr := c.conn.ReadMessage()
|
||||
if returnErr != nil {
|
||||
log.ZWarn(c.ctx, "readMessage", returnErr)
|
||||
log.ZWarn(c.ctx, "readMessage", returnErr, "messageType", messageType)
|
||||
c.closedErr = returnErr
|
||||
return
|
||||
}
|
||||
|
||||
log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
|
||||
if c.closed.Load() {
|
||||
// The scenario where the connection has just been closed, but the coroutine has not exited
|
||||
c.closedErr = ErrConnClosed
|
||||
return
|
||||
}
|
||||
|
||||
parseDataErr := c.handleMessage(message)
|
||||
if parseDataErr != nil {
|
||||
c.closedErr = parseDataErr
|
||||
switch messageType {
|
||||
case MessageBinary:
|
||||
_ = c.conn.SetReadDeadline(pongWait)
|
||||
parseDataErr := c.handleMessage(message)
|
||||
if parseDataErr != nil {
|
||||
c.closedErr = parseDataErr
|
||||
return
|
||||
}
|
||||
case MessageText:
|
||||
_ = c.conn.SetReadDeadline(pongWait)
|
||||
parseDataErr := c.handlerTextMessage(message)
|
||||
if parseDataErr != nil {
|
||||
c.closedErr = parseDataErr
|
||||
return
|
||||
}
|
||||
case PingMessage:
|
||||
err := c.writePongMsg("")
|
||||
log.ZError(c.ctx, "writePongMsg", err)
|
||||
|
||||
case CloseMessage:
|
||||
c.closedErr = ErrClientClosed
|
||||
return
|
||||
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -313,13 +358,109 @@ func (c *Client) writeBinaryMsg(resp Resp) error {
|
||||
c.w.Lock()
|
||||
defer c.w.Unlock()
|
||||
|
||||
err = c.conn.SetWriteDeadline(writeWait)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.IsCompress {
|
||||
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
|
||||
if compressErr != nil {
|
||||
return compressErr
|
||||
}
|
||||
return c.conn.WriteMessage(resultBuf)
|
||||
return c.conn.WriteMessage(MessageBinary, resultBuf)
|
||||
}
|
||||
|
||||
return c.conn.WriteMessage(encodedBuf)
|
||||
return c.conn.WriteMessage(MessageBinary, encodedBuf)
|
||||
}
|
||||
|
||||
// Actively initiate Heartbeat when platform in Web.
|
||||
func (c *Client) activeHeartbeat(ctx context.Context) {
|
||||
if c.PlatformID == constant.WebPlatformID {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.ZPanic(ctx, "activeHeartbeat Panic", errs.ErrPanic(r))
|
||||
}
|
||||
}()
|
||||
log.ZDebug(ctx, "server initiative send heartbeat start.")
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.writePingMsg(); err != nil {
|
||||
log.ZWarn(c.ctx, "send Ping Message error.", err)
|
||||
return
|
||||
}
|
||||
case <-c.hbCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
func (c *Client) writePingMsg() error {
|
||||
if c.closed.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.w.Lock()
|
||||
defer c.w.Unlock()
|
||||
|
||||
err := c.conn.SetWriteDeadline(writeWait)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.conn.WriteMessage(PingMessage, nil)
|
||||
}
|
||||
|
||||
func (c *Client) writePongMsg(appData string) error {
|
||||
log.ZDebug(c.ctx, "write Pong Msg in Server", "appData", appData)
|
||||
if c.closed.Load() {
|
||||
log.ZWarn(c.ctx, "is closed in server", nil, "appdata", appData, "closed err", c.closedErr)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.w.Lock()
|
||||
defer c.w.Unlock()
|
||||
|
||||
err := c.conn.SetWriteDeadline(writeWait)
|
||||
if err != nil {
|
||||
log.ZWarn(c.ctx, "SetWriteDeadline in Server have error", errs.Wrap(err), "writeWait", writeWait, "appData", appData)
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
err = c.conn.WriteMessage(PongMessage, []byte(appData))
|
||||
if err != nil {
|
||||
log.ZWarn(c.ctx, "Write Message have error", errs.Wrap(err), "Pong msg", PongMessage)
|
||||
}
|
||||
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
func (c *Client) handlerTextMessage(b []byte) error {
|
||||
var msg TextMessage
|
||||
if err := json.Unmarshal(b, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
switch msg.Type {
|
||||
case TextPong:
|
||||
return nil
|
||||
case TextPing:
|
||||
msg.Type = TextPong
|
||||
msgData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.w.Lock()
|
||||
defer c.w.Unlock()
|
||||
if err := c.conn.SetWriteDeadline(writeWait); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.conn.WriteMessage(MessageText, msgData)
|
||||
default:
|
||||
return fmt.Errorf("not support message type %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
package msggateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/openimsdk/tools/log"
|
||||
)
|
||||
|
||||
var ErrWriteFull = fmt.Errorf("websocket write buffer full,close connection")
|
||||
|
||||
type ClientConn interface {
|
||||
ReadMessage() ([]byte, error)
|
||||
WriteMessage(message []byte) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type websocketMessage struct {
|
||||
MessageType int
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func NewWebSocketClientConn(conn *websocket.Conn, readLimit int64, readTimeout time.Duration, pingInterval time.Duration) ClientConn {
|
||||
c := &websocketClientConn{
|
||||
readTimeout: readTimeout,
|
||||
conn: conn,
|
||||
writer: make(chan *websocketMessage, 256),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
if readLimit > 0 {
|
||||
c.conn.SetReadLimit(readLimit)
|
||||
}
|
||||
c.conn.SetPingHandler(c.pingHandler)
|
||||
c.conn.SetPongHandler(c.pongHandler)
|
||||
|
||||
go c.loopSend()
|
||||
if pingInterval > 0 {
|
||||
go c.doPing(pingInterval)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type websocketClientConn struct {
|
||||
readTimeout time.Duration
|
||||
conn *websocket.Conn
|
||||
writer chan *websocketMessage
|
||||
done chan struct{}
|
||||
err atomic.Pointer[error]
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) ReadMessage() ([]byte, error) {
|
||||
buf, err := c.readMessage()
|
||||
if err != nil {
|
||||
return nil, c.closeBy(fmt.Errorf("read message %w", err))
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) WriteMessage(message []byte) error {
|
||||
return c.writeMessage(websocket.BinaryMessage, message)
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) Close() error {
|
||||
return c.closeBy(fmt.Errorf("websocket connection closed"))
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) closeBy(err error) error {
|
||||
if !c.err.CompareAndSwap(nil, &err) {
|
||||
return *c.err.Load()
|
||||
}
|
||||
close(c.done)
|
||||
log.ZWarn(context.Background(), "websocket connection closed", err, "remoteAddr", c.conn.RemoteAddr(),
|
||||
"chan length", len(c.writer))
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) writeMessage(messageType int, data []byte) error {
|
||||
if errPtr := c.err.Load(); errPtr != nil {
|
||||
return *errPtr
|
||||
}
|
||||
select {
|
||||
case c.writer <- &websocketMessage{MessageType: messageType, Data: data}:
|
||||
return nil
|
||||
default:
|
||||
return c.closeBy(ErrWriteFull)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) loopSend() {
|
||||
defer func() {
|
||||
_ = c.conn.Close()
|
||||
}()
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
for {
|
||||
select {
|
||||
case msg := <-c.writer:
|
||||
switch msg.MessageType {
|
||||
case websocket.TextMessage, websocket.BinaryMessage:
|
||||
err = c.conn.WriteMessage(msg.MessageType, msg.Data)
|
||||
default:
|
||||
err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{})
|
||||
}
|
||||
if err != nil {
|
||||
_ = c.closeBy(err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case msg := <-c.writer:
|
||||
switch msg.MessageType {
|
||||
case websocket.TextMessage, websocket.BinaryMessage:
|
||||
err = c.conn.WriteMessage(msg.MessageType, msg.Data)
|
||||
default:
|
||||
err = c.conn.WriteControl(msg.MessageType, msg.Data, time.Time{})
|
||||
}
|
||||
if err != nil {
|
||||
_ = c.closeBy(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) setReadDeadline() error {
|
||||
deadline := time.Now().Add(c.readTimeout)
|
||||
return c.conn.SetReadDeadline(deadline)
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) readMessage() ([]byte, error) {
|
||||
for {
|
||||
if err := c.setReadDeadline(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
messageType, buf, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch messageType {
|
||||
case websocket.BinaryMessage:
|
||||
return buf, nil
|
||||
case websocket.TextMessage:
|
||||
if err := c.onReadTextMessage(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case websocket.PingMessage:
|
||||
if err := c.pingHandler(string(buf)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case websocket.PongMessage:
|
||||
if err := c.pongHandler(string(buf)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case websocket.CloseMessage:
|
||||
if len(buf) == 0 {
|
||||
return nil, errors.New("websocket connection closed by peer")
|
||||
}
|
||||
return nil, fmt.Errorf("websocket connection closed by peer, data %s", string(buf))
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown websocket message type %d", messageType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) onReadTextMessage(buf []byte) error {
|
||||
var msg struct {
|
||||
Type string `json:"type"`
|
||||
Body json.RawMessage `json:"body"`
|
||||
}
|
||||
if err := json.Unmarshal(buf, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
switch msg.Type {
|
||||
case TextPong:
|
||||
return nil
|
||||
case TextPing:
|
||||
msg.Type = TextPong
|
||||
msgData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.writeMessage(websocket.TextMessage, msgData)
|
||||
default:
|
||||
return fmt.Errorf("not support text message type %s", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) pingHandler(appData string) error {
|
||||
log.ZDebug(context.Background(), "ping handler recv ping", "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
||||
if err := c.setReadDeadline(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := c.conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(time.Second*1))
|
||||
if err != nil {
|
||||
log.ZWarn(context.Background(), "ping handler write pong error", err, "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
||||
}
|
||||
log.ZDebug(context.Background(), "ping handler write pong success", "remoteAddr", c.conn.RemoteAddr(), "appData", appData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) pongHandler(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketClientConn) doPing(d time.Duration) {
|
||||
ticker := time.NewTicker(d)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := c.writeMessage(websocket.PingMessage, nil); err != nil {
|
||||
_ = c.closeBy(fmt.Errorf("send ping %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+88
-145
@@ -15,8 +15,6 @@
|
||||
package msggateway
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -26,21 +24,10 @@ import (
|
||||
|
||||
"github.com/openimsdk/protocol/constant"
|
||||
"github.com/openimsdk/tools/utils/encrypt"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
"github.com/openimsdk/tools/utils/timeutil"
|
||||
)
|
||||
|
||||
type UserConnContextInfo struct {
|
||||
Token string `json:"token"`
|
||||
UserID string `json:"userID"`
|
||||
PlatformID int `json:"platformID"`
|
||||
OperationID string `json:"operationID"`
|
||||
Compression string `json:"compression"`
|
||||
SDKType string `json:"sdkType"`
|
||||
SendResponse bool `json:"sendResponse"`
|
||||
Background bool `json:"background"`
|
||||
SDKVersion string `json:"sdkVersion"`
|
||||
}
|
||||
|
||||
type UserConnContext struct {
|
||||
RespWriter http.ResponseWriter
|
||||
Req *http.Request
|
||||
@@ -48,7 +35,6 @@ type UserConnContext struct {
|
||||
Method string
|
||||
RemoteAddr string
|
||||
ConnID string
|
||||
info *UserConnContextInfo
|
||||
}
|
||||
|
||||
func (c *UserConnContext) Deadline() (deadline time.Time, ok bool) {
|
||||
@@ -72,11 +58,9 @@ func (c *UserConnContext) Value(key any) any {
|
||||
case constant.ConnID:
|
||||
return c.GetConnID()
|
||||
case constant.OpUserPlatform:
|
||||
return c.GetPlatformID()
|
||||
return constant.PlatformIDToName(stringutil.StringToInt(c.GetPlatformID()))
|
||||
case constant.RemoteAddr:
|
||||
return c.RemoteAddr
|
||||
case SDKVersion:
|
||||
return c.info.SDKVersion
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
@@ -99,92 +83,30 @@ func newContext(respWriter http.ResponseWriter, req *http.Request) *UserConnCont
|
||||
|
||||
func newTempContext() *UserConnContext {
|
||||
return &UserConnContext{
|
||||
Req: &http.Request{URL: &url.URL{}},
|
||||
info: &UserConnContextInfo{},
|
||||
Req: &http.Request{URL: &url.URL{}},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UserConnContext) ParseEssentialArgs() error {
|
||||
query := c.Req.URL.Query()
|
||||
if data := query.Get("v"); data != "" {
|
||||
return c.parseByJson(data)
|
||||
} else {
|
||||
return c.parseByQuery(query, c.Req.Header)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UserConnContext) parseByQuery(query url.Values, header http.Header) error {
|
||||
info := UserConnContextInfo{
|
||||
Token: query.Get(Token),
|
||||
UserID: query.Get(WsUserID),
|
||||
OperationID: query.Get(OperationID),
|
||||
Compression: query.Get(Compression),
|
||||
SDKType: query.Get(SDKType),
|
||||
SDKVersion: query.Get(SDKVersion),
|
||||
}
|
||||
platformID, err := strconv.Atoi(query.Get(PlatformID))
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
||||
}
|
||||
info.PlatformID = platformID
|
||||
if val := query.Get(SendResponse); val != "" {
|
||||
ok, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("isMsgResp is not bool")
|
||||
}
|
||||
info.SendResponse = ok
|
||||
}
|
||||
if info.Compression == "" {
|
||||
info.Compression = header.Get(Compression)
|
||||
}
|
||||
background, err := strconv.ParseBool(query.Get(BackgroundStatus))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info.Background = background
|
||||
return c.checkInfo(&info)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) parseByJson(data string) error {
|
||||
reqInfo, err := base64.RawURLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("data is not base64")
|
||||
}
|
||||
var info UserConnContextInfo
|
||||
if err := json.Unmarshal(reqInfo, &info); err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("data is not json", "info", err.Error())
|
||||
}
|
||||
return c.checkInfo(&info)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) checkInfo(info *UserConnContextInfo) error {
|
||||
if info.OperationID == "" {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("operationID is empty")
|
||||
}
|
||||
if info.Token == "" {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
||||
}
|
||||
if info.UserID == "" {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
||||
}
|
||||
if _, ok := constant.PlatformID2Name[info.PlatformID]; !ok {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is invalid")
|
||||
}
|
||||
switch info.SDKType {
|
||||
case "":
|
||||
info.SDKType = GoSDK
|
||||
case GoSDK, JsSDK:
|
||||
default:
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is invalid")
|
||||
}
|
||||
c.info = info
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetRemoteAddr() string {
|
||||
return c.RemoteAddr
|
||||
}
|
||||
|
||||
func (c *UserConnContext) Query(key string) (string, bool) {
|
||||
var value string
|
||||
if value = c.Req.URL.Query().Get(key); value == "" {
|
||||
return value, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetHeader(key string) (string, bool) {
|
||||
var value string
|
||||
if value = c.Req.Header.Get(key); value == "" {
|
||||
return value, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func (c *UserConnContext) SetHeader(key, value string) {
|
||||
c.RespWriter.Header().Set(key, value)
|
||||
}
|
||||
@@ -198,76 +120,97 @@ func (c *UserConnContext) GetConnID() string {
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetUserID() string {
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
}
|
||||
return c.info.UserID
|
||||
return c.Req.URL.Query().Get(WsUserID)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetPlatformID() int {
|
||||
if c == nil || c.info == nil {
|
||||
return 0
|
||||
}
|
||||
return c.info.PlatformID
|
||||
func (c *UserConnContext) GetPlatformID() string {
|
||||
return c.Req.URL.Query().Get(PlatformID)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetOperationID() string {
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
}
|
||||
return c.info.OperationID
|
||||
return c.Req.URL.Query().Get(OperationID)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) SetOperationID(operationID string) {
|
||||
if c.info == nil {
|
||||
c.info = &UserConnContextInfo{}
|
||||
}
|
||||
c.info.OperationID = operationID
|
||||
values := c.Req.URL.Query()
|
||||
values.Set(OperationID, operationID)
|
||||
c.Req.URL.RawQuery = values.Encode()
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetToken() string {
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
}
|
||||
return c.info.Token
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetCompression() bool {
|
||||
return c != nil && c.info != nil && c.info.Compression == GzipCompressionProtocol
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetSDKType() string {
|
||||
if c == nil || c.info == nil {
|
||||
return GoSDK
|
||||
}
|
||||
switch c.info.SDKType {
|
||||
case "", GoSDK:
|
||||
return GoSDK
|
||||
case JsSDK:
|
||||
return JsSDK
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
return c.Req.URL.Query().Get(Token)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetSDKVersion() string {
|
||||
if c == nil || c.info == nil {
|
||||
return ""
|
||||
return c.Req.URL.Query().Get(SDKVersion)
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetCompression() bool {
|
||||
compression, exists := c.Query(Compression)
|
||||
if exists && compression == GzipCompressionProtocol {
|
||||
return true
|
||||
} else {
|
||||
compression, exists := c.GetHeader(Compression)
|
||||
if exists && compression == GzipCompressionProtocol {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return c.info.SDKVersion
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetSDKType() string {
|
||||
sdkType := c.Req.URL.Query().Get(SDKType)
|
||||
if sdkType == "" {
|
||||
sdkType = GoSDK
|
||||
}
|
||||
return sdkType
|
||||
}
|
||||
|
||||
func (c *UserConnContext) ShouldSendResp() bool {
|
||||
return c != nil && c.info != nil && c.info.SendResponse
|
||||
errResp, exists := c.Query(SendResponse)
|
||||
if exists {
|
||||
b, err := strconv.ParseBool(errResp)
|
||||
if err != nil {
|
||||
return false
|
||||
} else {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *UserConnContext) SetToken(token string) {
|
||||
if c.info == nil {
|
||||
c.info = &UserConnContextInfo{}
|
||||
}
|
||||
c.info.Token = token
|
||||
c.Req.URL.RawQuery = Token + "=" + token
|
||||
}
|
||||
|
||||
func (c *UserConnContext) GetBackground() bool {
|
||||
return c != nil && c.info != nil && c.info.Background
|
||||
b, err := strconv.ParseBool(c.Req.URL.Query().Get(BackgroundStatus))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return b
|
||||
}
|
||||
func (c *UserConnContext) ParseEssentialArgs() error {
|
||||
_, exists := c.Query(Token)
|
||||
if !exists {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("token is empty")
|
||||
}
|
||||
_, exists = c.Query(WsUserID)
|
||||
if !exists {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sendID is empty")
|
||||
}
|
||||
platformIDStr, exists := c.Query(PlatformID)
|
||||
if !exists {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is empty")
|
||||
}
|
||||
_, err := strconv.Atoi(platformIDStr)
|
||||
if err != nil {
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("platformID is not int")
|
||||
}
|
||||
switch sdkType, _ := c.Query(SDKType); sdkType {
|
||||
case "", GoSDK, JsSDK:
|
||||
default:
|
||||
return servererrs.ErrConnArgsErr.WrapMsg("sdkType is not go or js")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
// Copyright © 2023 OpenIM. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msggateway
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/openimsdk/tools/apiresp"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
)
|
||||
|
||||
type LongConn interface {
|
||||
// Close this connection
|
||||
Close() error
|
||||
// WriteMessage Write message to connection,messageType means data type,can be set binary(2) and text(1).
|
||||
WriteMessage(messageType int, message []byte) error
|
||||
// ReadMessage Read message from connection.
|
||||
ReadMessage() (int, []byte, error)
|
||||
// SetReadDeadline sets the read deadline on the underlying network connection,
|
||||
// after a read has timed out, will return an error.
|
||||
SetReadDeadline(timeout time.Duration) error
|
||||
// SetWriteDeadline sets to write deadline when send message,when read has timed out,will return error.
|
||||
SetWriteDeadline(timeout time.Duration) error
|
||||
// Dial Try to dial a connection,url must set auth args,header can control compress data
|
||||
Dial(urlStr string, requestHeader http.Header) (*http.Response, error)
|
||||
// IsNil Whether the connection of the current long connection is nil
|
||||
IsNil() bool
|
||||
// SetConnNil Set the connection of the current long connection to nil
|
||||
SetConnNil()
|
||||
// SetReadLimit sets the maximum size for a message read from the peer.bytes
|
||||
SetReadLimit(limit int64)
|
||||
SetPongHandler(handler PingPongHandler)
|
||||
SetPingHandler(handler PingPongHandler)
|
||||
// GenerateLongConn Check the connection of the current and when it was sent are the same
|
||||
GenerateLongConn(w http.ResponseWriter, r *http.Request) error
|
||||
}
|
||||
type GWebSocket struct {
|
||||
protocolType int
|
||||
conn *websocket.Conn
|
||||
handshakeTimeout time.Duration
|
||||
writeBufferSize int
|
||||
}
|
||||
|
||||
func newGWebSocket(protocolType int, handshakeTimeout time.Duration, wbs int) *GWebSocket {
|
||||
return &GWebSocket{protocolType: protocolType, handshakeTimeout: handshakeTimeout, writeBufferSize: wbs}
|
||||
}
|
||||
|
||||
func (d *GWebSocket) Close() error {
|
||||
return d.conn.Close()
|
||||
}
|
||||
|
||||
func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) error {
|
||||
upgrader := &websocket.Upgrader{
|
||||
HandshakeTimeout: d.handshakeTimeout,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
if d.writeBufferSize > 0 { // default is 4kb.
|
||||
upgrader.WriteBufferSize = d.writeBufferSize
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
// The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
|
||||
return errs.WrapMsg(err, "GenerateLongConn: WebSocket upgrade failed")
|
||||
}
|
||||
d.conn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *GWebSocket) WriteMessage(messageType int, message []byte) error {
|
||||
// d.setSendConn(d.conn)
|
||||
return d.conn.WriteMessage(messageType, message)
|
||||
}
|
||||
|
||||
// func (d *GWebSocket) setSendConn(sendConn *websocket.Conn) {
|
||||
// d.sendConn = sendConn
|
||||
//}
|
||||
|
||||
func (d *GWebSocket) ReadMessage() (int, []byte, error) {
|
||||
return d.conn.ReadMessage()
|
||||
}
|
||||
|
||||
func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
|
||||
return d.conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
|
||||
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
return errs.New("timeout must be greater than 0")
|
||||
}
|
||||
|
||||
// TODO SetWriteDeadline Future add error handling
|
||||
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return errs.WrapMsg(err, "GWebSocket.SetWriteDeadline failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {
|
||||
conn, httpResp, err := websocket.DefaultDialer.Dial(urlStr, requestHeader)
|
||||
if err != nil {
|
||||
return httpResp, errs.WrapMsg(err, "GWebSocket.Dial failed", "url", urlStr)
|
||||
}
|
||||
d.conn = conn
|
||||
return httpResp, nil
|
||||
}
|
||||
|
||||
func (d *GWebSocket) IsNil() bool {
|
||||
return d.conn == nil
|
||||
//
|
||||
// if d.conn != nil {
|
||||
// return false
|
||||
// }
|
||||
// return true
|
||||
}
|
||||
|
||||
func (d *GWebSocket) SetConnNil() {
|
||||
d.conn = nil
|
||||
}
|
||||
|
||||
func (d *GWebSocket) SetReadLimit(limit int64) {
|
||||
d.conn.SetReadLimit(limit)
|
||||
}
|
||||
|
||||
func (d *GWebSocket) SetPongHandler(handler PingPongHandler) {
|
||||
d.conn.SetPongHandler(handler)
|
||||
}
|
||||
|
||||
func (d *GWebSocket) SetPingHandler(handler PingPongHandler) {
|
||||
d.conn.SetPingHandler(handler)
|
||||
}
|
||||
|
||||
func (d *GWebSocket) RespondWithError(err error, w http.ResponseWriter, r *http.Request) error {
|
||||
if err := d.GenerateLongConn(w, r); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(apiresp.ParseError(err))
|
||||
if err != nil {
|
||||
_ = d.Close()
|
||||
return errs.WrapMsg(err, "json marshal failed")
|
||||
}
|
||||
|
||||
if err := d.WriteMessage(MessageText, data); err != nil {
|
||||
_ = d.Close()
|
||||
return errs.WrapMsg(err, "WriteMessage failed")
|
||||
}
|
||||
_ = d.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *GWebSocket) RespondWithSuccess() error {
|
||||
data, err := json.Marshal(apiresp.ParseError(nil))
|
||||
if err != nil {
|
||||
_ = d.Close()
|
||||
return errs.WrapMsg(err, "json marshal failed")
|
||||
}
|
||||
|
||||
if err := d.WriteMessage(MessageText, data); err != nil {
|
||||
_ = d.Close()
|
||||
return errs.WrapMsg(err, "WriteMessage failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -2,20 +2,18 @@ package msggateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
|
||||
"github.com/openimsdk/tools/apiresp"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/webhook"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/rpccache"
|
||||
pbAuth "github.com/openimsdk/protocol/auth"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"github.com/openimsdk/tools/mcontext"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
@@ -25,11 +23,10 @@ import (
|
||||
"github.com/openimsdk/protocol/msggateway"
|
||||
"github.com/openimsdk/tools/discovery"
|
||||
"github.com/openimsdk/tools/log"
|
||||
"github.com/openimsdk/tools/utils/stringutil"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var wsSuccessResponse, _ = json.Marshal(&apiresp.ApiResponse{})
|
||||
|
||||
type LongConnServer interface {
|
||||
Run(ctx context.Context) error
|
||||
wsHandler(w http.ResponseWriter, r *http.Request)
|
||||
@@ -46,7 +43,6 @@ type LongConnServer interface {
|
||||
}
|
||||
|
||||
type WsServer struct {
|
||||
websocket *websocket.Upgrader
|
||||
msgGatewayConfig *Config
|
||||
port int
|
||||
wsMaxConnNum int64
|
||||
@@ -140,13 +136,9 @@ func NewWsServer(msgGatewayConfig *Config, opts ...Option) *WsServer {
|
||||
o(&config)
|
||||
}
|
||||
//userRpcClient := rpcclient.NewUserRpcClient(client, config.Discovery.RpcService.User, config.Share.IMAdminUser)
|
||||
upgrader := &websocket.Upgrader{
|
||||
HandshakeTimeout: config.handshakeTimeout,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
v := validator.New()
|
||||
return &WsServer{
|
||||
websocket: upgrader,
|
||||
msgGatewayConfig: msgGatewayConfig,
|
||||
port: config.port,
|
||||
wsMaxConnNum: config.maxConnNum,
|
||||
@@ -268,7 +260,8 @@ func (ws *WsServer) registerClient(client *Client) {
|
||||
)
|
||||
oldClients, userOK, clientOK = ws.clients.Get(client.UserID, client.PlatformID)
|
||||
|
||||
log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID)
|
||||
log.ZInfo(client.ctx, "registerClient", "userID", client.UserID, "platformID", client.PlatformID,
|
||||
"sdkVersion", client.SDKVersion)
|
||||
|
||||
if !userOK {
|
||||
ws.clients.Set(client.UserID, client)
|
||||
@@ -455,7 +448,7 @@ func (ws *WsServer) unregisterClient(client *Client) {
|
||||
// validateRespWithRequest checks if the response matches the expected userID and platformID.
|
||||
func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.ParseTokenResp) error {
|
||||
userID := ctx.GetUserID()
|
||||
platformID := int32(ctx.GetPlatformID())
|
||||
platformID := stringutil.StringToInt32(ctx.GetPlatformID())
|
||||
if resp.UserID != userID {
|
||||
return servererrs.ErrTokenInvalid.WrapMsg(fmt.Sprintf("token uid %s != userID %s", resp.UserID, userID))
|
||||
}
|
||||
@@ -465,37 +458,19 @@ func (ws *WsServer) validateRespWithRequest(ctx *UserConnContext, resp *pbAuth.P
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ws *WsServer) handlerError(ctx *UserConnContext, w http.ResponseWriter, r *http.Request, err error) {
|
||||
if !ctx.ShouldSendResp() {
|
||||
httpError(ctx, err)
|
||||
return
|
||||
}
|
||||
// the browser cannot get the response of upgrade failure
|
||||
data, err := json.Marshal(apiresp.ParseError(err))
|
||||
if err != nil {
|
||||
log.ZError(ctx, "json marshal failed", err)
|
||||
return
|
||||
}
|
||||
conn, upgradeErr := ws.websocket.Upgrade(w, r, nil)
|
||||
if upgradeErr != nil {
|
||||
log.ZWarn(ctx, "websocket upgrade failed", upgradeErr, "respErr", err, "resp", string(data))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
log.ZWarn(ctx, "WriteMessage failed", err, "respErr", err, "resp", string(data))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Create a new connection context
|
||||
connContext := newContext(w, r)
|
||||
|
||||
if !ws.ready.Load() {
|
||||
httpError(connContext, errs.New("ws server not ready"))
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the current number of online user connections exceeds the maximum limit
|
||||
if ws.onlineUserConnNum.Load() >= ws.wsMaxConnNum {
|
||||
// If it exceeds the maximum connection number, return an error via HTTP and stop processing
|
||||
ws.handlerError(connContext, w, r, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
||||
httpError(connContext, servererrs.ErrConnOverMaxNumLimit.WrapMsg("over max conn num limit"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -503,14 +478,31 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err := connContext.ParseEssentialArgs()
|
||||
if err != nil {
|
||||
// If there's an error during parsing, return an error via HTTP and stop processing
|
||||
ws.handlerError(connContext, w, r, err)
|
||||
|
||||
httpError(connContext, err)
|
||||
return
|
||||
}
|
||||
|
||||
if ws.authClient == nil {
|
||||
httpError(connContext, errs.New("auth client is not initialized"))
|
||||
return
|
||||
}
|
||||
|
||||
// Call the authentication client to parse the Token obtained from the context
|
||||
resp, err := ws.authClient.ParseToken(connContext, connContext.GetToken())
|
||||
if err != nil {
|
||||
ws.handlerError(connContext, w, r, err)
|
||||
// If there's an error parsing the Token, decide whether to send the error message via WebSocket based on the context flag
|
||||
shouldSendError := connContext.ShouldSendResp()
|
||||
if shouldSendError {
|
||||
// Create a WebSocket connection object and attempt to send the error message via WebSocket
|
||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||
if err := wsLongConn.RespondWithError(err, w, r); err == nil {
|
||||
// If the error message is successfully sent via WebSocket, stop processing
|
||||
return
|
||||
}
|
||||
}
|
||||
// If sending via WebSocket is not required or fails, return the error via HTTP and stop processing
|
||||
httpError(connContext, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -518,30 +510,32 @@ func (ws *WsServer) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
err = ws.validateRespWithRequest(connContext, resp)
|
||||
if err != nil {
|
||||
// If validation fails, return an error via HTTP and stop processing
|
||||
ws.handlerError(connContext, w, r, err)
|
||||
httpError(connContext, err)
|
||||
return
|
||||
}
|
||||
conn, err := ws.websocket.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.ZWarn(connContext, "websocket upgrade failed", err)
|
||||
return
|
||||
}
|
||||
if connContext.ShouldSendResp() {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, wsSuccessResponse); err != nil {
|
||||
log.ZWarn(connContext, "WriteMessage first response", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
log.ZDebug(connContext, "new conn", "token", connContext.GetToken())
|
||||
|
||||
var pingInterval time.Duration
|
||||
if connContext.GetPlatformID() == constant.WebPlatformID {
|
||||
pingInterval = pingPeriod
|
||||
// Create a WebSocket long connection object
|
||||
wsLongConn := newGWebSocket(WebSocket, ws.handshakeTimeout, ws.writeBufferSize)
|
||||
if err := wsLongConn.GenerateLongConn(w, r); err != nil {
|
||||
//If the creation of the long connection fails, the error is handled internally during the handshake process.
|
||||
log.ZWarn(connContext, "long connection fails", err)
|
||||
return
|
||||
} else {
|
||||
// Check if a normal response should be sent via WebSocket
|
||||
shouldSendSuccessResp := connContext.ShouldSendResp()
|
||||
if shouldSendSuccessResp {
|
||||
// Attempt to send a success message through WebSocket
|
||||
if err := wsLongConn.RespondWithSuccess(); err != nil {
|
||||
// If the success message is successfully sent, end further processing
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := new(Client)
|
||||
client.ResetClient(connContext, NewWebSocketClientConn(conn, maxMessageSize, pongWait, pingInterval), ws)
|
||||
// Retrieve a client object from the client pool, reset its state, and associate it with the current WebSocket long connection
|
||||
client := ws.clientPool.Get().(*Client)
|
||||
client.ResetClient(connContext, wsLongConn, ws)
|
||||
|
||||
// Register the client with the server and start message processing
|
||||
ws.registerChan <- client
|
||||
|
||||
@@ -27,7 +27,7 @@ func (c *conversationServer) GetFullOwnerConversationIDs(ctx context.Context, re
|
||||
conversationIDs = nil
|
||||
}
|
||||
return &conversation.GetFullOwnerConversationIDsResp{
|
||||
Version: uint64(vl.Version),
|
||||
Version: idHash,
|
||||
VersionID: vl.ID.Hex(),
|
||||
Equal: req.IdHash == idHash,
|
||||
ConversationIDs: conversationIDs,
|
||||
|
||||
@@ -34,7 +34,7 @@ func (g *groupServer) GetFullGroupMemberUserIDs(ctx context.Context, req *pbgrou
|
||||
userIDs = nil
|
||||
}
|
||||
return &pbgroup.GetFullGroupMemberUserIDsResp{
|
||||
Version: uint64(vl.Version),
|
||||
Version: idHash,
|
||||
VersionID: vl.ID.Hex(),
|
||||
Equal: req.IdHash == idHash,
|
||||
UserIDs: userIDs,
|
||||
@@ -58,7 +58,7 @@ func (g *groupServer) GetFullJoinGroupIDs(ctx context.Context, req *pbgroup.GetF
|
||||
groupIDs = nil
|
||||
}
|
||||
return &pbgroup.GetFullJoinGroupIDsResp{
|
||||
Version: uint64(vl.Version),
|
||||
Version: idHash,
|
||||
VersionID: vl.ID.Hex(),
|
||||
Equal: req.IdHash == idHash,
|
||||
GroupIDs: groupIDs,
|
||||
|
||||
@@ -56,7 +56,7 @@ func (s *friendServer) GetFullFriendUserIDs(ctx context.Context, req *relation.G
|
||||
userIDs = nil
|
||||
}
|
||||
return &relation.GetFullFriendUserIDsResp{
|
||||
Version: uint64(vl.Version),
|
||||
Version: idHash,
|
||||
VersionID: vl.ID.Hex(),
|
||||
Equal: req.IdHash == idHash,
|
||||
UserIDs: userIDs,
|
||||
|
||||
+1
@@ -220,6 +220,7 @@ func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, p
|
||||
if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
if c.localCache != nil {
|
||||
c.removeLocalTokenCache(ctx, key)
|
||||
}
|
||||
|
||||
@@ -57,8 +57,8 @@ func (s *seqConversationMongo) Malloc(ctx context.Context, conversationID string
|
||||
}
|
||||
filter := map[string]any{"conversation_id": conversationID}
|
||||
update := map[string]any{
|
||||
"$inc": map[string]any{"max_seq": size},
|
||||
"$setOnInsert": map[string]any{"min_seq": int64(0)},
|
||||
"$inc": map[string]any{"max_seq": size},
|
||||
"$set": map[string]any{"min_seq": int64(0)},
|
||||
}
|
||||
opt := options.FindOneAndUpdate().SetUpsert(true).SetReturnDocument(options.After).SetProjection(map[string]any{"_id": 0, "max_seq": 1})
|
||||
lastSeq, err := mongoutil.FindOneAndUpdate[int64](ctx, s.coll, filter, update, opt)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
// Copyright © 2024 OpenIM open source community. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kafka
|
||||
|
||||
type TLSConfig struct {
|
||||
EnableTLS bool `yaml:"enableTLS"`
|
||||
CACrt string `yaml:"caCrt"`
|
||||
ClientCrt string `yaml:"clientCrt"`
|
||||
ClientKey string `yaml:"clientKey"`
|
||||
ClientKeyPwd string `yaml:"clientKeyPwd"`
|
||||
InsecureSkipVerify bool `yaml:"insecureSkipVerify"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
ProducerAck string `yaml:"producerAck"`
|
||||
CompressType string `yaml:"compressType"`
|
||||
Addr []string `yaml:"addr"`
|
||||
TLS TLSConfig `yaml:"tls"`
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
// Copyright © 2023 OpenIM. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/openimsdk/tools/log"
|
||||
)
|
||||
|
||||
type MConsumerGroup struct {
|
||||
sarama.ConsumerGroup
|
||||
groupID string
|
||||
topics []string
|
||||
}
|
||||
|
||||
func NewMConsumerGroup(conf *Config, groupID string, topics []string, autoCommitEnable bool) (*MConsumerGroup, error) {
|
||||
config, err := BuildConsumerGroupConfig(conf, sarama.OffsetNewest, autoCommitEnable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group, err := NewConsumerGroup(config, conf.Addr, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MConsumerGroup{
|
||||
ConsumerGroup: group,
|
||||
groupID: groupID,
|
||||
topics: topics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (mc *MConsumerGroup) GetContextFromMsg(cMsg *sarama.ConsumerMessage) context.Context {
|
||||
return GetContextWithMQHeader(cMsg.Headers)
|
||||
}
|
||||
|
||||
func (mc *MConsumerGroup) RegisterHandleAndConsumer(ctx context.Context, handler sarama.ConsumerGroupHandler) {
|
||||
for {
|
||||
err := mc.ConsumerGroup.Consume(ctx, mc.topics, handler)
|
||||
if errors.Is(err, sarama.ErrClosedConsumerGroup) {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.ZWarn(ctx, "consume err", err, "topic", mc.topics, "groupID", mc.groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *MConsumerGroup) Close() error {
|
||||
return mc.ConsumerGroup.Close()
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
// Copyright © 2023 OpenIM. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// Producer represents a Kafka producer.
|
||||
type Producer struct {
|
||||
addr []string
|
||||
topic string
|
||||
config *sarama.Config
|
||||
producer sarama.SyncProducer
|
||||
}
|
||||
|
||||
func NewKafkaProducer(config *sarama.Config, addr []string, topic string) (*Producer, error) {
|
||||
producer, err := NewProducer(config, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Producer{
|
||||
addr: addr,
|
||||
topic: topic,
|
||||
config: config,
|
||||
producer: producer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendMessage sends a message to the Kafka topic configured in the Producer.
|
||||
func (p *Producer) SendMessage(ctx context.Context, key string, msg proto.Message) (int32, int64, error) {
|
||||
// Marshal the protobuf message
|
||||
bMsg, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return 0, 0, errs.WrapMsg(err, "kafka proto Marshal err")
|
||||
}
|
||||
if len(bMsg) == 0 {
|
||||
return 0, 0, errs.WrapMsg(errEmptyMsg, "kafka proto Marshal err")
|
||||
}
|
||||
|
||||
// Prepare Kafka message
|
||||
kMsg := &sarama.ProducerMessage{
|
||||
Topic: p.topic,
|
||||
Key: sarama.StringEncoder(key),
|
||||
Value: sarama.ByteEncoder(bMsg),
|
||||
}
|
||||
|
||||
// Validate message key and value
|
||||
if kMsg.Key.Length() == 0 || kMsg.Value.Length() == 0 {
|
||||
return 0, 0, errs.Wrap(errEmptyMsg)
|
||||
}
|
||||
|
||||
// Attach context metadata as headers
|
||||
header, err := GetMQHeaderWithContext(ctx)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
kMsg.Headers = header
|
||||
|
||||
// Send the message
|
||||
partition, offset, err := p.producer.SendMessage(kMsg)
|
||||
if err != nil {
|
||||
return 0, 0, errs.WrapMsg(err, "p.producer.SendMessage error")
|
||||
}
|
||||
|
||||
return partition, offset, nil
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
)
|
||||
|
||||
func BuildConsumerGroupConfig(conf *Config, initial int64, autoCommitEnable bool) (*sarama.Config, error) {
|
||||
kfk := sarama.NewConfig()
|
||||
kfk.Version = sarama.V2_0_0_0
|
||||
kfk.Consumer.Offsets.Initial = initial
|
||||
kfk.Consumer.Offsets.AutoCommit.Enable = autoCommitEnable
|
||||
kfk.Consumer.Return.Errors = false
|
||||
if conf.Username != "" || conf.Password != "" {
|
||||
kfk.Net.SASL.Enable = true
|
||||
kfk.Net.SASL.User = conf.Username
|
||||
kfk.Net.SASL.Password = conf.Password
|
||||
}
|
||||
if conf.TLS.EnableTLS {
|
||||
tls, err := newTLSConfig(conf.TLS.ClientCrt, conf.TLS.ClientKey, conf.TLS.CACrt, []byte(conf.TLS.ClientKeyPwd), conf.TLS.InsecureSkipVerify)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kfk.Net.TLS.Config = tls
|
||||
kfk.Net.TLS.Enable = true
|
||||
}
|
||||
return kfk, nil
|
||||
}
|
||||
|
||||
func NewConsumerGroup(conf *sarama.Config, addr []string, groupID string) (sarama.ConsumerGroup, error) {
|
||||
cg, err := sarama.NewConsumerGroup(addr, groupID, conf)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "NewConsumerGroup failed", "addr", addr, "groupID", groupID, "conf", *conf)
|
||||
}
|
||||
return cg, nil
|
||||
}
|
||||
|
||||
func BuildProducerConfig(conf Config) (*sarama.Config, error) {
|
||||
kfk := sarama.NewConfig()
|
||||
kfk.Producer.Return.Successes = true
|
||||
kfk.Producer.Return.Errors = true
|
||||
kfk.Producer.Partitioner = sarama.NewHashPartitioner
|
||||
if conf.Username != "" || conf.Password != "" {
|
||||
kfk.Net.SASL.Enable = true
|
||||
kfk.Net.SASL.User = conf.Username
|
||||
kfk.Net.SASL.Password = conf.Password
|
||||
}
|
||||
switch strings.ToLower(conf.ProducerAck) {
|
||||
case "no_response":
|
||||
kfk.Producer.RequiredAcks = sarama.NoResponse
|
||||
case "wait_for_local":
|
||||
kfk.Producer.RequiredAcks = sarama.WaitForLocal
|
||||
case "wait_for_all":
|
||||
kfk.Producer.RequiredAcks = sarama.WaitForAll
|
||||
default:
|
||||
kfk.Producer.RequiredAcks = sarama.WaitForAll
|
||||
}
|
||||
if conf.CompressType == "" {
|
||||
kfk.Producer.Compression = sarama.CompressionNone
|
||||
} else {
|
||||
if err := kfk.Producer.Compression.UnmarshalText(bytes.ToLower([]byte(conf.CompressType))); err != nil {
|
||||
return nil, errs.WrapMsg(err, "UnmarshalText failed", "compressType", conf.CompressType)
|
||||
}
|
||||
}
|
||||
if conf.TLS.EnableTLS {
|
||||
tls, err := newTLSConfig(conf.TLS.ClientCrt, conf.TLS.ClientKey, conf.TLS.CACrt, []byte(conf.TLS.ClientKeyPwd), conf.TLS.InsecureSkipVerify)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kfk.Net.TLS.Config = tls
|
||||
kfk.Net.TLS.Enable = true
|
||||
}
|
||||
return kfk, nil
|
||||
}
|
||||
|
||||
func NewProducer(conf *sarama.Config, addr []string) (sarama.SyncProducer, error) {
|
||||
producer, err := sarama.NewSyncProducer(addr, conf)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "NewSyncProducer failed", "addr", addr, "conf", *conf)
|
||||
}
|
||||
return producer, nil
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
// Copyright © 2024 OpenIM open source community. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
|
||||
"github.com/openimsdk/tools/errs"
|
||||
)
|
||||
|
||||
// decryptPEM decrypts a PEM block using a password.
|
||||
func decryptPEM(data []byte, passphrase []byte) ([]byte, error) {
|
||||
if len(passphrase) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
b, _ := pem.Decode(data)
|
||||
d, err := x509.DecryptPEMBlock(b, passphrase)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "DecryptPEMBlock failed")
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{
|
||||
Type: b.Type,
|
||||
Bytes: d,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "ReadFile failed", "path", path)
|
||||
}
|
||||
return decryptPEM(data, pwd)
|
||||
}
|
||||
|
||||
// newTLSConfig setup the TLS config from general config file.
|
||||
func newTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte, insecureSkipVerify bool) (*tls.Config, error) {
|
||||
var tlsConfig tls.Config
|
||||
if clientCertFile != "" && clientKeyFile != "" {
|
||||
certPEMBlock, err := os.ReadFile(clientCertFile)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "ReadFile failed", "clientCertFile", clientCertFile)
|
||||
}
|
||||
keyPEMBlock, err := readEncryptablePEMBlock(clientKeyFile, keyPwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "X509KeyPair failed")
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
if caCertFile != "" {
|
||||
caCert, err := os.ReadFile(caCertFile)
|
||||
if err != nil {
|
||||
return nil, errs.WrapMsg(err, "ReadFile failed", "caCertFile", caCertFile)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
if ok := caCertPool.AppendCertsFromPEM(caCert); !ok {
|
||||
return nil, errs.New("AppendCertsFromPEM failed")
|
||||
}
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = insecureSkipVerify
|
||||
return &tlsConfig, nil
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/openimsdk/protocol/constant"
|
||||
"github.com/openimsdk/tools/mcontext"
|
||||
)
|
||||
|
||||
var errEmptyMsg = errors.New("kafka binary msg is empty")
|
||||
|
||||
// GetMQHeaderWithContext extracts message queue headers from the context.
|
||||
func GetMQHeaderWithContext(ctx context.Context) ([]sarama.RecordHeader, error) {
|
||||
operationID, opUserID, platform, connID, err := mcontext.GetCtxInfos(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []sarama.RecordHeader{
|
||||
{Key: []byte(constant.OperationID), Value: []byte(operationID)},
|
||||
{Key: []byte(constant.OpUserID), Value: []byte(opUserID)},
|
||||
{Key: []byte(constant.OpUserPlatform), Value: []byte(platform)},
|
||||
{Key: []byte(constant.ConnID), Value: []byte(connID)},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetContextWithMQHeader creates a context from message queue headers.
|
||||
func GetContextWithMQHeader(header []*sarama.RecordHeader) context.Context {
|
||||
var values []string
|
||||
for _, recordHeader := range header {
|
||||
values = append(values, string(recordHeader.Value))
|
||||
}
|
||||
return mcontext.WithMustInfoCtx(values) // Attach extracted values to context
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
// Copyright © 2024 OpenIM open source community. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package kafka
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/openimsdk/tools/errs"
|
||||
)
|
||||
|
||||
func CheckTopics(ctx context.Context, conf *Config, topics []string) error {
|
||||
kfk, err := BuildConsumerGroupConfig(conf, sarama.OffsetNewest, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cli, err := sarama.NewClient(conf.Addr, kfk)
|
||||
if err != nil {
|
||||
return errs.WrapMsg(err, "NewClient failed", "config: ", fmt.Sprintf("%+v", conf))
|
||||
}
|
||||
defer cli.Close()
|
||||
|
||||
existingTopics, err := cli.Topics()
|
||||
if err != nil {
|
||||
return errs.WrapMsg(err, "Failed to list topics")
|
||||
}
|
||||
|
||||
existingTopicsMap := make(map[string]bool)
|
||||
for _, t := range existingTopics {
|
||||
existingTopicsMap[t] = true
|
||||
}
|
||||
|
||||
for _, topic := range topics {
|
||||
if !existingTopicsMap[topic] {
|
||||
return errs.New("topic not exist", "topic", topic).Wrap()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckHealth(ctx context.Context, conf *Config) error {
|
||||
kfk, err := BuildConsumerGroupConfig(conf, sarama.OffsetNewest, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cli, err := sarama.NewClient(conf.Addr, kfk)
|
||||
if err != nil {
|
||||
return errs.WrapMsg(err, "NewClient failed", "config: ", fmt.Sprintf("%+v", conf))
|
||||
}
|
||||
defer cli.Close()
|
||||
|
||||
// Get broker list
|
||||
brokers := cli.Brokers()
|
||||
if len(brokers) == 0 {
|
||||
return errs.New("no brokers found").Wrap()
|
||||
}
|
||||
|
||||
// Check if all brokers are reachable
|
||||
for _, broker := range brokers {
|
||||
if err := broker.Open(kfk); err != nil {
|
||||
return errs.WrapMsg(err, "failed to open broker", "broker", broker.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+11
-16
@@ -47,15 +47,15 @@ func New[V any](opts ...Option) Cache[V] {
|
||||
if opt.localSlotNum > 0 && opt.localSlotSize > 0 {
|
||||
createSimpleLRU := func() lru.LRU[string, V] {
|
||||
if opt.expirationEvict {
|
||||
return lru.NewExpirationLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
||||
return lru.NewExpirationLRU(opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
||||
} else {
|
||||
return lru.NewLazyLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
||||
return lru.NewLazyLRU(opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
|
||||
}
|
||||
}
|
||||
if opt.localSlotNum == 1 {
|
||||
c.local = createSimpleLRU()
|
||||
} else {
|
||||
c.local = lru.NewSlotLRU[string, V](opt.localSlotNum, LRUStringHash, createSimpleLRU)
|
||||
c.local = lru.NewSlotLRU(opt.localSlotNum, LRUStringHash, createSimpleLRU)
|
||||
}
|
||||
if opt.linkSlotNum > 0 {
|
||||
c.link = link.New(opt.linkSlotNum)
|
||||
@@ -71,19 +71,14 @@ type cache[V any] struct {
|
||||
}
|
||||
|
||||
func (c *cache[V]) onEvict(key string, value V) {
|
||||
if c.link != nil {
|
||||
// Do not delete other keys while the underlying LRU still holds its lock;
|
||||
// defer linked deletions to avoid re-entering the same slot and deadlocking.
|
||||
if lks := c.link.Del(key); len(lks) > 0 {
|
||||
go c.delLinked(key, lks)
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = value
|
||||
|
||||
func (c *cache[V]) delLinked(src string, keys map[string]struct{}) {
|
||||
for k := range keys {
|
||||
if src != k {
|
||||
c.local.Del(k)
|
||||
if c.link != nil {
|
||||
lks := c.link.Del(key)
|
||||
for k := range lks {
|
||||
if key != k { // prevent deadlock
|
||||
c.local.Del(k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -110,7 +105,7 @@ func (c *cache[V]) Get(ctx context.Context, key string, fetch func(ctx context.C
|
||||
func (c *cache[V]) GetLink(ctx context.Context, key string, fetch func(ctx context.Context) (V, error), link ...string) (V, error) {
|
||||
if c.local != nil {
|
||||
return c.local.Get(key, func() (V, error) {
|
||||
if len(link) > 0 && c.link != nil {
|
||||
if len(link) > 0 {
|
||||
c.link.Link(key, link...)
|
||||
}
|
||||
return fetch(ctx)
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
|
||||
)
|
||||
|
||||
func TestName(t *testing.T) {
|
||||
@@ -93,68 +91,3 @@ func TestName(t *testing.T) {
|
||||
t.Log("del", del.Load())
|
||||
// 137.35s
|
||||
}
|
||||
|
||||
// Test deadlock scenario when eviction callback deletes a linked key that hashes to the same slot.
|
||||
func TestCacheEvictDeadlock(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c := New[string](WithLocalSlotNum(1), WithLocalSlotSize(1), WithLazy())
|
||||
|
||||
if _, err := c.GetLink(ctx, "k1", func(ctx context.Context) (string, error) {
|
||||
return "v1", nil
|
||||
}, "k2"); err != nil {
|
||||
t.Fatalf("seed cache failed: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
_, _ = c.GetLink(ctx, "k2", func(ctx context.Context) (string, error) {
|
||||
return "v2", nil
|
||||
}, "k1")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// expected to finish quickly; current implementation deadlocks here.
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("GetLink deadlocked during eviction of linked key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpirationLRUGetBatch(t *testing.T) {
|
||||
l := lru.NewExpirationLRU[string, string](2, time.Minute, time.Second*5, EmptyTarget{}, nil)
|
||||
|
||||
keys := []string{"a", "b"}
|
||||
values, err := l.GetBatch(keys, func(keys []string) (map[string]string, error) {
|
||||
res := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
res[k] = k + "_v"
|
||||
}
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(values) != len(keys) {
|
||||
t.Fatalf("expected %d values, got %d", len(keys), len(values))
|
||||
}
|
||||
for _, k := range keys {
|
||||
if v, ok := values[k]; !ok || v != k+"_v" {
|
||||
t.Fatalf("unexpected value for %s: %q, ok=%v", k, v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
// second batch should hit cache
|
||||
values, err = l.GetBatch(keys, func(keys []string) (map[string]string, error) {
|
||||
t.Fatalf("should not fetch on cache hit")
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on cache hit: %v", err)
|
||||
}
|
||||
for _, k := range keys {
|
||||
if v, ok := values[k]; !ok || v != k+"_v" {
|
||||
t.Fatalf("unexpected cached value for %s: %q, ok=%v", k, v, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,11 @@
|
||||
package localcache
|
||||
|
||||
import (
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -32,6 +33,10 @@ func InitLocalCache(localCache *config.LocalCache) {
|
||||
Local config.CacheConfig
|
||||
Keys []string
|
||||
}{
|
||||
{
|
||||
Local: localCache.Auth,
|
||||
Keys: []string{cachekey.UidPidToken},
|
||||
},
|
||||
{
|
||||
Local: localCache.User,
|
||||
Keys: []string{cachekey.UserInfoKey, cachekey.UserGlobalRecvMsgOptKey},
|
||||
|
||||
@@ -52,53 +52,8 @@ type ExpirationLRU[K comparable, V any] struct {
|
||||
}
|
||||
|
||||
func (x *ExpirationLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
|
||||
var (
|
||||
err error
|
||||
results = make(map[K]V)
|
||||
misses = make([]K, 0, len(keys))
|
||||
)
|
||||
|
||||
for _, key := range keys {
|
||||
x.lock.Lock()
|
||||
v, ok := x.core.Get(key)
|
||||
x.lock.Unlock()
|
||||
if ok {
|
||||
x.target.IncrGetHit()
|
||||
v.lock.RLock()
|
||||
results[key] = v.value
|
||||
if v.err != nil && err == nil {
|
||||
err = v.err
|
||||
}
|
||||
v.lock.RUnlock()
|
||||
continue
|
||||
}
|
||||
misses = append(misses, key)
|
||||
}
|
||||
|
||||
if len(misses) == 0 {
|
||||
return results, err
|
||||
}
|
||||
|
||||
fetchValues, fetchErr := fetch(misses)
|
||||
if fetchErr != nil && err == nil {
|
||||
err = fetchErr
|
||||
}
|
||||
|
||||
for key, val := range fetchValues {
|
||||
results[key] = val
|
||||
if fetchErr != nil {
|
||||
x.target.IncrGetFailed()
|
||||
continue
|
||||
}
|
||||
x.target.IncrGetSuccess()
|
||||
item := &expirationLruItem[V]{value: val}
|
||||
x.lock.Lock()
|
||||
x.core.Add(key, item)
|
||||
x.lock.Unlock()
|
||||
}
|
||||
|
||||
// any keys not returned from fetch remain absent (no cache write)
|
||||
return results, err
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (x *ExpirationLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
|
||||
|
||||
@@ -35,7 +35,7 @@ type slotLRU[K comparable, V any] struct {
|
||||
func (x *slotLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
|
||||
var (
|
||||
slotKeys = make(map[uint64][]K)
|
||||
vs = make(map[K]V)
|
||||
kVs = make(map[K]V)
|
||||
)
|
||||
|
||||
for _, k := range keys {
|
||||
@@ -49,10 +49,10 @@ func (x *slotLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)
|
||||
return nil, err
|
||||
}
|
||||
for key, value := range batches {
|
||||
vs[key] = value
|
||||
kVs[key] = value
|
||||
}
|
||||
}
|
||||
return vs, nil
|
||||
return kVs, nil
|
||||
}
|
||||
|
||||
func (x *slotLRU[K, V]) getIndex(k K) uint64 {
|
||||
|
||||
@@ -72,7 +72,7 @@ func Main(conf string, del time.Duration) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
mongodbConfig, err := readConfig[config.Mongo](conf, config.MongodbConfigFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -1,735 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/apistruct"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/protocol/auth"
|
||||
"github.com/openimsdk/protocol/constant"
|
||||
"github.com/openimsdk/protocol/group"
|
||||
"github.com/openimsdk/protocol/sdkws"
|
||||
pbuser "github.com/openimsdk/protocol/user"
|
||||
"github.com/openimsdk/tools/log"
|
||||
"github.com/openimsdk/tools/system/program"
|
||||
)
|
||||
|
||||
// 1. Create 100K New Users
|
||||
// 2. Create 100 100K Groups
|
||||
// 3. Create 1000 999 Groups
|
||||
// 4. Send message to 100K Groups every second
|
||||
// 5. Send message to 999 Groups every minute
|
||||
|
||||
var (
|
||||
// Use default userIDs List for testing, need to be created.
|
||||
TestTargetUserList = []string{
|
||||
// "<need-update-it>",
|
||||
}
|
||||
// DefaultGroupID = "<need-update-it>" // Use default group ID for testing, need to be created.
|
||||
)
|
||||
|
||||
var (
|
||||
ApiAddress string
|
||||
|
||||
// API method
|
||||
GetAdminToken = "/auth/get_admin_token"
|
||||
UserCheck = "/user/account_check"
|
||||
CreateUser = "/user/user_register"
|
||||
ImportFriend = "/friend/import_friend"
|
||||
InviteToGroup = "/group/invite_user_to_group"
|
||||
GetGroupMemberInfo = "/group/get_group_members_info"
|
||||
SendMsg = "/msg/send_msg"
|
||||
CreateGroup = "/group/create_group"
|
||||
GetUserToken = "/auth/user_token"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxUser = 100000
|
||||
Max100KGroup = 100
|
||||
Max999Group = 1000
|
||||
MaxInviteUserLimit = 999
|
||||
|
||||
CreateUserTicker = 1 * time.Second
|
||||
CreateGroupTicker = 1 * time.Second
|
||||
Create100KGroupTicker = 1 * time.Second
|
||||
Create999GroupTicker = 1 * time.Second
|
||||
SendMsgTo100KGroupTicker = 1 * time.Second
|
||||
SendMsgTo999GroupTicker = 1 * time.Minute
|
||||
)
|
||||
|
||||
type BaseResp struct {
|
||||
ErrCode int `json:"errCode"`
|
||||
ErrMsg string `json:"errMsg"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
type StressTest struct {
|
||||
Conf *conf
|
||||
AdminUserID string
|
||||
AdminToken string
|
||||
DefaultGroupID string
|
||||
DefaultUserID string
|
||||
UserCounter int
|
||||
CreateUserCounter int
|
||||
Create100kGroupCounter int
|
||||
Create999GroupCounter int
|
||||
MsgCounter int
|
||||
CreatedUsers []string
|
||||
CreatedGroups []string
|
||||
Mutex sync.Mutex
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
HttpClient *http.Client
|
||||
Wg sync.WaitGroup
|
||||
Once sync.Once
|
||||
}
|
||||
|
||||
type conf struct {
|
||||
Share config.Share
|
||||
Api config.API
|
||||
}
|
||||
|
||||
func initConfig(configDir string) (*config.Share, *config.API, error) {
|
||||
var (
|
||||
share = &config.Share{}
|
||||
apiConfig = &config.API{}
|
||||
)
|
||||
|
||||
err := config.Load(configDir, config.ShareFileName, config.EnvPrefixMap[config.ShareFileName], share)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = config.Load(configDir, config.OpenIMAPICfgFileName, config.EnvPrefixMap[config.OpenIMAPICfgFileName], apiConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return share, apiConfig, nil
|
||||
}
|
||||
|
||||
// Post Request
|
||||
func (st *StressTest) PostRequest(ctx context.Context, url string, reqbody any) ([]byte, error) {
|
||||
// Marshal body
|
||||
jsonBody, err := json.Marshal(reqbody)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "Failed to marshal request body", err, "url", url, "reqbody", reqbody)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("operationID", st.AdminUserID)
|
||||
if st.AdminToken != "" {
|
||||
req.Header.Set("token", st.AdminToken)
|
||||
}
|
||||
|
||||
// log.ZInfo(ctx, "Header info is ", "Content-Type", "application/json", "operationID", st.AdminUserID, "token", st.AdminToken)
|
||||
|
||||
resp, err := st.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody)
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "Failed to read response body", err, "url", url)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var baseResp BaseResp
|
||||
if err := json.Unmarshal(respBody, &baseResp); err != nil {
|
||||
log.ZError(ctx, "Failed to unmarshal response body", err, "url", url, "respBody", string(respBody))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if baseResp.ErrCode != 0 {
|
||||
err = fmt.Errorf(baseResp.ErrMsg)
|
||||
log.ZError(ctx, "Failed to send request", err, "url", url, "reqbody", reqbody, "resp", baseResp)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return baseResp.Data, nil
|
||||
}
|
||||
|
||||
func (st *StressTest) GetAdminToken(ctx context.Context) (string, error) {
|
||||
req := auth.GetAdminTokenReq{
|
||||
Secret: st.Conf.Share.Secret,
|
||||
UserID: st.AdminUserID,
|
||||
}
|
||||
|
||||
resp, err := st.PostRequest(ctx, ApiAddress+GetAdminToken, &req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
data := &auth.GetAdminTokenResp{}
|
||||
if err := json.Unmarshal(resp, &data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return data.Token, nil
|
||||
}
|
||||
|
||||
func (st *StressTest) CheckUser(ctx context.Context, userIDs []string) ([]string, error) {
|
||||
req := pbuser.AccountCheckReq{
|
||||
CheckUserIDs: userIDs,
|
||||
}
|
||||
|
||||
resp, err := st.PostRequest(ctx, ApiAddress+UserCheck, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := &pbuser.AccountCheckResp{}
|
||||
if err := json.Unmarshal(resp, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
unRegisteredUserIDs := make([]string, 0)
|
||||
|
||||
for _, res := range data.Results {
|
||||
if res.AccountStatus == constant.UnRegistered {
|
||||
unRegisteredUserIDs = append(unRegisteredUserIDs, res.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
return unRegisteredUserIDs, nil
|
||||
}
|
||||
|
||||
func (st *StressTest) CreateUser(ctx context.Context, userID string) (string, error) {
|
||||
user := &sdkws.UserInfo{
|
||||
UserID: userID,
|
||||
Nickname: userID,
|
||||
}
|
||||
|
||||
req := pbuser.UserRegisterReq{
|
||||
Users: []*sdkws.UserInfo{user},
|
||||
}
|
||||
|
||||
_, err := st.PostRequest(ctx, ApiAddress+CreateUser, &req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
st.UserCounter++
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func (st *StressTest) CreateUserBatch(ctx context.Context, userIDs []string) error {
|
||||
// The method can import a large number of users at once.
|
||||
var userList []*sdkws.UserInfo
|
||||
|
||||
defer st.Once.Do(
|
||||
func() {
|
||||
st.DefaultUserID = userIDs[0]
|
||||
fmt.Println("Default Send User Created ID:", st.DefaultUserID)
|
||||
})
|
||||
|
||||
needUserIDs, err := st.CheckUser(ctx, userIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, userID := range needUserIDs {
|
||||
user := &sdkws.UserInfo{
|
||||
UserID: userID,
|
||||
Nickname: userID,
|
||||
}
|
||||
userList = append(userList, user)
|
||||
}
|
||||
|
||||
req := pbuser.UserRegisterReq{
|
||||
Users: userList,
|
||||
}
|
||||
|
||||
_, err = st.PostRequest(ctx, ApiAddress+CreateUser, &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
st.UserCounter += len(userList)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *StressTest) GetGroupMembersInfo(ctx context.Context, groupID string, userIDs []string) ([]string, error) {
|
||||
needInviteUserIDs := make([]string, 0)
|
||||
|
||||
const maxBatchSize = 500
|
||||
if len(userIDs) > maxBatchSize {
|
||||
for i := 0; i < len(userIDs); i += maxBatchSize {
|
||||
end := min(i+maxBatchSize, len(userIDs))
|
||||
batchUserIDs := userIDs[i:end]
|
||||
|
||||
// log.ZInfo(ctx, "Processing group members batch", "groupID", groupID, "batch", i/maxBatchSize+1,
|
||||
// "batchUserCount", len(batchUserIDs))
|
||||
|
||||
// Process a single batch
|
||||
batchReq := group.GetGroupMembersInfoReq{
|
||||
GroupID: groupID,
|
||||
UserIDs: batchUserIDs,
|
||||
}
|
||||
|
||||
resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &batchReq)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "Batch query failed", err, "batch", i/maxBatchSize+1)
|
||||
continue
|
||||
}
|
||||
|
||||
data := &group.GetGroupMembersInfoResp{}
|
||||
if err := json.Unmarshal(resp, &data); err != nil {
|
||||
log.ZError(ctx, "Failed to parse batch response", err, "batch", i/maxBatchSize+1)
|
||||
continue
|
||||
}
|
||||
|
||||
// Process the batch results
|
||||
existingMembers := make(map[string]bool)
|
||||
for _, member := range data.Members {
|
||||
existingMembers[member.UserID] = true
|
||||
}
|
||||
|
||||
for _, userID := range batchUserIDs {
|
||||
if !existingMembers[userID] {
|
||||
needInviteUserIDs = append(needInviteUserIDs, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return needInviteUserIDs, nil
|
||||
}
|
||||
|
||||
req := group.GetGroupMembersInfoReq{
|
||||
GroupID: groupID,
|
||||
UserIDs: userIDs,
|
||||
}
|
||||
|
||||
resp, err := st.PostRequest(ctx, ApiAddress+GetGroupMemberInfo, &req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := &group.GetGroupMembersInfoResp{}
|
||||
if err := json.Unmarshal(resp, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingMembers := make(map[string]bool)
|
||||
for _, member := range data.Members {
|
||||
existingMembers[member.UserID] = true
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
if !existingMembers[userID] {
|
||||
needInviteUserIDs = append(needInviteUserIDs, userID)
|
||||
}
|
||||
}
|
||||
|
||||
return needInviteUserIDs, nil
|
||||
}
|
||||
|
||||
func (st *StressTest) InviteToGroup(ctx context.Context, groupID string, userIDs []string) error {
|
||||
req := group.InviteUserToGroupReq{
|
||||
GroupID: groupID,
|
||||
InvitedUserIDs: userIDs,
|
||||
}
|
||||
_, err := st.PostRequest(ctx, ApiAddress+InviteToGroup, &req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *StressTest) SendMsg(ctx context.Context, userID string, groupID string) error {
|
||||
contentObj := map[string]any{
|
||||
// "content": fmt.Sprintf("index %d. The current time is %s", st.MsgCounter, time.Now().Format("2006-01-02 15:04:05.000")),
|
||||
"content": fmt.Sprintf("The current time is %s", time.Now().Format("2006-01-02 15:04:05.000")),
|
||||
}
|
||||
|
||||
req := &apistruct.SendMsgReq{
|
||||
SendMsg: apistruct.SendMsg{
|
||||
SendID: userID,
|
||||
SenderNickname: userID,
|
||||
GroupID: groupID,
|
||||
ContentType: constant.Text,
|
||||
SessionType: constant.ReadGroupChatType,
|
||||
Content: contentObj,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := st.PostRequest(ctx, ApiAddress+SendMsg, &req)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "Failed to send message", err, "userID", userID, "req", &req)
|
||||
return err
|
||||
}
|
||||
|
||||
st.MsgCounter++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Max userIDs number is 1000
|
||||
func (st *StressTest) CreateGroup(ctx context.Context, groupID string, userID string, userIDsList []string) (string, error) {
|
||||
groupInfo := &sdkws.GroupInfo{
|
||||
GroupID: groupID,
|
||||
GroupName: groupID,
|
||||
GroupType: constant.WorkingGroup,
|
||||
}
|
||||
|
||||
req := group.CreateGroupReq{
|
||||
OwnerUserID: userID,
|
||||
MemberUserIDs: userIDsList,
|
||||
GroupInfo: groupInfo,
|
||||
}
|
||||
|
||||
resp := group.CreateGroupResp{}
|
||||
|
||||
response, err := st.PostRequest(ctx, ApiAddress+CreateGroup, &req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(response, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// st.GroupCounter++
|
||||
|
||||
return resp.GroupInfo.GroupID, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
var configPath string
|
||||
// defaultConfigDir := filepath.Join("..", "..", "..", "..", "..", "config")
|
||||
// flag.StringVar(&configPath, "c", defaultConfigDir, "config path")
|
||||
flag.StringVar(&configPath, "c", "", "config path")
|
||||
flag.Parse()
|
||||
|
||||
if configPath == "" {
|
||||
_, _ = fmt.Fprintln(os.Stderr, "config path is empty")
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf(" Config Path: %s\n", configPath)
|
||||
|
||||
share, apiConfig, err := initConfig(configPath)
|
||||
if err != nil {
|
||||
program.ExitWithError(err)
|
||||
return
|
||||
}
|
||||
|
||||
ApiAddress = fmt.Sprintf("http://%s:%s", "127.0.0.1", fmt.Sprint(apiConfig.Api.Ports[0]))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// ch := make(chan struct{})
|
||||
|
||||
st := &StressTest{
|
||||
Conf: &conf{
|
||||
Share: *share,
|
||||
Api: *apiConfig,
|
||||
},
|
||||
AdminUserID: share.IMAdminUser.UserIDs[0],
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
HttpClient: &http.Client{
|
||||
Timeout: 50 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
fmt.Println("\nReceived stop signal, stopping...")
|
||||
|
||||
go func() {
|
||||
// time.Sleep(5 * time.Second)
|
||||
fmt.Println("Force exit")
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
st.Cancel()
|
||||
}()
|
||||
|
||||
token, err := st.GetAdminToken(st.Ctx)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "Get Admin Token failed.", err, "AdminUserID", st.AdminUserID)
|
||||
}
|
||||
|
||||
st.AdminToken = token
|
||||
fmt.Println("Admin Token:", st.AdminToken)
|
||||
fmt.Println("ApiAddress:", ApiAddress)
|
||||
for i := 0; i < MaxUser; i++ {
|
||||
userID := fmt.Sprintf("v2_StressTest_User_%d", i)
|
||||
st.CreatedUsers = append(st.CreatedUsers, userID)
|
||||
st.CreateUserCounter++
|
||||
}
|
||||
|
||||
// err = st.CreateUserBatch(st.Ctx, st.CreatedUsers)
|
||||
// if err != nil {
|
||||
// log.ZError(ctx, "Create user failed.", err)
|
||||
// }
|
||||
|
||||
const batchSize = 1000
|
||||
totalUsers := len(st.CreatedUsers)
|
||||
successCount := 0
|
||||
|
||||
if st.DefaultUserID == "" && len(st.CreatedUsers) > 0 {
|
||||
st.DefaultUserID = st.CreatedUsers[0]
|
||||
}
|
||||
|
||||
for i := 0; i < totalUsers; i += batchSize {
|
||||
end := min(i+batchSize, totalUsers)
|
||||
|
||||
userBatch := st.CreatedUsers[i:end]
|
||||
log.ZInfo(st.Ctx, "Creating user batch", "batch", i/batchSize+1, "count", len(userBatch))
|
||||
|
||||
err = st.CreateUserBatch(st.Ctx, userBatch)
|
||||
if err != nil {
|
||||
log.ZError(st.Ctx, "Batch user creation failed", err, "batch", i/batchSize+1)
|
||||
} else {
|
||||
successCount += len(userBatch)
|
||||
log.ZInfo(st.Ctx, "Batch user creation succeeded", "batch", i/batchSize+1,
|
||||
"progress", fmt.Sprintf("%d/%d", successCount, totalUsers))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute create 100k group
|
||||
st.Wg.Add(1)
|
||||
go func() {
|
||||
defer st.Wg.Done()
|
||||
|
||||
create100kGroupTicker := time.NewTicker(Create100KGroupTicker)
|
||||
defer create100kGroupTicker.Stop()
|
||||
|
||||
for i := 0; i < Max100KGroup; i++ {
|
||||
select {
|
||||
case <-st.Ctx.Done():
|
||||
log.ZInfo(st.Ctx, "Stop Create 100K Group")
|
||||
return
|
||||
|
||||
case <-create100kGroupTicker.C:
|
||||
// Create 100K groups
|
||||
st.Wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer st.Wg.Done()
|
||||
defer func() {
|
||||
st.Create100kGroupCounter++
|
||||
}()
|
||||
|
||||
groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", idx)
|
||||
|
||||
if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil {
|
||||
log.ZError(st.Ctx, "Create group failed.", err)
|
||||
// continue
|
||||
}
|
||||
|
||||
for i := 0; i < MaxUser/MaxInviteUserLimit; i++ {
|
||||
InviteUserIDs := make([]string, 0)
|
||||
// ensure TargetUserList is in group
|
||||
InviteUserIDs = append(InviteUserIDs, TestTargetUserList...)
|
||||
|
||||
startIdx := max(i*MaxInviteUserLimit, 1)
|
||||
endIdx := min((i+1)*MaxInviteUserLimit, MaxUser)
|
||||
|
||||
for j := startIdx; j < endIdx; j++ {
|
||||
userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j)
|
||||
InviteUserIDs = append(InviteUserIDs, userCreatedID)
|
||||
}
|
||||
|
||||
if len(InviteUserIDs) == 0 {
|
||||
log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs)
|
||||
if err != nil {
|
||||
log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(InviteUserIDs) == 0 {
|
||||
log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Invite To Group
|
||||
if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil {
|
||||
log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs)
|
||||
continue
|
||||
// os.Exit(1)
|
||||
// return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// create 999 groups
|
||||
st.Wg.Add(1)
|
||||
go func() {
|
||||
defer st.Wg.Done()
|
||||
|
||||
create999GroupTicker := time.NewTicker(Create999GroupTicker)
|
||||
defer create999GroupTicker.Stop()
|
||||
|
||||
for i := 0; i < Max999Group; i++ {
|
||||
select {
|
||||
case <-st.Ctx.Done():
|
||||
log.ZInfo(st.Ctx, "Stop Create 999 Group")
|
||||
return
|
||||
|
||||
case <-create999GroupTicker.C:
|
||||
// Create 999 groups
|
||||
st.Wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer st.Wg.Done()
|
||||
defer func() {
|
||||
st.Create999GroupCounter++
|
||||
}()
|
||||
|
||||
groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", idx)
|
||||
|
||||
if _, err = st.CreateGroup(st.Ctx, groupID, st.DefaultUserID, TestTargetUserList); err != nil {
|
||||
log.ZError(st.Ctx, "Create group failed.", err)
|
||||
// continue
|
||||
}
|
||||
for i := 0; i < MaxUser/MaxInviteUserLimit; i++ {
|
||||
InviteUserIDs := make([]string, 0)
|
||||
// ensure TargetUserList is in group
|
||||
InviteUserIDs = append(InviteUserIDs, TestTargetUserList...)
|
||||
|
||||
startIdx := max(i*MaxInviteUserLimit, 1)
|
||||
endIdx := min((i+1)*MaxInviteUserLimit, MaxUser)
|
||||
|
||||
for j := startIdx; j < endIdx; j++ {
|
||||
userCreatedID := fmt.Sprintf("v2_StressTest_User_%d", j)
|
||||
InviteUserIDs = append(InviteUserIDs, userCreatedID)
|
||||
}
|
||||
|
||||
if len(InviteUserIDs) == 0 {
|
||||
log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
InviteUserIDs, err := st.GetGroupMembersInfo(ctx, groupID, InviteUserIDs)
|
||||
if err != nil {
|
||||
log.ZError(st.Ctx, "GetGroupMembersInfo failed.", err, "groupID", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(InviteUserIDs) == 0 {
|
||||
log.ZWarn(st.Ctx, "InviteUserIDs is empty", nil, "groupID", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
// Invite To Group
|
||||
if err = st.InviteToGroup(st.Ctx, groupID, InviteUserIDs); err != nil {
|
||||
log.ZError(st.Ctx, "Invite To Group failed.", err, "UserID", InviteUserIDs)
|
||||
continue
|
||||
// os.Exit(1)
|
||||
// return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Send message to 100K groups
|
||||
st.Wg.Wait()
|
||||
fmt.Println("All groups created successfully, starting to send messages...")
|
||||
log.ZInfo(ctx, "All groups created successfully, starting to send messages...")
|
||||
|
||||
var groups100K []string
|
||||
var groups999 []string
|
||||
|
||||
for i := 0; i < Max100KGroup; i++ {
|
||||
groupID := fmt.Sprintf("v2_StressTest_Group_100K_%d", i)
|
||||
groups100K = append(groups100K, groupID)
|
||||
}
|
||||
|
||||
for i := 0; i < Max999Group; i++ {
|
||||
groupID := fmt.Sprintf("v2_StressTest_Group_1K_%d", i)
|
||||
groups999 = append(groups999, groupID)
|
||||
}
|
||||
|
||||
send100kGroupLimiter := make(chan struct{}, 20)
|
||||
send999GroupLimiter := make(chan struct{}, 100)
|
||||
|
||||
// execute Send message to 100K groups
|
||||
go func() {
|
||||
ticker := time.NewTicker(SendMsgTo100KGroupTicker)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-st.Ctx.Done():
|
||||
log.ZInfo(st.Ctx, "Stop Send Message to 100K Group")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
// Send message to 100K groups
|
||||
for _, groupID := range groups100K {
|
||||
send100kGroupLimiter <- struct{}{}
|
||||
go func(groupID string) {
|
||||
defer func() { <-send100kGroupLimiter }()
|
||||
if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil {
|
||||
log.ZError(st.Ctx, "Send message to 100K group failed.", err)
|
||||
}
|
||||
}(groupID)
|
||||
}
|
||||
// log.ZInfo(st.Ctx, "Send message to 100K groups successfully.")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// execute Send message to 999 groups
|
||||
go func() {
|
||||
ticker := time.NewTicker(SendMsgTo999GroupTicker)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-st.Ctx.Done():
|
||||
log.ZInfo(st.Ctx, "Stop Send Message to 999 Group")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
// Send message to 999 groups
|
||||
for _, groupID := range groups999 {
|
||||
send999GroupLimiter <- struct{}{}
|
||||
go func(groupID string) {
|
||||
defer func() { <-send999GroupLimiter }()
|
||||
|
||||
if err := st.SendMsg(st.Ctx, st.DefaultUserID, groupID); err != nil {
|
||||
log.ZError(st.Ctx, "Send message to 999 group failed.", err)
|
||||
}
|
||||
}(groupID)
|
||||
}
|
||||
// log.ZInfo(st.Ctx, "Send message to 999 groups successfully.")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
<-st.Ctx.Done()
|
||||
fmt.Println("Received signal to exit, shutting down...")
|
||||
}
|
||||
+1
-1
@@ -1 +1 @@
|
||||
v3.9.0
|
||||
main
|
||||
Reference in New Issue
Block a user