diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 113aa1ba..f0a1c10c 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -18,8 +18,9 @@ import ( var ports Ports type Ports struct { - daemonPorts map[string]int64 - daemonLock sync.Mutex + daemonPorts map[string]int64 + daemonsRunning map[string]struct{} + daemonLock sync.Mutex startPort, endPort int64 usedPorts map[int64]struct{} @@ -28,6 +29,13 @@ type Ports struct { daemonWG sync.WaitGroup } +func IsDaemonRunning(url string) bool { + ports.daemonLock.Lock() + defer ports.daemonLock.Unlock() + _, ok := ports.daemonsRunning[url] + return ok +} + func SetPorts(start, end int64) { ports.daemonLock.Lock() defer ports.daemonLock.Unlock() @@ -164,8 +172,10 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { if ports.daemonPorts == nil { ports.daemonPorts = map[string]int64{} + ports.daemonsRunning = map[string]struct{}{} } ports.daemonPorts[tool.ID] = port + ports.daemonsRunning[url] = struct{}{} killedCtx, cancel := context.WithCancelCause(ctx) defer cancel(nil) @@ -185,6 +195,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { defer ports.daemonLock.Unlock() delete(ports.daemonPorts, tool.ID) + delete(ports.daemonsRunning, url) ports.daemonWG.Done() }() diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 89863529..6a21413b 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -22,8 +22,9 @@ import ( ) type Client struct { - modelsLock sync.Mutex + clientsLock sync.Mutex cache *cache.Client + clients map[string]clientInfo modelToProvider map[string]string runner *runner.Runner envs []string @@ -38,13 +39,15 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent envs: envs, credStore: credStore, defaultProvider: defaultProvider, + modelToProvider: make(map[string]string), + clients: make(map[string]clientInfo), } } func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { - c.modelsLock.Lock() + c.clientsLock.Lock() provider, ok := c.modelToProvider[messageRequest.Model] - c.modelsLock.Unlock() + c.clientsLock.Unlock() if !ok { return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) @@ -105,12 +108,8 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error) return false, err } - c.modelsLock.Lock() - defer c.modelsLock.Unlock() - - if c.modelToProvider == nil { - c.modelToProvider = map[string]string{} - } + c.clientsLock.Lock() + defer c.clientsLock.Unlock() c.modelToProvider[modelString] = providerName return true, nil @@ -145,11 +144,23 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie } func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) { + c.clientsLock.Lock() + defer c.clientsLock.Unlock() + + client, ok := c.clients[toolName] + if ok && !isHTTPURL(toolName) && engine.IsDaemonRunning(client.url) { + return client.client, nil + } + if isHTTPURL(toolName) { remoteClient, err := c.clientFromURL(ctx, toolName) if err != nil { return nil, err } + c.clients[toolName] = clientInfo{ + client: remoteClient, + url: toolName, + } return remoteClient, nil } @@ -165,7 +176,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - client, err := openai.NewClient(ctx, c.credStore, openai.Options{ + oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{ BaseURL: strings.TrimSuffix(url, "/") + "/v1", Cache: c.cache, CacheKey: prg.EntryToolID, @@ -174,7 +185,11 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return nil, err } - return client, nil + c.clients[toolName] = clientInfo{ + client: oClient, + url: url, + } + return client.client, nil } func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) { @@ -185,3 +200,8 @@ func isLocalhost(url string) bool { return strings.HasPrefix(url, "http://localhost") || strings.HasPrefix(url, "http://127.0.0.1") || strings.HasPrefix(url, "https://localhost") || strings.HasPrefix(url, "https://127.0.0.1") } + +type clientInfo struct { + client *openai.Client + url string +}