diff --git a/pkg/container/client.go b/pkg/container/client.go index 75c0b4d..fa587e2 100644 --- a/pkg/container/client.go +++ b/pkg/container/client.go @@ -8,7 +8,6 @@ import ( "time" "github.com/containrrr/watchtower/pkg/registry" - "github.com/containrrr/watchtower/pkg/registry/digest" t "github.com/containrrr/watchtower/pkg/types" "github.com/docker/docker/api/types" @@ -52,6 +51,7 @@ func NewClient(opts ClientOptions) Client { return dockerClient{ api: cli, ClientOptions: opts, + reg: registry.NewClient(), } } @@ -63,6 +63,7 @@ type ClientOptions struct { ReviveStopped bool IncludeRestarting bool WarnOnHeadFailed WarningStrategy + Timeout time.Duration } // WarningStrategy is a value determining when to show warnings @@ -80,6 +81,16 @@ const ( type dockerClient struct { api sdkClient.CommonAPIClient ClientOptions + reg *registry.Client +} + +func (client *dockerClient) createContext() (context.Context, context.CancelFunc) { + base := context.TODO() + if client.ClientOptions.Timeout == 0 { + // No timeout has been specified, let's not create a context that instantly cancels itself + return base, func() {} + } + return context.WithTimeout(context.Background(), client.ClientOptions.Timeout) } func (client dockerClient) WarnOnHeadPullFailed(container Container) bool { @@ -278,7 +289,8 @@ func (client dockerClient) RenameContainer(c Container, newName string) error { } func (client dockerClient) IsContainerStale(container Container) (stale bool, latestImage t.ImageID, err error) { - ctx := context.Background() + ctx, cancel := client.createContext() + defer cancel() if !client.PullImages { log.Debugf("Skipping image pull.") @@ -335,12 +347,12 @@ func (client dockerClient) PullImage(ctx context.Context, container Container) e log.WithFields(fields).Debugf("Checking if pull is needed") - if match, err := digest.CompareDigest(ctx, container, opts.RegistryAuth); err != nil { + if match, err := client.reg.CompareDigest(ctx, container, opts.RegistryAuth); err != nil { headLevel := log.DebugLevel if client.WarnOnHeadPullFailed(container) { headLevel = log.WarnLevel } - log.WithFields(fields).Logf(headLevel, "Could not do a head request for %q, falling back to regular pull.", imageName) + log.WithFields(fields).Log(headLevel, "Could not do a head request, falling back to regular pull.") log.WithFields(fields).Log(headLevel, "Reason: ", err) } else if match { log.Debug("No pull needed. Skipping image.") diff --git a/pkg/registry/auth.go b/pkg/registry/auth.go new file mode 100644 index 0000000..92f6902 --- /dev/null +++ b/pkg/registry/auth.go @@ -0,0 +1,112 @@ +package registry + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/containrrr/watchtower/pkg/registry/auth" + "github.com/containrrr/watchtower/pkg/types" + "github.com/sirupsen/logrus" +) + +// ChallengeHeader is the HTTP Header containing challenge instructions +const ChallengeHeader = "WWW-Authenticate" + +// GetToken fetches a token for the registry hosting the provided image +func (rc *Client) GetToken(ctx context.Context, container types.Container, registryAuth string) (string, error) { + var err error + var URL url.URL + + if URL, err = auth.GetChallengeURL(container.ImageName()); err != nil { + return "", err + } + logrus.WithField("URL", URL.String()).Debug("Building challenge URL") + + var req *http.Request + if req, err = rc.GetChallengeRequest(ctx, URL); err != nil { + return "", err + } + + var res *http.Response + if res, err = rc.httpClient.Do(req); err != nil { + return "", err + } + defer res.Body.Close() + v := res.Header.Get(ChallengeHeader) + + logrus.WithFields(logrus.Fields{ + "status": res.Status, + "header": v, + }).Debug("Got response to challenge request") + + challenge := strings.ToLower(v) + if strings.HasPrefix(challenge, "basic") { + if registryAuth == "" { + return "", fmt.Errorf("no credentials available") + } + + return fmt.Sprintf("Basic %s", registryAuth), nil + } + if strings.HasPrefix(challenge, "bearer") { + return rc.GetBearerHeader(ctx, challenge, container.ImageName(), registryAuth) + } + + return "", errors.New("unsupported challenge type from registry") +} + +// GetChallengeRequest creates a request for getting challenge instructions +func (rc *Client) GetChallengeRequest(ctx context.Context, URL url.URL) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, "GET", URL.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "*/*") + req.Header.Set("User-Agent", "Watchtower (Docker)") + return req, nil +} + +// GetBearerHeader tries to fetch a bearer token from the registry based on the challenge instructions +func (rc *Client) GetBearerHeader(ctx context.Context, challenge string, img string, registryAuth string) (string, error) { + if strings.Contains(img, ":") { + img = strings.Split(img, ":")[0] + } + authURL, err := auth.GetAuthURL(challenge, img) + + if err != nil { + return "", err + } + + var r *http.Request + if r, err = http.NewRequestWithContext(ctx, "GET", authURL.String(), nil); err != nil { + return "", err + } + + if registryAuth != "" { + logrus.Debug("Credentials found.") + logrus.Tracef("Credentials: %v", registryAuth) + r.Header.Add("Authorization", fmt.Sprintf("Basic %s", registryAuth)) + } else { + logrus.Debug("No credentials found.") + } + + var authResponse *http.Response + if authResponse, err = rc.httpClient.Do(r); err != nil { + return "", err + } + + body, _ := ioutil.ReadAll(authResponse.Body) + tokenResponse := &types.TokenResponse{} + + err = json.Unmarshal(body, tokenResponse) + if err != nil { + return "", err + } + + return fmt.Sprintf("Bearer %s", tokenResponse.Token), nil +} diff --git a/pkg/registry/auth/auth.go b/pkg/registry/auth/auth.go index d7ed09f..64962a4 100644 --- a/pkg/registry/auth/auth.go +++ b/pkg/registry/auth/auth.go @@ -1,117 +1,15 @@ package auth import ( - "context" - "encoding/json" - "errors" "fmt" - "io/ioutil" - "net/http" "net/url" "strings" "github.com/containrrr/watchtower/pkg/registry/helpers" - "github.com/containrrr/watchtower/pkg/types" "github.com/docker/distribution/reference" "github.com/sirupsen/logrus" ) -// ChallengeHeader is the HTTP Header containing challenge instructions -const ChallengeHeader = "WWW-Authenticate" - -// GetToken fetches a token for the registry hosting the provided image -func GetToken(ctx context.Context, container types.Container, registryAuth string) (string, error) { - var err error - var URL url.URL - - if URL, err = GetChallengeURL(container.ImageName()); err != nil { - return "", err - } - logrus.WithField("URL", URL.String()).Debug("Building challenge URL") - - var req *http.Request - if req, err = GetChallengeRequest(ctx, URL); err != nil { - return "", err - } - - var res *http.Response - if res, err = http.DefaultClient.Do(req); err != nil { - return "", err - } - defer res.Body.Close() - v := res.Header.Get(ChallengeHeader) - - logrus.WithFields(logrus.Fields{ - "status": res.Status, - "header": v, - }).Debug("Got response to challenge request") - - challenge := strings.ToLower(v) - if strings.HasPrefix(challenge, "basic") { - if registryAuth == "" { - return "", fmt.Errorf("no credentials available") - } - - return fmt.Sprintf("Basic %s", registryAuth), nil - } - if strings.HasPrefix(challenge, "bearer") { - return GetBearerHeader(ctx, challenge, container.ImageName(), registryAuth) - } - - return "", errors.New("unsupported challenge type from registry") -} - -// GetChallengeRequest creates a request for getting challenge instructions -func GetChallengeRequest(ctx context.Context, URL url.URL) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, "GET", URL.String(), nil) - if err != nil { - return nil, err - } - req.Header.Set("Accept", "*/*") - req.Header.Set("User-Agent", "Watchtower (Docker)") - return req, nil -} - -// GetBearerHeader tries to fetch a bearer token from the registry based on the challenge instructions -func GetBearerHeader(ctx context.Context, challenge string, img string, registryAuth string) (string, error) { - if strings.Contains(img, ":") { - img = strings.Split(img, ":")[0] - } - authURL, err := GetAuthURL(challenge, img) - - if err != nil { - return "", err - } - - var r *http.Request - if r, err = http.NewRequestWithContext(ctx, "GET", authURL.String(), nil); err != nil { - return "", err - } - - if registryAuth != "" { - logrus.Debug("Credentials found.") - logrus.Tracef("Credentials: %v", registryAuth) - r.Header.Add("Authorization", fmt.Sprintf("Basic %s", registryAuth)) - } else { - logrus.Debug("No credentials found.") - } - - var authResponse *http.Response - if authResponse, err = http.DefaultClient.Do(r); err != nil { - return "", err - } - - body, _ := ioutil.ReadAll(authResponse.Body) - tokenResponse := &types.TokenResponse{} - - err = json.Unmarshal(body, tokenResponse) - if err != nil { - return "", err - } - - return fmt.Sprintf("Bearer %s", tokenResponse.Token), nil -} - // GetAuthURL from the instructions in the challenge func GetAuthURL(challenge string, img string) (*url.URL, error) { loweredChallenge := strings.ToLower(challenge) diff --git a/pkg/registry/auth/auth_test.go b/pkg/registry/auth/auth_test.go index 0778110..e11ea7c 100644 --- a/pkg/registry/auth/auth_test.go +++ b/pkg/registry/auth/auth_test.go @@ -1,14 +1,10 @@ package auth_test import ( - "context" - "fmt" "net/url" "os" "testing" - "time" - "github.com/containrrr/watchtower/internal/actions/mocks" "github.com/containrrr/watchtower/pkg/registry/auth" wtTypes "github.com/containrrr/watchtower/pkg/types" @@ -34,37 +30,13 @@ func SkipIfCredentialsEmpty(credentials *wtTypes.RegistryCredentials, fn func()) } } -var ctx = context.Background() - var GHCRCredentials = &wtTypes.RegistryCredentials{ Username: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_USERNAME"), Password: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_PASSWORD"), } var _ = Describe("the auth module", func() { - mockId := "mock-id" - mockName := "mock-container" - mockImage := "ghcr.io/k6io/operator:latest" - mockCreated := time.Now() - mockDigest := "ghcr.io/k6io/operator@sha256:d68e1e532088964195ad3a0a71526bc2f11a78de0def85629beb75e2265f0547" - - mockContainer := mocks.CreateMockContainerWithDigest( - mockId, - mockName, - mockImage, - mockCreated, - mockDigest) - When("getting an auth url", func() { - It("should parse the token from the response", - SkipIfCredentialsEmpty(GHCRCredentials, func() { - creds := fmt.Sprintf("%s:%s", GHCRCredentials.Username, GHCRCredentials.Password) - token, err := auth.GetToken(ctx, mockContainer, creds) - Expect(err).NotTo(HaveOccurred()) - Expect(token).NotTo(Equal("")) - }), - ) - It("should create a valid auth url object based on the challenge header supplied", func() { input := `bearer realm="https://ghcr.io/token",service="ghcr.io",scope="repository:user/image:pull"` expected := &url.URL{ diff --git a/pkg/registry/client.go b/pkg/registry/client.go new file mode 100644 index 0000000..695a8b0 --- /dev/null +++ b/pkg/registry/client.go @@ -0,0 +1,40 @@ +package registry + +import ( + "crypto/tls" + "net" + "net/http" + "time" +) + +type Client struct { + httpClient *http.Client + Timeout time.Duration +} + +// NewClientWithHTTPClient returns a custom registry client useful for testing +func NewClientWithHTTPClient(httpClient *http.Client) *Client { + timeout := 30 * time.Second + return &Client{ + httpClient, + timeout, + } +} + +// NewClient returns a registry client with the default values +func NewClient() *Client { + tr := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + return NewClientWithHTTPClient(&http.Client{Transport: tr}) +} diff --git a/pkg/registry/digest/digest.go b/pkg/registry/digest.go similarity index 74% rename from pkg/registry/digest/digest.go rename to pkg/registry/digest.go index c8bbf07..514d66d 100644 --- a/pkg/registry/digest/digest.go +++ b/pkg/registry/digest.go @@ -1,19 +1,15 @@ -package digest +package registry import ( "context" - "crypto/tls" "encoding/base64" "encoding/json" "errors" "fmt" - "net" "net/http" "strings" - "time" "github.com/containrrr/watchtower/internal/meta" - "github.com/containrrr/watchtower/pkg/registry/auth" "github.com/containrrr/watchtower/pkg/registry/manifest" "github.com/containrrr/watchtower/pkg/types" "github.com/sirupsen/logrus" @@ -22,8 +18,12 @@ import ( // ContentDigestHeader is the key for the key-value pair containing the digest header const ContentDigestHeader = "Docker-Content-Digest" -// CompareDigest ... -func CompareDigest(ctx context.Context, container types.Container, registryAuth string) (bool, error) { +// CompareDigest retrieves the latest digest for the container image from the registry +// and returns whether it matches any of the containers current image's digest +func (rc *Client) CompareDigest(ctx context.Context, container types.Container, registryAuth string) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, rc.Timeout) + defer cancel() + if !container.HasImageInfo() { return false, errors.New("container image info missing") } @@ -31,7 +31,7 @@ func CompareDigest(ctx context.Context, container types.Container, registryAuth var digest string registryAuth = TransformAuth(registryAuth) - token, err := auth.GetToken(ctx, container, registryAuth) + token, err := rc.GetToken(ctx, container, registryAuth) if err != nil { return false, err } @@ -41,7 +41,7 @@ func CompareDigest(ctx context.Context, container types.Container, registryAuth return false, err } - if digest, err = GetDigest(ctx, digestURL, token); err != nil { + if digest, err = rc.GetDigest(ctx, digestURL, token); err != nil { return false, err } @@ -76,21 +76,7 @@ func TransformAuth(registryAuth string) string { } // GetDigest from registry using a HEAD request to prevent rate limiting -func GetDigest(ctx context.Context, url string, token string) (string, error) { - tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - client := &http.Client{Transport: tr} +func (rc *Client) GetDigest(ctx context.Context, url string, token string) (string, error) { req, _ := http.NewRequestWithContext(ctx, "HEAD", url, nil) req.Header.Set("User-Agent", meta.UserAgent) @@ -108,7 +94,7 @@ func GetDigest(ctx context.Context, url string, token string) (string, error) { logrus.WithField("url", url).Debug("Doing a HEAD request to fetch a digest") - res, err := client.Do(req) + res, err := rc.httpClient.Do(req) if err != nil { return "", err } diff --git a/pkg/registry/digest/digest_test.go b/pkg/registry/digest_test.go similarity index 73% rename from pkg/registry/digest/digest_test.go rename to pkg/registry/digest_test.go index 8d53cf0..de788fa 100644 --- a/pkg/registry/digest/digest_test.go +++ b/pkg/registry/digest_test.go @@ -1,14 +1,13 @@ -package digest_test +package registry_test import ( "fmt" "net/http" "os" - "testing" "time" "github.com/containrrr/watchtower/internal/actions/mocks" - "github.com/containrrr/watchtower/pkg/registry/digest" + "github.com/containrrr/watchtower/pkg/registry" wtTypes "github.com/containrrr/watchtower/pkg/types" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -16,11 +15,6 @@ import ( "golang.org/x/net/context" ) -func TestDigest(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(GinkgoT(), "Digest Suite") -} - var ( DockerHubCredentials = &wtTypes.RegistryCredentials{ Username: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_DH_USERNAME"), @@ -31,6 +25,7 @@ var ( Password: os.Getenv("CI_INTEGRATION_TEST_REGISTRY_GH_PASSWORD"), } ctx = context.Background() + rc = registry.NewClientWithHTTPClient(http.DefaultClient) ) func SkipIfCredentialsEmpty(credentials *wtTypes.RegistryCredentials, fn func()) func() { @@ -67,7 +62,7 @@ var _ = Describe("Digests", func() { It("should return true if digests match", SkipIfCredentialsEmpty(GHCRCredentials, func() { creds := fmt.Sprintf("%s:%s", GHCRCredentials.Username, GHCRCredentials.Password) - matches, err := digest.CompareDigest(ctx, mockContainer, creds) + matches, err := rc.CompareDigest(ctx, mockContainer, creds) Expect(err).NotTo(HaveOccurred()) Expect(matches).To(Equal(true)) }), @@ -80,7 +75,7 @@ var _ = Describe("Digests", func() { }) It("should return an error when container contains no image info", func() { - matches, err := digest.CompareDigest(ctx, mockContainerNoImage, `user:pass`) + matches, err := rc.CompareDigest(ctx, mockContainerNoImage, `user:pass`) Expect(err).To(HaveOccurred()) Expect(matches).To(Equal(false)) }) @@ -112,16 +107,42 @@ var _ = Describe("Digests", func() { "User-Agent": []string{"Watchtower/v0.0.0-unknown"}, }), ghttp.RespondWith(http.StatusOK, "", http.Header{ - digest.ContentDigestHeader: []string{ + registry.ContentDigestHeader: []string{ mockDigest, }, }), ), ) - dig, err := digest.GetDigest(ctx, server.URL(), "token") + dig, err := rc.GetDigest(ctx, server.URL(), "token") Expect(server.ReceivedRequests()).Should(HaveLen(1)) Expect(err).NotTo(HaveOccurred()) Expect(dig).To(Equal(mockDigest)) }) }) }) + +var _ = Describe("the auth module", func() { + mockId := "mock-id" + mockName := "mock-container" + mockImage := "ghcr.io/k6io/operator:latest" + mockCreated := time.Now() + mockDigest := "ghcr.io/k6io/operator@sha256:d68e1e532088964195ad3a0a71526bc2f11a78de0def85629beb75e2265f0547" + + mockContainer := mocks.CreateMockContainerWithDigest( + mockId, + mockName, + mockImage, + mockCreated, + mockDigest) + + When("getting an auth url", func() { + It("should parse the token from the response", + SkipIfCredentialsEmpty(GHCRCredentials, func() { + creds := fmt.Sprintf("%s:%s", GHCRCredentials.Username, GHCRCredentials.Password) + token, err := rc.GetToken(ctx, mockContainer, creds) + Expect(err).NotTo(HaveOccurred()) + Expect(token).NotTo(Equal("")) + }), + ) + }) +})