mirror of https://github.com/synctv-org/synctv
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
542 lines
14 KiB
Go
542 lines
14 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/synctv-org/synctv/utils"
|
|
)
|
|
|
|
var (
|
|
_ io.ReadSeekCloser = (*HttpReadSeekCloser)(nil)
|
|
_ Proxy = (*HttpReadSeekCloser)(nil)
|
|
)
|
|
|
|
type HttpReadSeekCloser struct {
|
|
ctx context.Context
|
|
headHeaders http.Header
|
|
currentResp *http.Response
|
|
headers http.Header
|
|
client *http.Client
|
|
contentType string
|
|
method string
|
|
headMethod string
|
|
url string
|
|
allowedContentTypes []string
|
|
notAllowedStatusCodes []int
|
|
allowedStatusCodes []int
|
|
offset int64
|
|
contentLength int64
|
|
length int64
|
|
currentRespMaxOffset int64
|
|
notSupportRange bool
|
|
}
|
|
|
|
type HttpReadSeekerConf func(h *HttpReadSeekCloser)
|
|
|
|
func WithHeaders(headers http.Header) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if headers != nil {
|
|
h.headers = headers.Clone()
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithHeadersMap(headers map[string]string) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
for k, v := range headers {
|
|
h.headers.Set(k, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithClient(client *http.Client) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if client != nil {
|
|
h.client = client
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithMethod(method string) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if method != "" {
|
|
h.method = method
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithHeadMethod(method string) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if method != "" {
|
|
h.headMethod = method
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithContext(ctx context.Context) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if ctx != nil {
|
|
h.ctx = ctx
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithContentLength(contentLength int64) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if contentLength >= 0 {
|
|
h.contentLength = contentLength
|
|
}
|
|
}
|
|
}
|
|
|
|
func AllowedContentTypes(types ...string) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if len(types) > 0 {
|
|
h.allowedContentTypes = slices.Clone(types)
|
|
}
|
|
}
|
|
}
|
|
|
|
func AllowedStatusCodes(codes ...int) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if len(codes) > 0 {
|
|
h.allowedStatusCodes = slices.Clone(codes)
|
|
}
|
|
}
|
|
}
|
|
|
|
func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if len(codes) > 0 {
|
|
h.notAllowedStatusCodes = slices.Clone(codes)
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithLength(length int64) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
if length > 0 {
|
|
h.length = length
|
|
}
|
|
}
|
|
}
|
|
|
|
func WithNotSupportRange(notSupportRange bool) HttpReadSeekerConf {
|
|
return func(h *HttpReadSeekCloser) {
|
|
h.notSupportRange = notSupportRange
|
|
}
|
|
}
|
|
|
|
func NewHttpReadSeekCloser(url string, conf ...HttpReadSeekerConf) *HttpReadSeekCloser {
|
|
rs := &HttpReadSeekCloser{
|
|
url: url,
|
|
contentLength: -1,
|
|
method: http.MethodGet,
|
|
headMethod: http.MethodHead,
|
|
length: 1024 * 1024 * 16,
|
|
headers: make(http.Header),
|
|
ctx: context.Background(),
|
|
client: http.DefaultClient,
|
|
}
|
|
|
|
for _, c := range conf {
|
|
if c != nil {
|
|
c(rs)
|
|
}
|
|
}
|
|
|
|
rs.fix()
|
|
|
|
return rs
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) fix() *HttpReadSeekCloser {
|
|
if h.method == "" {
|
|
h.method = http.MethodGet
|
|
}
|
|
if h.headMethod == "" {
|
|
h.headMethod = http.MethodHead
|
|
}
|
|
if h.ctx == nil {
|
|
h.ctx = context.Background()
|
|
}
|
|
if h.client == nil {
|
|
h.client = http.DefaultClient
|
|
}
|
|
if len(h.notAllowedStatusCodes) == 0 {
|
|
h.notAllowedStatusCodes = []int{http.StatusNotFound}
|
|
}
|
|
if h.length <= 0 {
|
|
h.length = 64 * 1024
|
|
}
|
|
if h.headers == nil {
|
|
h.headers = make(http.Header)
|
|
}
|
|
return h
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) Read(p []byte) (n int, err error) {
|
|
for n < len(p) {
|
|
if h.currentResp == nil || h.offset > h.currentRespMaxOffset {
|
|
if err := h.FetchNextChunk(); err != nil {
|
|
if err == io.EOF {
|
|
return n, io.EOF
|
|
}
|
|
return 0, fmt.Errorf("failed to fetch next chunk: %w", err)
|
|
}
|
|
}
|
|
|
|
readN, err := h.currentResp.Body.Read(p[n:])
|
|
if readN > 0 {
|
|
n += readN
|
|
h.offset += int64(readN)
|
|
}
|
|
|
|
if err == io.EOF {
|
|
h.closeCurrentResp()
|
|
if n < len(p) {
|
|
continue
|
|
}
|
|
break
|
|
}
|
|
if err != nil {
|
|
if n > 0 {
|
|
return n, nil
|
|
}
|
|
return 0, fmt.Errorf("error reading response body: %w", err)
|
|
}
|
|
}
|
|
|
|
return n, nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) FetchNextChunk() error {
|
|
h.closeCurrentResp()
|
|
|
|
if h.contentLength > 0 && h.offset >= h.contentLength {
|
|
return io.EOF
|
|
}
|
|
|
|
req, err := h.createRequest()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
resp, err := h.client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to execute HTTP request: %w", err)
|
|
}
|
|
|
|
h.contentType = resp.Header.Get("Content-Type")
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
// if the maximum offset of the current response is less than the content length minus one, it means that the server does not support range requests
|
|
if h.currentRespMaxOffset < resp.ContentLength-1 || h.notSupportRange {
|
|
// if the offset is not 0, it means that the seek method is incorrectly used
|
|
if h.offset != 0 {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("server does not support range requests, cannot seek to non-zero offset")
|
|
}
|
|
h.notSupportRange = true
|
|
h.contentLength = resp.ContentLength
|
|
h.currentRespMaxOffset = h.contentLength - 1
|
|
h.currentResp = resp
|
|
return nil
|
|
}
|
|
// if the content length is not known, it may be because the requested length is too long, and a new request is needed
|
|
if h.contentLength < 0 {
|
|
h.contentLength = resp.ContentLength
|
|
resp.Body.Close()
|
|
return h.FetchNextChunk()
|
|
}
|
|
// if the offset is greater than 0, it means that the seek method is incorrectly used
|
|
if h.offset > 0 {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("server does not support range requests, cannot seek to offset %d", h.offset)
|
|
}
|
|
h.notSupportRange = true
|
|
h.currentRespMaxOffset = h.contentLength - 1
|
|
h.currentResp = resp
|
|
return nil
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusPartialContent {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("unexpected HTTP status code: %d (expected 206 Partial Content)", resp.StatusCode)
|
|
}
|
|
|
|
if err := h.checkResponse(resp); err != nil {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("response validation failed: %w", err)
|
|
}
|
|
|
|
contentTotalLength, err := ParseContentRangeTotalLength(resp.Header.Get("Content-Range"))
|
|
if err == nil && contentTotalLength > 0 {
|
|
h.contentLength = contentTotalLength
|
|
}
|
|
_, end, err := ParseContentRangeStartAndEnd(resp.Header.Get("Content-Range"))
|
|
if err == nil && end != -1 {
|
|
h.currentRespMaxOffset = end
|
|
}
|
|
|
|
h.currentResp = resp
|
|
return nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) createRequest() (*http.Request, error) {
|
|
if h.notSupportRange {
|
|
if h.contentLength != -1 {
|
|
h.currentRespMaxOffset = h.contentLength - 1
|
|
}
|
|
return h.createRequestWithoutRange()
|
|
}
|
|
|
|
req, err := h.createRequestWithoutRange()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
end := h.offset + h.length - 1
|
|
if h.contentLength > 0 && end > h.contentLength-1 {
|
|
end = h.contentLength - 1
|
|
}
|
|
|
|
h.currentRespMaxOffset = end
|
|
|
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", h.offset, end))
|
|
return req, nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) createRequestWithoutRange() (*http.Request, error) {
|
|
req, err := http.NewRequestWithContext(h.ctx, h.method, h.url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
|
}
|
|
req.Header = h.headers.Clone()
|
|
req.Header.Del("Range")
|
|
if req.Header.Get("User-Agent") == "" {
|
|
req.Header.Set("User-Agent", utils.UA)
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) checkResponse(resp *http.Response) error {
|
|
if err := h.checkStatusCode(resp.StatusCode); err != nil {
|
|
return err
|
|
}
|
|
return h.checkContentType(resp.Header.Get("Content-Type"))
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) closeCurrentResp() {
|
|
if h.currentResp != nil {
|
|
h.currentResp.Body.Close()
|
|
h.currentResp = nil
|
|
}
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) checkContentType(ct string) error {
|
|
if len(h.allowedContentTypes) != 0 {
|
|
if ct == "" || slices.Index(h.allowedContentTypes, ct) == -1 {
|
|
return fmt.Errorf("content type '%s' is not in the list of allowed content types: %v", ct, h.allowedContentTypes)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) checkStatusCode(code int) error {
|
|
if len(h.allowedStatusCodes) != 0 {
|
|
if slices.Index(h.allowedStatusCodes, code) == -1 {
|
|
return fmt.Errorf("HTTP status code %d is not in the list of allowed status codes: %v", code, h.allowedStatusCodes)
|
|
}
|
|
return nil
|
|
}
|
|
if len(h.notAllowedStatusCodes) != 0 {
|
|
if slices.Index(h.notAllowedStatusCodes, code) != -1 {
|
|
return fmt.Errorf("HTTP status code %d is in the list of not allowed status codes: %v", code, h.notAllowedStatusCodes)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) Seek(offset int64, whence int) (int64, error) {
|
|
newOffset, err := h.calculateNewOffset(offset, whence)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("failed to calculate new offset: %w", err)
|
|
}
|
|
|
|
if newOffset < 0 {
|
|
return 0, fmt.Errorf("cannot seek to negative offset: %d", newOffset)
|
|
}
|
|
|
|
if newOffset != h.offset {
|
|
h.closeCurrentResp()
|
|
h.offset = newOffset
|
|
}
|
|
|
|
return h.offset, nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) calculateNewOffset(offset int64, whence int) (int64, error) {
|
|
switch whence {
|
|
case io.SeekStart:
|
|
if h.notSupportRange && offset != 0 && offset != h.offset {
|
|
return 0, fmt.Errorf("server does not support range requests, cannot seek to non-zero offset")
|
|
}
|
|
return offset, nil
|
|
case io.SeekCurrent:
|
|
if h.notSupportRange && offset != 0 {
|
|
return 0, fmt.Errorf("server does not support range requests, cannot seek to non-zero offset")
|
|
}
|
|
return h.offset + offset, nil
|
|
case io.SeekEnd:
|
|
if h.contentLength < 0 {
|
|
if err := h.fetchContentLength(); err != nil {
|
|
return 0, fmt.Errorf("failed to fetch content length: %w", err)
|
|
}
|
|
}
|
|
newOffset := h.contentLength - offset
|
|
if h.notSupportRange && newOffset != h.offset {
|
|
return 0, fmt.Errorf("server does not support range requests, cannot seek to non-zero offset")
|
|
}
|
|
return newOffset, nil
|
|
default:
|
|
return 0, fmt.Errorf("invalid seek whence value: %d (must be 0, 1, or 2)", whence)
|
|
}
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) fetchContentLength() error {
|
|
req, err := h.createRequestWithoutRange()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Method = h.headMethod
|
|
|
|
resp, err := h.client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to execute HEAD request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("unexpected HTTP status code in HEAD request: %d (expected 200 OK)", resp.StatusCode)
|
|
}
|
|
|
|
if err := h.checkResponse(resp); err != nil {
|
|
return fmt.Errorf("HEAD response validation failed: %w", err)
|
|
}
|
|
|
|
if resp.ContentLength < 0 {
|
|
return fmt.Errorf("server returned invalid content length: %d", resp.ContentLength)
|
|
}
|
|
|
|
h.contentType = resp.Header.Get("Content-Type")
|
|
|
|
h.contentLength = resp.ContentLength
|
|
h.headHeaders = resp.Header.Clone()
|
|
return nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) Close() error {
|
|
if h.currentResp != nil {
|
|
return h.currentResp.Body.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) Offset() int64 {
|
|
return h.offset
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) ContentLength() int64 {
|
|
return h.contentLength
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) ContentType() (string, error) {
|
|
if h.contentType != "" {
|
|
return h.contentType, nil
|
|
}
|
|
return "", fmt.Errorf("content type is not available - no successful response received yet")
|
|
}
|
|
|
|
func (h *HttpReadSeekCloser) ContentTotalLength() (int64, error) {
|
|
if h.contentLength > 0 {
|
|
return h.contentLength, nil
|
|
}
|
|
return 0, fmt.Errorf("content total length is not available - no successful response received yet")
|
|
}
|
|
|
|
func ParseContentRangeStartAndEnd(contentRange string) (int64, int64, error) {
|
|
if contentRange == "" {
|
|
return 0, 0, fmt.Errorf("Content-Range header is empty")
|
|
}
|
|
|
|
if !strings.HasPrefix(contentRange, "bytes ") {
|
|
return 0, 0, fmt.Errorf("invalid Content-Range header format (expected 'bytes ' prefix): %s", contentRange)
|
|
}
|
|
|
|
parts := strings.Split(strings.TrimPrefix(contentRange, "bytes "), "/")
|
|
if len(parts) != 2 {
|
|
return 0, 0, fmt.Errorf("invalid Content-Range header format (expected 2 parts separated by '/'): %s", contentRange)
|
|
}
|
|
|
|
rangeParts := strings.Split(strings.TrimSpace(parts[0]), "-")
|
|
if len(rangeParts) != 2 {
|
|
return 0, 0, fmt.Errorf("invalid Content-Range range format (expected start-end): %s", contentRange)
|
|
}
|
|
|
|
start, err := strconv.ParseInt(strings.TrimSpace(rangeParts[0]), 10, 64)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("invalid Content-Range start value '%s': %w", rangeParts[0], err)
|
|
}
|
|
|
|
rangeParts[1] = strings.TrimSpace(rangeParts[1])
|
|
var end int64
|
|
if rangeParts[1] == "" || rangeParts[1] == "*" {
|
|
end = -1
|
|
} else {
|
|
end, err = strconv.ParseInt(rangeParts[1], 10, 64)
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("invalid Content-Range end value '%s': %w", rangeParts[1], err)
|
|
}
|
|
}
|
|
|
|
return start, end, nil
|
|
}
|
|
|
|
// ParseContentRangeTotalLength parses a Content-Range header value and returns the total length
|
|
func ParseContentRangeTotalLength(contentRange string) (int64, error) {
|
|
if contentRange == "" {
|
|
return 0, fmt.Errorf("Content-Range header is empty")
|
|
}
|
|
|
|
if !strings.HasPrefix(contentRange, "bytes ") {
|
|
return 0, fmt.Errorf("invalid Content-Range header format (expected 'bytes ' prefix): %s", contentRange)
|
|
}
|
|
|
|
parts := strings.Split(strings.TrimPrefix(contentRange, "bytes "), "/")
|
|
if len(parts) != 2 {
|
|
return 0, fmt.Errorf("invalid Content-Range header format (expected 2 parts separated by '/'): %s", contentRange)
|
|
}
|
|
|
|
if parts[1] == "" || parts[1] == "*" {
|
|
return -1, nil
|
|
}
|
|
|
|
length, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("invalid Content-Range total length value '%s': %w", parts[1], err)
|
|
}
|
|
|
|
if length < 0 {
|
|
return 0, fmt.Errorf("Content-Range total length cannot be negative: %d", length)
|
|
}
|
|
|
|
return length, nil
|
|
}
|