mirror of
https://github.com/openimsdk/open-im-server.git
synced 2026-04-28 06:19:20 +08:00
optimization: change the configuration file from being read globally … (#1935)
* optimization: change the configuration file from being read globally to being read independently. * optimization: change the configuration file from being read globally to being read independently. * optimization: change the configuration file from being read globally to being read independently. * optimization: config file changed to dependency injection. * fix: replace global config with dependency injection * fix: replace global config with dependency injection * fix: import the enough param * fix: import the enough param * fix: import the enough param * fix: fix the component check of path * fix: fix the kafka of tls is nil problem * fix: fix the TLS.CACrt is nil error * fix: fix the valiable shadows problem * fix: fix the comflect * optimization: message remove options. * fix: fix the param pass error * fix: find error * fix: find error * fix: find eror * fix: find error * fix: find error * fix: del the undifined func * fix: find error * fix: fix the error * fix: pass config * fix: find error * fix: find error * fix: find error * fix: find error * fix: find error * fix: fix the config * fix: fix the error * fix: fix the config pass error * fix: fix the eror * fix: fix the error * fix: fix the error * fix: fix the error * fix: find error * fix: fix the error * fix: fix the config * fix: add return err * fix: fix the err2 * fix: err * fix: fix the func * fix: del the chinese comment * fix: fix the func * fix: fix the gateway_test logic * fix: s3 * test * test * fix: not found --------- Co-authored-by: luhaoling <2198702716@qq.com> Co-authored-by: withchao <993506633@qq.com>
This commit is contained in:
+21
-22
@@ -26,61 +26,60 @@ import (
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
)
|
||||
|
||||
func Secret() jwt.Keyfunc {
|
||||
func Secret(secret string) jwt.Keyfunc {
|
||||
return func(token *jwt.Token) (any, error) {
|
||||
return []byte(config.Config.Secret), nil
|
||||
return []byte(secret), nil
|
||||
}
|
||||
}
|
||||
|
||||
func CheckAccessV3(ctx context.Context, ownerUserID string) (err error) {
|
||||
func CheckAccessV3(ctx context.Context, ownerUserID string, config *config.GlobalConfig) (err error) {
|
||||
opUserID := mcontext.GetOpUserID(ctx)
|
||||
if len(config.Config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Config.Manager.UserID) {
|
||||
if len(config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Manager.UserID) {
|
||||
return nil
|
||||
}
|
||||
if utils.IsContain(opUserID, config.Config.IMAdmin.UserID) {
|
||||
if utils.IsContain(opUserID, config.IMAdmin.UserID) {
|
||||
return nil
|
||||
}
|
||||
if opUserID == ownerUserID {
|
||||
return nil
|
||||
}
|
||||
return errs.Wrap(errs.ErrNoPermission, "CheckAccessV3: no permission for user "+opUserID)
|
||||
return errs.ErrNoPermission.Wrap("ownerUserID", ownerUserID)
|
||||
}
|
||||
|
||||
func IsAppManagerUid(ctx context.Context) bool {
|
||||
return (len(config.Config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID)) ||
|
||||
utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID)
|
||||
func IsAppManagerUid(ctx context.Context, config *config.GlobalConfig) bool {
|
||||
return (len(config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Manager.UserID)) ||
|
||||
utils.IsContain(mcontext.GetOpUserID(ctx), config.IMAdmin.UserID)
|
||||
}
|
||||
|
||||
func CheckAdmin(ctx context.Context) error {
|
||||
if len(config.Config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID) {
|
||||
func CheckAdmin(ctx context.Context, config *config.GlobalConfig) error {
|
||||
if len(config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Manager.UserID) {
|
||||
return nil
|
||||
}
|
||||
if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) {
|
||||
if utils.IsContain(mcontext.GetOpUserID(ctx), config.IMAdmin.UserID) {
|
||||
return nil
|
||||
}
|
||||
return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx)))
|
||||
}
|
||||
|
||||
func CheckIMAdmin(ctx context.Context) error {
|
||||
if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) {
|
||||
func CheckIMAdmin(ctx context.Context, config *config.GlobalConfig) error {
|
||||
if utils.IsContain(mcontext.GetOpUserID(ctx), config.IMAdmin.UserID) {
|
||||
return nil
|
||||
}
|
||||
if len(config.Config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.Manager.UserID) {
|
||||
if len(config.Manager.UserID) > 0 && utils.IsContain(mcontext.GetOpUserID(ctx), config.Manager.UserID) {
|
||||
return nil
|
||||
}
|
||||
return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not CheckIMAdmin userID", mcontext.GetOpUserID(ctx)))
|
||||
}
|
||||
|
||||
func ParseRedisInterfaceToken(redisToken any) (*tokenverify.Claims, error) {
|
||||
return tokenverify.GetClaimFromToken(string(redisToken.([]uint8)), Secret())
|
||||
func ParseRedisInterfaceToken(redisToken any, secret string) (*tokenverify.Claims, error) {
|
||||
return tokenverify.GetClaimFromToken(string(redisToken.([]uint8)), Secret(secret))
|
||||
}
|
||||
|
||||
func IsManagerUserID(opUserID string) bool {
|
||||
return (len(config.Config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Config.Manager.UserID)) || utils.IsContain(opUserID, config.Config.IMAdmin.UserID)
|
||||
func IsManagerUserID(opUserID string, config *config.GlobalConfig) bool {
|
||||
return (len(config.Manager.UserID) > 0 && utils.IsContain(opUserID, config.Manager.UserID)) || utils.IsContain(opUserID, config.IMAdmin.UserID)
|
||||
}
|
||||
|
||||
func WsVerifyToken(token, userID string, platformID int) error {
|
||||
claim, err := tokenverify.GetClaimFromToken(token, Secret())
|
||||
func WsVerifyToken(token, userID, secret string, platformID int) error {
|
||||
claim, err := tokenverify.GetClaimFromToken(token, Secret(secret))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
+20
-30
@@ -15,54 +15,44 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/internal/api"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
)
|
||||
|
||||
type ApiCmd struct {
|
||||
*RootCmd
|
||||
initFunc func(config *config.GlobalConfig, port int, promPort int) error
|
||||
}
|
||||
|
||||
func NewApiCmd() *ApiCmd {
|
||||
ret := &ApiCmd{NewRootCmd("api")}
|
||||
ret := &ApiCmd{RootCmd: NewRootCmd("api"), initFunc: api.Start}
|
||||
ret.SetRootCmdPt(ret)
|
||||
|
||||
ret.addPreRun()
|
||||
ret.addRunE()
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddApi configures the API command to run with specified ports for the API and Prometheus monitoring.
|
||||
// It ensures error handling for port retrieval and only proceeds if both port numbers are successfully obtained.
|
||||
func (a *ApiCmd) AddApi(f func(port int, promPort int) error) {
|
||||
func (a *ApiCmd) addPreRun() {
|
||||
a.Command.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
a.port = a.getPortFlag(cmd)
|
||||
a.prometheusPort = a.getPrometheusPortFlag(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ApiCmd) addRunE() {
|
||||
a.Command.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
port, err := a.getPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
promPort, err := a.getPrometheusPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return f(port, promPort)
|
||||
return a.initFunc(a.config, a.port, a.prometheusPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *ApiCmd) GetPortFromConfig(portType string) (int, error) {
|
||||
func (a *ApiCmd) GetPortFromConfig(portType string) int {
|
||||
if portType == constant.FlagPort {
|
||||
if len(config2.Config.Api.OpenImApiPort) > 0 {
|
||||
return config2.Config.Api.OpenImApiPort[0], nil
|
||||
}
|
||||
return 0, errors.New("API port configuration is empty or missing")
|
||||
return a.config.Api.OpenImApiPort[0]
|
||||
} else if portType == constant.FlagPrometheusPort {
|
||||
if len(config2.Config.Prometheus.ApiPrometheusPort) > 0 {
|
||||
return config2.Config.Prometheus.ApiPrometheusPort[0], nil
|
||||
}
|
||||
return 0, errors.New("Prometheus port configuration is empty or missing")
|
||||
return a.config.Prometheus.ApiPrometheusPort[0]
|
||||
}
|
||||
return 0, fmt.Errorf("unknown port type: %s", portType)
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -14,29 +14,35 @@
|
||||
|
||||
package cmd
|
||||
|
||||
import "github.com/spf13/cobra"
|
||||
import (
|
||||
"github.com/openimsdk/open-im-server/v3/internal/tools"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type CronTaskCmd struct {
|
||||
*RootCmd
|
||||
initFunc func(config *config.GlobalConfig) error
|
||||
}
|
||||
|
||||
func NewCronTaskCmd() *CronTaskCmd {
|
||||
ret := &CronTaskCmd{NewRootCmd("cronTask", WithCronTaskLogName())}
|
||||
ret := &CronTaskCmd{RootCmd: NewRootCmd("cronTask", WithCronTaskLogName()),
|
||||
initFunc: tools.StartTask}
|
||||
ret.addRunE()
|
||||
ret.SetRootCmdPt(ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (c *CronTaskCmd) addRunE(f func() error) {
|
||||
func (c *CronTaskCmd) addRunE() {
|
||||
c.Command.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
return f()
|
||||
return c.initFunc(c.config)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CronTaskCmd) Exec(f func() error) error {
|
||||
c.addRunE(f)
|
||||
func (c *CronTaskCmd) Exec() error {
|
||||
return c.Execute()
|
||||
}
|
||||
|
||||
func (c *CronTaskCmd) GetPortFromConfig(portType string) (int, error) {
|
||||
return 0, nil
|
||||
func (c *CronTaskCmd) GetPortFromConfig(portType string) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -15,13 +15,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/internal/msggateway"
|
||||
v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type MsgGatewayCmd struct {
|
||||
@@ -30,6 +30,7 @@ type MsgGatewayCmd struct {
|
||||
|
||||
func NewMsgGatewayCmd() *MsgGatewayCmd {
|
||||
ret := &MsgGatewayCmd{NewRootCmd("msgGateway")}
|
||||
ret.addRunE()
|
||||
ret.SetRootCmdPt(ret)
|
||||
return ret
|
||||
}
|
||||
@@ -38,67 +39,39 @@ func (m *MsgGatewayCmd) AddWsPortFlag() {
|
||||
m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port")
|
||||
}
|
||||
|
||||
func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) (int, error) {
|
||||
func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int {
|
||||
port, err := cmd.Flags().GetInt(constant.FlagWsPort)
|
||||
if err != nil {
|
||||
return 0, errs.Wrap(err, "error getting ws port flag")
|
||||
log.Println("Error getting ws port flag:", err)
|
||||
}
|
||||
if port == 0 {
|
||||
port, _ = m.PortFromConfig(constant.FlagWsPort)
|
||||
port = m.PortFromConfig(constant.FlagWsPort)
|
||||
}
|
||||
return port, nil
|
||||
return port
|
||||
}
|
||||
|
||||
func (m *MsgGatewayCmd) addRunE() {
|
||||
m.Command.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
wsPort, err := m.getWsPortFlag(cmd)
|
||||
if err != nil {
|
||||
return errs.Wrap(err, "failed to get WS port flag")
|
||||
}
|
||||
port, err := m.getPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
prometheusPort, err := m.getPrometheusPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return msggateway.RunWsAndServer(port, wsPort, prometheusPort)
|
||||
return msggateway.RunWsAndServer(m.config, m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MsgGatewayCmd) Exec() error {
|
||||
m.addRunE()
|
||||
return m.Execute()
|
||||
}
|
||||
|
||||
func (m *MsgGatewayCmd) GetPortFromConfig(portType string) (int, error) {
|
||||
var port int
|
||||
var exists bool
|
||||
|
||||
func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int {
|
||||
switch portType {
|
||||
case constant.FlagWsPort:
|
||||
if len(v3config.Config.LongConnSvr.OpenImWsPort) > 0 {
|
||||
port = v3config.Config.LongConnSvr.OpenImWsPort[0]
|
||||
exists = true
|
||||
}
|
||||
return m.config.LongConnSvr.OpenImWsPort[0]
|
||||
|
||||
case constant.FlagPort:
|
||||
if len(v3config.Config.LongConnSvr.OpenImMessageGatewayPort) > 0 {
|
||||
port = v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0]
|
||||
exists = true
|
||||
}
|
||||
return m.config.LongConnSvr.OpenImMessageGatewayPort[0]
|
||||
|
||||
case constant.FlagPrometheusPort:
|
||||
if len(v3config.Config.Prometheus.MessageGatewayPrometheusPort) > 0 {
|
||||
port = v3config.Config.Prometheus.MessageGatewayPrometheusPort[0]
|
||||
exists = true
|
||||
}
|
||||
}
|
||||
return m.config.Prometheus.MessageGatewayPrometheusPort[0]
|
||||
|
||||
if !exists {
|
||||
return 0, errs.Wrap(errors.New("port type '%s' not found in configuration"), portType)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestMsgGatewayCmd_GetPortFromConfig(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.portType, func(t *testing.T) {
|
||||
got, _ := msgGatewayCmd.GetPortFromConfig(tt.portType)
|
||||
got := msgGatewayCmd.GetPortFromConfig(tt.portType)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,11 +16,10 @@ package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/openimsdk/open-im-server/v3/internal/msgtransfer"
|
||||
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/internal/msgtransfer"
|
||||
)
|
||||
|
||||
type MsgTransferCmd struct {
|
||||
@@ -29,37 +28,29 @@ type MsgTransferCmd struct {
|
||||
|
||||
func NewMsgTransferCmd() *MsgTransferCmd {
|
||||
ret := &MsgTransferCmd{NewRootCmd("msgTransfer")}
|
||||
ret.addRunE()
|
||||
ret.SetRootCmdPt(ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (m *MsgTransferCmd) addRunE() {
|
||||
m.Command.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
prometheusPort, err := m.getPrometheusPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return msgtransfer.StartTransfer(prometheusPort)
|
||||
return msgtransfer.StartTransfer(m.config, m.getPrometheusPortFlag(cmd))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MsgTransferCmd) Exec() error {
|
||||
m.addRunE()
|
||||
return m.Execute()
|
||||
}
|
||||
|
||||
func (m *MsgTransferCmd) GetPortFromConfig(portType string) (int, error) {
|
||||
func (m *MsgTransferCmd) GetPortFromConfig(portType string) int {
|
||||
if portType == constant.FlagPort {
|
||||
return 0, nil
|
||||
return 0
|
||||
} else if portType == constant.FlagPrometheusPort {
|
||||
n := m.getTransferProgressFlagValue()
|
||||
|
||||
if n < len(config2.Config.Prometheus.MessageTransferPrometheusPort) {
|
||||
return config2.Config.Prometheus.MessageTransferPrometheusPort[n], nil
|
||||
}
|
||||
return 0, fmt.Errorf("index out of range for MessageTransferPrometheusPort with index %d", n)
|
||||
return m.config.Prometheus.MessageTransferPrometheusPort[n]
|
||||
}
|
||||
return 0, fmt.Errorf("unknown port type: %s", portType)
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *MsgTransferCmd) AddTransferProgressFlag() {
|
||||
@@ -67,10 +58,10 @@ func (m *MsgTransferCmd) AddTransferProgressFlag() {
|
||||
}
|
||||
|
||||
func (m *MsgTransferCmd) getTransferProgressFlagValue() int {
|
||||
nindex, err := m.Command.Flags().GetInt(constant.FlagTransferProgressIndex)
|
||||
nIndex, err := m.Command.Flags().GetInt(constant.FlagTransferProgressIndex)
|
||||
if err != nil {
|
||||
fmt.Println("get transfercmd error,make sure it is k8s env or not")
|
||||
fmt.Println("get transfer cmd error,make sure it is k8s env or not")
|
||||
return 0
|
||||
}
|
||||
return nindex
|
||||
return nIndex
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
type MsgUtilsCmd struct {
|
||||
cobra.Command
|
||||
MsgTool *tools.MsgTool
|
||||
}
|
||||
|
||||
func (m *MsgUtilsCmd) AddUserIDFlag() {
|
||||
@@ -135,7 +136,7 @@ func NewSeqCmd() *SeqCmd {
|
||||
|
||||
func (s *SeqCmd) GetSeqCmd() *cobra.Command {
|
||||
s.Command.Run = func(cmdLines *cobra.Command, args []string) {
|
||||
_, err := tools.InitMsgTool()
|
||||
_, err := tools.InitMsgTool(s.MsgTool.Config)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
|
||||
+22
-25
@@ -26,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
type RootCmdPt interface {
|
||||
GetPortFromConfig(portType string) (int, error)
|
||||
GetPortFromConfig(portType string) int
|
||||
}
|
||||
|
||||
type RootCmd struct {
|
||||
@@ -35,6 +35,11 @@ type RootCmd struct {
|
||||
port int
|
||||
prometheusPort int
|
||||
cmdItf RootCmdPt
|
||||
config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func (rc *RootCmd) Port() int {
|
||||
return rc.port
|
||||
}
|
||||
|
||||
type CmdOpts struct {
|
||||
@@ -54,7 +59,7 @@ func WithLogName(logName string) func(*CmdOpts) {
|
||||
}
|
||||
|
||||
func NewRootCmd(name string, opts ...func(*CmdOpts)) *RootCmd {
|
||||
rootCmd := &RootCmd{Name: name}
|
||||
rootCmd := &RootCmd{Name: name, config: config.NewGlobalConfig()}
|
||||
cmd := cobra.Command{
|
||||
Use: "Start openIM application",
|
||||
Short: fmt.Sprintf(`Start %s `, name),
|
||||
@@ -96,7 +101,7 @@ func (rc *RootCmd) applyOptions(opts ...func(*CmdOpts)) *CmdOpts {
|
||||
}
|
||||
|
||||
func (rc *RootCmd) initializeLogger(cmdOpts *CmdOpts) error {
|
||||
logConfig := config.Config.Log
|
||||
logConfig := rc.config.Log
|
||||
|
||||
return log.InitFromConfig(
|
||||
|
||||
@@ -129,41 +134,36 @@ func (r *RootCmd) AddPortFlag() {
|
||||
r.Command.Flags().IntP(constant.FlagPort, "p", 0, "server listen port")
|
||||
}
|
||||
|
||||
func (r *RootCmd) getPortFlag(cmd *cobra.Command) (int, error) {
|
||||
func (r *RootCmd) getPortFlag(cmd *cobra.Command) int {
|
||||
port, err := cmd.Flags().GetInt(constant.FlagPort)
|
||||
if err != nil {
|
||||
// Wrapping the error with additional context
|
||||
return 0, errs.Wrap(err, "error getting port flag")
|
||||
return 0
|
||||
}
|
||||
if port == 0 {
|
||||
port, _ = r.PortFromConfig(constant.FlagPort)
|
||||
// port, err := r.PortFromConfig(constant.FlagPort)
|
||||
// if err != nil {
|
||||
// // Optionally wrap the error if it's an internal error needing context
|
||||
// return 0, errs.Wrap(err, "error getting port from config")
|
||||
// }
|
||||
port = r.PortFromConfig(constant.FlagPort)
|
||||
}
|
||||
return port, nil
|
||||
return port
|
||||
}
|
||||
|
||||
// // GetPortFlag returns the port flag.
|
||||
func (r *RootCmd) GetPortFlag() (int, error) {
|
||||
return r.port, nil
|
||||
func (r *RootCmd) GetPortFlag() int {
|
||||
return r.port
|
||||
}
|
||||
|
||||
func (r *RootCmd) AddPrometheusPortFlag() {
|
||||
r.Command.Flags().IntP(constant.FlagPrometheusPort, "", 0, "server prometheus listen port")
|
||||
}
|
||||
|
||||
func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) (int, error) {
|
||||
func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) int {
|
||||
port, err := cmd.Flags().GetInt(constant.FlagPrometheusPort)
|
||||
if err != nil || port == 0 {
|
||||
port, err = r.PortFromConfig(constant.FlagPrometheusPort)
|
||||
port = r.PortFromConfig(constant.FlagPrometheusPort)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return port, nil
|
||||
return port
|
||||
}
|
||||
|
||||
func (r *RootCmd) GetPrometheusPortFlag() int {
|
||||
@@ -173,7 +173,7 @@ func (r *RootCmd) GetPrometheusPortFlag() int {
|
||||
func (r *RootCmd) getConfFromCmdAndInit(cmdLines *cobra.Command) error {
|
||||
configFolderPath, _ := cmdLines.Flags().GetString(constant.FlagConf)
|
||||
fmt.Println("The directory of the configuration file to start the process:", configFolderPath)
|
||||
return config2.InitConfig(configFolderPath)
|
||||
return config2.InitConfig(r.config, configFolderPath)
|
||||
}
|
||||
|
||||
func (r *RootCmd) Execute() error {
|
||||
@@ -184,11 +184,8 @@ func (r *RootCmd) AddCommand(cmds ...*cobra.Command) {
|
||||
r.Command.AddCommand(cmds...)
|
||||
}
|
||||
|
||||
func (r *RootCmd) PortFromConfig(portType string) (int, error) {
|
||||
func (r *RootCmd) PortFromConfig(portType string) int {
|
||||
// Retrieve the port and cache it
|
||||
port, err := r.cmdItf.GetPortFromConfig(portType)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return port, nil
|
||||
port := r.cmdItf.GetPortFromConfig(portType)
|
||||
return port
|
||||
}
|
||||
|
||||
+117
-73
@@ -16,100 +16,144 @@ package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/startrpc"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/startrpc"
|
||||
)
|
||||
|
||||
type rpcInitFuc func(config *config2.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error
|
||||
|
||||
type RpcCmd struct {
|
||||
*RootCmd
|
||||
RpcRegisterName string
|
||||
initFunc rpcInitFuc
|
||||
}
|
||||
|
||||
func NewRpcCmd(name string) *RpcCmd {
|
||||
ret := &RpcCmd{NewRootCmd(name)}
|
||||
func NewRpcCmd(name string, initFunc rpcInitFuc) *RpcCmd {
|
||||
ret := &RpcCmd{RootCmd: NewRootCmd(name), initFunc: initFunc}
|
||||
ret.addPreRun()
|
||||
ret.addRunE()
|
||||
ret.SetRootCmdPt(ret)
|
||||
return ret
|
||||
}
|
||||
|
||||
func (a *RpcCmd) Exec() error {
|
||||
a.Command.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
portFlag, err := a.getPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.port = portFlag
|
||||
|
||||
prometheusPort, err := a.getPrometheusPortFlag(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.prometheusPort = prometheusPort
|
||||
|
||||
return nil
|
||||
func (a *RpcCmd) addPreRun() {
|
||||
a.Command.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
a.port = a.getPortFlag(cmd)
|
||||
a.prometheusPort = a.getPrometheusPortFlag(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *RpcCmd) addRunE() {
|
||||
a.Command.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
rpcRegisterName, err := a.GetRpcRegisterNameFromConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
return a.StartSvr(rpcRegisterName, a.initFunc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *RpcCmd) Exec() error {
|
||||
return a.Execute()
|
||||
}
|
||||
|
||||
func (a *RpcCmd) StartSvr(name string, rpcFn func(discov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error {
|
||||
portFlag, err := a.GetPortFlag()
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
a.port = portFlag
|
||||
func (a *RpcCmd) StartSvr(name string, rpcFn func(config *config2.GlobalConfig, disCov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error {
|
||||
if a.GetPortFlag() == 0 {
|
||||
return errs.Wrap(errors.New("port is required"))
|
||||
}
|
||||
|
||||
return startrpc.Start(portFlag, name, a.GetPrometheusPortFlag(), rpcFn)
|
||||
return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), a.config, rpcFn)
|
||||
}
|
||||
|
||||
func (a *RpcCmd) GetPortFromConfig(portType string) (int, error) {
|
||||
portConfigMap := map[string]map[string]int{
|
||||
RpcPushServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImPushPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.PushPrometheusPort[0],
|
||||
},
|
||||
RpcAuthServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImAuthPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.AuthPrometheusPort[0],
|
||||
},
|
||||
RpcConversationServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImConversationPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.ConversationPrometheusPort[0],
|
||||
},
|
||||
RpcFriendServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImFriendPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.FriendPrometheusPort[0],
|
||||
},
|
||||
RpcGroupServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImGroupPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.GroupPrometheusPort[0],
|
||||
},
|
||||
RpcMsgServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImMessagePort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.MessagePrometheusPort[0],
|
||||
},
|
||||
RpcThirdServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImThirdPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.ThirdPrometheusPort[0],
|
||||
},
|
||||
RpcUserServer: {
|
||||
constant.FlagPort: config2.Config.RpcPort.OpenImUserPort[0],
|
||||
constant.FlagPrometheusPort: config2.Config.Prometheus.UserPrometheusPort[0],
|
||||
},
|
||||
}
|
||||
|
||||
if portMap, ok := portConfigMap[a.Name]; ok {
|
||||
if port, ok := portMap[portType]; ok {
|
||||
return port, nil
|
||||
} else {
|
||||
return 0, errs.Wrap(errors.New("port type not found"), fmt.Sprintf("Failed to get port for %s", a.Name))
|
||||
func (a *RpcCmd) GetPortFromConfig(portType string) int {
|
||||
switch a.Name {
|
||||
case RpcPushServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImPushPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.PushPrometheusPort[0]
|
||||
}
|
||||
case RpcAuthServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImAuthPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.AuthPrometheusPort[0]
|
||||
}
|
||||
case RpcConversationServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImConversationPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.ConversationPrometheusPort[0]
|
||||
}
|
||||
case RpcFriendServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImFriendPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.FriendPrometheusPort[0]
|
||||
}
|
||||
case RpcGroupServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImGroupPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.GroupPrometheusPort[0]
|
||||
}
|
||||
case RpcMsgServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImMessagePort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.MessagePrometheusPort[0]
|
||||
}
|
||||
case RpcThirdServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImThirdPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.ThirdPrometheusPort[0]
|
||||
}
|
||||
case RpcUserServer:
|
||||
if portType == constant.FlagPort {
|
||||
return a.config.RpcPort.OpenImUserPort[0]
|
||||
}
|
||||
if portType == constant.FlagPrometheusPort {
|
||||
return a.config.Prometheus.UserPrometheusPort[0]
|
||||
}
|
||||
}
|
||||
|
||||
return 0, errs.Wrap(fmt.Errorf("server name '%s' not found", a.Name), "Failed to get port configuration")
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *RpcCmd) GetRpcRegisterNameFromConfig() (string, error) {
|
||||
switch a.Name {
|
||||
case RpcPushServer:
|
||||
return a.config.RpcRegisterName.OpenImPushName, nil
|
||||
case RpcAuthServer:
|
||||
return a.config.RpcRegisterName.OpenImAuthName, nil
|
||||
case RpcConversationServer:
|
||||
return a.config.RpcRegisterName.OpenImConversationName, nil
|
||||
case RpcFriendServer:
|
||||
return a.config.RpcRegisterName.OpenImFriendName, nil
|
||||
case RpcGroupServer:
|
||||
return a.config.RpcRegisterName.OpenImGroupName, nil
|
||||
case RpcMsgServer:
|
||||
return a.config.RpcRegisterName.OpenImMsgName, nil
|
||||
case RpcThirdServer:
|
||||
return a.config.RpcRegisterName.OpenImThirdName, nil
|
||||
case RpcUserServer:
|
||||
return a.config.RpcRegisterName.OpenImUserName, nil
|
||||
}
|
||||
return "", errs.Wrap(errors.New("can not get rpc register name"), a.Name)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var Config configStruct
|
||||
var Config GlobalConfig
|
||||
|
||||
const ConfKey = "conf"
|
||||
|
||||
@@ -57,7 +57,7 @@ type MYSQL struct {
|
||||
SlowThreshold int `yaml:"slowThreshold"`
|
||||
}
|
||||
|
||||
type configStruct struct {
|
||||
type GlobalConfig struct {
|
||||
Envs struct {
|
||||
Discovery string `yaml:"discovery"`
|
||||
}
|
||||
@@ -339,6 +339,10 @@ type configStruct struct {
|
||||
Notification notification `yaml:"notification"`
|
||||
}
|
||||
|
||||
func NewGlobalConfig() *GlobalConfig {
|
||||
return &GlobalConfig{}
|
||||
}
|
||||
|
||||
type notification struct {
|
||||
GroupCreated NotificationConf `yaml:"groupCreated"`
|
||||
GroupInfoSet NotificationConf `yaml:"groupInfoSet"`
|
||||
@@ -378,7 +382,7 @@ type notification struct {
|
||||
ConversationSetPrivate NotificationConf `yaml:"conversationSetPrivate"`
|
||||
}
|
||||
|
||||
func (c *configStruct) GetServiceNames() []string {
|
||||
func (c *GlobalConfig) GetServiceNames() []string {
|
||||
return []string{
|
||||
c.RpcRegisterName.OpenImUserName,
|
||||
c.RpcRegisterName.OpenImFriendName,
|
||||
@@ -392,7 +396,7 @@ func (c *configStruct) GetServiceNames() []string {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDiscoveryRegistry) error {
|
||||
func (c *GlobalConfig) RegisterConf2Registry(registry discoveryregistry.SvcDiscoveryRegistry) error {
|
||||
data, err := yaml.Marshal(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -400,11 +404,11 @@ func (c *configStruct) RegisterConf2Registry(registry discoveryregistry.SvcDisco
|
||||
return registry.RegisterConf2Registry(ConfKey, data)
|
||||
}
|
||||
|
||||
func (c *configStruct) GetConfFromRegistry(registry discoveryregistry.SvcDiscoveryRegistry) ([]byte, error) {
|
||||
func (c *GlobalConfig) GetConfFromRegistry(registry discoveryregistry.SvcDiscoveryRegistry) ([]byte, error) {
|
||||
return registry.GetConfFromRegistry(ConfKey)
|
||||
}
|
||||
|
||||
func (c *configStruct) EncodeConfig() []byte {
|
||||
func (c *GlobalConfig) EncodeConfig() []byte {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
if err := yaml.NewEncoder(buf).Encode(c); err != nil {
|
||||
panic(err)
|
||||
|
||||
+33
-60
@@ -21,10 +21,10 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
//go:embed version
|
||||
@@ -36,38 +36,32 @@ const (
|
||||
DefaultFolderPath = "../config/"
|
||||
)
|
||||
|
||||
// GetDefaultConfigPath returns the absolute path to the default configuration directory
|
||||
// relative to the executable's location. It is intended for use in Kubernetes container configurations.
|
||||
// Errors are returned to the caller to allow for flexible error handling.
|
||||
func GetDefaultConfigPath() (string, error) {
|
||||
// return absolude path join ../config/, this is k8s container config path.
|
||||
func GetDefaultConfigPath() string {
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", errs.Wrap(err, "failed to get executable path")
|
||||
fmt.Println("GetDefaultConfigPath error:", err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate the config path as a directory relative to the executable's location
|
||||
configPath, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../config/"))
|
||||
if err != nil {
|
||||
return "", errs.Wrap(err, "failed to get output directory")
|
||||
fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return configPath, nil
|
||||
return configPath
|
||||
}
|
||||
|
||||
// GetProjectRoot returns the absolute path of the project root directory by navigating up from the directory
|
||||
// containing the executable. It provides a detailed error if the path cannot be determined.
|
||||
func GetProjectRoot() (string, error) {
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", errs.Wrap(err, "failed to retrieve executable path")
|
||||
}
|
||||
// getProjectRoot returns the absolute path of the project root directory.
|
||||
func GetProjectRoot() string {
|
||||
executablePath, _ := os.Executable()
|
||||
|
||||
// Attempt to compute the project root by navigating up from the executable's directory
|
||||
projectRoot, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../../../../.."))
|
||||
if err != nil {
|
||||
return "", err
|
||||
fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
return projectRoot, nil
|
||||
return projectRoot
|
||||
}
|
||||
|
||||
func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options {
|
||||
@@ -93,62 +87,41 @@ func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options {
|
||||
// If the specified config file does not exist, it attempts to load from the project's default "config" directory.
|
||||
// It logs informative messages regarding the configuration path being used.
|
||||
func initConfig(config any, configName, configFolderPath string) error {
|
||||
configFilePath := filepath.Join(configFolderPath, configName)
|
||||
_, err := os.Stat(configFilePath)
|
||||
configFolderPath = filepath.Join(configFolderPath, configName)
|
||||
_, err := os.Stat(configFolderPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return errs.Wrap(err, fmt.Sprintf("failed to check existence of config file at path: %s", configFilePath))
|
||||
fmt.Println("stat config path error:", err.Error())
|
||||
return fmt.Errorf("stat config path error: %w", err)
|
||||
}
|
||||
var projectRoot string
|
||||
projectRoot, err = GetProjectRoot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configFilePath = filepath.Join(projectRoot, "config", configName)
|
||||
fmt.Printf("Configuration file not found at specified path. Falling back to project path: %s\n", configFilePath)
|
||||
configFolderPath = filepath.Join(GetProjectRoot(), "config", configName)
|
||||
fmt.Println("flag's path,enviment's path,default path all is not exist,using project path:", configFolderPath)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configFilePath)
|
||||
data, err := os.ReadFile(configFolderPath)
|
||||
if err != nil {
|
||||
// Wrap and return the error if reading the configuration file fails.
|
||||
return errs.Wrap(err, fmt.Sprintf("failed to read configuration file at path: %s", configFilePath))
|
||||
return fmt.Errorf("read file error: %w", err)
|
||||
}
|
||||
|
||||
if err = yaml.Unmarshal(data, config); err != nil {
|
||||
// Wrap and return the error if unmarshalling the YAML configuration fails.
|
||||
return errs.Wrap(err, "failed to unmarshal YAML configuration")
|
||||
return fmt.Errorf("unmarshal yaml error: %w", err)
|
||||
}
|
||||
fmt.Println("The path of the configuration file to start the process:", configFolderPath)
|
||||
|
||||
fmt.Printf("Configuration file loaded successfully from path: %s\n", configFilePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitConfig initializes the application configuration by loading it from a specified folder path.
|
||||
// If the folder path is not provided, it attempts to use the OPENIMCONFIG environment variable,
|
||||
// and as a fallback, it uses the default configuration path. It loads both the main configuration
|
||||
// and notification configuration, wrapping errors for better context.
|
||||
func InitConfig(configFolderPath string) error {
|
||||
// Use the provided config folder path, or fallback to environment variable or default path
|
||||
func InitConfig(config *GlobalConfig, configFolderPath string) error {
|
||||
if configFolderPath == "" {
|
||||
configFolderPath = os.Getenv("OPENIMCONFIG")
|
||||
if configFolderPath == "" {
|
||||
var err error
|
||||
configFolderPath, err = GetDefaultConfigPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
envConfigPath := os.Getenv("OPENIMCONFIG")
|
||||
if envConfigPath != "" {
|
||||
configFolderPath = envConfigPath
|
||||
} else {
|
||||
configFolderPath = GetDefaultConfigPath()
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the main configuration
|
||||
if err := initConfig(&Config, FileName, configFolderPath); err != nil {
|
||||
if err := initConfig(config, FileName, configFolderPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the notification configuration
|
||||
if err := initConfig(&Config.Notification, NotificationFileName, configFolderPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return initConfig(&config.Notification, NotificationFileName, configFolderPath)
|
||||
}
|
||||
|
||||
@@ -103,13 +103,14 @@ func TestInitConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
config *GlobalConfig
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := InitConfig(tt.args.configFolderPath); (err != nil) != tt.wantErr {
|
||||
if err := InitConfig(tt.config, tt.args.configFolderPath); (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
|
||||
Vendored
+21
-21
@@ -38,34 +38,34 @@ const (
|
||||
)
|
||||
|
||||
// NewRedis Initialize redis connection.
|
||||
func NewRedis() (redis.UniversalClient, error) {
|
||||
func NewRedis(config *config.GlobalConfig) (redis.UniversalClient, error) {
|
||||
if redisClient != nil {
|
||||
return redisClient, nil
|
||||
}
|
||||
|
||||
// Read configuration from environment variables
|
||||
overrideConfigFromEnv()
|
||||
overrideConfigFromEnv(config)
|
||||
|
||||
if len(config.Config.Redis.Address) == 0 {
|
||||
return nil, errs.Wrap(errors.New("redis address is empty"), "Redis configuration error")
|
||||
if len(config.Redis.Address) == 0 {
|
||||
return nil, errs.Wrap(errors.New("redis address is empty"))
|
||||
}
|
||||
specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound)
|
||||
var rdb redis.UniversalClient
|
||||
if len(config.Config.Redis.Address) > 1 || config.Config.Redis.ClusterMode {
|
||||
if len(config.Redis.Address) > 1 || config.Redis.ClusterMode {
|
||||
rdb = redis.NewClusterClient(&redis.ClusterOptions{
|
||||
Addrs: config.Config.Redis.Address,
|
||||
Username: config.Config.Redis.Username,
|
||||
Password: config.Config.Redis.Password, // no password set
|
||||
Addrs: config.Redis.Address,
|
||||
Username: config.Redis.Username,
|
||||
Password: config.Redis.Password, // no password set
|
||||
PoolSize: 50,
|
||||
MaxRetries: maxRetry,
|
||||
})
|
||||
} else {
|
||||
rdb = redis.NewClient(&redis.Options{
|
||||
Addr: config.Config.Redis.Address[0],
|
||||
Username: config.Config.Redis.Username,
|
||||
Password: config.Config.Redis.Password, // no password set
|
||||
DB: 0, // use default DB
|
||||
PoolSize: 100, // connection pool size
|
||||
Addr: config.Redis.Address[0],
|
||||
Username: config.Redis.Username,
|
||||
Password: config.Redis.Password,
|
||||
DB: 0, // use default DB
|
||||
PoolSize: 100, // connection pool size
|
||||
MaxRetries: maxRetry,
|
||||
})
|
||||
}
|
||||
@@ -75,33 +75,33 @@ func NewRedis() (redis.UniversalClient, error) {
|
||||
defer cancel()
|
||||
err = rdb.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
uriFormat := "address:%v, username:%s, clusterMode:%t, enablePipeline:%t"
|
||||
errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline)
|
||||
return nil, errs.Wrap(err, "Redis connection failed: %s", errMsg)
|
||||
errMsg := fmt.Sprintf("address:%s, username:%s, password:%s, clusterMode:%t, enablePipeline:%t", config.Redis.Address, config.Redis.Username,
|
||||
config.Redis.Password, config.Redis.ClusterMode, config.Redis.EnablePipeline)
|
||||
return nil, errs.Wrap(err, errMsg)
|
||||
}
|
||||
redisClient = rdb
|
||||
return rdb, err
|
||||
}
|
||||
|
||||
// overrideConfigFromEnv overrides configuration fields with environment variables if present.
|
||||
func overrideConfigFromEnv() {
|
||||
func overrideConfigFromEnv(config *config.GlobalConfig) {
|
||||
if envAddr := os.Getenv("REDIS_ADDRESS"); envAddr != "" {
|
||||
if envPort := os.Getenv("REDIS_PORT"); envPort != "" {
|
||||
addresses := strings.Split(envAddr, ",")
|
||||
for i, addr := range addresses {
|
||||
addresses[i] = addr + ":" + envPort
|
||||
}
|
||||
config.Config.Redis.Address = addresses
|
||||
config.Redis.Address = addresses
|
||||
} else {
|
||||
config.Config.Redis.Address = strings.Split(envAddr, ",")
|
||||
config.Redis.Address = strings.Split(envAddr, ",")
|
||||
}
|
||||
}
|
||||
|
||||
if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" {
|
||||
config.Config.Redis.Username = envUser
|
||||
config.Redis.Username = envUser
|
||||
}
|
||||
|
||||
if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" {
|
||||
config.Config.Redis.Password = envPass
|
||||
config.Redis.Password = envPass
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+11
-8
@@ -18,13 +18,16 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/OpenIMSDK/tools/mw/specialerror"
|
||||
|
||||
"github.com/dtm-labs/rockscache"
|
||||
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
"github.com/OpenIMSDK/tools/mw/specialerror"
|
||||
"github.com/OpenIMSDK/tools/utils"
|
||||
"github.com/dtm-labs/rockscache"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -128,7 +131,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
|
||||
v, err := rcClient.Fetch2(ctx, key, expire, func() (s string, err error) {
|
||||
t, err = fn(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", errs.Wrap(err)
|
||||
}
|
||||
bs, err := json.Marshal(t)
|
||||
if err != nil {
|
||||
@@ -139,7 +142,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
|
||||
return string(bs), nil
|
||||
})
|
||||
if err != nil {
|
||||
return t, err
|
||||
return t, errs.Wrap(err)
|
||||
}
|
||||
if write {
|
||||
return t, nil
|
||||
@@ -149,8 +152,8 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
|
||||
}
|
||||
err = json.Unmarshal([]byte(v), &t)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire)
|
||||
return t, errs.Wrap(err, "unmarshal failed")
|
||||
errInfo := fmt.Sprintf("cache json.Unmarshal failed, key:%s, value:%s, expire:%s", key, v, expire)
|
||||
return t, errs.Wrap(err, errInfo)
|
||||
}
|
||||
|
||||
return t, nil
|
||||
@@ -203,7 +206,7 @@ func batchGetCache2[T any, K comparable](
|
||||
fns func(ctx context.Context, key K) (T, error),
|
||||
) ([]T, error) {
|
||||
if len(keys) == 0 {
|
||||
return nil, nil
|
||||
return nil, errs.ErrArgs.Wrap("groupID is empty")
|
||||
}
|
||||
res := make([]T, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
@@ -214,7 +217,7 @@ func batchGetCache2[T any, K comparable](
|
||||
if errs.ErrRecordNotFound.Is(specialerror.ErrCode(errs.Unwrap(err))) {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
res = append(res, val)
|
||||
}
|
||||
|
||||
Vendored
+12
-11
@@ -121,13 +121,14 @@ type MsgModel interface {
|
||||
UnLockMessageTypeKey(ctx context.Context, clientMsgID string, TypeKey string) error
|
||||
}
|
||||
|
||||
func NewMsgCacheModel(client redis.UniversalClient) MsgModel {
|
||||
return &msgCache{rdb: client}
|
||||
func NewMsgCacheModel(client redis.UniversalClient, config *config.GlobalConfig) MsgModel {
|
||||
return &msgCache{rdb: client, config: config}
|
||||
}
|
||||
|
||||
type msgCache struct {
|
||||
metaCache
|
||||
rdb redis.UniversalClient
|
||||
rdb redis.UniversalClient
|
||||
config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func (c *msgCache) getMaxSeqKey(conversationID string) string {
|
||||
@@ -315,7 +316,7 @@ func (c *msgCache) allMessageCacheKey(conversationID string) string {
|
||||
}
|
||||
|
||||
func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) {
|
||||
if config.Config.Redis.EnablePipeline {
|
||||
if c.config.Redis.EnablePipeline {
|
||||
return c.PipeGetMessagesBySeq(ctx, conversationID, seqs)
|
||||
}
|
||||
|
||||
@@ -416,7 +417,7 @@ func (c *msgCache) ParallelGetMessagesBySeq(ctx context.Context, conversationID
|
||||
}
|
||||
|
||||
func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) {
|
||||
if config.Config.Redis.EnablePipeline {
|
||||
if c.config.Redis.EnablePipeline {
|
||||
return c.PipeSetMessageToCache(ctx, conversationID, msgs)
|
||||
}
|
||||
return c.ParallelSetMessageToCache(ctx, conversationID, msgs)
|
||||
@@ -431,7 +432,7 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str
|
||||
}
|
||||
|
||||
key := c.getMessageCacheKey(conversationID, msg.Seq)
|
||||
_ = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second)
|
||||
_ = pipe.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second)
|
||||
}
|
||||
|
||||
results, err := pipe.Exec(ctx)
|
||||
@@ -461,7 +462,7 @@ func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID
|
||||
}
|
||||
|
||||
key := c.getMessageCacheKey(conversationID, msg.Seq)
|
||||
if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
if err := c.rdb.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
return nil
|
||||
@@ -496,10 +497,10 @@ func (c *msgCache) UserDeleteMsgs(ctx context.Context, conversationID string, se
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
if err := c.rdb.Expire(ctx, delUserListKey, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
if err := c.rdb.Expire(ctx, delUserListKey, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
if err := c.rdb.Expire(ctx, userDelListKey, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
if err := c.rdb.Expire(ctx, userDelListKey, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
}
|
||||
@@ -604,7 +605,7 @@ func (c *msgCache) DelUserDeleteMsgsList(ctx context.Context, conversationID str
|
||||
}
|
||||
|
||||
func (c *msgCache) DeleteMessages(ctx context.Context, conversationID string, seqs []int64) error {
|
||||
if config.Config.Redis.EnablePipeline {
|
||||
if c.config.Redis.EnablePipeline {
|
||||
return c.PipeDeleteMessages(ctx, conversationID, seqs)
|
||||
}
|
||||
|
||||
@@ -686,7 +687,7 @@ func (c *msgCache) DelMsgFromCache(ctx context.Context, userID string, seqs []in
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
if err := c.rdb.Set(ctx, key, s, time.Duration(c.config.MsgCacheTimeout)*time.Second).Err(); err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+20
-12
@@ -22,12 +22,16 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
|
||||
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/user"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
|
||||
"github.com/dtm-labs/rockscache"
|
||||
relationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -62,7 +66,11 @@ type UserCacheRedis struct {
|
||||
rcClient *rockscache.Client
|
||||
}
|
||||
|
||||
func NewUserCacheRedis(rdb redis.UniversalClient, userDB relationtb.UserModelInterface, options rockscache.Options) UserCache {
|
||||
func NewUserCacheRedis(
|
||||
rdb redis.UniversalClient,
|
||||
userDB relationtb.UserModelInterface,
|
||||
options rockscache.Options,
|
||||
) UserCache {
|
||||
rcClient := rockscache.NewClient(rdb, options)
|
||||
|
||||
return &UserCacheRedis{
|
||||
@@ -193,13 +201,13 @@ func (u *UserCacheRedis) SetUserStatus(ctx context.Context, userID string, statu
|
||||
Status: constant.Online,
|
||||
PlatformIDs: []int32{platformID},
|
||||
}
|
||||
jsonData, err2 := json.Marshal(&onlineStatus)
|
||||
if err2 != nil {
|
||||
return errs.Wrap(err2)
|
||||
jsonData, err := json.Marshal(&onlineStatus)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
_, err2 = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result()
|
||||
if err2 != nil {
|
||||
return errs.Wrap(err2)
|
||||
_, err = u.rdb.HSet(ctx, key, userID, string(jsonData)).Result()
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
u.rdb.Expire(ctx, key, userOlineStatusExpireTime)
|
||||
|
||||
@@ -273,9 +281,9 @@ func (u *UserCacheRedis) refreshStatusOffline(ctx context.Context, userID string
|
||||
func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string, platformID int32, isNil bool, err error, result, key string) error {
|
||||
var onlineStatus user.OnlineStatus
|
||||
if !isNil {
|
||||
err2 := json.Unmarshal([]byte(result), &onlineStatus)
|
||||
if err2 != nil {
|
||||
return errs.Wrap(err, "json.Unmarshal failed")
|
||||
err := json.Unmarshal([]byte(result), &onlineStatus)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID))
|
||||
} else {
|
||||
|
||||
@@ -16,6 +16,7 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
@@ -33,14 +34,14 @@ type AuthDatabase interface {
|
||||
}
|
||||
|
||||
type authDatabase struct {
|
||||
cache cache.MsgModel
|
||||
|
||||
cache cache.MsgModel
|
||||
accessSecret string
|
||||
accessExpire int64
|
||||
config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64) AuthDatabase {
|
||||
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire}
|
||||
func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int64, config *config.GlobalConfig) AuthDatabase {
|
||||
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, config: config}
|
||||
}
|
||||
|
||||
// If the result is empty.
|
||||
@@ -56,7 +57,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
|
||||
}
|
||||
var deleteTokenKey []string
|
||||
for k, v := range tokens {
|
||||
_, err = tokenverify.GetClaimFromToken(k, authverify.Secret())
|
||||
_, err = tokenverify.GetClaimFromToken(k, authverify.Secret(a.config.Secret))
|
||||
if err != nil || v != constant.NormalToken {
|
||||
deleteTokenKey = append(deleteTokenKey, k)
|
||||
}
|
||||
|
||||
@@ -120,16 +120,33 @@ type CommonMsgDatabase interface {
|
||||
ConvertMsgsDocLen(ctx context.Context, conversationIDs []string)
|
||||
}
|
||||
|
||||
func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel) (CommonMsgDatabase, error) {
|
||||
producerToRedis, err := kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.LatestMsgToRedis.Topic)
|
||||
func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheModel cache.MsgModel, config *config.GlobalConfig) (CommonMsgDatabase, error) {
|
||||
producerConfig := &kafka.ProducerConfig{
|
||||
ProducerAck: config.Kafka.ProducerAck,
|
||||
CompressType: config.Kafka.CompressType,
|
||||
Username: config.Kafka.Username,
|
||||
Password: config.Kafka.Password,
|
||||
}
|
||||
|
||||
var tlsConfig *kafka.TLSConfig
|
||||
if config.Kafka.TLS != nil {
|
||||
tlsConfig = &kafka.TLSConfig{
|
||||
CACrt: config.Kafka.TLS.CACrt,
|
||||
ClientCrt: config.Kafka.TLS.ClientCrt,
|
||||
ClientKey: config.Kafka.TLS.ClientKey,
|
||||
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
}
|
||||
producerToRedis, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.LatestMsgToRedis.Topic, producerConfig, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
producerToMongo, err := kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.MsgToMongo.Topic)
|
||||
producerToMongo, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToMongo.Topic, producerConfig, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
producerToPush, err := kafka.NewKafkaProducer(config.Config.Kafka.Addr, config.Config.Kafka.MsgToPush.Topic)
|
||||
producerToPush, err := kafka.NewKafkaProducer(config.Kafka.Addr, config.Kafka.MsgToPush.Topic, producerConfig, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -142,10 +159,10 @@ func NewCommonMsgDatabase(msgDocModel unrelationtb.MsgDocModelInterface, cacheMo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database) (CommonMsgDatabase, error) {
|
||||
cacheModel := cache.NewMsgCacheModel(rdb)
|
||||
func InitCommonMsgDatabase(rdb redis.UniversalClient, database *mongo.Database, config *config.GlobalConfig) (CommonMsgDatabase, error) {
|
||||
cacheModel := cache.NewMsgCacheModel(rdb, config)
|
||||
msgDocModel := unrelation.NewMsgMongoDriver(database)
|
||||
return NewCommonMsgDatabase(msgDocModel, cacheModel)
|
||||
return NewCommonMsgDatabase(msgDocModel, cacheModel, config)
|
||||
}
|
||||
|
||||
type commonMsgDatabase struct {
|
||||
@@ -397,9 +414,9 @@ func (db *commonMsgDatabase) BatchInsertChat2Cache(ctx context.Context, conversa
|
||||
log.ZError(ctx, "db.cache.SetMaxSeq error", err, "conversationID", conversationID)
|
||||
prommetrics.SeqSetFailedCounter.Inc()
|
||||
}
|
||||
err2 := db.cache.SetHasReadSeqs(ctx, conversationID, userSeqMap)
|
||||
err = db.cache.SetHasReadSeqs(ctx, conversationID, userSeqMap)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "SetHasReadSeqs error", err2, "userSeqMap", userSeqMap, "conversationID", conversationID)
|
||||
log.ZError(ctx, "SetHasReadSeqs error", err, "userSeqMap", userSeqMap, "conversationID", conversationID)
|
||||
prommetrics.SeqSetFailedCounter.Inc()
|
||||
}
|
||||
return lastMaxSeq, isNew, errs.Wrap(err)
|
||||
|
||||
@@ -33,27 +33,28 @@ import (
|
||||
)
|
||||
|
||||
func Test_BatchInsertChat2DB(t *testing.T) {
|
||||
config.Config.Mongo.Address = []string{"192.168.44.128:37017"}
|
||||
// config.Config.Mongo.Timeout = 60
|
||||
config.Config.Mongo.Database = "openIM"
|
||||
// config.Config.Mongo.Source = "admin"
|
||||
config.Config.Mongo.Username = "root"
|
||||
config.Config.Mongo.Password = "openIM123"
|
||||
config.Config.Mongo.MaxPoolSize = 100
|
||||
config.Config.RetainChatRecords = 3650
|
||||
config.Config.ChatRecordsClearTime = "0 2 * * 3"
|
||||
conf := config.NewGlobalConfig()
|
||||
conf.Mongo.Address = []string{"192.168.44.128:37017"}
|
||||
// conf.Mongo.Timeout = 60
|
||||
conf.Mongo.Database = "openIM"
|
||||
// conf.Mongo.Source = "admin"
|
||||
conf.Mongo.Username = "root"
|
||||
conf.Mongo.Password = "openIM123"
|
||||
conf.Mongo.MaxPoolSize = 100
|
||||
conf.RetainChatRecords = 3650
|
||||
conf.ChatRecordsClearTime = "0 2 * * 3"
|
||||
|
||||
mongo, err := unrelation.NewMongo()
|
||||
mongo, err := unrelation.NewMongo(conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mongo.GetDatabase().Client().Ping(context.Background(), nil)
|
||||
err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
db := &commonMsgDatabase{
|
||||
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()),
|
||||
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)),
|
||||
}
|
||||
|
||||
//ctx := context.Background()
|
||||
@@ -70,7 +71,7 @@ func Test_BatchInsertChat2DB(t *testing.T) {
|
||||
//}
|
||||
|
||||
_ = db.BatchInsertChat2DB
|
||||
c := mongo.GetDatabase().Collection("msg")
|
||||
c := mongo.GetDatabase(conf.Mongo.Database).Collection("msg")
|
||||
|
||||
ch := make(chan int)
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
@@ -144,26 +145,27 @@ func Test_BatchInsertChat2DB(t *testing.T) {
|
||||
}
|
||||
|
||||
func GetDB() *commonMsgDatabase {
|
||||
config.Config.Mongo.Address = []string{"203.56.175.233:37017"}
|
||||
// config.Config.Mongo.Timeout = 60
|
||||
config.Config.Mongo.Database = "openim_v3"
|
||||
// config.Config.Mongo.Source = "admin"
|
||||
config.Config.Mongo.Username = "root"
|
||||
config.Config.Mongo.Password = "openIM123"
|
||||
config.Config.Mongo.MaxPoolSize = 100
|
||||
config.Config.RetainChatRecords = 3650
|
||||
config.Config.ChatRecordsClearTime = "0 2 * * 3"
|
||||
conf := config.NewGlobalConfig()
|
||||
conf.Mongo.Address = []string{"203.56.175.233:37017"}
|
||||
// conf.Mongo.Timeout = 60
|
||||
conf.Mongo.Database = "openim_v3"
|
||||
// conf.Mongo.Source = "admin"
|
||||
conf.Mongo.Username = "root"
|
||||
conf.Mongo.Password = "openIM123"
|
||||
conf.Mongo.MaxPoolSize = 100
|
||||
conf.RetainChatRecords = 3650
|
||||
conf.ChatRecordsClearTime = "0 2 * * 3"
|
||||
|
||||
mongo, err := unrelation.NewMongo()
|
||||
mongo, err := unrelation.NewMongo(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = mongo.GetDatabase().Client().Ping(context.Background(), nil)
|
||||
err = mongo.GetDatabase(conf.Mongo.Database).Client().Ping(context.Background(), nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &commonMsgDatabase{
|
||||
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase()),
|
||||
msgDocDatabase: unrelation.NewMsgMongoDriver(mongo.GetDatabase(conf.Mongo.Database)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
// Copyright © 2023 OpenIM. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// docURL: https://docs.aws.amazon.com/AmazonS3/latest/API/Welcome.html
|
||||
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/awserr"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
sdk "github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
|
||||
)
|
||||
|
||||
const (
|
||||
minPartSize int64 = 1024 * 1024 * 1 // 1MB
|
||||
maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB
|
||||
maxNumSize int64 = 10000
|
||||
)
|
||||
|
||||
// const (
|
||||
// imagePng = "png"
|
||||
// imageJpg = "jpg"
|
||||
// imageJpeg = "jpeg"
|
||||
// imageGif = "gif"
|
||||
// imageWebp = "webp"
|
||||
// )
|
||||
|
||||
// const successCode = http.StatusOK
|
||||
|
||||
// const (
|
||||
// videoSnapshotImagePng = "png"
|
||||
// videoSnapshotImageJpg = "jpg"
|
||||
// )
|
||||
|
||||
func NewAWS() (s3.Interface, error) {
|
||||
conf := config.Config.Object.Aws
|
||||
credential := credentials.NewStaticCredentials(
|
||||
conf.AccessKeyID, // accessKey
|
||||
conf.AccessKeySecret, // secretKey
|
||||
"") // stoken
|
||||
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String(conf.Region), // The area where the bucket is located
|
||||
Credentials: credential,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Aws{
|
||||
bucket: conf.Bucket,
|
||||
client: sdk.New(sess),
|
||||
credential: credential,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Aws struct {
|
||||
bucket string
|
||||
client *sdk.S3
|
||||
credential *credentials.Credentials
|
||||
}
|
||||
|
||||
func (a *Aws) Engine() string {
|
||||
return "aws"
|
||||
}
|
||||
|
||||
func (a *Aws) InitiateMultipartUpload(ctx context.Context, name string) (*s3.InitiateMultipartUploadResult, error) {
|
||||
input := &sdk.CreateMultipartUploadInput{
|
||||
Bucket: aws.String(a.bucket), // TODO: To be verified whether it is required
|
||||
Key: aws.String(name),
|
||||
}
|
||||
result, err := a.client.CreateMultipartUploadWithContext(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s3.InitiateMultipartUploadResult{
|
||||
Bucket: *result.Bucket,
|
||||
Key: *result.Key,
|
||||
UploadID: *result.UploadId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Aws) CompleteMultipartUpload(ctx context.Context, uploadID string, name string, parts []s3.Part) (*s3.CompleteMultipartUploadResult, error) {
|
||||
sdkParts := make([]*sdk.CompletedPart, len(parts))
|
||||
for i, part := range parts {
|
||||
sdkParts[i] = &sdk.CompletedPart{
|
||||
ETag: aws.String(part.ETag),
|
||||
PartNumber: aws.Int64(int64(part.PartNumber)),
|
||||
}
|
||||
}
|
||||
input := &sdk.CompleteMultipartUploadInput{
|
||||
Bucket: aws.String(a.bucket), // TODO: To be verified whether it is required
|
||||
Key: aws.String(name),
|
||||
UploadId: aws.String(uploadID),
|
||||
MultipartUpload: &sdk.CompletedMultipartUpload{
|
||||
Parts: sdkParts,
|
||||
},
|
||||
}
|
||||
result, err := a.client.CompleteMultipartUploadWithContext(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s3.CompleteMultipartUploadResult{
|
||||
Location: *result.Location,
|
||||
Bucket: *result.Bucket,
|
||||
Key: *result.Key,
|
||||
ETag: *result.ETag,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Aws) PartSize(ctx context.Context, size int64) (int64, error) {
|
||||
if size <= 0 {
|
||||
return 0, errors.New("size must be greater than 0")
|
||||
}
|
||||
if size > maxPartSize*maxNumSize {
|
||||
return 0, fmt.Errorf("AWS size must be less than the maximum allowed limit")
|
||||
}
|
||||
if size <= minPartSize*maxNumSize {
|
||||
return minPartSize, nil
|
||||
}
|
||||
partSize := size / maxNumSize
|
||||
if size%maxNumSize != 0 {
|
||||
partSize++
|
||||
}
|
||||
return partSize, nil
|
||||
}
|
||||
|
||||
func (a *Aws) DeleteObject(ctx context.Context, name string) error {
|
||||
_, err := a.client.DeleteObjectWithContext(ctx, &sdk.DeleteObjectInput{
|
||||
Bucket: aws.String(a.bucket),
|
||||
Key: aws.String(name),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Aws) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) {
|
||||
result, err := a.client.CopyObjectWithContext(ctx, &sdk.CopyObjectInput{
|
||||
Bucket: aws.String(a.bucket),
|
||||
Key: aws.String(dst),
|
||||
CopySource: aws.String(src),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s3.CopyObjectInfo{
|
||||
ETag: *result.CopyObjectResult.ETag,
|
||||
Key: dst,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Aws) IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if aerr, ok := err.(awserr.Error); ok {
|
||||
switch aerr.Code() {
|
||||
case sdk.ErrCodeNoSuchKey:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Aws) AbortMultipartUpload(ctx context.Context, uploadID string, name string) error {
|
||||
_, err := a.client.AbortMultipartUploadWithContext(ctx, &sdk.AbortMultipartUploadInput{
|
||||
Bucket: aws.String(a.bucket),
|
||||
Key: aws.String(name),
|
||||
UploadId: aws.String(uploadID),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Aws) ListUploadedParts(ctx context.Context, uploadID string, name string, partNumberMarker int, maxParts int) (*s3.ListUploadedPartsResult, error) {
|
||||
result, err := a.client.ListPartsWithContext(ctx, &sdk.ListPartsInput{
|
||||
Bucket: aws.String(a.bucket),
|
||||
Key: aws.String(name),
|
||||
UploadId: aws.String(uploadID),
|
||||
MaxParts: aws.Int64(int64(maxParts)),
|
||||
PartNumberMarker: aws.Int64(int64(partNumberMarker)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := make([]s3.UploadedPart, len(result.Parts))
|
||||
for i, part := range result.Parts {
|
||||
parts[i] = s3.UploadedPart{
|
||||
PartNumber: int(*part.PartNumber),
|
||||
LastModified: *part.LastModified,
|
||||
Size: *part.Size,
|
||||
ETag: *part.ETag,
|
||||
}
|
||||
}
|
||||
return &s3.ListUploadedPartsResult{
|
||||
Key: *result.Key,
|
||||
UploadID: *result.UploadId,
|
||||
NextPartNumberMarker: int(*result.NextPartNumberMarker),
|
||||
MaxParts: int(*result.MaxParts),
|
||||
UploadedParts: parts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Aws) PartLimit() *s3.PartLimit {
|
||||
return &s3.PartLimit{
|
||||
MinPartSize: minPartSize,
|
||||
MaxPartSize: maxPartSize,
|
||||
MaxNumSize: maxNumSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Aws) PresignedPutObject(ctx context.Context, name string, expire time.Duration) (string, error) {
|
||||
req, _ := a.client.PutObjectRequest(&sdk.PutObjectInput{
|
||||
Bucket: aws.String(a.bucket),
|
||||
Key: aws.String(name),
|
||||
})
|
||||
url, err := req.Presign(expire)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func (a *Aws) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, error) {
|
||||
result, err := a.client.GetObjectWithContext(ctx, &sdk.GetObjectInput{
|
||||
Bucket: aws.String(a.bucket),
|
||||
Key: aws.String(name),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := &s3.ObjectInfo{
|
||||
Key: name,
|
||||
ETag: *result.ETag,
|
||||
Size: *result.ContentLength,
|
||||
LastModified: *result.LastModified,
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// AccessURL todo.
|
||||
func (a *Aws) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
|
||||
// todo
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (a *Aws) FormData(ctx context.Context, name string, size int64, contentType string, duration time.Duration) (*s3.FormData, error) {
|
||||
// todo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Aws) AuthSign(ctx context.Context, uploadID string, name string, expire time.Duration, partNumbers []int) (*s3.AuthSignResult, error) {
|
||||
// todo
|
||||
return nil, nil
|
||||
}
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
|
||||
"github.com/tencentyun/cos-go-sdk-v5"
|
||||
)
|
||||
@@ -50,13 +50,15 @@ const (
|
||||
|
||||
const successCode = http.StatusOK
|
||||
|
||||
const (
|
||||
// videoSnapshotImagePng = "png"
|
||||
// videoSnapshotImageJpg = "jpg"
|
||||
)
|
||||
type Config struct {
|
||||
BucketURL string
|
||||
SecretID string
|
||||
SecretKey string
|
||||
SessionToken string
|
||||
PublicRead bool
|
||||
}
|
||||
|
||||
func NewCos() (s3.Interface, error) {
|
||||
conf := config.Config.Object.Cos
|
||||
func NewCos(conf Config) (s3.Interface, error) {
|
||||
u, err := url.Parse(conf.BucketURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -69,6 +71,7 @@ func NewCos() (s3.Interface, error) {
|
||||
},
|
||||
})
|
||||
return &Cos{
|
||||
publicRead: conf.PublicRead,
|
||||
copyURL: u.Host + "/",
|
||||
client: client,
|
||||
credential: client.GetCredential(),
|
||||
@@ -76,6 +79,7 @@ func NewCos() (s3.Interface, error) {
|
||||
}
|
||||
|
||||
type Cos struct {
|
||||
publicRead bool
|
||||
copyURL string
|
||||
client *cos.Client
|
||||
credential *cos.Credential
|
||||
@@ -226,7 +230,7 @@ func (c *Cos) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO
|
||||
}
|
||||
|
||||
func (c *Cos) IsNotFound(err error) bool {
|
||||
switch e := err.(type) {
|
||||
switch e := errs.Unwrap(err).(type) {
|
||||
case *cos.ErrorResponse:
|
||||
return e.Response.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey"
|
||||
default:
|
||||
@@ -327,7 +331,7 @@ func (c *Cos) AccessURL(ctx context.Context, name string, expire time.Duration,
|
||||
}
|
||||
|
||||
func (c *Cos) getPresignedURL(ctx context.Context, name string, expire time.Duration, opt *cos.PresignedURLOptions) (*url.URL, error) {
|
||||
if !config.Config.Object.Cos.PublicRead {
|
||||
if !c.publicRead {
|
||||
return c.client.Object.GetPresignedURL(ctx, http.MethodGet, name, c.credential.SecretID, c.credential.SecretKey, expire, opt)
|
||||
}
|
||||
return c.client.Object.GetObjectURL(name), nil
|
||||
|
||||
@@ -42,51 +42,79 @@ func ImageWidthHeight(img image.Image) (int, int) {
|
||||
return bounds.X, bounds.Y
|
||||
}
|
||||
|
||||
// resizeImage resizes an image to a specified maximum width and height, maintaining the aspect ratio.
|
||||
// If both maxWidth and maxHeight are set to 0, the original image is returned.
|
||||
// If both are non-zero, the image is scaled to fit within the constraints while maintaining aspect ratio.
|
||||
// If only one of maxWidth or maxHeight is non-zero, the image is scaled accordingly.
|
||||
func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
|
||||
bounds := img.Bounds()
|
||||
imgWidth, imgHeight := bounds.Dx(), bounds.Dy()
|
||||
imgWidth := bounds.Max.X
|
||||
imgHeight := bounds.Max.Y
|
||||
|
||||
// Return original image if no resizing is needed.
|
||||
// 计算缩放比例
|
||||
scaleWidth := float64(maxWidth) / float64(imgWidth)
|
||||
scaleHeight := float64(maxHeight) / float64(imgHeight)
|
||||
|
||||
// 如果都为0,则不缩放,返回原始图片
|
||||
if maxWidth == 0 && maxHeight == 0 {
|
||||
return img
|
||||
}
|
||||
|
||||
var scale float64 = 1
|
||||
// 如果宽度和高度都大于0,则选择较小的缩放比例,以保持宽高比
|
||||
if maxWidth > 0 && maxHeight > 0 {
|
||||
scaleWidth := float64(maxWidth) / float64(imgWidth)
|
||||
scaleHeight := float64(maxHeight) / float64(imgHeight)
|
||||
// Choose the smaller scale to fit both constraints.
|
||||
scale = min(scaleWidth, scaleHeight)
|
||||
} else if maxWidth > 0 {
|
||||
scale = float64(maxWidth) / float64(imgWidth)
|
||||
} else if maxHeight > 0 {
|
||||
scale = float64(maxHeight) / float64(imgHeight)
|
||||
}
|
||||
|
||||
newWidth := int(float64(imgWidth) * scale)
|
||||
newHeight := int(float64(imgHeight) * scale)
|
||||
|
||||
// Resize the image by creating a new image and manually copying pixels.
|
||||
thumbnail := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight))
|
||||
for y := 0; y < newHeight; y++ {
|
||||
for x := 0; x < newWidth; x++ {
|
||||
srcX := int(float64(x) / scale)
|
||||
srcY := int(float64(y) / scale)
|
||||
thumbnail.Set(x, y, img.At(srcX, srcY))
|
||||
scale := scaleWidth
|
||||
if scaleHeight < scaleWidth {
|
||||
scale = scaleHeight
|
||||
}
|
||||
|
||||
// 计算缩略图尺寸
|
||||
thumbnailWidth := int(float64(imgWidth) * scale)
|
||||
thumbnailHeight := int(float64(imgHeight) * scale)
|
||||
|
||||
// 使用"image"库的Resample方法生成缩略图
|
||||
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
|
||||
for y := 0; y < thumbnailHeight; y++ {
|
||||
for x := 0; x < thumbnailWidth; x++ {
|
||||
srcX := int(float64(x) / scale)
|
||||
srcY := int(float64(y) / scale)
|
||||
thumbnail.Set(x, y, img.At(srcX, srcY))
|
||||
}
|
||||
}
|
||||
|
||||
return thumbnail
|
||||
}
|
||||
|
||||
return thumbnail
|
||||
}
|
||||
// 如果只指定了宽度或高度,则根据最大不超过的规则生成缩略图
|
||||
if maxWidth > 0 {
|
||||
thumbnailWidth := maxWidth
|
||||
thumbnailHeight := int(float64(imgHeight) * scaleWidth)
|
||||
|
||||
// min returns the smaller of x or y.
|
||||
func min(x, y float64) float64 {
|
||||
if x < y {
|
||||
return x
|
||||
// 使用"image"库的Resample方法生成缩略图
|
||||
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
|
||||
for y := 0; y < thumbnailHeight; y++ {
|
||||
for x := 0; x < thumbnailWidth; x++ {
|
||||
srcX := int(float64(x) / scaleWidth)
|
||||
srcY := int(float64(y) / scaleWidth)
|
||||
thumbnail.Set(x, y, img.At(srcX, srcY))
|
||||
}
|
||||
}
|
||||
|
||||
return thumbnail
|
||||
}
|
||||
return y
|
||||
|
||||
if maxHeight > 0 {
|
||||
thumbnailWidth := int(float64(imgWidth) * scaleHeight)
|
||||
thumbnailHeight := maxHeight
|
||||
|
||||
// 使用"image"库的Resample方法生成缩略图
|
||||
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
|
||||
for y := 0; y < thumbnailHeight; y++ {
|
||||
for x := 0; x < thumbnailWidth; x++ {
|
||||
srcX := int(float64(x) / scaleHeight)
|
||||
srcY := int(float64(y) / scaleHeight)
|
||||
thumbnail.Set(x, y, img.At(srcX, srcY))
|
||||
}
|
||||
}
|
||||
|
||||
return thumbnail
|
||||
}
|
||||
|
||||
// 默认情况下,返回原始图片
|
||||
return img
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -33,7 +34,6 @@ import (
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"github.com/minio/minio-go/v7/pkg/signer"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/db/cache"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
|
||||
)
|
||||
@@ -43,7 +43,7 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
minPartSize int64 = 1024 * 1024 * 5 // 1MB
|
||||
minPartSize int64 = 1024 * 1024 * 5 // 5MB
|
||||
maxPartSize int64 = 1024 * 1024 * 1024 * 5 // 5GB
|
||||
maxNumSize int64 = 10000
|
||||
)
|
||||
@@ -57,13 +57,23 @@ const (
|
||||
|
||||
const successCode = http.StatusOK
|
||||
|
||||
func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
|
||||
u, err := url.Parse(config.Config.Object.Minio.Endpoint)
|
||||
type Config struct {
|
||||
Bucket string
|
||||
Endpoint string
|
||||
AccessKeyID string
|
||||
SecretAccessKey string
|
||||
SessionToken string
|
||||
SignEndpoint string
|
||||
PublicRead bool
|
||||
}
|
||||
|
||||
func NewMinio(cache cache.MinioCache, conf Config) (s3.Interface, error) {
|
||||
u, err := url.Parse(conf.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := &minio.Options{
|
||||
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, config.Config.Object.Minio.SessionToken),
|
||||
Creds: credentials.NewStaticV4(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken),
|
||||
Secure: u.Scheme == "https",
|
||||
}
|
||||
client, err := minio.New(u.Host, opts)
|
||||
@@ -71,26 +81,27 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
|
||||
return nil, err
|
||||
}
|
||||
m := &Minio{
|
||||
bucket: config.Config.Object.Minio.Bucket,
|
||||
conf: conf,
|
||||
bucket: conf.Bucket,
|
||||
core: &minio.Core{Client: client},
|
||||
lock: &sync.Mutex{},
|
||||
init: false,
|
||||
cache: cache,
|
||||
}
|
||||
if config.Config.Object.Minio.SignEndpoint == "" || config.Config.Object.Minio.SignEndpoint == config.Config.Object.Minio.Endpoint {
|
||||
if conf.SignEndpoint == "" || conf.SignEndpoint == conf.Endpoint {
|
||||
m.opts = opts
|
||||
m.sign = m.core.Client
|
||||
m.prefix = u.Path
|
||||
u.Path = ""
|
||||
config.Config.Object.Minio.Endpoint = u.String()
|
||||
m.signEndpoint = config.Config.Object.Minio.Endpoint
|
||||
conf.Endpoint = u.String()
|
||||
m.signEndpoint = conf.Endpoint
|
||||
} else {
|
||||
su, err := url.Parse(config.Config.Object.Minio.SignEndpoint)
|
||||
su, err := url.Parse(conf.SignEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.opts = &minio.Options{
|
||||
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, config.Config.Object.Minio.SessionToken),
|
||||
Creds: credentials.NewStaticV4(conf.AccessKeyID, conf.SecretAccessKey, conf.SessionToken),
|
||||
Secure: su.Scheme == "https",
|
||||
}
|
||||
m.sign, err = minio.New(su.Host, m.opts)
|
||||
@@ -99,8 +110,8 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
|
||||
}
|
||||
m.prefix = su.Path
|
||||
su.Path = ""
|
||||
config.Config.Object.Minio.SignEndpoint = su.String()
|
||||
m.signEndpoint = config.Config.Object.Minio.SignEndpoint
|
||||
conf.SignEndpoint = su.String()
|
||||
m.signEndpoint = conf.SignEndpoint
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -111,6 +122,7 @@ func NewMinio(cache cache.MinioCache) (s3.Interface, error) {
|
||||
}
|
||||
|
||||
type Minio struct {
|
||||
conf Config
|
||||
bucket string
|
||||
signEndpoint string
|
||||
location string
|
||||
@@ -132,31 +144,30 @@ func (m *Minio) initMinio(ctx context.Context) error {
|
||||
if m.init {
|
||||
return nil
|
||||
}
|
||||
conf := config.Config.Object.Minio
|
||||
exists, err := m.core.Client.BucketExists(ctx, conf.Bucket)
|
||||
exists, err := m.core.Client.BucketExists(ctx, m.conf.Bucket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check bucket exists error: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
if err = m.core.Client.MakeBucket(ctx, conf.Bucket, minio.MakeBucketOptions{}); err != nil {
|
||||
if err = m.core.Client.MakeBucket(ctx, m.conf.Bucket, minio.MakeBucketOptions{}); err != nil {
|
||||
return fmt.Errorf("make bucket error: %w", err)
|
||||
}
|
||||
}
|
||||
if conf.PublicRead {
|
||||
if m.conf.PublicRead {
|
||||
policy := fmt.Sprintf(
|
||||
`{"Version": "2012-10-17","Statement": [{"Action": ["s3:GetObject","s3:PutObject"],"Effect": "Allow","Principal": {"AWS": ["*"]},"Resource": ["arn:aws:s3:::%s/*"],"Sid": ""}]}`,
|
||||
conf.Bucket,
|
||||
m.conf.Bucket,
|
||||
)
|
||||
if err = m.core.Client.SetBucketPolicy(ctx, conf.Bucket, policy); err != nil {
|
||||
if err = m.core.Client.SetBucketPolicy(ctx, m.conf.Bucket, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.location, err = m.core.Client.GetBucketLocation(ctx, conf.Bucket)
|
||||
m.location, err = m.core.Client.GetBucketLocation(ctx, m.conf.Bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func() {
|
||||
if conf.SignEndpoint == "" || conf.SignEndpoint == conf.Endpoint {
|
||||
if m.conf.SignEndpoint == "" || m.conf.SignEndpoint == m.conf.Endpoint {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -176,7 +187,7 @@ func (m *Minio) initMinio(ctx context.Context) error {
|
||||
blc := reflect.ValueOf(m.sign).Elem().FieldByName("bucketLocCache")
|
||||
vblc := reflect.New(reflect.PtrTo(blc.Type()))
|
||||
*(*unsafe.Pointer)(vblc.UnsafePointer()) = unsafe.Pointer(blc.UnsafeAddr())
|
||||
vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(conf.Bucket, m.location)
|
||||
vblc.Elem().Elem().Interface().(interface{ Set(string, string) }).Set(m.conf.Bucket, m.location)
|
||||
}()
|
||||
m.init = true
|
||||
return nil
|
||||
@@ -341,10 +352,7 @@ func (m *Minio) CopyObject(ctx context.Context, src string, dst string) (*s3.Cop
|
||||
}
|
||||
|
||||
func (m *Minio) IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
switch e := err.(type) {
|
||||
switch e := errs.Unwrap(err).(type) {
|
||||
case minio.ErrorResponse:
|
||||
return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey"
|
||||
case *minio.ErrorResponse:
|
||||
@@ -397,7 +405,7 @@ func (m *Minio) PresignedGetObject(ctx context.Context, name string, expire time
|
||||
rawURL *url.URL
|
||||
err error
|
||||
)
|
||||
if config.Config.Object.Minio.PublicRead {
|
||||
if m.conf.PublicRead {
|
||||
rawURL, err = makeTargetURL(m.sign, m.bucket, name, m.location, false, query)
|
||||
} else {
|
||||
rawURL, err = m.sign.PresignedGetObject(ctx, m.bucket, name, expire, query)
|
||||
|
||||
+15
-11
@@ -32,7 +32,6 @@ import (
|
||||
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/aliyun/aliyun-oss-go-sdk/oss"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
|
||||
)
|
||||
|
||||
@@ -52,13 +51,17 @@ const (
|
||||
|
||||
const successCode = http.StatusOK
|
||||
|
||||
/* const (
|
||||
videoSnapshotImagePng = "png"
|
||||
videoSnapshotImageJpg = "jpg"
|
||||
) */
|
||||
type Config struct {
|
||||
Endpoint string
|
||||
Bucket string
|
||||
BucketURL string
|
||||
AccessKeyID string
|
||||
AccessKeySecret string
|
||||
SessionToken string
|
||||
PublicRead bool
|
||||
}
|
||||
|
||||
func NewOSS() (s3.Interface, error) {
|
||||
conf := config.Config.Object.Oss
|
||||
func NewOSS(conf Config) (s3.Interface, error) {
|
||||
if conf.BucketURL == "" {
|
||||
return nil, errs.Wrap(errors.New("bucket url is empty"))
|
||||
}
|
||||
@@ -78,6 +81,7 @@ func NewOSS() (s3.Interface, error) {
|
||||
bucket: bucket,
|
||||
credentials: client.Config.GetCredentials(),
|
||||
um: *(*urlMaker)(reflect.ValueOf(bucket.Client.Conn).Elem().FieldByName("url").UnsafePointer()),
|
||||
publicRead: conf.PublicRead,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -86,6 +90,7 @@ type OSS struct {
|
||||
bucket *oss.Bucket
|
||||
credentials oss.Credentials
|
||||
um urlMaker
|
||||
publicRead bool
|
||||
}
|
||||
|
||||
func (o *OSS) Engine() string {
|
||||
@@ -236,7 +241,7 @@ func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyO
|
||||
}
|
||||
|
||||
func (o *OSS) IsNotFound(err error) bool {
|
||||
switch e := err.(type) {
|
||||
switch e := errs.Unwrap(err).(type) {
|
||||
case oss.ServiceError:
|
||||
return e.StatusCode == http.StatusNotFound || e.Code == "NoSuchKey"
|
||||
case *oss.ServiceError:
|
||||
@@ -282,7 +287,6 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin
|
||||
}
|
||||
|
||||
func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
|
||||
publicRead := config.Config.Object.Oss.PublicRead
|
||||
var opts []oss.Option
|
||||
if opt != nil {
|
||||
if opt.Image != nil {
|
||||
@@ -310,7 +314,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
|
||||
process += ",format," + format
|
||||
opts = append(opts, oss.Process(process))
|
||||
}
|
||||
if !publicRead {
|
||||
if !o.publicRead {
|
||||
if opt.ContentType != "" {
|
||||
opts = append(opts, oss.ResponseContentType(opt.ContentType))
|
||||
}
|
||||
@@ -324,7 +328,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
|
||||
} else if expire < time.Second {
|
||||
expire = time.Second
|
||||
}
|
||||
if !publicRead {
|
||||
if !o.publicRead {
|
||||
return o.bucket.SignURL(name, http.MethodGet, int64(expire/time.Second), opts...)
|
||||
}
|
||||
rawParams, err := oss.GetRawParams(opts)
|
||||
|
||||
@@ -36,13 +36,14 @@ const (
|
||||
)
|
||||
|
||||
type Mongo struct {
|
||||
db *mongo.Client
|
||||
db *mongo.Client
|
||||
config *config.GlobalConfig
|
||||
}
|
||||
|
||||
// NewMongo Initialize MongoDB connection.
|
||||
func NewMongo() (*Mongo, error) {
|
||||
func NewMongo(config *config.GlobalConfig) (*Mongo, error) {
|
||||
specialerror.AddReplace(mongo.ErrNoDocuments, errs.ErrRecordNotFound)
|
||||
uri := buildMongoURI()
|
||||
uri := buildMongoURI(config)
|
||||
|
||||
var mongoClient *mongo.Client
|
||||
var err error
|
||||
@@ -56,7 +57,7 @@ func NewMongo() (*Mongo, error) {
|
||||
if err = mongoClient.Ping(ctx, nil); err != nil {
|
||||
return nil, errs.Wrap(err, uri)
|
||||
}
|
||||
return &Mongo{db: mongoClient}, nil
|
||||
return &Mongo{db: mongoClient, config: config}, nil
|
||||
}
|
||||
if shouldRetry(err) {
|
||||
time.Sleep(time.Second) // exponential backoff could be implemented here
|
||||
@@ -66,14 +67,14 @@ func NewMongo() (*Mongo, error) {
|
||||
return nil, errs.Wrap(err, uri)
|
||||
}
|
||||
|
||||
func buildMongoURI() string {
|
||||
func buildMongoURI(config *config.GlobalConfig) string {
|
||||
uri := os.Getenv("MONGO_URI")
|
||||
if uri != "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
if config.Config.Mongo.Uri != "" {
|
||||
return config.Config.Mongo.Uri
|
||||
if config.Mongo.Uri != "" {
|
||||
return config.Mongo.Uri
|
||||
}
|
||||
|
||||
username := os.Getenv("MONGO_OPENIM_USERNAME")
|
||||
@@ -84,21 +85,21 @@ func buildMongoURI() string {
|
||||
maxPoolSize := os.Getenv("MONGO_MAX_POOL_SIZE")
|
||||
|
||||
if username == "" {
|
||||
username = config.Config.Mongo.Username
|
||||
username = config.Mongo.Username
|
||||
}
|
||||
if password == "" {
|
||||
password = config.Config.Mongo.Password
|
||||
password = config.Mongo.Password
|
||||
}
|
||||
if address == "" {
|
||||
address = strings.Join(config.Config.Mongo.Address, ",")
|
||||
address = strings.Join(config.Mongo.Address, ",")
|
||||
} else if port != "" {
|
||||
address = fmt.Sprintf("%s:%s", address, port)
|
||||
}
|
||||
if database == "" {
|
||||
database = config.Config.Mongo.Database
|
||||
database = config.Mongo.Database
|
||||
}
|
||||
if maxPoolSize == "" {
|
||||
maxPoolSize = fmt.Sprint(config.Config.Mongo.MaxPoolSize)
|
||||
maxPoolSize = fmt.Sprint(config.Mongo.MaxPoolSize)
|
||||
}
|
||||
|
||||
uriFormat := "mongodb://%s/%s?maxPoolSize=%s"
|
||||
@@ -122,8 +123,8 @@ func (m *Mongo) GetClient() *mongo.Client {
|
||||
}
|
||||
|
||||
// GetDatabase returns the specific database from MongoDB.
|
||||
func (m *Mongo) GetDatabase() *mongo.Database {
|
||||
return m.db.Database(config.Config.Mongo.Database)
|
||||
func (m *Mongo) GetDatabase(database string) *mongo.Database {
|
||||
return m.db.Database(database)
|
||||
}
|
||||
|
||||
// CreateMsgIndex creates an index for messages in MongoDB.
|
||||
@@ -133,7 +134,7 @@ func (m *Mongo) CreateMsgIndex() error {
|
||||
|
||||
// createMongoIndex creates an index in a MongoDB collection.
|
||||
func (m *Mongo) createMongoIndex(collection string, isUnique bool, keys ...string) error {
|
||||
db := m.GetDatabase().Collection(collection)
|
||||
db := m.GetDatabase(m.config.Mongo.Database).Collection(collection)
|
||||
opts := options.CreateIndexes().SetMaxTime(10 * time.Second)
|
||||
indexView := db.Indexes()
|
||||
|
||||
|
||||
@@ -27,17 +27,17 @@ import (
|
||||
|
||||
type ServiceAddresses map[string][]int
|
||||
|
||||
func getServiceAddresses() ServiceAddresses {
|
||||
func getServiceAddresses(config *config2.GlobalConfig) ServiceAddresses {
|
||||
return ServiceAddresses{
|
||||
config2.Config.RpcRegisterName.OpenImUserName: config2.Config.RpcPort.OpenImUserPort,
|
||||
config2.Config.RpcRegisterName.OpenImFriendName: config2.Config.RpcPort.OpenImFriendPort,
|
||||
config2.Config.RpcRegisterName.OpenImMsgName: config2.Config.RpcPort.OpenImMessagePort,
|
||||
config2.Config.RpcRegisterName.OpenImMessageGatewayName: config2.Config.LongConnSvr.OpenImMessageGatewayPort,
|
||||
config2.Config.RpcRegisterName.OpenImGroupName: config2.Config.RpcPort.OpenImGroupPort,
|
||||
config2.Config.RpcRegisterName.OpenImAuthName: config2.Config.RpcPort.OpenImAuthPort,
|
||||
config2.Config.RpcRegisterName.OpenImPushName: config2.Config.RpcPort.OpenImPushPort,
|
||||
config2.Config.RpcRegisterName.OpenImConversationName: config2.Config.RpcPort.OpenImConversationPort,
|
||||
config2.Config.RpcRegisterName.OpenImThirdName: config2.Config.RpcPort.OpenImThirdPort,
|
||||
config.RpcRegisterName.OpenImUserName: config.RpcPort.OpenImUserPort,
|
||||
config.RpcRegisterName.OpenImFriendName: config.RpcPort.OpenImFriendPort,
|
||||
config.RpcRegisterName.OpenImMsgName: config.RpcPort.OpenImMessagePort,
|
||||
config.RpcRegisterName.OpenImMessageGatewayName: config.LongConnSvr.OpenImMessageGatewayPort,
|
||||
config.RpcRegisterName.OpenImGroupName: config.RpcPort.OpenImGroupPort,
|
||||
config.RpcRegisterName.OpenImAuthName: config.RpcPort.OpenImAuthPort,
|
||||
config.RpcRegisterName.OpenImPushName: config.RpcPort.OpenImPushPort,
|
||||
config.RpcRegisterName.OpenImConversationName: config.RpcPort.OpenImConversationPort,
|
||||
config.RpcRegisterName.OpenImThirdName: config.RpcPort.OpenImThirdPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ type ConnDirect struct {
|
||||
currentServiceAddress string
|
||||
conns map[string][]*grpc.ClientConn
|
||||
resolverDirect *ResolverDirect
|
||||
config *config2.GlobalConfig
|
||||
}
|
||||
|
||||
func (cd *ConnDirect) GetClientLocalConns() map[string][]*grpc.ClientConn {
|
||||
@@ -80,10 +81,11 @@ func (cd *ConnDirect) Close() {
|
||||
|
||||
}
|
||||
|
||||
func NewConnDirect() (*ConnDirect, error) {
|
||||
func NewConnDirect(config *config2.GlobalConfig) (*ConnDirect, error) {
|
||||
return &ConnDirect{
|
||||
conns: make(map[string][]*grpc.ClientConn),
|
||||
resolverDirect: NewResolverDirect(),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -93,12 +95,12 @@ func (cd *ConnDirect) GetConns(ctx context.Context,
|
||||
if conns, exists := cd.conns[serviceName]; exists {
|
||||
return conns, nil
|
||||
}
|
||||
ports := getServiceAddresses()[serviceName]
|
||||
ports := getServiceAddresses(cd.config)[serviceName]
|
||||
var connections []*grpc.ClientConn
|
||||
for _, port := range ports {
|
||||
conn, err := cd.dialServiceWithoutResolver(ctx, fmt.Sprintf(config2.Config.Rpc.ListenIP+":%d", port), append(cd.additionalOpts, opts...)...)
|
||||
conn, err := cd.dialServiceWithoutResolver(ctx, fmt.Sprintf(cd.config.Rpc.ListenIP+":%d", port), append(cd.additionalOpts, opts...)...)
|
||||
if err != nil {
|
||||
fmt.Printf("connect to port %d failed,serviceName %s, IP %s\n", port, serviceName, config2.Config.Rpc.ListenIP)
|
||||
fmt.Printf("connect to port %d failed,serviceName %s, IP %s\n", port, serviceName, cd.config.Rpc.ListenIP)
|
||||
}
|
||||
connections = append(connections, conn)
|
||||
}
|
||||
@@ -111,7 +113,7 @@ func (cd *ConnDirect) GetConns(ctx context.Context,
|
||||
|
||||
func (cd *ConnDirect) GetConn(ctx context.Context, serviceName string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||
// Get service addresses
|
||||
addresses := getServiceAddresses()
|
||||
addresses := getServiceAddresses(cd.config)
|
||||
address, ok := addresses[serviceName]
|
||||
if !ok {
|
||||
return nil, errs.Wrap(errors.New("unknown service name"), "serviceName", serviceName)
|
||||
@@ -119,9 +121,9 @@ func (cd *ConnDirect) GetConn(ctx context.Context, serviceName string, opts ...g
|
||||
var result string
|
||||
for _, addr := range address {
|
||||
if result != "" {
|
||||
result = result + "," + fmt.Sprintf(config2.Config.Rpc.ListenIP+":%d", addr)
|
||||
result = result + "," + fmt.Sprintf(cd.config.Rpc.ListenIP+":%d", addr)
|
||||
} else {
|
||||
result = fmt.Sprintf(config2.Config.Rpc.ListenIP+":%d", addr)
|
||||
result = fmt.Sprintf(cd.config.Rpc.ListenIP+":%d", addr)
|
||||
}
|
||||
}
|
||||
// Try to dial a new connection
|
||||
|
||||
@@ -16,6 +16,7 @@ package discoveryregister
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"os"
|
||||
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
@@ -26,19 +27,19 @@ import (
|
||||
)
|
||||
|
||||
// NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type.
|
||||
func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
func NewDiscoveryRegister(config *config.GlobalConfig) (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
|
||||
if os.Getenv("ENVS_DISCOVERY") != "" {
|
||||
envType = os.Getenv("ENVS_DISCOVERY")
|
||||
config.Envs.Discovery = os.Getenv("ENVS_DISCOVERY")
|
||||
}
|
||||
|
||||
switch envType {
|
||||
switch config.Envs.Discovery {
|
||||
case "zookeeper":
|
||||
return zookeeper.NewZookeeperDiscoveryRegister()
|
||||
return zookeeper.NewZookeeperDiscoveryRegister(config)
|
||||
case "k8s":
|
||||
return kubernetes.NewK8sDiscoveryRegister()
|
||||
return kubernetes.NewK8sDiscoveryRegister(config.RpcRegisterName.OpenImMessageGatewayName)
|
||||
case "direct":
|
||||
return direct.NewConnDirect()
|
||||
return direct.NewConnDirect(config)
|
||||
default:
|
||||
return nil, errs.Wrap(errors.New("envType not correct"))
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package discoveryregister
|
||||
|
||||
import (
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@@ -32,20 +33,23 @@ func setupTestEnvironment() {
|
||||
|
||||
func TestNewDiscoveryRegister(t *testing.T) {
|
||||
setupTestEnvironment()
|
||||
|
||||
conf := config.NewGlobalConfig()
|
||||
tests := []struct {
|
||||
envType string
|
||||
gatewayName string
|
||||
expectedError bool
|
||||
expectedResult bool
|
||||
}{
|
||||
{"zookeeper", false, true},
|
||||
{"k8s", false, true}, // Assume that the k8s configuration is also set up correctly
|
||||
{"direct", false, true},
|
||||
{"invalid", true, false},
|
||||
{"zookeeper", "MessageGateway", false, true},
|
||||
{"k8s", "MessageGateway", false, true},
|
||||
{"direct", "MessageGateway", false, true},
|
||||
{"invalid", "MessageGateway", true, false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
client, err := NewDiscoveryRegister(test.envType)
|
||||
conf.Envs.Discovery = test.envType
|
||||
conf.RpcRegisterName.OpenImMessageGatewayName = test.gatewayName
|
||||
client, err := NewDiscoveryRegister(conf)
|
||||
|
||||
if test.expectedError {
|
||||
assert.Error(t, err)
|
||||
|
||||
@@ -22,11 +22,12 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/stathat/consistent"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/stathat/consistent"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// K8sDR represents the Kubernetes service discovery and registration client.
|
||||
@@ -34,11 +35,12 @@ type K8sDR struct {
|
||||
options []grpc.DialOption
|
||||
rpcRegisterAddr string
|
||||
gatewayHostConsistent *consistent.Consistent
|
||||
gatewayName string
|
||||
}
|
||||
|
||||
func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
func NewK8sDiscoveryRegister(gatewayName string) (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
gatewayConsistent := consistent.New()
|
||||
gatewayHosts := getMsgGatewayHost(context.Background())
|
||||
gatewayHosts := getMsgGatewayHost(context.Background(), gatewayName)
|
||||
for _, v := range gatewayHosts {
|
||||
gatewayConsistent.Add(v)
|
||||
}
|
||||
@@ -46,10 +48,10 @@ func NewK8sDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
}
|
||||
|
||||
func (cli *K8sDR) Register(serviceName, host string, port int, opts ...grpc.DialOption) error {
|
||||
if serviceName != config.Config.RpcRegisterName.OpenImMessageGatewayName {
|
||||
if serviceName != cli.gatewayName {
|
||||
cli.rpcRegisterAddr = serviceName
|
||||
} else {
|
||||
cli.rpcRegisterAddr = getSelfHost(context.Background())
|
||||
cli.rpcRegisterAddr = getSelfHost(context.Background(), cli.gatewayName)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -81,15 +83,15 @@ func (cli *K8sDR) GetUserIdHashGatewayHost(ctx context.Context, userId string) (
|
||||
}
|
||||
return host, err
|
||||
}
|
||||
func getSelfHost(ctx context.Context) string {
|
||||
func getSelfHost(ctx context.Context, gatewayName string) string {
|
||||
port := 88
|
||||
instance := "openimserver"
|
||||
selfPodName := os.Getenv("MY_POD_NAME")
|
||||
ns := os.Getenv("MY_POD_NAMESPACE")
|
||||
statefuleIndex := 0
|
||||
gatewayEnds := strings.Split(config.Config.RpcRegisterName.OpenImMessageGatewayName, ":")
|
||||
gatewayEnds := strings.Split(gatewayName, ":")
|
||||
if len(gatewayEnds) != 2 {
|
||||
log.ZError(ctx, "msggateway RpcRegisterName is error:config.Config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error"))
|
||||
log.ZError(ctx, "msggateway RpcRegisterName is error:config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error"))
|
||||
} else {
|
||||
port, _ = strconv.Atoi(gatewayEnds[1])
|
||||
}
|
||||
@@ -102,15 +104,15 @@ func getSelfHost(ctx context.Context) string {
|
||||
}
|
||||
|
||||
// like openimserver-openim-msggateway-0.openimserver-openim-msggateway-headless.openim-lin.svc.cluster.local:88.
|
||||
func getMsgGatewayHost(ctx context.Context) []string {
|
||||
func getMsgGatewayHost(ctx context.Context, gatewayName string) []string {
|
||||
port := 88
|
||||
instance := "openimserver"
|
||||
selfPodName := os.Getenv("MY_POD_NAME")
|
||||
replicas := os.Getenv("MY_MSGGATEWAY_REPLICACOUNT")
|
||||
ns := os.Getenv("MY_POD_NAMESPACE")
|
||||
gatewayEnds := strings.Split(config.Config.RpcRegisterName.OpenImMessageGatewayName, ":")
|
||||
gatewayEnds := strings.Split(gatewayName, ":")
|
||||
if len(gatewayEnds) != 2 {
|
||||
log.ZError(ctx, "msggateway RpcRegisterName is error:config.Config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error"))
|
||||
log.ZError(ctx, "msggateway RpcRegisterName is error:config.RpcRegisterName.OpenImMessageGatewayName", errors.New("config error"))
|
||||
} else {
|
||||
port, _ = strconv.Atoi(gatewayEnds[1])
|
||||
}
|
||||
@@ -131,7 +133,7 @@ func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc
|
||||
|
||||
// This conditional checks if the serviceName is not the OpenImMessageGatewayName.
|
||||
// It seems to handle a special case for the OpenImMessageGateway.
|
||||
if serviceName != config.Config.RpcRegisterName.OpenImMessageGatewayName {
|
||||
if serviceName != cli.gatewayName {
|
||||
// DialContext creates a client connection to the given target (serviceName) using the specified context.
|
||||
// 'cli.options' are likely default or common options for all connections in this struct.
|
||||
// 'opts...' allows for additional gRPC dial options to be passed and used.
|
||||
@@ -146,7 +148,7 @@ func (cli *K8sDR) GetConns(ctx context.Context, serviceName string, opts ...grpc
|
||||
|
||||
// getMsgGatewayHost presumably retrieves hosts for the message gateway service.
|
||||
// The context is passed, likely for cancellation and timeout control.
|
||||
gatewayHosts := getMsgGatewayHost(ctx)
|
||||
gatewayHosts := getMsgGatewayHost(ctx, cli.gatewayName)
|
||||
|
||||
// Iterating over the retrieved gateway hosts.
|
||||
for _, host := range gatewayHosts {
|
||||
|
||||
@@ -28,11 +28,11 @@ import (
|
||||
)
|
||||
|
||||
// NewZookeeperDiscoveryRegister creates a new instance of ZookeeperDR for Zookeeper service discovery and registration.
|
||||
func NewZookeeperDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
schema := getEnv("ZOOKEEPER_SCHEMA", config.Config.Zookeeper.Schema)
|
||||
zkAddr := getZkAddrFromEnv(config.Config.Zookeeper.ZkAddr)
|
||||
username := getEnv("ZOOKEEPER_USERNAME", config.Config.Zookeeper.Username)
|
||||
password := getEnv("ZOOKEEPER_PASSWORD", config.Config.Zookeeper.Password)
|
||||
func NewZookeeperDiscoveryRegister(config *config.GlobalConfig) (discoveryregistry.SvcDiscoveryRegistry, error) {
|
||||
schema := getEnv("ZOOKEEPER_SCHEMA", config.Zookeeper.Schema)
|
||||
zkAddr := getZkAddrFromEnv(config.Zookeeper.ZkAddr)
|
||||
username := getEnv("ZOOKEEPER_USERNAME", config.Zookeeper.Username)
|
||||
password := getEnv("ZOOKEEPER_PASSWORD", config.Zookeeper.Password)
|
||||
|
||||
zk, err := openkeeper.NewClient(
|
||||
zkAddr,
|
||||
@@ -46,10 +46,10 @@ func NewZookeeperDiscoveryRegister() (discoveryregistry.SvcDiscoveryRegistry, er
|
||||
if err != nil {
|
||||
uriFormat := "address:%s, username:%s, password:%s, schema:%s."
|
||||
errInfo := fmt.Sprintf(uriFormat,
|
||||
config.Config.Zookeeper.ZkAddr,
|
||||
config.Config.Zookeeper.Username,
|
||||
config.Config.Zookeeper.Password,
|
||||
config.Config.Zookeeper.Schema)
|
||||
config.Zookeeper.ZkAddr,
|
||||
config.Zookeeper.Username,
|
||||
config.Zookeeper.Password,
|
||||
config.Zookeeper.Schema)
|
||||
return nil, errs.Wrap(err, errInfo)
|
||||
}
|
||||
return zk, nil
|
||||
|
||||
@@ -17,6 +17,8 @@ package kafka
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
)
|
||||
@@ -29,22 +31,31 @@ type Consumer struct {
|
||||
Consumer sarama.Consumer
|
||||
}
|
||||
|
||||
func NewKafkaConsumer(addr []string, topic string, kafkaConfig *sarama.Config) (*Consumer, error) {
|
||||
p := Consumer{
|
||||
Topic: topic,
|
||||
addr: addr,
|
||||
func NewKafkaConsumer(addr []string, topic string, config *config.GlobalConfig) (*Consumer,error) {
|
||||
p := Consumer{}
|
||||
p.Topic = topic
|
||||
p.addr = addr
|
||||
consumerConfig := sarama.NewConfig()
|
||||
if config.Kafka.Username != "" && config.Kafka.Password != "" {
|
||||
consumerConfig.Net.SASL.Enable = true
|
||||
consumerConfig.Net.SASL.User = config.Kafka.Username
|
||||
consumerConfig.Net.SASL.Password = config.Kafka.Password
|
||||
}
|
||||
|
||||
if kafkaConfig.Net.SASL.User != "" && kafkaConfig.Net.SASL.Password != "" {
|
||||
kafkaConfig.Net.SASL.Enable = true
|
||||
var tlsConfig *TLSConfig
|
||||
if config.Kafka.TLS != nil {
|
||||
tlsConfig = &TLSConfig{
|
||||
CACrt: config.Kafka.TLS.CACrt,
|
||||
ClientCrt: config.Kafka.TLS.ClientCrt,
|
||||
ClientKey: config.Kafka.TLS.ClientKey,
|
||||
ClientKeyPwd: config.Kafka.TLS.ClientKeyPwd,
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
}
|
||||
|
||||
err := SetupTLSConfig(kafkaConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
err:=SetupTLSConfig(consumerConfig, tlsConfig)
|
||||
if err!=nil{
|
||||
return nil,err
|
||||
}
|
||||
|
||||
consumer, err := sarama.NewConsumer(p.addr, kafkaConfig)
|
||||
consumer, err := sarama.NewConsumer(p.addr, consumerConfig)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err, "NewKafkaConsumer: creating consumer failed")
|
||||
}
|
||||
|
||||
@@ -17,12 +17,12 @@ package kafka
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MConsumerGroup struct {
|
||||
@@ -35,22 +35,25 @@ type MConsumerGroupConfig struct {
|
||||
KafkaVersion sarama.KafkaVersion
|
||||
OffsetsInitial int64
|
||||
IsReturnErr bool
|
||||
UserName string
|
||||
Password string
|
||||
}
|
||||
|
||||
func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []string, groupID string) (*MConsumerGroup, error) {
|
||||
func NewMConsumerGroup(consumerConfig *MConsumerGroupConfig, topics, addrs []string, groupID string, tlsConfig *TLSConfig) (*MConsumerGroup, error) {
|
||||
consumerGroupConfig := sarama.NewConfig()
|
||||
consumerGroupConfig.Version = consumerConfig.KafkaVersion
|
||||
consumerGroupConfig.Consumer.Offsets.Initial = consumerConfig.OffsetsInitial
|
||||
consumerGroupConfig.Consumer.Return.Errors = consumerConfig.IsReturnErr
|
||||
if config.Config.Kafka.Username != "" && config.Config.Kafka.Password != "" {
|
||||
if consumerConfig.UserName != "" && consumerConfig.Password != "" {
|
||||
consumerGroupConfig.Net.SASL.Enable = true
|
||||
consumerGroupConfig.Net.SASL.User = config.Config.Kafka.Username
|
||||
consumerGroupConfig.Net.SASL.Password = config.Config.Kafka.Password
|
||||
consumerGroupConfig.Net.SASL.User = consumerConfig.UserName
|
||||
consumerGroupConfig.Net.SASL.Password = consumerConfig.Password
|
||||
}
|
||||
SetupTLSConfig(consumerGroupConfig)
|
||||
|
||||
SetupTLSConfig(consumerGroupConfig, tlsConfig)
|
||||
consumerGroup, err := sarama.NewConsumerGroup(addrs, groupID, consumerGroupConfig)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err, strings.Join(topics, ","), strings.Join(addrs, ","), groupID, config.Config.Kafka.Username, config.Config.Kafka.Password)
|
||||
return nil, errs.Wrap(err, strings.Join(topics, ","), strings.Join(addrs, ","), groupID, consumerConfig.UserName, consumerConfig.Password)
|
||||
}
|
||||
|
||||
return &MConsumerGroup{
|
||||
|
||||
@@ -22,12 +22,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/OpenIMSDK/tools/log"
|
||||
"github.com/OpenIMSDK/tools/mcontext"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
@@ -43,8 +43,15 @@ type Producer struct {
|
||||
producer sarama.SyncProducer
|
||||
}
|
||||
|
||||
type ProducerConfig struct {
|
||||
ProducerAck string
|
||||
CompressType string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// NewKafkaProducer initializes a new Kafka producer.
|
||||
func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
|
||||
func NewKafkaProducer(addr []string, topic string, producerConfig *ProducerConfig, tlsConfig *TLSConfig) (*Producer, error) {
|
||||
p := Producer{
|
||||
addr: addr,
|
||||
topic: topic,
|
||||
@@ -59,14 +66,14 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
|
||||
p.config.Producer.Partitioner = sarama.NewHashPartitioner
|
||||
|
||||
// Configure producer acknowledgement level
|
||||
configureProducerAck(&p, config.Config.Kafka.ProducerAck)
|
||||
configureProducerAck(&p, producerConfig.ProducerAck)
|
||||
|
||||
// Configure message compression
|
||||
configureCompression(&p, config.Config.Kafka.CompressType)
|
||||
configureCompression(&p, producerConfig.CompressType)
|
||||
|
||||
// Get Kafka configuration from environment variables or fallback to config file
|
||||
kafkaUsername := getEnvOrConfig("KAFKA_USERNAME", config.Config.Kafka.Username)
|
||||
kafkaPassword := getEnvOrConfig("KAFKA_PASSWORD", config.Config.Kafka.Password)
|
||||
kafkaUsername := getEnvOrConfig("KAFKA_USERNAME", producerConfig.Username)
|
||||
kafkaPassword := getEnvOrConfig("KAFKA_PASSWORD", producerConfig.Password)
|
||||
kafkaAddr := getKafkaAddrFromEnv(addr) // Updated to use the new function
|
||||
|
||||
// Configure SASL authentication if credentials are provided
|
||||
@@ -80,7 +87,7 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
|
||||
p.addr = kafkaAddr
|
||||
|
||||
// Set up TLS configuration (if required)
|
||||
SetupTLSConfig(p.config)
|
||||
SetupTLSConfig(p.config, tlsConfig)
|
||||
|
||||
// Create the producer with retries
|
||||
var err error
|
||||
|
||||
@@ -20,19 +20,27 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/IBM/sarama"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/tls"
|
||||
)
|
||||
|
||||
type TLSConfig struct {
|
||||
CACrt string
|
||||
ClientCrt string
|
||||
ClientKey string
|
||||
ClientKeyPwd string
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
// SetupTLSConfig set up the TLS config from config file.
|
||||
func SetupTLSConfig(cfg *sarama.Config) error {
|
||||
if config.Config.Kafka.TLS != nil {
|
||||
func SetupTLSConfig(cfg *sarama.Config, tlsConfig *TLSConfig) error {
|
||||
if tlsConfig != nil {
|
||||
cfg.Net.TLS.Enable = true
|
||||
tlsConfig, err := tls.NewTLSConfig(
|
||||
config.Config.Kafka.TLS.ClientCrt,
|
||||
config.Config.Kafka.TLS.ClientKey,
|
||||
config.Config.Kafka.TLS.CACrt,
|
||||
[]byte(config.Config.Kafka.TLS.ClientKeyPwd),
|
||||
tlsConfig.ClientCrt,
|
||||
tlsConfig.ClientKey,
|
||||
tlsConfig.CACrt,
|
||||
[]byte(tlsConfig.ClientKeyPwd),
|
||||
tlsConfig.InsecureSkipVerify,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -31,17 +31,17 @@ func NewGrpcPromObj(cusMetrics []prometheus.Collector) (*prometheus.Registry, *g
|
||||
return reg, grpcMetrics, nil
|
||||
}
|
||||
|
||||
func GetGrpcCusMetrics(registerName string) []prometheus.Collector {
|
||||
func GetGrpcCusMetrics(registerName string, config *config2.GlobalConfig) []prometheus.Collector {
|
||||
switch registerName {
|
||||
case config2.Config.RpcRegisterName.OpenImMessageGatewayName:
|
||||
case config.RpcRegisterName.OpenImMessageGatewayName:
|
||||
return []prometheus.Collector{OnlineUserGauge}
|
||||
case config2.Config.RpcRegisterName.OpenImMsgName:
|
||||
case config.RpcRegisterName.OpenImMsgName:
|
||||
return []prometheus.Collector{SingleChatMsgProcessSuccessCounter, SingleChatMsgProcessFailedCounter, GroupChatMsgProcessSuccessCounter, GroupChatMsgProcessFailedCounter}
|
||||
case "Transfer":
|
||||
return []prometheus.Collector{MsgInsertRedisSuccessCounter, MsgInsertRedisFailedCounter, MsgInsertMongoSuccessCounter, MsgInsertMongoFailedCounter, SeqSetFailedCounter}
|
||||
case config2.Config.RpcRegisterName.OpenImPushName:
|
||||
case config.RpcRegisterName.OpenImPushName:
|
||||
return []prometheus.Collector{MsgOfflinePushFailedCounter}
|
||||
case config2.Config.RpcRegisterName.OpenImAuthName:
|
||||
case config.RpcRegisterName.OpenImAuthName:
|
||||
return []prometheus.Collector{UserLoginCounter}
|
||||
default:
|
||||
return nil
|
||||
|
||||
@@ -58,17 +58,20 @@ func TestNewGrpcPromObj(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetGrpcCusMetrics(t *testing.T) {
|
||||
conf := config2.NewGlobalConfig()
|
||||
|
||||
config2.InitConfig(conf, "../../config")
|
||||
// Test various cases based on the switch statement in the GetGrpcCusMetrics function.
|
||||
testCases := []struct {
|
||||
name string
|
||||
expected int // The expected number of metrics for each case.
|
||||
}{
|
||||
{config2.Config.RpcRegisterName.OpenImMessageGatewayName, 1},
|
||||
{conf.RpcRegisterName.OpenImMessageGatewayName, 1},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
metrics := GetGrpcCusMetrics(tc.name)
|
||||
metrics := GetGrpcCusMetrics(tc.name, conf)
|
||||
assert.Len(t, metrics, tc.expected)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -27,19 +28,25 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/OpenIMSDK/tools/mw"
|
||||
"github.com/OpenIMSDK/tools/network"
|
||||
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
||||
|
||||
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/prommetrics"
|
||||
|
||||
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
kdisc "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister"
|
||||
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
"github.com/OpenIMSDK/tools/mw"
|
||||
"github.com/OpenIMSDK/tools/network"
|
||||
)
|
||||
|
||||
// Start rpc server.
|
||||
@@ -47,37 +54,38 @@ func Start(
|
||||
rpcPort int,
|
||||
rpcRegisterName string,
|
||||
prometheusPort int,
|
||||
rpcFn func(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error,
|
||||
config *config2.GlobalConfig,
|
||||
rpcFn func(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error,
|
||||
options ...grpc.ServerOption,
|
||||
) error {
|
||||
fmt.Printf("start %s server, port: %d, prometheusPort: %d, OpenIM version: %s\n",
|
||||
rpcRegisterName, rpcPort, prometheusPort, config.Version)
|
||||
rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Config.Rpc.ListenIP), strconv.Itoa(rpcPort))
|
||||
rpcRegisterName, rpcPort, prometheusPort, config2.Version)
|
||||
rpcTcpAddr := net.JoinHostPort(network.GetListenIP(config.Rpc.ListenIP), strconv.Itoa(rpcPort))
|
||||
listener, err := net.Listen(
|
||||
"tcp",
|
||||
rpcTcpAddr,
|
||||
)
|
||||
if err != nil {
|
||||
return errs.Wrap(err, "rpc start err", rpcTcpAddr)
|
||||
return errs.Wrap(err, "listen err", rpcTcpAddr)
|
||||
}
|
||||
|
||||
defer listener.Close()
|
||||
client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery)
|
||||
client, err := kdisc.NewDiscoveryRegister(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer client.Close()
|
||||
client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")))
|
||||
registerIP, err := network.GetRpcRegisterIP(config.Config.Rpc.RegisterIP)
|
||||
registerIP, err := network.GetRpcRegisterIP(config.Rpc.RegisterIP)
|
||||
if err != nil {
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
var reg *prometheus.Registry
|
||||
var metric *grpcprometheus.ServerMetrics
|
||||
if config.Config.Prometheus.Enable {
|
||||
cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName)
|
||||
if config.Prometheus.Enable {
|
||||
cusMetrics := prommetrics.GetGrpcCusMetrics(rpcRegisterName, config)
|
||||
reg, metric, _ = prommetrics.NewGrpcPromObj(cusMetrics)
|
||||
options = append(options, mw.GrpcServer(), grpc.StreamInterceptor(metric.StreamServerInterceptor()),
|
||||
grpc.UnaryInterceptor(metric.UnaryServerInterceptor()))
|
||||
@@ -91,7 +99,7 @@ func Start(
|
||||
once.Do(srv.GracefulStop)
|
||||
}()
|
||||
|
||||
err = rpcFn(client, srv)
|
||||
err = rpcFn(config, client, srv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -111,7 +119,7 @@ func Start(
|
||||
httpServer *http.Server
|
||||
)
|
||||
go func() {
|
||||
if config.Config.Prometheus.Enable && prometheusPort != 0 {
|
||||
if config.Prometheus.Enable && prometheusPort != 0 {
|
||||
metric.InitializeMetrics(srv)
|
||||
// Create a HTTP server for prometheus.
|
||||
httpServer = &http.Server{Handler: promhttp.HandlerFor(reg, promhttp.HandlerOpts{}), Addr: fmt.Sprintf("0.0.0.0:%d", prometheusPort)}
|
||||
|
||||
@@ -16,6 +16,7 @@ package startrpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -25,7 +26,7 @@ import (
|
||||
)
|
||||
|
||||
// mockRpcFn is a mock gRPC function for testing.
|
||||
func mockRpcFn(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
|
||||
func mockRpcFn(config *config.GlobalConfig, client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error {
|
||||
// Implement a mock gRPC service registration logic if needed
|
||||
return nil
|
||||
}
|
||||
@@ -40,7 +41,8 @@ func TestStart(t *testing.T) {
|
||||
doneChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := Start(testRpcPort, testRpcRegisterName, testPrometheusPort, mockRpcFn)
|
||||
err := Start(testRpcPort, testRpcRegisterName, testPrometheusPort,
|
||||
config.NewGlobalConfig(), mockRpcFn)
|
||||
doneChan <- err
|
||||
}()
|
||||
|
||||
|
||||
Regular → Executable
+3
-5
@@ -22,7 +22,6 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
)
|
||||
|
||||
// decryptPEM decrypts a PEM block using a password.
|
||||
@@ -50,15 +49,14 @@ func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// NewTLSConfig setup the TLS config from general config file.
|
||||
func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) (*tls.Config, error) {
|
||||
var tlsConfig tls.Config
|
||||
func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte, insecureSkipVerify bool) (*tls.Config,error) {
|
||||
tlsConfig := tls.Config{}
|
||||
|
||||
if clientCertFile != "" && clientKeyFile != "" {
|
||||
certPEMBlock, err := os.ReadFile(clientCertFile)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err, "NewTLSConfig: failed to read client cert file")
|
||||
}
|
||||
|
||||
keyPEMBlock, err := readEncryptablePEMBlock(clientKeyFile, keyPwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -84,7 +82,7 @@ func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byt
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
tlsConfig.InsecureSkipVerify = config.Config.Kafka.TLS.InsecureSkipVerify
|
||||
tlsConfig.InsecureSkipVerify = insecureSkipVerify
|
||||
|
||||
return &tlsConfig, nil
|
||||
}
|
||||
|
||||
@@ -24,17 +24,18 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry) *Auth {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImAuthName)
|
||||
func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Auth {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImAuthName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
client := auth.NewAuthClient(conn)
|
||||
return &Auth{discov: discov, conn: conn, Client: client}
|
||||
return &Auth{discov: discov, conn: conn, Client: client, Config: config}
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
conn grpc.ClientConnInterface
|
||||
Client auth.AuthClient
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
@@ -30,21 +30,22 @@ type Conversation struct {
|
||||
Client pbconversation.ConversationClient
|
||||
conn grpc.ClientConnInterface
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry) *Conversation {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImConversationName)
|
||||
func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Conversation {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImConversationName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
client := pbconversation.NewConversationClient(conn)
|
||||
return &Conversation{discov: discov, conn: conn, Client: client}
|
||||
return &Conversation{discov: discov, conn: conn, Client: client, Config: config}
|
||||
}
|
||||
|
||||
type ConversationRpcClient Conversation
|
||||
|
||||
func NewConversationRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) ConversationRpcClient {
|
||||
return ConversationRpcClient(*NewConversation(discov))
|
||||
func NewConversationRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) ConversationRpcClient {
|
||||
return ConversationRpcClient(*NewConversation(discov, config))
|
||||
}
|
||||
|
||||
func (c *ConversationRpcClient) GetSingleConversationRecvMsgOpt(ctx context.Context, userID, conversationID string) (int32, error) {
|
||||
|
||||
@@ -29,21 +29,22 @@ type Friend struct {
|
||||
conn grpc.ClientConnInterface
|
||||
Client friend.FriendClient
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry) *Friend {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImFriendName)
|
||||
func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Friend {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImFriendName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
client := friend.NewFriendClient(conn)
|
||||
return &Friend{discov: discov, conn: conn, Client: client}
|
||||
return &Friend{discov: discov, conn: conn, Client: client, Config: config}
|
||||
}
|
||||
|
||||
type FriendRpcClient Friend
|
||||
|
||||
func NewFriendRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) FriendRpcClient {
|
||||
return FriendRpcClient(*NewFriend(discov))
|
||||
func NewFriendRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) FriendRpcClient {
|
||||
return FriendRpcClient(*NewFriend(discov, config))
|
||||
}
|
||||
|
||||
func (f *FriendRpcClient) GetFriendsInfo(
|
||||
|
||||
@@ -33,21 +33,22 @@ type Group struct {
|
||||
conn grpc.ClientConnInterface
|
||||
Client group.GroupClient
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry) *Group {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImGroupName)
|
||||
func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Group {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImGroupName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
client := group.NewGroupClient(conn)
|
||||
return &Group{discov: discov, conn: conn, Client: client}
|
||||
return &Group{discov: discov, conn: conn, Client: client, Config: config}
|
||||
}
|
||||
|
||||
type GroupRpcClient Group
|
||||
|
||||
func NewGroupRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) GroupRpcClient {
|
||||
return GroupRpcClient(*NewGroup(discov))
|
||||
func NewGroupRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) GroupRpcClient {
|
||||
return GroupRpcClient(*NewGroup(discov, config))
|
||||
}
|
||||
|
||||
func (g *GroupRpcClient) GetGroupInfos(
|
||||
|
||||
+53
-50
@@ -17,6 +17,8 @@ package rpcclient
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/protocol/msg"
|
||||
@@ -29,47 +31,47 @@ import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func newContentTypeConf() map[int32]config.NotificationConf {
|
||||
func newContentTypeConf(conf *config.GlobalConfig) map[int32]config.NotificationConf {
|
||||
return map[int32]config.NotificationConf{
|
||||
// group
|
||||
constant.GroupCreatedNotification: config.Config.Notification.GroupCreated,
|
||||
constant.GroupInfoSetNotification: config.Config.Notification.GroupInfoSet,
|
||||
constant.JoinGroupApplicationNotification: config.Config.Notification.JoinGroupApplication,
|
||||
constant.MemberQuitNotification: config.Config.Notification.MemberQuit,
|
||||
constant.GroupApplicationAcceptedNotification: config.Config.Notification.GroupApplicationAccepted,
|
||||
constant.GroupApplicationRejectedNotification: config.Config.Notification.GroupApplicationRejected,
|
||||
constant.GroupOwnerTransferredNotification: config.Config.Notification.GroupOwnerTransferred,
|
||||
constant.MemberKickedNotification: config.Config.Notification.MemberKicked,
|
||||
constant.MemberInvitedNotification: config.Config.Notification.MemberInvited,
|
||||
constant.MemberEnterNotification: config.Config.Notification.MemberEnter,
|
||||
constant.GroupDismissedNotification: config.Config.Notification.GroupDismissed,
|
||||
constant.GroupMutedNotification: config.Config.Notification.GroupMuted,
|
||||
constant.GroupCancelMutedNotification: config.Config.Notification.GroupCancelMuted,
|
||||
constant.GroupMemberMutedNotification: config.Config.Notification.GroupMemberMuted,
|
||||
constant.GroupMemberCancelMutedNotification: config.Config.Notification.GroupMemberCancelMuted,
|
||||
constant.GroupMemberInfoSetNotification: config.Config.Notification.GroupMemberInfoSet,
|
||||
constant.GroupMemberSetToAdminNotification: config.Config.Notification.GroupMemberSetToAdmin,
|
||||
constant.GroupMemberSetToOrdinaryUserNotification: config.Config.Notification.GroupMemberSetToOrdinary,
|
||||
constant.GroupInfoSetAnnouncementNotification: config.Config.Notification.GroupInfoSetAnnouncement,
|
||||
constant.GroupInfoSetNameNotification: config.Config.Notification.GroupInfoSetName,
|
||||
constant.GroupCreatedNotification: conf.Notification.GroupCreated,
|
||||
constant.GroupInfoSetNotification: conf.Notification.GroupInfoSet,
|
||||
constant.JoinGroupApplicationNotification: conf.Notification.JoinGroupApplication,
|
||||
constant.MemberQuitNotification: conf.Notification.MemberQuit,
|
||||
constant.GroupApplicationAcceptedNotification: conf.Notification.GroupApplicationAccepted,
|
||||
constant.GroupApplicationRejectedNotification: conf.Notification.GroupApplicationRejected,
|
||||
constant.GroupOwnerTransferredNotification: conf.Notification.GroupOwnerTransferred,
|
||||
constant.MemberKickedNotification: conf.Notification.MemberKicked,
|
||||
constant.MemberInvitedNotification: conf.Notification.MemberInvited,
|
||||
constant.MemberEnterNotification: conf.Notification.MemberEnter,
|
||||
constant.GroupDismissedNotification: conf.Notification.GroupDismissed,
|
||||
constant.GroupMutedNotification: conf.Notification.GroupMuted,
|
||||
constant.GroupCancelMutedNotification: conf.Notification.GroupCancelMuted,
|
||||
constant.GroupMemberMutedNotification: conf.Notification.GroupMemberMuted,
|
||||
constant.GroupMemberCancelMutedNotification: conf.Notification.GroupMemberCancelMuted,
|
||||
constant.GroupMemberInfoSetNotification: conf.Notification.GroupMemberInfoSet,
|
||||
constant.GroupMemberSetToAdminNotification: conf.Notification.GroupMemberSetToAdmin,
|
||||
constant.GroupMemberSetToOrdinaryUserNotification: conf.Notification.GroupMemberSetToOrdinary,
|
||||
constant.GroupInfoSetAnnouncementNotification: conf.Notification.GroupInfoSetAnnouncement,
|
||||
constant.GroupInfoSetNameNotification: conf.Notification.GroupInfoSetName,
|
||||
// user
|
||||
constant.UserInfoUpdatedNotification: config.Config.Notification.UserInfoUpdated,
|
||||
constant.UserStatusChangeNotification: config.Config.Notification.UserStatusChanged,
|
||||
constant.UserInfoUpdatedNotification: conf.Notification.UserInfoUpdated,
|
||||
constant.UserStatusChangeNotification: conf.Notification.UserStatusChanged,
|
||||
// friend
|
||||
constant.FriendApplicationNotification: config.Config.Notification.FriendApplicationAdded,
|
||||
constant.FriendApplicationApprovedNotification: config.Config.Notification.FriendApplicationApproved,
|
||||
constant.FriendApplicationRejectedNotification: config.Config.Notification.FriendApplicationRejected,
|
||||
constant.FriendAddedNotification: config.Config.Notification.FriendAdded,
|
||||
constant.FriendDeletedNotification: config.Config.Notification.FriendDeleted,
|
||||
constant.FriendRemarkSetNotification: config.Config.Notification.FriendRemarkSet,
|
||||
constant.BlackAddedNotification: config.Config.Notification.BlackAdded,
|
||||
constant.BlackDeletedNotification: config.Config.Notification.BlackDeleted,
|
||||
constant.FriendInfoUpdatedNotification: config.Config.Notification.FriendInfoUpdated,
|
||||
constant.FriendsInfoUpdateNotification: config.Config.Notification.FriendInfoUpdated, //use the same FriendInfoUpdated
|
||||
constant.FriendApplicationNotification: conf.Notification.FriendApplicationAdded,
|
||||
constant.FriendApplicationApprovedNotification: conf.Notification.FriendApplicationApproved,
|
||||
constant.FriendApplicationRejectedNotification: conf.Notification.FriendApplicationRejected,
|
||||
constant.FriendAddedNotification: conf.Notification.FriendAdded,
|
||||
constant.FriendDeletedNotification: conf.Notification.FriendDeleted,
|
||||
constant.FriendRemarkSetNotification: conf.Notification.FriendRemarkSet,
|
||||
constant.BlackAddedNotification: conf.Notification.BlackAdded,
|
||||
constant.BlackDeletedNotification: conf.Notification.BlackDeleted,
|
||||
constant.FriendInfoUpdatedNotification: conf.Notification.FriendInfoUpdated,
|
||||
constant.FriendsInfoUpdateNotification: conf.Notification.FriendInfoUpdated, //use the same FriendInfoUpdated
|
||||
// conversation
|
||||
constant.ConversationChangeNotification: config.Config.Notification.ConversationChanged,
|
||||
constant.ConversationUnreadNotification: config.Config.Notification.ConversationChanged,
|
||||
constant.ConversationPrivateChatNotification: config.Config.Notification.ConversationSetPrivate,
|
||||
constant.ConversationChangeNotification: conf.Notification.ConversationChanged,
|
||||
constant.ConversationUnreadNotification: conf.Notification.ConversationChanged,
|
||||
constant.ConversationPrivateChatNotification: conf.Notification.ConversationSetPrivate,
|
||||
// msg
|
||||
constant.MsgRevokeNotification: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg},
|
||||
constant.HasReadReceipt: {IsSendMsg: false, ReliabilityLevel: constant.ReliableNotificationNoMsg},
|
||||
@@ -127,21 +129,22 @@ type Message struct {
|
||||
conn grpc.ClientConnInterface
|
||||
Client msg.MsgClient
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func NewMessage(discov discoveryregistry.SvcDiscoveryRegistry) *Message {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImMsgName)
|
||||
func NewMessage(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Message {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImMsgName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
client := msg.NewMsgClient(conn)
|
||||
return &Message{discov: discov, conn: conn, Client: client}
|
||||
return &Message{discov: discov, conn: conn, Client: client, Config: config}
|
||||
}
|
||||
|
||||
type MessageRpcClient Message
|
||||
|
||||
func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) MessageRpcClient {
|
||||
return MessageRpcClient(*NewMessage(discov))
|
||||
func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) MessageRpcClient {
|
||||
return MessageRpcClient(*NewMessage(discov, config))
|
||||
}
|
||||
|
||||
// SendMsg sends a message through the gRPC client and returns the response.
|
||||
@@ -234,8 +237,8 @@ func WithUserRpcClient(userRpcClient *UserRpcClient) NotificationSenderOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func NewNotificationSender(opts ...NotificationSenderOptions) *NotificationSender {
|
||||
notificationSender := &NotificationSender{contentTypeConf: newContentTypeConf(), sessionTypeConf: newSessionTypeConf()}
|
||||
func NewNotificationSender(config *config.GlobalConfig, opts ...NotificationSenderOptions) *NotificationSender {
|
||||
notificationSender := &NotificationSender{contentTypeConf: newContentTypeConf(config), sessionTypeConf: newSessionTypeConf()}
|
||||
for _, opt := range opts {
|
||||
opt(notificationSender)
|
||||
}
|
||||
@@ -258,8 +261,8 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s
|
||||
n := sdkws.NotificationElem{Detail: utils.StructToJsonString(m)}
|
||||
content, err := json.Marshal(&n)
|
||||
if err != nil {
|
||||
log.ZError(ctx, "MsgClient Notification json.Marshal failed", err, "sendID", sendID, "recvID", recvID, "contentType", contentType, "msg", m)
|
||||
return err
|
||||
errInfo := fmt.Sprintf("MsgClient Notification json.Marshal failed, sendID:%s, recvID:%s, contentType:%d, msg:%s", sendID, recvID, contentType, m)
|
||||
return errs.Wrap(err, errInfo)
|
||||
}
|
||||
notificationOpt := ¬ificationOpt{}
|
||||
for _, opt := range opts {
|
||||
@@ -271,7 +274,8 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s
|
||||
if notificationOpt.WithRpcGetUsername && s.getUserInfo != nil {
|
||||
userInfo, err = s.getUserInfo(ctx, sendID)
|
||||
if err != nil {
|
||||
log.ZWarn(ctx, "getUserInfo failed", err, "sendID", sendID)
|
||||
errInfo := fmt.Sprintf("getUserInfo failed, sendID:%s", sendID)
|
||||
return errs.Wrap(err, errInfo)
|
||||
} else {
|
||||
msg.SenderNickname = userInfo.Nickname
|
||||
msg.SenderFaceURL = userInfo.FaceURL
|
||||
@@ -303,10 +307,9 @@ func (s *NotificationSender) NotificationWithSesstionType(ctx context.Context, s
|
||||
msg.OfflinePushInfo = &offlineInfo
|
||||
req.MsgData = &msg
|
||||
_, err = s.sendMsg(ctx, &req)
|
||||
if err == nil {
|
||||
log.ZDebug(ctx, "MsgClient Notification SendMsg success", "req", &req)
|
||||
} else {
|
||||
log.ZError(ctx, "MsgClient Notification SendMsg failed", err, "req", &req)
|
||||
if err != nil {
|
||||
errInfo := fmt.Sprintf("MsgClient Notification SendMsg failed, req:%s", &req)
|
||||
return errs.Wrap(err, errInfo)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/protocol/sdkws"
|
||||
@@ -26,8 +27,8 @@ type ConversationNotificationSender struct {
|
||||
*rpcclient.NotificationSender
|
||||
}
|
||||
|
||||
func NewConversationNotificationSender(msgRpcClient *rpcclient.MessageRpcClient) *ConversationNotificationSender {
|
||||
return &ConversationNotificationSender{rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient))}
|
||||
func NewConversationNotificationSender(config *config.GlobalConfig, msgRpcClient *rpcclient.MessageRpcClient) *ConversationNotificationSender {
|
||||
return &ConversationNotificationSender{rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient))}
|
||||
}
|
||||
|
||||
// SetPrivate invote.
|
||||
|
||||
@@ -16,6 +16,7 @@ package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
pbfriend "github.com/OpenIMSDK/protocol/friend"
|
||||
@@ -80,11 +81,12 @@ func WithRpcFunc(
|
||||
}
|
||||
|
||||
func NewFriendNotificationSender(
|
||||
config *config.GlobalConfig,
|
||||
msgRpcClient *rpcclient.MessageRpcClient,
|
||||
opts ...friendNotificationSenderOptions,
|
||||
) *FriendNotificationSender {
|
||||
f := &FriendNotificationSender{
|
||||
NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient)),
|
||||
NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient)),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(f)
|
||||
|
||||
@@ -17,6 +17,7 @@ package notification
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
pbgroup "github.com/OpenIMSDK/protocol/group"
|
||||
@@ -35,12 +36,14 @@ func NewGroupNotificationSender(
|
||||
db controller.GroupDatabase,
|
||||
msgRpcClient *rpcclient.MessageRpcClient,
|
||||
userRpcClient *rpcclient.UserRpcClient,
|
||||
config *config.GlobalConfig,
|
||||
fn func(ctx context.Context, userIDs []string) ([]CommonUser, error),
|
||||
) *GroupNotificationSender {
|
||||
return &GroupNotificationSender{
|
||||
NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient), rpcclient.WithUserRpcClient(userRpcClient)),
|
||||
NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient), rpcclient.WithUserRpcClient(userRpcClient)),
|
||||
getUsersInfo: fn,
|
||||
db: db,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +51,7 @@ type GroupNotificationSender struct {
|
||||
*rpcclient.NotificationSender
|
||||
getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error)
|
||||
db controller.GroupDatabase
|
||||
config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func (g *GroupNotificationSender) PopulateGroupMember(ctx context.Context, members ...*relation.GroupMemberModel) error {
|
||||
@@ -243,21 +247,15 @@ func (g *GroupNotificationSender) groupMemberDB2PB(member *relation.GroupMemberM
|
||||
} */
|
||||
|
||||
func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws.GroupMemberFullInfo, groupID string) (err error) {
|
||||
defer log.ZDebug(ctx, "return")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.ZError(ctx, utils.GetFuncName(1)+" failed", err)
|
||||
}
|
||||
}()
|
||||
if opUser == nil {
|
||||
return errs.ErrInternalServer.Wrap("**sdkws.GroupMemberFullInfo is nil")
|
||||
}
|
||||
if *opUser != nil {
|
||||
return nil
|
||||
return errs.ErrArgs.Wrap("*opUser is not nil")
|
||||
}
|
||||
userID := mcontext.GetOpUserID(ctx)
|
||||
if groupID != "" {
|
||||
if authverify.IsManagerUserID(userID) {
|
||||
if authverify.IsManagerUserID(userID, g.config) {
|
||||
*opUser = &sdkws.GroupMemberFullInfo{
|
||||
GroupID: groupID,
|
||||
UserID: userID,
|
||||
@@ -265,11 +263,11 @@ func (g *GroupNotificationSender) fillOpUser(ctx context.Context, opUser **sdkws
|
||||
AppMangerLevel: constant.AppAdmin,
|
||||
}
|
||||
} else {
|
||||
member, err2 := g.db.TakeGroupMember(ctx, groupID, userID)
|
||||
if err2 == nil {
|
||||
member, err := g.db.TakeGroupMember(ctx, groupID, userID)
|
||||
if err == nil {
|
||||
*opUser = g.groupMemberDB2PB(member, 0)
|
||||
} else if !errs.ErrRecordNotFound.Is(err2) {
|
||||
return err2
|
||||
} else if !errs.ErrRecordNotFound.Is(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -650,12 +648,6 @@ func (g *GroupNotificationSender) GroupCancelMutedNotification(ctx context.Conte
|
||||
}
|
||||
|
||||
func (g *GroupNotificationSender) GroupMemberInfoSetNotification(ctx context.Context, groupID, groupMemberUserID string) (err error) {
|
||||
defer log.ZDebug(ctx, "return")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.ZError(ctx, utils.GetFuncName(1)+" failed", err)
|
||||
}
|
||||
}()
|
||||
group, err := g.getGroupInfo(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -672,12 +664,6 @@ func (g *GroupNotificationSender) GroupMemberInfoSetNotification(ctx context.Con
|
||||
}
|
||||
|
||||
func (g *GroupNotificationSender) GroupMemberSetToAdminNotification(ctx context.Context, groupID, groupMemberUserID string) (err error) {
|
||||
defer log.ZDebug(ctx, "return")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.ZError(ctx, utils.GetFuncName(1)+" failed", err)
|
||||
}
|
||||
}()
|
||||
group, err := g.getGroupInfo(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -16,6 +16,7 @@ package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/protocol/sdkws"
|
||||
@@ -26,8 +27,8 @@ type MsgNotificationSender struct {
|
||||
*rpcclient.NotificationSender
|
||||
}
|
||||
|
||||
func NewMsgNotificationSender(opts ...rpcclient.NotificationSenderOptions) *MsgNotificationSender {
|
||||
return &MsgNotificationSender{rpcclient.NewNotificationSender(opts...)}
|
||||
func NewMsgNotificationSender(config *config.GlobalConfig, opts ...rpcclient.NotificationSenderOptions) *MsgNotificationSender {
|
||||
return &MsgNotificationSender{rpcclient.NewNotificationSender(config, opts...)}
|
||||
}
|
||||
|
||||
func (m *MsgNotificationSender) UserDeleteMsgsNotification(ctx context.Context, userID, conversationID string, seqs []int64) error {
|
||||
|
||||
@@ -16,6 +16,7 @@ package notification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/constant"
|
||||
"github.com/OpenIMSDK/protocol/sdkws"
|
||||
@@ -58,11 +59,12 @@ func WithUserFunc(
|
||||
}
|
||||
|
||||
func NewUserNotificationSender(
|
||||
config *config.GlobalConfig,
|
||||
msgRpcClient *rpcclient.MessageRpcClient,
|
||||
opts ...userNotificationSenderOptions,
|
||||
) *UserNotificationSender {
|
||||
f := &UserNotificationSender{
|
||||
NotificationSender: rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient)),
|
||||
NotificationSender: rpcclient.NewNotificationSender(config, rpcclient.WithRpcClient(msgRpcClient)),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(f)
|
||||
|
||||
@@ -30,8 +30,8 @@ type Push struct {
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
}
|
||||
|
||||
func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImPushName)
|
||||
func NewPush(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Push {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImPushName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
@@ -44,8 +44,8 @@ func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push {
|
||||
|
||||
type PushRpcClient Push
|
||||
|
||||
func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) PushRpcClient {
|
||||
return PushRpcClient(*NewPush(discov))
|
||||
func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) PushRpcClient {
|
||||
return PushRpcClient(*NewPush(discov, config))
|
||||
}
|
||||
|
||||
func (p *PushRpcClient) DelUserPushToken(ctx context.Context, req *push.DelUserPushTokenReq) (*push.DelUserPushTokenResp, error) {
|
||||
|
||||
Regular → Executable
+22
-26
@@ -16,13 +16,15 @@ package rpcclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"net/url"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
|
||||
"github.com/OpenIMSDK/protocol/third"
|
||||
"github.com/OpenIMSDK/tools/discoveryregistry"
|
||||
"github.com/OpenIMSDK/tools/errs"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
|
||||
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
|
||||
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
|
||||
"google.golang.org/grpc"
|
||||
@@ -33,47 +35,41 @@ type Third struct {
|
||||
Client third.ThirdClient
|
||||
discov discoveryregistry.SvcDiscoveryRegistry
|
||||
MinioClient *minio.Client
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
func NewThird(discov discoveryregistry.SvcDiscoveryRegistry) *Third {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImThirdName)
|
||||
func NewThird(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *Third {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImThirdName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
client := third.NewThirdClient(conn)
|
||||
minioClient, err := minioInit()
|
||||
minioClient, err := minioInit(config)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient}
|
||||
return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient, Config: config}
|
||||
}
|
||||
|
||||
func minioInit() (*minio.Client, error) {
|
||||
// Retrieve MinIO configuration details
|
||||
endpoint := config.Config.Object.Minio.Endpoint
|
||||
accessKeyID := config.Config.Object.Minio.AccessKeyID
|
||||
secretAccessKey := config.Config.Object.Minio.SecretAccessKey
|
||||
|
||||
// Parse the MinIO URL to determine if the connection should be secure
|
||||
minioURL, err := url.Parse(endpoint)
|
||||
func minioInit(config *config.GlobalConfig) (*minio.Client, error) {
|
||||
minioClient := &minio.Client{}
|
||||
initUrl := config.Object.Minio.Endpoint
|
||||
minioUrl, err := url.Parse(initUrl)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err, "minioInit: failed to parse MinIO endpoint URL")
|
||||
}
|
||||
|
||||
// Determine the security of the connection based on the scheme
|
||||
secure := minioURL.Scheme == "https"
|
||||
|
||||
// Setup MinIO client options
|
||||
opts := &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
|
||||
Secure: secure,
|
||||
Creds: credentials.NewStaticV4(config.Object.Minio.AccessKeyID, config.Object.Minio.SecretAccessKey, ""),
|
||||
// Region: config.Credential.Minio.Location,
|
||||
}
|
||||
|
||||
// Initialize MinIO client
|
||||
minioClient, err := minio.New(minioURL.Host, opts)
|
||||
if minioUrl.Scheme == "http" {
|
||||
opts.Secure = false
|
||||
} else if minioUrl.Scheme == "https" {
|
||||
opts.Secure = true
|
||||
}
|
||||
minioClient, err = minio.New(minioUrl.Host, opts)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err, "minioInit: failed to create MinIO client")
|
||||
}
|
||||
|
||||
return minioClient, nil
|
||||
}
|
||||
|
||||
@@ -34,16 +34,17 @@ type User struct {
|
||||
conn grpc.ClientConnInterface
|
||||
Client user.UserClient
|
||||
Discov discoveryregistry.SvcDiscoveryRegistry
|
||||
Config *config.GlobalConfig
|
||||
}
|
||||
|
||||
// NewUser initializes and returns a User instance based on the provided service discovery registry.
|
||||
func NewUser(discov discoveryregistry.SvcDiscoveryRegistry) *User {
|
||||
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImUserName)
|
||||
func NewUser(discov discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) *User {
|
||||
conn, err := discov.GetConn(context.Background(), config.RpcRegisterName.OpenImUserName)
|
||||
if err != nil {
|
||||
util.ExitWithError(err)
|
||||
}
|
||||
client := user.NewUserClient(conn)
|
||||
return &User{Discov: discov, Client: client, conn: conn}
|
||||
return &User{Discov: discov, Client: client, conn: conn, Config: config}
|
||||
}
|
||||
|
||||
// UserRpcClient represents the structure for a User RPC client.
|
||||
@@ -56,8 +57,8 @@ func NewUserRpcClientByUser(user *User) *UserRpcClient {
|
||||
}
|
||||
|
||||
// NewUserRpcClient initializes a UserRpcClient based on the provided service discovery registry.
|
||||
func NewUserRpcClient(client discoveryregistry.SvcDiscoveryRegistry) UserRpcClient {
|
||||
return UserRpcClient(*NewUser(client))
|
||||
func NewUserRpcClient(client discoveryregistry.SvcDiscoveryRegistry, config *config.GlobalConfig) UserRpcClient {
|
||||
return UserRpcClient(*NewUser(client, config))
|
||||
}
|
||||
|
||||
// GetUsersInfo retrieves information for multiple users based on their user IDs.
|
||||
@@ -160,7 +161,7 @@ func (u *UserRpcClient) Access(ctx context.Context, ownerUserID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return authverify.CheckAccessV3(ctx, ownerUserID)
|
||||
return authverify.CheckAccessV3(ctx, ownerUserID, u.Config)
|
||||
}
|
||||
|
||||
// GetAllUserIDs retrieves all user IDs with pagination options.
|
||||
|
||||
Reference in New Issue
Block a user