Skip to content

Commit d92d23d

Browse files
committed
feat: Add support for tools from github enterprise.
1 parent 418a00a commit d92d23d

File tree

5 files changed

+155
-42
lines changed

5 files changed

+155
-42
lines changed

pkg/cli/gptscript.go

+24-18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"github.com/gptscript-ai/gptscript/pkg/gptscript"
2323
"github.com/gptscript-ai/gptscript/pkg/input"
2424
"github.com/gptscript-ai/gptscript/pkg/loader"
25+
"github.com/gptscript-ai/gptscript/pkg/loader/github"
2526
"github.com/gptscript-ai/gptscript/pkg/monitor"
2627
"github.com/gptscript-ai/gptscript/pkg/mvl"
2728
"github.com/gptscript-ai/gptscript/pkg/openai"
@@ -53,24 +54,25 @@ type GPTScript struct {
5354
Output string `usage:"Save output to a file, or - for stdout" short:"o"`
5455
EventsStreamTo string `usage:"Stream events to this location, could be a file descriptor/handle (e.g. fd://2), filename, or named pipe (e.g. \\\\.\\pipe\\my-pipe)" name:"events-stream-to"`
5556
// Input should not be using GPTSCRIPT_INPUT env var because that is the same value that is set in tool executions
56-
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
57-
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
58-
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
59-
ListModels bool `usage:"List the models available and exit" local:"true"`
60-
ListTools bool `usage:"List built-in tools and exit" local:"true"`
61-
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
62-
Chdir string `usage:"Change current working directory" short:"C"`
63-
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
64-
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
65-
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
66-
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
67-
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
68-
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
69-
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
70-
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
71-
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
72-
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
73-
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
57+
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
58+
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
59+
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
60+
ListModels bool `usage:"List the models available and exit" local:"true"`
61+
ListTools bool `usage:"List built-in tools and exit" local:"true"`
62+
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
63+
Chdir string `usage:"Change current working directory" short:"C"`
64+
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
65+
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
66+
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
67+
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
68+
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
69+
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
70+
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
71+
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
72+
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
73+
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
74+
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
75+
EnableGithubEnterprise string `usage:"The host name for a Github Enterprise instance to enable for remote loading" local:"true"`
7476

7577
readData []byte
7678
}
@@ -328,6 +330,10 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
328330
return err
329331
}
330332

