From 9220b51665b282be080b6a3769bb3083eee3bc86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?nils=20m=C3=A5s=C3=A9n?= Date: Tue, 6 Dec 2022 20:47:08 +0100 Subject: [PATCH] feat: pass context when fetching digests --- pkg/container/client.go | 2 +- pkg/registry/auth/auth.go | 21 ++++++++++----------- pkg/registry/auth/auth_test.go | 10 +++++++--- pkg/registry/digest/digest.go | 22 ++++++++++++---------- pkg/registry/digest/digest_test.go | 18 ++++++++++-------- 5 files changed, 40 insertions(+), 33 deletions(-) diff --git a/pkg/container/client.go b/pkg/container/client.go index 7447828..75c0b4d 100644 --- a/pkg/container/client.go +++ b/pkg/container/client.go @@ -335,7 +335,7 @@ func (client dockerClient) PullImage(ctx context.Context, container Container) e log.WithFields(fields).Debugf("Checking if pull is needed") - if match, err := digest.CompareDigest(container, opts.RegistryAuth); err != nil { + if match, err := digest.CompareDigest(ctx, container, opts.RegistryAuth); err != nil { headLevel := log.DebugLevel if client.WarnOnHeadPullFailed(container) { headLevel = log.WarnLevel diff --git a/pkg/registry/auth/auth.go b/pkg/registry/auth/auth.go index 23aef60..d7ed09f 100644 --- a/pkg/registry/auth/auth.go +++ b/pkg/registry/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "context" "encoding/json" "errors" "fmt" @@ -19,7 +20,7 @@ import ( const ChallengeHeader = "WWW-Authenticate" // GetToken fetches a token for the registry hosting the provided image -func GetToken(container types.Container, registryAuth string) (string, error) { +func GetToken(ctx context.Context, container types.Container, registryAuth string) (string, error) { var err error var URL url.URL @@ -29,13 +30,12 @@ func GetToken(container types.Container, registryAuth string) (string, error) { logrus.WithField("URL", URL.String()).Debug("Building challenge URL") var req *http.Request - if req, err = GetChallengeRequest(URL); err != nil { + if req, err = GetChallengeRequest(ctx, URL); err != nil { return "", err } - client := &http.Client{} var res *http.Response - if res, err = client.Do(req); err != nil { + if res, err = http.DefaultClient.Do(req); err != nil { return "", err } defer res.Body.Close() @@ -55,15 +55,15 @@ func GetToken(container types.Container, registryAuth string) (string, error) { return fmt.Sprintf("Basic %s", registryAuth), nil } if strings.HasPrefix(challenge, "bearer") { - return GetBearerHeader(challenge, container.ImageName(), registryAuth) + return GetBearerHeader(ctx, challenge, container.ImageName(), registryAuth) } return "", errors.New("unsupported challenge type from registry") } // GetChallengeRequest creates a request for getting challenge instructions -func GetChallengeRequest(URL url.URL) (*http.Request, error) { - req, err := http.NewRequest("GET", URL.String(), nil) +func GetChallengeRequest(ctx context.Context, URL url.URL) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, "GET", URL.String(), nil) if err != nil { return nil, err } @@ -73,8 +73,7 @@ func GetChallengeRequest(URL url.URL) (*http.Request, error) { } // GetBearerHeader tries to fetch a bearer token from the registry based on the challenge instructions -func GetBearerHeader(challenge string, img string, registryAuth string) (string, error) { - client := http.Client{} +func GetBearerHeader(ctx context.Context, challenge string, img string, registryAuth string) (string, error) { if strings.Contains(img, ":") { img = strings.Split(img, ":")[0] } @@ -85,7 +84,7 @@ func GetBearerHeader(challenge string, img string, registryAuth string) (string, } var r *http.Request - if r, err = http.NewRequest("GET", authURL.String(), nil); err != nil { + if r, err = http.NewRequestWithContext(ctx, "GET", authURL.String(), nil); err != nil { return "", err } @@ -98,7 +97,7 @@ func GetBearerHeader(challenge string, img string, registryAuth string) (string, } var authResponse *http.Response - if authResponse, err = client.Do(r); err != nil { + if authResponse, err = http.DefaultClient.Do(r); err != nil { return "", err } diff --git a/pkg/registry/auth/auth_test.go b/pkg/registry/auth/auth_test.go index 6ad2307..0778110 100644 --- a/pkg/registry/auth/auth_test.go +++ b/pkg/registry/auth/auth_test.go @@ -1,14 +1,16 @@ package auth_test import ( + "context" "fmt" - "github.com/containrrr/watchtower/internal/actions/mocks" - "github.com/containrrr/watchtower/pkg/registry/auth" "net/url" "os" "testing" "time" + "github.com/containrrr/watchtower/internal/actions/mocks" + "github.com/containrrr/watchtower/pkg/registry/auth" + wtTypes "github.com/containrrr/watchtower/pkg/types" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -32,6 +34,8 @@ func SkipIfCredentialsEmpty(credentials *wtTypes.RegistryCredentials, fn func()) } } +var ctx = context.Background() + var GHCRCredentials = &wtTypes.RegistryCredentials{ Username: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_USERNAME"), Password: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_PASSWORD"), @@ -55,7 +59,7 @@ var _ = Describe("the auth module", func() { It("should parse the token from the response", SkipIfCredentialsEmpty(GHCRCredentials, func() { creds := fmt.Sprintf("%s:%s", GHCRCredentials.Username, GHCRCredentials.Password) - token, err := auth.GetToken(mockContainer, creds) + token, err := auth.GetToken(ctx, mockContainer, creds) Expect(err).NotTo(HaveOccurred()) Expect(token).NotTo(Equal("")) }), diff --git a/pkg/registry/digest/digest.go b/pkg/registry/digest/digest.go index 26fbd8e..c8bbf07 100644 --- a/pkg/registry/digest/digest.go +++ b/pkg/registry/digest/digest.go @@ -1,35 +1,37 @@ package digest import ( + "context" "crypto/tls" "encoding/base64" "encoding/json" "errors" "fmt" + "net" + "net/http" + "strings" + "time" + "github.com/containrrr/watchtower/internal/meta" "github.com/containrrr/watchtower/pkg/registry/auth" "github.com/containrrr/watchtower/pkg/registry/manifest" "github.com/containrrr/watchtower/pkg/types" "github.com/sirupsen/logrus" - "net" - "net/http" - "strings" - "time" ) // ContentDigestHeader is the key for the key-value pair containing the digest header const ContentDigestHeader = "Docker-Content-Digest" // CompareDigest ... -func CompareDigest(container types.Container, registryAuth string) (bool, error) { +func CompareDigest(ctx context.Context, container types.Container, registryAuth string) (bool, error) { if !container.HasImageInfo() { return false, errors.New("container image info missing") } - + var digest string registryAuth = TransformAuth(registryAuth) - token, err := auth.GetToken(container, registryAuth) + token, err := auth.GetToken(ctx, container, registryAuth) if err != nil { return false, err } @@ -39,7 +41,7 @@ func CompareDigest(container types.Container, registryAuth string) (bool, error) return false, err } - if digest, err = GetDigest(digestURL, token); err != nil { + if digest, err = GetDigest(ctx, digestURL, token); err != nil { return false, err } @@ -74,7 +76,7 @@ func TransformAuth(registryAuth string) string { } // GetDigest from registry using a HEAD request to prevent rate limiting -func GetDigest(url string, token string) (string, error) { +func GetDigest(ctx context.Context, url string, token string) (string, error) { tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -90,7 +92,7 @@ func GetDigest(url string, token string) (string, error) { } client := &http.Client{Transport: tr} - req, _ := http.NewRequest("HEAD", url, nil) + req, _ := http.NewRequestWithContext(ctx, "HEAD", url, nil) req.Header.Set("User-Agent", meta.UserAgent) if token != "" { diff --git a/pkg/registry/digest/digest_test.go b/pkg/registry/digest/digest_test.go index a6e6650..8d53cf0 100644 --- a/pkg/registry/digest/digest_test.go +++ b/pkg/registry/digest/digest_test.go @@ -2,20 +2,21 @@ package digest_test import ( "fmt" + "net/http" + "os" + "testing" + "time" + "github.com/containrrr/watchtower/internal/actions/mocks" "github.com/containrrr/watchtower/pkg/registry/digest" wtTypes "github.com/containrrr/watchtower/pkg/types" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" "github.com/onsi/gomega/ghttp" - "net/http" - "os" - "testing" - "time" + "golang.org/x/net/context" ) func TestDigest(t *testing.T) { - RegisterFailHandler(Fail) RunSpecs(GinkgoT(), "Digest Suite") } @@ -29,6 +30,7 @@ var ( Username: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_USERNAME"), Password: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_PASSWORD"), } + ctx = context.Background() ) func SkipIfCredentialsEmpty(credentials *wtTypes.RegistryCredentials, fn func()) func() { @@ -65,7 +67,7 @@ var _ = Describe("Digests", func() { It("should return true if digests match", SkipIfCredentialsEmpty(GHCRCredentials, func() { creds := fmt.Sprintf("%s:%s", GHCRCredentials.Username, GHCRCredentials.Password) - matches, err := digest.CompareDigest(mockContainer, creds) + matches, err := digest.CompareDigest(ctx, mockContainer, creds) Expect(err).NotTo(HaveOccurred()) Expect(matches).To(Equal(true)) }), @@ -78,7 +80,7 @@ var _ = Describe("Digests", func() { }) It("should return an error when container contains no image info", func() { - matches, err := digest.CompareDigest(mockContainerNoImage, `user:pass`) + matches, err := digest.CompareDigest(ctx, mockContainerNoImage, `user:pass`) Expect(err).To(HaveOccurred()) Expect(matches).To(Equal(false)) }) @@ -116,7 +118,7 @@ var _ = Describe("Digests", func() { }), ), ) - dig, err := digest.GetDigest(server.URL(), "token") + dig, err := digest.GetDigest(ctx, server.URL(), "token") Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).NotTo(HaveOccurred()) Expect(dig).To(Equal(mockDigest))