333+
if r.EnableGithubEnterprise != "" {
334+
loader.AddVSC(github.LoaderForPrefix(r.EnableGithubEnterprise))
335+
}
336+
331337
// If the user is trying to launch the chat-builder UI, then set up the tool and options here.
332338
if r.UI {
333339
args = append([]string{env.VarOrDefault("GPTSCRIPT_CHAT_UI_TOOL", "github.com/gptscript-ai/ui@v2")}, args...)

pkg/loader/github/github.go

+65-24
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package github
22

33
import (
44
"context"
5+
"crypto/tls"
56
"encoding/json"
67
"fmt"
78
"io"
@@ -18,52 +19,65 @@ import (
1819
"github.com/gptscript-ai/gptscript/pkg/types"
1920
)
2021

21-
const (
22-
GithubPrefix = "github.com/"
23-
githubRepoURL = "https://github.com/%s/%s.git"
24-
githubDownloadURL = "https://raw.githubusercontent.com/%s/%s/%s/%s"
25-
githubCommitURL = "https://api.github.com/repos/%s/%s/commits/%s"
26-
)
22+
type GithubConfig struct {
23+
Prefix string
24+
RepoURL string
25+
DownloadURL string
26+
CommitURL string
27+
AuthToken string
28+
Enterprise bool
29+
}
2730

2831
var (
29-
githubAuthToken = os.Getenv("GITHUB_AUTH_TOKEN")
30-
log = mvl.Package()
32+
log = mvl.Package()
33+
DEFAULT_GITHUB_CONFIG = &GithubConfig{
34+
Prefix: "github.com/",
35+
RepoURL: "https://github.com/%s/%s.git",
36+
DownloadURL: "https://raw.githubusercontent.com/%s/%s/%s/%s",
37+
CommitURL: "https://api.github.com/repos/%s/%s/commits/%s",
38+
AuthToken: os.Getenv("GITHUB_AUTH_TOKEN"),
39+
Enterprise: false,
40+
}
3141
)
3242

3343
func init() {
3444
loader.AddVSC(Load)
3545
}
3646

37-
func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string, error) {
38-
url := fmt.Sprintf(githubRepoURL, account, repo)
47+
func getCommitLsRemote(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) {
48+
url := fmt.Sprintf(config.RepoURL, account, repo)
3949
return git.LsRemote(ctx, url, ref)
4050
}
4151

4252
// regexp to match a git commit id
4353
var commitRegexp = regexp.MustCompile("^[a-f0-9]{40}$")
4454

45-
func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
55+
func getCommit(ctx context.Context, account, repo, ref string, config *GithubConfig) (string, error) {
4656
if commitRegexp.MatchString(ref) {
4757
return ref, nil
4858
}
4959

50-
url := fmt.Sprintf(githubCommitURL, account, repo, ref)
60+
url := fmt.Sprintf(config.CommitURL, account, repo, ref)
5161
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
5262
if err != nil {
5363
return "", fmt.Errorf("failed to create request of %s/%s at %s: %w", account, repo, url, err)
5464
}
5565

56-
if githubAuthToken != "" {
57-
req.Header.Add("Authorization", "Bearer "+githubAuthToken)
66+
if config.AuthToken != "" {
67+
req.Header.Add("Authorization", "Bearer "+config.AuthToken)
5868
}
5969

60-
resp, err := http.DefaultClient.Do(req)
70+
client := http.DefaultClient
71+
if req.Host == config.Prefix && strings.ToLower(os.Getenv("GH_ENTERPRISE_SKIP_VERIFY")) == "true" {
72+
client = &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}
73+
}
74+
resp, err := client.Do(req)
6175
if err != nil {
6276
return "", err
6377
} else if resp.StatusCode != http.StatusOK {
6478
c, _ := io.ReadAll(resp.Body)
6579
resp.Body.Close()
66-
commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref)
80+
commit, fallBackErr := getCommitLsRemote(ctx, account, repo, ref, config)
6781
if fallBackErr == nil {
6882
return commit, nil
6983
}
@@ -88,8 +102,29 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
88102
return commit.SHA, nil
89103
}
90104

91-
func Load(ctx context.Context, _ *cache.Client, urlName string) (string, *types.Repo, bool, error) {
92-
if !strings.HasPrefix(urlName, GithubPrefix) {
105+
func LoaderForPrefix(prefix string) func(context.Context, *cache.Client, string) (string, *types.Repo, bool, error) {
106+
return func(ctx context.Context, c *cache.Client, urlName string) (string, *types.Repo, bool, error) {
107+
return LoadWithConfig(ctx, c, urlName, NewGithubEnterpriseConfig(prefix))
108+
}
109+
}
110+
111+
func Load(ctx context.Context, c *cache.Client, urlName string) (string, *types.Repo, bool, error) {
112+
return LoadWithConfig(ctx, c, urlName, DEFAULT_GITHUB_CONFIG)
113+
}
114+
115+
func NewGithubEnterpriseConfig(prefix string) *GithubConfig {
116+
return &GithubConfig{
117+
Prefix: prefix,
118+
RepoURL: fmt.Sprintf("https://%s/%%s/%%s.git", prefix),
119+
DownloadURL: fmt.Sprintf("https://raw.%s/%%s/%%s/%%s/%%s", prefix),
120+
CommitURL: fmt.Sprintf("https://%s/api/v3/repos/%%s/%%s/commits/%%s", prefix),
121+
AuthToken: os.Getenv("GH_ENTERPRISE_TOKEN"),
122+
Enterprise: true,
123+
}
124+
}
125+
126+
func LoadWithConfig(ctx context.Context, _ *cache.Client, urlName string, config *GithubConfig) (string, *types.Repo, bool, error) {
127+
if !strings.HasPrefix(urlName, config.Prefix) {
93128
return "", nil, false, nil
94129
}
95130

@@ -107,12 +142,12 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, *types.
107142
account, repo := parts[1], parts[2]
108143
path := strings.Join(parts[3:], "/")
109144

110-
ref, err := getCommit(ctx, account, repo, ref)
145+
ref, err := getCommit(ctx, account, repo, ref, config)
111146
if err != nil {
112147
return "", nil, false, err
113148
}
114149

115-
downloadURL := fmt.Sprintf(githubDownloadURL, account, repo, ref, path)
150+
downloadURL := fmt.Sprintf(config.DownloadURL, account, repo, ref, path)
116151
if path == "" || path == "/" || !strings.Contains(parts[len(parts)-1], ".") {
117152
var (
118153
testPath string
@@ -124,7 +159,7 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, *types.
124159
} else {
125160
testPath = path + "/" + ext
126161
}
127-
testURL = fmt.Sprintf(githubDownloadURL, account, repo, ref, testPath)
162+
testURL = fmt.Sprintf(config.DownloadURL, account, repo, ref, testPath)
128163
if i == len(types.DefaultFiles)-1 {
129164
// no reason to test the last one, we are just going to use it. Being that the default list is only
130165
// two elements this loop could have been one check, but hey over-engineered code ftw.
@@ -141,11 +176,17 @@ func Load(ctx context.Context, _ *cache.Client, urlName string) (string, *types.
141176
path = testPath
142177
}
143178

144-
return downloadURL, &types.Repo{
179+
repoConfig := &types.Repo{
145180
VCS: "git",
146-
Root: fmt.Sprintf(githubRepoURL, account, repo),
181+
Root: fmt.Sprintf(config.RepoURL, account, repo),
147182
Path: gpath.Dir(path),
148183
Name: gpath.Base(path),
149184
Revision: ref,
150-
}, true, nil
185+
}
186+
if config.Enterprise {
187+
repoConfig.Headers = map[string]string{
188+
"Authorization": fmt.Sprintf("bearer %s", config.AuthToken),
189+
}
190+
}
191+
return downloadURL, repoConfig, true, nil
151192
}

pkg/loader/github/github_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package github
22

33
import (
44
"context"
5+
"fmt"
6+
"net/http"
7+
"net/http/httptest"
8+
"os"
59
"testing"
610

711
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -44,3 +48,56 @@ func TestLoad(t *testing.T) {
4448
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
4549
}).Equal(t, repo)
4650
}
51+
52+
func TestLoad_GithubEnterprise(t *testing.T) {
53+
os.Setenv("GH_ENTERPRISE_SKIP_VERIFY", "true")
54+
os.Setenv("GH_ENTERPRISE_TOKEN", "mytoken")
55+
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
56+
fmt.Printf("Request for %s\n", r.URL.Path)
57+
switch r.URL.Path {
58+
case "/api/v3/repos/gptscript-ai/gptscript/commits/172dfb0":
59+
w.Write([]byte(`{"sha": "172dfb00b48c6adbbaa7e99270933f95887d1b91"}`))
60+
default:
61+
w.WriteHeader(404)
62+
}
63+
}))
64+
defer s.Close()
65+
66+
serverAddr := s.Listener.Addr().String()
67+
68+
url, repo, ok, err := LoadWithConfig(context.Background(), nil, fmt.Sprintf("%s/gptscript-ai/gptscript/pkg/loader/testdata/tool@172dfb0", serverAddr), NewGithubEnterpriseConfig(serverAddr))
69+
require.NoError(t, err)
70+
assert.True(t, ok)
71+
autogold.Expect(fmt.Sprintf("https://raw.%s/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/tool/tool.gpt", serverAddr)).Equal(t, url)
72+
autogold.Expect(&types.Repo{
73+
VCS: "git", Root: fmt.Sprintf("https://%s/gptscript-ai/gptscript.git", serverAddr),
74+
Path: "pkg/loader/testdata/tool",
75+
Name: "tool.gpt",
76+
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
77+
Header: map[string]string{
78+
"Authorization": "bearer mytoken",
79+
},
80+
}).Equal(t, repo)
81+
82+
url, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/agent@172dfb0")
83+
require.NoError(t, err)
84+
assert.True(t, ok)
85+
autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/agent/agent.gpt").Equal(t, url)
86+
autogold.Expect(&types.Repo{
87+
VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git",
88+
Path: "pkg/loader/testdata/agent",
89+
Name: "agent.gpt",
90+
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
91+
}).Equal(t, repo)
92+
93+
url, repo, ok, err = Load(context.Background(), nil, "github.com/gptscript-ai/gptscript/pkg/loader/testdata/bothtoolagent@172dfb0")
94+
require.NoError(t, err)
95+
assert.True(t, ok)
96+
autogold.Expect("https://raw.githubusercontent.com/gptscript-ai/gptscript/172dfb00b48c6adbbaa7e99270933f95887d1b91/pkg/loader/testdata/bothtoolagent/agent.gpt").Equal(t, url)
97+
autogold.Expect(&types.Repo{
98+
VCS: "git", Root: "https://github.com/gptscript-ai/gptscript.git",
99+
Path: "pkg/loader/testdata/bothtoolagent",
100+
Name: "agent.gpt",
101+
Revision: "172dfb00b48c6adbbaa7e99270933f95887d1b91",
102+
}).Equal(t, repo)
103+
}

pkg/loader/url.go

+6
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
105105
return nil, false, err
106106
}
107107

108+
if repo != nil {
109+
for key, value := range repo.Headers {
110+
req.Header.Add(key, value)
111+
}
112+
}
113+
108114
data, err := getWithDefaults(req)
109115
if err != nil {
110116
return nil, false, fmt.Errorf("error loading %s: %v", url, err)

pkg/types/tool.go

+3
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,9 @@ type Repo struct {
721721
Name string
722722
// The revision of this source
723723
Revision string
724+
725+
// Additional headers to pass when making requests for this repo
726+
Headers map[string]string
724727
}
725728

726729
type ToolSource struct {

0 commit comments

Comments
 (0)